# Source code for ufp.terms.threebody

"""
Spline-based three-body interaction term implementation.

``SplineThreeBodyTerm`` is the stable user-facing term. Bucket containers,
feature caches, and evaluator helpers exported from this module are expert
diagnostics for benchmarks, tests, and performance investigations.
"""

from __future__ import annotations

import hashlib
import json
from collections.abc import Sequence
from pathlib import Path
from typing import Literal

import torch

from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.splines.representation import (
    all_supported_uniform_stencil_3d,
    uniform_support_parameters,
)
from ufp.terms._base import (
    LinearAssemblyOptions,
    TermCacheOptions,
    TermInputRequirements,
    ThreeBodyTerm,
)
from ufp.terms._parameters import (
    ParameterBlock,
    ParameterBlockCacheChannel,
    ParameterBlockCacheDescriptor,
    copy_parameter_data,
)
from ufp.terms._selected_assembly import add_selected_entries, selected_column_lookup
from ufp.terms._shared import empty_atomwise_output
from ufp.terms._threebody_cache import (
    FeatureCacheMode,
    _build_dense_feature_cache_from_buckets,
    _dense_feature_cache_dir,
    _dense_feature_cache_metadata,
    _find_compatible_memmap_dense_feature_cache,
    _load_memmap_dense_feature_cache,
    load_memmap_threebody_feature_cache,
)
from ufp.terms._threebody_dense import (
    DenseThreeBodyFeatureCache,
    DenseTripletFeatureBlock,
    MemmapDenseThreeBodyFeatureCache,
    MemmapDenseTripletFeatureBlock,
    ThreeBodyDenseAtomFeatures,
    _build_dense_feature_cache_from_feature_cache,
    _dense_atom_features_from_feature_cache,
    _evaluate_dense_feature_cache_energy_forces,
    _selected_atom_indices,
    _symmetrize_dense_atom_features,
)
from ufp.terms._threebody_eval import (
    BucketedEnergyForceEvaluator,
    SplineKind,
    _neighbor_neighbor_cutoff,
    _same_neighbor_triplet_mask,
    _support_bounds,
    _symmetrize_same_neighbor_coeffs,
    evaluate_bucketed_energy_forces,
    get_eval_3d_with_grads,
    make_bucketed_energy_forces_evaluator,
)
from ufp.terms._threebody_eval import (
    Eval3DWithGrads as Eval3DWithGrads,
)
from ufp.terms._threebody_eval import (
    _evaluate_bucketed_energy_forces as _evaluate_bucketed_energy_forces,
)
from ufp.terms._threebody_features import (
    ThreeBodyFeatureBlock as ThreeBodyFeatureBlock,
)
from ufp.terms._threebody_features import (
    ThreeBodyFeatureCache as ThreeBodyFeatureCache,
)
from ufp.terms._threebody_features import (
    _build_feature_cache_from_buckets,
)
from ufp.terms._threebody_kernels import (
    preprocess_sources_native_or_torch,
)
from ufp.terms._threebody_ops import (
    Buckets,
    build_edge_category_table,
    num_edge_categories,
    pair_distance_partials_batched,
    pattern_triplet_layout,
    preprocess_sources,
)
from ufp.terms._threebody_runtime import (
    ThreeBodyRuntimeConfig,
    resolve_threebody_runtime_config,
)
from ufp.terms.alchemical import AlchemicalCoefficients
from ufp.terms.categories import (
    active_triplet_mask as _active_triplet_mask,
)
from ufp.terms.categories import (
    canonical_triplet as _canonical_triplet,
)
from ufp.terms.categories import (
    triplet_categories as _triplet_categories,
)


# Key strings for cached three-body buckets and feature caches. Neither name is
# referenced in the visible part of this module; presumably they key per-input
# cache storage used elsewhere in the file — TODO confirm.
_THREEBODY_BUCKET_CACHE_KEY = "_ufp_threebody_buckets"
_THREEBODY_FEATURE_CACHE_KEY = "_ufp_threebody_features"


