Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions tests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,67 @@ def test_integrate_double_nvt(
assert not torch.isnan(final_state.energy).any()


def test_integrate_converts_init_kwarg(
ar_supercell_sim_state: SimState, lj_model: LennardJonesModel
) -> None:
"""integrate scales Nose-Hoover `tau` to internal units like `timestep`.

Otherwise `tau` is ~98x too small for metal units, giving an over-stiff
thermostat that diverges on force spikes

See https://github.com/TorchSim/torch-sim/issues/579 for more info
"""
tau = 0.1 # ps, same convention as `timestep`
final = ts.integrate(
system=ar_supercell_sim_state,
model=lj_model,
integrator=ts.Integrator.nvt_nose_hoover,
n_steps=1,
temperature=100.0,
timestep=0.002,
init_kwargs={"tau": tau},
)
expected = tau * ts.units.MetalUnits.time
assert torch.allclose(final.chain.tau, torch.full_like(final.chain.tau, expected))


def test_integrate_converts_step_kwarg(
ar_supercell_sim_state: SimState,
lj_model: LennardJonesModel,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""integrate divides inverse-time kwargs (e.g. Langevin `gamma`) by the unit factor.

The opposite direction from relaxation times: `gamma` is a rate, so it scales as
1/time, not time. It is a step kwarg, so it travels through `**integrator_kwargs`
to the step function (issue #579 / InverseTimeArg).
"""
gamma = 10.0 # 1/ps, same time convention as `timestep`
# gamma is consumed inside the step function and never stored on the state, so
# unlike the tau test in test_integrate_converts_persistent_init_kwarg, we must
# spy on the step call to see the converted value.
received: dict[str, object] = {}
init, real_step = ts.integrators.INTEGRATOR_REGISTRY[ts.Integrator.nvt_langevin]

def spy_step(**kwargs: object) -> MDState:
received["gamma"] = kwargs["gamma"]
return real_step(**kwargs)

monkeypatch.setitem(
ts.integrators.INTEGRATOR_REGISTRY, ts.Integrator.nvt_langevin, (init, spy_step)
)
ts.integrate(
system=ar_supercell_sim_state,
model=lj_model,
integrator=ts.Integrator.nvt_langevin,
n_steps=1,
temperature=100.0,
timestep=0.002,
gamma=gamma,
)
assert received["gamma"] == gamma / ts.units.MetalUnits.time


def test_integrate_double_nvt_multiple_temperatures(
ar_double_sim_state: SimState, lj_model: LennardJonesModel
) -> None:
Expand Down
65 changes: 63 additions & 2 deletions torch_sim/integrators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,25 @@
"""

# ruff: noqa: F401
import typing
from collections.abc import Callable
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Final
from typing import Any, Final, Literal

import torch_sim as ts

from .md import MDState, initialize_momenta, momentum_step, position_step, velocity_verlet
from .md import (
InversePressureArg,
InverseTimeArg,
MDState,
PressureArg,
TimeArg,
initialize_momenta,
momentum_step,
position_step,
velocity_verlet,
)
from .npt import (
NPTLangevinAnisotropicState,
NPTLangevinIsotropicState,
Expand Down Expand Up @@ -202,3 +214,52 @@ class Integrator(StrEnum):
npt_crescale_triclinic_step,
),
}


@dataclass(frozen=True, slots=True)
class UnitKwarg:
"""Metadata for integrator parameters carrying physical units.

``factor`` converts the parameter to internal units (multiply by it).
``channel`` is the function the parameter is used in (either ``"init"``
or ``"step"``).
"""

factor: float
channel: Literal["init", "step"]


_UNIT_FACTORS: Final = frozenset(
arg.__metadata__[0]
for arg in (TimeArg, InverseTimeArg, PressureArg, InversePressureArg)
)


def _collect_unit_kwargs(integrator: Integrator) -> dict[str, UnitKwarg]:
"""Read the unit-annotated kwargs off an integrator's init/step signatures."""
init_fn, step_fn = INTEGRATOR_REGISTRY[integrator]
unit_kwargs: dict[str, UnitKwarg] = {}
for channel, fn in (("init", init_fn), ("step", step_fn)):
hints = typing.get_type_hints(fn, include_extras=True)
for name, hint in hints.items():
for meta in getattr(hint, "__metadata__", ()):
if isinstance(meta, float) and meta in _UNIT_FACTORS:
if name in unit_kwargs:
raise TypeError(
f"{integrator}: unit kwarg {name!r} is annotated in both "
f"the init and step functions, making its `integrate` "
f"channel ambiguous"
)
unit_kwargs[name] = UnitKwarg(float(meta), channel)
return unit_kwargs


#: Unit-carrying kwargs of each integrator, derived at import time from the
#: ``TimeArg``/``InverseTimeArg``/``PressureArg``/``InversePressureArg`` annotations
#: on the registered init/step functions. :func:`torch_sim.runners.integrate` reads
#: this to unit-convert kwargs.
INTEGRATOR_UNIT_KWARGS: Final[dict[Integrator, dict[str, UnitKwarg]]] = {
integrator: unit_kwargs
for integrator in Integrator
if (unit_kwargs := _collect_unit_kwargs(integrator))
}
8 changes: 8 additions & 0 deletions torch_sim/integrators/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Annotated

import torch

Expand All @@ -17,6 +18,13 @@
logger = logging.getLogger(__name__)


# The metadata is the factor that converts the kwarg to internal units
TimeArg = Annotated[float | torch.Tensor | None, MetalUnits.time]
InverseTimeArg = Annotated[float | torch.Tensor | None, 1 / MetalUnits.time]
PressureArg = Annotated[float | torch.Tensor, MetalUnits.pressure]
InversePressureArg = Annotated[float | torch.Tensor | None, 1 / MetalUnits.pressure]


@dataclass(kw_only=True)
class MDState(SimState):
"""State information for molecular dynamics simulations.
Expand Down
46 changes: 25 additions & 21 deletions torch_sim/integrators/npt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@
import torch_sim as ts
from torch_sim._duecredit import dcite
from torch_sim.integrators.md import (
InversePressureArg,
InverseTimeArg,
MDState,
NoseHooverChain,
NoseHooverChainFns,
PressureArg,
TimeArg,
construct_nose_hoover_chain,
initialize_momenta,
momentum_step,
Expand Down Expand Up @@ -407,9 +411,9 @@ def npt_langevin_anisotropic_init(
*,
kT: float | torch.Tensor,
dt: float | torch.Tensor,
alpha: float | torch.Tensor | None = None,
cell_alpha: float | torch.Tensor | None = None,
b_tau: float | torch.Tensor | None = None,
alpha: InverseTimeArg = None,
cell_alpha: InverseTimeArg = None,
b_tau: TimeArg = None,
**_kwargs: Any,
) -> NPTLangevinAnisotropicState:
"""Initialize NPT Langevin state with independent per-dimension cell lengths.
Expand Down Expand Up @@ -511,7 +515,7 @@ def npt_langevin_anisotropic_step(
*,
dt: float | torch.Tensor,
kT: float | torch.Tensor,
external_pressure: float | torch.Tensor,
external_pressure: PressureArg,
) -> NPTLangevinAnisotropicState:
r"""Perform one NPT Langevin step with independent per-dimension cell lengths.

Expand Down Expand Up @@ -886,9 +890,9 @@ def npt_langevin_isotropic_init(
*,
kT: float | torch.Tensor,
dt: float | torch.Tensor,
alpha: float | torch.Tensor | None = None,
cell_alpha: float | torch.Tensor | None = None,
b_tau: float | torch.Tensor | None = None,
alpha: InverseTimeArg = None,
cell_alpha: InverseTimeArg = None,
b_tau: TimeArg = None,
**_kwargs: Any,
) -> NPTLangevinIsotropicState:
"""Initialize an NPT Langevin state using logarithmic strain coordinate.
Expand Down Expand Up @@ -986,7 +990,7 @@ def npt_langevin_isotropic_step(
*,
dt: float | torch.Tensor,
kT: float | torch.Tensor,
external_pressure: float | torch.Tensor,
external_pressure: PressureArg,
) -> NPTLangevinIsotropicState:
r"""Perform one NPT Langevin step using logarithmic strain coordinate.

Expand Down Expand Up @@ -1643,8 +1647,8 @@ def npt_nose_hoover_isotropic_init(
chain_length: int = 3,
chain_steps: int = 2,
sy_steps: int = 3,
t_tau: float | torch.Tensor | None = None,
b_tau: float | torch.Tensor | None = None,
t_tau: TimeArg = None,
b_tau: TimeArg = None,
**kwargs: Any,
) -> NPTNoseHooverIsotropicState:
"""Initialize the NPT Nose-Hoover state.
Expand Down Expand Up @@ -1810,7 +1814,7 @@ def npt_nose_hoover_isotropic_step(
*,
dt: float | torch.Tensor,
kT: float | torch.Tensor,
external_pressure: float | torch.Tensor,
external_pressure: PressureArg,
) -> NPTNoseHooverIsotropicState:
r"""Perform a complete NPT integration step with Nose-Hoover chain thermostats.

Expand Down Expand Up @@ -2515,8 +2519,8 @@ def npt_crescale_triclinic_step(
*,
dt: float | torch.Tensor,
kT: float | torch.Tensor,
external_pressure: float | torch.Tensor,
tau: float | torch.Tensor | None = None,
external_pressure: PressureArg,
tau: TimeArg = None,
) -> NPTCRescaleState:
r"""Perform one NPT integration step with anisotropic stochastic cell rescaling.

Expand Down Expand Up @@ -2643,8 +2647,8 @@ def npt_crescale_anisotropic_step(
*,
dt: float | torch.Tensor,
kT: float | torch.Tensor,
external_pressure: float | torch.Tensor,
tau: float | torch.Tensor | None = None,
external_pressure: PressureArg,
tau: TimeArg = None,
) -> NPTCRescaleState:
"""Perform one NPT integration step with cell rescaling barostat.

Expand Down Expand Up @@ -2719,8 +2723,8 @@ def npt_crescale_triclinic_average_step(
*,
dt: float | torch.Tensor,
kT: float | torch.Tensor,
external_pressure: float | torch.Tensor,
tau: float | torch.Tensor | None = None,
external_pressure: PressureArg,
tau: TimeArg = None,
) -> NPTCRescaleState:
"""Perform one NPT integration step with cell rescaling barostat.

Expand Down Expand Up @@ -2795,8 +2799,8 @@ def npt_crescale_isotropic_step(
*,
dt: float | torch.Tensor,
kT: float | torch.Tensor,
external_pressure: float | torch.Tensor,
tau: float | torch.Tensor | None = None,
external_pressure: PressureArg,
tau: TimeArg = None,
) -> NPTCRescaleState:
r"""Perform one NPT integration step with isotropic stochastic cell rescaling.

Expand Down Expand Up @@ -2902,8 +2906,8 @@ def npt_crescale_init(
*,
kT: float | torch.Tensor,
dt: float | torch.Tensor,
tau_p: float | torch.Tensor | None = None,
isothermal_compressibility: float | torch.Tensor | None = None,
tau_p: TimeArg = None,
isothermal_compressibility: InversePressureArg = None,
) -> NPTCRescaleState:
"""Initialize the NPT cell rescaling state.

Expand Down
8 changes: 5 additions & 3 deletions torch_sim/integrators/nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import torch_sim as ts
from torch_sim._duecredit import dcite
from torch_sim.integrators.md import (
InverseTimeArg,
MDState,
NoseHooverChain,
NoseHooverChainFns,
TimeArg,
construct_nose_hoover_chain,
initialize_momenta,
momentum_step,
Expand Down Expand Up @@ -143,7 +145,7 @@ def nvt_langevin_step(
*,
dt: float | torch.Tensor,
kT: float | torch.Tensor,
gamma: float | torch.Tensor | None = None,
gamma: InverseTimeArg = None,
) -> MDState:
r"""Perform one complete Langevin dynamics integration step using the BAOAB scheme.

Expand Down Expand Up @@ -291,7 +293,7 @@ def nvt_nose_hoover_init(
*,
kT: float | torch.Tensor,
dt: float | torch.Tensor,
tau: float | torch.Tensor | None = None,
tau: TimeArg = None,
chain_length: int = 3,
chain_steps: int = 3,
sy_steps: int = 3,
Expand Down Expand Up @@ -711,7 +713,7 @@ def nvt_vrescale_step(
*,
dt: float | torch.Tensor,
kT: float | torch.Tensor,
tau: float | torch.Tensor | None = None,
tau: TimeArg = None,
) -> NVTVRescaleState:
r"""Perform one complete V-Rescale (CSVR) dynamics integration step.

Expand Down
20 changes: 14 additions & 6 deletions torch_sim/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import torch_sim as ts
from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher
from torch_sim.integrators import INTEGRATOR_REGISTRY, Integrator
from torch_sim.integrators import INTEGRATOR_REGISTRY, INTEGRATOR_UNIT_KWARGS, Integrator
from torch_sim.integrators.md import MDState
from torch_sim.models.interface import ModelInterface
from torch_sim.optimizers import OPTIM_REGISTRY, FireState, Optimizer, OptimState
Expand Down Expand Up @@ -249,7 +249,7 @@ def _write_initial_state(
trajectory_reporter.report(state, 0, model=model)


def integrate[T: SimState]( # noqa: C901
def integrate[T: SimState]( # noqa: C901, PLR0915
system: StateLike,
model: ModelInterface,
*,
Expand Down Expand Up @@ -287,7 +287,7 @@ def integrate[T: SimState]( # noqa: C901
it's passed to `tqdm` as kwargs.
init_kwargs (dict[str, Any], optional): Additional keyword arguments for
integrator init function.
**integrator_kwargs: Additional keyword arguments for integrator init function
**integrator_kwargs: Additional keyword arguments for integrator step function

Returns:
T: Final state after integration
Expand Down Expand Up @@ -320,6 +320,16 @@ def integrate[T: SimState]( # noqa: C901
f"integrator must be key from Integrator or a tuple of "
f"(init_func, step_func), got {type(integrator)}"
)

# Like `timestep` above, unit-carrying kwargs (e.g. `tau`, `gamma`,
# `external_pressure`) must be converted to internal units per
# INTEGRATOR_UNIT_KWARGS
init_kwargs = dict(init_kwargs or {})
channels = {"init": init_kwargs, "step": integrator_kwargs}
for key, meta in INTEGRATOR_UNIT_KWARGS.get(integrator, {}).items():
kwargs = channels[meta.channel]
if kwargs.get(key) is not None:
kwargs[key] = kwargs[key] * meta.factor
# batch_iterator will be a list if autobatcher is False
batch_iterator = _configure_batches_iterator(
initial_state, model, autobatcher=autobatcher
Expand Down Expand Up @@ -351,9 +361,7 @@ def integrate[T: SimState]( # noqa: C901
batch_kT = (
kTs[:, system_indices] if (system_indices and len(kTs.shape) == 2) else kTs
)
state = init_func(
state=state, model=model, kT=batch_kT[0], dt=dt, **init_kwargs or {}
)
state = init_func(state=state, model=model, kT=batch_kT[0], dt=dt, **init_kwargs)

# set up trajectory reporters
if autobatcher and trajectory_reporter is not None and og_filenames is not None:
Expand Down
Loading
Loading