Source code for ufp.adapters.metatomic

"""
Metatomic conversion and wrapper utilities for UFP models.

This module converts metatomic systems into the shared UFP input bundle and
converts UFP predictions back into metatensor outputs, without changing model
logic.
"""

from __future__ import annotations

from typing import Dict, List, Optional, Sequence, Union

import ase
import numpy as np
import torch

from ufp.core._arrays import _to_numpy
from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.core.potential import UFPotential
from ufp.neighbors._data import NeighborListData, concatenate_neighbor_lists
from ufp.neighbors._neighbors import NeighborListBackend


def _require_metatensor_torch():
    """Import and return ``metatensor.torch`` or raise with installation guidance."""
    try:
        import metatensor.torch as mts
    except ImportError as exc:
        raise ImportError(
            "metatensor-torch is not installed. Install it with `pip install "
            "metatensor-torch` or `pip install 'ufp[metatomic]'`."
        ) from exc

    return mts


def _require_metatomic_torch():
    """Import and return ``metatomic.torch`` or raise with installation guidance."""
    try:
        import metatomic.torch as mta
    except ImportError as exc:
        raise ImportError(
            "metatomic-torch is not installed. Install it with `pip install "
            "metatomic-torch` or `pip install 'ufp[metatomic]'`."
        ) from exc

    return mta


def _labels_values(samples, names):
    """Return the tensor backing a metatensor labels object."""
    return samples.view(list(names)).values


def _parameter_dtype(module: torch.nn.Module) -> Optional[torch.dtype]:
    """Infer the floating dtype advertised by a wrapped model."""
    for parameter in module.parameters():
        if parameter.is_floating_point():
            return parameter.dtype

    for buffer in module.buffers():
        if buffer.is_floating_point():
            return buffer.dtype

    return None


def _dtype_name(dtype: Optional[torch.dtype]) -> str:
    """Return the metatomic dtype string matching the provided torch dtype."""
    if dtype is None:
        return "float64"

    return str(dtype).removeprefix("torch.")


def _annotate_metatomic_forward() -> None:
    """
    Install concrete metatomic type annotations on ``UFPMetatomicModule.forward``.

    The metatomic and metatensor types are imported here, at call time, so
    they never become hard module-level imports. If either package lacks any
    of ``System``, ``ModelOutput``, ``Labels`` or ``TensorMap``, the existing
    annotations are left untouched.

    Raises:
        ImportError: If ``metatomic-torch`` or ``metatensor-torch`` is not
            installed.
    """
    mta = _require_metatomic_torch()
    mts = _require_metatensor_torch()

    types_available = (
        hasattr(mta, "System")
        and hasattr(mta, "ModelOutput")
        and hasattr(mts, "Labels")
        and hasattr(mts, "TensorMap")
    )
    if not types_available:
        return

    # NOTE(review): ``from __future__ import annotations`` turns the source
    # annotations into strings; presumably real type objects are needed here
    # so TorchScript can resolve ``forward``'s signature — confirm.
    UFPMetatomicModule.forward.__annotations__ = {
        "systems": List[mta.System],
        "outputs": Dict[str, mta.ModelOutput],
        "selected_atoms": Optional[mts.Labels],
        "return": Dict[str, mts.TensorMap],
    }


def system_to_ase_atoms(system) -> ase.Atoms:
    """
    Convert a metatomic system-like object into ASE atoms.

    This helper intentionally uses duck typing so it can operate on scripted
    or Python metatomic systems without importing metatomic at module import
    time.

    Args:
        system: Object with ``positions``, ``cell``, ``pbc``, and ``types``
            fields.

    Returns:
        ASE structure with the same geometry; ``types`` become the atomic
        numbers.

    Raises:
        TypeError: If ``system`` does not expose every required geometry
            field.
    """
    # Check every field read below so that an incomplete object raises the
    # documented TypeError instead of an AttributeError part-way through.
    if not all(
        hasattr(system, field) for field in ("positions", "cell", "pbc", "types")
    ):
        raise TypeError("`system` must expose positions, cell, pbc, and types")

    return ase.Atoms(
        numbers=_to_numpy(system.types),
        positions=_to_numpy(system.positions),
        cell=_to_numpy(system.cell),
        pbc=_to_numpy(system.pbc),
    )
