Source code for ufp.workflows.residuals

"""Residual-label materialization for frozen model components."""

from __future__ import annotations

import hashlib
import json
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Protocol, cast

import numpy as np
import torch

from ufp.core.input import UFPInput
from ufp.core.output import (
    UFPOutput,
    _coerce_energy,
    _coerce_forces,
    _coerce_stress,
)
from ufp.core.potential import UFPotential
from ufp.neighbors._neighbors import NeighborListBackend
from ufp.terms._base import UFPTerm
from ufp.training import ASEAtomsDataset, ASEAtomsSample


# Name of one supervised label; the supported names are listed in
# `_SUPPORTED_TARGETS` below.
TargetName = str
# A term selector is an integer term index, a string label, or a term class
# (matched with ``isinstance``).
TermSelector = int | str | type[UFPTerm]
# Predicate called as ``term_filter(index, term)``; terms for which it returns
# a truthy value are kept.
TermFilter = Callable[[int, UFPTerm], bool]

# Label names that can be residualized. Requested targets and the keys of
# target-keyed mappings are validated against this tuple.
_SUPPORTED_TARGETS = ("energy", "forces", "stress")
# Default key under which each residual label is written into the copied ASE
# atoms object when the caller supplies no `target_keys` entry.
_DEFAULT_TARGET_KEYS = {
    "energy": "energy",
    "forces": "forces",
    "stress": "stress",
}
# Default unit strings recorded in the residual metadata; they are only
# recorded, no unit conversion is applied to the labels.
_DEFAULT_UNITS = {
    "energy": "eV",
    "forces": "eV/angstrom",
    "stress": "eV/angstrom^3",
}
# Stored under the "version" key of the residual metadata payload, and is
# therefore part of the metadata hash.
_RESIDUAL_METADATA_VERSION = 1


class SampleDataset(Protocol):
    """Structural type for a sized, integer-indexable collection of samples.

    Any object providing ``__len__`` and ``__getitem__`` with these
    signatures satisfies the protocol; no inheritance is required.
    """

    def __len__(self) -> int:
        """Report how many samples the dataset holds."""
        ...

    def __getitem__(self, index: int) -> ASEAtomsSample:
        """Fetch the sample stored at position ``index``."""
        ...


@dataclass(frozen=True)
class ResidualDatasetResult:
    """Bundle a residualized ASE dataset with the metadata that produced it."""

    # Dataset holding the residualized samples.
    dataset: ASEAtomsDataset
    # Residual metadata; must contain a "metadata_hash" entry for
    # `metadata_hash` to succeed.
    metadata: Mapping[str, object]

    @property
    def metadata_hash(self) -> str:
        """Return the ``"metadata_hash"`` metadata entry as a string."""
        digest = self.metadata["metadata_hash"]
        return str(digest)
def _json_dumps(payload: object) -> str:
    """Return deterministic JSON for metadata payloads.

    Keys are sorted and separators are compact, so equal payloads always
    serialize to the same text. Values ``json`` cannot encode natively are
    stringified with ``str``.
    """
    return json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str)


def _sha256_json(payload: object) -> str:
    """Return the SHA-256 hex digest of the deterministic JSON of ``payload``."""
    return hashlib.sha256(_json_dumps(payload).encode("utf-8")).hexdigest()


def _normalize_dataset(
    dataset: SampleDataset,
) -> tuple[ASEAtomsSample, ...]:
    """Return a tuple of ASE samples from a dataset-like input.

    Raises:
        ValueError: If the dataset yields no samples.
        TypeError: If any entry is not an ``ASEAtomsSample``.
    """
    samples = tuple(dataset[index] for index in range(len(dataset)))
    if not samples:
        raise ValueError("`dataset` must contain at least one sample")
    if any(not isinstance(sample, ASEAtomsSample) for sample in samples):
        raise TypeError("all dataset entries must be ASEAtomsSample instances")
    return samples


