# Source code for ufp.training.engine

"""
Training and evaluation loops for UFP models.

This module runs epoch-level optimization over ASE-backed dataloaders while
reusing the same compute path as inference.
"""

from __future__ import annotations

import warnings
from typing import Iterable, Optional

import torch
from torch.utils.data import DataLoader, Subset

from ufp.core.potential import UFPotential
from ufp.neighbors._neighbors import NeighborListBackend
from ufp.terms._twobody_shape import TwoBodySplineShapePenalty
from ufp.terms._twobody_wall import TwoBodyCriticalWallPenalty
from ufp.training._losses import (
    _compute_batch_loss,
    _twobody_shape_penalty_loss,
    _twobody_wall_penalty_loss,
)
from ufp.training._metrics import (
    EpochMetrics,
    LossWeights,
    TrainingHistory,
    _EpochAccumulator,
)
from ufp.training.batch import ASEAtomsBatch
from ufp.training.dataset import ASEAtomsDataset


# Type alias for anything the epoch drivers can iterate over: a torch
# ``DataLoader`` or any other iterable of ``ASEAtomsBatch`` objects.
BatchLoader = DataLoader | Iterable[ASEAtomsBatch]


def _unwrap_ase_dataset(loader: BatchLoader) -> Optional[ASEAtomsDataset]:
    """Find the ``ASEAtomsDataset`` behind a loader, looking through ``Subset`` layers.

    Returns ``None`` when the loader exposes no ``dataset`` attribute or when
    the innermost dataset is not an ``ASEAtomsDataset``.
    """
    candidate = getattr(loader, "dataset", None)
    # Peel nested ``Subset`` wrappers until the base dataset is reached.
    while isinstance(candidate, Subset):
        candidate = candidate.dataset
    return candidate if isinstance(candidate, ASEAtomsDataset) else None


def _warm_loader_caches(
    model: UFPotential,
    loader: BatchLoader,
    *,
    neighbor_backend: Optional[str | NeighborListBackend],
) -> None:
    """Populate dataset-side caches before the loader is iterated.

    Does nothing when no ``ASEAtomsDataset`` can be found behind ``loader``.
    Otherwise the tensorized-sample cache is always filled, and the
    neighbor-list cache is filled only when the model defines a cutoff.
    """
    ase_dataset = _unwrap_ase_dataset(loader)
    if ase_dataset is None:
        return

    ase_dataset.cache_tensorized_samples()

    # Without a cutoff there are no neighbor lists to precompute.
    if model.cutoff is None:
        return

    # An explicit backend overrides the model's own default.
    if neighbor_backend is None:
        backend = model.neighbor_backend
    else:
        backend = NeighborListBackend(neighbor_backend)

    ase_dataset.cache_neighbor_lists(cutoff=model.cutoff, backend=backend)


def _iter_with_progress(
    loader: BatchLoader,
    *,
    enabled: bool,
    description: str,
):
    """Wrap one loader in tqdm when requested and available."""
    if not enabled:
        return loader

    try:
        from tqdm.auto import tqdm
    except ImportError:
        return loader

    total = len(loader) if hasattr(loader, "__len__") else None  # type: ignore[arg-type]
    return tqdm(
        loader,
        total=total,
        desc=description,
        leave=True,
    )


