"""Advanced prepared-geometry helpers for workflow-level cache reuse."""
from __future__ import annotations
import hashlib
import json
from collections.abc import Mapping, Sequence
from dataclasses import asdict, dataclass, field
from pathlib import Path
import ase
import numpy as np
import torch
from ufp.core.input import UFPInput
from ufp.core.potential import UFPotential
from ufp.leastsquares.dataset import PreparedBatch
from ufp.neighbors._data import NeighborListData, concatenate_neighbor_lists
from ufp.neighbors._neighbors import NeighborListBackend, build_neighbor_list
from ufp.terms._base import TermCacheOptions
from ufp.training.batch import ASEAtomsBatch
# NOTE(review): presumably a schema version for prepared-geometry metadata;
# it is not referenced in the visible part of this module — confirm its
# consumer before changing the value.
_PREPARED_METADATA_VERSION = 1
# Input-metadata key that `_base_input_metadata` drops when copying metadata.
_PREPARED_INPUT_METADATA_KEY = "prepared_geometry"
# Key prefix used to recognize three-body cache entries in input metadata
# (selected by `_triplet_cache_metadata`, dropped by `_base_input_metadata`).
_THREEBODY_METADATA_PREFIX = "_ufp_threebody_"
def _normalized_tensor(value: object, name: str) -> torch.Tensor:
"""Return a tensor from a container that has already been normalized."""
if not isinstance(value, torch.Tensor):
raise TypeError(f"`{name}` must be a normalized torch.Tensor")
return value
def _json_dumps(payload: object) -> str:
"""Return deterministic JSON for metadata payloads."""
return json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str)
def _sha256_json(payload: object) -> str:
    """Return the SHA-256 hex digest of ``payload``'s deterministic JSON text.

    The payload is serialized with ``_json_dumps`` and encoded as UTF-8 before
    hashing.
    """
    encoded = _json_dumps(payload).encode("utf-8")
    digest = hashlib.sha256()
    digest.update(encoded)
    return digest.hexdigest()
def _hash_tensor(
hasher: "hashlib._Hash",
tensor: torch.Tensor,
*,
dtype: torch.dtype | None = None,
) -> None:
"""Add tensor shape, dtype, device-independent values to a signature."""
value = tensor.detach().cpu()
if dtype is not None:
value = value.to(dtype=dtype)
array = np.ascontiguousarray(value.numpy())
hasher.update(str(array.dtype).encode("utf-8"))
hasher.update(np.asarray(array.shape, dtype=np.int64).tobytes())
hasher.update(array.tobytes())
def _input_geometry_hash(inputs: UFPInput) -> str:
    """Compute a strict SHA-256 geometry signature for one normalized input.

    The digest covers, in a fixed order, positions and cell (both cast to
    float64), pbc, atomic numbers, and system index. When a neighbor list is
    attached it also covers pairs, shifts, the optional distances and vectors
    (cast to float64), and the string forms of the backend, cutoff, full-list
    flag, sorted flag, and strict setting. Missing pieces contribute explicit
    marker bytes instead of being skipped.

    Raises:
        TypeError: If a hashed field is not a ``torch.Tensor``.
    """
    digest = hashlib.sha256()

    # (attribute on ``inputs``, dtype override) — the order fixes the byte stream.
    input_fields = (
        ("positions", "inputs.positions", torch.float64),
        ("cell", "inputs.cell", torch.float64),
        ("pbc", "inputs.pbc", None),
        ("atomic_numbers", "inputs.atomic_numbers", None),
        ("system_index", "inputs.system_index", None),
    )
    for attribute, label, cast in input_fields:
        tensor = _normalized_tensor(getattr(inputs, attribute), label)
        _hash_tensor(digest, tensor, dtype=cast)

    nl = inputs.neighbor_list
    if nl is None:
        digest.update(b"neighbor_list:none")
        return digest.hexdigest()

    digest.update(b"neighbor_list")
    _hash_tensor(digest, _normalized_tensor(nl.pairs, "neighbor_list.pairs"))
    _hash_tensor(digest, _normalized_tensor(nl.shifts, "neighbor_list.shifts"))

    # Optional per-pair geometry: hash a marker when a tensor is absent.
    optional_fields = (
        ("distances", "neighbor_list.distances", b"distances:none"),
        ("vectors", "neighbor_list.vectors", b"vectors:none"),
    )
    for attribute, label, missing_marker in optional_fields:
        if getattr(nl, attribute) is None:
            digest.update(missing_marker)
        else:
            _hash_tensor(
                digest,
                _normalized_tensor(getattr(nl, attribute), label),
                dtype=torch.float64,
            )

    # Scalar neighbor-list settings are folded in through their string forms.
    settings = (
        str(nl.backend),
        str(nl.cutoff),
        str(bool(nl.full_list)),
        str(bool(nl.sorted)),
        str(nl.strict),
    )
    for token in settings:
        digest.update(token.encode("utf-8"))
    return digest.hexdigest()
