# Source code for ufp.core.input
"""
Normalized torch-native input bundle for UFP models.
Use this module when converting ASE or engine-specific structures into the
shared atom, system, and neighbor-list representation.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import ase
import numpy as np
import torch
from ufp.core._arrays import ArrayLike, _to_tensor
from ufp.core.state import UFPInputState
from ufp.neighbors._data import NeighborListData
def _move_metadata_value(
value: object,
*,
device: torch.device,
dtype: torch.dtype,
) -> object:
"""Move internal cached metadata that participates in model execution."""
mover = getattr(value, "to_input_device", None)
if callable(mover):
return mover(device=device, dtype=dtype)
if isinstance(value, dict):
return {
key: _move_metadata_value(item, device=device, dtype=dtype)
for key, item in value.items()
}
return value
class _PairGeometry:
"""Private owner for pair geometry and categorical caches."""
def __init__(self, inputs: "UFPInput") -> None:
self.inputs = inputs
self.pair_system_index_cache: Optional[torch.Tensor] = None
self.pair_vectors_cache: Optional[torch.Tensor] = None
self.pair_distances_cache: Optional[torch.Tensor] = None
self.pair_atomic_numbers_cache: Optional[tuple[torch.Tensor, torch.Tensor]] = (
None
)
self.pair_mask_cache: dict[tuple[int, int, bool], torch.Tensor] = {}
self.atomic_category_cache: dict[tuple[int, ...], torch.Tensor] = {}
self.pair_category_cache: dict[tuple[tuple[int, ...], bool], torch.Tensor] = {}
def normalize_pair_mask(self, mask: ArrayLike) -> torch.Tensor:
"""Validate and tensorize a pair-selection mask."""
neighbor_list = self.inputs._require_neighbor_list()
mask_tensor = _to_tensor(
mask,
dtype=torch.bool,
device=neighbor_list.pairs.device,
)
if mask_tensor.ndim != 1 or mask_tensor.shape[0] != neighbor_list.n_pairs:
raise ValueError("`mask` must have shape (n_pairs,)")
return mask_tensor
def pair_indices(
self,
mask: Optional[ArrayLike] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Return atom indices for all or selected neighbor-list pairs."""
neighbor_list = self.inputs._require_neighbor_list()
first_atom, second_atom = neighbor_list.pairs[0], neighbor_list.pairs[1]
if mask is None:
return first_atom, second_atom
pair_mask = self.normalize_pair_mask(mask)
return first_atom[pair_mask], second_atom[pair_mask]
def full_pair_system_index(self) -> torch.Tensor:
"""Cache the system index attached to every neighbor-list pair."""
if self.pair_system_index_cache is None:
inputs = self.inputs
neighbor_list = inputs._require_neighbor_list()
first_atom, second_atom = neighbor_list.pairs[0], neighbor_list.pairs[1]
pair_system_index = inputs.system_index.index_select(0, first_atom)
second_system_index = inputs.system_index.index_select(0, second_atom)
if not torch.equal(pair_system_index, second_system_index):
raise ValueError("neighbor-list pairs may not span multiple systems")
self.pair_system_index_cache = pair_system_index
return self.pair_system_index_cache
def pair_system_index(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
"""Return the owning system index for all or selected pairs."""
pair_system_index = self.full_pair_system_index()
if mask is None:
return pair_system_index
return pair_system_index[self.normalize_pair_mask(mask)]
def full_pair_vectors(self) -> torch.Tensor:
"""Cache neighbor-list displacement vectors in concatenated atom coordinates."""
if self.pair_vectors_cache is None:
inputs = self.inputs
neighbor_list = inputs._require_neighbor_list()
if neighbor_list.vectors is not None and not inputs.positions.requires_grad:
self.pair_vectors_cache = neighbor_list.vectors
return self.pair_vectors_cache
first_atom, second_atom = neighbor_list.pairs[0], neighbor_list.pairs[1]
pair_system_index = self.full_pair_system_index()
shifts = neighbor_list.shifts.to(device=inputs.device, dtype=inputs.dtype)
cells = inputs.cell.index_select(0, pair_system_index)
shift_vectors = torch.einsum("pi,pij->pj", shifts, cells)
self.pair_vectors_cache = (
inputs.positions.index_select(0, second_atom)
- inputs.positions.index_select(0, first_atom)
+ shift_vectors
)
return self.pair_vectors_cache
def pair_vectors(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
"""Return pair displacement vectors for all or selected pairs."""
pair_vectors = self.full_pair_vectors()
if mask is None:
return pair_vectors
return pair_vectors[self.normalize_pair_mask(mask)]
def pair_shifts(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
"""Return cell-shift vectors for selected neighbor-list entries."""
inputs = self.inputs
neighbor_list = inputs._require_neighbor_list()
shifts = neighbor_list.shifts.to(device=inputs.device, dtype=torch.int64)
if mask is None:
return shifts
return shifts[self.normalize_pair_mask(mask)]
def full_pair_distances(self) -> torch.Tensor:
"""Cache pair distances derived from the full pair-vector tensor."""
if self.pair_distances_cache is None:
inputs = self.inputs
neighbor_list = inputs._require_neighbor_list()
if (
neighbor_list.distances is not None
and not inputs.positions.requires_grad
):
self.pair_distances_cache = neighbor_list.distances
return self.pair_distances_cache
self.pair_distances_cache = torch.linalg.vector_norm(
self.full_pair_vectors(),
dim=1,
)
return self.pair_distances_cache
def pair_distances(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
"""Return pair distances for all or selected neighbor-list entries."""
pair_distances = self.full_pair_distances()
if mask is None:
return pair_distances
return pair_distances[self.normalize_pair_mask(mask)]
def full_pair_atomic_numbers(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Cache atomic-number pairs for the full neighbor list."""
if self.pair_atomic_numbers_cache is None:
inputs = self.inputs
first_atom, second_atom = self.pair_indices()
self.pair_atomic_numbers_cache = (
inputs.atomic_numbers.index_select(0, first_atom),
inputs.atomic_numbers.index_select(0, second_atom),
)
return self.pair_atomic_numbers_cache
def pair_atomic_numbers(
self,
mask: Optional[ArrayLike] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Return atomic numbers for the first and second atom of each selected pair."""
first_numbers, second_numbers = self.full_pair_atomic_numbers()
if mask is None:
return first_numbers, second_numbers
pair_mask = self.normalize_pair_mask(mask)
return (
first_numbers[pair_mask],
second_numbers[pair_mask],
)
def pair_mask(
self,
first_atomic_number: int,
second_atomic_number: int,
*,
symmetric: bool = False,
) -> torch.Tensor:
"""Build a mask selecting pairs with the requested atomic numbers."""
first_atomic_number = int(first_atomic_number)
second_atomic_number = int(second_atomic_number)
if symmetric and first_atomic_number > second_atomic_number:
first_atomic_number, second_atomic_number = (
second_atomic_number,
first_atomic_number,
)
key = (first_atomic_number, second_atomic_number, bool(symmetric))
cached = self.pair_mask_cache.get(key)
if cached is not None:
return cached
first_numbers, second_numbers = self.full_pair_atomic_numbers()
mask = (first_numbers == first_atomic_number) & (
second_numbers == second_atomic_number
)
if symmetric and first_atomic_number != second_atomic_number:
mask = mask | (
(first_numbers == second_atomic_number)
& (second_numbers == first_atomic_number)
)
self.pair_mask_cache[key] = mask
return mask
def atomic_category_indices(self, atomic_types: Sequence[int]) -> torch.Tensor:
"""Return atomic-category indices for every atom."""
inputs = self.inputs
normalized_atomic_types = tuple(sorted(set(int(z) for z in atomic_types)))
cached = self.atomic_category_cache.get(normalized_atomic_types)
if cached is not None:
return cached
categories = torch.full_like(inputs.atomic_numbers, fill_value=-1)
if normalized_atomic_types:
atomic_type_tensor = torch.tensor(
normalized_atomic_types,
dtype=torch.int64,
device=inputs.device,
)
category_indices = torch.searchsorted(
atomic_type_tensor,
inputs.atomic_numbers,
)
clamped_category_indices = category_indices.clamp_max(
atomic_type_tensor.numel() - 1
)
valid = (category_indices < atomic_type_tensor.numel()) & (
atomic_type_tensor[clamped_category_indices] == inputs.atomic_numbers
)
categories[valid] = category_indices[valid]
self.atomic_category_cache[normalized_atomic_types] = categories
return categories
def pair_category_indices(
self,
atomic_types: Sequence[int],
*,
symmetric: bool = True,
) -> torch.Tensor:
"""Return pair-category indices for every neighbor-list row."""
inputs = self.inputs
normalized_atomic_types = tuple(sorted(set(int(z) for z in atomic_types)))
key = (normalized_atomic_types, bool(symmetric))
cached = self.pair_category_cache.get(key)
if cached is not None:
return cached
atom_category = self.atomic_category_indices(normalized_atomic_types)
first_atom, second_atom = self.pair_indices()
first_category = atom_category.index_select(0, first_atom)
second_category = atom_category.index_select(0, second_atom)
valid = (first_category >= 0) & (second_category >= 0)
n_categories = len(normalized_atomic_types)
table = torch.full(
(n_categories, n_categories),
fill_value=-1,
dtype=torch.int64,
device=inputs.device,
)
pair_index = 0
if symmetric:
for first in range(n_categories):
for second in range(first, n_categories):
table[first, second] = pair_index
table[second, first] = pair_index
pair_index += 1
else:
for first in range(n_categories):
for second in range(n_categories):
table[first, second] = pair_index
pair_index += 1
pair_category = torch.full(
(inputs._require_neighbor_list().n_pairs,),
fill_value=-1,
dtype=torch.int64,
device=inputs.device,
)
pair_category[valid] = table[first_category[valid], second_category[valid]]
self.pair_category_cache[key] = pair_category
return pair_category
def slice_neighbor_list(self, mask: ArrayLike) -> NeighborListData:
"""Return a ``NeighborListData`` view restricted to selected pairs."""
return self.inputs._require_neighbor_list().masked(
self.normalize_pair_mask(mask)
)
# [docs]
@dataclass
class UFPInput:
    """
    Torch-native input bundle passed to :class:`UFPotential`.

    The same structure works for single systems, batches of ASE structures, and
    metatomic-provided systems. Neighbor lists are optional but, when present, always
    refer to the concatenated atom indexing used by ``positions``.

    ``positions`` and ``cell`` use the same length unit as the source structure,
    normally angstroms for ASE inputs. The floating-point dtype and device of
    ``positions`` define the dtype and device used for geometric tensors and
    neighbor-list vectors.

    Attributes:
        positions: Atomic positions with shape ``(n_atoms, 3)``.
        cell: Unit cells with shape ``(n_systems, 3, 3)`` or ``(3, 3)`` for one
            system.
        pbc: Periodic boundary flags with shape ``(n_systems, 3)`` or ``(3,)``.
        atomic_numbers: Atomic numbers with shape ``(n_atoms,)``.
        system_index: System index for each atom with shape ``(n_atoms,)``. Every
            system in ``cell`` must appear at least once.
        neighbor_list: Optional neighbor-list data using concatenated atom indexing.
        atomic_charges: Optional local charges with shape ``(n_atoms,)``.
        atomic_spin_moments: Optional local collinear spin moments with shape
            ``(n_atoms,)``.
        system_charges: Optional total charge per system with shape ``(n_systems,)``.
        system_spin_moments: Optional total spin moment per system with shape
            ``(n_systems,)``.
        state: Optional charge/spin state bundle; mutually exclusive with the four
            explicit charge/spin arguments above.
        metadata: Optional metadata carried alongside the normalized tensors.
        source_atoms: Optional original ASE structures, one per system.

    Examples:
        >>> import torch
        >>> data = UFPInput(
        ...     positions=torch.zeros((2, 3), dtype=torch.float64),
        ...     cell=torch.eye(3, dtype=torch.float64),
        ...     pbc=torch.tensor([False, False, False]),
        ...     atomic_numbers=torch.tensor([1, 1]),
        ...     system_index=torch.tensor([0, 0]),
        ... )
        >>> data.n_atoms
        2
        >>> data.dtype
        torch.float64
    """

    positions: ArrayLike
    cell: ArrayLike
    pbc: ArrayLike
    atomic_numbers: ArrayLike
    system_index: ArrayLike
    neighbor_list: Optional[NeighborListData] = None
    metadata: Dict[str, object] = field(default_factory=dict)
    source_atoms: Optional[Sequence[ase.Atoms]] = None
    atomic_charges: Optional[ArrayLike] = None
    atomic_spin_moments: Optional[ArrayLike] = None
    system_charges: Optional[ArrayLike] = None
    system_spin_moments: Optional[ArrayLike] = None
    state: Optional[UFPInputState] = None
    # Lazily created by ``_pair_geometry``; excluded from ``__init__`` and repr.
    _pair_geometry_cache: Optional[_PairGeometry] = field(
        default=None,
        init=False,
        repr=False,
    )

    def __post_init__(self) -> None:
        """Normalize stored arrays into the shared tensor layout.

        Raises:
            ValueError: If any array has an unexpected shape, no atoms are
                given, ``system_index`` is negative, out of range or skips a
                system, both ``state`` and explicit charge/spin tensors are
                supplied, or ``source_atoms`` does not hold one structure per
                system.
        """
        positions = _to_tensor(self.positions)
        if positions.ndim != 2 or positions.shape[1] != 3:
            raise ValueError("`positions` must have shape (n_atoms, 3)")
        if not positions.is_floating_point():
            positions = positions.to(dtype=torch.get_default_dtype())
        self.positions = positions
        # Everything else follows the dtype/device of ``positions``.
        self.cell = _to_tensor(
            self.cell,
            dtype=self.positions.dtype,
            device=self.positions.device,
        )
        self.pbc = _to_tensor(self.pbc, dtype=torch.bool, device=self.positions.device)
        self.atomic_numbers = _to_tensor(
            self.atomic_numbers,
            dtype=torch.int64,
            device=self.positions.device,
        )
        self.system_index = _to_tensor(
            self.system_index,
            dtype=torch.int64,
            device=self.positions.device,
        )
        # A bare (3, 3) cell / (3,) pbc describes a single system.
        if self.cell.ndim == 2:
            self.cell = self.cell.unsqueeze(0)
        if self.cell.ndim != 3 or tuple(self.cell.shape[1:]) != (3, 3):
            raise ValueError("`cell` must have shape (n_systems, 3, 3) or (3, 3)")
        if self.pbc.ndim == 1:
            self.pbc = self.pbc.unsqueeze(0)
        if self.pbc.ndim != 2 or self.pbc.shape[1] != 3:
            raise ValueError("`pbc` must have shape (n_systems, 3) or (3,)")
        n_systems = int(self.cell.shape[0])
        # A single pbc row is broadcast to every system.
        if self.pbc.shape[0] == 1 and n_systems > 1:
            self.pbc = self.pbc.expand(n_systems, -1).clone()
        elif self.pbc.shape[0] != n_systems:
            raise ValueError(
                "`cell` and `pbc` must describe the same number of systems"
            )
        if (
            self.atomic_numbers.ndim != 1
            or self.atomic_numbers.shape[0] != self.n_atoms
        ):
            raise ValueError("`atomic_numbers` must have shape (n_atoms,)")
        if self.system_index.ndim != 1 or self.system_index.shape[0] != self.n_atoms:
            raise ValueError("`system_index` must have shape (n_atoms,)")
        if self.n_atoms == 0:
            raise ValueError("`positions` must contain at least one atom")
        if torch.any(self.system_index < 0):
            raise ValueError("`system_index` can not contain negative values")
        if int(self.system_index.max().item()) >= n_systems:
            raise ValueError("`system_index` references a system outside `cell`")
        expected_systems = torch.arange(
            n_systems,
            device=self.device,
            dtype=torch.int64,
        )
        present_systems = torch.unique(self.system_index, sorted=True)
        if not torch.equal(expected_systems, present_systems):
            # The check is on the set of unique indices, so a system may own
            # many atoms; what is rejected is a system that owns none.
            raise ValueError("`system_index` must contain every system at least once")
        if self.neighbor_list is not None:
            self.neighbor_list = self.neighbor_list.as_torch(
                dtype=self.dtype,
                device=self.device,
            )
        if self.state is not None and any(
            value is not None
            for value in (
                self.atomic_charges,
                self.atomic_spin_moments,
                self.system_charges,
                self.system_spin_moments,
            )
        ):
            raise ValueError(
                "pass either `state` or explicit charge/spin state tensors, not both"
            )
        state = self.state
        if state is None:
            state = UFPInputState(
                atomic_charges=self.atomic_charges,
                atomic_spin_moments=self.atomic_spin_moments,
                system_charges=self.system_charges,
                system_spin_moments=self.system_spin_moments,
            )
        self.state = state.as_torch(
            n_atoms=self.n_atoms,
            n_systems=n_systems,
            dtype=self.dtype,
            device=self.device,
        )
        # Mirror the normalized state onto the flat convenience attributes.
        self.atomic_charges = self.state.atomic_charges
        self.atomic_spin_moments = self.state.atomic_spin_moments
        self.system_charges = self.state.system_charges
        self.system_spin_moments = self.state.system_spin_moments
        if self.source_atoms is not None:
            # Accept a lone ``ase.Atoms`` as shorthand for a one-element tuple.
            if isinstance(self.source_atoms, ase.Atoms):
                self.source_atoms = (self.source_atoms,)
            else:
                self.source_atoms = tuple(self.source_atoms)
            if len(self.source_atoms) != n_systems:
                raise ValueError(
                    "`source_atoms` must contain one ASE structure per system"
                )

    @classmethod
    def from_ase(
        cls,
        atoms: ase.Atoms,
        *,
        neighbor_list: Optional[NeighborListData] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = False,
        metadata: Optional[Dict[str, object]] = None,
        state: Optional[UFPInputState] = None,
        extract_state: bool = True,
    ) -> "UFPInput":
        """Build a single-system input by delegating to ``from_ase_list``.

        All keyword arguments are forwarded unchanged; see ``from_ase_list``.
        """
        return cls.from_ase_list(
            [atoms],
            neighbor_list=neighbor_list,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
            metadata=metadata,
            state=state,
            extract_state=extract_state,
        )

    @classmethod
    def from_ase_list(
        cls,
        atoms_list: Sequence[ase.Atoms],
        *,
        neighbor_list: Optional[NeighborListData] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = False,
        metadata: Optional[Dict[str, object]] = None,
        state: Optional[UFPInputState] = None,
        extract_state: bool = True,
    ) -> "UFPInput":
        """Concatenate ASE structures into one batched input.

        Args:
            atoms_list: Structures to batch; entry ``i`` becomes system ``i``.
            neighbor_list: Optional neighbor list stored on the result.
            dtype: Floating-point dtype for positions and cells; defaults to
                ``torch.get_default_dtype()``.
            device: Device on which the tensors are created.
            requires_grad: Enable gradients on the concatenated positions.
            metadata: Optional metadata; a shallow copy is stored.
            state: Explicit charge/spin state; when given it takes precedence
                over extraction.
            extract_state: When ``state`` is ``None``, build the state with
                ``UFPInputState.from_ase_list``.

        Raises:
            ValueError: If ``atoms_list`` is empty.
            TypeError: If an entry is not an ``ase.Atoms`` instance.
        """
        if not atoms_list:
            raise ValueError("`atoms_list` must contain at least one ASE structure")
        resolved_dtype = torch.get_default_dtype() if dtype is None else dtype
        positions = []
        cells = []
        pbc = []
        atomic_numbers = []
        system_index = []
        for system_i, atoms in enumerate(atoms_list):
            if not isinstance(atoms, ase.Atoms):
                raise TypeError(
                    f"`atoms_list` should contain ase.Atoms, got {type(atoms)}"
                )
            positions.append(
                torch.as_tensor(atoms.positions, dtype=resolved_dtype, device=device)
            )
            cells.append(
                torch.as_tensor(atoms.cell.array, dtype=resolved_dtype, device=device)
            )
            pbc.append(
                torch.as_tensor(
                    np.asarray(atoms.pbc),
                    dtype=torch.bool,
                    device=device,
                )
            )
            atomic_numbers.append(
                torch.as_tensor(atoms.numbers, dtype=torch.int64, device=device)
            )
            # Every atom of this structure is tagged with its system number.
            system_index.append(
                torch.full(
                    (len(atoms),),
                    system_i,
                    dtype=torch.int64,
                    device=device,
                )
            )
        concatenated_positions = torch.cat(positions, dim=0)
        if requires_grad:
            concatenated_positions.requires_grad_(True)
        input_state = state
        if input_state is None and extract_state:
            input_state = UFPInputState.from_ase_list(
                tuple(atoms_list),
                dtype=resolved_dtype,
                device=device,
            )
        return cls(
            positions=concatenated_positions,
            cell=torch.stack(cells, dim=0),
            pbc=torch.stack(pbc, dim=0),
            atomic_numbers=torch.cat(atomic_numbers, dim=0),
            system_index=torch.cat(system_index, dim=0),
            neighbor_list=neighbor_list,
            state=input_state,
            metadata={} if metadata is None else dict(metadata),
            source_atoms=tuple(atoms_list),
        )

    def to(
        self,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        requires_grad: bool = False,
        neighbor_list: Optional[NeighborListData] = None,
    ) -> "UFPInput":
        """Return a copy moved to the requested device/dtype and gradient state.

        Args:
            device: Target device; defaults to the current device.
            dtype: Target floating-point dtype; defaults to the current dtype.
            requires_grad: When true, the new positions are a detached clone
                with gradients enabled; otherwise they are simply detached.
            neighbor_list: Replacement neighbor list. When ``None``, the stored
                neighbor list (if any) is converted to the target device/dtype.

        Returns:
            A new ``UFPInput``. ``source_atoms`` is shared with ``self``;
            metadata values are passed through ``_move_metadata_value``.
        """
        resolved_device = self.device if device is None else torch.device(device)
        resolved_dtype = self.dtype if dtype is None else dtype
        positions = self.positions.to(
            device=resolved_device,
            dtype=resolved_dtype,
            non_blocking=True,
        )
        if requires_grad:
            positions = positions.detach().clone().requires_grad_(True)
        else:
            positions = positions.detach()
        metadata = {
            key: _move_metadata_value(
                value,
                device=resolved_device,
                dtype=resolved_dtype,
            )
            for key, value in self.metadata.items()
        }
        # An explicitly supplied neighbor list always wins over the stored one.
        moved_neighbor_list = (
            self.neighbor_list.as_torch(
                dtype=resolved_dtype,
                device=resolved_device,
            )
            if neighbor_list is None and self.neighbor_list is not None
            else neighbor_list
        )
        return UFPInput(
            positions=positions,
            cell=self.cell.to(
                device=resolved_device,
                dtype=resolved_dtype,
                non_blocking=True,
            ),
            pbc=self.pbc.to(device=resolved_device, non_blocking=True),
            atomic_numbers=self.atomic_numbers.to(
                device=resolved_device,
                non_blocking=True,
            ),
            system_index=self.system_index.to(
                device=resolved_device,
                non_blocking=True,
            ),
            neighbor_list=moved_neighbor_list,
            state=self.state.to(
                n_atoms=self.n_atoms,
                n_systems=self.n_systems,
                dtype=resolved_dtype,
                device=resolved_device,
            )
            if self.state is not None
            else None,
            metadata=metadata,
            source_atoms=self.source_atoms,
        )

    def pin_memory(self) -> "UFPInput":
        """Pin stored tensors in place and return ``self`` for dataloader-style use."""
        self.positions = self.positions.pin_memory()
        self.cell = self.cell.pin_memory()
        self.pbc = self.pbc.pin_memory()
        self.atomic_numbers = self.atomic_numbers.pin_memory()
        self.system_index = self.system_index.pin_memory()
        if self.state is not None:
            self.state = self.state.pin_memory()
            # Keep the flat convenience attributes pointing at the pinned state.
            self.atomic_charges = self.state.atomic_charges
            self.atomic_spin_moments = self.state.atomic_spin_moments
            self.system_charges = self.state.system_charges
            self.system_spin_moments = self.state.system_spin_moments
        if self.neighbor_list is not None:
            self.neighbor_list = self.neighbor_list.pin_memory()
        return self

    @property
    def n_atoms(self) -> int:
        """Return the total number of atoms across all systems."""
        return int(self.positions.shape[0])

    @property
    def n_systems(self) -> int:
        """Return the number of systems represented by this input."""
        return int(self.cell.shape[0])

    @property
    def device(self) -> torch.device:
        """Return the torch device shared by the stored tensors."""
        return self.positions.device

    @property
    def dtype(self) -> torch.dtype:
        """Return the floating-point dtype used for geometric tensors."""
        return self.positions.dtype

    @property
    def system_sizes(self) -> list[int]:
        """Return the atom count for each system in concatenated order."""
        counts = torch.bincount(self.system_index, minlength=self.n_systems)
        return [int(value) for value in counts.tolist()]

    @property
    def atom_slices(self) -> list[slice]:
        """Return per-system slices into the concatenated atom axis.

        The slices are consecutive ranges sized by ``system_sizes``.
        NOTE(review): this assumes the atoms of each system are stored
        contiguously and in increasing system order; ``__post_init__`` does
        not enforce that ordering -- confirm against callers.
        """
        atom_slices = []
        start = 0
        for size in self.system_sizes:
            atom_slices.append(slice(start, start + size))
            start += size
        return atom_slices

    @property
    def atoms(self) -> ase.Atoms:
        """Return the original ASE structure for a single-system input.

        Raises:
            AttributeError: If no source structures are stored or the input
                holds more than one system.
        """
        if self.source_atoms is None or len(self.source_atoms) != 1:
            raise AttributeError(
                "`atoms` is only available for single-system ASE inputs"
            )
        return self.source_atoms[0]

    def _require_neighbor_list(self) -> NeighborListData:
        """Return the stored neighbor list or raise when geometry is missing."""
        if self.neighbor_list is None:
            raise RuntimeError("this model input does not contain a neighbor list")
        return self.neighbor_list

    def _pair_geometry(self) -> _PairGeometry:
        """Return the private pair-geometry cache owner, creating it on first use."""
        if self._pair_geometry_cache is None:
            self._pair_geometry_cache = _PairGeometry(self)
        return self._pair_geometry_cache

    @property
    def _pair_system_index_cache(self) -> Optional[torch.Tensor]:
        """Compatibility access to the pair-system-index cache."""
        return self._pair_geometry().pair_system_index_cache

    @property
    def _pair_vectors_cache(self) -> Optional[torch.Tensor]:
        """Compatibility access to the pair-vector cache."""
        return self._pair_geometry().pair_vectors_cache

    @property
    def _pair_distances_cache(self) -> Optional[torch.Tensor]:
        """Compatibility access to the pair-distance cache."""
        return self._pair_geometry().pair_distances_cache

    @property
    def _pair_atomic_numbers_cache(
        self,
    ) -> Optional[tuple[torch.Tensor, torch.Tensor]]:
        """Compatibility access to the pair-atomic-number cache."""
        return self._pair_geometry().pair_atomic_numbers_cache

    @property
    def _pair_mask_cache(self) -> dict[tuple[int, int, bool], torch.Tensor]:
        """Compatibility access to cached pair masks."""
        return self._pair_geometry().pair_mask_cache

    @property
    def _atomic_category_cache(self) -> dict[tuple[int, ...], torch.Tensor]:
        """Compatibility access to cached atomic categories."""
        return self._pair_geometry().atomic_category_cache

    @property
    def _pair_category_cache(
        self,
    ) -> dict[tuple[tuple[int, ...], bool], torch.Tensor]:
        """Compatibility access to cached pair categories."""
        return self._pair_geometry().pair_category_cache

    def _normalize_pair_mask(self, mask: ArrayLike) -> torch.Tensor:
        """Validate and tensorize a pair-selection mask."""
        return self._pair_geometry().normalize_pair_mask(mask)

    def pair_indices(
        self,
        mask: Optional[ArrayLike] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return atom indices for all or selected neighbor-list pairs."""
        return self._pair_geometry().pair_indices(mask)

    def _full_pair_system_index(self) -> torch.Tensor:
        """Cache the system index attached to every neighbor-list pair."""
        return self._pair_geometry().full_pair_system_index()

    def pair_system_index(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return the owning system index for all or selected pairs."""
        return self._pair_geometry().pair_system_index(mask)

    def _full_pair_vectors(self) -> torch.Tensor:
        """Cache neighbor-list displacement vectors in concatenated atom coordinates."""
        return self._pair_geometry().full_pair_vectors()

    def pair_vectors(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return pair displacement vectors for all or selected pairs."""
        return self._pair_geometry().pair_vectors(mask)

    def pair_shifts(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return cell-shift vectors for selected neighbor-list entries."""
        return self._pair_geometry().pair_shifts(mask)

    def _full_pair_distances(self) -> torch.Tensor:
        """Cache pair distances derived from the full pair-vector tensor."""
        return self._pair_geometry().full_pair_distances()

    def pair_distances(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return pair distances for all or selected neighbor-list entries."""
        return self._pair_geometry().pair_distances(mask)

    def _full_pair_atomic_numbers(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Cache atomic-number pairs for the full neighbor list."""
        return self._pair_geometry().full_pair_atomic_numbers()

    def pair_atomic_numbers(
        self,
        mask: Optional[ArrayLike] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return atomic numbers for the first and second atom of each selected pair."""
        return self._pair_geometry().pair_atomic_numbers(mask)

    def pair_mask(
        self,
        first_atomic_number: int,
        second_atomic_number: int,
        *,
        symmetric: bool = False,
    ) -> torch.Tensor:
        """Build a mask selecting pairs with the requested atomic numbers."""
        return self._pair_geometry().pair_mask(
            first_atomic_number,
            second_atomic_number,
            symmetric=symmetric,
        )

    def atomic_category_indices(self, atomic_types: Sequence[int]) -> torch.Tensor:
        """
        Return atomic-category indices for every atom.

        Atomic numbers outside ``atomic_types`` are marked ``-1``. Category ordering
        follows the sorted unique atomic types used throughout UFP term layouts.
        """
        return self._pair_geometry().atomic_category_indices(atomic_types)

    def pair_category_indices(
        self,
        atomic_types: Sequence[int],
        *,
        symmetric: bool = True,
    ) -> torch.Tensor:
        """
        Return pair-category indices for every neighbor-list row.

        Pairs containing atomic numbers outside ``atomic_types`` are marked ``-1``.
        When ``symmetric`` is true, category ordering matches unordered
        combinations with replacement over sorted unique atomic types.
        """
        return self._pair_geometry().pair_category_indices(
            atomic_types,
            symmetric=symmetric,
        )

    def slice_neighbor_list(self, mask: ArrayLike) -> NeighborListData:
        """Return a ``NeighborListData`` view restricted to selected pairs."""
        return self._pair_geometry().slice_neighbor_list(mask)

    def missing_state_fields(self, fields: Sequence[str]) -> tuple[str, ...]:
        """Return required charge/spin state fields that are absent.

        With no state attached every requested name is reported as missing;
        otherwise the decision is delegated to ``state.missing_fields``.
        """
        if self.state is None:
            # ``name`` (not ``field``) so the dataclasses import is not shadowed.
            return tuple(str(name) for name in fields)
        return self.state.missing_fields(fields)
# Public names exported by ``from ufp.core.input import *``.
__all__ = ["UFPInputState", "UFPInput"]