Source code for ufp.workflows.training

"""Reusable workflow helpers for examples and small supervised UFP studies."""

from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import torch

from ufp.terms import (
    TwoBodyCriticalWallPenalty,
    TwoBodySplineShapePenalty,
    UFPModel,
)
from ufp.training import (
    ASEAtomsDataset,
    EpochMetrics,
    LossWeights,
    TrainingHistory,
    build_ase_training_loader,
    evaluate_model,
    train_one_epoch,
)
from ufp.workflows._training import (
    CompileModelPolicy,
    _copy_state_dict_to_cpu,
    _epoch_metrics_dict,
    _optimizer_parameter_groups,
    _resolve_compile_model,
    _save_training_state_checkpoint,
)
from ufp.workflows.onebody import initialize_onebody_terms_from_dataset


# Accepted values for a ``progress`` option: a plain bool, or the literal
# string "summary". ``_resolve_training_progress`` below translates each value
# into the pair (show batch progress bars, show epoch summaries).
TrainingProgress = bool | Literal["summary"]


def _resolve_training_progress(progress: TrainingProgress) -> tuple[bool, bool]:
    """Return whether to show batch progress bars and epoch summaries."""
    if isinstance(progress, bool):
        return progress, progress
    if progress == "summary":
        return False, True
    raise ValueError("`progress` must be True, False, or 'summary'")


def _format_optional_metric(value: float | None, unit: str = "") -> str:
    """Render a metric in scientific notation, or ``"n/a"`` when it is missing.

    ``value`` may be ``None`` (for example when a dataset carries only
    energies or only forces); otherwise it is printed with six decimal digits
    followed directly by ``unit``.
    """
    return "n/a" if value is None else "{:.6e}{}".format(value, unit)


def _format_train_validation_summary(
    *,
    epoch: int,
    epochs: int,
    train_metrics: EpochMetrics,
    validation_metrics: EpochMetrics,
) -> str:
    """Build the single-line report printed after a validation pass.

    The line starts with the epoch counter and then lists loss, energy MAE and
    forces MAE for the training metrics followed by the validation metrics.
    Missing MAE values are rendered by ``_format_optional_metric``.
    """
    train_energy = _format_optional_metric(train_metrics.energy_mae, " eV")
    train_forces = _format_optional_metric(train_metrics.forces_mae, " eV/A")
    validation_energy = _format_optional_metric(
        validation_metrics.energy_mae, " eV"
    )
    validation_forces = _format_optional_metric(
        validation_metrics.forces_mae, " eV/A"
    )

    header = f"Validation epoch {epoch:4d}/{epochs}: "
    fields = (
        f"train_loss={train_metrics.loss:.6e}",
        f"train_energy_mae={train_energy}",
        f"train_forces_mae={train_forces}",
        f"validation_loss={validation_metrics.loss:.6e}",
        f"validation_energy_mae={validation_energy}",
        f"validation_forces_mae={validation_forces}",
    )
    return header + ", ".join(fields)


def _cuda_is_available() -> bool:
    """Report CUDA availability, treating a raised warning as "not available".

    This keeps driver warnings from aborting runs that can proceed on the CPU.
    """
    try:
        available = torch.cuda.is_available()
    except Warning:
        # A ``Warning`` only surfaces as an exception when warnings have been
        # escalated to errors; fall back to reporting no CUDA in that case.
        available = False
    return available


