From 9914af28b5cfb4c37c6d593cc2d509f28e6262a1 Mon Sep 17 00:00:00 2001 From: amateurcat Date: Tue, 16 Jun 2026 17:23:35 -0700 Subject: [PATCH 1/4] Implement Grimme-stype bias potentials at torch_sim/metadynamics.py --- docs/reference/index.rst | 1 + tests/test_metadynamics.py | 202 +++++++++++++++++++++ torch_sim/__init__.py | 2 + torch_sim/metadynamics.py | 362 +++++++++++++++++++++++++++++++++++++ 4 files changed, 567 insertions(+) create mode 100644 tests/test_metadynamics.py create mode 100644 torch_sim/metadynamics.py diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 9c6d172a0..d40cfd35a 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -18,6 +18,7 @@ Overview of the TorchSim API. io integrators math + metadynamics models monte_carlo neighbors diff --git a/tests/test_metadynamics.py b/tests/test_metadynamics.py new file mode 100644 index 000000000..566994b87 --- /dev/null +++ b/tests/test_metadynamics.py @@ -0,0 +1,202 @@ +import pytest +import torch +from ase.build import molecule + +import torch_sim as ts +from torch_sim.metadynamics import EPS, RMSDCV, LogfermiWall +from torch_sim.models.interface import SumModel +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.units import UnitConversion + + +DEVICE = torch.device("cpu") +DTYPE = torch.float64 + + +@pytest.fixture +def ethanol_state() -> ts.SimState: + """Single low-symmetry molecule (avoids degenerate Kabsch SVD).""" + return ts.io.atoms_to_state([molecule("CH3CH2OH")], DEVICE, DTYPE) + + +@pytest.fixture +def ragged_state() -> ts.SimState: + """Two molecules with different atom counts in one batch.""" + return ts.io.atoms_to_state([molecule("CH3CH2OH"), molecule("H2O")], DEVICE, DTYPE) + + +class TestLogfermiWall: + def test_output_shapes(self, ragged_state: ts.SimState) -> None: + wall = LogfermiWall(radius=10.0, device=DEVICE, dtype=DTYPE) + output = wall(ragged_state) + assert output["energy"].shape == (2,) + assert 
output["forces"].shape == (ragged_state.n_atoms, 3) + + def test_energy_negligible_deep_inside(self, ethanol_state: ts.SimState) -> None: + wall = LogfermiWall(radius=50.0, device=DEVICE, dtype=DTYPE) + output = wall(ethanol_state) + assert output["energy"].abs().max() < 1e-10 + assert output["forces"].abs().max() < 1e-10 + + def test_restoring_force_outside_wall(self, ethanol_state: ts.SimState) -> None: + state = ethanol_state + state.positions = state.positions + torch.tensor([20.0, 0.0, 0.0]) + wall = LogfermiWall(radius=5.0, device=DEVICE, dtype=DTYPE) + output = wall(state) + assert output["energy"].min() > 0 + # all atoms sit at x ~ 20 > radius, so forces must point back toward origin + assert (output["forces"][:, 0] < 0).all() + + def test_forces_match_autograd(self, ragged_state: ts.SimState) -> None: + state = ragged_state + wall = LogfermiWall(radius=2.0, beta=3.0, device=DEVICE, dtype=DTYPE) + analytic = wall(state)["forces"] + + positions = state.positions.detach().clone().requires_grad_(requires_grad=True) + dvec = positions + r = torch.norm(dvec, dim=-1) + EPS + v_atom = wall.k_wall * torch.nn.functional.softplus(wall.beta * (r - wall.radius)) + expected = -torch.autograd.grad(v_atom.sum(), positions)[0] + torch.testing.assert_close(analytic, expected, atol=1e-10, rtol=1e-8) + + def test_per_system_center(self, ragged_state: ts.SimState) -> None: + # centering the wall on each molecule's first atom changes nothing + # qualitatively, but exercises the (n_systems, 3) center branch + first_atom_idx = torch.tensor([0, ragged_state.system_idx.tolist().count(0)]) + center = ragged_state.positions[first_atom_idx] + wall = LogfermiWall(radius=5.0, center=center, device=DEVICE, dtype=DTYPE) + output = wall(ragged_state) + assert torch.isfinite(output["energy"]).all() + assert torch.isfinite(output["forces"]).all() + + +class TestRMSDCV: + def test_first_call_seeds_and_returns_zero(self, ragged_state: ts.SimState) -> None: + bias = RMSDCV(device=DEVICE, 
dtype=DTYPE) + output = bias(ragged_state) + assert output["energy"].abs().max() == 0 + assert output["forces"].abs().max() == 0 + assert bias.ref_buf is not None + assert bias.ref_buf.shape == (1, ragged_state.n_atoms, 3) + + def test_unmoved_state_feels_full_bias(self, ragged_state: ts.SimState) -> None: + k_push = 0.02 + bias = RMSDCV(k_push=k_push, update_interval=1000, device=DEVICE, dtype=DTYPE) + bias(ragged_state) # seed + output = bias(ragged_state) + expected = k_push * UnitConversion.Hartree_to_eV + torch.testing.assert_close( + output["energy"], + torch.full((2,), expected, device=DEVICE, dtype=DTYPE), + ) + # rmsd^2 = 0 is a stationary point of the bias, so forces vanish + assert output["forces"].abs().max() < 1e-8 + + def test_rotation_translation_invariance(self, ethanol_state: ts.SimState) -> None: + bias = RMSDCV(k_push=0.02, update_interval=1000, device=DEVICE, dtype=DTYPE) + bias(ethanol_state) # seed + reference_energy = bias(ethanol_state)["energy"] + + angle = torch.tensor(0.3, dtype=DTYPE) + cos_a, sin_a = torch.cos(angle), torch.sin(angle) + rot = torch.tensor( + [[cos_a, -sin_a, 0.0], [sin_a, cos_a, 0.0], [0.0, 0.0, 1.0]], dtype=DTYPE + ) + moved = ethanol_state.clone() + moved.positions = moved.positions @ rot.T + torch.tensor([1.0, -2.0, 3.0]) + output = bias(moved) + torch.testing.assert_close(output["energy"], reference_energy) + + def test_forces_match_finite_difference(self, ethanol_state: ts.SimState) -> None: + bias = RMSDCV( + k_push=0.02, alpha_width=1.0, update_interval=1000, device=DEVICE, dtype=DTYPE + ) + bias(ethanol_state) # seed + perturbed = ethanol_state.clone() + torch.manual_seed(42) + perturbed.positions = perturbed.positions + 0.1 * torch.randn_like( + perturbed.positions + ) + forces = bias(perturbed)["forces"] + + delta = 1e-6 + for atom, coord in [(0, 0), (3, 1), (7, 2)]: + plus = perturbed.clone() + plus.positions[atom, coord] += delta + minus = perturbed.clone() + minus.positions[atom, coord] -= delta + e_plus = 
bias(plus)["energy"].sum() + e_minus = bias(minus)["energy"].sum() + numerical = -(e_plus - e_minus) / (2 * delta) + torch.testing.assert_close( + forces[atom, coord], numerical, atol=1e-6, rtol=1e-4 + ) + + def test_buffer_capped_at_n_refs(self, ethanol_state: ts.SimState) -> None: + bias = RMSDCV(n_refs=3, update_interval=1, device=DEVICE, dtype=DTYPE) + for _ in range(6): + bias(ethanol_state) + assert bias.ref_buf is not None + assert bias.ref_buf.shape[0] == 3 + + def test_update_interval_deposition(self, ethanol_state: ts.SimState) -> None: + bias = RMSDCV(n_refs=100, update_interval=3, device=DEVICE, dtype=DTYPE) + for _ in range(7): # seed + deposits at calls 3 and 6 + bias(ethanol_state) + assert bias.ref_buf is not None + assert bias.ref_buf.shape[0] == 3 + + def test_atom_mask_zeroes_fixed_atom_forces(self, ethanol_state: ts.SimState) -> None: + mask = torch.ones(ethanol_state.n_atoms, dtype=torch.bool) + mask[:3] = False + bias = RMSDCV(atom_mask=mask, update_interval=1000, device=DEVICE, dtype=DTYPE) + bias(ethanol_state) # seed + perturbed = ethanol_state.clone() + torch.manual_seed(7) + perturbed.positions = perturbed.positions + 0.2 * torch.randn_like( + perturbed.positions + ) + output = bias(perturbed) + assert output["forces"][:3].abs().max() == 0 + assert output["forces"][3:].abs().max() > 0 + + def test_reset(self, ethanol_state: ts.SimState) -> None: + bias = RMSDCV(device=DEVICE, dtype=DTYPE) + bias(ethanol_state) + bias.reset() + assert bias.ref_buf is None + output = bias(ethanol_state) # re-seeds without error + assert output["energy"].abs().max() == 0 + + +class TestIntegrationWithSumModel: + @pytest.fixture + def lj_model(self) -> LennardJonesModel: + return LennardJonesModel( + sigma=2.0, + epsilon=0.01, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + cutoff=5.0, + ) + + def test_nvt_with_wall_and_rmsd_bias( + self, lj_model: LennardJonesModel, ethanol_state: ts.SimState + ) -> None: + wall = LogfermiWall(radius=6.0, 
device=DEVICE, dtype=DTYPE) + bias = RMSDCV(k_push=0.005, update_interval=2, device=DEVICE, dtype=DTYPE) + model = SumModel(lj_model, wall, bias) + + final_state = ts.integrate( + system=ethanol_state, + model=model, + integrator=ts.Integrator.nvt_langevin, + n_steps=10, + timestep=0.001, + temperature=300, + ) + assert torch.isfinite(final_state.positions).all() + assert torch.isfinite(final_state.energy).all() + assert bias.ref_buf is not None + assert bias.ref_buf.shape[0] > 1 diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index b24fab124..2a7bf2b20 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -11,6 +11,7 @@ elastic, io, math, + metadynamics, models, monte_carlo, neighbors, @@ -160,6 +161,7 @@ "lbfgs_init", "lbfgs_step", "math", + "metadynamics", "models", "monte_carlo", "neighbors", diff --git a/torch_sim/metadynamics.py b/torch_sim/metadynamics.py new file mode 100644 index 000000000..1b3b10ce4 --- /dev/null +++ b/torch_sim/metadynamics.py @@ -0,0 +1,362 @@ +"""Metadynamics bias potentials. + +This module implements history-dependent and static bias potentials that add +external energies and forces to a simulation. Each bias is a +:class:`~torch_sim.models.interface.ModelInterface`, so it composes with any +MLIP (or classical potential) through +:class:`~torch_sim.models.interface.SumModel`:: + + bias = RMSDCV(k_push=0.02, alpha_width=1.2) + biased_model = SumModel(mace_model, bias) + final_state = ts.integrate( + system=state, + model=biased_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=1000, + timestep=0.002, + temperature=300, + ) + +Both potentials follow the metadynamics scheme of Grimme +(10.1021/acs.jctc.9b00143, as implemented in xtb) and keep that paper's input +units (Hartree, Bohr) for their energy/width parameters; conversion to +TorchSim's eV/Angstrom internal units happens inside the classes. 
# Tiny offset added to radial distances so they stay strictly positive before
# being used as a divisor when normalising displacement vectors.
EPS = 1e-12


def _segment_sum(
    src: torch.Tensor, system_idx: torch.Tensor, n_systems: int, dim: int = 0
) -> torch.Tensor:
    """Reduce per-atom values to per-system totals.

    Every slice of *src* along *dim* belongs to the system named by the
    matching entry of *system_idx*; slices that share a system are added
    together.

    Args:
        src: Per-atom values; its size along *dim* equals the number of atoms.
        system_idx: System index of each atom with shape [n_atoms].
        n_systems: Number of systems in the batch (size of the reduced axis).
        dim: Axis of *src* that enumerates atoms. Defaults to 0.

    Returns:
        A tensor shaped like *src* except that axis *dim* has length
        *n_systems*. Systems that own no atoms hold zeros.
    """
    # Same shape as the input, with the atom axis shrunk to one slot per system.
    target_shape = [*src.shape]
    target_shape[dim] = n_systems
    totals = torch.zeros(target_shape, dtype=src.dtype, device=src.device)
    # Out-of-place scatter-add: each atom slice lands in its system's slot.
    totals = totals.index_add(dim, system_idx, src)
    return totals
class LogfermiWall(ModelInterface):
    """Log-Fermi wall potential confining atoms inside a sphere.

    Adds the per-atom energy ``k_wall * log(1 + exp(beta * (r - radius)))``
    where ``r`` is the distance of the atom from the wall center. The energy
    is near zero well inside the sphere and grows linearly (slope
    ``k_wall * beta``) outside it, gently steering escaping atoms back.
    Idea and default parameters from 10.1021/acs.jctc.9b00143.

    Forces are computed analytically, so the model is safe to call under
    ``torch.no_grad()``.

    Args:
        radius: Wall radius in Angstrom. Must be positive. Defaults to 10.0.
        k_wall: Wall strength in Hartree (Grimme's units, converted to eV
            internally). Defaults to 0.019.
        beta: Wall steepness in 1/Bohr (converted to 1/Angstrom internally).
            Must be positive. Defaults to 10.0.
        center: Wall center with shape [3] (shared by all systems) or
            [n_systems, 3]. Defaults to the origin.
        energy_label: Non-canonical output key under which this wall's
            per-system energy is also reported, letting it survive SumModel
            and land on the SimState as an extra (e.g. ``state.PE_FermiWall``).
            Defaults to "PE_FermiWall"; give distinct labels if composing
            several walls.
        device: Device for computations. Defaults to CPU.
        dtype: Floating-point dtype. Defaults to torch.float64.

    Raises:
        ValueError: If ``radius`` or ``beta`` is not positive, or if ``center``
            does not have shape [3] or [n, 3].

    Example::

        wall = LogfermiWall(radius=8.0, device=model.device, dtype=model.dtype)
        confined_model = SumModel(model, wall)
    """

    def __init__(
        self,
        radius: float = 10.0,
        k_wall: float = 0.019,
        beta: float = 10.0,
        center: torch.Tensor | None = None,
        energy_label: str = "PE_FermiWall",
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float64,
    ) -> None:
        """Initialize the log-Fermi wall."""
        super().__init__()
        if radius <= 0:
            raise ValueError(f"{radius=} must be > 0")
        if beta <= 0:
            raise ValueError(f"{beta=} must be > 0")
        self._device = device or torch.device("cpu")
        self._dtype = dtype
        self._compute_forces = True
        self._compute_stress = False
        self._memory_scales_with = "n_atoms"

        self.energy_label = str(energy_label)
        self.radius = float(radius)
        self.k_wall = float(k_wall) * UnitConversion.Hartree_to_eV
        self.beta = float(beta) * UnitConversion.Ang_to_Bohr  # 1/Bohr -> 1/Ang
        if center is not None:
            center = center.to(self._device, self._dtype)
            # Reject malformed centers here instead of failing (or silently
            # broadcasting) inside forward.
            if center.ndim not in (1, 2) or center.shape[-1] != 3:
                raise ValueError(
                    "center must have shape [3] or [n_systems, 3], "
                    f"got {tuple(center.shape)}"
                )
        self.center: torch.Tensor | None
        self.register_buffer("center", center)

    def forward(self, state: SimState, **_kwargs) -> dict[str, torch.Tensor]:
        """Compute wall energies and forces.

        Args:
            state: Simulation state with positions [n_atoms, 3] and system_idx.
            **_kwargs: Unused, accepted for interface compatibility.

        Returns:
            Dict with "energy" [n_systems], "forces" [n_atoms, 3], and the
            same wall energy under ``self.energy_label`` (a non-canonical key
            so it survives SumModel and is stored on the SimState as a
            per-system extra).

        Raises:
            ValueError: If a per-system ``center`` was given and its number of
                rows differs from ``state.n_systems``.
        """
        positions = state.positions
        system_idx = state.system_idx

        if self.center is None:
            dvec = positions
        else:
            # Follow the state's dtype/device so the outputs match the positions.
            center = self.center.to(positions)
            if center.ndim == 2:
                # A row-count mismatch would either index out of bounds or
                # silently pair systems with the wrong centers.
                if center.shape[0] != state.n_systems:
                    raise ValueError(
                        f"center has {center.shape[0]} rows but the state has "
                        f"{state.n_systems} systems"
                    )
                center = center[system_idx]
            dvec = positions - center

        # EPS keeps r strictly positive so the unit vector below is finite even
        # for an atom sitting exactly on the wall center (its force is then 0).
        r = torch.norm(dvec, dim=-1) + EPS  # (n_atoms,)
        x = self.beta * (r - self.radius)
        v_atom = self.k_wall * torch.nn.functional.softplus(x)  # log1p(e^x), no overflow
        energy = _segment_sum(v_atom, system_idx, state.n_systems)

        # F = -dV/dr * r_hat with dV/dr = k_wall * beta * sigmoid(x)
        dv_dr = self.k_wall * self.beta * torch.sigmoid(x)
        forces = -dv_dr.unsqueeze(-1) * dvec / r.unsqueeze(-1)

        return {"energy": energy, "forces": forces, self.energy_label: energy}
class RMSDCV(ModelInterface):
    """History-dependent RMSD bias (Kabsch-aligned) for metadynamics.

    Maintains a rolling buffer of reference structures and adds the repulsive
    bias ``E = k_push * sum_i exp(-alpha * rmsd2_i)`` per system, where
    ``rmsd2_i`` is the squared deviation from reference *i* after optimal
    (Kabsch) alignment, averaged over biased atoms and Cartesian components.
    This pushes the dynamics away from previously visited configurations.
    Idea and default parameters from 10.1021/acs.jctc.9b00143.

    The buffer is seeded on the first call (which returns zero bias), and a
    new reference is deposited every ``update_interval`` calls. TorchSim
    integrators evaluate the model once per MD step (plus once at
    initialization), so calls correspond to MD steps. Use
    :meth:`push_reference` and :meth:`reset` for manual control of the buffer.

    Forces are obtained by autograd with the optimal rotation held fixed.
    Because the rotation minimises the deviation, its own derivative does not
    contribute to the gradient, so nothing is back-propagated through the SVD
    and collinear or rank-deficient selections (e.g. two biased atoms) still
    give finite forces. A system without any biased atom receives zero bias.

    Because the buffer is shaped to the batch seen first, call :meth:`reset`
    before reusing the model with a different batch.

    Args:
        k_push: Bias strength in Hartree (converted to eV internally).
            Defaults to 1.0.
        alpha_width: Gaussian width in 1/Bohr^2 (converted to 1/Angstrom^2
            internally). Defaults to 1.0.
        n_refs: Maximum number of stored references; oldest are dropped.
            Defaults to 10.
        update_interval: Deposit a new reference every this many calls.
            Defaults to 1.
        atom_mask: Boolean mask with shape [n_atoms] over the concatenated
            atoms selecting which atoms participate in the CV (True =
            biased). Excluded atoms feel no bias force. Defaults to all atoms.
        energy_label: Non-canonical output key under which this bias's
            per-system energy is also reported, letting it survive SumModel
            and land on the SimState as an extra (e.g. ``state.PE_RMSDCV``).
            Defaults to "PE_RMSDCV"; give distinct labels if composing
            several biases.
        device: Device for computations. Defaults to CPU.
        dtype: Floating-point dtype. Defaults to torch.float64.

    Raises:
        ValueError: If ``n_refs`` or ``update_interval`` is smaller than 1, or
            if ``atom_mask`` is not one-dimensional.

    Example::

        bias = RMSDCV(k_push=0.02, alpha_width=1.2, n_refs=20, update_interval=50)
        metad_model = SumModel(mace_model, bias)
    """

    def __init__(
        self,
        k_push: float = 1.0,
        alpha_width: float = 1.0,
        n_refs: int = 10,
        update_interval: int = 1,
        atom_mask: torch.Tensor | None = None,
        energy_label: str = "PE_RMSDCV",
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float64,
    ) -> None:
        """Initialize the RMSD collective-variable bias."""
        super().__init__()
        if n_refs < 1:
            raise ValueError(f"{n_refs=} must be >= 1")
        if update_interval < 1:
            raise ValueError(f"{update_interval=} must be >= 1")
        self._device = device or torch.device("cpu")
        self._dtype = dtype
        self._compute_forces = True
        self._compute_stress = False
        self._memory_scales_with = "n_atoms"

        self.energy_label = str(energy_label)
        self.k_push = float(k_push) * UnitConversion.Hartree_to_eV
        self.alpha = float(alpha_width) * UnitConversion.Ang_to_Bohr**2
        self.n_refs = int(n_refs)
        self.update_interval = int(update_interval)
        if atom_mask is not None:
            atom_mask = atom_mask.to(self._device, torch.bool)
            # A mask of any other rank would flatten the positions it indexes.
            if atom_mask.ndim != 1:
                raise ValueError(
                    f"atom_mask must have shape [n_atoms], got {tuple(atom_mask.shape)}"
                )
        self.atom_mask: torch.Tensor | None
        self.register_buffer("atom_mask", atom_mask)

        # rolling buffer of centered references, shape (n_stored, n_biased_atoms, 3)
        self.ref_buf: torch.Tensor | None = None
        self._n_calls = 0

    def reset(self) -> None:
        """Clear all stored references and the internal call counter."""
        self.ref_buf = None
        self._n_calls = 0

    def _check_batch(self, n_biased: int) -> None:
        """Raise if the stored references do not match the current atom count."""
        if self.ref_buf is not None and self.ref_buf.shape[1] != n_biased:
            raise ValueError(
                f"RMSDCV reference buffer holds {self.ref_buf.shape[1]} biased atoms "
                f"but the state has {n_biased}; call reset() before reusing the "
                "bias with a different batch"
            )

    @torch.no_grad()
    def push_reference(self, state: SimState) -> None:
        """Deposit the current configuration as a new reference.

        The (masked) positions are centered per system before being stored;
        once more than ``n_refs`` references exist the oldest are dropped.

        Args:
            state: Simulation state whose (masked) positions are stored.

        Raises:
            ValueError: If references are already stored and the number of
                biased atoms in *state* differs from theirs.
        """
        positions, system_idx = self._masked(state)
        self._check_batch(positions.shape[0])
        counts = torch.bincount(system_idx, minlength=state.n_systems)
        # clamp avoids 0/0 for systems with no biased atom; no atom reads
        # such a system's (meaningless) centroid.
        com = _segment_sum(positions, system_idx, state.n_systems)
        com = com / counts.clamp(min=1).unsqueeze(-1)
        centered = (positions - com[system_idx]).unsqueeze(0)
        if self.ref_buf is None:
            self.ref_buf = centered
        else:
            self.ref_buf = torch.cat([self.ref_buf, centered], dim=0)[-self.n_refs :]

    def _masked(self, state: SimState) -> tuple[torch.Tensor, torch.Tensor]:
        """Return positions and system indices of biased atoms only."""
        if self.atom_mask is None:
            return state.positions, state.system_idx
        return state.positions[self.atom_mask], state.system_idx[self.atom_mask]

    @staticmethod
    def _kabsch(
        rc: torch.Tensor, qc: torch.Tensor, system_idx: torch.Tensor, n_systems: int
    ) -> torch.Tensor:
        """Batched Kabsch rotations aligning current coords to many references.

        Args:
            rc: Centered current coordinates with shape [n_biased, 3].
            qc: Centered reference coordinates with shape [n_refs, n_biased, 3].
            system_idx: System index of each biased atom with shape [n_biased].
            n_systems: Number of systems in the batch.

        Returns:
            Rotation matrices with shape [n_refs, n_systems, 3, 3]; applying
            one to the current coordinates of its system aligns them onto the
            corresponding reference.
        """
        outer = rc.unsqueeze(0).unsqueeze(-1) * qc.unsqueeze(-2)  # (X, n, 3, 3)
        cov = _segment_sum(outer, system_idx, n_systems, dim=1)  # (X, M, 3, 3)
        u_mat, _, vh_mat = torch.linalg.svd(cov)
        v_mat = vh_mat.transpose(-2, -1)
        ut_mat = u_mat.transpose(-2, -1)
        det = torch.linalg.det(v_mat @ ut_mat)  # (X, n_systems)
        # Flip the last axis where the SVD solution is an improper rotation
        # (reflection); tensors built from det keep its dtype and device.
        ones = torch.ones_like(det)
        corr = torch.eye(3, device=cov.device, dtype=cov.dtype).expand_as(cov).clone()
        corr[..., 2, 2] = torch.where(det < 0, -ones, ones)
        return v_mat @ corr @ ut_mat

    def forward(self, state: SimState, **_kwargs) -> dict[str, torch.Tensor]:
        """Compute bias energies and forces, depositing references as needed.

        Args:
            state: Simulation state with positions [n_atoms, 3] and system_idx.
            **_kwargs: Unused, accepted for interface compatibility.

        Returns:
            Dict with "energy" [n_systems], "forces" [n_atoms, 3], and the
            same bias energy under ``self.energy_label`` (a non-canonical key
            so it survives SumModel and is stored on the SimState as a
            per-system extra).

        Raises:
            ValueError: If the number of biased atoms differs from that of the
                stored references (the bias was built for another batch).
        """
        n_systems = state.n_systems

        if self.ref_buf is None:
            # First call: seed the buffer and report zero bias.
            self.push_reference(state)
            self._n_calls = 1
            zero_energy = torch.zeros(
                n_systems, device=state.positions.device, dtype=state.positions.dtype
            )
            return {
                "energy": zero_energy,
                "forces": torch.zeros_like(state.positions),
                self.energy_label: zero_energy,
            }

        masked_pos, system_idx = self._masked(state)
        self._check_batch(masked_pos.shape[0])
        counts = torch.bincount(system_idx, minlength=n_systems)
        has_atoms = counts > 0
        # Systems without biased atoms would divide by zero; they are given a
        # dummy count here and their bias is masked out below.
        safe_counts = counts.clamp(min=1)

        with torch.enable_grad():
            pos = masked_pos.detach().requires_grad_(requires_grad=True)  # (n_biased, 3)
            qc = self.ref_buf.to(pos)  # (X, n_biased, 3), already centered

            com = _segment_sum(pos, system_idx, n_systems) / safe_counts.unsqueeze(-1)
            rc = pos - com[system_idx]

            # The rotation is the minimiser of the deviation, so its derivative
            # drops out of the gradient. Holding it fixed gives the same forces
            # without differentiating the SVD, whose backward is unstable for
            # repeated or vanishing singular values.
            with torch.no_grad():
                rot = self._kabsch(rc, qc, system_idx, n_systems)  # (X, M, 3, 3)
            rc_rot = torch.einsum("xncd,nd->xnc", rot[:, system_idx], rc)
            diff = rc_rot - qc
            # squared deviation averaged over atoms AND coords (matches the xtb input
            # convention; alpha absorbs the factor 3 vs the conventional RMSD^2)
            sq = diff.pow(2).sum(dim=-1)  # (X, n_biased)
            rmsd2 = _segment_sum(sq, system_idx, n_systems, dim=1) / (3 * safe_counts)

            gauss = torch.exp(-self.alpha * rmsd2) * has_atoms.to(rmsd2.dtype)
            energy = self.k_push * gauss.sum(dim=0)  # (M,)
            grad = torch.autograd.grad(energy.sum(), pos)[0]

        if self.atom_mask is None:
            forces = -grad
        else:
            # Unbiased atoms keep an exactly zero force.
            forces = torch.zeros_like(state.positions)
            forces[self.atom_mask] = -grad

        if self._n_calls % self.update_interval == 0:
            self.push_reference(state)
        self._n_calls += 1

        detached_energy = energy.detach()
        return {
            "energy": detached_energy,
            "forces": forces.detach(),
            self.energy_label: detached_energy,
        }
191 +++++++++++ tests/test_metadynamics.py | 2 +- torch_sim/__init__.py | 4 +- torch_sim/enhanced_sampling/__init__.py | 23 ++ torch_sim/enhanced_sampling/boxed_md.py | 320 ++++++++++++++++++ torch_sim/enhanced_sampling/history.py | 134 ++++++++ .../{ => enhanced_sampling}/metadynamics.py | 52 +-- 8 files changed, 702 insertions(+), 26 deletions(-) create mode 100644 tests/test_boxed_md.py create mode 100644 torch_sim/enhanced_sampling/__init__.py create mode 100644 torch_sim/enhanced_sampling/boxed_md.py create mode 100644 torch_sim/enhanced_sampling/history.py rename torch_sim/{ => enhanced_sampling}/metadynamics.py (90%) diff --git a/docs/reference/index.rst b/docs/reference/index.rst index d40cfd35a..6f26a0c7e 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -15,10 +15,10 @@ Overview of the TorchSim API. autobatching properties.correlations elastic + enhanced_sampling io integrators math - metadynamics models monte_carlo neighbors diff --git a/tests/test_boxed_md.py b/tests/test_boxed_md.py new file mode 100644 index 000000000..6f4f8f58e --- /dev/null +++ b/tests/test_boxed_md.py @@ -0,0 +1,191 @@ +import pytest +import torch +from ase.build import molecule + +import torch_sim as ts +from torch_sim.enhanced_sampling.boxed_md import BoxedMD, run_boxed_md, velocity_inversion +from torch_sim.integrators.nvt import nvt_langevin_init +from torch_sim.models.interface import ModelInterface + + +DEVICE = torch.device("cpu") +DTYPE = torch.float64 + + +class HarmonicModel(ModelInterface): + """Per-atom harmonic well E = 1/2 k sum (x - x0)^2. + + A bound, stationary potential: the potential energy fluctuates around its + minimum and keeps producing fresh maxima, so BXDE can ratchet the system + outward and place a sequence of rising boundaries (unlike an unbound LJ + cluster, which simply relaxes and never beats its initial energy). 
+ """ + + def __init__( + self, + centers: torch.Tensor, + k: float = 1.0, + device: torch.device = DEVICE, + dtype: torch.dtype = DTYPE, + ) -> None: + super().__init__() + self._device = device + self._dtype = dtype + self._compute_forces = True + self._compute_stress = False + self.k = float(k) + self.register_buffer("centers", centers.to(device, dtype)) + + def forward(self, state: ts.SimState, **_kwargs) -> dict[str, torch.Tensor]: + disp = state.positions - self.centers + e_atom = 0.5 * self.k * disp.pow(2).sum(dim=-1) + energy = torch.zeros( + state.n_systems, device=self._device, dtype=self._dtype + ).index_add(0, state.system_idx, e_atom) + forces = -self.k * disp + return {"energy": energy, "forces": forces} + + +@pytest.fixture +def ethanol_state() -> ts.SimState: + return ts.io.atoms_to_state([molecule("CH3CH2OH")], DEVICE, DTYPE) + + +@pytest.fixture +def harmonic_model(ethanol_state: ts.SimState) -> HarmonicModel: + return HarmonicModel(ethanol_state.positions.clone(), k=2.0) + + +class TestVelocityInversion: + def _random_inputs(self, n_atoms: int = 6) -> tuple[torch.Tensor, ...]: + torch.manual_seed(0) + momenta = torch.randn(n_atoms, 3, dtype=DTYPE) + forces = torch.randn(n_atoms, 3, dtype=DTYPE) + masses = torch.rand(n_atoms, dtype=DTYPE) + 0.5 + return momenta, forces, masses + + def test_reverses_gradient_projection(self) -> None: + # grad(phi) = grad(E) = -F, so a valid reflection must flip grad(phi).v, + # equivalently F . v' == -(F . v). + momenta, forces, masses = self._random_inputs() + velocities = momenta / masses.unsqueeze(-1) + new_momenta = velocity_inversion(momenta, forces, masses) + new_velocities = new_momenta / masses.unsqueeze(-1) + + f_dot_v = (forces * velocities).sum() + f_dot_v_new = (forces * new_velocities).sum() + torch.testing.assert_close(f_dot_v_new, -f_dot_v) + + def test_conserves_kinetic_energy(self) -> None: + # The mass-metric reflection is elastic: KE = 1/2 sum p^2 / m is preserved. 
+ momenta, forces, masses = self._random_inputs() + new_momenta = velocity_inversion(momenta, forces, masses) + ke = (momenta.pow(2) / masses.unsqueeze(-1)).sum() + ke_new = (new_momenta.pow(2) / masses.unsqueeze(-1)).sum() + torch.testing.assert_close(ke_new, ke) + + def test_idempotent_pair(self) -> None: + # Applying the inversion twice (with the same forces) returns the original. + momenta, forces, masses = self._random_inputs() + once = velocity_inversion(momenta, forces, masses) + twice = velocity_inversion(once, forces, masses) + torch.testing.assert_close(twice, momenta) + + +class TestRunBoxedMD: + def test_rejects_multiple_systems(self, harmonic_model: HarmonicModel) -> None: + two = ts.io.atoms_to_state( + [molecule("CH3CH2OH"), molecule("H2O")], DEVICE, DTYPE + ) + with pytest.raises(ValueError, match="single system"): + run_boxed_md( + two, + harmonic_model, + n_steps=10, + i_samp=2, + timestep=0.001, + temperature=300, + ) + + def test_runs_and_returns_finite_state( + self, harmonic_model: HarmonicModel, ethanol_state: ts.SimState + ) -> None: + final_state, floors = run_boxed_md( + ethanol_state, + harmonic_model, + n_steps=200, + i_samp=5, + timestep=0.001, + temperature=300, + seed=1, + ) + assert torch.isfinite(final_state.positions).all() + assert torch.isfinite(final_state.energy).all() + assert floors.ndim == 1 + + def test_floors_monotonically_increase( + self, harmonic_model: HarmonicModel, ethanol_state: ts.SimState + ) -> None: + _, floors = run_boxed_md( + ethanol_state, + harmonic_model, + n_steps=500, + i_samp=3, + timestep=0.001, + temperature=500, + seed=2, + ) + assert floors.numel() >= 2 + # each new box raises (never lowers) the accessible-energy floor + assert (floors[1:] >= floors[:-1]).all() + + +class TestBoxedMDController: + def _init( + self, model: HarmonicModel, state: ts.SimState + ) -> tuple[ts.SimState, float, torch.Tensor]: + state.rng = 0 + kT = 300 * ts.units.UnitSystem.metal.temperature + dt = 0.001 * 
ts.units.UnitSystem.metal.time + return nvt_langevin_init(state, model, kT=kT), kT, dt + + def test_step_limit_then_resume( + self, harmonic_model: HarmonicModel, ethanol_state: ts.SimState + ) -> None: + md_state, kT, dt = self._init(harmonic_model, ethanol_state) + controller = BoxedMD(harmonic_model, i_samp=10_000, dt=dt, kT=kT) + + # i_samp is huge, so no boundary can be placed within a small budget + state, used, status = controller.run_epoch(md_state, max_steps=5) + assert used == 5 + assert status == BoxedMD.STEP_LIMIT + assert controller.total_steps == 5 + assert controller.i == 5 # all accepted (no floor yet), window advanced + + # resuming continues the same window rather than restarting it + controller.run_epoch(state, max_steps=3) + assert controller.total_steps == 8 + assert controller.i == 8 + + def test_new_boundary_status_and_record( + self, harmonic_model: HarmonicModel, ethanol_state: ts.SimState + ) -> None: + md_state, kT, dt = self._init(harmonic_model, ethanol_state) + controller = BoxedMD(harmonic_model, i_samp=3, dt=dt, kT=kT) + + _state, _used, status = controller.run_epoch(md_state, max_steps=500) + assert status == BoxedMD.NEW_BOUNDARY + assert controller.v_bxde is not None + assert len(controller.floors) == 1 + assert controller.i == 0 # window counter reset for the next box + + def test_reset( + self, harmonic_model: HarmonicModel, ethanol_state: ts.SimState + ) -> None: + md_state, kT, dt = self._init(harmonic_model, ethanol_state) + controller = BoxedMD(harmonic_model, i_samp=3, dt=dt, kT=kT) + controller.run_epoch(md_state, max_steps=500) + controller.reset() + assert controller.v_bxde is None + assert controller.total_steps == 0 + assert controller.floors.is_empty diff --git a/tests/test_metadynamics.py b/tests/test_metadynamics.py index 566994b87..c57a7099d 100644 --- a/tests/test_metadynamics.py +++ b/tests/test_metadynamics.py @@ -3,7 +3,7 @@ from ase.build import molecule import torch_sim as ts -from torch_sim.metadynamics 
import EPS, RMSDCV, LogfermiWall +from torch_sim.enhanced_sampling.metadynamics import EPS, RMSDCV, LogfermiWall from torch_sim.models.interface import SumModel from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.units import UnitConversion diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index 2a7bf2b20..0698957e5 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -9,9 +9,9 @@ autobatching, constraints, elastic, + enhanced_sampling, io, math, - metadynamics, models, monte_carlo, neighbors, @@ -147,6 +147,7 @@ "concatenate_states", "constraints", "elastic", + "enhanced_sampling", "fire_init", "fire_step", "generate_energy_convergence_fn", @@ -161,7 +162,6 @@ "lbfgs_init", "lbfgs_step", "math", - "metadynamics", "models", "monte_carlo", "neighbors", diff --git a/torch_sim/enhanced_sampling/__init__.py b/torch_sim/enhanced_sampling/__init__.py new file mode 100644 index 000000000..e94f2f2bf --- /dev/null +++ b/torch_sim/enhanced_sampling/__init__.py @@ -0,0 +1,23 @@ +"""Enhanced-sampling building blocks for TorchSim. + +This subpackage collects enhanced-sampling methods and the reusable machinery +behind them. :class:`~torch_sim.enhanced_sampling.history.History` provides the +deposition/capacity bookkeeping shared by history-dependent biases. The +:mod:`~torch_sim.enhanced_sampling.metadynamics` module implements bias +potentials (:class:`LogfermiWall`, :class:`RMSDCV`) that compose with any MLIP +through :class:`~torch_sim.models.interface.SumModel`. 
+""" + +from torch_sim.enhanced_sampling.boxed_md import BoxedMD, run_boxed_md, velocity_inversion +from torch_sim.enhanced_sampling.history import History +from torch_sim.enhanced_sampling.metadynamics import RMSDCV, LogfermiWall + + +__all__ = [ + "RMSDCV", + "BoxedMD", + "History", + "LogfermiWall", + "run_boxed_md", + "velocity_inversion", +] diff --git a/torch_sim/enhanced_sampling/boxed_md.py b/torch_sim/enhanced_sampling/boxed_md.py new file mode 100644 index 000000000..0b31b7417 --- /dev/null +++ b/torch_sim/enhanced_sampling/boxed_md.py @@ -0,0 +1,320 @@ +"""Boxed molecular dynamics in energy space (BXDE). + +BXDE accelerates the discovery of rare events by progressively raising a lower +bound on the accessible potential energy, forcing the system to climb into +high-energy regions it would otherwise visit only rarely. The method follows +Shannon et al., *J. Chem. Theory Comput.* **2018**, 14, 4541 +(10.1021/acs.jctc.8b00515). + +The scheme alternates free-sampling windows with adaptive boundary placement: + +1. Run MD freely, tracking the running maximum potential energy ``PE_max``. +2. Once at least ``i_samp`` steps have elapsed *and* a fresh maximum is reached, + freeze a reflective lower boundary at ``PE_max`` and begin a new window. +3. Thereafter, whenever a step would take the potential energy below the active + boundary, revert to the previous step and invert the velocity component + along the energy gradient, so the boundary bounds the energy from below only. + +Because the physical force always points toward lower potential energy -- into +the forbidden region -- the boundary is run inside a Langevin (NVT) integrator: +the stochastic force lets the trajectory random-walk along the boundary rather +than becoming trapped against it. + +This module is a trajectory-level controller, not a bias potential: it adds no +energy or forces and so is not a +:class:`~torch_sim.models.interface.ModelInterface`. 
Drive a whole run with +:func:`run_boxed_md`, or step box-by-box with :class:`BoxedMD` for finer control. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import torch + +from torch_sim.enhanced_sampling.history import History +from torch_sim.integrators.md import MDState +from torch_sim.integrators.nvt import nvt_langevin_init, nvt_langevin_step +from torch_sim.units import UnitSystem + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch_sim.models.interface import ModelInterface + from torch_sim.state import SimState + from torch_sim.trajectory import TrajectoryReporter + + +EPS = 1e-12 + + +def velocity_inversion( + momenta: torch.Tensor, forces: torch.Tensor, masses: torch.Tensor +) -> torch.Tensor: + r"""Reflect the velocity component along the potential-energy gradient. + + Implements the BXD velocity inversion (Eqs 2-3 of 10.1021/acs.jctc.8b00515) + for a boundary in potential-energy space, where the constraint gradient + :math:`\nabla\phi = \nabla E = -F` is given directly by the forces. In terms + of momenta :math:`p = M v` the update is + + .. math:: + + p' = p - 2\,\frac{F \cdot v}{F \cdot M^{-1} F}\,F, + \qquad v = M^{-1} p + + which reverses the sign of :math:`\nabla\phi \cdot v` (so the system is + pushed back toward higher energy) while conserving kinetic energy. + + Args: + momenta: Particle momenta with shape [n_atoms, 3]. + forces: Forces on the particles with shape [n_atoms, 3]. + masses: Particle masses with shape [n_atoms]. + + Returns: + The reflected momenta with shape [n_atoms, 3]. + """ + inv_m = (1.0 / masses).unsqueeze(-1) # [n_atoms, 1] + f_dot_v = (forces * momenta * inv_m).sum() # F . v (single system) + f_m_f = (forces.pow(2) * inv_m).sum() # F . M^-1 . F + coeff = 2.0 * f_dot_v / (f_m_f + EPS) + return momenta - coeff * forces + + +class BoxedMD: + """Stateful BXDE controller that drives one energy box at a time. 
+ + The controller owns everything that must persist across boxes: the active + lower boundary, the running ``PE_max`` and step counter of the current + sampling window, the record of placed boundaries, and the previous accepted + step used for roll-backs. Call :meth:`run_epoch` repeatedly; it advances one + box and returns control either when a new boundary is placed or when a step + budget is exhausted mid-box (in which case the next call resumes the same + box). + + Args: + model: Energy/force model passed to the inner integrator step. + i_samp: Minimum number of accepted steps in a window before a new + boundary may be placed. + dt: Integration timestep in the model's internal time units. + kT: Target temperature in energy units for the Langevin step. + gamma: Langevin friction coefficient forwarded to the step function. + Defaults to the integrator's own default (``1/(100*dt)``). + floor_capacity: Maximum number of placed boundaries to retain in the + history buffer. Defaults to 10000. + step_fn: Inner integrator step. Defaults to ``nvt_langevin_step``. + trajectory_reporter: Optional :class:`~torch_sim.trajectory.TrajectoryReporter` + whose ``report`` is called once per attempted step (indexed by the + cumulative attempted-step count), recording rolled-back steps too. + device: Device for the history/counters. Defaults to ``model.device``. + dtype: Floating-point dtype. Defaults to ``torch.float64``. 
+ """ + + NEW_BOUNDARY = "new_boundary" + STEP_LIMIT = "step_limit" + + def __init__( + self, + model: ModelInterface, + *, + i_samp: int, + dt: torch.Tensor | float, + kT: torch.Tensor | float, + gamma: torch.Tensor | float | None = None, + floor_capacity: int = 10000, + step_fn: Callable[..., MDState] = nvt_langevin_step, + trajectory_reporter: TrajectoryReporter | None = None, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + ) -> None: + """Initialize the BXDE controller.""" + if i_samp < 1: + raise ValueError(f"{i_samp=} must be >= 1") + self.model = model + self.step_fn = step_fn + self.trajectory_reporter = trajectory_reporter + self.i_samp = int(i_samp) + self.dt = dt + self.kT = kT + self.gamma = gamma + self._device = device or model.device + self._dtype = dtype + + self.floors = History( + capacity=floor_capacity, stride=1, device=self._device, dtype=self._dtype + ) + self.v_bxde: torch.Tensor | None = None # active lower boundary (scalar) + self.pe_max = self._neg_inf() + self.i = 0 # accepted steps in the current window + self.total_steps = 0 # attempted integrator steps over the whole run + self.n_inversions = 0 + + def _neg_inf(self) -> torch.Tensor: + return torch.tensor(float("-inf"), device=self._device, dtype=self._dtype) + + def _begin_new_window(self) -> None: + """Reset the per-window running max and counter, keeping the boundary.""" + self.pe_max = self._neg_inf() + self.i = 0 + + def reset(self) -> None: + """Clear all boundaries and counters, returning to a pristine state.""" + self.floors.reset() + self.v_bxde = None + self._begin_new_window() + self.total_steps = 0 + self.n_inversions = 0 + + def _report(self, state: MDState) -> None: + """Record the current state to the trajectory reporter, if configured.""" + if self.trajectory_reporter is not None: + self.trajectory_reporter.report(state, self.total_steps, self.model) + + def run_epoch(self, state: MDState, max_steps: int) -> tuple[MDState, int, str]: + 
"""Advance the current energy box until a boundary is placed or budget runs out. + + Args: + state: Current MD state (single system). + max_steps: Maximum number of attempted integrator steps to take. + + Returns: + A tuple ``(state, n_used, status)`` where ``n_used`` is the number of + attempted steps taken (each costs one model call, including rolled-back + steps) and ``status`` is :attr:`NEW_BOUNDARY` if a boundary was placed + or :attr:`STEP_LIMIT` if the budget was reached first. + """ + used = 0 + while used < max_steps: + floor_active = self.v_bxde is not None + if floor_active: + # snapshot the accepted step so we can revert to it (time t) + prev_positions = state.positions.clone() + prev_momenta = state.momenta.clone() + prev_energy = state.energy.clone() + prev_forces = state.forces.clone() + + state = self.step_fn( + state, self.model, dt=self.dt, kT=self.kT, gamma=self.gamma + ) + used += 1 + self.total_steps += 1 + pe = state.energy.reshape(()) # scalar potential energy + + if floor_active and bool(pe < self.v_bxde): + # crossed below the boundary: revert to t and invert the velocity + state.positions = prev_positions + state.momenta = velocity_inversion( + prev_momenta, prev_forces, state.masses + ) + state.energy = prev_energy + state.forces = prev_forces + self.n_inversions += 1 + self._report(state) + continue + + # accepted step: track the running maximum and place a boundary once + # the window is long enough and a fresh maximum is reached + self.i += 1 + if bool(pe > self.pe_max): + if self.i > self.i_samp: + self.v_bxde = self.pe_max.clone() # freeze at the running max + self.floors.push(self.v_bxde) + self._begin_new_window() + self._report(state) + return state, used, self.NEW_BOUNDARY + self.pe_max = pe.clone() + + self._report(state) + + return state, used, self.STEP_LIMIT + + +def run_boxed_md( + state: SimState, + model: ModelInterface, + *, + n_steps: int, + i_samp: int, + timestep: float, + temperature: float, + gamma: torch.Tensor | 
float | None = None, + seed: int | None = None, + floor_capacity: int = 10000, + trajectory_reporter: TrajectoryReporter | None = None, + unit_system: UnitSystem = UnitSystem.metal, + **init_kwargs: Any, +) -> tuple[MDState, torch.Tensor]: + """Run boxed molecular dynamics in energy space for a single system. + + Wraps a :class:`BoxedMD` controller in a loop that respects a global step + budget: it advances box by box, and stops as soon as ``n_steps`` attempted + integrator steps have been taken -- interrupting mid-box if necessary. + + Args: + state: Initial system (single system). If not already an + :class:`~torch_sim.integrators.md.MDState`, momenta are sampled from + a Maxwell-Boltzmann distribution at ``temperature``. + model: Energy/force model. + n_steps: Total number of attempted integrator steps (the hard budget). + i_samp: Minimum accepted steps per window before a boundary may be placed. + timestep: Integration timestep in ``unit_system`` time units. + temperature: Target temperature in Kelvin (converted to energy units). + gamma: Optional Langevin friction coefficient. + seed: Optional RNG seed for momentum initialization and the thermostat. + floor_capacity: Maximum number of placed boundaries to retain. + trajectory_reporter: Optional reporter recording every attempted step + (including the initial frame at step 0). The caller owns its + lifecycle and is responsible for closing it. + unit_system: Unit system for temperature/timestep conversion. Defaults + to metal units (eV, Angstrom, ps). + **init_kwargs: Extra keyword arguments forwarded to ``nvt_langevin_init``. + + Returns: + A tuple ``(final_state, floors)`` where ``floors`` is a 1-D tensor of the + placed lower boundaries (in energy units), ordered and monotonically + non-decreasing. + + Raises: + ValueError: If ``state`` contains more than one system. 
+ """ + if state.n_systems != 1: + raise ValueError( + f"run_boxed_md expects a single system, got {state.n_systems}" + ) + + device, dtype = state.device, state.dtype + kT = float(temperature) * unit_system.temperature + dt = torch.as_tensor(timestep * unit_system.time, dtype=dtype, device=device) + + if seed is not None: + state.rng = seed + if not isinstance(state, MDState): + state = nvt_langevin_init(state, model, kT=kT, **init_kwargs) + + controller = BoxedMD( + model, + i_samp=i_samp, + dt=dt, + kT=kT, + gamma=gamma, + floor_capacity=floor_capacity, + trajectory_reporter=trajectory_reporter, + device=device, + dtype=dtype, + ) + + if trajectory_reporter is not None: + trajectory_reporter.report(state, 0, model) # initial frame + + remaining = n_steps + while remaining > 0: + state, used, status = controller.run_epoch(state, remaining) + remaining -= used + if status == BoxedMD.STEP_LIMIT: + break + + floors = controller.floors.stack() + if floors is None: + floors = torch.empty(0, device=device, dtype=dtype) + return state, floors diff --git a/torch_sim/enhanced_sampling/history.py b/torch_sim/enhanced_sampling/history.py new file mode 100644 index 000000000..d792dfb81 --- /dev/null +++ b/torch_sim/enhanced_sampling/history.py @@ -0,0 +1,134 @@ +"""Trajectory history buffer for history-dependent enhanced-sampling methods. + +History-dependent biases accumulate a record of the trajectory and read it +back to build a bias energy: metadynamics deposits reference structures or +collective-variable values and later sums repulsive kernels over them, while +boundary methods track a quantity (an energy, a distance) over time. The +*what* and the energy mapping differ between methods, but the bookkeeping -- +a deposition cadence, a capacity limit that drops the oldest entries, and +restart-safe storage -- is identical. :class:`History` owns exactly that +bookkeeping so each method only decides what to deposit and how to use it. 
+""" + +from __future__ import annotations + +import torch + + +class History(torch.nn.Module): + """Rolling, capacity-bounded buffer of per-step quantities. + + Deposited values are stacked along a new leading axis, so every value + passed to :meth:`push` must share the same trailing shape. The first + ``capacity`` deposits grow the buffer; later deposits drop the oldest + entry. Stored values are detached and held in a registered buffer, so the + history moves with :meth:`~torch.nn.Module.to` and is captured by + ``state_dict`` for checkpoint/restart. + + A bias typically calls :meth:`maybe_push` once per forward pass and lets + the configured ``stride`` decide when a deposit actually happens; use + :meth:`push` for unconditional, manually controlled deposits (e.g. seeding + the buffer on the first step). + + Args: + capacity: Maximum number of stored entries; the oldest are dropped + once it is exceeded. Must be >= 1. + stride: Deposit on every ``stride``-th :meth:`maybe_push` call. Must + be >= 1. Defaults to 1 (deposit on every call). + device: Device for the stored buffer. Defaults to CPU. + dtype: Floating-point dtype of the stored buffer. Defaults to + ``torch.float64``. 
+ + Example:: + + history = History(capacity=10, stride=5) + history.push(value) # unconditional seed + deposited = history.maybe_push(value) # True every 5th call + record = history.stack() # (n_stored, *value.shape) + """ + + data: torch.Tensor | None + + def __init__( + self, + capacity: int, + stride: int = 1, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + ) -> None: + """Initialize an empty history buffer.""" + super().__init__() + if capacity < 1: + raise ValueError(f"{capacity=} must be >= 1") + if stride < 1: + raise ValueError(f"{stride=} must be >= 1") + self.capacity = int(capacity) + self.stride = int(stride) + self._device = device or torch.device("cpu") + self._dtype = dtype + # registered (rather than a plain attribute) so it moves with .to() and + # is saved/restored via state_dict; starts None until the first deposit. + self.register_buffer("data", None) + self._n_calls = 0 + + @property + def is_empty(self) -> bool: + """Whether no value has been deposited yet.""" + return self.data is None + + def __len__(self) -> int: + """Number of currently stored entries.""" + return 0 if self.data is None else self.data.shape[0] + + @torch.no_grad() + def push(self, value: torch.Tensor) -> None: + """Deposit *value* as the newest entry, dropping the oldest past capacity. + + Args: + value: Tensor to store. Detached and cast to the buffer's device + and dtype. Its shape must match earlier deposits. + + Raises: + ValueError: If *value*'s shape differs from existing entries. A + history shaped to one batch cannot ingest a different one; + call :meth:`reset` before reusing it with another system. 
+ """ + entry = value.detach().to(self._device, self._dtype).unsqueeze(0) + if self.data is None: + self.data = entry + return + if entry.shape[1:] != self.data.shape[1:]: + raise ValueError( + f"value shape {tuple(value.shape)} does not match stored entry " + f"shape {tuple(self.data.shape[1:])}; call reset() before reusing " + "this history with a different system." + ) + self.data = torch.cat([self.data, entry], dim=0)[-self.capacity :] + + def maybe_push(self, value: torch.Tensor) -> bool: + """Advance the call counter and deposit *value* on ``stride`` boundaries. + + Args: + value: Tensor to deposit if this call lands on a stride boundary. + + Returns: + ``True`` if a deposit happened, ``False`` otherwise. + """ + self._n_calls += 1 + if self._n_calls % self.stride == 0: + self.push(value) + return True + return False + + def stack(self) -> torch.Tensor | None: + """Return the stored record with shape ``(n_stored, *value.shape)``. + + Returns: + The stacked entries, or ``None`` if the buffer is empty. + """ + return self.data + + def reset(self) -> None: + """Clear all stored entries and the deposition counter.""" + self.data = None + self._n_calls = 0 diff --git a/torch_sim/metadynamics.py b/torch_sim/enhanced_sampling/metadynamics.py similarity index 90% rename from torch_sim/metadynamics.py rename to torch_sim/enhanced_sampling/metadynamics.py index 1b3b10ce4..f4de94fa6 100644 --- a/torch_sim/metadynamics.py +++ b/torch_sim/enhanced_sampling/metadynamics.py @@ -34,6 +34,7 @@ import torch +from torch_sim.enhanced_sampling.history import History from torch_sim.models.interface import ModelInterface from torch_sim.units import UnitConversion @@ -170,11 +171,13 @@ class RMSDCV(ModelInterface): This pushes the dynamics away from previously visited configurations. Idea and default parameters from 10.1021/acs.jctc.9b00143. - The buffer is seeded on the first call (which returns zero bias), and a - new reference is deposited every ``update_interval`` calls. 
TorchSim integrators - evaluate the model once per MD step (plus once at initialization), so - calls correspond to MD steps. Use :meth:`push_reference` and :meth:`reset` for manual - control of the buffer. + The references are held in a :class:`~torch_sim.enhanced_sampling.history.History` + buffer, which owns the deposition cadence and capacity limit. The buffer is + seeded on the first call (which returns zero bias), and a new reference is + deposited every ``update_interval`` calls. TorchSim integrators evaluate the + model once per MD step (plus once at initialization), so calls correspond to + MD steps. Use :meth:`push_reference` and :meth:`reset` for manual control of + the buffer. Forces are obtained by autograd through the alignment, so backprop through the SVD requires non-degenerate singular values (generic for molecular @@ -239,14 +242,22 @@ def __init__( self.atom_mask: torch.Tensor | None self.register_buffer("atom_mask", atom_mask) - # rolling buffer of centered references, shape (n_stored, n_biased_atoms, 3) - self.ref_buf: torch.Tensor | None = None - self._n_calls = 0 + # rolling buffer of centered references, each entry (n_biased_atoms, 3) + self.history = History( + capacity=self.n_refs, + stride=self.update_interval, + device=self._device, + dtype=self._dtype, + ) + + @property + def ref_buf(self) -> torch.Tensor | None: + """Stored references with shape (n_stored, n_biased_atoms, 3), or None.""" + return self.history.stack() def reset(self) -> None: - """Clear all stored references and the internal call counter.""" - self.ref_buf = None - self._n_calls = 0 + """Clear all stored references and the internal deposition counter.""" + self.history.reset() @torch.no_grad() def push_reference(self, state: SimState) -> None: @@ -255,15 +266,15 @@ def push_reference(self, state: SimState) -> None: Args: state: Simulation state whose (masked) positions are stored. 
""" + self.history.push(self._centered(state)) + + def _centered(self, state: SimState) -> torch.Tensor: + """Return the biased atoms' positions with each system's COM removed.""" positions, system_idx = self._masked(state) counts = torch.bincount(system_idx, minlength=state.n_systems) com = _segment_sum(positions, system_idx, state.n_systems) com = com / counts.unsqueeze(-1) - centered = (positions - com[system_idx]).unsqueeze(0) - if self.ref_buf is None: - self.ref_buf = centered - else: - self.ref_buf = torch.cat([self.ref_buf, centered], dim=0)[-self.n_refs :] + return positions - com[system_idx] def _masked(self, state: SimState) -> tuple[torch.Tensor, torch.Tensor]: """Return positions and system indices of biased atoms only.""" @@ -311,9 +322,8 @@ def forward(self, state: SimState, **_kwargs) -> dict[str, torch.Tensor]: """ n_systems = state.n_systems - if self.ref_buf is None: + if self.history.is_empty: self.push_reference(state) - self._n_calls = 1 zero_energy = torch.zeros( n_systems, device=state.positions.device, dtype=state.positions.dtype ) @@ -328,7 +338,7 @@ def forward(self, state: SimState, **_kwargs) -> dict[str, torch.Tensor]: with torch.enable_grad(): pos = masked_pos.detach().requires_grad_(requires_grad=True) # (n_biased, 3) - qc = self.ref_buf.to(pos) # (X, n_biased, 3), already centered + qc = self.history.stack().to(pos) # (X, n_biased, 3), already centered com = _segment_sum(pos, system_idx, n_systems) / counts.unsqueeze(-1) rc = pos - com[system_idx] @@ -350,9 +360,7 @@ def forward(self, state: SimState, **_kwargs) -> dict[str, torch.Tensor]: else: forces[self.atom_mask] = -grad - if self._n_calls % self.update_interval == 0: - self.push_reference(state) - self._n_calls += 1 + self.history.maybe_push(self._centered(state)) detached_energy = energy.detach() return { From abc7074992862649df0f24e640e98742b6d7ff18 Mon Sep 17 00:00:00 2001 From: amateurcat Date: Wed, 24 Jun 2026 17:57:31 -0700 Subject: [PATCH 3/4] 06242026@U52 add 
loxodynamics --- .../{ => enhanced_sampling}/test_boxed_md.py | 0 tests/enhanced_sampling/test_loxodynamics.py | 210 ++++++ .../test_metadynamics.py | 0 tests/enhanced_sampling/test_skewencoder.py | 108 +++ torch_sim/enhanced_sampling/__init__.py | 31 +- torch_sim/enhanced_sampling/loxodynamics.py | 641 ++++++++++++++++++ torch_sim/enhanced_sampling/metadynamics.py | 45 +- torch_sim/enhanced_sampling/skewencoder.py | 442 ++++++++++++ 8 files changed, 1457 insertions(+), 20 deletions(-) rename tests/{ => enhanced_sampling}/test_boxed_md.py (100%) create mode 100644 tests/enhanced_sampling/test_loxodynamics.py rename tests/{ => enhanced_sampling}/test_metadynamics.py (100%) create mode 100644 tests/enhanced_sampling/test_skewencoder.py create mode 100644 torch_sim/enhanced_sampling/loxodynamics.py create mode 100644 torch_sim/enhanced_sampling/skewencoder.py diff --git a/tests/test_boxed_md.py b/tests/enhanced_sampling/test_boxed_md.py similarity index 100% rename from tests/test_boxed_md.py rename to tests/enhanced_sampling/test_boxed_md.py diff --git a/tests/enhanced_sampling/test_loxodynamics.py b/tests/enhanced_sampling/test_loxodynamics.py new file mode 100644 index 000000000..d9a583031 --- /dev/null +++ b/tests/enhanced_sampling/test_loxodynamics.py @@ -0,0 +1,210 @@ +import pytest +import torch +from ase.build import molecule + +import torch_sim as ts +from torch_sim.enhanced_sampling.loxodynamics import ( + LoxodynamicsWall, + PairDistanceDescriptor, + run_loxodynamics, +) +from torch_sim.enhanced_sampling.skewencoder import ( + Skewencoder, + SkewencoderConfig, + fit_descriptor_normalizer, +) +from torch_sim.models.interface import ModelInterface + + +DEVICE = torch.device("cpu") +DTYPE = torch.float64 + + +class HarmonicModel(ModelInterface): + """Per-atom harmonic well E = 1/2 k sum (x - x0)^2 (single bound system).""" + + def __init__( + self, centers: torch.Tensor, k: float = 1.0, dtype: torch.dtype = DTYPE + ) -> None: + super().__init__() + 
self._device = DEVICE + self._dtype = dtype + self._compute_forces = True + self._compute_stress = False + self.k = float(k) + self.register_buffer("centers", centers.to(DEVICE, dtype)) + + def forward(self, state: ts.SimState, **_kwargs) -> dict[str, torch.Tensor]: + disp = state.positions - self.centers + e_atom = 0.5 * self.k * disp.pow(2).sum(dim=-1) + energy = torch.zeros(state.n_systems, device=DEVICE, dtype=self._dtype).index_add( + 0, state.system_idx, e_atom + ) + return {"energy": energy, "forces": -self.k * disp} + + +@pytest.fixture +def water_state() -> ts.SimState: + return ts.io.atoms_to_state([molecule("H2O")], DEVICE, DTYPE) + + +def _all_pairs_3() -> torch.Tensor: + return torch.tensor([[0, 1], [0, 2], [1, 2]], dtype=torch.long) + + +class TestPairDistanceDescriptor: + def test_correct_distances(self) -> None: + positions = torch.tensor( + [[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [0.0, 4.0, 0.0]], dtype=DTYPE + ) + desc = PairDistanceDescriptor(_all_pairs_3()) + assert desc.n_descriptors == 3 + d = desc(positions) + torch.testing.assert_close(d, torch.tensor([3.0, 4.0, 5.0], dtype=DTYPE)) + + def test_gradients_flow(self) -> None: + positions = torch.randn(3, 3, dtype=DTYPE, requires_grad=True) + desc = PairDistanceDescriptor(_all_pairs_3()) + desc(positions).sum().backward() + assert positions.grad is not None + assert torch.isfinite(positions.grad).all() + + def test_bad_shape_raises(self) -> None: + with pytest.raises(ValueError, match="n_pairs, 2"): + PairDistanceDescriptor(torch.tensor([0, 1, 2], dtype=torch.long)) + + +class TestLoxodynamicsWall: + def _wall(self, *, offset: float) -> tuple[LoxodynamicsWall, int]: + torch.manual_seed(0) + desc = PairDistanceDescriptor(_all_pairs_3()) + cfg = SkewencoderConfig(input_dim=3, hidden_dims=(8, 4)) + enc = Skewencoder(cfg).to(DTYPE) + sample = torch.randn(20, 3, dtype=DTYPE).abs() + 1.0 + norm = fit_descriptor_normalizer(sample) + wall = LoxodynamicsWall( + desc, + enc, + norm, + mu=0.0, + sigma=1.0, # 
standardized wall divides by sigma; use a unit scale + skewness=1.0, # positive -> sign +1, lower wall + kappa=1.0, + offset=offset, + device=DEVICE, + dtype=DTYPE, + ) + return wall, 3 + + def test_shapes(self, water_state: ts.SimState) -> None: + wall, _ = self._wall(offset=1.0) + out = wall(water_state) + assert out["energy"].shape == (1,) + assert out["forces"].shape == (water_state.n_atoms, 3) + + def test_zero_beyond_wall(self, water_state: ts.SimState) -> None: + # a very negative offset puts the CV beyond the wall -> no penalty + wall, _ = self._wall(offset=-1000.0) + out = wall(water_state) + assert out["energy"].item() == 0.0 + assert out["forces"].abs().max() == 0.0 + + def test_positive_when_violating(self, water_state: ts.SimState) -> None: + # a large positive offset forces a violation -> positive penalty + wall, _ = self._wall(offset=1000.0) + out = wall(water_state) + assert out["energy"].item() > 0.0 + + def test_rejects_multiple_systems(self) -> None: + two = ts.io.atoms_to_state([molecule("H2O"), molecule("H2O")], DEVICE, DTYPE) + wall, _ = self._wall(offset=1.0) + with pytest.raises(ValueError, match="single system"): + wall(two) + + +class TestRunLoxodynamics: + def _setup( + self, water_state: ts.SimState + ) -> tuple[HarmonicModel, PairDistanceDescriptor, SkewencoderConfig]: + model = HarmonicModel(water_state.positions.clone(), k=2.0) + desc = PairDistanceDescriptor(_all_pairs_3()) + cfg = SkewencoderConfig( + input_dim=3, + hidden_dims=(8, 4), + max_epochs=5, + batch_size=16, + early_stopping_patience=3, + ) + return model, desc, cfg + + def test_runs_to_budget(self, water_state: ts.SimState) -> None: + model, desc, cfg = self._setup(water_state) + result = run_loxodynamics( + water_state, + model, + descriptor=desc, + max_steps=20, + segment_steps=5, + initial_unbiased_steps=5, + timestep=0.0005, + temperature=300.0, + sample_stride=1, + min_local_samples=3, + seed=0, + skewencoder_config=cfg, + ) + assert result.total_steps == 20 + assert 
len(result.training_reports) >= 1 + assert result.global_descriptors.shape[1] == 3 + assert result.global_descriptors.shape[0] > 0 + assert len(result.wall_stats) == len(result.training_reports) + + def test_rejects_multiple_systems(self) -> None: + two = ts.io.atoms_to_state([molecule("H2O"), molecule("H2O")], DEVICE, DTYPE) + model = HarmonicModel(two.positions.clone(), k=2.0) + desc = PairDistanceDescriptor(_all_pairs_3()) + with pytest.raises(ValueError, match="single system"): + run_loxodynamics( + two, + model, + descriptor=desc, + max_steps=10, + segment_steps=5, + timestep=0.0005, + temperature=300.0, + min_local_samples=3, + ) + + +class TestExecutorDtype: + def test_float32_model_end_to_end(self) -> None: + # With no explicit dtype the executor adopts the model's dtype, so the + # latent wall composes with a float32 model under SumModel and the whole + # run stays in float32. + f32_state = ts.io.atoms_to_state([molecule("H2O")], DEVICE, torch.float32) + model = HarmonicModel(f32_state.positions.clone(), k=2.0, dtype=torch.float32) + cfg = SkewencoderConfig( + input_dim=3, + hidden_dims=(8, 4), + max_epochs=5, + batch_size=16, + early_stopping_patience=3, + ) + result = run_loxodynamics( + f32_state, + model, + descriptor=PairDistanceDescriptor(_all_pairs_3()), + max_steps=20, + segment_steps=5, + initial_unbiased_steps=5, + timestep=0.0005, + temperature=300.0, + sample_stride=1, + min_local_samples=3, + seed=0, + skewencoder_config=cfg, + ) + assert result.total_steps == 20 + assert result.global_descriptors.dtype == torch.float32 + assert result.final_state.positions.dtype == torch.float32 + assert len(result.training_reports) >= 1 diff --git a/tests/test_metadynamics.py b/tests/enhanced_sampling/test_metadynamics.py similarity index 100% rename from tests/test_metadynamics.py rename to tests/enhanced_sampling/test_metadynamics.py diff --git a/tests/enhanced_sampling/test_skewencoder.py b/tests/enhanced_sampling/test_skewencoder.py new file mode 100644 
index 000000000..c3203f247 --- /dev/null +++ b/tests/enhanced_sampling/test_skewencoder.py @@ -0,0 +1,108 @@ +import pytest +import torch + +from torch_sim.enhanced_sampling.skewencoder import ( + DescriptorNormalizer, + Skewencoder, + SkewencoderConfig, + SkewencoderTrainer, + fit_descriptor_normalizer, + skewencoder_loss, + skewness_1d, +) + + +DTYPE = torch.float64 + + +def _small_config(input_dim: int = 4) -> SkewencoderConfig: + return SkewencoderConfig( + input_dim=input_dim, + hidden_dims=(8, 4), + max_epochs=5, + batch_size=16, + early_stopping_patience=3, + ) + + +class TestSkewness: + def test_symmetric_near_zero(self) -> None: + values = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=DTYPE) + assert skewness_1d(values).abs() < 1e-9 + + def test_right_tail_positive(self) -> None: + values = torch.tensor([1.0, 1.0, 1.0, 1.0, 10.0], dtype=DTYPE) + assert skewness_1d(values) > 0 + + def test_left_tail_negative(self) -> None: + values = torch.tensor([-1.0, -1.0, -1.0, -1.0, -10.0], dtype=DTYPE) + assert skewness_1d(values) < 0 + + def test_differentiable(self) -> None: + values = torch.randn(50, dtype=DTYPE, requires_grad=True) + skewness_1d(values).backward() + assert values.grad is not None + assert torch.isfinite(values.grad).all() + + +class TestSkewencoder: + def test_output_shapes(self) -> None: + model = Skewencoder(_small_config(input_dim=5)).to(DTYPE) + x = torch.randn(8, 5, dtype=DTYPE) + recon, latent = model(x) + assert recon.shape == (8, 5) + assert latent.shape == (8, 1) + assert model.encode(x).shape == (8, 1) + + def test_latent_dim_must_be_one(self) -> None: + with pytest.raises(ValueError, match="latent_dim"): + SkewencoderConfig(input_dim=4, latent_dim=2) + + +class TestSkewencoderLoss: + def test_finite_and_diagnostics(self) -> None: + cfg = _small_config(input_dim=4) + model = Skewencoder(cfg).to(DTYPE) + global_x = torch.randn(32, 4, dtype=DTYPE) + local_x = torch.randn(12, 4, dtype=DTYPE) + loss, diag = skewencoder_loss( + model, 
class TestTrainer:
    """Smoke test for one ``SkewencoderTrainer.train`` call."""

    def test_trains_and_encodes(self) -> None:
        # Fixed seed so weight initialisation and the random inputs are reproducible.
        torch.manual_seed(0)
        cfg = _small_config(input_dim=4)
        model = Skewencoder(cfg).to(DTYPE)
        trainer = SkewencoderTrainer(cfg)
        global_x = torch.randn(64, 4, dtype=DTYPE)
        local_x = torch.randn(20, 4, dtype=DTYPE)

        normalizer, report = trainer.train(model, global_x, local_x)
        assert isinstance(normalizer, DescriptorNormalizer)
        assert 1 <= report.n_epochs <= cfg.max_epochs
        # Explicit finiteness check: the former ``x == x`` self-comparison only
        # rejected NaN and would have let a diverged (+/-inf) loss pass.
        final_loss = torch.as_tensor(report.final_loss, dtype=torch.float64)
        assert torch.isfinite(final_loss).all()

        # The trained encoder maps each normalized local sample to one latent value.
        latent = model.encode(normalizer.transform(local_x))
        assert latent.shape == (20, 1)
The +:mod:`~torch_sim.enhanced_sampling.boxed_md` module implements boxed MD in energy +space (BXDE), and :mod:`~torch_sim.enhanced_sampling.loxodynamics` implements the +skewness-guided latent-space Loxodynamics method. """ from torch_sim.enhanced_sampling.boxed_md import BoxedMD, run_boxed_md, velocity_inversion from torch_sim.enhanced_sampling.history import History +from torch_sim.enhanced_sampling.loxodynamics import ( + LoxodynamicsExecutor, + LoxodynamicsResult, + LoxodynamicsWall, + LoxodynamicsWallStats, + PairDistanceDescriptor, + run_loxodynamics, +) from torch_sim.enhanced_sampling.metadynamics import RMSDCV, LogfermiWall +from torch_sim.enhanced_sampling.skewencoder import ( + DescriptorNormalizer, + Skewencoder, + SkewencoderConfig, + SkewencoderTrainer, + SkewencoderTrainingReport, +) __all__ = [ "RMSDCV", "BoxedMD", + "DescriptorNormalizer", "History", "LogfermiWall", + "LoxodynamicsExecutor", + "LoxodynamicsResult", + "LoxodynamicsWall", + "LoxodynamicsWallStats", + "PairDistanceDescriptor", + "Skewencoder", + "SkewencoderConfig", + "SkewencoderTrainer", + "SkewencoderTrainingReport", "run_boxed_md", + "run_loxodynamics", "velocity_inversion", ] diff --git a/torch_sim/enhanced_sampling/loxodynamics.py b/torch_sim/enhanced_sampling/loxodynamics.py new file mode 100644 index 000000000..5449b158e --- /dev/null +++ b/torch_sim/enhanced_sampling/loxodynamics.py @@ -0,0 +1,641 @@ +"""Loxodynamics: skewness-guided latent-space enhanced sampling. + +Loxodynamics explores reactions without a predefined product or hand-built +reaction coordinate. It learns a one-dimensional latent collective variable from +structural descriptors with a +:class:`~torch_sim.enhanced_sampling.skewencoder.Skewencoder`, +reads the *skewness* of the local latent distribution to decide which way the +basin has a low-barrier exit, and erects a half-harmonic wall in latent +space that pushes the system that way. 
class PairDistanceDescriptor(torch.nn.Module):
    """Map atomic positions to a vector of selected interatomic distances.

    Each row of ``pairs`` names two atoms; the descriptor returns
    ``sqrt(|r_i - r_j|^2 + eps)`` for every row. Plain Cartesian differences
    are used (no periodic images), and the result is differentiable with
    respect to the positions.

    Args:
        pairs: Tensor of atom index pairs with shape ``[n_pairs, 2]``; it is
            converted to ``torch.long`` and stored as a buffer.
        eps: Constant added to the squared distance before the square root.
    """

    pairs: torch.Tensor

    def __init__(self, pairs: torch.Tensor, *, eps: float = 1.0e-12) -> None:
        """Check the pair list's shape and register it as a buffer.

        Raises:
            ValueError: If ``pairs`` is not ``[n_pairs, 2]`` or has no rows.
        """
        super().__init__()
        shape = tuple(pairs.shape)
        if len(shape) != 2 or shape[1] != 2:
            raise ValueError(f"pairs must have shape [n_pairs, 2], got {shape}")
        if shape[0] < 1:
            raise ValueError("pairs must contain at least one pair")
        self.eps = float(eps)
        self.register_buffer("pairs", pairs.to(torch.long))

    @property
    def n_descriptors(self) -> int:
        """Length of the descriptor vector (one entry per pair)."""
        return int(self.pairs.size(0))

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the pair distances ``[n_pairs]`` for positions ``[n_atoms, 3]``.

        Raises:
            ValueError: If ``positions`` is not a ``[n_atoms, 3]`` tensor.
        """
        shape = tuple(positions.shape)
        if len(shape) != 2 or shape[1] != 3:
            raise ValueError(f"positions must have shape [n_atoms, 3], got {shape}")
        first, second = self.pairs.unbind(dim=1)
        separation = positions[first] - positions[second]
        squared = separation.pow(2).sum(dim=-1)
        # eps keeps the square root (and its gradient) finite at zero separation.
        return (squared + self.eps).sqrt()