def systems_to_ase_atoms(systems) -> list[ase.Atoms]:
    """
    Convert several metatomic system-like objects into ASE atoms.

    Args:
        systems: Iterable of system-like objects, each accepted by
            ``system_to_ase_atoms``.

    Returns:
        One ASE structure per input system, in the same order.
    """
    return list(map(system_to_ase_atoms, systems))
def neighbor_list_to_data(
    neighbors,
    options=None,
    *,
    atom_offset: int = 0,
) -> NeighborListData:
    """
    Convert a metatomic neighbor-list block into UFP neighbor-list data.

    The block must use the metatomic sample names ``["first_atom",
    "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"]`` (in that
    order) and store pair distance vectors in ``neighbors.values``.

    Args:
        neighbors: Metatensor-like tensor block holding a neighbor list.
        options: Optional object exposing ``cutoff``, ``full_list``, and
            ``strict``. Without it, the cutoff and strictness are recorded as
            ``None`` and the list is marked as full.
        atom_offset: Offset added to the local atom indices, e.g. to map them
            into a batched atom numbering.

    Returns:
        Normalized neighbor-list data, marked as unsorted and tagged with the
        metatomic backend.

    Raises:
        ValueError: If the sample names or the vector shape do not follow the
            metatomic neighbor-list conventions.
    """
    expected = [
        "first_atom",
        "second_atom",
        "cell_shift_a",
        "cell_shift_b",
        "cell_shift_c",
    ]
    found = list(neighbors.samples.names)
    if found != expected:
        raise ValueError(
            "unexpected metatomic neighbor-list samples. "
            f"Expected {expected}, got {found}"
        )

    pair_names, shift_names = expected[:2], expected[2:]
    # Transposed so the first-atom and second-atom indices form the rows.
    pairs = _labels_values(neighbors.samples, pair_names).T
    if atom_offset:
        pairs = pairs + atom_offset
    shifts = _labels_values(neighbors.samples, shift_names)

    vectors = neighbors.values
    # Accept the (n_pairs, 3, 1) layout by dropping the trailing singleton axis.
    if vectors.ndim == 3 and vectors.shape[-1] == 1:
        vectors = vectors.squeeze(-1)
    if vectors.ndim != 2 or vectors.shape[1] != 3:
        raise ValueError(
            "metatomic neighbor-list values must contain distance vectors with "
            f"shape (n_pairs, 3) or (n_pairs, 3, 1), got {tuple(vectors.shape)}"
        )

    # Keep distances in the same array library as the vectors.
    distances = (
        torch.linalg.vector_norm(vectors, dim=1)
        if isinstance(vectors, torch.Tensor)
        else np.linalg.norm(np.asarray(vectors), axis=1)
    )

    if options is None:
        cutoff, full_list, strict = None, True, None
    else:
        cutoff = float(options.cutoff)
        full_list = bool(options.full_list)
        strict = bool(options.strict)

    return NeighborListData(
        pairs=pairs,
        shifts=shifts,
        distances=distances,
        vectors=vectors,
        backend=NeighborListBackend.METATOMIC.value,
        cutoff=cutoff,
        full_list=full_list,
        sorted=False,
        strict=strict,
    )
