"""
Spline-based three-body interaction term implementation.
``SplineThreeBodyTerm`` is the stable user-facing term. Bucket containers,
feature caches, and evaluator helpers exported from this module are expert
diagnostics for benchmarks, tests, and performance investigations.
"""
from __future__ import annotations
import hashlib
import json
from collections.abc import Sequence
from pathlib import Path
from typing import Literal
import torch
from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.splines.representation import (
all_supported_uniform_stencil_3d,
uniform_support_parameters,
)
from ufp.terms._base import (
LinearAssemblyOptions,
TermCacheOptions,
TermInputRequirements,
ThreeBodyTerm,
)
from ufp.terms._parameters import (
ParameterBlock,
ParameterBlockCacheChannel,
ParameterBlockCacheDescriptor,
copy_parameter_data,
)
from ufp.terms._selected_assembly import add_selected_entries, selected_column_lookup
from ufp.terms._shared import empty_atomwise_output
from ufp.terms._threebody_cache import (
FeatureCacheMode,
_build_dense_feature_cache_from_buckets,
_dense_feature_cache_dir,
_dense_feature_cache_metadata,
_find_compatible_memmap_dense_feature_cache,
_load_memmap_dense_feature_cache,
load_memmap_threebody_feature_cache,
)
from ufp.terms._threebody_dense import (
DenseThreeBodyFeatureCache,
DenseTripletFeatureBlock,
MemmapDenseThreeBodyFeatureCache,
MemmapDenseTripletFeatureBlock,
ThreeBodyDenseAtomFeatures,
_build_dense_feature_cache_from_feature_cache,
_dense_atom_features_from_feature_cache,
_evaluate_dense_feature_cache_energy_forces,
_selected_atom_indices,
_symmetrize_dense_atom_features,
)
from ufp.terms._threebody_eval import (
BucketedEnergyForceEvaluator,
SplineKind,
_neighbor_neighbor_cutoff,
_same_neighbor_triplet_mask,
_support_bounds,
_symmetrize_same_neighbor_coeffs,
evaluate_bucketed_energy_forces,
get_eval_3d_with_grads,
make_bucketed_energy_forces_evaluator,
)
from ufp.terms._threebody_eval import (
Eval3DWithGrads as Eval3DWithGrads,
)
from ufp.terms._threebody_eval import (
_evaluate_bucketed_energy_forces as _evaluate_bucketed_energy_forces,
)
from ufp.terms._threebody_features import (
ThreeBodyFeatureBlock as ThreeBodyFeatureBlock,
)
from ufp.terms._threebody_features import (
ThreeBodyFeatureCache as ThreeBodyFeatureCache,
)
from ufp.terms._threebody_features import (
_build_feature_cache_from_buckets,
)
from ufp.terms._threebody_kernels import (
preprocess_sources_native_or_torch,
)
from ufp.terms._threebody_ops import (
Buckets,
build_edge_category_table,
num_edge_categories,
pair_distance_partials_batched,
pattern_triplet_layout,
preprocess_sources,
)
from ufp.terms._threebody_runtime import (
ThreeBodyRuntimeConfig,
resolve_threebody_runtime_config,
)
from ufp.terms.alchemical import AlchemicalCoefficients
from ufp.terms.categories import (
active_triplet_mask as _active_triplet_mask,
)
from ufp.terms.categories import (
canonical_triplet as _canonical_triplet,
)
from ufp.terms.categories import (
triplet_categories as _triplet_categories,
)
# Key in ``inputs.metadata`` under which ``cache_input`` stores a dict of
# triplet buckets, one entry per term keyed by that term's ``_cache_key()``.
_THREEBODY_BUCKET_CACHE_KEY = "_ufp_threebody_buckets"
# Key in ``inputs.metadata`` under which ``cache_input`` stores a dict of dense
# feature caches, keyed the same way as the bucket cache.
_THREEBODY_FEATURE_CACHE_KEY = "_ufp_threebody_features"
def _selected_block_matrix(
    targets,
    selected_indices: Sequence[int],
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Allocate the zero-filled matrix for a subset of coefficient columns.

    The result has one row per target row (``targets.n_rows``) and one column
    per entry of ``selected_indices``.
    """
    n_rows = targets.n_rows
    n_cols = len(tuple(selected_indices))
    return torch.zeros((n_rows, n_cols), device=device, dtype=dtype)
def _swapped_xy_cols(
    cols: torch.Tensor,
    coeff_volume: int,
    nx: int,
    ny: int,
    nz: int,
) -> torch.Tensor:
    """Map flat coefficient columns to their x/y-transposed counterparts.

    Each column is split into a per-triplet block offset and a local flat
    index. The local index is decoded into ``(ix, iy, iz)`` and re-encoded
    with the x and y indices exchanged, staying inside the same block.
    """
    within_block = torch.remainder(cols, coeff_volume)
    z_index = torch.remainder(within_block, nz)
    y_index = torch.remainder(
        torch.div(within_block, nz, rounding_mode="floor"),
        ny,
    )
    x_index = torch.div(within_block, ny * nz, rounding_mode="floor")
    swapped_local = (y_index * nx + x_index) * nz + z_index
    return (cols - within_block) + swapped_local
def _add_selected_threebody_entries(
    matrix: torch.Tensor,
    rows: torch.Tensor,
    cols: torch.Tensor,
    values: torch.Tensor,
    *,
    selected_lookup: torch.Tensor,
    same_triplet_mask: torch.Tensor,
    coeff_volume: int,
    nx: int,
    ny: int,
    nz: int,
) -> None:
    """Scatter three-body entries into ``matrix``, tying x/y for same-neighbor triplets.

    Entries whose triplet is not flagged in ``same_triplet_mask`` are added
    unchanged. Flagged entries are split in half: one half goes to the
    original column and the other half to the column with x and y swapped.
    """
    if 0 in (rows.numel(), cols.numel(), values.numel()):
        return

    tied = same_triplet_mask.to(device=cols.device, dtype=torch.bool)
    if tied.numel() == 0 or not bool(torch.any(tied)):
        # Nothing is tied, so everything is accumulated as-is.
        add_selected_entries(matrix, rows, cols, values, selected_lookup)
        return

    untied = ~tied
    if torch.any(untied):
        add_selected_entries(
            matrix,
            rows[untied],
            cols[untied],
            values[untied],
            selected_lookup,
        )

    # At least one entry is tied here (guaranteed by the guard above).
    tied_rows = rows[tied]
    tied_cols = cols[tied]
    half_values = 0.5 * values[tied]
    mirrored_cols = _swapped_xy_cols(tied_cols, coeff_volume, nx, ny, nz)
    for target_cols in (tied_cols, mirrored_cols):
        add_selected_entries(
            matrix,
            tied_rows,
            target_cols,
            half_values,
            selected_lookup,
        )
def _selected_channel_mask(
    triplet_index: torch.Tensor,
    selected_triplet_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Return which local triplet categories have selected coefficients.

    Args:
        triplet_index: Triplet-category index of each local entry.
        selected_triplet_indices: Category indices that own at least one
            selected coefficient, or ``None`` to accept every category.

    Returns:
        Boolean tensor shaped like ``triplet_index`` that is ``True`` where
        the entry's category appears in ``selected_triplet_indices``.
    """
    if selected_triplet_indices is None:
        return torch.ones_like(triplet_index, dtype=torch.bool)
    # Compare on-device in one broadcast instead of looping over a host-side
    # list; this avoids a device-to-host sync and one kernel per index.
    selected = selected_triplet_indices.detach().reshape(-1).to(
        device=triplet_index.device,
        dtype=torch.int64,
    )
    if selected.numel() == 0:
        return torch.zeros_like(triplet_index, dtype=torch.bool)
    return (triplet_index.unsqueeze(-1) == selected).any(dim=-1)
def _accumulate_selected_threebody_pairs(
    *,
    matrix: torch.Tensor,
    term: "SplineThreeBodyTerm",
    src_ids: torch.Tensor,
    src_system: torch.Tensor,
    triplet_index: torch.Tensor,
    row: torch.Tensor,
    col: torch.Tensor,
    vec: torch.Tensor,
    distances: torch.Tensor,
    nbr_sorted: torch.Tensor,
    coeff_shape: tuple[int, int, int, int],
    selected_triplet_index: int | None,
    selected_triplet_indices: torch.Tensor | None,
    selected_lookup: torch.Tensor,
    energy_rows: torch.Tensor,
    force_rows: torch.Tensor,
    per_atom_rows: torch.Tensor,
) -> None:
    """Accumulate one source-pattern's selected three-body columns.

    For every (source atom, neighbor pair) combination of the pattern this
    evaluates the spline stencil and its gradients, then adds the energy,
    per-atom-energy and force contributions to ``matrix`` in place through
    ``_add_selected_threebody_entries``.

    Args:
        matrix: Selected-column matrix that receives the contributions.
        term: Term supplying the spline grid, support bounds, triplet masks
            and ``eps``.
        src_ids: Atom indices of the source atoms in this pattern.
        src_system: Per-source index into ``energy_rows``.
        triplet_index: Triplet-category index of each neighbor pair.
        row: Neighbor-slot index of the first neighbor (j) of each pair.
        col: Neighbor-slot index of the second neighbor (k) of each pair.
        vec: Pair vectors indexed as ``vec[source, slot, xyz]``.
        distances: Pair distances indexed as ``distances[source, slot]``.
        nbr_sorted: Neighbor atom indices indexed as
            ``nbr_sorted[source, slot]``.
        coeff_shape: ``(n_triplet_categories, nx, ny, nz)``; only the last
            three entries are used here.
        selected_triplet_index: When not ``None``, only pairs of this category
            are kept and columns are not offset by category.
        selected_triplet_indices: Categories to keep when
            ``selected_triplet_index`` is ``None``; ``None`` keeps all.
        selected_lookup: Forwarded to the selected-entry accumulator.
        energy_rows: Matrix rows indexed by ``src_system``.
        force_rows: Matrix rows indexed by atom index; force accumulation is
            skipped when no entry is ``>= 0``.
        per_atom_rows: Matrix rows indexed by source atom; skipped when no
            entry is ``>= 0``.
    """
    if row.numel() == 0 or src_ids.numel() == 0:
        return
    _, nx, ny, nz = coeff_shape
    # Keep only neighbor pairs whose triplet category is active on the term
    # and (depending on the block layout) selected by the caller.
    active_triplet_mask = term.active_triplet_mask.to(device=triplet_index.device)
    active_mask = active_triplet_mask.index_select(0, triplet_index)
    if selected_triplet_index is not None:
        active_mask = active_mask & (triplet_index == int(selected_triplet_index))
    else:
        active_mask = active_mask & _selected_channel_mask(
            triplet_index,
            selected_triplet_indices,
        )
    if not torch.any(active_mask):
        return
    row = row[active_mask]
    col = col[active_mask]
    triplet_index = triplet_index[active_mask]
    coeff_volume = int(nx * ny * nz)
    # Gather both legs of every pair for every source: x and y are the two
    # source-neighbor distances, z is the neighbor-neighbor distance.
    vj = vec[:, row, :]
    vk = vec[:, col, :]
    x = distances[:, row]
    y = distances[:, col]
    diff = vj - vk
    z = torch.linalg.norm(diff, dim=2)
    flat_x = x.reshape(-1)
    flat_y = y.reshape(-1)
    flat_z = z.reshape(-1)
    # NOTE(review): only the lower z bound is tested here; presumably the
    # upper z bound is handled by the stencil evaluation — confirm.
    flat_mask = (
        (flat_x >= term.lower_support_xy)
        & (flat_x < term.upper_support_xy)
        & (flat_y >= term.lower_support_xy)
        & (flat_y < term.upper_support_xy)
        & (flat_z >= term.lower_support_z)
    )
    if not torch.any(flat_mask):
        return
    if bool(torch.all(flat_mask)):
        # Every triplet is supported: skip the boolean gathers.
        supported_x = flat_x
        supported_y = flat_y
        supported_z = flat_z
        flat_triplet_positions = torch.arange(
            flat_x.numel(),
            dtype=torch.int64,
            device=flat_x.device,
        )
    else:
        supported_x = flat_x[flat_mask]
        supported_y = flat_y[flat_mask]
        supported_z = flat_z[flat_mask]
        flat_triplet_positions = torch.nonzero(flat_mask, as_tuple=False).reshape(-1)
    stencil = all_supported_uniform_stencil_3d(
        supported_x,
        supported_y,
        supported_z,
        coeff_shape=(nx, ny, nz),
        first_knot_xy=term.first_knot_xy,
        first_knot_z=term.first_knot_z,
        knot_spacing_xy=term.knot_spacing_xy,
        knot_spacing_z=term.knot_spacing_z,
        spline=term.spline,
    )
    # Broadcast per-source and per-pair quantities to the flattened
    # (source, pair) layout and keep only the supported triplets.
    supported_src = src_ids[:, None].expand(-1, row.numel()).reshape(-1)
    supported_src = supported_src.index_select(0, flat_triplet_positions)
    supported_system = src_system[:, None].expand(-1, row.numel()).reshape(-1)
    supported_system = supported_system.index_select(0, flat_triplet_positions)
    supported_dst_j = nbr_sorted[:, row].reshape(-1)
    supported_dst_j = supported_dst_j.index_select(0, flat_triplet_positions)
    supported_dst_k = nbr_sorted[:, col].reshape(-1)
    supported_dst_k = supported_dst_k.index_select(0, flat_triplet_positions)
    supported_triplet_index = (
        triplet_index[None, :]
        .expand(src_ids.shape[0], -1)
        .reshape(-1)
        .index_select(0, flat_triplet_positions)
    )
    supported_vj = vj.reshape(-1, 3).index_select(0, flat_triplet_positions)
    supported_vk = vk.reshape(-1, 3).index_select(0, flat_triplet_positions)
    supported_diff = diff.reshape(-1, 3).index_select(0, flat_triplet_positions)
    if selected_triplet_index is None:
        # Categorized block: shift the stencil columns into the coefficient
        # slab of each triplet category.
        cols = stencil.indices + supported_triplet_index[:, None] * coeff_volume
    else:
        # Single-triplet block: columns are already local to that block.
        cols = stencil.indices
    d_e_dvj, d_e_dvk = pair_distance_partials_batched(
        stencil.grad_x,
        stencil.grad_y,
        stencil.grad_z,
        supported_vj,
        supported_vk,
        supported_diff,
        supported_x,
        supported_y,
        supported_z,
        term.eps,
    )
    # The source-atom force is the negated sum of the two neighbor forces.
    force_j = -d_e_dvj
    force_k = -d_e_dvk
    force_i = d_e_dvj + d_e_dvk
    same_triplet_mask = term.same_neighbor_triplet_mask.to(
        device=supported_triplet_index.device
    ).index_select(
        0,
        supported_triplet_index,
    )
    # Energy rows are looked up by the source atom's system.
    _add_selected_threebody_entries(
        matrix,
        energy_rows.index_select(0, supported_system)[:, None],
        cols,
        stencil.values,
        selected_lookup=selected_lookup,
        same_triplet_mask=same_triplet_mask,
        coeff_volume=coeff_volume,
        nx=nx,
        ny=ny,
        nz=nz,
    )
    if torch.any(per_atom_rows >= 0):
        # Per-atom energy rows are looked up by the source atom.
        _add_selected_threebody_entries(
            matrix,
            per_atom_rows.index_select(0, supported_src)[:, None],
            cols,
            stencil.values,
            selected_lookup=selected_lookup,
            same_triplet_mask=same_triplet_mask,
            coeff_volume=coeff_volume,
            nx=nx,
            ny=ny,
            nz=nz,
        )
    if torch.any(force_rows >= 0):
        # One pass per atom role in the triplet: source, neighbor j, neighbor k.
        for atom_index, force in (
            (supported_src, force_i),
            (supported_dst_j, force_j),
            (supported_dst_k, force_k),
        ):
            _add_selected_threebody_entries(
                matrix,
                force_rows.index_select(0, atom_index)[:, :, None],
                cols[:, None, :],
                force.permute(0, 2, 1),
                selected_lookup=selected_lookup,
                same_triplet_mask=same_triplet_mask,
                coeff_volume=coeff_volume,
                nx=nx,
                ny=ny,
                nz=nz,
            )
[docs]
class SplineThreeBodyTerm(ThreeBodyTerm):
"""
Source-distinguished three-body spline interaction term.
"""
def __init__(
    self,
    *,
    cutoff: float,
    atomic_types: Sequence[int],
    coeffs_by_triplet=None,
    coefficient_provider: AlchemicalCoefficients | None = None,
    coefficient_index: int | None = None,
    active_triplets: Sequence[tuple[int, int, int]] | None = None,
    spline: SplineKind = "cubic",
    full_support_start_xy: float = 0.0,
    full_support_start_z: float = 2.0,
    eps: float = 1.0e-12,
    trainable: bool = True,
    fittable: bool = True,
    frozen: bool = False,
    dtype: torch.dtype | None = None,
) -> None:
    """Store categorized three-body coefficients and category layout.

    Args:
        cutoff: Cutoff forwarded to the base class and used as the upper
            full-support bound of the x/y spline grid.
        atomic_types: Atomic types handled by this term; must be non-empty.
        coeffs_by_triplet: Coefficients with shape
            ``(n_triplet_categories, Nx, Ny, Nz)``. Required when
            ``coefficient_provider`` is not given.
        coefficient_provider: Optional shared provider whose
            ``coefficient_shape`` is either three-dimensional (single active
            triplet) or four-dimensional (categorized).
        coefficient_index: Index passed to the provider; required when a
            provider is given.
        active_triplets: Forwarded to the active-triplet mask builder.
        spline: Spline kind used for both grids.
        full_support_start_xy: Lower full-support bound of the x/y grid.
        full_support_start_z: Lower full-support bound of the z grid.
        eps: Stored as ``self.eps``.
        trainable: Whether the direct coefficient parameter requires
            gradients; ignored (treated as ``False``) when ``frozen``.
        fittable: Stored as ``self.fittable``.
        frozen: Stored as ``self.frozen``.
        dtype: dtype used when converting ``coeffs_by_triplet`` to a tensor.

    Raises:
        ValueError: If ``atomic_types`` is empty, required coefficients or
            the coefficient index are missing, a coefficient shape does not
            match the expected category layout, or the x and y coefficient
            dimensions differ.
    """
    super().__init__(cutoff=cutoff, atomic_types=atomic_types)
    if self.atomic_types is None or not self.atomic_types:
        raise ValueError("`atomic_types` must contain at least one element")
    n_cat = len(self.atomic_types)
    expected_triplet_categories = n_cat * num_edge_categories(n_cat)
    triplet_categories = _triplet_categories(self.atomic_types)
    # NOTE(review): plain-Python bookkeeping is stored via object.__setattr__,
    # presumably to bypass the base class's attribute handling — confirm.
    object.__setattr__(self, "_triplet_categories", triplet_categories)
    object.__setattr__(
        self,
        "_triplet_index",
        {triplet: index for index, triplet in enumerate(triplet_categories)},
    )
    self.coefficient_index = (
        None if coefficient_index is None else int(coefficient_index)
    )
    object.__setattr__(self, "_coefficient_provider", coefficient_provider)
    self.fittable = bool(fittable)
    self.frozen = bool(frozen)
    active_triplet_mask = _active_triplet_mask(
        triplet_categories,
        active_triplets=active_triplets,
    )
    same_neighbor_triplet_mask = _same_neighbor_triplet_mask(triplet_categories)
    # Both masks are non-persistent buffers: they are derived from the
    # constructor arguments rather than stored state.
    self.register_buffer(
        "active_triplet_mask",
        active_triplet_mask,
        persistent=False,
    )
    self.register_buffer(
        "same_neighbor_triplet_mask",
        same_neighbor_triplet_mask,
        persistent=False,
    )
    # Cache the enabled category indices as a plain tuple.
    object.__setattr__(
        self,
        "_active_triplet_indices",
        tuple(
            index
            for index, enabled in enumerate(active_triplet_mask.tolist())
            if enabled
        ),
    )
    if coefficient_provider is None:
        # Direct mode: this term owns its coefficient parameter.
        if coeffs_by_triplet is None:
            raise ValueError(
                "`coeffs_by_triplet` is required when "
                "`coefficient_provider` is not set"
            )
        coeffs_tensor = torch.as_tensor(
            coeffs_by_triplet,
            dtype=dtype,
        )
        if coeffs_tensor.ndim != 4:
            raise ValueError(
                "`coeffs_by_triplet` must have shape "
                "(n_triplet_categories, Nx, Ny, Nz)"
            )
        if coeffs_tensor.shape[0] != expected_triplet_categories:
            raise ValueError(
                "`coeffs_by_triplet.shape[0]` must equal "
                f"{expected_triplet_categories} for "
                f"atomic_types={self.atomic_types}, "
                f"got {coeffs_tensor.shape[0]}"
            )
        self.coeffs_by_triplet = torch.nn.Parameter(
            coeffs_tensor,
            requires_grad=bool(trainable) and not self.frozen,
        )
        coeff_shape = tuple(int(dim) for dim in coeffs_tensor.shape[1:])
    else:
        # Provider mode: coefficients come from the shared provider.
        provider_shape = coefficient_provider.coefficient_shape
        if len(provider_shape) == 3:
            # A 3D provider describes exactly one triplet category.
            if len(self._active_triplet_indices) != 1:
                raise ValueError(
                    "three-dimensional three-body alchemical coefficients "
                    "require exactly one active triplet"
                )
            coeff_shape = tuple(int(dim) for dim in provider_shape)
        elif len(provider_shape) == 4:
            if provider_shape[0] != expected_triplet_categories:
                raise ValueError(
                    "`coefficient_provider` must provide "
                    f"{expected_triplet_categories} triplet categories for "
                    f"atomic_types={self.atomic_types}, got "
                    f"{provider_shape[0]}"
                )
            coeff_shape = tuple(int(dim) for dim in provider_shape[1:])
        else:
            raise ValueError(
                "`coefficient_provider` must provide three-dimensional "
                "single-triplet coefficients or four-dimensional categorized "
                "coefficients"
            )
        if self.coefficient_index is None:
            raise ValueError(
                "`coefficient_index` is required when `coefficient_provider` is set"
            )
        # Result is discarded. NOTE(review): presumably called so an invalid
        # index fails at construction time — confirm.
        coefficient_provider.true_coeffs_for(self.coefficient_index)
    self.spline = spline
    self.full_support_start_xy = float(full_support_start_xy)
    self.full_support_start_z = float(full_support_start_z)
    # x and y share one grid, so their coefficient counts must match.
    if coeff_shape[0] != coeff_shape[1]:
        raise ValueError(
            "three-body coefficients must have matching x/y dimensions when "
            "using a shared center-neighbor support grid"
        )
    object.__setattr__(self, "coeff_shape", coeff_shape)
    coeff_size_xy = coeff_shape[0]
    coeff_size_z = coeff_shape[2]
    # Knot layout of the x/y grid spans [full_support_start_xy, cutoff].
    self.first_knot_xy, self.knot_spacing_xy = uniform_support_parameters(
        coeff_size=coeff_size_xy,
        lower_full_support=self.full_support_start_xy,
        upper_full_support=cutoff,
        spline=self.spline,
    )
    # The z grid uses the derived neighbor-neighbor cutoff as its upper bound.
    self.first_knot_z, self.knot_spacing_z = uniform_support_parameters(
        coeff_size=coeff_size_z,
        lower_full_support=self.full_support_start_z,
        upper_full_support=_neighbor_neighbor_cutoff(cutoff),
        spline=self.spline,
    )
    self.lower_support_xy, self.upper_support_xy = _support_bounds(
        self.first_knot_xy,
        self.knot_spacing_xy,
        coeff_size_xy,
        lower_full_support=self.full_support_start_xy,
    )
    self.lower_support_z, self.upper_support_z = _support_bounds(
        self.first_knot_z,
        self.knot_spacing_z,
        coeff_size_z,
        lower_full_support=self.full_support_start_z,
    )
    self.eps = float(eps)
    self.register_buffer(
        "edge_cat_table",
        build_edge_category_table(n_cat),
        persistent=False,
    )
@property
def n_categories(self) -> int:
    """Count the atomic categories configured for this term."""
    atomic_types = self.atomic_types
    assert atomic_types is not None
    return len(atomic_types)
@property
def triplet_categories(self) -> tuple[tuple[int, int, int], ...]:
    """Ordered triplet categories, in the order the coefficient tensor indexes them."""
    categories = self._triplet_categories
    return categories
@property
def coefficient_provider(self) -> AlchemicalCoefficients | None:
    """Shared alchemical coefficient provider, or ``None`` in direct mode."""
    provider = self._coefficient_provider
    return provider
@property
def provides_forces(self) -> bool:
    """Always ``True``: this term computes analytic forces itself."""
    return True
@property
def input_requirements(self) -> TermInputRequirements:
    """Input requirements of this term: a full neighbor list."""
    requirements = TermInputRequirements(full_neighbor_list=True)
    return requirements
@property
def neighbor_neighbor_cutoff(self) -> float:
    """Cutoff for neighbor-neighbor distances, derived from the term cutoff."""
    center_cutoff = self.cutoff
    assert center_cutoff is not None
    return _neighbor_neighbor_cutoff(center_cutoff)
@property
def active_triplet_categories(self) -> tuple[tuple[int, int, int], ...]:
    """Triplet categories that are currently enabled, in category order."""
    categories = self.triplet_categories
    return tuple(categories[index] for index in self._active_triplet_indices)
@property
def true_coeffs_by_triplet(self) -> torch.Tensor:
    """Categorized coefficient tensor, from the parameter or the provider.

    Three-dimensional provider coefficients are embedded into a zero tensor
    at the single active triplet's slot. The result is then passed through
    ``_symmetrize_same_neighbor_coeffs`` with this term's same-neighbor mask.
    """
    provider = self.coefficient_provider
    if provider is not None:
        assert self.coefficient_index is not None
        coeffs = provider.true_coeffs_for(self.coefficient_index)
        if coeffs.ndim == 3:
            # Single-triplet provider: place it in the categorized layout.
            slot = self._active_triplet_indices[0]
            nx, ny, nz = (int(size) for size in coeffs.shape[:3])
            expanded = coeffs.new_zeros(
                (len(self.triplet_categories), nx, ny, nz)
            )
            expanded[slot] = coeffs
            coeffs = expanded
    else:
        coeffs = self.coeffs_by_triplet
    return _symmetrize_same_neighbor_coeffs(coeffs, self.same_neighbor_triplet_mask)
def _parameter_block_shape(self) -> tuple[int, ...]:
    """Return the solved coefficient-block shape for this term.

    The shape is taken from the tensor returned by
    ``_read_parameter_block``, so it is three-dimensional for a
    single-triplet provider and four-dimensional otherwise. Deriving it from
    that one method keeps the reported shape, the read path and the write
    path (which reshapes to ``_read_parameter_block().shape``) in agreement.

    Returns:
        The block shape as a tuple of plain Python ints.
    """
    return tuple(int(dim) for dim in self._read_parameter_block().shape)
def _read_parameter_block(self) -> torch.Tensor:
    """Read the coefficients in the layout used by the solve block.

    A provider with a three-dimensional coefficient shape yields its own
    tensor for this term's index; every other configuration yields the
    categorized ``true_coeffs_by_triplet`` tensor.
    """
    provider = self.coefficient_provider
    single_triplet = provider is not None and len(provider.coefficient_shape) == 3
    if not single_triplet:
        return self.true_coeffs_by_triplet
    assert self.coefficient_index is not None
    return provider.true_coeffs_for(self.coefficient_index)
def _write_parameter_block(self, values: torch.Tensor) -> None:
    """Store solved coefficients in the parameter or the provider's proxy.

    Raises:
        ValueError: If a provider is set but does not use identity weights.
    """
    provider = self.coefficient_provider
    if provider is None:
        copy_parameter_data(self.coeffs_by_triplet, values)
        return
    if not provider.uses_identity_weights:
        raise ValueError(
            "can not write true coefficients directly into a non-identity "
            "alchemical provider"
        )
    assert self.coefficient_index is not None
    block_shape = self._read_parameter_block().shape
    proxy = provider.proxy_coeffs
    # Match the proxy's dtype/device before copying into this term's slot.
    reshaped = values.reshape(block_shape).to(proxy)
    proxy.data[self.coefficient_index].copy_(reshaped)
def _parameter_block_cache_descriptor(self) -> ParameterBlockCacheDescriptor | None:
    """Describe this coefficient block for semantic cache reuse.

    Returns ``None`` when the block shape is neither three- nor
    four-dimensional, or when a three-dimensional block does not have
    exactly one active triplet. Otherwise the descriptor carries the spline
    grid settings plus one channel per active triplet with its column range.
    """
    shape = self._parameter_block_shape()
    rank = len(shape)
    if rank not in (3, 4):
        return None
    if rank == 3 and len(self._active_triplet_indices) != 1:
        return None

    coeff_shape = tuple(int(size) for size in shape[-3:])
    volume = int(coeff_shape[0] * coeff_shape[1] * coeff_shape[2])
    triplet_indices = tuple(int(index) for index in self._active_triplet_indices)
    # A single-triplet block starts at column 0; a categorized block gives
    # each triplet its own slab of ``volume`` columns.
    starts = {
        index: (0 if rank == 3 else index * volume)
        for index in triplet_indices
    }

    family = {
        "kind": "threebody_spline",
        "atomic_types": [int(value) for value in self.atomic_types or ()],
        "spline": str(self.spline),
        "first_knot_xy": float(self.first_knot_xy),
        "first_knot_z": float(self.first_knot_z),
        "knot_spacing_xy": float(self.knot_spacing_xy),
        "knot_spacing_z": float(self.knot_spacing_z),
        "lower_support_xy": float(self.lower_support_xy),
        "lower_support_z": float(self.lower_support_z),
        "coeff_shape": [int(value) for value in coeff_shape],
        "eps": float(self.eps),
    }
    channels = tuple(
        ParameterBlockCacheChannel(
            kind="triplet",
            values=self.triplet_categories[index],
            start=int(starts[index]),
            stop=int(starts[index]) + volume,
        )
        for index in triplet_indices
    )
    return ParameterBlockCacheDescriptor(family=family, channels=channels)
[docs]
def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
    """Expose the single three-body spline coefficient block of this term.

    The block is reported as fittable only when the term is fittable and at
    least one triplet category is active.
    """
    block_shape = self._parameter_block_shape()
    is_fittable = self.fittable and bool(self._active_triplet_indices)
    coeff_block = ParameterBlock(
        name="coeffs_by_triplet",
        kind="threebody",
        shape=block_shape,
        read=self._read_parameter_block,
        write=self._write_parameter_block,
        label=f"threebody[{self.atomic_types}]",
        coefficient_provider=self.coefficient_provider,
        coefficient_index=self.coefficient_index,
        regularization_group="threebody",
        fittable=is_fittable,
        frozen=self.frozen,
        assembler="threebody",
        cache_descriptor=self._parameter_block_cache_descriptor(),
    )
    return (coeff_block,)
[docs]
def assemble_linear_blocks(
    self,
    batch,
    targets,
    options: LinearAssemblyOptions | None = None,
):
    """Build the least-squares matrices for this term's requested blocks.

    Returns a dict mapping ``block.index`` to its assembled matrix. Blocks
    for which the assembler returns ``None`` are left out. Without
    ``options`` there are no blocks and the result is empty.
    """
    from ufp.leastsquares._assemble import _assemble_threebody_block

    if options is None:
        blocks = ()
        lstsq_backend = None
        bucket_backend = None
        runtime_config = None
    else:
        blocks = options.blocks
        lstsq_backend = options.threebody_lstsq_backend
        bucket_backend = options.threebody_bucket_backend
        runtime_config = options.threebody_runtime_config

    assembled = {}
    for block in blocks:
        matrix = _assemble_threebody_block(
            block,
            batch.inputs,
            targets,
            threebody_lstsq_backend=lstsq_backend,
            threebody_bucket_backend=bucket_backend,
            threebody_runtime_config=runtime_config,
        )
        if matrix is None:
            continue
        assembled[block.index] = matrix
    return assembled
[docs]
def assemble_selected_linear_block(
    self,
    block,
    inputs: UFPInput,
    targets,
    selected_indices: Sequence[int],
) -> torch.Tensor | None:
    """Assemble only requested coefficient columns for this three-body block.

    Args:
        block: Parameter block being assembled; ``block.shape`` is either
            ``(nx, ny, nz)`` (single-triplet block) or
            ``(n_triplet_categories, nx, ny, nz)``, and ``block.size`` is
            passed to the selected-column lookup.
        inputs: Input bundle; must carry a full neighbor list.
        targets: Provides ``n_rows``, ``energy_rows``, ``force_rows`` and
            ``per_atom_rows``.
        selected_indices: Flat coefficient indices within the block whose
            columns should be assembled.

    Returns:
        A matrix with ``targets.n_rows`` rows and one column per selected
        index, or ``None`` when nothing contributes (no supported atoms or
        pairs, no buckets, or an all-zero matrix).

    Raises:
        RuntimeError: If the input has no neighbor list or the list is not a
            full list.
        ValueError: If a three-dimensional block is requested while the
            number of active triplets is not exactly one.
    """
    selected_indices = tuple(int(index) for index in selected_indices)
    if inputs.neighbor_list is None:
        raise RuntimeError("SplineThreeBodyTerm requires a neighbor list")
    if not inputs.neighbor_list.full_list:
        raise RuntimeError("SplineThreeBodyTerm requires a full neighbor list")
    assert self.atomic_types is not None
    # Atoms whose type is not handled by this term get a negative category.
    node_cat = inputs.atomic_category_indices(self.atomic_types)
    supported_atoms = node_cat >= 0
    if not torch.any(supported_atoms):
        return None
    first_atom, second_atom = inputs.pair_indices()
    pair_mask = supported_atoms[first_atom] & supported_atoms[second_atom]
    if not torch.any(pair_mask):
        return None
    # Keep only pairs whose distance lies inside the x/y spline support.
    pair_distances = inputs.pair_distances(pair_mask)
    center_support_mask = (pair_distances >= self.lower_support_xy) & (
        pair_distances < self.upper_support_xy
    )
    if not torch.any(center_support_mask):
        return None
    filtered_first, filtered_second = inputs.pair_indices(pair_mask)
    pair_vectors = inputs.pair_vectors(pair_mask)
    filtered_first = filtered_first[center_support_mask]
    filtered_second = filtered_second[center_support_mask]
    pair_vectors = pair_vectors[center_support_mask]
    pair_distances = pair_distances[center_support_mask]
    buckets = preprocess_sources(
        filtered_first,
        filtered_second,
        node_cat,
        self.n_categories,
        pair_vectors,
        pair_distances,
    )
    if not buckets:
        return None
    if len(block.shape) == 3:
        # Single-triplet block: columns are local to the one active triplet.
        if len(self._active_triplet_indices) != 1:
            raise ValueError(
                "single-triplet three-body alchemical blocks require exactly one "
                "active triplet"
            )
        nx, ny, nz = block.shape
        n_triplet_categories = len(self.triplet_categories)
        selected_triplet_index = int(self._active_triplet_indices[0])
        selected_triplet_indices = None
    else:
        # Categorized block: derive which triplet categories own at least
        # one selected column so the others can be skipped.
        n_triplet_categories, nx, ny, nz = block.shape
        coeff_volume = int(nx * ny * nz)
        selected_triplet_index = None
        selected_triplet_indices = torch.unique(
            torch.div(
                torch.as_tensor(
                    [int(index) for index in selected_indices],
                    dtype=torch.int64,
                    device=inputs.device,
                ),
                coeff_volume,
                rounding_mode="floor",
            )
        )
    selected_lookup = selected_column_lookup(
        selected_indices,
        block_size=block.size,
        device=inputs.device,
    )
    matrix = _selected_block_matrix(
        targets,
        selected_indices,
        device=inputs.device,
        dtype=inputs.dtype,
    )
    # Number of unordered neighbor-category pairs per source category.
    triplets_per_src_cat = self.n_categories * (self.n_categories + 1) // 2
    system_index = inputs.system_index
    for pattern_index in range(int(buckets.patterns.shape[0])):
        # ``pattern_ptr`` delimits the sources that share this pattern.
        src_start = int(buckets.pattern_ptr[pattern_index].item())
        src_end = int(buckets.pattern_ptr[pattern_index + 1].item())
        pattern = buckets.patterns[pattern_index]
        # First entry is the source category, the rest are per-category
        # neighbor counts.
        src_cat = int(pattern[0].item())
        counts = pattern[1:].detach().cpu().tolist()
        layout = pattern_triplet_layout(counts, inputs.device)
        if layout.row.numel() == 0:
            continue
        degree = int(sum(counts))
        # ``row_ptr`` delimits the edges of those sources; every source in a
        # pattern has ``degree`` edges, so the slices reshape to a dense grid.
        edge_start = int(buckets.row_ptr[src_start].item())
        edge_end = int(buckets.row_ptr[src_end].item())
        src_ids = buckets.src_ids[src_start:src_end]
        nbr_ids = buckets.nbr_ids[edge_start:edge_end].view(
            src_end - src_start,
            degree,
        )
        vectors = buckets.pair_vectors[edge_start:edge_end].view(
            src_end - src_start,
            degree,
            3,
        )
        distances = buckets.pair_distances[edge_start:edge_end].view(
            src_end - src_start,
            degree,
        )
        triplet_index = src_cat * triplets_per_src_cat + layout.edge_cat
        _accumulate_selected_threebody_pairs(
            matrix=matrix,
            term=self,
            src_ids=src_ids,
            src_system=system_index.index_select(0, src_ids),
            triplet_index=triplet_index,
            row=layout.row,
            col=layout.col,
            vec=vectors,
            distances=distances,
            nbr_sorted=nbr_ids,
            coeff_shape=(
                int(n_triplet_categories),
                int(nx),
                int(ny),
                int(nz),
            ),
            selected_triplet_index=selected_triplet_index,
            selected_triplet_indices=selected_triplet_indices,
            selected_lookup=selected_lookup,
            energy_rows=targets.energy_rows,
            force_rows=targets.force_rows,
            per_atom_rows=targets.per_atom_rows,
        )
    # An all-zero matrix is reported as "no contribution".
    return None if torch.count_nonzero(matrix) == 0 else matrix
def _cache_key(self) -> str:
    """Derive this term's key inside the per-input three-body caches.

    The key is the first 16 hex characters of the SHA-256 digest of a
    canonical JSON dump of the term's category layout and spline grid
    settings.
    """
    scalar_fields = (
        "spline",
        "first_knot_xy",
        "first_knot_z",
        "knot_spacing_xy",
        "knot_spacing_z",
        "lower_support_xy",
        "lower_support_z",
        "eps",
    )
    payload = {name: getattr(self, name) for name in scalar_fields}
    payload["atomic_types"] = list(self.atomic_types or ())
    payload["coeff_shape"] = list(self.coeff_shape)
    payload["active_triplet_indices"] = list(self._active_triplet_indices)
    # Sorted keys and compact separators make the dump order-independent.
    serialized = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256(serialized.encode("utf8"))
    return digest.hexdigest()[:16]
def _bucket_triplets(
    self,
    inputs: UFPInput,
    node_cat: torch.Tensor,
    *,
    attach_pattern_plans: bool,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> Buckets:
    """Group this input's supported pairs into reusable triplet buckets.

    Pairs are kept only when both atoms have a non-negative category and the
    pair distance lies inside the x/y spline support. When nothing survives
    a filter, buckets built from empty pair arrays are returned directly.
    Pattern plans are attached only to non-empty results, and only when
    requested and not already present.
    """
    config = resolve_threebody_runtime_config(runtime_config)

    def build(first, second, vectors, distances):
        # Single call site for the preprocessing backend.
        return preprocess_sources_native_or_torch(
            first,
            second,
            node_cat,
            self.n_categories,
            vectors,
            distances,
            runtime_config=config,
        )

    def empty_buckets():
        return build(
            torch.zeros((0,), dtype=torch.int64, device=inputs.device),
            torch.zeros((0,), dtype=torch.int64, device=inputs.device),
            inputs.positions.new_zeros((0, 3)),
            inputs.positions.new_zeros((0,)),
        )

    supported = node_cat >= 0
    if not torch.any(supported):
        return empty_buckets()

    first_atom, second_atom = inputs.pair_indices()
    pair_mask = supported[first_atom] & supported[second_atom]
    if not torch.any(pair_mask):
        return empty_buckets()

    distances = inputs.pair_distances(pair_mask)
    in_support = (distances >= self.lower_support_xy) & (
        distances < self.upper_support_xy
    )
    if not torch.any(in_support):
        return empty_buckets()

    kept_first, kept_second = inputs.pair_indices(pair_mask)
    vectors = inputs.pair_vectors(pair_mask)
    buckets = build(
        kept_first[in_support],
        kept_second[in_support],
        vectors[in_support],
        distances[in_support],
    )
    if attach_pattern_plans and buckets.tensor_pattern_plans is None:
        return buckets.with_pattern_plans(inputs.device)
    return buckets
[docs]
def cache_input(
    self,
    inputs: UFPInput,
    options: TermCacheOptions | None = None,
    *,
    feature_cache_storage: Literal["none", "cpu", "disk"] = "cpu",
    feature_cache_mode: FeatureCacheMode = "auto",
    feature_cache_dir: Path | str | None = None,
    cache_prefix: str = "threebody",
    legacy_cache_prefixes: Sequence[str] = (),
    include_per_atom_energy: bool = True,
) -> None:
    """Attach precomputed three-body buckets and dense features to ``inputs``.

    Settings from ``options`` override the keyword arguments. With disk
    storage and a cache directory, a compatible on-disk cache is tried
    first (unless refreshing). Otherwise buckets are built and, unless
    storage is ``"none"``, a dense feature cache is built from them. Results
    are stored in ``inputs.metadata``. Inputs without a full neighbor list
    and terms without active triplets are left untouched.

    Raises:
        ValueError: If ``feature_cache_mode`` is unknown, or is ``"read"``
            without disk storage.
        FileNotFoundError: If read mode finds no compatible disk cache.
    """
    if options is not None:
        feature_cache_storage = options.feature_cache_storage
        feature_cache_mode = options.feature_cache_mode  # type: ignore[assignment]
        feature_cache_dir = options.feature_cache_dir
        cache_prefix = options.cache_prefix
        include_per_atom_energy = options.include_per_atom_energy
    del legacy_cache_prefixes

    if feature_cache_mode not in {"auto", "read", "refresh"}:
        raise ValueError(
            "`feature_cache_mode` must be 'auto', 'read', or 'refresh'"
        )
    uses_disk = feature_cache_storage == "disk"
    read_only = feature_cache_mode == "read"
    if read_only and not uses_disk:
        raise ValueError("`feature_cache_mode='read'` requires disk feature cache")

    neighbor_list = inputs.neighbor_list
    if neighbor_list is None or not neighbor_list.full_list:
        return
    if not self._active_triplet_indices:
        return
    assert self.atomic_types is not None

    runtime_config = resolve_threebody_runtime_config()
    # Work on copies so the metadata entries are replaced, not mutated.
    bucket_cache = dict(inputs.metadata.get(_THREEBODY_BUCKET_CACHE_KEY, {}))
    feature_blocks = dict(inputs.metadata.get(_THREEBODY_FEATURE_CACHE_KEY, {}))
    coeff_shape = tuple(int(value) for value in self.coeff_shape)
    cache_dir = None if feature_cache_dir is None else Path(feature_cache_dir)
    cache_key = self._cache_key()
    disk_prefix = f"{cache_prefix}_term{cache_key}"
    metadata = _dense_feature_cache_metadata(
        inputs,
        cache_key=cache_key,
        atomic_types=self.atomic_types,
        triplet_categories=self.triplet_categories,
        coeff_shape=coeff_shape,
        active_triplet_indices=self._active_triplet_indices,
        include_per_atom_energy=include_per_atom_energy,
        spline=self.spline,
        first_knot_xy=self.first_knot_xy,
        first_knot_z=self.first_knot_z,
        knot_spacing_xy=self.knot_spacing_xy,
        knot_spacing_z=self.knot_spacing_z,
        lower_support_xy=self.lower_support_xy,
        lower_support_z=self.lower_support_z,
        eps=self.eps,
    )

    if uses_disk and cache_dir is not None and feature_cache_mode != "refresh":
        settings_dir = _dense_feature_cache_dir(cache_dir, disk_prefix, metadata)
        # Lookup order: settings-specific directory, then the cache root under
        # the same prefix, then any compatible cache found in the root.
        attempts = (
            lambda: _load_memmap_dense_feature_cache(
                settings_dir,
                disk_prefix,
                expected_metadata=metadata,
                required_triplet_indices=self._active_triplet_indices,
            ),
            lambda: _load_memmap_dense_feature_cache(
                cache_dir,
                disk_prefix,
                expected_metadata=metadata,
                required_triplet_indices=self._active_triplet_indices,
            ),
            lambda: _find_compatible_memmap_dense_feature_cache(
                cache_dir,
                expected_metadata=metadata,
                required_triplet_indices=self._active_triplet_indices,
            ),
        )
        cached_disk_features = None
        for attempt in attempts:
            try:
                cached_disk_features = attempt()
            except (OSError, ValueError, json.JSONDecodeError):
                # An unreadable or incompatible cache counts as a miss.
                cached_disk_features = None
            if cached_disk_features is not None:
                break
        if cached_disk_features is not None:
            feature_blocks[cache_key] = cached_disk_features
            inputs.metadata[_THREEBODY_FEATURE_CACHE_KEY] = feature_blocks
            return

    if read_only and feature_cache_storage == "disk":
        raise FileNotFoundError(
            "three-body feature cache requested in read mode, but no compatible "
            f"V2 cache was found for prefix {disk_prefix!r}"
        )

    node_cat = inputs.atomic_category_indices(self.atomic_types)
    buckets = self._bucket_triplets(
        inputs,
        node_cat,
        attach_pattern_plans=True,
        runtime_config=runtime_config,
    )
    bucket_cache[cache_key] = buckets
    inputs.metadata[_THREEBODY_BUCKET_CACHE_KEY] = bucket_cache
    if not buckets:
        return
    if feature_cache_storage == "none":
        return

    every_triplet_active = len(self._active_triplet_indices) == len(
        self.triplet_categories
    )
    feature_blocks[cache_key] = _build_dense_feature_cache_from_buckets(
        buckets,
        inputs.system_index,
        coeff_shape,
        spline=self.spline,
        active_triplet_mask=(
            None if every_triplet_active else self.active_triplet_mask
        ),
        n_cat=self.n_categories,
        first_knot_xy=self.first_knot_xy,
        first_knot_z=self.first_knot_z,
        knot_spacing_xy=self.knot_spacing_xy,
        knot_spacing_z=self.knot_spacing_z,
        lower_support_xy=self.lower_support_xy,
        lower_support_z=self.lower_support_z,
        eps=self.eps,
        storage=feature_cache_storage,
        cache_dir=cache_dir,
        cache_prefix=disk_prefix,
        metadata=metadata,
        overwrite=feature_cache_mode == "refresh",
        include_per_atom_energy=include_per_atom_energy,
        runtime_config=runtime_config,
    )
    inputs.metadata[_THREEBODY_FEATURE_CACHE_KEY] = feature_blocks
[docs]
def dense_atom_features(
    self,
    inputs: UFPInput,
    atom_indices: Sequence[int] | torch.Tensor | None = None,
    *,
    force_scope: Literal["output", "source"] = "output",
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> ThreeBodyDenseAtomFeatures:
    """
    Build dense coefficient-space feature rows for a selection of atoms.

    This is a debugging aid for inspecting fixed-geometry feature
    construction. For every selected atom the result holds one per-atom
    energy row and one force row per Cartesian component, i.e. the dense
    counterpart of the sparse cached operators this term uses.

    ``force_scope`` decides what the force rows contain:

    * ``"output"``: the full model-output rows for the selected atoms.
    * ``"source"``: only interactions centered on the selected atoms.

    For ``force_scope="output"``, a dense feature cache already stored in
    ``inputs.metadata`` for this term is reused when every one of its blocks
    carries per-atom energies. In all other cases the features are rebuilt
    from freshly bucketed triplets. If the term has no active triplet
    category, or bucketing yields nothing, an empty dense cache is used.

    Args:
        inputs: Normalized input bundle with a full neighbor list.
        atom_indices: Optional atom indices to extract. If omitted, all atoms
            are returned in input order.
        force_scope: Whether force rows should include all output
            contributions or only source-centered contributions.
        runtime_config: Optional runtime configuration. It is passed through
            ``resolve_threebody_runtime_config`` before being used.

    Returns:
        Dense per-atom energy and force-component feature rows, symmetrized
        with this term's ``same_neighbor_triplet_mask``.

    Raises:
        RuntimeError: If the input lacks a full neighbor list.
        ValueError: If ``force_scope`` is not ``"output"`` or ``"source"``.
    """
    # Argument checks come first so that bad calls fail before any work.
    if force_scope not in {"output", "source"}:
        raise ValueError("`force_scope` must be 'output' or 'source'")
    neighbor_list = inputs.neighbor_list
    if neighbor_list is None:
        raise RuntimeError(
            "SplineThreeBodyTerm.dense_atom_features requires a neighbor list"
        )
    if not neighbor_list.full_list:
        raise RuntimeError(
            "SplineThreeBodyTerm.dense_atom_features requires a full neighbor list"
        )

    atom_rows = _selected_atom_indices(
        inputs.n_atoms,
        atom_indices,
        device=inputs.device,
    )
    runtime_config = resolve_threebody_runtime_config(runtime_config)
    coeff_shape = tuple(map(int, self.coeff_shape))

    # Default to an empty cache; it is replaced below only when there is
    # something to evaluate.
    dense_cache: DenseThreeBodyFeatureCache | MemmapDenseThreeBodyFeatureCache = (
        DenseThreeBodyFeatureCache(blocks=())
    )

    if self._active_triplet_indices:
        # Cached dense features are only looked up for output-scoped rows.
        reusable = None
        if force_scope == "output":
            per_term_caches = inputs.metadata.get(_THREEBODY_FEATURE_CACHE_KEY)
            if isinstance(per_term_caches, dict):
                reusable = per_term_caches.get(self._cache_key())

        # A cached entry is usable only if every block has energy rows.
        cache_is_usable = isinstance(
            reusable,
            (DenseThreeBodyFeatureCache, MemmapDenseThreeBodyFeatureCache),
        ) and all(block.per_atom_energy is not None for block in reusable.blocks)

        if cache_is_usable:
            dense_cache = reusable
        else:
            assert self.atomic_types is not None
            category_of_atom = inputs.atomic_category_indices(self.atomic_types)
            triplet_buckets = self._bucket_triplets(
                inputs,
                category_of_atom,
                attach_pattern_plans=True,
                runtime_config=runtime_config,
            )
            if triplet_buckets:
                # No mask is passed when every triplet category is active.
                if len(self._active_triplet_indices) == len(
                    self.triplet_categories
                ):
                    mask = None
                else:
                    mask = self.active_triplet_mask
                sparse_features = _build_feature_cache_from_buckets(
                    triplet_buckets,
                    coeff_shape,
                    spline=self.spline,
                    active_triplet_mask=mask,
                    n_cat=self.n_categories,
                    first_knot_xy=self.first_knot_xy,
                    first_knot_z=self.first_knot_z,
                    knot_spacing_xy=self.knot_spacing_xy,
                    knot_spacing_z=self.knot_spacing_z,
                    lower_support_xy=self.lower_support_xy,
                    lower_support_z=self.lower_support_z,
                    eps=self.eps,
                    runtime_config=runtime_config,
                )
                dense_cache = _build_dense_feature_cache_from_feature_cache(
                    sparse_features,
                    inputs.system_index,
                    coeff_shape=coeff_shape,
                    force_scope=force_scope,
                    runtime_config=runtime_config,
                )

    rows = _dense_atom_features_from_feature_cache(
        dense_cache,
        atom_rows,
        n_triplet_categories=len(self.triplet_categories),
        coeff_shape=coeff_shape,
        dtype=inputs.dtype,
    )
    return _symmetrize_dense_atom_features(
        rows,
        self.same_neighbor_triplet_mask,
        coeff_shape=coeff_shape,
    )
[docs]
def canonical_triplet(
    self,
    source: int,
    first_neighbor: int,
    second_neighbor: int,
) -> tuple[int, int, int]:
    """
    Return the canonical form of a ``(source, neighbor, neighbor)`` key.

    Thin instance-level wrapper around the module-level
    ``_canonical_triplet`` helper, which defines the neighbor-ordering
    convention used by this term.
    """
    canonical_key = _canonical_triplet(source, first_neighbor, second_neighbor)
    return canonical_key
[docs]
def triplet_category_index(
    self,
    source: int,
    first_neighbor: int,
    second_neighbor: int,
) -> int:
    """
    Look up the coefficient-block index of a triplet category.

    The arguments are first normalized with ``canonical_triplet`` and the
    resulting key is then resolved through ``self._triplet_index``.

    Raises:
        KeyError: If the canonical triplet is not part of this term. The
            original lookup error is attached as the cause.
    """
    key = self.canonical_triplet(source, first_neighbor, second_neighbor)
    try:
        block_index = self._triplet_index[key]
    except KeyError as missing:
        raise KeyError(f"triplet {key} is not part of this term") from missing
    return block_index
[docs]
def is_triplet_active(
    self,
    source: int,
    first_neighbor: int,
    second_neighbor: int,
) -> bool:
    """
    Tell whether the category of the given triplet is still enabled.

    The triplet is resolved to its coefficient-block index via
    ``triplet_category_index`` (which raises ``KeyError`` for triplets that
    are not part of this term), and the matching entry of
    ``self.active_triplet_mask`` is returned as a plain ``bool``.
    """
    block_index = self.triplet_category_index(
        source, first_neighbor, second_neighbor
    )
    mask_entry = self.active_triplet_mask[block_index]
    return bool(mask_entry.item())
[docs]
def forward(self, inputs: UFPInput) -> UFPOutput:
    """
    Evaluate this term's three-body energy and forces for ``inputs``.

    The evaluation path is chosen in this order:

    1. If ``inputs.metadata`` holds a dense feature cache for this term, the
       result comes from ``_evaluate_dense_feature_cache_energy_forces``.
    2. Otherwise, if it holds cached ``Buckets`` for this term, those are
       evaluated with ``evaluate_bucketed_energy_forces``.
    3. Otherwise triplets are bucketed from scratch and then evaluated.

    Cached entries are consulted only when ``inputs.positions`` does not
    require grad.

    Args:
        inputs: Input bundle; it must carry a full neighbor list.

    Returns:
        A ``UFPOutput`` with ``energy``, ``forces`` and ``per_atom_energy``.
        When the term has no active triplet category, or bucketing produces
        no triplets, the result of ``empty_atomwise_output(inputs,
        forces=True)`` is returned instead.

    Raises:
        RuntimeError: If ``inputs`` has no neighbor list, or the neighbor
            list is not a full list.
    """
    neighbor_list = inputs.neighbor_list
    if neighbor_list is None:
        raise RuntimeError(
            "SplineThreeBodyTerm requires a neighbor list, but `inputs` does not "
            "contain one"
        )
    if not neighbor_list.full_list:
        raise RuntimeError("SplineThreeBodyTerm requires a full neighbor list")
    if not self._active_triplet_indices:
        return empty_atomwise_output(inputs, forces=True)

    assert self.atomic_types is not None
    runtime_config = resolve_threebody_runtime_config()
    node_cat = inputs.atomic_category_indices(self.atomic_types)

    # Look for precomputed features and buckets attached to the inputs.
    dense_features = None
    prepared_buckets = None
    if not inputs.positions.requires_grad:
        feature_store = inputs.metadata.get(_THREEBODY_FEATURE_CACHE_KEY)
        if isinstance(feature_store, dict):
            dense_features = feature_store.get(self._cache_key())
        bucket_store = inputs.metadata.get(_THREEBODY_BUCKET_CACHE_KEY)
        if isinstance(bucket_store, dict):
            prepared_buckets = bucket_store.get(self._cache_key())

    coeffs = self.true_coeffs_by_triplet.to(
        device=inputs.device,
        dtype=inputs.dtype,
    )

    # Fast path: evaluate directly from the dense feature cache.
    if isinstance(
        dense_features,
        (DenseThreeBodyFeatureCache, MemmapDenseThreeBodyFeatureCache),
    ):
        energy, per_atom_energy, forces = (
            _evaluate_dense_feature_cache_energy_forces(
                dense_features,
                coeffs,
                n_nodes=inputs.n_atoms,
                n_systems=inputs.n_systems,
            )
        )
        return UFPOutput(
            energy=energy,
            forces=forces,
            per_atom_energy=per_atom_energy,
        )

    # Bucketed path: reuse cached buckets when present, else build them.
    buckets = (
        prepared_buckets
        if isinstance(prepared_buckets, Buckets)
        else self._bucket_triplets(
            inputs,
            node_cat,
            attach_pattern_plans=False,
            runtime_config=runtime_config,
        )
    )
    if not buckets:
        return empty_atomwise_output(inputs, forces=True)

    edge_cat_table = self.edge_cat_table.to(device=inputs.device)
    # No mask is passed when every triplet category is active.
    if len(self._active_triplet_indices) == len(self.triplet_categories):
        active_triplet_mask = None
    else:
        active_triplet_mask = self.active_triplet_mask.to(device=inputs.device)

    per_atom_energy, forces = evaluate_bucketed_energy_forces(
        buckets,
        node_cat,
        coeffs,
        edge_cat_table,
        spline=self.spline,
        active_triplet_mask=active_triplet_mask,
        n_nodes=inputs.n_atoms,
        n_cat=self.n_categories,
        first_knot_xy=self.first_knot_xy,
        first_knot_z=self.first_knot_z,
        knot_spacing_xy=self.knot_spacing_xy,
        knot_spacing_z=self.knot_spacing_z,
        lower_support_xy=self.lower_support_xy,
        lower_support_z=self.lower_support_z,
        eps=self.eps,
        runtime_config=runtime_config,
    )

    # Sum per-atom energies into one total per system.
    energy = torch.zeros(
        inputs.n_systems,
        device=inputs.device,
        dtype=inputs.dtype,
    ).index_add_(0, inputs.system_index, per_atom_energy)
    return UFPOutput(
        energy=energy,
        forces=forces,
        per_atom_energy=per_atom_energy,
    )
# Names exported by this module. Per the module docstring,
# ``SplineThreeBodyTerm`` is the stable user-facing term; the bucket
# containers, feature caches, and evaluator helpers listed here are expert
# diagnostics for benchmarks, tests, and performance investigations.
__all__ = [
"BucketedEnergyForceEvaluator",
"Buckets",
"DenseThreeBodyFeatureCache",
"DenseTripletFeatureBlock",
"MemmapDenseThreeBodyFeatureCache",
"MemmapDenseTripletFeatureBlock",
"SplineKind",
"SplineThreeBodyTerm",
"ThreeBodyDenseAtomFeatures",
"ThreeBodyTerm",
"build_edge_category_table",
"evaluate_bucketed_energy_forces",
"get_eval_3d_with_grads",
"load_memmap_threebody_feature_cache",
"make_bucketed_energy_forces_evaluator",
"num_edge_categories",
"preprocess_sources",
]