Source code for ufp.terms.fourbody

"""Torch-only four-body spline energy term."""

from __future__ import annotations

from itertools import combinations
from typing import Literal

import torch

from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.splines.representation import (
    spline_support_mask_6d,
    uniform_stencil_6d,
    uniform_support_parameters,
)
from ufp.terms._base import TermInputRequirements, UFPTerm
from ufp.terms._parameters import ParameterBlock, copy_parameter_data


# Spline family names accepted by the ``spline`` argument of the term below.
SplineKind = Literal[
    "quadratic",
    "cubic",
    "quartic",
]


class SplineFourBody6DTerm(UFPTerm):
    """
    Torch-only four-body term over the six pair distances in a source-neighbor triple.

    This term contributes energy directly. Forces can be derived through
    autograd by calling model APIs with ``derive_forces=True``.

    Args:
        cutoff: Maximum center-neighbor pair distance included in the term.
        coeffs: Six-dimensional coefficient tensor.
        spline: Spline family name.
        full_support_start: Lower center-neighbor distance with full spline
            support.
        neighbor_neighbor_full_support_start: Lower neighbor-neighbor distance
            with full spline support.
        atomic_types: Optional atomic-number coverage for the term.
        trainable: Whether coefficients require gradients.
        fittable: Whether this term exposes a coefficient block to linear
            fitters.
        frozen: Whether the coefficient block is fixed during fitting.
        dtype: Optional dtype used when converting coefficients.

    Raises:
        ValueError: If ``coeffs`` does not have exactly six dimensions.
    """

    def __init__(
        self,
        *,
        cutoff: float,
        coeffs,
        spline: SplineKind = "cubic",
        full_support_start: float = 0.0,
        neighbor_neighbor_full_support_start: float = 0.0,
        atomic_types: tuple[int, ...] | None = None,
        trainable: bool = True,
        fittable: bool = False,
        frozen: bool = False,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store one uncategorized 6D coefficient block."""
        super().__init__(cutoff=cutoff, atomic_types=atomic_types)

        # NOTE(review): torch.as_tensor avoids a copy when it can, so the
        # parameter below may share storage with the caller's `coeffs`.
        # Confirm that this aliasing is intended.
        coeff_tensor = torch.as_tensor(coeffs, dtype=dtype)
        if coeff_tensor.ndim != 6:
            raise ValueError("`coeffs` must have shape (N1, N2, N3, N4, N5, N6)")

        self.spline = spline
        self.full_support_start = float(full_support_start)
        self.neighbor_neighbor_full_support_start = float(
            neighbor_neighbor_full_support_start
        )
        self.fittable = bool(fittable)
        self.frozen = bool(frozen)
        # A frozen block never receives gradients, whatever `trainable` says.
        self.coeffs = torch.nn.Parameter(
            coeff_tensor,
            requires_grad=bool(trainable) and not self.frozen,
        )

        # Knot layout of the three center-neighbor axes, taken from axis 0.
        # NOTE(review): only axes 0 and 3 are consulted for the knot layout;
        # presumably axes 0-2 and axes 3-5 are expected to have equal sizes.
        # TODO confirm.
        center_first, center_spacing = uniform_support_parameters(
            coeff_size=int(coeff_tensor.shape[0]),
            lower_full_support=self.full_support_start,
            upper_full_support=cutoff,
            spline=spline,
        )
        # Knot layout of the three neighbor-neighbor axes, taken from axis 3.
        # Two neighbors that are both within `cutoff` of the center can be at
        # most 2 * cutoff apart, hence the doubled upper bound.
        outer_first, outer_spacing = uniform_support_parameters(
            coeff_size=int(coeff_tensor.shape[3]),
            lower_full_support=self.neighbor_neighbor_full_support_start,
            upper_full_support=2.0 * float(cutoff),
            spline=spline,
        )
        self.first_knots = (
            center_first,
            center_first,
            center_first,
            outer_first,
            outer_first,
            outer_first,
        )
        self.knot_spacings = (
            center_spacing,
            center_spacing,
            center_spacing,
            outer_spacing,
            outer_spacing,
            outer_spacing,
        )

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return the single 6D coefficient block of this term.

        The block is always returned; its ``fittable`` and ``frozen`` fields
        carry the flags given at construction time.
        """
        return (
            ParameterBlock(
                name="coeffs",
                kind="fourbody6d",
                shape=tuple(int(dim) for dim in self.coeffs.shape),
                read=lambda: self.coeffs,
                write=lambda values: copy_parameter_data(self.coeffs, values),
                label="fourbody6d",
                regularization_group="fourbody",
                fittable=self.fittable,
                frozen=self.frozen,
                assembler=None,
            ),
        )

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Declare the directed full-neighbor-list requirement."""
        return TermInputRequirements(full_neighbor_list=True)

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate four-body energies for source-centered neighbor triples.

        For every atom, each unordered triple of its neighbor-list rows is
        turned into six coordinates: the three center-neighbor distances
        followed by the three neighbor-neighbor distances. Triples that are
        not fully inside the spline support are skipped. Each remaining
        triple adds its spline value to the energy of the source atom's
        system, and a quarter of that value to each of the four atoms.

        Args:
            inputs: Batched input providing pair indices, pair vectors, pair
                distances and the per-atom system index.

        Returns:
            A ``UFPOutput`` with ``energy`` of length ``inputs.n_systems`` and
            ``per_atom_energy`` of length ``inputs.n_atoms``.
        """
        self.validate_inputs(inputs)
        coeffs = self.coeffs.to(device=inputs.device, dtype=inputs.dtype)
        first_atom, second_atom = inputs.pair_indices()
        pair_vectors = inputs.pair_vectors()
        pair_distances = inputs.pair_distances()
        energy = torch.zeros(inputs.n_systems, dtype=inputs.dtype, device=inputs.device)
        per_atom_energy = torch.zeros(
            inputs.n_atoms,
            dtype=inputs.dtype,
            device=inputs.device,
        )

        # Loop-invariant values, computed once instead of once per triple.
        coeff_shape = tuple(int(dim) for dim in coeffs.shape)
        flat_coeffs = coeffs.reshape(-1)

        # Pull the integer pair indices to the host once. This replaces three
        # `.item()` calls per triple and a full scan of the pair list per atom.
        neighbor_of_row = second_atom.detach().cpu().tolist()
        rows_by_source: dict[int, list[int]] = {}
        for row, center in enumerate(first_atom.detach().cpu().tolist()):
            # Rows are appended in ascending order, so each source sees its
            # rows in the same order a per-source scan would produce.
            rows_by_source.setdefault(center, []).append(row)

        for source in range(inputs.n_atoms):
            source_rows = rows_by_source.get(source)
            if source_rows is None or len(source_rows) < 3:
                # A triple needs at least three neighbors.
                continue
            system = inputs.system_index[source]

            for row_a, row_b, row_c in combinations(source_rows, 3):
                neighbor_a = neighbor_of_row[row_a]
                neighbor_b = neighbor_of_row[row_b]
                neighbor_c = neighbor_of_row[row_c]
                vectors = (
                    pair_vectors[row_a],
                    pair_vectors[row_b],
                    pair_vectors[row_c],
                )
                center_distances = (
                    pair_distances[row_a].reshape(1),
                    pair_distances[row_b].reshape(1),
                    pair_distances[row_c].reshape(1),
                )
                # Neighbor-neighbor distances come from differences of the
                # source-to-neighbor pair vectors.
                outer_distances = (
                    torch.linalg.vector_norm(vectors[0] - vectors[1]).reshape(1),
                    torch.linalg.vector_norm(vectors[0] - vectors[2]).reshape(1),
                    torch.linalg.vector_norm(vectors[1] - vectors[2]).reshape(1),
                )
                coords = (*center_distances, *outer_distances)

                supported = spline_support_mask_6d(
                    coords,
                    coeff_shape=coeff_shape,
                    first_knots=self.first_knots,
                    knot_spacings=self.knot_spacings,
                    spline=self.spline,
                )
                if not bool(torch.all(supported)):
                    continue

                stencil = uniform_stencil_6d(
                    coords,
                    coeff_shape=coeff_shape,
                    first_knots=self.first_knots,
                    knot_spacings=self.knot_spacings,
                    spline=self.spline,
                )
                value = torch.sum(flat_coeffs[stencil.indices[0]] * stencil.values[0])

                # Out-of-place accumulation, as in the original, so autograd
                # can track every contribution.
                energy[system] = energy[system] + value
                share = value / 4.0
                per_atom_energy[source] = per_atom_energy[source] + share
                per_atom_energy[neighbor_a] = per_atom_energy[neighbor_a] + share
                per_atom_energy[neighbor_b] = per_atom_energy[neighbor_b] + share
                per_atom_energy[neighbor_c] = per_atom_energy[neighbor_c] + share

        return UFPOutput(
            energy=energy,
            per_atom_energy=per_atom_energy,
        )
# Public names exported by this module.
__all__ = [
    "SplineFourBody6DTerm",
    "SplineKind",
]