def _normalize_targets(targets: Sequence[TargetName] | None) -> tuple[TargetName, ...]:
    """Normalize requested residual target names.

    Args:
        targets: Requested target names, a single target name, or ``None``.

    Returns:
        An empty tuple when ``targets`` is ``None`` (meaning "infer from the
        dataset"), otherwise the validated names in the order given.

    Raises:
        ValueError: If ``targets`` is empty, contains an unsupported name,
            or contains duplicates.
    """
    if targets is None:
        return ()
    if isinstance(targets, str):
        # A bare string is itself a sequence of characters; without this guard
        # "energy" would be reported as the unsupported targets e, g, n, r, y.
        targets = (targets,)
    normalized = tuple(str(target) for target in targets)
    if not normalized:
        raise ValueError("`targets` must contain at least one target")
    unknown = sorted(set(normalized).difference(_SUPPORTED_TARGETS))
    if unknown:
        raise ValueError("unsupported residual targets: " + ", ".join(unknown))
    if len(set(normalized)) != len(normalized):
        raise ValueError("`targets` must not contain duplicates")
    return normalized


def _infer_targets(samples: Sequence[ASEAtomsSample]) -> tuple[TargetName, ...]:
    """Infer targets present on every sample.

    Raises:
        ValueError: If no label is present on all samples.
    """
    inferred: list[TargetName] = []
    if all(sample.energy is not None for sample in samples):
        inferred.append("energy")
    if all(sample.forces is not None for sample in samples):
        inferred.append("forces")
    if all(sample.stress is not None for sample in samples):
        inferred.append("stress")
    if not inferred:
        raise ValueError(
            "`targets` was not provided and no common labels were found in `dataset`"
        )
    return tuple(inferred)


def _validate_targets(
    samples: Sequence[ASEAtomsSample],
    targets: Sequence[TargetName],
) -> None:
    """Reject explicitly requested targets missing from any sample.

    Raises:
        ValueError: Naming the first sample index that lacks a target.
    """
    for target in targets:
        # Only the first offending index is reported, so stop at it instead of
        # collecting every missing index.
        first_missing = next(
            (
                index
                for index, sample in enumerate(samples)
                if getattr(sample, target) is None
            ),
            None,
        )
        if first_missing is not None:
            raise ValueError(
                f"sample {first_missing} is missing requested target `{target}`"
            )


def _normalize_mapping(
    values: Mapping[str, Any] | None,
    *,
    defaults: Mapping[str, Any],
    targets: Sequence[TargetName],
    name: str,
) -> dict[str, Any]:
    """Normalize target-keyed metadata mappings.

    Args:
        values: Caller-supplied overrides keyed by target name, or ``None``.
        defaults: Fallback value for every target not present in ``values``.
        targets: Targets to include in the result.
        name: Argument name used in the error message.

    Returns:
        A dict with exactly one entry per requested target.

    Raises:
        ValueError: If ``values`` has a key that is not a supported target.
    """
    supplied = {} if values is None else dict(values)
    unknown = sorted(set(supplied).difference(_SUPPORTED_TARGETS))
    if unknown:
        raise ValueError(f"`{name}` contains unsupported targets: {unknown}")
    return {target: supplied.get(target, defaults[target]) for target in targets}


def _selector_label(selector: TermSelector) -> str:
    """Return a stable label for one term selector.

    Classes are labelled ``type:<module>.<qualname>``; other selectors are
    labelled ``<type name>:<value>``.
    """
    if isinstance(selector, type):
        return f"type:{selector.__module__}.{selector.__qualname__}"
    return f"{type(selector).__name__}:{selector}"


def _term_family(model: UFPotential, term: UFPTerm) -> str:
    """Return the coarse family for one term inside a model.

    Membership is tested by identity against the model's ``onebody_terms``,
    ``pair_terms`` and ``threebody_terms`` attributes (each optional); a term
    found in none of them is reported as ``"other"``.
    """
    if any(term is item for item in getattr(model, "onebody_terms", ())):
        return "onebody"
    if any(term is item for item in getattr(model, "pair_terms", ())):
        return "pair"
    if any(term is item for item in getattr(model, "threebody_terms", ())):
        return "threebody"
    return "other"


