From 97f483ae177dee4df8e55ceca13b5ee2e2e2fd70 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 20 Jun 2026 13:50:27 -0700 Subject: [PATCH 1/2] wip fix 579 add run cmd simplify repro new repro repro fix use their model integrator fixes revert to sevennet cleanup get structures in memory better integrator definitions cleanup better test names more cleanup cleanup more cleanup cleanup cleanup cleanup delete repro file fix other params better name use constants cleanup --- tests/test_runners.py | 61 +++++++++++++++++++++++++++++ torch_sim/integrators/__init__.py | 65 ++++++++++++++++++++++++++++++- torch_sim/integrators/md.py | 8 ++++ torch_sim/integrators/npt.py | 46 ++++++++++++---------- torch_sim/integrators/nvt.py | 8 ++-- torch_sim/runners.py | 20 +++++++--- 6 files changed, 176 insertions(+), 32 deletions(-) diff --git a/tests/test_runners.py b/tests/test_runners.py index 8a817c1f6..7e03d7e54 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -108,6 +108,67 @@ def test_integrate_double_nvt( assert not torch.isnan(final_state.energy).any() +def test_integrate_converts_init_kwarg( + ar_supercell_sim_state: SimState, lj_model: LennardJonesModel +) -> None: + """integrate scales Nose-Hoover `tau` to internal units like `timestep`. + + Otherwise `tau` is ~98x too small for metal units, giving an over-stiff + thermostat that diverges on force spikes + + See https://github.com/TorchSim/torch-sim/issues/579 for more info + """ + tau = 0.1 # ps, same convention as `timestep` + final = ts.integrate( + system=ar_supercell_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_nose_hoover, + n_steps=1, + temperature=100.0, + timestep=0.002, + init_kwargs={"tau": tau}, + ) + expected = tau * ts.units.MetalUnits.time + assert torch.allclose(final.chain.tau, torch.full_like(final.chain.tau, expected)) + + +def test_integrate_converts_step_kwarg( + ar_supercell_sim_state: SimState, + lj_model: LennardJonesModel, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """integrate divides inverse-time kwargs (e.g. Langevin `gamma`) by the unit factor. + + The opposite direction from relaxation times: `gamma` is a rate, so it scales as + 1/time, not time. It is a step kwarg, so it travels through `**integrator_kwargs` + to the step function (issue #579 / InverseTimeArg). + """ + gamma = 10.0 # 1/ps, same time convention as `timestep` + # gamma is consumed inside the step function and never stored on the state, so + # unlike the tau test in test_integrate_converts_persistent_init_kwarg, we must + # spy on the step call to see the converted value. + received: dict[str, object] = {} + init, real_step = ts.integrators.INTEGRATOR_REGISTRY[ts.Integrator.nvt_langevin] + + def spy_step(**kwargs: object) -> MDState: + received["gamma"] = kwargs["gamma"] + return real_step(**kwargs) + + monkeypatch.setitem( + ts.integrators.INTEGRATOR_REGISTRY, ts.Integrator.nvt_langevin, (init, spy_step) + ) + ts.integrate( + system=ar_supercell_sim_state, + model=lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=1, + temperature=100.0, + timestep=0.002, + gamma=gamma, + ) + assert received["gamma"] == gamma / ts.units.MetalUnits.time + + def test_integrate_double_nvt_multiple_temperatures( ar_double_sim_state: SimState, lj_model: LennardJonesModel ) -> None: diff --git a/torch_sim/integrators/__init__.py b/torch_sim/integrators/__init__.py index e6f09f534..3af9ac8fe 100644 --- a/torch_sim/integrators/__init__.py +++ b/torch_sim/integrators/__init__.py @@ -74,13 +74,25 @@ """ # ruff: noqa: F401 +import typing from collections.abc import Callable +from dataclasses import dataclass from enum import StrEnum -from typing import Any, Final +from typing import Any, Final, Literal import torch_sim as ts -from .md import MDState, initialize_momenta, momentum_step, position_step, velocity_verlet +from .md import ( + InversePressureArg, + InverseTimeArg, + MDState, + PressureArg, + TimeArg, + initialize_momenta, + momentum_step, + position_step, + velocity_verlet, +) from .npt import ( NPTLangevinAnisotropicState, NPTLangevinIsotropicState, @@ -202,3 +214,52 @@ class Integrator(StrEnum): npt_crescale_triclinic_step, ), } + + +@dataclass(frozen=True, slots=True) +class UnitKwarg: + """Metadata for integrator parameters carrying physical units. + + ``factor`` converts the parameter to internal units (multiply by it). + ``channel`` is the function the parameter is used in (either ``"init"`` + or ``"step"``). + """ + + factor: float + channel: Literal["init", "step"] + + +_UNIT_FACTORS: Final = frozenset( + arg.__metadata__[0] + for arg in (TimeArg, InverseTimeArg, PressureArg, InversePressureArg) +) + + +def _collect_unit_kwargs(integrator: Integrator) -> dict[str, UnitKwarg]: + """Read the unit-annotated kwargs off an integrator's init/step signatures.""" + init_fn, step_fn = INTEGRATOR_REGISTRY[integrator] + unit_kwargs: dict[str, UnitKwarg] = {} + for channel, fn in (("init", init_fn), ("step", step_fn)): + hints = typing.get_type_hints(fn, include_extras=True) + for name, hint in hints.items(): + for meta in getattr(hint, "__metadata__", ()): + if isinstance(meta, float) and meta in _UNIT_FACTORS: + if name in unit_kwargs: + raise TypeError( + f"{integrator}: unit kwarg {name!r} is annotated in both " + f"the init and step functions, making its `integrate` " + f"channel ambiguous" + ) + unit_kwargs[name] = UnitKwarg(float(meta), channel) + return unit_kwargs + + +#: Unit-carrying kwargs of each integrator, derived at import time from the +#: ``TimeArg``/``InverseTimeArg``/``PressureArg``/``InversePressureArg`` annotations +#: on the registered init/step functions. :func:`torch_sim.runners.integrate` reads +#: this to unit-convert kwargs. +INTEGRATOR_UNIT_KWARGS: Final[dict[Integrator, dict[str, UnitKwarg]]] = { + integrator: unit_kwargs + for integrator in Integrator + if (unit_kwargs := _collect_unit_kwargs(integrator)) +} diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index a88685746..1b82d660e 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -4,6 +4,7 @@ import warnings from collections.abc import Callable from dataclasses import dataclass +from typing import Annotated import torch @@ -17,6 +18,13 @@ logger = logging.getLogger(__name__) +# The metadata is the factor that converts the kwarg to internal units +TimeArg = Annotated[float | torch.Tensor | None, MetalUnits.time] +InverseTimeArg = Annotated[float | torch.Tensor | None, 1 / MetalUnits.time] +PressureArg = Annotated[float | torch.Tensor, MetalUnits.pressure] +InversePressureArg = Annotated[float | torch.Tensor | None, 1 / MetalUnits.pressure] + + @dataclass(kw_only=True) class MDState(SimState): """State information for molecular dynamics simulations. diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 35cacd32f..f12fc39a5 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -11,9 +11,13 @@ import torch_sim as ts from torch_sim._duecredit import dcite from torch_sim.integrators.md import ( + InversePressureArg, + InverseTimeArg, MDState, NoseHooverChain, NoseHooverChainFns, + PressureArg, + TimeArg, construct_nose_hoover_chain, initialize_momenta, momentum_step, @@ -407,9 +411,9 @@ def npt_langevin_anisotropic_init( *, kT: float | torch.Tensor, dt: float | torch.Tensor, - alpha: float | torch.Tensor | None = None, - cell_alpha: float | torch.Tensor | None = None, - b_tau: float | torch.Tensor | None = None, + alpha: InverseTimeArg = None, + cell_alpha: InverseTimeArg = None, + b_tau: TimeArg = None, **_kwargs: Any, ) -> NPTLangevinAnisotropicState: """Initialize NPT Langevin state with independent per-dimension cell lengths. @@ -511,7 +515,7 @@ def npt_langevin_anisotropic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, + external_pressure: PressureArg, ) -> NPTLangevinAnisotropicState: r"""Perform one NPT Langevin step with independent per-dimension cell lengths. @@ -886,9 +890,9 @@ def npt_langevin_isotropic_init( *, kT: float | torch.Tensor, dt: float | torch.Tensor, - alpha: float | torch.Tensor | None = None, - cell_alpha: float | torch.Tensor | None = None, - b_tau: float | torch.Tensor | None = None, + alpha: InverseTimeArg = None, + cell_alpha: InverseTimeArg = None, + b_tau: TimeArg = None, **_kwargs: Any, ) -> NPTLangevinIsotropicState: """Initialize an NPT Langevin state using logarithmic strain coordinate. @@ -986,7 +990,7 @@ def npt_langevin_isotropic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, + external_pressure: PressureArg, ) -> NPTLangevinIsotropicState: r"""Perform one NPT Langevin step using logarithmic strain coordinate. @@ -1643,8 +1647,8 @@ def npt_nose_hoover_isotropic_init( chain_length: int = 3, chain_steps: int = 2, sy_steps: int = 3, - t_tau: float | torch.Tensor | None = None, - b_tau: float | torch.Tensor | None = None, + t_tau: TimeArg = None, + b_tau: TimeArg = None, **kwargs: Any, ) -> NPTNoseHooverIsotropicState: """Initialize the NPT Nose-Hoover state. @@ -1810,7 +1814,7 @@ def npt_nose_hoover_isotropic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, + external_pressure: PressureArg, ) -> NPTNoseHooverIsotropicState: r"""Perform a complete NPT integration step with Nose-Hoover chain thermostats. @@ -2515,8 +2519,8 @@ def npt_crescale_triclinic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + external_pressure: PressureArg, + tau: TimeArg = None, ) -> NPTCRescaleState: r"""Perform one NPT integration step with anisotropic stochastic cell rescaling. @@ -2643,8 +2647,8 @@ def npt_crescale_anisotropic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + external_pressure: PressureArg, + tau: TimeArg = None, ) -> NPTCRescaleState: """Perform one NPT integration step with cell rescaling barostat. @@ -2719,8 +2723,8 @@ def npt_crescale_triclinic_average_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + external_pressure: PressureArg, + tau: TimeArg = None, ) -> NPTCRescaleState: """Perform one NPT integration step with cell rescaling barostat. @@ -2795,8 +2799,8 @@ def npt_crescale_isotropic_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - external_pressure: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + external_pressure: PressureArg, + tau: TimeArg = None, ) -> NPTCRescaleState: r"""Perform one NPT integration step with isotropic stochastic cell rescaling. @@ -2902,8 +2906,8 @@ def npt_crescale_init( *, kT: float | torch.Tensor, dt: float | torch.Tensor, - tau_p: float | torch.Tensor | None = None, - isothermal_compressibility: float | torch.Tensor | None = None, + tau_p: TimeArg = None, + isothermal_compressibility: InversePressureArg = None, ) -> NPTCRescaleState: """Initialize the NPT cell rescaling state. diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 644b10897..9007e63dc 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -8,9 +8,11 @@ import torch_sim as ts from torch_sim._duecredit import dcite from torch_sim.integrators.md import ( + InverseTimeArg, MDState, NoseHooverChain, NoseHooverChainFns, + TimeArg, construct_nose_hoover_chain, initialize_momenta, momentum_step, @@ -143,7 +145,7 @@ def nvt_langevin_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - gamma: float | torch.Tensor | None = None, + gamma: InverseTimeArg = None, ) -> MDState: r"""Perform one complete Langevin dynamics integration step using the BAOAB scheme. @@ -291,7 +293,7 @@ def nvt_nose_hoover_init( *, kT: float | torch.Tensor, dt: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + tau: TimeArg = None, chain_length: int = 3, chain_steps: int = 3, sy_steps: int = 3, @@ -711,7 +713,7 @@ def nvt_vrescale_step( *, dt: float | torch.Tensor, kT: float | torch.Tensor, - tau: float | torch.Tensor | None = None, + tau: TimeArg = None, ) -> NVTVRescaleState: r"""Perform one complete V-Rescale (CSVR) dynamics integration step. diff --git a/torch_sim/runners.py b/torch_sim/runners.py index cd46cc7ce..2f7bb7ccd 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -18,7 +18,7 @@ import torch_sim as ts from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher -from torch_sim.integrators import INTEGRATOR_REGISTRY, Integrator +from torch_sim.integrators import INTEGRATOR_REGISTRY, INTEGRATOR_UNIT_KWARGS, Integrator from torch_sim.integrators.md import MDState from torch_sim.models.interface import ModelInterface from torch_sim.optimizers import OPTIM_REGISTRY, FireState, Optimizer, OptimState @@ -249,7 +249,7 @@ def _write_initial_state( trajectory_reporter.report(state, 0, model=model) -def integrate[T: SimState]( # noqa: C901 +def integrate[T: SimState]( # noqa: C901, PLR0915 system: StateLike, model: ModelInterface, *, @@ -287,7 +287,7 @@ def integrate[T: SimState]( # noqa: C901 it's passed to `tqdm` as kwargs. init_kwargs (dict[str, Any], optional): Additional keyword arguments for integrator init function. - **integrator_kwargs: Additional keyword arguments for integrator init function + **integrator_kwargs: Additional keyword arguments for integrator step function Returns: T: Final state after integration @@ -320,6 +320,16 @@ def integrate[T: SimState]( # noqa: C901 f"integrator must be key from Integrator or a tuple of " f"(init_func, step_func), got {type(integrator)}" ) + + # Like `timestep` above, unit-carrying kwargs (e.g. `tau`, `gamma`, + # `external_pressure`) must be converted to internal units per + # INTEGRATOR_UNIT_KWARGS + init_kwargs = dict(init_kwargs or {}) + channels = {"init": init_kwargs, "step": integrator_kwargs} + for key, meta in INTEGRATOR_UNIT_KWARGS.get(integrator, {}).items(): + kwargs = channels[meta.channel] + if kwargs.get(key) is not None: + kwargs[key] = kwargs[key] * meta.factor # batch_iterator will be a list if autobatcher is False batch_iterator = _configure_batches_iterator( initial_state, model, autobatcher=autobatcher @@ -351,9 +361,7 @@ def integrate[T: SimState]( # noqa: C901 batch_kT = ( kTs[:, system_indices] if (system_indices and len(kTs.shape) == 2) else kTs ) - state = init_func( - state=state, model=model, kT=batch_kT[0], dt=dt, **init_kwargs or {} - ) + state = init_func(state=state, model=model, kT=batch_kT[0], dt=dt, **init_kwargs) # set up trajectory reporters if autobatcher and trajectory_reporter is not None and og_filenames is not None: From 36e2f3612662fb1ac826e9a46a0878d21406b489 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Fri, 3 Jul 2026 12:12:21 -0700 Subject: [PATCH 2/2] fix lint --- torch_sim/units.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/units.py b/torch_sim/units.py index 19cd6c4fc..146adab22 100644 --- a/torch_sim/units.py +++ b/torch_sim/units.py @@ -15,7 +15,7 @@ class BaseConstant(float, Enum): References: http://arxiv.org/pdf/1507.07956.pdf - https://wiki.fysik.dtu.dk/ase/_modules/ase/units.html#create_units + https://docs.ase-lib.org/_modules/ase/units.html#create_units """ def __new__(cls, value: float) -> Self: