# Source code for ufp.training.dataset

"""
Dataset and split helpers for supervised ``ase.Atoms`` samples.

Use this module to normalize energies, forces, stresses, and metadata into a
small PyTorch-friendly dataset layer.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from math import floor, isclose
from typing import Mapping, Optional, Sequence, TypeVar, Union

import ase
import torch
from ase.stress import voigt_6_to_full_3x3_stress
from torch.utils.data import DataLoader, Dataset, Subset

from ufp.neighbors._data import NeighborListData
from ufp.neighbors._neighbors import NeighborListBackend, build_neighbor_list


# Element type for the generic ``_sequence_or_none`` helper (preserved from
# input ``Sequence[T]`` to output ``list[T]``).
T = TypeVar("T")


def _as_scalar(value: object, *, name: str) -> float:
    """Coerce a scalar-like value into a Python float for stored supervision targets."""
    tensor = torch.as_tensor(value, dtype=torch.get_default_dtype())
    if tensor.numel() != 1:
        raise ValueError(f"`{name}` must contain exactly one scalar value")

    return float(tensor.reshape(-1)[0].item())


def _as_forces_tensor(value: object, *, n_atoms: int, name: str) -> torch.Tensor:
    """Coerce force targets into a detached ``(n_atoms, 3)`` tensor."""
    tensor = torch.as_tensor(value, dtype=torch.get_default_dtype())
    if tuple(tensor.shape) != (n_atoms, 3):
        raise ValueError(f"`{name}` must have shape ({n_atoms}, 3)")

    return tensor.detach().clone()


def _as_force_mask_tensor(value: object, *, n_atoms: int, name: str) -> torch.Tensor:
    """Coerce force-component masks into a detached bool ``(n_atoms, 3)`` tensor."""
    tensor = torch.as_tensor(value, dtype=torch.bool)
    if tuple(tensor.shape) != (n_atoms, 3):
        raise ValueError(f"`{name}` must have shape ({n_atoms}, 3)")

    return tensor.detach().clone()


def _as_stress_tensor(value: object, *, name: str) -> torch.Tensor:
    """Coerce stress targets into a detached full ``(3, 3)`` tensor."""
    tensor = torch.as_tensor(value, dtype=torch.get_default_dtype())
    if tuple(tensor.shape) == (3, 3):
        return tensor.detach().clone()

    if tuple(tensor.shape) == (6,):
        full = voigt_6_to_full_3x3_stress(tensor.detach().cpu().numpy())
        return torch.as_tensor(full, dtype=tensor.dtype)

    raise ValueError(f"`{name}` must have shape (3, 3) or (6,)")


def _sequence_or_none(
    values: Optional[Sequence[T]],
    *,
    size: int,
    name: str,
) -> Optional[list[T]]:
    """Validate optional per-sample sequences against the number of structures."""
    if values is None:
        return None

    sequence = list(values)
    if len(sequence) != size:
        raise ValueError(f"`{name}` must contain exactly {size} entries")

    return sequence


def _read_target(
    atoms: ase.Atoms,
    *,
    key: Optional[str],
    name: str,
    arrays_first: bool = False,
) -> object | None:
    """Read one target from ``atoms.info`` or ``atoms.arrays``."""
    if key is None:
        return None

    if arrays_first and key in atoms.arrays:
        return atoms.arrays[key]

    if key in atoms.info:
        return atoms.info[key]

    if key in atoms.arrays:
        return atoms.arrays[key]

    raise KeyError(f"ASE atoms object is missing `{name}` target `{key}`")


@dataclass(frozen=True)
class _ASETensorizedSample:
    """Cached tensor view of one ASE sample's structure fields.

    Instances are built and memoized by ``ASEAtomsSample.tensorized()``; the
    dtypes noted below are the ones that method uses when converting.
    """

    # From ``atoms.positions``, converted to the default floating dtype.
    positions: torch.Tensor
    # From ``atoms.cell.array``, converted to the default floating dtype.
    cell: torch.Tensor
    # Periodic-boundary flags from ``atoms.pbc`` as a ``torch.bool`` tensor.
    pbc: torch.Tensor
    # From ``atoms.numbers`` as a ``torch.int64`` tensor.
    atomic_numbers: torch.Tensor


@dataclass(frozen=True)
class ASEAtomsSample:
    """
    Single supervised training sample backed by :class:`ase.Atoms`.

    Construction takes a private copy of ``atoms`` and normalizes each
    supplied target:

    - energy becomes a Python ``float``;
    - forces and the force mask become ``(n_atoms, 3)`` tensors;
    - stress becomes a ``(3, 3)`` tensor.

    Tensorized structure fields and neighbor lists are computed lazily and
    memoized on the instance.
    """

    atoms: ase.Atoms
    energy: Optional[float] = None
    forces: Optional[torch.Tensor] = None
    force_mask: Optional[torch.Tensor] = None
    stress: Optional[torch.Tensor] = None
    neighbor_list: Optional[NeighborListData] = None
    metadata: Mapping[str, object] = field(default_factory=dict)
    # Lazily populated caches: not constructor arguments, hidden from repr,
    # and ignored by equality.
    _tensorized: Optional[_ASETensorizedSample] = field(
        default=None,
        init=False,
        repr=False,
        compare=False,
    )
    _neighbor_list_cache: dict[tuple[float, str], NeighborListData] = field(
        default_factory=dict,
        init=False,
        repr=False,
        compare=False,
    )

    def __post_init__(self) -> None:
        """Validate inputs, copy ``atoms``, and coerce targets to canonical forms.

        Raises:
            TypeError: If ``atoms`` is not an ``ase.Atoms`` or ``neighbor_list``
                is given but is not ``NeighborListData``.
            ValueError: If a target has the wrong shape, or ``force_mask`` is
                given without ``forces``.
        """
        # frozen=True blocks normal assignment, so normalized values are
        # written back through object.__setattr__.
        assign = object.__setattr__

        source_atoms = self.atoms
        if not isinstance(source_atoms, ase.Atoms):
            raise TypeError(f"`atoms` must be ase.Atoms, got {type(source_atoms)}")

        atoms_copy = source_atoms.copy()
        assign(self, "atoms", atoms_copy)
        n_atoms = len(atoms_copy)

        if self.energy is not None:
            assign(self, "energy", _as_scalar(self.energy, name="energy"))

        if self.forces is not None:
            normalized_forces = _as_forces_tensor(
                self.forces,
                n_atoms=n_atoms,
                name="forces",
            )
            assign(self, "forces", normalized_forces)

        if self.force_mask is not None:
            if self.forces is None:
                raise ValueError("`force_mask` requires `forces`")
            normalized_mask = _as_force_mask_tensor(
                self.force_mask,
                n_atoms=n_atoms,
                name="force_mask",
            )
            assign(self, "force_mask", normalized_mask)

        if self.stress is not None:
            assign(self, "stress", _as_stress_tensor(self.stress, name="stress"))

        explicit_neighbors = self.neighbor_list
        if explicit_neighbors is not None and not isinstance(
            explicit_neighbors,
            NeighborListData,
        ):
            raise TypeError("`neighbor_list` must be NeighborListData when provided")

        assign(self, "metadata", dict(self.metadata))

    def tensorized(self) -> _ASETensorizedSample:
        """Build (once) and return tensor versions of the structural fields.

        Float fields use the default dtype, ``pbc`` is ``torch.bool`` and the
        atomic numbers are ``torch.int64``. Because ``torch.as_tensor`` is used,
        a tensor may share memory with the underlying NumPy array when the
        dtypes already match.
        """
        cached = self._tensorized
        if cached is not None:
            return cached

        float_dtype = torch.get_default_dtype()
        atoms = self.atoms
        cached = _ASETensorizedSample(
            positions=torch.as_tensor(atoms.positions, dtype=float_dtype),
            cell=torch.as_tensor(atoms.cell.array, dtype=float_dtype),
            pbc=torch.as_tensor(atoms.pbc, dtype=torch.bool),
            atomic_numbers=torch.as_tensor(atoms.numbers, dtype=torch.int64),
        )
        object.__setattr__(self, "_tensorized", cached)
        return cached

    def cached_neighbor_list(
        self,
        *,
        cutoff: float,
        backend: Union[str, NeighborListBackend],
    ) -> NeighborListData:
        """Return the explicit neighbor list, or a memoized one per cutoff/backend.

        An explicitly supplied ``neighbor_list`` always wins, whatever the
        ``cutoff``. Otherwise the list is built once via ``build_neighbor_list``
        for each ``(float(cutoff), backend value)`` pair and reused afterwards.
        """
        explicit = self.neighbor_list
        if explicit is not None:
            return explicit

        memo = self._neighbor_list_cache
        cache_key = (float(cutoff), NeighborListBackend(backend).value)
        found = memo.get(cache_key)
        if found is not None:
            return found

        found = build_neighbor_list(
            atoms=self.atoms,
            cutoff=cutoff,
            backend=backend,
            arrays="torch",
        )
        memo[cache_key] = found
        return found
class ASEAtomsDataset(Dataset[ASEAtomsSample]):
    """
    Thin :mod:`torch.utils.data` dataset over pre-normalized ASE samples.

    Samples are held in an immutable tuple. Optional warm-up helpers fill each
    sample's tensor and neighbor-list caches ahead of training.
    """

    def __init__(self, samples: Sequence[ASEAtomsSample]) -> None:
        """Validate and store a non-empty sequence of ``ASEAtomsSample`` objects.

        Raises:
            ValueError: If ``samples`` is empty.
            TypeError: If any entry is not an ``ASEAtomsSample``.
        """
        collected = tuple(samples)
        if len(collected) == 0:
            raise ValueError("`samples` must contain at least one ASEAtomsSample")
        if not all(isinstance(entry, ASEAtomsSample) for entry in collected):
            raise TypeError("all dataset entries must be ASEAtomsSample instances")

        self._samples = collected
        # Warm-up bookkeeping so repeated cache_* calls become cheap no-ops.
        self._tensorized_cache_warmed = False
        self._warmed_neighbor_list_keys: set[tuple[float, str]] = set()

    @classmethod
    def from_atoms(
        cls,
        atoms_list: Sequence[ase.Atoms],
        *,
        energies: Optional[Sequence[object]] = None,
        forces: Optional[Sequence[object]] = None,
        force_masks: Optional[Sequence[object]] = None,
        stresses: Optional[Sequence[object]] = None,
        neighbor_lists: Optional[Sequence[NeighborListData | None]] = None,
        energy_key: Optional[str] = "energy",
        forces_key: Optional[str] = None,
        stress_key: Optional[str] = None,
        metadata: Optional[Sequence[Mapping[str, object]]] = None,
        metadata_keys: Sequence[str] = (),
    ) -> "ASEAtomsDataset":
        """Build a dataset directly from labeled ``ase.Atoms`` objects.

        Each explicit per-structure sequence (``energies``, ``forces``,
        ``stresses``) takes precedence over the matching ``*_key`` lookup on
        ``atoms.info`` / ``atoms.arrays``. Lookup details:

        - ``energy_key`` defaults to ``"energy"``, so without explicit
          ``energies`` every structure must carry that key. Pass
          ``energy_key=None`` to skip energies.
        - Forces are looked up in ``atoms.arrays`` before ``atoms.info``.
        - When ``metadata`` is given, ``metadata_keys`` is ignored. Otherwise
          the listed keys present in ``atoms.info`` are copied.

        Raises:
            ValueError: If ``atoms_list`` is empty or an explicit sequence has
                the wrong length.
            TypeError: If an entry of ``atoms_list`` is not ``ase.Atoms``.
            KeyError: If a configured target key is missing on a structure.
        """
        structures = list(atoms_list)
        if not structures:
            raise ValueError("`atoms_list` must contain at least one ASE structure")

        count = len(structures)
        # Validated in this fixed order so the first bad sequence is reported.
        (
            explicit_energies,
            explicit_forces,
            explicit_force_masks,
            explicit_stresses,
            explicit_neighbor_lists,
            explicit_metadata,
        ) = (
            _sequence_or_none(values, size=count, name=label)
            for values, label in (
                (energies, "energies"),
                (forces, "forces"),
                (force_masks, "force_masks"),
                (stresses, "stresses"),
                (neighbor_lists, "neighbor_lists"),
                (metadata, "metadata"),
            )
        )

        def resolve(explicit, position, atoms, *, key, name, arrays_first=False):
            # Explicit per-structure values beat lookups on the Atoms object.
            if explicit is not None:
                return explicit[position]
            return _read_target(atoms, key=key, name=name, arrays_first=arrays_first)

        samples: list[ASEAtomsSample] = []
        for position, atoms in enumerate(structures):
            if not isinstance(atoms, ase.Atoms):
                raise TypeError(
                    f"`atoms_list` should contain ase.Atoms, got {type(atoms)}"
                )

            energy = resolve(
                explicit_energies, position, atoms, key=energy_key, name="energy"
            )
            force_values = resolve(
                explicit_forces,
                position,
                atoms,
                key=forces_key,
                name="forces",
                arrays_first=True,
            )
            force_mask_values = (
                explicit_force_masks[position]
                if explicit_force_masks is not None
                else None
            )
            stress_values = resolve(
                explicit_stresses, position, atoms, key=stress_key, name="stress"
            )

            if explicit_metadata is None:
                sample_metadata = {
                    key: atoms.info[key] for key in metadata_keys if key in atoms.info
                }
            else:
                sample_metadata = dict(explicit_metadata[position])

            sample_neighbors = (
                explicit_neighbor_lists[position]
                if explicit_neighbor_lists is not None
                else None
            )

            samples.append(
                ASEAtomsSample(
                    atoms=atoms,
                    energy=energy,  # type: ignore[arg-type]
                    forces=force_values,  # type: ignore[arg-type]
                    force_mask=force_mask_values,  # type: ignore[arg-type]
                    stress=stress_values,  # type: ignore[arg-type]
                    neighbor_list=sample_neighbors,
                    metadata=sample_metadata,
                )
            )

        return cls(samples)

    def __len__(self) -> int:
        """Number of structures held by the dataset."""
        return len(self._samples)

    def __getitem__(self, index: int) -> ASEAtomsSample:
        """Sample stored at ``index`` (plain tuple indexing semantics)."""
        return self._samples[index]

    def cache_tensorized_samples(self) -> None:
        """Populate every sample's tensorized-field cache, once per dataset."""
        if not self._tensorized_cache_warmed:
            for sample in self._samples:
                sample.tensorized()
            self._tensorized_cache_warmed = True

    def cache_neighbor_lists(
        self,
        *,
        cutoff: float,
        backend: Union[str, NeighborListBackend] = NeighborListBackend.AUTO,
    ) -> None:
        """Populate every sample's neighbor-list cache for one cutoff/backend pair.

        Repeat calls with the same ``(float(cutoff), backend value)`` pair
        return immediately.
        """
        warm_key = (float(cutoff), NeighborListBackend(backend).value)
        if warm_key not in self._warmed_neighbor_list_keys:
            for sample in self._samples:
                sample.cached_neighbor_list(cutoff=cutoff, backend=backend)
            self._warmed_neighbor_list_keys.add(warm_key)
