# Source code for ufp.adapters.torchsim

"""
Torch-sim adapters for running UFP models inside state-based simulations.

Prefer the metatomic-backed wrapper when available; keep the ASE-backed path as
the simple fallback for debugging and compatibility.
"""

from __future__ import annotations

from typing import List, Optional, Union

import ase
import numpy as np
import torch
from ase.stress import voigt_6_to_full_3x3_stress

from ufp.adapters.metatomic import wrap_atomistic_model
from ufp.core._arrays import _to_numpy
from ufp.core.potential import UFPotential
from ufp.neighbors._neighbors import NeighborListBackend


try:
    from torch_sim.models.interface import ModelInterface as _TorchSimModelInterface

    HAS_TORCHSIM = True
except ImportError:
    _TorchSimModelInterface = torch.nn.Module
    HAS_TORCHSIM = False


def state_to_ase_atoms(state) -> List[ase.Atoms]:
    """
    Convert a torch-sim state-like object into ASE atoms.

    The state is expected to expose ``positions``, ``cell``, ``pbc``,
    ``atomic_numbers``, and ``system_idx``.

    If the state also exposes ``row_vector_cell`` (as torch-sim ``SimState``
    does), that attribute is used for the ASE cell. torch-sim stores
    ``cell`` with the lattice vectors as columns, while ASE expects them as
    rows. State-like objects without ``row_vector_cell`` have ``cell`` passed
    to ASE unchanged. A single unbatched ``(3, 3)`` cell is treated as a
    batch containing one system.

    Args:
        state: Torch-sim state-like object.

    Returns:
        Per-system ASE structures in system-index order.

    Raises:
        TypeError: If the state is missing required attributes.
    """

    required = ["positions", "cell", "pbc", "atomic_numbers", "system_idx"]
    missing = [name for name in required if not hasattr(state, name)]
    if missing:
        raise TypeError(
            f"state-like object is missing required attributes: {', '.join(missing)}"
        )

    # Prefer the row-vector view when available: torch-sim's ``cell`` uses a
    # column-vector convention, but ``ase.Atoms(cell=...)`` expects rows.
    raw_cell = getattr(state, "row_vector_cell", None)
    if raw_cell is None:
        raw_cell = state.cell

    positions = _to_numpy(state.positions)
    cells = _to_numpy(raw_cell)
    pbc = _to_numpy(state.pbc)
    atomic_numbers = _to_numpy(state.atomic_numbers)
    system_idx = _to_numpy(state.system_idx).astype(np.int64, copy=False)

    # Without this, ``cells[0]`` on a bare (3, 3) cell would be a single row,
    # which ASE would silently read as orthorhombic cell lengths.
    if np.ndim(cells) == 2:
        cells = cells[np.newaxis]

    # A 2-D pbc array holds one flag triple per system; otherwise share it.
    per_system_pbc = np.ndim(pbc) == 2

    n_systems = int(system_idx.max()) + 1 if len(system_idx) != 0 else int(len(cells))
    atoms_list = []
    for system_i in range(n_systems):
        atom_mask = system_idx == system_i
        atoms_list.append(
            ase.Atoms(
                numbers=atomic_numbers[atom_mask],
                positions=positions[atom_mask],
                cell=cells[system_i],
                pbc=pbc[system_i] if per_system_pbc else pbc,
            )
        )

    return atoms_list


def _split_state(state):
    """
    Pair each per-system ASE structure with the flat atom indices it owns.

    Returns:
        A list of ``(atoms, atom_indices)`` tuples, one per system, where
        ``atom_indices`` is a 1-D tensor of the positions in
        ``state.system_idx`` whose value equals that system's index.
    """
    structures = state_to_ase_atoms(state)
    owner = state.system_idx
    index_groups = [
        torch.nonzero(owner == sys_id, as_tuple=False).reshape(-1)
        for sys_id in range(len(structures))
    ]
    return list(zip(structures, index_groups, strict=True))


