"""Two-distance spline three-body term."""
from __future__ import annotations

from collections.abc import Iterator, Sequence
from typing import Any, Literal, NamedTuple

import torch

from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.splines.representation import (
    spline_support_mask_2d,
    uniform_stencil_2d,
    uniform_support_parameters,
)
from ufp.terms._base import LinearAssemblyOptions, TermInputRequirements, ThreeBodyTerm
from ufp.terms._parameters import (
    ParameterBlock,
    ParameterBlockCacheChannel,
    ParameterBlockCacheDescriptor,
    copy_parameter_data,
)
from ufp.terms._selected_assembly import add_selected_entries, selected_column_lookup
from ufp.terms._shared import empty_atomwise_output
from ufp.terms._threebody_ops import (
    num_edge_categories,
    pattern_triplet_layout,
    preprocess_sources,
)
from ufp.terms.categories import triplet_categories as _triplet_categories
from ufp.terms.threebody import _active_triplet_mask, _support_bounds
# Spline kind names accepted by the ``spline`` argument of ``SplineTriplet2DTerm``.
SplineKind = Literal["quadratic", "cubic", "quartic"]
def _add_entries(
    matrix: torch.Tensor,
    rows: torch.Tensor,
    cols: torch.Tensor,
    values: torch.Tensor,
) -> None:
    """Accumulate broadcastable row/column/value entries into a dense matrix.

    ``rows``, ``cols`` and ``values`` are broadcast against each other and every
    resulting ``(row, col, value)`` triple is added to ``matrix[row, col]``.
    Duplicate ``(row, col)`` pairs accumulate. Entries whose row index is
    negative are skipped.

    Args:
        matrix: 2D tensor updated in place.
        rows: Integer row indices; negative values mark entries to drop.
        cols: Integer column indices.
        values: Values to add, with the same dtype as ``matrix``.
    """
    if rows.numel() == 0 or cols.numel() == 0 or values.numel() == 0:
        return
    rows, cols, values = torch.broadcast_tensors(rows, cols, values)
    valid = rows >= 0
    if not torch.any(valid):
        return
    # Boolean-mask indexing already yields flat 1D tensors.
    kept_rows = rows[valid]
    kept_cols = cols[valid]
    kept_values = values[valid]
    if matrix.is_contiguous():
        # Fast path: a flat view aliases the matrix storage, so the in-place
        # add is visible through ``matrix``.
        width = int(matrix.shape[1])
        matrix.view(-1).index_add_(0, kept_rows * width + kept_cols, kept_values)
    else:
        # ``reshape(-1)`` would copy a non-contiguous matrix and silently
        # discard the update; scatter through the 2D indices instead.
        matrix.index_put_(
            (kept_rows.long(), kept_cols.long()),
            kept_values,
            accumulate=True,
        )
