"""Workflow helpers for physical pair priors and two-body warm starts."""
from __future__ import annotations
import math
from collections.abc import Sequence
from dataclasses import dataclass
import torch
from ufp.splines import UniformSpline1DFitResult, fit_uniform_spline_1d
from ufp.terms import (
PowerLawRepulsionTerm,
RepulsiveSplineTwoBodyTerm,
SplinePairTerm,
SplineTwoBodyTerm,
UFPModel,
)
@dataclass(frozen=True)
class _PairCoefficientRow:
    """Immutable record of one pair's spline coefficients and its grid settings.

    Instances are built from source terms and later compared against a target
    term's grid metadata before the coefficients are copied.
    """

    # Pair identifier the coefficients belong to.
    pair: tuple[int, int]
    # Coefficient tensor for this pair.
    coeffs: torch.Tensor
    # Cutoff of the term the row was read from.
    cutoff: float
    # Spline family name of the term the row was read from.
    spline: str
    # Full-support start of the term the row was read from.
    full_support_start: float
    # Symmetry flag of the term the row was read from.
    symmetric: bool
def _term_grid_metadata(term) -> tuple[float, str, float, bool]:
    """Extract the grid settings of a two-body term as plain Python values.

    Args:
        term: Object exposing ``cutoff``, ``spline``, ``full_support_start``
            and ``symmetric`` attributes.

    Returns:
        ``(cutoff, spline, full_support_start, symmetric)`` coerced to
        ``float``, ``str``, ``float`` and ``bool`` respectively.
    """
    cutoff = float(term.cutoff)
    spline_name = str(term.spline)
    support_start = float(term.full_support_start)
    is_symmetric = bool(term.symmetric)
    return cutoff, spline_name, support_start, is_symmetric
def _isclose(first: float, second: float) -> bool:
    """Compare two grid scalars with a tight relative and absolute tolerance.

    Both inputs are converted with ``float`` before the comparison.

    Returns:
        ``True`` when ``math.isclose`` accepts the pair at a tolerance of
        ``1e-12`` (used for both ``rel_tol`` and ``abs_tol``).
    """
    tolerance = 1.0e-12
    lhs = float(first)
    rhs = float(second)
    return math.isclose(lhs, rhs, rel_tol=tolerance, abs_tol=tolerance)
def _validate_compatible_row(
    source: _PairCoefficientRow,
    target_term,
    target_shape: tuple[int, ...],
) -> None:
    """Check that a source row can be copied into ``target_term``.

    The checks run in a fixed order (cutoff, spline family, full-support
    start, symmetry, coefficient shape) and only the first failure is reported.

    Args:
        source: Source coefficient row with its grid metadata.
        target_term: Term whose grid metadata the row must match.
        target_shape: Shape the source coefficients must have.

    Raises:
        ValueError: If any grid setting or the coefficient shape differs.
    """
    target_cutoff, target_spline, target_start, target_symmetric = (
        _term_grid_metadata(target_term)
    )

    problem = None
    if not _isclose(source.cutoff, target_cutoff):
        problem = (
            f"cutoff mismatch for pair {source.pair}: "
            f"{source.cutoff} != {target_cutoff}"
        )
    elif source.spline != target_spline:
        problem = (
            f"spline mismatch for pair {source.pair}: "
            f"{source.spline!r} != {target_spline!r}"
        )
    elif not _isclose(source.full_support_start, target_start):
        problem = (
            f"full-support start mismatch for pair {source.pair}: "
            f"{source.full_support_start} != {target_start}"
        )
    elif source.symmetric != target_symmetric:
        problem = (
            f"symmetry mismatch for pair {source.pair}: "
            f"{source.symmetric} != {target_symmetric}"
        )
    elif tuple(source.coeffs.shape) != tuple(target_shape):
        problem = (
            f"coefficient shape mismatch for pair {source.pair}: "
            f"{tuple(source.coeffs.shape)} != {tuple(target_shape)}"
        )

    if problem is not None:
        raise ValueError(problem)
def _as_twobody_terms(value) -> tuple[torch.nn.Module, ...]:
    """Normalize a model, a single term, or a sequence of terms into a tuple.

    Args:
        value: A ``UFPModel`` (its ``pair_terms`` are used), one
            ``SplinePairTerm`` / ``SplineTwoBodyTerm`` /
            ``RepulsiveSplineTwoBodyTerm``, or any ``Sequence``.

    Returns:
        Tuple of terms. Elements of a passed sequence are not type-checked
        here.

    Raises:
        TypeError: If ``value`` is none of the accepted kinds.
    """
    single_term_types = (
        SplinePairTerm,
        SplineTwoBodyTerm,
        RepulsiveSplineTwoBodyTerm,
    )
    if isinstance(value, UFPModel):
        terms = tuple(value.pair_terms)
    elif isinstance(value, single_term_types):
        terms = (value,)
    elif isinstance(value, Sequence):
        terms = tuple(value)
    else:
        raise TypeError(
            "`source` and `target` must be UFPModel, SplinePairTerm, "
            "SplineTwoBodyTerm, RepulsiveSplineTwoBodyTerm, or a sequence of such terms"
        )
    return terms
