"""
Low-level combinatorics for three-body spline assembly.
This module groups neighbors by category and exposes reusable bucket, edge, and
distance-derivative helpers shared by forward and fitting code.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Sequence
import torch
# Alias for ``(int, Tensor, Tensor)`` tuples. It is exported via ``__all__`` but
# is not referenced by the code in this module.
# NOTE(review): confirm what each tuple element represents with the callers.
BucketItem = tuple[int, torch.Tensor, torch.Tensor]
[docs]
@dataclass(frozen=True)
class Buckets:
"""Flattened CSR-like neighborhood batches grouped by shared source patterns."""
src_ids: torch.Tensor
row_ptr: torch.Tensor
nbr_ids: torch.Tensor
pair_vectors: torch.Tensor
pair_distances: torch.Tensor
patterns: torch.Tensor
pattern_ptr: torch.Tensor
pattern_plans: tuple["BucketPatternPlan", ...] = field(default_factory=tuple)
tensor_pattern_plans: "TensorBucketPatternPlans | None" = None
def __bool__(self) -> bool:
"""Report whether any valid source neighborhoods were retained."""
return bool(self.src_ids.numel())
[docs]
def with_pattern_plans(self, device: torch.device | None = None) -> "Buckets":
"""Return buckets with reusable CPU-side pattern metadata attached."""
resolved_device = self.pair_vectors.device if device is None else device
pattern_plans = build_bucket_pattern_plans(
self.patterns,
self.pattern_ptr,
self.row_ptr,
torch.device(resolved_device),
)
return Buckets(
src_ids=self.src_ids,
row_ptr=self.row_ptr,
nbr_ids=self.nbr_ids,
pair_vectors=self.pair_vectors,
pair_distances=self.pair_distances,
patterns=self.patterns,
pattern_ptr=self.pattern_ptr,
pattern_plans=pattern_plans,
tensor_pattern_plans=tensorize_bucket_pattern_plans(
pattern_plans,
torch.device(resolved_device),
),
)
@dataclass(frozen=True)
class PatternTripletLayout:
    """Neighbor-slot pairs enumerated for one neighbor-category histogram.

    Attributes:
        row: First neighbor slot of each pair.
        col: Second neighbor slot of each pair.
        edge_cat: Unordered edge-category index of each pair.
    """

    row: torch.Tensor
    col: torch.Tensor
    edge_cat: torch.Tensor

    def to_device(self, device: torch.device) -> "PatternTripletLayout":
        """Return the layout on ``device``.

        ``self`` is returned unchanged when ``row`` already lives on ``device``;
        otherwise all three tensors are copied with ``non_blocking=True``.
        """
        if self.row.device == device:
            return self
        moved = (
            tensor.to(device=device, non_blocking=True)
            for tensor in (self.row, self.col, self.edge_cat)
        )
        return PatternTripletLayout(*moved)
@dataclass(frozen=True)
class BucketPatternPlan:
    """Host-side bookkeeping for one source pattern plus its triplet layout.

    Attributes:
        src_cat: Source category of the pattern.
        counts: Neighbor count per category.
        src_start: First source row of this pattern.
        src_end: One past the last source row of this pattern.
        edge_start: First packed pair index of this pattern.
        edge_end: One past the last packed pair index of this pattern.
        layout: Triplet layout shared by every source with this pattern.
    """

    src_cat: int
    counts: tuple[int, ...]
    src_start: int
    src_end: int
    edge_start: int
    edge_end: int
    layout: "PatternTripletLayout"

    def to_device(self, device: torch.device) -> "BucketPatternPlan":
        """Return a copy whose layout is on ``device``; scalar fields are reused."""
        relocated = self.layout.to_device(device)
        return BucketPatternPlan(
            src_cat=self.src_cat,
            counts=self.counts,
            src_start=self.src_start,
            src_end=self.src_end,
            edge_start=self.edge_start,
            edge_end=self.edge_end,
            layout=relocated,
        )
@dataclass(frozen=True)
class TensorBucketPatternPlans:
    """Tensor-only bucket metadata passed to native three-body operators.

    Per-pattern fields (``src_cat`` .. ``degree``) have one entry per pattern.
    ``layout_ptr`` gives each pattern's slice of the concatenated ``row``,
    ``col`` and ``edge_cat`` triplet tensors.
    """

    src_cat: torch.Tensor
    src_start: torch.Tensor
    src_end: torch.Tensor
    edge_start: torch.Tensor
    edge_end: torch.Tensor
    degree: torch.Tensor
    layout_ptr: torch.Tensor
    row: torch.Tensor
    col: torch.Tensor
    edge_cat: torch.Tensor

    def __bool__(self) -> bool:
        """Return ``True`` when at least one pattern is described."""
        return self.src_cat.numel() > 0

    def to_device(self, device: torch.device) -> "TensorBucketPatternPlans":
        """Return the metadata on ``device``.

        ``self`` is reused when ``src_cat`` already lives on ``device``.
        Otherwise every tensor is copied with ``non_blocking=True``.
        """
        if self.src_cat.device == device:
            return self
        names = (
            "src_cat",
            "src_start",
            "src_end",
            "edge_start",
            "edge_end",
            "degree",
            "layout_ptr",
            "row",
            "col",
            "edge_cat",
        )
        moved = {
            name: getattr(self, name).to(device=device, non_blocking=True)
            for name in names
        }
        return TensorBucketPatternPlans(**moved)
# Process-wide memo tables filled lazily by samecat_pairs, crosscat_pairs and
# pattern_triplet_layout. Keys end with ``str(device)``. Entries are never
# evicted, and the cached tensors are handed out as-is to every caller, so they
# should not be modified in place.
# (count, device) -> (row, col) upper-triangular pair indices within one block.
_PAIR_CACHE_SAMECAT: dict[tuple[int, str], tuple[torch.Tensor, torch.Tensor]] = {}
# (count_a, count_b, device) -> (row, col) Cartesian pair indices between blocks.
_PAIR_CACHE_CROSSCAT: dict[tuple[int, int, str], tuple[torch.Tensor, torch.Tensor]] = {}
# (per-category neighbor counts, device) -> PatternTripletLayout.
_PATTERN_TRIPLET_LAYOUT_CACHE: dict[
tuple[tuple[int, ...], str],
PatternTripletLayout,
] = {}
def num_edge_categories(n_cat: int) -> int:
    """Count unordered category pairs ``(a, b)`` with ``a <= b``.

    This is the triangular number ``n_cat * (n_cat + 1) / 2``. It covers
    ``n_cat`` same-category pairs plus ``n_cat * (n_cat - 1) / 2`` cross pairs.
    """
    ordered_with_diagonal = n_cat * (n_cat + 1)
    return ordered_with_diagonal // 2
def build_edge_category_table(n_cat: int, device=None) -> torch.Tensor:
    """Map ordered category pairs to unordered edge-category indices.

    Entry ``[a, b]`` equals ``[b, a]``. The pairs with ``a <= b`` are numbered
    row-major over the upper triangle, i.e. ``(0, 0), (0, 1), ..., (1, 1), ...``.
    This is the same numbering as ``edge_category_index``.

    Args:
        n_cat: Number of categories.
        device: Device for the returned table.

    Returns:
        Symmetric int64 tensor of shape ``(n_cat, n_cat)``.
    """
    cats = torch.arange(n_cat, dtype=torch.int64, device=device)
    # Evaluate the closed form for the whole table at once instead of issuing
    # n_cat**2 scalar writes, each of which is a separate device operation.
    lo = torch.minimum(cats[:, None], cats[None, :])
    hi = torch.maximum(cats[:, None], cats[None, :])
    # Rows before ``lo`` in the upper triangle hold lo*n_cat - lo*(lo-1)/2 pairs.
    return lo * n_cat - lo * (lo - 1) // 2 + (hi - lo)
def edge_category_index(cat_a: int, cat_b: int, n_cat: int) -> int:
    """Return the unordered edge-category index of categories ``cat_a`` and ``cat_b``.

    Pairs ``(lo, hi)`` with ``lo <= hi`` are numbered row-major over the upper
    triangle. The rows before ``lo`` contribute ``lo * n_cat - lo * (lo - 1) / 2``
    entries, and ``hi - lo`` is the position inside row ``lo``.
    """
    lo, hi = sorted((int(cat_a), int(cat_b)))
    preceding = lo * int(n_cat) - lo * (lo - 1) // 2
    return preceding + (hi - lo)
def _empty_buckets(
    pair_vectors: torch.Tensor,
    pair_distances: torch.Tensor,
    n_cat: int,
    device: torch.device,
) -> Buckets:
    """Return source-free buckets whose float tensors match the pair inputs."""
    empty_int = torch.zeros((0,), dtype=torch.int64, device=device)
    return Buckets(
        src_ids=empty_int,
        row_ptr=torch.zeros((1,), dtype=torch.int64, device=device),
        nbr_ids=empty_int,
        pair_vectors=pair_vectors.new_zeros((0, 3)),
        pair_distances=pair_distances.new_zeros((0,)),
        patterns=torch.zeros((0, n_cat + 1), dtype=torch.int64, device=device),
        pattern_ptr=torch.zeros((1,), dtype=torch.int64, device=device),
    )


def preprocess_sources(
    i: torch.Tensor,
    j: torch.Tensor,
    node_cat: torch.Tensor,
    n_cat: int,
    pair_vectors: torch.Tensor,
    pair_distances: torch.Tensor | None = None,
) -> Buckets:
    """Group directed neighbor rows into source buckets for triplet assembly.

    Rows are sorted by (source, neighbor category, neighbor id). Sources with
    fewer than two neighbors are dropped. The remaining sources are reordered so
    that sources sharing a pattern ``[source category, neighbor count per
    category...]`` are contiguous, and each pattern owns one ``pattern_ptr``
    segment.

    Args:
        i: Source node id of each directed pair, shape ``(n_pairs,)``.
        j: Neighbor node id of each directed pair, shape ``(n_pairs,)``.
        node_cat: Category of each node, indexed by node id. Values are assumed
            to lie in ``[0, n_cat)``; this is not checked here.
        n_cat: Number of categories.
        pair_vectors: Pair vectors, shape ``(n_pairs, 3)``.
        pair_distances: Optional pair distances of shape ``(n_pairs,)``. When
            omitted, they are the norms of ``pair_vectors``.

    Returns:
        ``Buckets`` in CSR layout without pattern plans attached.

    Raises:
        ValueError: If the input shapes are inconsistent.
    """
    if i.ndim != 1 or j.ndim != 1 or i.shape != j.shape:
        raise ValueError("`i` and `j` must be 1D tensors with the same shape")
    if pair_vectors.ndim != 2 or pair_vectors.shape != (i.numel(), 3):
        raise ValueError("`pair_vectors` must have shape (n_pairs, 3)")
    if pair_distances is None:
        pair_distances = torch.linalg.vector_norm(pair_vectors, dim=1)
    elif pair_distances.ndim != 1 or pair_distances.shape != i.shape:
        raise ValueError("`pair_distances` must have shape (n_pairs,)")
    if i.numel() == 0:
        return _empty_buckets(pair_vectors, pair_distances, n_cat, i.device)
    nbr_cat = node_cat[j]
    # Sort globally by (source, neighbor category, neighbor id) so each source
    # neighborhood becomes one contiguous CSR row with category blocks.
    n_nodes = int(node_cat.numel())
    row_sort_key = (i * int(n_cat) + nbr_cat.to(torch.int64)) * n_nodes + j
    perm = torch.argsort(row_sort_key, stable=True)
    i = i[perm]
    j = j[perm]
    nbr_cat = nbr_cat[perm]
    pair_vectors = pair_vectors[perm]
    pair_distances = pair_distances[perm]
    # A triplet needs two neighbors, so single-neighbor sources are dropped.
    src_vals, counts = torch.unique_consecutive(i, return_counts=True)
    keep_src = counts >= 2
    if not torch.any(keep_src):
        return _empty_buckets(pair_vectors, pair_distances, n_cat, i.device)
    edge_keep = torch.repeat_interleave(keep_src, counts)
    i = i[edge_keep]
    j = j[edge_keep]
    nbr_cat = nbr_cat[edge_keep]
    pair_vectors = pair_vectors[edge_keep]
    pair_distances = pair_distances[edge_keep]
    src_ids = src_vals[keep_src]
    degrees = counts[keep_src]
    # Per-source histogram of neighbor categories, shape (n_sources, n_cat).
    source_index = torch.repeat_interleave(
        torch.arange(src_ids.numel(), device=i.device, dtype=torch.int64),
        degrees,
    )
    counts_by_source_cat = torch.bincount(
        source_index * n_cat + nbr_cat.to(torch.int64),
        minlength=src_ids.numel() * n_cat,
    ).reshape(src_ids.numel(), n_cat)
    # Pattern row = [source category, histogram...]; identical rows are merged.
    src_cat = node_cat[src_ids].to(torch.int64)
    pattern_rows = torch.cat([src_cat[:, None], counts_by_source_cat], dim=1)
    patterns, pattern_index = torch.unique(
        pattern_rows,
        dim=0,
        return_inverse=True,
        sorted=True,
    )
    # Stable sort groups sources by pattern while keeping their relative order.
    source_order = torch.argsort(pattern_index, stable=True)
    original_row_ptr = torch.cat(
        [torch.zeros((1,), dtype=torch.int64, device=i.device), degrees.cumsum(0)]
    )
    ordered_degrees = degrees[source_order]
    ordered_edge_starts = original_row_ptr[source_order]
    ordered_row_starts = torch.cat(
        [
            torch.zeros((1,), dtype=torch.int64, device=i.device),
            ordered_degrees.cumsum(0)[:-1],
        ]
    )
    # edge_order[k] = position (before regrouping) of the k-th edge once the
    # sources are laid out in pattern order; each source's edge slice is moved
    # as a whole.
    edge_offsets = torch.arange(i.numel(), device=i.device, dtype=torch.int64)
    edge_offsets = edge_offsets - torch.repeat_interleave(
        ordered_row_starts,
        ordered_degrees,
    )
    edge_order = (
        torch.repeat_interleave(ordered_edge_starts, ordered_degrees) + edge_offsets
    )
    src_ids = src_ids[source_order]
    degrees = ordered_degrees
    pattern_index = pattern_index[source_order]
    _, pattern_counts = torch.unique_consecutive(pattern_index, return_counts=True)
    row_ptr = torch.cat(
        [torch.zeros((1,), dtype=torch.int64, device=i.device), degrees.cumsum(0)]
    )
    pattern_ptr = torch.cat(
        [
            torch.zeros((1,), dtype=torch.int64, device=i.device),
            pattern_counts.cumsum(0),
        ]
    )
    return Buckets(
        src_ids=src_ids,
        row_ptr=row_ptr,
        nbr_ids=j[edge_order],
        pair_vectors=pair_vectors[edge_order],
        pair_distances=pair_distances[edge_order],
        patterns=patterns,
        pattern_ptr=pattern_ptr,
    )
def samecat_pairs(
    count: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(row, col)`` indices of every pair ``row < col`` among ``count`` items.

    The result is memoized per ``(count, str(device))`` in ``_PAIR_CACHE_SAMECAT``.
    The same tensors are returned to every caller.
    """
    cache_key = (count, str(device))
    cached = _PAIR_CACHE_SAMECAT.get(cache_key)
    if cached is None:
        upper = torch.triu_indices(count, count, offset=1)
        cached = (
            upper[0].to(device, non_blocking=True),
            upper[1].to(device, non_blocking=True),
        )
        _PAIR_CACHE_SAMECAT[cache_key] = cached
    return cached
