"""Torch-only four-body spline energy term."""
from __future__ import annotations
from itertools import combinations
from typing import Literal
import torch
from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.splines.representation import (
spline_support_mask_6d,
uniform_stencil_6d,
uniform_support_parameters,
)
from ufp.terms._base import TermInputRequirements, UFPTerm
from ufp.terms._parameters import ParameterBlock, copy_parameter_data
# Spline family names accepted by the ``spline`` argument of the term below.
SplineKind = Literal["quadratic", "cubic", "quartic"]
class SplineFourBody6DTerm(UFPTerm):
    """
    Torch-only four-body term over the six pair distances in a source-neighbor
    triple.

    For every source atom and every unordered triple of its neighbor-list
    rows, a six-dimensional spline is evaluated at the three center-neighbor
    distances followed by the three neighbor-neighbor distances.

    This term contributes energy directly. Forces can be derived through
    autograd by calling model APIs with ``derive_forces=True``.

    Args:
        cutoff: Maximum center-neighbor pair distance included in the term.
        coeffs: Six-dimensional coefficient tensor.
        spline: Spline family name.
        full_support_start: Lower center-neighbor distance with full spline
            support.
        neighbor_neighbor_full_support_start: Lower neighbor-neighbor distance
            with full spline support.
        atomic_types: Optional atomic-number coverage for the term.
        trainable: Whether coefficients require gradients.
        fittable: Whether this term exposes a coefficient block to linear
            fitters.
        frozen: Whether the coefficient block is fixed during fitting. A
            frozen term never requires gradients, regardless of ``trainable``.
        dtype: Optional dtype used when converting coefficients.

    Raises:
        ValueError: If ``coeffs`` is not six-dimensional.
    """

    def __init__(
        self,
        *,
        cutoff: float,
        coeffs,
        spline: SplineKind = "cubic",
        full_support_start: float = 0.0,
        neighbor_neighbor_full_support_start: float = 0.0,
        atomic_types: tuple[int, ...] | None = None,
        trainable: bool = True,
        fittable: bool = False,
        frozen: bool = False,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store one uncategorized 6D coefficient block."""
        super().__init__(cutoff=cutoff, atomic_types=atomic_types)

        # NOTE(review): ``torch.as_tensor`` may share memory with the caller's
        # array/tensor, so the Parameter below can alias ``coeffs`` -- confirm
        # that this is intended.
        coeff_tensor = torch.as_tensor(coeffs, dtype=dtype)
        if coeff_tensor.ndim != 6:
            raise ValueError("`coeffs` must have shape (N1, N2, N3, N4, N5, N6)")

        self.spline = spline
        self.full_support_start = float(full_support_start)
        self.neighbor_neighbor_full_support_start = float(
            neighbor_neighbor_full_support_start
        )
        self.fittable = bool(fittable)
        self.frozen = bool(frozen)
        self.coeffs = torch.nn.Parameter(
            coeff_tensor,
            requires_grad=bool(trainable) and not self.frozen,
        )

        # Axes 0-2 (center-neighbor distances) share the knot grid derived
        # from the size of axis 0.
        # NOTE(review): the sizes of axes 1-2 are not checked against axis 0
        # (likewise axes 4-5 against axis 3) -- confirm callers pass matching
        # sizes.
        center_first, center_spacing = uniform_support_parameters(
            coeff_size=int(coeff_tensor.shape[0]),
            lower_full_support=self.full_support_start,
            upper_full_support=cutoff,
            spline=spline,
        )
        # Axes 3-5 (neighbor-neighbor distances) share the grid derived from
        # axis 3. Two neighbors that each lie within ``cutoff`` of the center
        # are at most ``2 * cutoff`` apart, hence the doubled upper bound.
        outer_first, outer_spacing = uniform_support_parameters(
            coeff_size=int(coeff_tensor.shape[3]),
            lower_full_support=self.neighbor_neighbor_full_support_start,
            upper_full_support=2.0 * float(cutoff),
            spline=spline,
        )
        self.first_knots = (
            center_first,
            center_first,
            center_first,
            outer_first,
            outer_first,
            outer_first,
        )
        self.knot_spacings = (
            center_spacing,
            center_spacing,
            center_spacing,
            outer_spacing,
            outer_spacing,
            outer_spacing,
        )

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return a one-element tuple describing the 6D coefficient block.

        The block is always returned; whether it takes part in fitting is
        conveyed by its ``fittable`` and ``frozen`` flags.
        """
        return (
            ParameterBlock(
                name="coeffs",
                kind="fourbody6d",
                shape=tuple(int(dim) for dim in self.coeffs.shape),
                read=lambda: self.coeffs,
                write=lambda values: copy_parameter_data(self.coeffs, values),
                label="fourbody6d",
                regularization_group="fourbody",
                fittable=self.fittable,
                frozen=self.frozen,
                assembler=None,
            ),
        )

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Declare the directed full-neighbor-list requirement."""
        return TermInputRequirements(full_neighbor_list=True)

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate four-body energies for source-centered neighbor triples.

        Args:
            inputs: Batched input providing pair indices, pair vectors, pair
                distances and the per-atom system index.

        Returns:
            A ``UFPOutput`` holding the per-system ``energy`` and the
            ``per_atom_energy``. Each triple's value is added in full to the
            system of its source atom and split in equal quarters between the
            source atom and its three neighbors.
        """
        self.validate_inputs(inputs)
        coeffs = self.coeffs.to(device=inputs.device, dtype=inputs.dtype)
        first_atom, second_atom = inputs.pair_indices()
        pair_vectors = inputs.pair_vectors()
        pair_distances = inputs.pair_distances()
        energy = torch.zeros(inputs.n_systems, dtype=inputs.dtype, device=inputs.device)
        per_atom_energy = torch.zeros(
            inputs.n_atoms,
            dtype=inputs.dtype,
            device=inputs.device,
        )

        # Loop invariants: the coefficient shape and the flattened view do not
        # change between triples.
        coeff_shape = tuple(int(dim) for dim in coeffs.shape)
        flat_coeffs = coeffs.reshape(-1)

        # Move the pair indices to the host once instead of issuing one
        # comparison per atom and three ``.item()`` syncs per triple.
        source_of_row = first_atom.detach().reshape(-1).cpu().tolist()
        neighbor_of_row = second_atom.detach().reshape(-1).cpu().tolist()

        # Group pair rows by source atom; rows stay in ascending order, which
        # keeps the triple enumeration order stable.
        rows_by_source: dict[int, list[int]] = {}
        for row, row_source in enumerate(source_of_row):
            rows_by_source.setdefault(int(row_source), []).append(row)

        for source in range(inputs.n_atoms):
            rows_list = rows_by_source.get(source, ())
            # A four-body contribution needs at least three neighbors.
            if len(rows_list) < 3:
                continue
            for row_a, row_b, row_c in combinations(rows_list, 3):
                neighbor_a = int(neighbor_of_row[row_a])
                neighbor_b = int(neighbor_of_row[row_b])
                neighbor_c = int(neighbor_of_row[row_c])
                vectors = (
                    pair_vectors[row_a],
                    pair_vectors[row_b],
                    pair_vectors[row_c],
                )
                center_distances = (
                    pair_distances[row_a].reshape(1),
                    pair_distances[row_b].reshape(1),
                    pair_distances[row_c].reshape(1),
                )
                # Neighbor-neighbor distances from differences of the
                # source-to-neighbor vectors: (a, b), (a, c), (b, c).
                outer_distances = (
                    torch.linalg.vector_norm(vectors[0] - vectors[1]).reshape(1),
                    torch.linalg.vector_norm(vectors[0] - vectors[2]).reshape(1),
                    torch.linalg.vector_norm(vectors[1] - vectors[2]).reshape(1),
                )
                coords = (*center_distances, *outer_distances)
                supported = spline_support_mask_6d(
                    coords,
                    coeff_shape=coeff_shape,
                    first_knots=self.first_knots,
                    knot_spacings=self.knot_spacings,
                    spline=self.spline,
                )
                # Skip triples that the support mask does not fully accept.
                if not bool(torch.all(supported)):
                    continue
                stencil = uniform_stencil_6d(
                    coords,
                    coeff_shape=coeff_shape,
                    first_knots=self.first_knots,
                    knot_spacings=self.knot_spacings,
                    spline=self.spline,
                )
                value = torch.sum(
                    flat_coeffs[stencil.indices[0]] * stencil.values[0]
                )
                system = inputs.system_index[source]
                energy[system] = energy[system] + value
                # Attribute the triple's energy equally to its four atoms.
                share = value / 4.0
                per_atom_energy[source] = per_atom_energy[source] + share
                per_atom_energy[neighbor_a] = per_atom_energy[neighbor_a] + share
                per_atom_energy[neighbor_b] = per_atom_energy[neighbor_b] + share
                per_atom_energy[neighbor_c] = per_atom_energy[neighbor_c] + share

        return UFPOutput(
            energy=energy,
            per_atom_energy=per_atom_energy,
        )
# Public names exported by this module.
__all__ = [
"SplineFourBody6DTerm",
"SplineKind",
]