@dataclass(frozen=True)
class TorchTrainingResult:
    """Summary of one torch-native supervised training run.

    Instances are immutable records produced by ``train_interaction_model``.
    They bundle the per-epoch metric history, the metric snapshots taken
    before and after training, and the *resolved* settings the run actually
    used (after defaults were filled in), so a run can be reported or logged
    without re-deriving any of them.
    """

    # --- Metrics -----------------------------------------------------------
    # Per-epoch training metrics and per-check validation metrics.
    history: TrainingHistory
    # Metrics measured before the first epoch; None when not evaluated.
    initial_train_metrics: EpochMetrics | None
    initial_validation_metrics: EpochMetrics | None
    # Metrics reported for the end of the run; validation is None when no
    # validation check was performed.
    final_train_metrics: EpochMetrics
    final_validation_metrics: EpochMetrics | None
    loss_weights: LossWeights

    # --- Core run settings --------------------------------------------------
    # Device and dtype are stored as strings (dtype without the "torch." prefix).
    device: str
    dtype: str
    batch_size: int
    # None when the run had no validation dataset.
    validation_batch_size: int | None
    # Requested number of epochs versus the number that actually finished.
    epochs: int
    completed_epochs: int
    learning_rate: float

    # --- Weight decay (resolved values) -------------------------------------
    weight_decay: float
    onebody_weight_decay: float
    pair_weight_decay: float
    threebody_weight_decay: float
    embedding_weight_decay: float
    state_weight_decay: float
    charge_spin_weight_decay: float
    # Group name -> decay mapping handed to the optimizer parameter groups.
    term_weight_decays: dict[str, float]
    max_grad_norm: float | None

    # --- Validation, early stopping and checkpoint outcome -------------------
    validation_frequency: int | None
    early_stopping_patience: int | None
    early_stopping_min_delta: float
    restore_best: bool
    # True when the early-stopping criterion ended the loop.
    stopped_early: bool
    # True when a KeyboardInterrupt ended the loop.
    interrupted: bool
    # Checkpoint path associated with a state restored after an interrupt.
    restored_checkpoint_path: str | None
    best_validation_epoch: int | None
    best_validation_metrics: EpochMetrics | None
    best_checkpoint_path: str | None

    # --- Data loading and feature caching ------------------------------------
    num_workers: int
    pin_memory: bool
    cache_batches: bool
    cache_batches_on_device: bool
    feature_cache_storage: str
    feature_cache_mode: str
    validation_feature_cache_storage: str
    validation_feature_cache_mode: str
    feature_cache_dir: str | None
    validation_feature_cache_dir: str | None
    feature_cache_per_atom_energy: bool

    # --- Miscellaneous -------------------------------------------------------
    # Whether the model was actually compiled (the resolved policy).
    compile_model: bool
    evaluate_initial: bool
    evaluate_final_train: bool
    # One-body energies set during initialization; empty when it was skipped.
    initial_onebody_energies: dict[int, float]
    # Count of model parameter elements with ``requires_grad`` set.
    n_parameters: int