def _iter_source_rows(source) -> tuple[_PairCoefficientRow, ...]:
    """Gather one coefficient row per active pair found in ``source``.

    A ``SplinePairTerm`` contributes a single row and is skipped when its
    ``enabled`` flag is falsy. Categorized two-body terms contribute one row
    for each entry of ``active_pair_categories``. All coefficient tensors are
    detached from the autograd graph.

    Raises:
        TypeError: If any resolved term is not a supported two-body spline
            term.
    """

    def build_row(term, pair, coeffs) -> _PairCoefficientRow:
        # Grid settings are coerced the same way for every term kind.
        cutoff, spline, support_start, symmetric = _term_grid_metadata(term)
        return _PairCoefficientRow(
            pair=pair,
            coeffs=coeffs,
            cutoff=cutoff,
            spline=spline,
            full_support_start=support_start,
            symmetric=symmetric,
        )

    collected: list[_PairCoefficientRow] = []
    for term in _as_twobody_terms(source):
        if isinstance(term, SplinePairTerm):
            if term.enabled:
                collected.append(
                    build_row(term, term.pair, term.true_coeffs.detach())
                )
            continue
        if not isinstance(term, (SplineTwoBodyTerm, RepulsiveSplineTwoBodyTerm)):
            raise TypeError("all copied terms must be two-body spline terms")
        table = term.true_coeffs_by_pair.detach()
        collected.extend(
            build_row(
                term,
                pair,
                table[term.pair_category_index(pair[0], pair[1])],
            )
            for pair in term.active_pair_categories
        )
    return tuple(collected)
def _source_row_map(source) -> dict[tuple[int, int], _PairCoefficientRow]:
    """Build a pair -> row lookup from the rows found in ``source``.

    Raises:
        ValueError: If two source rows refer to the same pair, since it would
            be ambiguous which one to copy.
    """
    lookup: dict[tuple[int, int], _PairCoefficientRow] = {}
    for candidate in _iter_source_rows(source):
        key = candidate.pair
        if key not in lookup:
            lookup[key] = candidate
            continue
        raise ValueError(f"multiple source coefficient rows for pair {candidate.pair}")
    return lookup
[docs]
def project_pair_prior_to_twobody(
    prior_term: PowerLawRepulsionTerm,
    *,
    coeff_size: int,
    spline: str = "cubic",
    full_support_start: float,
    n_samples: int | None = None,
    derivative_weight: float = 1.0,
    trainable: bool = True,
    fittable: bool = True,
) -> tuple[SplineTwoBodyTerm, dict[tuple[int, int], UniformSpline1DFitResult]]:
    """
    Fit a spline two-body term to a pair-dependent power-law prior.

    For every active pair category of ``prior_term`` the function
    ``prefactor / max(r, eps) ** power`` is fitted with
    ``fit_uniform_spline_1d`` on ``[full_support_start, cutoff]``. The fitted
    coefficients fill the matching row of a new ``SplineTwoBodyTerm``; rows of
    pair categories that are not active stay zero.

    Args:
        prior_term: Fitted analytic prior term to project.
        coeff_size: Number of 1D spline coefficients per pair category.
        spline: Spline family for the projected term.
        full_support_start: Lower distance where the projected spline has full
            support.
        n_samples: Optional number of midpoint samples per pair category.
        derivative_weight: Weight for radial-derivative projection rows.
        trainable: Whether projected spline coefficients require gradients.
        fittable: Whether the projected term exposes coefficients to
            least-squares fitters.

    Returns:
        The projected categorized two-body term and a dict mapping each active
        pair to its fit result.

    Raises:
        TypeError: If ``prior_term`` is not a ``PowerLawRepulsionTerm``.
        ValueError: If the prior has no cutoff or no atomic types, or if
            ``full_support_start`` is not smaller than the cutoff.
    """
    if not isinstance(prior_term, PowerLawRepulsionTerm):
        raise TypeError("`prior_term` must be PowerLawRepulsionTerm")
    if prior_term.cutoff is None:
        raise ValueError("`prior_term.cutoff` must be set")
    upper = float(prior_term.cutoff)
    lower = float(full_support_start)
    if lower >= upper:
        raise ValueError("`full_support_start` must be smaller than the cutoff")
    if prior_term.atomic_types is None:
        raise ValueError("`prior_term.atomic_types` must be set")

    prefactors = prior_term.true_prefactors_by_pair.detach()
    n_categories = len(prior_term.pair_categories)
    coeffs_by_pair = torch.zeros(
        (n_categories, int(coeff_size)),
        dtype=prefactors.dtype,
        device=prefactors.device,
    )

    def make_prior(prefactor):
        """Bind one pair's prefactor into the radial prior function."""

        def prior_function(distances):
            scale = prefactor.to(dtype=distances.dtype, device=distances.device)
            # Distances are clamped from below by eps before the power is taken.
            clamped = distances.clamp_min(prior_term.eps)
            return scale / torch.pow(clamped, prior_term.power)

        return prior_function

    # Everything except the target function is identical for every pair.
    fit_options = dict(
        coeff_size=coeff_size,
        lower_full_support=lower,
        upper_full_support=upper,
        spline=spline,
        n_samples=n_samples,
        derivative_weight=derivative_weight,
        dtype=prefactors.dtype,
        device=prefactors.device,
    )

    diagnostics: dict[tuple[int, int], UniformSpline1DFitResult] = {}
    for pair in prior_term.active_pair_categories:
        row = prior_term.pair_category_index(pair[0], pair[1])
        fit = fit_uniform_spline_1d(
            make_prior(prefactors[row].detach()),
            **fit_options,
        )
        coeffs_by_pair[row].copy_(fit.coeffs)
        diagnostics[pair] = fit

    projected = SplineTwoBodyTerm(
        cutoff=upper,
        atomic_types=prior_term.atomic_types,
        coeffs_by_pair=coeffs_by_pair,
        active_pairs=prior_term.active_pair_categories,
        symmetric=prior_term.symmetric,
        spline=spline,
        full_support_start=lower,
        eps=prior_term.eps,
        trainable=trainable,
        fittable=fittable,
    )
    return projected, diagnostics