def systems_to_input(
    systems,
    *,
    neighbor_options=None,
    metadata: Optional[Dict[str, object]] = None,
) -> UFPInput:
    """
    Convert metatomic systems into a shared UFP input bundle.

    Positions, atomic types and per-atom system indices are concatenated
    along dimension 0, while cells and periodic-boundary flags are stacked
    along a new leading dimension (one entry per system). When
    ``neighbor_options`` is given, each system's neighbor list is fetched and
    its atom indices are shifted by the number of atoms in the preceding
    systems before the lists are merged.

    Args:
        systems: Sequence of system-like objects exposing ``positions``,
            ``cell``, ``pbc`` and ``types`` (and ``get_neighbor_list`` when
            ``neighbor_options`` is given).
        neighbor_options: Optional neighbor-list request used to fetch
            neighbors.
        metadata: Optional metadata dictionary copied into the input.

    Returns:
        Normalized UFP input bundle.

    Raises:
        ValueError: If no systems are provided.
        TypeError: If any system does not expose the required geometry fields.
    """
    systems = list(systems)
    if not systems:
        raise ValueError("`systems` must contain at least one system")

    positions = []
    cells = []
    pbc = []
    atomic_numbers = []
    system_index = []
    neighbor_lists = []
    atom_offset = 0
    for system_i, system in enumerate(systems):
        # Check every field read below so that an incomplete object raises the
        # documented TypeError instead of an AttributeError part-way through.
        if not all(
            hasattr(system, field) for field in ("positions", "cell", "pbc", "types")
        ):
            raise TypeError(
                "each system must expose positions, cell, pbc, and types fields"
            )

        n_atoms = int(system.positions.shape[0])
        positions.append(system.positions)
        cells.append(system.cell)
        pbc.append(system.pbc)
        atomic_numbers.append(system.types)
        system_index.append(
            torch.full(
                (n_atoms,),
                system_i,
                dtype=torch.int64,
                device=system.positions.device,
            )
        )

        if neighbor_options is not None:
            neighbors = system.get_neighbor_list(neighbor_options)
            neighbor_lists.append(
                neighbor_list_to_data(
                    neighbors,
                    options=neighbor_options,
                    # Shift local atom indices into the batched numbering.
                    atom_offset=atom_offset,
                )
            )
        atom_offset += n_atoms

    # NOTE(review): without ``neighbor_options`` this receives an empty list;
    # presumably ``concatenate_neighbor_lists`` maps that to "no neighbor
    # list" — confirm against its implementation.
    neighbor_list = concatenate_neighbor_lists(neighbor_lists)
    return UFPInput(
        positions=torch.cat(positions, dim=0),
        cell=torch.stack(cells, dim=0),
        pbc=torch.stack(pbc, dim=0),
        atomic_numbers=torch.cat(atomic_numbers, dim=0),
        system_index=torch.cat(system_index, dim=0),
        neighbor_list=neighbor_list,
        metadata={} if metadata is None else dict(metadata),
    )
def make_system_output(
    values: torch.Tensor,
    property_name: str,
):
    """
    Wrap per-system values in a single-block metatensor ``TensorMap``.

    Samples are labelled ``system`` (0..n_systems-1) and properties are
    labelled ``property_name`` (0..n_properties-1); the map has one key,
    ``_`` = 0.

    Args:
        values: Tensor with shape ``(n_systems,)`` or
            ``(n_systems, n_properties)``. A 1-D input is treated as a single
            property column.
        property_name: Name of the property axis.

    Returns:
        Metatensor tensor map storing the values.

    Raises:
        ImportError: If ``metatensor-torch`` is not installed.
        ValueError: If ``values`` does not have one or two dimensions.
    """
    mts = _require_metatensor_torch()

    if values.ndim not in (1, 2):
        raise ValueError(
            "`values` must have shape (n_systems,) or (n_systems, n_properties)"
        )
    if values.ndim == 1:
        values = values.reshape(-1, 1)

    device = values.device
    n_systems, n_properties = values.shape

    def index_column(count):
        # One label row per index: shape (count, 1), int64 on ``device``.
        return torch.arange(count, device=device, dtype=torch.int64).reshape(-1, 1)

    block = mts.TensorBlock(
        values=values,
        samples=mts.Labels(["system"], index_column(n_systems)),
        components=torch.jit.annotate(list[mts.Labels], []),
        properties=mts.Labels([property_name], index_column(n_properties)),
    )
    keys = mts.Labels("_", torch.tensor([[0]], device=device, dtype=torch.int64))
    return mts.TensorMap(keys, [block])
