Source code for ufp.terms.powerlaw

"""Trainable pair-dependent inverse-power repulsion terms."""

from __future__ import annotations

from collections.abc import Mapping, Sequence

import torch

from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.terms._base import PairTerm
from ufp.terms._shared import (
    accumulate_pair_energies,
    empty_atomwise_output,
    pair_weight,
)
from ufp.terms.twobody import _active_pair_mask, _canonical_pair, _pair_categories


def _pair_label(pair: tuple[int, int]) -> str:
    """Format one pair category for error messages."""
    return f"({pair[0]}, {pair[1]})"


class PowerLawRepulsionTerm(PairTerm):
    """
    Pair-dependent inverse-power repulsion prior.

    The term evaluates

    .. math::

        E_{ij} = a_{Z_i, Z_j} / \\max(r_{ij}, \\epsilon)^p

    where ``a`` is a trainable pair-channel prefactor. Forces are
    intentionally not returned by this term; callers that request forces
    should use the normal autograd path via ``derive_forces=True``.

    Args:
        cutoff: Maximum pair distance included in the interaction.
        atomic_types: Atomic numbers used to enumerate pair categories.
        prefactors_by_pair: Either a one-dimensional tensor-like with shape
            ``(n_pair_categories,)`` in the term's category order (a scalar
            is accepted when there is exactly one category), or a mapping
            from pair tuples to prefactor values for all active pair
            categories. Categories omitted from a mapping get a zero
            prefactor. Tensor-like inputs are always copied, and
            integer/boolean inputs are promoted to the default floating
            dtype when ``dtype`` is ``None``.
        power: Positive inverse-power exponent. ``2.0`` gives ``a / r**2``.
        active_pairs: Optional subset of pair categories that should be
            evaluated. Inactive categories may have zero prefactors.
        symmetric: If ``True``, treat ``(a, b)`` and ``(b, a)`` as the same
            pair category.
        eps: Minimum distance used before applying the inverse power.
        trainable: Whether prefactors require gradients.
        dtype: Optional dtype used when converting prefactors to a tensor.

    Raises:
        ValueError: If ``atomic_types`` is empty, ``power`` or ``eps`` is not
            positive, or ``prefactors_by_pair`` is malformed (wrong shape,
            malformed/unknown/duplicate pair keys, or missing active pairs).

    Examples:
        >>> term = PowerLawRepulsionTerm(
        ...     cutoff=2.5,
        ...     atomic_types=[1],
        ...     prefactors_by_pair={(1, 1): 0.4},
        ... )
        >>> term.covers_pair(1, 1)
        True
    """

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        prefactors_by_pair: (
            Mapping[tuple[int, int], float | torch.Tensor]
            | Sequence[float]
            | torch.Tensor
        ),
        power: float = 2.0,
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        eps: float = 1.0e-12,
        trainable: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store one trainable inverse-power prefactor per pair category."""
        super().__init__(cutoff=cutoff, atomic_types=atomic_types)
        # `not` also covers ``None``.
        if not self.atomic_types:
            raise ValueError("`atomic_types` must contain at least one element")

        self.symmetric = bool(symmetric)
        self.power = float(power)
        self.eps = float(eps)
        if self.power <= 0.0:
            raise ValueError("`power` must be positive")
        if self.eps <= 0.0:
            raise ValueError("`eps` must be positive")

        pair_categories = _pair_categories(self.atomic_types, symmetric=self.symmetric)
        # Plain Python bookkeeping is written with object.__setattr__ so it
        # bypasses the base class's __setattr__ machinery.
        object.__setattr__(self, "_pair_categories", pair_categories)
        object.__setattr__(
            self,
            "_pair_index",
            {pair: index for index, pair in enumerate(pair_categories)},
        )

        active_pair_mask = _active_pair_mask(
            pair_categories,
            active_pairs=active_pairs,
            symmetric=self.symmetric,
        )
        # Non-persistent: the mask is derived from constructor arguments and
        # is not written to the state dict.
        self.register_buffer(
            "active_pair_mask",
            active_pair_mask,
            persistent=False,
        )
        object.__setattr__(
            self,
            "_active_pair_indices",
            tuple(
                index
                for index, enabled in enumerate(active_pair_mask.tolist())
                if enabled
            ),
        )

        prefactors = self._coerce_prefactors(prefactors_by_pair, dtype=dtype)
        self.prefactors_by_pair = torch.nn.Parameter(
            prefactors,
            requires_grad=bool(trainable),
        )

    def _coerce_prefactors(
        self,
        prefactors_by_pair: (
            Mapping[tuple[int, int], float | torch.Tensor]
            | Sequence[float]
            | torch.Tensor
        ),
        *,
        dtype: torch.dtype | None,
    ) -> torch.Tensor:
        """
        Normalize mapping or tensor prefactors into pair-category order.

        Returns:
            A freshly allocated 1-D tensor of length ``len(pair_categories)``
            that never shares storage with ``prefactors_by_pair``.

        Raises:
            ValueError: On malformed, unknown, duplicate, or missing pair keys
                (mapping input) or on a shape mismatch (tensor-like input).
        """
        if isinstance(prefactors_by_pair, Mapping):
            # Categories absent from the mapping default to a zero prefactor.
            values = torch.zeros(len(self.pair_categories), dtype=dtype)
            seen: set[tuple[int, int]] = set()
            for pair, value in prefactors_by_pair.items():
                if len(pair) != 2:
                    raise ValueError(
                        "prefactor mapping keys must contain exactly two atomic numbers"
                    )
                canonical = self.canonical_pair(pair[0], pair[1])
                # With symmetric=True, (a, b) and (b, a) collide here.
                if canonical in seen:
                    raise ValueError(
                        "duplicate prefactor after canonicalization for "
                        f"{_pair_label(canonical)}"
                    )
                try:
                    pair_index = self._pair_index[canonical]
                except KeyError as exc:
                    raise ValueError(
                        f"unknown prefactor pair category {_pair_label(canonical)}"
                    ) from exc
                values[pair_index] = torch.as_tensor(value, dtype=dtype)
                seen.add(canonical)

            missing = [pair for pair in self.active_pair_categories if pair not in seen]
            if missing:
                raise ValueError(
                    "missing prefactors for active pair categories: "
                    + ", ".join(_pair_label(pair) for pair in missing)
                )
            return values

        tensor = torch.as_tensor(prefactors_by_pair, dtype=dtype)
        # Integer/bool inputs such as ``[1, 2]`` would otherwise yield a
        # Parameter that cannot require gradients. This matches the mapping
        # branch, which fills a default-dtype floating buffer.
        if dtype is None and not (tensor.is_floating_point() or tensor.is_complex()):
            tensor = tensor.to(torch.get_default_dtype())
        if tensor.ndim == 0 and len(self.pair_categories) == 1:
            tensor = tensor.reshape(1)
        if tensor.ndim != 1:
            raise ValueError("`prefactors_by_pair` must have shape (n_pairs,)")
        if tensor.shape[0] != len(self.pair_categories):
            raise ValueError(
                "`prefactors_by_pair.shape[0]` must equal "
                f"{len(self.pair_categories)} for atomic_types={self.atomic_types}, "
                f"got {tensor.shape[0]}"
            )
        # torch.as_tensor returns its input unchanged when no conversion is
        # needed, and nn.Parameter wraps that storage. Copy it so training
        # never mutates the caller's tensor or another term built from it.
        return tensor.detach().clone()

    @property
    def pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return the full ordered list of pair categories owned by this term."""
        return self._pair_categories

    @property
    def active_pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return the subset of pair categories enabled for evaluation."""
        return tuple(self.pair_categories[index] for index in self._active_pair_indices)

    @property
    def true_prefactors_by_pair(self) -> torch.Tensor:
        """Return prefactors indexed by pair category."""
        return self.prefactors_by_pair

    def canonical_pair(self, first: int, second: int) -> tuple[int, int]:
        """Normalize a pair key using this term's symmetry convention."""
        return _canonical_pair(first, second, symmetric=self.symmetric)

    def pair_category_index(self, first: int, second: int) -> int:
        """
        Return the category index for one pair.

        Raises:
            KeyError: If the canonicalized pair is not a category of this term.
        """
        pair = self.canonical_pair(first, second)
        try:
            return self._pair_index[pair]
        except KeyError as exc:
            raise KeyError(f"pair {pair} is not part of this term") from exc

    def covers_pair(self, first_atomic_number: int, second_atomic_number: int) -> bool:
        """
        Report whether this term evaluates the requested pair category.

        Unknown pairs return ``False`` rather than raising.
        """
        pair = self.canonical_pair(first_atomic_number, second_atomic_number)
        index = self._pair_index.get(pair)
        if index is None:
            return False
        return bool(self.active_pair_mask[index].item())

    def is_pair_active(self, first: int, second: int) -> bool:
        """
        Report whether the (canonicalized) pair category is enabled.

        Raises:
            KeyError: If the pair is not a category of this term.
        """
        return bool(
            self.active_pair_mask[self.pair_category_index(first, second)].item()
        )

    def prefactor_for_pair(self, first: int, second: int) -> torch.Tensor:
        """
        Return the prefactor tensor for one pair category.

        Raises:
            KeyError: If the pair is not a category of this term.
        """
        return self.true_prefactors_by_pair[self.pair_category_index(first, second)]

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """
        Evaluate active inverse-power pair energies.

        Args:
            inputs: Model input; must provide a neighbor list.

        Returns:
            The result of ``accumulate_pair_energies`` over the handled pairs,
            or an empty atomwise output when no pair is handled. Forces are
            not computed here (``forces=False``); use autograd instead.

        Raises:
            RuntimeError: If ``inputs`` has no neighbor list, or if the term
                is asymmetric and the neighbor list is not a full list.
        """
        if inputs.neighbor_list is None:
            raise RuntimeError(
                "PowerLawRepulsionTerm requires a neighbor list, but `inputs` does "
                "not contain one"
            )
        # An asymmetric term needs both (i, j) and (j, i) to be present.
        if not self.symmetric and not inputs.neighbor_list.full_list:
            raise RuntimeError(
                "asymmetric power-law pair terms require a full neighbor list"
            )
        if not self._active_pair_indices:
            return empty_atomwise_output(inputs, forces=False)

        # Negative indices mark pairs whose category this term does not own.
        pair_category = inputs.pair_category_indices(
            self.atomic_types,
            symmetric=self.symmetric,
        )
        handled_mask = pair_category >= 0
        # Only pay for the active-mask lookup when some categories are disabled.
        if len(self._active_pair_indices) != len(self.pair_categories):
            active_pair_mask = self.active_pair_mask.to(device=inputs.device)
            active_handled_mask = torch.zeros_like(handled_mask)
            active_handled_mask[handled_mask] = active_pair_mask.index_select(
                0,
                pair_category[handled_mask],
            )
            handled_mask = active_handled_mask
        if not torch.any(handled_mask):
            return empty_atomwise_output(inputs, forces=False)

        # NOTE(review): pairs are not filtered by ``self.cutoff`` here;
        # presumably the neighbor list / pair_category_indices already
        # enforces it. Confirm before relying on a term-specific cutoff.
        # clamp_min(eps) guards the inverse power against r -> 0.
        distances = inputs.pair_distances(handled_mask).clamp_min(self.eps)
        handled_pair_category = pair_category[handled_mask]
        prefactors = self.true_prefactors_by_pair.to(
            device=inputs.device,
            dtype=inputs.dtype,
        ).index_select(0, handled_pair_category)
        weighted_pair_energy = (
            pair_weight(inputs) * prefactors / torch.pow(distances, self.power)
        )
        return accumulate_pair_energies(
            inputs,
            handled_mask,
            weighted_pair_energy=weighted_pair_energy,
        )
# Public API of this module.
__all__ = ["PowerLawRepulsionTerm"]