def _term_matches_selector(
    model: UFPotential,
    index: int,
    term: UFPTerm,
    selector: TermSelector,
) -> bool:
    """Return whether one selector matches a model term.

    Integers match the term index, classes match via ``isinstance``, and any
    other selector is compared as a string against the index, the term class
    name, the term family, and each parameter block's kind, label and
    regularization group.
    """
    if isinstance(selector, int):
        return int(selector) == index
    if isinstance(selector, type):
        return isinstance(term, selector)
    value = str(selector)
    if value in {str(index), type(term).__name__, _term_family(model, term)}:
        return True
    for block in term.parameter_blocks():
        if value in {
            str(block.kind),
            str(block.label),
            str(block.regularization_group),
        }:
            return True
    return False


def _term_is_frozen(term: UFPTerm) -> bool:
    """Return whether all parameters in one term are frozen.

    A term without parameters counts as frozen.
    """
    return not any(parameter.requires_grad for parameter in term.parameters())


def _resolve_terms(
    model: UFPotential,
    *,
    selectors: Sequence[TermSelector] | None,
    term_filter: TermFilter | None,
    require_frozen: bool,
) -> tuple[tuple[int, UFPTerm], ...]:
    """Resolve frozen components to subtract from labels.

    Without selectors, ``term_filter`` alone picks the terms; with neither,
    every frozen term is picked. With selectors, each selector must match at
    least one term and ``term_filter`` (if given) narrows the matches.

    Returns:
        ``(index, term)`` pairs, de-duplicated and sorted by index.

    Raises:
        TypeError: If ``model`` exposes no terms.
        ValueError: If a selector matches nothing, if nothing is selected, or
            if ``require_frozen`` is set and a selected term is trainable.
    """
    terms = tuple(getattr(model, "terms", ()))
    if not terms:
        raise TypeError("`model` must expose a non-empty `terms` sequence")

    if selectors is None:
        if term_filter is None:
            selected = [
                (index, term)
                for index, term in enumerate(terms)
                if _term_is_frozen(term)
            ]
        else:
            selected = [
                (index, term)
                for index, term in enumerate(terms)
                if term_filter(index, term)
            ]
    else:
        selected = []
        for selector in selectors:
            matches = [
                (index, term)
                for index, term in enumerate(terms)
                if _term_matches_selector(model, index, term, selector)
            ]
            if not matches:
                raise ValueError(
                    f"term selector {_selector_label(selector)!r} matched no terms"
                )
            selected.extend(matches)
        if term_filter is not None:
            selected = [
                (index, term) for index, term in selected if term_filter(index, term)
            ]

    # Several selectors may hit the same term; keep one entry per index.
    deduplicated: dict[int, UFPTerm] = {}
    for index, term in selected:
        deduplicated[int(index)] = term
    resolved = tuple((index, deduplicated[index]) for index in sorted(deduplicated))
    if not resolved:
        raise ValueError("no model terms were selected for residualization")

    if require_frozen:
        trainable = [
            (index, type(term).__name__)
            for index, term in resolved
            if not _term_is_frozen(term)
        ]
        if trainable:
            first_index, first_name = trainable[0]
            raise ValueError(
                "selected residual term is not frozen: "
                f"index {first_index} ({first_name})"
            )
    return resolved


def _term_metadata(
    model: UFPotential,
    selected_terms: Sequence[tuple[int, UFPTerm]],
) -> tuple[dict[str, object], ...]:
    """Return compact selected-term metadata.

    Each entry records the term index, family, class path, cutoff, atomic
    types, parameter-block descriptions and state hash, plus a
    ``"projection_metadata"`` entry when the term exposes one.
    """
    metadata: list[dict[str, object]] = []
    for index, term in selected_terms:
        blocks = [
            {
                "kind": block.kind,
                "label": block.label,
                "shape": tuple(int(dim) for dim in block.shape),
                "regularization_group": block.regularization_group,
            }
            for block in term.parameter_blocks()
        ]
        metadata.append(
            {
                "index": int(index),
                "family": _term_family(model, term),
                "class": f"{type(term).__module__}.{type(term).__qualname__}",
                "cutoff": None if term.cutoff is None else float(term.cutoff),
                "atomic_types": None
                if term.atomic_types is None
                else tuple(int(value) for value in term.atomic_types),
                "parameter_blocks": blocks,
                "state_hash": _term_state_hash(index, term),
            }
        )
        projection_metadata = _term_projection_metadata(term)
        if projection_metadata is not None:
            metadata[-1]["projection_metadata"] = projection_metadata
    return tuple(metadata)