note:: + This **standardized** wall is a deliberate departure from the original + Loxodynamics method (10.1038/s41467-026-69586-8), which places the + half-harmonic wall directly on the **raw** latent ``s`` (``offset`` in + raw latent units). We found the raw-latent wall numerically unstable: + the latent ``s`` has an arbitrary, unanchored scale -- the encoder's + output magnitude drifts or jumps across warm-start retrains -- so a raw + wall ``relu(mu + sigma + offset - sign*s)`` injects an energy + ``~kappa*(sigma + offset)**2`` that diverges as ``sigma`` grows large, + producing a force spike that detonates the integrator (observed e.g. on + rigid molecules whose autoencoder escapes a degenerate solution into a + large-magnitude latent). Standardizing by ``sigma`` removes this: the + initial violation is ``offset`` and the initial energy is + ``kappa*offset**2`` regardless of ``sigma``, and the force scales as + ``(1/sigma) * ds/dx`` so the encoder's scale cancels. Consequently + ``offset`` here is in units of ``sigma`` (standard deviations past the + mean), not raw latent units as in the paper. + + Args: + descriptor: Pair-distance descriptor. + skewencoder: Trained Skewencoder (used in eval mode, parameters frozen). + normalizer: Descriptor normalizer fit on the global buffer. + mu: Local latent mean. + sigma: Local latent standard deviation. + skewness: Local latent skewness (sets the wall orientation). + kappa: Harmonic wall force constant. + offset: Wall margin in units of the local latent std ``sigma`` (i.e. how + many ``sigma`` past the mean to push the CV before the wall relaxes). + min_abs_skew: Below this magnitude the skew sign is treated as positive. + min_sigma: Floor applied to ``sigma`` to keep the standardization well + defined when the latent is nearly degenerate. + energy_label: Non-canonical output key for this wall's energy. + cv_label: Non-canonical output key for the (raw) latent collective variable. + device: Computation device. 
Defaults to CPU. + dtype: Floating-point dtype. Defaults to ``torch.float64``. + """ + + def __init__( + self, + descriptor: PairDistanceDescriptor, + skewencoder: Skewencoder, + normalizer: DescriptorNormalizer, + *, + mu: torch.Tensor | float, + sigma: torch.Tensor | float, + skewness: torch.Tensor | float, + kappa: torch.Tensor | float, + offset: torch.Tensor | float, + min_abs_skew: float = 1.0e-8, + min_sigma: float = 1.0e-6, + energy_label: str = "PE_LoxodynamicsWall", + cv_label: str = "loxo_cv", + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + ) -> None: + """Build the wall from trained components and latent statistics.""" + super().__init__() + self._device = device or torch.device("cpu") + self._dtype = dtype + self._compute_forces = True + self._compute_stress = False + self._memory_scales_with = "n_atoms" + self.energy_label = str(energy_label) + self.cv_label = str(cv_label) + + self.descriptor = descriptor.to(self._device) + # Eval mode only; do not freeze parameters -- the encoder instance is + # shared with the executor and warm-start retrained between segments. + # Forces use autograd.grad(energy, [positions]), which differentiates + # w.r.t. positions only, so parameter requires_grad is irrelevant here. + self.skewencoder = skewencoder.to(self._device, self._dtype).eval() + self.normalizer = normalizer.to(self._device, self._dtype) + + mu_t = torch.as_tensor(mu, device=self._device, dtype=self._dtype) + sigma_t = torch.as_tensor( + sigma, device=self._device, dtype=self._dtype + ).clamp_min(min_sigma) + skew_t = torch.as_tensor(skewness, device=self._device, dtype=self._dtype) + offset_t = torch.as_tensor(offset, device=self._device, dtype=self._dtype) + sign = torch.where( + skew_t.abs() < min_abs_skew, + torch.ones((), device=self._device, dtype=self._dtype), + torch.sign(skew_t), + ) + # Raw-latent value where the wall relaxes (z = offset), for reporting. 
+ boundary = mu_t + sign * offset_t * sigma_t + self.register_buffer("sign", sign) + self.register_buffer("mu", mu_t) + self.register_buffer("sigma", sigma_t) + self.register_buffer("offset", offset_t) + self.register_buffer("boundary", boundary) + self.register_buffer( + "kappa", torch.as_tensor(kappa, device=self._device, dtype=self._dtype) + ) + + def forward(self, state: SimState, **_kwargs) -> dict[str, torch.Tensor]: + """Compute wall energy ``[1]`` and forces ``[n_atoms, 3]`` from positions.""" + if state.n_systems != 1: + raise ValueError( + f"LoxodynamicsWall supports a single system, got {state.n_systems}" + ) + with torch.enable_grad(): + pos = state.positions.detach().requires_grad_(requires_grad=True) + desc = self.descriptor(pos) + norm_desc = self.normalizer.transform(desc) + latent = self.skewencoder.encode(norm_desc.unsqueeze(0)).reshape(()) + scaled_z = self.sign * (latent - self.mu) / self.sigma + violation = torch.relu(self.offset - scaled_z) + energy_scalar = self.kappa * violation.pow(2) + grad = torch.autograd.grad(energy_scalar, pos)[0] + + forces = -grad + energy = energy_scalar.detach().reshape(1) + return { + "energy": energy, + "forces": forces.detach(), + self.energy_label: energy, + # The latent CV (raw, un-signed) so the trajectory can record the + # collective variable the wall acts on, step by step. + self.cv_label: latent.detach().reshape(1), + } + + +@dataclass +class LoxodynamicsResult: + """Outcome of a Loxodynamics run.""" + + final_state: MDState + wall_stats: list[LoxodynamicsWallStats] + training_reports: list[SkewencoderTrainingReport] + global_descriptors: torch.Tensor + total_steps: int + segment_times: list[float] = field(default_factory=list) + """Wall-clock seconds per MD segment; index 0 is the initial unbiased + segment, index ``i + 1`` is the biased segment that follows iteration ``i``.""" + + +class LoxodynamicsExecutor: + """Trajectory-level Loxodynamics controller (single system). 
+ + Runs an initial unbiased segment, then repeatedly trains the Skewencoder, + builds a latent wall, and runs a biased segment, until ``max_steps`` + attempted MD steps have been taken. + + Args: + model: Base energy/force model. + descriptor: Pair-distance descriptor. + max_steps: Total attempted-step budget (the only stopping criterion). + segment_steps: Steps per biased segment. + timestep: Timestep in ``unit_system`` time units. + temperature: Temperature in Kelvin. + initial_unbiased_steps: Steps in the first unbiased segment. Defaults to + ``segment_steps``. + sample_stride: Collect a descriptor sample every this many steps. + kappa: Wall force constant. + wall_offset: Wall margin in units of the local latent std ``sigma`` -- + how many ``sigma`` past the mean to push the standardized CV before + the wall relaxes. Defaults to 1.0 (one standard deviation). + global_buffer_capacity: Max global descriptor samples retained (newest). + min_local_samples: Minimum local samples required before training. + gamma: Optional Langevin friction. + seed: Optional RNG seed. + unit_system: Unit system for conversions. Defaults to metal units. + skewencoder_config: Optional config; one is built from the descriptor if + omitted. + trajectory_reporter: Optional reporter recording every attempted step. + checkpoint_dir: If set, write a Skewencoder checkpoint after each retrain + to ``{checkpoint_dir}/skewencoder_iter{iteration}.pt`` (model weights, + config, normalizer, latent stats, and training report). The directory + is created if needed. + device: Computation device. Defaults to ``model.device``. + dtype: Working dtype for the Skewencoder, wall, and descriptor buffers. + Defaults to ``model.dtype`` so the wall composes with the base model + under :class:`~torch_sim.models.interface.SumModel`. + verbose: If True, print a progress hint before each segment/retrain and a + timing summary at the end. + **init_kwargs: Extra arguments forwarded to ``nvt_langevin_init``. 
+ """ + + def __init__( + self, + model: ModelInterface, + descriptor: PairDistanceDescriptor, + *, + max_steps: int, + segment_steps: int, + timestep: float, + temperature: float, + initial_unbiased_steps: int | None = None, + sample_stride: int = 1, + kappa: float = 1.0, + wall_offset: float = 1.0, + global_buffer_capacity: int = 50000, + min_local_samples: int = 10, + gamma: torch.Tensor | float | None = None, + seed: int | None = None, + unit_system: UnitSystem = UnitSystem.metal, + skewencoder_config: SkewencoderConfig | None = None, + trajectory_reporter: TrajectoryReporter | None = None, + checkpoint_dir: str | Path | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + verbose: bool = False, + **init_kwargs: Any, + ) -> None: + """Validate arguments and store configuration.""" + if max_steps < 1: + raise ValueError(f"{max_steps=} must be >= 1") + if segment_steps < 1: + raise ValueError(f"{segment_steps=} must be >= 1") + if sample_stride < 1: + raise ValueError(f"{sample_stride=} must be >= 1") + if min_local_samples < 1: + raise ValueError(f"{min_local_samples=} must be >= 1") + + if skewencoder_config is None: + skewencoder_config = SkewencoderConfig(input_dim=descriptor.n_descriptors) + if descriptor.n_descriptors != skewencoder_config.input_dim: + raise ValueError( + f"descriptor.n_descriptors ({descriptor.n_descriptors}) != " + f"skewencoder_config.input_dim ({skewencoder_config.input_dim})" + ) + + self.model = model + self.descriptor = descriptor + self.config = skewencoder_config + self.max_steps = int(max_steps) + self.segment_steps = int(segment_steps) + self.timestep = timestep + self.temperature = temperature + self.initial_unbiased_steps = ( + int(initial_unbiased_steps) + if initial_unbiased_steps is not None + else int(segment_steps) + ) + self.sample_stride = int(sample_stride) + self.kappa = kappa + self.wall_offset = wall_offset + self.global_buffer_capacity = int(global_buffer_capacity) + 
self.min_local_samples = int(min_local_samples) + self.gamma = gamma + self.seed = seed + self.unit_system = unit_system + self.trajectory_reporter = trajectory_reporter + self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir is not None else None + self.verbose = bool(verbose) + self._device = device or model.device + self._dtype = dtype if dtype is not None else model.dtype + self.init_kwargs = init_kwargs + + def _run_segment( + self, + state: MDState, + model: ModelInterface, + n_steps: int, + step_offset: int, + dt: torch.Tensor, + kT: float, + ) -> tuple[MDState, list[torch.Tensor], int]: + """Run ``n_steps`` Langevin steps, collecting descriptors and reporting.""" + local: list[torch.Tensor] = [] + for i in range(n_steps): + state = nvt_langevin_step(state, model, dt=dt, kT=kT, gamma=self.gamma) + if (i + 1) % self.sample_stride == 0: + local.append(self.descriptor(state.positions).detach()) + if self.trajectory_reporter is not None: + self.trajectory_reporter.report(state, step_offset + i + 1, self.model) + return state, local, n_steps + + def _extend_global( + self, global_buf: list[torch.Tensor], new: list[torch.Tensor] + ) -> None: + """Append new descriptors and keep only the newest ``capacity`` samples.""" + global_buf.extend(d.to(self._device, self._dtype) for d in new) + if len(global_buf) > self.global_buffer_capacity: + del global_buf[: len(global_buf) - self.global_buffer_capacity] + + def _save_checkpoint( + self, + iteration: int, + skewencoder: Skewencoder, + normalizer: DescriptorNormalizer, + stats: LoxodynamicsWallStats, + report: SkewencoderTrainingReport, + ) -> Path: + """Snapshot the (warm-started) Skewencoder and its context to disk.""" + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + path = self.checkpoint_dir / f"skewencoder_iter{iteration}.pt" + config = asdict(self.config) if is_dataclass(self.config) else self.config + torch.save( + { + "iteration": iteration, + "skewencoder_state_dict": { + k: v.detach().cpu() for 
k, v in skewencoder.state_dict().items() + }, + "skewencoder_config": config, + "normalizer": { + "mean": normalizer.mean.detach().cpu(), + "std": normalizer.std.detach().cpu(), + "eps": normalizer.eps, + }, + "wall_stats": asdict(stats), + "training_report": asdict(report), + }, + path, + ) + return path + + def run(self, state: SimState | MDState) -> LoxodynamicsResult: # noqa: C901, PLR0915 + """Execute the Loxodynamics loop and return the result.""" + if state.n_systems != 1: + raise ValueError( + f"Loxodynamics supports a single system, got {state.n_systems}" + ) + + kT = float(self.temperature) * self.unit_system.temperature + dt = torch.as_tensor( + self.timestep * self.unit_system.time, + device=self._device, + dtype=self._dtype, + ) + + if self.seed is not None: + state.rng = self.seed + if not isinstance(state, MDState): + state = nvt_langevin_init(state, self.model, kT=kT, **self.init_kwargs) + if self.trajectory_reporter is not None: + self.trajectory_reporter.report(state, 0, self.model) + + skewencoder = Skewencoder(self.config).to(self._device, self._dtype) + trainer = SkewencoderTrainer(self.config) + + global_buf: list[torch.Tensor] = [] + wall_stats: list[LoxodynamicsWallStats] = [] + training_reports: list[SkewencoderTrainingReport] = [] + segment_times: list[float] = [] + + t_run = time.perf_counter() + total_steps = 0 + # initial unbiased segment + n0 = min(self.initial_unbiased_steps, self.max_steps) + if self.verbose: + print( # noqa: T201 + f"[loxodynamics] Now start unbiased sampling: {n0} steps at " + f"{self.temperature:g} K, dt={self.timestep:g}", + flush=True, + ) + t_seg = time.perf_counter() + state, local, used = self._run_segment(state, self.model, n0, total_steps, dt, kT) + segment_times.append(time.perf_counter() - t_seg) + total_steps += used + self._extend_global(global_buf, local) + if self.verbose: + print( # noqa: T201 + f"[loxodynamics] unbiased segment done: {used} steps in " + f"{segment_times[-1]:.1f} s", + flush=True, + 
) + + iteration = 0 + while total_steps < self.max_steps: + if len(local) < self.min_local_samples: + raise ValueError( + f"collected only {len(local)} local samples (< " + f"{self.min_local_samples}); increase segment length or " + "decrease sample_stride / min_local_samples" + ) + + if self.verbose: + print( # noqa: T201 + f"[loxodynamics] iter {iteration}: retraining Skewencoder on " + f"{len(global_buf)} global / {len(local)} local samples ...", + flush=True, + ) + global_t = torch.stack(global_buf).to(self._device, self._dtype) + local_t = torch.stack(local).to(self._device, self._dtype) + _normalizer, report = trainer.train(skewencoder, global_t, local_t) + training_reports.append(report) + + with torch.no_grad(): + latent = skewencoder.encode(_normalizer.transform(local_t)).reshape(-1) + mu = latent.mean() + sigma = latent.std(unbiased=False) + skew = skewness_1d(latent, eps=self.config.eps) + + wall = LoxodynamicsWall( + self.descriptor, + skewencoder, + _normalizer, + mu=mu, + sigma=sigma, + skewness=skew, + kappa=self.kappa, + offset=self.wall_offset, + device=self._device, + dtype=self._dtype, + ) + wall_stats.append( + LoxodynamicsWallStats( + iteration=iteration, + mu=float(mu), + sigma=float(sigma), + skewness=float(skew), + sign=float(wall.sign), + boundary=float(wall.boundary), + n_local_samples=len(local), + n_global_samples=len(global_buf), + ) + ) + if self.checkpoint_dir is not None: + ckpt_path = self._save_checkpoint( + iteration, skewencoder, _normalizer, wall_stats[-1], report + ) + if self.verbose: + print( # noqa: T201 + f"[loxodynamics] saved model checkpoint: {ckpt_path}", flush=True + ) + iteration += 1 + + remaining = self.max_steps - total_steps + if remaining <= 0: + break + seg = min(self.segment_steps, remaining) + if self.verbose: + w = wall_stats[-1] + print( # noqa: T201 + f"[loxodynamics] Now start loxodynamics with wall settings: " + f"iter={w.iteration}, sign={w.sign:+.0f}, mu={w.mu:.4g}, " + f"sigma={w.sigma:.4g}, 
skewness={w.skewness:+.4g}, " + f"boundary={w.boundary:.4g}, kappa={float(self.kappa):g}, " + f"offset={float(self.wall_offset):g}; {seg} steps", + flush=True, + ) + biased_model = SumModel(self.model, wall) + t_seg = time.perf_counter() + state, local, used = self._run_segment( + state, biased_model, seg, total_steps, dt, kT + ) + segment_times.append(time.perf_counter() - t_seg) + total_steps += used + self._extend_global(global_buf, local) + if self.verbose: + print( # noqa: T201 + f"[loxodynamics] biased segment (iter {wall_stats[-1].iteration}) " + f"done: {used} steps in {segment_times[-1]:.1f} s", + flush=True, + ) + + if self.verbose: + print( # noqa: T201 + f"[loxodynamics] simulation finished: {total_steps} steps, " + f"{len(wall_stats)} walls, in {time.perf_counter() - t_run:.1f} s " + "(time cost)", + flush=True, + ) + + if global_buf: + global_descriptors = torch.stack(global_buf) + else: + global_descriptors = torch.empty( + 0, + self.descriptor.n_descriptors, + device=self._device, + dtype=self._dtype, + ) + return LoxodynamicsResult( + final_state=state, + wall_stats=wall_stats, + training_reports=training_reports, + global_descriptors=global_descriptors, + total_steps=total_steps, + segment_times=segment_times, + ) + + +def run_loxodynamics( + state: SimState, + model: ModelInterface, + *, + descriptor: PairDistanceDescriptor, + max_steps: int, + segment_steps: int, + timestep: float, + temperature: float, + initial_unbiased_steps: int | None = None, + sample_stride: int = 1, + kappa: float = 1.0, + wall_offset: float = 1.0, + global_buffer_capacity: int = 50000, + min_local_samples: int = 10, + gamma: torch.Tensor | float | None = None, + seed: int | None = None, + unit_system: UnitSystem = UnitSystem.metal, + skewencoder_config: SkewencoderConfig | None = None, + trajectory_reporter: TrajectoryReporter | None = None, + checkpoint_dir: str | Path | None = None, + verbose: bool = False, + **init_kwargs: Any, +) -> LoxodynamicsResult: + 
"""Convenience wrapper: build a :class:`LoxodynamicsExecutor` and run it. + + See :class:`LoxodynamicsExecutor` for argument semantics. + """ + executor = LoxodynamicsExecutor( + model, + descriptor, + max_steps=max_steps, + segment_steps=segment_steps, + timestep=timestep, + temperature=temperature, + initial_unbiased_steps=initial_unbiased_steps, + sample_stride=sample_stride, + kappa=kappa, + wall_offset=wall_offset, + global_buffer_capacity=global_buffer_capacity, + min_local_samples=min_local_samples, + gamma=gamma, + seed=seed, + unit_system=unit_system, + skewencoder_config=skewencoder_config, + trajectory_reporter=trajectory_reporter, + checkpoint_dir=checkpoint_dir, + verbose=verbose, + **init_kwargs, + ) + return executor.run(state) diff --git a/torch_sim/enhanced_sampling/metadynamics.py b/torch_sim/enhanced_sampling/metadynamics.py index f4de94fa6..ea798d2d7 100644 --- a/torch_sim/enhanced_sampling/metadynamics.py +++ b/torch_sim/enhanced_sampling/metadynamics.py @@ -1,9 +1,10 @@ -"""Metadynamics bias potentials. +"""Bias potentials for enhanced sampling. -This module implements history-dependent and static bias potentials that add -external energies and forces to a simulation. Each bias is a -:class:`~torch_sim.models.interface.ModelInterface`, so it composes with any -MLIP (or classical potential) through +This module provides two composable bias potentials -- a static spherical +confining wall and a history-dependent RMSD bias -- that add extra energy and +forces on top of a base potential. Each bias is a +:class:`~torch_sim.models.interface.ModelInterface`, so it layers onto any MLIP +(or classical potential) through :class:`~torch_sim.models.interface.SumModel`:: bias = RMSDCV(k_push=0.02, alpha_width=1.2) @@ -67,13 +68,16 @@ def _segment_sum( class LogfermiWall(ModelInterface): - """Log-Fermi wall potential confining atoms inside a sphere. + """Spherical confining potential with a smooth log-Fermi boundary. 
- Adds the per-atom energy ``k_wall * log(1 + exp(beta * (r - radius)))`` - where ``r`` is the distance of the atom from the wall center. The energy - is near zero well inside the sphere and grows linearly (slope - ``k_wall * beta``) outside it, gently steering escaping atoms back. - Idea and default parameters from 10.1021/acs.jctc.9b00143. + Holds the atoms of each system near a center point through the per-atom + penalty ``k_wall * softplus(beta * (r - radius))``, where ``r`` is the + atom's distance from the center. Well inside the radius the penalty is + negligible; once an atom crosses the radius the softplus turns linear, so + the inward force saturates at ``k_wall * beta`` rather than diverging -- + a soft, leak-proof boundary for keeping fragments from drifting apart. This + is the log-Fermi restraint from Grimme's metadynamics-style biasing + (10.1021/acs.jctc.9b00143); the default parameters follow that reference. Forces are computed analytically, so the model is safe to call under ``torch.no_grad()``. @@ -162,14 +166,17 @@ def forward(self, state: SimState, **_kwargs) -> dict[str, torch.Tensor]: class RMSDCV(ModelInterface): - """History-dependent RMSD bias (weighted, Kabsch-aligned) for metadynamics. - - Maintains a rolling buffer of reference structures and adds the repulsive - bias ``E = k_push * sum_i exp(-alpha * rmsd2_i)`` per system, where - ``rmsd2_i`` is the squared deviation from reference *i* after optimal - (Kabsch) alignment, averaged over biased atoms and Cartesian components. - This pushes the dynamics away from previously visited configurations. - Idea and default parameters from 10.1021/acs.jctc.9b00143. + """History-dependent Gaussian bias on an RMSD collective variable. 
+ + Discourages the system from revisiting earlier geometries by summing a + Gaussian repulsion over a rolling set of stored reference structures, + ``E = k_push * sum_i exp(-alpha * d2_i)`` per system, where ``d2_i`` is the + mean-square displacement from reference *i* after a least-squares (Kabsch) + rotation that removes rigid-body orientation, averaged over the biased atoms + and Cartesian components. As references accumulate along the trajectory the + bias fills the basins already visited and steers the dynamics toward new + configurations. The functional form and default parameters follow Grimme's + RMSD-based metadynamics (10.1021/acs.jctc.9b00143). The references are held in a :class:`~torch_sim.enhanced_sampling.history.History` buffer, which owns the deposition cadence and capacity limit. The buffer is diff --git a/torch_sim/enhanced_sampling/skewencoder.py b/torch_sim/enhanced_sampling/skewencoder.py new file mode 100644 index 000000000..c266f70f1 --- /dev/null +++ b/torch_sim/enhanced_sampling/skewencoder.py @@ -0,0 +1,442 @@ +"""Skewencoder: a 1-D-bottleneck autoencoder with a skewness auxiliary loss. + +The Skewencoder learns a one-dimensional latent collective variable from +structural descriptors. Its training objective combines an autoencoder +reconstruction loss (on the *global* descriptor history) with an auxiliary loss +(on the *local*, most-recent segment) that rewards a large-magnitude skewness of +the latent distribution -- the signal Loxodynamics uses to pick a biasing +direction. + +This is a lightweight, pure-PyTorch adaptation of the reference ``skewencoder`` +package (MIT licensed) by the original authors; only the model structure, the +``log(1 + exp(-skew^2))`` skewness loss, and the shifted-softplus activation are +reused. None of the reference package's Lightning / mlcolvar / PLUMED machinery +is required here. 
_ACTIVATIONS = ("shifted_softplus", "tanh", "relu", "silu")


