"""
Dataset and split helpers for supervised ``ase.Atoms`` samples.
Use this module to normalize energies, forces, stresses, and metadata into a
small PyTorch-friendly dataset layer.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from math import floor, isclose
from typing import Mapping, Optional, Sequence, TypeVar, Union
import ase
import torch
from ase.stress import voigt_6_to_full_3x3_stress
from torch.utils.data import DataLoader, Dataset, Subset
from ufp.neighbors._data import NeighborListData
from ufp.neighbors._neighbors import NeighborListBackend, build_neighbor_list
# Generic element type for the optional per-sample sequences validated by
# ``_sequence_or_none``.
T = TypeVar("T")
def _as_scalar(value: object, *, name: str) -> float:
"""Coerce a scalar-like value into a Python float for stored supervision targets."""
tensor = torch.as_tensor(value, dtype=torch.get_default_dtype())
if tensor.numel() != 1:
raise ValueError(f"`{name}` must contain exactly one scalar value")
return float(tensor.reshape(-1)[0].item())
def _as_forces_tensor(value: object, *, n_atoms: int, name: str) -> torch.Tensor:
"""Coerce force targets into a detached ``(n_atoms, 3)`` tensor."""
tensor = torch.as_tensor(value, dtype=torch.get_default_dtype())
if tuple(tensor.shape) != (n_atoms, 3):
raise ValueError(f"`{name}` must have shape ({n_atoms}, 3)")
return tensor.detach().clone()
def _as_force_mask_tensor(value: object, *, n_atoms: int, name: str) -> torch.Tensor:
"""Coerce force-component masks into a detached bool ``(n_atoms, 3)`` tensor."""
tensor = torch.as_tensor(value, dtype=torch.bool)
if tuple(tensor.shape) != (n_atoms, 3):
raise ValueError(f"`{name}` must have shape ({n_atoms}, 3)")
return tensor.detach().clone()
def _as_stress_tensor(value: object, *, name: str) -> torch.Tensor:
"""Coerce stress targets into a detached full ``(3, 3)`` tensor."""
tensor = torch.as_tensor(value, dtype=torch.get_default_dtype())
if tuple(tensor.shape) == (3, 3):
return tensor.detach().clone()
if tuple(tensor.shape) == (6,):
full = voigt_6_to_full_3x3_stress(tensor.detach().cpu().numpy())
return torch.as_tensor(full, dtype=tensor.dtype)
raise ValueError(f"`{name}` must have shape (3, 3) or (6,)")
def _sequence_or_none(
values: Optional[Sequence[T]],
*,
size: int,
name: str,
) -> Optional[list[T]]:
"""Validate optional per-sample sequences against the number of structures."""
if values is None:
return None
sequence = list(values)
if len(sequence) != size:
raise ValueError(f"`{name}` must contain exactly {size} entries")
return sequence
def _read_target(
atoms: ase.Atoms,
*,
key: Optional[str],
name: str,
arrays_first: bool = False,
) -> object | None:
"""Read one target from ``atoms.info`` or ``atoms.arrays``."""
if key is None:
return None
if arrays_first and key in atoms.arrays:
return atoms.arrays[key]
if key in atoms.info:
return atoms.info[key]
if key in atoms.arrays:
return atoms.arrays[key]
raise KeyError(f"ASE atoms object is missing `{name}` target `{key}`")
@dataclass(frozen=True)
class _ASETensorizedSample:
"""Cached tensor view of one ASE sample's structure fields."""
positions: torch.Tensor
cell: torch.Tensor
pbc: torch.Tensor
atomic_numbers: torch.Tensor
@dataclass(frozen=True)
class ASEAtomsSample:
    """
    One labeled :class:`ase.Atoms` structure used as a supervised sample.

    On construction the sample stores its own copy of ``atoms`` and
    normalizes every supplied target through the module's coercion helpers.
    ``metadata`` is copied into a plain ``dict``. Two private fields cache
    derived data: the tensorized structure and neighbor lists keyed by
    ``(cutoff, backend)``.
    """

    atoms: ase.Atoms
    energy: Optional[float] = None
    forces: Optional[torch.Tensor] = None
    force_mask: Optional[torch.Tensor] = None
    stress: Optional[torch.Tensor] = None
    neighbor_list: Optional[NeighborListData] = None
    metadata: Mapping[str, object] = field(default_factory=dict)
    # Lazily filled by ``tensorized()``; excluded from init, repr, and equality.
    _tensorized: Optional[_ASETensorizedSample] = field(
        default=None,
        init=False,
        repr=False,
        compare=False,
    )
    # Filled by ``cached_neighbor_list()``; excluded from init, repr, and equality.
    _neighbor_list_cache: dict[tuple[float, str], NeighborListData] = field(
        default_factory=dict,
        init=False,
        repr=False,
        compare=False,
    )

    def __post_init__(self) -> None:
        """
        Take ownership of the structure and normalize the supplied targets.

        Raises ``TypeError`` for a non-``ase.Atoms`` structure or a
        ``neighbor_list`` of the wrong type, and ``ValueError`` when a
        ``force_mask`` is given without ``forces``.
        """

        def _store(attribute: str, new_value: object) -> None:
            # The dataclass is frozen, so writes must bypass its ``__setattr__``.
            object.__setattr__(self, attribute, new_value)

        if not isinstance(self.atoms, ase.Atoms):
            raise TypeError(f"`atoms` must be ase.Atoms, got {type(self.atoms)}")
        owned_atoms = self.atoms.copy()
        _store("atoms", owned_atoms)
        n_atoms = len(owned_atoms)

        if self.energy is not None:
            _store("energy", _as_scalar(self.energy, name="energy"))

        if self.forces is not None:
            _store(
                "forces",
                _as_forces_tensor(self.forces, n_atoms=n_atoms, name="forces"),
            )

        if self.force_mask is not None:
            if self.forces is None:
                raise ValueError("`force_mask` requires `forces`")
            _store(
                "force_mask",
                _as_force_mask_tensor(
                    self.force_mask,
                    n_atoms=n_atoms,
                    name="force_mask",
                ),
            )

        if self.stress is not None:
            _store("stress", _as_stress_tensor(self.stress, name="stress"))

        explicit_neighbors = self.neighbor_list
        if not (
            explicit_neighbors is None
            or isinstance(explicit_neighbors, NeighborListData)
        ):
            raise TypeError("`neighbor_list` must be NeighborListData when provided")

        _store("metadata", dict(self.metadata))

    def tensorized(self) -> _ASETensorizedSample:
        """Build the tensor form of the structure on first use, then reuse it."""
        if self._tensorized is not None:
            return self._tensorized

        structure = self.atoms
        float_dtype = torch.get_default_dtype()
        built = _ASETensorizedSample(
            positions=torch.as_tensor(structure.positions, dtype=float_dtype),
            cell=torch.as_tensor(structure.cell.array, dtype=float_dtype),
            pbc=torch.as_tensor(structure.pbc, dtype=torch.bool),
            atomic_numbers=torch.as_tensor(structure.numbers, dtype=torch.int64),
        )
        object.__setattr__(self, "_tensorized", built)
        return built

    def cached_neighbor_list(
        self,
        *,
        cutoff: float,
        backend: Union[str, NeighborListBackend],
    ) -> NeighborListData:
        """
        Return the neighbor list for this sample.

        A neighbor list supplied at construction is returned as-is, whatever
        ``cutoff`` and ``backend`` are. Otherwise the list is built once per
        ``(cutoff, backend)`` pair and served from the per-sample cache.
        """
        explicit_neighbors = self.neighbor_list
        if explicit_neighbors is not None:
            return explicit_neighbors

        cache = self._neighbor_list_cache
        cache_key = (float(cutoff), NeighborListBackend(backend).value)
        cached = cache.get(cache_key)
        if cached is not None:
            return cached

        built = build_neighbor_list(
            atoms=self.atoms,
            cutoff=cutoff,
            backend=backend,
            arrays="torch",
        )
        cache[cache_key] = built
        return built