def _term_state_hash(index: int, term: UFPTerm) -> str:
    """Hash one selected frozen-term state, including its model index.

    The digest covers the index, the term class path, and for every
    ``state_dict`` entry (in sorted name order) its name, shape, dtype and
    raw bytes.
    """
    hasher = hashlib.sha256()
    header = {
        "index": int(index),
        "class": f"{type(term).__module__}.{type(term).__qualname__}",
    }
    hasher.update(_json_dumps(header).encode("utf-8"))
    for name, tensor in sorted(term.state_dict().items()):
        detached = tensor.detach().cpu().contiguous()
        tensor_header = {
            "name": name,
            "shape": tuple(int(dim) for dim in detached.shape),
            "dtype": str(detached.dtype),
        }
        hasher.update(_json_dumps(tensor_header).encode("utf-8"))
        array = detached.numpy()
        hasher.update(np.ascontiguousarray(array).tobytes())
    return hasher.hexdigest()


def _term_projection_metadata(term: UFPTerm) -> Mapping[str, object] | None:
    """Return optional projection metadata exposed by analytic prior terms.

    Returns:
        A dict copy of ``term.projection_metadata()``, or ``None`` when the
        term has no callable ``projection_metadata`` attribute.

    Raises:
        TypeError: If ``projection_metadata()`` does not return a mapping.
    """
    projection_metadata = getattr(term, "projection_metadata", None)
    if not callable(projection_metadata):
        return None
    metadata = projection_metadata()
    if not isinstance(metadata, Mapping):
        raise TypeError("`projection_metadata()` must return a mapping")
    return dict(metadata)


def _model_state_hash(selected_terms: Sequence[tuple[int, UFPTerm]]) -> str:
    """Hash the selected frozen-term states as a single strict signature."""
    hasher = hashlib.sha256()
    for index, term in selected_terms:
        hasher.update(_term_state_hash(index, term).encode("utf-8"))
    return hasher.hexdigest()


def _callable_label(callback: object | None) -> str | None:
    """Return a stable-enough label for callable metadata.

    Uses the callback's own ``__module__``/``__qualname__`` when present and
    falls back to those of its type; ``None`` maps to ``None``.
    """
    if callback is None:
        return None
    module = getattr(callback, "__module__", type(callback).__module__)
    qualname = getattr(callback, "__qualname__", type(callback).__qualname__)
    return f"{module}.{qualname}"


def _residual_metadata(
    model: UFPotential,
    *,
    selected_terms: Sequence[tuple[int, UFPTerm]],
    selectors: Sequence[TermSelector] | None,
    term_filter: TermFilter | None,
    targets: Sequence[TargetName],
    target_keys: Mapping[str, str],
    target_weights: Mapping[str, float],
    units: Mapping[str, str],
) -> dict[str, object]:
    """Build deterministic residual metadata and hash.

    Returns:
        The metadata payload plus a ``"metadata_hash"`` entry holding the
        SHA-256 of the payload's deterministic JSON.
    """
    state_hash = _model_state_hash(selected_terms)
    payload = {
        "version": _RESIDUAL_METADATA_VERSION,
        "model_class": f"{type(model).__module__}.{type(model).__qualname__}",
        # Both keys carry the same selected-term state signature.
        "frozen_terms_state_hash": state_hash,
        "model_state_hash": state_hash,
        "selectors": None
        if selectors is None
        else tuple(_selector_label(selector) for selector in selectors),
        "term_filter": _callable_label(term_filter),
        "selected_terms": _term_metadata(model, selected_terms),
        "targets": tuple(targets),
        "target_keys": dict(target_keys),
        "target_weights": {target: float(target_weights[target]) for target in targets},
        "units": dict(units),
    }
    return {
        **payload,
        "metadata_hash": _sha256_json(payload),
    }


def _sum_energy(outputs: Sequence[UFPOutput], inputs: UFPInput) -> torch.Tensor | None:
    """Sum available energy contributions.

    Outputs without an energy are skipped; ``None`` is returned when no
    output carries one.
    """
    energy = None
    for output in outputs:
        if output.energy is None:
            continue
        term_energy = _coerce_energy(output.energy, inputs)
        energy = term_energy if energy is None else energy + term_energy
    return energy