def _run_epoch(
    model: UFPotential,
    loader: BatchLoader,
    *,
    split: str,
    optimizer: Optional[torch.optim.Optimizer],
    dtype: torch.dtype,
    device: Optional[torch.device],
    neighbor_backend: Optional[str | NeighborListBackend],
    loss_weights: Optional[LossWeights],
    twobody_shape_penalty: TwoBodySplineShapePenalty | None,
    twobody_wall_penalty: TwoBodyCriticalWallPenalty | None,
    max_grad_norm: Optional[float],
    progress: bool,
    progress_description: str,
) -> EpochMetrics:
    """Make a single pass over ``loader`` and aggregate per-batch errors.

    Supplying an optimizer selects training: the model is put in train mode
    and every batch triggers a backward pass, optional gradient clipping and
    an optimizer step. With ``optimizer=None`` the model is put in eval mode
    and only the forward pass runs; the two-body penalties and
    ``max_grad_norm`` are ignored in that case.

    Raises:
        TypeError: If the loader yields anything other than ``ASEAtomsBatch``.
    """
    if optimizer is None:
        model.eval()
    else:
        model.train()

    metrics = _EpochAccumulator(
        split=split,
        requested_loss_weights=loss_weights,
    )
    _warm_loader_caches(model, loader, neighbor_backend=neighbor_backend)

    def _forward_pass(batch, *, needs_position_grad, with_forces):
        """Run the forward pass shared by both modes; returns the loss result."""
        prepared = batch.prepare_input(
            model,
            backend=neighbor_backend,
            device=device,
            dtype=dtype,
            requires_grad=needs_position_grad,
        )
        prediction = model.compute_input(prepared, derive_forces=with_forces)
        return _compute_batch_loss(
            prediction,
            batch,
            loss_weights=loss_weights,
            dtype=dtype,
            device=device,
        )

    batches = _iter_with_progress(
        loader,
        enabled=progress,
        description=progress_description,
    )
    for batch in batches:
        if not isinstance(batch, ASEAtomsBatch):
            raise TypeError(
                "training dataloaders must use `ase_atoms_collate_fn` or "
                "`build_ase_dataloader`"
            )

        # Forces are only derived when the batch carries reference forces, and
        # position gradients are only requested when the model does not
        # report providing forces itself.
        with_forces = batch.forces is not None
        needs_position_grad = with_forces and not model.provides_forces()

        if optimizer is None:
            # Evaluation: autograd stays on only if positions need gradients.
            with torch.set_grad_enabled(needs_position_grad):
                _, diffs = _forward_pass(
                    batch,
                    needs_position_grad=needs_position_grad,
                    with_forces=with_forces,
                )
        else:
            optimizer.zero_grad(set_to_none=True)
            loss, diffs = _forward_pass(
                batch,
                needs_position_grad=needs_position_grad,
                with_forces=with_forces,
            )
            # Penalty terms are added on top of the data loss: shape, then wall.
            loss = loss + _twobody_shape_penalty_loss(
                model,
                twobody_shape_penalty,
                dtype=dtype,
                device=device,
            )
            loss = loss + _twobody_wall_penalty_loss(
                model,
                twobody_wall_penalty,
                dtype=dtype,
                device=device,
            )
            # Silence the "driver too old" UserWarning for the backward call only.
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message="CUDA initialization: The NVIDIA driver on your system "
                    "is too old.*",
                    category=UserWarning,
                )
                loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

        metrics.add_batch(
            batch,
            energy_diff=diffs.get("energy"),
            force_diff=diffs.get("forces"),
            stress_diff=diffs.get("stress"),
        )

    return metrics.finalize()


def train_one_epoch(
    model: UFPotential,
    loader: BatchLoader,
    *,
    optimizer: torch.optim.Optimizer,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device | str] = None,
    neighbor_backend: Optional[str | NeighborListBackend] = None,
    loss_weights: Optional[LossWeights] = None,
    twobody_shape_penalty: TwoBodySplineShapePenalty | None = None,
    twobody_wall_penalty: TwoBodyCriticalWallPenalty | None = None,
    max_grad_norm: Optional[float] = None,
    progress: bool = False,
    progress_description: str = "train",
) -> EpochMetrics:
    """Run one optimization epoch over a training dataloader.

    Args:
        model: Potential to optimize; moved to ``device`` when one is given.
        loader: Source of ``ASEAtomsBatch`` objects.
        optimizer: Optimizer stepped once per batch.
        dtype: Working dtype; defaults to ``model.preferred_dtype()``.
        device: Target device, as a ``torch.device`` or a string.
        neighbor_backend: Neighbor-list backend forwarded to the epoch loop.
        loss_weights: Optional loss weighting forwarded to the epoch loop.
        twobody_shape_penalty: Optional two-body shape penalty added to the loss.
        twobody_wall_penalty: Optional two-body wall penalty added to the loss.
        max_grad_norm: When given, gradients are clipped to this norm.
        progress: Whether to show a progress bar.
        progress_description: Label for the progress bar.

    Returns:
        Aggregated metrics for the ``"train"`` split.

    Raises:
        ValueError: If ``max_grad_norm`` is given but is not positive.
    """
    # Validate first so a bad argument fails before the model is moved.
    if max_grad_norm is not None and max_grad_norm <= 0.0:
        raise ValueError("`max_grad_norm` must be positive")

    resolved_device = None if device is None else torch.device(device)
    if resolved_device is not None:
        model.to(resolved_device)
    resolved_dtype = model.preferred_dtype() if dtype is None else dtype

    return _run_epoch(
        model,
        loader,
        split="train",
        optimizer=optimizer,
        dtype=resolved_dtype,
        device=resolved_device,
        neighbor_backend=neighbor_backend,
        loss_weights=loss_weights,
        twobody_shape_penalty=twobody_shape_penalty,
        twobody_wall_penalty=twobody_wall_penalty,
        max_grad_norm=max_grad_norm,
        progress=progress,
        progress_description=progress_description,
    )
