"""
Execution helpers shared by potential compute paths.
This module prepares ASE inputs, derives forces from energies, and validates
model outputs against the normalized ``UFPInput`` contract.
"""
from __future__ import annotations
import weakref
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Optional, Sequence, Union
import ase
import numpy as np
import torch
from ufp.core.input import UFPInput, UFPInputState
from ufp.core.output import UFPOutput
from ufp.neighbors._data import NeighborListData, concatenate_neighbor_lists
from ufp.neighbors._neighbors import NeighborListBackend, build_neighbor_list
def _shape(array) -> tuple[int, ...]:
    """Expose ``array.shape`` as a plain tuple so it formats cleanly in errors."""
    dims = array.shape
    return tuple(dims)
@dataclass(frozen=True)
class _GeometrySnapshot:
    """Private copies of the geometry arrays used to validate identity-cache hits.

    A snapshot is compared against the live ``ase.Atoms`` object to confirm
    that positions, cell, periodic flags, and atomic numbers are unchanged.
    """

    positions: np.ndarray
    cell: np.ndarray
    pbc: np.ndarray
    numbers: np.ndarray

    @classmethod
    def from_atoms(cls, atoms: ase.Atoms) -> "_GeometrySnapshot":
        """Build a snapshot by copying the geometry arrays out of ``atoms``."""
        # Copies are taken so later in-place edits of ``atoms`` cannot alter
        # the stored reference geometry.
        copied = {
            "positions": np.array(atoms.positions, copy=True),
            "cell": np.array(atoms.cell.array, copy=True),
            "pbc": np.array(atoms.pbc, dtype=np.bool_, copy=True),
            "numbers": np.array(atoms.numbers, copy=True),
        }
        return cls(**copied)

    def matches(self, atoms: ase.Atoms) -> bool:
        """Return ``True`` only when every stored array equals its live counterpart."""
        if not np.array_equal(self.positions, atoms.positions):
            return False
        if not np.array_equal(self.cell, atoms.cell.array):
            return False
        live_pbc = np.asarray(atoms.pbc, dtype=np.bool_)
        if not np.array_equal(self.pbc, live_pbc):
            return False
        return np.array_equal(self.numbers, atoms.numbers)
@dataclass(frozen=True)
class _GeometryCacheEntry:
    """Validated cache record: a neighbor list, its owner, and a geometry snapshot.

    The owner is identified by ``atoms_ref`` (a weak reference) when the
    object supports weak references, otherwise by the strong reference
    ``atoms``; the unused field is ``None``.
    """

    atoms_ref: weakref.ReferenceType[ase.Atoms] | None
    atoms: ase.Atoms | None
    snapshot: _GeometrySnapshot
    neighbor_list: NeighborListData

    @classmethod
    def create(
        cls,
        atoms: ase.Atoms,
        neighbor_list: NeighborListData,
    ) -> "_GeometryCacheEntry":
        """Build an entry for ``atoms``, holding the owner weakly when possible."""
        try:
            weak_owner = weakref.ref(atoms)
            strong_owner = None
        except TypeError:
            # The object does not support weak references; keep it alive instead.
            weak_owner = None
            strong_owner = atoms
        return cls(
            atoms_ref=weak_owner,
            atoms=strong_owner,
            snapshot=_GeometrySnapshot.from_atoms(atoms),
            neighbor_list=neighbor_list,
        )

    def matches(self, atoms: ase.Atoms) -> bool:
        """Return whether ``atoms`` is the recorded owner and its geometry is unchanged."""
        if self.atoms_ref is None:
            owner = self.atoms
        else:
            # A dead weak reference yields ``None`` and therefore never matches.
            owner = self.atoms_ref()
        if owner is not atoms:
            return False
        return self.snapshot.matches(atoms)
@dataclass(frozen=True)
class _GeometryCacheCandidate:
    """Owner-only record of a first sighting, created without a geometry snapshot."""

    atoms_ref: weakref.ReferenceType[ase.Atoms] | None
    atoms: ase.Atoms | None

    @classmethod
    def create(cls, atoms: ase.Atoms) -> "_GeometryCacheCandidate":
        """Record ``atoms`` as the owner, weakly when the object allows it."""
        try:
            weak_owner = weakref.ref(atoms)
        except TypeError:
            # No weak-reference support: fall back to a strong reference.
            return cls(atoms_ref=None, atoms=atoms)
        return cls(atoms_ref=weak_owner, atoms=None)

    def matches(self, atoms: ase.Atoms) -> bool:
        """Return whether ``atoms`` is the very object this candidate was created for."""
        if self.atoms_ref is not None:
            return self.atoms_ref() is atoms
        return self.atoms is atoms
