"""
Training and evaluation loops for UFP models.
This module runs epoch-level optimization over ASE-backed dataloaders while
reusing the same compute path as inference.
"""
from __future__ import annotations
import warnings
from typing import Iterable, Optional
import torch
from torch.utils.data import DataLoader, Subset
from ufp.core.potential import UFPotential
from ufp.neighbors._neighbors import NeighborListBackend
from ufp.terms._twobody_shape import TwoBodySplineShapePenalty
from ufp.terms._twobody_wall import TwoBodyCriticalWallPenalty
from ufp.training._losses import (
_compute_batch_loss,
_twobody_shape_penalty_loss,
_twobody_wall_penalty_loss,
)
from ufp.training._metrics import (
EpochMetrics,
LossWeights,
TrainingHistory,
_EpochAccumulator,
)
from ufp.training.batch import ASEAtomsBatch
from ufp.training.dataset import ASEAtomsDataset
# Anything the epoch loops can iterate to obtain ``ASEAtomsBatch`` objects:
# typically a ``DataLoader`` built with ``ase_atoms_collate_fn`` /
# ``build_ase_dataloader``, but any iterable of batches is accepted.
# NOTE: this is an assignment, not an annotation, so it is evaluated at import
# time despite ``from __future__ import annotations``; the runtime ``|`` between
# a class and a ``typing`` generic alias requires Python 3.10+.
BatchLoader = DataLoader | Iterable[ASEAtomsBatch]
def _unwrap_ase_dataset(loader: BatchLoader) -> Optional[ASEAtomsDataset]:
    """Find the ``ASEAtomsDataset`` behind ``loader``, looking through ``Subset`` wrappers.

    Returns ``None`` when the loader has no ``dataset`` attribute (e.g. a plain
    iterable of batches) or when the innermost dataset is some other type.
    """
    node = getattr(loader, "dataset", None)
    # Subsets may be nested (a split of a split), so peel every layer.
    while isinstance(node, Subset):
        node = node.dataset
    return node if isinstance(node, ASEAtomsDataset) else None
def _warm_loader_caches(
    model: UFPotential,
    loader: BatchLoader,
    *,
    neighbor_backend: Optional[str | NeighborListBackend],
) -> None:
    """Pre-populate the dataset's tensorized-sample and neighbor-list caches.

    Does nothing when ``loader`` is not backed by an ``ASEAtomsDataset``.
    Neighbor lists are only cached when the model defines a cutoff; the backend
    is the explicitly requested one, falling back to ``model.neighbor_backend``.
    """
    ase_dataset = _unwrap_ase_dataset(loader)
    if ase_dataset is None:
        return
    ase_dataset.cache_tensorized_samples()
    cutoff = model.cutoff
    if cutoff is None:
        # Without a cutoff there is no neighbor list to precompute.
        return
    if neighbor_backend is None:
        backend = model.neighbor_backend
    else:
        # Accepts either a string name or an existing enum member.
        backend = NeighborListBackend(neighbor_backend)
    ase_dataset.cache_neighbor_lists(cutoff=cutoff, backend=backend)
