# Source code for ufp.workflows.prepared

"""Advanced prepared-geometry helpers for workflow-level cache reuse."""

from __future__ import annotations

import hashlib
import json
from collections.abc import Mapping, Sequence
from dataclasses import asdict, dataclass, field
from pathlib import Path

import ase
import numpy as np
import torch

from ufp.core.input import UFPInput
from ufp.core.potential import UFPotential
from ufp.leastsquares.dataset import PreparedBatch
from ufp.neighbors._data import NeighborListData, concatenate_neighbor_lists
from ufp.neighbors._neighbors import NeighborListBackend, build_neighbor_list
from ufp.terms._base import TermCacheOptions
from ufp.training.batch import ASEAtomsBatch


_PREPARED_METADATA_VERSION = 1
_PREPARED_INPUT_METADATA_KEY = "prepared_geometry"
_THREEBODY_METADATA_PREFIX = "_ufp_threebody_"


def _normalized_tensor(value: object, name: str) -> torch.Tensor:
    """Return ``value`` unchanged if it is a tensor, else raise.

    Used on fields of containers that are expected to have been normalized
    already; ``name`` is only used to label the error.

    Raises:
        TypeError: If ``value`` is not a ``torch.Tensor``.
    """
    if isinstance(value, torch.Tensor):
        return value
    raise TypeError(f"`{name}` must be a normalized torch.Tensor")


def _json_dumps(payload: object) -> str:
    """Serialize ``payload`` to canonical, compact JSON.

    Keys are sorted and no whitespace is emitted, so equal payloads always
    produce identical text. Values JSON cannot encode are passed through
    ``str``.
    """
    compact = (",", ":")
    return json.dumps(payload, default=str, separators=compact, sort_keys=True)


def _sha256_json(payload: object) -> str:
    """Return the hex SHA-256 digest of ``payload``'s canonical JSON (UTF-8)."""
    encoded = _json_dumps(payload).encode("utf-8")
    digest = hashlib.sha256(encoded)
    return digest.hexdigest()


def _hash_tensor(
    hasher: "hashlib._Hash",
    tensor: torch.Tensor,
    *,
    dtype: torch.dtype | None = None,
) -> None:
    """Feed one tensor's dtype name, shape, and raw values into ``hasher``.

    The tensor is detached and copied to CPU first, so the bytes do not
    depend on the source device or autograd state. When ``dtype`` is given,
    values are cast to it before hashing.
    """
    host = tensor.detach().cpu()
    if dtype is not None:
        host = host.to(dtype=dtype)
    contiguous = np.ascontiguousarray(host.numpy())
    for chunk in (
        str(contiguous.dtype).encode("utf-8"),
        np.asarray(contiguous.shape, dtype=np.int64).tobytes(),
        contiguous.tobytes(),
    ):
        hasher.update(chunk)


def _input_geometry_hash(inputs: UFPInput) -> str:
    """Return a strict SHA-256 geometry signature for one normalized input.

    The digest covers positions and cell (cast to float64), periodic flags,
    atomic numbers, and system indices. When a neighbor list is attached it
    also covers pair indices, shifts, optional distances/vectors (float64),
    and the backend, cutoff, full-list, sorted, and strict settings. Absent
    optional pieces contribute a fixed marker instead of tensor bytes.

    Raises:
        TypeError: If any hashed tensor field is not a ``torch.Tensor``.
    """
    digest = hashlib.sha256()

    # (attribute, error label, cast dtype) for the atom/system tensors.
    input_fields = (
        ("positions", "inputs.positions", torch.float64),
        ("cell", "inputs.cell", torch.float64),
        ("pbc", "inputs.pbc", None),
        ("atomic_numbers", "inputs.atomic_numbers", None),
        ("system_index", "inputs.system_index", None),
    )
    for attribute, label, cast in input_fields:
        _hash_tensor(
            digest,
            _normalized_tensor(getattr(inputs, attribute), label),
            dtype=cast,
        )

    neighbors = inputs.neighbor_list
    if neighbors is None:
        digest.update(b"neighbor_list:none")
        return digest.hexdigest()

    digest.update(b"neighbor_list")
    for attribute, label in (
        ("pairs", "neighbor_list.pairs"),
        ("shifts", "neighbor_list.shifts"),
    ):
        _hash_tensor(digest, _normalized_tensor(getattr(neighbors, attribute), label))

    for attribute, label, missing_marker in (
        ("distances", "neighbor_list.distances", b"distances:none"),
        ("vectors", "neighbor_list.vectors", b"vectors:none"),
    ):
        optional = getattr(neighbors, attribute)
        if optional is None:
            digest.update(missing_marker)
        else:
            _hash_tensor(
                digest,
                _normalized_tensor(optional, label),
                dtype=torch.float64,
            )

    # Neighbor-list settings are hashed as their ``str`` forms.
    for setting in (
        str(neighbors.backend),
        str(neighbors.cutoff),
        str(bool(neighbors.full_list)),
        str(bool(neighbors.sorted)),
        str(neighbors.strict),
    ):
        digest.update(setting.encode("utf-8"))
    return digest.hexdigest()


