"""
Torch-sim adapters for running UFP models inside state-based simulations.
Prefer the metatomic-backed wrapper when available; keep the ASE-backed path as
the simple fallback for debugging and compatibility.
"""
from __future__ import annotations
from typing import List, Optional, Union
import ase
import numpy as np
import torch
from ase.stress import voigt_6_to_full_3x3_stress
from ufp.adapters.metatomic import wrap_atomistic_model
from ufp.core._arrays import _to_numpy
from ufp.core.potential import UFPotential
from ufp.neighbors._neighbors import NeighborListBackend
try:
    from torch_sim.models.interface import ModelInterface as _TorchSimModelInterface
except ImportError:
    # Keep this module importable without torch-sim: subclass a plain
    # ``torch.nn.Module`` instead and record that torch-sim is unavailable.
    _TorchSimModelInterface = torch.nn.Module
    HAS_TORCHSIM = False
else:
    HAS_TORCHSIM = True
def state_to_ase_atoms(state) -> List[ase.Atoms]:
    """
    Convert a torch-sim state-like object into ASE atoms.

    The state is expected to expose ``positions``, ``cell``, ``pbc``,
    ``atomic_numbers``, and ``system_idx``. ``cell`` is batched along its first
    axis and read in torch-sim's column-vector convention (each lattice vector
    stored as a column); it is transposed into ASE's row-vector convention
    before the structures are built. ``pbc`` is either shared by all systems or
    given per system as a 2-D array.

    Args:
        state: Torch-sim state-like object.

    Returns:
        Per-system ASE structures in system-index order.

    Raises:
        TypeError: If the state is missing required attributes.
    """
    required = ["positions", "cell", "pbc", "atomic_numbers", "system_idx"]
    missing = [name for name in required if not hasattr(state, name)]
    if missing:
        raise TypeError(
            f"state-like object is missing required attributes: {', '.join(missing)}"
        )
    positions = _to_numpy(state.positions)
    # torch-sim stores lattice vectors as columns while ASE expects rows; without
    # this transpose every non-orthorhombic cell would be handed to ASE wrongly.
    cells = np.swapaxes(_to_numpy(state.cell), -1, -2)
    pbc = _to_numpy(state.pbc)
    per_system_pbc = np.ndim(pbc) == 2
    atomic_numbers = _to_numpy(state.atomic_numbers)
    system_idx = _to_numpy(state.system_idx).astype(np.int64, copy=False)
    # With no atoms there is no index to inspect, so size the batch by the cells.
    n_systems = int(system_idx.max()) + 1 if len(system_idx) != 0 else int(len(cells))
    atoms_list = []
    for system_i in range(n_systems):
        atom_mask = system_idx == system_i
        atoms_list.append(
            ase.Atoms(
                numbers=atomic_numbers[atom_mask],
                positions=positions[atom_mask],
                cell=cells[system_i],
                pbc=pbc[system_i] if per_system_pbc else pbc,
            )
        )
    return atoms_list
def _split_state(state):
    """
    Pair each per-system ASE structure with the flat atom indices it came from.

    Args:
        state: Torch-sim state-like object; ``state.system_idx`` is compared
            against each system index and passed to ``torch.nonzero``.

    Returns:
        A list of ``(atoms, atom_indices)`` tuples in system-index order, where
        ``atom_indices`` is a 1-D tensor of positions into the flat atom arrays.
    """
    structures = state_to_ase_atoms(state)
    index_groups = [
        torch.nonzero(state.system_idx == system_i, as_tuple=False).reshape(-1)
        for system_i in range(len(structures))
    ]
    return list(zip(structures, index_groups, strict=True))