def make_energy_output(energies: torch.Tensor):
    """
    Wrap batched total energies in a metatensor ``TensorMap``.

    This is ``make_system_output`` with the property axis named ``energy``.

    Args:
        energies: Tensor with shape ``(n_systems,)`` or ``(n_systems, 1)``.

    Returns:
        Metatensor tensor map storing total energies.
    """
    energy_property = "energy"
    return make_system_output(energies, property_name=energy_property)
def make_per_atom_output(
    values: torch.Tensor,
    system_sizes: Sequence[int],
    property_name: str = "feature",
):
    """
    Create a metatensor ``TensorMap`` for batched per-atom outputs.

    Samples are labelled ``(system, atom)``, where ``atom`` restarts at 0 for
    every system; rows of ``values`` are assumed to be ordered system by
    system, matching ``system_sizes``.

    Args:
        values: Tensor with shape ``(n_atoms_total,)`` or
            ``(n_atoms_total, n_properties)``.
        system_sizes: Atom counts for each system in the batch (may be empty).
        property_name: Name to use for the output properties axis.

    Returns:
        Metatensor tensor map storing the per-atom values.

    Raises:
        ImportError: If ``metatensor-torch`` is not installed.
        ValueError: If ``values`` has an unsupported shape or atom count.
    """
    mts = _require_metatensor_torch()
    if values.ndim == 1:
        values = values.reshape(-1, 1)
    elif values.ndim != 2:
        raise ValueError(
            "`values` must have shape (n_atoms_total,) or (n_atoms_total, n_props)"
        )

    # Normalize to plain ints so tensor/NumPy sizes behave like Python ints.
    sizes = [int(size) for size in system_sizes]
    n_atoms_total = int(values.shape[0])
    if n_atoms_total != sum(sizes):
        raise ValueError(
            "`values` and `system_sizes` describe different numbers of atoms"
        )

    device = values.device
    counts = torch.tensor(sizes, device=device, dtype=torch.int64)
    # Build the (system, atom) indices in one vectorized pass. Unlike
    # concatenating per-system tensors, this also works for an empty batch
    # (torch.cat rejects an empty list).
    system_index = torch.repeat_interleave(
        torch.arange(len(sizes), device=device, dtype=torch.int64),
        counts,
        output_size=n_atoms_total,
    )
    # Global index of each system's first atom, broadcast to all its atoms.
    first_atom = torch.cumsum(counts, dim=0) - counts
    atom_index = torch.arange(
        n_atoms_total, device=device, dtype=torch.int64
    ) - torch.repeat_interleave(first_atom, counts, output_size=n_atoms_total)

    samples = mts.Labels(
        ["system", "atom"],
        torch.stack([system_index, atom_index], dim=1),
    )
    properties = mts.Labels(
        [property_name],
        torch.arange(values.shape[1], device=device, dtype=torch.int64).reshape(-1, 1),
    )
    block = mts.TensorBlock(
        values=values,
        samples=samples,
        components=torch.jit.annotate(list[mts.Labels], []),
        properties=properties,
    )
    return mts.TensorMap(
        mts.Labels("_", torch.tensor([[0]], device=device, dtype=torch.int64)),
        [block],
    )
