# Source code for ufp.core.input

"""
Normalized torch-native input bundle for UFP models.

Use this module when converting ASE or engine-specific structures into the
shared atom, system, and neighbor-list representation.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import ase
import numpy as np
import torch

from ufp.core._arrays import ArrayLike, _to_tensor
from ufp.core.state import UFPInputState
from ufp.neighbors._data import NeighborListData


def _move_metadata_value(
    value: object,
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> object:
    """Move internal cached metadata that participates in model execution."""
    mover = getattr(value, "to_input_device", None)
    if callable(mover):
        return mover(device=device, dtype=dtype)
    if isinstance(value, dict):
        return {
            key: _move_metadata_value(item, device=device, dtype=dtype)
            for key, item in value.items()
        }
    return value


class _PairGeometry:
    """Private owner for pair geometry and categorical caches."""

    def __init__(self, inputs: "UFPInput") -> None:
        self.inputs = inputs
        self.pair_system_index_cache: Optional[torch.Tensor] = None
        self.pair_vectors_cache: Optional[torch.Tensor] = None
        self.pair_distances_cache: Optional[torch.Tensor] = None
        self.pair_atomic_numbers_cache: Optional[tuple[torch.Tensor, torch.Tensor]] = (
            None
        )
        self.pair_mask_cache: dict[tuple[int, int, bool], torch.Tensor] = {}
        self.atomic_category_cache: dict[tuple[int, ...], torch.Tensor] = {}
        self.pair_category_cache: dict[tuple[tuple[int, ...], bool], torch.Tensor] = {}

    def normalize_pair_mask(self, mask: ArrayLike) -> torch.Tensor:
        """Validate and tensorize a pair-selection mask."""
        neighbor_list = self.inputs._require_neighbor_list()
        mask_tensor = _to_tensor(
            mask,
            dtype=torch.bool,
            device=neighbor_list.pairs.device,
        )
        if mask_tensor.ndim != 1 or mask_tensor.shape[0] != neighbor_list.n_pairs:
            raise ValueError("`mask` must have shape (n_pairs,)")

        return mask_tensor

    def pair_indices(
        self,
        mask: Optional[ArrayLike] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return atom indices for all or selected neighbor-list pairs."""
        neighbor_list = self.inputs._require_neighbor_list()
        first_atom, second_atom = neighbor_list.pairs[0], neighbor_list.pairs[1]
        if mask is None:
            return first_atom, second_atom

        pair_mask = self.normalize_pair_mask(mask)
        return first_atom[pair_mask], second_atom[pair_mask]

    def full_pair_system_index(self) -> torch.Tensor:
        """Cache the system index attached to every neighbor-list pair."""
        if self.pair_system_index_cache is None:
            inputs = self.inputs
            neighbor_list = inputs._require_neighbor_list()
            first_atom, second_atom = neighbor_list.pairs[0], neighbor_list.pairs[1]
            pair_system_index = inputs.system_index.index_select(0, first_atom)
            second_system_index = inputs.system_index.index_select(0, second_atom)
            if not torch.equal(pair_system_index, second_system_index):
                raise ValueError("neighbor-list pairs may not span multiple systems")
            self.pair_system_index_cache = pair_system_index
        return self.pair_system_index_cache

    def pair_system_index(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return the owning system index for all or selected pairs."""
        pair_system_index = self.full_pair_system_index()
        if mask is None:
            return pair_system_index
        return pair_system_index[self.normalize_pair_mask(mask)]

    def full_pair_vectors(self) -> torch.Tensor:
        """Cache neighbor-list displacement vectors in concatenated atom coordinates."""
        if self.pair_vectors_cache is None:
            inputs = self.inputs
            neighbor_list = inputs._require_neighbor_list()
            if neighbor_list.vectors is not None and not inputs.positions.requires_grad:
                self.pair_vectors_cache = neighbor_list.vectors
                return self.pair_vectors_cache

            first_atom, second_atom = neighbor_list.pairs[0], neighbor_list.pairs[1]
            pair_system_index = self.full_pair_system_index()
            shifts = neighbor_list.shifts.to(device=inputs.device, dtype=inputs.dtype)
            cells = inputs.cell.index_select(0, pair_system_index)
            shift_vectors = torch.einsum("pi,pij->pj", shifts, cells)
            self.pair_vectors_cache = (
                inputs.positions.index_select(0, second_atom)
                - inputs.positions.index_select(0, first_atom)
                + shift_vectors
            )
        return self.pair_vectors_cache

    def pair_vectors(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return pair displacement vectors for all or selected pairs."""
        pair_vectors = self.full_pair_vectors()
        if mask is None:
            return pair_vectors
        return pair_vectors[self.normalize_pair_mask(mask)]

    def pair_shifts(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return cell-shift vectors for selected neighbor-list entries."""
        inputs = self.inputs
        neighbor_list = inputs._require_neighbor_list()
        shifts = neighbor_list.shifts.to(device=inputs.device, dtype=torch.int64)
        if mask is None:
            return shifts
        return shifts[self.normalize_pair_mask(mask)]

    def full_pair_distances(self) -> torch.Tensor:
        """Cache pair distances derived from the full pair-vector tensor."""
        if self.pair_distances_cache is None:
            inputs = self.inputs
            neighbor_list = inputs._require_neighbor_list()
            if (
                neighbor_list.distances is not None
                and not inputs.positions.requires_grad
            ):
                self.pair_distances_cache = neighbor_list.distances
                return self.pair_distances_cache
            self.pair_distances_cache = torch.linalg.vector_norm(
                self.full_pair_vectors(),
                dim=1,
            )
        return self.pair_distances_cache

    def pair_distances(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return pair distances for all or selected neighbor-list entries."""
        pair_distances = self.full_pair_distances()
        if mask is None:
            return pair_distances
        return pair_distances[self.normalize_pair_mask(mask)]

    def full_pair_atomic_numbers(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Cache atomic-number pairs for the full neighbor list."""
        if self.pair_atomic_numbers_cache is None:
            inputs = self.inputs
            first_atom, second_atom = self.pair_indices()
            self.pair_atomic_numbers_cache = (
                inputs.atomic_numbers.index_select(0, first_atom),
                inputs.atomic_numbers.index_select(0, second_atom),
            )
        return self.pair_atomic_numbers_cache

    def pair_atomic_numbers(
        self,
        mask: Optional[ArrayLike] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return atomic numbers for the first and second atom of each selected pair."""
        first_numbers, second_numbers = self.full_pair_atomic_numbers()
        if mask is None:
            return first_numbers, second_numbers

        pair_mask = self.normalize_pair_mask(mask)
        return (
            first_numbers[pair_mask],
            second_numbers[pair_mask],
        )

    def pair_mask(
        self,
        first_atomic_number: int,
        second_atomic_number: int,
        *,
        symmetric: bool = False,
    ) -> torch.Tensor:
        """Build a mask selecting pairs with the requested atomic numbers."""
        first_atomic_number = int(first_atomic_number)
        second_atomic_number = int(second_atomic_number)
        if symmetric and first_atomic_number > second_atomic_number:
            first_atomic_number, second_atomic_number = (
                second_atomic_number,
                first_atomic_number,
            )
        key = (first_atomic_number, second_atomic_number, bool(symmetric))
        cached = self.pair_mask_cache.get(key)
        if cached is not None:
            return cached

        first_numbers, second_numbers = self.full_pair_atomic_numbers()
        mask = (first_numbers == first_atomic_number) & (
            second_numbers == second_atomic_number
        )
        if symmetric and first_atomic_number != second_atomic_number:
            mask = mask | (
                (first_numbers == second_atomic_number)
                & (second_numbers == first_atomic_number)
            )

        self.pair_mask_cache[key] = mask
        return mask

    def atomic_category_indices(self, atomic_types: Sequence[int]) -> torch.Tensor:
        """Return atomic-category indices for every atom."""
        inputs = self.inputs
        normalized_atomic_types = tuple(sorted(set(int(z) for z in atomic_types)))
        cached = self.atomic_category_cache.get(normalized_atomic_types)
        if cached is not None:
            return cached

        categories = torch.full_like(inputs.atomic_numbers, fill_value=-1)
        if normalized_atomic_types:
            atomic_type_tensor = torch.tensor(
                normalized_atomic_types,
                dtype=torch.int64,
                device=inputs.device,
            )
            category_indices = torch.searchsorted(
                atomic_type_tensor,
                inputs.atomic_numbers,
            )
            clamped_category_indices = category_indices.clamp_max(
                atomic_type_tensor.numel() - 1
            )
            valid = (category_indices < atomic_type_tensor.numel()) & (
                atomic_type_tensor[clamped_category_indices] == inputs.atomic_numbers
            )
            categories[valid] = category_indices[valid]

        self.atomic_category_cache[normalized_atomic_types] = categories
        return categories

    def pair_category_indices(
        self,
        atomic_types: Sequence[int],
        *,
        symmetric: bool = True,
    ) -> torch.Tensor:
        """Return pair-category indices for every neighbor-list row."""
        inputs = self.inputs
        normalized_atomic_types = tuple(sorted(set(int(z) for z in atomic_types)))
        key = (normalized_atomic_types, bool(symmetric))
        cached = self.pair_category_cache.get(key)
        if cached is not None:
            return cached

        atom_category = self.atomic_category_indices(normalized_atomic_types)
        first_atom, second_atom = self.pair_indices()
        first_category = atom_category.index_select(0, first_atom)
        second_category = atom_category.index_select(0, second_atom)
        valid = (first_category >= 0) & (second_category >= 0)

        n_categories = len(normalized_atomic_types)
        table = torch.full(
            (n_categories, n_categories),
            fill_value=-1,
            dtype=torch.int64,
            device=inputs.device,
        )
        pair_index = 0
        if symmetric:
            for first in range(n_categories):
                for second in range(first, n_categories):
                    table[first, second] = pair_index
                    table[second, first] = pair_index
                    pair_index += 1
        else:
            for first in range(n_categories):
                for second in range(n_categories):
                    table[first, second] = pair_index
                    pair_index += 1

        pair_category = torch.full(
            (inputs._require_neighbor_list().n_pairs,),
            fill_value=-1,
            dtype=torch.int64,
            device=inputs.device,
        )
        pair_category[valid] = table[first_category[valid], second_category[valid]]
        self.pair_category_cache[key] = pair_category
        return pair_category

    def slice_neighbor_list(self, mask: ArrayLike) -> NeighborListData:
        """Return a ``NeighborListData`` view restricted to selected pairs."""
        return self.inputs._require_neighbor_list().masked(
            self.normalize_pair_mask(mask)
        )


@dataclass
class UFPInput:
    """
    Torch-native input bundle passed to :class:`UFPotential`.

    The same structure works for single systems, batches of ASE structures,
    and metatomic-provided systems. Neighbor lists are optional but, when
    present, always refer to the concatenated atom indexing used by
    ``positions``. ``positions`` and ``cell`` use the same length unit as the
    source structure, normally angstroms for ASE inputs. The floating-point
    dtype and device of ``positions`` define the dtype and device used for
    geometric tensors and neighbor-list vectors.

    Attributes:
        positions: Atomic positions with shape ``(n_atoms, 3)``.
        cell: Unit cells with shape ``(n_systems, 3, 3)`` or ``(3, 3)`` for one
            system.
        pbc: Periodic boundary flags with shape ``(n_systems, 3)`` or ``(3,)``.
        atomic_numbers: Atomic numbers with shape ``(n_atoms,)``.
        system_index: System index for each atom with shape ``(n_atoms,)``.
            Every system in ``cell`` must appear at least once.
        neighbor_list: Optional neighbor-list data using concatenated atom
            indexing.
        atomic_charges: Optional local charges with shape ``(n_atoms,)``.
        atomic_spin_moments: Optional local collinear spin moments with shape
            ``(n_atoms,)``.
        system_charges: Optional total charge per system with shape
            ``(n_systems,)``.
        system_spin_moments: Optional total spin moment per system with shape
            ``(n_systems,)``.
        metadata: Optional metadata carried alongside the normalized tensors.
        source_atoms: Optional original ASE structures, one per system.

    Examples:
        >>> import torch
        >>> data = UFPInput(
        ...     positions=torch.zeros((2, 3), dtype=torch.float64),
        ...     cell=torch.eye(3, dtype=torch.float64),
        ...     pbc=torch.tensor([False, False, False]),
        ...     atomic_numbers=torch.tensor([1, 1]),
        ...     system_index=torch.tensor([0, 0]),
        ... )
        >>> data.n_atoms
        2
        >>> data.dtype
        torch.float64
    """

    positions: ArrayLike
    cell: ArrayLike
    pbc: ArrayLike
    atomic_numbers: ArrayLike
    system_index: ArrayLike
    neighbor_list: Optional[NeighborListData] = None
    metadata: Dict[str, object] = field(default_factory=dict)
    source_atoms: Optional[Sequence[ase.Atoms]] = None
    atomic_charges: Optional[ArrayLike] = None
    atomic_spin_moments: Optional[ArrayLike] = None
    system_charges: Optional[ArrayLike] = None
    system_spin_moments: Optional[ArrayLike] = None
    state: Optional[UFPInputState] = None
    # Lazily created owner of pair-derived caches; see ``_pair_geometry``.
    _pair_geometry_cache: Optional[_PairGeometry] = field(
        default=None,
        init=False,
        repr=False,
    )

    def __post_init__(self) -> None:
        """
        Normalize stored arrays into the shared tensor layout.

        Positions fix the dtype (defaulting to ``torch.get_default_dtype()``
        for non-floating input) and device for every other tensor. A 2-D cell
        or 1-D pbc is promoted to a single-system batch, and a single pbc row
        is broadcast across all systems.

        Raises:
            ValueError: On any shape mismatch, an empty structure, invalid
                ``system_index`` values, ``state`` combined with explicit
                charge/spin tensors, or a ``source_atoms`` count that differs
                from the number of systems.
        """
        positions = _to_tensor(self.positions)
        if positions.ndim != 2 or positions.shape[1] != 3:
            raise ValueError("`positions` must have shape (n_atoms, 3)")
        if not positions.is_floating_point():
            positions = positions.to(dtype=torch.get_default_dtype())
        self.positions = positions

        self.cell = _to_tensor(
            self.cell,
            dtype=self.positions.dtype,
            device=self.positions.device,
        )
        self.pbc = _to_tensor(self.pbc, dtype=torch.bool, device=self.positions.device)
        self.atomic_numbers = _to_tensor(
            self.atomic_numbers,
            dtype=torch.int64,
            device=self.positions.device,
        )
        self.system_index = _to_tensor(
            self.system_index,
            dtype=torch.int64,
            device=self.positions.device,
        )

        # Cell / pbc: promote single-system inputs to a batch of one.
        if self.cell.ndim == 2:
            self.cell = self.cell.unsqueeze(0)
        if self.cell.ndim != 3 or tuple(self.cell.shape[1:]) != (3, 3):
            raise ValueError("`cell` must have shape (n_systems, 3, 3) or (3, 3)")
        if self.pbc.ndim == 1:
            self.pbc = self.pbc.unsqueeze(0)
        if self.pbc.ndim != 2 or self.pbc.shape[1] != 3:
            raise ValueError("`pbc` must have shape (n_systems, 3) or (3,)")

        n_systems = int(self.cell.shape[0])
        if self.pbc.shape[0] == 1 and n_systems > 1:
            # clone() so each system owns its row instead of an expanded view.
            self.pbc = self.pbc.expand(n_systems, -1).clone()
        elif self.pbc.shape[0] != n_systems:
            raise ValueError(
                "`cell` and `pbc` must describe the same number of systems"
            )

        # Per-atom tensors.
        if (
            self.atomic_numbers.ndim != 1
            or self.atomic_numbers.shape[0] != self.n_atoms
        ):
            raise ValueError("`atomic_numbers` must have shape (n_atoms,)")
        if self.system_index.ndim != 1 or self.system_index.shape[0] != self.n_atoms:
            raise ValueError("`system_index` must have shape (n_atoms,)")
        if self.n_atoms == 0:
            raise ValueError("`positions` must contain at least one atom")
        if torch.any(self.system_index < 0):
            raise ValueError("`system_index` can not contain negative values")
        if int(self.system_index.max().item()) >= n_systems:
            raise ValueError("`system_index` references a system outside `cell`")
        expected_systems = torch.arange(
            n_systems,
            device=self.device,
            dtype=torch.int64,
        )
        present_systems = torch.unique(self.system_index, sorted=True)
        if not torch.equal(expected_systems, present_systems):
            # The check only requires every system to be referenced by at
            # least one atom (systems normally hold several atoms).
            raise ValueError("`system_index` must reference every system at least once")

        if self.neighbor_list is not None:
            self.neighbor_list = self.neighbor_list.as_torch(
                dtype=self.dtype,
                device=self.device,
            )

        # Charge/spin state: either a ready ``state`` or the four loose fields.
        if self.state is not None and any(
            value is not None
            for value in (
                self.atomic_charges,
                self.atomic_spin_moments,
                self.system_charges,
                self.system_spin_moments,
            )
        ):
            raise ValueError(
                "pass either `state` or explicit charge/spin state tensors, not both"
            )
        state = self.state
        if state is None:
            state = UFPInputState(
                atomic_charges=self.atomic_charges,
                atomic_spin_moments=self.atomic_spin_moments,
                system_charges=self.system_charges,
                system_spin_moments=self.system_spin_moments,
            )
        self.state = state.as_torch(
            n_atoms=self.n_atoms,
            n_systems=n_systems,
            dtype=self.dtype,
            device=self.device,
        )
        self._mirror_state_fields()

        if self.source_atoms is not None:
            if isinstance(self.source_atoms, ase.Atoms):
                self.source_atoms = (self.source_atoms,)
            else:
                self.source_atoms = tuple(self.source_atoms)
            if len(self.source_atoms) != n_systems:
                raise ValueError(
                    "`source_atoms` must contain one ASE structure per system"
                )

    def _mirror_state_fields(self) -> None:
        """Expose the normalized ``state`` tensors through the per-field attributes."""
        self.atomic_charges = self.state.atomic_charges
        self.atomic_spin_moments = self.state.atomic_spin_moments
        self.system_charges = self.state.system_charges
        self.system_spin_moments = self.state.system_spin_moments

    @classmethod
    def from_ase(
        cls,
        atoms: ase.Atoms,
        *,
        neighbor_list: Optional[NeighborListData] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = False,
        metadata: Optional[Dict[str, object]] = None,
        state: Optional[UFPInputState] = None,
        extract_state: bool = True,
    ) -> "UFPInput":
        """Build a single-system input by delegating to ``from_ase_list``."""
        return cls.from_ase_list(
            [atoms],
            neighbor_list=neighbor_list,
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
            metadata=metadata,
            state=state,
            extract_state=extract_state,
        )

    @classmethod
    def from_ase_list(
        cls,
        atoms_list: Sequence[ase.Atoms],
        *,
        neighbor_list: Optional[NeighborListData] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = False,
        metadata: Optional[Dict[str, object]] = None,
        state: Optional[UFPInputState] = None,
        extract_state: bool = True,
    ) -> "UFPInput":
        """
        Concatenate ASE structures into one batched input.

        Atoms are stacked in list order, so each system occupies a contiguous
        range of the atom axis and ``system_index`` is non-decreasing.

        Args:
            atoms_list: One ``ase.Atoms`` per system; must be non-empty.
            neighbor_list: Optional neighbor list in concatenated indexing.
            dtype: Floating dtype; defaults to ``torch.get_default_dtype()``.
            device: Target device for every tensor.
            requires_grad: Mark the concatenated positions as requiring grad.
            metadata: Copied (shallowly) into the new input.
            state: Explicit charge/spin state.
            extract_state: When ``state`` is ``None``, read the state from the
                ASE structures via ``UFPInputState.from_ase_list``.

        Raises:
            ValueError: If ``atoms_list`` is empty.
            TypeError: If an entry is not an ``ase.Atoms``.
        """
        if not atoms_list:
            raise ValueError("`atoms_list` must contain at least one ASE structure")
        resolved_dtype = torch.get_default_dtype() if dtype is None else dtype
        positions = []
        cells = []
        pbc = []
        atomic_numbers = []
        system_index = []
        for system_i, atoms in enumerate(atoms_list):
            if not isinstance(atoms, ase.Atoms):
                raise TypeError(
                    f"`atoms_list` should contain ase.Atoms, got {type(atoms)}"
                )
            positions.append(
                torch.as_tensor(atoms.positions, dtype=resolved_dtype, device=device)
            )
            cells.append(
                torch.as_tensor(atoms.cell.array, dtype=resolved_dtype, device=device)
            )
            pbc.append(
                torch.as_tensor(
                    np.asarray(atoms.pbc),
                    dtype=torch.bool,
                    device=device,
                )
            )
            atomic_numbers.append(
                torch.as_tensor(atoms.numbers, dtype=torch.int64, device=device)
            )
            system_index.append(
                torch.full(
                    (len(atoms),),
                    system_i,
                    dtype=torch.int64,
                    device=device,
                )
            )

        concatenated_positions = torch.cat(positions, dim=0)
        if requires_grad:
            concatenated_positions.requires_grad_(True)

        input_state = state
        if input_state is None and extract_state:
            input_state = UFPInputState.from_ase_list(
                tuple(atoms_list),
                dtype=resolved_dtype,
                device=device,
            )

        return cls(
            positions=concatenated_positions,
            cell=torch.stack(cells, dim=0),
            pbc=torch.stack(pbc, dim=0),
            atomic_numbers=torch.cat(atomic_numbers, dim=0),
            system_index=torch.cat(system_index, dim=0),
            neighbor_list=neighbor_list,
            state=input_state,
            metadata={} if metadata is None else dict(metadata),
            source_atoms=tuple(atoms_list),
        )

    def to(
        self,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        requires_grad: bool = False,
        neighbor_list: Optional[NeighborListData] = None,
    ) -> "UFPInput":
        """
        Return a copy moved to the requested device/dtype and gradient state.

        Positions are always detached; with ``requires_grad=True`` they are
        additionally cloned into a fresh leaf tensor that requires gradients.
        Metadata values exposing ``to_input_device`` (also inside nested
        dicts) are moved. An explicit ``neighbor_list`` replaces the stored
        one; otherwise the stored list is converted to the new device/dtype.
        """
        resolved_device = self.device if device is None else torch.device(device)
        resolved_dtype = self.dtype if dtype is None else dtype
        positions = self.positions.to(
            device=resolved_device,
            dtype=resolved_dtype,
            non_blocking=True,
        )
        if requires_grad:
            positions = positions.detach().clone().requires_grad_(True)
        else:
            positions = positions.detach()
        metadata = {
            key: _move_metadata_value(
                value,
                device=resolved_device,
                dtype=resolved_dtype,
            )
            for key, value in self.metadata.items()
        }
        moved_neighbor_list = (
            self.neighbor_list.as_torch(
                dtype=resolved_dtype,
                device=resolved_device,
            )
            if neighbor_list is None and self.neighbor_list is not None
            else neighbor_list
        )
        return UFPInput(
            positions=positions,
            cell=self.cell.to(
                device=resolved_device,
                dtype=resolved_dtype,
                non_blocking=True,
            ),
            pbc=self.pbc.to(device=resolved_device, non_blocking=True),
            atomic_numbers=self.atomic_numbers.to(
                device=resolved_device,
                non_blocking=True,
            ),
            system_index=self.system_index.to(
                device=resolved_device,
                non_blocking=True,
            ),
            neighbor_list=moved_neighbor_list,
            state=(
                self.state.to(
                    n_atoms=self.n_atoms,
                    n_systems=self.n_systems,
                    dtype=resolved_dtype,
                    device=resolved_device,
                )
                if self.state is not None
                else None
            ),
            metadata=metadata,
            source_atoms=self.source_atoms,
        )

    def pin_memory(self) -> "UFPInput":
        """
        Pin stored tensors in place and return ``self`` for dataloader-style use.

        ``Tensor.pin_memory`` returns new tensors, so every attribute is
        rebound and the pair-geometry cache (which holds tensors derived from
        the old ones) is dropped to be rebuilt lazily.
        """
        self.positions = self.positions.pin_memory()
        self.cell = self.cell.pin_memory()
        self.pbc = self.pbc.pin_memory()
        self.atomic_numbers = self.atomic_numbers.pin_memory()
        self.system_index = self.system_index.pin_memory()
        if self.state is not None:
            self.state = self.state.pin_memory()
            self._mirror_state_fields()
        if self.neighbor_list is not None:
            self.neighbor_list = self.neighbor_list.pin_memory()
        # Cached pair vectors/distances would otherwise still reference the
        # pre-pinning tensors (and their autograd graph).
        self._pair_geometry_cache = None
        return self

    @property
    def n_atoms(self) -> int:
        """Return the total number of atoms across all systems."""
        return int(self.positions.shape[0])

    @property
    def n_systems(self) -> int:
        """Return the number of systems represented by this input."""
        return int(self.cell.shape[0])

    @property
    def device(self) -> torch.device:
        """Return the torch device shared by the stored tensors."""
        return self.positions.device

    @property
    def dtype(self) -> torch.dtype:
        """Return the floating-point dtype used for geometric tensors."""
        return self.positions.dtype

    @property
    def system_sizes(self) -> list[int]:
        """Return the atom count for each system in concatenated order."""
        counts = torch.bincount(self.system_index, minlength=self.n_systems)
        return [int(value) for value in counts.tolist()]

    @property
    def atom_slices(self) -> list[slice]:
        """
        Return per-system slices into the concatenated atom axis.

        Slices are consecutive ranges built from :attr:`system_sizes`, so they
        match the atoms only when each system occupies a contiguous block in
        system order (as ``from_ase_list`` produces); ``__post_init__`` does
        not enforce that ordering.
        """
        atom_slices = []
        start = 0
        for size in self.system_sizes:
            atom_slices.append(slice(start, start + size))
            start += size
        return atom_slices

    @property
    def atoms(self) -> ase.Atoms:
        """Return the original ASE structure for a single-system input."""
        if self.source_atoms is None or len(self.source_atoms) != 1:
            raise AttributeError(
                "`atoms` is only available for single-system ASE inputs"
            )
        return self.source_atoms[0]

    def _require_neighbor_list(self) -> NeighborListData:
        """Return the stored neighbor list or raise when geometry is missing."""
        if self.neighbor_list is None:
            raise RuntimeError("this model input does not contain a neighbor list")
        return self.neighbor_list

    def _pair_geometry(self) -> _PairGeometry:
        """Return the private pair-geometry cache owner, creating it on first use."""
        if self._pair_geometry_cache is None:
            self._pair_geometry_cache = _PairGeometry(self)
        return self._pair_geometry_cache

    @property
    def _pair_system_index_cache(self) -> Optional[torch.Tensor]:
        """Compatibility access to the pair-system-index cache."""
        return self._pair_geometry().pair_system_index_cache

    @property
    def _pair_vectors_cache(self) -> Optional[torch.Tensor]:
        """Compatibility access to the pair-vector cache."""
        return self._pair_geometry().pair_vectors_cache

    @property
    def _pair_distances_cache(self) -> Optional[torch.Tensor]:
        """Compatibility access to the pair-distance cache."""
        return self._pair_geometry().pair_distances_cache

    @property
    def _pair_atomic_numbers_cache(
        self,
    ) -> Optional[tuple[torch.Tensor, torch.Tensor]]:
        """Compatibility access to the pair-atomic-number cache."""
        return self._pair_geometry().pair_atomic_numbers_cache

    @property
    def _pair_mask_cache(self) -> dict[tuple[int, int, bool], torch.Tensor]:
        """Compatibility access to cached pair masks."""
        return self._pair_geometry().pair_mask_cache

    @property
    def _atomic_category_cache(self) -> dict[tuple[int, ...], torch.Tensor]:
        """Compatibility access to cached atomic categories."""
        return self._pair_geometry().atomic_category_cache

    @property
    def _pair_category_cache(
        self,
    ) -> dict[tuple[tuple[int, ...], bool], torch.Tensor]:
        """Compatibility access to cached pair categories."""
        return self._pair_geometry().pair_category_cache

    def _normalize_pair_mask(self, mask: ArrayLike) -> torch.Tensor:
        """Validate and tensorize a pair-selection mask."""
        return self._pair_geometry().normalize_pair_mask(mask)

    def pair_indices(
        self,
        mask: Optional[ArrayLike] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return atom indices for all or selected neighbor-list pairs."""
        return self._pair_geometry().pair_indices(mask)

    def _full_pair_system_index(self) -> torch.Tensor:
        """Cache the system index attached to every neighbor-list pair."""
        return self._pair_geometry().full_pair_system_index()

    def pair_system_index(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return the owning system index for all or selected pairs."""
        return self._pair_geometry().pair_system_index(mask)

    def _full_pair_vectors(self) -> torch.Tensor:
        """Cache neighbor-list displacement vectors in concatenated atom coordinates."""
        return self._pair_geometry().full_pair_vectors()

    def pair_vectors(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return pair displacement vectors for all or selected pairs."""
        return self._pair_geometry().pair_vectors(mask)

    def pair_shifts(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return cell-shift vectors for selected neighbor-list entries."""
        return self._pair_geometry().pair_shifts(mask)

    def _full_pair_distances(self) -> torch.Tensor:
        """Cache pair distances derived from the full pair-vector tensor."""
        return self._pair_geometry().full_pair_distances()

    def pair_distances(self, mask: Optional[ArrayLike] = None) -> torch.Tensor:
        """Return pair distances for all or selected neighbor-list entries."""
        return self._pair_geometry().pair_distances(mask)

    def _full_pair_atomic_numbers(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Cache atomic-number pairs for the full neighbor list."""
        return self._pair_geometry().full_pair_atomic_numbers()

    def pair_atomic_numbers(
        self,
        mask: Optional[ArrayLike] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return atomic numbers for the first and second atom of each selected pair."""
        return self._pair_geometry().pair_atomic_numbers(mask)

    def pair_mask(
        self,
        first_atomic_number: int,
        second_atomic_number: int,
        *,
        symmetric: bool = False,
    ) -> torch.Tensor:
        """Build a mask selecting pairs with the requested atomic numbers."""
        return self._pair_geometry().pair_mask(
            first_atomic_number,
            second_atomic_number,
            symmetric=symmetric,
        )

    def atomic_category_indices(self, atomic_types: Sequence[int]) -> torch.Tensor:
        """
        Return atomic-category indices for every atom.

        Atomic numbers outside ``atomic_types`` are marked ``-1``. Category
        ordering follows the sorted unique atomic types used throughout UFP
        term layouts.
        """
        return self._pair_geometry().atomic_category_indices(atomic_types)

    def pair_category_indices(
        self,
        atomic_types: Sequence[int],
        *,
        symmetric: bool = True,
    ) -> torch.Tensor:
        """
        Return pair-category indices for every neighbor-list row.

        Pairs containing atomic numbers outside ``atomic_types`` are marked
        ``-1``. When ``symmetric`` is true, category ordering matches unordered
        combinations with replacement over sorted unique atomic types.
        """
        return self._pair_geometry().pair_category_indices(
            atomic_types,
            symmetric=symmetric,
        )

    def slice_neighbor_list(self, mask: ArrayLike) -> NeighborListData:
        """Return a ``NeighborListData`` view restricted to selected pairs."""
        return self._pair_geometry().slice_neighbor_list(mask)

    def missing_state_fields(self, fields: Sequence[str]) -> tuple[str, ...]:
        """Return required charge/spin state fields that are absent."""
        if self.state is None:
            # Loop variable deliberately not named ``field`` to avoid shadowing
            # ``dataclasses.field``.
            return tuple(str(name) for name in fields)
        return self.state.missing_fields(fields)
# Public names re-exported by this module (UFPInputState comes from ufp.core.state).
__all__ = [
    "UFPInputState",
    "UFPInput",
]