def _source_identity(atoms: ase.Atoms) -> str | None:
    """Return the first non-None user-supplied id in ``atoms.info``, as a string.

    Keys are tried in order ``source_id``, ``uid``, ``id``, ``name``; ``None``
    is returned when none of them holds a value.
    """
    info = atoms.info
    candidates = (info.get(key) for key in ("source_id", "uid", "id", "name"))
    found = next((candidate for candidate in candidates if candidate is not None), None)
    return None if found is None else str(found)


def _source_hash(atoms: ase.Atoms) -> str:
    """Return a SHA-256 digest of the ASE fields used for UFP geometry.

    Atomic numbers (int64), positions and cell (float64), and periodic flags
    (bool) are hashed in that order as raw contiguous bytes.
    """
    components = (
        np.ascontiguousarray(atoms.numbers, dtype=np.int64),
        np.ascontiguousarray(atoms.positions, dtype=np.float64),
        np.ascontiguousarray(atoms.cell.array, dtype=np.float64),
        np.ascontiguousarray(np.asarray(atoms.pbc), dtype=bool),
    )
    digest = hashlib.sha256()
    for component in components:
        digest.update(component.tobytes())
    return digest.hexdigest()


def _source_metadata(
    source_atoms: Sequence[ase.Atoms] | None,
) -> tuple[tuple[str | None, ...], tuple[str, ...]]:
    """Return per-structure ids and content hashes for ASE sources.

    Both tuples are empty when ``source_atoms`` is ``None``.
    """
    if source_atoms is None:
        return (), ()
    identities = tuple(map(_source_identity, source_atoms))
    digests = tuple(map(_source_hash, source_atoms))
    return identities, digests


def _normalize_species_ordering(
    inputs: UFPInput,
    species_ordering: Sequence[int] | None,
) -> tuple[int, ...]:
    """Return a sorted, de-duplicated tuple of species ids.

    An explicit ``species_ordering`` takes precedence; otherwise the unique
    atomic numbers found in ``inputs`` are used.
    """
    if species_ordering is None:
        atomic_numbers = _normalized_tensor(
            inputs.atomic_numbers, "inputs.atomic_numbers"
        )
        raw_species = atomic_numbers.detach().cpu().tolist()
    else:
        raw_species = species_ordering
    return tuple(sorted({int(species) for species in raw_species}))


def _periodicity(inputs: UFPInput) -> tuple[tuple[bool, bool, bool], ...]:
    """Return the periodic flags of each ``inputs.pbc`` row as plain bool triples."""
    pbc_rows = _normalized_tensor(inputs.pbc, "inputs.pbc").detach().cpu().tolist()
    return tuple((bool(row[0]), bool(row[1]), bool(row[2])) for row in pbc_rows)


@dataclass(frozen=True)
class PairCategorySpec:
    """Pair-category request stored by a prepared geometry object.

    ``atomic_types`` is always stored as a sorted tuple of unique ints, so
    requests that differ only in order or duplicates compare and hash equal.
    """

    atomic_types: tuple[int, ...]
    symmetric: bool = True

    def __init__(self, atomic_types: Sequence[int], symmetric: bool = True) -> None:
        """Normalize category metadata to match UFP pair-category ordering."""
        unique_types = {int(atomic_type) for atomic_type in atomic_types}
        # The dataclass is frozen, so bypass its __setattr__ guard.
        object.__setattr__(self, "atomic_types", tuple(sorted(unique_types)))
        object.__setattr__(self, "symmetric", bool(symmetric))
