"""Trainable pair-dependent inverse-power repulsion terms."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
import torch
from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.terms._base import PairTerm
from ufp.terms._shared import (
accumulate_pair_energies,
empty_atomwise_output,
pair_weight,
)
from ufp.terms.twobody import _active_pair_mask, _canonical_pair, _pair_categories
def _pair_label(pair: tuple[int, int]) -> str:
"""Format one pair category for error messages."""
return f"({pair[0]}, {pair[1]})"
class PowerLawRepulsionTerm(PairTerm):
    """
    Pair-dependent inverse-power repulsion prior.

    The term evaluates

    .. math::

        E_{ij} = a_{Z_i, Z_j} / \\max(r_{ij}, \\epsilon)^p

    where ``a`` is a trainable pair-channel prefactor. Forces are intentionally
    not returned by this term; callers that request forces should use the normal
    autograd path via ``derive_forces=True``.

    Args:
        cutoff: Maximum pair distance included in the interaction.
        atomic_types: Atomic numbers used to enumerate pair categories.
        prefactors_by_pair: Either a tensor with shape ``(n_pair_categories,)``
            in the term's category order, or a mapping from pair tuples to
            prefactor values for all active pair categories. Tensor input is
            copied, so the resulting parameter never shares storage with the
            caller's tensor.
        power: Positive, finite inverse-power exponent. ``2.0`` gives
            ``a / r**2``.
        active_pairs: Optional subset of pair categories that should be
            evaluated. Inactive categories may have zero prefactors.
        symmetric: If ``True``, treat ``(a, b)`` and ``(b, a)`` as the same pair
            category.
        eps: Positive, finite minimum distance used before applying the
            inverse power.
        trainable: Whether prefactors require gradients.
        dtype: Optional dtype used when converting prefactors to a tensor.
            When omitted, integer or boolean prefactors are promoted to
            ``torch.get_default_dtype()``.

    Examples:
        >>> term = PowerLawRepulsionTerm(
        ...     cutoff=2.5,
        ...     atomic_types=[1],
        ...     prefactors_by_pair={(1, 1): 0.4},
        ... )
        >>> term.covers_pair(1, 1)
        True
    """

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        prefactors_by_pair,
        power: float = 2.0,
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        eps: float = 1.0e-12,
        trainable: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store one trainable inverse-power prefactor per pair category.

        Raises:
            ValueError: If ``atomic_types`` is empty, ``power`` or ``eps`` is
                not a positive finite number, or ``prefactors_by_pair`` is
                malformed (see ``_coerce_prefactors``).
        """
        super().__init__(cutoff=cutoff, atomic_types=atomic_types)
        if self.atomic_types is None or not self.atomic_types:
            raise ValueError("`atomic_types` must contain at least one element")
        self.symmetric = bool(symmetric)
        self.power = float(power)
        self.eps = float(eps)
        # The chained comparison is False for NaN and for +inf, both of which
        # would slip past a plain ``<= 0.0`` check.
        if not 0.0 < self.power < float("inf"):
            raise ValueError("`power` must be positive and finite")
        if not 0.0 < self.eps < float("inf"):
            raise ValueError("`eps` must be positive and finite")

        pair_categories = _pair_categories(self.atomic_types, symmetric=self.symmetric)
        # Plain-Python lookup tables are stored via object.__setattr__,
        # bypassing any __setattr__ override in the module hierarchy.
        object.__setattr__(self, "_pair_categories", pair_categories)
        object.__setattr__(
            self,
            "_pair_index",
            {pair: index for index, pair in enumerate(pair_categories)},
        )

        active_pair_mask = _active_pair_mask(
            pair_categories,
            active_pairs=active_pairs,
            symmetric=self.symmetric,
        )
        # Non-persistent: the mask follows the module across devices but is
        # not written to the state dict.
        self.register_buffer(
            "active_pair_mask",
            active_pair_mask,
            persistent=False,
        )
        object.__setattr__(
            self,
            "_active_pair_indices",
            tuple(
                index
                for index, enabled in enumerate(active_pair_mask.tolist())
                if enabled
            ),
        )

        prefactors = self._coerce_prefactors(prefactors_by_pair, dtype=dtype)
        self.prefactors_by_pair = torch.nn.Parameter(
            prefactors,
            requires_grad=bool(trainable),
        )

    def _coerce_prefactors(
        self,
        prefactors_by_pair,
        *,
        dtype: torch.dtype | None,
    ) -> torch.Tensor:
        """Normalize mapping or tensor prefactors into pair-category order.

        Args:
            prefactors_by_pair: A mapping from pair tuples to values, or
                anything ``torch.as_tensor`` accepts with shape
                ``(n_pair_categories,)``. A scalar is also accepted when there
                is exactly one pair category.
            dtype: Target dtype. When ``None``, mapping input uses the default
                floating dtype, and non-floating tensor input is promoted to it.

        Returns:
            A new 1-D tensor of length ``len(self.pair_categories)`` that
            shares no storage with the caller's input.

        Raises:
            ValueError: On keys that are not pairs, duplicate pairs after
                canonicalization, unknown pair categories, missing active
                categories, or a tensor of the wrong shape.
        """
        if isinstance(prefactors_by_pair, Mapping):
            # Categories not named in the mapping keep a zero prefactor.
            values = torch.zeros(len(self.pair_categories), dtype=dtype)
            seen: set[tuple[int, int]] = set()
            for pair, value in prefactors_by_pair.items():
                try:
                    first, second = pair
                except (TypeError, ValueError) as exc:
                    # Non-iterable keys (e.g. a bare int) or keys of the wrong
                    # length are reported as the documented ValueError.
                    raise ValueError(
                        "prefactor mapping keys must contain exactly two atomic numbers"
                    ) from exc
                canonical = self.canonical_pair(first, second)
                if canonical in seen:
                    raise ValueError(
                        "duplicate prefactor after canonicalization for "
                        f"{_pair_label(canonical)}"
                    )
                try:
                    pair_index = self._pair_index[canonical]
                except KeyError as exc:
                    raise ValueError(
                        f"unknown prefactor pair category {_pair_label(canonical)}"
                    ) from exc
                values[pair_index] = torch.as_tensor(value, dtype=dtype)
                seen.add(canonical)
            missing = [pair for pair in self.active_pair_categories if pair not in seen]
            if missing:
                raise ValueError(
                    "missing prefactors for active pair categories: "
                    + ", ".join(_pair_label(pair) for pair in missing)
                )
            return values

        tensor = torch.as_tensor(prefactors_by_pair, dtype=dtype)
        if dtype is None and not (tensor.is_floating_point() or tensor.is_complex()):
            # Integer/bool input (e.g. ``[1, 2, 3]``) cannot require gradients.
            # Promote it the same way the mapping path's ``torch.zeros`` does.
            tensor = tensor.to(torch.get_default_dtype())
        if tensor.ndim == 0 and len(self.pair_categories) == 1:
            tensor = tensor.reshape(1)
        if tensor.ndim != 1:
            raise ValueError("`prefactors_by_pair` must have shape (n_pairs,)")
        if tensor.shape[0] != len(self.pair_categories):
            raise ValueError(
                "`prefactors_by_pair.shape[0]` must equal "
                f"{len(self.pair_categories)} for atomic_types={self.atomic_types}, "
                f"got {tensor.shape[0]}"
            )
        # ``torch.as_tensor`` may hand back the caller's tensor unchanged, and
        # ``nn.Parameter`` shares storage with its input. Copy here so that
        # optimizer updates never write through to the caller's tensor.
        return tensor.detach().clone()

    @property
    def pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return the full ordered list of pair categories owned by this term."""
        return self._pair_categories

    @property
    def active_pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return the subset of pair categories enabled for evaluation."""
        return tuple(self.pair_categories[index] for index in self._active_pair_indices)

    @property
    def true_prefactors_by_pair(self) -> torch.Tensor:
        """Return prefactors indexed by pair category (the raw parameter)."""
        return self.prefactors_by_pair

    def canonical_pair(self, first: int, second: int) -> tuple[int, int]:
        """Normalize a pair key using this term's symmetry convention."""
        return _canonical_pair(first, second, symmetric=self.symmetric)

    def pair_category_index(self, first: int, second: int) -> int:
        """Return the category index for one pair.

        Raises:
            KeyError: If the canonical pair is not one of this term's categories.
        """
        pair = self.canonical_pair(first, second)
        try:
            return self._pair_index[pair]
        except KeyError as exc:
            raise KeyError(f"pair {pair} is not part of this term") from exc

    def covers_pair(self, first_atomic_number: int, second_atomic_number: int) -> bool:
        """Report whether this term evaluates the requested pair category.

        Unknown pairs return ``False`` instead of raising.
        """
        pair = self.canonical_pair(first_atomic_number, second_atomic_number)
        index = self._pair_index.get(pair)
        if index is None:
            return False
        return bool(self.active_pair_mask[index].item())

    def is_pair_active(self, first: int, second: int) -> bool:
        """Report whether a canonical pair category remains enabled.

        Raises:
            KeyError: If the pair is not one of this term's categories.
        """
        return bool(
            self.active_pair_mask[self.pair_category_index(first, second)].item()
        )

    def prefactor_for_pair(self, first: int, second: int) -> torch.Tensor:
        """Return the prefactor tensor for one pair category.

        Raises:
            KeyError: If the pair is not one of this term's categories.
        """
        return self.true_prefactors_by_pair[self.pair_category_index(first, second)]

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate active inverse-power pair energies.

        Raises:
            RuntimeError: If ``inputs`` carries no neighbor list, or if the term
                is asymmetric and the neighbor list is not a full list.
        """
        if inputs.neighbor_list is None:
            raise RuntimeError(
                "PowerLawRepulsionTerm requires a neighbor list, but `inputs` does "
                "not contain one"
            )
        if not self.symmetric and not inputs.neighbor_list.full_list:
            raise RuntimeError(
                "asymmetric power-law pair terms require a full neighbor list"
            )
        if not self._active_pair_indices:
            return empty_atomwise_output(inputs, forces=False)

        # Per-pair category index. Entries below zero are pairs outside this
        # term's categories and are excluded from ``handled_mask``.
        pair_category = inputs.pair_category_indices(
            self.atomic_types,
            symmetric=self.symmetric,
        )
        handled_mask = pair_category >= 0
        if len(self._active_pair_indices) != len(self.pair_categories):
            # Only a subset of categories is enabled, so also drop pairs whose
            # category is inactive.
            active_pair_mask = self.active_pair_mask.to(device=inputs.device)
            active_handled_mask = torch.zeros_like(handled_mask)
            active_handled_mask[handled_mask] = active_pair_mask.index_select(
                0,
                pair_category[handled_mask],
            )
            handled_mask = active_handled_mask
        if not torch.any(handled_mask):
            return empty_atomwise_output(inputs, forces=False)

        # clamp_min implements max(r, eps), so coincident atoms cannot divide
        # by zero.
        distances = inputs.pair_distances(handled_mask).clamp_min(self.eps)
        handled_pair_category = pair_category[handled_mask]
        prefactors = self.true_prefactors_by_pair.to(
            device=inputs.device,
            dtype=inputs.dtype,
        ).index_select(0, handled_pair_category)
        weighted_pair_energy = (
            pair_weight(inputs) * prefactors / torch.pow(distances, self.power)
        )
        return accumulate_pair_energies(
            inputs,
            handled_mask,
            weighted_pair_energy=weighted_pair_energy,
        )
# Public API of this module: only the term class is exported.
__all__ = ["PowerLawRepulsionTerm"]