# Source code for ufp.workflows.prior

"""Workflow helpers for physical pair priors and two-body warm starts."""

from __future__ import annotations

import math
from collections.abc import Sequence
from dataclasses import dataclass

import torch

from ufp.splines import UniformSpline1DFitResult, fit_uniform_spline_1d
from ufp.terms import (
    PowerLawRepulsionTerm,
    RepulsiveSplineTwoBodyTerm,
    SplinePairTerm,
    SplineTwoBodyTerm,
    UFPModel,
)


@dataclass(frozen=True)
class _PairCoefficientRow:
    """One source two-body coefficient row plus compatible grid metadata.

    Rows are produced by ``_iter_source_rows`` and checked against a target
    term by ``_validate_compatible_row`` before their coefficients are copied.

    NOTE(review): ``@dataclass`` also generates ``__eq__``, which would compare
    the ``coeffs`` tensors element-wise; rows are not compared with ``==`` in
    the code shown here, so avoid relying on row equality.
    """

    # Pair-category key; used as the lookup key for source rows and passed
    # element-wise to the target's ``pair_category_index``.
    pair: tuple[int, int]
    # Detached coefficients for this pair; for categorized source terms this is
    # an indexed view into the term's detached coefficient table.
    coeffs: torch.Tensor
    # Cutoff of the term this row was read from.
    cutoff: float
    # Spline family name of the term this row was read from.
    spline: str
    # Lower distance of full spline support in the originating term.
    full_support_start: float
    # Symmetry flag of the originating term.
    symmetric: bool


def _term_grid_metadata(term) -> tuple[float, str, float, bool]:
    """Read ``(cutoff, spline, full_support_start, symmetric)`` from ``term``.

    Any object exposing those four attributes works, so pair terms and
    categorized two-body terms share this helper. Values are coerced to the
    builtin ``float``/``str``/``float``/``bool`` types.
    """
    cutoff = float(term.cutoff)
    spline_name = str(term.spline)
    support_start = float(term.full_support_start)
    is_symmetric = bool(term.symmetric)
    return cutoff, spline_name, support_start, is_symmetric


def _isclose(first: float, second: float) -> bool:
    """Compare two grid scalars with a tight 1e-12 relative/absolute tolerance."""
    tolerance = 1.0e-12
    lhs, rhs = float(first), float(second)
    return math.isclose(lhs, rhs, rel_tol=tolerance, abs_tol=tolerance)


def _validate_compatible_row(
    source: _PairCoefficientRow,
    target_term,
    target_shape: tuple[int, ...],
) -> None:
    """Raise ``ValueError`` unless ``source`` can be copied into ``target_term``.

    Checks run in this order: cutoff, spline family, full-support start,
    symmetry flag, and finally the coefficient shape against ``target_shape``.
    Floating-point grid values are compared with ``_isclose``. The first
    mismatch is reported together with ``source.pair``.
    """
    cutoff, spline, start, symmetric = _term_grid_metadata(target_term)
    pair = source.pair

    if not _isclose(source.cutoff, cutoff):
        raise ValueError(
            f"cutoff mismatch for pair {pair}: {source.cutoff} != {cutoff}"
        )
    if source.spline != spline:
        raise ValueError(
            f"spline mismatch for pair {pair}: {source.spline!r} != {spline!r}"
        )
    if not _isclose(source.full_support_start, start):
        raise ValueError(
            f"full-support start mismatch for pair {pair}: "
            f"{source.full_support_start} != {start}"
        )
    if source.symmetric != symmetric:
        raise ValueError(
            f"symmetry mismatch for pair {pair}: {source.symmetric} != {symmetric}"
        )

    got_shape = tuple(source.coeffs.shape)
    want_shape = tuple(target_shape)
    if got_shape != want_shape:
        raise ValueError(
            f"coefficient shape mismatch for pair {pair}: "
            f"{got_shape} != {want_shape}"
        )


def _as_twobody_terms(value) -> tuple[torch.nn.Module, ...]:
    """Normalize ``value`` into a tuple of candidate two-body terms.

    Accepted inputs:

    * a ``UFPModel``, whose ``pair_terms`` are returned;
    * a single ``SplinePairTerm``, ``SplineTwoBodyTerm`` or
      ``RepulsiveSplineTwoBodyTerm``, returned as a one-element tuple;
    * a ``torch.nn.ModuleList`` or any other non-string sequence of terms.

    Elements of containers are not type-checked here; callers do that.

    Raises:
        TypeError: If ``value`` is none of the accepted kinds (including plain
            strings or bytes).
    """
    if isinstance(value, UFPModel):
        return tuple(value.pair_terms)
    if isinstance(
        value, (SplinePairTerm, SplineTwoBodyTerm, RepulsiveSplineTwoBodyTerm)
    ):
        return (value,)
    # ``torch.nn.ModuleList`` is iterable but is not registered as a
    # ``collections.abc.Sequence``. ``str``/``bytes`` *are* sequences but would
    # be exploded into characters and fail later with a misleading message.
    is_term_container = isinstance(
        value, (Sequence, torch.nn.ModuleList)
    ) and not isinstance(value, (str, bytes))
    if is_term_container:
        return tuple(value)
    raise TypeError(
        "`source` and `target` must be UFPModel, SplinePairTerm, "
        "SplineTwoBodyTerm, RepulsiveSplineTwoBodyTerm, or a sequence of such terms"
    )


