From b05400851bc51fc2b5766775031aa3e3aee5f8ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Benoit=20Skl=C3=A9nard?= Date: Wed, 1 Jul 2026 11:13:21 +0200 Subject: [PATCH] Fix missing device in FixCom and get_centers_of_mass torch.zeros() calls, closes #580 --- torch_sim/constraints.py | 43 ++++++++++++++++++++++++++-------------- torch_sim/transforms.py | 10 ++++++---- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 02bc03043..adc7b5c27 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -544,7 +544,9 @@ def get_removed_dof(self, state: SimState) -> torch.Tensor: Returns: Always returns 3 (center of mass translation degrees of freedom) """ - affected_systems = torch.zeros(state.n_systems, dtype=torch.long) + affected_systems = torch.zeros( + state.n_systems, dtype=torch.long, device=state.device + ) affected_systems[self.system_idx] = 1 return 3 * affected_systems @@ -559,24 +561,29 @@ def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None raise ValueError("FixCom requires state with system_idx") system_idx = state.system_idx dtype = state.positions.dtype - system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( - 0, system_idx, state.masses - ) + device = state.device + system_mass = torch.zeros( + state.n_systems, dtype=dtype, device=device + ).scatter_add_(0, system_idx, state.masses) if self.coms is None: - self.coms = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + self.coms = torch.zeros( + (state.n_systems, 3), dtype=dtype, device=device + ).scatter_add_( 0, system_idx.unsqueeze(-1).expand(-1, 3), state.masses.unsqueeze(-1) * state.positions, ) self.coms /= system_mass.unsqueeze(-1) - new_com = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + new_com = torch.zeros( + (state.n_systems, 3), dtype=dtype, device=device + ).scatter_add_( 0, system_idx.unsqueeze(-1).expand(-1, 3), state.masses.unsqueeze(-1) * new_positions, ) new_com /= system_mass.unsqueeze(-1) - displacement = torch.zeros(state.n_systems, 3, dtype=dtype) + displacement = torch.zeros(state.n_systems, 3, dtype=dtype, device=device) displacement[self.system_idx] = ( -new_com[self.system_idx] + self.coms[self.system_idx] ) @@ -594,16 +601,19 @@ def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: system_idx = state.system_idx # Compute center of mass momenta dtype = momenta.dtype - com_momenta = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + device = state.device + com_momenta = torch.zeros( + (state.n_systems, 3), dtype=dtype, device=device + ).scatter_add_( 0, system_idx.unsqueeze(-1).expand(-1, 3), momenta, ) - system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( - 0, system_idx, state.masses - ) + system_mass = torch.zeros( + state.n_systems, dtype=dtype, device=device + ).scatter_add_(0, system_idx, state.masses) velocity_com = com_momenta / system_mass.unsqueeze(-1) - velocity_change = torch.zeros(state.n_systems, 3, dtype=dtype) + velocity_change = torch.zeros(state.n_systems, 3, dtype=dtype, device=device) velocity_change[self.system_idx] = velocity_com[self.system_idx] momenta -= velocity_change[system_idx] * state.masses.unsqueeze(-1) @@ -621,18 +631,21 @@ def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: raise ValueError("FixCom requires state with system_idx") system_idx = state.system_idx dtype = state.positions.dtype - system_square_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + device = state.device + system_square_mass = torch.zeros( + state.n_systems, dtype=dtype, device=device + ).scatter_add_( 0, system_idx, torch.square(state.masses), ) - lmd = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + lmd = torch.zeros((state.n_systems, 3), dtype=dtype, device=device).scatter_add_( 0, system_idx.unsqueeze(-1).expand(-1, 3), forces * state.masses.unsqueeze(-1), ) lmd /= system_square_mass.unsqueeze(-1) - forces_change = torch.zeros(state.n_systems, 3, dtype=dtype) + forces_change = torch.zeros(state.n_systems, 3, dtype=dtype, device=device) forces_change[self.system_idx] = lmd[self.system_idx] forces -= forces_change[system_idx] * state.masses.unsqueeze(-1) diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index f4f51189c..d99f98c71 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -1620,13 +1620,15 @@ def get_centers_of_mass( torch.Tensor: A tensor of shape (n_structures, 3) containing the center of mass coordinates for each structure. """ - coms = torch.zeros((n_systems, 3), dtype=positions.dtype).scatter_add_( + coms = torch.zeros( + (n_systems, 3), dtype=positions.dtype, device=positions.device + ).scatter_add_( 0, system_idx.unsqueeze(-1).expand(-1, 3), masses.unsqueeze(-1) * positions, ) - system_masses = torch.zeros((n_systems,), dtype=positions.dtype).scatter_add_( - 0, system_idx, masses - ) + system_masses = torch.zeros( + (n_systems,), dtype=positions.dtype, device=positions.device + ).scatter_add_(0, system_idx, masses) coms /= system_masses.unsqueeze(-1) return coms