def _as_scalar_tensor(
    value,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """
    Coerce a single energy-like value into a 0-d tensor on ``device``.

    Tensor inputs are detached before conversion; anything else goes through
    ``torch.as_tensor``. The input may have any shape as long as it holds
    exactly one element.

    Raises:
        ValueError: If ``value`` does not contain exactly one element.
    """
    converted = (
        value.detach().to(device=device, dtype=dtype)
        if isinstance(value, torch.Tensor)
        else torch.as_tensor(value, device=device, dtype=dtype)
    )
    flat = converted.reshape(-1)
    if flat.numel() == 1:
        return flat[0]
    raise ValueError("expected a scalar energy value")
def _as_tensor(
    value,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """
    Coerce array-like output into a tensor with the requested dtype and device.

    Existing tensors are detached from autograd first; other inputs are handed
    to ``torch.as_tensor``.
    """
    if not isinstance(value, torch.Tensor):
        return torch.as_tensor(value, device=device, dtype=dtype)
    detached = value.detach()
    return detached.to(device=device, dtype=dtype)
def _resolve_torch_dtype(dtype: Optional[Union[str, torch.dtype]]) -> torch.dtype:
    """
    Resolve a user-facing dtype specification into a ``torch.dtype``.

    Args:
        dtype: ``None`` (use ``torch.get_default_dtype()``), a ``torch.dtype``,
            the name of a torch dtype attribute such as ``"float64"`` (an
            optional ``"torch."`` prefix is accepted), or a NumPy dtype-like
            object such as ``np.float64``.

    Returns:
        The resolved torch dtype.

    Raises:
        ValueError: If a string does not name a torch dtype.
        TypeError: If ``dtype`` is of any other unsupported kind.
    """
    if dtype is None:
        return torch.get_default_dtype()
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        resolved = getattr(torch, dtype.removeprefix("torch."), None)
        if isinstance(resolved, torch.dtype):
            return resolved
        raise ValueError(f"unsupported torch dtype string: {dtype}")
    # Map NumPy dtype-likes explicitly and reject everything else, rather than
    # silently falling back to the default dtype and hiding the mistake.
    try:
        return torch.from_numpy(np.empty(0, dtype=np.dtype(dtype))).dtype
    except TypeError as exc:
        raise TypeError(f"unsupported dtype specification: {dtype!r}") from exc
class ASEBackedTorchSimModel(_TorchSimModelInterface):
    """
    Adapt a UFP potential to the torch-sim model interface through ASE.

    This adapter treats ASE as the internal structure interchange format. It is
    useful for early prototyping and CPU-side integration, but it should not be
    treated as the final high-performance path for large or differentiable
    simulations: each system is evaluated separately and every output is
    detached from autograd.

    Args:
        potential: Wrapped UFP potential.
        neighbor_backend: Backend used when the potential builds neighbor lists.
        device: Device of the output tensors (CPU when ``None``).
        dtype: Dtype of the output tensors, as a ``torch.dtype`` or dtype name;
            ``None`` selects ``torch.get_default_dtype()``.
        compute_forces: Whether forces are expected from the wrapped potential.
        compute_stress: Whether stress is expected from the wrapped potential.

    Raises:
        ImportError: If torch-sim is not installed.
    """

    def __init__(
        self,
        potential: UFPotential,
        *,
        neighbor_backend: Union[str, NeighborListBackend] = NeighborListBackend.AUTO,
        device: Optional[torch.device] = None,
        dtype: Optional[Union[str, torch.dtype]] = None,
        compute_forces: bool = True,
        compute_stress: bool = False,
    ) -> None:
        """Initialize the ASE-backed torch-sim adapter around one UFP potential."""
        if not HAS_TORCHSIM:
            raise ImportError(
                "torch-sim is not installed. Install it with `pip install "
                "torch-sim-atomistic` or `pip install 'ufp[torchsim]'`."
            )
        super().__init__()
        self.potential = potential
        self._device = torch.device("cpu") if device is None else device
        self._dtype = _resolve_torch_dtype(dtype)
        self._compute_forces = bool(compute_forces)
        self._compute_stress = bool(compute_stress)
        self._neighbor_backend = NeighborListBackend(neighbor_backend)

    @property
    def device(self) -> torch.device:
        """Return the device of the output tensors."""
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        """Return the dtype of the output tensors."""
        return self._dtype

    @property
    def compute_forces(self) -> bool:
        """Report whether the wrapper requests forces from the wrapped potential."""
        return self._compute_forces

    @property
    def compute_stress(self) -> bool:
        """Report whether the wrapper requests stress from the wrapped potential."""
        return self._compute_stress

    def forward(self, state, **kwargs) -> dict[str, torch.Tensor]:
        """
        Evaluate the wrapped potential on every system in ``state``.

        The state is split into one ASE structure per system, each structure is
        evaluated independently, and the per-system outputs are packed into
        batched tensors on ``self.device`` with ``self.dtype``.

        Args:
            state: Torch-sim state-like object exposing ``positions``, ``cell``,
                ``pbc``, ``atomic_numbers`` and ``system_idx``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A dict with ``"energy"`` of shape ``(n_systems,)``; ``"forces"`` with
            the shape of ``state.positions`` when ``compute_forces`` is set; and
            ``"stress"`` of shape ``(n_systems, 3, 3)`` when ``compute_stress``
            is set.

        Raises:
            RuntimeError: If the potential omits a required output.
            ValueError: If returned forces or stress have an unexpected shape.
        """
        del kwargs
        split_state = _split_state(state)
        n_systems = len(split_state)
        results: dict[str, torch.Tensor] = {
            "energy": torch.zeros(n_systems, device=self.device, dtype=self.dtype)
        }
        if self.compute_forces:
            results["forces"] = torch.zeros(
                tuple(state.positions.shape),
                device=self.device,
                dtype=self.dtype,
            )
        if self.compute_stress:
            results["stress"] = torch.zeros(
                (n_systems, 3, 3),
                device=self.device,
                dtype=self.dtype,
            )
        for system_i, (atoms, atom_indices) in enumerate(split_state):
            # The indices are derived from ``state.system_idx`` and so live on the
            # state's device; move them so indexing the output tensors also works
            # when the state and the model are on different devices.
            atom_indices = atom_indices.to(self.device)
            # NOTE(review): stress is not requested explicitly from ``compute``;
            # presumably the potential fills ``prediction.stress`` on its own -
            # confirm against UFPotential.compute.
            prediction = self.potential.compute(
                atoms=atoms,
                backend=self._neighbor_backend,
                derive_forces=self.compute_forces,
                dtype=self.dtype,
                device=self.device,
            )
            if prediction.energy is None:
                raise RuntimeError(
                    "wrapped UFPotential must provide `energy` for "
                    "torch-sim integration"
                )
            results["energy"][system_i] = _as_scalar_tensor(
                prediction.energy,
                dtype=self.dtype,
                device=self.device,
            )
            if self.compute_forces:
                results["forces"][atom_indices] = self._checked_forces(
                    prediction, len(atom_indices)
                )
            if self.compute_stress:
                results["stress"][system_i] = self._checked_stress(prediction)
        return results

    def _checked_forces(self, prediction, n_atoms: int) -> torch.Tensor:
        """Return the prediction's forces as a validated ``(n_atoms, 3)`` tensor."""
        if prediction.forces is None:
            raise RuntimeError(
                "wrapped UFPotential must provide `forces` when "
                "`compute_forces=True`"
            )
        forces = _as_tensor(
            prediction.forces,
            dtype=self.dtype,
            device=self.device,
        )
        if tuple(forces.shape) != (n_atoms, 3):
            raise ValueError(
                "wrapped UFPotential returned forces with shape "
                f"{tuple(forces.shape)}, expected ({n_atoms}, 3)"
            )
        return forces

    def _checked_stress(self, prediction) -> torch.Tensor:
        """Return the prediction's stress as a validated full ``(3, 3)`` tensor."""
        if prediction.stress is None:
            raise RuntimeError(
                "wrapped UFPotential must provide `stress` when "
                "`compute_stress=True`"
            )
        stress = _as_tensor(
            prediction.stress,
            dtype=self.dtype,
            device=self.device,
        )
        if tuple(stress.shape) == (6,):
            # Expand a 6-component Voigt vector with ASE's own helper.
            stress = torch.as_tensor(
                voigt_6_to_full_3x3_stress(stress.detach().cpu().numpy()),
                device=self.device,
                dtype=self.dtype,
            )
        elif tuple(stress.shape) == (1, 3, 3):
            # Drop a leading singleton batch axis.
            stress = stress[0]
        if tuple(stress.shape) != (3, 3):
            raise ValueError(
                "wrapped UFPotential returned stress with shape "
                f"{tuple(stress.shape)}, expected (3, 3)"
            )
        return stress
def build_torchsim_model(
    potential: UFPotential,
    *,
    atomic_types: Optional[list[int]] = None,
    device: Union[str, torch.device] = "cpu",
    length_unit: str = "Angstrom",
    energy_unit: str = "eV",
    supported_devices: Optional[list[str]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    neighbor_backend: Union[str, NeighborListBackend] = NeighborListBackend.AUTO,
    fallback_to_ase: bool = False,
):
    """
    Build the preferred torch-sim wrapper for a UFP potential.

    When ``metatomic-torchsim`` is installed, this returns a
    :py:class:`metatomic_torchsim.MetatomicModel` wrapped around a metatomic
    ``AtomisticModel`` created from the UFP potential. This keeps the model in
    torch/metatensor space so energies remain differentiable for forces and stress.
    If ``fallback_to_ase=True`` and the metatomic path is unavailable, UFP falls back
    to the slower ASE-backed adapter.

    Args:
        potential: UFP potential to wrap.
        atomic_types: Supported atomic numbers for metatomic model capabilities.
        device: Device passed to the torch-sim wrapper.
        length_unit: Length unit advertised to metatomic.
        energy_unit: Energy unit advertised for total energies.
        supported_devices: Devices advertised in metatomic capabilities.
        dtype: Explicit model dtype string or torch dtype.
        neighbor_backend: Backend used by the ASE fallback path.
        fallback_to_ase: Whether to use the ASE-backed adapter when
            ``metatomic-torchsim`` is unavailable. The fallback does not use
            ``atomic_types``, ``length_unit``, ``energy_unit`` or
            ``supported_devices``.

    Returns:
        A metatomic-backed torch-sim model, or an ASE-backed fallback model.

    Raises:
        ImportError: If ``metatomic-torchsim`` is unavailable and fallback is
            disabled, or if the fallback is selected but torch-sim itself is
            not installed.
    """
    try:
        from metatomic_torchsim import MetatomicModel
    except ImportError as exc:
        if not fallback_to_ase:
            raise ImportError(
                "metatomic-torchsim is not installed. Install it with `pip install "
                "metatomic-torchsim` or `pip install 'ufp[metatomic,torchsim]'`. "
                "Set `fallback_to_ase=True` to use the slower ASE-backed adapter."
            ) from exc
        # The ASE adapter has no metatomic capabilities to advertise, so only the
        # device, dtype and neighbor backend carry over.
        return ASEBackedTorchSimModel(
            potential,
            neighbor_backend=neighbor_backend,
            device=torch.device(device),
            dtype=_resolve_torch_dtype(dtype),
        )
    atomistic_model = wrap_atomistic_model(
        potential,
        atomic_types=atomic_types,
        length_unit=length_unit,
        energy_unit=energy_unit,
        supported_devices=supported_devices,
        dtype=dtype,
    )
    return MetatomicModel(atomistic_model, device=device)
# Public names exported by ``from <this module> import *``; the underscore
# helpers above are intentionally left out.
__all__ = [
    "ASEBackedTorchSimModel",
    "HAS_TORCHSIM",
    "build_torchsim_model",
    "state_to_ase_atoms",
]