# Source code for ufp.neighbors._neighbors

"""
Neighbor-list construction helpers and backend selection.

Use this module to build ``NeighborListData`` from ASE structures while keeping
the backend choice explicit and swappable.
"""

from enum import Enum
from typing import List, Optional, Tuple, Union

import ase
import ase.neighborlist
import numpy as np
import torch

from ufp.core._arrays import ArrayLike
from ufp.neighbors._data import NeighborListData


class NeighborListBackend(str, Enum):
    """
    Known neighbor-list backends supported directly by UFP.

    Members subclass ``str``, so each one compares equal to its lowercase
    name (``NeighborListBackend.ASE == "ase"``). Lookup by value is
    case-insensitive and ignores surrounding whitespace, so
    ``NeighborListBackend("ASE")`` and ``NeighborListBackend(" ase ")`` both
    resolve to ``NeighborListBackend.ASE``.
    """

    # Placeholder meaning "let the library pick a concrete backend".
    AUTO = "auto"
    # Backend built on ``ase.neighborlist``.
    ASE = "ase"
    # Name reserved for a metatomic-based backend.
    METATOMIC = "metatomic"
    # Backend built on the optional ``vesin`` package.
    VESIN = "vesin"

    @classmethod
    def _missing_(cls, value):
        """
        Resolve values that did not match any member exactly.

        ``Enum`` only calls this hook after the exact-value lookup failed, so
        already-valid inputs never reach it. Returning ``None`` lets ``Enum``
        raise its usual ``ValueError`` for names that are still unknown.
        """
        if not isinstance(value, str):
            return None
        normalized = value.strip().lower()
        for member in cls:
            if member.value == normalized:
                return member
        return None
def _as_backend(backend: Union[str, NeighborListBackend]) -> NeighborListBackend:
    """
    Coerce ``backend`` into a ``NeighborListBackend`` member.

    Enum members are handed back untouched; any other input is looked up by
    value through the enum constructor, which raises ``ValueError`` when the
    name is not a known backend.
    """
    already_member = isinstance(backend, NeighborListBackend)
    return backend if already_member else NeighborListBackend(backend)


def _vesin_available() -> bool:
    """Return ``True`` when the optional ``vesin`` package can be imported."""
    try:
        # The import is attempted only as a probe; the module is not used here.
        import vesin  # noqa: F401
    except ImportError:
        return False
    return True
def available_neighbor_backends() -> List[NeighborListBackend]:
    """
    Report which neighbor-list backends are usable in this environment.

    The ASE backend is always included. ``vesin`` is included, and listed
    first, only when it can be imported.

    Returns:
        Usable backends, most preferred first.
    """
    if _vesin_available():
        return [NeighborListBackend.VESIN, NeighborListBackend.ASE]
    return [NeighborListBackend.ASE]
def _maybe_to_torch(array: np.ndarray, arrays: str) -> ArrayLike:
    """
    Return ``array`` in the container type requested by ``arrays``.

    Args:
        array: Numpy array produced by a neighbor-list backend.
        arrays: ``"numpy"`` to return ``array`` unchanged, ``"torch"`` to wrap
            it in a tensor.

    Raises:
        ValueError: If ``arrays`` is neither ``"numpy"`` nor ``"torch"``.
    """
    if arrays == "numpy":
        return array
    if arrays == "torch":
        # Make the buffer C-contiguous before handing it to torch.
        return torch.from_numpy(np.ascontiguousarray(array))
    raise ValueError(
        f"unknown array target '{arrays}', expected 'numpy' or 'torch'"
    )