def _iter_source_rows(source) -> tuple[_PairCoefficientRow, ...]:
    """Return detached coefficient rows for every active source pair.

    Disabled ``SplinePairTerm`` instances are skipped. Categorized terms
    (``SplineTwoBodyTerm`` / ``RepulsiveSplineTwoBodyTerm``) contribute one row
    per entry of ``active_pair_categories``. Each row carries the grid metadata
    of the term it was read from.

    Raises:
        TypeError: If any source term is not a two-body spline term.
    """

    def build_row(term, pair, coeffs) -> _PairCoefficientRow:
        cutoff, spline, start, symmetric = _term_grid_metadata(term)
        return _PairCoefficientRow(
            pair=pair,
            coeffs=coeffs,
            cutoff=cutoff,
            spline=spline,
            full_support_start=start,
            symmetric=symmetric,
        )

    collected: list[_PairCoefficientRow] = []
    for term in _as_twobody_terms(source):
        if isinstance(term, SplinePairTerm):
            if term.enabled:
                collected.append(
                    build_row(term, term.pair, term.true_coeffs.detach())
                )
            continue
        if not isinstance(term, (SplineTwoBodyTerm, RepulsiveSplineTwoBodyTerm)):
            raise TypeError("all copied terms must be two-body spline terms")
        table = term.true_coeffs_by_pair.detach()
        for pair in term.active_pair_categories:
            slot = term.pair_category_index(pair[0], pair[1])
            collected.append(build_row(term, pair, table[slot]))
    return tuple(collected)


def _source_row_map(source) -> dict[tuple[int, int], _PairCoefficientRow]:
    """Map each source pair to its single coefficient row.

    Raises:
        ValueError: If two source rows share the same pair, because it would
            be ambiguous which one should be copied.
    """
    by_pair: dict[tuple[int, int], _PairCoefficientRow] = {}
    for row in _iter_source_rows(source):
        # Every row is a fresh object, so a different object here means the
        # pair was already claimed by an earlier row.
        kept = by_pair.setdefault(row.pair, row)
        if kept is not row:
            raise ValueError(f"multiple source coefficient rows for pair {row.pair}")
    return by_pair


def _power_law_prior(prefactor: torch.Tensor, power, eps):
    """Return ``r -> prefactor / clamp_min(r, eps) ** power`` as a callable.

    ``prefactor`` is cast to the dtype/device of the distances passed in, so
    the callable can be sampled at whatever precision the spline fitter uses.
    Distances below ``eps`` are clamped to ``eps`` before the power is taken.
    """

    def prior_function(distances: torch.Tensor) -> torch.Tensor:
        scale = prefactor.to(dtype=distances.dtype, device=distances.device)
        return scale / torch.pow(distances.clamp_min(eps), power)

    return prior_function