def _prediction_to_outputs(
    prediction: UFPOutput,
    inputs: UFPInput,
    outputs,
) -> Dict[str, object]:
    """
    Build the metatensor outputs requested by a metatomic engine.

    ``energy`` is served from ``prediction.per_atom_energy`` when the request
    is per-atom and from ``prediction.energy`` otherwise (keeping only the
    first column after reshaping to ``(n_systems, -1)``). Any other name is
    looked up in ``prediction.features``. Every value is converted to a
    tensor on the device and dtype of ``inputs``.

    Args:
        prediction: Normalized UFP prediction.
        inputs: UFP input bundle the prediction was computed from.
        outputs: Mapping of output names to specs; a spec's optional
            ``per_atom`` attribute selects per-atom output.

    Returns:
        Dictionary mapping each requested name to a metatensor tensor map.

    Raises:
        RuntimeError: If the prediction lacks the data for a requested output.
    """

    def as_input_tensor(data) -> torch.Tensor:
        return torch.as_tensor(data, device=inputs.device, dtype=inputs.dtype)

    converted = {}
    for name, spec in outputs.items():
        per_atom = bool(getattr(spec, "per_atom", False))

        if name != "energy":
            if name not in prediction.features:
                raise RuntimeError(
                    f"the wrapped UFPotential did not provide the requested output `{name}`"
                )
            feature = as_input_tensor(prediction.features[name])
            if per_atom:
                converted[name] = make_per_atom_output(
                    feature,
                    inputs.system_sizes,
                    property_name=name,
                )
            else:
                converted[name] = make_system_output(feature, name)
            continue

        if per_atom:
            if prediction.per_atom_energy is None:
                raise RuntimeError(
                    "the wrapped UFPotential did not provide `per_atom_energy`"
                )
            converted[name] = make_per_atom_output(
                as_input_tensor(prediction.per_atom_energy),
                inputs.system_sizes,
                property_name="energy",
            )
        else:
            if prediction.energy is None:
                raise RuntimeError(
                    "the wrapped UFPotential did not provide `energy`"
                )
            total = as_input_tensor(prediction.energy)
            converted[name] = make_energy_output(
                total.reshape(inputs.n_systems, -1)[:, 0]
            )
    return converted
class UFPMetatomicModule(torch.nn.Module):
    """
    Wrap a UFP potential for metatomic execution.

    The wrapped UFP potential still owns all chemistry and neighbor-list
    consumption. This adapter only converts metatomic systems into the shared
    UFP input bundle and turns the predictions into metatensor outputs.

    Args:
        potential: Wrapped UFP potential.
        full_neighbor_list: Whether the requested metatomic neighbor list
            should contain both directions for every pair.
        strict_neighbor_list: Whether the requested neighbor list should be
            strictly within the cutoff.
    """

    def __init__(
        self,
        potential: UFPotential,
        *,
        full_neighbor_list: bool = True,
        strict_neighbor_list: bool = True,
    ) -> None:
        """Store ``potential`` and build its neighbor-list request, if any."""
        super().__init__()
        self.potential = potential

        cutoff = self.potential.cutoff
        if cutoff is None:
            # Without a cutoff no neighbor list is requested from metatomic.
            self._neighbor_options = None
            return

        mta = _require_metatomic_torch()
        self._neighbor_options = mta.NeighborListOptions(
            cutoff=float(cutoff),
            full_list=bool(full_neighbor_list),
            strict=bool(strict_neighbor_list),
        )

    def requested_neighbor_lists(self):
        """
        Declare the neighbor list needed by the wrapped UFP model.

        Returns:
            ``[options]`` when the potential has a cutoff, otherwise ``[]``.
        """
        options = self._neighbor_options
        return [] if options is None else [options]

    def forward(
        self,
        systems,
        outputs,
        selected_atoms=None,
    ) -> Dict[str, object]:
        """
        Evaluate the wrapped UFP potential inside a metatomic workflow.

        Args:
            systems: Metatomic systems passed by the engine.
            outputs: Requested metatomic outputs.
            selected_atoms: Optional atom selection. This is not implemented.

        Returns:
            Dictionary mapping output names to metatensor tensor maps.

        Raises:
            NotImplementedError: If ``selected_atoms`` is provided.
        """
        if selected_atoms is not None:
            raise NotImplementedError("`selected_atoms` is not implemented")

        ufp_inputs = systems_to_input(
            systems, neighbor_options=self._neighbor_options
        )
        return _prediction_to_outputs(
            self.potential.compute_input(ufp_inputs),
            ufp_inputs,
            outputs,
        )