def _source_identity(atoms: ase.Atoms) -> str | None:
"""Return a user-provided structure id when one is available."""
for key in ("source_id", "uid", "id", "name"):
value = atoms.info.get(key)
if value is not None:
return str(value)
return None
def _source_hash(atoms: ase.Atoms) -> str:
"""Hash ASE structure content used by UFP geometry conversion."""
hasher = hashlib.sha256()
hasher.update(np.ascontiguousarray(atoms.numbers, dtype=np.int64).tobytes())
hasher.update(np.ascontiguousarray(atoms.positions, dtype=np.float64).tobytes())
hasher.update(np.ascontiguousarray(atoms.cell.array, dtype=np.float64).tobytes())
hasher.update(np.ascontiguousarray(np.asarray(atoms.pbc), dtype=bool).tobytes())
return hasher.hexdigest()
def _source_metadata(
source_atoms: Sequence[ase.Atoms] | None,
) -> tuple[tuple[str | None, ...], tuple[str, ...]]:
"""Return source ids and content hashes when ASE sources are available."""
if source_atoms is None:
return (), ()
return (
tuple(_source_identity(atoms) for atoms in source_atoms),
tuple(_source_hash(atoms) for atoms in source_atoms),
)
def _normalize_species_ordering(
inputs: UFPInput,
species_ordering: Sequence[int] | None,
) -> tuple[int, ...]:
"""Normalize species ordering metadata for category construction."""
if species_ordering is not None:
return tuple(sorted(set(int(value) for value in species_ordering)))
values = (
_normalized_tensor(inputs.atomic_numbers, "inputs.atomic_numbers")
.detach()
.cpu()
.tolist()
)
return tuple(sorted(set(int(value) for value in values)))
def _periodicity(inputs: UFPInput) -> tuple[tuple[bool, bool, bool], ...]:
    """Convert ``inputs.pbc`` into nested tuples of plain Python booleans.

    Each row of the pbc tensor yields one 3-tuple built from its first three
    entries.

    Raises:
        TypeError: If ``inputs.pbc`` is not a ``torch.Tensor``.
    """
    pbc_rows = _normalized_tensor(inputs.pbc, "inputs.pbc").detach().cpu().tolist()
    return tuple(
        (bool(flags[0]), bool(flags[1]), bool(flags[2])) for flags in pbc_rows
    )
[docs]
@dataclass(frozen=True)
class PairCategorySpec:
    """Pair-category request stored by a prepared geometry object.

    ``atomic_types`` is always held as a sorted tuple of unique ints and
    ``symmetric`` as a plain ``bool``, regardless of how they were passed in.
    """

    atomic_types: tuple[int, ...]
    symmetric: bool = True

    def __init__(self, atomic_types: Sequence[int], symmetric: bool = True) -> None:
        """Store canonicalized category metadata.

        Args:
            atomic_types: Atomic type values; duplicates are removed and the
                remainder sorted ascending.
            symmetric: Symmetry flag, coerced with ``bool``.
        """
        unique_types = {int(value) for value in atomic_types}
        # The dataclass is frozen, so regular attribute assignment would raise;
        # write through ``object.__setattr__`` instead.
        assign = object.__setattr__
        assign(self, "atomic_types", tuple(sorted(unique_types)))
        assign(self, "symmetric", bool(symmetric))
