# Source code for ufp.core._execution

"""
Execution helpers shared by potential compute paths.

This module prepares ASE inputs, derives forces from energies, and validates
model outputs against the normalized ``UFPInput`` contract.
"""

from __future__ import annotations

import weakref
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Optional, Sequence, Union

import ase
import numpy as np
import torch

from ufp.core.input import UFPInput, UFPInputState
from ufp.core.output import UFPOutput
from ufp.neighbors._data import NeighborListData, concatenate_neighbor_lists
from ufp.neighbors._neighbors import NeighborListBackend, build_neighbor_list


def _shape(array) -> tuple[int, ...]:
    """Materialize ``array.shape`` as a plain tuple.

    The result is intended for error messages and shape comparisons, where a
    plain tuple reads better than a framework-specific shape object.
    """
    dims = array.shape
    return tuple(dims)


@dataclass(frozen=True)
class _GeometrySnapshot:
    """Private copy of the geometry fields a cached neighbor list depends on.

    A snapshot is compared element-for-element against a live ``ase.Atoms``
    object to decide whether a cache hit found by identity is still valid.
    """

    positions: np.ndarray  # copy of ``atoms.positions``
    cell: np.ndarray  # copy of ``atoms.cell.array``
    pbc: np.ndarray  # boolean copy of ``atoms.pbc``
    numbers: np.ndarray  # copy of ``atoms.numbers``

    @classmethod
    def from_atoms(cls, atoms: ase.Atoms) -> "_GeometrySnapshot":
        """Build a snapshot from ``atoms``.

        Every field is copied, so later in-place edits to ``atoms`` do not
        leak into the snapshot.
        """
        copied = {
            "positions": np.array(atoms.positions, copy=True),
            "cell": np.array(atoms.cell.array, copy=True),
            "pbc": np.array(atoms.pbc, dtype=np.bool_, copy=True),
            "numbers": np.array(atoms.numbers, copy=True),
        }
        return cls(**copied)

    def matches(self, atoms: ase.Atoms) -> bool:
        """Tell whether ``atoms`` still has exactly the stored geometry.

        Comparison is exact (``np.array_equal``), not tolerance-based, and
        stops at the first field that differs.
        """
        if not np.array_equal(self.positions, atoms.positions):
            return False
        if not np.array_equal(self.cell, atoms.cell.array):
            return False
        live_pbc = np.asarray(atoms.pbc, dtype=np.bool_)
        if not np.array_equal(self.pbc, live_pbc):
            return False
        return np.array_equal(self.numbers, atoms.numbers)


@dataclass(frozen=True)
class _GeometryCacheEntry:
    """A stored neighbor list, its owner, and the geometry it was built for.

    Exactly one of ``atoms_ref`` / ``atoms`` identifies the owner: a weak
    reference when the object supports one, otherwise a strong reference.
    """

    atoms_ref: weakref.ReferenceType[ase.Atoms] | None
    atoms: ase.Atoms | None
    snapshot: _GeometrySnapshot
    neighbor_list: NeighborListData

    @classmethod
    def create(
        cls,
        atoms: ase.Atoms,
        neighbor_list: NeighborListData,
    ) -> "_GeometryCacheEntry":
        """Build an entry for ``atoms``.

        The owner is held weakly when possible, so the entry does not keep
        it alive. A strong reference is used only when ``weakref.ref``
        rejects the object with ``TypeError``.
        """
        try:
            owner_ref = weakref.ref(atoms)
            strong_owner = None
        except TypeError:
            owner_ref = None
            strong_owner = atoms

        return cls(
            atoms_ref=owner_ref,
            atoms=strong_owner,
            snapshot=_GeometrySnapshot.from_atoms(atoms),
            neighbor_list=neighbor_list,
        )

    def matches(self, atoms: ase.Atoms) -> bool:
        """Tell whether ``atoms`` is the recorded owner with unchanged geometry."""
        if self.atoms_ref is None:
            owner = self.atoms
        else:
            # A dead weak reference yields ``None``, which never matches.
            owner = self.atoms_ref()
        if owner is not atoms:
            return False
        return self.snapshot.matches(atoms)


