diff --git a/src/aca_model/agent/labor_market.py b/src/aca_model/agent/labor_market.py index 1421e9e..5665d92 100644 --- a/src/aca_model/agent/labor_market.py +++ b/src/aca_model/agent/labor_market.py @@ -4,6 +4,7 @@ """ import jax.numpy as jnp +import numpy as np from lcm import categorical from lcm.typing import ( ContinuousState, @@ -39,12 +40,16 @@ class SpousalIncome: married_has_inc: ScalarInt -HOURS_VALUES = jnp.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0]) +# Host array, not a module-level JAX array: a device array here would +# reserve the GPU memory pool at import time in every process that imports +# the model. It is converted to a device array at each indexing site, where +# the value folds into the surrounding compiled function. +HOURS_VALUES = np.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0]) def working_hours_value(labor_supply: DiscreteAction) -> FloatND: """Map labor supply choice to annual hours worked.""" - return HOURS_VALUES[labor_supply] + return jnp.asarray(HOURS_VALUES)[labor_supply] def wage( @@ -74,7 +79,7 @@ def income( income = wage * hours^(1 + exp) * int^(-exp) """ - hours = HOURS_VALUES[labor_supply] + hours = jnp.asarray(HOURS_VALUES)[labor_supply] return jnp.where( hours > 0.0, wage diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 381dec6..eeba5bd 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -7,7 +7,6 @@ from lcm import categorical from lcm.typing import ( Age, - BoolND, ContinuousAction, ContinuousState, DiscreteState, @@ -19,6 +18,11 @@ from aca_model.agent.labor_market import LaggedLaborSupply +# Width of the smooth leisure floor, as a fraction of the time endowment. Small enough +# that leisure equals `time_endowment - cost` wherever work costs sit well below the +# endowment; it only bends the map near and beyond the endowment. +_LEISURE_SMOOTHING_FRACTION = 0.01 + @categorical(ordered=False) class PrefType: @@ -44,11 +48,6 @@ class BenchmarkPrefType: type_1: ScalarInt -def positive_leisure(leisure: FloatND) -> BoolND: - """Return True where leisure is strictly positive.""" - return leisure > 0 - - def equivalence_scale(is_married: IntND, exponent: ScalarFloat) -> FloatND: """Return the equivalence scale for household size adjustment. @@ -69,6 +68,22 @@ def fixed_cost_of_work( ) +def _smooth_leisure_floor( + leisure_available: FloatND, time_endowment: ScalarFloat +) -> FloatND: + """Bend leisure to a strictly positive floor as work costs approach the endowment. + + `softplus(x) = log(1 + e^x)` via `jnp.logaddexp(0, x)`, scaled by a small fraction + of the endowment. Where `leisure_available` is large relative to the smoothing width + the map reduces to `leisure_available` (bulk unchanged); as it falls to zero leisure + bends to `0⁺` — never negative, never a kinked clamp — so the CRRA aggregator never + receives a non-positive base. The smoothing width scales with the endowment, so the + map is scale-invariant. + """ + smoothing = _LEISURE_SMOOTHING_FRACTION * time_endowment + return smoothing * jnp.logaddexp(0.0, leisure_available / smoothing) + + def leisure_canwork_retiree_or_nongroup( working_hours_value: FloatND, good_health: IntND, @@ -94,7 +109,8 @@ def leisure_canwork_retiree_or_nongroup( 0.0, ) - return time_endowment - health_loss - work_loss + leisure_available = time_endowment - health_loss - work_loss + return _smooth_leisure_floor(leisure_available, time_endowment) def leisure_canwork_tied( @@ -112,7 +128,8 @@ def leisure_canwork_tied( work_loss = jnp.where( working_hours_value > 0.0, working_hours_value + fixed_cost_of_work, 0.0 ) - return time_endowment - health_loss - work_loss + leisure_available = time_endowment - health_loss - work_loss + return _smooth_leisure_floor(leisure_available, time_endowment) def leisure_forcedout( @@ -122,7 +139,8 @@ def leisure_forcedout( ) -> FloatND: """Compute leisure for forcedout regimes (no work).""" health_loss = jnp.where(good_health, 0.0, leisure_cost_of_bad_health) - return time_endowment - health_loss + leisure_available = time_endowment - health_loss + return _smooth_leisure_floor(leisure_available, time_endowment) def consumption_equiv( diff --git a/src/aca_model/baseline/regimes/_nongroup.py b/src/aca_model/baseline/regimes/_nongroup.py index 151929c..410ef6f 100644 --- a/src/aca_model/baseline/regimes/_nongroup.py +++ b/src/aca_model/baseline/regimes/_nongroup.py @@ -10,7 +10,6 @@ from lcm.solvers import DCEGM from lcm.typing import Age, DiscreteAction, FloatND, Period -from aca_model.agent import preferences from aca_model.agent.labor_market import LaborSupply from aca_model.baseline import health_insurance from aca_model.baseline.regimes._common import ( @@ -112,10 +111,6 @@ def build_regime( transition_func = _make_transition_forcedout(gets_mc, own) states = build_states(spec, grids) - # `borrowing_constraint` is broadcast from the model level. - constraints: dict = {} - if spec["canwork"] == "canwork": - constraints["positive_leisure"] = preferences.positive_leisure solver_kwargs: dict = {} if dcegm_solver is None else {"solver": dcegm_solver} return Regime( @@ -129,6 +124,5 @@ def build_regime( ), actions=build_actions(spec, grids), functions=_build_functions(spec), - constraints=constraints, **solver_kwargs, ) diff --git a/src/aca_model/baseline/regimes/_retiree.py b/src/aca_model/baseline/regimes/_retiree.py index 1dd054b..2c86376 100644 --- a/src/aca_model/baseline/regimes/_retiree.py +++ b/src/aca_model/baseline/regimes/_retiree.py @@ -11,7 +11,6 @@ from lcm.solvers import DCEGM from lcm.typing import Age, BoolND, DiscreteAction, FloatND, Period -from aca_model.agent import preferences from aca_model.agent.labor_market import LaborSupply from aca_model.baseline import health_insurance from aca_model.baseline.regimes._common import ( @@ -122,10 +121,6 @@ def build_regime( transition_func = _make_transition_forcedout(gets_mc, own, ng) states = build_states(spec, grids) - # `borrowing_constraint` is broadcast from the model level. - constraints: dict = {} - if spec["canwork"] == "canwork": - constraints["positive_leisure"] = preferences.positive_leisure solver_kwargs: dict = {} if dcegm_solver is None else {"solver": dcegm_solver} return Regime( @@ -139,6 +134,5 @@ def build_regime( ), actions=build_actions(spec, grids), functions=_build_functions(spec), - constraints=constraints, **solver_kwargs, ) diff --git a/src/aca_model/baseline/regimes/_tied.py b/src/aca_model/baseline/regimes/_tied.py index d7b66ad..b1f6f41 100644 --- a/src/aca_model/baseline/regimes/_tied.py +++ b/src/aca_model/baseline/regimes/_tied.py @@ -12,7 +12,6 @@ from lcm.solvers import DCEGM from lcm.typing import Age, BoolND, DiscreteAction, FloatND, Period -from aca_model.agent import preferences from aca_model.agent.labor_market import LaborSupply from aca_model.baseline import health_insurance from aca_model.baseline.regimes._common import ( @@ -105,7 +104,5 @@ def build_regime( ), actions=build_actions(spec, grids), functions=_build_functions(spec), - # `borrowing_constraint` is broadcast from the model level. - constraints={"positive_leisure": preferences.positive_leisure}, **solver_kwargs, ) diff --git a/tests/test_labor_market.py b/tests/test_labor_market.py index 18dcaa2..f26992a 100644 --- a/tests/test_labor_market.py +++ b/tests/test_labor_market.py @@ -2,11 +2,35 @@ import jax.numpy as jnp import numpy as np +import pytest from aca_model.agent import labor_market from aca_model.agent.labor_market import LaborSupply +def test_hours_values_is_host_array_so_import_allocates_no_device_memory() -> None: + """`HOURS_VALUES` is a host (NumPy) array, not a device-pinned JAX array. + + A module-level JAX array materializes on the default device the moment the + module is imported, reserving the GPU memory pool in every process that + imports the model — including the estimation orchestrator, which only + launches GPU worker ranks and must leave the devices free for them. + """ + assert isinstance(labor_market.HOURS_VALUES, np.ndarray) + + +@pytest.mark.parametrize( + ("choice", "expected_hours"), + [(0, 0.0), (1, 1000.0), (2, 1500.0), (3, 2000.0), (4, 2500.0)], +) +def test_working_hours_value_maps_choice_to_annual_hours( + choice: int, expected_hours: float +) -> None: + """Each labor-supply choice maps to its annual hours worked.""" + result = labor_market.working_hours_value(jnp.asarray(choice, dtype=jnp.int32)) + np.testing.assert_allclose(float(result), expected_hours) + + def test_wage_combines_age_health_profile_with_residual() -> None: """`wage = exp(log_ft_wage_mean[period, good_health] + log_ft_wage_std * res)`.""" log_ft_wage_mean = jnp.array([[1.0, 2.0], [3.0, 4.0]]) # [period, good_health] diff --git a/tests/test_model_components.py b/tests/test_model_components.py index 733fecd..5caa981 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -67,6 +67,49 @@ def test_leisure_bad_health() -> None: assert jnp.isclose(result, 4500.0) +def test_leisure_canwork_stays_positive_when_work_cost_meets_endowment() -> None: + """Leisure bends to a strictly positive floor when work costs reach the endowment. + + Without a floor, leisure would be zero (or negative past the endowment) and feed a + non-positive base into the CRRA aggregator. The smooth floor keeps it clearly + positive. + """ + result = preferences.leisure_canwork_retiree_or_nongroup( + working_hours_value=jnp.array(4500.0), + good_health=jnp.int32(1), + lagged_labor_supply=jnp.int32(1), + time_endowment=jnp.asarray(5000.0), + leisure_cost_of_bad_health=jnp.asarray(0.0), + fixed_cost_of_work=jnp.asarray(500.0), # 4500 + 500 == 5000 == endowment + labor_force_reentry_cost=jnp.asarray(0.0), + ) + assert result > 1.0 + + +def test_leisure_canwork_tied_decreases_smoothly_into_saturation() -> None: + """Past the endowment, more work cost still lowers leisure, never below zero. + + A flat clamp or a raw `endowment - cost` would either pin leisure or drive it + negative; the smooth floor keeps it strictly positive and strictly decreasing. + """ + common = { + "good_health": jnp.int32(1), + "time_endowment": jnp.asarray(5000.0), + "leisure_cost_of_bad_health": jnp.asarray(0.0), + } + at_endowment = preferences.leisure_canwork_tied( + working_hours_value=jnp.array(4500.0), + fixed_cost_of_work=jnp.asarray(500.0), # cost == endowment + **common, + ) + past_endowment = preferences.leisure_canwork_tied( + working_hours_value=jnp.array(4500.0), + fixed_cost_of_work=jnp.asarray(700.0), # cost exceeds endowment by 200 + **common, + ) + assert at_endowment > past_endowment > 0.0 + + def test_utility_positive_leisure() -> None: result = preferences.u_alive( consumption_equiv=jnp.array(10000.0),