def _iter_with_progress(
loader: BatchLoader,
*,
enabled: bool,
description: str,
):
"""Wrap one loader in tqdm when requested and available."""
if not enabled:
return loader
try:
from tqdm.auto import tqdm
except ImportError:
return loader
total = len(loader) if hasattr(loader, "__len__") else None # type: ignore[arg-type]
return tqdm(
loader,
total=total,
desc=description,
leave=True,
)
def _optimization_step(
    model: UFPotential,
    batch: ASEAtomsBatch,
    *,
    optimizer: torch.optim.Optimizer,
    dtype: torch.dtype,
    device: Optional[torch.device],
    neighbor_backend: Optional[str | NeighborListBackend],
    loss_weights: Optional[LossWeights],
    twobody_shape_penalty: TwoBodySplineShapePenalty | None,
    twobody_wall_penalty: TwoBodyCriticalWallPenalty | None,
    max_grad_norm: Optional[float],
    derive_forces: bool,
    position_grad: bool,
):
    """Forward, backward and optimizer update for one batch; returns the loss diffs."""
    optimizer.zero_grad(set_to_none=True)
    model_input = batch.prepare_input(
        model,
        backend=neighbor_backend,
        device=device,
        dtype=dtype,
        requires_grad=position_grad,
    )
    prediction = model.compute_input(model_input, derive_forces=derive_forces)
    data_loss, residuals = _compute_batch_loss(
        prediction,
        batch,
        loss_weights=loss_weights,
        dtype=dtype,
        device=device,
    )
    # Regularizers are stacked on top of the data loss (shape first, then wall).
    objective = (
        data_loss
        + _twobody_shape_penalty_loss(
            model,
            twobody_shape_penalty,
            dtype=dtype,
            device=device,
        )
        + _twobody_wall_penalty_loss(
            model,
            twobody_wall_penalty,
            dtype=dtype,
            device=device,
        )
    )
    with warnings.catch_warnings():
        # Hide the outdated-NVIDIA-driver warning that may surface during
        # backward() on hosts with an old CUDA driver.
        warnings.filterwarnings(
            "ignore",
            message="CUDA initialization: The NVIDIA driver on your system "
            "is too old.*",
            category=UserWarning,
        )
        objective.backward()
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return residuals


def _evaluation_step(
    model: UFPotential,
    batch: ASEAtomsBatch,
    *,
    dtype: torch.dtype,
    device: Optional[torch.device],
    neighbor_backend: Optional[str | NeighborListBackend],
    loss_weights: Optional[LossWeights],
    derive_forces: bool,
    position_grad: bool,
):
    """Forward-only pass for one batch; returns the loss diffs without any update."""
    # Autograd is only switched on when forces are requested but the model does
    # not provide them itself (position gradients are then presumably needed);
    # otherwise the pass runs without building a graph.
    with torch.set_grad_enabled(position_grad):
        model_input = batch.prepare_input(
            model,
            backend=neighbor_backend,
            device=device,
            dtype=dtype,
            requires_grad=position_grad,
        )
        prediction = model.compute_input(model_input, derive_forces=derive_forces)
        _, residuals = _compute_batch_loss(
            prediction,
            batch,
            loss_weights=loss_weights,
            dtype=dtype,
            device=device,
        )
    return residuals


