# Source code for ufp.core.state
"""Optional per-atom and per-system state carried by :class:`UFPInput`."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Optional, cast
import ase
import torch
from ufp.core._arrays import ArrayLike, _to_tensor
# Lookup keys used when extracting optional state from ASE structures.
# Each tuple is probed in order and the first key that is present wins.
# Atomwise arrays are read from ``atoms.arrays``.
_ATOMIC_CHARGE_ARRAY_KEYS = (
    "charges",
    "initial_charges",
)
_ATOMIC_SPIN_ARRAY_KEYS = (
    "magmoms",
    "initial_magmoms",
)
# Systemwise scalars are read from ``atoms.info``.
_SYSTEM_CHARGE_INFO_KEYS = (
    "system_charge",
    "charge",
)
_SYSTEM_SPIN_INFO_KEYS = (
    "system_spin_moment",
    "spin_moment",
    "magmom",
)
def _optional_atom_tensor(
    value: Optional[ArrayLike],
    *,
    name: str,
    n_atoms: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor | None:
    """
    Convert an optional per-atom scalar array into a 1-D tensor.

    ``None`` passes straight through. Otherwise ``value`` is converted with
    ``_to_tensor`` using the requested ``dtype``/``device``. A column of shape
    ``(n_atoms, 1)`` is squeezed to ``(n_atoms,)``.

    Raises:
        ValueError: if the converted tensor is not ``(n_atoms,)`` after the
            optional squeeze; ``name`` is used in the message.
    """
    if value is not None:
        result = _to_tensor(value, dtype=dtype, device=device)
        is_column = result.ndim == 2 and tuple(result.shape) == (n_atoms, 1)
        if is_column:
            result = result[:, 0]
        if result.ndim == 1 and result.shape[0] == n_atoms:
            return result
        raise ValueError(f"`{name}` must have shape ({n_atoms},)")
    return None
def _optional_system_tensor(
    value: Optional[ArrayLike],
    *,
    name: str,
    n_systems: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor | None:
    """
    Convert an optional per-system scalar array into a 1-D tensor.

    ``None`` passes straight through. Otherwise ``value`` is converted with
    ``_to_tensor`` using the requested ``dtype``/``device``. Two shapes are
    then normalized to ``(n_systems,)``:

    * a 0-d scalar, only when ``n_systems == 1``;
    * a ``(n_systems, 1)`` column.

    Raises:
        ValueError: if the result is not ``(n_systems,)``; ``name`` is used in
            the message.
    """
    if value is None:
        return None
    result = _to_tensor(value, dtype=dtype, device=device)
    if n_systems == 1 and result.ndim == 0:
        result = result.reshape(1)
    elif result.ndim == 2 and tuple(result.shape) == (n_systems, 1):
        result = result[:, 0]
    shape_ok = result.ndim == 1 and result.shape[0] == n_systems
    if not shape_ok:
        raise ValueError(f"`{name}` must have shape ({n_systems},)")
    return result
def _atoms_array(atoms: ase.Atoms, keys: Sequence[str]) -> object | None:
    """
    Look up a per-atom array on ``atoms`` by trying ``keys`` in order.

    Returns the value stored in ``atoms.arrays`` under the first key that is
    present, or ``None`` when none of the keys exist.
    """
    arrays = atoms.arrays
    return next((arrays[key] for key in keys if key in arrays), None)
def _atoms_info_scalar(atoms: ase.Atoms, keys: Sequence[str]) -> object | None:
    """
    Look up a scalar in ``atoms.info`` by trying ``keys`` in order.

    The value under the first key present is returned as-is, even if that
    value is ``None``. ``None`` is also returned when no key is present.
    """
    info = atoms.info
    matches = [key for key in keys if key in info]
    if not matches:
        return None
    return info[matches[0]]
def _require_all_or_none(values: Sequence[object | None], *, name: str) -> bool:
    """
    Check that a quantity is provided by every structure or by none of them.

    Returns:
        ``False`` when every entry of ``values`` is ``None``, including when
        ``values`` is empty. ``True`` when every entry is not ``None``.

    Raises:
        ValueError: when some entries are ``None`` and others are not;
            ``name`` identifies the quantity in the message.
    """
    present = [value is not None for value in values]
    if not any(present):
        return False
    if all(present):
        return True
    raise ValueError(
        f"either all ASE structures must provide `{name}`, or none may"
    )
[docs]
@dataclass
class UFPInputState:
    """
    Optional local and system-level state variables aligned with ``UFPInput``.

    The first supported state fields are local charge, collinear scalar magnetic
    moment, total system charge, and total system spin moment. Missing fields stay
    as ``None`` so legacy models and datasets keep their existing behavior.

    Field values are stored as given (any ``ArrayLike``). :meth:`as_torch`
    returns a validated copy in which every present field is a 1-D tensor.
    """

    # Per-atom charges for the whole batch; ``as_torch`` expects (n_atoms,).
    atomic_charges: Optional[ArrayLike] = None
    # Per-atom collinear (scalar) spin moments; ``as_torch`` expects (n_atoms,).
    atomic_spin_moments: Optional[ArrayLike] = None
    # Total charge of each system; ``as_torch`` expects (n_systems,).
    system_charges: Optional[ArrayLike] = None
    # Total spin moment of each system; ``as_torch`` expects (n_systems,).
    system_spin_moments: Optional[ArrayLike] = None

    @classmethod
    def _field_names(cls) -> tuple[str, ...]:
        """Return the names of the state fields in declaration order."""
        # Local import: ``missing_fields`` has a parameter named ``fields`` that
        # would shadow a module-level ``dataclasses.fields`` import there.
        from dataclasses import fields as dataclass_fields

        return tuple(f.name for f in dataclass_fields(cls))

    @staticmethod
    def _concat_atom_values(
        values: Sequence[object],
        atoms_list: Sequence[ase.Atoms],
        *,
        name: str,
        dtype: torch.dtype,
        device: torch.device | None,
    ) -> torch.Tensor:
        """Flatten each structure's atomwise array and concatenate them.

        Raises:
            ValueError: if a structure's array does not hold exactly one value
                per atom (e.g. a ``(n_atoms, 3)`` array). Without this check,
                flattening would silently misalign values with atoms.
        """
        chunks = []
        for index, (value, atoms) in enumerate(zip(values, atoms_list)):
            chunk = torch.as_tensor(value, dtype=dtype, device=device).reshape(-1)
            n_atoms = len(atoms)
            if chunk.shape[0] != n_atoms:
                raise ValueError(
                    f"{name} of structure {index} has {chunk.shape[0]} values "
                    f"but the structure has {n_atoms} atoms; exactly one scalar "
                    "per atom is required"
                )
            chunks.append(chunk)
        return torch.cat(chunks, dim=0)

    @staticmethod
    def _stack_system_values(
        values: Sequence[object],
        *,
        dtype: torch.dtype,
        device: torch.device | None,
    ) -> torch.Tensor:
        """Convert one scalar per structure (via ``float``) into a 1-D tensor."""
        return torch.as_tensor(
            [float(cast(Any, value)) for value in values],
            dtype=dtype,
            device=device,
        )

    @classmethod
    def from_ase_list(
        cls,
        atoms_list: Sequence[ase.Atoms],
        *,
        dtype: torch.dtype,
        device: torch.device | None = None,
    ) -> "UFPInputState":
        """
        Extract supported charge and spin arrays from ASE structures.

        Atomwise values are read from ``atoms.arrays`` and systemwise scalars
        from ``atoms.info``. For each quantity the first matching key in the
        module's ``_*_KEYS`` tuples is used. A quantity is populated only when
        every structure provides it, and stays ``None`` when none does.
        Atomwise values are concatenated across structures in order.

        Raises:
            ValueError: if only some structures provide a quantity, or if a
                structure's atomwise array does not hold one value per atom.
        """
        atomic_charge_values = [
            _atoms_array(atoms, _ATOMIC_CHARGE_ARRAY_KEYS) for atoms in atoms_list
        ]
        atomic_spin_values = [
            _atoms_array(atoms, _ATOMIC_SPIN_ARRAY_KEYS) for atoms in atoms_list
        ]
        system_charge_values = [
            _atoms_info_scalar(atoms, _SYSTEM_CHARGE_INFO_KEYS) for atoms in atoms_list
        ]
        system_spin_values = [
            _atoms_info_scalar(atoms, _SYSTEM_SPIN_INFO_KEYS) for atoms in atoms_list
        ]

        atomic_charges = None
        if _require_all_or_none(atomic_charge_values, name="atomic charge state"):
            atomic_charges = cls._concat_atom_values(
                atomic_charge_values,
                atoms_list,
                name="atomic charge state",
                dtype=dtype,
                device=device,
            )

        atomic_spin_moments = None
        if _require_all_or_none(atomic_spin_values, name="atomic spin state"):
            atomic_spin_moments = cls._concat_atom_values(
                atomic_spin_values,
                atoms_list,
                name="atomic spin state",
                dtype=dtype,
                device=device,
            )

        system_charges = None
        if _require_all_or_none(system_charge_values, name="system charge state"):
            system_charges = cls._stack_system_values(
                system_charge_values, dtype=dtype, device=device
            )

        system_spin_moments = None
        if _require_all_or_none(system_spin_values, name="system spin state"):
            system_spin_moments = cls._stack_system_values(
                system_spin_values, dtype=dtype, device=device
            )

        return cls(
            atomic_charges=atomic_charges,
            atomic_spin_moments=atomic_spin_moments,
            system_charges=system_charges,
            system_spin_moments=system_spin_moments,
        )

    def as_torch(
        self,
        *,
        n_atoms: int,
        n_systems: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "UFPInputState":
        """
        Return a copy normalized to the requested input shape and tensor options.

        Each present field is converted via ``_to_tensor`` with ``dtype`` and
        ``device`` and reshaped to 1-D:

        * atomwise fields accept ``(n_atoms,)`` or ``(n_atoms, 1)``;
        * systemwise fields accept ``(n_systems,)``, ``(n_systems, 1)``, or a
          scalar when ``n_systems == 1``.

        Absent fields stay ``None``.

        Raises:
            ValueError: if a present field has an incompatible shape.
        """
        return UFPInputState(
            atomic_charges=_optional_atom_tensor(
                self.atomic_charges,
                name="atomic_charges",
                n_atoms=n_atoms,
                dtype=dtype,
                device=device,
            ),
            atomic_spin_moments=_optional_atom_tensor(
                self.atomic_spin_moments,
                name="atomic_spin_moments",
                n_atoms=n_atoms,
                dtype=dtype,
                device=device,
            ),
            system_charges=_optional_system_tensor(
                self.system_charges,
                name="system_charges",
                n_systems=n_systems,
                dtype=dtype,
                device=device,
            ),
            system_spin_moments=_optional_system_tensor(
                self.system_spin_moments,
                name="system_spin_moments",
                n_systems=n_systems,
                dtype=dtype,
                device=device,
            ),
        )

    def to(
        self,
        *,
        n_atoms: int,
        n_systems: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "UFPInputState":
        """Move state tensors to the requested device and dtype.

        This is the same as :meth:`as_torch`: it returns a new instance, and
        shapes are revalidated against ``n_atoms``/``n_systems``.
        """
        return self.as_torch(
            n_atoms=n_atoms,
            n_systems=n_systems,
            dtype=dtype,
            device=device,
        )

    def pin_memory(self) -> "UFPInputState":
        """
        Pin any torch-backed state tensors in place and return ``self``.

        Each tensor-valued field on ``self`` is replaced by the result of
        ``Tensor.pin_memory()``. Fields that are ``None`` or not tensors are
        left untouched.
        """
        for name in self._field_names():
            value = getattr(self, name)
            if isinstance(value, torch.Tensor):
                setattr(self, name, value.pin_memory())
        return self

    def missing_fields(self, fields: Sequence[str]) -> tuple[str, ...]:
        """
        Return requested state field names whose values are absent (``None``).

        Args:
            fields: state field names to check, e.g. ``("system_charges",)``.

        Returns:
            The names from ``fields`` whose value is ``None``, in request order.

        Raises:
            ValueError: if a name is not a declared state field. Method or
                other attribute names such as ``"to"`` are rejected, not
                treated as present.
        """
        known = self._field_names()
        missing = []
        for field_name in fields:
            if field_name not in known:
                raise ValueError(f"unknown input state field: {field_name!r}")
            if getattr(self, field_name) is None:
                missing.append(str(field_name))
        return tuple(missing)
# Public API of this module: only the state container is exported.
__all__ = ["UFPInputState"]