diff --git a/tests/test_runners.py b/tests/test_runners.py index 8a817c1f6..7e03d7e54 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -108,6 +108,67 @@ def test_integrate_double_nvt( assert not torch.isnan(final_state.energy).any() +def test_integrate_converts_init_kwarg( + ar_supercell_sim_state: SimState, lj_model: LennardJonesModel +) -> None: + """integrate scales Nose-Hoover `tau` to internal units like `timestep`. + + Otherwise `tau` is ~98x too small for metal units, giving an over-stiff + thermostat that diverges on force spikes + + See https://github.com/TorchSim/torch-sim/issues/579 for more info + """ + tau = 0.1 # ps, same convention as `timestep` + final = ts.integrate( + system=ar_supercell_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_nose_hoover, + n_steps=1, + temperature=100.0, + timestep=0.002, + init_kwargs={"tau": tau}, + ) + expected = tau * ts.units.MetalUnits.time + assert torch.allclose(final.chain.tau, torch.full_like(final.chain.tau, expected)) + + +def test_integrate_converts_step_kwarg( + ar_supercell_sim_state: SimState, + lj_model: LennardJonesModel, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """integrate divides inverse-time kwargs (e.g. Langevin `gamma`) by the unit factor. + + The opposite direction from relaxation times: `gamma` is a rate, so it scales as + 1/time, not time. It is a step kwarg, so it travels through `**integrator_kwargs` + to the step function (issue #579 / InverseTimeArg). + """ + gamma = 10.0 # 1/ps, same time convention as `timestep` + # gamma is consumed inside the step function and never stored on the state, so + # unlike the tau test in test_integrate_converts_persistent_init_kwarg, we must + # spy on the step call to see the converted value. + received: dict[str, object] = {} + init, real_step = ts.integrators.INTEGRATOR_REGISTRY[ts.Integrator.nvt_langevin] + + def spy_step(**kwargs: object) -> MDState: + received["gamma"] = kwargs["gamma"] + return real_step(**kwargs) + + monkeypatch.setitem( + ts.integrators.INTEGRATOR_REGISTRY, ts.Integrator.nvt_langevin, (init, spy_step) + ) + ts.integrate( + system=ar_supercell_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=1, + temperature=100.0, + timestep=0.002, + gamma=gamma, + ) + assert received["gamma"] == gamma / ts.units.MetalUnits.time + + def test_integrate_double_nvt_multiple_temperatures( ar_double_sim_state: SimState, lj_model: LennardJonesModel ) -> None: diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index e6f09f534..3af9ac8fe 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -74,13 +74,25 @@ """ # ruff: noqa: F401 +import typing from collections.abc import Callable +from dataclasses import dataclass from enum import StrEnum -from typing import Any, Final +from typing import Any, Final, Literal import torch_sim as ts -from .md import MDState, initialize_momenta, momentum_step, position_step, velocity_verlet +from .md import ( + InversePressureArg, + InverseTimeArg, + MDState, + PressureArg, + TimeArg, + initialize_momenta, + momentum_step, + position_step, + velocity_verlet, +) from .npt import ( NPTLangevinAnisotropicState, NPTLangevinIsotropicState, @@ -202,3 +214,52 @@ class Integrator(StrEnum): npt_crescale_triclinic_step, ), } + + +@dataclass(frozen=True, slots=True) +class UnitKwarg: + """Metadata for integrator parameters carrying physical units. + + ``factor`` converts the parameter to internal units (multiply by it). + ``channel`` is the function the parameter is used in (either ``"init"`` + or ``"step"``). + """ + + factor: float + channel: Literal["init", "step"] + + +_UNIT_FACTORS: Final = frozenset( + arg.__metadata__[0] + for arg in (TimeArg, InverseTimeArg, PressureArg, InversePressureArg) +) + + +def _collect_unit_kwargs(integrator: Integrator) -> dict[str, UnitKwarg]: + """Read the unit-annotated kwargs off an integrator's init/step signatures.""" + init_fn, step_fn = INTEGRATOR_REGISTRY[integrator] + unit_kwargs: dict[str, UnitKwarg] = {} + for channel, fn in (("init", init_fn), ("step", step_fn)): + hints = typing.get_type_hints(fn, include_extras=True) + for name, hint in hints.items(): + for meta in getattr(hint, "__metadata__", ()): + if isinstance(meta, float) and meta in _UNIT_FACTORS: + if name in unit_kwargs: + raise TypeError( + f"{integrator}: unit kwarg {name!r} is annotated in both " + f"the init and step functions, making its `integrate` " + f"channel ambiguous" + ) + unit_kwargs[name] = UnitKwarg(float(meta), channel) + return unit_kwargs + + +#: Unit-carrying kwargs of each integrator, derived at import time from the +#: ``TimeArg``/``InverseTimeArg``/``PressureArg``/``InversePressureArg`` annotations +#: on the registered init/step functions. :func:`torch_sim.runners.integrate` reads +#: this to unit-convert kwargs. +INTEGRATOR_UNIT_KWARGS: Final[dict[Integrator, dict[str, UnitKwarg]]] = { + integrator: unit_kwargs + for integrator in Integrator + if (unit_kwargs := _collect_unit_kwargs(integrator)) +} diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index a88685746..1b82d660e 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -4,6 +4,7 @@ import warnings from collections.abc import Callable from dataclasses import dataclass +from typing import Annotated import torch @@ -17,6 +18,13 @@ logger = logging.getLogger(__name__) +# The metadata is the factor that converts the kwarg to internal units +TimeArg = Annotated[float | torch.Tensor | None, MetalUnits.time] +InverseTimeArg = Annotated[float | torch.Tensor | None, 1 / MetalUnits.time] +PressureArg = Annotated[float | torch.Tensor, MetalUnits.pressure] +InversePressureArg = Annotated[float | torch.Tensor | None, 1 / MetalUnits.pressure] + + @dataclass(kw_only=True) class MDState(SimState): """State information for molecular dynamics simulations. diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 35cacd32f..f12fc39a5 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -11,9 +11,13 @@ import torch_sim as ts from torch_sim._duecredit import dcite from torch_sim.integrators.md import ( + InversePressureArg, + InverseTimeArg, MDState, NoseHooverChain, NoseHooverChainFns, + PressureArg, + TimeArg, construct_nose_hoover_chain, initialize_momenta, momentum_step, @@ -407,9 +411,9 @@ def npt_langevin_anisotropic_init( *, kT: float | torch.Tensor, dt: float | torch.Tensor, - alpha: float | torch.Tensor | None = None, - cell_alpha: float | torch.Tensor | None = None, - b_tau: float | torch.Tensor | None = None, + alpha: InverseTimeArg = None, + cell_alpha: InverseTimeArg = None, + b_tau: TimeArg = None, **_kwargs: Any, ) -> NPTLangevinAnisotropicState: """Initialize NPT Langevin state with independent per-dimension cell lengths. @@ -511,7 +515,7 @@ def npt_langevin_anisotropic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, + external_pressure: PressureArg, ) -> NPTLangevinAnisotropicState: r"""Perform one NPT Langevin step with independent per-dimension cell lengths. @@ -886,9 +890,9 @@ def npt_langevin_isotropic_init( *, kT: float | torch.Tensor, dt: float | torch.Tensor, - alpha: float | torch.Tensor | None = None, - cell_alpha: float | torch.Tensor | None = None, - b_tau: float | torch.Tensor | None = None, + alpha: InverseTimeArg = None, + cell_alpha: InverseTimeArg = None, + b_tau: TimeArg = None, **_kwargs: Any, ) -> NPTLangevinIsotropicState: """Initialize an NPT Langevin state using logarithmic strain coordinate. @@ -986,7 +990,7 @@ def npt_langevin_isotropic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, + external_pressure: PressureArg, ) -> NPTLangevinIsotropicState: r"""Perform one NPT Langevin step using logarithmic strain coordinate. @@ -1643,8 +1647,8 @@ def npt_nose_hoover_isotropic_init( chain_length: int = 3, chain_steps: int = 2, sy_steps: int = 3, - t_tau: float | torch.Tensor | None = None, - b_tau: float | torch.Tensor | None = None, + t_tau: TimeArg = None, + b_tau: TimeArg = None, **kwargs: Any, ) -> NPTNoseHooverIsotropicState: """Initialize the NPT Nose-Hoover state. @@ -1810,7 +1814,7 @@ def npt_nose_hoover_isotropic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, + external_pressure: PressureArg, ) -> NPTNoseHooverIsotropicState: r"""Perform a complete NPT integration step with Nose-Hoover chain thermostats. @@ -2515,8 +2519,8 @@ def npt_crescale_triclinic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + external_pressure: PressureArg, + tau: TimeArg = None, ) -> NPTCRescaleState: r"""Perform one NPT integration step with anisotropic stochastic cell rescaling. @@ -2643,8 +2647,8 @@ def npt_crescale_anisotropic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + external_pressure: PressureArg, + tau: TimeArg = None, ) -> NPTCRescaleState: """Perform one NPT integration step with cell rescaling barostat. @@ -2719,8 +2723,8 @@ def npt_crescale_triclinic_average_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + external_pressure: PressureArg, + tau: TimeArg = None, ) -> NPTCRescaleState: """Perform one NPT integration step with cell rescaling barostat. @@ -2795,8 +2799,8 @@ def npt_crescale_isotropic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + external_pressure: PressureArg, + tau: TimeArg = None, ) -> NPTCRescaleState: r"""Perform one NPT integration step with isotropic stochastic cell rescaling. @@ -2902,8 +2906,8 @@ def npt_crescale_init( *, kT: float | torch.Tensor, dt: float | torch.Tensor, - tau_p: float | torch.Tensor | None = None, - isothermal_compressibility: float | torch.Tensor | None = None, + tau_p: TimeArg = None, + isothermal_compressibility: InversePressureArg = None, ) -> NPTCRescaleState: """Initialize the NPT cell rescaling state. diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 644b10897..9007e63dc 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -8,9 +8,11 @@ import torch_sim as ts from torch_sim._duecredit import dcite from torch_sim.integrators.md import ( + InverseTimeArg, MDState, NoseHooverChain, NoseHooverChainFns, + TimeArg, construct_nose_hoover_chain, initialize_momenta, momentum_step, @@ -143,7 +145,7 @@ def nvt_langevin_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - gamma: float | torch.Tensor | None = None, + gamma: InverseTimeArg = None, ) -> MDState: r"""Perform one complete Langevin dynamics integration step using the BAOAB scheme. @@ -291,7 +293,7 @@ def nvt_nose_hoover_init( *, kT: float | torch.Tensor, dt: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + tau: TimeArg = None, chain_length: int = 3, chain_steps: int = 3, sy_steps: int = 3, @@ -711,7 +713,7 @@ def nvt_vrescale_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + tau: TimeArg = None, ) -> NVTVRescaleState: r"""Perform one complete V-Rescale (CSVR) dynamics integration step. diff --git a/torch_sim/runners.py b/torch_sim/runners.py index cd46cc7ce..2f7bb7ccd 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -18,7 +18,7 @@ import torch_sim as ts from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher -from torch_sim.integrators import INTEGRATOR_REGISTRY, Integrator +from torch_sim.integrators import INTEGRATOR_REGISTRY, INTEGRATOR_UNIT_KWARGS, Integrator from torch_sim.integrators.md import MDState from torch_sim.models.interface import ModelInterface from torch_sim.optimizers import OPTIM_REGISTRY, FireState, Optimizer, OptimState @@ -249,7 +249,7 @@ def _write_initial_state( trajectory_reporter.report(state, 0, model=model) -def integrate[T: SimState]( # noqa: C901 +def integrate[T: SimState]( # noqa: C901, PLR0915 system: StateLike, model: ModelInterface, *, @@ -287,7 +287,7 @@ def integrate[T: SimState]( # noqa: C901 it's passed to `tqdm` as kwargs. init_kwargs (dict[str, Any], optional): Additional keyword arguments for integrator init function. - **integrator_kwargs: Additional keyword arguments for integrator init function + **integrator_kwargs: Additional keyword arguments for integrator step function Returns: T: Final state after integration @@ -320,6 +320,16 @@ def integrate[T: SimState]( # noqa: C901 f"integrator must be key from Integrator or a tuple of " f"(init_func, step_func), got {type(integrator)}" ) + + # Like `timestep` above, unit-carrying kwargs (e.g. `tau`, `gamma`, + # `external_pressure`) must be converted to internal units per + # INTEGRATOR_UNIT_KWARGS + init_kwargs = dict(init_kwargs or {}) + channels = {"init": init_kwargs, "step": integrator_kwargs} + for key, meta in INTEGRATOR_UNIT_KWARGS.get(integrator, {}).items(): + kwargs = channels[meta.channel] + if kwargs.get(key) is not None: + kwargs[key] = kwargs[key] * meta.factor # batch_iterator will be a list if autobatcher is False batch_iterator = _configure_batches_iterator( initial_state, model, autobatcher=autobatcher @@ -351,9 +361,7 @@ def integrate[T: SimState]( # noqa: C901 batch_kT = ( kTs[:, system_indices] if (system_indices and len(kTs.shape) == 2) else kTs ) - state = init_func( - state=state, model=model, kT=batch_kT[0], dt=dt, **init_kwargs or {} - ) + state = init_func(state=state, model=model, kT=batch_kT[0], dt=dt, **init_kwargs) # set up trajectory reporters if autobatcher and trajectory_reporter is not None and og_filenames is not None: diff --git a/torch_sim/units.py b/torch_sim/units.py index 19cd6c4fc..146adab22 100644 --- a/torch_sim/units.py +++ b/torch_sim/units.py @@ -15,7 +15,7 @@ class BaseConstant(float, Enum): References: http://arxiv.org/pdf/1507.07956.pdf - https://wiki.fysik.dtu.dk/ase/_modules/ase/units.html#create_units + https://docs.ase-lib.org/_modules/ase/units.html#create_units """ def __new__(cls, value: float) -> Self: