# Source code for ufp.neighbors._data

"""
Neighbor-list data container and concatenation utilities.

This module keeps pair indices, shifts, vectors, and distances in one reusable
structure that can move cleanly between numpy and torch.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence

import numpy as np
import torch

from ufp.core._arrays import ArrayLike, _concatenate_arrays, _to_tensor


def _normalize_pair_array(pairs: ArrayLike) -> ArrayLike:
    """
    Normalize pair indices to shape ``(2, n_pairs)``.

    An array already shaped ``(2, n_pairs)`` is returned unchanged; an array
    shaped ``(n_pairs, 2)`` is returned transposed. A ``(2, 2)`` input is
    ambiguous and is treated as ``(2, n_pairs)`` because that layout is
    checked first.

    Raises:
        ValueError: If ``pairs`` is not 2-D or neither axis has length 2.
    """
    if pairs.ndim != 2:
        raise ValueError("`pairs` must be a 2-dimensional array")

    if pairs.shape[0] == 2:
        return pairs

    if pairs.shape[1] == 2:
        return pairs.T

    raise ValueError("`pairs` must have shape (2, n_pairs) or (n_pairs, 2)")


@dataclass
class NeighborListData:
    """
    Lightweight container for neighbor-list data used by UFP models.

    ``pairs`` may be given as ``(2, n_pairs)`` or ``(n_pairs, 2)``; it is
    normalized to ``(2, n_pairs)`` in ``__post_init__``, and every per-pair
    array is validated against that ``n_pairs``.

    Attributes:
        pairs: Pair indices with shape ``(2, n_pairs)`` or ``(n_pairs, 2)``.
        shifts: Periodic cell shifts with shape ``(n_pairs, 3)``.
        distances: Optional pair distances with shape ``(n_pairs,)``.
        vectors: Optional pair displacement vectors with shape ``(n_pairs, 3)``.
        backend: Name of the backend that produced this neighbor list.
        cutoff: Cutoff used to build the list, if known.
        full_list: Whether the list contains both ``i-j`` and ``j-i`` entries.
        sorted: Whether the list is lexicographically sorted by pair indices.
        strict: Whether all pairs are guaranteed to be strictly within the cutoff.
    """

    pairs: ArrayLike
    shifts: ArrayLike
    distances: Optional[ArrayLike] = None
    vectors: Optional[ArrayLike] = None
    backend: str = "unknown"
    cutoff: Optional[float] = None
    full_list: bool = True
    sorted: bool = False
    strict: Optional[bool] = None

    def __post_init__(self) -> None:
        """
        Normalize ``pairs`` and validate the shapes of all per-pair arrays.

        Raises:
            ValueError: If ``pairs`` has an unsupported layout, or if
                ``shifts``, ``distances`` or ``vectors`` do not match the
                number of pairs and their documented shapes.
        """
        self.pairs = _normalize_pair_array(self.pairs)

        if self.shifts.ndim != 2 or self.shifts.shape[1] != 3:
            raise ValueError("`shifts` must have shape (n_pairs, 3)")

        n_pairs = self.pairs.shape[1]
        if self.shifts.shape[0] != n_pairs:
            raise ValueError(
                "`pairs` and `shifts` must describe the same number of pairs"
            )

        # Check the rank as well as the length: without the ``ndim`` test an
        # ``(n_pairs, k)`` array would be accepted despite the documented
        # ``(n_pairs,)`` shape, and a 0-d array would raise IndexError.
        if self.distances is not None and (
            self.distances.ndim != 1 or self.distances.shape[0] != n_pairs
        ):
            raise ValueError("`distances` must have shape (n_pairs,)")

        if self.vectors is not None and tuple(self.vectors.shape) != (n_pairs, 3):
            raise ValueError("`vectors` must have shape (n_pairs, 3)")

    @property
    def n_pairs(self) -> int:
        """Return the number of pairs (the second axis of ``pairs``)."""
        return int(self.pairs.shape[1])

    def shifted(self, atom_offset: int) -> "NeighborListData":
        """
        Return a neighbor list whose pair indices are offset by ``atom_offset``.

        With a zero offset, ``self`` is returned as-is: no copy is made and
        the dtype of ``pairs`` is left untouched. Otherwise, a new instance is
        returned whose ``pairs`` are converted to int64 before adding the
        offset. ``shifts``, ``distances`` and ``vectors`` are shared with
        ``self``, not copied.

        Args:
            atom_offset: Value added to every pair index.
        """
        if atom_offset == 0:
            return self

        if isinstance(self.pairs, torch.Tensor):
            pairs = self.pairs.to(dtype=torch.int64) + atom_offset
        else:
            pairs = np.asarray(self.pairs, dtype=np.int64) + atom_offset

        return NeighborListData(
            pairs=pairs,
            shifts=self.shifts,
            distances=self.distances,
            vectors=self.vectors,
            backend=self.backend,
            cutoff=self.cutoff,
            full_list=self.full_list,
            sorted=self.sorted,
            strict=self.strict,
        )

    def as_torch(
        self,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ) -> "NeighborListData":
        """
        Return a torch-backed version of this neighbor list.

        ``pairs`` and ``shifts`` are converted with dtype ``torch.int64``.
        ``distances`` and ``vectors`` (when present) use ``dtype``, or
        ``torch.get_default_dtype()`` when ``dtype`` is None. All arrays are
        placed on ``device``. Whether data is copied when an input is already
        a matching tensor depends on ``_to_tensor`` (not visible here).

        Args:
            dtype: Floating dtype for ``distances`` and ``vectors``.
            device: Target device for every array.
        """
        float_dtype = torch.get_default_dtype() if dtype is None else dtype
        return NeighborListData(
            pairs=_to_tensor(self.pairs, dtype=torch.int64, device=device),
            shifts=_to_tensor(self.shifts, dtype=torch.int64, device=device),
            distances=None
            if self.distances is None
            else _to_tensor(self.distances, dtype=float_dtype, device=device),
            vectors=None
            if self.vectors is None
            else _to_tensor(self.vectors, dtype=float_dtype, device=device),
            backend=self.backend,
            cutoff=self.cutoff,
            full_list=self.full_list,
            sorted=self.sorted,
            strict=self.strict,
        )

    def pin_memory(self) -> "NeighborListData":
        """
        Pin torch-backed arrays and return ``self``.

        Each field that is a ``torch.Tensor`` is replaced on ``self`` by the
        result of its ``pin_memory()`` call; non-tensor fields (e.g. numpy
        arrays or ``None``) are left untouched.
        """
        if isinstance(self.pairs, torch.Tensor):
            self.pairs = self.pairs.pin_memory()
        if isinstance(self.shifts, torch.Tensor):
            self.shifts = self.shifts.pin_memory()
        if isinstance(self.distances, torch.Tensor):
            self.distances = self.distances.pin_memory()
        if isinstance(self.vectors, torch.Tensor):
            self.vectors = self.vectors.pin_memory()
        return self

    def masked(self, mask: ArrayLike) -> "NeighborListData":
        """
        Return a new neighbor list containing only the pairs selected by ``mask``.

        The backend (torch vs numpy) is chosen from the type of ``pairs``.
        Metadata flags are copied unchanged.
        NOTE(review): dropping pairs can break the ``i-j``/``j-i`` symmetry
        implied by ``full_list=True``; callers should confirm that is acceptable.

        Args:
            mask: Boolean selector with shape ``(n_pairs,)``. It is cast to
                bool, so it is treated as a mask, not as an index array.

        Raises:
            ValueError: If ``mask`` is not 1-D with length ``n_pairs``.
        """
        if isinstance(self.pairs, torch.Tensor):
            mask_tensor = _to_tensor(mask, dtype=torch.bool, device=self.pairs.device)
            if mask_tensor.ndim != 1 or mask_tensor.shape[0] != self.n_pairs:
                raise ValueError("`mask` must have shape (n_pairs,)")
            # Assumes the other fields are torch tensors too when ``pairs`` is.
            pairs = self.pairs[:, mask_tensor]
            shifts = self.shifts[mask_tensor]
            distances = None if self.distances is None else self.distances[mask_tensor]
            vectors = None if self.vectors is None else self.vectors[mask_tensor]
        else:
            mask_array = np.asarray(mask, dtype=bool)
            if mask_array.ndim != 1 or mask_array.shape[0] != self.n_pairs:
                raise ValueError("`mask` must have shape (n_pairs,)")
            pairs = np.asarray(self.pairs)[:, mask_array]
            shifts = np.asarray(self.shifts)[mask_array]
            distances = (
                None if self.distances is None else np.asarray(self.distances)[mask_array]
            )
            vectors = (
                None if self.vectors is None else np.asarray(self.vectors)[mask_array]
            )

        return NeighborListData(
            pairs=pairs,
            shifts=shifts,
            distances=distances,
            vectors=vectors,
            backend=self.backend,
            cutoff=self.cutoff,
            full_list=self.full_list,
            sorted=self.sorted,
            strict=self.strict,
        )
def _concatenation_keeps_pair_order(
    neighbor_lists: Sequence[NeighborListData],
) -> bool:
    """
    Return whether joining individually sorted lists end-to-end stays sorted.

    Each list is assumed to already be sorted lexicographically by
    ``(pairs[0], pairs[1])``; only the seams between consecutive non-empty
    lists are inspected.
    """
    previous_last = None
    for neighbor_list in neighbor_lists:
        if neighbor_list.n_pairs == 0:
            continue
        pairs = neighbor_list.pairs
        # ``int()`` on a device tensor forces a host sync; this only runs when
        # every input claims to be sorted, and reads two pairs per list.
        first_pair = (int(pairs[0, 0]), int(pairs[1, 0]))
        if previous_last is not None and first_pair < previous_last:
            return False
        previous_last = (int(pairs[0, -1]), int(pairs[1, -1]))
    return True


def concatenate_neighbor_lists(
    neighbor_lists: Sequence[NeighborListData],
    *,
    atom_offsets: Optional[Sequence[int]] = None,
) -> Optional[NeighborListData]:
    """
    Concatenate per-system neighbor lists into one batch-level list.

    Args:
        neighbor_lists: Neighbor lists to merge, in batch order.
        atom_offsets: Per-list offsets added to the pair indices before
            concatenation. Defaults to zero for every list.

    Returns:
        ``None`` if ``neighbor_lists`` is empty. With a single input, that
        input shifted by its offset (which is the input object itself when
        the offset is zero). Otherwise a new ``NeighborListData`` whose
        metadata is merged as follows:

        - ``full_list`` and ``sorted`` must agree across inputs.
        - ``strict`` and ``cutoff`` become ``None`` when the inputs disagree.
        - ``backend`` becomes ``"mixed"`` when the inputs disagree.
        - ``distances`` and ``vectors`` are kept only if every input has them.
        - The result is marked ``sorted`` only if every input is sorted and
          the concatenation order is preserved across list boundaries.

    Raises:
        ValueError: If ``atom_offsets`` has the wrong length, or the inputs
            disagree on ``full_list`` or ``sorted``.
    """
    if not neighbor_lists:
        return None

    if atom_offsets is None:
        atom_offsets = [0] * len(neighbor_lists)
    elif len(atom_offsets) != len(neighbor_lists):
        raise ValueError("`atom_offsets` must match the number of neighbor lists")

    # Lengths are validated above, so a plain zip is safe (``zip(strict=...)``
    # would additionally require Python 3.10+).
    shifted_lists = [
        neighbor_list.shifted(atom_offset)
        for neighbor_list, atom_offset in zip(neighbor_lists, atom_offsets)
    ]
    if len(shifted_lists) == 1:
        return shifted_lists[0]

    first = shifted_lists[0]
    rest = shifted_lists[1:]

    if any(neighbor_list.full_list != first.full_list for neighbor_list in rest):
        raise ValueError("all neighbor lists must agree on `full_list`")
    if any(neighbor_list.sorted != first.sorted for neighbor_list in rest):
        raise ValueError("all neighbor lists must agree on `sorted`")

    strict = first.strict
    if any(neighbor_list.strict != strict for neighbor_list in rest):
        strict = None

    cutoff = first.cutoff
    if any(neighbor_list.cutoff != cutoff for neighbor_list in rest):
        cutoff = None

    backend = first.backend
    if any(neighbor_list.backend != backend for neighbor_list in rest):
        backend = "mixed"

    # Individually sorted lists only stay sorted after concatenation when each
    # list starts at or after the previous list's last pair (e.g. the default
    # all-zero offsets generally break this).
    is_sorted = first.sorted and _concatenation_keeps_pair_order(shifted_lists)

    pairs = _concatenate_arrays(
        [neighbor_list.pairs for neighbor_list in shifted_lists],
        axis=1,
    )
    shifts = _concatenate_arrays(
        [neighbor_list.shifts for neighbor_list in shifted_lists],
        axis=0,
    )

    distances = None
    if all(neighbor_list.distances is not None for neighbor_list in shifted_lists):
        distances = _concatenate_arrays(
            [neighbor_list.distances for neighbor_list in shifted_lists],
            axis=0,
        )

    vectors = None
    if all(neighbor_list.vectors is not None for neighbor_list in shifted_lists):
        vectors = _concatenate_arrays(
            [neighbor_list.vectors for neighbor_list in shifted_lists],
            axis=0,
        )

    return NeighborListData(
        pairs=pairs,
        shifts=shifts,
        distances=distances,
        vectors=vectors,
        backend=backend,
        cutoff=cutoff,
        full_list=first.full_list,
        sorted=is_sorted,
        strict=strict,
    )
# Public API of this module; leading-underscore helpers are intentionally not exported.
__all__ = [
    "NeighborListData",
    "concatenate_neighbor_lists",
]