Source code for ufp.terms.twobody

"""
Spline-based two-body interaction term implementation.

Use this module when pair energies and forces should come from 1D spline
coefficients, optionally shared through alchemical providers.
"""

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Literal

import torch

from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.splines._cubic import cubic_eval_1d_with_grads
from ufp.splines._quadratic import quadratic_eval_1d_with_grads
from ufp.splines._quartic import quartic_eval_1d_with_grads
from ufp.splines.derivatives import (
    cubic_derivative_rows_1d,
    cubic_second_derivative_operator,
    cubic_spline_diagnostics_1d,
    cubic_value_rows_1d,
)
from ufp.splines.representation import (
    spline_support_mask_1d,
    uniform_stencil_1d,
    uniform_support_parameters,
)
from ufp.terms._base import LinearAssemblyOptions, PairTerm
from ufp.terms._constraints import softplus_inverse as _softplus_inverse
from ufp.terms._parameters import (
    ParameterBlock,
    ParameterBlockCacheChannel,
    ParameterBlockCacheDescriptor,
    copy_parameter_data,
)
from ufp.terms._selected_assembly import add_selected_entries, selected_column_lookup
from ufp.terms._shared import empty_atomwise_output, pair_weight
from ufp.terms.alchemical import AlchemicalCoefficients
from ufp.terms.categories import (
    active_pair_mask as _active_pair_mask,
)
from ufp.terms.categories import (
    canonical_pair as _canonical_pair,
)
from ufp.terms.categories import (
    pair_categories as _pair_categories,
)


# Common signature of the 1D spline evaluators registered below: called as
# ``evaluate(knot_spacing, shifted_distances, coeffs)`` and returning a
# ``(values, gradients)`` pair of tensors.
Eval1DWithGrads = Callable[
    [float, torch.Tensor, torch.Tensor],
    tuple[torch.Tensor, torch.Tensor],
]

# Spline family names understood by the terms in this module.
SplineKind = Literal["quadratic", "cubic", "quartic"]

# Registry from spline family name to its value-and-gradient evaluator.
_SPLINE_EVAL_1D_WITH_GRADS: dict[str, Eval1DWithGrads] = dict(
    quadratic=quadratic_eval_1d_with_grads,
    cubic=cubic_eval_1d_with_grads,
    quartic=quartic_eval_1d_with_grads,
)


def get_eval_1d_with_grads(spline: SplineKind | str) -> Eval1DWithGrads:
    """Look up the value-and-gradient evaluator for a spline family.

    Args:
        spline: Spline family name; the registered names are ``"quadratic"``,
            ``"cubic"`` and ``"quartic"``.

    Returns:
        The evaluator registered under ``spline``.

    Raises:
        ValueError: If ``spline`` is not a registered family. The message
            lists the supported names in sorted order.
    """
    try:
        evaluator = _SPLINE_EVAL_1D_WITH_GRADS[spline]
    except KeyError as missing:
        supported = ", ".join(sorted(_SPLINE_EVAL_1D_WITH_GRADS))
        raise ValueError(
            f"Unsupported spline '{spline}'. Expected one of: {supported}."
        ) from missing
    return evaluator