def evaluate_model(
    model: UFPotential,
    loader: BatchLoader,
    *,
    split: str = "validation",
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device | str] = None,
    neighbor_backend: Optional[str | NeighborListBackend] = None,
    loss_weights: Optional[LossWeights] = None,
    progress: bool = False,
    progress_description: str | None = None,
) -> EpochMetrics:
    """Run one evaluation pass without optimizer updates.

    Args:
        model: Potential to evaluate; moved to ``device`` when one is given.
        loader: Source of ``ASEAtomsBatch`` objects.
        split: Name recorded for this pass in the returned metrics.
        dtype: Working dtype; defaults to ``model.preferred_dtype()``.
        device: Target device, as a ``torch.device`` or a string.
        neighbor_backend: Neighbor-list backend forwarded to the epoch loop.
        loss_weights: Optional loss weighting forwarded to the epoch loop.
        progress: Whether to show a progress bar.
        progress_description: Label for the progress bar; falls back to
            ``split`` when omitted.

    Returns:
        Aggregated metrics for the requested split.
    """
    resolved_device = None if device is None else torch.device(device)
    if resolved_device is not None:
        model.to(resolved_device)
    resolved_dtype = model.preferred_dtype() if dtype is None else dtype
    description = split if progress_description is None else progress_description

    # No optimizer, penalties or clipping: the epoch loop runs forward-only.
    return _run_epoch(
        model,
        loader,
        split=split,
        optimizer=None,
        dtype=resolved_dtype,
        device=resolved_device,
        neighbor_backend=neighbor_backend,
        loss_weights=loss_weights,
        twobody_shape_penalty=None,
        twobody_wall_penalty=None,
        max_grad_norm=None,
        progress=progress,
        progress_description=description,
    )
def test_model(
    model: UFPotential,
    loader: BatchLoader,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device | str] = None,
    neighbor_backend: Optional[str | NeighborListBackend] = None,
    loss_weights: Optional[LossWeights] = None,
    progress: bool = False,
    progress_description: str = "test",
) -> EpochMetrics:
    """Evaluate the model on a test dataloader.

    Thin wrapper around ``evaluate_model`` that fixes ``split="test"``; all
    other arguments are forwarded unchanged.

    Returns:
        Aggregated metrics for the ``"test"`` split.
    """
    return evaluate_model(
        model,
        loader,
        split="test",
        dtype=dtype,
        device=device,
        neighbor_backend=neighbor_backend,
        loss_weights=loss_weights,
        progress=progress,
        progress_description=progress_description,
    )
