Source code for ufp.leastsquares.hybrid

"""Hybrid residualized linear fitting helpers."""

from __future__ import annotations

import hashlib
from collections.abc import Callable, Sequence
from pathlib import Path

import numpy as np
import torch

from ufp.core.output import UFPOutput
from ufp.leastsquares.dataset import FitSample
from ufp.leastsquares.linear import LinearFitter
from ufp.terms.model import UFPModel


# Ways to name a frozen term: its model-order index (int), a string matched
# against ``_term_selector_labels`` (index as text, class name, ``name``/``label``
# attributes, parameter-block names/kinds/labels), or the term module itself
# (matched by identity against ``model.terms``).
FrozenTermSelector = int | str | torch.nn.Module
# Predicate evaluated once per term in ``model.terms``; a truthy result freezes it.
FrozenTermFilter = Callable[[torch.nn.Module], bool]


def _term_selector_labels(
    term: torch.nn.Module,
    *,
    index: int,
) -> set[str]:
    """Collect every string a frozen-term selector may use to address ``term``.

    The result holds the model-order ``index`` rendered as text, the term's
    class name (bare and module-qualified), its optional ``name`` and
    ``label`` attributes, and the name, kind and label of each parameter
    block the term exposes. Empty strings are never part of the result.
    """
    cls = type(term)
    found: set[str] = {
        str(index),
        cls.__name__,
        f"{cls.__module__}.{cls.__name__}",
    }
    for attribute in ("name", "label"):
        attribute_value = getattr(term, attribute, None)
        if attribute_value is not None:
            found.add(str(attribute_value))
    for block in term.parameter_blocks():
        # Blocks without a label contribute "" here, which is dropped below.
        found.update((block.name, block.kind, block.label or ""))
    found.discard("")
    return found


def _resolve_frozen_terms(
    model: UFPModel,
    *,
    frozen_terms: Sequence[FrozenTermSelector],
    frozen_term_filter: FrozenTermFilter | None,
) -> tuple[torch.nn.Module, ...]:
    """Resolve user selectors into unique terms from ``model`` in model order.

    Each selector may be a module (matched by identity against
    ``model.terms``), an integer model-order index (Python ``int`` or NumPy
    integer), or any other value, which is converted with ``str`` and matched
    against ``_term_selector_labels``. Terms accepted by
    ``frozen_term_filter`` are added to the selection as well.

    Returns:
        The selected terms, each exactly once, in ``model.terms`` order.

    Raises:
        TypeError: If a selector is a ``bool`` (ambiguous with index 0/1).
        ValueError: If a module selector is not in ``model.terms``, an index
            is out of range, or a label selector matches no term.
    """
    terms = tuple(model.terms)
    selected_ids: set[int] = set()
    # Label sets are only needed for string selectors; build them lazily, once.
    labels_by_position: list[set[str]] | None = None

    if frozen_term_filter is not None:
        selected_ids.update(
            id(term) for term in terms if bool(frozen_term_filter(term))
        )

    for selector in frozen_terms:
        if isinstance(selector, torch.nn.Module):
            matched = [term for term in terms if term is selector]
            if not matched:
                raise ValueError("frozen term selector is not part of `model.terms`")
        elif isinstance(selector, bool):
            # ``bool`` subclasses ``int``; treating True/False as indices 1/0
            # would silently freeze the wrong term.
            raise TypeError(
                "frozen term selector must not be a bool; use an integer index"
            )
        elif isinstance(selector, (int, np.integer)):
            index = int(selector)
            if not 0 <= index < len(terms):
                raise ValueError(
                    f"frozen term index {index} is outside [0, {len(terms)})"
                )
            matched = [terms[index]]
        else:
            value = str(selector)
            if labels_by_position is None:
                labels_by_position = [
                    _term_selector_labels(term, index=position)
                    for position, term in enumerate(terms)
                ]
            matched = [
                term
                for term, labels in zip(terms, labels_by_position)
                if value in labels
            ]
            if not matched:
                raise ValueError(f"frozen term selector {value!r} did not match a term")
        selected_ids.update(id(term) for term in matched)

    # Emit each selected module once, in model order, even if ``model.terms``
    # contains the same module object more than once.
    selected: list[torch.nn.Module] = []
    emitted: set[int] = set()
    for term in terms:
        key = id(term)
        if key in selected_ids and key not in emitted:
            emitted.add(key)
            selected.append(term)
    return tuple(selected)