def _sum_forces(outputs: Sequence[UFPOutput], inputs: UFPInput) -> torch.Tensor | None:
    """Sum available analytic force contributions.

    Returns ``None`` as soon as any output lacks forces, so the caller can
    fall back to differentiating the energy.
    """
    forces = None
    for output in outputs:
        if output.forces is None:
            return None
        term_forces = _coerce_forces(output.forces, inputs)
        forces = term_forces if forces is None else forces + term_forces
    return forces


def _derive_forces(energy: torch.Tensor, inputs: UFPInput) -> torch.Tensor:
    """Derive force residuals from a selected energy contribution.

    Returns the negative gradient of the summed energy with respect to
    ``inputs.positions``.
    """
    gradient = torch.autograd.grad(
        energy.reshape(-1).sum(),
        cast(torch.Tensor, inputs.positions),
        create_graph=False,
        retain_graph=False,
    )[0]
    return -gradient


def _sum_stress(outputs: Sequence[UFPOutput], inputs: UFPInput) -> torch.Tensor | None:
    """Sum available stress contributions, treating absent stress as zero."""
    stress = None
    for output in outputs:
        if output.stress is None:
            continue
        term_stress = _coerce_stress(output.stress, inputs)
        stress = term_stress if stress is None else stress + term_stress
    return stress


@dataclass(frozen=True)
class _ResidualContribution:
    """Selected model contribution for one sample."""

    # Each field is ``None`` when that quantity was not produced or requested.
    energy: torch.Tensor | None
    forces: torch.Tensor | None
    stress: torch.Tensor | None


def _evaluate_selected_terms(
    model: UFPotential,
    selected_terms: Sequence[tuple[int, UFPTerm]],
    sample: ASEAtomsSample,
    *,
    targets: Sequence[TargetName],
    dtype: torch.dtype,
    device: torch.device | None,
    backend: str | NeighborListBackend | None,
) -> _ResidualContribution:
    """Evaluate selected terms for one sample.

    Forces are taken from the terms' analytic outputs when every selected
    term reports ``provides_forces`` and actually returns forces; otherwise
    they are derived from the summed energy by differentiation.

    Raises:
        ValueError: If forces must be derived but no term produced an energy.
    """
    need_forces = "forces" in targets
    can_use_analytic_forces = all(term.provides_forces for _, term in selected_terms)
    # Gradients are enabled whenever forces are requested, because the
    # analytic path can still fall back to differentiating the energy.
    inputs = model.prepare_input(
        sample.atoms,
        neighbor_list=sample.neighbor_list,
        backend=backend,
        device=device,
        dtype=dtype,
        requires_grad=need_forces,
    )
    validator = getattr(model, "_validate_input_atomic_types", None)
    if callable(validator):
        validator(inputs)

    outputs = [term(inputs) for _, term in selected_terms]
    energy = _sum_energy(outputs, inputs)
    forces = None
    if need_forces:
        forces = _sum_forces(outputs, inputs) if can_use_analytic_forces else None
        if forces is None:
            if energy is None:
                raise ValueError("can not derive selected forces without energy")
            forces = _derive_forces(energy, inputs)
    stress = _sum_stress(outputs, inputs) if "stress" in targets else None
    return _ResidualContribution(
        energy=None if energy is None else energy.detach(),
        forces=None if forces is None else forces.detach(),
        stress=None if stress is None else stress.detach(),
    )