def train_interaction_model(
    model: UFPModel,
    train_dataset: ASEAtomsDataset,
    *,
    batch_size: int,
    epochs: int,
    learning_rate: float,
    force_loss_weight: float = 5.0,
    weight_decay: float = 0.0,
    onebody_weight_decay: float | None = None,
    pair_weight_decay: float | None = None,
    threebody_weight_decay: float | None = None,
    embedding_weight_decay: float | None = None,
    state_weight_decay: float | None = None,
    charge_spin_weight_decay: float | None = None,
    term_weight_decays: Mapping[str, float] | None = None,
    twobody_shape_penalty: TwoBodySplineShapePenalty | None = None,
    twobody_wall_penalty: TwoBodyCriticalWallPenalty | None = None,
    max_grad_norm: float | None = None,
    validation_dataset: ASEAtomsDataset | None = None,
    validation_batch_size: int | None = None,
    validation_frequency: int | None = 10,
    num_workers: int = 0,
    pin_memory: bool | None = None,
    persistent_workers: bool = True,
    prefetch_factor: int | None = 2,
    cache_batches: bool = True,
    cache_batches_on_device: bool = False,
    feature_cache_storage: str = "cpu",
    feature_cache_mode: str = "auto",
    feature_cache_dir: Path | str | None = None,
    validation_feature_cache_storage: str = "none",
    validation_feature_cache_mode: str | None = None,
    validation_feature_cache_dir: Path | str | None = None,
    feature_cache_per_atom_energy: bool = False,
    compile_model: CompileModelPolicy = "auto",
    evaluate_initial: bool = False,
    evaluate_final_train: bool = False,
    initialize_onebody: bool = True,
    onebody_rcond: float | None = None,
    early_stopping_patience: int | None = None,
    early_stopping_min_delta: float = 0.0,
    restore_best: bool = True,
    best_checkpoint_path: Path | str | None = None,
    best_checkpoint_metadata: dict[str, object] | None = None,
    device: torch.device | str | None = None,
    dtype: torch.dtype = torch.float32,
    seed: int = 7,
    progress: TrainingProgress = False,
    progress_frequency: int = 10,
) -> TorchTrainingResult:
    """Optimize an interaction model with the torch-native training stack.

    The function validates its arguments, seeds torch, moves ``model`` to the
    resolved device and dtype, optionally initializes the one-body terms from
    ``train_dataset``, builds the training (and optional validation) loaders,
    and then runs up to ``epochs`` epochs with an AdamW optimizer. Validation
    runs every ``validation_frequency`` epochs and on the last epoch. The best
    validation result is tracked for checkpointing, early stopping, and for
    restoring the model after a ``KeyboardInterrupt``.

    Pass ``progress="summary"`` to print validation-epoch summaries without
    per-batch progress bars.

    Args:
        model: Model to train. It is moved to the target device/dtype and
            updated in place.
        train_dataset: Structures used for training (and for one-body
            initialization when ``initialize_onebody`` is true).
        batch_size: Training batch size; must be positive.
        epochs: Maximum number of epochs; must be positive.
        learning_rate: Learning rate handed to AdamW.
        force_loss_weight: Weight of the force term in the loss; the energy
            term always has weight 1.0.
        weight_decay: Default decay; every per-group decay left as ``None``
            falls back to this value. Must be non-negative.
        onebody_weight_decay, pair_weight_decay, threebody_weight_decay,
        embedding_weight_decay, state_weight_decay, charge_spin_weight_decay:
            Per-group overrides of ``weight_decay``; must be non-negative.
        term_weight_decays: Extra ``group name -> decay`` entries merged over
            the "embedding", "state" and "charge_spin" defaults.
        twobody_shape_penalty, twobody_wall_penalty: Optional penalties
            forwarded to ``train_one_epoch``.
        max_grad_norm: Forwarded to ``train_one_epoch``.
        validation_dataset: Optional held-out structures.
        validation_batch_size: Defaults to ``batch_size``.
        validation_frequency: Validate every this many epochs (and on the last
            one); ``None`` disables validation inside the loop.
        num_workers, pin_memory, persistent_workers, prefetch_factor,
        cache_batches, cache_batches_on_device: Loader options forwarded to
            ``build_ase_training_loader``. ``pin_memory=None`` resolves to
            true exactly when the device type is CUDA.
        feature_cache_storage, validation_feature_cache_storage: One of
            ``"none"``, ``"cpu"`` or ``"disk"``.
        feature_cache_mode, validation_feature_cache_mode: One of ``"auto"``,
            ``"read"`` or ``"refresh"``; ``"read"`` requires disk storage. The
            validation mode defaults to the training mode when validation
            caching is enabled, otherwise to ``"auto"``.
        feature_cache_dir, validation_feature_cache_dir: Cache directories;
            the validation directory defaults to the training one.
        feature_cache_per_atom_energy: Forwarded to the loader builder.
        compile_model: Compile policy resolved by ``_resolve_compile_model``;
            when it resolves to true the model is wrapped by ``torch.compile``.
        evaluate_initial: Evaluate train (and validation) metrics before the
            first epoch.
        evaluate_final_train: Re-evaluate the training set after the loop
            instead of reusing the last epoch's training metrics.
        initialize_onebody, onebody_rcond: Control the one-body initialization
            performed by ``initialize_onebody_terms_from_dataset``.
        early_stopping_patience: Stop after this many consecutive validation
            checks without improvement. Requires a validation dataset and a
            validation frequency.
        early_stopping_min_delta: A validation loss only counts as improved
            when it is below the best loss minus this margin.
        restore_best: Keep a CPU copy of the best validation state and reload
            it after early-stopping runs or interrupts.
        best_checkpoint_path: Where to save the best-validation checkpoint.
            Requires a validation dataset and a validation frequency.
        best_checkpoint_metadata: Extra metadata stored with that checkpoint.
        device: Target device; defaults to CUDA when available, else CPU.
        dtype: Floating-point dtype for the model and batches.
        seed: Seed for ``torch.manual_seed`` and the loaders (the validation
            loader uses ``seed + 1``).
        progress: ``True``/``False`` or ``"summary"``; see above.
        progress_frequency: With progress bars enabled, an epoch line is
            printed on the first, the last and every this-many-th epoch.

    Returns:
        A ``TorchTrainingResult`` holding the metric history, the initial and
        final metrics, and the resolved settings of the run. When
        ``restore_best`` is set and the run used early stopping or was
        interrupted, ``final_validation_metrics`` is the best validation
        result rather than the last one.

    Raises:
        ValueError: If any argument fails the checks at the top of the
            function (non-positive sizes, negative decays, unknown cache
            options, or early stopping / checkpointing without validation).
    """
    # ------------------------------------------------------------------
    # Argument validation (happens before any model or dataset is touched).
    # ------------------------------------------------------------------
    if batch_size <= 0:
        raise ValueError("`batch_size` must be positive")
    if epochs <= 0:
        raise ValueError("`epochs` must be positive")
    if validation_batch_size is not None and validation_batch_size <= 0:
        raise ValueError("`validation_batch_size` must be positive")
    if validation_frequency is not None and validation_frequency <= 0:
        raise ValueError("`validation_frequency` must be positive")

    # Every per-group decay falls back to the global ``weight_decay``.
    resolved_onebody_weight_decay = (
        float(weight_decay)
        if onebody_weight_decay is None
        else float(onebody_weight_decay)
    )
    resolved_pair_weight_decay = (
        float(weight_decay) if pair_weight_decay is None else float(pair_weight_decay)
    )
    resolved_threebody_weight_decay = (
        float(weight_decay)
        if threebody_weight_decay is None
        else float(threebody_weight_decay)
    )
    resolved_embedding_weight_decay = (
        float(weight_decay)
        if embedding_weight_decay is None
        else float(embedding_weight_decay)
    )
    resolved_state_weight_decay = (
        float(weight_decay) if state_weight_decay is None else float(state_weight_decay)
    )
    resolved_charge_spin_weight_decay = (
        float(weight_decay)
        if charge_spin_weight_decay is None
        else float(charge_spin_weight_decay)
    )
    resolved_term_weight_decays = {
        "embedding": resolved_embedding_weight_decay,
        "state": resolved_state_weight_decay,
        "charge_spin": resolved_charge_spin_weight_decay,
    }
    if term_weight_decays is not None:
        # NOTE(review): an explicit "embedding"/"state"/"charge_spin" entry
        # here overrides the value used by the optimizer, but the separately
        # reported ``resolved_*_weight_decay`` values are not updated.
        # Confirm that this reporting difference is intended.
        resolved_term_weight_decays.update(
            {str(key): float(value) for key, value in term_weight_decays.items()}
        )
    if weight_decay < 0.0:
        raise ValueError("`weight_decay` must be non-negative")
    if resolved_onebody_weight_decay < 0.0:
        raise ValueError("`onebody_weight_decay` must be non-negative")
    if resolved_pair_weight_decay < 0.0:
        raise ValueError("`pair_weight_decay` must be non-negative")
    if resolved_threebody_weight_decay < 0.0:
        raise ValueError("`threebody_weight_decay` must be non-negative")
    for group_name, decay in resolved_term_weight_decays.items():
        if decay < 0.0:
            raise ValueError(
                f"`term_weight_decays[{group_name!r}]` must be non-negative"
            )

    if early_stopping_patience is not None and early_stopping_patience <= 0:
        raise ValueError("`early_stopping_patience` must be positive")
    if early_stopping_min_delta < 0.0:
        raise ValueError("`early_stopping_min_delta` must be non-negative")
    # Early stopping and best-checkpoint saving are both driven by validation
    # checks, so they cannot work without them.
    if early_stopping_patience is not None and (
        validation_dataset is None or validation_frequency is None
    ):
        raise ValueError(
            "early stopping requires `validation_dataset` and `validation_frequency`"
        )
    if best_checkpoint_path is not None and (
        validation_dataset is None or validation_frequency is None
    ):
        raise ValueError(
            "best-checkpoint saving requires `validation_dataset` and "
            "`validation_frequency`"
        )
    if num_workers < 0:
        raise ValueError("`num_workers` must be non-negative")
    if prefetch_factor is not None and prefetch_factor <= 0:
        raise ValueError("`prefetch_factor` must be positive")
    if progress_frequency <= 0:
        raise ValueError("`progress_frequency` must be positive")
    if feature_cache_storage not in {"none", "cpu", "disk"}:
        raise ValueError("`feature_cache_storage` must be 'none', 'cpu', or 'disk'")
    if validation_feature_cache_storage not in {"none", "cpu", "disk"}:
        raise ValueError(
            "`validation_feature_cache_storage` must be 'none', 'cpu', or 'disk'"
        )
    if feature_cache_mode not in {"auto", "read", "refresh"}:
        raise ValueError("`feature_cache_mode` must be 'auto', 'read', or 'refresh'")
    if validation_feature_cache_mode is None:
        # Inherit the training mode only when validation caching is enabled.
        resolved_validation_feature_cache_mode = (
            feature_cache_mode if validation_feature_cache_storage != "none" else "auto"
        )
    else:
        resolved_validation_feature_cache_mode = validation_feature_cache_mode
    if resolved_validation_feature_cache_mode not in {"auto", "read", "refresh"}:
        raise ValueError(
            "`validation_feature_cache_mode` must be 'auto', 'read', or 'refresh'"
        )
    if feature_cache_mode == "read" and feature_cache_storage != "disk":
        raise ValueError("`feature_cache_mode='read'` requires disk feature caches")
    if (
        resolved_validation_feature_cache_mode == "read"
        and validation_feature_cache_storage != "disk"
    ):
        raise ValueError(
            "`validation_feature_cache_mode='read'` requires disk feature caches"
        )

    # ------------------------------------------------------------------
    # Resolve run settings and prepare the model.
    # ------------------------------------------------------------------
    resolved_best_checkpoint_path = (
        None if best_checkpoint_path is None else Path(best_checkpoint_path)
    )
    # Copy so later updates never mutate the caller's dictionary.
    resolved_best_checkpoint_metadata = (
        {} if best_checkpoint_metadata is None else dict(best_checkpoint_metadata)
    )
    progress_bars, progress_summary = _resolve_training_progress(progress)
    torch.manual_seed(seed)
    resolved_device = (
        torch.device("cuda" if _cuda_is_available() else "cpu")
        if device is None
        else torch.device(device)
    )
    model.to(device=resolved_device, dtype=dtype)
    initial_onebody_energies: dict[int, float] = {}
    if initialize_onebody:
        initial_onebody_energies = initialize_onebody_terms_from_dataset(
            model,
            train_dataset,
            rcond=onebody_rcond,
        )
    resolved_compile_model = _resolve_compile_model(
        compile_model,
        epochs=epochs,
        device=resolved_device,
    )
    # ``training_model`` runs the forward passes; ``model`` remains the object
    # used for optimizer groups, state copies and checkpoints.
    training_model: UFPModel
    if resolved_compile_model:
        training_model = torch.compile(model)
    else:
        training_model = model
    resolved_pin_memory = (
        resolved_device.type == "cuda" if pin_memory is None else pin_memory
    )
    resolved_validation_batch_size = (
        batch_size if validation_batch_size is None else validation_batch_size
    )

    # ------------------------------------------------------------------
    # Data loaders.
    # ------------------------------------------------------------------
    train_loader = build_ase_training_loader(
        model,
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=resolved_pin_memory,
        persistent_workers=persistent_workers,
        prefetch_factor=prefetch_factor,
        dtype=dtype,
        device=resolved_device,
        cache_batches=cache_batches,
        cache_batches_on_device=cache_batches_on_device,
        feature_cache_storage=feature_cache_storage,
        feature_cache_mode=feature_cache_mode,
        feature_cache_dir=feature_cache_dir,
        feature_cache_prefix="train_batch",
        feature_cache_per_atom_energy=feature_cache_per_atom_energy,
        seed=seed,
        progress=progress_bars,
        progress_description="Caching training batches",
    )
    validation_loader = None
    if validation_dataset is not None:
        resolved_validation_feature_cache_dir = (
            feature_cache_dir
            if validation_feature_cache_dir is None
            else validation_feature_cache_dir
        )
        validation_loader = build_ase_training_loader(
            model,
            validation_dataset,
            batch_size=resolved_validation_batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=resolved_pin_memory,
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
            dtype=dtype,
            device=resolved_device,
            cache_batches=cache_batches,
            cache_batches_on_device=cache_batches_on_device,
            feature_cache_storage=validation_feature_cache_storage,
            feature_cache_mode=resolved_validation_feature_cache_mode,
            feature_cache_dir=resolved_validation_feature_cache_dir,
            feature_cache_prefix="validation_batch",
            feature_cache_per_atom_energy=feature_cache_per_atom_energy,
            seed=seed + 1,
            progress=progress_bars,
            progress_description="Caching validation batches",
        )

    # ------------------------------------------------------------------
    # Loss, optimizer and optional pre-training evaluation.
    # ------------------------------------------------------------------
    loss_weights = LossWeights(energy=1.0, forces=force_loss_weight)
    optimizer = torch.optim.AdamW(
        _optimizer_parameter_groups(
            model,
            weight_decay=float(weight_decay),
            onebody_weight_decay=resolved_onebody_weight_decay,
            pair_weight_decay=resolved_pair_weight_decay,
            threebody_weight_decay=resolved_threebody_weight_decay,
            term_weight_decays=resolved_term_weight_decays,
        ),
        lr=learning_rate,
    )
    initial_train_metrics = None
    initial_validation_metrics = None
    if evaluate_initial:
        initial_train_metrics = evaluate_model(
            training_model,
            train_loader,
            split="train",
            dtype=dtype,
            device=resolved_device,
            loss_weights=loss_weights,
        )
    if evaluate_initial and validation_loader is not None:
        initial_validation_metrics = evaluate_model(
            training_model,
            validation_loader,
            split="validation",
            dtype=dtype,
            device=resolved_device,
            loss_weights=loss_weights,
        )

    # ------------------------------------------------------------------
    # Bookkeeping for the training loop.
    # ------------------------------------------------------------------
    train_history: list[EpochMetrics] = []
    validation_history: list[EpochMetrics] = []
    best_validation_loss: float | None = None
    best_validation_epoch: int | None = None
    best_validation_metrics: EpochMetrics | None = None
    best_state_dict: dict[str, torch.Tensor] | None = None
    checks_without_improvement = 0
    stopped_early = False
    interrupted = False
    restored_checkpoint_path: str | None = None

    if progress_summary:
        print(
            "Starting torch-native optimization: "
            f"{len(train_dataset)} structures, "
            f"{epochs} epochs, batch_size={batch_size}, "
            f"device={resolved_device}, dtype={dtype}"
        )
        print(
            "Batching: "
            f"cache_batches={cache_batches}, pin_memory={resolved_pin_memory}, "
            f"cache_batches_on_device={cache_batches_on_device}, "
            f"feature_cache_storage={feature_cache_storage}, "
            f"feature_cache_mode={feature_cache_mode}, "
            f"num_workers={num_workers}, compile_model={resolved_compile_model}"
        )
        if initial_onebody_energies:
            formatted_onebody = ", ".join(
                f"Z={atomic_type}: {value:.8f} eV"
                for atomic_type, value in sorted(initial_onebody_energies.items())
            )
            print("Initialized one-body terms:", formatted_onebody)
        if initial_train_metrics is not None:
            print(
                "Initial train metrics: "
                f"loss={initial_train_metrics.loss:.6e}, "
                "energy_mae="
                f"{_format_optional_metric(initial_train_metrics.energy_mae, ' eV')}, "
                "forces_mae="
                f"{_format_optional_metric(initial_train_metrics.forces_mae, ' eV/A')}"
            )
        if initial_validation_metrics is not None:
            initial_validation_energy_mae = _format_optional_metric(
                initial_validation_metrics.energy_mae,
                " eV",
            )
            initial_validation_forces_mae = _format_optional_metric(
                initial_validation_metrics.forces_mae,
                " eV/A",
            )
            print(
                "Initial validation metrics: "
                f"loss={initial_validation_metrics.loss:.6e}, "
                f"energy_mae={initial_validation_energy_mae}, "
                f"forces_mae={initial_validation_forces_mae}"
            )
        if early_stopping_patience is not None:
            print(
                "Early stopping: "
                f"patience={early_stopping_patience} validation checks, "
                f"min_delta={early_stopping_min_delta:.3e}, "
                f"restore_best={restore_best}"
            )
        if resolved_best_checkpoint_path is not None:
            print("Best-validation checkpoint:", resolved_best_checkpoint_path)
        print(
            "Weight decay: "
            f"default={float(weight_decay):.3e}, "
            f"one_body={resolved_onebody_weight_decay:.3e}, "
            f"two_body={resolved_pair_weight_decay:.3e}, "
            f"three_body={resolved_threebody_weight_decay:.3e}, "
            f"embedding={resolved_embedding_weight_decay:.3e}, "
            f"state={resolved_state_weight_decay:.3e}, "
            f"charge_spin={resolved_charge_spin_weight_decay:.3e}"
        )

    # ------------------------------------------------------------------
    # Training loop. A KeyboardInterrupt ends it gracefully so the best state
    # seen so far can still be restored and reported.
    # ------------------------------------------------------------------
    try:
        for epoch in range(1, epochs + 1):
            metrics = train_one_epoch(
                training_model,
                train_loader,
                optimizer=optimizer,
                dtype=dtype,
                device=resolved_device,
                loss_weights=loss_weights,
                twobody_shape_penalty=twobody_shape_penalty,
                twobody_wall_penalty=twobody_wall_penalty,
                max_grad_norm=max_grad_norm,
                progress=progress_bars,
                progress_description=f"train {epoch}/{epochs}",
            )
            train_history.append(metrics)
            if progress_bars and (
                epoch == 1 or epoch == epochs or epoch % progress_frequency == 0
            ):
                print(
                    f"Epoch {epoch:4d}/{epochs}: "
                    f"loss={metrics.loss:.6e}, "
                    "energy_mae="
                    f"{_format_optional_metric(metrics.energy_mae, ' eV')}, "
                    "forces_mae="
                    f"{_format_optional_metric(metrics.forces_mae, ' eV/A')}"
                )

            # Always validate on the final epoch so the run ends with a
            # validation result even when ``epochs`` is not a multiple of
            # ``validation_frequency``.
            should_validate = (
                validation_loader is not None
                and validation_frequency is not None
                and (epoch == epochs or epoch % validation_frequency == 0)
            )
            if should_validate:
                validation_metrics = evaluate_model(
                    training_model,
                    validation_loader,  # type: ignore[arg-type]
                    split="validation",
                    dtype=dtype,
                    device=resolved_device,
                    loss_weights=loss_weights,
                    progress=progress_bars,
                    progress_description=f"validation {epoch}/{epochs}",
                )
                validation_history.append(validation_metrics)
                # The first check always counts as an improvement; later ones
                # must beat the best loss by more than ``min_delta``.
                improved = (
                    best_validation_loss is None
                    or validation_metrics.loss
                    < best_validation_loss - early_stopping_min_delta
                )
                if improved:
                    best_validation_loss = validation_metrics.loss
                    best_validation_epoch = epoch
                    best_validation_metrics = validation_metrics
                    checks_without_improvement = 0
                    if restore_best:
                        best_state_dict = _copy_state_dict_to_cpu(model)
                    if resolved_best_checkpoint_path is not None:
                        checkpoint_metadata = dict(resolved_best_checkpoint_metadata)
                        checkpoint_metadata.update(
                            {
                                "checkpoint": "best_validation",
                                "best_validation_epoch": int(epoch),
                                "best_validation_loss": float(validation_metrics.loss),
                                "best_validation_metrics": _epoch_metrics_dict(
                                    validation_metrics
                                ),
                                "completed_epochs": int(epoch),
                                "dtype": str(dtype).replace("torch.", ""),
                                "device": str(resolved_device),
                                "force_loss_weight": float(force_loss_weight),
                            }
                        )
                        _save_training_state_checkpoint(
                            resolved_best_checkpoint_path,
                            model=model,
                            metadata=checkpoint_metadata,
                        )
                        if progress_summary:
                            print(
                                "Saved best-validation checkpoint: "
                                f"{resolved_best_checkpoint_path}"
                            )
                elif early_stopping_patience is not None:
                    checks_without_improvement += 1
                if progress_summary:
                    print(
                        _format_train_validation_summary(
                            epoch=epoch,
                            epochs=epochs,
                            train_metrics=metrics,
                            validation_metrics=validation_metrics,
                        )
                    )
                    if early_stopping_patience is not None:
                        print(
                            "Early stopping status: "
                            f"best_epoch={best_validation_epoch}, "
                            f"checks_without_improvement="
                            f"{checks_without_improvement}/"
                            f"{early_stopping_patience}"
                        )
                if (
                    early_stopping_patience is not None
                    and checks_without_improvement >= early_stopping_patience
                ):
                    stopped_early = True
                    if progress_summary:
                        print(
                            "Stopping early after validation epoch "
                            f"{epoch}; best validation epoch was "
                            f"{best_validation_epoch}."
                        )
                    break
    except KeyboardInterrupt:
        interrupted = True
        if progress_summary:
            print(
                "Interrupted during training; restoring the best current-run "
                "checkpoint when available."
            )

    # ------------------------------------------------------------------
    # Restore the best state where requested.
    # ------------------------------------------------------------------
    if interrupted and restore_best and best_state_dict is not None:
        # Prefer the in-memory copy captured during this run.
        model.load_state_dict(best_state_dict)
        if resolved_best_checkpoint_path is not None:
            restored_checkpoint_path = str(resolved_best_checkpoint_path)
        if progress_summary:
            print(f"Restored best validation state from epoch {best_validation_epoch}.")
    elif (
        interrupted
        and restore_best
        and resolved_best_checkpoint_path is not None
        and resolved_best_checkpoint_path.is_file()
    ):
        # NOTE(review): this branch is only reached when no in-memory best
        # state exists, so the file read here may come from an earlier run.
        # Confirm that reusing it is the intended behavior.
        # SECURITY: ``weights_only=False`` allows full unpickling, so the
        # checkpoint path must point at a trusted file.
        try:
            checkpoint = torch.load(
                resolved_best_checkpoint_path,
                map_location="cpu",
                weights_only=False,
            )
        except TypeError:
            # Presumably a torch version whose ``torch.load`` does not accept
            # ``weights_only``; retry without it.
            checkpoint = torch.load(resolved_best_checkpoint_path, map_location="cpu")
        state_dict = (
            checkpoint.get("state_dict") if isinstance(checkpoint, dict) else None
        )
        if isinstance(state_dict, dict):
            model.load_state_dict(state_dict)
            restored_checkpoint_path = str(resolved_best_checkpoint_path)
            if progress_summary:
                print(
                    "Restored best validation checkpoint: "
                    f"{resolved_best_checkpoint_path}"
                )
    if (
        not interrupted
        and early_stopping_patience is not None
        and restore_best
        and best_state_dict is not None
    ):
        # Uninterrupted runs only roll back to the best state when early
        # stopping is enabled.
        model.load_state_dict(best_state_dict)

    # ------------------------------------------------------------------
    # Final metrics.
    # ------------------------------------------------------------------
    # With no completed epoch (e.g. interrupted during the first one) there
    # are no training metrics to reuse, so an evaluation pass is required.
    need_final_train_evaluation = evaluate_final_train or not train_history
    if need_final_train_evaluation:
        final_train_metrics = evaluate_model(
            training_model,
            train_loader,
            split="train",
            dtype=dtype,
            device=resolved_device,
            loss_weights=loss_weights,
            progress=progress_bars,
            progress_description="train final",
        )
    else:
        final_train_metrics = train_history[-1]
    if progress_summary and need_final_train_evaluation:
        print(
            "Final train metrics: "
            f"loss={final_train_metrics.loss:.6e}, "
            "energy_mae="
            f"{_format_optional_metric(final_train_metrics.energy_mae, ' eV')}, "
            "forces_mae="
            f"{_format_optional_metric(final_train_metrics.forces_mae, ' eV/A')}"
        )
    final_validation_metrics = (
        None if not validation_history else validation_history[-1]
    )
    if (
        restore_best
        and best_validation_metrics is not None
        and (early_stopping_patience is not None or interrupted)
    ):
        # Report the best validation result in the cases where a best-state
        # restore is attempted above.
        final_validation_metrics = best_validation_metrics
    n_parameters = sum(
        parameter.numel() for parameter in model.parameters() if parameter.requires_grad
    )
    return TorchTrainingResult(
        history=TrainingHistory(
            train=tuple(train_history),
            validation=tuple(validation_history),
        ),
        initial_train_metrics=initial_train_metrics,
        initial_validation_metrics=initial_validation_metrics,
        final_train_metrics=final_train_metrics,
        final_validation_metrics=final_validation_metrics,
        loss_weights=loss_weights,
        device=str(resolved_device),
        dtype=str(dtype).replace("torch.", ""),
        batch_size=int(batch_size),
        validation_batch_size=(
            None if validation_dataset is None else int(resolved_validation_batch_size)
        ),
        epochs=int(epochs),
        completed_epochs=int(len(train_history)),
        learning_rate=float(learning_rate),
        weight_decay=float(weight_decay),
        onebody_weight_decay=float(resolved_onebody_weight_decay),
        pair_weight_decay=float(resolved_pair_weight_decay),
        threebody_weight_decay=float(resolved_threebody_weight_decay),
        embedding_weight_decay=float(resolved_embedding_weight_decay),
        state_weight_decay=float(resolved_state_weight_decay),
        charge_spin_weight_decay=float(resolved_charge_spin_weight_decay),
        term_weight_decays={
            str(key): float(value)
            for key, value in sorted(resolved_term_weight_decays.items())
        },
        max_grad_norm=max_grad_norm,
        validation_frequency=validation_frequency,
        early_stopping_patience=early_stopping_patience,
        early_stopping_min_delta=float(early_stopping_min_delta),
        restore_best=bool(restore_best),
        stopped_early=bool(stopped_early),
        interrupted=bool(interrupted),
        restored_checkpoint_path=restored_checkpoint_path,
        best_validation_epoch=best_validation_epoch,
        best_validation_metrics=best_validation_metrics,
        best_checkpoint_path=(
            None
            if resolved_best_checkpoint_path is None
            else str(resolved_best_checkpoint_path)
        ),
        num_workers=int(num_workers),
        pin_memory=bool(resolved_pin_memory),
        cache_batches=bool(cache_batches),
        cache_batches_on_device=bool(cache_batches_on_device),
        feature_cache_storage=str(feature_cache_storage),
        feature_cache_mode=str(feature_cache_mode),
        validation_feature_cache_storage=str(validation_feature_cache_storage),
        validation_feature_cache_mode=str(resolved_validation_feature_cache_mode),
        feature_cache_dir=None if feature_cache_dir is None else str(feature_cache_dir),
        validation_feature_cache_dir=(
            None
            if validation_feature_cache_dir is None
            else str(validation_feature_cache_dir)
        ),
        feature_cache_per_atom_energy=bool(feature_cache_per_atom_energy),
        compile_model=bool(resolved_compile_model),
        evaluate_initial=bool(evaluate_initial),
        evaluate_final_train=bool(evaluate_final_train),
        initial_onebody_energies=initial_onebody_energies,
        n_parameters=int(n_parameters),
    )