[docs]
@dataclass(frozen=True)
class PreparedGeometryReuseEstimate:
    """Memory and reuse estimate for a prepared geometry object.

    Instances are built by ``PreparedGeometry.reuse_estimate``; the comments
    on each field describe how that method fills the value.
    """

    # Requested reuse count; ``reuse_estimate`` rejects values <= 0.
    n_reuses: int
    # Atom count of the prepared geometry (``PreparedGeometry.n_atoms``).
    n_atoms: int
    # Neighbor-list row count (``PreparedGeometry.n_pairs``).
    n_pairs: int
    # Estimated bytes for all tensor payloads (``PreparedGeometry.memory_nbytes``).
    total_nbytes: int
    # Estimated bytes for pair geometry (``PreparedGeometry.pair_geometry_nbytes``).
    pair_geometry_nbytes: int
    # Estimated bytes for triplet caches (``PreparedGeometry.triplet_cache_nbytes``).
    triplet_cache_nbytes: int
    # ``max(0, n_reuses - 1) * pair_geometry_nbytes``.
    avoided_pair_geometry_nbytes: int
    # True when there is at least one pair and the avoided byte count is positive.
    reuse_helpful: bool
    # True when a memory budget was given and ``total_nbytes`` exceeds it;
    # False when no budget was given.
    memory_dominated: bool
def _tensor_nbytes(value: object, seen: set[int]) -> int:
"""Return tensor memory estimate while avoiding exact object duplicates."""
if not isinstance(value, torch.Tensor):
return 0
key = id(value)
if key in seen:
return 0
seen.add(key)
return int(value.numel() * value.element_size())
def _neighbor_list_nbytes(
neighbor_list: NeighborListData | None,
seen: set[int],
) -> int:
"""Return estimated tensor storage for a neighbor list."""
if neighbor_list is None:
return 0
return sum(
_tensor_nbytes(value, seen)
for value in (
neighbor_list.pairs,
neighbor_list.shifts,
neighbor_list.distances,
neighbor_list.vectors,
)
)
def _nested_tensor_nbytes(value: object, seen: set[int]) -> int:
    """Estimate tensor bytes recursively for cache metadata containers.

    Tensors and neighbor lists are measured directly. Mappings, tuples, lists,
    and objects exposing ``__dict__`` are traversed recursively; every other
    value contributes zero bytes.

    Self-referential structures (a container or object reachable from itself)
    are traversed only once per path, so they yield a finite estimate instead
    of exhausting the recursion limit.

    Args:
        value: Arbitrary cache payload to inspect.
        seen: ``id`` values of tensors already counted; updated in place by
            ``_tensor_nbytes`` so shared tensors are counted once.

    Returns:
        Estimated number of bytes held by not-yet-counted tensors.
    """
    # Ids of containers currently on the recursion stack. Those objects are
    # alive while listed, so an id match means a genuine cycle.
    active: set[int] = set()

    def _walk(node: object) -> int:
        """Return the estimate for one node of the payload graph."""
        if isinstance(node, torch.Tensor):
            return _tensor_nbytes(node, seen)
        if isinstance(node, NeighborListData):
            return _neighbor_list_nbytes(node, seen)
        if isinstance(node, Mapping):
            children = node.values()
        elif isinstance(node, (tuple, list)):
            children = node
        elif hasattr(node, "__dict__"):
            children = vars(node).values()
        else:
            return 0

        key = id(node)
        if key in active:
            # Already being traversed further up the stack: its tensors are
            # accounted for there, so stop here.
            return 0
        active.add(key)
        try:
            return sum(_walk(child) for child in children)
        finally:
            active.discard(key)

    return _walk(value)