@dataclass(frozen=True)
class PreparedGeometryMetadata:
    """Strict validity metadata for one prepared geometry object.

    ``metadata_hash`` is the SHA-256 digest of the canonical JSON of every
    other field. Two metadata objects describe the same prepared geometry
    only when all fields, including the hash, agree.
    """

    schema_version: int
    species_ordering: tuple[int, ...]
    cutoff: float | None
    neighbor_backend: str | None
    neighbor_full_list: bool | None
    neighbor_sorted: bool | None
    neighbor_strict: bool | None
    device: str
    dtype: str
    periodicity: tuple[tuple[bool, bool, bool], ...]
    n_atoms: int
    n_systems: int
    system_sizes: tuple[int, ...]
    source_structure_ids: tuple[str | None, ...]
    source_structure_hashes: tuple[str, ...]
    geometry_hash: str
    metadata_hash: str

    @classmethod
    def from_input(
        cls,
        inputs: UFPInput,
        *,
        species_ordering: Sequence[int] | None = None,
        cutoff: float | None = None,
    ) -> "PreparedGeometryMetadata":
        """Build strict prepared-geometry metadata from a normalized input.

        Args:
            inputs: Normalized input whose tensors are hashed and summarized.
            species_ordering: Species to record. Defaults to the unique atomic
                numbers in ``inputs``. Always stored sorted and de-duplicated.
            cutoff: Cutoff to record. Defaults to the neighbor-list cutoff, or
                ``None`` when there is no neighbor list or it has no cutoff.

        Returns:
            Metadata whose ``metadata_hash`` covers every other field.

        Raises:
            TypeError: If a hashed tensor field of ``inputs`` is not a
                ``torch.Tensor``.
        """
        neighbor_list = inputs.neighbor_list
        if cutoff is not None:
            resolved_cutoff = float(cutoff)
        elif neighbor_list is None or neighbor_list.cutoff is None:
            resolved_cutoff = None
        else:
            resolved_cutoff = float(neighbor_list.cutoff)

        source_ids, source_hashes = _source_metadata(inputs.source_atoms)
        resolved_species_ordering = _normalize_species_ordering(
            inputs,
            species_ordering,
        )

        if neighbor_list is None:
            neighbor_backend = None
            neighbor_full_list = None
            neighbor_sorted = None
            neighbor_strict = None
        else:
            neighbor_backend = str(neighbor_list.backend)
            neighbor_full_list = bool(neighbor_list.full_list)
            neighbor_sorted = bool(neighbor_list.sorted)
            neighbor_strict = (
                None if neighbor_list.strict is None else bool(neighbor_list.strict)
            )

        periodicity = _periodicity(inputs)
        system_sizes = tuple(int(value) for value in inputs.system_sizes)
        geometry_hash = _input_geometry_hash(inputs)

        # Single source of truth: the hashed payload *is* the constructor
        # argument set. A new field therefore cannot be added to the object
        # while being silently left out of ``metadata_hash``.
        payload: dict[str, object] = {
            "schema_version": _PREPARED_METADATA_VERSION,
            "species_ordering": resolved_species_ordering,
            "cutoff": resolved_cutoff,
            "neighbor_backend": neighbor_backend,
            "neighbor_full_list": neighbor_full_list,
            "neighbor_sorted": neighbor_sorted,
            "neighbor_strict": neighbor_strict,
            "device": str(inputs.device),
            "dtype": str(inputs.dtype),
            "periodicity": periodicity,
            "n_atoms": int(inputs.n_atoms),
            "n_systems": int(inputs.n_systems),
            "system_sizes": system_sizes,
            "source_structure_ids": source_ids,
            "source_structure_hashes": source_hashes,
            "geometry_hash": geometry_hash,
        }
        return cls(**payload, metadata_hash=_sha256_json(payload))

    def to_dict(self) -> dict[str, object]:
        """Return the metadata fields as a plain ``dict`` (via ``dataclasses.asdict``)."""
        return asdict(self)
@dataclass(frozen=True)
class PreparedGeometryReuseEstimate:
    """Memory cost and expected benefit of reusing a prepared geometry.

    All byte counts are estimates of tensor storage, not measured usage.
    """

    # Number of planned uses the estimate was computed for.
    n_reuses: int
    # Atom count of the prepared geometry.
    n_atoms: int
    # Neighbor-list row count (0 when no neighbor list is attached).
    n_pairs: int
    # Estimated bytes of all owned tensors, including triplet caches.
    total_nbytes: int
    # Estimated bytes of the reusable pair geometry.
    pair_geometry_nbytes: int
    # Estimated bytes held inside triplet caches.
    triplet_cache_nbytes: int
    # Pair-geometry bytes that do not need to be rebuilt across reuses.
    avoided_pair_geometry_nbytes: int
    # Whether reuse is expected to save any pair-geometry work.
    reuse_helpful: bool
    # Whether the total estimate exceeds the caller-supplied memory budget.
    memory_dominated: bool
def _tensor_nbytes(value: object, seen: set[int]) -> int:
    """Return ``numel * element_size`` for a tensor not already counted.

    Non-tensors contribute 0. A tensor whose ``id`` is already in ``seen``
    also contributes 0, so one tensor object referenced from several places
    is counted once. ``seen`` is updated in place.
    """
    if not isinstance(value, torch.Tensor):
        return 0
    key = id(value)
    if key in seen:
        return 0
    seen.add(key)
    return int(value.numel() * value.element_size())