def crosscat_pairs(
    ca: int,
    cb: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(row, col)`` indices of the full ``ca x cb`` grid in row-major order.

    The result is memoized per ``(ca, cb, str(device))`` in ``_PAIR_CACHE_CROSSCAT``.
    The same tensors are returned to every caller.
    """
    cache_key = (ca, cb, str(device))
    cached = _PAIR_CACHE_CROSSCAT.get(cache_key)
    if cached is None:
        # Row r is paired with every column before moving on to row r + 1.
        grid_rows = torch.arange(ca, dtype=torch.int64).repeat_interleave(cb)
        grid_cols = torch.arange(cb, dtype=torch.int64).repeat(ca)
        cached = (
            grid_rows.to(device, non_blocking=True),
            grid_cols.to(device, non_blocking=True),
        )
        _PAIR_CACHE_CROSSCAT[cache_key] = cached
    return cached
def pattern_triplet_layout(
    counts: Sequence[int],
    device: torch.device,
) -> PatternTripletLayout:
    """Return the cached triplet layout for one neighbor-count pattern.

    ``counts[c]`` is the number of neighbors of category ``c``. The neighbors
    are assumed to be laid out in consecutive category blocks within a source's
    row, as ``preprocess_sources`` produces.

    Pairs are emitted category pair by category pair: for each ``cat_a``, first
    the pairs inside ``cat_a``, then the pairs between ``cat_a`` and each later
    ``cat_b``. Each pair carries its ``edge_category_index``. Layouts are
    memoized per ``(counts, str(device))``.
    """
    normalized_counts = tuple(int(count) for count in counts)
    cache_key = (normalized_counts, str(device))
    cached = _PATTERN_TRIPLET_LAYOUT_CACHE.get(cache_key)
    if cached is not None:
        return cached
    n_cat = len(normalized_counts)
    # Exclusive prefix sums: first slot of each category block.
    block_starts = [0]
    for size in normalized_counts[:-1]:
        block_starts.append(block_starts[-1] + size)
    rows: list[torch.Tensor] = []
    cols: list[torch.Tensor] = []
    edges: list[torch.Tensor] = []
    for cat_a, (count_a, start_a) in enumerate(zip(normalized_counts, block_starts)):
        if count_a == 0:
            continue
        same_row, same_col = samecat_pairs(count_a, device)
        if same_row.numel() > 0:
            rows.append(start_a + same_row)
            cols.append(start_a + same_col)
            edges.append(
                torch.full_like(same_row, edge_category_index(cat_a, cat_a, n_cat))
            )
        for cat_b in range(cat_a + 1, n_cat):
            count_b = normalized_counts[cat_b]
            if count_b == 0:
                continue
            cross_row, cross_col = crosscat_pairs(count_a, count_b, device)
            if cross_row.numel() == 0:
                continue
            rows.append(start_a + cross_row)
            cols.append(block_starts[cat_b] + cross_col)
            edges.append(
                torch.full_like(cross_row, edge_category_index(cat_a, cat_b, n_cat))
            )
    if rows:
        layout = PatternTripletLayout(
            row=torch.cat(rows, dim=0),
            col=torch.cat(cols, dim=0),
            edge_cat=torch.cat(edges, dim=0),
        )
    else:
        nothing = torch.zeros((0,), dtype=torch.int64, device=device)
        layout = PatternTripletLayout(row=nothing, col=nothing, edge_cat=nothing)
    _PATTERN_TRIPLET_LAYOUT_CACHE[cache_key] = layout
    return layout
def build_bucket_pattern_plans(
    patterns: torch.Tensor,
    pattern_ptr: torch.Tensor,
    row_ptr: torch.Tensor,
    device: torch.device,
) -> tuple[BucketPatternPlan, ...]:
    """Build reusable Python metadata for walking bucketed source patterns.

    Args:
        patterns: Pattern rows ``[source category, neighbor count per category...]``.
        pattern_ptr: Offsets into the source rows, one segment per pattern.
        row_ptr: CSR offsets into the packed pairs, one segment per source.
        device: Device on which each pattern's triplet layout is built.

    Returns:
        One ``BucketPatternPlan`` per pattern row, in pattern order.
    """
    # Convert each index tensor to Python lists once (one host transfer each)
    # instead of indexing tensors element-by-element inside the loop.
    pattern_rows = patterns.detach().cpu().tolist()
    pattern_bounds = pattern_ptr.detach().cpu().tolist()
    row_bounds = row_ptr.detach().cpu().tolist()
    plans: list[BucketPatternPlan] = []
    for pattern_index, pattern in enumerate(pattern_rows):
        counts = tuple(int(value) for value in pattern[1:])
        src_start = int(pattern_bounds[pattern_index])
        src_end = int(pattern_bounds[pattern_index + 1])
        plans.append(
            BucketPatternPlan(
                src_cat=int(pattern[0]),
                counts=counts,
                src_start=src_start,
                src_end=src_end,
                edge_start=int(row_bounds[src_start]),
                edge_end=int(row_bounds[src_end]),
                layout=pattern_triplet_layout(counts, device),
            )
        )
    return tuple(plans)
def tensorize_bucket_pattern_plans(
    pattern_plans: Sequence[BucketPatternPlan],
    device: torch.device,
) -> TensorBucketPatternPlans:
    """Flatten Python bucket pattern plans into tensor-only metadata on ``device``.

    Per-plan scalars become int64 vectors. Every plan's triplet layout is
    concatenated into ``row``/``col``/``edge_cat``. ``layout_ptr`` holds the
    running offsets of those concatenated segments.
    """
    int_kwargs = {"dtype": torch.int64, "device": device}
    if len(pattern_plans) == 0:
        nothing = torch.empty((0,), **int_kwargs)
        return TensorBucketPatternPlans(
            src_cat=nothing,
            src_start=nothing,
            src_end=nothing,
            edge_start=nothing,
            edge_end=nothing,
            degree=nothing,
            layout_ptr=torch.zeros((1,), **int_kwargs),
            row=nothing,
            col=nothing,
            edge_cat=nothing,
        )

    def as_int_tensor(values: list[int]) -> torch.Tensor:
        return torch.tensor(values, **int_kwargs)

    layouts = [plan.layout for plan in pattern_plans]
    segment_sizes = as_int_tensor([layout.row.numel() for layout in layouts])
    return TensorBucketPatternPlans(
        src_cat=as_int_tensor([plan.src_cat for plan in pattern_plans]),
        src_start=as_int_tensor([plan.src_start for plan in pattern_plans]),
        src_end=as_int_tensor([plan.src_end for plan in pattern_plans]),
        edge_start=as_int_tensor([plan.edge_start for plan in pattern_plans]),
        edge_end=as_int_tensor([plan.edge_end for plan in pattern_plans]),
        degree=as_int_tensor([sum(plan.counts) for plan in pattern_plans]),
        layout_ptr=torch.cat(
            [torch.zeros((1,), **int_kwargs), segment_sizes.cumsum(0)]
        ),
        row=torch.cat([layout.row.to(device=device) for layout in layouts], dim=0),
        col=torch.cat([layout.col.to(device=device) for layout in layouts], dim=0),
        edge_cat=torch.cat(
            [layout.edge_cat.to(device=device) for layout in layouts], dim=0
        ),
    )
def build_tensor_bucket_pattern_plans(
    patterns: torch.Tensor,
    pattern_ptr: torch.Tensor,
    row_ptr: torch.Tensor,
    device: torch.device,
) -> TensorBucketPatternPlans:
    """Build tensorized pattern metadata from bucket CSR structures.

    This goes through the host-side ``BucketPatternPlan`` path first and then
    flattens the plans into tensors on ``device``.
    """
    python_plans = build_bucket_pattern_plans(patterns, pattern_ptr, row_ptr, device)
    return tensorize_bucket_pattern_plans(python_plans, device)
def _masked_triplet_block(
    base_row: torch.Tensor,
    base_col: torch.Tensor,
    row_offset: torch.Tensor,
    col_offset: torch.Tensor,
    valid: torch.Tensor,
    edge_category: int,
    plan_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Expand one category-pair slot grid over all plans and keep the valid slots.

    Args:
        base_row: Local row slot of each candidate pair, shape ``(n_slots,)``.
        base_col: Local column slot of each candidate pair, shape ``(n_slots,)``.
        row_offset: Per-plan start of the row category block, shape ``(n_plans,)``.
        col_offset: Per-plan start of the column category block, shape ``(n_plans,)``.
        valid: Boolean mask of shape ``(n_plans, n_slots)`` marking the slots
            that exist for each plan.
        edge_category: Edge-category index shared by every pair of the block.
        plan_ids: ``arange(n_plans)``.

    Returns:
        ``(row, col, edge_cat, plan, local_order)`` over the valid slots. Boolean
        masking flattens row-major, so entries are grouped by plan and keep the
        base slot order within each plan.
    """
    row = (row_offset[:, None] + base_row[None, :])[valid]
    col = (col_offset[:, None] + base_col[None, :])[valid]
    plan = plan_ids[:, None].expand_as(valid)[valid]
    local_order = torch.arange(
        base_row.numel(), dtype=torch.int64, device=valid.device
    ).expand_as(valid)[valid]
    # full_like sizes the tensor from ``plan``, so no host sync is needed here.
    edge_cat = torch.full_like(plan, edge_category)
    return row, col, edge_cat, plan, local_order


def build_tensor_bucket_pattern_plans_torch(
    patterns: torch.Tensor,
    pattern_ptr: torch.Tensor,
    row_ptr: torch.Tensor,
    device: torch.device | None = None,
) -> TensorBucketPatternPlans:
    """
    Build tensorized pattern metadata with Torch tensor operations.
    This path is intended for CUDA dynamic inference, where the native evaluator
    consumes tensor metadata and the Python ``BucketPatternPlan`` tuple would
    otherwise force all pattern rows through host-side loops.

    Triplet slots are generated for all patterns at once, one category pair at
    a time. They are then sorted so that, inside every plan, they follow the
    same order as ``pattern_triplet_layout``: category pairs in
    ``edge_category_index`` order, and row-major slot order within each pair.

    Args:
        patterns: Tensor of source-category and neighbor-count patterns.
        pattern_ptr: CSR pointer into source rows for each pattern.
        row_ptr: CSR pointer into neighbor rows for each source.
        device: Optional destination device for the tensorized metadata.
    Returns:
        Tensorized pattern plans on the resolved device.
    """
    resolved_device = patterns.device if device is None else torch.device(device)
    int_kwargs = {"dtype": torch.int64, "device": resolved_device}
    n_plans = int(patterns.shape[0])
    if n_plans == 0:
        empty = torch.empty((0,), **int_kwargs)
        return TensorBucketPatternPlans(
            src_cat=empty,
            src_start=empty,
            src_end=empty,
            edge_start=empty,
            edge_end=empty,
            degree=empty,
            layout_ptr=torch.zeros((1,), **int_kwargs),
            row=empty,
            col=empty,
            edge_cat=empty,
        )
    patterns = patterns.to(device=resolved_device, dtype=torch.int64)
    pattern_ptr = pattern_ptr.to(device=resolved_device, dtype=torch.int64)
    row_ptr = row_ptr.to(device=resolved_device, dtype=torch.int64)
    counts = patterns[:, 1:]
    n_cat = int(counts.shape[1])
    src_start = pattern_ptr[:-1]
    src_end = pattern_ptr[1:]
    degree = counts.sum(dim=1)
    max_degree = int(degree.detach().max().cpu().item())
    if max_degree < 2:
        # No pattern has two neighbors, so no triplets exist.
        empty = torch.empty((0,), **int_kwargs)
        return TensorBucketPatternPlans(
            src_cat=patterns[:, 0],
            src_start=src_start,
            src_end=src_end,
            edge_start=row_ptr.index_select(0, src_start),
            edge_end=row_ptr.index_select(0, src_end),
            degree=degree,
            layout_ptr=torch.zeros((n_plans + 1,), **int_kwargs),
            row=empty,
            col=empty,
            edge_cat=empty,
        )
    # Exclusive prefix sums: first neighbor slot of each category block per plan.
    offsets = torch.cat(
        [
            torch.zeros((n_plans, 1), **int_kwargs),
            counts.cumsum(dim=1)[:, :-1],
        ],
        dim=1,
    )
    plan_ids = torch.arange(n_plans, **int_kwargs)
    layout_lengths = torch.zeros((n_plans,), **int_kwargs)
    # Fetch every per-category maximum with one host sync instead of one sync
    # per category pair.
    max_counts = [int(value) for value in counts.detach().amax(dim=0).cpu().tolist()]
    blocks: list[tuple[torch.Tensor, ...]] = []
    for cat_a in range(n_cat):
        count_a = counts[:, cat_a]
        offset_a = offsets[:, cat_a]
        max_a = max_counts[cat_a]
        if max_a > 1:
            base_row, base_col = torch.triu_indices(
                max_a,
                max_a,
                offset=1,
                dtype=torch.int64,
                device=resolved_device,
            )
            layout_lengths += count_a * (count_a - 1) // 2
            # base_row < base_col, so bounding the column bounds both slots.
            valid = base_col.unsqueeze(0) < count_a.unsqueeze(1)
            blocks.append(
                _masked_triplet_block(
                    base_row,
                    base_col,
                    offset_a,
                    offset_a,
                    valid,
                    edge_category_index(cat_a, cat_a, n_cat),
                    plan_ids,
                )
            )
        for cat_b in range(cat_a + 1, n_cat):
            count_b = counts[:, cat_b]
            layout_lengths += count_a * count_b
            max_b = max_counts[cat_b]
            if max_a == 0 or max_b == 0:
                continue
            base_row = torch.arange(max_a, **int_kwargs).repeat_interleave(max_b)
            base_col = torch.arange(max_b, **int_kwargs).repeat(max_a)
            valid = (base_row.unsqueeze(0) < count_a.unsqueeze(1)) & (
                base_col.unsqueeze(0) < count_b.unsqueeze(1)
            )
            blocks.append(
                _masked_triplet_block(
                    base_row,
                    base_col,
                    offset_a,
                    offsets[:, cat_b],
                    valid,
                    edge_category_index(cat_a, cat_b, n_cat),
                    plan_ids,
                )
            )
    layout_ptr = torch.cat(
        [
            torch.zeros((1,), **int_kwargs),
            layout_lengths.cumsum(dim=0),
        ],
    )
    if not blocks:
        empty = torch.empty((0,), **int_kwargs)
        row = empty
        col = empty
        edge_cat = empty
    else:
        row_all, col_all, edge_cat_all, plan_all, local_all = (
            torch.cat(parts) for parts in zip(*blocks)
        )
        # Each local slot index is < max_degree**2 and each edge category is
        # < num_edge_categories(n_cat), so this key orders entries by
        # (plan, category pair, local slot) without collisions.
        max_local_slots = max(max_degree * max_degree, 1)
        sort_key = (
            plan_all * num_edge_categories(n_cat) + edge_cat_all
        ) * max_local_slots + local_all
        order = torch.argsort(sort_key, stable=True)
        row = row_all.index_select(0, order).contiguous()
        col = col_all.index_select(0, order).contiguous()
        edge_cat = edge_cat_all.index_select(0, order).contiguous()
    return TensorBucketPatternPlans(
        src_cat=patterns[:, 0],
        src_start=src_start,
        src_end=src_end,
        edge_start=row_ptr.index_select(0, src_start),
        edge_end=row_ptr.index_select(0, src_end),
        degree=degree,
        layout_ptr=layout_ptr,
        row=row,
        col=col,
        edge_cat=edge_cat,
    )
def pair_distance_partials(
    ex: torch.Tensor,
    ey: torch.Tensor,
    ez: torch.Tensor,
    vj: torch.Tensor,
    vk: torch.Tensor,
    diff: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Chain per-distance derivatives onto the two pair vectors.

    Computes
    ``dE_dvj = ex * vj / x + ez * diff / z`` and
    ``dE_dvk = ey * vk / y - ez * diff / z`` row by row. Any distance ``<= eps``
    contributes zero instead of dividing by it.

    NOTE(review): the signs assume ``x``, ``y``, ``z`` are the norms of ``vj``,
    ``vk`` and ``diff`` with ``diff = vj - vk``. Confirm this with the callers.

    Args:
        ex, ey, ez: Per-row derivatives with respect to ``x``, ``y``, ``z``.
        vj, vk, diff: Row vectors of shape ``(n, 3)``.
        x, y, z: Per-row distances.
        eps: Distances at or below this threshold are treated as degenerate.

    Returns:
        ``(dE_dvj, dE_dvk)``, each with the shape of ``vj``.
    """

    def _safe_inverse(dist: torch.Tensor) -> torch.Tensor:
        # Replace masked entries by 1 *before* the reciprocal. Otherwise the
        # unselected branch holds inf, and torch.where backpropagates 0 * inf = NaN.
        keep = dist > eps
        denom = torch.where(keep, dist, torch.ones_like(dist))
        return torch.where(keep, denom.reciprocal(), torch.zeros_like(dist))

    inv_x = _safe_inverse(x)
    inv_y = _safe_inverse(y)
    inv_z = _safe_inverse(z)
    unit_z = diff * inv_z[:, None]
    dE_dvj = ex[:, None] * (vj * inv_x[:, None]) + ez[:, None] * unit_z
    dE_dvk = ey[:, None] * (vk * inv_y[:, None]) - ez[:, None] * unit_z
    return dE_dvj, dE_dvk
def pair_distance_partials_batched(
    ex: torch.Tensor,
    ey: torch.Tensor,
    ez: torch.Tensor,
    vj: torch.Tensor,
    vk: torch.Tensor,
    diff: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Chain per-distance derivatives onto the pair vectors for several stencil columns.

    Same formula as ``pair_distance_partials``, but ``ex``, ``ey``, ``ez`` carry
    an extra second axis. The result is
    ``dE_dvj[n, s] = ex[n, s] * vj[n] / x[n] + ez[n, s] * diff[n] / z[n]`` and
    ``dE_dvk[n, s] = ey[n, s] * vk[n] / y[n] - ez[n, s] * diff[n] / z[n]``.
    Distances ``<= eps`` contribute zero.

    NOTE(review): the signs assume ``diff = vj - vk``. Confirm this with the callers.

    Returns:
        ``(dE_dvj, dE_dvk)`` of shape ``(n, s, 3)``.
    """

    def _safe_inverse(dist: torch.Tensor) -> torch.Tensor:
        # Replace masked entries by 1 *before* the reciprocal. Otherwise the
        # unselected branch holds inf, and torch.where backpropagates 0 * inf = NaN.
        keep = dist > eps
        denom = torch.where(keep, dist, torch.ones_like(dist))
        return torch.where(keep, denom.reciprocal(), torch.zeros_like(dist))

    inv_x = _safe_inverse(x)
    inv_y = _safe_inverse(y)
    inv_z = _safe_inverse(z)
    unit_x = vj * inv_x[:, None]
    unit_y = vk * inv_y[:, None]
    unit_z = diff * inv_z[:, None]
    dE_dvj = ex[:, :, None] * unit_x[:, None, :] + ez[:, :, None] * unit_z[:, None, :]
    dE_dvk = ey[:, :, None] * unit_y[:, None, :] - ez[:, :, None] * unit_z[:, None, :]
    return dE_dvj, dE_dvk
def map_atomic_numbers_to_categories(
    atomic_numbers: torch.Tensor,
    atomic_types: Sequence[int],
) -> torch.Tensor:
    """Replace each atomic number with its position in ``atomic_types``.

    Numbers absent from ``atomic_types`` map to ``-1``. If a number is listed
    more than once, its last position wins. The result is int64 and has the
    shape of ``atomic_numbers``.
    """
    categories = torch.full_like(atomic_numbers, fill_value=-1, dtype=torch.int64)
    for position, number in enumerate(atomic_types):
        matches = atomic_numbers == int(number)
        categories.masked_fill_(matches, position)
    return categories
def shortest_unique_directed_pairs(
    first: torch.Tensor,
    second: torch.Tensor,
    pair_vectors: torch.Tensor,
    *,
    n_nodes: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Keep one row per directed pair ``first -> second``, namely its shortest image.

    Duplicates are resolved by ``shortest_unique_directed_pair_indices`` using
    the norms of ``pair_vectors``. The same selection is applied to all three
    inputs.
    """
    retained = shortest_unique_directed_pair_indices(
        first,
        second,
        n_nodes=n_nodes,
        pair_vectors=pair_vectors,
    )
    return tuple(values[retained] for values in (first, second, pair_vectors))
def shortest_unique_directed_pair_indices(
first: torch.Tensor,
second: torch.Tensor,
*,
n_nodes: int,
pair_vectors: torch.Tensor | None = None,
pair_distances: torch.Tensor | None = None,
pair_shifts: torch.Tensor | None = None,
atol: float = 1.0e-12,
) -> torch.Tensor:
"""Return the retained row indices when duplicate directed periodic images exist."""
if first.ndim != 1 or second.ndim != 1 or first.shape != second.shape:
raise ValueError("`first` and `second` must be 1D tensors with the same shape")
if pair_distances is None:
if pair_vectors is None:
raise ValueError(
"either `pair_distances` or `pair_vectors` is required to resolve "
"duplicate directed pairs"
)
if pair_vectors.ndim != 2 or pair_vectors.shape != (first.numel(), 3):
raise ValueError("`pair_vectors` must have shape (n_pairs, 3)")
pair_distances = torch.linalg.norm(pair_vectors, dim=1)
else:
if pair_distances.ndim != 1 or pair_distances.shape != first.shape:
raise ValueError("`pair_distances` must have shape (n_pairs,)")
if pair_shifts is not None:
if pair_shifts.ndim != 2 or pair_shifts.shape != (first.numel(), 3):
raise ValueError("`pair_shifts` must have shape (n_pairs, 3)")
shift_norm = torch.sum(torch.abs(pair_shifts.to(torch.int64)), dim=1)
else:
shift_norm = None
pair_ids = first.to(torch.int64) * int(n_nodes) + second.to(torch.int64)
if torch.unique(pair_ids).numel() == pair_ids.numel():
return torch.arange(first.numel(), device=first.device, dtype=torch.int64)
order = torch.argsort(pair_ids, stable=True)
pair_ids_sorted = pair_ids[order]
distances_sorted = pair_distances[order]
_, counts = torch.unique_consecutive(pair_ids_sorted, return_counts=True)
offsets = torch.cat([counts.new_zeros(1), counts.cumsum(0)[:-1]])
kept_indices: list[torch.Tensor] = []
for offset, count in zip(offsets.tolist(), counts.tolist(), strict=True):
segment = slice(offset, offset + count)
local_distances = distances_sorted[segment]
min_distance = torch.min(local_distances)
candidates = torch.nonzero(
torch.isclose(
local_distances,
min_distance,
rtol=0.0,
atol=float(atol),
)
).reshape(-1)
if candidates.numel() > 1 and shift_norm is not None:
local_shift_norm = shift_norm[order[segment]].index_select(0, candidates)
chosen = candidates[int(torch.argmin(local_shift_norm).item())]
else:
chosen = candidates[0]
kept_indices.append(order[offset + int(chosen.item())])
return torch.stack(kept_indices)
# Public API. BucketPatternPlan, build_bucket_pattern_plans and
# edge_category_index are exported alongside their already-public siblings
# (Buckets.pattern_plans, tensorize_bucket_pattern_plans,
# build_edge_category_table).
__all__ = [
    "BucketItem",
    "BucketPatternPlan",
    "Buckets",
    "PatternTripletLayout",
    "TensorBucketPatternPlans",
    "build_bucket_pattern_plans",
    "build_edge_category_table",
    "build_tensor_bucket_pattern_plans",
    "build_tensor_bucket_pattern_plans_torch",
    "crosscat_pairs",
    "edge_category_index",
    "map_atomic_numbers_to_categories",
    "num_edge_categories",
    "pair_distance_partials",
    "pair_distance_partials_batched",
    "pattern_triplet_layout",
    "preprocess_sources",
    "samecat_pairs",
    "shortest_unique_directed_pair_indices",
    "shortest_unique_directed_pairs",
    "tensorize_bucket_pattern_plans",
]