[docs]
@dataclass(frozen=True)
class PreparedGeometry:
    """
    Reusable tensorized geometry for advanced workflow caching.

    This object intentionally stays outside term ``forward`` dispatch. Convert it
    back to ``UFPInput`` when a runtime model path is needed.
    """

    positions: torch.Tensor
    cell: torch.Tensor
    pbc: torch.Tensor
    atomic_numbers: torch.Tensor
    system_index: torch.Tensor
    neighbor_list: NeighborListData | None
    pair_system_index: torch.Tensor | None
    pair_vectors: torch.Tensor | None
    pair_distances: torch.Tensor | None
    pair_category_indices: Mapping[PairCategorySpec, torch.Tensor]
    metadata: PreparedGeometryMetadata
    input_metadata: Mapping[str, object] = field(default_factory=dict)
    triplet_caches: Mapping[str, object] = field(default_factory=dict)
    source_atoms: Sequence[ase.Atoms] | None = field(default=None, repr=False)

    @property
    def n_atoms(self) -> int:
        """Number of atoms, taken from the first dimension of ``positions``."""
        return int(self.positions.shape[0])

    @property
    def n_systems(self) -> int:
        """Number of systems, taken from the first dimension of ``cell``."""
        return int(self.cell.shape[0])

    @property
    def n_pairs(self) -> int:
        """Number of neighbor-list rows, or zero without a neighbor list."""
        neighbor_list = self.neighbor_list
        return 0 if neighbor_list is None else int(neighbor_list.n_pairs)

    @property
    def memory_nbytes(self) -> int:
        """Estimated bytes of all tensor payloads held by this object.

        Covers the per-atom tensors, pair tensors, neighbor list, pair-category
        index tensors, and triplet caches. A tensor object reachable through
        several fields is counted once.
        """
        seen: set[int] = set()
        direct_tensors = (
            self.positions,
            self.cell,
            self.pbc,
            self.atomic_numbers,
            self.system_index,
            self.pair_system_index,
            self.pair_vectors,
            self.pair_distances,
        )
        total = 0
        for tensor in direct_tensors:
            total += _tensor_nbytes(tensor, seen)
        total += _neighbor_list_nbytes(self.neighbor_list, seen)
        for indices in self.pair_category_indices.values():
            total += _tensor_nbytes(indices, seen)
        total += _nested_tensor_nbytes(self.triplet_caches, seen)
        return int(total)

    @property
    def pair_geometry_nbytes(self) -> int:
        """Estimated bytes of the reusable pair geometry.

        Covers ``pair_system_index``, ``pair_vectors``, ``pair_distances``, and
        the pair-category index tensors.
        """
        seen: set[int] = set()
        pair_tensors = (
            self.pair_system_index,
            self.pair_vectors,
            self.pair_distances,
        )
        total = 0
        for tensor in pair_tensors:
            total += _tensor_nbytes(tensor, seen)
        for indices in self.pair_category_indices.values():
            total += _tensor_nbytes(indices, seen)
        return int(total)

    @property
    def triplet_cache_nbytes(self) -> int:
        """Estimated bytes of tensors found inside the opaque triplet caches."""
        fresh_seen: set[int] = set()
        return _nested_tensor_nbytes(self.triplet_caches, fresh_seen)

    def reuse_estimate(
        self,
        *,
        n_reuses: int,
        memory_budget_bytes: int | None = None,
    ) -> PreparedGeometryReuseEstimate:
        """Estimate when prepared pair geometry reuse is worth its memory cost.

        Args:
            n_reuses: Planned number of uses; coerced with ``int``.
            memory_budget_bytes: Optional budget the total footprint is
                compared against.

        Returns:
            A ``PreparedGeometryReuseEstimate``. The avoided byte count is the
            pair-geometry size multiplied by the number of uses after the
            first; ``memory_dominated`` is ``False`` when no budget is given.

        Raises:
            ValueError: If ``n_reuses`` is not positive.
        """
        reuse_count = int(n_reuses)
        if reuse_count <= 0:
            raise ValueError("`n_reuses` must be positive")

        pair_count = self.n_pairs
        pair_bytes = self.pair_geometry_nbytes
        saved_bytes = max(0, reuse_count - 1) * pair_bytes
        total_bytes = self.memory_nbytes

        if memory_budget_bytes is None:
            over_budget = False
        else:
            over_budget = total_bytes > int(memory_budget_bytes)

        return PreparedGeometryReuseEstimate(
            n_reuses=reuse_count,
            n_atoms=self.n_atoms,
            n_pairs=pair_count,
            total_nbytes=total_bytes,
            pair_geometry_nbytes=pair_bytes,
            triplet_cache_nbytes=self.triplet_cache_nbytes,
            avoided_pair_geometry_nbytes=saved_bytes,
            reuse_helpful=bool(pair_count > 0 and saved_bytes > 0),
            memory_dominated=over_budget,
        )

    def assert_valid_for(
        self,
        inputs: UFPInput,
        *,
        species_ordering: Sequence[int] | None = None,
        cutoff: float | None = None,
    ) -> None:
        """Raise if ``inputs`` no longer match this prepared geometry metadata.

        Fresh metadata is derived from ``inputs`` and compared, via
        ``to_dict()``, against the stored metadata. ``species_ordering`` and
        ``cutoff`` default to the stored values when omitted.

        Raises:
            ValueError: If the two metadata dictionaries differ.
        """
        if species_ordering is None:
            species_ordering = self.metadata.species_ordering
        if cutoff is None:
            cutoff = self.metadata.cutoff

        expected = PreparedGeometryMetadata.from_input(
            inputs,
            species_ordering=species_ordering,
            cutoff=cutoff,
        )
        if expected.to_dict() != self.metadata.to_dict():
            raise ValueError("prepared geometry metadata does not match input")

    def is_valid_for(
        self,
        inputs: UFPInput,
        *,
        species_ordering: Sequence[int] | None = None,
        cutoff: float | None = None,
    ) -> bool:
        """Return whether ``inputs`` match this prepared geometry metadata.

        Boolean counterpart of ``assert_valid_for``: a ``ValueError`` from that
        check is reported as ``False``; other exceptions propagate.
        """
        try:
            self.assert_valid_for(
                inputs,
                species_ordering=species_ordering,
                cutoff=cutoff,
            )
        except ValueError:
            return False
        else:
            return True