def _selected_block_matrix(
    targets,
    selected_indices: Sequence[int],
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Allocate a zeroed ``(targets.n_rows, len(selected_indices))`` matrix.

    One column is reserved for each requested coefficient index. The row
    count is taken from ``targets.n_rows``.
    """
    n_columns = len(tuple(selected_indices))
    return torch.zeros(
        (targets.n_rows, n_columns),
        dtype=dtype,
        device=device,
    )


def _infer_dtype(
    dtype: torch.dtype | None,
    *values,
) -> torch.dtype:
    """Choose the dtype used to build initial parameter tensors.

    An explicit ``dtype`` always wins. Otherwise the dtype of the first
    ``torch.Tensor`` found in ``values`` is used. If there is none, the
    result falls back to ``torch.get_default_dtype()``.
    """
    if dtype is not None:
        return dtype
    first_tensor = next(
        (candidate for candidate in values if isinstance(candidate, torch.Tensor)),
        None,
    )
    if first_tensor is None:
        return torch.get_default_dtype()
    return first_tensor.dtype


def _as_pair_parameter_tensor(
    value,
    *,
    shape: tuple[int, ...],
    name: str,
    dtype: torch.dtype,
    default: float,
) -> torch.Tensor:
    """Broadcast a scalar/``None`` initializer or validate a full-shaped one.

    Args:
        value: ``None`` (fill with ``default``), a scalar (fill with that
            value), or an array-like whose shape equals ``shape``.
        shape: Required shape of the result.
        name: Argument name quoted in the error message.
        dtype: Dtype of the result.
        default: Fill value used when ``value`` is ``None``.

    Returns:
        A newly allocated tensor of shape ``shape``. It never shares storage
        with ``value``.

    Raises:
        ValueError: If ``value`` is neither a scalar nor shaped like ``shape``.
    """
    if value is None:
        return torch.full(shape, float(default), dtype=dtype)
    converted = torch.as_tensor(value, dtype=dtype)
    if converted.ndim == 0:
        fill_value = float(converted.item())
        return torch.full(shape, fill_value, dtype=dtype, device=converted.device)
    actual_shape = tuple(int(extent) for extent in converted.shape)
    if actual_shape == shape:
        # Detach and copy so later in-place updates never touch the caller's data.
        return converted.detach().clone()
    raise ValueError(f"`{name}` must be a scalar or have shape {shape}")


def _accumulate_spline_pair_output(
    output: UFPOutput,
    inputs: UFPInput,
    *,
    pair_mask: torch.Tensor,
    coeffs: torch.Tensor,
    first_knot: float,
    knot_spacing: float,
    spline: SplineKind | str,
    eps: float,
) -> None:
    """Add the energies and forces of one spline coefficient block to ``output``.

    Only neighbor-list rows selected by ``pair_mask`` whose distance lies on
    the spline support contribute. Each such pair does three things:

    - adds its weighted energy to its system;
    - splits that energy evenly between its two atoms;
    - applies equal and opposite forces along the pair vector.

    Args:
        output: Accumulator whose ``energy``, ``per_atom_energy`` and
            ``forces`` tensors are updated in place. All three must be set.
        inputs: Structures plus neighbor list being evaluated.
        pair_mask: Boolean selection over neighbor-list rows.
        coeffs: One-dimensional spline coefficients.
        first_knot: Distance of the first knot.
        knot_spacing: Uniform spacing between knots.
        spline: Spline family name.
        eps: Distances at or below this threshold get a zero force direction.
    """
    evaluate = get_eval_1d_with_grads(spline)
    distances = inputs.pair_distances(pair_mask)
    in_support = spline_support_mask_1d(
        distances,
        coeff_size=int(coeffs.shape[0]),
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
    )
    if not torch.any(in_support):
        return

    distances = distances[in_support]
    # Evaluators work in coordinates measured from the first knot.
    energy, grad = evaluate(knot_spacing, distances - first_knot, coeffs)
    weight = pair_weight(inputs)
    energy = weight * energy
    grad = weight * grad

    i_atom, j_atom = inputs.pair_indices(pair_mask)
    i_atom = i_atom[in_support]
    j_atom = j_atom[in_support]
    system = inputs.pair_system_index(pair_mask)[in_support]
    vectors = inputs.pair_vectors(pair_mask)[in_support]

    assert output.energy is not None
    assert output.forces is not None
    assert output.per_atom_energy is not None

    output.energy.index_add_(0, system, energy)
    half_energy = 0.5 * energy
    for atom in (i_atom, j_atom):
        output.per_atom_energy.index_add_(0, atom, half_energy)

    # Guard the unit direction 1/r against (near-)coincident atoms.
    inv_r = torch.where(
        distances > eps,
        distances.reciprocal(),
        torch.zeros_like(distances),
    )
    force_on_j = -grad[:, None] * vectors * inv_r[:, None]
    output.forces.index_add_(0, i_atom, -force_on_j)
    output.forces.index_add_(0, j_atom, force_on_j)
class SplinePairTerm(PairTerm):
    """
    One-dimensional spline two-body interaction term for a specific element pair.

    The term evaluates a uniformly spaced spline over pair distances in the
    same length unit as the input coordinates, normally angstroms for ASE
    structures. Coefficients store interaction energies, normally in electron
    volts, and the analytic force contribution is derived from the spline
    gradient.

    Direct coefficients are supplied through ``coeffs``. For alchemical
    fitting, pass ``coefficient_provider`` and ``coefficient_index`` instead;
    the provider supplies the true coefficient vector used during evaluation.

    Args:
        cutoff: Maximum pair distance included in the interaction.
        pair: Two atomic numbers identifying the element pair handled by this
            term.
        coeffs: One-dimensional coefficient vector with shape ``(n_knots,)``.
        coefficient_provider: Optional shared provider that projects proxy
            coefficients into true pair-specific coefficients.
        coefficient_index: Index into ``coefficient_provider`` for this pair
            channel.
        symmetric: If ``True``, treat ``(a, b)`` and ``(b, a)`` as the same pair
            category.
        spline: Spline family name. Supported values are ``"quadratic"``,
            ``"cubic"``, and ``"quartic"``.
        full_support_start: Lower distance where the spline has full support.
            The upper full-support boundary is ``cutoff``.
        eps: Small distance threshold used to avoid division by zero when
            forming force directions.
        enabled: If ``False``, keep the term parameters but skip evaluation.
        trainable: If ``True``, direct ``coeffs`` are stored as trainable
            parameters.
        fittable: Whether this term exposes a coefficient block to linear
            fitters.
        frozen: Whether this term's parameter block is fixed during fitting.
        dtype: Optional dtype used when converting direct coefficients to a
            tensor.

    Examples:
        >>> import torch
        >>> term = SplinePairTerm(
        ...     cutoff=2.0,
        ...     pair=(1, 1),
        ...     coeffs=torch.zeros(6),
        ... )
        >>> term.covers_pair(1, 1)
        True
    """

    def __init__(
        self,
        *,
        cutoff: float,
        pair: tuple[int, int],
        coeffs=None,
        coefficient_provider: AlchemicalCoefficients | None = None,
        coefficient_index: int | None = None,
        symmetric: bool = True,
        spline: SplineKind = "cubic",
        full_support_start: float = 0.0,
        eps: float = 1.0e-12,
        enabled: bool = True,
        trainable: bool = True,
        fittable: bool = True,
        frozen: bool = False,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store one pair-specific spline block.

        Raises:
            ValueError: If ``pair`` does not have exactly two entries, if
                ``coeffs`` is missing or not one-dimensional in direct mode,
                or if the provider is not one-dimensional or
                ``coefficient_index`` is missing in provider mode.
        """
        if len(pair) != 2:
            raise ValueError("`pair` must contain exactly two atomic numbers")
        first_atomic_number = int(pair[0])
        second_atomic_number = int(pair[1])
        super().__init__(
            cutoff=cutoff,
            atomic_types=[first_atomic_number, second_atomic_number],
        )
        self.symmetric = bool(symmetric)
        self.pair = _canonical_pair(
            first_atomic_number,
            second_atomic_number,
            symmetric=self.symmetric,
        )
        self.spline = spline
        self.eps = float(eps)
        self.enabled = bool(enabled)
        self.fittable = bool(fittable)
        self.frozen = bool(frozen)
        self.coefficient_index = (
            None if coefficient_index is None else int(coefficient_index)
        )
        # Stored via object.__setattr__ to bypass nn.Module attribute
        # registration (presumably so a shared provider is not owned as a
        # submodule by every term that references it).
        object.__setattr__(self, "_coefficient_provider", coefficient_provider)

        if coefficient_provider is None:
            if coeffs is None:
                raise ValueError(
                    "`coeffs` is required when `coefficient_provider` is not set"
                )
            coeffs_tensor = torch.as_tensor(coeffs, dtype=dtype)
            if coeffs_tensor.ndim != 1:
                raise ValueError("`coeffs` must have shape (n_knots,)")
            self.coeffs = torch.nn.Parameter(
                coeffs_tensor,
                requires_grad=bool(trainable) and not self.frozen,
            )
            coeff_size = int(coeffs_tensor.shape[0])
        else:
            if len(coefficient_provider.coefficient_shape) != 1:
                raise ValueError(
                    "`coefficient_provider` must provide one-dimensional coefficients"
                )
            if self.coefficient_index is None:
                raise ValueError(
                    "`coefficient_index` is required when `coefficient_provider` is set"
                )
            # Projected once up front; presumably validates the index early.
            coefficient_provider.true_coeffs_for(self.coefficient_index)
            coeff_size = int(coefficient_provider.coefficient_shape[0])

        self.full_support_start = float(full_support_start)
        self.first_knot, self.knot_spacing = uniform_support_parameters(
            coeff_size=coeff_size,
            lower_full_support=self.full_support_start,
            upper_full_support=cutoff,
            spline=self.spline,
        )

    @property
    def coefficient_provider(self) -> AlchemicalCoefficients | None:
        """Return the shared coefficient provider, if any."""
        return self._coefficient_provider

    @property
    def provides_forces(self) -> bool:
        """Report that this term produces analytic forces directly."""
        return True

    @property
    def true_coeffs(self) -> torch.Tensor:
        """Return direct or provider-projected true coefficients."""
        if self.coefficient_provider is None:
            return self.coeffs
        assert self.coefficient_index is not None
        return self.coefficient_provider.true_coeffs_for(self.coefficient_index)

    def _write_true_coeffs(self, values: torch.Tensor) -> None:
        """Write a solved true coefficient vector back into direct storage.

        Raises:
            ValueError: If the coefficients come from a provider whose
                weights are not the identity, so true coefficients cannot be
                written back into its proxy storage.
        """
        if self.coefficient_provider is None:
            copy_parameter_data(self.coeffs, values)
            return
        if not self.coefficient_provider.uses_identity_weights:
            raise ValueError(
                "can not write true coefficients directly into a non-identity "
                "alchemical provider"
            )
        assert self.coefficient_index is not None
        # `.to(tensor)` matches the proxy storage's dtype and device.
        self.coefficient_provider.proxy_coeffs.data[self.coefficient_index].copy_(
            values.reshape(self.true_coeffs.shape).to(
                self.coefficient_provider.proxy_coeffs
            )
        )

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return the pair spline coefficient block."""
        # Project the coefficients once instead of once per metadata field.
        coeff_shape = tuple(int(dim) for dim in self.true_coeffs.shape)
        coeff_size = coeff_shape[0]
        return (
            ParameterBlock(
                name="coeffs",
                kind="pair",
                shape=coeff_shape,
                read=lambda: self.true_coeffs,
                write=self._write_true_coeffs,
                label=f"pair[{self.pair}]",
                coefficient_provider=self.coefficient_provider,
                coefficient_index=self.coefficient_index,
                regularization_group="twobody",
                fittable=self.fittable and self.enabled,
                frozen=self.frozen,
                assembler="pair",
                cache_descriptor=ParameterBlockCacheDescriptor(
                    family={
                        "kind": "pair_spline",
                        "symmetric": bool(self.symmetric),
                        "spline": str(self.spline),
                        "first_knot": float(self.first_knot),
                        "knot_spacing": float(self.knot_spacing),
                        "coeff_size": int(coeff_size),
                        "eps": float(self.eps),
                    },
                    channels=(
                        ParameterBlockCacheChannel(
                            kind="pair",
                            values=self.pair,
                            start=0,
                            stop=int(coeff_size),
                        ),
                    ),
                ),
            ),
        )

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Assemble pair least-squares blocks for this term.

        Returns:
            Mapping from block index to its design matrix. Blocks for which
            the assembler returns ``None`` are omitted.
        """
        from ufp.leastsquares._assemble import _assemble_pair_block

        blocks = () if options is None else options.blocks
        return {
            block.index: matrix
            for block in blocks
            if (matrix := _assemble_pair_block(block, batch.inputs, targets))
            is not None
        }

    def assemble_selected_linear_block(
        self,
        block,
        inputs: UFPInput,
        targets,
        selected_indices: Sequence[int],
    ) -> torch.Tensor | None:
        """Assemble only requested coefficient columns for this pair block.

        Returns:
            A ``(targets.n_rows, len(selected_indices))`` matrix, or ``None``.
            ``None`` is returned when no covered pair is on the spline
            support, or when every assembled entry is zero.

        Raises:
            RuntimeError: If ``inputs`` has no neighbor list, or if an
                asymmetric term is used with a half neighbor list.
        """
        selected_indices = tuple(int(index) for index in selected_indices)
        if inputs.neighbor_list is None:
            raise RuntimeError("SplinePairTerm requires a neighbor list")
        if not self.symmetric and not inputs.neighbor_list.full_list:
            raise RuntimeError(
                "asymmetric spline two-body terms require a full neighbor list"
            )
        selected_lookup = selected_column_lookup(
            selected_indices,
            block_size=block.size,
            device=inputs.device,
        )
        pair_mask = self.covered_pair_mask(inputs)
        if not torch.any(pair_mask):
            return None
        pair_distances = inputs.pair_distances(pair_mask)
        support_mask = spline_support_mask_1d(
            pair_distances,
            coeff_size=block.size,
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        if not torch.any(support_mask):
            return None

        pair_distances = pair_distances[support_mask]
        pair_vectors = inputs.pair_vectors(pair_mask)[support_mask]
        first_atom, second_atom = inputs.pair_indices(pair_mask)
        pair_system_index = inputs.pair_system_index(pair_mask)
        first_atom = first_atom[support_mask]
        second_atom = second_atom[support_mask]
        pair_system_index = pair_system_index[support_mask]
        stencil = uniform_stencil_1d(
            pair_distances,
            coeff_size=block.size,
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        matrix = _selected_block_matrix(
            targets,
            selected_indices,
            device=inputs.device,
            dtype=inputs.dtype,
        )
        scale = pair_weight(inputs)
        values = scale * stencil.values
        grads = scale * stencil.grads
        # Zero direction for (near-)coincident atoms instead of dividing by ~0.
        inv_r = torch.where(
            pair_distances > self.eps,
            pair_distances.reciprocal(),
            torch.zeros_like(pair_distances),
        )
        direction = pair_vectors * inv_r[:, None]
        cols = stencil.indices

        # Energy rows: one per system.
        add_selected_entries(
            matrix,
            targets.energy_rows.index_select(0, pair_system_index)[:, None],
            cols,
            values,
            selected_lookup,
        )
        # Per-atom energy rows: the pair energy is split evenly.
        half_values = 0.5 * values
        add_selected_entries(
            matrix,
            targets.per_atom_rows.index_select(0, first_atom)[:, None],
            cols,
            half_values,
            selected_lookup,
        )
        add_selected_entries(
            matrix,
            targets.per_atom_rows.index_select(0, second_atom)[:, None],
            cols,
            half_values,
            selected_lookup,
        )
        # Force rows: equal and opposite contributions along the pair vector.
        force_second = -(grads[:, :, None] * direction[:, None, :])
        force_first = -force_second
        add_selected_entries(
            matrix,
            targets.force_rows.index_select(0, first_atom)[:, :, None],
            cols[:, None, :],
            force_first.permute(0, 2, 1),
            selected_lookup,
        )
        add_selected_entries(
            matrix,
            targets.force_rows.index_select(0, second_atom)[:, :, None],
            cols[:, None, :],
            force_second.permute(0, 2, 1),
            selected_lookup,
        )
        return None if torch.count_nonzero(matrix) == 0 else matrix

    def covers_pair(self, first_atomic_number: int, second_atomic_number: int) -> bool:
        """Report whether this (enabled) term handles the requested pair."""
        return self.enabled and (
            _canonical_pair(
                first_atomic_number,
                second_atomic_number,
                symmetric=self.symmetric,
            )
            == self.pair
        )

    def covered_pair_mask(self, inputs: UFPInput) -> torch.Tensor:
        """Select neighbor-list rows handled by this pair-specific spline block.

        A disabled term selects no rows.

        Raises:
            RuntimeError: If the term is disabled and ``inputs`` has no
                neighbor list.
        """
        if not self.enabled:
            # Explicit check instead of `assert`: input validation must not
            # disappear under `python -O`.
            if inputs.neighbor_list is None:
                raise RuntimeError("SplinePairTerm requires a neighbor list")
            return torch.zeros(
                inputs.neighbor_list.n_pairs,
                dtype=torch.bool,
                device=inputs.device,
            )
        return inputs.pair_mask(
            self.pair[0],
            self.pair[1],
            symmetric=self.symmetric,
        )

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate this spline block over covered pairs.

        Raises:
            RuntimeError: If ``inputs`` has no neighbor list, or if an
                asymmetric term is used with a half neighbor list.
        """
        if inputs.neighbor_list is None:
            raise RuntimeError(
                "SplinePairTerm requires a neighbor list, but `inputs` does not "
                "contain one"
            )
        if not self.symmetric and not inputs.neighbor_list.full_list:
            raise RuntimeError(
                "asymmetric spline pair terms require a full neighbor list"
            )
        pair_mask = self.covered_pair_mask(inputs)
        output = empty_atomwise_output(inputs, forces=True)
        if not torch.any(pair_mask):
            return output
        _accumulate_spline_pair_output(
            output,
            inputs,
            pair_mask=pair_mask,
            coeffs=self.true_coeffs.to(device=inputs.device, dtype=inputs.dtype),
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
            eps=self.eps,
        )
        return output
class SplineTwoBodyTerm(PairTerm):
    """
    One-dimensional spline family covering all configured chemical pair channels.

    The term owns one row of uniformly spaced spline coefficients per pair
    category derived from ``atomic_types``. All rows share the same knot
    layout, spline family and cutoff. Energies and analytic forces for every
    active category are therefore evaluated in a single vectorized pass.

    Args:
        cutoff: Maximum pair distance included in the interaction. It is
            also the upper full-support boundary of the spline.
        atomic_types: Atomic numbers from which pair categories are built.
            Must contain at least one element.
        coeffs_by_pair: Direct coefficients with shape
            ``(n_pair_categories, n_knots)``.
        coefficient_provider: Optional shared provider supplying
            two-dimensional ``(n_pair_categories, n_knots)`` coefficients
            instead of ``coeffs_by_pair``.
        coefficient_index: Index into ``coefficient_provider``. Required when
            a provider is set.
        active_pairs: Optional subset of pair categories to evaluate. It is
            forwarded to ``active_pair_mask``.
        symmetric: If ``True``, treat ``(a, b)`` and ``(b, a)`` as the same
            pair category.
        spline: Spline family name.
        full_support_start: Lower distance where the spline has full support.
        eps: Distance threshold below which the force direction is zeroed.
        trainable: If ``True``, direct coefficients are trainable parameters.
        fittable: Whether this term exposes a coefficient block to linear
            fitters.
        frozen: Whether this term's parameter block is fixed during fitting.
        dtype: Optional dtype used when converting direct coefficients.
    """

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        coeffs_by_pair=None,
        coefficient_provider: AlchemicalCoefficients | None = None,
        coefficient_index: int | None = None,
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        spline: SplineKind = "cubic",
        full_support_start: float = 0.0,
        eps: float = 1.0e-12,
        trainable: bool = True,
        fittable: bool = True,
        frozen: bool = False,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store one spline block per pair category.

        Raises:
            ValueError: For any of the following:

                - ``atomic_types`` is empty;
                - direct coefficients are missing or have the wrong shape;
                - the provider's shape does not match the pair categories;
                - ``coefficient_index`` is missing in provider mode.
        """
        super().__init__(cutoff=cutoff, atomic_types=atomic_types)
        if self.atomic_types is None or not self.atomic_types:
            raise ValueError("`atomic_types` must contain at least one element")
        self.symmetric = bool(symmetric)
        self.spline = spline
        self.eps = float(eps)
        self.fittable = bool(fittable)
        self.frozen = bool(frozen)
        self.coefficient_index = (
            None if coefficient_index is None else int(coefficient_index)
        )
        # Stored via object.__setattr__ to bypass nn.Module attribute handling.
        object.__setattr__(self, "_coefficient_provider", coefficient_provider)

        pair_categories = _pair_categories(self.atomic_types, symmetric=self.symmetric)
        object.__setattr__(self, "_pair_categories", pair_categories)
        object.__setattr__(
            self,
            "_pair_index",
            {pair: index for index, pair in enumerate(pair_categories)},
        )
        active_pair_mask = _active_pair_mask(
            pair_categories,
            active_pairs=active_pairs,
            symmetric=self.symmetric,
        )
        self.register_buffer(
            "active_pair_mask",
            active_pair_mask,
            persistent=False,
        )
        object.__setattr__(
            self,
            "_active_pair_indices",
            tuple(
                index
                for index, enabled in enumerate(active_pair_mask.tolist())
                if enabled
            ),
        )

        expected_pair_categories = len(pair_categories)
        if coefficient_provider is None:
            if coeffs_by_pair is None:
                raise ValueError(
                    "`coeffs_by_pair` is required when "
                    "`coefficient_provider` is not set"
                )
            coeffs_tensor = torch.as_tensor(coeffs_by_pair, dtype=dtype)
            if coeffs_tensor.ndim != 2:
                raise ValueError(
                    "`coeffs_by_pair` must have shape (n_pair_categories, n_knots)"
                )
            if coeffs_tensor.shape[0] != expected_pair_categories:
                raise ValueError(
                    "`coeffs_by_pair.shape[0]` must equal "
                    f"{expected_pair_categories} for atomic_types={self.atomic_types}, "
                    f"got {coeffs_tensor.shape[0]}"
                )
            self.coeffs_by_pair = torch.nn.Parameter(
                coeffs_tensor,
                requires_grad=bool(trainable) and not self.frozen,
            )
            coeff_size = int(coeffs_tensor.shape[1])
        else:
            if len(coefficient_provider.coefficient_shape) != 2:
                raise ValueError(
                    "`coefficient_provider` must provide two-dimensional coefficients"
                )
            if coefficient_provider.coefficient_shape[0] != expected_pair_categories:
                raise ValueError(
                    "`coefficient_provider` must provide "
                    f"{expected_pair_categories} pair categories for "
                    f"atomic_types={self.atomic_types}, got "
                    f"{coefficient_provider.coefficient_shape[0]}"
                )
            if self.coefficient_index is None:
                raise ValueError(
                    "`coefficient_index` is required when `coefficient_provider` is set"
                )
            # Projected once up front; presumably validates the index early.
            coefficient_provider.true_coeffs_for(self.coefficient_index)
            coeff_size = int(coefficient_provider.coefficient_shape[1])

        self.full_support_start = float(full_support_start)
        self.first_knot, self.knot_spacing = uniform_support_parameters(
            coeff_size=coeff_size,
            lower_full_support=self.full_support_start,
            upper_full_support=cutoff,
            spline=self.spline,
        )

    @property
    def coefficient_provider(self) -> AlchemicalCoefficients | None:
        """Return the shared coefficient provider, when one supplies the pair blocks."""
        return self._coefficient_provider

    @property
    def provides_forces(self) -> bool:
        """Report that this term produces analytic forces directly."""
        return True

    @property
    def pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return the full ordered list of pair categories owned by this term."""
        return self._pair_categories

    @property
    def active_pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return the subset of pair categories still enabled for evaluation."""
        return tuple(self.pair_categories[index] for index in self._active_pair_indices)

    @property
    def true_coeffs_by_pair(self) -> torch.Tensor:
        """Return coefficients indexed by pair category."""
        if self.coefficient_provider is None:
            return self.coeffs_by_pair
        assert self.coefficient_index is not None
        return self.coefficient_provider.true_coeffs_for(self.coefficient_index)

    def _write_true_coeffs_by_pair(self, values: torch.Tensor) -> None:
        """Write solved categorized pair coefficients back into storage.

        Raises:
            ValueError: If the coefficients come from a provider whose
                weights are not the identity.
        """
        if self.coefficient_provider is None:
            copy_parameter_data(self.coeffs_by_pair, values)
            return
        if not self.coefficient_provider.uses_identity_weights:
            raise ValueError(
                "can not write true coefficients directly into a non-identity "
                "alchemical provider"
            )
        assert self.coefficient_index is not None
        self.coefficient_provider.proxy_coeffs.data[self.coefficient_index].copy_(
            values.reshape(self.true_coeffs_by_pair.shape).to(
                self.coefficient_provider.proxy_coeffs
            )
        )

    def _handled_pair_rows(
        self, inputs: UFPInput
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return per-row pair categories and the mask of rows this term evaluates.

        Rows with a negative category index are treated as unhandled. When
        some categories are inactive, their rows are excluded as well.
        """
        pair_category = inputs.pair_category_indices(
            self.atomic_types,
            symmetric=self.symmetric,
        )
        handled_mask = pair_category >= 0
        if len(self._active_pair_indices) != len(self.pair_categories):
            active_pair_mask = self.active_pair_mask.to(device=inputs.device)
            active_handled_mask = torch.zeros_like(handled_mask)
            active_handled_mask[handled_mask] = active_pair_mask.index_select(
                0,
                pair_category[handled_mask],
            )
            handled_mask = active_handled_mask
        return pair_category, handled_mask

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return the categorized pair spline coefficient block.

        Cache channels are emitted only for active pair categories. Each
        channel spans that category's row in the flattened coefficient
        vector.
        """
        # Project the coefficients once instead of once per metadata field.
        coeff_shape = tuple(int(dim) for dim in self.true_coeffs_by_pair.shape)
        n_knots = coeff_shape[1]
        return (
            ParameterBlock(
                name="coeffs_by_pair",
                kind="twobody",
                shape=coeff_shape,
                read=lambda: self.true_coeffs_by_pair,
                write=self._write_true_coeffs_by_pair,
                label=f"twobody[{self.atomic_types}]",
                coefficient_provider=self.coefficient_provider,
                coefficient_index=self.coefficient_index,
                regularization_group="twobody",
                fittable=self.fittable,
                frozen=self.frozen,
                assembler="twobody",
                cache_descriptor=ParameterBlockCacheDescriptor(
                    family={
                        "kind": "pair_spline",
                        "symmetric": bool(self.symmetric),
                        "spline": str(self.spline),
                        "first_knot": float(self.first_knot),
                        "knot_spacing": float(self.knot_spacing),
                        "coeff_size": int(n_knots),
                        "eps": float(self.eps),
                    },
                    channels=tuple(
                        ParameterBlockCacheChannel(
                            kind="pair",
                            values=self.pair_categories[pair_index],
                            start=int(pair_index) * int(n_knots),
                            stop=(int(pair_index) + 1) * int(n_knots),
                        )
                        for pair_index in self._active_pair_indices
                    ),
                ),
            ),
        )

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Assemble categorized pair least-squares blocks for this term.

        Returns:
            Mapping from block index to its design matrix. Blocks for which
            the assembler returns ``None`` are omitted.
        """
        from ufp.leastsquares._assemble import _assemble_twobody_block

        blocks = () if options is None else options.blocks
        return {
            block.index: matrix
            for block in blocks
            if (matrix := _assemble_twobody_block(block, batch.inputs, targets))
            is not None
        }

    def assemble_selected_linear_block(
        self,
        block,
        inputs: UFPInput,
        targets,
        selected_indices: Sequence[int],
    ) -> torch.Tensor | None:
        """Assemble only requested coefficient columns for this two-body block.

        Column indices address the flattened ``(n_pair_categories, n_knots)``
        coefficient layout.

        Returns:
            A ``(targets.n_rows, len(selected_indices))`` matrix, or ``None``.
            ``None`` is returned when no active pair is on the spline support,
            or when every assembled entry is zero.

        Raises:
            RuntimeError: If ``inputs`` has no neighbor list, or if an
                asymmetric term is used with a half neighbor list.
        """
        selected_indices = tuple(int(index) for index in selected_indices)
        if inputs.neighbor_list is None:
            raise RuntimeError("SplineTwoBodyTerm requires a neighbor list")
        if not self.symmetric and not inputs.neighbor_list.full_list:
            raise RuntimeError(
                "asymmetric spline two-body terms require a full neighbor list"
            )
        _, n_knots = block.shape
        n_knots = int(n_knots)
        selected_lookup = selected_column_lookup(
            selected_indices,
            block_size=block.size,
            device=inputs.device,
        )
        pair_category, handled_mask = self._handled_pair_rows(inputs)
        if not torch.any(handled_mask):
            return None
        pair_distances = inputs.pair_distances(handled_mask)
        support_mask = spline_support_mask_1d(
            pair_distances,
            coeff_size=n_knots,
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        if not torch.any(support_mask):
            return None

        pair_distances = pair_distances[support_mask]
        handled_pair_category = pair_category[handled_mask][support_mask]
        pair_vectors = inputs.pair_vectors(handled_mask)[support_mask]
        first_atom, second_atom = inputs.pair_indices(handled_mask)
        pair_system_index = inputs.pair_system_index(handled_mask)
        first_atom = first_atom[support_mask]
        second_atom = second_atom[support_mask]
        pair_system_index = pair_system_index[support_mask]
        stencil = uniform_stencil_1d(
            pair_distances,
            coeff_size=n_knots,
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        # Allocate the dense result only once we know there is work to do.
        matrix = _selected_block_matrix(
            targets,
            selected_indices,
            device=inputs.device,
            dtype=inputs.dtype,
        )
        scale = pair_weight(inputs)
        # Offset knot indices into the flattened per-category coefficient layout.
        cols = stencil.indices + handled_pair_category[:, None] * n_knots
        values = scale * stencil.values
        grads = scale * stencil.grads
        inv_r = torch.where(
            pair_distances > self.eps,
            pair_distances.reciprocal(),
            torch.zeros_like(pair_distances),
        )
        direction = pair_vectors * inv_r[:, None]

        add_selected_entries(
            matrix,
            targets.energy_rows.index_select(0, pair_system_index)[:, None],
            cols,
            values,
            selected_lookup,
        )
        half_values = 0.5 * values
        add_selected_entries(
            matrix,
            targets.per_atom_rows.index_select(0, first_atom)[:, None],
            cols,
            half_values,
            selected_lookup,
        )
        add_selected_entries(
            matrix,
            targets.per_atom_rows.index_select(0, second_atom)[:, None],
            cols,
            half_values,
            selected_lookup,
        )
        force_second = -(grads[:, :, None] * direction[:, None, :])
        force_first = -force_second
        add_selected_entries(
            matrix,
            targets.force_rows.index_select(0, first_atom)[:, :, None],
            cols[:, None, :],
            force_first.permute(0, 2, 1),
            selected_lookup,
        )
        add_selected_entries(
            matrix,
            targets.force_rows.index_select(0, second_atom)[:, :, None],
            cols[:, None, :],
            force_second.permute(0, 2, 1),
            selected_lookup,
        )
        return None if torch.count_nonzero(matrix) == 0 else matrix

    def canonical_pair(self, first: int, second: int) -> tuple[int, int]:
        """Normalize a pair key using this term's symmetry convention."""
        return _canonical_pair(first, second, symmetric=self.symmetric)

    def pair_category_index(self, first: int, second: int) -> int:
        """Return the block index for one canonical pair category.

        Raises:
            KeyError: If the canonical pair is not owned by this term.
        """
        pair = self.canonical_pair(first, second)
        try:
            return self._pair_index[pair]
        except KeyError as exc:
            raise KeyError(f"pair {pair} is not part of this term") from exc

    def is_pair_active(self, first: int, second: int) -> bool:
        """Report whether a canonical pair category remains enabled."""
        return bool(
            self.active_pair_mask[self.pair_category_index(first, second)].item()
        )

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Accumulate spline blocks for all active pair categories.

        Raises:
            RuntimeError: If ``inputs`` has no neighbor list, or if an
                asymmetric term is used with a half neighbor list.
        """
        if inputs.neighbor_list is None:
            raise RuntimeError(
                "SplineTwoBodyTerm requires a neighbor list, but `inputs` does not "
                "contain one"
            )
        if not self.symmetric and not inputs.neighbor_list.full_list:
            raise RuntimeError(
                "asymmetric spline two-body terms require a full neighbor list"
            )
        output = empty_atomwise_output(inputs, forces=True)
        if not self._active_pair_indices:
            return output

        coeffs_by_pair = self.true_coeffs_by_pair.to(
            device=inputs.device,
            dtype=inputs.dtype,
        )
        n_knots = int(coeffs_by_pair.shape[1])
        pair_category, handled_mask = self._handled_pair_rows(inputs)
        if not torch.any(handled_mask):
            return output
        pair_distances = inputs.pair_distances(handled_mask)
        support_mask = spline_support_mask_1d(
            pair_distances,
            coeff_size=n_knots,
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        if not torch.any(support_mask):
            return output

        pair_distances = pair_distances[support_mask]
        handled_pair_category = pair_category[handled_mask][support_mask]
        stencil = uniform_stencil_1d(
            pair_distances,
            coeff_size=n_knots,
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        # Gather each pair's local coefficient window from its category row.
        coeff_window = coeffs_by_pair[handled_pair_category[:, None], stencil.indices]
        scale = pair_weight(inputs)
        weighted_pair_energy = scale * (stencil.values * coeff_window).sum(dim=1)
        weighted_pair_grad = scale * (stencil.grads * coeff_window).sum(dim=1)

        first_atom, second_atom = inputs.pair_indices(handled_mask)
        pair_system_index = inputs.pair_system_index(handled_mask)
        pair_vectors = inputs.pair_vectors(handled_mask)
        first_atom = first_atom[support_mask]
        second_atom = second_atom[support_mask]
        pair_system_index = pair_system_index[support_mask]
        pair_vectors = pair_vectors[support_mask]

        assert output.energy is not None
        assert output.forces is not None
        assert output.per_atom_energy is not None
        output.energy.index_add_(0, pair_system_index, weighted_pair_energy)
        per_atom_contribution = 0.5 * weighted_pair_energy
        output.per_atom_energy.index_add_(0, first_atom, per_atom_contribution)
        output.per_atom_energy.index_add_(0, second_atom, per_atom_contribution)
        inv_r = torch.where(
            pair_distances > self.eps,
            pair_distances.reciprocal(),
            torch.zeros_like(pair_distances),
        )
        force_on_second = -weighted_pair_grad[:, None] * pair_vectors * inv_r[:, None]
        output.forces.index_add_(0, first_atom, -force_on_second)
        output.forces.index_add_(0, second_atom, force_on_second)
        return output
class RepulsiveSplineTwoBodyTerm(PairTerm):
    """
    Cubic two-body spline family with a constrained repulsive inner wall.

    The term owns one generated cubic spline coefficient row per pair category.
    Each row is the concatenation of ``inner_coeff_count`` inner coefficients
    and the trainable outer coefficients. The inner coefficients are not stored
    directly: they are solved from positive curvature coefficients, a positive
    transition force, and a trainable transition value. The generated rows are
    evaluated through the same uniform cubic stencil path as
    :class:`SplineTwoBodyTerm`.
    """

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        coeff_size: int,
        transition_span: int,
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        spline: SplineKind = "cubic",
        full_support_start: float = 0.0,
        eps: float = 1.0e-12,
        trainable: bool = True,
        outer_coeffs_by_pair=None,
        curvature_floors_by_pair=None,
        initial_curvatures_by_pair=None,
        transition_force_floors_by_pair=None,
        initial_transition_forces_by_pair=None,
        initial_transition_values_by_pair=None,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store nonlinear wall parameters for a cubic pair-family spline.

        Parameters
        ----------
        cutoff:
            Upper end of the full-support interval; also forwarded to
            ``PairTerm``.
        atomic_types:
            Element identifiers used to build the pair categories. Must be
            non-empty.
        coeff_size:
            Total number of cubic spline coefficients per pair row. Must be
            greater than 3.
        transition_span:
            Number of knot intervals between ``full_support_start`` and the
            transition point. The first ``transition_span + 3`` coefficients
            of each row are generated (solved) rather than trained directly.
            At most ``coeff_size - 4``.
        active_pairs:
            Optional subset of pair categories to evaluate. ``None`` is
            resolved by ``_active_pair_mask``.
        symmetric:
            Whether ``(a, b)`` and ``(b, a)`` share one pair category.
        spline:
            Spline family; only ``"cubic"`` is accepted.
        full_support_start:
            Lower end of the full-support interval, where the wall starts.
        eps:
            Distances at or below this value receive zero force in
            :meth:`forward`. Must be non-negative.
        trainable:
            Whether the created parameters require gradients.
        outer_coeffs_by_pair:
            Initial outer coefficients, shape
            ``(n_pairs, coeff_size - transition_span - 3)``. ``None`` uses
            the helper default of ``0.0``. Must be finite.
        curvature_floors_by_pair:
            Non-negative lower bounds on the inner curvature coefficients,
            shape ``(n_pairs, transition_span + 1)``. ``None`` uses ``0.0``.
        initial_curvatures_by_pair:
            Initial curvature coefficients, strictly above the floors. When
            ``None`` they are set to ``floor + 1``.
        transition_force_floors_by_pair:
            Non-negative lower bounds on the transition force, shape
            ``(n_pairs,)``. ``None`` uses ``0.0``.
        initial_transition_forces_by_pair:
            Initial transition forces, strictly above the floors. When
            ``None`` they are set to ``floor + 1``.
        initial_transition_values_by_pair:
            Initial spline value at the transition point, shape
            ``(n_pairs,)``. ``None`` uses ``0.0``. Must be finite.
        dtype:
            Explicit parameter dtype; resolved together with the provided
            tensors by ``_infer_dtype``.

        Raises
        ------
        ValueError
            If any structural argument or initial parameter tensor violates
            the constraints listed above.
        """
        super().__init__(cutoff=cutoff, atomic_types=atomic_types)
        if self.atomic_types is None or not self.atomic_types:
            raise ValueError("`atomic_types` must contain at least one element")
        if spline != "cubic":
            raise ValueError("RepulsiveSplineTwoBodyTerm only supports cubic splines")

        coeff_size = int(coeff_size)
        if coeff_size <= 3:
            raise ValueError("`coeff_size` must be greater than 3")
        transition_span = int(transition_span)
        if transition_span < 0:
            raise ValueError("`transition_span` must be non-negative")
        # One curvature row per knot interval up to the transition point, plus
        # the two extra cubic coefficients that reach into that interval.
        inner_count = transition_span + 3
        if inner_count >= coeff_size:
            raise ValueError(
                "`transition_span` leaves no outer coefficients; expected "
                f"transition_span <= {coeff_size - 4}"
            )

        self.symmetric = bool(symmetric)
        self.spline = "cubic"
        self.eps = float(eps)
        # Written as `not (eps >= 0)` so a NaN eps is rejected as well. A
        # negative eps would let r == 0 reach `reciprocal()` in `forward`.
        if not (self.eps >= 0.0):
            raise ValueError("`eps` must be non-negative")
        self.coeff_size = coeff_size
        self.transition_span = transition_span
        self.inner_coeff_count = inner_count
        self.curvature_count = transition_span + 1

        # Pair bookkeeping is stored with object.__setattr__ so nn.Module does
        # not try to register these plain Python containers.
        pair_categories = _pair_categories(self.atomic_types, symmetric=self.symmetric)
        object.__setattr__(self, "_pair_categories", pair_categories)
        object.__setattr__(
            self,
            "_pair_index",
            {pair: index for index, pair in enumerate(pair_categories)},
        )
        active_pair_mask = _active_pair_mask(
            pair_categories,
            active_pairs=active_pairs,
            symmetric=self.symmetric,
        )
        self.register_buffer(
            "active_pair_mask",
            active_pair_mask,
            persistent=False,
        )
        object.__setattr__(
            self,
            "_active_pair_indices",
            tuple(
                index
                for index, enabled in enumerate(active_pair_mask.tolist())
                if enabled
            ),
        )

        self.full_support_start = float(full_support_start)
        self.first_knot, self.knot_spacing = uniform_support_parameters(
            coeff_size=coeff_size,
            lower_full_support=self.full_support_start,
            upper_full_support=cutoff,
            spline=self.spline,
        )
        # Distance where the generated inner wall hands over to the outer
        # (directly trained) coefficients.
        self.transition_distance = (
            self.full_support_start + self.transition_span * self.knot_spacing
        )

        resolved_dtype = _infer_dtype(
            dtype,
            outer_coeffs_by_pair,
            curvature_floors_by_pair,
            initial_curvatures_by_pair,
            transition_force_floors_by_pair,
            initial_transition_forces_by_pair,
            initial_transition_values_by_pair,
        )
        n_pairs = len(pair_categories)
        n_outer = coeff_size - inner_count

        outer = _as_pair_parameter_tensor(
            outer_coeffs_by_pair,
            shape=(n_pairs, n_outer),
            name="outer_coeffs_by_pair",
            dtype=resolved_dtype,
            default=0.0,
        )
        if not bool(torch.all(torch.isfinite(outer))):
            raise ValueError("`outer_coeffs_by_pair` must be finite")

        curvature_floor = _as_pair_parameter_tensor(
            curvature_floors_by_pair,
            shape=(n_pairs, self.curvature_count),
            name="curvature_floors_by_pair",
            dtype=resolved_dtype,
            default=0.0,
        )
        # `all(x >= 0)` instead of `any(x < 0)`: comparisons with NaN are
        # always False, so only this form also rejects NaN entries.
        if not bool(torch.all(curvature_floor >= 0.0)):
            raise ValueError("`curvature_floors_by_pair` must be non-negative")
        if initial_curvatures_by_pair is None:
            initial_curvature = curvature_floor + 1.0
        else:
            initial_curvature = _as_pair_parameter_tensor(
                initial_curvatures_by_pair,
                shape=(n_pairs, self.curvature_count),
                name="initial_curvatures_by_pair",
                dtype=resolved_dtype,
                default=1.0,
            )
        curvature_delta = initial_curvature - curvature_floor
        if not bool(torch.all(curvature_delta > 0.0)):
            raise ValueError(
                "`initial_curvatures_by_pair` must be greater than "
                "`curvature_floors_by_pair`"
            )

        force_floor = _as_pair_parameter_tensor(
            transition_force_floors_by_pair,
            shape=(n_pairs,),
            name="transition_force_floors_by_pair",
            dtype=resolved_dtype,
            default=0.0,
        )
        if not bool(torch.all(force_floor >= 0.0)):
            raise ValueError("`transition_force_floors_by_pair` must be non-negative")
        if initial_transition_forces_by_pair is None:
            initial_force = force_floor + 1.0
        else:
            initial_force = _as_pair_parameter_tensor(
                initial_transition_forces_by_pair,
                shape=(n_pairs,),
                name="initial_transition_forces_by_pair",
                dtype=resolved_dtype,
                default=1.0,
            )
        force_delta = initial_force - force_floor
        if not bool(torch.all(force_delta > 0.0)):
            raise ValueError(
                "`initial_transition_forces_by_pair` must be greater than "
                "`transition_force_floors_by_pair`"
            )

        transition_values = _as_pair_parameter_tensor(
            initial_transition_values_by_pair,
            shape=(n_pairs,),
            name="initial_transition_values_by_pair",
            dtype=resolved_dtype,
            default=0.0,
        )
        if not bool(torch.all(torch.isfinite(transition_values))):
            raise ValueError("`initial_transition_values_by_pair` must be finite")

        self.outer_coeffs_by_pair = torch.nn.Parameter(
            outer,
            requires_grad=bool(trainable),
        )
        # Positive quantities are stored as raw softplus arguments above their
        # floors; the public properties map them back via `floor + softplus`.
        self.raw_curvatures_by_pair = torch.nn.Parameter(
            _softplus_inverse(curvature_delta),
            requires_grad=bool(trainable),
        )
        self.raw_transition_forces_by_pair = torch.nn.Parameter(
            _softplus_inverse(force_delta),
            requires_grad=bool(trainable),
        )
        self.transition_values_by_pair = torch.nn.Parameter(
            transition_values,
            requires_grad=bool(trainable),
        )
        self.register_buffer("curvature_floors_by_pair", curvature_floor)
        self.register_buffer("transition_force_floors_by_pair", force_floor)

        # Precompute the linear operators that map coefficients to
        # curvature, slope, and value at the transition point.
        d2 = cubic_second_derivative_operator(
            coeff_size=coeff_size,
            knot_spacing=self.knot_spacing,
            dtype=resolved_dtype,
        )
        transition_point = torch.tensor(
            [self.transition_distance],
            dtype=resolved_dtype,
        )
        value_row = cubic_value_rows_1d(
            transition_point,
            coeff_size=coeff_size,
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
        )[0]
        derivative_row = cubic_derivative_rows_1d(
            transition_point,
            coeff_size=coeff_size,
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
        )[0]
        # Square system on the inner coefficients:
        # (curvature_count + 2) rows == inner_count columns.
        constraint_matrix = torch.cat(
            (
                d2[: self.curvature_count, :inner_count],
                derivative_row[None, :inner_count],
                value_row[None, :inner_count],
            ),
            dim=0,
        )
        self.register_buffer("_d2_matrix", d2)
        self.register_buffer("_transition_value_row", value_row)
        self.register_buffer("_transition_derivative_row", derivative_row)
        self.register_buffer("_constraint_matrix", constraint_matrix)

    @property
    def provides_forces(self) -> bool:
        """Report that this term produces analytic forces directly."""
        return True

    @property
    def pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return the full ordered list of pair categories owned by this term."""
        return self._pair_categories

    @property
    def active_pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return the subset of pair categories enabled for evaluation."""
        return tuple(self.pair_categories[index] for index in self._active_pair_indices)

    @property
    def curvature_coeffs_by_pair(self) -> torch.Tensor:
        """Return positive inner curvature coefficients by pair category.

        Computed as ``floor + softplus(raw)``, so values stay strictly above
        ``curvature_floors_by_pair``.
        """
        floors = self.curvature_floors_by_pair.to(
            dtype=self.raw_curvatures_by_pair.dtype,
            device=self.raw_curvatures_by_pair.device,
        )
        return floors + torch.nn.functional.softplus(self.raw_curvatures_by_pair)

    @property
    def transition_forces_by_pair(self) -> torch.Tensor:
        """Return positive transition force magnitudes by pair category.

        Computed as ``floor + softplus(raw)``, so values stay strictly above
        ``transition_force_floors_by_pair``.
        """
        floors = self.transition_force_floors_by_pair.to(
            dtype=self.raw_transition_forces_by_pair.dtype,
            device=self.raw_transition_forces_by_pair.device,
        )
        return floors + torch.nn.functional.softplus(self.raw_transition_forces_by_pair)

    @property
    def true_coeffs_by_pair(self) -> torch.Tensor:
        """Return generated cubic spline coefficients by pair category.

        The inner coefficients are solved so that, together with the outer
        coefficients, the spline has:

        * the requested curvature coefficients on the first
          ``curvature_count`` second-derivative rows,
        * slope ``-transition_force`` at the transition point (energy
          decreasing with distance, i.e. a repulsive slope), and
        * value ``transition_value`` at the transition point.

        Returns a tensor of shape ``(n_pairs, coeff_size)`` laid out as
        ``[inner | outer]``.
        """
        outer = self.outer_coeffs_by_pair
        dtype = outer.dtype
        device = outer.device
        inner_count = self.inner_coeff_count
        curvature_count = self.curvature_count
        d2 = self._d2_matrix.to(dtype=dtype, device=device)
        value_row = self._transition_value_row.to(dtype=dtype, device=device)
        derivative_row = self._transition_derivative_row.to(
            dtype=dtype,
            device=device,
        )
        constraint_matrix = self._constraint_matrix.to(dtype=dtype, device=device)

        # Move the known outer-coefficient contribution of each constraint to
        # the right-hand side.
        rhs_curvature = self.curvature_coeffs_by_pair - torch.matmul(
            outer,
            d2[:curvature_count, inner_count:].transpose(0, 1),
        )
        rhs_force = -self.transition_forces_by_pair - torch.matmul(
            outer,
            derivative_row[inner_count:],
        )
        rhs_value = self.transition_values_by_pair - torch.matmul(
            outer,
            value_row[inner_count:],
        )
        rhs = torch.cat(
            (
                rhs_curvature,
                rhs_force[:, None],
                rhs_value[:, None],
            ),
            dim=1,
        )
        # One shared matrix, one right-hand-side column per pair category.
        inner = torch.linalg.solve(constraint_matrix, rhs.transpose(0, 1)).transpose(
            0,
            1,
        )
        return torch.cat((inner, outer), dim=1)

    def to_spline_twobody_term(
        self,
        *,
        trainable: bool = True,
        fittable: bool = True,
    ) -> SplineTwoBodyTerm:
        """Return an ordinary spline term with the current generated coefficients.

        The coefficients are detached and cloned, so the returned term does not
        share autograd history or storage with this one.
        """
        return SplineTwoBodyTerm(
            cutoff=float(self.cutoff),
            atomic_types=self.atomic_types,
            coeffs_by_pair=self.true_coeffs_by_pair.detach().clone(),
            active_pairs=self.active_pair_categories,
            symmetric=self.symmetric,
            spline=self.spline,
            full_support_start=self.full_support_start,
            eps=self.eps,
            trainable=trainable,
            fittable=fittable,
        )

    def wall_diagnostics(
        self,
        *,
        n_samples: int = 200,
    ) -> dict[tuple[int, int], object]:
        """Sample value, gradient, and curvature on the constrained interval.

        Samples ``n_samples`` evenly spaced distances from
        ``full_support_start`` to ``transition_distance`` for every active pair
        category. Results are keyed by pair and come from
        ``cubic_spline_diagnostics_1d``.

        Raises
        ------
        ValueError
            If ``n_samples`` is smaller than 2.
        """
        n_samples = int(n_samples)
        if n_samples < 2:
            raise ValueError("`n_samples` must be at least 2")
        coeffs = self.true_coeffs_by_pair
        distances = torch.linspace(
            self.full_support_start,
            self.transition_distance,
            n_samples,
            dtype=coeffs.dtype,
            device=coeffs.device,
        )
        diagnostics = {}
        for pair in self.active_pair_categories:
            index = self.pair_category_index(pair[0], pair[1])
            diagnostics[pair] = cubic_spline_diagnostics_1d(
                distances,
                coeffs[index],
                first_knot=self.first_knot,
                knot_spacing=self.knot_spacing,
            )
        return diagnostics

    def canonical_pair(self, first: int, second: int) -> tuple[int, int]:
        """Normalize a pair key using this term's symmetry convention."""
        return _canonical_pair(first, second, symmetric=self.symmetric)

    def pair_category_index(self, first: int, second: int) -> int:
        """Return the block index for one canonical pair category.

        Raises
        ------
        KeyError
            If the canonical pair is not one of this term's categories.
        """
        pair = self.canonical_pair(first, second)
        try:
            return self._pair_index[pair]
        except KeyError as exc:
            raise KeyError(f"pair {pair} is not part of this term") from exc

    def is_pair_active(self, first: int, second: int) -> bool:
        """Report whether a canonical pair category is enabled."""
        return bool(
            self.active_pair_mask[self.pair_category_index(first, second)].item()
        )

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Accumulate generated spline blocks for all active pair categories.

        Each handled pair contributes ``scale * E(r)`` to its system energy,
        split equally between the two atoms for per-atom energy. Equal and
        opposite radial forces are applied to both atoms. Pairs outside the
        spline support, or in inactive categories, contribute nothing. Pairs
        with ``r <= eps`` receive zero force.

        Raises
        ------
        RuntimeError
            If ``inputs`` has no neighbor list, or the term is asymmetric and
            the neighbor list is not a full list.
        """
        if inputs.neighbor_list is None:
            raise RuntimeError(
                "RepulsiveSplineTwoBodyTerm requires a neighbor list, but `inputs` "
                "does not contain one"
            )
        if not self.symmetric and not inputs.neighbor_list.full_list:
            raise RuntimeError(
                "asymmetric spline two-body terms require a full neighbor list"
            )
        output = empty_atomwise_output(inputs, forces=True)
        if not self._active_pair_indices:
            return output

        pair_category = inputs.pair_category_indices(
            self.atomic_types,
            symmetric=self.symmetric,
        )
        # Negative category indices mark pairs this term does not own.
        handled_mask = pair_category >= 0
        if len(self._active_pair_indices) != len(self.pair_categories):
            active_pair_mask = self.active_pair_mask.to(device=inputs.device)
            active_handled_mask = torch.zeros_like(handled_mask)
            active_handled_mask[handled_mask] = active_pair_mask.index_select(
                0,
                pair_category[handled_mask],
            )
            handled_mask = active_handled_mask
        if not torch.any(handled_mask):
            return output

        pair_distances = inputs.pair_distances(handled_mask)
        # The generated rows always have `coeff_size` columns (inner + outer),
        # so the support test does not need the coefficients themselves.
        support_mask = spline_support_mask_1d(
            pair_distances,
            coeff_size=self.coeff_size,
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        if not torch.any(support_mask):
            return output

        # The batched linear solve is only paid for once some pair needs it.
        coeffs_by_pair = self.true_coeffs_by_pair.to(
            device=inputs.device,
            dtype=inputs.dtype,
        )
        pair_distances = pair_distances[support_mask]
        handled_pair_category = pair_category[handled_mask][support_mask]
        stencil = uniform_stencil_1d(
            pair_distances,
            coeff_size=self.coeff_size,
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        # Gather each pair's local coefficient window from its category row.
        coeff_window = coeffs_by_pair[handled_pair_category[:, None], stencil.indices]
        # NOTE(review): `pair_weight` presumably compensates for double counting
        # in full neighbor lists; confirm against its definition.
        scale = pair_weight(inputs)
        weighted_pair_energy = scale * (stencil.values * coeff_window).sum(dim=1)
        weighted_pair_grad = scale * (stencil.grads * coeff_window).sum(dim=1)

        first_atom, second_atom = inputs.pair_indices(handled_mask)
        pair_system_index = inputs.pair_system_index(handled_mask)
        pair_vectors = inputs.pair_vectors(handled_mask)
        first_atom = first_atom[support_mask]
        second_atom = second_atom[support_mask]
        pair_system_index = pair_system_index[support_mask]
        pair_vectors = pair_vectors[support_mask]

        assert output.energy is not None
        assert output.forces is not None
        assert output.per_atom_energy is not None
        output.energy.index_add_(0, pair_system_index, weighted_pair_energy)
        per_atom_contribution = 0.5 * weighted_pair_energy
        output.per_atom_energy.index_add_(0, first_atom, per_atom_contribution)
        output.per_atom_energy.index_add_(0, second_atom, per_atom_contribution)

        # Guard the 1/r factor so coincident atoms get zero force, not inf/NaN.
        inv_r = torch.where(
            pair_distances > self.eps,
            pair_distances.reciprocal(),
            torch.zeros_like(pair_distances),
        )
        # -dE/dr along the unit pair vector; the first atom gets the opposite.
        force_on_second = -weighted_pair_grad[:, None] * pair_vectors * inv_r[:, None]
        output.forces.index_add_(0, first_atom, -force_on_second)
        output.forces.index_add_(0, second_atom, force_on_second)
        return output
# Names exported by a wildcard import of this module, kept in sorted order.
__all__ = [
    "RepulsiveSplineTwoBodyTerm",
    "SplineKind",
    "SplinePairTerm",
    "SplineTwoBodyTerm",
    "get_eval_1d_with_grads",
]