@dataclass
class SkewencoderConfig:
    """Configuration for a :class:`Skewencoder` and its trainer.

    Attributes:
        input_dim: Number of input descriptors. Must be >= 1.
        hidden_dims: Encoder hidden layer widths; the decoder mirrors them.
            The default ``(90, 40, 20, 5)`` follows the reference architecture.
            Every width must be >= 1 (an empty tuple is allowed).
        latent_dim: Bottleneck width. Must be 1 in this version.
        activation: One of ``"shifted_softplus"``, ``"tanh"``, ``"relu"``,
            ``"silu"``.
        alpha: Coefficient of the skewness auxiliary loss.
        beta: Coefficient of the L2 weight regularization.
        learning_rate: Adam learning rate.
        batch_size: Minibatch size for the global reconstruction term.
            Must be >= 1.
        max_epochs: Maximum training epochs per call.
        early_stopping_patience: Epochs without improvement before stopping.
        min_delta: Minimum total-loss improvement counted as progress.
        eps: Numerical-stability floor used in skewness/normalization.
        verbose: If True, print per-epoch training diagnostics to stdout.
        verbose_stride: Print every this many epochs when ``verbose`` (the first
            and last epoch are always printed). Must be >= 1.

    Raises:
        ValueError: If any of the constraints above is violated.
    """

    input_dim: int
    hidden_dims: tuple[int, ...] = (90, 40, 20, 5)
    latent_dim: int = 1
    activation: str = "shifted_softplus"
    alpha: float = 0.1
    beta: float = 1.0e-5
    learning_rate: float = 1.0e-3
    batch_size: int = 128
    max_epochs: int = 200
    early_stopping_patience: int = 10
    min_delta: float = 1.0e-6
    eps: float = 1.0e-12
    verbose: bool = False
    verbose_stride: int = 1

    def __post_init__(self) -> None:
        """Validate the configuration."""
        if self.latent_dim != 1:
            raise ValueError(f"{self.latent_dim=} must be 1 in this version")
        if self.input_dim < 1:
            raise ValueError(f"{self.input_dim=} must be >= 1")
        if self.activation not in _ACTIVATIONS:
            raise ValueError(f"{self.activation=} must be one of {_ACTIVATIONS}")
        # The checks below turn late, opaque failures into an immediate error:
        # a non-positive width breaks layer construction, batch_size is used
        # as a ``range`` step by the trainer, and verbose_stride is a modulus.
        if any(width < 1 for width in self.hidden_dims):
            raise ValueError(f"{self.hidden_dims=} must only contain widths >= 1")
        if self.batch_size < 1:
            raise ValueError(f"{self.batch_size=} must be >= 1")
        if self.verbose_stride < 1:
            raise ValueError(f"{self.verbose_stride=} must be >= 1")