def _sort_outputs(
    pairs: np.ndarray,
    shifts: np.ndarray,
    distances: Optional[np.ndarray],
    vectors: Optional[np.ndarray],
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Sort neighbor-list rows lexicographically by pair index.

    ``pairs`` holds the first indices in ``pairs[0]`` and the second indices in
    ``pairs[1]``; the other arrays are reordered along their first axis with
    the same permutation. ``distances`` and ``vectors`` may be ``None`` and are
    then passed through unchanged.
    """
    # ``np.lexsort`` treats the LAST key as the primary one, so this orders by
    # ``pairs[0]`` first and breaks ties with ``pairs[1]``.
    sort_indices = np.lexsort((pairs[1], pairs[0]))
    pairs = pairs[:, sort_indices]
    shifts = shifts[sort_indices]
    if distances is not None:
        distances = distances[sort_indices]
    if vectors is not None:
        vectors = vectors[sort_indices]
    return pairs, shifts, distances, vectors


def _build_ase_neighbor_list(
    atoms: ase.Atoms,
    cutoff: float,
    arrays: str,
    sorted: bool,
) -> NeighborListData:
    """
    Build a neighbor list with ``ase.neighborlist.neighbor_list``.

    Args:
        atoms: Structure to analyse.
        cutoff: Cutoff radius forwarded to ASE.
        arrays: Output container, ``"numpy"`` or ``"torch"``.
        sorted: Whether to sort the rows by pair index before returning.

    Returns:
        Neighbor-list data tagged with the ASE backend; always reported as a
        full list.
    """
    i, j, shifts, vectors, distances = ase.neighborlist.neighbor_list(
        quantities="ijSDd",
        a=atoms,
        cutoff=cutoff,
    )
    # Pin the index dtype so every backend returns the same integer type,
    # mirroring the int32 normalisation applied to ``shifts`` below.
    pairs = np.stack([i, j], axis=0).astype(np.int64, copy=False)
    if sorted:
        pairs, shifts, distances, vectors = _sort_outputs(
            pairs,
            shifts,
            distances,
            vectors,
        )
    return NeighborListData(
        pairs=_maybe_to_torch(pairs, arrays),
        shifts=_maybe_to_torch(shifts.astype(np.int32, copy=False), arrays),
        distances=_maybe_to_torch(distances, arrays),
        vectors=_maybe_to_torch(vectors, arrays),
        backend=NeighborListBackend.ASE.value,
        cutoff=cutoff,
        full_list=True,
        sorted=sorted,
        strict=True,
    )


def _build_vesin_neighbor_list(
    atoms: ase.Atoms,
    cutoff: float,
    arrays: str,
    full_list: bool,
    sorted: bool,
) -> NeighborListData:
    """
    Build a neighbor list with the optional ``vesin`` package.

    Args:
        atoms: Structure to analyse; its positions, cell and periodicity flags
            are forwarded to vesin.
        cutoff: Cutoff radius forwarded to vesin.
        arrays: Output container, ``"numpy"`` or ``"torch"``.
        full_list: Forwarded to vesin's ``NeighborList``.
        sorted: Forwarded to vesin's ``NeighborList``; no extra sorting is
            done here.

    Returns:
        Neighbor-list data tagged with the vesin backend.
    """
    # Imported lazily so the module stays importable without vesin installed.
    from vesin import NeighborList

    calculator = NeighborList(cutoff=cutoff, full_list=full_list, sorted=sorted)
    i, j, shifts, vectors, distances = calculator.compute(
        points=atoms.positions,
        box=atoms.cell[:],
        periodic=atoms.pbc,
        quantities="ijSDd",
    )
    # NOTE(review): vesin is believed to report indices as unsigned ``size_t``
    # values - confirm against the vesin version in use. Casting to int64
    # keeps the index dtype identical to the ASE backend and usable for
    # tensor indexing after conversion to torch.
    pairs = np.stack([i, j], axis=0).astype(np.int64, copy=False)
    return NeighborListData(
        pairs=_maybe_to_torch(pairs, arrays),
        shifts=_maybe_to_torch(shifts.astype(np.int32, copy=False), arrays),
        distances=_maybe_to_torch(distances, arrays),
        vectors=_maybe_to_torch(vectors, arrays),
        backend=NeighborListBackend.VESIN.value,
        cutoff=cutoff,
        full_list=full_list,
        sorted=sorted,
        strict=True,
    )
def build_neighbor_list(
    atoms: ase.Atoms,
    cutoff: float,
    backend: Union[str, NeighborListBackend] = NeighborListBackend.AUTO,
    arrays: str = "torch",
    full_list: bool = True,
    sorted: bool = True,
) -> NeighborListData:
    """
    Build a neighbor list for an ASE structure.

    When ``backend="auto"``, UFP prefers ``vesin`` if it is importable and
    falls back to :py:mod:`ase.neighborlist` otherwise. All options are
    validated before any backend is invoked, so invalid input fails without
    computing a neighbor list first.

    Args:
        atoms: Structure for which the neighbor list should be built.
        cutoff: Spherical cutoff radius.
        backend: One of ``"auto"``, ``"vesin"``, or ``"ase"``.
        arrays: Output array type, either ``"torch"`` or ``"numpy"``.
        full_list: Whether both ``i-j`` and ``j-i`` entries should be present.
            Half lists are rejected when the ASE backend is selected,
            including when ``"auto"`` falls back to it.
        sorted: Whether to sort the output lexicographically by pair indices.

    Returns:
        Normalized neighbor-list data.

    Raises:
        TypeError: If ``atoms`` is not an ASE structure.
        ValueError: If options are invalid or the selected backend is
            unsupported.
    """
    if not isinstance(atoms, ase.Atoms):
        raise TypeError(f"`atoms` should be ase.Atoms, got {type(atoms)}")

    cutoff = float(cutoff)
    if not np.isfinite(cutoff) or cutoff <= 0.0:
        raise ValueError("`cutoff` must be a finite, positive number")

    backend = _as_backend(backend)
    if backend == NeighborListBackend.AUTO:
        backend = (
            NeighborListBackend.VESIN
            if _vesin_available()
            else NeighborListBackend.ASE
        )

    if backend == NeighborListBackend.ASE and not full_list:
        raise ValueError("the ASE neighbor-list backend only supports full lists")
    if backend not in (NeighborListBackend.VESIN, NeighborListBackend.ASE):
        # Interpolate ``.value``: the f-string rendering of a ``str``-mixin
        # enum member is not the same across Python versions.
        raise ValueError(f"unsupported neighbor-list backend: {backend.value}")

    # Reject a bad ``arrays`` target up front rather than after the (possibly
    # expensive) neighbor search has already run.
    if arrays not in ("numpy", "torch"):
        raise ValueError(
            f"unknown array target '{arrays}', expected 'numpy' or 'torch'"
        )

    if backend == NeighborListBackend.VESIN:
        return _build_vesin_neighbor_list(
            atoms=atoms,
            cutoff=cutoff,
            arrays=arrays,
            full_list=full_list,
            sorted=sorted,
        )
    return _build_ase_neighbor_list(
        atoms=atoms,
        cutoff=cutoff,
        arrays=arrays,
        sorted=sorted,
    )