def _neighbor_list_nbytes(
    neighbor_list: NeighborListData | None,
    seen: set[int],
) -> int:
    """Return estimated tensor bytes of a neighbor list's pairs/shifts/distances/vectors."""
    if neighbor_list is None:
        return 0
    return sum(
        _tensor_nbytes(value, seen)
        for value in (
            neighbor_list.pairs,
            neighbor_list.shifts,
            neighbor_list.distances,
            neighbor_list.vectors,
        )
    )


def _nested_tensor_nbytes(
    value: object,
    seen: set[int],
    _active: set[int] | None = None,
) -> int:
    """Estimate tensor bytes recursively for cache metadata containers.

    Walks mapping values, tuples, lists, and the ``__dict__`` of arbitrary
    objects, summing tensor storage (tensors deduplicated through ``seen``).

    ``_active`` holds the ids of containers on the current recursion path.
    A container that (directly or indirectly) contains itself is therefore
    walked only once instead of recursing until ``RecursionError``. Only
    objects on the active path are tracked, and those are kept alive by the
    call stack, so id reuse cannot cause values to be skipped.
    """
    if isinstance(value, torch.Tensor):
        return _tensor_nbytes(value, seen)
    if isinstance(value, NeighborListData):
        return _neighbor_list_nbytes(value, seen)
    if isinstance(value, Mapping):
        children = value.values()
    elif isinstance(value, (tuple, list)):
        children = value
    elif hasattr(value, "__dict__"):
        # Walk the attribute namespace the same way as any other mapping.
        children = (vars(value),)
    else:
        return 0

    active = set() if _active is None else _active
    key = id(value)
    if key in active:
        # Cycle back to a container that is already being walked.
        return 0
    active.add(key)
    try:
        return sum(_nested_tensor_nbytes(child, seen, active) for child in children)
    finally:
        active.discard(key)