def _run_epoch(
    model: UFPotential,
    loader: BatchLoader,
    *,
    split: str,
    optimizer: Optional[torch.optim.Optimizer],
    dtype: torch.dtype,
    device: Optional[torch.device],
    neighbor_backend: Optional[str | NeighborListBackend],
    loss_weights: Optional[LossWeights],
    twobody_shape_penalty: TwoBodySplineShapePenalty | None,
    twobody_wall_penalty: TwoBodyCriticalWallPenalty | None,
    max_grad_norm: Optional[float],
    progress: bool,
    progress_description: str,
) -> EpochMetrics:
    """Make one pass over ``loader`` and return the aggregated metrics for ``split``.

    Passing an ``optimizer`` selects training mode (model in train mode,
    penalties added, gradients applied); ``optimizer=None`` selects evaluation
    mode. Every item yielded by ``loader`` must be an ``ASEAtomsBatch``,
    otherwise ``TypeError`` is raised.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()
    accumulator = _EpochAccumulator(
        split=split,
        requested_loss_weights=loss_weights,
    )
    _warm_loader_caches(
        model,
        loader,
        neighbor_backend=neighbor_backend,
    )
    batches = _iter_with_progress(
        loader,
        enabled=progress,
        description=progress_description,
    )
    for batch in batches:
        if not isinstance(batch, ASEAtomsBatch):
            raise TypeError(
                "training dataloaders must use `ase_atoms_collate_fn` or "
                "`build_ase_dataloader`"
            )
        derive_forces = batch.forces is not None
        # Short-circuits so ``provides_forces`` is only queried when forces
        # are actually needed for this batch.
        position_grad = derive_forces and not model.provides_forces()
        if training:
            residuals = _optimization_step(
                model,
                batch,
                optimizer=optimizer,
                dtype=dtype,
                device=device,
                neighbor_backend=neighbor_backend,
                loss_weights=loss_weights,
                twobody_shape_penalty=twobody_shape_penalty,
                twobody_wall_penalty=twobody_wall_penalty,
                max_grad_norm=max_grad_norm,
                derive_forces=derive_forces,
                position_grad=position_grad,
            )
        else:
            residuals = _evaluation_step(
                model,
                batch,
                dtype=dtype,
                device=device,
                neighbor_backend=neighbor_backend,
                loss_weights=loss_weights,
                derive_forces=derive_forces,
                position_grad=position_grad,
            )
        accumulator.add_batch(
            batch,
            energy_diff=residuals.get("energy"),
            force_diff=residuals.get("forces"),
            stress_diff=residuals.get("stress"),
        )
    return accumulator.finalize()
def train_one_epoch(
    model: UFPotential,
    loader: BatchLoader,
    *,
    optimizer: torch.optim.Optimizer,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device | str] = None,
    neighbor_backend: Optional[str | NeighborListBackend] = None,
    loss_weights: Optional[LossWeights] = None,
    twobody_shape_penalty: TwoBodySplineShapePenalty | None = None,
    twobody_wall_penalty: TwoBodyCriticalWallPenalty | None = None,
    max_grad_norm: Optional[float] = None,
    progress: bool = False,
    progress_description: str = "train",
) -> EpochMetrics:
    """Run one optimization epoch over a training dataloader.

    Args:
        model: Potential to optimize. Moved to ``device`` when one is given.
        loader: Loader yielding ``ASEAtomsBatch`` objects.
        optimizer: Optimizer stepped once per batch.
        dtype: Compute dtype. Defaults to ``model.preferred_dtype()``.
        device: Target device, or ``None`` to leave the model where it is.
        neighbor_backend: Optional neighbor-list backend override.
        loss_weights: Optional per-quantity loss weights.
        twobody_shape_penalty: Optional regularizer added to each batch loss.
        twobody_wall_penalty: Optional regularizer added to each batch loss.
        max_grad_norm: Optional gradient-norm clipping threshold.
        progress: Show a tqdm progress bar when tqdm is installed.
        progress_description: Label for the progress bar.

    Returns:
        Aggregated ``EpochMetrics`` for the ``"train"`` split.

    Raises:
        ValueError: If ``max_grad_norm`` is given but is not a positive number
            (NaN included, which would otherwise poison every gradient).
    """
    # Validate before any side effect such as moving the model. ``not x > 0``
    # (rather than ``x <= 0``) also rejects NaN.
    if max_grad_norm is not None and not max_grad_norm > 0.0:
        raise ValueError("`max_grad_norm` must be positive")
    resolved_device = None if device is None else torch.device(device)
    if resolved_device is not None:
        model.to(resolved_device)
    resolved_dtype = model.preferred_dtype() if dtype is None else dtype
    return _run_epoch(
        model,
        loader,
        split="train",
        optimizer=optimizer,
        dtype=resolved_dtype,
        device=resolved_device,
        neighbor_backend=neighbor_backend,
        loss_weights=loss_weights,
        twobody_shape_penalty=twobody_shape_penalty,
        twobody_wall_penalty=twobody_wall_penalty,
        max_grad_norm=max_grad_norm,
        progress=progress,
        progress_description=progress_description,
    )
def evaluate_model(
    model: UFPotential,
    loader: BatchLoader,
    *,
    split: str = "validation",
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device | str] = None,
    neighbor_backend: Optional[str | NeighborListBackend] = None,
    loss_weights: Optional[LossWeights] = None,
    progress: bool = False,
    progress_description: str | None = None,
) -> EpochMetrics:
    """Run one evaluation pass without optimizer updates.

    The model is put in eval mode, and no two-body penalties or gradient
    clipping are applied.

    Args:
        model: Potential to evaluate. Moved to ``device`` when one is given.
        loader: Loader yielding ``ASEAtomsBatch`` objects.
        split: Name recorded in the returned metrics.
        dtype: Compute dtype. Defaults to ``model.preferred_dtype()``.
        device: Target device, or ``None`` to leave the model where it is.
        neighbor_backend: Optional neighbor-list backend override.
        loss_weights: Optional per-quantity loss weights.
        progress: Show a tqdm progress bar when tqdm is installed.
        progress_description: Progress-bar label. Defaults to ``split``.

    Returns:
        Aggregated ``EpochMetrics`` for ``split``.
    """
    resolved_device = None if device is None else torch.device(device)
    if resolved_device is not None:
        model.to(resolved_device)
    resolved_dtype = model.preferred_dtype() if dtype is None else dtype
    return _run_epoch(
        model,
        loader,
        split=split,
        optimizer=None,
        dtype=resolved_dtype,
        device=resolved_device,
        neighbor_backend=neighbor_backend,
        loss_weights=loss_weights,
        twobody_shape_penalty=None,
        twobody_wall_penalty=None,
        max_grad_norm=None,
        progress=progress,
        progress_description=(
            split if progress_description is None else progress_description
        ),
    )
def test_model(
    model: UFPotential,
    loader: BatchLoader,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device | str] = None,
    neighbor_backend: Optional[str | NeighborListBackend] = None,
    loss_weights: Optional[LossWeights] = None,
    progress: bool = False,
    progress_description: str = "test",
) -> EpochMetrics:
    """Evaluate the model on a test dataloader.

    Thin wrapper around ``evaluate_model`` with ``split="test"``. All other
    arguments are forwarded unchanged.
    """
    return evaluate_model(
        model,
        loader,
        split="test",
        dtype=dtype,
        device=device,
        neighbor_backend=neighbor_backend,
        loss_weights=loss_weights,
        progress=progress,
        progress_description=progress_description,
    )


# The ``test_`` prefix makes pytest collect this function whenever it is
# imported into a test module (e.g. via a star import). Pytest would then fail
# looking for ``model``/``loader`` fixtures, so opt out of collection.
test_model.__test__ = False
def fit_model(
    model: UFPotential,
    train_loader: BatchLoader,
    *,
    optimizer: torch.optim.Optimizer,
    epochs: int,
    validation_loader: Optional[BatchLoader] = None,
    test_loader: Optional[BatchLoader] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device | str] = None,
    neighbor_backend: Optional[str | NeighborListBackend] = None,
    loss_weights: Optional[LossWeights] = None,
    twobody_shape_penalty: TwoBodySplineShapePenalty | None = None,
    twobody_wall_penalty: TwoBodyCriticalWallPenalty | None = None,
    max_grad_norm: Optional[float] = None,
    validation_frequency: Optional[int] = 1,
    evaluate_initial: bool = False,
    evaluate_final_train: bool = False,
    compile_model: bool = False,
    progress: bool = False,
) -> TrainingHistory:
    """Repeat training and evaluation epochs, then package the full history.

    Args:
        model: Potential to optimize.
        train_loader: Loader yielding ``ASEAtomsBatch`` objects for training.
        optimizer: Optimizer stepping the model's parameters.
        epochs: Number of training epochs (converted with ``int``).
        validation_loader: Optional loader evaluated every
            ``validation_frequency`` epochs and always after the last epoch.
        test_loader: Optional loader evaluated once after training.
        dtype, device, neighbor_backend, loss_weights: Forwarded to every
            training and evaluation pass.
        twobody_shape_penalty, twobody_wall_penalty: Optional regularizers,
            applied to the training loss only.
        max_grad_norm: Optional gradient-norm clipping threshold.
        validation_frequency: Validate every N epochs. ``None`` disables
            validation.
        evaluate_initial: Evaluate train (and validation) before training.
        evaluate_final_train: Re-evaluate the training set after training.
        compile_model: Wrap the model with ``torch.compile`` before use.
        progress: Show tqdm progress bars when tqdm is installed.

    Returns:
        A ``TrainingHistory`` with per-epoch train metrics, one validation
        entry per validation pass, optional test metrics, and the optional
        initial/final evaluations.

    Raises:
        ValueError: If ``epochs`` is not positive, ``validation_frequency`` is
            not positive, or ``max_grad_norm`` is given and not positive.
    """
    # Check the truncated value actually used by the loop, so that e.g.
    # ``epochs=0.5`` is rejected instead of silently training zero epochs.
    num_epochs = int(epochs)
    if num_epochs <= 0:
        raise ValueError("`epochs` must be a positive integer")
    if validation_frequency is not None and validation_frequency <= 0:
        raise ValueError("`validation_frequency` must be positive")
    # ``train_one_epoch`` rejects this too, but only after any (possibly
    # expensive) initial evaluation has run. Fail fast instead.
    if max_grad_norm is not None and not max_grad_norm > 0.0:
        raise ValueError("`max_grad_norm` must be positive")
    training_model: UFPotential
    if compile_model:
        # NOTE(review): torch.compile only compiles ``forward``. The epoch loop
        # calls ``compute_input``, which the compiled wrapper appears to
        # delegate to the uncompiled module via attribute lookup. Confirm that
        # compilation actually takes effect on that path.
        training_model = torch.compile(model)  # type: ignore[assignment]
    else:
        training_model = model
    initial_train = None
    initial_validation = None
    if evaluate_initial:
        initial_train = evaluate_model(
            training_model,
            train_loader,
            split="train",
            dtype=dtype,
            device=device,
            neighbor_backend=neighbor_backend,
            loss_weights=loss_weights,
            progress=progress,
            progress_description="train initial",
        )
        if validation_loader is not None:
            initial_validation = evaluate_model(
                training_model,
                validation_loader,
                split="validation",
                dtype=dtype,
                device=device,
                neighbor_backend=neighbor_backend,
                loss_weights=loss_weights,
                progress=progress,
                progress_description="validation initial",
            )
    train_history: list[EpochMetrics] = []
    validation_history: list[EpochMetrics] = []
    for epoch in range(1, num_epochs + 1):
        train_history.append(
            train_one_epoch(
                training_model,
                train_loader,
                optimizer=optimizer,
                dtype=dtype,
                device=device,
                neighbor_backend=neighbor_backend,
                loss_weights=loss_weights,
                twobody_shape_penalty=twobody_shape_penalty,
                twobody_wall_penalty=twobody_wall_penalty,
                max_grad_norm=max_grad_norm,
                progress=progress,
                progress_description=f"train {epoch}/{num_epochs}",
            )
        )
        # The last epoch is always validated, so the history ends with metrics
        # for the fully trained model.
        should_validate = (
            validation_loader is not None
            and validation_frequency is not None
            and (epoch == num_epochs or epoch % validation_frequency == 0)
        )
        if should_validate:
            validation_history.append(
                evaluate_model(
                    training_model,
                    validation_loader,  # type: ignore[arg-type]
                    split="validation",
                    dtype=dtype,
                    device=device,
                    neighbor_backend=neighbor_backend,
                    loss_weights=loss_weights,
                    progress=progress,
                    progress_description=f"validation {epoch}/{num_epochs}",
                )
            )
    final_train = None
    if evaluate_final_train:
        final_train = evaluate_model(
            training_model,
            train_loader,
            split="train",
            dtype=dtype,
            device=device,
            neighbor_backend=neighbor_backend,
            loss_weights=loss_weights,
            progress=progress,
            progress_description="train final",
        )
    test_metrics = None
    if test_loader is not None:
        test_metrics = test_model(
            training_model,
            test_loader,
            dtype=dtype,
            device=device,
            neighbor_backend=neighbor_backend,
            loss_weights=loss_weights,
            progress=progress,
            progress_description="test",
        )
    return TrainingHistory(
        train=tuple(train_history),
        validation=tuple(validation_history),
        test=test_metrics,
        initial_train=initial_train,
        initial_validation=initial_validation,
        final_train=final_train,
    )
# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = [
"evaluate_model",
"fit_model",
"test_model",
"train_one_epoch",
]