def _term_indices(
    model: UFPModel,
    selected_terms: Sequence[torch.nn.Module],
) -> tuple[int, ...]:
    """Map each selected term to its position in ``model.terms``.

    Terms are matched by object identity. If a module occurs more than once
    in ``model.terms``, its last position is reported.

    Raises:
        ValueError: If a selected term is not one of ``model.terms``; the
            underlying ``KeyError`` is chained as the cause.
    """
    position_of: dict[int, int] = {}
    for position, term in enumerate(model.terms):
        position_of[id(term)] = position

    def lookup(term: torch.nn.Module) -> int:
        try:
            return position_of[id(term)]
        except KeyError as missing:
            raise ValueError(
                "selected frozen term is not part of `model.terms`"
            ) from missing

    return tuple(lookup(term) for term in selected_terms)


def _state_signature(
    terms: Sequence[torch.nn.Module],
    term_indices: Sequence[int] | None = None,
) -> str:
    """Return a stable digest of selected frozen-term module state."""
    indices = (
        tuple(range(len(terms)))
        if term_indices is None
        else tuple(int(index) for index in term_indices)
    )
    if len(indices) != len(terms):
        raise ValueError("`term_indices` must match `terms` length")
    hasher = hashlib.sha256()
    for term_index, term in zip(indices, terms, strict=True):
        hasher.update(str(term_index).encode("utf8"))
        hasher.update(type(term).__module__.encode("utf8"))
        hasher.update(type(term).__name__.encode("utf8"))
        state = term.state_dict()
        for name in sorted(state):
            tensor = state[name].detach().cpu().contiguous()
            hasher.update(name.encode("utf8"))
            hasher.update(str(tensor.dtype).encode("utf8"))
            hasher.update(
                torch.tensor(tensor.shape, dtype=torch.int64).numpy().tobytes()
            )
            hasher.update(tensor.numpy().tobytes())
    return hasher.hexdigest()


def _as_numpy_or_none(value: torch.Tensor | np.ndarray | None) -> np.ndarray | None:
    """Return a CPU numpy array for tensor-like values, preserving ``None``."""
    if value is None:
        return None
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


