"""
Metatomic conversion and wrapper utilities for UFP models.
This module translates metatomic systems and outputs into the shared UFP input
and output structures without changing model logic.
"""
from __future__ import annotations
from typing import Dict, List, Optional, Sequence, Union
import ase
import numpy as np
import torch
from ufp.core._arrays import _to_numpy
from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.core.potential import UFPotential
from ufp.neighbors._data import NeighborListData, concatenate_neighbor_lists
from ufp.neighbors._neighbors import NeighborListBackend
def _require_metatensor_torch():
    """
    Lazily import ``metatensor.torch``.

    The import is deferred so that this module can be imported without the
    optional metatensor dependency being present.

    Returns:
        The imported ``metatensor.torch`` module.

    Raises:
        ImportError: If ``metatensor.torch`` cannot be imported. The original
            import failure is chained as the cause.
    """
    try:
        import metatensor.torch as mts
    except ImportError as exc:
        # Re-raise with installation guidance while keeping the root cause.
        raise ImportError(
            "metatensor-torch is not installed. Install it with `pip install "
            "metatensor-torch` or `pip install 'ufp[metatomic]'`."
        ) from exc
    else:
        return mts
def _require_metatomic_torch():
    """
    Lazily import ``metatomic.torch``.

    The import is deferred so that this module can be imported without the
    optional metatomic dependency being present.

    Returns:
        The imported ``metatomic.torch`` module.

    Raises:
        ImportError: If ``metatomic.torch`` cannot be imported. The original
            import failure is chained as the cause.
    """
    try:
        import metatomic.torch as mta
    except ImportError as exc:
        # Re-raise with installation guidance while keeping the root cause.
        raise ImportError(
            "metatomic-torch is not installed. Install it with `pip install "
            "metatomic-torch` or `pip install 'ufp[metatomic]'`."
        ) from exc
    else:
        return mta
def _labels_values(samples, names):
    """
    Extract the values stored for a subset of label dimensions.

    Args:
        samples: Labels-like object exposing ``view(names)``.
        names: Iterable of dimension names to select; it is converted to a
            ``list`` before being handed to ``view``.

    Returns:
        The ``values`` attribute of the selected view.
    """
    selected = samples.view(list(names))
    return selected.values
def _parameter_dtype(module: torch.nn.Module) -> Optional[torch.dtype]:
    """
    Find the floating-point dtype used by a module.

    Parameters are inspected first, then buffers; the dtype of the first
    floating-point tensor encountered wins.

    Args:
        module: Torch module to inspect.

    Returns:
        The dtype of the first floating-point parameter or buffer, or ``None``
        when the module holds no floating-point tensors.
    """
    # Call each accessor only when its turn comes, parameters before buffers.
    for tensor_source in (module.parameters, module.buffers):
        for tensor in tensor_source():
            if tensor.is_floating_point():
                return tensor.dtype
    return None
def _dtype_name(dtype: Optional[torch.dtype]) -> str:
    """
    Convert a torch dtype into its bare string name.

    Args:
        dtype: Torch dtype, or ``None`` when no dtype could be determined.

    Returns:
        ``str(dtype)`` with a leading ``"torch."`` removed, or ``"float64"``
        when ``dtype`` is ``None``.
    """
    if dtype is not None:
        return str(dtype).removeprefix("torch.")
    # No dtype available: fall back to double precision.
    return "float64"