def _residual_energy(
    sample: ASEAtomsSample, contribution: torch.Tensor | None
) -> float:
    """Return one residual total energy label.

    Raises:
        ValueError: If the sample has no energy label, or the contribution
            does not hold exactly one element.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if sample.energy is None:
        raise ValueError("sample is missing the `energy` label")
    if contribution is None:
        return float(sample.energy)
    if contribution.numel() != 1:
        raise ValueError("selected energy contribution must contain one scalar")
    return float(sample.energy) - float(contribution.reshape(-1)[0].cpu().item())


def _residual_forces(
    sample: ASEAtomsSample,
    contribution: torch.Tensor | None,
) -> torch.Tensor:
    """Return one residual force label tensor.

    Raises:
        ValueError: If the sample has no force label, or the contribution
            shape differs from the label shape.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if sample.forces is None:
        raise ValueError("sample is missing the `forces` label")
    if contribution is None:
        return sample.forces.detach().clone()
    contribution = contribution.detach().to(dtype=sample.forces.dtype, device="cpu")
    if tuple(contribution.shape) != tuple(sample.forces.shape):
        raise ValueError(
            "selected force contribution has shape "
            f"{tuple(contribution.shape)}, expected {tuple(sample.forces.shape)}"
        )
    return sample.forces.detach().cpu() - contribution


