# Source code for ufp.core.state

"""Optional per-atom and per-system state carried by :class:`UFPInput`."""

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Optional, cast

import ase
import torch

from ufp.core._arrays import ArrayLike, _to_tensor


# Lookup keys for optional charge/spin state on ASE structures. Each tuple is
# scanned in order and the first key present wins, so earlier names take
# precedence over later ones.

# Per-atom arrays, looked up in ``atoms.arrays``.
_ATOMIC_CHARGE_ARRAY_KEYS = (
    "charges",
    "initial_charges",
)
_ATOMIC_SPIN_ARRAY_KEYS = (
    "magmoms",
    "initial_magmoms",
)

# Per-structure scalars, looked up in ``atoms.info``.
_SYSTEM_CHARGE_INFO_KEYS = (
    "system_charge",
    "charge",
)
_SYSTEM_SPIN_INFO_KEYS = (
    "system_spin_moment",
    "spin_moment",
    "magmom",
)


def _optional_atom_tensor(
    value: Optional[ArrayLike],
    *,
    name: str,
    n_atoms: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor | None:
    """Coerce an optional per-atom scalar field to a flat ``(n_atoms,)`` tensor.

    ``None`` passes straight through. Otherwise the value is converted with
    ``_to_tensor``; a column of shape ``(n_atoms, 1)`` is squeezed to
    ``(n_atoms,)`` and any other shape is rejected.

    Raises:
        ValueError: if the converted value is not ``(n_atoms,)`` after the
            optional column squeeze.
    """
    if value is None:
        return None
    converted = _to_tensor(value, dtype=dtype, device=device)
    is_column = converted.ndim == 2 and tuple(converted.shape) == (n_atoms, 1)
    flat = converted[:, 0] if is_column else converted
    if flat.ndim == 1 and flat.shape[0] == n_atoms:
        return flat
    raise ValueError(f"`{name}` must have shape ({n_atoms},)")


def _optional_system_tensor(
    value: Optional[ArrayLike],
    *,
    name: str,
    n_systems: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor | None:
    """Coerce an optional per-system scalar field to a ``(n_systems,)`` tensor.

    ``None`` passes straight through. Otherwise the value is converted with
    ``_to_tensor``. A 0-d value is promoted to shape ``(1,)``, but only when
    ``n_systems == 1``. A ``(n_systems, 1)`` column is squeezed. Every other
    shape is rejected.

    Raises:
        ValueError: if the normalized value is not ``(n_systems,)``.
    """
    if value is None:
        return None
    result = _to_tensor(value, dtype=dtype, device=device)
    if n_systems == 1 and result.ndim == 0:
        # A bare scalar is unambiguous only for a single-system batch.
        result = result.reshape(1)
    elif result.ndim == 2 and tuple(result.shape) == (n_systems, 1):
        result = result[:, 0]
    if result.ndim == 1 and result.shape[0] == n_systems:
        return result
    raise ValueError(f"`{name}` must have shape ({n_systems},)")


def _atoms_array(atoms: ase.Atoms, keys: Sequence[str]) -> object | None:
    """Look up the first of ``keys`` present in ``atoms.arrays``.

    Keys are checked in order, so earlier keys take precedence. The stored
    value is returned as-is; ``None`` is returned when no key matches.
    """
    arrays = atoms.arrays
    return next((arrays[key] for key in keys if key in arrays), None)


def _atoms_info_scalar(atoms: ase.Atoms, keys: Sequence[str]) -> object | None:
    """Look up the first of ``keys`` present in ``atoms.info``.

    Keys are checked in order, so earlier keys take precedence. The value is
    returned exactly as stored (no scalar check or conversion happens here);
    ``None`` is returned when no key matches.
    """
    info = atoms.info
    return next((info[key] for key in keys if key in info), None)


def _require_all_or_none(values: Sequence[object | None], *, name: str) -> bool:
    """Check that an optional quantity is present on every structure or on none.

    Returns:
        ``True`` when every entry of ``values`` is not ``None``; ``False`` when
        every entry is ``None`` (including when ``values`` is empty).

    Raises:
        ValueError: when some entries are ``None`` and others are not.
    """
    present = [value is not None for value in values]
    if not any(present):
        return False
    if all(present):
        return True
    raise ValueError(
        f"either all ASE structures must provide `{name}`, or none may"
    )


def _concat_atomwise_values(
    values: Sequence[object],
    atoms_list: Sequence[ase.Atoms],
    *,
    name: str,
    dtype: torch.dtype,
    device: torch.device | None,
) -> torch.Tensor:
    """Flatten and concatenate per-structure atomwise arrays into one tensor.

    Each entry of ``values`` belongs to the structure at the same position in
    ``atoms_list`` and must hold exactly one scalar per atom of that structure.
    Shapes such as ``(n,)`` or ``(n, 1)`` qualify. Vector-valued per-atom data
    (e.g. shape ``(n, 3)``) is rejected here instead of being flattened into a
    ``3 * n`` vector that would misalign every later structure in the batch.

    Raises:
        ValueError: if an entry does not hold one value per atom.
    """
    chunks = []
    for index, (atoms, value) in enumerate(zip(atoms_list, values)):
        chunk = torch.as_tensor(value, dtype=dtype, device=device)
        n_atoms = len(atoms)
        if chunk.numel() != n_atoms:
            raise ValueError(
                f"{name} of structure {index} must hold one scalar per atom "
                f"(expected {n_atoms} values, got shape {tuple(chunk.shape)})"
            )
        chunks.append(chunk.reshape(-1))
    return torch.cat(chunks, dim=0)


def _stack_system_scalars(
    values: Sequence[object],
    *,
    dtype: torch.dtype,
    device: torch.device | None,
) -> torch.Tensor:
    """Convert one scalar per structure into a 1-D tensor via ``float()``."""
    return torch.as_tensor(
        [float(cast(Any, value)) for value in values],
        dtype=dtype,
        device=device,
    )


@dataclass
class UFPInputState:
    """
    Optional local and system-level state variables aligned with ``UFPInput``.

    The first supported state fields are local charge, collinear scalar magnetic
    moment, total system charge, and total system spin moment. Missing fields stay
    as ``None`` so legacy models and datasets keep their existing behavior.
    """

    atomic_charges: Optional[ArrayLike] = None
    atomic_spin_moments: Optional[ArrayLike] = None
    system_charges: Optional[ArrayLike] = None
    system_spin_moments: Optional[ArrayLike] = None

    @classmethod
    def _state_field_names(cls) -> tuple[str, ...]:
        """Return the names of this dataclass's state fields, in declaration order."""
        import dataclasses

        return tuple(field.name for field in dataclasses.fields(cls))

    @classmethod
    def from_ase_list(
        cls,
        atoms_list: Sequence[ase.Atoms],
        *,
        dtype: torch.dtype,
        device: torch.device | None = None,
    ) -> "UFPInputState":
        """Extract supported charge and spin arrays from ASE structures.

        Per-atom values are read from ``atoms.arrays`` and per-structure
        scalars from ``atoms.info``, trying the module-level key tuples in
        order. Each quantity must be provided by every structure or by none;
        quantities provided by none stay ``None``.

        Args:
            atoms_list: Structures to read, in batch order.
            dtype: Dtype of the produced tensors.
            device: Passed through to ``torch.as_tensor``.

        Returns:
            A new state whose atomwise fields are concatenated across all
            structures and whose system fields hold one value per structure.

        Raises:
            ValueError: if a quantity is present on only some structures, or
                an atomwise array does not hold exactly one value per atom.
        """
        atomic_charge_values = [
            _atoms_array(atoms, _ATOMIC_CHARGE_ARRAY_KEYS) for atoms in atoms_list
        ]
        atomic_spin_values = [
            _atoms_array(atoms, _ATOMIC_SPIN_ARRAY_KEYS) for atoms in atoms_list
        ]
        system_charge_values = [
            _atoms_info_scalar(atoms, _SYSTEM_CHARGE_INFO_KEYS) for atoms in atoms_list
        ]
        system_spin_values = [
            _atoms_info_scalar(atoms, _SYSTEM_SPIN_INFO_KEYS) for atoms in atoms_list
        ]

        atomic_charges = None
        if _require_all_or_none(atomic_charge_values, name="atomic charge state"):
            atomic_charges = _concat_atomwise_values(
                atomic_charge_values,
                atoms_list,
                name="atomic charge state",
                dtype=dtype,
                device=device,
            )

        atomic_spin_moments = None
        if _require_all_or_none(atomic_spin_values, name="atomic spin state"):
            atomic_spin_moments = _concat_atomwise_values(
                atomic_spin_values,
                atoms_list,
                name="atomic spin state",
                dtype=dtype,
                device=device,
            )

        system_charges = None
        if _require_all_or_none(system_charge_values, name="system charge state"):
            system_charges = _stack_system_scalars(
                system_charge_values, dtype=dtype, device=device
            )

        system_spin_moments = None
        if _require_all_or_none(system_spin_values, name="system spin state"):
            system_spin_moments = _stack_system_scalars(
                system_spin_values, dtype=dtype, device=device
            )

        return cls(
            atomic_charges=atomic_charges,
            atomic_spin_moments=atomic_spin_moments,
            system_charges=system_charges,
            system_spin_moments=system_spin_moments,
        )

    def as_torch(
        self,
        *,
        n_atoms: int,
        n_systems: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "UFPInputState":
        """Return a copy normalized to the requested input shape and tensor options.

        Atomwise fields are validated as ``(n_atoms,)`` and system fields as
        ``(n_systems,)``; ``None`` fields stay ``None``.

        Raises:
            ValueError: if a present field has an incompatible shape.
        """
        return UFPInputState(
            atomic_charges=_optional_atom_tensor(
                self.atomic_charges,
                name="atomic_charges",
                n_atoms=n_atoms,
                dtype=dtype,
                device=device,
            ),
            atomic_spin_moments=_optional_atom_tensor(
                self.atomic_spin_moments,
                name="atomic_spin_moments",
                n_atoms=n_atoms,
                dtype=dtype,
                device=device,
            ),
            system_charges=_optional_system_tensor(
                self.system_charges,
                name="system_charges",
                n_systems=n_systems,
                dtype=dtype,
                device=device,
            ),
            system_spin_moments=_optional_system_tensor(
                self.system_spin_moments,
                name="system_spin_moments",
                n_systems=n_systems,
                dtype=dtype,
                device=device,
            ),
        )

    def to(
        self,
        *,
        n_atoms: int,
        n_systems: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "UFPInputState":
        """Move state tensors to the requested device and dtype.

        Alias for :meth:`as_torch`, including its shape validation.
        """
        return self.as_torch(
            n_atoms=n_atoms,
            n_systems=n_systems,
            dtype=dtype,
            device=device,
        )

    def pin_memory(self) -> "UFPInputState":
        """Pin any torch-backed state tensors and return ``self``.

        Each field holding a ``torch.Tensor`` is replaced on this instance by
        its ``pin_memory()`` result; ``None`` and non-tensor values are left
        untouched.
        """
        for name in self._state_field_names():
            value = getattr(self, name)
            if isinstance(value, torch.Tensor):
                setattr(self, name, value.pin_memory())
        return self

    def missing_fields(self, fields: Sequence[str]) -> tuple[str, ...]:
        """Return requested state field names whose values are absent.

        Args:
            fields: Names of state fields to check.

        Returns:
            The requested names whose value is ``None``, in request order.

        Raises:
            ValueError: if a name is not one of this dataclass's fields. Methods
                and other attributes (e.g. ``"to"``) are rejected rather than
                silently reported as present.
        """
        known = self._state_field_names()
        missing = []
        for field in fields:
            if field not in known:
                raise ValueError(f"unknown input state field: {field!r}")
            if getattr(self, field) is None:
                missing.append(str(field))
        return tuple(missing)
# Public API of this module.
__all__ = ["UFPInputState"]