class ASEAtomsDataset(Dataset[ASEAtomsSample]):
    """
    In-memory :mod:`torch.utils.data` dataset of ``ASEAtomsSample`` objects.

    Samples are stored in a tuple. The dataset also remembers which caches
    it has already warmed, so repeated warm-up calls return immediately.
    """

    def __init__(self, samples: Sequence[ASEAtomsSample]) -> None:
        """Validate and store ``samples``; empty input and foreign types are rejected."""
        frozen_samples = tuple(samples)
        if len(frozen_samples) == 0:
            raise ValueError("`samples` must contain at least one ASEAtomsSample")
        if not all(isinstance(entry, ASEAtomsSample) for entry in frozen_samples):
            raise TypeError("all dataset entries must be ASEAtomsSample instances")
        self._samples = frozen_samples
        self._tensorized_cache_warmed = False
        self._warmed_neighbor_list_keys: set[tuple[float, str]] = set()

    @classmethod
    def from_atoms(
        cls,
        atoms_list: Sequence[ase.Atoms],
        *,
        energies: Optional[Sequence[object]] = None,
        forces: Optional[Sequence[object]] = None,
        force_masks: Optional[Sequence[object]] = None,
        stresses: Optional[Sequence[object]] = None,
        neighbor_lists: Optional[Sequence[NeighborListData | None]] = None,
        energy_key: Optional[str] = "energy",
        forces_key: Optional[str] = None,
        stress_key: Optional[str] = None,
        metadata: Optional[Sequence[Mapping[str, object]]] = None,
        metadata_keys: Sequence[str] = (),
    ) -> "ASEAtomsDataset":
        """
        Assemble a dataset from ``ase.Atoms`` objects and their labels.

        For energy, forces, and stress, an explicit per-sample sequence takes
        precedence; otherwise the value is read from the structure under the
        matching ``*_key`` (a key of ``None`` disables that target). Force
        masks and neighbor lists come only from their explicit sequences.
        Metadata is the explicit mapping when given, else the subset of
        ``metadata_keys`` present in ``atoms.info``.

        Every explicit sequence must have one entry per structure.
        """
        structures = list(atoms_list)
        if not structures:
            raise ValueError("`atoms_list` must contain at least one ASE structure")
        count = len(structures)

        # Length-check all optional per-sample sequences up front, in this order.
        explicit = {
            label: _sequence_or_none(values, size=count, name=label)
            for label, values in (
                ("energies", energies),
                ("forces", forces),
                ("force_masks", force_masks),
                ("stresses", stresses),
                ("neighbor_lists", neighbor_lists),
                ("metadata", metadata),
            )
        }

        def explicit_entry(label: str, position: int) -> Optional[object]:
            # Entry of an explicit sequence, or None when it was not supplied.
            column = explicit[label]
            return None if column is None else column[position]

        samples: list[ASEAtomsSample] = []
        for index, atoms in enumerate(structures):
            if not isinstance(atoms, ase.Atoms):
                raise TypeError(
                    f"`atoms_list` should contain ase.Atoms, got {type(atoms)}"
                )

            if explicit["energies"] is None:
                energy = _read_target(atoms, key=energy_key, name="energy")
            else:
                energy = explicit["energies"][index]

            if explicit["forces"] is None:
                force_values = _read_target(
                    atoms,
                    key=forces_key,
                    name="forces",
                    arrays_first=True,
                )
            else:
                force_values = explicit["forces"][index]

            force_mask_values = explicit_entry("force_masks", index)

            if explicit["stresses"] is None:
                stress_values = _read_target(atoms, key=stress_key, name="stress")
            else:
                stress_values = explicit["stresses"][index]

            if explicit["metadata"] is None:
                sample_metadata = {}
                for key in metadata_keys:
                    if key in atoms.info:
                        sample_metadata[key] = atoms.info[key]
            else:
                sample_metadata = dict(explicit["metadata"][index])

            samples.append(
                ASEAtomsSample(
                    atoms=atoms,
                    energy=energy,  # type: ignore[arg-type]
                    forces=force_values,  # type: ignore[arg-type]
                    force_mask=force_mask_values,  # type: ignore[arg-type]
                    stress=stress_values,  # type: ignore[arg-type]
                    neighbor_list=explicit_entry("neighbor_lists", index),  # type: ignore[arg-type]
                    metadata=sample_metadata,
                )
            )
        return cls(samples)

    def __len__(self) -> int:
        """Return how many samples the dataset holds."""
        return len(self._samples)

    def __getitem__(self, index: int) -> ASEAtomsSample:
        """Return the sample stored at ``index``."""
        return self._samples[index]

    def cache_tensorized_samples(self) -> None:
        """Tensorize every sample once; later calls are no-ops."""
        if not self._tensorized_cache_warmed:
            for sample in self._samples:
                sample.tensorized()
            self._tensorized_cache_warmed = True

    def cache_neighbor_lists(
        self,
        *,
        cutoff: float,
        backend: Union[str, NeighborListBackend] = NeighborListBackend.AUTO,
    ) -> None:
        """Request every sample's neighbor list once per ``(cutoff, backend)`` pair."""
        warm_key = (float(cutoff), NeighborListBackend(backend).value)
        if warm_key not in self._warmed_neighbor_list_keys:
            for sample in self._samples:
                sample.cached_neighbor_list(cutoff=cutoff, backend=backend)
            self._warmed_neighbor_list_keys.add(warm_key)