def _annotate_metatomic_forward() -> None:
    """
    Install metatomic type annotations on ``UFPMetatomicModule.forward``.

    The annotations are assigned at call time so that metatomic and metatensor
    do not have to be imported when this module is loaded. If any of the
    required classes is missing from the imported modules, the function
    returns without changing anything.

    Raises:
        ImportError: If ``metatomic-torch`` or ``metatensor-torch`` is not
            installed.
    """
    mta = _require_metatomic_torch()
    mts = _require_metatensor_torch()

    # Bail out quietly when an expected class is not exposed.
    expected_attributes = (
        (mta, ("System", "ModelOutput")),
        (mts, ("Labels", "TensorMap")),
    )
    for module, attribute_names in expected_attributes:
        for attribute_name in attribute_names:
            if not hasattr(module, attribute_name):
                return

    forward_annotations = {
        "systems": List[mta.System],
        "outputs": Dict[str, mta.ModelOutput],
        "selected_atoms": Optional[mts.Labels],
        "return": Dict[str, mts.TensorMap],
    }
    UFPMetatomicModule.forward.__annotations__ = forward_annotations
[docs]
def system_to_ase_atoms(system) -> ase.Atoms:
    """
    Convert a metatomic system-like object into ASE atoms.

    This helper intentionally uses duck typing so it can operate on scripted or
    Python metatomic systems without importing metatomic at module import time.

    Args:
        system: Object with ``positions``, ``cell``, ``pbc``, and ``types`` fields.

    Returns:
        ASE structure built from the system's types, positions, cell, and
        periodic boundary flags.

    Raises:
        TypeError: If ``system`` is missing any of the required geometry fields.
    """
    # Validate every field that is read below, not only ``positions``, so that
    # an incomplete object fails with the documented TypeError rather than an
    # AttributeError part-way through the conversion.
    required_fields = ("positions", "cell", "pbc", "types")
    if any(not hasattr(system, field) for field in required_fields):
        raise TypeError("`system` must expose positions, cell, pbc, and types")
    return ase.Atoms(
        numbers=_to_numpy(system.types),
        positions=_to_numpy(system.positions),
        cell=_to_numpy(system.cell),
        pbc=_to_numpy(system.pbc),
    )
[docs]
def systems_to_ase_atoms(systems) -> list[ase.Atoms]:
    """
    Convert several metatomic system-like objects into ASE atoms.

    Args:
        systems: Iterable of system-like objects accepted by
            ``system_to_ase_atoms``.

    Returns:
        List of ASE structures, one per input system, in input order.
    """
    converted = []
    for system in systems:
        converted.append(system_to_ase_atoms(system))
    return converted
[docs]
def neighbor_list_to_data(
    neighbors,
    options=None,
    *,
    atom_offset: int = 0,
) -> NeighborListData:
    """
    Translate a metatomic neighbor-list block into UFP neighbor-list data.

    The block must use exactly the sample names
    ``["first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"]``
    and carry the pair distance vectors in ``neighbors.values``.

    Args:
        neighbors: Metatensor-like tensor block holding a neighbor list.
        options: Optional object exposing ``cutoff``, ``full_list``, and
            ``strict``. When omitted, the cutoff and strictness are recorded as
            ``None`` and the list is recorded as a full list.
        atom_offset: Offset added to both atom indices of every pair.

    Returns:
        Neighbor-list data tagged with the metatomic backend and marked as
        unsorted.

    Raises:
        ValueError: If the sample names differ from the expected ones, or the
            values are not shaped ``(n_pairs, 3)`` or ``(n_pairs, 3, 1)``.
    """
    required_names = [
        "first_atom",
        "second_atom",
        "cell_shift_a",
        "cell_shift_b",
        "cell_shift_c",
    ]
    sample_names = list(neighbors.samples.names)
    if sample_names != required_names:
        raise ValueError(
            "unexpected metatomic neighbor-list samples. "
            f"Expected {required_names}, got {sample_names}"
        )

    # The first two sample dimensions are atom indices, the rest cell shifts.
    atom_columns, shift_columns = required_names[:2], required_names[2:]
    pair_index = _labels_values(neighbors.samples, atom_columns).T
    if atom_offset != 0:
        pair_index = pair_index + atom_offset
    cell_shifts = _labels_values(neighbors.samples, shift_columns)

    vectors = neighbors.values
    has_trailing_unit_axis = vectors.ndim == 3 and vectors.shape[-1] == 1
    if has_trailing_unit_axis:
        # Drop the length-one trailing axis to get plain (n_pairs, 3) vectors.
        vectors = vectors.squeeze(-1)
    if not (vectors.ndim == 2 and vectors.shape[1] == 3):
        raise ValueError(
            "metatomic neighbor-list values must contain distance vectors with "
            f"shape (n_pairs, 3) or (n_pairs, 3, 1), got {tuple(vectors.shape)}"
        )

    # Keep torch inputs in torch; anything else goes through numpy.
    distances = (
        torch.linalg.vector_norm(vectors, dim=1)
        if isinstance(vectors, torch.Tensor)
        else np.linalg.norm(np.asarray(vectors), axis=1)
    )

    if options is None:
        cutoff, full_list, strict = None, True, None
    else:
        cutoff = float(options.cutoff)
        full_list = bool(options.full_list)
        strict = bool(options.strict)

    return NeighborListData(
        pairs=pair_index,
        shifts=cell_shifts,
        distances=distances,
        vectors=vectors,
        backend=NeighborListBackend.METATOMIC.value,
        cutoff=cutoff,
        full_list=full_list,
        sorted=False,
        strict=strict,
    )