[docs]
def copy_twobody_coefficients(source, target, *, strict: bool = True) -> int:
    """
    Copy matching two-body spline coefficient rows from ``source`` to ``target``.

    ``source`` and ``target`` may be models, individual two-body terms, or
    sequences of such terms. For categorized ``SplineTwoBodyTerm`` targets only
    the rows listed in ``active_pair_categories`` are copied.

    The copy is all-or-nothing: every target row is resolved and validated
    first, and coefficients are written only after all checks have passed, so
    a raised error never leaves ``target`` partially updated.

    Args:
        source: Where coefficient rows are read from.
        target: Where coefficient rows are written to.
        strict: When ``True``, a target row with no source row for the same
            pair raises ``ValueError``; when ``False`` such rows are skipped.
            Rows that do have a source row are always validated, regardless of
            ``strict``.

    Returns:
        Number of coefficient rows copied.

    Raises:
        ValueError: If a source row is missing (``strict=True`` only), if the
            source holds duplicate rows for one pair, or if grid metadata or
            coefficient shape differ between a source row and its target.
        TypeError: If a target term is neither ``SplinePairTerm`` nor
            ``SplineTwoBodyTerm``.
    """
    rows_by_pair = _source_row_map(source)

    # NOTE(review): RepulsiveSplineTwoBodyTerm is accepted as a source but, unless
    # it subclasses SplineTwoBodyTerm, is rejected as a target below - confirm
    # that this is intentional.
    # NOTE(review): target SplinePairTerm instances are not filtered by
    # `enabled`, while disabled source pair terms are skipped - confirm.

    # Phase 1: resolve and validate everything; only queue the writes.
    pending_writes = []
    copied = 0
    for term in _as_twobody_terms(target):
        if isinstance(term, SplinePairTerm):
            source_row = rows_by_pair.get(term.pair)
            if source_row is None:
                if strict:
                    raise ValueError(
                        f"missing source coefficients for pair {term.pair}"
                    )
                continue
            current = term.true_coeffs
            _validate_compatible_row(source_row, term, tuple(current.shape))
            # Align dtype/device with the target, as the two-body branch does.
            pending_writes.append(
                (term._write_true_coeffs, source_row.coeffs.to(current))
            )
            copied += 1
        elif isinstance(term, SplineTwoBodyTerm):
            # Work on a private copy so untouched rows keep their values.
            target_coeffs = term.true_coeffs_by_pair.detach().clone()
            changed = False
            for pair in term.active_pair_categories:
                source_row = rows_by_pair.get(pair)
                if source_row is None:
                    if strict:
                        raise ValueError(f"missing source coefficients for pair {pair}")
                    continue
                row_index = term.pair_category_index(pair[0], pair[1])
                _validate_compatible_row(
                    source_row,
                    term,
                    tuple(target_coeffs[row_index].shape),
                )
                target_coeffs[row_index].copy_(source_row.coeffs.to(target_coeffs))
                copied += 1
                changed = True
            if changed:
                pending_writes.append((term._write_true_coeffs_by_pair, target_coeffs))
        else:
            raise TypeError("all target terms must be two-body spline terms")

    # Phase 2: every check passed, so the targets can now be modified.
    for write, coeffs in pending_writes:
        write(coeffs)
    return copied
# Public API of this module; the underscore-prefixed helpers are not exported.
__all__ = [
    "copy_twobody_coefficients",
    "project_pair_prior_to_twobody",
]