"""Hybrid residualized linear fitting helpers."""
from __future__ import annotations
import hashlib
from collections.abc import Callable, Sequence
from pathlib import Path
import numpy as np
import torch
from ufp.core.output import UFPOutput
from ufp.leastsquares.dataset import FitSample
from ufp.leastsquares.linear import LinearFitter
from ufp.terms.model import UFPModel
# One frozen-term selector: a position in ``model.terms``, a string label
# matched against each term's selector labels, or the term module itself
# (matched by identity).
FrozenTermSelector = int | str | torch.nn.Module
# Predicate called once per model term; a truthy result marks the term frozen.
FrozenTermFilter = Callable[[torch.nn.Module], bool]
def _term_selector_labels(
    term: torch.nn.Module,
    *,
    index: int,
) -> set[str]:
    """Collect every string label that can be used to select ``term``.

    The result contains the stringified ``index``, the term's class name, its
    module-qualified class name, the term's ``name`` and ``label`` attributes
    when present and not ``None``, and the ``name``, ``kind`` and ``label`` of
    every block yielded by ``term.parameter_blocks()``. The empty string is
    never part of the result.
    """
    cls = type(term)
    aliases = {str(index), cls.__name__, f"{cls.__module__}.{cls.__name__}"}

    # Optional human-facing identifiers carried by the term itself.
    for attribute in ("name", "label"):
        attr_value = getattr(term, attribute, None)
        if attr_value is not None:
            aliases.add(str(attr_value))

    # Identifiers of the parameter blocks the term exposes.
    for block in term.parameter_blocks():
        aliases.add(block.name)
        aliases.add(block.kind)
        aliases.add(block.label or "")

    # A missing block label was mapped to "" above; it must not be selectable.
    aliases.discard("")
    return aliases
def _resolve_frozen_terms(
    model: UFPModel,
    *,
    frozen_terms: Sequence[FrozenTermSelector],
    frozen_term_filter: FrozenTermFilter | None,
) -> tuple[torch.nn.Module, ...]:
    """Turn selectors and an optional filter into the frozen terms of ``model``.

    Terms accepted by ``frozen_term_filter`` are combined with the terms
    addressed by each entry of ``frozen_terms``. A module selector is matched
    by identity, an integer selector is a position in ``model.terms`` (negative
    positions are rejected; ``bool`` values are integers and are treated as
    positions too), and any other selector is converted to ``str`` and compared
    against the labels from ``_term_selector_labels``.

    Returns the selected terms in ``model.terms`` order, each listed once per
    occurrence in ``model.terms``.

    Raises:
        ValueError: If a module selector is not in ``model.terms``, an index is
            out of range, or a string selector matches no term.
    """
    ordered_terms = tuple(model.terms)
    chosen: set[int] = set()

    if frozen_term_filter is not None:
        chosen.update(
            id(candidate)
            for candidate in ordered_terms
            if bool(frozen_term_filter(candidate))
        )

    for selector in frozen_terms:
        if isinstance(selector, torch.nn.Module):
            hits = [candidate for candidate in ordered_terms if candidate is selector]
            if not hits:
                raise ValueError("frozen term selector is not part of `model.terms`")
        elif isinstance(selector, int):
            position = int(selector)
            if not 0 <= position < len(ordered_terms):
                raise ValueError(
                    f"frozen term index {position} is outside [0, {len(ordered_terms)})"
                )
            hits = [ordered_terms[position]]
        else:
            text = str(selector)
            hits = [
                candidate
                for position, candidate in enumerate(ordered_terms)
                if text in _term_selector_labels(candidate, index=position)
            ]
            if not hits:
                raise ValueError(f"frozen term selector {text!r} did not match a term")
        chosen.update(id(candidate) for candidate in hits)

    # Emit in model order regardless of the order selectors were given in.
    return tuple(candidate for candidate in ordered_terms if id(candidate) in chosen)