@dataclass(frozen=True)
class PreparedGeometry:
    """
    Reusable tensorized geometry for advanced workflow caching.

    This object intentionally stays outside term ``forward`` dispatch.
    Convert it back to ``UFPInput`` when a runtime model path is needed.
    """

    positions: torch.Tensor
    cell: torch.Tensor
    pbc: torch.Tensor
    atomic_numbers: torch.Tensor
    system_index: torch.Tensor
    neighbor_list: NeighborListData | None
    # Per-pair geometry precomputed from ``neighbor_list``; None without one.
    pair_system_index: torch.Tensor | None
    pair_vectors: torch.Tensor | None
    pair_distances: torch.Tensor | None
    # Pair-row indices for each requested pair category.
    pair_category_indices: Mapping[PairCategorySpec, torch.Tensor]
    metadata: PreparedGeometryMetadata
    # Non-cache metadata carried over from the source input.
    input_metadata: Mapping[str, object] = field(default_factory=dict)
    # Opaque three-body cache entries, keyed by their input-metadata names.
    triplet_caches: Mapping[str, object] = field(default_factory=dict)
    source_atoms: Sequence[ase.Atoms] | None = field(default=None, repr=False)

    @property
    def n_atoms(self) -> int:
        """Return the number of atoms in the prepared geometry."""
        return int(self.positions.shape[0])

    @property
    def n_systems(self) -> int:
        """Return the number of systems in the prepared geometry."""
        return int(self.cell.shape[0])

    @property
    def n_pairs(self) -> int:
        """Return the number of neighbor-list rows (0 without a neighbor list)."""
        if self.neighbor_list is None:
            return 0
        return int(self.neighbor_list.n_pairs)

    @property
    def memory_nbytes(self) -> int:
        """Return an estimated memory footprint for owned tensor payloads.

        Tensor objects shared between fields are counted once. Each call
        walks every tensor, including those nested in ``triplet_caches``.
        """
        seen: set[int] = set()
        total = sum(
            _tensor_nbytes(value, seen)
            for value in (
                self.positions,
                self.cell,
                self.pbc,
                self.atomic_numbers,
                self.system_index,
                self.pair_system_index,
                self.pair_vectors,
                self.pair_distances,
            )
        )
        total += _neighbor_list_nbytes(self.neighbor_list, seen)
        total += sum(
            _tensor_nbytes(value, seen) for value in self.pair_category_indices.values()
        )
        total += _nested_tensor_nbytes(self.triplet_caches, seen)
        return int(total)

    @property
    def pair_geometry_nbytes(self) -> int:
        """Return the estimated memory footprint of reusable pair geometry."""
        seen: set[int] = set()
        total = sum(
            _tensor_nbytes(value, seen)
            for value in (
                self.pair_system_index,
                self.pair_vectors,
                self.pair_distances,
            )
        )
        total += sum(
            _tensor_nbytes(value, seen) for value in self.pair_category_indices.values()
        )
        return int(total)

    @property
    def triplet_cache_nbytes(self) -> int:
        """Return the estimated memory footprint of opaque triplet caches."""
        return _nested_tensor_nbytes(self.triplet_caches, set())

    def reuse_estimate(
        self,
        *,
        n_reuses: int,
        memory_budget_bytes: int | None = None,
    ) -> PreparedGeometryReuseEstimate:
        """Estimate when prepared pair geometry reuse is worth its memory cost.

        Args:
            n_reuses: Number of planned uses; must be positive.
            memory_budget_bytes: Optional budget; ``memory_dominated`` is true
                when the total estimate exceeds it.

        Returns:
            A ``PreparedGeometryReuseEstimate``. ``avoided_pair_geometry_nbytes``
            is ``(n_reuses - 1) * pair_geometry_nbytes``.

        Raises:
            ValueError: If ``n_reuses`` is not positive.
        """
        n_reuses = int(n_reuses)
        if n_reuses <= 0:
            raise ValueError("`n_reuses` must be positive")
        pair_bytes = self.pair_geometry_nbytes
        # ``memory_nbytes`` walks every tensor and cache; compute it once and
        # reuse it for both the reported total and the budget comparison.
        total_bytes = self.memory_nbytes
        avoided = max(0, n_reuses - 1) * pair_bytes
        return PreparedGeometryReuseEstimate(
            n_reuses=n_reuses,
            n_atoms=self.n_atoms,
            n_pairs=self.n_pairs,
            total_nbytes=total_bytes,
            pair_geometry_nbytes=pair_bytes,
            triplet_cache_nbytes=self.triplet_cache_nbytes,
            avoided_pair_geometry_nbytes=avoided,
            reuse_helpful=bool(self.n_pairs > 0 and avoided > 0),
            memory_dominated=(
                False
                if memory_budget_bytes is None
                else total_bytes > int(memory_budget_bytes)
            ),
        )

    def to_input(
        self,
        *,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
        requires_grad: bool = False,
        include_triplet_caches: bool = True,
    ) -> UFPInput:
        """Convert prepared geometry back to the runtime ``UFPInput`` container.

        Args:
            device: Target device; defaults to the device of ``positions``.
            dtype: Floating dtype for positions, cell and neighbor-list
                geometry; defaults to the dtype of ``positions``.
            requires_grad: When true, positions become a fresh leaf tensor
                with ``requires_grad=True`` and triplet caches are not attached.
            include_triplet_caches: Attach stored triplet caches to the result
                metadata. This only happens when ``requires_grad`` is false and
                the target device/dtype string-match those in ``metadata``.

        Returns:
            A ``UFPInput`` whose metadata also carries ``self.metadata.to_dict()``
            under the prepared-geometry metadata key.
        """
        resolved_device = (
            self.positions.device if device is None else torch.device(device)
        )
        resolved_dtype = self.positions.dtype if dtype is None else dtype
        positions = self.positions.to(device=resolved_device, dtype=resolved_dtype)
        if requires_grad:
            positions = positions.detach().clone().requires_grad_(True)
        else:
            positions = positions.detach()

        neighbor_list = None
        if self.neighbor_list is not None:
            # Prefer the precomputed pair geometry over what the stored
            # neighbor list carries.
            # NOTE(review): with requires_grad=True these precomputed
            # distances/vectors are detached; confirm UFPInput recomputes them
            # from ``positions`` so autograd forces see the pair terms.
            neighbor_list = NeighborListData(
                pairs=self.neighbor_list.pairs,
                shifts=self.neighbor_list.shifts,
                distances=(
                    self.neighbor_list.distances
                    if self.pair_distances is None
                    else self.pair_distances
                ),
                vectors=(
                    self.neighbor_list.vectors
                    if self.pair_vectors is None
                    else self.pair_vectors
                ),
                backend=self.neighbor_list.backend,
                cutoff=self.neighbor_list.cutoff,
                full_list=self.neighbor_list.full_list,
                sorted=self.neighbor_list.sorted,
                strict=self.neighbor_list.strict,
            ).as_torch(dtype=resolved_dtype, device=resolved_device)

        metadata = dict(self.input_metadata)
        metadata[_PREPARED_INPUT_METADATA_KEY] = self.metadata.to_dict()
        # Triplet caches were built for the recorded device/dtype without
        # gradients; only re-attach them when those conditions still hold.
        same_geometry_device = str(resolved_device) == self.metadata.device
        same_geometry_dtype = str(resolved_dtype) == self.metadata.dtype
        if (
            include_triplet_caches
            and not requires_grad
            and same_geometry_device
            and same_geometry_dtype
        ):
            metadata.update(self.triplet_caches)

        return UFPInput(
            positions=positions,
            cell=self.cell.to(device=resolved_device, dtype=resolved_dtype),
            pbc=self.pbc.to(device=resolved_device),
            atomic_numbers=self.atomic_numbers.to(device=resolved_device),
            system_index=self.system_index.to(device=resolved_device),
            neighbor_list=neighbor_list,
            metadata=metadata,
            source_atoms=self.source_atoms,
        )

    def assert_valid_for(
        self,
        inputs: UFPInput,
        *,
        species_ordering: Sequence[int] | None = None,
        cutoff: float | None = None,
    ) -> None:
        """Raise if ``inputs`` no longer match this prepared geometry metadata.

        Metadata is rebuilt from ``inputs``, using this object's species
        ordering and cutoff unless overridden, and compared field by field.

        Raises:
            ValueError: If any metadata field differs.
        """
        expected = PreparedGeometryMetadata.from_input(
            inputs,
            species_ordering=(
                self.metadata.species_ordering
                if species_ordering is None
                else species_ordering
            ),
            cutoff=self.metadata.cutoff if cutoff is None else cutoff,
        )
        if expected.to_dict() != self.metadata.to_dict():
            raise ValueError("prepared geometry metadata does not match input")

    def is_valid_for(
        self,
        inputs: UFPInput,
        *,
        species_ordering: Sequence[int] | None = None,
        cutoff: float | None = None,
    ) -> bool:
        """Return whether ``inputs`` match this prepared geometry metadata.

        Boolean form of :meth:`assert_valid_for`. Only the mismatch
        ``ValueError`` is converted to ``False``; other errors propagate.
        """
        try:
            self.assert_valid_for(
                inputs,
                species_ordering=species_ordering,
                cutoff=cutoff,
            )
        except ValueError:
            return False
        return True
