"""
Neighbor-list data container and concatenation utilities.
This module keeps pair indices, shifts, vectors, and distances in one reusable
structure that can move cleanly between numpy and torch.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Sequence
import numpy as np
import torch
from ufp.core._arrays import ArrayLike, _concatenate_arrays, _to_tensor
def _normalize_pair_array(pairs: ArrayLike) -> ArrayLike:
"""Normalize pair indices to shape ``(2, n_pairs)``."""
if pairs.ndim != 2:
raise ValueError("`pairs` must be a 2-dimensional array")
if pairs.shape[0] == 2:
return pairs
if pairs.shape[1] == 2:
return pairs.T
raise ValueError("`pairs` must have shape (2, n_pairs) or (n_pairs, 2)")
@dataclass
class NeighborListData:
    """
    Lightweight container for neighbor-list data used by UFP models.

    The arrays may be numpy arrays or torch tensors; :meth:`as_torch` produces
    a torch-backed copy.

    Attributes:
        pairs: Pair indices with shape ``(2, n_pairs)`` or ``(n_pairs, 2)``.
            Always stored as ``(2, n_pairs)`` after construction.
        shifts: Periodic cell shifts with shape ``(n_pairs, 3)``.
        distances: Optional pair distances with shape ``(n_pairs,)``.
        vectors: Optional pair displacement vectors with shape ``(n_pairs, 3)``.
        backend: Name of the backend that produced this neighbor list.
        cutoff: Cutoff used to build the list, if known.
        full_list: Whether the list contains both ``i-j`` and ``j-i`` entries.
        sorted: Whether the list is lexicographically sorted by pair indices.
        strict: Whether all pairs are guaranteed to be strictly within the cutoff.
    """

    pairs: ArrayLike
    shifts: ArrayLike
    distances: Optional[ArrayLike] = None
    vectors: Optional[ArrayLike] = None
    backend: str = "unknown"
    cutoff: Optional[float] = None
    full_list: bool = True
    # NOTE: this public field name shadows the ``sorted`` builtin inside the
    # class body (not inside methods); do not call ``sorted()`` at class level.
    sorted: bool = False
    strict: Optional[bool] = None

    def __post_init__(self) -> None:
        """Normalize ``pairs`` to ``(2, n_pairs)`` and validate array shapes.

        Raises:
            ValueError: If ``pairs`` has an unsupported layout, or ``shifts``,
                ``distances`` or ``vectors`` do not match ``n_pairs``.
        """
        self.pairs = _normalize_pair_array(self.pairs)
        if self.shifts.ndim != 2 or self.shifts.shape[1] != 3:
            raise ValueError("`shifts` must have shape (n_pairs, 3)")
        n_pairs = self.pairs.shape[1]
        if self.shifts.shape[0] != n_pairs:
            raise ValueError(
                "`pairs` and `shifts` must describe the same number of pairs"
            )
        # Check the rank as well as the length. Otherwise an (n_pairs, 1) array
        # would be accepted, and a 0-d array would raise IndexError instead of
        # ValueError.
        if self.distances is not None and (
            self.distances.ndim != 1 or self.distances.shape[0] != n_pairs
        ):
            raise ValueError("`distances` must have shape (n_pairs,)")
        # numpy shapes are tuples and torch.Size subclasses tuple, so one
        # comparison covers both the rank and the extents.
        if self.vectors is not None and tuple(self.vectors.shape) != (n_pairs, 3):
            raise ValueError("`vectors` must have shape (n_pairs, 3)")

    @property
    def n_pairs(self) -> int:
        """Return the number of pairs."""
        return int(self.pairs.shape[1])

    def shifted(self, atom_offset: int) -> "NeighborListData":
        """Return a neighbor list whose pair indices are offset by ``atom_offset``.

        ``pairs`` is converted to int64 before the offset is added. ``shifts``,
        ``distances`` and ``vectors`` are shared with ``self`` (not copied),
        and all metadata flags are carried over unchanged.

        When ``atom_offset`` is 0, ``self`` itself is returned (no copy and no
        dtype conversion).
        """
        if atom_offset == 0:
            return self
        if isinstance(self.pairs, torch.Tensor):
            pairs = self.pairs.to(dtype=torch.int64) + atom_offset
        else:
            pairs = np.asarray(self.pairs, dtype=np.int64) + atom_offset
        return NeighborListData(
            pairs=pairs,
            shifts=self.shifts,
            distances=self.distances,
            vectors=self.vectors,
            backend=self.backend,
            cutoff=self.cutoff,
            full_list=self.full_list,
            sorted=self.sorted,
            strict=self.strict,
        )

    def as_torch(
        self,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ) -> "NeighborListData":
        """Return a torch-backed copy of this neighbor list.

        Args:
            dtype: Floating dtype for ``distances`` and ``vectors``; defaults
                to ``torch.get_default_dtype()``.
            device: Target device for every tensor (passed to ``_to_tensor``).

        Returns:
            A new instance with ``pairs`` and ``shifts`` as int64 tensors and
            the optional float arrays converted to ``dtype``. Whether inputs
            that are already matching tensors get copied depends on
            ``_to_tensor``.
        """
        float_dtype = torch.get_default_dtype() if dtype is None else dtype
        return NeighborListData(
            pairs=_to_tensor(self.pairs, dtype=torch.int64, device=device),
            shifts=_to_tensor(self.shifts, dtype=torch.int64, device=device),
            distances=None
            if self.distances is None
            else _to_tensor(self.distances, dtype=float_dtype, device=device),
            vectors=None
            if self.vectors is None
            else _to_tensor(self.vectors, dtype=float_dtype, device=device),
            backend=self.backend,
            cutoff=self.cutoff,
            full_list=self.full_list,
            sorted=self.sorted,
            strict=self.strict,
        )

    def pin_memory(self) -> "NeighborListData":
        """Replace each torch tensor attribute with a pinned copy; return ``self``.

        Non-tensor (e.g. numpy) attributes are left untouched. This mutates
        ``self``, which matters if this instance is shared, for example via
        ``shifted(0)``.
        """
        if isinstance(self.pairs, torch.Tensor):
            self.pairs = self.pairs.pin_memory()
        if isinstance(self.shifts, torch.Tensor):
            self.shifts = self.shifts.pin_memory()
        if isinstance(self.distances, torch.Tensor):
            self.distances = self.distances.pin_memory()
        if isinstance(self.vectors, torch.Tensor):
            self.vectors = self.vectors.pin_memory()
        return self

    def masked(self, mask: ArrayLike) -> "NeighborListData":
        """Return a new neighbor list containing only the pairs where ``mask`` is true.

        The backend of ``pairs`` decides the code path. If ``pairs`` is a
        tensor, the mask is converted to a bool tensor on ``pairs``' device and
        the other arrays are indexed with it (this presumably expects them to
        be tensors on the same device). Otherwise everything is indexed as
        numpy.

        All metadata flags are copied unchanged. Callers must make sure the
        mask does not invalidate them, e.g. by dropping only one direction of
        a pair from a full list.

        Raises:
            ValueError: If ``mask`` is not 1-D with length ``n_pairs``.
        """
        if isinstance(self.pairs, torch.Tensor):
            mask_tensor = _to_tensor(mask, dtype=torch.bool, device=self.pairs.device)
            if mask_tensor.ndim != 1 or mask_tensor.shape[0] != self.n_pairs:
                raise ValueError("`mask` must have shape (n_pairs,)")
            pairs = self.pairs[:, mask_tensor]
            shifts = self.shifts[mask_tensor]
            distances = None if self.distances is None else self.distances[mask_tensor]
            vectors = None if self.vectors is None else self.vectors[mask_tensor]
        else:
            # np.asarray cannot read a CUDA tensor (and warns on tensors that
            # require grad), so bring a torch mask to host memory first.
            if isinstance(mask, torch.Tensor):
                mask = mask.detach().cpu()
            mask_array = np.asarray(mask, dtype=bool)
            if mask_array.ndim != 1 or mask_array.shape[0] != self.n_pairs:
                raise ValueError("`mask` must have shape (n_pairs,)")
            pairs = np.asarray(self.pairs)[:, mask_array]
            shifts = np.asarray(self.shifts)[mask_array]
            distances = (
                None
                if self.distances is None
                else np.asarray(self.distances)[mask_array]
            )
            vectors = (
                None if self.vectors is None else np.asarray(self.vectors)[mask_array]
            )
        return NeighborListData(
            pairs=pairs,
            shifts=shifts,
            distances=distances,
            vectors=vectors,
            backend=self.backend,
            cutoff=self.cutoff,
            full_list=self.full_list,
            sorted=self.sorted,
            strict=self.strict,
        )
def _pair_boundaries_ordered(neighbor_lists: Sequence[NeighborListData]) -> bool:
    """Check that consecutive non-empty lists are ordered at their boundaries.

    Returns True when each non-empty list's first ``(i, j)`` pair is
    lexicographically >= the previous non-empty list's last pair. Together
    with every list being sorted on its own, this makes the concatenation
    sorted. Empty lists are skipped.
    """
    previous_last = None
    for neighbor_list in neighbor_lists:
        if neighbor_list.n_pairs == 0:
            continue
        pairs = neighbor_list.pairs
        # int() works on numpy scalars and 0-d torch tensors alike.
        first_pair = (int(pairs[0, 0]), int(pairs[1, 0]))
        if previous_last is not None and first_pair < previous_last:
            return False
        previous_last = (int(pairs[0, -1]), int(pairs[1, -1]))
    return True


def concatenate_neighbor_lists(
    neighbor_lists: Sequence[NeighborListData],
    *,
    atom_offsets: Optional[Sequence[int]] = None,
) -> Optional[NeighborListData]:
    """Concatenate per-system neighbor lists into one batch-level list.

    Args:
        neighbor_lists: Lists to concatenate, in batch order.
        atom_offsets: Per-list offsets added to each list's pair indices (for
            example, the number of atoms in the preceding systems). Defaults
            to all zeros.

    Returns:
        ``None`` if ``neighbor_lists`` is empty, or the single shifted list if
        there is exactly one. Otherwise returns a new combined list, where:

        * ``cutoff`` and ``strict`` become ``None`` when the inputs disagree.
        * ``backend`` becomes ``"mixed"`` when the inputs disagree.
        * ``distances`` and ``vectors`` are kept only if every input has them.
        * ``sorted`` is True only if every input is sorted and the shifted
          lists are also ordered relative to each other.

    Raises:
        ValueError: If ``atom_offsets`` has the wrong length, or the inputs
            disagree on ``full_list`` or ``sorted``.
    """
    if not neighbor_lists:
        return None
    if atom_offsets is None:
        atom_offsets = [0] * len(neighbor_lists)
    elif len(atom_offsets) != len(neighbor_lists):
        raise ValueError("`atom_offsets` must match the number of neighbor lists")

    shifted_lists = [
        neighbor_list.shifted(atom_offset)
        for neighbor_list, atom_offset in zip(neighbor_lists, atom_offsets, strict=True)
    ]
    if len(shifted_lists) == 1:
        return shifted_lists[0]

    first = shifted_lists[0]
    rest = shifted_lists[1:]
    if any(neighbor_list.full_list != first.full_list for neighbor_list in rest):
        raise ValueError("all neighbor lists must agree on `full_list`")
    if any(neighbor_list.sorted != first.sorted for neighbor_list in rest):
        raise ValueError("all neighbor lists must agree on `sorted`")

    # Metadata that does not hold for the whole batch degrades to "unknown".
    strict = first.strict
    if any(neighbor_list.strict != strict for neighbor_list in rest):
        strict = None
    cutoff = first.cutoff
    if any(neighbor_list.cutoff != cutoff for neighbor_list in rest):
        cutoff = None
    backend = first.backend
    if any(neighbor_list.backend != backend for neighbor_list in rest):
        backend = "mixed"

    # Concatenating individually sorted lists only gives a sorted result when
    # the lists are also ordered relative to each other. That generally fails
    # with the default all-zero offsets, so check it instead of assuming it.
    is_sorted = first.sorted and _pair_boundaries_ordered(shifted_lists)

    pairs = _concatenate_arrays(
        [neighbor_list.pairs for neighbor_list in shifted_lists],
        axis=1,
    )
    shifts = _concatenate_arrays(
        [neighbor_list.shifts for neighbor_list in shifted_lists],
        axis=0,
    )
    # Optional per-pair arrays are only kept when every input provides them;
    # otherwise the combined array could not line up with ``pairs``.
    distances = None
    if all(neighbor_list.distances is not None for neighbor_list in shifted_lists):
        distances = _concatenate_arrays(
            [neighbor_list.distances for neighbor_list in shifted_lists],
            axis=0,
        )
    vectors = None
    if all(neighbor_list.vectors is not None for neighbor_list in shifted_lists):
        vectors = _concatenate_arrays(
            [neighbor_list.vectors for neighbor_list in shifted_lists],
            axis=0,
        )
    return NeighborListData(
        pairs=pairs,
        shifts=shifts,
        distances=distances,
        vectors=vectors,
        backend=backend,
        cutoff=cutoff,
        full_list=first.full_list,
        sorted=is_sorted,
        strict=strict,
    )
# Public API of this module: the container class and the batching helper.
__all__ = ["NeighborListData", "concatenate_neighbor_lists"]