"""Bucketed three-body spline evaluation helpers."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import Literal
import torch
from ufp.splines._cubic import cubic_eval_3d_with_grads
from ufp.splines._quadratic import quadratic_eval_3d_with_grads
from ufp.splines._quartic import quartic_eval_3d_with_grads
from ufp.splines.representation import (
all_supported_uniform_stencil_3d,
spline_support_mask_3d,
supported_uniform_stencil_3d,
)
from ufp.terms._threebody_kernels import evaluate_bucketed_energy_forces_native_or_torch
from ufp.terms._threebody_ops import (
Buckets,
build_bucket_pattern_plans,
pair_distance_partials,
)
from ufp.terms._threebody_runtime import ThreeBodyRuntimeConfig
from ufp.terms.categories import (
active_triplet_mask as _active_triplet_mask, # noqa: F401
)
from ufp.terms.categories import (
canonical_triplet as _canonical_triplet, # noqa: F401
)
from ufp.terms.categories import (
triplet_categories as _triplet_categories, # noqa: F401
)
# Signature shared by the per-family ``*_eval_3d_with_grads`` spline evaluators
# registered below: four ``float`` arguments followed by four tensors, returning
# a 4-tuple of tensors.
Eval3DWithGrads = Callable[
    [float, float, float, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]
# Callable shape returned by ``make_bucketed_energy_forces_evaluator``:
# ``(buckets, node_cat, coeffs_by_triplet, edge_cat_table) -> (energy, force)``.
BucketedEnergyForceEvaluator = Callable[
    [Buckets, torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]
]
# Spline basis families accepted by ``get_eval_3d_with_grads``.
SplineKind = Literal["quadratic", "cubic", "quartic"]
# Registry from spline family name to its 3D evaluator; keys follow the same
# order as ``SplineKind``.
_SPLINE_EVAL_3D_WITH_GRADS: dict[str, Eval3DWithGrads] = dict(
    quadratic=quadratic_eval_3d_with_grads,
    cubic=cubic_eval_3d_with_grads,
    quartic=quartic_eval_3d_with_grads,
)
[docs]
def get_eval_3d_with_grads(spline: SplineKind | str) -> Eval3DWithGrads:
    """Return the 3D spline evaluator for the requested basis family.

    Args:
        spline: Family name; must be a key of ``_SPLINE_EVAL_3D_WITH_GRADS``
            (``"quadratic"``, ``"cubic"`` or ``"quartic"``).

    Returns:
        The registered ``Eval3DWithGrads`` callable for that family.

    Raises:
        ValueError: If ``spline`` is not a registered family. The message lists
            the supported names in sorted order, and the original ``KeyError``
            is chained as the cause.
    """
    registry = _SPLINE_EVAL_3D_WITH_GRADS
    try:
        evaluator = registry[spline]
    except KeyError as missing:
        supported = ", ".join(sorted(registry))
        message = f"Unsupported spline '{spline}'. Expected one of: {supported}."
        raise ValueError(message) from missing
    return evaluator
def _same_neighbor_triplet_mask(
triplet_categories: Sequence[tuple[int, int, int]],
) -> torch.Tensor:
"""Return triplet channels whose two neighbor categories are identical."""
return torch.tensor(
[
int(first_neighbor) == int(second_neighbor)
for _, first_neighbor, second_neighbor in triplet_categories
],
dtype=torch.bool,
)
def _symmetrize_same_neighbor_coeffs(
coeffs_by_triplet: torch.Tensor,
same_neighbor_triplet_mask: torch.Tensor,
) -> torch.Tensor:
"""Tie x/y coefficient axes for same-neighbor-category triplet channels."""
same_mask = same_neighbor_triplet_mask.to(device=coeffs_by_triplet.device)
if same_mask.numel() == 0 or not bool(torch.any(same_mask)):
return coeffs_by_triplet
output = coeffs_by_triplet.clone()
same_coeffs = output[same_mask]
output[same_mask] = 0.5 * (same_coeffs + same_coeffs.transpose(1, 2))
return output
def _neighbor_neighbor_cutoff(center_neighbor_cutoff: float) -> float:
"""Derive the outer cutoff implied by the center-neighbor cutoff."""
return 2.0 * float(center_neighbor_cutoff)
def _support_bounds(
first_knot: float,
knot_spacing: float,
coeff_size: int,
*,
lower_full_support: float,
) -> tuple[float, float]:
"""Return the physical support interval used before stencil construction."""
return float(lower_full_support), float(first_knot + int(coeff_size) * knot_spacing)
def _evaluate_pair_block(
    vec: torch.Tensor,
    r: torch.Tensor,
    nbr_ids: torch.Tensor,
    coeffs_by_triplet: torch.Tensor,
    edge_cat_table: torch.Tensor,
    eval_3d_with_grads: Eval3DWithGrads,
    triplets_per_src_cat: int,
    src_cat: int,
    cat_a: int,
    cat_b: int,
    offset_a: int,
    offset_b: int,
    row: torch.Tensor,
    col: torch.Tensor,
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    eps: float,
    spline: SplineKind | str,
    active_triplet_mask: torch.Tensor | None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Evaluate one bucketed triplet block and return its energy and force pieces.

    Neighbor slot ``offset_a + row`` is paired with neighbor slot
    ``offset_b + col`` for every source in the batch. The spline of channel
    ``src_cat * triplets_per_src_cat + edge_cat_table[cat_a, cat_b]`` is then
    evaluated at ``x = r[:, offset_a + row]``, ``y = r[:, offset_b + col]`` and
    ``z = ||vec_j - vec_k||``.

    Args:
        vec: Per-neighbor vectors indexed as ``vec[:, slot, :]`` with a trailing
            dim of 3. Presumably ``(batch, degree, 3)``; confirm against callers.
        r: Per-neighbor scalars indexed as ``r[:, slot]``. Presumably the
            neighbor distances; confirm.
        nbr_ids: Node ids per neighbor slot, same leading layout as ``r``.
        coeffs_by_triplet: Coefficients indexed by triplet channel first.
        edge_cat_table: Maps a neighbor-category pair to its offset within the
            channels of one source category.
        eval_3d_with_grads: Accepted but not referenced in this body.
        row, col: Pair indices within the ``a`` / ``b`` neighbor slot ranges.
        active_triplet_mask: Optional per-channel boolean mask. An inactive
            channel short-circuits to zero/empty outputs.

    Returns:
        ``(energy, force_i, dst_j, force_j, dst_k, force_k)``:
        ``energy`` has shape ``(batch,)`` (summed over pairs), ``force_i`` has
        shape ``(batch, 3)`` (summed over pairs), ``dst_j``/``dst_k`` are the
        neighbor node ids of shape ``(batch, n_pairs)``, and ``force_j``/``force_k``
        have shape ``(batch, n_pairs, 3)``. For an inactive channel the
        ``dst``/``force`` entries have 0 pairs.
    """
    vj = vec[:, offset_a + row, :]
    vk = vec[:, offset_b + col, :]
    x = r[:, offset_a + row]
    y = r[:, offset_b + col]
    diff = vj - vk
    # Third spline coordinate: the neighbor-neighbor separation.
    z = torch.linalg.norm(diff, dim=2)
    batch_size, n_pairs = x.shape
    triplet_idx = src_cat * triplets_per_src_cat + edge_cat_table[cat_a, cat_b]
    # bool(...) forces a host sync; an inactive channel contributes nothing.
    if active_triplet_mask is not None and not bool(active_triplet_mask[triplet_idx]):
        # Zero-width slices keep dtype/device consistent with the active path.
        empty_dst = nbr_ids[:, offset_a + row][:, :0]
        empty_force = vec.new_zeros((batch_size, 0, 3))
        return (
            coeffs_by_triplet.new_zeros((batch_size,)),
            coeffs_by_triplet.new_zeros((batch_size, 3)),
            empty_dst,
            empty_force,
            empty_dst,
            empty_force,
        )
    coeffs = coeffs_by_triplet[triplet_idx]
    # Stencil is built only for points inside the spline support; `supported.mask`
    # marks which flattened (batch * n_pairs) points those are.
    supported = supported_uniform_stencil_3d(
        x.reshape(batch_size * n_pairs),
        y.reshape(batch_size * n_pairs),
        z.reshape(batch_size * n_pairs),
        coeff_shape=tuple(int(value) for value in coeffs.shape),
        first_knot_xy=first_knot_xy,
        first_knot_z=first_knot_z,
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        spline=spline,
    )
    flat_mask = supported.mask.reshape(-1)
    flat_x = x.reshape(-1)
    stencil = supported.stencil
    # Gather the coefficients touched by each supported point's stencil.
    coeff_window = coeffs.reshape(-1)[stencil.indices]
    e = (stencil.values * coeff_window).sum(dim=1)
    ex = (stencil.grad_x * coeff_window).sum(dim=1)
    ey = (stencil.grad_y * coeff_window).sum(dim=1)
    ez = (stencil.grad_z * coeff_window).sum(dim=1)
    # Scatter supported results back into dense buffers; unsupported points stay 0.
    flat_e = flat_x.new_zeros(flat_x.shape)
    flat_ex = flat_x.new_zeros(flat_x.shape)
    flat_ey = flat_x.new_zeros(flat_x.shape)
    flat_ez = flat_x.new_zeros(flat_x.shape)
    flat_e[flat_mask] = e
    flat_ex[flat_mask] = ex
    flat_ey[flat_mask] = ey
    flat_ez[flat_mask] = ez
    e = flat_e.view(batch_size, n_pairs)
    # NOTE(review): the ex/ey/ez views below are not used afterwards; the
    # partials call reads flat_ex/flat_ey/flat_ez[flat_mask] directly.
    ex = flat_ex.view(batch_size, n_pairs)
    ey = flat_ey.view(batch_size, n_pairs)
    ez = flat_ez.view(batch_size, n_pairs)
    flat_vj = vj.reshape(batch_size * n_pairs, 3)
    flat_vk = vk.reshape(batch_size * n_pairs, 3)
    flat_diff = diff.reshape(batch_size * n_pairs, 3)
    # Chain rule from (dE/dx, dE/dy, dE/dz) to gradients w.r.t. vj and vk,
    # computed only for supported points.
    supported_dE_dvj, supported_dE_dvk = pair_distance_partials(
        flat_ex[flat_mask],
        flat_ey[flat_mask],
        flat_ez[flat_mask],
        flat_vj[flat_mask],
        flat_vk[flat_mask],
        flat_diff[flat_mask],
        supported.x,
        supported.y,
        supported.z,
        eps,
    )
    flat_dE_dvj = flat_vj.new_zeros(flat_vj.shape)
    flat_dE_dvk = flat_vk.new_zeros(flat_vk.shape)
    flat_dE_dvj[flat_mask] = supported_dE_dvj
    flat_dE_dvk[flat_mask] = supported_dE_dvk
    # Forces are negative gradients. The center force is the negated sum of the
    # neighbor forces, so the three forces of each triplet sum to zero.
    force_j = -flat_dE_dvj.view(batch_size, n_pairs, 3)
    force_k = -flat_dE_dvk.view(batch_size, n_pairs, 3)
    force_i = (flat_dE_dvj + flat_dE_dvk).view(batch_size, n_pairs, 3)
    dst_j = nbr_ids[:, offset_a + row]
    dst_k = nbr_ids[:, offset_b + col]
    return (
        e.sum(dim=1),
        force_i.sum(dim=1),
        dst_j,
        force_j,
        dst_k,
        force_k,
    )