def _neighbor_list_option_key(
    atoms: ase.Atoms,
    *,
    cutoff: float,
    backend: NeighborListBackend,
    arrays: str,
    full_list: bool,
    sorted: bool,
    dtype: torch.dtype,
    device: Optional[torch.device],
) -> tuple[object, ...]:
    """Build the hashable lookup key for one ASE neighbor-list request.

    The key combines the identity of ``atoms`` (``id``) with every
    non-geometric option; geometry itself is not part of the key.
    """
    if device is None:
        device_label = None
    else:
        # Round-trip through ``torch.device`` so strings and device objects
        # that name the same device produce the same label.
        device_label = str(torch.device(device))
    owner_id = id(atoms)
    options = (
        float(cutoff),
        backend.value,
        arrays,
        bool(full_list),
        bool(sorted),
        str(dtype),
        device_label,
    )
    return (owner_id, *options)
@dataclass
class GeometryNeighborListCache:
    """
    Caller-owned LRU cache for ASE neighbor lists keyed by exact geometry.

    Lookups are keyed by the identity of the ``ase.Atoms`` object together
    with cutoff, backend, list options, dtype, and device. A hit is honored
    only while the stored snapshot of positions, cell, periodic flags, and
    atomic numbers still equals the live object. The cache stores the
    normalized :class:`NeighborListData` returned by the builder and leaves
    tensor dtype/device coercion to :class:`UFPInput`.

    Admission is two-stage: the first request for a key only records a
    lightweight candidate; a repeat request from the same object promotes it
    to a validated entry holding the neighbor list. Candidates are admitted
    only while the cache is below ``max_size``. Records whose weakly
    referenced owner has been garbage-collected are discarded before that
    check so they cannot hold capacity indefinitely.
    """

    # Upper bound on validated entries plus candidates.
    max_size: int = 128
    # Size threshold consulted by ``prepare_ase_input``: structures with fewer
    # atoms bypass the cache.
    min_atoms: int = 8
    # Validated entries in least-recently-used-first order.
    _entries: OrderedDict[tuple[object, ...], _GeometryCacheEntry] = field(
        default_factory=OrderedDict,
        init=False,
        repr=False,
    )
    # First sightings awaiting promotion, oldest first.
    _candidates: OrderedDict[tuple[object, ...], _GeometryCacheCandidate] = field(
        default_factory=OrderedDict,
        init=False,
        repr=False,
    )

    def __post_init__(self) -> None:
        """Validate cache sizing.

        Raises:
            ValueError: if ``max_size`` or ``min_atoms`` is not positive.
        """
        if self.max_size <= 0:
            raise ValueError("`max_size` must be positive")
        if self.min_atoms <= 0:
            raise ValueError("`min_atoms` must be positive")

    def __len__(self) -> int:
        """Return the number of validated entries plus pending candidates."""
        return len(self._entries) + len(self._candidates)

    def clear(self) -> None:
        """Remove all cached neighbor lists and pending candidates."""
        self._entries.clear()
        self._candidates.clear()

    def _trim(self) -> None:
        """Trim to ``max_size``, dropping oldest candidates before oldest entries."""
        while len(self) > self.max_size and self._candidates:
            self._candidates.popitem(last=False)
        while len(self) > self.max_size:
            self._entries.popitem(last=False)

    def _drop_dead_owners(self) -> None:
        """Discard records whose weakly referenced owner has been garbage-collected.

        Such records can never match a live object again, so keeping them only
        consumes capacity. Records holding a strong owner reference
        (``atoms_ref is None``) are always kept.
        """
        for table in (self._entries, self._candidates):
            # Collect keys first: a dict must not be mutated while iterating.
            stale_keys = [
                key
                for key, record in table.items()
                if record.atoms_ref is not None and record.atoms_ref() is None
            ]
            for key in stale_keys:
                del table[key]

    def get_or_build(
        self,
        *,
        atoms: ase.Atoms,
        cutoff: float,
        backend: NeighborListBackend,
        arrays: str = "torch",
        full_list: bool = True,
        sorted: bool = True,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> NeighborListData:
        """Return a cached neighbor list or build and store one.

        ``dtype`` and ``device`` only partition the cache key; they are not
        forwarded to the neighbor-list builder.

        Returns:
            The cached neighbor list when ``atoms`` is the recorded owner and
            its geometry is unchanged, otherwise a freshly built one.
        """
        key = _neighbor_list_option_key(
            atoms,
            cutoff=cutoff,
            backend=backend,
            arrays=arrays,
            full_list=full_list,
            sorted=sorted,
            dtype=dtype,
            device=device,
        )
        cached = self._entries.get(key)
        if cached is not None:
            if cached.matches(atoms):
                # Mark as most recently used.
                self._entries.move_to_end(key)
                return cached.neighbor_list
            # Same key but a different owner or changed geometry: stale.
            del self._entries[key]

        # A candidate is consumed whether or not it matches; only a matching
        # one (same object seen again) earns a validated entry.
        candidate = self._candidates.pop(key, None)
        promote = candidate is not None and candidate.matches(atoms)

        neighbor_list = build_neighbor_list(
            atoms=atoms,
            cutoff=cutoff,
            backend=backend,
            arrays=arrays,
            full_list=full_list,
            sorted=sorted,
        )
        if promote:
            self._entries[key] = _GeometryCacheEntry.create(atoms, neighbor_list)
        else:
            if len(self) >= self.max_size:
                # Reclaim slots held by collected owners before refusing
                # admission; otherwise a cache filled by short-lived objects
                # would never accept a new geometry again.
                self._drop_dead_owners()
            if len(self) < self.max_size:
                self._candidates[key] = _GeometryCacheCandidate.create(atoms)
        self._trim()
        return neighbor_list
def normalize_ase_atoms(
    atoms: Union[ase.Atoms, Sequence[ase.Atoms]],
) -> list[ase.Atoms]:
    """Coerce one ASE structure or a sequence of them into a validated list.

    Raises:
        ValueError: if the resulting list is empty.
        TypeError: if any element is not an ``ase.Atoms`` instance.
    """
    # A single structure is wrapped; anything else is materialized as a list.
    atoms_list = [atoms] if isinstance(atoms, ase.Atoms) else list(atoms)
    if not atoms_list:
        raise ValueError("`atoms` must contain at least one ASE structure")
    for item in atoms_list:
        if not isinstance(item, ase.Atoms):
            raise TypeError("`atoms` should be ase.Atoms or a sequence of ase.Atoms")
    return atoms_list
def atom_offsets(atoms_list: Sequence[ase.Atoms]) -> list[int]:
    """Return, for each system, the index of its first atom in the concatenated batch."""
    # ``boundaries[i]`` is the number of atoms preceding system ``i``; the
    # final element is the batch total, which is not a start offset.
    boundaries = [0]
    for system in atoms_list:
        boundaries.append(boundaries[-1] + len(system))
    return boundaries[:-1]
def prepare_ase_input(
    atoms: Union[ase.Atoms, Sequence[ase.Atoms]],
    *,
    cutoff: Optional[float],
    default_backend: NeighborListBackend,
    neighbor_list: Optional[Union[NeighborListData, Sequence[NeighborListData]]] = None,
    backend: Optional[Union[str, NeighborListBackend]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    requires_grad: bool = False,
    neighbor_list_cache: Optional[GeometryNeighborListCache] = None,
    state: Optional[UFPInputState] = None,
    extract_state: bool = True,
) -> UFPInput:
    """Build a ``UFPInput`` from one ASE structure or a sequence of them.

    The neighbor list handed to ``UFPInput.from_ase_list`` is resolved as:

    * ``neighbor_list is None`` and ``cutoff`` given: one list is built per
      structure with ``backend`` (or ``default_backend``) and the results are
      concatenated. Structures with at least ``neighbor_list_cache.min_atoms``
      atoms go through the cache when one is supplied.
    * a sequence of per-structure lists: concatenated with the atom offsets.
    * anything else (a single ``NeighborListData``, or ``None`` without a
      cutoff): passed through unchanged.

    ``dtype`` falls back to ``torch.get_default_dtype()`` when omitted.
    """
    resolved_dtype = dtype if dtype is not None else torch.get_default_dtype()
    atoms_list = normalize_ase_atoms(atoms)
    offsets = atom_offsets(atoms_list)

    if neighbor_list is None and cutoff is not None:
        if backend is None:
            resolved_backend = default_backend
        else:
            resolved_backend = NeighborListBackend(backend)

        def _neighbors_for(item: ase.Atoms) -> NeighborListData:
            """Build or fetch the neighbor list for a single structure."""
            cache = neighbor_list_cache
            # Structures below the cache's size threshold are always rebuilt.
            if cache is not None and len(item) >= cache.min_atoms:
                return cache.get_or_build(
                    atoms=item,
                    cutoff=cutoff,
                    backend=resolved_backend,
                    arrays="torch",
                    dtype=resolved_dtype,
                    device=device,
                )
            return build_neighbor_list(
                atoms=item,
                cutoff=cutoff,
                backend=resolved_backend,
                arrays="torch",
            )

        neighbor_list = concatenate_neighbor_lists(
            [_neighbors_for(item) for item in atoms_list],
            atom_offsets=offsets,
        )
    elif isinstance(neighbor_list, Sequence) and not isinstance(
        neighbor_list,
        NeighborListData,
    ):
        # Caller supplied one neighbor list per structure.
        neighbor_list = concatenate_neighbor_lists(
            list(neighbor_list),
            atom_offsets=offsets,
        )

    return UFPInput.from_ase_list(
        atoms_list,
        neighbor_list=neighbor_list,
        dtype=resolved_dtype,
        device=device,
        requires_grad=requires_grad,
        state=state,
        extract_state=extract_state,
    )
def derive_forces_from_energy(
    output: UFPOutput,
    inputs: UFPInput,
    *,
    training: bool,
) -> torch.Tensor:
    """Return forces as the negative gradient of the summed energy w.r.t. positions.

    When ``training`` is true the autograd graph is retained and a graph of
    the derivative is built, so the returned forces stay differentiable.

    Returns:
        A tensor shaped like ``inputs.positions``; all zeros when the energy
        does not require grad.

    Raises:
        RuntimeError: if ``output.energy`` is ``None`` or
            ``inputs.positions`` does not require grad.
        TypeError: if ``output.energy`` is not a ``torch.Tensor``.
    """
    energy = output.energy
    if energy is None:
        raise RuntimeError(
            "automatic force derivation requires the model to return `energy`"
        )
    if not isinstance(energy, torch.Tensor):
        raise TypeError(
            "automatic force derivation requires `energy` to be a torch.Tensor"
        )
    positions = inputs.positions
    if not positions.requires_grad:
        raise RuntimeError(
            "automatic force derivation requires `inputs.positions.requires_grad`"
        )
    if not energy.requires_grad:
        # Energy is detached from the graph: there is nothing to differentiate.
        return torch.zeros_like(positions)

    (position_gradient,) = torch.autograd.grad(
        energy.reshape(-1).sum(),
        positions,
        retain_graph=training,
        create_graph=training,
        allow_unused=False,
    )
    return -position_gradient
def derive_batched_forces_from_energy(
    output: UFPOutput,
    inputs: UFPInput,
    *,
    training: bool,
    independence_tolerance: float,
) -> torch.Tensor:
    """Differentiate each system's energy on its own and check batch independence.

    Single-system inputs are delegated to ``derive_forces_from_energy``.
    Otherwise every system energy is differentiated against all positions;
    the gradient on atoms outside that system's slice must not exceed
    ``independence_tolerance`` in absolute value.

    Returns:
        A tensor shaped like ``inputs.positions``; rows of systems whose
        energy does not require grad stay zero.

    Raises:
        RuntimeError: if ``output.energy`` is ``None``, positions do not
            require grad, or one system's energy depends on another system's
            atoms beyond the tolerance.
        TypeError: if ``output.energy`` is not a ``torch.Tensor``.
        ValueError: if the energy does not hold exactly one value per system.
    """
    n_systems = inputs.n_systems
    if n_systems == 1:
        return derive_forces_from_energy(output, inputs, training=training)

    energy = output.energy
    if energy is None:
        raise RuntimeError(
            "automatic force derivation requires the model to return `energy`"
        )
    if not isinstance(energy, torch.Tensor):
        raise TypeError(
            "automatic force derivation requires `energy` to be a torch.Tensor"
        )
    positions = inputs.positions
    if not positions.requires_grad:
        raise RuntimeError(
            "automatic force derivation requires `inputs.positions.requires_grad`"
        )

    energies = energy.reshape(n_systems, -1)
    if energies.shape[1] != 1:
        raise ValueError("batch force derivation requires one total energy per system")

    forces = torch.zeros_like(positions)
    last_index = n_systems - 1
    for system_i, atom_slice in enumerate(inputs.atom_slices):
        system_energy = energies[system_i, 0]
        if not system_energy.requires_grad:
            continue
        # The shared graph must survive until the last system is differentiated.
        keep_graph = training or system_i < last_index
        (full_gradient,) = torch.autograd.grad(
            system_energy,
            positions,
            retain_graph=keep_graph,
            create_graph=training,
            allow_unused=False,
        )
        # Gradient rows belonging to every other system in the batch.
        outside_rows = (
            full_gradient[: atom_slice.start],
            full_gradient[atom_slice.stop :],
        )
        leaked_gradient = torch.cat(outside_rows, dim=0)
        if leaked_gradient.numel() > 0:
            max_leak = leaked_gradient.abs().max().item()
            if max_leak > independence_tolerance:
                raise RuntimeError(
                    "system energies are not independent inside the batch: "
                    f"energy[{system_i}] has gradient magnitude {max_leak:.3e} "
                    "with respect to atoms from another system"
                )
        forces[atom_slice] = -full_gradient[atom_slice]
    return forces
def validate_output(output: UFPOutput, inputs: UFPInput) -> None:
    """Verify that every populated output field has a shape consistent with ``inputs``.

    Fields left as ``None`` are skipped.

    Raises:
        TypeError: if ``output`` is not a ``UFPOutput``.
        ValueError: if ``energy``, ``forces``, ``per_atom_energy``, or
            ``stress`` has a shape that does not fit the input's atom and
            system counts.
    """
    if not isinstance(output, UFPOutput):
        raise TypeError(f"`forward` must return UFPOutput, got {type(output).__name__}")

    energy = output.energy
    if energy is not None:
        energy_shape = _shape(energy)
        allowed_energy = {(inputs.n_systems,), (inputs.n_systems, 1)}
        if not inputs.n_systems > 1:
            # A lone system may also report its energy as a scalar.
            allowed_energy |= {(), (1,), (1, 1)}
        if energy_shape not in allowed_energy:
            raise ValueError(
                "`energy` must have shape "
                f"({inputs.n_systems},), ({inputs.n_systems}, 1), or be a "
                f"single scalar for one-system inputs. Got {energy_shape}."
            )

    forces = output.forces
    if forces is not None:
        forces_shape = _shape(forces)
        expected_forces = (inputs.n_atoms, 3)
        if forces_shape != expected_forces:
            raise ValueError(
                f"`forces` must have shape ({inputs.n_atoms}, 3), got {forces_shape}"
            )

    per_atom_energy = output.per_atom_energy
    if per_atom_energy is not None:
        per_atom_shape = _shape(per_atom_energy)
        allowed_per_atom = {(inputs.n_atoms,), (inputs.n_atoms, 1)}
        if per_atom_shape not in allowed_per_atom:
            raise ValueError(
                "`per_atom_energy` must have shape "
                f"({inputs.n_atoms},) or ({inputs.n_atoms}, 1), got "
                f"{per_atom_shape}"
            )

    stress = output.stress
    if stress is not None:
        stress_shape = _shape(stress)
        allowed_stress: set[tuple[int, ...]] = {(inputs.n_systems, 3, 3)}
        if inputs.n_systems == 1:
            # Single systems may drop the batch axis or use a 6-component form.
            allowed_stress |= {(3, 3), (6,), (1, 3, 3)}
        if stress_shape not in allowed_stress:
            raise ValueError(
                "`stress` must have shape "
                f"({inputs.n_systems}, 3, 3) for batches, or (3, 3)/(6,) for "
                f"single systems. Got {stress_shape}."
            )