# Source code for ufp.terms.analytic

"""Autograd-friendly analytic pair prior terms."""

from __future__ import annotations

import hashlib
import math
from collections.abc import Mapping, Sequence

import torch

from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.terms._base import PairTerm
from ufp.terms._shared import (
    accumulate_pair_energies,
    empty_atomwise_output,
    pair_weight,
)
from ufp.terms.cutoffs import CutoffEnvelope, normalize_cutoff_envelope
from ufp.terms.twobody import _active_pair_mask, _canonical_pair, _pair_categories
from ufp.terms.zbl import (
    _COULOMB_EV_ANGSTROM,
    _SCREENING_LENGTH_FACTOR,
    _ZBL_COEFFS,
    _ZBL_EXPONENTS,
)


def _pair_label(pair: tuple[int, int]) -> str:
    """Format one pair category for error messages."""
    return f"({pair[0]}, {pair[1]})"


def _validate_positive(name: str, value: float) -> float:
    """Validate a finite positive scalar."""
    value = float(value)
    if not math.isfinite(value) or value <= 0.0:
        raise ValueError(f"`{name}` must be a finite positive value")
    return value


def _validate_non_negative_tensor(name: str, tensor: torch.Tensor) -> None:
    """Reject non-finite or negative parameter initial values."""
    if not torch.all(torch.isfinite(tensor)):
        raise ValueError(f"`{name}` must contain only finite values")
    if torch.any(tensor < 0.0):
        raise ValueError(f"`{name}` must be non-negative")


def _validate_positive_tensor(name: str, tensor: torch.Tensor) -> None:
    """Reject non-finite or non-positive parameter initial values."""
    if not torch.all(torch.isfinite(tensor)):
        raise ValueError(f"`{name}` must contain only finite values")
    if torch.any(tensor <= 0.0):
        raise ValueError(f"`{name}` must be positive")


