"""
Batch containers and dataloader helpers for ASE-backed training.
This module keeps batched structures and targets aligned while exposing a
``prepare_input`` bridge back into the shared model input path.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Mapping, Optional, Sequence
import ase
import torch
from torch.utils.data import DataLoader, Dataset, Subset
from ufp.core.input import UFPInput
from ufp.core.potential import UFPotential
from ufp.core.state import UFPInputState
from ufp.neighbors._data import NeighborListData, concatenate_neighbor_lists
from ufp.neighbors._neighbors import NeighborListBackend
from ufp.terms._base import TermCacheOptions
from ufp.training.dataset import (
ASEAtomsDataset,
ASEAtomsSample,
ASEDataLoaders,
ASEDatasetSplit,
_as_force_mask_tensor,
_as_forces_tensor,
)
def _model_device(model: UFPotential) -> torch.device:
"""Return the first parameter or buffer device for ``model``."""
parameter = next(model.parameters(), None)
if parameter is not None:
return parameter.device
buffer = next(model.buffers(), None)
if buffer is not None:
return buffer.device
return torch.device("cpu")
def _iter_with_progress(
iterable,
*,
enabled: bool,
description: str,
total: int | None = None,
):
"""Wrap an iterable in a tqdm progress bar when requested."""
if not enabled:
return iterable
from tqdm.auto import tqdm
if total is None and hasattr(iterable, "__len__"):
total = len(iterable)
return tqdm(iterable, total=total, desc=description, leave=True)
def _warm_input_caches(
model: UFPotential,
inputs: UFPInput,
*,
feature_cache_storage: str = "cpu",
feature_cache_mode: str = "auto",
feature_cache_dir: Path | None = None,
cache_prefix: str = "batch",
legacy_cache_prefix: str | None = None,
include_per_atom_energy: bool = True,
) -> None:
"""Let model terms precompute reusable geometry metadata for one input."""
terms = getattr(model, "terms", ())
for term_index, term in enumerate(terms):
cache_input = getattr(term, "cache_input", None)
if callable(cache_input):
legacy_cache_prefixes: tuple[str, ...] = ()
if legacy_cache_prefix is not None:
legacy_cache_prefixes = (f"{legacy_cache_prefix}_term{term_index}",)
options = TermCacheOptions(
feature_cache_storage=feature_cache_storage,
feature_cache_mode=feature_cache_mode,
feature_cache_dir=feature_cache_dir,
cache_prefix=f"{cache_prefix}_term{term_index}",
legacy_cache_prefixes=legacy_cache_prefixes,
include_per_atom_energy=include_per_atom_energy,
)
cache_input(inputs, options=options)
@dataclass
class ASEAtomsBatch:
    """
    Mini-batch of structures and targets produced by :func:`ase_atoms_collate_fn`.

    Construction normalizes every field in ``__post_init__``: sequences become
    tuples, targets become tensors with validated shapes, and ``system_sizes``
    (one atom count per structure, ``int64``) is derived from ``atoms``.

    Attributes:
        atoms: The batched ASE structures; at least one is required.
        energy: Optional per-system energies, stored with shape ``(n_systems,)``.
        forces: Optional forces, normalized by ``_as_forces_tensor`` against
            the total atom count.
        force_mask: Optional mask normalized by ``_as_force_mask_tensor``;
            only allowed together with ``forces``.
        stress: Optional stresses, stored with shape ``(n_systems, 3, 3)``.
        neighbor_lists: Optional per-system neighbor lists.
        metadata: One mapping per system; defaults to empty dicts.
        samples: Optional originating samples, one per system.
        cached_input: Optional prebuilt model input describing the same
            systems and atoms.
        system_sizes: Atom count of each structure (not an ``__init__`` argument).
    """

    atoms: Sequence[ase.Atoms]
    energy: Optional[torch.Tensor] = None
    forces: Optional[torch.Tensor] = None
    force_mask: Optional[torch.Tensor] = None
    stress: Optional[torch.Tensor] = None
    neighbor_lists: Optional[Sequence[NeighborListData]] = None
    metadata: Sequence[Mapping[str, object]] = field(default_factory=tuple)
    samples: Sequence[ASEAtomsSample] = field(default_factory=tuple, repr=False)
    cached_input: Optional[UFPInput] = field(default=None, repr=False)
    system_sizes: torch.Tensor = field(init=False)

    def __post_init__(self) -> None:
        """Normalize all fields and reject inputs that disagree with ``atoms``."""
        structures = tuple(self.atoms)
        if len(structures) == 0:
            raise ValueError("`atoms` must contain at least one ASE structure")
        if not all(isinstance(structure, ase.Atoms) for structure in structures):
            raise TypeError("all batched structures must be ase.Atoms")
        self.atoms = structures
        self.system_sizes = torch.tensor(
            [len(structure) for structure in structures],
            dtype=torch.int64,
        )
        system_count = self.n_systems

        if self.energy is not None:
            energy_tensor = torch.as_tensor(
                self.energy,
                dtype=torch.get_default_dtype(),
            )
            if energy_tensor.ndim == 0:
                # A bare scalar is treated as the energy of a single system.
                energy_tensor = energy_tensor.reshape(1)
            elif tuple(energy_tensor.shape) == (system_count, 1):
                # Accept a column vector and flatten it.
                energy_tensor = energy_tensor[:, 0]
            if tuple(energy_tensor.shape) != (system_count,):
                raise ValueError(f"`energy` must have shape ({system_count},)")
            self.energy = energy_tensor

        if self.forces is not None:
            self.forces = _as_forces_tensor(
                self.forces,
                n_atoms=self.n_atoms,
                name="forces",
            )

        if self.force_mask is not None:
            if self.forces is None:
                raise ValueError("`force_mask` requires `forces`")
            self.force_mask = _as_force_mask_tensor(
                self.force_mask,
                n_atoms=self.n_atoms,
                name="force_mask",
            )

        if self.stress is not None:
            stress_tensor = torch.as_tensor(
                self.stress,
                dtype=torch.get_default_dtype(),
            )
            if system_count == 1 and tuple(stress_tensor.shape) == (3, 3):
                # A single 3x3 matrix is promoted to a one-system batch.
                stress_tensor = stress_tensor.unsqueeze(0)
            if tuple(stress_tensor.shape) != (system_count, 3, 3):
                raise ValueError(f"`stress` must have shape ({system_count}, 3, 3)")
            self.stress = stress_tensor

        metadata_entries = tuple(dict(entry) for entry in self.metadata)
        if metadata_entries:
            if len(metadata_entries) != system_count:
                raise ValueError("`metadata` must contain one entry per system")
            self.metadata = metadata_entries
        else:
            self.metadata = tuple({} for _ in range(system_count))

        sample_entries = tuple(self.samples)
        if sample_entries and len(sample_entries) != system_count:
            raise ValueError("`samples` must contain one entry per system")
        if not all(isinstance(entry, ASEAtomsSample) for entry in sample_entries):
            raise TypeError("all `samples` entries must be ASEAtomsSample")
        self.samples = sample_entries

        if self.neighbor_lists is not None:
            per_system_lists = tuple(self.neighbor_lists)
            if len(per_system_lists) != system_count:
                raise ValueError("`neighbor_lists` must contain one entry per system")
            if not all(
                isinstance(entry, NeighborListData) for entry in per_system_lists
            ):
                raise TypeError("all `neighbor_lists` entries must be NeighborListData")
            self.neighbor_lists = per_system_lists

        prebuilt = self.cached_input
        if prebuilt is not None:
            if not isinstance(prebuilt, UFPInput):
                raise TypeError("`cached_input` must be UFPInput when provided")
            if prebuilt.n_systems != system_count:
                raise ValueError("`cached_input` must describe the same systems")
            if prebuilt.n_atoms != self.n_atoms:
                raise ValueError("`cached_input` must describe the same atoms")

    @property
    def n_systems(self) -> int:
        """Number of structures held by this batch."""
        return len(self.atoms)

    @property
    def n_atoms(self) -> int:
        """Total number of atoms summed over all structures."""
        return int(self.system_sizes.sum().item())

    @property
    def atom_slices(self) -> tuple[slice, ...]:
        """One slice per system addressing its rows on the concatenated atom axis."""
        sizes = self.system_sizes.tolist()
        stops = torch.cumsum(self.system_sizes, dim=0).tolist()
        return tuple(
            slice(int(stop) - int(size), int(stop))
            for size, stop in zip(sizes, stops)
        )

    def pin_memory(self) -> "ASEAtomsBatch":
        """Replace stored tensors with pinned copies and return ``self``.

        Covers the supervised targets, every neighbor list, the cached input,
        and ``system_sizes``; fields that are ``None`` are skipped.
        """
        for target_name in ("energy", "forces", "force_mask", "stress"):
            target = getattr(self, target_name)
            if target is not None:
                setattr(self, target_name, target.pin_memory())
        if self.neighbor_lists is not None:
            self.neighbor_lists = tuple(
                entry.pin_memory() for entry in self.neighbor_lists
            )
        if self.cached_input is not None:
            self.cached_input = self.cached_input.pin_memory()
        self.system_sizes = self.system_sizes.pin_memory()
        return self

    def cache_targets_on_device(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "ASEAtomsBatch":
        """Transfer the supervised targets to ``device`` and return ``self``.

        ``energy``, ``forces`` and ``stress`` are also cast to ``dtype``;
        ``force_mask`` is moved without a dtype change. All transfers are
        issued with ``non_blocking=True``.
        """
        # (attribute name, whether the floating dtype is applied)
        transfer_plan = (
            ("energy", True),
            ("forces", True),
            ("force_mask", False),
            ("stress", True),
        )
        for target_name, cast_dtype in transfer_plan:
            target = getattr(self, target_name)
            if target is None:
                continue
            if cast_dtype:
                moved = target.to(device=device, dtype=dtype, non_blocking=True)
            else:
                moved = target.to(device=device, non_blocking=True)
            setattr(self, target_name, moved)
        return self
def _collect_optional_field(samples: Sequence[ASEAtomsSample], name: str):
    """Return attribute ``name`` from every sample, or ``None`` if no sample sets it.

    Raises:
        ValueError: If only some of the samples provide the attribute.
    """
    values = [getattr(sample, name) for sample in samples]
    provided = sum(value is not None for value in values)
    if provided == 0:
        return None
    if provided != len(values):
        raise ValueError(
            f"either all samples in a batch must provide `{name}`, or none"
        )
    return values


def ase_atoms_collate_fn(samples: Sequence[ASEAtomsSample]) -> ASEAtomsBatch:
    """Collate normalized ASE samples into one ``ASEAtomsBatch``.

    The input is materialized into a tuple once and every later step reads
    that tuple, so a one-shot iterable is handled the same way as a list.
    Each optional field (``neighbor_list``, ``energy``, ``forces``,
    ``force_mask``, ``stress``) must be provided by all samples or by none.
    A ``UFPInput`` covering the whole batch is built from each sample's
    ``tensorized()`` view and attached as ``cached_input``.

    Args:
        samples: Samples to merge into one batch; must not be empty.

    Returns:
        The collated batch, with concatenated targets and a prebuilt
        ``cached_input``.

    Raises:
        ValueError: If ``samples`` is empty, or if an optional field is set
            on only some of the samples.
    """
    normalized_samples = () if samples is None else tuple(samples)
    if not normalized_samples:
        raise ValueError("`samples` must contain at least one ASEAtomsSample")

    atoms = tuple(sample.atoms for sample in normalized_samples)
    metadata = tuple(dict(sample.metadata) for sample in normalized_samples)

    # Validation order matches the field order below so the first reported
    # inconsistency is deterministic.
    neighbor_list_values = _collect_optional_field(normalized_samples, "neighbor_list")
    neighbor_lists = (
        None if neighbor_list_values is None else tuple(neighbor_list_values)
    )

    energy_values = _collect_optional_field(normalized_samples, "energy")
    energies = None
    if energy_values is not None:
        energies = torch.tensor(
            [float(value) for value in energy_values],
            dtype=torch.get_default_dtype(),
        )

    force_values = _collect_optional_field(normalized_samples, "forces")
    forces = None if force_values is None else torch.cat(force_values, dim=0)

    force_mask_values = _collect_optional_field(normalized_samples, "force_mask")
    force_mask = (
        None if force_mask_values is None else torch.cat(force_mask_values, dim=0)
    )

    stress_values = _collect_optional_field(normalized_samples, "stress")
    stress = None if stress_values is None else torch.stack(stress_values, dim=0)

    tensorized_samples = [sample.tensorized() for sample in normalized_samples]
    system_sizes = torch.tensor(
        [tensorized.positions.shape[0] for tensorized in tensorized_samples],
        dtype=torch.int64,
    )
    # Exclusive prefix sum: index of each system's first atom in the batch.
    atom_offsets = torch.cumsum(system_sizes, dim=0) - system_sizes

    cached_neighbor_list = None
    if neighbor_lists is not None:
        cached_neighbor_list = concatenate_neighbor_lists(
            list(neighbor_lists),
            atom_offsets=atom_offsets.tolist(),
        )

    cached_input = UFPInput(
        positions=torch.cat(
            [tensorized.positions for tensorized in tensorized_samples],
            dim=0,
        ),
        cell=torch.stack(
            [tensorized.cell for tensorized in tensorized_samples],
            dim=0,
        ),
        pbc=torch.stack(
            [tensorized.pbc for tensorized in tensorized_samples],
            dim=0,
        ),
        atomic_numbers=torch.cat(
            [tensorized.atomic_numbers for tensorized in tensorized_samples],
            dim=0,
        ),
        # Each atom is labelled with the position of its owning system.
        system_index=torch.repeat_interleave(
            torch.arange(len(tensorized_samples), dtype=torch.int64),
            system_sizes,
        ),
        neighbor_list=cached_neighbor_list,
        source_atoms=atoms,
        state=UFPInputState.from_ase_list(
            atoms,
            dtype=torch.get_default_dtype(),
        ),
    )
    return ASEAtomsBatch(
        atoms=atoms,
        energy=energies,
        forces=forces,
        force_mask=force_mask,
        stress=stress,
        neighbor_lists=neighbor_lists,
        metadata=metadata,
        samples=normalized_samples,
        cached_input=cached_input,
    )
def build_ase_dataloader(
    dataset: Dataset[ASEAtomsSample] | Sequence[ASEAtomsSample],
    *,
    batch_size: int = 1,
    shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
    pin_memory: bool = False,
    persistent_workers: bool = False,
    prefetch_factor: int | None = None,
) -> DataLoader:
    """Create a ``DataLoader`` that batches samples with ``ase_atoms_collate_fn``.

    A ``Dataset`` is used as-is; anything else is wrapped in
    ``ASEAtomsDataset``. ``persistent_workers`` and ``prefetch_factor`` are
    forwarded only when ``num_workers`` is positive and are otherwise ignored.
    """
    source = dataset if isinstance(dataset, Dataset) else ASEAtomsDataset(dataset)

    loader_options: dict[str, object] = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "drop_last": drop_last,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "collate_fn": ase_atoms_collate_fn,
    }
    if num_workers > 0:
        # Worker-only options are withheld for in-process loading.
        loader_options["persistent_workers"] = persistent_workers
        if prefetch_factor is not None:
            loader_options["prefetch_factor"] = prefetch_factor
    return DataLoader(source, **loader_options)
class CachedASEBatchLoader:
    """Reusable iterable over pre-collated batches for static ASE datasets.

    The batches are stored once as a tuple. Every pass yields those same
    batch objects, either in stored order or, when shuffling is enabled, in a
    fresh permutation drawn from a private seeded generator.
    """

    def __init__(
        self,
        dataset: Dataset[ASEAtomsSample],
        batches: Sequence[ASEAtomsBatch],
        *,
        shuffle: bool = False,
        seed: int = 0,
    ) -> None:
        """Store the batches and seed the shuffle generator.

        Raises:
            ValueError: If ``batches`` is empty.
        """
        if not batches:
            raise ValueError("`batches` must contain at least one ASEAtomsBatch")
        rng = torch.Generator()
        rng.manual_seed(seed)
        self.dataset = dataset
        self._batches = tuple(batches)
        self._shuffle = bool(shuffle)
        self._generator = rng

    def __iter__(self):
        """Iterate the cached batches, reshuffling their order when enabled."""
        if self._shuffle:
            permutation = torch.randperm(
                len(self._batches),
                generator=self._generator,
            )
            for position in permutation.tolist():
                yield self._batches[int(position)]
        else:
            yield from self._batches

    def __len__(self) -> int:
        """Number of batches held in the cache."""
        return len(self._batches)
def _validate_feature_cache_options(
    feature_cache_storage: str,
    feature_cache_mode: str,
    feature_cache_dir: Path | None,
) -> None:
    """Reject unknown or mutually inconsistent feature-cache settings."""
    if feature_cache_storage not in {"none", "cpu", "disk"}:
        raise ValueError("`feature_cache_storage` must be 'none', 'cpu', or 'disk'")
    if feature_cache_mode not in {"auto", "read", "refresh"}:
        raise ValueError("`feature_cache_mode` must be 'auto', 'read', or 'refresh'")
    if feature_cache_mode == "read" and feature_cache_storage != "disk":
        raise ValueError("`feature_cache_mode='read'` requires disk feature caches")
    if feature_cache_storage == "disk" and feature_cache_dir is None:
        raise ValueError(
            "`feature_cache_dir` is required when `feature_cache_storage='disk'`"
        )


def build_cached_ase_dataloader(
    model: UFPotential,
    dataset: Dataset[ASEAtomsSample] | Sequence[ASEAtomsSample],
    *,
    batch_size: int = 1,
    shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
    pin_memory: bool = False,
    persistent_workers: bool = False,
    prefetch_factor: int | None = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device | str] = None,
    backend: Optional[str | NeighborListBackend] = None,
    cache_batches_on_device: bool = False,
    feature_cache_storage: str = "cpu",
    feature_cache_mode: str = "auto",
    feature_cache_dir: Path | str | None = None,
    feature_cache_prefix: str = "batch",
    feature_cache_per_atom_energy: bool = True,
    seed: int = 0,
    progress: bool = False,
    progress_description: str = "Caching training batches",
) -> CachedASEBatchLoader:
    """
    Pre-collate fixed-geometry batches and reuse them across epochs.

    Cached loaders intentionally materialize batches in-process. Using PyTorch
    worker processes here would retain shared-memory handles for every cached
    tensor and can quickly exhaust the process file-descriptor limit. The
    internal loader is therefore always built with ``num_workers=0``,
    ``pin_memory=False`` and no shuffling; shuffling is applied to the order
    of the cached batches instead.

    All scalar and feature-cache arguments are validated before the dataset
    is tensorized or any neighbor list is built, so a bad argument fails
    immediately.

    Args:
        model: Potential providing ``cutoff``, ``neighbor_backend`` and
            ``preferred_dtype()``; its terms are offered each batch for
            cache warm-up.
        dataset: A ``Dataset`` of samples, or a raw sequence that is wrapped
            in ``ASEAtomsDataset``.
        batch_size: Number of samples per cached batch; must be positive.
        shuffle: Whether to shuffle cached batch order on each iteration.
        drop_last: Whether to drop the final incomplete batch before caching.
        num_workers: Checked for being non-negative, otherwise unused because
            batches are collated in-process.
        pin_memory: Whether to pin each cached batch's tensors.
        persistent_workers: Accepted for signature parity; unused.
        prefetch_factor: Accepted for signature parity; unused.
        dtype: Dtype for cached inputs; defaults to ``model.preferred_dtype()``.
        device: Device used when ``cache_batches_on_device`` is set, and for
            warming disk feature caches on a non-CPU device.
        backend: Optional neighbor-list backend override.
        cache_batches_on_device: Whether to retain cached inputs and targets
            on the device (the model's device when ``device`` is ``None``).
        feature_cache_storage: One of ``'none'``, ``'cpu'`` or ``'disk'``.
        feature_cache_mode: One of ``'auto'``, ``'read'`` or ``'refresh'``;
            ``'read'`` requires disk storage.
        feature_cache_dir: Directory for disk-backed feature caches; required
            when ``feature_cache_storage='disk'``.
        feature_cache_prefix: Prefix for feature-cache entries; the batch
            index is appended.
        feature_cache_per_atom_energy: Whether feature caches include per-atom rows.
        seed: Seed used when shuffling cached batches.
        progress: Whether to show cache-building progress.
        progress_description: Label for the progress bar.

    Returns:
        Reusable iterable over cached ASE batches.

    Raises:
        ValueError: If ``batch_size``, ``num_workers`` or the feature-cache
            options are invalid, or if no batch was produced.
        TypeError: If collation yields something other than ``ASEAtomsBatch``.
    """
    # Fail fast: every check below is cheap and independent of the dataset.
    if batch_size <= 0:
        raise ValueError("`batch_size` must be positive")
    if num_workers < 0:
        raise ValueError("`num_workers` must be non-negative")
    resolved_feature_cache_dir = (
        None if feature_cache_dir is None else Path(feature_cache_dir)
    )
    _validate_feature_cache_options(
        feature_cache_storage,
        feature_cache_mode,
        resolved_feature_cache_dir,
    )

    if isinstance(dataset, Dataset):
        normalized_dataset = dataset
    else:
        normalized_dataset = ASEAtomsDataset(dataset)

    # Peel off any Subset wrappers so per-sample caches live on the base dataset.
    unwrapped_dataset: Dataset[ASEAtomsSample] = normalized_dataset
    while isinstance(unwrapped_dataset, Subset):
        unwrapped_dataset = unwrapped_dataset.dataset
    if isinstance(unwrapped_dataset, ASEAtomsDataset):
        unwrapped_dataset.cache_tensorized_samples()
        if model.cutoff is not None:
            resolved_backend = (
                model.neighbor_backend
                if backend is None
                else NeighborListBackend(backend)
            )
            unwrapped_dataset.cache_neighbor_lists(
                cutoff=model.cutoff,
                backend=resolved_backend,
            )

    loader = build_ase_dataloader(
        normalized_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=drop_last,
        num_workers=0,
        pin_memory=False,
        persistent_workers=False,
        prefetch_factor=None,
    )

    resolved_dtype = model.preferred_dtype() if dtype is None else dtype
    resolved_device = None if device is None else torch.device(device)
    if cache_batches_on_device and resolved_device is None:
        resolved_device = _model_device(model)

    # Loop-invariant decisions, computed once.
    # Disk feature caches are warmed from a copy on `resolved_device` when the
    # cached batch itself is not kept on that device.
    warm_on_device = (
        feature_cache_storage == "disk"
        and not cache_batches_on_device
        and resolved_device is not None
        and resolved_device.type != "cpu"
    )
    # NOTE(review): with `cache_batches_on_device` the input is prepared with
    # `device=resolved_device` before pinning; confirm that
    # `UFPInput.pin_memory` tolerates tensors that are already on CUDA.
    should_pin = pin_memory and (
        not cache_batches_on_device
        or (resolved_device is not None and resolved_device.type == "cuda")
    )

    batches: list[ASEAtomsBatch] = []
    cached_loader = _iter_with_progress(
        loader,
        enabled=progress,
        description=progress_description,
    )
    for batch_index, batch in enumerate(cached_loader):
        if not isinstance(batch, ASEAtomsBatch):
            raise TypeError("cached ASE loader requires ASEAtomsBatch instances")
        batch.cached_input = batch.prepare_input(
            model,
            backend=backend,
            dtype=resolved_dtype,
            device=resolved_device if cache_batches_on_device else None,
            requires_grad=False,
        )
        cache_input = batch.cached_input
        if warm_on_device:
            cache_input = batch.cached_input.to(
                device=resolved_device,
                dtype=resolved_dtype,
                requires_grad=False,
            )
        _warm_input_caches(
            model,
            cache_input,
            feature_cache_storage=feature_cache_storage,
            feature_cache_mode=feature_cache_mode,
            feature_cache_dir=resolved_feature_cache_dir,
            cache_prefix=f"{feature_cache_prefix}{batch_index}",
            legacy_cache_prefix=f"batch{batch_index}",
            include_per_atom_energy=feature_cache_per_atom_energy,
        )
        if cache_input is not batch.cached_input:
            # Carry metadata recorded on the temporary device copy back to
            # the input that is actually retained.
            batch.cached_input.metadata = cache_input.metadata
        if should_pin:
            batch.pin_memory()
        if cache_batches_on_device:
            # Guaranteed above: the device falls back to the model's device.
            assert resolved_device is not None
            batch.cached_input = batch.cached_input.to(
                device=resolved_device,
                dtype=resolved_dtype,
                requires_grad=False,
            )
            batch.cache_targets_on_device(
                device=resolved_device,
                dtype=resolved_dtype,
            )
        batches.append(batch)
    return CachedASEBatchLoader(
        normalized_dataset,
        batches,
        shuffle=shuffle,
        seed=seed,
    )
def build_ase_training_loader(
    model: UFPotential,
    dataset: Dataset[ASEAtomsSample] | Sequence[ASEAtomsSample],
    *,
    batch_size: int = 1,
    shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
    pin_memory: bool = False,
    persistent_workers: bool = False,
    prefetch_factor: int | None = None,
    dtype: Optional[torch.dtype] = None,
    backend: Optional[str | NeighborListBackend] = None,
    cache_batches: bool = True,
    cache_batches_on_device: bool = False,
    feature_cache_storage: str = "cpu",
    feature_cache_mode: str = "auto",
    feature_cache_dir: Path | str | None = None,
    feature_cache_prefix: str = "batch",
    feature_cache_per_atom_energy: bool = True,
    device: Optional[torch.device | str] = None,
    seed: int = 0,
    progress: bool = False,
    progress_description: str = "Caching training batches",
) -> DataLoader | CachedASEBatchLoader:
    """Return the loader used for training.

    With ``cache_batches`` (the default) every argument is forwarded to
    ``build_cached_ase_dataloader``. Otherwise only the plain loading options
    are forwarded to ``build_ase_dataloader``; ``model`` and the
    cache-related arguments are then unused.
    """
    loading_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "drop_last": drop_last,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "persistent_workers": persistent_workers,
        "prefetch_factor": prefetch_factor,
    }
    if not cache_batches:
        return build_ase_dataloader(dataset, **loading_options)

    caching_options = {
        "dtype": dtype,
        "device": device,
        "backend": backend,
        "cache_batches_on_device": cache_batches_on_device,
        "feature_cache_storage": feature_cache_storage,
        "feature_cache_mode": feature_cache_mode,
        "feature_cache_dir": feature_cache_dir,
        "feature_cache_prefix": feature_cache_prefix,
        "feature_cache_per_atom_energy": feature_cache_per_atom_energy,
        "seed": seed,
        "progress": progress,
        "progress_description": progress_description,
    }
    return build_cached_ase_dataloader(
        model,
        dataset,
        **loading_options,
        **caching_options,
    )
def build_ase_dataloaders(
    split: ASEDatasetSplit,
    *,
    batch_size: int = 1,
    num_workers: int = 0,
    pin_memory: bool = False,
    persistent_workers: bool = False,
    prefetch_factor: int | None = None,
    shuffle_train: bool = True,
    drop_last_train: bool = False,
) -> ASEDataLoaders:
    """Create one loader each for ``split.train``, ``split.validation`` and ``split.test``.

    All three share the batch size and worker settings. Only the training
    loader honours ``shuffle_train`` and ``drop_last_train``; the validation
    and test loaders never shuffle and keep their final partial batch.
    """
    shared_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "persistent_workers": persistent_workers,
        "prefetch_factor": prefetch_factor,
    }
    train_loader = build_ase_dataloader(
        split.train,
        shuffle=shuffle_train,
        drop_last=drop_last_train,
        **shared_options,
    )
    validation_loader = build_ase_dataloader(
        split.validation,
        shuffle=False,
        **shared_options,
    )
    test_loader = build_ase_dataloader(
        split.test,
        shuffle=False,
        **shared_options,
    )
    return ASEDataLoaders(
        train=train_loader,
        validation=validation_loader,
        test=test_loader,
    )
# Names exported by ``from <module> import *``; the underscore-prefixed
# helpers defined in this module are deliberately left out.
__all__ = [
    "ASEAtomsBatch",
    "CachedASEBatchLoader",
    "ase_atoms_collate_fn",
    "build_ase_training_loader",
    "build_cached_ase_dataloader",
    "build_ase_dataloader",
    "build_ase_dataloaders",
]