def _normalize_pair_category_specs(
    inputs: UFPInput,
    *,
    species_ordering: Sequence[int] | None,
    pair_category_specs: Sequence[PairCategorySpec | Sequence[int]] | None,
) -> tuple[PairCategorySpec, ...]:
    """Turn pair-category requests into a tuple of ``PairCategorySpec``.

    With no explicit requests, a single spec covering every species in the
    normalized species ordering is returned. Plain sequences of ints are
    wrapped in ``PairCategorySpec``; existing specs are kept as-is.
    """
    if pair_category_specs is None:
        every_species = _normalize_species_ordering(inputs, species_ordering)
        return (PairCategorySpec(every_species),)
    return tuple(
        request if isinstance(request, PairCategorySpec) else PairCategorySpec(request)
        for request in pair_category_specs
    )


def _triplet_cache_metadata(metadata: Mapping[str, object]) -> dict[str, object]:
    """Return the input-metadata entries whose key has the three-body cache prefix."""
    selected: dict[str, object] = {}
    for raw_key, value in metadata.items():
        key = str(raw_key)
        if key.startswith(_THREEBODY_METADATA_PREFIX):
            selected[key] = value
    return selected


def _base_input_metadata(metadata: Mapping[str, object]) -> dict[str, object]:
    """Return input metadata minus three-body cache entries and prepared-geometry info."""
    kept: dict[str, object] = {}
    for raw_key, value in metadata.items():
        key = str(raw_key)
        if key.startswith(_THREEBODY_METADATA_PREFIX) or key == _PREPARED_INPUT_METADATA_KEY:
            continue
        kept[key] = value
    return kept


def _cache_triplet_terms(
    inputs: UFPInput,
    terms: Sequence[object],
    *,
    feature_cache_storage: str,
    feature_cache_dir: Path | str | None,
    cache_prefix: str,
    include_per_atom_energy: bool,
) -> None:
    """Call ``cache_input`` on every term that exposes one, mutating ``inputs``.

    Terms without a callable ``cache_input`` are skipped. Each term gets its
    own cache prefix, ``f"{cache_prefix}_term{index}"``.
    """
    for term_index, term in enumerate(terms):
        warm_cache = getattr(term, "cache_input", None)
        if not callable(warm_cache):
            continue
        warm_cache(
            inputs,
            options=TermCacheOptions(
                feature_cache_storage=feature_cache_storage,
                feature_cache_dir=(
                    None if feature_cache_dir is None else Path(feature_cache_dir)
                ),
                cache_prefix=f"{cache_prefix}_term{term_index}",
                include_per_atom_energy=include_per_atom_energy,
            ),
        )