@dataclass(frozen=True)
class _GeometryCacheCandidate:
    """Owner-only marker recorded the first time a structure is requested.

    It holds no geometry snapshot and no neighbor list, only enough to
    recognise the same object on a later request.
    """

    atoms_ref: weakref.ReferenceType[ase.Atoms] | None
    atoms: ase.Atoms | None

    @classmethod
    def create(cls, atoms: ase.Atoms) -> "_GeometryCacheCandidate":
        """Record ``atoms`` weakly, or strongly if it cannot be weak-referenced."""
        try:
            owner_ref = weakref.ref(atoms)
        except TypeError:
            return cls(atoms_ref=None, atoms=atoms)
        return cls(atoms_ref=owner_ref, atoms=None)

    def matches(self, atoms: ase.Atoms) -> bool:
        """Tell whether ``atoms`` is the very object this candidate recorded."""
        if self.atoms_ref is None:
            owner = self.atoms
        else:
            # A dead weak reference yields ``None``, which never matches.
            owner = self.atoms_ref()
        return owner is atoms


def _neighbor_list_option_key(
    atoms: ase.Atoms,
    *,
    cutoff: float,
    backend: NeighborListBackend,
    arrays: str,
    full_list: bool,
    sorted: bool,
    dtype: torch.dtype,
    device: Optional[torch.device],
) -> tuple[object, ...]:
    """Compose the hashable lookup key for one neighbor-list request.

    The key combines the identity of ``atoms`` (``id``) with the request
    options, each normalized to a plain Python value. Geometry itself is not
    part of the key.

    Returns:
        ``(id(atoms), cutoff, backend.value, arrays, full_list, sorted,
        str(dtype), device-string-or-None)``.
    """
    if device is None:
        device_label = None
    else:
        # Round-trip through ``torch.device`` so equivalent spellings of the
        # same device produce the same key component.
        device_label = str(torch.device(device))

    owner_id = id(atoms)
    cutoff_value = float(cutoff)
    backend_name = backend.value
    return (
        owner_id,
        cutoff_value,
        backend_name,
        arrays,
        bool(full_list),
        bool(sorted),
        str(dtype),
        device_label,
    )