def _normalize_pair_category_specs(
    inputs: UFPInput,
    *,
    species_ordering: Sequence[int] | None,
    pair_category_specs: Sequence[PairCategorySpec | Sequence[int]] | None,
) -> tuple[PairCategorySpec, ...]:
    """Turn pair-category requests into a tuple of ``PairCategorySpec``.

    Without explicit requests, a single spec covering the normalized species
    ordering is returned. Otherwise existing ``PairCategorySpec`` entries are
    kept as they are and every other entry is wrapped in a new spec.
    """
    if pair_category_specs is None:
        default_types = _normalize_species_ordering(inputs, species_ordering)
        return (PairCategorySpec(default_types),)
    return tuple(
        entry if isinstance(entry, PairCategorySpec) else PairCategorySpec(entry)
        for entry in pair_category_specs
    )
def _triplet_cache_metadata(metadata: Mapping[str, object]) -> dict[str, object]:
    """Select the three-body cache entries from input metadata.

    An entry is selected when the string form of its key starts with
    ``_THREEBODY_METADATA_PREFIX``. Keys in the result are strings; values are
    passed through untouched.
    """
    selected: dict[str, object] = {}
    for key, value in metadata.items():
        name = str(key)
        if name.startswith(_THREEBODY_METADATA_PREFIX):
            selected[name] = value
    return selected
def _base_input_metadata(metadata: Mapping[str, object]) -> dict[str, object]:
    """Copy input metadata without cache or prepared-geometry entries.

    Entries whose stringified key starts with ``_THREEBODY_METADATA_PREFIX``
    or equals ``_PREPARED_INPUT_METADATA_KEY`` are left out. Keys in the
    result are strings.
    """
    kept: dict[str, object] = {}
    for key, value in metadata.items():
        name = str(key)
        is_cache_entry = name.startswith(_THREEBODY_METADATA_PREFIX)
        if is_cache_entry or name == _PREPARED_INPUT_METADATA_KEY:
            continue
        kept[name] = value
    return kept
