Source code for ufp.terms._threebody_ops

"""
Low-level combinatorics for three-body spline assembly.

This module groups neighbors by category and exposes reusable bucket, edge, and
distance-derivative helpers shared by forward and fitting code.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Sequence

import torch


BucketItem = tuple[int, torch.Tensor, torch.Tensor]


[docs] @dataclass(frozen=True) class Buckets: """Flattened CSR-like neighborhood batches grouped by shared source patterns.""" src_ids: torch.Tensor row_ptr: torch.Tensor nbr_ids: torch.Tensor pair_vectors: torch.Tensor pair_distances: torch.Tensor patterns: torch.Tensor pattern_ptr: torch.Tensor pattern_plans: tuple["BucketPatternPlan", ...] = field(default_factory=tuple) tensor_pattern_plans: "TensorBucketPatternPlans | None" = None def __bool__(self) -> bool: """Report whether any valid source neighborhoods were retained.""" return bool(self.src_ids.numel())
[docs] def to_input_device( self, *, device: torch.device, dtype: torch.dtype, ) -> "Buckets": """Return a copy on the requested input device and floating dtype.""" return Buckets( src_ids=self.src_ids.to(device=device, non_blocking=True), row_ptr=self.row_ptr.to(device=device, non_blocking=True), nbr_ids=self.nbr_ids.to(device=device, non_blocking=True), pair_vectors=self.pair_vectors.to( device=device, dtype=dtype, non_blocking=True, ), pair_distances=self.pair_distances.to( device=device, dtype=dtype, non_blocking=True, ), patterns=self.patterns.to(device=device, non_blocking=True), pattern_ptr=self.pattern_ptr.to(device=device, non_blocking=True), pattern_plans=tuple(plan.to_device(device) for plan in self.pattern_plans), tensor_pattern_plans=( None if self.tensor_pattern_plans is None else self.tensor_pattern_plans.to_device(device) ), )
[docs] def with_pattern_plans(self, device: torch.device | None = None) -> "Buckets": """Return buckets with reusable CPU-side pattern metadata attached.""" resolved_device = self.pair_vectors.device if device is None else device pattern_plans = build_bucket_pattern_plans( self.patterns, self.pattern_ptr, self.row_ptr, torch.device(resolved_device), ) return Buckets( src_ids=self.src_ids, row_ptr=self.row_ptr, nbr_ids=self.nbr_ids, pair_vectors=self.pair_vectors, pair_distances=self.pair_distances, patterns=self.patterns, pattern_ptr=self.pattern_ptr, pattern_plans=pattern_plans, tensor_pattern_plans=tensorize_bucket_pattern_plans( pattern_plans, torch.device(resolved_device), ), )
@dataclass(frozen=True)
class PatternTripletLayout:
    """Triplet index layout cached for one neighbor-category histogram."""

    row: torch.Tensor
    col: torch.Tensor
    edge_cat: torch.Tensor

    def to_device(self, device: torch.device) -> "PatternTripletLayout":
        """Return the layout on ``device``; ``self`` is reused when already there."""
        if self.row.device == device:
            return self
        row, col, edge_cat = (
            tensor.to(device=device, non_blocking=True)
            for tensor in (self.row, self.col, self.edge_cat)
        )
        return PatternTripletLayout(row=row, col=col, edge_cat=edge_cat)


@dataclass(frozen=True)
class BucketPatternPlan:
    """Python-side bucket bounds plus the triplet layout of one source pattern."""

    src_cat: int
    counts: tuple[int, ...]
    src_start: int
    src_end: int
    edge_start: int
    edge_end: int
    layout: PatternTripletLayout

    def to_device(self, device: torch.device) -> "BucketPatternPlan":
        """Return a new plan whose layout has been moved to ``device``."""
        moved_layout = self.layout.to_device(device)
        return BucketPatternPlan(
            src_cat=self.src_cat,
            counts=self.counts,
            src_start=self.src_start,
            src_end=self.src_end,
            edge_start=self.edge_start,
            edge_end=self.edge_end,
            layout=moved_layout,
        )


@dataclass(frozen=True)
class TensorBucketPatternPlans:
    """Bucket pattern metadata stored purely as tensors."""

    src_cat: torch.Tensor
    src_start: torch.Tensor
    src_end: torch.Tensor
    edge_start: torch.Tensor
    edge_end: torch.Tensor
    degree: torch.Tensor
    layout_ptr: torch.Tensor
    row: torch.Tensor
    col: torch.Tensor
    edge_cat: torch.Tensor

    def __bool__(self) -> bool:
        """Return ``True`` when at least one source pattern is described."""
        return self.src_cat.numel() != 0

    def to_device(self, device: torch.device) -> "TensorBucketPatternPlans":
        """Return the metadata on ``device``; ``self`` is reused when already there."""
        if self.src_cat.device == device:
            return self
        tensor_fields = (
            "src_cat",
            "src_start",
            "src_end",
            "edge_start",
            "edge_end",
            "degree",
            "layout_ptr",
            "row",
            "col",
            "edge_cat",
        )
        moved = {
            name: getattr(self, name).to(device=device, non_blocking=True)
            for name in tensor_fields
        }
        return TensorBucketPatternPlans(**moved)


# Module-level memoization tables. Keys combine the block sizes (or the full
# count histogram) with ``str(device)``.
_PAIR_CACHE_SAMECAT: dict[tuple[int, str], tuple[torch.Tensor, torch.Tensor]] = {}
_PAIR_CACHE_CROSSCAT: dict[tuple[int, int, str], tuple[torch.Tensor, torch.Tensor]] = {}
_PATTERN_TRIPLET_LAYOUT_CACHE: dict[
    tuple[tuple[int, ...], str],
    PatternTripletLayout,
] = {}


def num_edge_categories(n_cat: int) -> int:
    """Count the unordered category pairs (self-pairs included) for ``n_cat``."""
    doubled_pair_count = n_cat * (n_cat + 1)
    return doubled_pair_count // 2


def build_edge_category_table(n_cat: int, device=None) -> torch.Tensor:
    """Map ordered category pairs to unordered edge-category indices.

    Args:
        n_cat: Number of node categories.
        device: Optional device for the returned table.

    Returns:
        A symmetric ``(n_cat, n_cat)`` int64 tensor where entry ``[a, b]`` is
        the index of the unordered pair ``{a, b}``. Pairs are numbered in
        row-major order over the upper triangle, diagonal included.
    """
    table = torch.empty((n_cat, n_cat), dtype=torch.int64, device=device)
    # ``triu_indices`` walks the upper triangle row by row, which is the same
    # order as a nested ``for a / for b in range(a, n_cat)`` enumeration. Two
    # indexed assignments replace 2 * n_pairs scalar writes.
    upper_row, upper_col = torch.triu_indices(n_cat, n_cat, device=table.device)
    edge_ids = torch.arange(
        upper_row.numel(),
        dtype=torch.int64,
        device=table.device,
    )
    table[upper_row, upper_col] = edge_ids
    table[upper_col, upper_row] = edge_ids
    return table


def edge_category_index(cat_a: int, cat_b: int, n_cat: int) -> int:
    """Return the index of the unordered pair ``{cat_a, cat_b}``.

    Pairs are numbered row by row over the upper triangle (diagonal included)
    of an ``n_cat`` x ``n_cat`` grid, so argument order does not matter.
    """
    low, high = sorted((int(cat_a), int(cat_b)))
    # Index of the diagonal entry (low, low): the rows above it contribute
    # n_cat, n_cat - 1, ... entries.
    row_start = low * int(n_cat) - low * (low - 1) // 2
    return row_start + (high - low)


def preprocess_sources(
    i: torch.Tensor,
    j: torch.Tensor,
    node_cat: torch.Tensor,
    n_cat: int,
    pair_vectors: torch.Tensor,
    pair_distances: torch.Tensor | None = None,
) -> Buckets:
    """Group directed neighbor rows into source buckets for triplet assembly.

    Rows are sorted by (source, neighbor category, neighbor id), sources with
    fewer than two neighbors are dropped, and the remaining sources are
    reordered so that sources sharing the same pattern (source category plus
    per-category neighbor counts) are contiguous.

    Args:
        i: 1D tensor of source node ids, one per directed pair.
        j: 1D tensor of neighbor node ids with the same shape as ``i``.
        node_cat: Per-node category ids, indexed by node id.
        n_cat: Number of categories.
        pair_vectors: ``(n_pairs, 3)`` tensor of per-pair vectors.
        pair_distances: Optional ``(n_pairs,)`` tensor; when omitted it is
            computed as the vector norm of ``pair_vectors``.

    Returns:
        A ``Buckets`` instance without pattern plans attached. It is empty
        when there are no pairs or no source has at least two neighbors.

    Raises:
        ValueError: If any input has an unexpected rank or shape.
    """
    if i.ndim != 1 or j.ndim != 1 or i.shape != j.shape:
        raise ValueError("`i` and `j` must be 1D tensors with the same shape")
    if pair_vectors.ndim != 2 or pair_vectors.shape != (i.numel(), 3):
        raise ValueError("`pair_vectors` must have shape (n_pairs, 3)")
    if pair_distances is None:
        pair_distances = torch.linalg.vector_norm(pair_vectors, dim=1)
    elif pair_distances.ndim != 1 or pair_distances.shape != i.shape:
        raise ValueError("`pair_distances` must have shape (n_pairs,)")

    # No pairs at all: return an empty bucket set with valid CSR pointers.
    if i.numel() == 0:
        empty_int = torch.zeros((0,), dtype=torch.int64, device=i.device)
        return Buckets(
            src_ids=empty_int,
            row_ptr=torch.zeros((1,), dtype=torch.int64, device=i.device),
            nbr_ids=empty_int,
            pair_vectors=pair_vectors.new_zeros((0, 3)),
            pair_distances=pair_distances.new_zeros((0,)),
            patterns=torch.zeros((0, n_cat + 1), dtype=torch.int64, device=i.device),
            pattern_ptr=torch.zeros((1,), dtype=torch.int64, device=i.device),
        )

    nbr_cat = node_cat[j]
    # Sort globally by (source, neighbor category, neighbor id) so each source
    # neighborhood becomes one contiguous CSR row with category blocks.
    n_nodes = int(node_cat.numel())
    row_sort_key = (i * int(n_cat) + nbr_cat.to(torch.int64)) * n_nodes + j
    perm = torch.argsort(row_sort_key, stable=True)
    i = i[perm]
    j = j[perm]
    nbr_cat = nbr_cat[perm]
    pair_vectors = pair_vectors[perm]
    pair_distances = pair_distances[perm]

    # ``i`` is now sorted, so consecutive runs give each source and its degree.
    src_vals, counts = torch.unique_consecutive(i, return_counts=True)
    # Only sources with at least two neighbors are kept.
    keep_src = counts >= 2
    if not torch.any(keep_src):
        empty_int = torch.zeros((0,), dtype=torch.int64, device=i.device)
        return Buckets(
            src_ids=empty_int,
            row_ptr=torch.zeros((1,), dtype=torch.int64, device=i.device),
            nbr_ids=empty_int,
            pair_vectors=pair_vectors.new_zeros((0, 3)),
            pair_distances=pair_distances.new_zeros((0,)),
            patterns=torch.zeros((0, n_cat + 1), dtype=torch.int64, device=i.device),
            pattern_ptr=torch.zeros((1,), dtype=torch.int64, device=i.device),
        )

    # Expand the per-source keep flag to a per-pair mask and drop the rows of
    # rejected sources.
    edge_keep = torch.repeat_interleave(keep_src, counts)
    i = i[edge_keep]
    j = j[edge_keep]
    nbr_cat = nbr_cat[edge_keep]
    pair_vectors = pair_vectors[edge_keep]
    pair_distances = pair_distances[edge_keep]

    src_ids = src_vals[keep_src]
    degrees = counts[keep_src]
    # Dense 0..n_sources-1 index of the owning source for every kept pair.
    source_index = torch.repeat_interleave(
        torch.arange(src_ids.numel(), device=i.device, dtype=torch.int64),
        degrees,
    )
    # Histogram of neighbor categories per source, shape (n_sources, n_cat).
    counts_by_source_cat = torch.bincount(
        source_index * n_cat + nbr_cat.to(torch.int64),
        minlength=src_ids.numel() * n_cat,
    ).reshape(src_ids.numel(), n_cat)
    src_cat = node_cat[src_ids].to(torch.int64)
    # A pattern row is [source category, count per neighbor category].
    pattern_rows = torch.cat([src_cat[:, None], counts_by_source_cat], dim=1)
    patterns, pattern_index = torch.unique(
        pattern_rows,
        dim=0,
        return_inverse=True,
        sorted=True,
    )

    # Stable sort keeps the relative order of sources inside each pattern.
    source_order = torch.argsort(pattern_index, stable=True)
    original_row_ptr = torch.cat(
        [torch.zeros((1,), dtype=torch.int64, device=i.device), degrees.cumsum(0)]
    )
    ordered_degrees = degrees[source_order]
    ordered_edge_starts = original_row_ptr[source_order]
    ordered_row_starts = torch.cat(
        [
            torch.zeros((1,), dtype=torch.int64, device=i.device),
            ordered_degrees.cumsum(0)[:-1],
        ]
    )
    # For each output pair slot: its offset inside its (reordered) source row,
    # added to where that source's row started before the reorder. This yields
    # the gather indices that move whole rows into pattern order.
    edge_offsets = torch.arange(i.numel(), device=i.device, dtype=torch.int64)
    edge_offsets = edge_offsets - torch.repeat_interleave(
        ordered_row_starts,
        ordered_degrees,
    )
    edge_order = (
        torch.repeat_interleave(ordered_edge_starts, ordered_degrees) + edge_offsets
    )

    src_ids = src_ids[source_order]
    degrees = ordered_degrees
    pattern_index = pattern_index[source_order]
    # ``pattern_index`` is sorted now, so run lengths are sources per pattern.
    _, pattern_counts = torch.unique_consecutive(pattern_index, return_counts=True)
    row_ptr = torch.cat(
        [torch.zeros((1,), dtype=torch.int64, device=i.device), degrees.cumsum(0)]
    )
    pattern_ptr = torch.cat(
        [
            torch.zeros((1,), dtype=torch.int64, device=i.device),
            pattern_counts.cumsum(0),
        ]
    )
    return Buckets(
        src_ids=src_ids,
        row_ptr=row_ptr,
        nbr_ids=j[edge_order],
        pair_vectors=pair_vectors[edge_order],
        pair_distances=pair_distances[edge_order],
        patterns=patterns,
        pattern_ptr=pattern_ptr,
    )


def samecat_pairs(
    count: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return cached upper-triangular pair indices within one category block.

    The result is the strict upper triangle (``row < col``) of a
    ``count`` x ``count`` grid. It is memoized per ``(count, str(device))``,
    so callers must not modify the returned tensors in place.
    """
    key = (count, str(device))
    if key not in _PAIR_CACHE_SAMECAT:
        row, col = torch.triu_indices(count, count, offset=1)
        _PAIR_CACHE_SAMECAT[key] = (
            row.to(device, non_blocking=True),
            col.to(device, non_blocking=True),
        )
    return _PAIR_CACHE_SAMECAT[key]