class ShiftedSoftplus(nn.Module):
    """Shifted softplus activation ``softplus(x) - log(2)``.

    The shift makes the output zero at the origin while keeping the smooth,
    everywhere-differentiable shape of softplus.
    """

    def __init__(self) -> None:
        """Initialize the activation."""
        super().__init__()
        self._softplus = nn.Softplus()
        self._shift = math.log(2.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the shifted softplus elementwise."""
        return self._softplus(x) - self._shift


def _make_activation(name: str) -> nn.Module:
    """Return a fresh activation module for the given name.

    Args:
        name: One of the names listed in ``_ACTIVATIONS``.

    Returns:
        A newly constructed activation module (never shared between calls).

    Raises:
        ValueError: If ``name`` is not a supported activation.
    """
    factories = {
        "shifted_softplus": ShiftedSoftplus,
        "tanh": nn.Tanh,
        "relu": nn.ReLU,
        "silu": nn.SiLU,
    }
    factory = factories.get(name)
    if factory is None:
        raise ValueError(f"unsupported activation {name!r}")
    return factory()


def _build_mlp(dims: list[int], activation: str) -> nn.Sequential:
    """Build an MLP over ``dims`` with the activation between hidden layers.

    No activation is applied after the final linear layer, so the encoder
    bottleneck and the decoder output stay linear.

    Args:
        dims: Layer widths, input first and output last.
        activation: Activation name understood by :func:`_make_activation`.

    Returns:
        The layers wrapped in an ``nn.Sequential``.
    """
    layers: list[nn.Module] = []
    last_linear = len(dims) - 2
    for i, (n_in, n_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(n_in, n_out))
        if i < last_linear:
            layers.append(_make_activation(activation))
    return nn.Sequential(*layers)


class Skewencoder(nn.Module):
    """Autoencoder with a one-dimensional latent bottleneck.

    The encoder maps ``input_dim`` descriptors through ``hidden_dims`` down to a
    single latent value; the decoder mirrors that path back to ``input_dim``.

    Args:
        config: The :class:`SkewencoderConfig`.
    """

    def __init__(self, config: SkewencoderConfig) -> None:
        """Build the encoder/decoder from the config."""
        super().__init__()
        self.config = config
        enc_dims = [config.input_dim, *config.hidden_dims, config.latent_dim]
        self.encoder = _build_mlp(enc_dims, config.activation)
        # The decoder uses the same widths in reverse order.
        self.decoder = _build_mlp(list(reversed(enc_dims)), config.activation)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode descriptors ``[n_samples, input_dim]`` to latent ``[n_samples, 1]``."""
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent ``[n_samples, 1]`` back to ``[n_samples, input_dim]``."""
        return self.decoder(z)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(reconstruction, latent)`` for inputs ``[n_samples, input_dim]``."""
        z = self.encode(x)
        return self.decode(z), z
def skewness_1d(values: torch.Tensor, eps: float = 1.0e-12) -> torch.Tensor:
    """Differentiable sample skewness of a 1-D set of values.

    Args:
        values: Any tensor; it is flattened to 1-D.
        eps: Numerical-stability floor for the variance.

    Returns:
        A scalar tensor with the (biased) sample skewness
        ``mean(c^3) / (mean(c^2) + eps)^{1.5}`` where ``c = values - mean``.
    """
    flat = values.reshape(-1)
    centered = flat - flat.mean()
    m2 = centered.pow(2).mean()
    m3 = centered.pow(3).mean()
    return m3 / (m2 + eps).pow(1.5)


def skewness_loss(latent_local: torch.Tensor, eps: float = 1.0e-12) -> torch.Tensor:
    """Skewness auxiliary loss ``log(1 + exp(-skew^2)) = softplus(-skew^2)``.

    Minimizing this drives the latent skewness toward large magnitude (either
    sign), so the latent distribution develops a pronounced tail.

    Args:
        latent_local: Latent values of the local segment, any shape.
        eps: Numerical-stability floor passed to :func:`skewness_1d`.

    Returns:
        A scalar loss tensor.
    """
    gamma = skewness_1d(latent_local, eps=eps)
    return torch.nn.functional.softplus(-gamma.pow(2))


def _l2_weights(model: Skewencoder) -> torch.Tensor:
    """Sum of squared trainable parameters (weights and biases, not buffers).

    Args:
        model: Module whose ``parameters()`` are regularized.

    Returns:
        A scalar tensor. When no parameter requires grad the result is a zero
        scalar on the device/dtype of the first parameter, or a default zero
        scalar when the module has no parameters at all.
    """
    params = list(model.parameters())
    terms = [p.pow(2).sum() for p in params if p.requires_grad]
    if terms:
        # Sequential addition keeps the accumulation order of a plain loop.
        return sum(terms[1:], terms[0])
    if not params:
        # Parameter-free module: there is no device/dtype to borrow.
        return torch.zeros(())
    return torch.zeros((), device=params[0].device, dtype=params[0].dtype)


def skewencoder_loss(
    model: Skewencoder,
    global_x: torch.Tensor,
    local_x: torch.Tensor,
    *,
    alpha: float,
    beta: float,
    eps: float,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Compute the multitask Skewencoder loss.

    ``L = L_AE(global_x) + alpha * L_skew(local_x) + beta * L2_weights``

    Reconstruction is taken on the global dataset and the skewness term on the
    local dataset.

    Args:
        model: The Skewencoder.
        global_x: Normalized global descriptors ``[n_global, input_dim]``.
        local_x: Normalized local descriptors ``[n_local, input_dim]``.
        alpha: Skewness loss coefficient.
        beta: L2 regularization coefficient.
        eps: Numerical-stability floor for skewness.

    Returns:
        A tuple ``(loss_total, diagnostics)`` where ``diagnostics`` holds
        detached scalars: ``loss_total``, ``loss_reconstruction``,
        ``loss_skew``, ``loss_l2``, ``local_skewness``.
    """
    recon, _ = model(global_x)
    loss_ae = torch.nn.functional.mse_loss(recon, global_x)

    z_local = model.encode(local_x)
    loss_skew = skewness_loss(z_local, eps=eps)
    l2 = _l2_weights(model)

    loss_total = loss_ae + alpha * loss_skew + beta * l2
    # The reported skewness is diagnostic only, so skip building a graph for it.
    with torch.no_grad():
        local_skewness = skewness_1d(z_local, eps=eps)
    diagnostics = {
        "loss_total": loss_total.detach(),
        "loss_reconstruction": loss_ae.detach(),
        "loss_skew": loss_skew.detach(),
        "loss_l2": (beta * l2).detach(),
        "local_skewness": local_skewness.detach(),
    }
    return loss_total, diagnostics


@dataclass
class DescriptorNormalizer:
    """Affine descriptor normalizer ``(x - mean) / std``.

    ``std`` is used exactly as stored; the ``eps`` floor is applied when the
    statistics are produced by :func:`fit_descriptor_normalizer`.

    Attributes:
        mean: Per-descriptor mean ``[input_dim]``.
        std: Per-descriptor standard deviation ``[input_dim]``.
        eps: Numerical-stability floor for the standard deviation.
    """

    mean: torch.Tensor
    std: torch.Tensor
    eps: float = 1.0e-12

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` as ``(x - mean) / std`` per descriptor."""
        return (x - self.mean) / self.std

    def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`transform`."""
        return x * self.std + self.mean

    def to(self, device: torch.device, dtype: torch.dtype) -> DescriptorNormalizer:
        """Return a copy with the statistics moved to ``device``/``dtype``."""
        return DescriptorNormalizer(
            mean=self.mean.to(device, dtype),
            std=self.std.to(device, dtype),
            eps=self.eps,
        )


def fit_descriptor_normalizer(
    x: torch.Tensor, eps: float = 1.0e-12
) -> DescriptorNormalizer:
    """Fit a :class:`DescriptorNormalizer` from descriptor samples.

    Args:
        x: Descriptor samples ``[n_samples, input_dim]``.
        eps: Floor applied to the standard deviation.

    Returns:
        A normalizer whose ``std`` is clamped to at least ``eps``.
    """
    mean = x.mean(dim=0)
    # Population (biased) std; the clamp keeps constant descriptors finite.
    std = x.std(dim=0, unbiased=False).clamp_min(eps)
    return DescriptorNormalizer(mean=mean, std=std, eps=eps)


@dataclass
class SkewencoderTrainingReport:
    """Summary of a single :meth:`SkewencoderTrainer.train` call."""

    n_epochs: int
    final_loss: float
    final_reconstruction_loss: float
    final_skew_loss: float
    final_l2_loss: float
    final_local_skewness: float
    stopped_early: bool
    train_time_s: float = 0.0
    """Wall-clock time spent in :meth:`SkewencoderTrainer.train`, in seconds."""


class SkewencoderTrainer:
    """Adam trainer for the multitask Skewencoder loss with early stopping.

    Args:
        config: The :class:`SkewencoderConfig` (hyperparameters and stopping).
    """

    def __init__(self, config: SkewencoderConfig) -> None:
        """Store the training configuration."""
        self.config = config

    def train(  # noqa: C901
        self,
        model: Skewencoder,
        global_descriptors: torch.Tensor,
        local_descriptors: torch.Tensor,
        *,
        normalizer: DescriptorNormalizer | None = None,
    ) -> tuple[DescriptorNormalizer, SkewencoderTrainingReport]:
        """Train ``model`` in place; warm-start by reusing the same instance.

        Each epoch runs shuffled minibatches of the global set (the full local
        set is used in every step), then evaluates the full-batch loss for
        early stopping. The model is left in eval mode on return.

        Args:
            model: Skewencoder to train (modified in place).
            global_descriptors: Raw global descriptors ``[n_global, input_dim]``
                used for the reconstruction loss.
            local_descriptors: Raw local descriptors ``[n_local, input_dim]``
                used for the skewness loss.
            normalizer: Optional fixed normalizer; if ``None`` one is fit from
                ``global_descriptors``.

        Returns:
            A tuple ``(normalizer, report)``.

        Raises:
            ValueError: If ``global_descriptors`` contains no samples.
        """
        cfg = self.config
        t_start = time.perf_counter()
        first_param = next(model.parameters())
        device, dtype = first_param.device, first_param.dtype

        global_x = global_descriptors.to(device, dtype)
        local_x = local_descriptors.to(device, dtype)
        n_global = global_x.shape[0]
        if n_global == 0:
            # Otherwise the minibatch ``range`` below fails with a zero step.
            raise ValueError("global_descriptors must contain at least one sample")
        if normalizer is None:
            normalizer = fit_descriptor_normalizer(global_x, eps=cfg.eps)
        normalizer = normalizer.to(device, dtype)

        gx = normalizer.transform(global_x)
        lx = normalizer.transform(local_x)

        # Re-enable grad in case a previous wall put the shared encoder in a
        # frozen/eval state; warm-start retraining needs trainable parameters.
        for weight in model.parameters():
            weight.requires_grad_(requires_grad=True)

        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
        batch_size = min(cfg.batch_size, n_global)
        # Guard the modulus below against a zero/negative stride.
        stride = max(1, cfg.verbose_stride)

        best_loss = math.inf
        patience = 0
        stopped_early = False
        n_epochs = 0
        if cfg.verbose:
            print(  # noqa: T201
                f" [skewencoder] train: n_global={n_global} n_local={lx.shape[0]} "
                f"max_epochs={cfg.max_epochs} batch={batch_size}",
                flush=True,
            )
        model.train()
        for epoch in range(1, cfg.max_epochs + 1):
            n_epochs = epoch
            perm = torch.randperm(n_global, device=device)
            for start in range(0, n_global, batch_size):
                idx = perm[start : start + batch_size]
                loss, _ = skewencoder_loss(
                    model, gx[idx], lx, alpha=cfg.alpha, beta=cfg.beta, eps=cfg.eps
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            with torch.no_grad():
                epoch_loss, epoch_diag = skewencoder_loss(
                    model, gx, lx, alpha=cfg.alpha, beta=cfg.beta, eps=cfg.eps
                )
            current = float(epoch_loss)
            if best_loss - current > cfg.min_delta:
                best_loss = current
                patience = 0
            else:
                patience += 1
                if patience >= cfg.early_stopping_patience:
                    stopped_early = True

            # The last epoch is reached either by early stopping or by
            # exhausting max_epochs; both must be reported when verbose.
            is_last_epoch = stopped_early or epoch == cfg.max_epochs
            if cfg.verbose and (epoch == 1 or is_last_epoch or epoch % stride == 0):
                print(  # noqa: T201
                    f" epoch {epoch:>4d}/{cfg.max_epochs} "
                    f"loss={current:.4e} "
                    f"recon={float(epoch_diag['loss_reconstruction']):.4e} "
                    f"skew={float(epoch_diag['loss_skew']):.4e} "
                    f"gamma={float(epoch_diag['local_skewness']):+.4f} "
                    f"patience={patience}",
                    flush=True,
                )
            if stopped_early:
                break

        if cfg.verbose:
            print(  # noqa: T201
                f" [skewencoder] done: {n_epochs} epochs, "
                f"stopped_early={stopped_early}, best_loss={best_loss:.4e}",
                flush=True,
            )
        model.eval()
        with torch.no_grad():
            _, diag = skewencoder_loss(
                model, gx, lx, alpha=cfg.alpha, beta=cfg.beta, eps=cfg.eps
            )
        report = SkewencoderTrainingReport(
            n_epochs=n_epochs,
            final_loss=float(diag["loss_total"]),
            final_reconstruction_loss=float(diag["loss_reconstruction"]),
            final_skew_loss=float(diag["loss_skew"]),
            final_l2_loss=float(diag["loss_l2"]),
            final_local_skewness=float(diag["local_skewness"]),
            stopped_early=stopped_early,
            train_time_s=time.perf_counter() - t_start,
        )
        return normalizer, report
pytest.raises(ValueError, match="single system"): run_boxed_md( two, diff --git a/tests/enhanced_sampling/test_loxodynamics.py b/tests/enhanced_sampling/test_loxodynamics.py index d9a583031..589d16fd5 100644 --- a/tests/enhanced_sampling/test_loxodynamics.py +++ b/tests/enhanced_sampling/test_loxodynamics.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest import torch from ase.build import molecule @@ -75,13 +77,16 @@ def test_bad_shape_raises(self) -> None: class TestLoxodynamicsWall: - def _wall(self, *, offset: float) -> tuple[LoxodynamicsWall, int]: + def _components(self): + """Deterministic (descriptor, encoder, fitted normalizer) for the wall.""" torch.manual_seed(0) desc = PairDistanceDescriptor(_all_pairs_3()) - cfg = SkewencoderConfig(input_dim=3, hidden_dims=(8, 4)) - enc = Skewencoder(cfg).to(DTYPE) + enc = Skewencoder(SkewencoderConfig(input_dim=3, hidden_dims=(8, 4))).to(DTYPE) sample = torch.randn(20, 3, dtype=DTYPE).abs() + 1.0 - norm = fit_descriptor_normalizer(sample) + return desc, enc, fit_descriptor_normalizer(sample) + + def _wall(self, *, offset: float) -> tuple[LoxodynamicsWall, int]: + desc, enc, norm = self._components() wall = LoxodynamicsWall( desc, enc, @@ -121,6 +126,60 @@ def test_rejects_multiple_systems(self) -> None: with pytest.raises(ValueError, match="single system"): wall(two) + def test_emits_latent_cv(self, water_state: ts.SimState) -> None: + # the wall reports the raw (unsigned) latent CV under the "loxo_cv" key + desc, enc, norm = self._components() + wall = LoxodynamicsWall( + desc, + enc, + norm, + mu=0.0, + sigma=1.0, + skewness=1.0, + kappa=1.0, + offset=1.0, + device=DEVICE, + dtype=DTYPE, + ) + out = wall(water_state) + assert "loxo_cv" in out + assert out["loxo_cv"].shape == (1,) + with torch.no_grad(): + latent = enc.encode( + norm.transform(desc(water_state.positions)).unsqueeze(0) + ).reshape(1) + torch.testing.assert_close(out["loxo_cv"], latent) + + def test_energy_scale_invariant_at_reference(self, 
water_state: ts.SimState) -> None: + # The standardized wall acts on (s - mu)/sigma, so at the reference point + # (mu == the current latent) the violation is exactly `offset` and the + # energy is kappa*offset**2 regardless of sigma. A raw-latent wall would + # instead give ~kappa*(sigma + offset)**2, which diverges as sigma grows. + desc, enc, norm = self._components() + with torch.no_grad(): + s0 = ( + enc.encode(norm.transform(desc(water_state.positions)).unsqueeze(0)) + .reshape(()) + .item() + ) + energies = [ + LoxodynamicsWall( + desc, + enc, + norm, + mu=s0, + sigma=sigma, + skewness=1.0, + kappa=1.0, + offset=2.0, + device=DEVICE, + dtype=DTYPE, + )(water_state)["energy"].item() + for sigma in (0.1, 1.0, 100.0) + ] + for energy in energies: + assert energy == pytest.approx(4.0, abs=1e-9) # kappa * offset**2 + class TestRunLoxodynamics: def _setup( @@ -175,6 +234,40 @@ def test_rejects_multiple_systems(self) -> None: min_local_samples=3, ) + def test_checkpoint_dir_saves_loadable_models( + self, water_state: ts.SimState, tmp_path: Path + ) -> None: + model, desc, cfg = self._setup(water_state) + result = run_loxodynamics( + water_state, + model, + descriptor=desc, + max_steps=20, + segment_steps=5, + initial_unbiased_steps=5, + timestep=0.0005, + temperature=300.0, + sample_stride=1, + min_local_samples=3, + seed=0, + skewencoder_config=cfg, + checkpoint_dir=tmp_path, + ) + # one checkpoint per retrain/wall + files = sorted(tmp_path.glob("skewencoder_iter*.pt")) + assert len(files) == len(result.wall_stats) >= 1 + ckpt = torch.load(files[0], weights_only=False) + assert { + "skewencoder_state_dict", + "skewencoder_config", + "normalizer", + "wall_stats", + "training_report", + } <= set(ckpt) + # the saved weights reload into a fresh Skewencoder built from the config + enc = Skewencoder(SkewencoderConfig(**ckpt["skewencoder_config"])) + enc.load_state_dict(ckpt["skewencoder_state_dict"]) + class TestExecutorDtype: def test_float32_model_end_to_end(self) -> 
None: diff --git a/torch_sim/enhanced_sampling/boxed_md.py b/torch_sim/enhanced_sampling/boxed_md.py index 0b31b7417..61bb9869c 100644 --- a/torch_sim/enhanced_sampling/boxed_md.py +++ b/torch_sim/enhanced_sampling/boxed_md.py @@ -279,9 +279,7 @@ def run_boxed_md( ValueError: If ``state`` contains more than one system. """ if state.n_systems != 1: - raise ValueError( - f"run_boxed_md expects a single system, got {state.n_systems}" - ) + raise ValueError(f"run_boxed_md expects a single system, got {state.n_systems}") device, dtype = state.device, state.dtype kT = float(temperature) * unit_system.temperature diff --git a/torch_sim/enhanced_sampling/history.py b/torch_sim/enhanced_sampling/history.py index d792dfb81..d0a11f99e 100644 --- a/torch_sim/enhanced_sampling/history.py +++ b/torch_sim/enhanced_sampling/history.py @@ -42,9 +42,9 @@ class History(torch.nn.Module): Example:: history = History(capacity=10, stride=5) - history.push(value) # unconditional seed + history.push(value) # unconditional seed deposited = history.maybe_push(value) # True every 5th call - record = history.stack() # (n_stored, *value.shape) + record = history.stack() # (n_stored, *value.shape) """ data: torch.Tensor | None