def _residual_stress(
    sample: ASEAtomsSample,
    contribution: torch.Tensor | None,
) -> torch.Tensor:
    """Return one residual stress label tensor.

    A contribution of shape ``(1, 3, 3)`` has its leading axis dropped when
    the label itself is not ``(1, 3, 3)``.

    Raises:
        ValueError: If the sample has no stress label, or the contribution
            shape differs from the label shape.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if sample.stress is None:
        raise ValueError("sample is missing the `stress` label")
    if contribution is None:
        return sample.stress.detach().clone()
    contribution = contribution.detach().to(dtype=sample.stress.dtype, device="cpu")
    expected_shape = tuple(sample.stress.shape)
    # Drop the leading axis only when the label is not itself (1, 3, 3);
    # squeezing unconditionally would reject a matching (1, 3, 3) label.
    if tuple(contribution.shape) == (1, 3, 3) and expected_shape != (1, 3, 3):
        contribution = contribution[0]
    if tuple(contribution.shape) != expected_shape:
        raise ValueError(
            "selected stress contribution has shape "
            f"{tuple(contribution.shape)}, expected {expected_shape}"
        )
    return sample.stress.detach().cpu() - contribution


def _write_atoms_targets(
    sample: ASEAtomsSample,
    *,
    energy: float | None,
    forces: torch.Tensor | None,
    stress: torch.Tensor | None,
    target_keys: Mapping[str, str],
) -> Any:
    """Copy source atoms and update residual labels under the requested keys.

    Energy and stress are written to ``atoms.info`` and forces to
    ``atoms.arrays``; labels passed as ``None`` are left untouched.

    Returns:
        The copy of ``sample.atoms``; the sample's own atoms are not modified.
    """
    atoms = sample.atoms.copy()
    if energy is not None:
        atoms.info[target_keys["energy"]] = float(energy)
    if forces is not None:
        atoms.arrays[target_keys["forces"]] = forces.detach().cpu().numpy()
    if stress is not None:
        atoms.info[target_keys["stress"]] = stress.detach().cpu().numpy()
    return atoms
def materialize_residual_dataset(
    model: UFPotential,
    dataset: SampleDataset,
    *,
    selectors: Sequence[TermSelector] | None = None,
    term_filter: TermFilter | None = None,
    targets: Sequence[TargetName] | None = None,
    target_keys: Mapping[str, str] | None = None,
    target_weights: Mapping[str, float] | None = None,
    units: Mapping[str, str] | None = None,
    require_frozen: bool = True,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
    backend: str | NeighborListBackend | None = None,
    metadata_key: str = "residuals",
) -> ResidualDatasetResult:
    """
    Build a dataset whose labels exclude the selected model contributions.

    Every sample is evaluated with the selected terms only, and the result
    is subtracted from the sample's supervised labels. The model is put in
    eval mode for the evaluation and restored to its previous mode
    afterwards, even if evaluation fails.

    Args:
        model: Model owning the terms whose contribution is subtracted.
        dataset: Source ASE training samples.
        selectors: Optional term selectors. Integers match evaluation-order
            term indices; strings match indices, term families, class names,
            or parameter-block labels/kinds/groups; term classes match by
            ``isinstance``.
        term_filter: Optional predicate called with ``(index, term)``.
        targets: Labels to residualize. When omitted, every label present on
            all samples is residualized.
        target_keys: Keys under which labels are stored in the copied ASE
            objects and recorded in the metadata.
        target_weights: Optional downstream target weights, recorded in the
            metadata only; residual labels are not scaled by them.
        units: Unit strings recorded in the metadata; no conversion is done.
        require_frozen: When ``True``, selected terms with trainable
            parameters are rejected.
        dtype: Evaluation dtype; defaults to ``model.preferred_dtype()``.
        device: Optional evaluation device.
        backend: Optional neighbor-list backend override.
        metadata_key: Key of each sample's metadata under which the residual
            metadata is stored.

    Returns:
        The residualized dataset together with its deterministic metadata.
    """
    samples = _normalize_dataset(dataset)
    # An empty tuple means no explicit request, so fall back to inference.
    resolved_targets = _normalize_targets(targets) or _infer_targets(samples)
    _validate_targets(samples, resolved_targets)

    selected_terms = _resolve_terms(
        model,
        selectors=selectors,
        term_filter=term_filter,
        require_frozen=require_frozen,
    )

    normalized_target_keys = _normalize_mapping(
        target_keys,
        defaults=_DEFAULT_TARGET_KEYS,
        targets=resolved_targets,
        name="target_keys",
    )
    normalized_units = _normalize_mapping(
        units,
        defaults=_DEFAULT_UNITS,
        targets=resolved_targets,
        name="units",
    )
    raw_weights = _normalize_mapping(
        target_weights,
        defaults=dict.fromkeys(_SUPPORTED_TARGETS, 1.0),
        targets=resolved_targets,
        name="target_weights",
    )
    normalized_weights = {name: float(weight) for name, weight in raw_weights.items()}

    metadata = _residual_metadata(
        model,
        selected_terms=selected_terms,
        selectors=selectors,
        term_filter=term_filter,
        targets=resolved_targets,
        target_keys=normalized_target_keys,
        target_weights=normalized_weights,
        units=normalized_units,
    )

    if dtype is None:
        resolved_dtype = model.preferred_dtype()
    else:
        resolved_dtype = dtype
    resolved_device = torch.device(device) if device is not None else None

    want_energy = "energy" in resolved_targets
    want_forces = "forces" in resolved_targets
    want_stress = "stress" in resolved_targets

    was_training = model.training
    residual_samples: list[ASEAtomsSample] = []
    try:
        model.eval()
        for sample in samples:
            contribution = _evaluate_selected_terms(
                model,
                selected_terms,
                sample,
                targets=resolved_targets,
                dtype=resolved_dtype,
                device=resolved_device,
                backend=backend,
            )

            # Labels that were not requested are carried over unchanged.
            energy = sample.energy
            forces = sample.forces
            stress = sample.stress
            if want_energy:
                energy = _residual_energy(sample, contribution.energy)
            if want_forces:
                forces = _residual_forces(sample, contribution.forces)
            if want_stress:
                stress = _residual_stress(sample, contribution.stress)

            sample_metadata = dict(sample.metadata)
            sample_metadata[metadata_key] = metadata

            # Only residualized labels are written onto the copied atoms.
            atoms = _write_atoms_targets(
                sample,
                energy=energy if want_energy else None,
                forces=forces if want_forces else None,
                stress=stress if want_stress else None,
                target_keys=normalized_target_keys,
            )
            residual_samples.append(
                ASEAtomsSample(
                    atoms=atoms,
                    energy=energy,
                    forces=forces,
                    force_mask=sample.force_mask,
                    stress=stress,
                    neighbor_list=sample.neighbor_list,
                    metadata=sample_metadata,
                )
            )
    finally:
        model.train(was_training)

    return ResidualDatasetResult(
        dataset=ASEAtomsDataset(residual_samples),
        metadata=metadata,
    )
def residualize_ase_dataset(
    model: UFPotential,
    dataset: SampleDataset,
    **kwargs: Any,
) -> ResidualDatasetResult:
    """Forward to :func:`materialize_residual_dataset` with the same arguments."""
    result = materialize_residual_dataset(model, dataset, **kwargs)
    return result
# Public API of this module: the result container, the two selector/filter
# type aliases, and the two residualization entry points.
__all__ = [ "ResidualDatasetResult", "TermFilter", "TermSelector", "materialize_residual_dataset", "residualize_ase_dataset", ]