def crosscat_pairs(
    ca: int,
    cb: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return cached Cartesian pair indices between two category blocks.

    Pairs are listed row-major: every ``row`` in ``range(ca)`` is combined
    with every ``col`` in ``range(cb)``. The result is memoized per
    ``(ca, cb, str(device))``, so callers must not modify it in place.
    """
    key = (ca, cb, str(device))
    if key not in _PAIR_CACHE_CROSSCAT:
        row = torch.arange(ca, dtype=torch.int64).repeat_interleave(cb)
        col = torch.arange(cb, dtype=torch.int64).repeat(ca)
        _PAIR_CACHE_CROSSCAT[key] = (
            row.to(device, non_blocking=True),
            col.to(device, non_blocking=True),
        )
    return _PAIR_CACHE_CROSSCAT[key]


def pattern_triplet_layout(
    counts: Sequence[int],
    device: torch.device,
) -> PatternTripletLayout:
    """Return cached triplet row/column offsets for one source pattern.

    Args:
        counts: Number of neighbors per category; neighbors are assumed to be
            stored as consecutive category blocks in this order.
        device: Device on which the layout tensors are created.

    Returns:
        A layout whose ``row``/``col`` entries are neighbor offsets inside one
        source row and whose ``edge_cat`` entry is the unordered category-pair
        index. For each category, same-category pairs come first, followed by
        cross pairs with every later category.
    """
    normalized_counts = tuple(int(count) for count in counts)
    key = (normalized_counts, str(device))
    if key in _PATTERN_TRIPLET_LAYOUT_CACHE:
        return _PATTERN_TRIPLET_LAYOUT_CACHE[key]

    n_cat = len(normalized_counts)
    # Exclusive prefix sum: start offset of each category block.
    offsets = [0]
    for count in normalized_counts[:-1]:
        offsets.append(offsets[-1] + count)

    row_parts: list[torch.Tensor] = []
    col_parts: list[torch.Tensor] = []
    edge_parts: list[torch.Tensor] = []
    for cat_a in range(n_cat):
        count_a = normalized_counts[cat_a]
        if count_a == 0:
            continue
        offset_a = offsets[cat_a]
        row, col = samecat_pairs(count_a, device)
        if row.numel() > 0:
            row_parts.append(offset_a + row)
            col_parts.append(offset_a + col)
            edge_parts.append(
                torch.full_like(
                    row,
                    edge_category_index(cat_a, cat_a, n_cat),
                )
            )
        for cat_b in range(cat_a + 1, n_cat):
            count_b = normalized_counts[cat_b]
            if count_b == 0:
                continue
            offset_b = offsets[cat_b]
            row, col = crosscat_pairs(count_a, count_b, device)
            if row.numel() == 0:
                continue
            row_parts.append(offset_a + row)
            col_parts.append(offset_b + col)
            edge_parts.append(
                torch.full_like(
                    row,
                    edge_category_index(cat_a, cat_b, n_cat),
                )
            )

    if row_parts:
        layout = PatternTripletLayout(
            row=torch.cat(row_parts, dim=0),
            col=torch.cat(col_parts, dim=0),
            edge_cat=torch.cat(edge_parts, dim=0),
        )
    else:
        empty = torch.zeros((0,), dtype=torch.int64, device=device)
        layout = PatternTripletLayout(
            row=empty,
            col=empty,
            edge_cat=empty,
        )
    _PATTERN_TRIPLET_LAYOUT_CACHE[key] = layout
    return layout


def build_bucket_pattern_plans(
    patterns: torch.Tensor,
    pattern_ptr: torch.Tensor,
    row_ptr: torch.Tensor,
    device: torch.device,
) -> tuple[BucketPatternPlan, ...]:
    """Build reusable Python metadata for walking bucketed source patterns.

    Args:
        patterns: One row per pattern: source category followed by the
            neighbor count of every category.
        pattern_ptr: CSR pointer from patterns into source rows.
        row_ptr: CSR pointer from sources into neighbor rows.
        device: Device for the triplet layout tensors of each plan.

    Returns:
        One ``BucketPatternPlan`` per pattern row, in pattern order.
    """
    patterns_host = (
        patterns if patterns.device.type == "cpu" else patterns.detach().cpu()
    )
    pattern_ptr_host = (
        pattern_ptr if pattern_ptr.device.type == "cpu" else pattern_ptr.detach().cpu()
    )
    row_ptr_host = row_ptr if row_ptr.device.type == "cpu" else row_ptr.detach().cpu()

    # Convert to Python lists once instead of indexing tensors element by
    # element inside the loop.
    pattern_rows = patterns_host.tolist()
    src_bounds = pattern_ptr_host.tolist()
    edge_bounds = row_ptr_host.tolist()

    plans: list[BucketPatternPlan] = []
    for pattern_index, pattern in enumerate(pattern_rows):
        counts = tuple(int(value) for value in pattern[1:])
        src_start = int(src_bounds[pattern_index])
        src_end = int(src_bounds[pattern_index + 1])
        plans.append(
            BucketPatternPlan(
                src_cat=int(pattern[0]),
                counts=counts,
                src_start=src_start,
                src_end=src_end,
                edge_start=int(edge_bounds[src_start]),
                edge_end=int(edge_bounds[src_end]),
                layout=pattern_triplet_layout(counts, device),
            )
        )
    return tuple(plans)


def tensorize_bucket_pattern_plans(
    pattern_plans: Sequence[BucketPatternPlan],
    device: torch.device,
) -> TensorBucketPatternPlans:
    """Convert Python bucket pattern plans to a tensor-only execution contract.

    Per-plan scalars become int64 vectors on ``device``. The per-plan triplet
    layouts are concatenated, with ``layout_ptr`` marking where each plan's
    slice starts and ends.
    """
    int_kwargs = {"dtype": torch.int64, "device": device}
    n_plans = len(pattern_plans)
    if n_plans == 0:
        empty = torch.empty((0,), **int_kwargs)
        return TensorBucketPatternPlans(
            src_cat=empty,
            src_start=empty,
            src_end=empty,
            edge_start=empty,
            edge_end=empty,
            degree=empty,
            layout_ptr=torch.zeros((1,), **int_kwargs),
            row=empty,
            col=empty,
            edge_cat=empty,
        )

    src_cat = torch.tensor([plan.src_cat for plan in pattern_plans], **int_kwargs)
    src_start = torch.tensor([plan.src_start for plan in pattern_plans], **int_kwargs)
    src_end = torch.tensor([plan.src_end for plan in pattern_plans], **int_kwargs)
    edge_start = torch.tensor([plan.edge_start for plan in pattern_plans], **int_kwargs)
    edge_end = torch.tensor([plan.edge_end for plan in pattern_plans], **int_kwargs)
    degree = torch.tensor([sum(plan.counts) for plan in pattern_plans], **int_kwargs)
    layout_lengths = torch.tensor(
        [plan.layout.row.numel() for plan in pattern_plans],
        **int_kwargs,
    )
    layout_ptr = torch.cat(
        [
            torch.zeros((1,), **int_kwargs),
            layout_lengths.cumsum(0),
        ],
    )
    row_parts = [plan.layout.row.to(device=device) for plan in pattern_plans]
    col_parts = [plan.layout.col.to(device=device) for plan in pattern_plans]
    edge_cat_parts = [plan.layout.edge_cat.to(device=device) for plan in pattern_plans]
    return TensorBucketPatternPlans(
        src_cat=src_cat,
        src_start=src_start,
        src_end=src_end,
        edge_start=edge_start,
        edge_end=edge_end,
        degree=degree,
        layout_ptr=layout_ptr,
        row=torch.cat(row_parts, dim=0),
        col=torch.cat(col_parts, dim=0),
        edge_cat=torch.cat(edge_cat_parts, dim=0),
    )


def build_tensor_bucket_pattern_plans(
    patterns: torch.Tensor,
    pattern_ptr: torch.Tensor,
    row_ptr: torch.Tensor,
    device: torch.device,
) -> TensorBucketPatternPlans:
    """Build tensorized pattern metadata directly from bucket CSR structures.

    This is the composition of ``build_bucket_pattern_plans`` and
    ``tensorize_bucket_pattern_plans``.
    """
    return tensorize_bucket_pattern_plans(
        build_bucket_pattern_plans(patterns, pattern_ptr, row_ptr, device),
        device,
    )


def build_tensor_bucket_pattern_plans_torch(
    patterns: torch.Tensor,
    pattern_ptr: torch.Tensor,
    row_ptr: torch.Tensor,
    device: torch.device | None = None,
) -> TensorBucketPatternPlans:
    """
    Build tensorized pattern metadata with Torch tensor operations.

    This path is intended for CUDA dynamic inference, where the native
    evaluator consumes tensor metadata and the Python ``BucketPatternPlan``
    tuple would otherwise force all pattern rows through host-side loops.

    Layout entries are generated for all plans at once, one category pair at
    a time, and then sorted by (plan, category pair, local pair position) so
    each plan's slice is contiguous and ordered like
    ``pattern_triplet_layout``.

    Args:
        patterns: Tensor of source-category and neighbor-count patterns.
        pattern_ptr: CSR pointer into source rows for each pattern.
        row_ptr: CSR pointer into neighbor rows for each source.
        device: Optional destination device for the tensorized metadata.

    Returns:
        Tensorized pattern plans on the resolved device.
    """
    resolved_device = patterns.device if device is None else torch.device(device)
    int_kwargs = {"dtype": torch.int64, "device": resolved_device}
    n_plans = int(patterns.shape[0])
    if n_plans == 0:
        empty = torch.empty((0,), **int_kwargs)
        return TensorBucketPatternPlans(
            src_cat=empty,
            src_start=empty,
            src_end=empty,
            edge_start=empty,
            edge_end=empty,
            degree=empty,
            layout_ptr=torch.zeros((1,), **int_kwargs),
            row=empty,
            col=empty,
            edge_cat=empty,
        )

    patterns = patterns.to(device=resolved_device, dtype=torch.int64)
    pattern_ptr = pattern_ptr.to(device=resolved_device, dtype=torch.int64)
    row_ptr = row_ptr.to(device=resolved_device, dtype=torch.int64)
    counts = patterns[:, 1:]
    n_cat = int(counts.shape[1])
    src_start = pattern_ptr[:-1]
    src_end = pattern_ptr[1:]
    degree = counts.sum(dim=1)
    max_degree = int(degree.detach().max().cpu().item())
    if max_degree < 2:
        # No plan has two neighbors, so no triplet pairs exist.
        empty = torch.empty((0,), **int_kwargs)
        return TensorBucketPatternPlans(
            src_cat=patterns[:, 0],
            src_start=src_start,
            src_end=src_end,
            edge_start=row_ptr.index_select(0, src_start),
            edge_end=row_ptr.index_select(0, src_end),
            degree=degree,
            layout_ptr=torch.zeros((n_plans + 1,), **int_kwargs),
            row=empty,
            col=empty,
            edge_cat=empty,
        )

    # Start offset of every category block inside each plan's source row.
    offsets = torch.cat(
        [
            torch.zeros((n_plans, 1), **int_kwargs),
            counts.cumsum(dim=1)[:, :-1],
        ],
        dim=1,
    )
    # Per-category maximum over plans, fetched with a single device-to-host
    # transfer instead of one sync per category pair.
    max_counts = [int(value) for value in counts.detach().max(dim=0).values.cpu().tolist()]

    plan_ids = torch.arange(n_plans, **int_kwargs)
    layout_lengths = torch.zeros((n_plans,), **int_kwargs)
    row_parts: list[torch.Tensor] = []
    col_parts: list[torch.Tensor] = []
    edge_cat_parts: list[torch.Tensor] = []
    plan_parts: list[torch.Tensor] = []
    pair_order_parts: list[torch.Tensor] = []
    local_order_parts: list[torch.Tensor] = []
    pair_order = 0
    for cat_a in range(n_cat):
        count_a = counts[:, cat_a]
        offset_a = offsets[:, cat_a]
        max_a = max_counts[cat_a]
        if max_a > 1:
            # Build the pair grid for the largest block, then mask it per plan.
            base_row, base_col = torch.triu_indices(
                max_a,
                max_a,
                offset=1,
                dtype=torch.int64,
                device=resolved_device,
            )
            # ``base_row < base_col`` already holds, so bounding the column
            # is enough to keep a pair inside a plan's block.
            valid = base_col.unsqueeze(0) < count_a.unsqueeze(1)
            layout_lengths += count_a * (count_a - 1) // 2
            row_values = offset_a[:, None] + base_row[None, :]
            col_values = offset_a[:, None] + base_col[None, :]
            local_order = torch.arange(base_row.numel(), **int_kwargs)
            picked_rows = row_values.expand(n_plans, -1)[valid]
            row_parts.append(picked_rows)
            col_parts.append(col_values.expand(n_plans, -1)[valid])
            # ``full_like`` on the selected rows avoids an extra
            # ``valid.sum().item()`` host sync.
            edge_cat_parts.append(
                torch.full_like(
                    picked_rows,
                    edge_category_index(cat_a, cat_a, n_cat),
                )
            )
            plan_parts.append(
                plan_ids[:, None].expand(n_plans, base_row.numel())[valid]
            )
            pair_order_parts.append(torch.full_like(picked_rows, pair_order))
            local_order_parts.append(local_order[None, :].expand(n_plans, -1)[valid])
        # Advance unconditionally (same-category pair slot), matching the
        # cross-category branch below; only the relative order matters for
        # the sort key.
        pair_order += 1
        for cat_b in range(cat_a + 1, n_cat):
            count_b = counts[:, cat_b]
            offset_b = offsets[:, cat_b]
            layout_lengths += count_a * count_b
            max_b = max_counts[cat_b]
            if max_a == 0 or max_b == 0:
                pair_order += 1
                continue
            base_row = torch.arange(max_a, **int_kwargs).repeat_interleave(max_b)
            base_col = torch.arange(max_b, **int_kwargs).repeat(max_a)
            valid = (base_row.unsqueeze(0) < count_a.unsqueeze(1)) & (
                base_col.unsqueeze(0) < count_b.unsqueeze(1)
            )
            row_values = offset_a[:, None] + base_row[None, :]
            col_values = offset_b[:, None] + base_col[None, :]
            local_order = torch.arange(base_row.numel(), **int_kwargs)
            picked_rows = row_values.expand(n_plans, -1)[valid]
            row_parts.append(picked_rows)
            col_parts.append(col_values.expand(n_plans, -1)[valid])
            edge_cat_parts.append(
                torch.full_like(
                    picked_rows,
                    edge_category_index(cat_a, cat_b, n_cat),
                )
            )
            plan_parts.append(
                plan_ids[:, None].expand(n_plans, base_row.numel())[valid]
            )
            pair_order_parts.append(torch.full_like(picked_rows, pair_order))
            local_order_parts.append(local_order[None, :].expand(n_plans, -1)[valid])
            pair_order += 1

    layout_ptr = torch.cat(
        [
            torch.zeros((1,), **int_kwargs),
            layout_lengths.cumsum(dim=0),
        ],
    )
    if not row_parts:
        empty = torch.empty((0,), **int_kwargs)
        row = empty
        col = empty
        edge_cat = empty
    else:
        # Every local pair position is below max_degree ** 2, so the three
        # components of the key cannot collide.
        max_local_slots = max(max_degree * max_degree, 1)
        sort_key = (
            torch.cat(plan_parts) * (num_edge_categories(n_cat) * max_local_slots)
            + torch.cat(pair_order_parts) * max_local_slots
            + torch.cat(local_order_parts)
        )
        order = torch.argsort(sort_key, stable=True)
        row = torch.cat(row_parts).index_select(0, order).contiguous()
        col = torch.cat(col_parts).index_select(0, order).contiguous()
        edge_cat = torch.cat(edge_cat_parts).index_select(0, order).contiguous()
    return TensorBucketPatternPlans(
        src_cat=patterns[:, 0],
        src_start=src_start,
        src_end=src_end,
        edge_start=row_ptr.index_select(0, src_start),
        edge_end=row_ptr.index_select(0, src_end),
        degree=degree,
        layout_ptr=layout_ptr,
        row=row,
        col=col,
        edge_cat=edge_cat,
    )


def pair_distance_partials(
    ex: torch.Tensor,
    ey: torch.Tensor,
    ez: torch.Tensor,
    vj: torch.Tensor,
    vk: torch.Tensor,
    diff: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Combine per-distance weights into partials with respect to ``vj`` and ``vk``.

    Computes ``ex * vj / x + ez * diff / z`` and ``ey * vk / y - ez * diff / z``.
    A reciprocal is replaced by zero wherever the corresponding distance is
    not greater than ``eps``, so degenerate pairs contribute nothing.

    NOTE(review): presumably ``x``, ``y``, ``z`` are the norms of ``vj``,
    ``vk`` and ``diff`` with ``diff = vj - vk``, and the 1D weights broadcast
    against ``(n, 3)`` vectors; confirm against callers.

    Returns:
        ``(dE_dvj, dE_dvk)``.
    """
    inv_x = torch.where(x > eps, x.reciprocal(), torch.zeros_like(x))
    inv_y = torch.where(y > eps, y.reciprocal(), torch.zeros_like(y))
    inv_z = torch.where(z > eps, z.reciprocal(), torch.zeros_like(z))
    dE_dvj = ex[:, None] * (vj * inv_x[:, None]) + ez[:, None] * (diff * inv_z[:, None])
    dE_dvk = ey[:, None] * (vk * inv_y[:, None]) - ez[:, None] * (diff * inv_z[:, None])
    return dE_dvj, dE_dvk


def pair_distance_partials_batched(
    ex: torch.Tensor,
    ey: torch.Tensor,
    ez: torch.Tensor,
    vj: torch.Tensor,
    vk: torch.Tensor,
    diff: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return pair distance partials for one stencil axis per triplet.

    Same combination as ``pair_distance_partials``, but ``ex``, ``ey`` and
    ``ez`` carry an extra second axis that is broadcast against the unit
    vectors, so each result gains that axis in the middle. Reciprocals of
    distances not greater than ``eps`` are replaced by zero.

    Returns:
        ``(dE_dvj, dE_dvk)``.
    """
    inv_x = torch.where(x > eps, x.reciprocal(), torch.zeros_like(x))
    inv_y = torch.where(y > eps, y.reciprocal(), torch.zeros_like(y))
    inv_z = torch.where(z > eps, z.reciprocal(), torch.zeros_like(z))
    unit_x = vj * inv_x[:, None]
    unit_y = vk * inv_y[:, None]
    unit_z = diff * inv_z[:, None]
    dE_dvj = ex[:, :, None] * unit_x[:, None, :] + ez[:, :, None] * unit_z[:, None, :]
    dE_dvk = ey[:, :, None] * unit_y[:, None, :] - ez[:, :, None] * unit_z[:, None, :]
    return dE_dvj, dE_dvk


def map_atomic_numbers_to_categories(
    atomic_numbers: torch.Tensor,
    atomic_types: Sequence[int],
) -> torch.Tensor:
    """Map atomic numbers to their position in ``atomic_types``.

    Returns:
        An int64 tensor shaped like ``atomic_numbers``. Entries whose atomic
        number does not appear in ``atomic_types`` are left at ``-1``.
    """
    categories = torch.full_like(atomic_numbers, fill_value=-1, dtype=torch.int64)
    for category, atomic_number in enumerate(atomic_types):
        categories[atomic_numbers == int(atomic_number)] = category
    return categories


def shortest_unique_directed_pairs(
    first: torch.Tensor,
    second: torch.Tensor,
    pair_vectors: torch.Tensor,
    *,
    n_nodes: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Keep one directed pair per ``i -> j`` by selecting the shortest image.

    Returns:
        ``(first, second, pair_vectors)`` restricted to the rows chosen by
        ``shortest_unique_directed_pair_indices``.
    """
    keep = shortest_unique_directed_pair_indices(
        first,
        second,
        n_nodes=n_nodes,
        pair_vectors=pair_vectors,
    )
    return first[keep], second[keep], pair_vectors[keep]


def shortest_unique_directed_pair_indices(
    first: torch.Tensor,
    second: torch.Tensor,
    *,
    n_nodes: int,
    pair_vectors: torch.Tensor | None = None,
    pair_distances: torch.Tensor | None = None,
    pair_shifts: torch.Tensor | None = None,
    atol: float = 1.0e-12,
) -> torch.Tensor:
    """Return the retained row indices when duplicate directed periodic images exist.

    For every distinct ``(first, second)`` pair the row with the smallest
    distance is kept. Rows within ``atol`` of the minimum count as tied; a
    tie is broken by the smallest L1 norm of ``pair_shifts`` when provided,
    otherwise by the earliest row.

    Args:
        first: 1D tensor of source ids.
        second: 1D tensor of target ids, same shape as ``first``.
        n_nodes: Multiplier used to fold ``(first, second)`` into one pair id.
        pair_vectors: Optional ``(n_pairs, 3)`` vectors, used to derive
            distances when ``pair_distances`` is omitted.
        pair_distances: Optional ``(n_pairs,)`` distances.
        pair_shifts: Optional ``(n_pairs, 3)`` shifts used only to break ties.
        atol: Absolute tolerance for treating distances as tied.

    Returns:
        A 1D int64 tensor of row indices. If no pair id is duplicated this is
        ``arange(n_pairs)`` (original order); otherwise the indices are
        ordered by ascending pair id.

    Raises:
        ValueError: On shape mismatches, or when neither ``pair_distances``
            nor ``pair_vectors`` is given.
    """
    if first.ndim != 1 or second.ndim != 1 or first.shape != second.shape:
        raise ValueError("`first` and `second` must be 1D tensors with the same shape")
    if pair_distances is None:
        if pair_vectors is None:
            raise ValueError(
                "either `pair_distances` or `pair_vectors` is required to resolve "
                "duplicate directed pairs"
            )
        if pair_vectors.ndim != 2 or pair_vectors.shape != (first.numel(), 3):
            raise ValueError("`pair_vectors` must have shape (n_pairs, 3)")
        pair_distances = torch.linalg.norm(pair_vectors, dim=1)
    else:
        if pair_distances.ndim != 1 or pair_distances.shape != first.shape:
            raise ValueError("`pair_distances` must have shape (n_pairs,)")
    if pair_shifts is not None:
        if pair_shifts.ndim != 2 or pair_shifts.shape != (first.numel(), 3):
            raise ValueError("`pair_shifts` must have shape (n_pairs, 3)")
        shift_norm = torch.sum(torch.abs(pair_shifts.to(torch.int64)), dim=1)
    else:
        shift_norm = None

    pair_ids = first.to(torch.int64) * int(n_nodes) + second.to(torch.int64)
    # Fast path: nothing to deduplicate, keep every row in its original order.
    if torch.unique(pair_ids).numel() == pair_ids.numel():
        return torch.arange(first.numel(), device=first.device, dtype=torch.int64)

    # Stable sort keeps duplicates of one pair in their original row order.
    order = torch.argsort(pair_ids, stable=True)
    pair_ids_sorted = pair_ids[order]
    distances_sorted = pair_distances[order]
    _, counts = torch.unique_consecutive(pair_ids_sorted, return_counts=True)
    offsets = torch.cat([counts.new_zeros(1), counts.cumsum(0)[:-1]])

    kept_indices: list[torch.Tensor] = []
    for offset, count in zip(offsets.tolist(), counts.tolist(), strict=True):
        segment = slice(offset, offset + count)
        local_distances = distances_sorted[segment]
        min_distance = torch.min(local_distances)
        candidates = torch.nonzero(
            torch.isclose(
                local_distances,
                min_distance,
                rtol=0.0,
                atol=float(atol),
            )
        ).reshape(-1)
        if candidates.numel() > 1 and shift_norm is not None:
            local_shift_norm = shift_norm[order[segment]].index_select(0, candidates)
            chosen = candidates[int(torch.argmin(local_shift_norm).item())]
        else:
            chosen = candidates[0]
        kept_indices.append(order[offset + int(chosen.item())])
    return torch.stack(kept_indices)


__all__ = [
    "BucketItem",
    "Buckets",
    "PatternTripletLayout",
    "TensorBucketPatternPlans",
    "build_edge_category_table",
    "crosscat_pairs",
    "build_tensor_bucket_pattern_plans",
    "build_tensor_bucket_pattern_plans_torch",
    "map_atomic_numbers_to_categories",
    "num_edge_categories",
    "pair_distance_partials",
    "pair_distance_partials_batched",
    "pattern_triplet_layout",
    "preprocess_sources",
    "samecat_pairs",
    "shortest_unique_directed_pair_indices",
    "shortest_unique_directed_pairs",
    "tensorize_bucket_pattern_plans",
]