[docs]
def make_system_output(
    values: torch.Tensor,
    property_name: str,
):
    """
    Build a single-block metatensor ``TensorMap`` holding per-system values.

    Samples are labelled ``system`` and numbered from zero; properties are
    labelled with ``property_name`` and numbered from zero.

    Args:
        values: Tensor with shape ``(n_systems,)`` or ``(n_systems, n_properties)``.
            A one-dimensional tensor is treated as a single property column.
        property_name: Name of the property axis.

    Returns:
        Metatensor tensor map with one block and no components.

    Raises:
        ImportError: If ``metatensor-torch`` is not installed.
        ValueError: If ``values`` does not have one or two dimensions.
    """
    mts = _require_metatensor_torch()
    if values.ndim not in (1, 2):
        raise ValueError(
            "`values` must have shape (n_systems,) or (n_systems, n_properties)"
        )
    if values.ndim == 1:
        values = values.reshape(-1, 1)

    device = values.device

    def _index_column(count):
        # Column vector 0..count-1 on the same device as ``values``.
        return torch.arange(count, device=device, dtype=torch.int64).reshape(-1, 1)

    system_labels = mts.Labels(["system"], _index_column(len(values)))
    property_labels = mts.Labels([property_name], _index_column(values.shape[1]))
    block = mts.TensorBlock(
        values=values,
        samples=system_labels,
        components=torch.jit.annotate(list[mts.Labels], []),
        properties=property_labels,
    )
    # Single dummy key: the map holds exactly one block.
    keys = mts.Labels("_", torch.tensor([[0]], device=device, dtype=torch.int64))
    return mts.TensorMap(keys, [block])
[docs]
def make_energy_output(energies: torch.Tensor):
    """
    Build a metatensor ``TensorMap`` for batched total energies.

    This is ``make_system_output`` with the property axis named ``"energy"``.

    Args:
        energies: Tensor with shape ``(n_systems,)`` or ``(n_systems, 1)``.

    Returns:
        Metatensor tensor map storing total energies.
    """
    return make_system_output(values=energies, property_name="energy")