def _evaluate_bucketed_energy_forces_torch(
    buckets: Buckets,
    node_cat: torch.Tensor,
    coeffs_by_triplet: torch.Tensor,
    edge_cat_table: torch.Tensor,
    eval_3d_with_grads: Eval3DWithGrads,
    *,
    spline: SplineKind = "cubic",
    active_triplet_mask: torch.Tensor | None = None,
    n_nodes: int,
    n_cat: int = 10,
    first_knot_xy: float = 0.0,
    first_knot_z: float = 0.0,
    knot_spacing_xy: float = 0.25,
    knot_spacing_z: float = 0.25,
    lower_support_xy: float = 0.0,
    lower_support_z: float = 0.0,
    eps: float = 1.0e-12,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Accumulate per-node three-body energies and forces using plain Torch ops.

    The function walks every bucket pattern plan. If ``buckets.pattern_plans``
    is empty, the plans are first built from ``buckets.patterns`` /
    ``pattern_ptr`` / ``row_ptr``. Each plan's layout gives neighbor-pair
    indices ``(row, col)``. The spline of channel
    ``src_cat * (n_cat * (n_cat + 1) // 2) + layout.edge_cat`` is evaluated at
    ``x = pair_distances[:, row]``, ``y = pair_distances[:, col]`` and
    ``z = ||vec[:, row] - vec[:, col]||``. The results are then scatter-added
    into per-node buffers.

    ``node_cat``, ``edge_cat_table`` and ``eval_3d_with_grads`` are accepted for
    signature parity with the dispatcher and discarded immediately.

    Args:
        coeffs_by_triplet: At least 4-D; dims 1-3 form each channel's
            coefficient grid. Presumably ``(n_channels, nx, ny, nz)``; confirm.
        active_triplet_mask: Optional per-channel boolean mask. Masked-out
            channels are skipped entirely.
        lower_support_xy, lower_support_z: Lower bounds a point must meet
            (``>=``) to be evaluated.

    Returns:
        ``(node_energy, node_force)`` with shapes ``(n_nodes,)`` and
        ``(n_nodes, 3)``. Each triplet's energy is credited to its source node.
    """
    del node_cat, edge_cat_table, eval_3d_with_grads
    device = coeffs_by_triplet.device
    if active_triplet_mask is not None:
        active_triplet_mask = active_triplet_mask.to(device=device)
    # n_cat * (n_cat + 1) / 2 = number of unordered category pairs (including
    # equal pairs); presumably one channel per neighbor-category pair.
    triplets_per_src_cat = n_cat * (n_cat + 1) // 2
    # Element count of one channel's grid; used to offset into coeffs_flat.
    coeff_volume = int(
        coeffs_by_triplet.shape[1]
        * coeffs_by_triplet.shape[2]
        * coeffs_by_triplet.shape[3]
    )
    coeffs_flat = coeffs_by_triplet.reshape(-1)
    coeff_shape = tuple(int(value) for value in coeffs_by_triplet.shape[1:])
    # Uses coeff_shape[0] for both x and y; assumes x and y share one knot grid.
    upper_support_xy = first_knot_xy + coeff_shape[0] * knot_spacing_xy
    node_energy = coeffs_by_triplet.new_zeros((n_nodes,))
    node_force = coeffs_by_triplet.new_zeros((n_nodes, 3))
    pattern_plans = buckets.pattern_plans
    if not pattern_plans:
        pattern_plans = build_bucket_pattern_plans(
            buckets.patterns,
            buckets.pattern_ptr,
            buckets.row_ptr,
            device,
        )
    for plan in pattern_plans:
        src_start = plan.src_start
        src_end = plan.src_end
        src_cat = plan.src_cat
        layout = plan.layout.to_device(device)
        if layout.row.numel() == 0:
            continue
        # Every source in a plan has the same neighbor count (`degree`), so the
        # plan's edge slice can be viewed as a dense (batch_size, degree) grid.
        batch_size = src_end - src_start
        degree = int(sum(plan.counts))
        src_ids = buckets.src_ids[src_start:src_end].to(device=device)
        edge_start = plan.edge_start
        edge_end = plan.edge_end
        nbr_ids = (
            buckets.nbr_ids[edge_start:edge_end]
            .to(device=device)
            .view(
                batch_size,
                degree,
            )
        )
        vec = (
            buckets.pair_vectors[edge_start:edge_end]
            .to(device=device)
            .view(
                batch_size,
                degree,
                3,
            )
        )
        pair_distances = (
            buckets.pair_distances[edge_start:edge_end]
            .to(device=device)
            .view(batch_size, degree)
        )
        row = layout.row
        col = layout.col
        triplet_idx = src_cat * triplets_per_src_cat + layout.edge_cat
        if active_triplet_mask is not None:
            # Drop pairs whose channel is inactive (torch.any forces a host sync).
            active_mask = active_triplet_mask.index_select(0, triplet_idx)
            if not torch.any(active_mask):
                continue
            row = row[active_mask]
            col = col[active_mask]
            triplet_idx = triplet_idx[active_mask]
        vj = vec[:, row, :]
        vk = vec[:, col, :]
        x = pair_distances[:, row]
        y = pair_distances[:, col]
        diff = vj - vk
        z = torch.linalg.norm(diff, dim=2)
        # Flattened in row-major (batch, pair) order; every per-pair gather
        # below uses the same ordering.
        flat_x = x.reshape(-1)
        flat_y = y.reshape(-1)
        flat_z = z.reshape(-1)
        # Explicit x/y half-open bounds and a z lower bound, combined with the
        # spline's own support mask. No explicit z upper bound is applied here;
        # presumably spline_support_mask_3d enforces it (confirm).
        flat_mask = (
            (flat_x >= lower_support_xy)
            & (flat_x < upper_support_xy)
            & (flat_y >= lower_support_xy)
            & (flat_y < upper_support_xy)
            & (flat_z >= lower_support_z)
            & spline_support_mask_3d(
                flat_x,
                flat_y,
                flat_z,
                coeff_shape=coeff_shape,
                first_knot_xy=first_knot_xy,
                first_knot_z=first_knot_z,
                knot_spacing_xy=knot_spacing_xy,
                knot_spacing_z=knot_spacing_z,
                spline=spline,
            )
        )
        if not torch.any(flat_mask):
            continue
        # Compact to supported points; when every point is supported, skip the
        # boolean-mask copies and use an identity index.
        all_supported = bool(torch.all(flat_mask))
        if all_supported:
            supported_x = flat_x
            supported_y = flat_y
            supported_z = flat_z
            flat_triplet_positions = torch.arange(
                flat_x.numel(),
                dtype=torch.int64,
                device=device,
            )
        else:
            supported_x = flat_x[flat_mask]
            supported_y = flat_y[flat_mask]
            supported_z = flat_z[flat_mask]
            flat_triplet_positions = torch.nonzero(
                flat_mask,
                as_tuple=False,
            ).reshape(-1)
        stencil = all_supported_uniform_stencil_3d(
            supported_x,
            supported_y,
            supported_z,
            coeff_shape=coeff_shape,
            first_knot_xy=first_knot_xy,
            first_knot_z=first_knot_z,
            knot_spacing_xy=knot_spacing_xy,
            knot_spacing_z=knot_spacing_z,
            spline=spline,
        )
        # Per-supported-point source id, neighbor ids and channel id, all
        # gathered with the same flat_triplet_positions.
        src_triplet_ids = src_ids[:, None].expand(-1, row.numel()).reshape(-1)
        src_triplet_ids = src_triplet_ids.index_select(0, flat_triplet_positions)
        dst_j = nbr_ids[:, row].reshape(-1).index_select(0, flat_triplet_positions)
        dst_k = nbr_ids[:, col].reshape(-1).index_select(0, flat_triplet_positions)
        flat_triplet_idx = (
            triplet_idx[None, :]
            .expand(batch_size, -1)
            .reshape(-1)
            .index_select(0, flat_triplet_positions)
        )
        supported_vj = vj.reshape(-1, 3).index_select(0, flat_triplet_positions)
        supported_vk = vk.reshape(-1, 3).index_select(0, flat_triplet_positions)
        supported_diff = diff.reshape(-1, 3).index_select(0, flat_triplet_positions)
        # Offset each point's stencil indices into its own channel's block.
        coeff_window = coeffs_flat[
            stencil.indices + flat_triplet_idx[:, None] * coeff_volume
        ]
        e = (stencil.values * coeff_window).sum(dim=1)
        ex = (stencil.grad_x * coeff_window).sum(dim=1)
        ey = (stencil.grad_y * coeff_window).sum(dim=1)
        ez = (stencil.grad_z * coeff_window).sum(dim=1)
        dE_dvj, dE_dvk = pair_distance_partials(
            ex,
            ey,
            ez,
            supported_vj,
            supported_vk,
            supported_diff,
            supported_x,
            supported_y,
            supported_z,
            eps,
        )
        # Energy goes to the source node. Forces are negative gradients on the
        # neighbors, and the source receives their negated sum, so the three
        # contributions of each triplet sum to zero.
        node_energy.index_add_(0, src_triplet_ids, e)
        node_force.index_add_(0, src_triplet_ids, dE_dvj + dE_dvk)
        node_force.index_add_(0, dst_j, -dE_dvj)
        node_force.index_add_(0, dst_k, -dE_dvk)
    return node_energy, node_force
def _evaluate_bucketed_energy_forces(
    buckets: Buckets,
    node_cat: torch.Tensor,
    coeffs_by_triplet: torch.Tensor,
    edge_cat_table: torch.Tensor,
    eval_3d_with_grads: Eval3DWithGrads,
    *,
    spline: SplineKind = "cubic",
    active_triplet_mask: torch.Tensor | None = None,
    n_nodes: int,
    n_cat: int = 10,
    first_knot_xy: float = 0.0,
    first_knot_z: float = 0.0,
    knot_spacing_xy: float = 0.25,
    knot_spacing_z: float = 0.25,
    lower_support_xy: float = 0.0,
    lower_support_z: float = 0.0,
    eps: float = 1.0e-12,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dispatch bucketed triplet evaluation to native or Torch backends.

    The pure-Torch implementation is wrapped in a zero-argument closure and
    handed to ``evaluate_bucketed_energy_forces_native_or_torch`` along with the
    same spline/grid options and ``runtime_config``. That helper presumably
    decides which backend actually runs; confirm in ``_threebody_kernels``.
    """
    # Options passed identically to both the Torch fallback and the dispatcher.
    grid_options = dict(
        n_nodes=n_nodes,
        n_cat=n_cat,
        first_knot_xy=first_knot_xy,
        first_knot_z=first_knot_z,
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        lower_support_xy=lower_support_xy,
        lower_support_z=lower_support_z,
        eps=eps,
    )

    def torch_evaluator() -> tuple[torch.Tensor, torch.Tensor]:
        return _evaluate_bucketed_energy_forces_torch(
            buckets,
            node_cat,
            coeffs_by_triplet,
            edge_cat_table,
            eval_3d_with_grads,
            spline=spline,
            active_triplet_mask=active_triplet_mask,
            **grid_options,
        )

    return evaluate_bucketed_energy_forces_native_or_torch(
        buckets,
        coeffs_by_triplet,
        active_triplet_mask,
        spline=spline,
        torch_evaluator=torch_evaluator,
        runtime_config=runtime_config,
        **grid_options,
    )
[docs]
def evaluate_bucketed_energy_forces(
    buckets: Buckets,
    node_cat: torch.Tensor,
    coeffs_by_triplet: torch.Tensor,
    edge_cat_table: torch.Tensor,
    eval_3d_with_grads: Eval3DWithGrads | None = None,
    *,
    spline: SplineKind = "cubic",
    active_triplet_mask: torch.Tensor | None = None,
    n_nodes: int,
    n_cat: int = 10,
    first_knot_xy: float = 0.0,
    first_knot_z: float = 0.0,
    knot_spacing_xy: float = 0.25,
    knot_spacing_z: float = 0.25,
    lower_support_xy: float = 0.0,
    lower_support_z: float = 0.0,
    eps: float = 1.0e-12,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Public wrapper around the bucket evaluator used by three-body forward paths.

    When ``eval_3d_with_grads`` is ``None`` it is looked up from ``spline`` via
    ``get_eval_3d_with_grads``, which raises ``ValueError`` for an unknown
    family. All other arguments are forwarded unchanged.

    Returns:
        The ``(energy, force)`` pair produced by the backend dispatcher.
    """
    resolved_eval = (
        eval_3d_with_grads
        if eval_3d_with_grads is not None
        else get_eval_3d_with_grads(spline)
    )
    return _evaluate_bucketed_energy_forces(
        buckets,
        node_cat,
        coeffs_by_triplet,
        edge_cat_table,
        resolved_eval,
        spline=spline,
        active_triplet_mask=active_triplet_mask,
        n_nodes=n_nodes,
        n_cat=n_cat,
        first_knot_xy=first_knot_xy,
        first_knot_z=first_knot_z,
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        lower_support_xy=lower_support_xy,
        lower_support_z=lower_support_z,
        eps=eps,
        runtime_config=runtime_config,
    )
[docs]
def make_bucketed_energy_forces_evaluator(
    *,
    spline: SplineKind = "cubic",
    active_triplet_mask: torch.Tensor | None = None,
    n_nodes: int,
    n_cat: int = 10,
    first_knot_xy: float = 0.0,
    first_knot_z: float = 0.0,
    knot_spacing_xy: float = 0.25,
    knot_spacing_z: float = 0.25,
    lower_support_xy: float = 0.0,
    lower_support_z: float = 0.0,
    eps: float = 1.0e-12,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> BucketedEnergyForceEvaluator:
    """Bind spline hyperparameters into a reusable bucket evaluator callable.

    The spline evaluator is resolved eagerly via ``get_eval_3d_with_grads``. An
    unsupported ``spline`` therefore raises ``ValueError`` here, not on the
    first call.

    Args:
        spline: Spline basis family name.
        active_triplet_mask: Optional per-triplet-channel boolean mask forwarded
            on every call, matching the argument of the same name on
            ``evaluate_bucketed_energy_forces``. The default ``None`` (the
            previous fixed behaviour) requests no channel filtering.
        n_nodes: Length of the per-node output buffers.
        n_cat: Number of node categories.
        first_knot_xy, first_knot_z, knot_spacing_xy, knot_spacing_z: Uniform
            knot grid parameters for the x/y and z spline axes.
        lower_support_xy, lower_support_z: Lower support bounds forwarded to the
            backend.
        eps: Numerical guard forwarded to the backend.
        runtime_config: Optional backend selection/runtime settings.

    Returns:
        A callable ``evaluate(buckets, node_cat, coeffs_by_triplet, edge_cat_table)``
        that returns the ``(energy, force)`` pair from the backend dispatcher.
    """
    eval_3d_with_grads = get_eval_3d_with_grads(spline)

    def evaluate(
        buckets: Buckets,
        node_cat: torch.Tensor,
        coeffs_by_triplet: torch.Tensor,
        edge_cat_table: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Evaluate bucketed three-body energies and forces."""
        return _evaluate_bucketed_energy_forces(
            buckets,
            node_cat,
            coeffs_by_triplet,
            edge_cat_table,
            eval_3d_with_grads,
            spline=spline,
            active_triplet_mask=active_triplet_mask,
            n_nodes=n_nodes,
            n_cat=n_cat,
            first_knot_xy=first_knot_xy,
            first_knot_z=first_knot_z,
            knot_spacing_xy=knot_spacing_xy,
            knot_spacing_z=knot_spacing_z,
            lower_support_xy=lower_support_xy,
            lower_support_z=lower_support_z,
            eps=eps,
            runtime_config=runtime_config,
        )

    return evaluate