def _as_scalar_tensor(
    value,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """
    Coerce a single energy-like value into a 0-d tensor.

    Tensors are detached before being cast and moved; any other value is
    converted with ``torch.as_tensor``. The result is flattened, and its only
    element is returned.

    Raises:
        ValueError: If ``value`` does not hold exactly one element.
    """
    flat = (
        value.detach().to(device=device, dtype=dtype)
        if isinstance(value, torch.Tensor)
        else torch.as_tensor(value, device=device, dtype=dtype)
    ).reshape(-1)

    if flat.numel() == 1:
        return flat[0]

    raise ValueError("expected a scalar energy value")


def _as_tensor(
    value,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """
    Convert array-like output into a tensor with the given dtype and device.

    Non-tensor inputs go through ``torch.as_tensor``; tensor inputs are
    detached from the autograd graph before being cast and moved.
    """
    if not isinstance(value, torch.Tensor):
        return torch.as_tensor(value, device=device, dtype=dtype)

    detached = value.detach()
    return detached.to(device=device, dtype=dtype)


def _resolve_torch_dtype(dtype: Optional[Union[str, torch.dtype]]) -> torch.dtype:
    """
    Resolve a user-facing dtype specification into a ``torch.dtype``.

    Args:
        dtype: One of the following:

            - a ``torch.dtype``, returned as-is;
            - a dtype name such as ``"float32"``, or the
              ``"torch.float32"`` form produced by ``str(torch.float32)``;
            - ``None``, which selects ``torch.get_default_dtype()``.

    Returns:
        The resolved torch dtype.

    Raises:
        ValueError: If a string does not name a torch dtype.
        TypeError: If ``dtype`` is not ``None``, a string, or a
            ``torch.dtype``. Previously such values silently fell back to the
            default dtype.
    """
    if dtype is None:
        return torch.get_default_dtype()

    if isinstance(dtype, torch.dtype):
        return dtype

    if isinstance(dtype, str):
        # Accept the ``str(torch.float32) == "torch.float32"`` round-trip form.
        name = dtype.strip().removeprefix("torch.")
        resolved = getattr(torch, name, None)
        if isinstance(resolved, torch.dtype):
            return resolved

        raise ValueError(f"unsupported torch dtype string: {dtype}")

    raise TypeError(
        f"dtype must be None, a string, or a torch.dtype; got {type(dtype).__name__}"
    )


class ASEBackedTorchSimModel(_TorchSimModelInterface):
    """
    Adapt a UFP potential to the torch-sim model interface through ASE.

    This adapter treats ASE as the internal structure interchange format. It
    is useful for early prototyping and CPU-side integration, but it should
    not be treated as the final high-performance path for large or
    differentiable simulations.

    Args:
        potential: Wrapped UFP potential.
        neighbor_backend: Backend used when the potential builds neighbor lists.
        device: Device of the output tensors. A string such as ``"cuda:0"`` is
            normalized to ``torch.device``; ``None`` means CPU.
        dtype: Dtype of the output tensors.
        compute_forces: Whether forces are expected from the wrapped potential.
        compute_stress: Whether stress is expected from the wrapped potential.
    """

    def __init__(
        self,
        potential: UFPotential,
        *,
        neighbor_backend: Union[str, NeighborListBackend] = NeighborListBackend.AUTO,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[Union[str, torch.dtype]] = None,
        compute_forces: bool = True,
        compute_stress: bool = False,
    ) -> None:
        """
        Initialize the ASE-backed torch-sim adapter around one UFP potential.

        Raises:
            ImportError: If torch-sim is not installed.
        """
        if not HAS_TORCHSIM:
            raise ImportError(
                "torch-sim is not installed. Install it with `pip install "
                "torch-sim-atomistic` or `pip install 'ufp[torchsim]'`."
            )

        super().__init__()
        self.potential = potential
        # torch.device() accepts both strings and existing torch.device objects.
        self._device = torch.device("cpu") if device is None else torch.device(device)
        self._dtype = _resolve_torch_dtype(dtype)
        self._compute_forces = bool(compute_forces)
        self._compute_stress = bool(compute_stress)
        self._neighbor_backend = NeighborListBackend(neighbor_backend)

    @property
    def device(self) -> torch.device:
        """Return the device of the output tensors."""
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        """Return the dtype of the output tensors."""
        return self._dtype

    @property
    def compute_forces(self) -> bool:
        """Report whether the wrapper requests forces from the wrapped potential."""
        return self._compute_forces

    @property
    def compute_stress(self) -> bool:
        """Report whether the wrapper requests stress from the wrapped potential."""
        return self._compute_stress

    def forward(self, state, **kwargs) -> dict[str, torch.Tensor]:
        """
        Evaluate the wrapped potential on every system in ``state``.

        The state is split into per-system ASE structures. Each structure is
        evaluated independently, and the results are packed into batched
        tensors on ``self.device`` with ``self.dtype``.

        Args:
            state: Torch-sim state-like object exposing ``positions``,
                ``cell``, ``pbc``, ``atomic_numbers``, and ``system_idx``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A dict with:

            - ``"energy"``, shape ``(n_systems,)``;
            - ``"forces"``, shaped like ``state.positions``, when
              ``compute_forces`` is enabled;
            - ``"stress"``, shape ``(n_systems, 3, 3)``, when
              ``compute_stress`` is enabled.

        Raises:
            RuntimeError: If the potential omits energy or a requested
                quantity.
            ValueError: If the potential returns forces or stress with an
                unexpected shape.
        """
        del kwargs
        split_state = _split_state(state)
        n_systems = len(split_state)

        results: dict[str, torch.Tensor] = {
            "energy": torch.zeros(n_systems, device=self.device, dtype=self.dtype),
        }
        if self.compute_forces:
            results["forces"] = torch.zeros(
                tuple(state.positions.shape),
                device=self.device,
                dtype=self.dtype,
            )
        if self.compute_stress:
            results["stress"] = torch.zeros(
                (n_systems, 3, 3),
                device=self.device,
                dtype=self.dtype,
            )

        for system_i, (atoms, atom_indices) in enumerate(split_state):
            # NOTE(review): only `derive_forces` is forwarded; whether the
            # potential populates `stress` without an explicit request is not
            # visible here — confirm against UFPotential.compute.
            prediction = self.potential.compute(
                atoms=atoms,
                backend=self._neighbor_backend,
                derive_forces=self.compute_forces,
                dtype=self.dtype,
                device=self.device,
            )

            if prediction.energy is None:
                raise RuntimeError(
                    "wrapped UFPotential must provide `energy` for "
                    "torch-sim integration"
                )
            results["energy"][system_i] = _as_scalar_tensor(
                prediction.energy,
                dtype=self.dtype,
                device=self.device,
            )

            if self.compute_forces:
                forces = self._system_forces(prediction, len(atom_indices))
                # The indices come from `state.system_idx` and may live on a
                # different device than the output buffer. Indexing a CPU
                # tensor with CUDA indices raises, so move them first.
                results["forces"][atom_indices.to(self.device)] = forces

            if self.compute_stress:
                results["stress"][system_i] = self._system_stress(prediction)

        return results

    def _system_forces(self, prediction, n_atoms: int) -> torch.Tensor:
        """Validate one system's forces and return them as an ``(n_atoms, 3)`` tensor."""
        if prediction.forces is None:
            raise RuntimeError(
                "wrapped UFPotential must provide `forces` when "
                "`compute_forces=True`"
            )

        forces = _as_tensor(
            prediction.forces,
            dtype=self.dtype,
            device=self.device,
        )
        if tuple(forces.shape) != (n_atoms, 3):
            raise ValueError(
                "wrapped UFPotential returned forces with shape "
                f"{tuple(forces.shape)}, expected ({n_atoms}, 3)"
            )
        return forces

    def _system_stress(self, prediction) -> torch.Tensor:
        """
        Validate one system's stress and return it as a ``(3, 3)`` tensor.

        Accepted input shapes:

        - ``(6,)``: Voigt stress, expanded with ASE's
          ``voigt_6_to_full_3x3_stress``;
        - ``(1, 3, 3)``: a batch of one;
        - ``(3, 3)``: used as-is.
        """
        if prediction.stress is None:
            raise RuntimeError(
                "wrapped UFPotential must provide `stress` when "
                "`compute_stress=True`"
            )

        stress = _as_tensor(
            prediction.stress,
            dtype=self.dtype,
            device=self.device,
        )
        if tuple(stress.shape) == (6,):
            stress = torch.as_tensor(
                voigt_6_to_full_3x3_stress(stress.detach().cpu().numpy()),
                device=self.device,
                dtype=self.dtype,
            )
        elif tuple(stress.shape) == (1, 3, 3):
            stress = stress[0]

        if tuple(stress.shape) != (3, 3):
            raise ValueError(
                "wrapped UFPotential returned stress with shape "
                f"{tuple(stress.shape)}, expected (3, 3)"
            )
        return stress
def build_torchsim_model(
    potential: UFPotential,
    *,
    atomic_types: Optional[list[int]] = None,
    device: Union[str, torch.device] = "cpu",
    length_unit: str = "Angstrom",
    energy_unit: str = "eV",
    supported_devices: Optional[list[str]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    neighbor_backend: Union[str, NeighborListBackend] = NeighborListBackend.AUTO,
    fallback_to_ase: bool = False,
):
    """
    Wrap a UFP potential in the best available torch-sim model.

    Primary path: when ``metatomic_torchsim`` can be imported, the potential
    is converted with ``wrap_atomistic_model``. The result is handed to
    :py:class:`metatomic_torchsim.MetatomicModel`, which keeps the model in
    torch/metatensor space so energies remain differentiable for forces and
    stress.

    Fallback path: when that import fails and ``fallback_to_ase=True``, an
    :py:class:`ASEBackedTorchSimModel` is returned instead. It is slower.

    Args:
        potential: UFP potential to wrap.
        atomic_types: Supported atomic numbers for metatomic model capabilities.
        device: Device passed to the torch-sim wrapper.
        length_unit: Length unit advertised to metatomic.
        energy_unit: Energy unit advertised for total energies.
        supported_devices: Devices advertised in metatomic capabilities.
        dtype: Explicit model dtype string or torch dtype.
        neighbor_backend: Backend used by the ASE fallback path only.
        fallback_to_ase: Whether to use the ASE-backed adapter when
            ``metatomic-torchsim`` is unavailable.

    Returns:
        A metatomic-backed torch-sim model, or an ASE-backed fallback model.

    Raises:
        ImportError: If ``metatomic-torchsim`` is unavailable and fallback is
            disabled.
    """
    try:
        from metatomic_torchsim import MetatomicModel
    except ImportError as exc:
        if not fallback_to_ase:
            raise ImportError(
                "metatomic-torchsim is not installed. Install it with `pip install "
                "metatomic-torchsim` or `pip install 'ufp[metatomic,torchsim]'`. "
                "Set `fallback_to_ase=True` to use the slower ASE-backed adapter."
            ) from exc

        return ASEBackedTorchSimModel(
            potential,
            neighbor_backend=neighbor_backend,
            device=torch.device(device),
            dtype=_resolve_torch_dtype(dtype),
        )

    wrapped = wrap_atomistic_model(
        potential,
        atomic_types=atomic_types,
        length_unit=length_unit,
        energy_unit=energy_unit,
        supported_devices=supported_devices,
        dtype=dtype,
    )
    return MetatomicModel(wrapped, device=device)
# Public API of this module; the remaining helpers are private.
__all__ = [
    "ASEBackedTorchSimModel",
    "HAS_TORCHSIM",
    "build_torchsim_model",
    "state_to_ase_atoms",
]