def prepare_geometry_from_input(
    inputs: UFPInput,
    *,
    species_ordering: Sequence[int] | None = None,
    pair_category_specs: Sequence[PairCategorySpec | Sequence[int]] | None = None,
    triplet_terms: Sequence[object] = (),
    feature_cache_storage: str = "none",
    feature_cache_dir: Path | str | None = None,
    cache_prefix: str = "prepared",
    include_per_atom_energy: bool = True,
) -> PreparedGeometry:
    """Materialize reusable geometry tensors from an existing ``UFPInput``.

    Args:
        inputs: Normalized input to prepare.
        species_ordering: Species used for the default pair category and for
            metadata. Defaults to the atomic numbers present in ``inputs``.
        pair_category_specs: Pair categories to precompute indices for.
            Defaults to one category covering every species.
        triplet_terms: Terms whose ``cache_input`` hook is called to warm
            three-body caches on a working copy of ``inputs``.
        feature_cache_storage: ``"none"``, ``"cpu"``, or ``"disk"``.
        feature_cache_dir: Directory for disk caches. Required when
            ``feature_cache_storage == "disk"``.
        cache_prefix: Prefix for per-term cache names.
        include_per_atom_energy: Forwarded to the term cache options.

    Returns:
        A ``PreparedGeometry``. Pair tensors are only populated when
        ``inputs`` carries a neighbor list.

    Raises:
        ValueError: If ``feature_cache_storage`` is invalid, or ``"disk"`` is
            requested without ``feature_cache_dir``.
    """
    # Validate options before touching ``inputs`` so invalid arguments are
    # rejected without doing any tensor conversion or copy first.
    if feature_cache_storage not in {"none", "cpu", "disk"}:
        raise ValueError("`feature_cache_storage` must be 'none', 'cpu', or 'disk'")
    if feature_cache_storage == "disk" and feature_cache_dir is None:
        raise ValueError(
            "`feature_cache_dir` is required when `feature_cache_storage='disk'`"
        )

    working = inputs.to(device=inputs.device, dtype=inputs.dtype, requires_grad=False)

    pair_system_index = None
    pair_vectors = None
    pair_distances = None
    pair_categories: dict[PairCategorySpec, torch.Tensor] = {}
    if working.neighbor_list is not None:
        pair_system_index = working.pair_system_index().detach()
        pair_vectors = working.pair_vectors().detach()
        pair_distances = working.pair_distances().detach()
        for spec in _normalize_pair_category_specs(
            working,
            species_ordering=species_ordering,
            pair_category_specs=pair_category_specs,
        ):
            pair_categories[spec] = working.pair_category_indices(
                spec.atomic_types,
                symmetric=spec.symmetric,
            ).detach()

    if triplet_terms:
        # Terms write their three-body cache entries into ``working.metadata``;
        # they are extracted into ``triplet_caches`` below.
        _cache_triplet_terms(
            working,
            triplet_terms,
            feature_cache_storage=feature_cache_storage,
            feature_cache_dir=feature_cache_dir,
            cache_prefix=cache_prefix,
            include_per_atom_energy=include_per_atom_energy,
        )

    metadata = PreparedGeometryMetadata.from_input(
        working,
        species_ordering=species_ordering,
    )
    return PreparedGeometry(
        positions=_normalized_tensor(working.positions, "working.positions").detach(),
        cell=_normalized_tensor(working.cell, "working.cell").detach(),
        pbc=_normalized_tensor(working.pbc, "working.pbc").detach(),
        atomic_numbers=_normalized_tensor(
            working.atomic_numbers,
            "working.atomic_numbers",
        ).detach(),
        system_index=_normalized_tensor(
            working.system_index,
            "working.system_index",
        ).detach(),
        neighbor_list=working.neighbor_list,
        pair_system_index=pair_system_index,
        pair_vectors=pair_vectors,
        pair_distances=pair_distances,
        pair_category_indices=pair_categories,
        metadata=metadata,
        input_metadata=_base_input_metadata(inputs.metadata),
        triplet_caches=_triplet_cache_metadata(working.metadata),
        source_atoms=working.source_atoms,
    )
def _normalize_atoms_sequence(
    atoms: ase.Atoms | Sequence[ase.Atoms],
) -> tuple[ase.Atoms, ...]:
    """Return ``atoms`` as a non-empty tuple of ``ase.Atoms``.

    Raises:
        ValueError: If a sequence with no structures is given.
        TypeError: If any entry of the sequence is not ``ase.Atoms``.
    """
    if isinstance(atoms, ase.Atoms):
        return (atoms,)
    structures = tuple(atoms)
    if not structures:
        raise ValueError("`atoms` must contain at least one structure")
    if not all(isinstance(entry, ase.Atoms) for entry in structures):
        raise TypeError("all `atoms` entries must be ase.Atoms")
    return structures
