# Source code for ufp.terms.triplet2d

"""Two-distance spline three-body term."""

from __future__ import annotations

from collections.abc import Iterator, Sequence
from typing import Any, Literal, NamedTuple

import torch

from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.splines.representation import (
    spline_support_mask_2d,
    uniform_stencil_2d,
    uniform_support_parameters,
)
from ufp.terms._base import LinearAssemblyOptions, TermInputRequirements, ThreeBodyTerm
from ufp.terms._parameters import (
    ParameterBlock,
    ParameterBlockCacheChannel,
    ParameterBlockCacheDescriptor,
    copy_parameter_data,
)
from ufp.terms._selected_assembly import add_selected_entries, selected_column_lookup
from ufp.terms._shared import empty_atomwise_output
from ufp.terms._threebody_ops import (
    num_edge_categories,
    pattern_triplet_layout,
    preprocess_sources,
)
from ufp.terms.categories import triplet_categories as _triplet_categories
from ufp.terms.threebody import _active_triplet_mask, _support_bounds


# Spline orders accepted by the ``spline`` argument of ``SplineTriplet2DTerm``.
SplineKind = Literal["quadratic", "cubic", "quartic"]


def _add_entries(
    matrix: torch.Tensor,
    rows: torch.Tensor,
    cols: torch.Tensor,
    values: torch.Tensor,
) -> None:
    """Accumulate broadcastable row/column/value entries into a dense matrix.

    ``rows``, ``cols`` and ``values`` are broadcast against each other; every
    broadcast element adds ``values[...]`` to ``matrix[rows[...], cols[...]]``.
    Elements whose row index is negative are skipped, and duplicate
    ``(row, col)`` pairs are summed. ``matrix`` is modified in place.

    Args:
        matrix: 2D destination tensor, updated in place.
        rows: Integer row indices (negative entries are ignored).
        cols: Integer column indices.
        values: Values to add; must share ``matrix``'s dtype.
    """
    if rows.numel() == 0 or cols.numel() == 0 or values.numel() == 0:
        return
    rows, cols, values = torch.broadcast_tensors(rows, cols, values)
    valid = rows >= 0
    if not torch.any(valid):
        return
    # index_put_ writes through any stride layout. Flattening with
    # reshape(-1) would silently copy a non-contiguous matrix and lose the
    # update.
    matrix.index_put_((rows[valid], cols[valid]), values[valid], accumulate=True)


