"""
Neighbor-list construction helpers and backend selection.
Use this module to build ``NeighborListData`` from ASE structures while keeping
the backend choice explicit and swappable.
"""
from enum import Enum
from typing import List, Optional, Tuple, Union
import ase
import ase.neighborlist
import numpy as np
import torch
from ufp.core._arrays import ArrayLike
from ufp.neighbors._data import NeighborListData
[docs]
class NeighborListBackend(str, Enum):
    """
    Enumerate the neighbor-list backend names that UFP recognizes.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase string value (e.g. ``NeighborListBackend.ASE == "ase"``).
    """

    # Placeholder resolved by ``build_neighbor_list`` to VESIN when importable,
    # otherwise to ASE.
    AUTO = "auto"
    # Backend built on ``ase.neighborlist``.
    ASE = "ase"
    # Recognized name; ``build_neighbor_list`` in this module does not dispatch it.
    METATOMIC = "metatomic"
    # Optional ``vesin`` package backend.
    VESIN = "vesin"
def _as_backend(backend: Union[str, NeighborListBackend]) -> NeighborListBackend:
    """
    Normalize a backend string or enum into ``NeighborListBackend``.

    Plain strings are matched case-insensitively after stripping surrounding
    whitespace, so ``"VESIN"`` and ``" ase "`` are accepted as well.

    Args:
        backend: Backend name or ``NeighborListBackend`` member.

    Returns:
        The matching ``NeighborListBackend`` member.

    Raises:
        ValueError: If ``backend`` does not name a known backend; the message
            lists the accepted names.
    """
    if isinstance(backend, NeighborListBackend):
        return backend
    # Only plain strings are normalized; anything else goes straight to the
    # enum lookup so it fails with the same ValueError as before.
    candidate = backend.strip().lower() if isinstance(backend, str) else backend
    try:
        return NeighborListBackend(candidate)
    except ValueError:
        known = ", ".join(repr(member.value) for member in NeighborListBackend)
        raise ValueError(
            f"unknown neighbor-list backend {backend!r}, expected one of {known}"
        ) from None
def _vesin_available() -> bool:
    """Return ``True`` when the optional ``vesin`` package can be imported."""
    try:
        import vesin  # noqa: F401
        return True
    except ImportError:
        return False
[docs]
def available_neighbor_backends() -> List[NeighborListBackend]:
    """
    Report which neighbor-list backends can be used in this environment.

    ``vesin`` comes first when it can be imported; ``ase`` is always listed
    last as the fallback.

    Returns:
        Available backends ordered by preference.
    """
    preferred = [NeighborListBackend.VESIN] if _vesin_available() else []
    return preferred + [NeighborListBackend.ASE]
def _maybe_to_torch(array: np.ndarray, arrays: str) -> ArrayLike:
    """
    Return ``array`` unchanged or wrapped as a torch tensor.

    Args:
        array: NumPy output produced by a neighbor-list backend.
        arrays: Target container, ``"numpy"`` or ``"torch"``.

    Returns:
        ``array`` itself for ``"numpy"``; for ``"torch"``, the result of
        ``torch.from_numpy`` on a C-contiguous version of ``array``.

    Raises:
        ValueError: If ``arrays`` is neither ``"numpy"`` nor ``"torch"``.
    """
    if arrays == "torch":
        # ascontiguousarray only copies when the input is not already
        # C-contiguous; from_numpy then shares memory with that buffer.
        contiguous = np.ascontiguousarray(array)
        return torch.from_numpy(contiguous)
    if arrays != "numpy":
        raise ValueError(
            f"unknown array target '{arrays}', expected 'numpy' or 'torch'"
        )
    return array