@dataclass(frozen=True)
class ASEDatasetSplit:
    """Frozen bundle of the train, validation, and test subsets of one dataset."""

    # Subset used for training.
    train: Subset
    # Subset used for validation.
    validation: Subset
    # Subset used for testing.
    test: Subset
@dataclass(frozen=True)
class ASEDataLoaders:
    """Frozen bundle of dataloaders named like the fields of ``ASEDatasetSplit``."""

    # Loader over the training subset.
    train: DataLoader
    # Loader over the validation subset.
    validation: DataLoader
    # Loader over the test subset.
    test: DataLoader
def _split_lengths(
size: int,
*,
train_fraction: float,
validation_fraction: float,
test_fraction: float,
) -> tuple[int, int, int]:
"""Convert split fractions into integer subset lengths."""
fractions = [train_fraction, validation_fraction, test_fraction]
if any(fraction < 0.0 for fraction in fractions):
raise ValueError("split fractions must be non-negative")
if not isclose(sum(fractions), 1.0, rel_tol=0.0, abs_tol=1.0e-8):
raise ValueError("split fractions must sum to 1.0")
raw_lengths = [size * fraction for fraction in fractions]
lengths = [floor(value) for value in raw_lengths]
remainder = size - sum(lengths)
ordering = sorted(
range(len(raw_lengths)),
key=lambda idx: raw_lengths[idx] - lengths[idx],
reverse=True,
)
for idx in ordering[:remainder]:
lengths[idx] += 1
return lengths[0], lengths[1], lengths[2]
def split_ase_dataset(
    dataset: Dataset[ASEAtomsSample] | Sequence[ASEAtomsSample],
    *,
    train_fraction: float = 0.8,
    validation_fraction: float = 0.1,
    test_fraction: float = 0.1,
    shuffle: bool = True,
    seed: int = 0,
) -> ASEDatasetSplit:
    """
    Partition an ASE dataset into train, validation, and test subsets.

    A plain sequence of samples is first wrapped in ``ASEAtomsDataset``; an
    existing ``Dataset`` is used as-is. With ``shuffle`` enabled the index
    order is a permutation drawn from a generator seeded with ``seed``;
    otherwise the original order is kept. The train and validation subsets
    take their computed lengths from the front of that order and the test
    subset receives every remaining index.

    Raises ``ValueError`` when the dataset is empty.
    """
    normalized_dataset = (
        dataset if isinstance(dataset, Dataset) else ASEAtomsDataset(dataset)
    )
    dataset_size = len(normalized_dataset)  # type: ignore[arg-type]
    if dataset_size == 0:
        raise ValueError("`dataset` must contain at least one sample")

    # The test length is implied: the test subset is whatever remains.
    n_train, n_validation, _ = _split_lengths(
        dataset_size,
        train_fraction=train_fraction,
        validation_fraction=validation_fraction,
        test_fraction=test_fraction,
    )

    if shuffle:
        rng = torch.Generator()
        rng.manual_seed(seed)
        order = torch.randperm(dataset_size, generator=rng).tolist()
    else:
        order = list(range(dataset_size))

    validation_end = n_train + n_validation
    return ASEDatasetSplit(
        train=Subset(normalized_dataset, order[:n_train]),
        validation=Subset(normalized_dataset, order[n_train:validation_end]),
        test=Subset(normalized_dataset, order[validation_end:]),
    )
# Names exported by ``from <module> import *``; the underscore-prefixed
# helpers defined in this file are not listed.
__all__ = [
"ASEAtomsDataset",
"ASEAtomsSample",
"ASEDataLoaders",
"ASEDatasetSplit",
"split_ase_dataset",
]