def _selected_block_matrix(
    targets,
    selected_indices: Sequence[int],
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Allocate a zero design matrix restricted to the selected columns.

    The matrix has ``targets.n_rows`` rows and one column per entry of
    ``selected_indices``; only ``targets.n_rows`` is read from ``targets``.
    """
    column_count = len(tuple(selected_indices))
    return torch.zeros(targets.n_rows, column_count, device=device, dtype=dtype)


def _selected_channel_mask(
    triplet_index: torch.Tensor,
    selected_triplet_indices: torch.Tensor,
) -> torch.Tensor:
    """Return which local triplet categories have selected coefficients.

    Args:
        triplet_index: Triplet-category index of each local entry.
        selected_triplet_indices: Triplet-category indices that own at least
            one selected coefficient column.

    Returns:
        Boolean tensor shaped like ``triplet_index`` (on its device) that is
        ``True`` wherever the entry's category is in
        ``selected_triplet_indices``. An empty selection yields all ``False``.
    """
    # A single vectorized membership test replaces a Python loop. The loop
    # also forced a device-to-host copy of the selection on every call.
    candidates = selected_triplet_indices.detach().to(
        device=triplet_index.device,
        dtype=triplet_index.dtype,
    )
    return torch.isin(triplet_index, candidates)


class _TripletBatch(NamedTuple):
    """Spline-supported triplets gathered from one neighbor pattern.

    Every field shares the same leading length ``T``. Entry ``t`` describes
    one ``(i, j, k)`` triplet whose distances ``(r_ij, r_ik)`` lie inside the
    2D spline support.
    """

    stencil: Any  # ``uniform_stencil_2d`` result (values/grad_x/grad_y/indices)
    triplet_index: torch.Tensor  # triplet-category index of each triplet
    src: torch.Tensor  # center atom ``i``
    system: torch.Tensor  # system index of the center atom
    atom_j: torch.Tensor  # neighbor providing the x distance
    atom_k: torch.Tensor  # neighbor providing the y distance
    unit_j: torch.Tensor  # (T, 3): pair vector to ``j`` over clamped r_ij
    unit_k: torch.Tensor  # (T, 3): pair vector to ``k`` over clamped r_ik


class SplineTriplet2DTerm(ThreeBodyTerm):
    """
    Three-body spline over the two center-neighbor distances ``r_ij`` and ``r_ik``.

    Each triplet category owns an ``(N, N)`` coefficient grid. Categories are
    indexed as ``src_cat * n(n + 1) / 2 + edge_cat``, where ``n`` is the
    number of atomic types.
    """

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        coeffs_by_triplet,
        active_triplets: Sequence[tuple[int, int, int]] | None = None,
        spline: SplineKind = "cubic",
        full_support_start: float = 0.0,
        eps: float = 1.0e-12,
        trainable: bool = True,
        fittable: bool = True,
        frozen: bool = False,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store categorized 2D triplet spline coefficients.

        Args:
            cutoff: Interaction cutoff. It is forwarded to the base term and
                used as the upper full-support bound of the knot grid.
            atomic_types: Supported atomic types; must be non-empty.
            coeffs_by_triplet: Array-like of shape
                ``(n_triplet_categories, N, N)`` with
                ``n_triplet_categories = n_types * num_edge_categories(n_types)``.
            active_triplets: Optional subset of triplet categories to
                evaluate; passed to ``_active_triplet_mask``.
            spline: Spline order used for the stencil and support.
            full_support_start: Lower full-support bound of the knot grid.
            eps: Lower clamp on distances when forming unit vectors.
            trainable: Whether coefficients require gradients (ignored when
                ``frozen``).
            fittable: Whether the coefficient block is offered for linear
                fitting.
            frozen: Marks the block frozen and disables gradients.
            dtype: Coefficient dtype. When ``None``, integer/bool input is
                promoted to the default floating dtype.

        Raises:
            ValueError: On empty ``atomic_types`` or badly shaped coefficients.
        """
        super().__init__(cutoff=cutoff, atomic_types=atomic_types)
        if self.atomic_types is None or not self.atomic_types:
            raise ValueError("`atomic_types` must contain at least one element")

        coeffs = torch.as_tensor(coeffs_by_triplet, dtype=dtype)
        if dtype is None and not coeffs.is_floating_point() and not coeffs.is_complex():
            # Integer tensors cannot require gradients; promote so that
            # plain-int coefficient lists work.
            coeffs = coeffs.to(torch.get_default_dtype())

        # Validate every shape constraint before mutating module state.
        if coeffs.ndim != 3:
            raise ValueError(
                "`coeffs_by_triplet` must have shape (n_triplet_categories, Nx, Ny)"
            )
        n_cat = len(self.atomic_types)
        expected_triplet_categories = n_cat * num_edge_categories(n_cat)
        if int(coeffs.shape[0]) != expected_triplet_categories:
            raise ValueError(
                "`coeffs_by_triplet.shape[0]` must equal "
                f"{expected_triplet_categories}, got {coeffs.shape[0]}"
            )
        if int(coeffs.shape[1]) != int(coeffs.shape[2]):
            raise ValueError("2D triplet coefficients must have matching x/y sizes")

        self.spline = spline
        self.full_support_start = float(full_support_start)
        self.eps = float(eps)
        self.fittable = bool(fittable)
        self.frozen = bool(frozen)
        self.coeffs_by_triplet = torch.nn.Parameter(
            coeffs,
            requires_grad=bool(trainable) and not self.frozen,
        )
        self.triplet_categories = _triplet_categories(self.atomic_types)
        self.register_buffer(
            "active_triplet_mask",
            _active_triplet_mask(
                self.triplet_categories,
                active_triplets=active_triplets,
            ),
            persistent=False,
        )
        self._active_triplet_indices = tuple(
            index
            for index, enabled in enumerate(self.active_triplet_mask.tolist())
            if enabled
        )
        self.coeff_shape = (int(coeffs.shape[1]), int(coeffs.shape[2]))
        # Both axes share one knot grid because the grid is square.
        self.first_knot, self.knot_spacing = uniform_support_parameters(
            coeff_size=self.coeff_shape[0],
            lower_full_support=self.full_support_start,
            upper_full_support=cutoff,
            spline=spline,
        )
        self.lower_support, self.upper_support = _support_bounds(
            self.first_knot,
            self.knot_spacing,
            self.coeff_shape[0],
            lower_full_support=self.full_support_start,
        )

    @property
    def n_categories(self) -> int:
        """Return the number of atomic categories."""
        assert self.atomic_types is not None
        return len(self.atomic_types)

    @property
    def provides_forces(self) -> bool:
        """Report that this term produces analytic forces."""
        return True

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Declare the directed full-neighbor-list requirement."""
        return TermInputRequirements(full_neighbor_list=True)

    @property
    def active_triplet_categories(self) -> tuple[tuple[int, int, int], ...]:
        """Return the subset of triplet categories enabled for evaluation."""
        return tuple(
            self.triplet_categories[index] for index in self._active_triplet_indices
        )

    def _parameter_block_cache_descriptor(self) -> ParameterBlockCacheDescriptor:
        """Return reusable semantic cache metadata for 2D triplet coefficients.

        Each active triplet category maps to a contiguous slice of
        ``N * N`` entries in the flattened coefficient tensor.
        """
        nx, ny = self.coeff_shape
        volume = int(nx * ny)
        return ParameterBlockCacheDescriptor(
            family={
                "kind": "triplet2d_spline",
                "atomic_types": [int(value) for value in self.atomic_types or ()],
                "spline": str(self.spline),
                "first_knot": float(self.first_knot),
                "knot_spacing": float(self.knot_spacing),
                "lower_support": float(self.lower_support),
                "coeff_shape": [int(nx), int(ny)],
                "eps": float(self.eps),
            },
            channels=tuple(
                ParameterBlockCacheChannel(
                    kind="triplet2d",
                    values=self.triplet_categories[triplet_index],
                    start=int(triplet_index) * volume,
                    stop=(int(triplet_index) + 1) * volume,
                )
                for triplet_index in self._active_triplet_indices
            ),
        )

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return the categorized 2D triplet coefficient block.

        The block is fittable only when ``fittable`` is set and at least one
        triplet category is active.
        """
        return (
            ParameterBlock(
                name="coeffs_by_triplet",
                kind="triplet2d",
                shape=tuple(int(dim) for dim in self.coeffs_by_triplet.shape),
                read=lambda: self.coeffs_by_triplet,
                write=lambda values: copy_parameter_data(
                    self.coeffs_by_triplet, values
                ),
                label=f"triplet2d[{self.atomic_types}]",
                regularization_group="threebody",
                fittable=self.fittable and bool(self._active_triplet_indices),
                frozen=self.frozen,
                assembler=self.assemble_linear_block,
                cache_descriptor=self._parameter_block_cache_descriptor(),
            ),
        )

    def _bucket_triplets(self, inputs: UFPInput):
        """Build source-neighbor buckets for the supported atoms.

        Only pairs whose atoms both have a supported type and whose distance
        lies in ``[lower_support, upper_support)`` are bucketed.

        Returns:
            ``(buckets, node_cat)``. ``buckets`` is ``None`` when no pair
            survives the filters. ``node_cat`` is the per-atom category index,
            which is negative for unsupported atoms.
        """
        assert self.atomic_types is not None
        node_cat = inputs.atomic_category_indices(self.atomic_types)
        supported_atoms = node_cat >= 0
        if not torch.any(supported_atoms):
            return None, node_cat
        first_atom, second_atom = inputs.pair_indices()
        pair_mask = supported_atoms[first_atom] & supported_atoms[second_atom]
        if not torch.any(pair_mask):
            return None, node_cat
        pair_distances = inputs.pair_distances(pair_mask)
        center_support_mask = (pair_distances >= self.lower_support) & (
            pair_distances < self.upper_support
        )
        if not torch.any(center_support_mask):
            return None, node_cat
        filtered_first, filtered_second = inputs.pair_indices(pair_mask)
        pair_vectors = inputs.pair_vectors(pair_mask)
        buckets = preprocess_sources(
            filtered_first[center_support_mask],
            filtered_second[center_support_mask],
            node_cat,
            self.n_categories,
            pair_vectors[center_support_mask],
            pair_distances[center_support_mask],
        )
        return buckets, node_cat

    def _iter_triplet_batches(
        self,
        inputs: UFPInput,
        buckets: Any,
        coeff_shape: tuple[int, int],
        selected_triplet_indices: torch.Tensor | None = None,
    ) -> Iterator[_TripletBatch]:
        """Yield flattened, spline-supported triplets pattern by pattern.

        This is the shared extraction step for ``forward`` and both linear
        assemblers.

        Args:
            inputs: Structure batch (``device`` and ``system_index`` are read).
            buckets: Non-empty result of ``_bucket_triplets``.
            coeff_shape: ``(Nx, Ny)`` used for the support mask and stencil.
            selected_triplet_indices: Optional triplet-category indices. When
                given, triplets of all other categories are skipped.

        Yields:
            One ``_TripletBatch`` per neighbor pattern that still has active
            (and, if requested, selected) triplets inside the support.
        """
        triplets_per_src_cat = self.n_categories * (self.n_categories + 1) // 2
        active_mask = self.active_triplet_mask.to(device=inputs.device)
        spline_options = {
            "coeff_shape": coeff_shape,
            "first_knot_x": self.first_knot,
            "first_knot_y": self.first_knot,
            "knot_spacing_x": self.knot_spacing,
            "knot_spacing_y": self.knot_spacing,
            "spline": self.spline,
        }
        for pattern_index in range(int(buckets.patterns.shape[0])):
            src_start = int(buckets.pattern_ptr[pattern_index].item())
            src_end = int(buckets.pattern_ptr[pattern_index + 1].item())
            n_src = src_end - src_start
            pattern = buckets.patterns[pattern_index]
            src_cat = int(pattern[0].item())
            counts = pattern[1:].detach().cpu().tolist()
            layout = pattern_triplet_layout(counts, inputs.device)
            if layout.row.numel() == 0:
                continue

            # All sources in one pattern have the same degree, so their edges
            # can be viewed as dense (n_src, degree) blocks.
            degree = int(sum(counts))
            edge_start = int(buckets.row_ptr[src_start].item())
            edge_end = int(buckets.row_ptr[src_end].item())
            src_ids = buckets.src_ids[src_start:src_end]
            src_system = inputs.system_index.index_select(0, src_ids)
            nbr_ids = buckets.nbr_ids[edge_start:edge_end].view(n_src, degree)
            vec = buckets.pair_vectors[edge_start:edge_end].view(n_src, degree, 3)
            distances = buckets.pair_distances[edge_start:edge_end].view(
                n_src, degree
            )

            triplet_index = src_cat * triplets_per_src_cat + layout.edge_cat
            enabled = active_mask.index_select(0, triplet_index)
            if selected_triplet_indices is not None:
                enabled = enabled & _selected_channel_mask(
                    triplet_index,
                    selected_triplet_indices,
                )
            if not torch.any(enabled):
                continue
            row = layout.row[enabled]
            col = layout.col[enabled]
            triplet_index = triplet_index[enabled]

            # Flatten (n_src, n_pairs) row-major and keep supported entries.
            flat_x = distances[:, row].reshape(-1)
            flat_y = distances[:, col].reshape(-1)
            supported = spline_support_mask_2d(flat_x, flat_y, **spline_options)
            if not torch.any(supported):
                continue
            flat_positions = torch.nonzero(supported, as_tuple=False).reshape(-1)
            x_supported = flat_x[supported]
            y_supported = flat_y[supported]
            stencil = uniform_stencil_2d(x_supported, y_supported, **spline_options)

            n_pairs = row.numel()
            yield _TripletBatch(
                stencil=stencil,
                triplet_index=(
                    triplet_index[None, :]
                    .expand(src_ids.shape[0], -1)
                    .reshape(-1)
                    .index_select(0, flat_positions)
                ),
                src=src_ids[:, None]
                .expand(-1, n_pairs)
                .reshape(-1)
                .index_select(0, flat_positions),
                system=src_system[:, None]
                .expand(-1, n_pairs)
                .reshape(-1)
                .index_select(0, flat_positions),
                atom_j=nbr_ids[:, row].reshape(-1).index_select(0, flat_positions),
                atom_k=nbr_ids[:, col].reshape(-1).index_select(0, flat_positions),
                unit_j=(
                    vec[:, row, :].reshape(-1, 3).index_select(0, flat_positions)
                    / x_supported.clamp_min(self.eps)[:, None]
                ),
                unit_k=(
                    vec[:, col, :].reshape(-1, 3).index_select(0, flat_positions)
                    / y_supported.clamp_min(self.eps)[:, None]
                ),
            )

    @staticmethod
    def _scatter_design_entries(
        add_entries,
        matrix: torch.Tensor,
        targets: Any,
        batch: _TripletBatch,
        cols: torch.Tensor,
        *extra: Any,
    ) -> None:
        """Scatter one batch's energy, per-atom and force rows into ``matrix``.

        ``add_entries(matrix, rows, cols, values, *extra)`` performs the
        accumulation (dense or selected-column variant).
        """
        stencil = batch.stencil
        # Per-basis-function forces, shape (T, S, 3). The center receives the
        # negated sum of the neighbor forces.
        force_j = -(stencil.grad_x[:, :, None] * batch.unit_j[:, None, :])
        force_k = -(stencil.grad_y[:, :, None] * batch.unit_k[:, None, :])
        force_i = -(force_j + force_k)
        add_entries(
            matrix,
            targets.energy_rows.index_select(0, batch.system)[:, None],
            cols,
            stencil.values,
            *extra,
        )
        add_entries(
            matrix,
            targets.per_atom_rows.index_select(0, batch.src)[:, None],
            cols,
            stencil.values,
            *extra,
        )
        for atom_index, force in (
            (batch.src, force_i),
            (batch.atom_j, force_j),
            (batch.atom_k, force_k),
        ):
            add_entries(
                matrix,
                targets.force_rows.index_select(0, atom_index)[:, :, None],
                cols[:, None, :],
                force.permute(0, 2, 1),
                *extra,
            )

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate the 2D triplet spline over full-neighbor-list triplets.

        Returns:
            Atomwise output with per-system energy, per-atom energy (credited
            to the center atom) and analytic forces.
        """
        self.validate_inputs(inputs)
        if not self._active_triplet_indices:
            return empty_atomwise_output(inputs, forces=True)
        buckets, _ = self._bucket_triplets(inputs)
        if not buckets:
            return empty_atomwise_output(inputs, forces=True)
        output = empty_atomwise_output(inputs, forces=True)
        assert output.energy is not None
        assert output.forces is not None
        assert output.per_atom_energy is not None
        # One flat coefficient row per triplet category, hoisted out of the loop.
        flat_coeffs = self.coeffs_by_triplet.to(
            device=inputs.device,
            dtype=inputs.dtype,
        ).reshape(len(self.triplet_categories), -1)
        for batch in self._iter_triplet_batches(inputs, buckets, self.coeff_shape):
            stencil = batch.stencil
            coeff_window = flat_coeffs[batch.triplet_index[:, None], stencil.indices]
            energy = (stencil.values * coeff_window).sum(dim=1)
            grad_x = (stencil.grad_x * coeff_window).sum(dim=1)
            grad_y = (stencil.grad_y * coeff_window).sum(dim=1)
            force_j = -grad_x[:, None] * batch.unit_j
            force_k = -grad_y[:, None] * batch.unit_k
            force_i = -(force_j + force_k)
            output.energy.index_add_(0, batch.system, energy)
            output.per_atom_energy.index_add_(0, batch.src, energy)
            output.forces.index_add_(0, batch.src, force_i)
            output.forces.index_add_(0, batch.atom_j, force_j)
            output.forces.index_add_(0, batch.atom_k, force_k)
        return output

    def assemble_linear_block(
        self,
        block,
        inputs: UFPInput,
        targets: Any,
    ) -> torch.Tensor | None:
        """Assemble this term's dense least-squares block.

        Returns:
            A ``(targets.n_rows, block.size)`` matrix, or ``None`` when no
            triplet contributes.

        Raises:
            RuntimeError: If ``inputs`` lacks a full neighbor list.
        """
        if inputs.neighbor_list is None or not inputs.neighbor_list.full_list:
            raise RuntimeError("SplineTriplet2DTerm requires a full neighbor list")
        buckets, _ = self._bucket_triplets(inputs)
        if not buckets:
            return None
        _, nx, ny = block.shape
        coeff_volume = int(nx * ny)
        matrix = torch.zeros(
            (targets.n_rows, block.size),
            dtype=inputs.dtype,
            device=inputs.device,
        )
        for batch in self._iter_triplet_batches(inputs, buckets, (int(nx), int(ny))):
            cols = batch.stencil.indices + batch.triplet_index[:, None] * coeff_volume
            self._scatter_design_entries(_add_entries, matrix, targets, batch, cols)
        return None if torch.count_nonzero(matrix) == 0 else matrix

    def assemble_selected_linear_block(
        self,
        block,
        inputs: UFPInput,
        targets: Any,
        selected_indices: Sequence[int],
    ) -> torch.Tensor | None:
        """Assemble only requested coefficient columns for this 2D triplet block.

        Returns:
            A ``(targets.n_rows, len(selected_indices))`` matrix, or ``None``
            when no selected column receives a contribution.

        Raises:
            RuntimeError: If ``inputs`` lacks a full neighbor list.
        """
        selected_indices = tuple(int(index) for index in selected_indices)
        if inputs.neighbor_list is None or not inputs.neighbor_list.full_list:
            raise RuntimeError("SplineTriplet2DTerm requires a full neighbor list")
        buckets, _ = self._bucket_triplets(inputs)
        if not buckets:
            return None
        _, nx, ny = block.shape
        coeff_volume = int(nx * ny)
        selected_lookup = selected_column_lookup(
            selected_indices,
            block_size=block.size,
            device=inputs.device,
        )
        # Triplet categories that own at least one selected column.
        selected_triplet_indices = torch.unique(
            torch.div(
                torch.as_tensor(
                    selected_indices,
                    dtype=torch.int64,
                    device=inputs.device,
                ),
                coeff_volume,
                rounding_mode="floor",
            )
        )
        matrix = _selected_block_matrix(
            targets,
            selected_indices,
            device=inputs.device,
            dtype=inputs.dtype,
        )
        for batch in self._iter_triplet_batches(
            inputs,
            buckets,
            (int(nx), int(ny)),
            selected_triplet_indices,
        ):
            cols = batch.stencil.indices + batch.triplet_index[:, None] * coeff_volume
            self._scatter_design_entries(
                add_selected_entries,
                matrix,
                targets,
                batch,
                cols,
                selected_lookup,
            )
        return None if torch.count_nonzero(matrix) == 0 else matrix

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Assemble all requested 2D triplet blocks for this term.

        Returns:
            Mapping from ``block.index`` to its dense matrix. Blocks without
            contributions are omitted.
        """
        blocks = () if options is None else options.blocks
        return {
            block.index: matrix
            for block in blocks
            if (matrix := self.assemble_linear_block(block, batch.inputs, targets))
            is not None
        }
# Public API of this module.
__all__ = [
    "SplineKind",
    "SplineTriplet2DTerm",
]