def fit_model(
    model: UFPotential,
    train_loader: BatchLoader,
    *,
    optimizer: torch.optim.Optimizer,
    epochs: int,
    validation_loader: Optional[BatchLoader] = None,
    test_loader: Optional[BatchLoader] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device | str] = None,
    neighbor_backend: Optional[str | NeighborListBackend] = None,
    loss_weights: Optional[LossWeights] = None,
    twobody_shape_penalty: TwoBodySplineShapePenalty | None = None,
    twobody_wall_penalty: TwoBodyCriticalWallPenalty | None = None,
    max_grad_norm: Optional[float] = None,
    validation_frequency: Optional[int] = 1,
    evaluate_initial: bool = False,
    evaluate_final_train: bool = False,
    compile_model: bool = False,
    progress: bool = False,
) -> TrainingHistory:
    """Repeat training and evaluation epochs, then package the full history.

    Args:
        model: Potential to train.
        train_loader: Batches used for the optimization epochs.
        optimizer: Optimizer stepped during training.
        epochs: Number of training epochs; must be positive.
        validation_loader: Optional batches evaluated during training.
        test_loader: Optional batches evaluated once after training.
        dtype: Working dtype forwarded to every pass.
        device: Target device forwarded to every pass.
        neighbor_backend: Neighbor-list backend forwarded to every pass.
        loss_weights: Optional loss weighting forwarded to every pass.
        twobody_shape_penalty: Optional penalty applied in training epochs only.
        twobody_wall_penalty: Optional penalty applied in training epochs only.
        max_grad_norm: Optional gradient-clipping norm for training epochs.
        validation_frequency: Validate every this many epochs, and always
            after the last one. ``None`` disables validation during training.
        evaluate_initial: Evaluate the train loader (and the validation
            loader, if given) before the first training epoch.
        evaluate_final_train: Re-evaluate the train loader after training.
        compile_model: Wrap the model with ``torch.compile`` before use.
        progress: Show progress bars for every pass, including the initial
            and final evaluations.

    Returns:
        A ``TrainingHistory`` holding the per-epoch train and validation
        metrics plus the optional initial, final-train and test metrics.

    Raises:
        ValueError: If ``epochs`` or ``validation_frequency`` is not positive.
    """
    if epochs <= 0:
        raise ValueError("`epochs` must be a positive integer")
    if validation_frequency is not None and validation_frequency <= 0:
        raise ValueError("`validation_frequency` must be positive")

    training_model: UFPotential
    if compile_model:
        # NOTE(review): the epoch loop calls ``compute_input`` rather than the
        # module's ``__call__``; confirm that path actually runs compiled code.
        training_model = torch.compile(model)  # type: ignore[assignment]
    else:
        training_model = model

    initial_train = None
    initial_validation = None
    if evaluate_initial:
        # Honour ``progress`` here too, matching every other pass.
        initial_train = evaluate_model(
            training_model,
            train_loader,
            split="train",
            dtype=dtype,
            device=device,
            neighbor_backend=neighbor_backend,
            loss_weights=loss_weights,
            progress=progress,
            progress_description="train initial",
        )
        if validation_loader is not None:
            initial_validation = evaluate_model(
                training_model,
                validation_loader,
                split="validation",
                dtype=dtype,
                device=device,
                neighbor_backend=neighbor_backend,
                loss_weights=loss_weights,
                progress=progress,
                progress_description="validation initial",
            )

    train_history: list[EpochMetrics] = []
    validation_history: list[EpochMetrics] = []
    last_epoch = int(epochs)
    for epoch in range(1, last_epoch + 1):
        train_history.append(
            train_one_epoch(
                training_model,
                train_loader,
                optimizer=optimizer,
                dtype=dtype,
                device=device,
                neighbor_backend=neighbor_backend,
                loss_weights=loss_weights,
                twobody_shape_penalty=twobody_shape_penalty,
                twobody_wall_penalty=twobody_wall_penalty,
                max_grad_norm=max_grad_norm,
                progress=progress,
                progress_description=f"train {epoch}/{epochs}",
            )
        )

        # Validate on schedule, and always after the final epoch so the
        # history ends with an up-to-date validation entry.
        if (
            validation_loader is not None
            and validation_frequency is not None
            and (epoch == last_epoch or epoch % validation_frequency == 0)
        ):
            validation_history.append(
                evaluate_model(
                    training_model,
                    validation_loader,
                    split="validation",
                    dtype=dtype,
                    device=device,
                    neighbor_backend=neighbor_backend,
                    loss_weights=loss_weights,
                    progress=progress,
                    progress_description=f"validation {epoch}/{epochs}",
                )
            )

    final_train = None
    if evaluate_final_train:
        final_train = evaluate_model(
            training_model,
            train_loader,
            split="train",
            dtype=dtype,
            device=device,
            neighbor_backend=neighbor_backend,
            loss_weights=loss_weights,
            progress=progress,
            progress_description="train final",
        )

    test_metrics = None
    if test_loader is not None:
        test_metrics = test_model(
            training_model,
            test_loader,
            dtype=dtype,
            device=device,
            neighbor_backend=neighbor_backend,
            loss_weights=loss_weights,
            progress=progress,
            progress_description="test",
        )

    return TrainingHistory(
        train=tuple(train_history),
        validation=tuple(validation_history),
        test=test_metrics,
        initial_train=initial_train,
        initial_validation=initial_validation,
        final_train=final_train,
    )
# Public API of this module; the underscore-prefixed helpers stay private.
__all__ = [ "evaluate_model", "fit_model", "test_model", "train_one_epoch", ]