def _selected_block_matrix(
    targets,
    selected_indices: Sequence[int],
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Allocate a zero-filled design matrix for a subset of coefficient columns.

    The matrix has ``targets.n_rows`` rows and one column per entry of
    ``selected_indices``; it is created on ``device`` with ``dtype``.
    """
    n_columns = sum(1 for _ in selected_indices)
    shape = (targets.n_rows, n_columns)
    return torch.zeros(shape, dtype=dtype, device=device)


def _swapped_xy_cols(
    cols: torch.Tensor,
    coeff_volume: int,
    nx: int,
    ny: int,
    nz: int,
) -> torch.Tensor:
    """Return columns addressing the same coefficient with x/y swapped.

    Columns are flat indices into consecutive row-major ``(nx, ny, nz)``
    coefficient grids, each occupying ``coeff_volume`` columns. The grid
    offset of every column is preserved; inside its grid, the coefficient at
    ``(ix, iy, iz)`` is mapped to the flat index of ``(iy, ix, iz)`` in the
    same ``(nx, ny, nz)`` layout.

    The mapping is only a permutation of the grid when ``nx == ny``, which
    ``SplineThreeBodyTerm`` enforces for its coefficient tensors.
    """
    local = torch.remainder(cols, coeff_volume)
    block_offset = cols - local
    iz = torch.remainder(local, nz)
    iy = torch.remainder(torch.div(local, nz, rounding_mode="floor"), ny)
    ix = torch.div(local, ny * nz, rounding_mode="floor")
    # The leading coordinate of the swapped entry is ``iy``; in a row-major
    # (nx, ny, nz) grid the leading axis has stride ``ny * nz``, so the
    # flattened (first, second) pair is ``iy * ny + ix``.
    return block_offset + ((iy * ny + ix) * nz + iz)


def _add_selected_threebody_entries(
    matrix: torch.Tensor,
    rows: torch.Tensor,
    cols: torch.Tensor,
    values: torch.Tensor,
    *,
    selected_lookup: torch.Tensor,
    same_triplet_mask: torch.Tensor,
    coeff_volume: int,
    nx: int,
    ny: int,
    nz: int,
) -> None:
    """Scatter selected three-body contributions into ``matrix``.

    Entries flagged in ``same_triplet_mask`` (indexed along the first axis of
    ``rows``/``cols``/``values``) are split in half between their own column
    and the column with x/y swapped, tying that symmetric coefficient pair.
    Unflagged entries are added unchanged. Nothing happens when ``rows``,
    ``cols``, or ``values`` is empty.
    """
    if 0 in (rows.numel(), cols.numel(), values.numel()):
        return

    tied = same_triplet_mask.to(device=cols.device, dtype=torch.bool)
    has_tied = tied.numel() > 0 and bool(torch.any(tied))
    if not has_tied:
        add_selected_entries(matrix, rows, cols, values, selected_lookup)
        return

    untied = ~tied
    if torch.any(untied):
        add_selected_entries(
            matrix,
            rows[untied],
            cols[untied],
            values[untied],
            selected_lookup,
        )

    # At this point at least one entry is tied: split it across (x, y) and (y, x).
    tied_rows = rows[tied]
    tied_cols = cols[tied]
    half_values = 0.5 * values[tied]
    add_selected_entries(matrix, tied_rows, tied_cols, half_values, selected_lookup)
    add_selected_entries(
        matrix,
        tied_rows,
        _swapped_xy_cols(tied_cols, coeff_volume, nx, ny, nz),
        half_values,
        selected_lookup,
    )


def _selected_channel_mask(
    triplet_index: torch.Tensor,
    selected_triplet_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Return which local triplet categories have selected coefficients."""
    if selected_triplet_indices is None:
        return torch.ones_like(triplet_index, dtype=torch.bool)
    mask = torch.zeros_like(triplet_index, dtype=torch.bool)
    for selected_index in selected_triplet_indices.detach().cpu().tolist():
        mask = mask | (triplet_index == int(selected_index))
    return mask


def _accumulate_selected_threebody_pairs(
    *,
    matrix: torch.Tensor,
    term: "SplineThreeBodyTerm",
    src_ids: torch.Tensor,
    src_system: torch.Tensor,
    triplet_index: torch.Tensor,
    row: torch.Tensor,
    col: torch.Tensor,
    vec: torch.Tensor,
    distances: torch.Tensor,
    nbr_sorted: torch.Tensor,
    coeff_shape: tuple[int, int, int, int],
    selected_triplet_index: int | None,
    selected_triplet_indices: torch.Tensor | None,
    selected_lookup: torch.Tensor,
    energy_rows: torch.Tensor,
    force_rows: torch.Tensor,
    per_atom_rows: torch.Tensor,
) -> None:
    """Accumulate one source-pattern's selected three-body columns.

    For each source atom of the pattern and each local neighbor pair
    ``(row[p], col[p])``, the spline arguments are ``x = distances[:, row]``,
    ``y = distances[:, col]`` and ``z = |vec[:, row] - vec[:, col]|``. Pairs
    whose triplet category is inactive or unselected, or whose arguments fall
    outside the spline support, are dropped. The remaining pairs add their
    stencil values to the energy rows (and per-atom rows, if any are present)
    and their analytic gradients to the force rows of ``matrix`` via
    ``_add_selected_threebody_entries``.

    Args:
        matrix: Compact design matrix, updated in place.
        term: Term supplying support bounds, knots, spline kind, ``eps`` and
            the active / same-neighbor triplet masks.
        src_ids: Atom index of each source (center) atom in this pattern.
        src_system: System index of each source atom.
        triplet_index: Triplet-category index of each local pair.
        row, col: Local neighbor slots of the two neighbors in each pair.
        vec: Neighbor vectors, indexed here as ``(source, slot, 3)``.
        distances: Neighbor distances, indexed here as ``(source, slot)``.
        nbr_sorted: Neighbor atom indices, indexed here as ``(source, slot)``.
        coeff_shape: ``(n_triplet_categories, nx, ny, nz)``.
        selected_triplet_index: If set, only this triplet category contributes
            and columns are local to one ``(nx, ny, nz)`` grid (single-triplet
            block).
        selected_triplet_indices: For categorized blocks, the triplet
            categories that own selected columns (``None`` keeps all).
        selected_lookup: Column lookup forwarded to ``add_selected_entries``.
        energy_rows: Matrix row of each system's energy target.
        force_rows: Matrix rows of each atom's force components; presumably
            ``(n_atoms, 3)`` — TODO confirm.
        per_atom_rows: Matrix row of each atom's per-atom target.
    """
    if row.numel() == 0 or src_ids.numel() == 0:
        return

    _, nx, ny, nz = coeff_shape
    # Keep only enabled triplet categories, further restricted to the
    # selected category (single-triplet block) or selected categories.
    active_triplet_mask = term.active_triplet_mask.to(device=triplet_index.device)
    active_mask = active_triplet_mask.index_select(0, triplet_index)
    if selected_triplet_index is not None:
        active_mask = active_mask & (triplet_index == int(selected_triplet_index))
    else:
        active_mask = active_mask & _selected_channel_mask(
            triplet_index,
            selected_triplet_indices,
        )
    if not torch.any(active_mask):
        return

    row = row[active_mask]
    col = col[active_mask]
    triplet_index = triplet_index[active_mask]
    coeff_volume = int(nx * ny * nz)

    # Gather both neighbors of every (source, pair); the leading axis is the
    # source atom, the second axis the surviving local pair.
    vj = vec[:, row, :]
    vk = vec[:, col, :]
    x = distances[:, row]
    y = distances[:, col]
    diff = vj - vk
    # z is the neighbor-neighbor distance of each triplet.
    z = torch.linalg.norm(diff, dim=2)
    flat_x = x.reshape(-1)
    flat_y = y.reshape(-1)
    flat_z = z.reshape(-1)
    # x/y must lie in [lower_support_xy, upper_support_xy); only a lower bound
    # is applied to z in this function.
    flat_mask = (
        (flat_x >= term.lower_support_xy)
        & (flat_x < term.upper_support_xy)
        & (flat_y >= term.lower_support_xy)
        & (flat_y < term.upper_support_xy)
        & (flat_z >= term.lower_support_z)
    )
    if not torch.any(flat_mask):
        return

    if bool(torch.all(flat_mask)):
        # Fast path: every triplet is supported, so avoid boolean indexing.
        supported_x = flat_x
        supported_y = flat_y
        supported_z = flat_z
        flat_triplet_positions = torch.arange(
            flat_x.numel(),
            dtype=torch.int64,
            device=flat_x.device,
        )
    else:
        supported_x = flat_x[flat_mask]
        supported_y = flat_y[flat_mask]
        supported_z = flat_z[flat_mask]
        flat_triplet_positions = torch.nonzero(flat_mask, as_tuple=False).reshape(-1)

    stencil = all_supported_uniform_stencil_3d(
        supported_x,
        supported_y,
        supported_z,
        coeff_shape=(nx, ny, nz),
        first_knot_xy=term.first_knot_xy,
        first_knot_z=term.first_knot_z,
        knot_spacing_xy=term.knot_spacing_xy,
        knot_spacing_z=term.knot_spacing_z,
        spline=term.spline,
    )
    # Per-triplet metadata, flattened in the same source-major order as
    # flat_x and then restricted to the supported positions.
    supported_src = src_ids[:, None].expand(-1, row.numel()).reshape(-1)
    supported_src = supported_src.index_select(0, flat_triplet_positions)
    supported_system = src_system[:, None].expand(-1, row.numel()).reshape(-1)
    supported_system = supported_system.index_select(0, flat_triplet_positions)
    supported_dst_j = nbr_sorted[:, row].reshape(-1)
    supported_dst_j = supported_dst_j.index_select(0, flat_triplet_positions)
    supported_dst_k = nbr_sorted[:, col].reshape(-1)
    supported_dst_k = supported_dst_k.index_select(0, flat_triplet_positions)
    supported_triplet_index = (
        triplet_index[None, :]
        .expand(src_ids.shape[0], -1)
        .reshape(-1)
        .index_select(0, flat_triplet_positions)
    )

    supported_vj = vj.reshape(-1, 3).index_select(0, flat_triplet_positions)
    supported_vk = vk.reshape(-1, 3).index_select(0, flat_triplet_positions)
    supported_diff = diff.reshape(-1, 3).index_select(0, flat_triplet_positions)

    # Categorized blocks address the full (n_triplet, nx, ny, nz) layout, so
    # each stencil column is offset by its triplet's grid; single-triplet
    # blocks use the grid-local stencil columns directly.
    if selected_triplet_index is None:
        cols = stencil.indices + supported_triplet_index[:, None] * coeff_volume
    else:
        cols = stencil.indices

    d_e_dvj, d_e_dvk = pair_distance_partials_batched(
        stencil.grad_x,
        stencil.grad_y,
        stencil.grad_z,
        supported_vj,
        supported_vk,
        supported_diff,
        supported_x,
        supported_y,
        supported_z,
        term.eps,
    )

    # Forces are minus the energy gradient. The center atom receives the sum
    # of both neighbor partials with the opposite sign. NOTE(review): this
    # assumes vj/vk are neighbor-minus-center vectors — confirm.
    force_j = -d_e_dvj
    force_k = -d_e_dvk
    force_i = d_e_dvj + d_e_dvk
    # Per-triplet flag from term.same_neighbor_triplet_mask; flagged entries
    # are split across x/y-swapped columns by _add_selected_threebody_entries.
    same_triplet_mask = term.same_neighbor_triplet_mask.to(
        device=supported_triplet_index.device
    ).index_select(
        0,
        supported_triplet_index,
    )

    # Energy targets: one row per system, looked up by each triplet's system.
    _add_selected_threebody_entries(
        matrix,
        energy_rows.index_select(0, supported_system)[:, None],
        cols,
        stencil.values,
        selected_lookup=selected_lookup,
        same_triplet_mask=same_triplet_mask,
        coeff_volume=coeff_volume,
        nx=nx,
        ny=ny,
        nz=nz,
    )
    # Per-atom targets are attributed to the source (center) atom; skipped when
    # no per-atom row is >= 0.
    if torch.any(per_atom_rows >= 0):
        _add_selected_threebody_entries(
            matrix,
            per_atom_rows.index_select(0, supported_src)[:, None],
            cols,
            stencil.values,
            selected_lookup=selected_lookup,
            same_triplet_mask=same_triplet_mask,
            coeff_volume=coeff_volume,
            nx=nx,
            ny=ny,
            nz=nz,
        )
    # Force targets for the center and both neighbors. Rows are shaped
    # (n, 3, 1) and columns (n, 1, S); the permuted force presumably matches
    # (n, 3, S) — TODO confirm pair_distance_partials_batched's output layout.
    if torch.any(force_rows >= 0):
        for atom_index, force in (
            (supported_src, force_i),
            (supported_dst_j, force_j),
            (supported_dst_k, force_k),
        ):
            _add_selected_threebody_entries(
                matrix,
                force_rows.index_select(0, atom_index)[:, :, None],
                cols[:, None, :],
                force.permute(0, 2, 1),
                selected_lookup=selected_lookup,
                same_triplet_mask=same_triplet_mask,
                coeff_volume=coeff_volume,
                nx=nx,
                ny=ny,
                nz=nz,
            )


[docs] class SplineThreeBodyTerm(ThreeBodyTerm): """ Source-distinguished three-body spline interaction term. """ def __init__( self, *, cutoff: float, atomic_types: Sequence[int], coeffs_by_triplet=None, coefficient_provider: AlchemicalCoefficients | None = None, coefficient_index: int | None = None, active_triplets: Sequence[tuple[int, int, int]] | None = None, spline: SplineKind = "cubic", full_support_start_xy: float = 0.0, full_support_start_z: float = 2.0, eps: float = 1.0e-12, trainable: bool = True, fittable: bool = True, frozen: bool = False, dtype: torch.dtype | None = None, ) -> None: """Store categorized three-body coefficients and category layout.""" super().__init__(cutoff=cutoff, atomic_types=atomic_types) if self.atomic_types is None or not self.atomic_types: raise ValueError("`atomic_types` must contain at least one element") n_cat = len(self.atomic_types) expected_triplet_categories = n_cat * num_edge_categories(n_cat) triplet_categories = _triplet_categories(self.atomic_types) object.__setattr__(self, "_triplet_categories", triplet_categories) object.__setattr__( self, "_triplet_index", {triplet: index for index, triplet in enumerate(triplet_categories)}, ) self.coefficient_index = ( None if coefficient_index is None else int(coefficient_index) ) object.__setattr__(self, "_coefficient_provider", coefficient_provider) self.fittable = bool(fittable) self.frozen = bool(frozen) active_triplet_mask = _active_triplet_mask( triplet_categories, active_triplets=active_triplets, ) same_neighbor_triplet_mask = _same_neighbor_triplet_mask(triplet_categories) self.register_buffer( "active_triplet_mask", active_triplet_mask, persistent=False, ) self.register_buffer( "same_neighbor_triplet_mask", same_neighbor_triplet_mask, persistent=False, ) object.__setattr__( self, "_active_triplet_indices", tuple( index for index, enabled in enumerate(active_triplet_mask.tolist()) if enabled ), ) if coefficient_provider is None: if coeffs_by_triplet is None: raise 
ValueError( "`coeffs_by_triplet` is required when " "`coefficient_provider` is not set" ) coeffs_tensor = torch.as_tensor( coeffs_by_triplet, dtype=dtype, ) if coeffs_tensor.ndim != 4: raise ValueError( "`coeffs_by_triplet` must have shape " "(n_triplet_categories, Nx, Ny, Nz)" ) if coeffs_tensor.shape[0] != expected_triplet_categories: raise ValueError( "`coeffs_by_triplet.shape[0]` must equal " f"{expected_triplet_categories} for " f"atomic_types={self.atomic_types}, " f"got {coeffs_tensor.shape[0]}" ) self.coeffs_by_triplet = torch.nn.Parameter( coeffs_tensor, requires_grad=bool(trainable) and not self.frozen, ) coeff_shape = tuple(int(dim) for dim in coeffs_tensor.shape[1:]) else: provider_shape = coefficient_provider.coefficient_shape if len(provider_shape) == 3: if len(self._active_triplet_indices) != 1: raise ValueError( "three-dimensional three-body alchemical coefficients " "require exactly one active triplet" ) coeff_shape = tuple(int(dim) for dim in provider_shape) elif len(provider_shape) == 4: if provider_shape[0] != expected_triplet_categories: raise ValueError( "`coefficient_provider` must provide " f"{expected_triplet_categories} triplet categories for " f"atomic_types={self.atomic_types}, got " f"{provider_shape[0]}" ) coeff_shape = tuple(int(dim) for dim in provider_shape[1:]) else: raise ValueError( "`coefficient_provider` must provide three-dimensional " "single-triplet coefficients or four-dimensional categorized " "coefficients" ) if self.coefficient_index is None: raise ValueError( "`coefficient_index` is required when `coefficient_provider` is set" ) coefficient_provider.true_coeffs_for(self.coefficient_index) self.spline = spline self.full_support_start_xy = float(full_support_start_xy) self.full_support_start_z = float(full_support_start_z) if coeff_shape[0] != coeff_shape[1]: raise ValueError( "three-body coefficients must have matching x/y dimensions when " "using a shared center-neighbor support grid" ) object.__setattr__(self, 
"coeff_shape", coeff_shape) coeff_size_xy = coeff_shape[0] coeff_size_z = coeff_shape[2] self.first_knot_xy, self.knot_spacing_xy = uniform_support_parameters( coeff_size=coeff_size_xy, lower_full_support=self.full_support_start_xy, upper_full_support=cutoff, spline=self.spline, ) self.first_knot_z, self.knot_spacing_z = uniform_support_parameters( coeff_size=coeff_size_z, lower_full_support=self.full_support_start_z, upper_full_support=_neighbor_neighbor_cutoff(cutoff), spline=self.spline, ) self.lower_support_xy, self.upper_support_xy = _support_bounds( self.first_knot_xy, self.knot_spacing_xy, coeff_size_xy, lower_full_support=self.full_support_start_xy, ) self.lower_support_z, self.upper_support_z = _support_bounds( self.first_knot_z, self.knot_spacing_z, coeff_size_z, lower_full_support=self.full_support_start_z, ) self.eps = float(eps) self.register_buffer( "edge_cat_table", build_edge_category_table(n_cat), persistent=False, ) @property def n_categories(self) -> int: """Return the number of atomic categories tracked by this term.""" assert self.atomic_types is not None return len(self.atomic_types) @property def triplet_categories(self) -> tuple[tuple[int, int, int], ...]: """Return the ordered triplet categories addressed by the coefficient tensor.""" return self._triplet_categories @property def coefficient_provider(self) -> AlchemicalCoefficients | None: """Return the shared coefficient provider for alchemical fitting.""" return self._coefficient_provider @property def provides_forces(self) -> bool: """Report that this term produces analytic forces directly.""" return True @property def input_requirements(self) -> TermInputRequirements: """Declare the directed full-neighbor-list requirement.""" return TermInputRequirements(full_neighbor_list=True) @property def neighbor_neighbor_cutoff(self) -> float: """Return the derived cutoff used when forming neighbor-neighbor triplets.""" assert self.cutoff is not None return _neighbor_neighbor_cutoff(self.cutoff) 
@property def active_triplet_categories(self) -> tuple[tuple[int, int, int], ...]: """Return the subset of triplet categories that remain enabled.""" return tuple( self.triplet_categories[index] for index in self._active_triplet_indices ) @property def true_coeffs_by_triplet(self) -> torch.Tensor: """Return the direct or provider-projected triplet coefficient tensor.""" if self.coefficient_provider is None: coeffs = self.coeffs_by_triplet else: assert self.coefficient_index is not None coeffs = self.coefficient_provider.true_coeffs_for(self.coefficient_index) if coeffs.ndim == 3: triplet_index = self._active_triplet_indices[0] full_shape = ( len(self.triplet_categories), int(coeffs.shape[0]), int(coeffs.shape[1]), int(coeffs.shape[2]), ) full_coeffs = coeffs.new_zeros(full_shape) full_coeffs[triplet_index] = coeffs coeffs = full_coeffs return _symmetrize_same_neighbor_coeffs(coeffs, self.same_neighbor_triplet_mask) def _parameter_block_shape(self) -> tuple[int, ...]: """Return the solved coefficient-block shape for this term.""" if ( self.coefficient_provider is not None and len(self.coefficient_provider.coefficient_shape) == 3 ): assert self.coefficient_index is not None return tuple( int(dim) for dim in self.coefficient_provider.true_coeffs_for( self.coefficient_index ).shape ) return tuple(int(dim) for dim in self.true_coeffs_by_triplet.shape) def _read_parameter_block(self) -> torch.Tensor: """Return the coefficient tensor represented by the solve block.""" if ( self.coefficient_provider is not None and len(self.coefficient_provider.coefficient_shape) == 3 ): assert self.coefficient_index is not None return self.coefficient_provider.true_coeffs_for(self.coefficient_index) return self.true_coeffs_by_triplet def _write_parameter_block(self, values: torch.Tensor) -> None: """Write solved three-body coefficients back into storage.""" if self.coefficient_provider is None: copy_parameter_data(self.coeffs_by_triplet, values) return if not 
self.coefficient_provider.uses_identity_weights: raise ValueError( "can not write true coefficients directly into a non-identity " "alchemical provider" ) assert self.coefficient_index is not None target_shape = self._read_parameter_block().shape self.coefficient_provider.proxy_coeffs.data[self.coefficient_index].copy_( values.reshape(target_shape).to(self.coefficient_provider.proxy_coeffs) ) def _parameter_block_cache_descriptor(self) -> ParameterBlockCacheDescriptor | None: """Return reusable semantic cache metadata for this coefficient block.""" shape = self._parameter_block_shape() if len(shape) == 3: if len(self._active_triplet_indices) != 1: return None nx, ny, nz = shape coeff_shape = (int(nx), int(ny), int(nz)) triplet_indices = tuple( int(index) for index in self._active_triplet_indices ) starts = {triplet_indices[0]: 0} elif len(shape) == 4: _, nx, ny, nz = shape coeff_shape = (int(nx), int(ny), int(nz)) triplet_indices = tuple( int(index) for index in self._active_triplet_indices ) volume = int(nx) * int(ny) * int(nz) starts = { int(triplet_index): int(triplet_index) * volume for triplet_index in triplet_indices } else: return None volume = int(coeff_shape[0] * coeff_shape[1] * coeff_shape[2]) return ParameterBlockCacheDescriptor( family={ "kind": "threebody_spline", "atomic_types": [int(value) for value in self.atomic_types or ()], "spline": str(self.spline), "first_knot_xy": float(self.first_knot_xy), "first_knot_z": float(self.first_knot_z), "knot_spacing_xy": float(self.knot_spacing_xy), "knot_spacing_z": float(self.knot_spacing_z), "lower_support_xy": float(self.lower_support_xy), "lower_support_z": float(self.lower_support_z), "coeff_shape": [int(value) for value in coeff_shape], "eps": float(self.eps), }, channels=tuple( ParameterBlockCacheChannel( kind="triplet", values=self.triplet_categories[triplet_index], start=int(starts[triplet_index]), stop=int(starts[triplet_index]) + volume, ) for triplet_index in triplet_indices ), )
[docs] def parameter_blocks(self) -> tuple[ParameterBlock, ...]: """Return the three-body spline coefficient block.""" return ( ParameterBlock( name="coeffs_by_triplet", kind="threebody", shape=self._parameter_block_shape(), read=self._read_parameter_block, write=self._write_parameter_block, label=f"threebody[{self.atomic_types}]", coefficient_provider=self.coefficient_provider, coefficient_index=self.coefficient_index, regularization_group="threebody", fittable=self.fittable and bool(self._active_triplet_indices), frozen=self.frozen, assembler="threebody", cache_descriptor=self._parameter_block_cache_descriptor(), ), )
[docs] def assemble_linear_blocks( self, batch, targets, options: LinearAssemblyOptions | None = None, ): """Assemble three-body least-squares blocks for this term.""" from ufp.leastsquares._assemble import _assemble_threebody_block blocks = () if options is None else options.blocks threebody_lstsq_backend = ( None if options is None else options.threebody_lstsq_backend ) threebody_bucket_backend = ( None if options is None else options.threebody_bucket_backend ) runtime_config = None if options is None else options.threebody_runtime_config matrices = {} for block in blocks: matrix = _assemble_threebody_block( block, batch.inputs, targets, threebody_lstsq_backend=threebody_lstsq_backend, threebody_bucket_backend=threebody_bucket_backend, threebody_runtime_config=runtime_config, ) if matrix is not None: matrices[block.index] = matrix return matrices
[docs] def assemble_selected_linear_block( self, block, inputs: UFPInput, targets, selected_indices: Sequence[int], ) -> torch.Tensor | None: """Assemble only requested coefficient columns for this three-body block.""" selected_indices = tuple(int(index) for index in selected_indices) if inputs.neighbor_list is None: raise RuntimeError("SplineThreeBodyTerm requires a neighbor list") if not inputs.neighbor_list.full_list: raise RuntimeError("SplineThreeBodyTerm requires a full neighbor list") assert self.atomic_types is not None node_cat = inputs.atomic_category_indices(self.atomic_types) supported_atoms = node_cat >= 0 if not torch.any(supported_atoms): return None first_atom, second_atom = inputs.pair_indices() pair_mask = supported_atoms[first_atom] & supported_atoms[second_atom] if not torch.any(pair_mask): return None pair_distances = inputs.pair_distances(pair_mask) center_support_mask = (pair_distances >= self.lower_support_xy) & ( pair_distances < self.upper_support_xy ) if not torch.any(center_support_mask): return None filtered_first, filtered_second = inputs.pair_indices(pair_mask) pair_vectors = inputs.pair_vectors(pair_mask) filtered_first = filtered_first[center_support_mask] filtered_second = filtered_second[center_support_mask] pair_vectors = pair_vectors[center_support_mask] pair_distances = pair_distances[center_support_mask] buckets = preprocess_sources( filtered_first, filtered_second, node_cat, self.n_categories, pair_vectors, pair_distances, ) if not buckets: return None if len(block.shape) == 3: if len(self._active_triplet_indices) != 1: raise ValueError( "single-triplet three-body alchemical blocks require exactly one " "active triplet" ) nx, ny, nz = block.shape n_triplet_categories = len(self.triplet_categories) selected_triplet_index = int(self._active_triplet_indices[0]) selected_triplet_indices = None else: n_triplet_categories, nx, ny, nz = block.shape coeff_volume = int(nx * ny * nz) selected_triplet_index = None 
selected_triplet_indices = torch.unique( torch.div( torch.as_tensor( [int(index) for index in selected_indices], dtype=torch.int64, device=inputs.device, ), coeff_volume, rounding_mode="floor", ) ) selected_lookup = selected_column_lookup( selected_indices, block_size=block.size, device=inputs.device, ) matrix = _selected_block_matrix( targets, selected_indices, device=inputs.device, dtype=inputs.dtype, ) triplets_per_src_cat = self.n_categories * (self.n_categories + 1) // 2 system_index = inputs.system_index for pattern_index in range(int(buckets.patterns.shape[0])): src_start = int(buckets.pattern_ptr[pattern_index].item()) src_end = int(buckets.pattern_ptr[pattern_index + 1].item()) pattern = buckets.patterns[pattern_index] src_cat = int(pattern[0].item()) counts = pattern[1:].detach().cpu().tolist() layout = pattern_triplet_layout(counts, inputs.device) if layout.row.numel() == 0: continue degree = int(sum(counts)) edge_start = int(buckets.row_ptr[src_start].item()) edge_end = int(buckets.row_ptr[src_end].item()) src_ids = buckets.src_ids[src_start:src_end] nbr_ids = buckets.nbr_ids[edge_start:edge_end].view( src_end - src_start, degree, ) vectors = buckets.pair_vectors[edge_start:edge_end].view( src_end - src_start, degree, 3, ) distances = buckets.pair_distances[edge_start:edge_end].view( src_end - src_start, degree, ) triplet_index = src_cat * triplets_per_src_cat + layout.edge_cat _accumulate_selected_threebody_pairs( matrix=matrix, term=self, src_ids=src_ids, src_system=system_index.index_select(0, src_ids), triplet_index=triplet_index, row=layout.row, col=layout.col, vec=vectors, distances=distances, nbr_sorted=nbr_ids, coeff_shape=( int(n_triplet_categories), int(nx), int(ny), int(nz), ), selected_triplet_index=selected_triplet_index, selected_triplet_indices=selected_triplet_indices, selected_lookup=selected_lookup, energy_rows=targets.energy_rows, force_rows=targets.force_rows, per_atom_rows=targets.per_atom_rows, ) return None if 
torch.count_nonzero(matrix) == 0 else matrix
def _cache_key(self) -> str:
    """
    Return this term's key inside the per-input three-body caches.

    The same key indexes both the triplet-bucket cache and the dense
    feature cache stored in ``inputs.metadata``. It is the first 16 hex
    digits of a SHA-256 digest over a canonical JSON encoding (sorted keys,
    compact separators) of the settings in ``payload``. Terms whose settings
    encode identically therefore produce the same key.

    NOTE(review): the upper support bounds are not hashed directly;
    presumably they are fixed by the knot / first-knot / coeff-shape
    settings that are hashed. Confirm before relying on that.

    Returns:
        A 16-character hexadecimal string.
    """
    payload = {
        "atomic_types": list(self.atomic_types or ()),
        "coeff_shape": list(self.coeff_shape),
        "active_triplet_indices": list(self._active_triplet_indices),
        "spline": self.spline,
        "first_knot_xy": self.first_knot_xy,
        "first_knot_z": self.first_knot_z,
        "knot_spacing_xy": self.knot_spacing_xy,
        "knot_spacing_z": self.knot_spacing_z,
        "lower_support_xy": self.lower_support_xy,
        "lower_support_z": self.lower_support_z,
        "eps": self.eps,
    }
    # Canonical encoding so that dict ordering never changes the digest.
    encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode(
        "utf8"
    )
    return hashlib.sha256(encoded).hexdigest()[:16]


def _empty_triplet_buckets(
    self,
    inputs: UFPInput,
    node_cat: torch.Tensor,
    config: ThreeBodyRuntimeConfig,
) -> Buckets:
    """
    Build the bucket container for an input with no eligible pairs.

    Runs ``preprocess_sources_native_or_torch`` on zero pairs. The empty
    result then has the same type as a populated one.

    Args:
        inputs: Input bundle; supplies the device and the position dtype.
        node_cat: Per-atom category indices, forwarded unchanged.
        config: Already-resolved three-body runtime configuration.

    Returns:
        The (empty) bucket container.
    """
    return preprocess_sources_native_or_torch(
        torch.zeros((0,), dtype=torch.int64, device=inputs.device),
        torch.zeros((0,), dtype=torch.int64, device=inputs.device),
        node_cat,
        self.n_categories,
        inputs.positions.new_zeros((0, 3)),
        inputs.positions.new_zeros((0,)),
        runtime_config=config,
    )


def _bucket_triplets(
    self,
    inputs: UFPInput,
    node_cat: torch.Tensor,
    *,
    attach_pattern_plans: bool,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> Buckets:
    """
    Build reusable triplet buckets for this term and input geometry.

    A pair is bucketed only when both of its atoms have a supported category
    (``node_cat >= 0``) and its distance lies in
    ``[lower_support_xy, upper_support_xy)``.

    Args:
        inputs: Normalized input bundle providing pairs, vectors and distances.
        node_cat: Per-atom category indices for this term's atomic types.
            Negative entries mark atoms this term does not support.
        attach_pattern_plans: When true, pattern plans are attached via
            ``with_pattern_plans`` unless the buckets already carry them.
            The early-exit empty buckets are returned without plans.
        runtime_config: Optional runtime configuration, resolved with
            ``resolve_threebody_runtime_config``.

    Returns:
        The bucket container produced by ``preprocess_sources_native_or_torch``.
    """
    config = resolve_threebody_runtime_config(runtime_config)

    supported_atoms = node_cat >= 0
    if not torch.any(supported_atoms):
        return self._empty_triplet_buckets(inputs, node_cat, config)

    first_atom, second_atom = inputs.pair_indices()
    pair_mask = supported_atoms[first_atom] & supported_atoms[second_atom]
    if not torch.any(pair_mask):
        return self._empty_triplet_buckets(inputs, node_cat, config)

    pair_distances = inputs.pair_distances(pair_mask)
    # Keep only pairs inside the half-open xy support window of the spline.
    center_support_mask = (pair_distances >= self.lower_support_xy) & (
        pair_distances < self.upper_support_xy
    )
    if not torch.any(center_support_mask):
        return self._empty_triplet_buckets(inputs, node_cat, config)

    filtered_first, filtered_second = inputs.pair_indices(pair_mask)
    pair_vectors = inputs.pair_vectors(pair_mask)
    filtered_first = filtered_first[center_support_mask]
    filtered_second = filtered_second[center_support_mask]
    pair_vectors = pair_vectors[center_support_mask]
    pair_distances = pair_distances[center_support_mask]

    buckets = preprocess_sources_native_or_torch(
        filtered_first,
        filtered_second,
        node_cat,
        self.n_categories,
        pair_vectors,
        pair_distances,
        runtime_config=config,
    )
    if attach_pattern_plans and buckets.tensor_pattern_plans is None:
        return buckets.with_pattern_plans(inputs.device)
    return buckets
def cache_input(
    self,
    inputs: UFPInput,
    options: TermCacheOptions | None = None,
    *,
    feature_cache_storage: Literal["none", "cpu", "disk"] = "cpu",
    feature_cache_mode: FeatureCacheMode = "auto",
    feature_cache_dir: Path | str | None = None,
    cache_prefix: str = "threebody",
    legacy_cache_prefixes: Sequence[str] = (),
    include_per_atom_energy: bool = True,
) -> None:
    """
    Precompute static dense three-body feature blocks for a cached input.

    Results are stored in ``inputs.metadata`` under this term's cache key.
    Triplet buckets go under ``_THREEBODY_BUCKET_CACHE_KEY`` and dense
    feature blocks under ``_THREEBODY_FEATURE_CACHE_KEY``.

    With disk storage and a cache directory, a compatible on-disk cache is
    reused when one exists. This reuse is skipped when ``feature_cache_mode``
    is ``"refresh"``.

    Args:
        inputs: Normalized input bundle. Nothing is cached unless it carries
            a full neighbor list and this term has active triplets.
        options: When given, its ``feature_cache_storage``,
            ``feature_cache_mode``, ``feature_cache_dir``, ``cache_prefix``
            and ``include_per_atom_energy`` override the keyword arguments.
        feature_cache_storage: ``"none"`` caches only the buckets. Other
            values also build dense feature blocks with that storage kind.
        feature_cache_mode: ``"auto"``, ``"read"`` or ``"refresh"``.
            ``"read"`` requires disk storage and an existing compatible cache.
            ``"refresh"`` skips reuse and passes ``overwrite=True`` to the
            builder.
        feature_cache_dir: Directory searched for, and used by, the disk cache.
        cache_prefix: Prefix for disk-cache names; the term key is appended.
        legacy_cache_prefixes: Accepted for compatibility and ignored.
        include_per_atom_energy: Whether dense blocks keep per-atom energy rows.

    Raises:
        ValueError: If ``feature_cache_mode`` is unknown, or if ``"read"``
            mode is requested without disk storage.
        FileNotFoundError: If ``"read"`` mode finds no compatible disk cache.
    """
    if options is not None:
        feature_cache_storage = options.feature_cache_storage
        feature_cache_mode = options.feature_cache_mode  # type: ignore[assignment]
        feature_cache_dir = options.feature_cache_dir
        cache_prefix = options.cache_prefix
        include_per_atom_energy = options.include_per_atom_energy
    del legacy_cache_prefixes  # kept in the signature for backward compatibility

    if feature_cache_mode not in {"auto", "read", "refresh"}:
        raise ValueError(
            "`feature_cache_mode` must be 'auto', 'read', or 'refresh'"
        )
    wants_disk = feature_cache_storage == "disk"
    if feature_cache_mode == "read" and not wants_disk:
        raise ValueError("`feature_cache_mode='read'` requires disk feature cache")

    neighbors = inputs.neighbor_list
    if neighbors is None or not neighbors.full_list:
        return
    if not self._active_triplet_indices:
        return
    assert self.atomic_types is not None

    runtime_config = resolve_threebody_runtime_config()
    # Copy the existing per-input caches so other terms' entries are kept.
    bucket_cache = dict(inputs.metadata.get(_THREEBODY_BUCKET_CACHE_KEY, {}))
    feature_blocks = dict(inputs.metadata.get(_THREEBODY_FEATURE_CACHE_KEY, {}))
    coeff_shape = tuple(int(value) for value in self.coeff_shape)
    cache_dir = Path(feature_cache_dir) if feature_cache_dir is not None else None
    term_key = self._cache_key()
    disk_prefix = f"{cache_prefix}_term{term_key}"
    cache_metadata = _dense_feature_cache_metadata(
        inputs,
        cache_key=term_key,
        atomic_types=self.atomic_types,
        triplet_categories=self.triplet_categories,
        coeff_shape=coeff_shape,
        active_triplet_indices=self._active_triplet_indices,
        include_per_atom_energy=include_per_atom_energy,
        spline=self.spline,
        first_knot_xy=self.first_knot_xy,
        first_knot_z=self.first_knot_z,
        knot_spacing_xy=self.knot_spacing_xy,
        knot_spacing_z=self.knot_spacing_z,
        lower_support_xy=self.lower_support_xy,
        lower_support_z=self.lower_support_z,
        eps=self.eps,
    )

    if wants_disk and cache_dir is not None and feature_cache_mode != "refresh":
        reused = self._find_disk_dense_feature_cache(
            cache_dir, disk_prefix, cache_metadata
        )
        if reused is not None:
            feature_blocks[term_key] = reused
            inputs.metadata[_THREEBODY_FEATURE_CACHE_KEY] = feature_blocks
            return

    if feature_cache_mode == "read" and wants_disk:
        raise FileNotFoundError(
            "three-body feature cache requested in read mode, but no compatible "
            f"V2 cache was found for prefix {disk_prefix!r}"
        )

    node_cat = inputs.atomic_category_indices(self.atomic_types)
    buckets = self._bucket_triplets(
        inputs,
        node_cat,
        attach_pattern_plans=True,
        runtime_config=runtime_config,
    )
    bucket_cache[term_key] = buckets
    inputs.metadata[_THREEBODY_BUCKET_CACHE_KEY] = bucket_cache
    if not buckets or feature_cache_storage == "none":
        return

    all_active = len(self._active_triplet_indices) == len(self.triplet_categories)
    feature_blocks[term_key] = _build_dense_feature_cache_from_buckets(
        buckets,
        inputs.system_index,
        coeff_shape,
        spline=self.spline,
        active_triplet_mask=None if all_active else self.active_triplet_mask,
        n_cat=self.n_categories,
        first_knot_xy=self.first_knot_xy,
        first_knot_z=self.first_knot_z,
        knot_spacing_xy=self.knot_spacing_xy,
        knot_spacing_z=self.knot_spacing_z,
        lower_support_xy=self.lower_support_xy,
        lower_support_z=self.lower_support_z,
        eps=self.eps,
        storage=feature_cache_storage,
        cache_dir=cache_dir,
        cache_prefix=disk_prefix,
        metadata=cache_metadata,
        overwrite=feature_cache_mode == "refresh",
        include_per_atom_energy=include_per_atom_energy,
        runtime_config=runtime_config,
    )
    inputs.metadata[_THREEBODY_FEATURE_CACHE_KEY] = feature_blocks


def _find_disk_dense_feature_cache(self, cache_dir, disk_prefix, metadata):
    """
    Look for a reusable on-disk dense feature cache for this term.

    Three candidates are tried in order, stopping at the first non-``None``
    result:

    1. The settings-specific directory from ``_dense_feature_cache_dir``,
       loaded under ``disk_prefix``.
    2. ``cache_dir`` itself, loaded under ``disk_prefix``.
    3. Any compatible cache located by
       ``_find_compatible_memmap_dense_feature_cache``.

    An ``OSError``, ``ValueError`` or ``json.JSONDecodeError`` raised by a
    loader counts as a miss for that candidate.

    Args:
        cache_dir: Root directory of the disk feature cache.
        disk_prefix: Term-specific cache name prefix.
        metadata: Expected cache metadata, forwarded to the loaders.

    Returns:
        The loaded cache, or ``None`` when no candidate yields one.
    """
    required = self._active_triplet_indices
    # Resolved outside the try blocks: errors here propagate to the caller.
    settings_dir = _dense_feature_cache_dir(cache_dir, disk_prefix, metadata)
    attempts = (
        lambda: _load_memmap_dense_feature_cache(
            settings_dir,
            disk_prefix,
            expected_metadata=metadata,
            required_triplet_indices=required,
        ),
        lambda: _load_memmap_dense_feature_cache(
            cache_dir,
            disk_prefix,
            expected_metadata=metadata,
            required_triplet_indices=required,
        ),
        lambda: _find_compatible_memmap_dense_feature_cache(
            cache_dir,
            expected_metadata=metadata,
            required_triplet_indices=required,
        ),
    )
    for attempt in attempts:
        try:
            found = attempt()
        except (OSError, ValueError, json.JSONDecodeError):
            continue
        if found is not None:
            return found
    return None
def dense_atom_features(
    self,
    inputs: UFPInput,
    atom_indices: Sequence[int] | torch.Tensor | None = None,
    *,
    force_scope: Literal["output", "source"] = "output",
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> ThreeBodyDenseAtomFeatures:
    """
    Return dense coefficient-space output rows for selected atoms.

    Meant for debugging fixed-geometry feature construction. Each selected
    atom gets the dense counterpart of the sparse cached operators this term
    uses: one per-atom energy row and one force row per Cartesian component.

    With ``force_scope="output"``, the force rows are the full model-output
    rows of the selected atoms. With ``force_scope="source"``, they keep only
    interactions centered on the selected atoms.

    For ``"output"``, a dense feature cache already stored in
    ``inputs.metadata`` for this term is reused, provided every block
    carries per-atom energy rows. In every other case the features are
    rebuilt from freshly computed triplet buckets.

    Args:
        inputs: Normalized input bundle with a full neighbor list.
        atom_indices: Optional atom indices to extract. When omitted, all
            atoms are returned in input order.
        force_scope: Whether force rows include all output contributions or
            only source-centered contributions.
        runtime_config: Optional three-body runtime configuration, resolved
            with ``resolve_threebody_runtime_config``.

    Returns:
        Dense per-atom energy and force-component feature rows. They are
        passed through ``_symmetrize_dense_atom_features`` with the term's
        same-neighbor triplet mask.

    Raises:
        RuntimeError: If the input lacks a neighbor list, or the neighbor
            list is not a full list.
        ValueError: If ``force_scope`` is not ``"output"`` or ``"source"``.
    """
    if force_scope not in {"output", "source"}:
        raise ValueError("`force_scope` must be 'output' or 'source'")
    neighbors = inputs.neighbor_list
    if neighbors is None:
        raise RuntimeError(
            "SplineThreeBodyTerm.dense_atom_features requires a neighbor list"
        )
    if not neighbors.full_list:
        raise RuntimeError(
            "SplineThreeBodyTerm.dense_atom_features requires a full neighbor list"
        )

    selected_atoms = _selected_atom_indices(
        inputs.n_atoms,
        atom_indices,
        device=inputs.device,
    )
    runtime_config = resolve_threebody_runtime_config(runtime_config)
    coeff_shape = tuple(int(value) for value in self.coeff_shape)

    dense_cache: DenseThreeBodyFeatureCache | MemmapDenseThreeBodyFeatureCache
    dense_cache = DenseThreeBodyFeatureCache(blocks=())

    if self._active_triplet_indices:
        stored_dense = None
        if force_scope == "output":
            feature_store = inputs.metadata.get(_THREEBODY_FEATURE_CACHE_KEY)
            if isinstance(feature_store, dict):
                stored_dense = feature_store.get(self._cache_key())

        # A stored cache is only usable when it kept per-atom energy rows.
        reusable = isinstance(
            stored_dense,
            (DenseThreeBodyFeatureCache, MemmapDenseThreeBodyFeatureCache),
        ) and all(block.per_atom_energy is not None for block in stored_dense.blocks)

        if reusable:
            dense_cache = stored_dense
        else:
            assert self.atomic_types is not None
            node_cat = inputs.atomic_category_indices(self.atomic_types)
            buckets = self._bucket_triplets(
                inputs,
                node_cat,
                attach_pattern_plans=True,
                runtime_config=runtime_config,
            )
            if buckets:
                all_active = len(self._active_triplet_indices) == len(
                    self.triplet_categories
                )
                sparse_features = _build_feature_cache_from_buckets(
                    buckets,
                    coeff_shape,
                    spline=self.spline,
                    active_triplet_mask=(
                        None if all_active else self.active_triplet_mask
                    ),
                    n_cat=self.n_categories,
                    first_knot_xy=self.first_knot_xy,
                    first_knot_z=self.first_knot_z,
                    knot_spacing_xy=self.knot_spacing_xy,
                    knot_spacing_z=self.knot_spacing_z,
                    lower_support_xy=self.lower_support_xy,
                    lower_support_z=self.lower_support_z,
                    eps=self.eps,
                    runtime_config=runtime_config,
                )
                dense_cache = _build_dense_feature_cache_from_feature_cache(
                    sparse_features,
                    inputs.system_index,
                    coeff_shape=coeff_shape,
                    force_scope=force_scope,
                    runtime_config=runtime_config,
                )

    features = _dense_atom_features_from_feature_cache(
        dense_cache,
        selected_atoms,
        n_triplet_categories=len(self.triplet_categories),
        coeff_shape=coeff_shape,
        dtype=inputs.dtype,
    )
    return _symmetrize_dense_atom_features(
        features,
        self.same_neighbor_triplet_mask,
        coeff_shape=coeff_shape,
    )
def canonical_triplet(
    self,
    source: int,
    first_neighbor: int,
    second_neighbor: int,
) -> tuple[int, int, int]:
    """
    Map a ``(source, first_neighbor, second_neighbor)`` key to canonical form.

    Delegates to the module-level ``_canonical_triplet`` helper. Every
    triplet key is thus normalized with the term's neighbor-ordering
    convention.

    Args:
        source: Atomic type of the central atom.
        first_neighbor: Atomic type of one neighbor.
        second_neighbor: Atomic type of the other neighbor.

    Returns:
        The normalized three-element triplet key.
    """
    normalized = _canonical_triplet(source, first_neighbor, second_neighbor)
    return normalized
def triplet_category_index(
    self,
    source: int,
    first_neighbor: int,
    second_neighbor: int,
) -> int:
    """
    Look up the coefficient-block index of one triplet.

    The triplet is first normalized with ``canonical_triplet``. The
    canonical key is then resolved through ``self._triplet_index``.

    Returns:
        The coefficient-block index for the canonical triplet.

    Raises:
        KeyError: If the canonical triplet is not part of this term. The
            original lookup error is chained as the cause.
    """
    key = self.canonical_triplet(source, first_neighbor, second_neighbor)
    lookup = self._triplet_index
    try:
        category = lookup[key]
    except KeyError as missing:
        raise KeyError(f"triplet {key} is not part of this term") from missing
    return category
def is_triplet_active(
    self,
    source: int,
    first_neighbor: int,
    second_neighbor: int,
) -> bool:
    """
    Tell whether a triplet's coefficient category is still enabled.

    The triplet is mapped to its category index via
    ``triplet_category_index``, which raises ``KeyError`` for triplets
    outside this term. The matching entry of ``active_triplet_mask`` is
    then read.

    Returns:
        ``True`` when the mask entry is set, otherwise ``False``.
    """
    category = self.triplet_category_index(source, first_neighbor, second_neighbor)
    flag = self.active_triplet_mask[category]
    return bool(flag.item())
def forward(self, inputs: UFPInput) -> UFPOutput:
    """
    Evaluate this term's three-body energy, forces and per-atom energy.

    Two evaluation paths exist, tried in this order:

    1. Dense-cache path. It applies when ``inputs.positions`` does not
       require grad and ``inputs.metadata`` holds a dense feature cache for
       this term's key. That cache is contracted with the coefficients
       directly.
    2. Bucketed path. It reuses cached triplet buckets (again only when
       positions do not require grad) or builds fresh ones, then runs the
       bucketed spline evaluator. Per-atom energies are summed into
       per-system totals with ``inputs.system_index``.

    An empty atomwise output (with forces) is returned when the term has no
    active triplets or the input yields no buckets.

    Raises:
        RuntimeError: If ``inputs`` has no neighbor list, or its neighbor
            list is not a full list.
    """
    neighbors = inputs.neighbor_list
    if neighbors is None:
        raise RuntimeError(
            "SplineThreeBodyTerm requires a neighbor list, but `inputs` does not "
            "contain one"
        )
    if not neighbors.full_list:
        raise RuntimeError("SplineThreeBodyTerm requires a full neighbor list")
    if not self._active_triplet_indices:
        return empty_atomwise_output(inputs, forces=True)
    assert self.atomic_types is not None

    runtime_config = resolve_threebody_runtime_config()
    node_cat = inputs.atomic_category_indices(self.atomic_types)

    cached_features, cached_buckets = None, None
    # Cached data comes from precomputed (static) geometry, so it is only
    # consulted when no gradient with respect to positions is requested.
    if not inputs.positions.requires_grad:
        term_key = self._cache_key()
        feature_store = inputs.metadata.get(_THREEBODY_FEATURE_CACHE_KEY)
        if isinstance(feature_store, dict):
            cached_features = feature_store.get(term_key)
        bucket_store = inputs.metadata.get(_THREEBODY_BUCKET_CACHE_KEY)
        if isinstance(bucket_store, dict):
            cached_buckets = bucket_store.get(term_key)

    coeffs = self.true_coeffs_by_triplet.to(
        device=inputs.device,
        dtype=inputs.dtype,
    )

    if isinstance(
        cached_features,
        (DenseThreeBodyFeatureCache, MemmapDenseThreeBodyFeatureCache),
    ):
        energy, per_atom_energy, forces = _evaluate_dense_feature_cache_energy_forces(
            cached_features,
            coeffs,
            n_nodes=inputs.n_atoms,
            n_systems=inputs.n_systems,
        )
        return UFPOutput(
            energy=energy,
            forces=forces,
            per_atom_energy=per_atom_energy,
        )

    buckets = (
        cached_buckets
        if isinstance(cached_buckets, Buckets)
        else self._bucket_triplets(
            inputs,
            node_cat,
            attach_pattern_plans=False,
            runtime_config=runtime_config,
        )
    )
    if not buckets:
        return empty_atomwise_output(inputs, forces=True)

    edge_cat_table = self.edge_cat_table.to(device=inputs.device)
    all_active = len(self._active_triplet_indices) == len(self.triplet_categories)
    active_triplet_mask = (
        None if all_active else self.active_triplet_mask.to(device=inputs.device)
    )
    per_atom_energy, forces = evaluate_bucketed_energy_forces(
        buckets,
        node_cat,
        coeffs,
        edge_cat_table,
        spline=self.spline,
        active_triplet_mask=active_triplet_mask,
        n_nodes=inputs.n_atoms,
        n_cat=self.n_categories,
        first_knot_xy=self.first_knot_xy,
        first_knot_z=self.first_knot_z,
        knot_spacing_xy=self.knot_spacing_xy,
        knot_spacing_z=self.knot_spacing_z,
        lower_support_xy=self.lower_support_xy,
        lower_support_z=self.lower_support_z,
        eps=self.eps,
        runtime_config=runtime_config,
    )

    # Sum per-atom energies into one total per system.
    energy = torch.zeros(inputs.n_systems, device=inputs.device, dtype=inputs.dtype)
    energy.index_add_(0, inputs.system_index, per_atom_energy)
    return UFPOutput(
        energy=energy,
        forces=forces,
        per_atom_energy=per_atom_energy,
    )
# Public names exported by this module (kept in sorted order).
# ``SplineThreeBodyTerm`` is the stable user-facing term. The bucket
# containers, feature caches and evaluator helpers listed here are expert
# diagnostics for benchmarks, tests and performance investigations.
__all__ = [
    "BucketedEnergyForceEvaluator",
    "Buckets",
    "DenseThreeBodyFeatureCache",
    "DenseTripletFeatureBlock",
    "MemmapDenseThreeBodyFeatureCache",
    "MemmapDenseTripletFeatureBlock",
    "SplineKind",
    "SplineThreeBodyTerm",
    "ThreeBodyDenseAtomFeatures",
    "ThreeBodyTerm",
    "build_edge_category_table",
    "evaluate_bucketed_energy_forces",
    "get_eval_3d_with_grads",
    "load_memmap_threebody_feature_cache",
    "make_bucketed_energy_forces_evaluator",
    "num_edge_categories",
    "preprocess_sources",
]