def project_pair_prior_to_twobody(
    prior_term: PowerLawRepulsionTerm,
    *,
    coeff_size: int,
    spline: str = "cubic",
    full_support_start: float,
    n_samples: int | None = None,
    derivative_weight: float = 1.0,
    trainable: bool = True,
    fittable: bool = True,
) -> tuple[SplineTwoBodyTerm, dict[tuple[int, int], UniformSpline1DFitResult]]:
    """
    Project one pair-dependent analytic prior onto a spline two-body term.

    Each active pair category of ``prior_term`` is fitted independently with
    ``fit_uniform_spline_1d`` between ``full_support_start`` and the prior's
    cutoff. Inactive categories keep all-zero coefficients.

    Args:
        prior_term: Fitted analytic prior term to project.
        coeff_size: Number of 1D spline coefficients per pair category.
        spline: Spline family for the projected term.
        full_support_start: Lower distance where the projected spline has full
            support.
        n_samples: Optional number of midpoint samples per pair category.
        derivative_weight: Weight for radial-derivative projection rows.
        trainable: Whether projected spline coefficients require gradients.
        fittable: Whether the projected term exposes coefficients to
            least-squares fitters.

    Returns:
        The projected categorized two-body term and per-active-pair diagnostics.

    Raises:
        TypeError: If ``prior_term`` is not a ``PowerLawRepulsionTerm``.
        ValueError: If ``prior_term`` has no cutoff or no atomic types, if
            ``full_support_start`` is not below the cutoff, or if
            ``coeff_size`` is smaller than one.
    """
    if not isinstance(prior_term, PowerLawRepulsionTerm):
        raise TypeError("`prior_term` must be PowerLawRepulsionTerm")
    if prior_term.cutoff is None:
        raise ValueError("`prior_term.cutoff` must be set")
    cutoff = float(prior_term.cutoff)
    full_support_start = float(full_support_start)
    if full_support_start >= cutoff:
        raise ValueError("`full_support_start` must be smaller than the cutoff")
    if prior_term.atomic_types is None:
        raise ValueError("`prior_term.atomic_types` must be set")
    # Normalize once so the coefficient table and every per-pair fit agree on
    # the same integer size.
    coeff_size = int(coeff_size)
    if coeff_size < 1:
        raise ValueError("`coeff_size` must be a positive integer")

    # Detached so the fitted coefficients carry no autograd link back to the
    # prior's prefactors.
    prefactors = prior_term.true_prefactors_by_pair.detach()
    coeffs_by_pair = torch.zeros(
        (len(prior_term.pair_categories), coeff_size),
        dtype=prefactors.dtype,
        device=prefactors.device,
    )
    diagnostics: dict[tuple[int, int], UniformSpline1DFitResult] = {}
    for pair in prior_term.active_pair_categories:
        pair_index = prior_term.pair_category_index(pair[0], pair[1])
        fit = fit_uniform_spline_1d(
            _power_law_prior(prefactors[pair_index], prior_term.power, prior_term.eps),
            coeff_size=coeff_size,
            lower_full_support=full_support_start,
            upper_full_support=cutoff,
            spline=spline,
            n_samples=n_samples,
            derivative_weight=derivative_weight,
            dtype=prefactors.dtype,
            device=prefactors.device,
        )
        coeffs_by_pair[pair_index].copy_(fit.coeffs)
        diagnostics[pair] = fit

    projected = SplineTwoBodyTerm(
        cutoff=cutoff,
        atomic_types=prior_term.atomic_types,
        coeffs_by_pair=coeffs_by_pair,
        active_pairs=prior_term.active_pair_categories,
        symmetric=prior_term.symmetric,
        spline=spline,
        full_support_start=full_support_start,
        eps=prior_term.eps,
        trainable=trainable,
        fittable=fittable,
    )
    return projected, diagnostics
def copy_twobody_coefficients(source, target, *, strict: bool = True) -> int:
    """
    Copy matching two-body spline coefficient rows from ``source`` to ``target``.

    ``source`` and ``target`` may be models, individual two-body terms, or
    sequences of such terms. Only active target rows are copied.

    * Disabled ``SplinePairTerm`` targets are skipped.
    * Inactive pair categories of ``SplineTwoBodyTerm`` targets are skipped.

    This mirrors how disabled or inactive rows are ignored on the source side.

    With ``strict=True``, every active target row must exist in the source.
    Rows that are copied must always have matching spline metadata and shape.
    All rows are looked up and validated before any target is modified, so an
    error leaves every target untouched.

    Args:
        source: Model, term, or sequence of terms providing coefficients.
        target: Model, term, or sequence of terms receiving coefficients.
        strict: Whether a missing source row for an active target pair raises.

    Returns:
        Number of coefficient rows copied.

    Raises:
        ValueError: If the source has duplicate rows for a pair, if ``strict``
            and an active target pair has no source row, or if a source row's
            grid metadata or shape does not match its target.
        TypeError: If a source or target element is not a supported two-body
            spline term.
    """
    rows_by_pair = _source_row_map(source)

    def find_row(pair):
        row = rows_by_pair.get(pair)
        if row is None and strict:
            raise ValueError(f"missing source coefficients for pair {pair}")
        return row

    # (writer, coefficients) pairs, applied only after every row validated.
    pending_writes = []
    copied = 0
    for term in _as_twobody_terms(target):
        if isinstance(term, SplinePairTerm):
            if not term.enabled:
                continue
            source_row = find_row(term.pair)
            if source_row is None:
                continue
            current = term.true_coeffs
            _validate_compatible_row(source_row, term, tuple(current.shape))
            # Match the target's dtype/device and take a private copy so the
            # target never aliases the source's storage.
            pending_writes.append(
                (term._write_true_coeffs, source_row.coeffs.to(current, copy=True))
            )
            copied += 1
        elif isinstance(term, SplineTwoBodyTerm):
            # NOTE(review): RepulsiveSplineTwoBodyTerm is accepted as a source
            # but reaches this branch only if it subclasses SplineTwoBodyTerm.
            target_coeffs = term.true_coeffs_by_pair.detach().clone()
            changed = False
            for pair in term.active_pair_categories:
                source_row = find_row(pair)
                if source_row is None:
                    continue
                row_index = term.pair_category_index(pair[0], pair[1])
                _validate_compatible_row(
                    source_row,
                    term,
                    tuple(target_coeffs[row_index].shape),
                )
                target_coeffs[row_index].copy_(source_row.coeffs.to(target_coeffs))
                copied += 1
                changed = True
            if changed:
                pending_writes.append((term._write_true_coeffs_by_pair, target_coeffs))
        else:
            raise TypeError("all target terms must be two-body spline terms")

    for write, coeffs in pending_writes:
        write(coeffs)
    return copied
__all__ = [
    "copy_twobody_coefficients",
    "project_pair_prior_to_twobody",
]