@dataclass(frozen=True)
class ASEDatasetSplit:
    """Grouped train, validation, and test subsets over one ASE dataset.

    ``split_ase_dataset`` returns instances of this class. All three subsets
    index into the same underlying dataset.
    """

    # Subset holding the training indices.
    train: Subset
    # Subset holding the validation indices.
    validation: Subset
    # Subset holding the held-out test indices.
    test: Subset
@dataclass(frozen=True)
class ASEDataLoaders:
    """Grouped dataloaders aligned with an ``ASEDatasetSplit``.

    The three field names mirror ``ASEDatasetSplit`` one-to-one. No builder for
    this class is visible here; presumably each loader wraps the matching
    subset of a split (TODO confirm against the caller).
    """

    # Loader paired with the split's ``train`` subset.
    train: DataLoader
    # Loader paired with the split's ``validation`` subset.
    validation: DataLoader
    # Loader paired with the split's ``test`` subset.
    test: DataLoader
def _split_lengths( size: int, *, train_fraction: float, validation_fraction: float, test_fraction: float, ) -> tuple[int, int, int]: """Convert split fractions into integer subset lengths.""" fractions = [train_fraction, validation_fraction, test_fraction] if any(fraction < 0.0 for fraction in fractions): raise ValueError("split fractions must be non-negative") if not isclose(sum(fractions), 1.0, rel_tol=0.0, abs_tol=1.0e-8): raise ValueError("split fractions must sum to 1.0") raw_lengths = [size * fraction for fraction in fractions] lengths = [floor(value) for value in raw_lengths] remainder = size - sum(lengths) ordering = sorted( range(len(raw_lengths)), key=lambda idx: raw_lengths[idx] - lengths[idx], reverse=True, ) for idx in ordering[:remainder]: lengths[idx] += 1 return lengths[0], lengths[1], lengths[2]
def split_ase_dataset(
    dataset: Dataset[ASEAtomsSample] | Sequence[ASEAtomsSample],
    *,
    train_fraction: float = 0.8,
    validation_fraction: float = 0.1,
    test_fraction: float = 0.1,
    shuffle: bool = True,
    seed: int = 0,
) -> ASEDatasetSplit:
    """Partition an ASE dataset into train / validation / test ``Subset`` views.

    A plain sequence of samples is first wrapped in ``ASEAtomsDataset``; any
    ``torch.utils.data.Dataset`` is used as-is. Subset sizes come from
    ``_split_lengths``. With ``shuffle=True`` the index order is a
    ``torch.randperm`` drawn from a fresh generator seeded with ``seed``, so a
    given seed yields a deterministic split. Otherwise the subsets are
    contiguous index ranges in train, validation, test order.

    Raises:
        ValueError: If the dataset is empty or the fractions are invalid.
    """
    base = dataset if isinstance(dataset, Dataset) else ASEAtomsDataset(dataset)

    total = len(base)  # type: ignore[arg-type]
    if total == 0:
        raise ValueError("`dataset` must contain at least one sample")

    n_train, n_validation, _ = _split_lengths(
        total,
        train_fraction=train_fraction,
        validation_fraction=validation_fraction,
        test_fraction=test_fraction,
    )

    if shuffle:
        rng = torch.Generator()
        rng.manual_seed(seed)
        order = torch.randperm(total, generator=rng).tolist()
    else:
        order = list(range(total))

    # Test takes everything after validation, i.e. exactly the test length.
    validation_end = n_train + n_validation
    return ASEDatasetSplit(
        train=Subset(base, order[:n_train]),
        validation=Subset(base, order[n_train:validation_end]),
        test=Subset(base, order[validation_end:]),
    )
# Public API of this module; the ``_``-prefixed helpers are intentionally
# left out of star-imports.
__all__ = [
    "ASEAtomsDataset",
    "ASEAtomsSample",
    "ASEDataLoaders",
    "ASEDatasetSplit",
    "split_ase_dataset",
]