class _CategorizedAnalyticPairTerm(PairTerm):
    """Shared category, parameter, and cutoff handling for analytic priors.

    Subclasses set ``prior_kind`` and ``pair_parameter_names``, register their
    per-category parameters through ``_register_pair_parameter``, and
    implement ``_energy_from_category`` (used by ``forward``) and
    ``_radial_values_from_category`` (used by ``radial_values``).
    """

    prior_kind = "analytic_pair"
    pair_parameter_names: tuple[str, ...] = ()

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        cutoff_envelope: CutoffEnvelope | str | None = None,
    ) -> None:
        """Store pair categories and cutoff metadata.

        Args:
            cutoff: Finite positive cutoff; validated before being handed to
                the base class.
            atomic_types: Atomic types owned by this term; must be non-empty.
            active_pairs: Optional subset of pair categories to evaluate;
                forwarded to ``_active_pair_mask``.
            symmetric: Whether pair keys are canonicalized symmetrically.
            cutoff_envelope: Envelope spec normalized through
                ``normalize_cutoff_envelope``.

        Raises:
            ValueError: If ``cutoff`` is not finite and positive, or if
                ``atomic_types`` is empty after base-class handling.
        """
        cutoff = _validate_positive("cutoff", cutoff)
        super().__init__(cutoff=cutoff, atomic_types=atomic_types)
        if self.atomic_types is None or not self.atomic_types:
            raise ValueError("`atomic_types` must contain at least one element")
        self.symmetric = bool(symmetric)
        self.cutoff_envelope = normalize_cutoff_envelope(
            cutoff_envelope,
            cutoff=cutoff,
        )
        pair_categories = _pair_categories(self.atomic_types, symmetric=self.symmetric)
        # Plain Python containers are stored with object.__setattr__ so they
        # bypass any attribute interception done by the base class.
        object.__setattr__(self, "_pair_categories", pair_categories)
        object.__setattr__(
            self,
            "_pair_index",
            {pair: index for index, pair in enumerate(pair_categories)},
        )
        active_pair_mask = _active_pair_mask(
            pair_categories,
            active_pairs=active_pairs,
            symmetric=self.symmetric,
        )
        self.register_buffer(
            "active_pair_mask",
            active_pair_mask,
            persistent=False,
        )
        # Cache the enabled category indices as a tuple so hot paths can test
        # "any active pairs?" without touching the tensor.
        object.__setattr__(
            self,
            "_active_pair_indices",
            tuple(
                index
                for index, enabled in enumerate(active_pair_mask.tolist())
                if enabled
            ),
        )

    @property
    def pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return the full ordered list of pair categories owned by this term."""
        return self._pair_categories

    @property
    def active_pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return the subset of pair categories enabled for evaluation."""
        return tuple(self.pair_categories[index] for index in self._active_pair_indices)

    def canonical_pair(self, first: int, second: int) -> tuple[int, int]:
        """Normalize a pair key using this term's symmetry convention."""
        return _canonical_pair(first, second, symmetric=self.symmetric)

    def pair_category_index(self, first: int, second: int) -> int:
        """Return the category index for one pair.

        Raises:
            KeyError: If the canonicalized pair is not one of this term's
                pair categories.
        """
        pair = self.canonical_pair(first, second)
        try:
            return self._pair_index[pair]
        except KeyError as exc:
            raise KeyError(f"pair {pair} is not part of this term") from exc

    def covers_pair(self, first_atomic_number: int, second_atomic_number: int) -> bool:
        """Report whether this term evaluates the requested pair category.

        Unlike ``is_pair_active``, an unknown pair yields ``False`` instead of
        raising.
        """
        pair = self.canonical_pair(first_atomic_number, second_atomic_number)
        index = self._pair_index.get(pair)
        if index is None:
            return False
        return bool(self.active_pair_mask[index].item())

    def is_pair_active(self, first: int, second: int) -> bool:
        """Report whether a known pair category remains enabled.

        Raises:
            KeyError: If the pair is not part of this term (propagated from
                ``pair_category_index``).
        """
        return bool(
            self.active_pair_mask[self.pair_category_index(first, second)].item()
        )

    def _coerce_pair_values(
        self,
        name: str,
        values_by_pair,
        *,
        dtype: torch.dtype | None,
    ) -> torch.Tensor:
        """Normalize mapping or tensor values into pair-category order.

        Args:
            name: Parameter name used in error messages.
            values_by_pair: Either a mapping from two-element pair keys to
                scalar values, or anything ``torch.as_tensor`` accepts with
                one entry per pair category.
            dtype: Requested dtype, forwarded to the torch constructors.

        Returns:
            A 1-D tensor with one entry per pair category. For mapping input,
            categories absent from the mapping are left at zero.

        Raises:
            ValueError: On malformed or duplicate keys, unknown categories,
                missing active categories, wrong shape, or non-finite values
                (checked for mapping and tensor input alike).
        """
        if isinstance(values_by_pair, Mapping):
            tensor = torch.zeros(len(self.pair_categories), dtype=dtype)
            seen: set[tuple[int, int]] = set()
            for pair, value in values_by_pair.items():
                if len(pair) != 2:
                    raise ValueError(f"`{name}` mapping keys must contain two values")
                canonical = self.canonical_pair(pair[0], pair[1])
                # Two keys that canonicalize to the same category would
                # silently overwrite each other; reject them instead.
                if canonical in seen:
                    raise ValueError(
                        "duplicate parameter after canonicalization for "
                        f"{_pair_label(canonical)}"
                    )
                try:
                    pair_index = self._pair_index[canonical]
                except KeyError as exc:
                    raise ValueError(
                        f"unknown parameter pair category {_pair_label(canonical)}"
                    ) from exc
                tensor[pair_index] = torch.as_tensor(value, dtype=dtype)
                seen.add(canonical)

            # Only active categories are mandatory; inactive ones may be
            # omitted and keep the zero default.
            missing = [pair for pair in self.active_pair_categories if pair not in seen]
            if missing:
                raise ValueError(
                    f"missing `{name}` for active pair categories: "
                    + ", ".join(_pair_label(pair) for pair in missing)
                )
        else:
            tensor = torch.as_tensor(values_by_pair, dtype=dtype)
            # Allow a bare scalar when there is exactly one pair category.
            if tensor.ndim == 0 and len(self.pair_categories) == 1:
                tensor = tensor.reshape(1)
            if tensor.ndim != 1:
                raise ValueError(f"`{name}` must have shape (n_pairs,)")
            if tensor.shape[0] != len(self.pair_categories):
                raise ValueError(
                    f"`{name}.shape[0]` must equal {len(self.pair_categories)} "
                    f"for atomic_types={self.atomic_types}, got {tensor.shape[0]}"
                )

        # Shared by both input forms so NaN/inf supplied through a mapping is
        # rejected exactly like NaN/inf supplied through a tensor.
        if not torch.all(torch.isfinite(tensor)):
            raise ValueError(f"`{name}` must contain only finite values")
        return tensor

    def _register_pair_parameter(
        self,
        name: str,
        values_by_pair,
        *,
        dtype: torch.dtype | None,
        trainable: bool,
    ) -> torch.Tensor:
        """Register one pair-category parameter tensor.

        The coerced values are wrapped in a ``torch.nn.Parameter`` stored
        under ``name``; ``trainable`` controls ``requires_grad``.

        Returns:
            The coerced initial-value tensor (not the ``Parameter`` wrapper),
            so callers can run extra validation on it.

        NOTE(review): ``torch.as_tensor`` does not copy when given a tensor
        that already has the requested dtype, so the parameter may share
        storage with a tensor passed in by the caller — confirm this is
        intended.
        """
        tensor = self._coerce_pair_values(name, values_by_pair, dtype=dtype)
        parameter = torch.nn.Parameter(tensor, requires_grad=bool(trainable))
        setattr(self, name, parameter)
        return tensor

    def _active_initial_values(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return initial values for active pair categories only."""
        active_indices = torch.tensor(
            self._active_pair_indices,
            dtype=torch.int64,
            device=tensor.device,
        )
        return tensor.index_select(0, active_indices)

    def _pair_parameter(
        self,
        name: str,
        pair_category: torch.Tensor,
        inputs: UFPInput,
    ) -> torch.Tensor:
        """Return one pair-parameter vector for selected neighbor-list rows.

        The stored parameter is moved to the device and dtype of ``inputs``
        and gathered with ``pair_category`` along dimension 0.
        """
        values = getattr(self, name).to(device=inputs.device, dtype=inputs.dtype)
        return values.index_select(0, pair_category)

    def _selected_pairs(
        self,
        inputs: UFPInput,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
        """Return a full-list mask, category indices, and distances for active pairs.

        Returns:
            ``None`` when there is nothing to evaluate (no active categories,
            no neighbor rows in an active category, or no such row inside the
            cutoff). Otherwise a tuple ``(selected_mask, pair_category,
            distances)`` where ``selected_mask`` has the shape of the
            per-row category tensor and the other two hold one entry per
            selected row.

        Raises:
            RuntimeError: If ``inputs`` has no neighbor list, or if this term
                is asymmetric and the neighbor list is not a full list.
        """
        if inputs.neighbor_list is None:
            raise RuntimeError(f"{type(self).__name__} requires a neighbor list")
        if not self.symmetric and not inputs.neighbor_list.full_list:
            raise RuntimeError(
                f"asymmetric {type(self).__name__} terms require a full neighbor list"
            )
        if not self._active_pair_indices:
            return None

        pair_category = inputs.pair_category_indices(
            self.atomic_types,
            symmetric=self.symmetric,
        )
        # Rows with a negative category index are treated as not handled by
        # this term.
        handled_mask = pair_category >= 0
        # Only pay for the extra gather when some categories are disabled.
        if len(self._active_pair_indices) != len(self.pair_categories):
            active_pair_mask = self.active_pair_mask.to(device=inputs.device)
            active_handled_mask = torch.zeros_like(handled_mask)
            active_handled_mask[handled_mask] = active_pair_mask.index_select(
                0,
                pair_category[handled_mask],
            )
            handled_mask = active_handled_mask
        if not torch.any(handled_mask):
            return None

        # Presumably one distance per True row of ``handled_mask`` — the
        # indexing below relies on that alignment; verify against UFPInput.
        distances = inputs.pair_distances(handled_mask)
        support_mask = distances < float(self.cutoff)
        if not torch.any(support_mask):
            return None

        # Map the within-cutoff subset back onto full neighbor-list rows.
        selected_indices = torch.nonzero(handled_mask, as_tuple=False).reshape(-1)
        selected_indices = selected_indices[support_mask]
        selected_mask = torch.zeros_like(handled_mask)
        selected_mask[selected_indices] = True
        return (
            selected_mask,
            pair_category[handled_mask][support_mask],
            distances[support_mask],
        )

    def _energy_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
        inputs: UFPInput,
    ) -> torch.Tensor:
        """Evaluate unclipped pair energies for selected rows.

        Subclasses must override this; ``forward`` multiplies the result by
        the pair weight and the cutoff envelope.
        """
        raise NotImplementedError

    def radial_values(
        self,
        pair: Sequence[int],
        distances: object,
    ) -> torch.Tensor:
        """Evaluate this prior for one pair channel on arbitrary radial samples.

        Args:
            pair: Two atomic numbers identifying the pair category.
            distances: Anything ``torch.as_tensor`` accepts; non-floating
                input is converted to the default torch dtype.

        Returns:
            A tensor with the same shape as ``distances``.

        Raises:
            ValueError: If ``pair`` does not have exactly two entries.
            KeyError: If the pair is not part of this term.
        """
        if len(pair) != 2:
            raise ValueError("`pair` must contain exactly two atomic numbers")
        pair_index = self.pair_category_index(pair[0], pair[1])
        r = torch.as_tensor(distances)
        if not r.is_floating_point():
            r = r.to(dtype=torch.get_default_dtype())
        category = torch.full(
            (r.numel(),), pair_index, dtype=torch.int64, device=r.device
        )
        # Evaluate on a flat view, then restore the caller's shape.
        values = self._radial_values_from_category(r.reshape(-1), category)
        return values.reshape(r.shape)

    def _radial_values_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate radial values for category indices without an input object.

        Subclasses must override this.
        """
        raise NotImplementedError

    def parameter_state_hash(self) -> str:
        """Return a deterministic hash of the current prior parameter tensors.

        Every ``state_dict`` entry, visited in name order, contributes its
        name, dtype, shape, and raw bytes to a SHA-256 digest.
        """
        hasher = hashlib.sha256()
        for name, tensor in sorted(self.state_dict().items()):
            detached = tensor.detach().cpu().contiguous()
            hasher.update(name.encode("utf-8"))
            hasher.update(str(detached.dtype).encode("utf-8"))
            hasher.update(
                str(tuple(int(dim) for dim in detached.shape)).encode("utf-8")
            )
            hasher.update(detached.numpy().tobytes())
        return hasher.hexdigest()

    def projection_metadata(self) -> dict[str, object]:
        """Return metadata needed by offline projection workflows."""
        return {
            "kind": self.prior_kind,
            "cutoff": self.cutoff,
            "atomic_types": self.atomic_types,
            "pair_categories": self.pair_categories,
            "active_pair_categories": self.active_pair_categories,
            "symmetric": self.symmetric,
            "cutoff_envelope": self.cutoff_envelope.to_dict(),
            "pair_parameter_names": self.pair_parameter_names,
            "parameter_state_hash": self.parameter_state_hash(),
        }

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate active analytic pair energies.

        Returns the result of ``empty_atomwise_output`` when no pair is
        selected; otherwise accumulates ``pair_weight * energy * envelope``
        over the selected neighbor-list rows.
        """
        selected = self._selected_pairs(inputs)
        if selected is None:
            return empty_atomwise_output(inputs, forces=False)

        selected_mask, pair_category, distances = selected
        pair_energy = self._energy_from_category(distances, pair_category, inputs)
        envelope = self.cutoff_envelope.values(distances).to(
            device=inputs.device,
            dtype=inputs.dtype,
        )
        weighted_pair_energy = pair_weight(inputs) * pair_energy * envelope
        return accumulate_pair_energies(
            inputs,
            selected_mask,
            weighted_pair_energy=weighted_pair_energy,
        )


class InversePowerPairPrior(_CategorizedAnalyticPairTerm):
    """Pair-dependent inverse-power analytic prior.

    The unenveloped value is ``prefactor / max(r, eps) ** power`` with one
    prefactor per pair category.
    """

    prior_kind = "inverse_power"
    pair_parameter_names = ("prefactors_by_pair",)

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        prefactors_by_pair,
        power: float = 2.0,
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        eps: float = 1.0e-12,
        cutoff_envelope: CutoffEnvelope | str | None = None,
        trainable: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store inverse-power pair prefactors.

        Args:
            prefactors_by_pair: Per-category prefactors (mapping or
                tensor-like), registered as a parameter.
            power: Finite positive exponent applied to the distance.
            eps: Finite positive lower clamp for distances.
            trainable: Whether the registered parameter requires gradients.
            dtype: Dtype used when coercing the parameter values.

        The remaining arguments are forwarded unchanged to the base class.

        Raises:
            ValueError: If ``power`` or ``eps`` is not finite and positive.
        """
        super().__init__(
            cutoff=cutoff,
            atomic_types=atomic_types,
            active_pairs=active_pairs,
            symmetric=symmetric,
            cutoff_envelope=cutoff_envelope,
        )
        self.power = _validate_positive("power", power)
        self.eps = _validate_positive("eps", eps)
        self._register_pair_parameter(
            "prefactors_by_pair",
            prefactors_by_pair,
            dtype=dtype,
            trainable=trainable,
        )

    def _raw_values(
        self,
        distances: torch.Tensor,
        prefactors: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate inverse-power values before the cutoff envelope."""
        # Clamp so r -> 0 cannot divide by zero.
        return prefactors / torch.pow(distances.clamp_min(self.eps), self.power)

    def _radial_values_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
    ) -> torch.Tensor:
        """Return enveloped radial values for the given category indices."""
        prefactors = self.prefactors_by_pair.to(
            device=distances.device,
            dtype=distances.dtype,
        ).index_select(0, pair_category)
        values = self._raw_values(distances, prefactors)
        return values * self.cutoff_envelope.values(distances)

    def _energy_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
        inputs: UFPInput,
    ) -> torch.Tensor:
        """Return unenveloped pair energies for selected neighbor-list rows."""
        prefactors = self._pair_parameter("prefactors_by_pair", pair_category, inputs)
        return self._raw_values(distances, prefactors)
class DampedInversePowerPairPrior(_CategorizedAnalyticPairTerm):
    """Inverse-power prior with a short-range exponential damping factor.

    The unenveloped value is
    ``prefactor * ((1 - exp(-rate * r)) / r) ** power`` with ``r`` clamped
    below at ``eps``.
    """

    prior_kind = "damped_inverse_power"
    pair_parameter_names = ("prefactors_by_pair", "damping_rates_by_pair")

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        prefactors_by_pair,
        damping_rates_by_pair,
        power: float = 6.0,
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        eps: float = 1.0e-12,
        cutoff_envelope: CutoffEnvelope | str | None = None,
        trainable: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store damped inverse-power parameters.

        Args:
            prefactors_by_pair: Per-category prefactors (mapping or
                tensor-like).
            damping_rates_by_pair: Per-category damping rates; initial values
                must be finite and non-negative for every category.
            power: Finite positive exponent.
            eps: Finite positive lower clamp for distances.
            trainable: Whether the registered parameters require gradients.
            dtype: Dtype used when coercing the parameter values.

        The remaining arguments are forwarded unchanged to the base class.

        Raises:
            ValueError: If ``power`` or ``eps`` is not finite and positive,
                or a damping rate is negative or non-finite.
        """
        super().__init__(
            cutoff=cutoff,
            atomic_types=atomic_types,
            active_pairs=active_pairs,
            symmetric=symmetric,
            cutoff_envelope=cutoff_envelope,
        )
        self.power = _validate_positive("power", power)
        self.eps = _validate_positive("eps", eps)
        self._register_pair_parameter(
            "prefactors_by_pair",
            prefactors_by_pair,
            dtype=dtype,
            trainable=trainable,
        )
        damping = self._register_pair_parameter(
            "damping_rates_by_pair",
            damping_rates_by_pair,
            dtype=dtype,
            trainable=trainable,
        )
        _validate_non_negative_tensor("damping_rates_by_pair", damping)

    def _undamped_values(
        self,
        distances: torch.Tensor,
        prefactors: torch.Tensor,
        damping_rates: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate damped inverse-power values before the cutoff envelope.

        Despite the method name, the result includes the damping factor;
        only the cutoff envelope is left out.
        """
        safe_distances = distances.clamp_min(self.eps)
        # -expm1(-x) equals 1 - exp(-x) but avoids the cancellation that makes
        # the difference collapse to 0 for small x.
        damping = -torch.expm1(-damping_rates * safe_distances)
        # Take the ratio before the power: damping**p and r**p can each
        # underflow to zero at short range (notably in float32), turning the
        # original quotient into 0/0.
        return prefactors * torch.pow(damping / safe_distances, self.power)

    def _radial_values_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
    ) -> torch.Tensor:
        """Return enveloped radial values for the given category indices."""
        prefactors = self.prefactors_by_pair.to(
            device=distances.device,
            dtype=distances.dtype,
        ).index_select(0, pair_category)
        rates = self.damping_rates_by_pair.to(
            device=distances.device,
            dtype=distances.dtype,
        ).index_select(0, pair_category)
        return self._undamped_values(
            distances,
            prefactors,
            rates,
        ) * self.cutoff_envelope.values(distances)

    def _energy_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
        inputs: UFPInput,
    ) -> torch.Tensor:
        """Return unenveloped pair energies for selected neighbor-list rows."""
        prefactors = self._pair_parameter("prefactors_by_pair", pair_category, inputs)
        rates = self._pair_parameter("damping_rates_by_pair", pair_category, inputs)
        return self._undamped_values(distances, prefactors, rates)
class ExponentialRepulsionPairPrior(_CategorizedAnalyticPairTerm):
    """Pair-dependent exponential repulsion prior.

    The unenveloped value is ``amplitude * exp(-decay_rate * r)``.
    """

    prior_kind = "exponential_repulsion"
    pair_parameter_names = ("amplitudes_by_pair", "decay_rates_by_pair")

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        amplitudes_by_pair,
        decay_rates_by_pair,
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        cutoff_envelope: CutoffEnvelope | str | None = None,
        trainable: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store exponential amplitudes and decay rates.

        Args:
            amplitudes_by_pair: Per-category amplitudes (mapping or
                tensor-like).
            decay_rates_by_pair: Per-category decay rates; initial values of
                the active categories must be finite and positive.
            trainable: Whether the registered parameters require gradients.
            dtype: Dtype used when coercing the parameter values.

        The remaining arguments are forwarded unchanged to the base class.

        Raises:
            ValueError: If an active decay rate is non-positive or
                non-finite.
        """
        super().__init__(
            cutoff=cutoff,
            atomic_types=atomic_types,
            active_pairs=active_pairs,
            symmetric=symmetric,
            cutoff_envelope=cutoff_envelope,
        )
        self._register_pair_parameter(
            "amplitudes_by_pair",
            amplitudes_by_pair,
            dtype=dtype,
            trainable=trainable,
        )
        rates = self._register_pair_parameter(
            "decay_rates_by_pair",
            decay_rates_by_pair,
            dtype=dtype,
            trainable=trainable,
        )
        # Only active categories are checked: inactive ones may have been
        # omitted from a mapping and left at the zero default.
        _validate_positive_tensor(
            "decay_rates_by_pair",
            self._active_initial_values(rates),
        )

    def _raw_values(
        self,
        distances: torch.Tensor,
        amplitudes: torch.Tensor,
        decay_rates: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate exponential values before the cutoff envelope."""
        return amplitudes * torch.exp(-decay_rates * distances)

    def _radial_values_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
    ) -> torch.Tensor:
        """Return enveloped radial values for the given category indices."""
        amplitudes = self.amplitudes_by_pair.to(
            device=distances.device,
            dtype=distances.dtype,
        ).index_select(0, pair_category)
        rates = self.decay_rates_by_pair.to(
            device=distances.device,
            dtype=distances.dtype,
        ).index_select(0, pair_category)
        return self._raw_values(
            distances,
            amplitudes,
            rates,
        ) * self.cutoff_envelope.values(distances)

    def _energy_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
        inputs: UFPInput,
    ) -> torch.Tensor:
        """Return unenveloped pair energies for selected neighbor-list rows."""
        amplitudes = self._pair_parameter("amplitudes_by_pair", pair_category, inputs)
        rates = self._pair_parameter("decay_rates_by_pair", pair_category, inputs)
        return self._raw_values(distances, amplitudes, rates)
class MorsePairPrior(_CategorizedAnalyticPairTerm):
    """Pair-dependent Morse-like analytic interaction prior.

    The unenveloped value is
    ``depth * ((1 - exp(-width * (r - r_eq))) ** 2 - 1)``.
    """

    prior_kind = "morse"
    pair_parameter_names = (
        "depths_by_pair",
        "equilibrium_distances_by_pair",
        "widths_by_pair",
    )

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        depths_by_pair,
        equilibrium_distances_by_pair,
        widths_by_pair,
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        cutoff_envelope: CutoffEnvelope | str | None = None,
        trainable: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store Morse depths, equilibrium distances, and widths.

        Args:
            depths_by_pair: Per-category depths; initial values must be
                finite and non-negative.
            equilibrium_distances_by_pair: Per-category equilibrium
                distances; initial values must be finite and non-negative.
            widths_by_pair: Per-category widths; initial values of the
                active categories must be finite and positive.
            trainable: Whether the registered parameters require gradients.
            dtype: Dtype used when coercing the parameter values.

        The remaining arguments are forwarded unchanged to the base class.

        Raises:
            ValueError: If any of the value constraints above is violated.
        """
        super().__init__(
            cutoff=cutoff,
            atomic_types=atomic_types,
            active_pairs=active_pairs,
            symmetric=symmetric,
            cutoff_envelope=cutoff_envelope,
        )
        depths = self._register_pair_parameter(
            "depths_by_pair",
            depths_by_pair,
            dtype=dtype,
            trainable=trainable,
        )
        _validate_non_negative_tensor("depths_by_pair", depths)
        equilibrium = self._register_pair_parameter(
            "equilibrium_distances_by_pair",
            equilibrium_distances_by_pair,
            dtype=dtype,
            trainable=trainable,
        )
        _validate_non_negative_tensor("equilibrium_distances_by_pair", equilibrium)
        widths = self._register_pair_parameter(
            "widths_by_pair",
            widths_by_pair,
            dtype=dtype,
            trainable=trainable,
        )
        # Widths are checked on active categories only: inactive ones may
        # have been omitted from a mapping and left at the zero default.
        _validate_positive_tensor("widths_by_pair", self._active_initial_values(widths))

    def _raw_values(
        self,
        distances: torch.Tensor,
        depths: torch.Tensor,
        equilibrium_distances: torch.Tensor,
        widths: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate Morse values before the cutoff envelope."""
        exponent = torch.exp(-widths * (distances - equilibrium_distances))
        return depths * ((1.0 - exponent).square() - 1.0)

    def _radial_values_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
    ) -> torch.Tensor:
        """Return enveloped radial values for the given category indices."""
        depths = self.depths_by_pair.to(
            device=distances.device,
            dtype=distances.dtype,
        ).index_select(0, pair_category)
        equilibrium = self.equilibrium_distances_by_pair.to(
            device=distances.device,
            dtype=distances.dtype,
        ).index_select(0, pair_category)
        widths = self.widths_by_pair.to(
            device=distances.device,
            dtype=distances.dtype,
        ).index_select(0, pair_category)
        return self._raw_values(
            distances,
            depths,
            equilibrium,
            widths,
        ) * self.cutoff_envelope.values(distances)

    def _energy_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
        inputs: UFPInput,
    ) -> torch.Tensor:
        """Return unenveloped pair energies for selected neighbor-list rows."""
        depths = self._pair_parameter("depths_by_pair", pair_category, inputs)
        equilibrium = self._pair_parameter(
            "equilibrium_distances_by_pair",
            pair_category,
            inputs,
        )
        widths = self._pair_parameter("widths_by_pair", pair_category, inputs)
        return self._raw_values(distances, depths, equilibrium, widths)
class ScaledZBLPairPrior(_CategorizedAnalyticPairTerm):
    """ZBL screened nuclear repulsion with trainable pair-channel scales.

    The unenveloped value is ``scale * C * Z1 * Z2 * phi(r / a) / r`` where
    ``a = _SCREENING_LENGTH_FACTOR / (Z1 ** 0.23 + Z2 ** 0.23)``, ``phi`` is
    a sum of exponentials built from ``_ZBL_COEFFS`` and ``_ZBL_EXPONENTS``,
    ``C`` is ``_COULOMB_EV_ANGSTROM``, and ``r`` is clamped below at ``eps``.
    """

    prior_kind = "scaled_zbl"
    pair_parameter_names = ("scales_by_pair",)

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        scales_by_pair,
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        eps: float = 1.0e-12,
        cutoff_envelope: CutoffEnvelope | str | None = None,
        trainable: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store ZBL scale factors.

        Args:
            scales_by_pair: Per-category scale factors; initial values must
                be finite and non-negative.
            eps: Finite positive lower clamp for distances.
            trainable: Whether the registered parameter requires gradients.
            dtype: Dtype used when coercing the parameter values.

        The remaining arguments are forwarded unchanged to the base class.

        Raises:
            ValueError: If ``eps`` is not finite and positive, or a scale is
                negative or non-finite.
        """
        super().__init__(
            cutoff=cutoff,
            atomic_types=atomic_types,
            active_pairs=active_pairs,
            symmetric=symmetric,
            cutoff_envelope=cutoff_envelope,
        )
        self.eps = _validate_positive("eps", eps)
        scales = self._register_pair_parameter(
            "scales_by_pair",
            scales_by_pair,
            dtype=dtype,
            trainable=trainable,
        )
        _validate_non_negative_tensor("scales_by_pair", scales)
        # Per-category atomic numbers, kept as non-persistent float64 buffers
        # and cast to the working dtype at evaluation time.
        first_z = torch.tensor(
            [pair[0] for pair in self.pair_categories],
            dtype=torch.float64,
        )
        second_z = torch.tensor(
            [pair[1] for pair in self.pair_categories],
            dtype=torch.float64,
        )
        self.register_buffer("_first_z_by_pair", first_z, persistent=False)
        self.register_buffer("_second_z_by_pair", second_z, persistent=False)

    def _zbl_values(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
        scales: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate scaled ZBL values before the cutoff envelope."""
        dtype = distances.dtype
        device = distances.device
        z1 = self._first_z_by_pair.to(dtype=dtype, device=device).index_select(
            0,
            pair_category,
        )
        z2 = self._second_z_by_pair.to(dtype=dtype, device=device).index_select(
            0,
            pair_category,
        )
        screening = _SCREENING_LENGTH_FACTOR / (
            torch.pow(z1, 0.23) + torch.pow(z2, 0.23)
        )
        # Clamp once and reuse for both the reduced distance and the 1/r term.
        safe_distances = distances.clamp_min(self.eps)
        scaled = safe_distances / screening
        coeffs = torch.tensor(_ZBL_COEFFS, dtype=dtype, device=device)
        exponents = torch.tensor(_ZBL_EXPONENTS, dtype=dtype, device=device)
        # Broadcast (n_rows, 1) against (1, n_terms) and sum over the terms.
        exp_terms = torch.exp(-scaled[:, None] * exponents[None, :])
        phi = torch.sum(coeffs[None, :] * exp_terms, dim=1)
        prefactor = _COULOMB_EV_ANGSTROM * z1 * z2
        return scales * prefactor * phi / safe_distances

    def _radial_values_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
    ) -> torch.Tensor:
        """Return enveloped radial values for the given category indices."""
        scales = self.scales_by_pair.to(
            device=distances.device,
            dtype=distances.dtype,
        ).index_select(0, pair_category)
        return self._zbl_values(
            distances,
            pair_category,
            scales,
        ) * self.cutoff_envelope.values(distances)

    def _energy_from_category(
        self,
        distances: torch.Tensor,
        pair_category: torch.Tensor,
        inputs: UFPInput,
    ) -> torch.Tensor:
        """Return unenveloped pair energies for selected neighbor-list rows."""
        scales = self._pair_parameter("scales_by_pair", pair_category, inputs)
        return self._zbl_values(distances, pair_category, scales)
# Public analytic pair priors exported by this module.
__all__ = [
    "DampedInversePowerPairPrior",
    "ExponentialRepulsionPairPrior",
    "InversePowerPairPrior",
    "MorsePairPrior",
    "ScaledZBLPairPrior",
]