def wrap_atomistic_model(
    potential: UFPotential,
    *,
    atomic_types: Optional[Sequence[int]] = None,
    length_unit: str = "Angstrom",
    energy_unit: str = "eV",
    supported_devices: Optional[Sequence[str]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    metadata=None,
    full_neighbor_list: bool = True,
    strict_neighbor_list: bool = True,
):
    """
    Wrap a UFP potential as a metatomic ``AtomisticModel``.

    Args:
        potential: UFP potential to wrap.
        atomic_types: Supported atomic numbers. When omitted, the value is
            inferred from ``potential.atomic_types`` when available. Entries
            are converted to plain Python ``int``.
        length_unit: Length unit advertised to metatomic.
        energy_unit: Energy unit advertised for the total-energy output.
        supported_devices: Explicit list of supported devices; defaults to
            ``["cpu"]``.
        dtype: Explicit model dtype string or torch dtype. When omitted, the
            dtype of the first floating parameter or buffer is used, falling
            back to ``"float64"``.
        metadata: Optional metatomic ``ModelMetadata`` object; an empty one is
            created when omitted.
        full_neighbor_list: Whether metatomic should request a full neighbor
            list.
        strict_neighbor_list: Whether the requested list should be strict.

    Returns:
        Metatomic ``AtomisticModel`` advertising a per-system ``energy``
        output and an interaction range equal to the potential cutoff
        (``0.0`` when the potential has none).

    Raises:
        ImportError: If ``metatomic-torch`` is not installed.
        ValueError: If supported atomic types cannot be inferred.
    """
    mta = _require_metatomic_torch()

    if atomic_types is None:
        atomic_types = getattr(potential, "atomic_types", None)
    if atomic_types is None:
        raise ValueError(
            "`atomic_types` is required to build a metatomic AtomisticModel"
        )
    # Normalize to plain Python ints: an inferred ``potential.atomic_types``
    # may be a tensor or NumPy array, whose elements are not ``int`` objects.
    atomic_types = [int(atomic_type) for atomic_type in atomic_types]

    if supported_devices is None:
        supported_devices = ["cpu"]

    if isinstance(dtype, torch.dtype):
        dtype_name = _dtype_name(dtype)
    elif isinstance(dtype, str):
        dtype_name = dtype
    else:
        dtype_name = _dtype_name(_parameter_dtype(potential))

    capabilities = mta.ModelCapabilities(
        length_unit=length_unit,
        atomic_types=atomic_types,
        interaction_range=0.0 if potential.cutoff is None else float(potential.cutoff),
        outputs={
            "energy": mta.ModelOutput(
                quantity="energy",
                unit=energy_unit,
                per_atom=False,
            )
        },
        supported_devices=list(supported_devices),
        dtype=dtype_name,
    )
    if metadata is None:
        metadata = mta.ModelMetadata()

    raw_model = UFPMetatomicModule(
        potential,
        full_neighbor_list=full_neighbor_list,
        strict_neighbor_list=strict_neighbor_list,
    )
    _annotate_metatomic_forward()
    return mta.AtomisticModel(raw_model.eval(), metadata, capabilities)
# Public names exported by a star import of this module.
__all__ = [
    "UFPMetatomicModule",
    "make_energy_output",
    "make_per_atom_output",
    "make_system_output",
    "neighbor_list_to_data",
    "system_to_ase_atoms",
    "systems_to_ase_atoms",
    "systems_to_input",
    "wrap_atomistic_model",
]