def prepare_geometry_from_ase(
    atoms: ase.Atoms | Sequence[ase.Atoms],
    *,
    cutoff: float,
    backend: str | NeighborListBackend = NeighborListBackend.AUTO,
    full_list: bool = True,
    sorted: bool = True,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
    species_ordering: Sequence[int] | None = None,
    pair_category_specs: Sequence[PairCategorySpec | Sequence[int]] | None = None,
    triplet_terms: Sequence[object] = (),
    feature_cache_storage: str = "none",
    feature_cache_dir: Path | str | None = None,
    cache_prefix: str = "prepared",
    include_per_atom_energy: bool = True,
) -> PreparedGeometry:
    """Build prepared geometry directly from one or more ASE structures.

    One neighbor list is built per structure with the shared ``cutoff`` and
    neighbor-list options. The lists are concatenated with per-structure atom
    offsets, batched into a ``UFPInput``, and passed to
    ``prepare_geometry_from_input``.
    """
    structures = _normalize_atoms_sequence(atoms)
    cutoff = float(cutoff)
    per_structure = [
        build_neighbor_list(
            structure,
            cutoff=cutoff,
            backend=backend,
            arrays="torch",
            full_list=full_list,
            sorted=sorted,
        )
        for structure in structures
    ]
    # Index of each structure's first atom in the concatenated atom array.
    # ``structures`` is never empty, so seeding with 0 is always correct.
    offsets = [0]
    for structure in structures[:-1]:
        offsets.append(offsets[-1] + len(structure))
    combined_neighbors = concatenate_neighbor_lists(
        per_structure,
        atom_offsets=offsets,
    )
    batched_inputs = UFPInput.from_ase_list(
        structures,
        neighbor_list=combined_neighbors,
        dtype=dtype,
        device=None if device is None else torch.device(device),
    )
    return prepare_geometry_from_input(
        batched_inputs,
        species_ordering=species_ordering,
        pair_category_specs=pair_category_specs,
        triplet_terms=triplet_terms,
        feature_cache_storage=feature_cache_storage,
        feature_cache_dir=feature_cache_dir,
        cache_prefix=cache_prefix,
        include_per_atom_energy=include_per_atom_energy,
    )
def prepare_geometry_from_batch(
    batch: ASEAtomsBatch,
    model: UFPotential,
    *,
    backend: str | NeighborListBackend | None = None,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
    species_ordering: Sequence[int] | None = None,
    pair_category_specs: Sequence[PairCategorySpec | Sequence[int]] | None = None,
    triplet_terms: Sequence[object] = (),
    feature_cache_storage: str = "none",
    feature_cache_dir: Path | str | None = None,
    cache_prefix: str = "prepared",
    include_per_atom_energy: bool = True,
) -> PreparedGeometry:
    """Prepare geometry from a training ``ASEAtomsBatch``.

    The batch is first converted with ``batch.prepare_input(model, ...)``
    (always with ``requires_grad=False``). The result is then passed to
    ``prepare_geometry_from_input`` with the remaining options.
    """
    model_inputs = batch.prepare_input(
        model,
        backend=backend,
        dtype=dtype,
        device=device,
        requires_grad=False,
    )
    prepared = prepare_geometry_from_input(
        model_inputs,
        species_ordering=species_ordering,
        pair_category_specs=pair_category_specs,
        triplet_terms=triplet_terms,
        feature_cache_storage=feature_cache_storage,
        feature_cache_dir=feature_cache_dir,
        cache_prefix=cache_prefix,
        include_per_atom_energy=include_per_atom_energy,
    )
    return prepared
def prepare_geometry_from_prepared_batch(
    batch: PreparedBatch,
    *,
    species_ordering: Sequence[int] | None = None,
    pair_category_specs: Sequence[PairCategorySpec | Sequence[int]] | None = None,
    triplet_terms: Sequence[object] = (),
    feature_cache_storage: str = "none",
    feature_cache_dir: Path | str | None = None,
    cache_prefix: str = "prepared",
    include_per_atom_energy: bool = True,
) -> PreparedGeometry:
    """Prepare geometry from a least-squares ``PreparedBatch``.

    ``batch.inputs`` is passed to ``prepare_geometry_from_input`` with the
    remaining options unchanged.
    """
    batch_inputs = batch.inputs
    prepared = prepare_geometry_from_input(
        batch_inputs,
        species_ordering=species_ordering,
        pair_category_specs=pair_category_specs,
        triplet_terms=triplet_terms,
        feature_cache_storage=feature_cache_storage,
        feature_cache_dir=feature_cache_dir,
        cache_prefix=cache_prefix,
        include_per_atom_energy=include_per_atom_energy,
    )
    return prepared
# Public API of this module.
__all__ = [
    "PairCategorySpec",
    "PreparedGeometry",
    "PreparedGeometryMetadata",
    "PreparedGeometryReuseEstimate",
    "prepare_geometry_from_ase",
    "prepare_geometry_from_batch",
    "prepare_geometry_from_input",
    "prepare_geometry_from_prepared_batch",
]