def _sort_outputs(
    pairs: np.ndarray,
    shifts: np.ndarray,
    distances: Optional[np.ndarray],
    vectors: Optional[np.ndarray],
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Reorder neighbor-list entries by first index, breaking ties by second index.

    Args:
        pairs: Index array whose row 0 holds first indices and row 1 second
            indices; each column is one neighbor entry.
        shifts: Per-entry array, permuted along its first axis.
        distances: Optional per-entry array, permuted like ``shifts``.
        vectors: Optional per-entry array, permuted like ``shifts``.

    Returns:
        ``(pairs, shifts, distances, vectors)`` with one shared permutation
        applied; ``None`` inputs come back as ``None``.
    """
    # np.lexsort uses the LAST key as the primary one, hence (second, first).
    order = np.lexsort((pairs[1], pairs[0]))

    def _take(values):
        return None if values is None else values[order]

    return pairs[:, order], shifts[order], _take(distances), _take(vectors)
def _build_ase_neighbor_list(
    atoms: ase.Atoms,
    cutoff: float,
    arrays: str,
    sorted: bool,
) -> NeighborListData:
    """
    Build a full neighbor list with :py:func:`ase.neighborlist.neighbor_list`.

    Args:
        atoms: Structure passed to ASE as ``a``.
        cutoff: Spherical cutoff radius forwarded to ASE.
        arrays: Output container, ``"numpy"`` or ``"torch"``.
        sorted: Whether to sort entries lexicographically by ``(i, j)``.

    Returns:
        ``NeighborListData`` tagged with the ASE backend and ``full_list=True``.
        Pair indices are int64 and shifts are int32.
    """
    i, j, shifts, vectors, distances = ase.neighborlist.neighbor_list(
        quantities="ijSDd",
        a=atoms,
        cutoff=cutoff,
    )
    # Pin a signed 64-bit index dtype: ASE hands back the platform default
    # integer (presumably int32 on some platforms, e.g. Windows with NumPy < 2),
    # while torch index operations expect int64. copy=False makes this free
    # when the dtype already matches.
    pairs = np.stack([i, j], axis=0).astype(np.int64, copy=False)
    if sorted:
        pairs, shifts, distances, vectors = _sort_outputs(
            pairs,
            shifts,
            distances,
            vectors,
        )
    return NeighborListData(
        pairs=_maybe_to_torch(pairs, arrays),
        shifts=_maybe_to_torch(shifts.astype(np.int32, copy=False), arrays),
        distances=_maybe_to_torch(distances, arrays),
        vectors=_maybe_to_torch(vectors, arrays),
        backend=NeighborListBackend.ASE.value,
        cutoff=cutoff,
        # ASE always emits both i-j and j-i entries.
        full_list=True,
        sorted=sorted,
        strict=True,
    )
def _build_vesin_neighbor_list(
    atoms: ase.Atoms,
    cutoff: float,
    arrays: str,
    full_list: bool,
    sorted: bool,
) -> NeighborListData:
    """
    Build a neighbor list with the optional ``vesin`` package.

    Args:
        atoms: Structure whose positions, cell and ``pbc`` flags are passed to vesin.
        cutoff: Spherical cutoff radius.
        arrays: Output container, ``"numpy"`` or ``"torch"``.
        full_list: Whether both ``i-j`` and ``j-i`` entries are emitted.
        sorted: Whether vesin should sort the entries.

    Returns:
        ``NeighborListData`` tagged with the vesin backend. Pair indices are
        int64 and shifts are int32.

    Raises:
        ImportError: If ``vesin`` cannot be imported.
    """
    from vesin import NeighborList

    calculator = NeighborList(cutoff=cutoff, full_list=full_list, sorted=sorted)
    i, j, shifts, vectors, distances = calculator.compute(
        points=atoms.positions,
        box=atoms.cell[:],
        # NOTE(review): a per-axis pbc array needs a vesin release that accepts
        # mixed periodicity; older releases only took a single bool — confirm.
        periodic=atoms.pbc,
        quantities="ijSDd",
    )
    # vesin reports pair indices as unsigned (size_t-based) integers, which
    # older torch releases cannot wrap via from_numpy; convert to signed int64
    # so the output dtype does not depend on the backend. copy=False keeps
    # this free if the dtype already matches.
    pairs = np.stack([i, j], axis=0).astype(np.int64, copy=False)
    return NeighborListData(
        pairs=_maybe_to_torch(pairs, arrays),
        shifts=_maybe_to_torch(shifts.astype(np.int32, copy=False), arrays),
        distances=_maybe_to_torch(distances, arrays),
        vectors=_maybe_to_torch(vectors, arrays),
        backend=NeighborListBackend.VESIN.value,
        cutoff=cutoff,
        full_list=full_list,
        sorted=sorted,
        strict=True,
    )
[docs]
def build_neighbor_list(
    atoms: ase.Atoms,
    cutoff: float,
    backend: Union[str, NeighborListBackend] = NeighborListBackend.AUTO,
    arrays: str = "torch",
    full_list: bool = True,
    sorted: bool = True,
) -> NeighborListData:
    """
    Build a neighbor list for an ASE structure.

    When ``backend="auto"``, UFP prefers ``vesin`` if it is importable and falls
    back to :py:mod:`ase.neighborlist` otherwise. All options are validated
    before any neighbor search runs, so invalid arguments fail fast.

    Args:
        atoms: Structure for which the neighbor list should be built.
        cutoff: Spherical cutoff radius; must be finite and positive.
        backend: One of ``"auto"``, ``"vesin"``, or ``"ase"`` (string or
            ``NeighborListBackend``). ``"metatomic"`` is a recognized name but
            is not handled here and raises ``ValueError``.
        arrays: Output array type, either ``"torch"`` or ``"numpy"``.
        full_list: Whether both ``i-j`` and ``j-i`` entries should be present.
            The ASE backend only supports ``True``.
        sorted: Whether to sort the output lexicographically by pair indices.

    Returns:
        Normalized neighbor-list data.

    Raises:
        TypeError: If ``atoms`` is not an ASE structure.
        ValueError: If options are invalid or the selected backend is unsupported.
        ImportError: If ``backend="vesin"`` is requested but ``vesin`` cannot
            be imported.
    """
    if not isinstance(atoms, ase.Atoms):
        raise TypeError(f"`atoms` should be ase.Atoms, got {type(atoms)}")

    cutoff = float(cutoff)
    if not np.isfinite(cutoff) or cutoff <= 0.0:
        raise ValueError("`cutoff` must be a finite, positive number")

    # Reject a bad output container up front instead of after the (potentially
    # expensive) neighbor search; the message matches the conversion error.
    if arrays not in ("numpy", "torch"):
        raise ValueError(
            f"unknown array target '{arrays}', expected 'numpy' or 'torch'"
        )

    backend = _as_backend(backend)
    if backend == NeighborListBackend.AUTO:
        backend = (
            NeighborListBackend.VESIN if _vesin_available() else NeighborListBackend.ASE
        )

    if backend == NeighborListBackend.VESIN:
        return _build_vesin_neighbor_list(
            atoms=atoms,
            cutoff=cutoff,
            arrays=arrays,
            full_list=full_list,
            sorted=sorted,
        )

    if backend == NeighborListBackend.ASE:
        if not full_list:
            raise ValueError("the ASE neighbor-list backend only supports full lists")
        return _build_ase_neighbor_list(
            atoms=atoms,
            cutoff=cutoff,
            arrays=arrays,
            sorted=sorted,
        )

    # Use ``.value`` explicitly: formatting a str-mixin Enum member changed in
    # Python 3.12 (it now renders as ``NeighborListBackend.X``).
    raise ValueError(f"unsupported neighbor-list backend: {backend.value}")