def _selected_block_matrix(
    targets,
    selected_indices: Sequence[int],
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Allocate the zero-filled compact matrix for a selected coefficient block.

    The result has ``targets.n_rows`` rows and one column per entry of
    ``selected_indices``.
    """
    # Materialize first so any iterable (not only sized sequences) is counted.
    n_columns = len(tuple(selected_indices))
    n_rows = targets.n_rows
    return torch.zeros((n_rows, n_columns), device=device, dtype=dtype)
def _selected_channel_mask(
    triplet_index: torch.Tensor,
    selected_triplet_indices: torch.Tensor,
) -> torch.Tensor:
    """Return which local triplet categories have selected coefficients.

    Args:
        triplet_index: Integer tensor of triplet-category indices.
        selected_triplet_indices: Triplet-category indices that own at least
            one selected coefficient.

    Returns:
        Boolean tensor shaped like ``triplet_index`` that is ``True`` where the
        entry occurs in ``selected_triplet_indices``.
    """
    # A single vectorized membership test replaces the per-index Python loop
    # and its ``.cpu().tolist()`` device-to-host synchronization.
    selected = selected_triplet_indices.detach().to(
        device=triplet_index.device,
        dtype=triplet_index.dtype,
    )
    return torch.isin(triplet_index, selected)
[docs]
class _TripletContribution(NamedTuple):
    """Flattened tensors for the supported, enabled triplets of one pattern.

    Every field has one leading entry per surviving triplet ``(i, j, k)``.
    """

    # Result of ``uniform_stencil_2d``; consumers read ``indices``, ``values``,
    # ``grad_x`` and ``grad_y`` from it.
    stencil: Any
    # Triplet-category index of each triplet.
    triplet_index: torch.Tensor
    # Entry of ``inputs.system_index`` for each triplet's center atom.
    system: torch.Tensor
    # Center atom index ``i``.
    center: torch.Tensor
    # First neighbor atom index ``j`` (distance fed to the spline's x axis).
    neighbor_j: torch.Tensor
    # Second neighbor atom index ``k`` (distance fed to the spline's y axis).
    neighbor_k: torch.Tensor
    # Stored pair vector of the ``j`` edge divided by its clamped length, (n, 3).
    unit_j: torch.Tensor
    # Stored pair vector of the ``k`` edge divided by its clamped length, (n, 3).
    unit_k: torch.Tensor


class SplineTriplet2DTerm(ThreeBodyTerm):
    """
    Three-body spline over the two center-neighbor distances ``r_ij`` and ``r_ik``.

    Each triplet category owns a square 2D grid of spline coefficients. The
    same uniform knot layout is used along both distance axes.
    """

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        coeffs_by_triplet,
        active_triplets: Sequence[tuple[int, int, int]] | None = None,
        spline: SplineKind = "cubic",
        full_support_start: float = 0.0,
        eps: float = 1.0e-12,
        trainable: bool = True,
        fittable: bool = True,
        frozen: bool = False,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store categorized 2D triplet spline coefficients.

        Args:
            cutoff: Forwarded to the base class and used as the upper end of
                the full-support interval of the spline.
            atomic_types: Non-empty sequence of supported atomic types.
            coeffs_by_triplet: Array-like of shape
                ``(n_triplet_categories, N, N)``.
            active_triplets: Optional subset of triplet categories to enable;
                forwarded to ``_active_triplet_mask``.
            spline: Spline kind name.
            full_support_start: Lower end of the full-support interval.
            eps: Lower clamp applied to distances before normalizing vectors.
            trainable: Whether the coefficients require gradients (ignored
                when ``frozen`` is true).
            fittable: Whether the parameter block is reported as fittable.
            frozen: Whether the parameter block is reported as frozen.
            dtype: Optional dtype for the stored coefficients.

        Raises:
            ValueError: If ``atomic_types`` is empty or the coefficient array
                has the wrong rank, category count, or non-square grid.
        """
        super().__init__(cutoff=cutoff, atomic_types=atomic_types)
        if self.atomic_types is None or not self.atomic_types:
            raise ValueError("`atomic_types` must contain at least one element")

        coeffs = torch.as_tensor(coeffs_by_triplet, dtype=dtype)
        if coeffs.ndim != 3:
            raise ValueError(
                "`coeffs_by_triplet` must have shape (n_triplet_categories, Nx, Ny)"
            )
        n_cat = len(self.atomic_types)
        expected_triplet_categories = n_cat * num_edge_categories(n_cat)
        if int(coeffs.shape[0]) != expected_triplet_categories:
            raise ValueError(
                "`coeffs_by_triplet.shape[0]` must equal "
                f"{expected_triplet_categories}, got {coeffs.shape[0]}"
            )
        # Validate the grid before registering any parameter or buffer.
        if int(coeffs.shape[1]) != int(coeffs.shape[2]):
            raise ValueError("2D triplet coefficients must have matching x/y sizes")
        if (
            dtype is None
            and not coeffs.is_floating_point()
            and not coeffs.is_complex()
        ):
            # Integer/bool input (e.g. nested lists of Python ints) cannot
            # require gradients; promote it to the default float dtype.
            coeffs = coeffs.to(torch.get_default_dtype())

        self.spline = spline
        self.full_support_start = float(full_support_start)
        self.eps = float(eps)
        self.fittable = bool(fittable)
        self.frozen = bool(frozen)
        self.coeffs_by_triplet = torch.nn.Parameter(
            coeffs,
            requires_grad=bool(trainable) and not self.frozen,
        )
        self.triplet_categories = _triplet_categories(self.atomic_types)
        self.register_buffer(
            "active_triplet_mask",
            _active_triplet_mask(
                self.triplet_categories,
                active_triplets=active_triplets,
            ),
            persistent=False,
        )
        self._active_triplet_indices = tuple(
            index
            for index, enabled in enumerate(self.active_triplet_mask.tolist())
            if enabled
        )
        self.coeff_shape = (int(coeffs.shape[1]), int(coeffs.shape[2]))
        self.first_knot, self.knot_spacing = uniform_support_parameters(
            coeff_size=self.coeff_shape[0],
            lower_full_support=self.full_support_start,
            upper_full_support=cutoff,
            spline=spline,
        )
        self.lower_support, self.upper_support = _support_bounds(
            self.first_knot,
            self.knot_spacing,
            self.coeff_shape[0],
            lower_full_support=self.full_support_start,
        )

    @property
    def n_categories(self) -> int:
        """Return the number of atomic categories."""
        assert self.atomic_types is not None
        return len(self.atomic_types)

    @property
    def provides_forces(self) -> bool:
        """Report that this term produces analytic forces."""
        return True

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Declare the directed full-neighbor-list requirement."""
        return TermInputRequirements(full_neighbor_list=True)

    @property
    def active_triplet_categories(self) -> tuple[tuple[int, int, int], ...]:
        """Return the subset of triplet categories enabled for evaluation."""
        return tuple(
            self.triplet_categories[index] for index in self._active_triplet_indices
        )

    def _parameter_block_cache_descriptor(self) -> ParameterBlockCacheDescriptor:
        """Return reusable semantic cache metadata for 2D triplet coefficients.

        One channel is emitted per active triplet category; its ``start`` and
        ``stop`` delimit that category's slice of the flattened coefficients.
        """
        nx, ny = self.coeff_shape
        volume = int(nx * ny)
        return ParameterBlockCacheDescriptor(
            family={
                "kind": "triplet2d_spline",
                "atomic_types": [int(value) for value in self.atomic_types or ()],
                "spline": str(self.spline),
                "first_knot": float(self.first_knot),
                "knot_spacing": float(self.knot_spacing),
                "lower_support": float(self.lower_support),
                "coeff_shape": [int(nx), int(ny)],
                "eps": float(self.eps),
            },
            channels=tuple(
                ParameterBlockCacheChannel(
                    kind="triplet2d",
                    values=self.triplet_categories[triplet_index],
                    start=int(triplet_index) * volume,
                    stop=(int(triplet_index) + 1) * volume,
                )
                for triplet_index in self._active_triplet_indices
            ),
        )

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return the categorized 2D triplet coefficient block."""
        return (
            ParameterBlock(
                name="coeffs_by_triplet",
                kind="triplet2d",
                shape=tuple(int(dim) for dim in self.coeffs_by_triplet.shape),
                read=lambda: self.coeffs_by_triplet,
                write=lambda values: copy_parameter_data(
                    self.coeffs_by_triplet, values
                ),
                label=f"triplet2d[{self.atomic_types}]",
                regularization_group="threebody",
                # With no active triplet there is nothing to fit.
                fittable=self.fittable and bool(self._active_triplet_indices),
                frozen=self.frozen,
                assembler=self.assemble_linear_block,
                cache_descriptor=self._parameter_block_cache_descriptor(),
            ),
        )

    @staticmethod
    def _require_full_neighbor_list(inputs: UFPInput) -> None:
        """Raise ``RuntimeError`` unless ``inputs`` carries a full neighbor list."""
        if inputs.neighbor_list is None or not inputs.neighbor_list.full_list:
            raise RuntimeError("SplineTriplet2DTerm requires a full neighbor list")

    def _bucket_triplets(self, inputs: UFPInput):
        """Build source-neighbor buckets for the supported atoms.

        Returns:
            ``(buckets, node_cat)`` where ``buckets`` is ``None`` when no atom,
            no pair, or no pair distance inside
            ``[lower_support, upper_support)`` remains.
        """
        assert self.atomic_types is not None
        node_cat = inputs.atomic_category_indices(self.atomic_types)
        supported_atoms = node_cat >= 0
        if not torch.any(supported_atoms):
            return None, node_cat

        first_atom, second_atom = inputs.pair_indices()
        pair_mask = supported_atoms[first_atom] & supported_atoms[second_atom]
        if not torch.any(pair_mask):
            return None, node_cat

        pair_distances = inputs.pair_distances(pair_mask)
        in_support = (pair_distances >= self.lower_support) & (
            pair_distances < self.upper_support
        )
        if not torch.any(in_support):
            return None, node_cat

        filtered_first, filtered_second = inputs.pair_indices(pair_mask)
        pair_vectors = inputs.pair_vectors(pair_mask)
        buckets = preprocess_sources(
            filtered_first[in_support],
            filtered_second[in_support],
            node_cat,
            self.n_categories,
            pair_vectors[in_support],
            pair_distances[in_support],
        )
        return buckets, node_cat

    def _iter_triplet_contributions(
        self,
        inputs: UFPInput,
        buckets: Any,
        *,
        coeff_shape: tuple[int, int],
        selected_channels: torch.Tensor | None = None,
    ) -> Iterator[_TripletContribution]:
        """Yield the supported, enabled triplets of every neighbor pattern.

        Shared traversal used by ``forward`` and both linear assemblers.

        Args:
            inputs: Batch being evaluated.
            buckets: Non-empty result of ``_bucket_triplets``.
            coeff_shape: ``(Nx, Ny)`` grid passed to the spline helpers.
            selected_channels: Optional triplet-category indices; when given,
                triplets of other categories are skipped.
        """
        # n(n+1)/2 unordered pairs of neighbor categories per source category;
        # presumably equal to ``num_edge_categories(n)`` -- TODO confirm.
        triplets_per_src_cat = self.n_categories * (self.n_categories + 1) // 2
        active_mask = self.active_triplet_mask.to(device=inputs.device)
        knots = {
            "coeff_shape": coeff_shape,
            "first_knot_x": self.first_knot,
            "first_knot_y": self.first_knot,
            "knot_spacing_x": self.knot_spacing,
            "knot_spacing_y": self.knot_spacing,
            "spline": self.spline,
        }

        for pattern_index in range(int(buckets.patterns.shape[0])):
            src_start = int(buckets.pattern_ptr[pattern_index].item())
            src_end = int(buckets.pattern_ptr[pattern_index + 1].item())
            pattern = buckets.patterns[pattern_index]
            src_cat = int(pattern[0].item())
            counts = pattern[1:].detach().cpu().tolist()
            layout = pattern_triplet_layout(counts, inputs.device)
            if layout.row.numel() == 0:
                continue

            triplet_index = src_cat * triplets_per_src_cat + layout.edge_cat
            enabled = active_mask.index_select(0, triplet_index)
            if selected_channels is not None:
                enabled = enabled & _selected_channel_mask(
                    triplet_index,
                    selected_channels,
                )
            if not torch.any(enabled):
                continue

            # All sources of one pattern share the same neighbor count, so the
            # edge slice can be viewed as a dense (source, neighbor) table.
            n_sources = src_end - src_start
            degree = int(sum(counts))
            edge_start = int(buckets.row_ptr[src_start].item())
            edge_end = int(buckets.row_ptr[src_end].item())
            src_ids = buckets.src_ids[src_start:src_end]
            src_system = inputs.system_index.index_select(0, src_ids)
            nbr_ids = buckets.nbr_ids[edge_start:edge_end].view(n_sources, degree)
            vec = buckets.pair_vectors[edge_start:edge_end].view(
                n_sources, degree, 3
            )
            distances = buckets.pair_distances[edge_start:edge_end].view(
                n_sources, degree
            )

            row = layout.row[enabled]
            col = layout.col[enabled]
            triplet_index = triplet_index[enabled]
            flat_x = distances[:, row].reshape(-1)
            flat_y = distances[:, col].reshape(-1)
            supported = spline_support_mask_2d(flat_x, flat_y, **knots)
            if not torch.any(supported):
                continue

            flat_positions = torch.nonzero(supported, as_tuple=False).reshape(-1)
            supported_x = flat_x[supported]
            supported_y = flat_y[supported]
            stencil = uniform_stencil_2d(supported_x, supported_y, **knots)

            n_local = row.numel()
            flat_triplet_index = (
                triplet_index[None, :]
                .expand(src_ids.shape[0], -1)
                .reshape(-1)
                .index_select(0, flat_positions)
            )
            flat_src = (
                src_ids[:, None]
                .expand(-1, n_local)
                .reshape(-1)
                .index_select(0, flat_positions)
            )
            flat_system = (
                src_system[:, None]
                .expand(-1, n_local)
                .reshape(-1)
                .index_select(0, flat_positions)
            )
            flat_j = nbr_ids[:, row].reshape(-1).index_select(0, flat_positions)
            flat_k = nbr_ids[:, col].reshape(-1).index_select(0, flat_positions)
            # Clamp the distance so a vanishing edge cannot divide by zero.
            unit_j = (
                vec[:, row, :].reshape(-1, 3).index_select(0, flat_positions)
                / supported_x.clamp_min(self.eps)[:, None]
            )
            unit_k = (
                vec[:, col, :].reshape(-1, 3).index_select(0, flat_positions)
                / supported_y.clamp_min(self.eps)[:, None]
            )
            yield _TripletContribution(
                stencil=stencil,
                triplet_index=flat_triplet_index,
                system=flat_system,
                center=flat_src,
                neighbor_j=flat_j,
                neighbor_k=flat_k,
                unit_j=unit_j,
                unit_k=unit_k,
            )

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate the 2D triplet spline over full-neighbor-list triplets.

        Returns:
            Output with per-system energy, per-atom energy (attributed to the
            center atom) and forces. An all-empty output is returned when no
            triplet category is active or no bucket exists.
        """
        self.validate_inputs(inputs)
        if not self._active_triplet_indices:
            return empty_atomwise_output(inputs, forces=True)
        buckets, _ = self._bucket_triplets(inputs)
        if not buckets:
            return empty_atomwise_output(inputs, forces=True)

        output = empty_atomwise_output(inputs, forces=True)
        assert output.energy is not None
        assert output.forces is not None
        assert output.per_atom_energy is not None

        # One row of flattened coefficients per triplet category.
        coeff_table = self.coeffs_by_triplet.to(
            device=inputs.device,
            dtype=inputs.dtype,
        ).reshape(len(self.triplet_categories), -1)

        for item in self._iter_triplet_contributions(
            inputs,
            buckets,
            coeff_shape=self.coeff_shape,
        ):
            stencil = item.stencil
            coeff_window = coeff_table[item.triplet_index[:, None], stencil.indices]
            energy = (stencil.values * coeff_window).sum(dim=1)
            grad_x = (stencil.grad_x * coeff_window).sum(dim=1)
            grad_y = (stencil.grad_y * coeff_window).sum(dim=1)

            force_j = -grad_x[:, None] * item.unit_j
            force_k = -grad_y[:, None] * item.unit_k
            # The center atom takes the opposite of both neighbor forces, so
            # the three forces of a triplet sum to zero.
            force_i = -(force_j + force_k)

            output.energy.index_add_(0, item.system, energy)
            output.per_atom_energy.index_add_(0, item.center, energy)
            output.forces.index_add_(0, item.center, force_i)
            output.forces.index_add_(0, item.neighbor_j, force_j)
            output.forces.index_add_(0, item.neighbor_k, force_k)
        return output

    def _assemble_entries(
        self,
        block,
        inputs: UFPInput,
        targets: Any,
        buckets: Any,
        add_entries,
        *,
        selected_channels: torch.Tensor | None = None,
    ) -> None:
        """Scatter energy, per-atom and force design entries of every triplet.

        ``add_entries(rows, cols, values)`` receives broadcastable target rows,
        full-block coefficient columns and basis values; it decides how they
        are stored (dense or compact/selected matrix).
        """
        _, nx, ny = block.shape
        coeff_volume = int(nx * ny)
        for item in self._iter_triplet_contributions(
            inputs,
            buckets,
            coeff_shape=(int(nx), int(ny)),
            selected_channels=selected_channels,
        ):
            stencil = item.stencil
            # Column of each stencil coefficient within the flattened block.
            cols = stencil.indices + item.triplet_index[:, None] * coeff_volume
            force_j = -(stencil.grad_x[:, :, None] * item.unit_j[:, None, :])
            force_k = -(stencil.grad_y[:, :, None] * item.unit_k[:, None, :])
            force_i = -(force_j + force_k)

            add_entries(
                targets.energy_rows.index_select(0, item.system)[:, None],
                cols,
                stencil.values,
            )
            add_entries(
                targets.per_atom_rows.index_select(0, item.center)[:, None],
                cols,
                stencil.values,
            )
            for atom_index, force in (
                (item.center, force_i),
                (item.neighbor_j, force_j),
                (item.neighbor_k, force_k),
            ):
                add_entries(
                    targets.force_rows.index_select(0, atom_index)[:, :, None],
                    cols[:, None, :],
                    force.permute(0, 2, 1),
                )

    def assemble_linear_block(
        self,
        block,
        inputs: UFPInput,
        targets: Any,
    ) -> torch.Tensor | None:
        """Assemble this term's dense least-squares block.

        Returns:
            A ``(targets.n_rows, block.size)`` matrix, or ``None`` when there
            are no buckets or the assembled matrix is entirely zero.

        Raises:
            RuntimeError: If ``inputs`` has no full neighbor list.
        """
        self._require_full_neighbor_list(inputs)
        buckets, _ = self._bucket_triplets(inputs)
        if not buckets:
            return None
        matrix = torch.zeros(
            (targets.n_rows, block.size),
            dtype=inputs.dtype,
            device=inputs.device,
        )

        def accumulate(rows, cols, values) -> None:
            _add_entries(matrix, rows, cols, values)

        self._assemble_entries(block, inputs, targets, buckets, accumulate)
        return None if torch.count_nonzero(matrix) == 0 else matrix

    def assemble_selected_linear_block(
        self,
        block,
        inputs: UFPInput,
        targets: Any,
        selected_indices: Sequence[int],
    ) -> torch.Tensor | None:
        """Assemble only requested coefficient columns for this 2D triplet block.

        Returns:
            A ``(targets.n_rows, len(selected_indices))`` matrix, or ``None``
            when there are no buckets or the assembled matrix is entirely zero.

        Raises:
            RuntimeError: If ``inputs`` has no full neighbor list.
        """
        selected_indices = tuple(int(index) for index in selected_indices)
        self._require_full_neighbor_list(inputs)
        buckets, _ = self._bucket_triplets(inputs)
        if not buckets:
            return None

        _, nx, ny = block.shape
        coeff_volume = int(nx * ny)
        selected_lookup = selected_column_lookup(
            selected_indices,
            block_size=block.size,
            device=inputs.device,
        )
        # Triplet categories owning at least one selected coefficient; every
        # other category is skipped during traversal.
        selected_triplet_indices = torch.unique(
            torch.div(
                torch.as_tensor(
                    list(selected_indices),
                    dtype=torch.int64,
                    device=inputs.device,
                ),
                coeff_volume,
                rounding_mode="floor",
            )
        )
        matrix = _selected_block_matrix(
            targets,
            selected_indices,
            device=inputs.device,
            dtype=inputs.dtype,
        )

        def accumulate(rows, cols, values) -> None:
            add_selected_entries(matrix, rows, cols, values, selected_lookup)

        self._assemble_entries(
            block,
            inputs,
            targets,
            buckets,
            accumulate,
            selected_channels=selected_triplet_indices,
        )
        return None if torch.count_nonzero(matrix) == 0 else matrix

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Assemble all requested 2D triplet blocks for this term.

        Returns:
            Mapping from ``block.index`` to its dense matrix for every block in
            ``options.blocks`` whose assembly is not ``None``; empty when
            ``options`` is ``None``.
        """
        blocks = () if options is None else options.blocks
        return {
            block.index: matrix
            for block in blocks
            if (matrix := self.assemble_linear_block(block, batch.inputs, targets))
            is not None
        }
# Names exported by ``from <module> import *``.
__all__ = [
"SplineKind",
"SplineTriplet2DTerm",
]