class HybridLinearFitter:
    """Residualize frozen nonlinear terms, then delegate to ``LinearFitter``.

    Frozen terms are evaluated on the same ASE structures and target
    components as least-squares fitting. The resulting residual ``FitSample``
    objects are passed to the ordinary linear fitter so dense,
    normal-equation, cache, and CG solve paths remain centralized.
    """

    # Keyword options of the delegated fitter that name cache locations. Every
    # delegating entry point rescopes all of them by frozen-term state, so a
    # cache built for one set of frozen parameters is never reused for another.
    _CACHE_DIRECTORY_KEYS = ("cache_directory", "normal_equation_cache_directory")

    def __init__(
        self,
        model: UFPModel,
        *,
        frozen_terms: Sequence[FrozenTermSelector] = (),
        frozen_term_filter: FrozenTermFilter | None = None,
        **linear_options,
    ) -> None:
        """Store the residual terms and create the delegated linear fitter.

        Args:
            model: Model whose linear blocks are fitted and whose
                ``terms`` the frozen selectors refer to.
            frozen_terms: Index, label, or module selectors naming terms whose
                outputs are subtracted from the targets instead of fitted.
            frozen_term_filter: Optional predicate selecting additional frozen
                terms from ``model.terms``.
            **linear_options: Forwarded unchanged to ``LinearFitter``.

        Raises:
            ValueError: If a selector does not resolve to a term, or a frozen
                term exposes blocks in the delegated fitter's layout.
        """
        self.model = model
        self.linear_fitter = LinearFitter(model, **linear_options)
        self.frozen_terms = _resolve_frozen_terms(
            model,
            frozen_terms=tuple(frozen_terms),
            frozen_term_filter=frozen_term_filter,
        )
        self._frozen_term_indices = _term_indices(model, self.frozen_terms)
        self._validate_frozen_terms_are_not_linear_solve_terms()
        # The frozen model wraps the very same module objects as ``model``.
        self.frozen_model = (
            None
            if not self.frozen_terms
            else UFPModel(
                terms=self.frozen_terms,
                atomic_types=model.atomic_types,
                neighbor_backend=model.neighbor_backend,
            )
        )

    def _validate_frozen_terms_are_not_linear_solve_terms(self) -> None:
        """Reject residual terms that also expose fitted linear coefficient blocks.

        Raises:
            ValueError: Listing every layout block owned by a frozen term.
        """
        frozen_ids = {id(term) for term in self.frozen_terms}
        # Labels may be missing (``None``/empty); fall back to the block name so
        # the error message can always be built instead of failing in ``join``.
        linear_blocks = [
            str(block.label or getattr(block, "name", None) or "<unnamed block>")
            for block in self.linear_fitter.layout.blocks
            if id(block.term) in frozen_ids
        ]
        if linear_blocks:
            joined = ", ".join(linear_blocks)
            raise ValueError(
                "hybrid residual terms must not expose least-squares parameter "
                f"blocks; use LinearFitter freeze_blocks for linear terms: {joined}"
            )

    @property
    def layout(self):
        """Return the delegated linear fitter parameter layout."""
        return self.linear_fitter.layout

    @property
    def frozen_term_indices(self) -> tuple[int, ...]:
        """Return model-order indices for the selected frozen residual terms."""
        return self._frozen_term_indices

    @property
    def frozen_terms_state_hash(self) -> str:
        """Return a strict state digest for selected frozen residual terms."""
        return _state_signature(self.frozen_terms, self._frozen_term_indices)

    @property
    def frozen_terms_signature(self) -> str:
        """Return the hybrid cache state digest.

        This alias is kept for callers that used the original name. New code
        should prefer :attr:`frozen_terms_state_hash`, matching residual
        materialization metadata.
        """
        return self.frozen_terms_state_hash

    def residualized_samples(
        self,
        samples: Sequence[FitSample],
    ) -> tuple[FitSample, ...]:
        """Return samples with frozen-term target contributions subtracted.

        The frozen model is moved to the delegated fitter's device and dtype
        and evaluated in eval mode. The frozen modules are shared with
        ``self.model``, so (presumably, if ``UFPModel.to`` acts in place like
        ``torch.nn.Module.to``) the device/dtype change persists on them. Their
        previous train/eval flags are restored once evaluation finishes.

        Raises:
            ValueError: If ``samples`` is empty, or if the frozen terms do not
                return an output component that is being fitted.
        """
        items = tuple(samples)
        if not items:
            raise ValueError("`samples` must contain at least one FitSample")
        if self.frozen_model is None:
            return items

        fitter = self.linear_fitter
        needs_forces = fitter.fit_forces
        self.frozen_model.to(device=fitter.device, dtype=fitter.dtype)
        # Remember the caller's train/eval flags so forcing eval mode for the
        # residual evaluation does not leak into ``self.model``.
        previous_modes = [
            (module, module.training)
            for term in self.frozen_terms
            for module in term.modules()
        ]
        self.frozen_model.eval()
        try:
            residuals: list[FitSample] = []
            for sample in items:
                # Ask for energy-derived forces only when forces are fitted,
                # labeled on this sample, and not produced directly.
                derive_forces = (
                    needs_forces
                    and sample.forces is not None
                    and not self.frozen_model.provides_forces()
                )
                output = self.frozen_model.compute(
                    sample.atoms,
                    neighbor_list=sample.neighbor_list,
                    device=fitter.device,
                    dtype=fitter.dtype,
                    derive_forces=derive_forces,
                )
                residuals.append(self._residualized_sample(sample, output))
        finally:
            for module, was_training in previous_modes:
                module.training = was_training
        return tuple(residuals)

    def _residualized_sample(
        self,
        sample: FitSample,
        output: UFPOutput,
    ) -> FitSample:
        """Subtract one frozen-model output from one labeled fit sample.

        Only components that the delegated fitter fits and the sample labels
        are residualized; other labels are copied through unchanged.

        Raises:
            ValueError: If a fitted, labeled component is missing from ``output``.
        """
        energy = sample.energy
        if self.linear_fitter.fit_energy and sample.energy is not None:
            if output.energy is None:
                raise ValueError("frozen terms did not return requested energy output")
            energy_values = _as_numpy_or_none(output.energy)
            assert energy_values is not None
            energy = float(sample.energy) - float(energy_values.reshape(-1)[0])

        forces = sample.forces
        if self.linear_fitter.fit_forces and sample.forces is not None:
            if output.forces is None:
                raise ValueError("frozen terms did not return requested force output")
            force_values = _as_numpy_or_none(output.forces)
            assert force_values is not None
            forces = np.asarray(sample.forces, dtype=float) - np.asarray(
                force_values,
                dtype=float,
            )

        per_atom_energy = sample.per_atom_energy
        if (
            self.linear_fitter.fit_per_atom_energy
            and sample.per_atom_energy is not None
        ):
            if output.per_atom_energy is None:
                raise ValueError(
                    "frozen terms did not return requested per-atom energy output"
                )
            per_atom_values = _as_numpy_or_none(output.per_atom_energy)
            assert per_atom_values is not None
            per_atom_energy = np.asarray(sample.per_atom_energy, dtype=float).reshape(
                -1
            ) - np.asarray(per_atom_values, dtype=float).reshape(-1)

        # NOTE(review): only the fields listed here are carried over; confirm
        # FitSample has no other fields that should be preserved.
        return FitSample(
            atoms=sample.atoms.copy(),
            neighbor_list=sample.neighbor_list,
            energy=energy,
            forces=forces,
            per_atom_energy=per_atom_energy,
            energy_weight=sample.energy_weight,
            force_weight=sample.force_weight,
            per_atom_weight=sample.per_atom_weight,
        )

    def _hybrid_cache_directory(
        self,
        cache_directory: Path | str | None,
    ) -> Path | str | None:
        """Scope delegated fitter caches by frozen-term state."""
        if cache_directory is None or not self.frozen_terms:
            return cache_directory
        return Path(cache_directory) / f"hybrid_{self.frozen_terms_state_hash}"

    def _scoped_cache_kwargs(self, kwargs: dict) -> dict:
        """Rescope every cache-directory option in ``kwargs`` in place and return it.

        Callers invoke this after ``residualized_samples`` so the state hash
        reflects the frozen modules as they are after the device/dtype move.
        """
        for key in self._CACHE_DIRECTORY_KEYS:
            if key in kwargs:
                kwargs[key] = self._hybrid_cache_directory(kwargs[key])
        return kwargs

    def build_problem(self, samples: Sequence[FitSample], **kwargs):
        """Build a residualized linear problem through the delegated fitter."""
        residual_samples = self.residualized_samples(samples)
        return self.linear_fitter.build_problem(
            residual_samples,
            **self._scoped_cache_kwargs(kwargs),
        )

    def make_linear_operator(self, samples: Sequence[FitSample], **kwargs):
        """Alias ``build_problem`` for matrix-free callers."""
        return self.build_problem(samples, **kwargs)

    def materialize_design_matrix(self, samples: Sequence[FitSample], **kwargs):
        """Build the explicit residualized design matrix and target vector."""
        residual_samples = self.residualized_samples(samples)
        return self.linear_fitter.materialize_design_matrix(
            residual_samples,
            **self._scoped_cache_kwargs(kwargs),
        )

    def accumulate_normal_equations(self, samples: Sequence[FitSample], **kwargs):
        """Accumulate normal equations from residualized samples."""
        residual_samples = self.residualized_samples(samples)
        return self.linear_fitter.accumulate_normal_equations(
            residual_samples,
            **self._scoped_cache_kwargs(kwargs),
        )

    def fit(self, samples: Sequence[FitSample], **kwargs):
        """Fit selected linear coefficients against residualized targets."""
        residual_samples = self.residualized_samples(samples)
        return self.linear_fitter.fit(
            residual_samples,
            **self._scoped_cache_kwargs(kwargs),
        )

    def write_back(self, theta: torch.Tensor) -> None:
        """Write fitted coefficients through the delegated linear fitter."""
        self.linear_fitter.write_back(theta)
# Public names exported by ``from <this module> import *``.
__all__ = [ "FrozenTermFilter", "FrozenTermSelector", "HybridLinearFitter", ]