Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/aca_model/agent/labor_market.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import jax.numpy as jnp
import numpy as np
from lcm import categorical
from lcm.typing import (
ContinuousState,
Expand Down Expand Up @@ -39,12 +40,16 @@ class SpousalIncome:
married_has_inc: ScalarInt


HOURS_VALUES = jnp.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0])
# Host array, not a module-level JAX array: a device array here would
# reserve the GPU memory pool at import time in every process that imports
# the model. It is converted to a device array at each indexing site, where
# the value folds into the surrounding compiled function.
HOURS_VALUES = np.array([0.0, 1000.0, 1500.0, 2000.0, 2500.0])


def working_hours_value(labor_supply: DiscreteAction) -> FloatND:
"""Map labor supply choice to annual hours worked."""
return HOURS_VALUES[labor_supply]
return jnp.asarray(HOURS_VALUES)[labor_supply]


def wage(
Expand Down Expand Up @@ -74,7 +79,7 @@ def income(

income = wage * hours^(1 + exp) * int^(-exp)
"""
hours = HOURS_VALUES[labor_supply]
hours = jnp.asarray(HOURS_VALUES)[labor_supply]
return jnp.where(
hours > 0.0,
wage
Expand Down
36 changes: 27 additions & 9 deletions src/aca_model/agent/preferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from lcm import categorical
from lcm.typing import (
Age,
BoolND,
ContinuousAction,
ContinuousState,
DiscreteState,
Expand All @@ -19,6 +18,11 @@

from aca_model.agent.labor_market import LaggedLaborSupply

# Width of the smooth leisure floor, as a fraction of the time endowment. Small enough
# that leisure equals `time_endowment - cost` wherever work costs sit well below the
# endowment; it only bends the map near and beyond the endowment.
_LEISURE_SMOOTHING_FRACTION = 0.01


@categorical(ordered=False)
class PrefType:
Expand All @@ -44,11 +48,6 @@ class BenchmarkPrefType:
type_1: ScalarInt


def positive_leisure(leisure: FloatND) -> BoolND:
"""Return True where leisure is strictly positive."""
return leisure > 0


def equivalence_scale(is_married: IntND, exponent: ScalarFloat) -> FloatND:
"""Return the equivalence scale for household size adjustment.

Expand All @@ -69,6 +68,22 @@ def fixed_cost_of_work(
)


def _smooth_leisure_floor(
leisure_available: FloatND, time_endowment: ScalarFloat
) -> FloatND:
"""Bend leisure to a strictly positive floor as work costs approach the endowment.

`softplus(x) = log(1 + e^x)` via `jnp.logaddexp(0, x)`, scaled by a small fraction
of the endowment. Where `leisure_available` is large relative to the smoothing width
the map reduces to `leisure_available` (bulk unchanged); as it falls to zero leisure
bends to `0⁺` — never negative, never a kinked clamp — so the CRRA aggregator never
receives a non-positive base. The smoothing width scales with the endowment, so the
map is scale-invariant.
"""
smoothing = _LEISURE_SMOOTHING_FRACTION * time_endowment
return smoothing * jnp.logaddexp(0.0, leisure_available / smoothing)


def leisure_canwork_retiree_or_nongroup(
working_hours_value: FloatND,
good_health: IntND,
Expand All @@ -94,7 +109,8 @@ def leisure_canwork_retiree_or_nongroup(
0.0,
)

return time_endowment - health_loss - work_loss
leisure_available = time_endowment - health_loss - work_loss
return _smooth_leisure_floor(leisure_available, time_endowment)


def leisure_canwork_tied(
Expand All @@ -112,7 +128,8 @@ def leisure_canwork_tied(
work_loss = jnp.where(
working_hours_value > 0.0, working_hours_value + fixed_cost_of_work, 0.0
)
return time_endowment - health_loss - work_loss
leisure_available = time_endowment - health_loss - work_loss
return _smooth_leisure_floor(leisure_available, time_endowment)


def leisure_forcedout(
Expand All @@ -122,7 +139,8 @@ def leisure_forcedout(
) -> FloatND:
"""Compute leisure for forcedout regimes (no work)."""
health_loss = jnp.where(good_health, 0.0, leisure_cost_of_bad_health)
return time_endowment - health_loss
leisure_available = time_endowment - health_loss
return _smooth_leisure_floor(leisure_available, time_endowment)


def consumption_equiv(
Expand Down
6 changes: 0 additions & 6 deletions src/aca_model/baseline/regimes/_nongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from lcm.solvers import DCEGM
from lcm.typing import Age, DiscreteAction, FloatND, Period

from aca_model.agent import preferences
from aca_model.agent.labor_market import LaborSupply
from aca_model.baseline import health_insurance
from aca_model.baseline.regimes._common import (
Expand Down Expand Up @@ -112,10 +111,6 @@ def build_regime(
transition_func = _make_transition_forcedout(gets_mc, own)

states = build_states(spec, grids)
# `borrowing_constraint` is broadcast from the model level.
constraints: dict = {}
if spec["canwork"] == "canwork":
constraints["positive_leisure"] = preferences.positive_leisure

solver_kwargs: dict = {} if dcegm_solver is None else {"solver": dcegm_solver}
return Regime(
Expand All @@ -129,6 +124,5 @@ def build_regime(
),
actions=build_actions(spec, grids),
functions=_build_functions(spec),
constraints=constraints,
**solver_kwargs,
)
6 changes: 0 additions & 6 deletions src/aca_model/baseline/regimes/_retiree.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from lcm.solvers import DCEGM
from lcm.typing import Age, BoolND, DiscreteAction, FloatND, Period

from aca_model.agent import preferences
from aca_model.agent.labor_market import LaborSupply
from aca_model.baseline import health_insurance
from aca_model.baseline.regimes._common import (
Expand Down Expand Up @@ -122,10 +121,6 @@ def build_regime(
transition_func = _make_transition_forcedout(gets_mc, own, ng)

states = build_states(spec, grids)
# `borrowing_constraint` is broadcast from the model level.
constraints: dict = {}
if spec["canwork"] == "canwork":
constraints["positive_leisure"] = preferences.positive_leisure

solver_kwargs: dict = {} if dcegm_solver is None else {"solver": dcegm_solver}
return Regime(
Expand All @@ -139,6 +134,5 @@ def build_regime(
),
actions=build_actions(spec, grids),
functions=_build_functions(spec),
constraints=constraints,
**solver_kwargs,
)
3 changes: 0 additions & 3 deletions src/aca_model/baseline/regimes/_tied.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from lcm.solvers import DCEGM
from lcm.typing import Age, BoolND, DiscreteAction, FloatND, Period

from aca_model.agent import preferences
from aca_model.agent.labor_market import LaborSupply
from aca_model.baseline import health_insurance
from aca_model.baseline.regimes._common import (
Expand Down Expand Up @@ -105,7 +104,5 @@ def build_regime(
),
actions=build_actions(spec, grids),
functions=_build_functions(spec),
# `borrowing_constraint` is broadcast from the model level.
constraints={"positive_leisure": preferences.positive_leisure},
**solver_kwargs,
)
24 changes: 24 additions & 0 deletions tests/test_labor_market.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,35 @@

import jax.numpy as jnp
import numpy as np
import pytest

from aca_model.agent import labor_market
from aca_model.agent.labor_market import LaborSupply


def test_hours_values_is_host_array_so_import_allocates_no_device_memory() -> None:
"""`HOURS_VALUES` is a host (NumPy) array, not a device-pinned JAX array.

A module-level JAX array materializes on the default device the moment the
module is imported, reserving the GPU memory pool in every process that
imports the model — including the estimation orchestrator, which only
launches GPU worker ranks and must leave the devices free for them.
"""
assert isinstance(labor_market.HOURS_VALUES, np.ndarray)


@pytest.mark.parametrize(
("choice", "expected_hours"),
[(0, 0.0), (1, 1000.0), (2, 1500.0), (3, 2000.0), (4, 2500.0)],
)
def test_working_hours_value_maps_choice_to_annual_hours(
choice: int, expected_hours: float
) -> None:
"""Each labor-supply choice maps to its annual hours worked."""
result = labor_market.working_hours_value(jnp.asarray(choice, dtype=jnp.int32))
np.testing.assert_allclose(float(result), expected_hours)


def test_wage_combines_age_health_profile_with_residual() -> None:
"""`wage = exp(log_ft_wage_mean[period, good_health] + log_ft_wage_std * res)`."""
log_ft_wage_mean = jnp.array([[1.0, 2.0], [3.0, 4.0]]) # [period, good_health]
Expand Down
43 changes: 43 additions & 0 deletions tests/test_model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,49 @@ def test_leisure_bad_health() -> None:
assert jnp.isclose(result, 4500.0)


def test_leisure_canwork_stays_positive_when_work_cost_meets_endowment() -> None:
"""Leisure bends to a strictly positive floor when work costs reach the endowment.

Without a floor, leisure would be zero (or negative past the endowment) and feed a
non-positive base into the CRRA aggregator. The smooth floor keeps it clearly
positive.
"""
result = preferences.leisure_canwork_retiree_or_nongroup(
working_hours_value=jnp.array(4500.0),
good_health=jnp.int32(1),
lagged_labor_supply=jnp.int32(1),
time_endowment=jnp.asarray(5000.0),
leisure_cost_of_bad_health=jnp.asarray(0.0),
fixed_cost_of_work=jnp.asarray(500.0), # 4500 + 500 == 5000 == endowment
labor_force_reentry_cost=jnp.asarray(0.0),
)
assert result > 1.0


def test_leisure_canwork_tied_decreases_smoothly_into_saturation() -> None:
"""Past the endowment, more work cost still lowers leisure, never below zero.

A flat clamp or a raw `endowment - cost` would either pin leisure or drive it
negative; the smooth floor keeps it strictly positive and strictly decreasing.
"""
common = {
"good_health": jnp.int32(1),
"time_endowment": jnp.asarray(5000.0),
"leisure_cost_of_bad_health": jnp.asarray(0.0),
}
at_endowment = preferences.leisure_canwork_tied(
working_hours_value=jnp.array(4500.0),
fixed_cost_of_work=jnp.asarray(500.0), # cost == endowment
**common,
)
past_endowment = preferences.leisure_canwork_tied(
working_hours_value=jnp.array(4500.0),
fixed_cost_of_work=jnp.asarray(700.0), # cost exceeds endowment by 200
**common,
)
assert at_endowment > past_endowment > 0.0


def test_utility_positive_leisure() -> None:
result = preferences.u_alive(
consumption_equiv=jnp.array(10000.0),
Expand Down
Loading