def _term_indices(
    model: UFPModel,
    selected_terms: Sequence[torch.nn.Module],
) -> tuple[int, ...]:
    """Map each selected term to its position in ``model.terms``.

    Terms are matched by object identity and the result follows the order of
    ``selected_terms``.

    Raises:
        ValueError: If a selected term is not one of ``model.terms``.
    """
    positions = {id(term): position for position, term in enumerate(model.terms)}

    def _position_of(term: torch.nn.Module) -> int:
        try:
            return positions[id(term)]
        except KeyError as missing:
            raise ValueError(
                "selected frozen term is not part of `model.terms`"
            ) from missing

    return tuple(_position_of(term) for term in selected_terms)
def _state_signature(
terms: Sequence[torch.nn.Module],
term_indices: Sequence[int] | None = None,
) -> str:
"""Return a stable digest of selected frozen-term module state."""
indices = (
tuple(range(len(terms)))
if term_indices is None
else tuple(int(index) for index in term_indices)
)
if len(indices) != len(terms):
raise ValueError("`term_indices` must match `terms` length")
hasher = hashlib.sha256()
for term_index, term in zip(indices, terms, strict=True):
hasher.update(str(term_index).encode("utf8"))
hasher.update(type(term).__module__.encode("utf8"))
hasher.update(type(term).__name__.encode("utf8"))
state = term.state_dict()
for name in sorted(state):
tensor = state[name].detach().cpu().contiguous()
hasher.update(name.encode("utf8"))
hasher.update(str(tensor.dtype).encode("utf8"))
hasher.update(
torch.tensor(tensor.shape, dtype=torch.int64).numpy().tobytes()
)
hasher.update(tensor.numpy().tobytes())
return hasher.hexdigest()
def _as_numpy_or_none(value: torch.Tensor | np.ndarray | None) -> np.ndarray | None:
    """Convert ``value`` to a numpy array, passing ``None`` through untouched.

    Torch tensors are detached and moved to the CPU before conversion; any
    other non-``None`` value goes through ``np.asarray``.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return None if value is None else np.asarray(value)
class HybridLinearFitter:
    """
    Linear fitter that first subtracts frozen nonlinear terms from the targets.

    Frozen terms are evaluated on the same ASE structures and target components
    as least-squares fitting. The residual ``FitSample`` objects are then handed
    to an ordinary ``LinearFitter`` so that dense, normal-equation, cache, and
    CG solve paths stay in one place.
    """

    def __init__(
        self,
        model: UFPModel,
        *,
        frozen_terms: Sequence[FrozenTermSelector] = (),
        frozen_term_filter: FrozenTermFilter | None = None,
        **linear_options,
    ) -> None:
        """Resolve the frozen terms and build the delegated linear fitter.

        ``linear_options`` are forwarded unchanged to ``LinearFitter``. When at
        least one term is frozen, a separate ``UFPModel`` holding only those
        terms is created; otherwise ``frozen_model`` is ``None``.

        Raises:
            ValueError: If a selector cannot be resolved, or a frozen term also
                owns a block in the linear fitter layout.
        """
        self.model = model
        self.linear_fitter = LinearFitter(model, **linear_options)
        self.frozen_terms = _resolve_frozen_terms(
            model,
            frozen_terms=tuple(frozen_terms),
            frozen_term_filter=frozen_term_filter,
        )
        self._frozen_term_indices = _term_indices(model, self.frozen_terms)
        self._validate_frozen_terms_are_not_linear_solve_terms()
        if self.frozen_terms:
            self.frozen_model = UFPModel(
                terms=self.frozen_terms,
                atomic_types=model.atomic_types,
                neighbor_backend=model.neighbor_backend,
            )
        else:
            self.frozen_model = None

    def _validate_frozen_terms_are_not_linear_solve_terms(self) -> None:
        """Raise ``ValueError`` if a frozen term owns a linear layout block."""
        frozen_ids = {id(term) for term in self.frozen_terms}
        offending: list = []
        for block in self.linear_fitter.layout.blocks:
            if id(block.term) in frozen_ids:
                offending.append(block.label)
        if not offending:
            return
        joined = ", ".join(offending)
        raise ValueError(
            "hybrid residual terms must not expose least-squares parameter "
            f"blocks; use LinearFitter freeze_blocks for linear terms: {joined}"
        )

    @property
    def layout(self):
        """Return the parameter layout of the delegated linear fitter."""
        return self.linear_fitter.layout

    @property
    def frozen_term_indices(self) -> tuple[int, ...]:
        """Return the ``model.terms`` positions of the frozen residual terms."""
        return self._frozen_term_indices

    @property
    def frozen_terms_state_hash(self) -> str:
        """Return a digest of the current state of the frozen residual terms."""
        return _state_signature(self.frozen_terms, self._frozen_term_indices)

    @property
    def frozen_terms_signature(self) -> str:
        """Return the hybrid cache state digest.

        This alias is kept for callers that used the original name. New code
        should prefer :attr:`frozen_terms_state_hash`, matching residual
        materialization metadata.
        """
        return self.frozen_terms_state_hash

    def residualized_samples(
        self,
        samples: Sequence[FitSample],
    ) -> tuple[FitSample, ...]:
        """Return samples whose targets have the frozen-term output removed.

        With no frozen terms the input samples are returned as a tuple without
        being copied. Otherwise the frozen model is moved to the linear
        fitter's device and dtype, put in eval mode, and evaluated once per
        sample.

        Raises:
            ValueError: If ``samples`` is empty, or the frozen model does not
                return an output that is being fitted.
        """
        batch = tuple(samples)
        if not batch:
            raise ValueError("`samples` must contain at least one FitSample")
        frozen = self.frozen_model
        if frozen is None:
            return batch

        fitter = self.linear_fitter
        want_forces = fitter.fit_forces
        frozen.to(device=fitter.device, dtype=fitter.dtype)
        frozen.eval()

        def _frozen_output(sample: FitSample) -> UFPOutput:
            # Only ask for derived forces when they are fitted, labeled on this
            # sample, and not already provided by the frozen model itself.
            derive = (
                want_forces
                and sample.forces is not None
                and not frozen.provides_forces()
            )
            return frozen.compute(
                sample.atoms,
                neighbor_list=sample.neighbor_list,
                device=fitter.device,
                dtype=fitter.dtype,
                derive_forces=derive,
            )

        return tuple(
            self._residualized_sample(sample, _frozen_output(sample))
            for sample in batch
        )

    def _residualized_sample(
        self,
        sample: FitSample,
        output: UFPOutput,
    ) -> FitSample:
        """Build a new sample with ``output`` subtracted from fitted targets.

        A target is only changed when the linear fitter fits it and the sample
        carries a label for it; otherwise it is passed through as is. Weights
        and the neighbor list are carried over and the atoms are copied.

        Raises:
            ValueError: If ``output`` lacks a quantity that must be subtracted.
        """
        fitter = self.linear_fitter

        residual_energy = sample.energy
        if fitter.fit_energy and sample.energy is not None:
            if output.energy is None:
                raise ValueError("frozen terms did not return requested energy output")
            frozen_energy = _as_numpy_or_none(output.energy)
            assert frozen_energy is not None
            # Only the first element of the frozen energy output is used.
            residual_energy = float(sample.energy) - float(
                frozen_energy.reshape(-1)[0]
            )

        residual_forces = sample.forces
        if fitter.fit_forces and sample.forces is not None:
            if output.forces is None:
                raise ValueError("frozen terms did not return requested force output")
            frozen_forces = _as_numpy_or_none(output.forces)
            assert frozen_forces is not None
            target_forces = np.asarray(sample.forces, dtype=float)
            residual_forces = target_forces - np.asarray(frozen_forces, dtype=float)

        residual_per_atom = sample.per_atom_energy
        if fitter.fit_per_atom_energy and sample.per_atom_energy is not None:
            if output.per_atom_energy is None:
                raise ValueError(
                    "frozen terms did not return requested per-atom energy output"
                )
            frozen_per_atom = _as_numpy_or_none(output.per_atom_energy)
            assert frozen_per_atom is not None
            target_per_atom = np.asarray(sample.per_atom_energy, dtype=float).reshape(-1)
            residual_per_atom = target_per_atom - np.asarray(
                frozen_per_atom,
                dtype=float,
            ).reshape(-1)

        return FitSample(
            atoms=sample.atoms.copy(),
            neighbor_list=sample.neighbor_list,
            energy=residual_energy,
            forces=residual_forces,
            per_atom_energy=residual_per_atom,
            energy_weight=sample.energy_weight,
            force_weight=sample.force_weight,
            per_atom_weight=sample.per_atom_weight,
        )

    def _hybrid_cache_directory(
        self,
        cache_directory: Path | str | None,
    ) -> Path | str | None:
        """Return a cache directory scoped by the frozen-term state digest.

        ``None`` and the no-frozen-terms case are returned unchanged; otherwise
        the result is the ``hybrid_<digest>`` child of ``cache_directory``.
        """
        if cache_directory is None:
            return None
        if not self.frozen_terms:
            return cache_directory
        return Path(cache_directory) / f"hybrid_{self.frozen_terms_state_hash}"

    def _scope_cache_kwargs(self, options: dict, keys: tuple[str, ...]) -> dict:
        """Rewrite the cache-directory entries of ``options`` named in ``keys``."""
        for key in keys:
            if key in options:
                options[key] = self._hybrid_cache_directory(options[key])
        return options

    def build_problem(self, samples: Sequence[FitSample], **kwargs):
        """Build a residualized linear problem through the delegated fitter."""
        # Residualize before scoping the cache: residualization moves the frozen
        # model to the fitter device/dtype, and the digest is taken afterwards.
        residuals = self.residualized_samples(samples)
        options = self._scope_cache_kwargs(kwargs, ("cache_directory",))
        return self.linear_fitter.build_problem(residuals, **options)

    def make_linear_operator(self, samples: Sequence[FitSample], **kwargs):
        """Alias ``build_problem`` for matrix-free callers."""
        return self.build_problem(samples, **kwargs)

    def materialize_design_matrix(self, samples: Sequence[FitSample], **kwargs):
        """Build the explicit residualized design matrix and target vector."""
        residuals = self.residualized_samples(samples)
        options = self._scope_cache_kwargs(kwargs, ("cache_directory",))
        return self.linear_fitter.materialize_design_matrix(residuals, **options)

    def accumulate_normal_equations(self, samples: Sequence[FitSample], **kwargs):
        """Accumulate normal equations from residualized samples."""
        residuals = self.residualized_samples(samples)
        options = self._scope_cache_kwargs(
            kwargs,
            ("cache_directory", "normal_equation_cache_directory"),
        )
        return self.linear_fitter.accumulate_normal_equations(residuals, **options)

    def fit(self, samples: Sequence[FitSample], **kwargs):
        """Fit selected linear coefficients against residualized targets."""
        residuals = self.residualized_samples(samples)
        options = self._scope_cache_kwargs(
            kwargs,
            ("cache_directory", "normal_equation_cache_directory"),
        )
        return self.linear_fitter.fit(residuals, **options)

    def write_back(self, theta: torch.Tensor) -> None:
        """Write fitted coefficients through the delegated linear fitter."""
        self.linear_fitter.write_back(theta)
# Public names exported by this module; the underscore-prefixed helpers above
# are intentionally left out.
__all__ = [
"FrozenTermFilter",
"FrozenTermSelector",
"HybridLinearFitter",
]