Source code for ufp.terms._threebody_eval

"""Bucketed three-body spline evaluation helpers."""

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Literal

import torch

from ufp.splines._cubic import cubic_eval_3d_with_grads
from ufp.splines._quadratic import quadratic_eval_3d_with_grads
from ufp.splines._quartic import quartic_eval_3d_with_grads
from ufp.splines.representation import (
    all_supported_uniform_stencil_3d,
    spline_support_mask_3d,
    supported_uniform_stencil_3d,
)
from ufp.terms._threebody_kernels import evaluate_bucketed_energy_forces_native_or_torch
from ufp.terms._threebody_ops import (
    Buckets,
    build_bucket_pattern_plans,
    pair_distance_partials,
)
from ufp.terms._threebody_runtime import ThreeBodyRuntimeConfig
from ufp.terms.categories import (
    active_triplet_mask as _active_triplet_mask,  # noqa: F401
)
from ufp.terms.categories import (
    canonical_triplet as _canonical_triplet,  # noqa: F401
)
from ufp.terms.categories import (
    triplet_categories as _triplet_categories,  # noqa: F401
)


# Call signature shared by the per-family 3D spline kernels registered below:
# four float arguments followed by four tensors, returning four tensors.
Eval3DWithGrads = Callable[
    [float, float, float, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]

# A bound bucket evaluator: (buckets, node_cat, coeffs_by_triplet,
# edge_cat_table) -> (node_energy, node_force).
BucketedEnergyForceEvaluator = Callable[
    [Buckets, torch.Tensor, torch.Tensor, torch.Tensor],
    tuple[torch.Tensor, torch.Tensor],
]

# Names of the supported spline basis families.
SplineKind = Literal["quadratic", "cubic", "quartic"]

# Registry mapping a spline family name to its 3D evaluate-with-gradients kernel.
_SPLINE_EVAL_3D_WITH_GRADS: dict[str, Eval3DWithGrads] = dict(
    quadratic=quadratic_eval_3d_with_grads,
    cubic=cubic_eval_3d_with_grads,
    quartic=quartic_eval_3d_with_grads,
)


def get_eval_3d_with_grads(spline: SplineKind | str) -> Eval3DWithGrads:
    """Look up the 3D evaluate-with-gradients kernel for a spline family.

    Args:
        spline: Registered family name (``"quadratic"``, ``"cubic"`` or
            ``"quartic"``).

    Returns:
        The kernel registered under ``spline``.

    Raises:
        ValueError: If ``spline`` is not a registered family name. The
            original ``KeyError`` is chained as ``__cause__``.
    """
    registry = _SPLINE_EVAL_3D_WITH_GRADS
    try:
        kernel = registry[spline]
    except KeyError as missing:
        known = ", ".join(sorted(registry))
        message = f"Unsupported spline '{spline}'. Expected one of: {known}."
        raise ValueError(message) from missing
    return kernel
def _same_neighbor_triplet_mask( triplet_categories: Sequence[tuple[int, int, int]], ) -> torch.Tensor: """Return triplet channels whose two neighbor categories are identical.""" return torch.tensor( [ int(first_neighbor) == int(second_neighbor) for _, first_neighbor, second_neighbor in triplet_categories ], dtype=torch.bool, ) def _symmetrize_same_neighbor_coeffs( coeffs_by_triplet: torch.Tensor, same_neighbor_triplet_mask: torch.Tensor, ) -> torch.Tensor: """Tie x/y coefficient axes for same-neighbor-category triplet channels.""" same_mask = same_neighbor_triplet_mask.to(device=coeffs_by_triplet.device) if same_mask.numel() == 0 or not bool(torch.any(same_mask)): return coeffs_by_triplet output = coeffs_by_triplet.clone() same_coeffs = output[same_mask] output[same_mask] = 0.5 * (same_coeffs + same_coeffs.transpose(1, 2)) return output def _neighbor_neighbor_cutoff(center_neighbor_cutoff: float) -> float: """Derive the outer cutoff implied by the center-neighbor cutoff.""" return 2.0 * float(center_neighbor_cutoff) def _support_bounds( first_knot: float, knot_spacing: float, coeff_size: int, *, lower_full_support: float, ) -> tuple[float, float]: """Return the physical support interval used before stencil construction.""" return float(lower_full_support), float(first_knot + int(coeff_size) * knot_spacing) def _evaluate_pair_block( vec: torch.Tensor, r: torch.Tensor, nbr_ids: torch.Tensor, coeffs_by_triplet: torch.Tensor, edge_cat_table: torch.Tensor, eval_3d_with_grads: Eval3DWithGrads, triplets_per_src_cat: int, src_cat: int, cat_a: int, cat_b: int, offset_a: int, offset_b: int, row: torch.Tensor, col: torch.Tensor, first_knot_xy: float, first_knot_z: float, knot_spacing_xy: float, knot_spacing_z: float, eps: float, spline: SplineKind | str, active_triplet_mask: torch.Tensor | None, ) -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, ]: """Evaluate one bucketed triplet block and return its energy and force pieces.""" 
vj = vec[:, offset_a + row, :] vk = vec[:, offset_b + col, :] x = r[:, offset_a + row] y = r[:, offset_b + col] diff = vj - vk z = torch.linalg.norm(diff, dim=2) batch_size, n_pairs = x.shape triplet_idx = src_cat * triplets_per_src_cat + edge_cat_table[cat_a, cat_b] if active_triplet_mask is not None and not bool(active_triplet_mask[triplet_idx]): empty_dst = nbr_ids[:, offset_a + row][:, :0] empty_force = vec.new_zeros((batch_size, 0, 3)) return ( coeffs_by_triplet.new_zeros((batch_size,)), coeffs_by_triplet.new_zeros((batch_size, 3)), empty_dst, empty_force, empty_dst, empty_force, ) coeffs = coeffs_by_triplet[triplet_idx] supported = supported_uniform_stencil_3d( x.reshape(batch_size * n_pairs), y.reshape(batch_size * n_pairs), z.reshape(batch_size * n_pairs), coeff_shape=tuple(int(value) for value in coeffs.shape), first_knot_xy=first_knot_xy, first_knot_z=first_knot_z, knot_spacing_xy=knot_spacing_xy, knot_spacing_z=knot_spacing_z, spline=spline, ) flat_mask = supported.mask.reshape(-1) flat_x = x.reshape(-1) stencil = supported.stencil coeff_window = coeffs.reshape(-1)[stencil.indices] e = (stencil.values * coeff_window).sum(dim=1) ex = (stencil.grad_x * coeff_window).sum(dim=1) ey = (stencil.grad_y * coeff_window).sum(dim=1) ez = (stencil.grad_z * coeff_window).sum(dim=1) flat_e = flat_x.new_zeros(flat_x.shape) flat_ex = flat_x.new_zeros(flat_x.shape) flat_ey = flat_x.new_zeros(flat_x.shape) flat_ez = flat_x.new_zeros(flat_x.shape) flat_e[flat_mask] = e flat_ex[flat_mask] = ex flat_ey[flat_mask] = ey flat_ez[flat_mask] = ez e = flat_e.view(batch_size, n_pairs) ex = flat_ex.view(batch_size, n_pairs) ey = flat_ey.view(batch_size, n_pairs) ez = flat_ez.view(batch_size, n_pairs) flat_vj = vj.reshape(batch_size * n_pairs, 3) flat_vk = vk.reshape(batch_size * n_pairs, 3) flat_diff = diff.reshape(batch_size * n_pairs, 3) supported_dE_dvj, supported_dE_dvk = pair_distance_partials( flat_ex[flat_mask], flat_ey[flat_mask], flat_ez[flat_mask], flat_vj[flat_mask], 
flat_vk[flat_mask], flat_diff[flat_mask], supported.x, supported.y, supported.z, eps, ) flat_dE_dvj = flat_vj.new_zeros(flat_vj.shape) flat_dE_dvk = flat_vk.new_zeros(flat_vk.shape) flat_dE_dvj[flat_mask] = supported_dE_dvj flat_dE_dvk[flat_mask] = supported_dE_dvk force_j = -flat_dE_dvj.view(batch_size, n_pairs, 3) force_k = -flat_dE_dvk.view(batch_size, n_pairs, 3) force_i = (flat_dE_dvj + flat_dE_dvk).view(batch_size, n_pairs, 3) dst_j = nbr_ids[:, offset_a + row] dst_k = nbr_ids[:, offset_b + col] return ( e.sum(dim=1), force_i.sum(dim=1), dst_j, force_j, dst_k, force_k, ) def _evaluate_bucketed_energy_forces_torch( buckets: Buckets, node_cat: torch.Tensor, coeffs_by_triplet: torch.Tensor, edge_cat_table: torch.Tensor, eval_3d_with_grads: Eval3DWithGrads, *, spline: SplineKind = "cubic", active_triplet_mask: torch.Tensor | None = None, n_nodes: int, n_cat: int = 10, first_knot_xy: float = 0.0, first_knot_z: float = 0.0, knot_spacing_xy: float = 0.25, knot_spacing_z: float = 0.25, lower_support_xy: float = 0.0, lower_support_z: float = 0.0, eps: float = 1.0e-12, ) -> tuple[torch.Tensor, torch.Tensor]: """Walk bucketed triplets and accumulate one coefficient tensor.""" del node_cat, edge_cat_table, eval_3d_with_grads device = coeffs_by_triplet.device if active_triplet_mask is not None: active_triplet_mask = active_triplet_mask.to(device=device) triplets_per_src_cat = n_cat * (n_cat + 1) // 2 coeff_volume = int( coeffs_by_triplet.shape[1] * coeffs_by_triplet.shape[2] * coeffs_by_triplet.shape[3] ) coeffs_flat = coeffs_by_triplet.reshape(-1) coeff_shape = tuple(int(value) for value in coeffs_by_triplet.shape[1:]) upper_support_xy = first_knot_xy + coeff_shape[0] * knot_spacing_xy node_energy = coeffs_by_triplet.new_zeros((n_nodes,)) node_force = coeffs_by_triplet.new_zeros((n_nodes, 3)) pattern_plans = buckets.pattern_plans if not pattern_plans: pattern_plans = build_bucket_pattern_plans( buckets.patterns, buckets.pattern_ptr, buckets.row_ptr, device, ) for plan in 
pattern_plans: src_start = plan.src_start src_end = plan.src_end src_cat = plan.src_cat layout = plan.layout.to_device(device) if layout.row.numel() == 0: continue batch_size = src_end - src_start degree = int(sum(plan.counts)) src_ids = buckets.src_ids[src_start:src_end].to(device=device) edge_start = plan.edge_start edge_end = plan.edge_end nbr_ids = ( buckets.nbr_ids[edge_start:edge_end] .to(device=device) .view( batch_size, degree, ) ) vec = ( buckets.pair_vectors[edge_start:edge_end] .to(device=device) .view( batch_size, degree, 3, ) ) pair_distances = ( buckets.pair_distances[edge_start:edge_end] .to(device=device) .view(batch_size, degree) ) row = layout.row col = layout.col triplet_idx = src_cat * triplets_per_src_cat + layout.edge_cat if active_triplet_mask is not None: active_mask = active_triplet_mask.index_select(0, triplet_idx) if not torch.any(active_mask): continue row = row[active_mask] col = col[active_mask] triplet_idx = triplet_idx[active_mask] vj = vec[:, row, :] vk = vec[:, col, :] x = pair_distances[:, row] y = pair_distances[:, col] diff = vj - vk z = torch.linalg.norm(diff, dim=2) flat_x = x.reshape(-1) flat_y = y.reshape(-1) flat_z = z.reshape(-1) flat_mask = ( (flat_x >= lower_support_xy) & (flat_x < upper_support_xy) & (flat_y >= lower_support_xy) & (flat_y < upper_support_xy) & (flat_z >= lower_support_z) & spline_support_mask_3d( flat_x, flat_y, flat_z, coeff_shape=coeff_shape, first_knot_xy=first_knot_xy, first_knot_z=first_knot_z, knot_spacing_xy=knot_spacing_xy, knot_spacing_z=knot_spacing_z, spline=spline, ) ) if not torch.any(flat_mask): continue all_supported = bool(torch.all(flat_mask)) if all_supported: supported_x = flat_x supported_y = flat_y supported_z = flat_z flat_triplet_positions = torch.arange( flat_x.numel(), dtype=torch.int64, device=device, ) else: supported_x = flat_x[flat_mask] supported_y = flat_y[flat_mask] supported_z = flat_z[flat_mask] flat_triplet_positions = torch.nonzero( flat_mask, as_tuple=False, 
).reshape(-1) stencil = all_supported_uniform_stencil_3d( supported_x, supported_y, supported_z, coeff_shape=coeff_shape, first_knot_xy=first_knot_xy, first_knot_z=first_knot_z, knot_spacing_xy=knot_spacing_xy, knot_spacing_z=knot_spacing_z, spline=spline, ) src_triplet_ids = src_ids[:, None].expand(-1, row.numel()).reshape(-1) src_triplet_ids = src_triplet_ids.index_select(0, flat_triplet_positions) dst_j = nbr_ids[:, row].reshape(-1).index_select(0, flat_triplet_positions) dst_k = nbr_ids[:, col].reshape(-1).index_select(0, flat_triplet_positions) flat_triplet_idx = ( triplet_idx[None, :] .expand(batch_size, -1) .reshape(-1) .index_select(0, flat_triplet_positions) ) supported_vj = vj.reshape(-1, 3).index_select(0, flat_triplet_positions) supported_vk = vk.reshape(-1, 3).index_select(0, flat_triplet_positions) supported_diff = diff.reshape(-1, 3).index_select(0, flat_triplet_positions) coeff_window = coeffs_flat[ stencil.indices + flat_triplet_idx[:, None] * coeff_volume ] e = (stencil.values * coeff_window).sum(dim=1) ex = (stencil.grad_x * coeff_window).sum(dim=1) ey = (stencil.grad_y * coeff_window).sum(dim=1) ez = (stencil.grad_z * coeff_window).sum(dim=1) dE_dvj, dE_dvk = pair_distance_partials( ex, ey, ez, supported_vj, supported_vk, supported_diff, supported_x, supported_y, supported_z, eps, ) node_energy.index_add_(0, src_triplet_ids, e) node_force.index_add_(0, src_triplet_ids, dE_dvj + dE_dvk) node_force.index_add_(0, dst_j, -dE_dvj) node_force.index_add_(0, dst_k, -dE_dvk) return node_energy, node_force def _evaluate_bucketed_energy_forces( buckets: Buckets, node_cat: torch.Tensor, coeffs_by_triplet: torch.Tensor, edge_cat_table: torch.Tensor, eval_3d_with_grads: Eval3DWithGrads, *, spline: SplineKind = "cubic", active_triplet_mask: torch.Tensor | None = None, n_nodes: int, n_cat: int = 10, first_knot_xy: float = 0.0, first_knot_z: float = 0.0, knot_spacing_xy: float = 0.25, knot_spacing_z: float = 0.25, lower_support_xy: float = 0.0, lower_support_z: 
float = 0.0, eps: float = 1.0e-12, runtime_config: ThreeBodyRuntimeConfig | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Dispatch bucketed triplet evaluation to native or Torch backends.""" def torch_evaluator() -> tuple[torch.Tensor, torch.Tensor]: return _evaluate_bucketed_energy_forces_torch( buckets, node_cat, coeffs_by_triplet, edge_cat_table, eval_3d_with_grads, spline=spline, active_triplet_mask=active_triplet_mask, n_nodes=n_nodes, n_cat=n_cat, first_knot_xy=first_knot_xy, first_knot_z=first_knot_z, knot_spacing_xy=knot_spacing_xy, knot_spacing_z=knot_spacing_z, lower_support_xy=lower_support_xy, lower_support_z=lower_support_z, eps=eps, ) return evaluate_bucketed_energy_forces_native_or_torch( buckets, coeffs_by_triplet, active_triplet_mask, spline=spline, n_nodes=n_nodes, n_cat=n_cat, first_knot_xy=first_knot_xy, first_knot_z=first_knot_z, knot_spacing_xy=knot_spacing_xy, knot_spacing_z=knot_spacing_z, lower_support_xy=lower_support_xy, lower_support_z=lower_support_z, eps=eps, torch_evaluator=torch_evaluator, runtime_config=runtime_config, )
def evaluate_bucketed_energy_forces(
    buckets: Buckets,
    node_cat: torch.Tensor,
    coeffs_by_triplet: torch.Tensor,
    edge_cat_table: torch.Tensor,
    eval_3d_with_grads: Eval3DWithGrads | None = None,
    *,
    spline: SplineKind = "cubic",
    active_triplet_mask: torch.Tensor | None = None,
    n_nodes: int,
    n_cat: int = 10,
    first_knot_xy: float = 0.0,
    first_knot_z: float = 0.0,
    knot_spacing_xy: float = 0.25,
    knot_spacing_z: float = 0.25,
    lower_support_xy: float = 0.0,
    lower_support_z: float = 0.0,
    eps: float = 1.0e-12,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Evaluate three-body energies and forces over bucketed triplets.

    This is the public entry point used by three-body forward paths. When
    ``eval_3d_with_grads`` is omitted, the kernel registered for ``spline``
    is used. All other arguments are forwarded unchanged to the internal
    backend dispatcher.

    Returns:
        ``(node_energy, node_force)`` from the selected backend.

    Raises:
        ValueError: If ``eval_3d_with_grads`` is ``None`` and ``spline`` is
            not a registered family.
    """
    kernel = (
        get_eval_3d_with_grads(spline)
        if eval_3d_with_grads is None
        else eval_3d_with_grads
    )
    return _evaluate_bucketed_energy_forces(
        buckets,
        node_cat,
        coeffs_by_triplet,
        edge_cat_table,
        kernel,
        spline=spline,
        active_triplet_mask=active_triplet_mask,
        n_nodes=n_nodes,
        n_cat=n_cat,
        first_knot_xy=first_knot_xy,
        first_knot_z=first_knot_z,
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        lower_support_xy=lower_support_xy,
        lower_support_z=lower_support_z,
        eps=eps,
        runtime_config=runtime_config,
    )
def make_bucketed_energy_forces_evaluator(
    *,
    spline: SplineKind = "cubic",
    n_nodes: int,
    n_cat: int = 10,
    first_knot_xy: float = 0.0,
    first_knot_z: float = 0.0,
    knot_spacing_xy: float = 0.25,
    knot_spacing_z: float = 0.25,
    lower_support_xy: float = 0.0,
    lower_support_z: float = 0.0,
    eps: float = 1.0e-12,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
    active_triplet_mask: torch.Tensor | None = None,
) -> BucketedEnergyForceEvaluator:
    """Bind spline hyperparameters into a reusable bucket evaluator callable.

    The spline kernel is resolved once, when the factory is called. An
    unknown ``spline`` therefore fails immediately, not on first evaluation.

    Args:
        spline: Spline family name.
        n_nodes: Length of the returned per-node accumulators.
        n_cat: Number of categories.
        first_knot_xy, first_knot_z, knot_spacing_xy, knot_spacing_z: Grid
            parameters.
        lower_support_xy, lower_support_z: Lower bounds of the support box.
        eps: Passed through to the backend.
        runtime_config: Passed through to the backend dispatcher.
        active_triplet_mask: Optional per-channel activity mask bound into
            every call. It matches the parameter of the same name on
            ``evaluate_bucketed_energy_forces``. The default ``None``
            evaluates all channels, which was the previous behavior.

    Returns:
        A callable ``(buckets, node_cat, coeffs_by_triplet, edge_cat_table)``
        that returns ``(node_energy, node_force)``.

    Raises:
        ValueError: If ``spline`` is not a registered family.
    """
    eval_3d_with_grads = get_eval_3d_with_grads(spline)

    def evaluate(
        buckets: Buckets,
        node_cat: torch.Tensor,
        coeffs_by_triplet: torch.Tensor,
        edge_cat_table: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Evaluate bucketed three-body energies and forces."""
        return _evaluate_bucketed_energy_forces(
            buckets,
            node_cat,
            coeffs_by_triplet,
            edge_cat_table,
            eval_3d_with_grads,
            spline=spline,
            active_triplet_mask=active_triplet_mask,
            n_nodes=n_nodes,
            n_cat=n_cat,
            first_knot_xy=first_knot_xy,
            first_knot_z=first_knot_z,
            knot_spacing_xy=knot_spacing_xy,
            knot_spacing_z=knot_spacing_z,
            lower_support_xy=lower_support_xy,
            lower_support_z=lower_support_z,
            eps=eps,
            runtime_config=runtime_config,
        )

    return evaluate