[docs]
def make_per_atom_output(
    values: torch.Tensor,
    system_sizes: Sequence[int],
    property_name: str = "feature",
):
    """
    Create a metatensor ``TensorMap`` for batched per-atom outputs.

    Samples are labelled ``("system", "atom")``: the system index followed by
    the atom index local to that system, in batch order.

    Args:
        values: Tensor with shape ``(n_atoms_total,)`` or
            ``(n_atoms_total, n_properties)``. A one-dimensional tensor is
            treated as a single property column.
        system_sizes: Atom counts for each system in the batch. May be empty,
            in which case ``values`` must have zero rows and the resulting
            block has no samples.
        property_name: Name to use for the output properties axis.

    Returns:
        Metatensor tensor map with one block and no components.

    Raises:
        ImportError: If ``metatensor-torch`` is not installed.
        ValueError: If ``values`` has an unsupported shape or atom count.
    """
    mts = _require_metatensor_torch()
    if values.ndim == 1:
        values = values.reshape(-1, 1)
    elif values.ndim != 2:
        raise ValueError(
            "`values` must have shape (n_atoms_total,) or (n_atoms_total, n_props)"
        )

    # Materialize the sizes once as plain ints: they are used both for
    # validation and for building the sample labels.
    sizes = [int(size) for size in system_sizes]
    if values.shape[0] != sum(sizes):
        raise ValueError(
            "`values` and `system_sizes` describe different numbers of atoms"
        )

    device = values.device
    # One (size, 2) table of (system, atom) index pairs per system.
    per_system_samples = [
        torch.stack(
            [
                torch.full((size,), system_i, device=device, dtype=torch.int64),
                torch.arange(size, device=device, dtype=torch.int64),
            ],
            dim=1,
        )
        for system_i, size in enumerate(sizes)
    ]
    if per_system_samples:
        sample_values = torch.cat(per_system_samples)
    else:
        # torch.cat rejects an empty list, so an empty batch needs an explicit
        # zero-row sample table.
        sample_values = torch.empty((0, 2), device=device, dtype=torch.int64)

    samples = mts.Labels(["system", "atom"], sample_values)
    properties = mts.Labels(
        [property_name],
        torch.arange(values.shape[1], device=device, dtype=torch.int64).reshape(-1, 1),
    )
    block = mts.TensorBlock(
        values=values,
        samples=samples,
        components=torch.jit.annotate(list[mts.Labels], []),
        properties=properties,
    )
    return mts.TensorMap(
        mts.Labels("_", torch.tensor([[0]], device=device, dtype=torch.int64)),
        [block],
    )
def _prediction_to_outputs(
    prediction: UFPOutput,
    inputs: UFPInput,
    outputs,
) -> Dict[str, object]:
    """
    Convert a UFP prediction into the metatomic outputs that were requested.

    The ``"energy"`` request is served from ``prediction.energy`` or
    ``prediction.per_atom_energy``; every other name is looked up in
    ``prediction.features``. Each request's ``per_atom`` attribute (treated as
    false when absent) selects per-atom or per-system packaging.

    Args:
        prediction: Prediction returned by the wrapped potential.
        inputs: Batched input; supplies the device, dtype, system sizes, and
            number of systems used to package the values.
        outputs: Mapping from requested output name to its specification.

    Returns:
        Mapping from each requested name to the tensor map built for it.

    Raises:
        RuntimeError: If the prediction lacks a requested quantity.
    """

    def _on_input_device(data):
        # Match the device and dtype of the batched inputs.
        return torch.as_tensor(data, device=inputs.device, dtype=inputs.dtype)

    converted = {}
    for name, output_spec in outputs.items():
        wants_per_atom = bool(getattr(output_spec, "per_atom", False))

        if name != "energy":
            if name not in prediction.features:
                raise RuntimeError(
                    f"the wrapped UFPotential did not provide the requested output `{name}`"
                )
            values = _on_input_device(prediction.features[name])
            if wants_per_atom:
                converted[name] = make_per_atom_output(
                    values,
                    inputs.system_sizes,
                    property_name=name,
                )
            else:
                converted[name] = make_system_output(values, name)
            continue

        if wants_per_atom:
            if prediction.per_atom_energy is None:
                raise RuntimeError(
                    "the wrapped UFPotential did not provide `per_atom_energy`"
                )
            converted[name] = make_per_atom_output(
                _on_input_device(prediction.per_atom_energy),
                inputs.system_sizes,
                property_name="energy",
            )
            continue

        if prediction.energy is None:
            raise RuntimeError("the wrapped UFPotential did not provide `energy`")
        # Keep one scalar per system: the first column after reshaping.
        total_energy = _on_input_device(prediction.energy)
        converted[name] = make_energy_output(
            total_energy.reshape(inputs.n_systems, -1)[:, 0]
        )
    return converted