@dataclass
class GeometryNeighborListCache:
    """
    Caller-owned, bounded cache of ASE neighbor lists.

    Records are keyed by the identity of the ``ase.Atoms`` object plus the
    request options (cutoff, backend, list options, dtype, device). A hit
    additionally requires that the object's positions, cell, periodic flags,
    and atomic numbers exactly equal the snapshot taken when the list was
    stored.

    Admission is two-stage. The first request for a key records only a
    lightweight candidate; a later request for the same object builds the
    validated entry. One-off structures therefore never pay for a snapshot.

    The cache stores the :class:`NeighborListData` returned by the builder
    and leaves tensor dtype/device coercion to :class:`UFPInput`.

    Attributes:
        max_size: Upper bound on candidates plus validated entries.
        min_atoms: Smallest structure size callers are expected to route
            through the cache. It is validated here but not otherwise
            consulted by this class.
    """

    max_size: int = 128
    min_atoms: int = 8
    _entries: OrderedDict[tuple[object, ...], _GeometryCacheEntry] = field(
        default_factory=OrderedDict,
        init=False,
        repr=False,
    )
    _candidates: OrderedDict[tuple[object, ...], _GeometryCacheCandidate] = field(
        default_factory=OrderedDict,
        init=False,
        repr=False,
    )

    def __post_init__(self) -> None:
        """Validate cache sizing.

        Raises:
            ValueError: If ``max_size`` or ``min_atoms`` is not positive.
        """
        if self.max_size <= 0:
            raise ValueError("`max_size` must be positive")
        if self.min_atoms <= 0:
            raise ValueError("`min_atoms` must be positive")

    def __len__(self) -> int:
        """Return the number of validated entries plus pending candidates."""
        return len(self._entries) + len(self._candidates)

    def clear(self) -> None:
        """Remove all cached neighbor lists and pending candidates."""
        self._entries.clear()
        self._candidates.clear()

    def _purge_dead(self) -> None:
        """Drop records whose weakly referenced owner has been collected.

        Such records can never produce a hit again. Without this they would
        keep their payload alive and occupy a slot, which blocks admission
        of new structures once the cache is full. Records holding a strong
        owner reference (``atoms_ref is None``) are never considered dead.
        """
        for store in (self._entries, self._candidates):
            dead_keys = [
                key
                for key, record in store.items()
                if record.atoms_ref is not None and record.atoms_ref() is None
            ]
            for key in dead_keys:
                del store[key]

    def _trim(self) -> None:
        """Trim candidate and validated entries to the configured maximum.

        Candidates are evicted first (oldest first) because they are cheap
        to re-create. Validated entries are evicted only once no candidates
        remain, least recently used first.
        """
        while len(self) > self.max_size and self._candidates:
            self._candidates.popitem(last=False)
        while len(self) > self.max_size:
            self._entries.popitem(last=False)

    def get_or_build(
        self,
        *,
        atoms: ase.Atoms,
        cutoff: float,
        backend: NeighborListBackend,
        arrays: str = "torch",
        full_list: bool = True,
        sorted: bool = True,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> NeighborListData:
        """Return a cached neighbor list or build and store one.

        Args:
            atoms: Structure to build the neighbor list for.
            cutoff: Neighbor cutoff forwarded to the builder.
            backend: Neighbor-list backend forwarded to the builder.
            arrays: Array flavour forwarded to the builder.
            full_list: Forwarded to the builder.
            sorted: Forwarded to the builder.
            dtype: Only part of the cache key; not forwarded to the builder.
            device: Only part of the cache key; not forwarded to the builder.

        Returns:
            The stored neighbor list on a validated hit, otherwise a freshly
            built one.
        """
        key = _neighbor_list_option_key(
            atoms,
            cutoff=cutoff,
            backend=backend,
            arrays=arrays,
            full_list=full_list,
            sorted=sorted,
            dtype=dtype,
            device=device,
        )

        cached = self._entries.get(key)
        if cached is not None:
            if cached.matches(atoms):
                self._entries.move_to_end(key)
                return cached.neighbor_list
            # Same key but a different owner or changed geometry: stale.
            del self._entries[key]

        # A candidate is consumed whether or not it matches. It is either
        # promoted to a validated entry or replaced by a fresh candidate.
        candidate = self._candidates.pop(key, None)
        promote = candidate is not None and candidate.matches(atoms)

        neighbor_list = build_neighbor_list(
            atoms=atoms,
            cutoff=cutoff,
            backend=backend,
            arrays=arrays,
            full_list=full_list,
            sorted=sorted,
        )

        if promote:
            self._entries[key] = _GeometryCacheEntry.create(atoms, neighbor_list)
        else:
            if len(self) >= self.max_size:
                # Reclaim slots held by collected owners before refusing
                # admission. Live records are never evicted here.
                self._purge_dead()
            if len(self) < self.max_size:
                self._candidates[key] = _GeometryCacheCandidate.create(atoms)

        self._trim()
        return neighbor_list
def normalize_ase_atoms(
    atoms: Union[ase.Atoms, Sequence[ase.Atoms]],
) -> list[ase.Atoms]:
    """Return ``atoms`` as a non-empty list of ``ase.Atoms`` objects.

    A single ``ase.Atoms`` is wrapped in a one-element list; any other input
    is materialized with ``list``.

    Raises:
        ValueError: If the resulting list is empty.
        TypeError: If any element is not an ``ase.Atoms`` instance.
    """
    atoms_list = [atoms] if isinstance(atoms, ase.Atoms) else list(atoms)
    if not atoms_list:
        raise ValueError("`atoms` must contain at least one ASE structure")
    for item in atoms_list:
        if not isinstance(item, ase.Atoms):
            raise TypeError("`atoms` should be ase.Atoms or a sequence of ase.Atoms")
    return atoms_list


def atom_offsets(atoms_list: Sequence[ase.Atoms]) -> list[int]:
    """Return the starting atom index of each system in a concatenated batch.

    The i-th offset is the sum of ``len(...)`` over all preceding systems, so
    the first offset is always ``0`` and an empty input yields ``[]``.
    """
    starts: list[int] = []
    total = 0
    for system in atoms_list:
        starts.append(total)
        total += len(system)
    return starts


def prepare_ase_input(
    atoms: Union[ase.Atoms, Sequence[ase.Atoms]],
    *,
    cutoff: Optional[float],
    default_backend: NeighborListBackend,
    neighbor_list: Optional[Union[NeighborListData, Sequence[NeighborListData]]] = None,
    backend: Optional[Union[str, NeighborListBackend]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    requires_grad: bool = False,
    neighbor_list_cache: Optional[GeometryNeighborListCache] = None,
    state: Optional[UFPInputState] = None,
    extract_state: bool = True,
) -> UFPInput:
    """Turn one or more ASE structures into a single ``UFPInput``.

    Neighbor lists are resolved in this order:

    * ``neighbor_list`` is ``None`` and ``cutoff`` is given: one list is
      built per system and the results are concatenated using per-system
      atom offsets. A system goes through ``neighbor_list_cache`` only when a
      cache is supplied and the system has at least ``min_atoms`` atoms.
    * ``neighbor_list`` is a sequence of per-system lists: they are
      concatenated with the same offsets.
    * Otherwise ``neighbor_list`` is passed through unchanged, including
      ``None`` when no cutoff was given.

    ``dtype`` defaults to ``torch.get_default_dtype()``. The remaining
    arguments are forwarded to ``UFPInput.from_ase_list``.
    """
    resolved_dtype = dtype if dtype is not None else torch.get_default_dtype()
    atoms_list = normalize_ase_atoms(atoms)
    offsets = atom_offsets(atoms_list)

    if neighbor_list is None and cutoff is not None:
        if backend is None:
            resolved_backend = default_backend
        else:
            resolved_backend = NeighborListBackend(backend)

        def _neighbor_list_for(item: ase.Atoms) -> NeighborListData:
            # Small systems, and every system when no cache was supplied,
            # bypass the cache and are built directly.
            use_cache = (
                neighbor_list_cache is not None
                and len(item) >= neighbor_list_cache.min_atoms
            )
            if not use_cache:
                return build_neighbor_list(
                    atoms=item,
                    cutoff=cutoff,
                    backend=resolved_backend,
                    arrays="torch",
                )
            return neighbor_list_cache.get_or_build(
                atoms=item,
                cutoff=cutoff,
                backend=resolved_backend,
                arrays="torch",
                dtype=resolved_dtype,
                device=device,
            )

        neighbor_list = concatenate_neighbor_lists(
            [_neighbor_list_for(item) for item in atoms_list],
            atom_offsets=offsets,
        )
    elif isinstance(neighbor_list, Sequence) and not isinstance(
        neighbor_list,
        NeighborListData,
    ):
        neighbor_list = concatenate_neighbor_lists(
            list(neighbor_list),
            atom_offsets=offsets,
        )

    return UFPInput.from_ase_list(
        atoms_list,
        neighbor_list=neighbor_list,
        dtype=resolved_dtype,
        device=device,
        requires_grad=requires_grad,
        state=state,
        extract_state=extract_state,
    )


def _require_differentiable_energy(output: UFPOutput, inputs: UFPInput) -> None:
    """Raise unless ``output.energy`` can be differentiated w.r.t. positions.

    Raises:
        RuntimeError: If ``output.energy`` is ``None`` or
            ``inputs.positions`` does not require gradients.
        TypeError: If ``output.energy`` is not a ``torch.Tensor``.
    """
    energy = output.energy
    if energy is None:
        raise RuntimeError(
            "automatic force derivation requires the model to return `energy`"
        )
    if not isinstance(energy, torch.Tensor):
        raise TypeError(
            "automatic force derivation requires `energy` to be a torch.Tensor"
        )
    if not inputs.positions.requires_grad:
        raise RuntimeError(
            "automatic force derivation requires `inputs.positions.requires_grad`"
        )


def derive_forces_from_energy(
    output: UFPOutput,
    inputs: UFPInput,
    *,
    training: bool,
) -> torch.Tensor:
    """Return forces as the negative gradient of the summed energy.

    All elements of ``output.energy`` are summed into one scalar before
    differentiating with respect to ``inputs.positions``. When the energy
    does not require gradients, zeros shaped like the positions are returned.

    ``training`` controls both ``retain_graph`` and ``create_graph``.

    Raises:
        RuntimeError: If energy is missing or positions do not require
            gradients.
        TypeError: If energy is not a ``torch.Tensor``.
    """
    _require_differentiable_energy(output, inputs)

    energy = output.energy
    positions = inputs.positions
    if not energy.requires_grad:
        return torch.zeros_like(positions)

    (gradients,) = torch.autograd.grad(
        energy.reshape(-1).sum(),
        positions,
        retain_graph=training,
        create_graph=training,
        allow_unused=False,
    )
    return -gradients


def derive_batched_forces_from_energy(
    output: UFPOutput,
    inputs: UFPInput,
    *,
    training: bool,
    independence_tolerance: float,
) -> torch.Tensor:
    """Differentiate each system energy separately and enforce batch independence.

    One-system inputs are delegated to ``derive_forces_from_energy``.
    Otherwise each system's energy is differentiated against all positions.
    The gradient on atoms outside that system's slice must not exceed
    ``independence_tolerance`` in absolute value; only the slice inside the
    system contributes to the returned forces.

    Raises:
        RuntimeError: If energy is missing, positions do not require
            gradients, or one system's energy depends on another system's
            atoms beyond the tolerance.
        TypeError: If energy is not a ``torch.Tensor``.
        ValueError: If there is not exactly one energy value per system.
    """
    n_systems = inputs.n_systems
    if n_systems == 1:
        return derive_forces_from_energy(output, inputs, training=training)

    _require_differentiable_energy(output, inputs)

    per_system = output.energy.reshape(n_systems, -1)
    if per_system.shape[1] != 1:
        raise ValueError("batch force derivation requires one total energy per system")

    positions = inputs.positions
    forces = torch.zeros_like(positions)
    last_index = n_systems - 1
    for system_i, atom_slice in enumerate(inputs.atom_slices):
        system_energy = per_system[system_i, 0]
        if not system_energy.requires_grad:
            continue

        # The graph must survive until the last system has been
        # differentiated, even outside training.
        (full_gradient,) = torch.autograd.grad(
            system_energy,
            positions,
            retain_graph=training or system_i < last_index,
            create_graph=training,
            allow_unused=False,
        )

        outside = torch.cat(
            [
                full_gradient[: atom_slice.start],
                full_gradient[atom_slice.stop :],
            ],
            dim=0,
        )
        if outside.numel() != 0:
            max_leak = outside.abs().max().item()
            if max_leak > independence_tolerance:
                raise RuntimeError(
                    "system energies are not independent inside the batch: "
                    f"energy[{system_i}] has gradient magnitude {max_leak:.3e} "
                    "with respect to atoms from another system"
                )

        forces[atom_slice] = -full_gradient[atom_slice]

    return forces


def validate_output(output: UFPOutput, inputs: UFPInput) -> None:
    """Check that one output matches the shape and batching implied by the input.

    Only fields that are not ``None`` are checked. Accepted shapes:

    * ``energy``: ``(n_systems,)`` or ``(n_systems, 1)``; for inputs with at
      most one system also ``()``, ``(1,)`` and ``(1, 1)``.
    * ``forces``: ``(n_atoms, 3)``.
    * ``per_atom_energy``: ``(n_atoms,)`` or ``(n_atoms, 1)``.
    * ``stress``: ``(n_systems, 3, 3)``; for exactly one system also
      ``(3, 3)``, ``(6,)`` and ``(1, 3, 3)``.

    Raises:
        TypeError: If ``output`` is not a ``UFPOutput``.
        ValueError: If any present field has an unexpected shape.
    """
    if not isinstance(output, UFPOutput):
        raise TypeError(f"`forward` must return UFPOutput, got {type(output).__name__}")

    energy = output.energy
    if energy is not None:
        energy_shape = _shape(energy)
        allowed_energy = {(inputs.n_systems,), (inputs.n_systems, 1)}
        if not inputs.n_systems > 1:
            allowed_energy = {(), (1,), (1, 1)} | allowed_energy
        if energy_shape not in allowed_energy:
            raise ValueError(
                "`energy` must have shape "
                f"({inputs.n_systems},), ({inputs.n_systems}, 1), or be a "
                f"single scalar for one-system inputs. Got {energy_shape}."
            )

    forces = output.forces
    if forces is not None:
        forces_shape = _shape(forces)
        if forces_shape != (inputs.n_atoms, 3):
            raise ValueError(
                f"`forces` must have shape ({inputs.n_atoms}, 3), got {forces_shape}"
            )

    per_atom_energy = output.per_atom_energy
    if per_atom_energy is not None:
        per_atom_shape = _shape(per_atom_energy)
        allowed_per_atom = {(inputs.n_atoms,), (inputs.n_atoms, 1)}
        if per_atom_shape not in allowed_per_atom:
            raise ValueError(
                "`per_atom_energy` must have shape "
                f"({inputs.n_atoms},) or ({inputs.n_atoms}, 1), got "
                f"{per_atom_shape}"
            )

    stress = output.stress
    if stress is not None:
        stress_shape = _shape(stress)
        allowed_stress: set[tuple[int, ...]] = {(inputs.n_systems, 3, 3)}
        if inputs.n_systems == 1:
            allowed_stress |= {(3, 3), (6,), (1, 3, 3)}
        if stress_shape not in allowed_stress:
            raise ValueError(
                "`stress` must have shape "
                f"({inputs.n_systems}, 3, 3) for batches, or (3, 3)/(6,) for "
                f"single systems. Got {stress_shape}."
            )