def _cache_triplet_terms(
inputs: UFPInput,
terms: Sequence[object],
*,
feature_cache_storage: str,
feature_cache_dir: Path | str | None,
cache_prefix: str,
include_per_atom_energy: bool,
) -> None:
"""Ask compatible terms to warm triplet caches on a working input."""
for term_index, term in enumerate(terms):
cache_input = getattr(term, "cache_input", None)
if not callable(cache_input):
continue
options = TermCacheOptions(
feature_cache_storage=feature_cache_storage,
feature_cache_dir=None
if feature_cache_dir is None
else Path(feature_cache_dir),
cache_prefix=f"{cache_prefix}_term{term_index}",
include_per_atom_energy=include_per_atom_energy,
)
cache_input(inputs, options=options)
def _normalize_atoms_sequence(
    atoms: ase.Atoms | Sequence[ase.Atoms],
) -> tuple[ase.Atoms, ...]:
    """Return the given structure(s) as a non-empty tuple of ``ase.Atoms``.

    A single ``ase.Atoms`` is wrapped in a one-element tuple; any other input
    is materialized with ``tuple`` and validated.

    Raises:
        ValueError: If the sequence is empty.
        TypeError: If any entry is not an ``ase.Atoms`` instance.
    """
    # Checked first so a lone structure is wrapped rather than passed to tuple().
    if isinstance(atoms, ase.Atoms):
        return (atoms,)
    structures = tuple(atoms)
    if len(structures) == 0:
        raise ValueError("`atoms` must contain at least one structure")
    for entry in structures:
        if not isinstance(entry, ase.Atoms):
            raise TypeError("all `atoms` entries must be ase.Atoms")
    return structures
[docs]
def prepare_geometry_from_ase(
    atoms: ase.Atoms | Sequence[ase.Atoms],
    *,
    cutoff: float,
    backend: str | NeighborListBackend = NeighborListBackend.AUTO,
    full_list: bool = True,
    sorted: bool = True,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
    species_ordering: Sequence[int] | None = None,
    pair_category_specs: Sequence[PairCategorySpec | Sequence[int]] | None = None,
    triplet_terms: Sequence[object] = (),
    feature_cache_storage: str = "none",
    feature_cache_dir: Path | str | None = None,
    cache_prefix: str = "prepared",
    include_per_atom_energy: bool = True,
) -> PreparedGeometry:
    """Build prepared geometry directly from one or more ASE structures.

    A neighbor list is built per structure, the lists are concatenated using
    each structure's starting atom offset, and the resulting ``UFPInput`` is
    passed to ``prepare_geometry_from_input``.

    Args:
        atoms: One structure or a sequence of structures.
        cutoff: Neighbor cutoff; coerced with ``float``.
        backend: Forwarded to ``build_neighbor_list``.
        full_list: Forwarded to ``build_neighbor_list``.
        sorted: Forwarded to ``build_neighbor_list``.
        dtype: Forwarded to ``UFPInput.from_ase_list``.
        device: Wrapped in ``torch.device`` unless ``None``, then forwarded to
            ``UFPInput.from_ase_list``.
        species_ordering: Forwarded to ``prepare_geometry_from_input``.
        pair_category_specs: Forwarded to ``prepare_geometry_from_input``.
        triplet_terms: Forwarded to ``prepare_geometry_from_input``.
        feature_cache_storage: Forwarded to ``prepare_geometry_from_input``.
        feature_cache_dir: Forwarded to ``prepare_geometry_from_input``.
        cache_prefix: Forwarded to ``prepare_geometry_from_input``.
        include_per_atom_energy: Forwarded to ``prepare_geometry_from_input``.

    Raises:
        ValueError: If ``atoms`` is an empty sequence.
        TypeError: If an entry of ``atoms`` is not an ``ase.Atoms``.
    """
    structures = _normalize_atoms_sequence(atoms)
    cutoff = float(cutoff)

    per_structure_lists = []
    atom_offsets = []
    running_atoms = 0
    for structure in structures:
        per_structure_lists.append(
            build_neighbor_list(
                structure,
                cutoff=cutoff,
                backend=backend,
                arrays="torch",
                full_list=full_list,
                sorted=sorted,
            )
        )
        # Offset of this structure's first atom in the concatenated batch.
        atom_offsets.append(running_atoms)
        running_atoms += len(structure)

    merged_neighbor_list = concatenate_neighbor_lists(
        per_structure_lists,
        atom_offsets=atom_offsets,
    )
    target_device = torch.device(device) if device is not None else None
    inputs = UFPInput.from_ase_list(
        structures,
        neighbor_list=merged_neighbor_list,
        dtype=dtype,
        device=target_device,
    )

    forwarded_options = {
        "species_ordering": species_ordering,
        "pair_category_specs": pair_category_specs,
        "triplet_terms": triplet_terms,
        "feature_cache_storage": feature_cache_storage,
        "feature_cache_dir": feature_cache_dir,
        "cache_prefix": cache_prefix,
        "include_per_atom_energy": include_per_atom_energy,
    }
    return prepare_geometry_from_input(inputs, **forwarded_options)