[docs]
def wrap_atomistic_model(
    potential: UFPotential,
    *,
    atomic_types: Optional[Sequence[int]] = None,
    length_unit: str = "Angstrom",
    energy_unit: str = "eV",
    supported_devices: Optional[Sequence[str]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    metadata=None,
    full_neighbor_list: bool = True,
    strict_neighbor_list: bool = True,
):
    """
    Expose a UFP potential as a metatomic ``AtomisticModel``.

    Args:
        potential: UFP potential to wrap.
        atomic_types: Supported atomic numbers. When omitted, they are read
            from ``potential.atomic_types`` if that attribute exists.
        length_unit: Length unit advertised to metatomic.
        energy_unit: Energy unit advertised for the total-energy output.
        supported_devices: Devices advertised to metatomic; ``["cpu"]`` when
            omitted.
        dtype: Model dtype as a string or torch dtype. When omitted, it is
            inferred from the potential's floating-point parameters and buffers.
        metadata: Optional metatomic ``ModelMetadata`` object; an empty one is
            created when omitted.
        full_neighbor_list: Forwarded to ``UFPMetatomicModule``; whether a full
            neighbor list should be requested.
        strict_neighbor_list: Forwarded to ``UFPMetatomicModule``; whether the
            requested list should be strict.

    Returns:
        Metatomic ``AtomisticModel`` wrapping the potential in evaluation mode.

    Raises:
        ImportError: If ``metatomic-torch`` is not installed.
        ValueError: If supported atomic types cannot be inferred.
    """
    mta = _require_metatomic_torch()

    # Fall back to what the potential advertises before giving up.
    if atomic_types is None:
        atomic_types = getattr(potential, "atomic_types", None)
    if atomic_types is None:
        raise ValueError(
            "`atomic_types` is required to build a metatomic AtomisticModel"
        )

    if supported_devices is None:
        supported_devices = ["cpu"]

    if isinstance(dtype, str):
        # Strings are passed through unchanged.
        dtype_name = dtype
    elif isinstance(dtype, torch.dtype):
        dtype_name = _dtype_name(dtype)
    else:
        dtype_name = _dtype_name(_parameter_dtype(potential))

    energy_output = mta.ModelOutput(
        quantity="energy",
        unit=energy_unit,
        per_atom=False,
    )
    capabilities = mta.ModelCapabilities(
        length_unit=length_unit,
        atomic_types=list(atomic_types),
        interaction_range=0.0 if potential.cutoff is None else float(potential.cutoff),
        outputs={"energy": energy_output},
        supported_devices=list(supported_devices),
        dtype=dtype_name,
    )

    if metadata is None:
        metadata = mta.ModelMetadata()

    module = UFPMetatomicModule(
        potential,
        full_neighbor_list=full_neighbor_list,
        strict_neighbor_list=strict_neighbor_list,
    )
    # Install the metatomic annotations on ``forward`` before the module is
    # handed to ``AtomisticModel``.
    _annotate_metatomic_forward()
    return mta.AtomisticModel(module.eval(), metadata, capabilities)
# Public names of this module; this list controls what
# ``from <module> import *`` exports.
# NOTE(review): ``UFPMetatomicModule`` and ``systems_to_input`` are listed here
# but are not defined in the portion of the file visible in this review —
# confirm they are defined elsewhere in the module.
__all__ = [
"UFPMetatomicModule",
"make_energy_output",
"make_per_atom_output",
"make_system_output",
"neighbor_list_to_data",
"system_to_ase_atoms",
"systems_to_ase_atoms",
"systems_to_input",
"wrap_atomistic_model",
]