[docs]
def prepare_geometry_from_batch(
    batch: ASEAtomsBatch,
    model: UFPotential,
    *,
    backend: str | NeighborListBackend | None = None,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
    species_ordering: Sequence[int] | None = None,
    pair_category_specs: Sequence[PairCategorySpec | Sequence[int]] | None = None,
    triplet_terms: Sequence[object] = (),
    feature_cache_storage: str = "none",
    feature_cache_dir: Path | str | None = None,
    cache_prefix: str = "prepared",
    include_per_atom_energy: bool = True,
) -> PreparedGeometry:
    """Prepare geometry from a training ``ASEAtomsBatch``.

    The batch builds its input through ``batch.prepare_input`` with
    ``requires_grad=False``; that input is then passed to
    ``prepare_geometry_from_input`` together with the remaining options.

    Args:
        batch: Training batch that produces the input.
        model: Passed to ``batch.prepare_input``.
        backend: Passed to ``batch.prepare_input``.
        dtype: Passed to ``batch.prepare_input``.
        device: Passed to ``batch.prepare_input``.
        species_ordering: Forwarded to ``prepare_geometry_from_input``.
        pair_category_specs: Forwarded to ``prepare_geometry_from_input``.
        triplet_terms: Forwarded to ``prepare_geometry_from_input``.
        feature_cache_storage: Forwarded to ``prepare_geometry_from_input``.
        feature_cache_dir: Forwarded to ``prepare_geometry_from_input``.
        cache_prefix: Forwarded to ``prepare_geometry_from_input``.
        include_per_atom_energy: Forwarded to ``prepare_geometry_from_input``.
    """
    input_options = {
        "backend": backend,
        "dtype": dtype,
        "device": device,
        "requires_grad": False,
    }
    inputs = batch.prepare_input(model, **input_options)

    forwarded_options = {
        "species_ordering": species_ordering,
        "pair_category_specs": pair_category_specs,
        "triplet_terms": triplet_terms,
        "feature_cache_storage": feature_cache_storage,
        "feature_cache_dir": feature_cache_dir,
        "cache_prefix": cache_prefix,
        "include_per_atom_energy": include_per_atom_energy,
    }
    return prepare_geometry_from_input(inputs, **forwarded_options)
[docs]
def prepare_geometry_from_prepared_batch(
    batch: PreparedBatch,
    *,
    species_ordering: Sequence[int] | None = None,
    pair_category_specs: Sequence[PairCategorySpec | Sequence[int]] | None = None,
    triplet_terms: Sequence[object] = (),
    feature_cache_storage: str = "none",
    feature_cache_dir: Path | str | None = None,
    cache_prefix: str = "prepared",
    include_per_atom_energy: bool = True,
) -> PreparedGeometry:
    """Prepare geometry from a least-squares ``PreparedBatch``.

    Passes ``batch.inputs`` and every keyword option unchanged to
    ``prepare_geometry_from_input``.
    """
    forwarded_options = {
        "species_ordering": species_ordering,
        "pair_category_specs": pair_category_specs,
        "triplet_terms": triplet_terms,
        "feature_cache_storage": feature_cache_storage,
        "feature_cache_dir": feature_cache_dir,
        "cache_prefix": cache_prefix,
        "include_per_atom_energy": include_per_atom_energy,
    }
    return prepare_geometry_from_input(batch.inputs, **forwarded_options)
# Names exported by ``from <module> import *``.
# NOTE(review): "PreparedGeometryMetadata" and "prepare_geometry_from_input"
# are not defined in the part of the module visible here — confirm they are
# defined elsewhere in this file.
__all__ = [
    "PairCategorySpec",
    "PreparedGeometry",
    "PreparedGeometryMetadata",
    "PreparedGeometryReuseEstimate",
    "prepare_geometry_from_ase",
    "prepare_geometry_from_batch",
    "prepare_geometry_from_input",
    "prepare_geometry_from_prepared_batch",
]