"""Reusable workflow helpers for examples and small supervised UFP studies."""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import torch
from ufp.terms import (
TwoBodyCriticalWallPenalty,
TwoBodySplineShapePenalty,
UFPModel,
)
from ufp.training import (
ASEAtomsDataset,
EpochMetrics,
LossWeights,
TrainingHistory,
build_ase_training_loader,
evaluate_model,
train_one_epoch,
)
from ufp.workflows._training import (
CompileModelPolicy,
_copy_state_dict_to_cpu,
_epoch_metrics_dict,
_optimizer_parameter_groups,
_resolve_compile_model,
_save_training_state_checkpoint,
)
from ufp.workflows.onebody import initialize_onebody_terms_from_dataset
# Values accepted by the ``progress`` option: a plain bool, or the literal
# string "summary" (epoch summaries without per-batch progress bars).
TrainingProgress = bool | Literal["summary"]
def _resolve_training_progress(progress: TrainingProgress) -> tuple[bool, bool]:
"""Return whether to show batch progress bars and epoch summaries."""
if isinstance(progress, bool):
return progress, progress
if progress == "summary":
return False, True
raise ValueError("`progress` must be True, False, or 'summary'")
def _format_optional_metric(value: float | None, unit: str = "") -> str:
"""Format metrics that may be absent for energy-only or force-only data."""
if value is None:
return "n/a"
return f"{value:.6e}{unit}"
def _format_train_validation_summary(
    *,
    epoch: int,
    epochs: int,
    train_metrics: EpochMetrics,
    validation_metrics: EpochMetrics,
) -> str:
    """Build the single-line report printed after a validation pass.

    The line starts with the epoch counter and then lists loss, energy MAE
    and forces MAE for the training split followed by the validation split.
    Energy/forces MAE values that are ``None`` are shown as ``n/a``.
    """
    header = f"Validation epoch {epoch:4d}/{epochs}: "
    # Fields are listed in output order; they are joined with ", " below.
    fields = (
        f"train_loss={train_metrics.loss:.6e}",
        f"train_energy_mae={_format_optional_metric(train_metrics.energy_mae, ' eV')}",
        f"train_forces_mae={_format_optional_metric(train_metrics.forces_mae, ' eV/A')}",
        f"validation_loss={validation_metrics.loss:.6e}",
        "validation_energy_mae="
        + _format_optional_metric(validation_metrics.energy_mae, " eV"),
        "validation_forces_mae="
        + _format_optional_metric(validation_metrics.forces_mae, " eV/A"),
    )
    return header + ", ".join(fields)
def _cuda_is_available() -> bool:
"""Return CUDA availability without letting driver warnings abort CPU runs."""
try:
return torch.cuda.is_available()
except Warning:
return False
# "[docs]" -- presumably a documentation-site link marker captured during extraction; not code.
@dataclass(frozen=True)
class TorchTrainingResult:
    """Immutable record of one ``train_interaction_model`` run.

    Bundles the metric history, the resolved hyper-parameters and the
    early-stopping / checkpoint bookkeeping of a single supervised run.
    ``device`` and ``dtype`` are stored as strings, and the path-like fields
    are stored as strings or ``None``.
    """

    # --- Metrics -----------------------------------------------------------
    # Per-epoch training metrics and per-check validation metrics.
    history: TrainingHistory
    # Metrics measured before the first optimizer step; None when not requested.
    initial_train_metrics: EpochMetrics | None
    initial_validation_metrics: EpochMetrics | None
    final_train_metrics: EpochMetrics
    # None when no validation pass ran.
    final_validation_metrics: EpochMetrics | None
    loss_weights: LossWeights

    # --- Run configuration -------------------------------------------------
    device: str
    dtype: str
    batch_size: int
    # None when no validation dataset was supplied.
    validation_batch_size: int | None
    # Requested epoch count versus the number of epochs actually trained.
    epochs: int
    completed_epochs: int
    learning_rate: float

    # --- Weight decay (default plus per-group values) ----------------------
    weight_decay: float
    onebody_weight_decay: float
    pair_weight_decay: float
    threebody_weight_decay: float
    embedding_weight_decay: float
    state_weight_decay: float
    charge_spin_weight_decay: float
    term_weight_decays: dict[str, float]
    max_grad_norm: float | None

    # --- Validation, early stopping and checkpoints ------------------------
    validation_frequency: int | None
    early_stopping_patience: int | None
    early_stopping_min_delta: float
    restore_best: bool
    stopped_early: bool
    # True when the training loop was ended by a KeyboardInterrupt.
    interrupted: bool
    restored_checkpoint_path: str | None
    best_validation_epoch: int | None
    best_validation_metrics: EpochMetrics | None
    best_checkpoint_path: str | None

    # --- Data loading and feature caching ----------------------------------
    num_workers: int
    pin_memory: bool
    cache_batches: bool
    cache_batches_on_device: bool
    feature_cache_storage: str
    feature_cache_mode: str
    validation_feature_cache_storage: str
    validation_feature_cache_mode: str
    feature_cache_dir: str | None
    validation_feature_cache_dir: str | None
    feature_cache_per_atom_energy: bool

    # --- Miscellaneous -----------------------------------------------------
    # Whether the model was actually wrapped with ``torch.compile``.
    compile_model: bool
    evaluate_initial: bool
    evaluate_final_train: bool
    # One-body energies set during initialization; empty when that was skipped.
    initial_onebody_energies: dict[int, float]
    # Count of parameters with ``requires_grad`` set.
    n_parameters: int
# "[docs]" -- presumably a documentation-site link marker captured during extraction; not code.
def _resolve_optional_weight_decay(override: float | None, default: float) -> float:
    """Return ``override`` as a float, or ``default`` when no override is given."""
    return float(default) if override is None else float(override)


def train_interaction_model(
    model: UFPModel,
    train_dataset: ASEAtomsDataset,
    *,
    batch_size: int,
    epochs: int,
    learning_rate: float,
    force_loss_weight: float = 5.0,
    weight_decay: float = 0.0,
    onebody_weight_decay: float | None = None,
    pair_weight_decay: float | None = None,
    threebody_weight_decay: float | None = None,
    embedding_weight_decay: float | None = None,
    state_weight_decay: float | None = None,
    charge_spin_weight_decay: float | None = None,
    term_weight_decays: Mapping[str, float] | None = None,
    twobody_shape_penalty: TwoBodySplineShapePenalty | None = None,
    twobody_wall_penalty: TwoBodyCriticalWallPenalty | None = None,
    max_grad_norm: float | None = None,
    validation_dataset: ASEAtomsDataset | None = None,
    validation_batch_size: int | None = None,
    validation_frequency: int | None = 10,
    num_workers: int = 0,
    pin_memory: bool | None = None,
    persistent_workers: bool = True,
    prefetch_factor: int | None = 2,
    cache_batches: bool = True,
    cache_batches_on_device: bool = False,
    feature_cache_storage: str = "cpu",
    feature_cache_mode: str = "auto",
    feature_cache_dir: Path | str | None = None,
    validation_feature_cache_storage: str = "none",
    validation_feature_cache_mode: str | None = None,
    validation_feature_cache_dir: Path | str | None = None,
    feature_cache_per_atom_energy: bool = False,
    compile_model: CompileModelPolicy = "auto",
    evaluate_initial: bool = False,
    evaluate_final_train: bool = False,
    initialize_onebody: bool = True,
    onebody_rcond: float | None = None,
    early_stopping_patience: int | None = None,
    early_stopping_min_delta: float = 0.0,
    restore_best: bool = True,
    best_checkpoint_path: Path | str | None = None,
    best_checkpoint_metadata: dict[str, object] | None = None,
    device: torch.device | str | None = None,
    dtype: torch.dtype = torch.float32,
    seed: int = 7,
    progress: TrainingProgress = False,
    progress_frequency: int = 10,
) -> TorchTrainingResult:
    """Optimize an interaction model with the torch-native training stack.

    All arguments are validated first. The model is then moved to the
    resolved device/dtype, its one-body terms are optionally initialized from
    ``train_dataset``, and it is trained with ``torch.optim.AdamW`` for up to
    ``epochs`` epochs. ``model`` is modified in place.

    Pass ``progress="summary"`` to print validation-epoch summaries without
    per-batch progress bars.

    Args:
        model: Model to train.
        train_dataset: Training structures.
        batch_size: Training batch size; also used for validation when
            ``validation_batch_size`` is ``None``.
        epochs: Maximum number of training epochs.
        learning_rate: AdamW learning rate.
        force_loss_weight: Weight of the force term in the loss; the energy
            weight is fixed at 1.0.
        weight_decay: Default weight decay, used for every parameter group
            whose specific value is ``None``.
        onebody_weight_decay, pair_weight_decay, threebody_weight_decay,
        embedding_weight_decay, state_weight_decay, charge_spin_weight_decay:
            Per-group overrides of ``weight_decay``.
        term_weight_decays: Extra per-group values keyed by name. Entries for
            ``"embedding"``, ``"state"`` and ``"charge_spin"`` take precedence
            over the corresponding keyword arguments.
        twobody_shape_penalty, twobody_wall_penalty, max_grad_norm: Forwarded
            unchanged to ``train_one_epoch``.
        validation_dataset: Optional validation structures.
        validation_frequency: Validate every this many epochs and after the
            final epoch; ``None`` disables validation.
        num_workers, pin_memory, persistent_workers, prefetch_factor,
        cache_batches, cache_batches_on_device, feature_cache_*,
        validation_feature_cache_*: Forwarded to
            ``build_ase_training_loader``. ``pin_memory=None`` pins only when
            the resolved device is CUDA. An unset validation cache mode
            follows ``feature_cache_mode`` unless validation caching is
            ``"none"``, and an unset validation cache directory falls back to
            ``feature_cache_dir``.
        compile_model: Policy resolved by ``_resolve_compile_model``; when it
            resolves to true the model is wrapped with ``torch.compile``.
        evaluate_initial: Evaluate the train (and validation) split before
            the first epoch.
        evaluate_final_train: Re-evaluate the training split after training
            instead of reusing the last epoch's metrics.
        initialize_onebody, onebody_rcond: Control one-body initialization
            via ``initialize_onebody_terms_from_dataset``.
        early_stopping_patience: Stop after this many consecutive validation
            checks without improvement; ``None`` disables early stopping.
        early_stopping_min_delta: Minimum decrease in validation loss that
            counts as an improvement.
        restore_best: Keep a CPU copy of the best-validation state. It is
            loaded back after the loop when early stopping is configured, or
            after a ``KeyboardInterrupt``.
        best_checkpoint_path: Where to save a checkpoint on each validation
            improvement.
        best_checkpoint_metadata: Extra metadata merged into each checkpoint.
        device: Target device; defaults to CUDA when available, else CPU.
        dtype: Floating-point dtype for the model and batches.
        seed: Seed for ``torch.manual_seed`` and the training loader; the
            validation loader receives ``seed + 1``.
        progress: ``True``, ``False`` or ``"summary"``.
        progress_frequency: With progress bars on, an epoch line is printed
            for the first epoch, the last epoch and every
            ``progress_frequency``-th epoch.

    Returns:
        A ``TorchTrainingResult`` holding the metric history, the resolved
        settings and the early-stopping/checkpoint bookkeeping. A
        ``KeyboardInterrupt`` raised inside the epoch loop is caught and
        reported through ``interrupted=True``.

    Raises:
        ValueError: If an argument is out of range, or if early stopping or
            best-checkpoint saving is requested without both
            ``validation_dataset`` and ``validation_frequency``.
    """
    # Function-scope import keeps this block self-contained.
    import math

    # ------------------------------------------------------------------
    # Argument validation (no side effects before this section completes).
    # ------------------------------------------------------------------
    if batch_size <= 0:
        raise ValueError("`batch_size` must be positive")
    if epochs <= 0:
        raise ValueError("`epochs` must be positive")
    if validation_batch_size is not None and validation_batch_size <= 0:
        raise ValueError("`validation_batch_size` must be positive")
    if validation_frequency is not None and validation_frequency <= 0:
        raise ValueError("`validation_frequency` must be positive")

    resolved_onebody_weight_decay = _resolve_optional_weight_decay(
        onebody_weight_decay, weight_decay
    )
    resolved_pair_weight_decay = _resolve_optional_weight_decay(
        pair_weight_decay, weight_decay
    )
    resolved_threebody_weight_decay = _resolve_optional_weight_decay(
        threebody_weight_decay, weight_decay
    )
    resolved_term_weight_decays = {
        "embedding": _resolve_optional_weight_decay(
            embedding_weight_decay, weight_decay
        ),
        "state": _resolve_optional_weight_decay(state_weight_decay, weight_decay),
        "charge_spin": _resolve_optional_weight_decay(
            charge_spin_weight_decay, weight_decay
        ),
    }
    if term_weight_decays is not None:
        resolved_term_weight_decays.update(
            {str(key): float(value) for key, value in term_weight_decays.items()}
        )
    # The optimizer receives the merged mapping, so read the per-group values
    # back from it: the summary printout and the returned result then report
    # what was actually applied, even when `term_weight_decays` overrides them.
    resolved_embedding_weight_decay = resolved_term_weight_decays["embedding"]
    resolved_state_weight_decay = resolved_term_weight_decays["state"]
    resolved_charge_spin_weight_decay = resolved_term_weight_decays["charge_spin"]

    if weight_decay < 0.0:
        raise ValueError("`weight_decay` must be non-negative")
    if resolved_onebody_weight_decay < 0.0:
        raise ValueError("`onebody_weight_decay` must be non-negative")
    if resolved_pair_weight_decay < 0.0:
        raise ValueError("`pair_weight_decay` must be non-negative")
    if resolved_threebody_weight_decay < 0.0:
        raise ValueError("`threebody_weight_decay` must be non-negative")
    for group_name, decay in resolved_term_weight_decays.items():
        if decay < 0.0:
            raise ValueError(
                f"`term_weight_decays[{group_name!r}]` must be non-negative"
            )
    if early_stopping_patience is not None and early_stopping_patience <= 0:
        raise ValueError("`early_stopping_patience` must be positive")
    if early_stopping_min_delta < 0.0:
        raise ValueError("`early_stopping_min_delta` must be non-negative")
    if early_stopping_patience is not None and (
        validation_dataset is None or validation_frequency is None
    ):
        raise ValueError(
            "early stopping requires `validation_dataset` and `validation_frequency`"
        )
    if best_checkpoint_path is not None and (
        validation_dataset is None or validation_frequency is None
    ):
        raise ValueError(
            "best-checkpoint saving requires `validation_dataset` and "
            "`validation_frequency`"
        )
    if num_workers < 0:
        raise ValueError("`num_workers` must be non-negative")
    if prefetch_factor is not None and prefetch_factor <= 0:
        raise ValueError("`prefetch_factor` must be positive")
    if progress_frequency <= 0:
        raise ValueError("`progress_frequency` must be positive")
    if feature_cache_storage not in {"none", "cpu", "disk"}:
        raise ValueError("`feature_cache_storage` must be 'none', 'cpu', or 'disk'")
    if validation_feature_cache_storage not in {"none", "cpu", "disk"}:
        raise ValueError(
            "`validation_feature_cache_storage` must be 'none', 'cpu', or 'disk'"
        )
    if feature_cache_mode not in {"auto", "read", "refresh"}:
        raise ValueError("`feature_cache_mode` must be 'auto', 'read', or 'refresh'")
    if validation_feature_cache_mode is None:
        # Follow the training cache mode only when validation caching is on.
        resolved_validation_feature_cache_mode = (
            feature_cache_mode if validation_feature_cache_storage != "none" else "auto"
        )
    else:
        resolved_validation_feature_cache_mode = validation_feature_cache_mode
    if resolved_validation_feature_cache_mode not in {"auto", "read", "refresh"}:
        raise ValueError(
            "`validation_feature_cache_mode` must be 'auto', 'read', or 'refresh'"
        )
    if feature_cache_mode == "read" and feature_cache_storage != "disk":
        raise ValueError("`feature_cache_mode='read'` requires disk feature caches")
    if (
        resolved_validation_feature_cache_mode == "read"
        and validation_feature_cache_storage != "disk"
    ):
        raise ValueError(
            "`validation_feature_cache_mode='read'` requires disk feature caches"
        )

    resolved_best_checkpoint_path = (
        None if best_checkpoint_path is None else Path(best_checkpoint_path)
    )
    # Copy so the caller's metadata dict is never mutated.
    resolved_best_checkpoint_metadata = (
        {} if best_checkpoint_metadata is None else dict(best_checkpoint_metadata)
    )
    progress_bars, progress_summary = _resolve_training_progress(progress)

    # ------------------------------------------------------------------
    # Model, loaders and optimizer.
    # ------------------------------------------------------------------
    torch.manual_seed(seed)
    resolved_device = (
        torch.device("cuda" if _cuda_is_available() else "cpu")
        if device is None
        else torch.device(device)
    )
    model.to(device=resolved_device, dtype=dtype)
    initial_onebody_energies: dict[int, float] = {}
    if initialize_onebody:
        initial_onebody_energies = initialize_onebody_terms_from_dataset(
            model,
            train_dataset,
            rcond=onebody_rcond,
        )
    resolved_compile_model = _resolve_compile_model(
        compile_model,
        epochs=epochs,
        device=resolved_device,
    )
    # `training_model` is used for forward/backward passes; the uncompiled
    # `model` is used for state-dict snapshots, checkpoints and the optimizer.
    training_model: UFPModel
    if resolved_compile_model:
        training_model = torch.compile(model)
    else:
        training_model = model
    resolved_pin_memory = (
        resolved_device.type == "cuda" if pin_memory is None else pin_memory
    )
    resolved_validation_batch_size = (
        batch_size if validation_batch_size is None else validation_batch_size
    )
    train_loader = build_ase_training_loader(
        model,
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=resolved_pin_memory,
        persistent_workers=persistent_workers,
        prefetch_factor=prefetch_factor,
        dtype=dtype,
        device=resolved_device,
        cache_batches=cache_batches,
        cache_batches_on_device=cache_batches_on_device,
        feature_cache_storage=feature_cache_storage,
        feature_cache_mode=feature_cache_mode,
        feature_cache_dir=feature_cache_dir,
        feature_cache_prefix="train_batch",
        feature_cache_per_atom_energy=feature_cache_per_atom_energy,
        seed=seed,
        progress=progress_bars,
        progress_description="Caching training batches",
    )
    validation_loader = None
    if validation_dataset is not None:
        resolved_validation_feature_cache_dir = (
            feature_cache_dir
            if validation_feature_cache_dir is None
            else validation_feature_cache_dir
        )
        validation_loader = build_ase_training_loader(
            model,
            validation_dataset,
            batch_size=resolved_validation_batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=resolved_pin_memory,
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
            dtype=dtype,
            device=resolved_device,
            cache_batches=cache_batches,
            cache_batches_on_device=cache_batches_on_device,
            feature_cache_storage=validation_feature_cache_storage,
            feature_cache_mode=resolved_validation_feature_cache_mode,
            feature_cache_dir=resolved_validation_feature_cache_dir,
            feature_cache_prefix="validation_batch",
            feature_cache_per_atom_energy=feature_cache_per_atom_energy,
            seed=seed + 1,
            progress=progress_bars,
            progress_description="Caching validation batches",
        )
    loss_weights = LossWeights(energy=1.0, forces=force_loss_weight)
    optimizer = torch.optim.AdamW(
        _optimizer_parameter_groups(
            model,
            weight_decay=float(weight_decay),
            onebody_weight_decay=resolved_onebody_weight_decay,
            pair_weight_decay=resolved_pair_weight_decay,
            threebody_weight_decay=resolved_threebody_weight_decay,
            term_weight_decays=resolved_term_weight_decays,
        ),
        lr=learning_rate,
    )

    # ------------------------------------------------------------------
    # Optional pre-training evaluation.
    # ------------------------------------------------------------------
    initial_train_metrics = None
    initial_validation_metrics = None
    if evaluate_initial:
        initial_train_metrics = evaluate_model(
            training_model,
            train_loader,
            split="train",
            dtype=dtype,
            device=resolved_device,
            loss_weights=loss_weights,
        )
    if evaluate_initial and validation_loader is not None:
        initial_validation_metrics = evaluate_model(
            training_model,
            validation_loader,
            split="validation",
            dtype=dtype,
            device=resolved_device,
            loss_weights=loss_weights,
        )

    train_history: list[EpochMetrics] = []
    validation_history: list[EpochMetrics] = []
    best_validation_loss: float | None = None
    best_validation_epoch: int | None = None
    best_validation_metrics: EpochMetrics | None = None
    best_state_dict: dict[str, torch.Tensor] | None = None
    checks_without_improvement = 0
    stopped_early = False
    interrupted = False
    restored_checkpoint_path: str | None = None

    if progress_summary:
        print(
            "Starting torch-native optimization: "
            f"{len(train_dataset)} structures, "
            f"{epochs} epochs, batch_size={batch_size}, "
            f"device={resolved_device}, dtype={dtype}"
        )
        print(
            "Batching: "
            f"cache_batches={cache_batches}, pin_memory={resolved_pin_memory}, "
            f"cache_batches_on_device={cache_batches_on_device}, "
            f"feature_cache_storage={feature_cache_storage}, "
            f"feature_cache_mode={feature_cache_mode}, "
            f"num_workers={num_workers}, compile_model={resolved_compile_model}"
        )
        if initial_onebody_energies:
            formatted_onebody = ", ".join(
                f"Z={atomic_type}: {value:.8f} eV"
                for atomic_type, value in sorted(initial_onebody_energies.items())
            )
            print("Initialized one-body terms:", formatted_onebody)
        if initial_train_metrics is not None:
            print(
                "Initial train metrics: "
                f"loss={initial_train_metrics.loss:.6e}, "
                "energy_mae="
                f"{_format_optional_metric(initial_train_metrics.energy_mae, ' eV')}, "
                "forces_mae="
                f"{_format_optional_metric(initial_train_metrics.forces_mae, ' eV/A')}"
            )
        if initial_validation_metrics is not None:
            initial_validation_energy_mae = _format_optional_metric(
                initial_validation_metrics.energy_mae,
                " eV",
            )
            initial_validation_forces_mae = _format_optional_metric(
                initial_validation_metrics.forces_mae,
                " eV/A",
            )
            print(
                "Initial validation metrics: "
                f"loss={initial_validation_metrics.loss:.6e}, "
                f"energy_mae={initial_validation_energy_mae}, "
                f"forces_mae={initial_validation_forces_mae}"
            )
        if early_stopping_patience is not None:
            print(
                "Early stopping: "
                f"patience={early_stopping_patience} validation checks, "
                f"min_delta={early_stopping_min_delta:.3e}, "
                f"restore_best={restore_best}"
            )
        if resolved_best_checkpoint_path is not None:
            print("Best-validation checkpoint:", resolved_best_checkpoint_path)
        print(
            "Weight decay: "
            f"default={float(weight_decay):.3e}, "
            f"one_body={resolved_onebody_weight_decay:.3e}, "
            f"two_body={resolved_pair_weight_decay:.3e}, "
            f"three_body={resolved_threebody_weight_decay:.3e}, "
            f"embedding={resolved_embedding_weight_decay:.3e}, "
            f"state={resolved_state_weight_decay:.3e}, "
            f"charge_spin={resolved_charge_spin_weight_decay:.3e}"
        )

    # ------------------------------------------------------------------
    # Epoch loop. Ctrl-C ends training but still produces a result.
    # ------------------------------------------------------------------
    try:
        for epoch in range(1, epochs + 1):
            metrics = train_one_epoch(
                training_model,
                train_loader,
                optimizer=optimizer,
                dtype=dtype,
                device=resolved_device,
                loss_weights=loss_weights,
                twobody_shape_penalty=twobody_shape_penalty,
                twobody_wall_penalty=twobody_wall_penalty,
                max_grad_norm=max_grad_norm,
                progress=progress_bars,
                progress_description=f"train {epoch}/{epochs}",
            )
            train_history.append(metrics)
            if progress_bars and (
                epoch == 1 or epoch == epochs or epoch % progress_frequency == 0
            ):
                print(
                    f"Epoch {epoch:4d}/{epochs}: "
                    f"loss={metrics.loss:.6e}, "
                    "energy_mae="
                    f"{_format_optional_metric(metrics.energy_mae, ' eV')}, "
                    "forces_mae="
                    f"{_format_optional_metric(metrics.forces_mae, ' eV/A')}"
                )
            # Always validate after the final epoch, even off-schedule.
            should_validate = (
                validation_loader is not None
                and validation_frequency is not None
                and (epoch == epochs or epoch % validation_frequency == 0)
            )
            if should_validate:
                validation_metrics = evaluate_model(
                    training_model,
                    validation_loader,  # type: ignore[arg-type]
                    split="validation",
                    dtype=dtype,
                    device=resolved_device,
                    loss_weights=loss_weights,
                    progress=progress_bars,
                    progress_description=f"validation {epoch}/{epochs}",
                )
                validation_history.append(validation_metrics)
                current_loss = validation_metrics.loss
                improved = (
                    best_validation_loss is None
                    # A NaN "best" can never be beaten through `<`; let any
                    # non-NaN loss replace it so a diverged first check does
                    # not pin the best state forever.
                    or (
                        math.isnan(best_validation_loss)
                        and not math.isnan(current_loss)
                    )
                    or current_loss < best_validation_loss - early_stopping_min_delta
                )
                if improved:
                    best_validation_loss = validation_metrics.loss
                    best_validation_epoch = epoch
                    best_validation_metrics = validation_metrics
                    checks_without_improvement = 0
                    if restore_best:
                        best_state_dict = _copy_state_dict_to_cpu(model)
                    if resolved_best_checkpoint_path is not None:
                        checkpoint_metadata = dict(resolved_best_checkpoint_metadata)
                        checkpoint_metadata.update(
                            {
                                "checkpoint": "best_validation",
                                "best_validation_epoch": int(epoch),
                                "best_validation_loss": float(validation_metrics.loss),
                                "best_validation_metrics": _epoch_metrics_dict(
                                    validation_metrics
                                ),
                                "completed_epochs": int(epoch),
                                "dtype": str(dtype).replace("torch.", ""),
                                "device": str(resolved_device),
                                "force_loss_weight": float(force_loss_weight),
                            }
                        )
                        _save_training_state_checkpoint(
                            resolved_best_checkpoint_path,
                            model=model,
                            metadata=checkpoint_metadata,
                        )
                        if progress_summary:
                            print(
                                "Saved best-validation checkpoint: "
                                f"{resolved_best_checkpoint_path}"
                            )
                elif early_stopping_patience is not None:
                    checks_without_improvement += 1
                if progress_summary:
                    print(
                        _format_train_validation_summary(
                            epoch=epoch,
                            epochs=epochs,
                            train_metrics=metrics,
                            validation_metrics=validation_metrics,
                        )
                    )
                    if early_stopping_patience is not None:
                        print(
                            "Early stopping status: "
                            f"best_epoch={best_validation_epoch}, "
                            f"checks_without_improvement="
                            f"{checks_without_improvement}/"
                            f"{early_stopping_patience}"
                        )
                if (
                    early_stopping_patience is not None
                    and checks_without_improvement >= early_stopping_patience
                ):
                    stopped_early = True
                    if progress_summary:
                        print(
                            "Stopping early after validation epoch "
                            f"{epoch}; best validation epoch was "
                            f"{best_validation_epoch}."
                        )
                    break
    except KeyboardInterrupt:
        interrupted = True
        if progress_summary:
            print(
                "Interrupted during training; restoring the best current-run "
                "checkpoint when available."
            )

    # ------------------------------------------------------------------
    # Restore the best state after an interrupt or when early stopping is on.
    # ------------------------------------------------------------------
    if interrupted and restore_best and best_state_dict is not None:
        model.load_state_dict(best_state_dict)
        if resolved_best_checkpoint_path is not None:
            restored_checkpoint_path = str(resolved_best_checkpoint_path)
        if progress_summary:
            print(f"Restored best validation state from epoch {best_validation_epoch}.")
    elif (
        interrupted
        and restore_best
        and resolved_best_checkpoint_path is not None
        and resolved_best_checkpoint_path.is_file()
    ):
        # NOTE(review): this branch runs only when no in-memory best state was
        # captured during this run, so the file presumably predates the
        # current training loop -- confirm that loading it is intended.
        # NOTE(security): `weights_only=False` unpickles arbitrary objects;
        # `best_checkpoint_path` must only ever point at trusted files.
        try:
            checkpoint = torch.load(
                resolved_best_checkpoint_path,
                map_location="cpu",
                weights_only=False,
            )
        except TypeError:
            # Presumably a torch version whose `torch.load` lacks the
            # `weights_only` keyword.
            checkpoint = torch.load(resolved_best_checkpoint_path, map_location="cpu")
        state_dict = (
            checkpoint.get("state_dict") if isinstance(checkpoint, dict) else None
        )
        if isinstance(state_dict, dict):
            model.load_state_dict(state_dict)
            restored_checkpoint_path = str(resolved_best_checkpoint_path)
            if progress_summary:
                print(
                    "Restored best validation checkpoint: "
                    f"{resolved_best_checkpoint_path}"
                )
    if (
        not interrupted
        and early_stopping_patience is not None
        and restore_best
        and best_state_dict is not None
    ):
        model.load_state_dict(best_state_dict)

    # ------------------------------------------------------------------
    # Final metrics and result assembly.
    # ------------------------------------------------------------------
    # With no completed epoch (interrupted during the first one) there are no
    # training metrics to reuse, so an evaluation pass is required.
    need_final_train_evaluation = evaluate_final_train or not train_history
    if need_final_train_evaluation:
        final_train_metrics = evaluate_model(
            training_model,
            train_loader,
            split="train",
            dtype=dtype,
            device=resolved_device,
            loss_weights=loss_weights,
            progress=progress_bars,
            progress_description="train final",
        )
    else:
        final_train_metrics = train_history[-1]
    if progress_summary and need_final_train_evaluation:
        print(
            "Final train metrics: "
            f"loss={final_train_metrics.loss:.6e}, "
            "energy_mae="
            f"{_format_optional_metric(final_train_metrics.energy_mae, ' eV')}, "
            "forces_mae="
            f"{_format_optional_metric(final_train_metrics.forces_mae, ' eV/A')}"
        )
    final_validation_metrics = (
        None if not validation_history else validation_history[-1]
    )
    # When the best state may have been restored, report the metrics that
    # belong to that state rather than those of the last validation check.
    if (
        restore_best
        and best_validation_metrics is not None
        and (early_stopping_patience is not None or interrupted)
    ):
        final_validation_metrics = best_validation_metrics
    n_parameters = sum(
        parameter.numel() for parameter in model.parameters() if parameter.requires_grad
    )
    return TorchTrainingResult(
        history=TrainingHistory(
            train=tuple(train_history),
            validation=tuple(validation_history),
        ),
        initial_train_metrics=initial_train_metrics,
        initial_validation_metrics=initial_validation_metrics,
        final_train_metrics=final_train_metrics,
        final_validation_metrics=final_validation_metrics,
        loss_weights=loss_weights,
        device=str(resolved_device),
        dtype=str(dtype).replace("torch.", ""),
        batch_size=int(batch_size),
        validation_batch_size=(
            None if validation_dataset is None else int(resolved_validation_batch_size)
        ),
        epochs=int(epochs),
        completed_epochs=int(len(train_history)),
        learning_rate=float(learning_rate),
        weight_decay=float(weight_decay),
        onebody_weight_decay=float(resolved_onebody_weight_decay),
        pair_weight_decay=float(resolved_pair_weight_decay),
        threebody_weight_decay=float(resolved_threebody_weight_decay),
        embedding_weight_decay=float(resolved_embedding_weight_decay),
        state_weight_decay=float(resolved_state_weight_decay),
        charge_spin_weight_decay=float(resolved_charge_spin_weight_decay),
        term_weight_decays={
            str(key): float(value)
            for key, value in sorted(resolved_term_weight_decays.items())
        },
        max_grad_norm=max_grad_norm,
        validation_frequency=validation_frequency,
        early_stopping_patience=early_stopping_patience,
        early_stopping_min_delta=float(early_stopping_min_delta),
        restore_best=bool(restore_best),
        stopped_early=bool(stopped_early),
        interrupted=bool(interrupted),
        restored_checkpoint_path=restored_checkpoint_path,
        best_validation_epoch=best_validation_epoch,
        best_validation_metrics=best_validation_metrics,
        best_checkpoint_path=(
            None
            if resolved_best_checkpoint_path is None
            else str(resolved_best_checkpoint_path)
        ),
        num_workers=int(num_workers),
        pin_memory=bool(resolved_pin_memory),
        cache_batches=bool(cache_batches),
        cache_batches_on_device=bool(cache_batches_on_device),
        feature_cache_storage=str(feature_cache_storage),
        feature_cache_mode=str(feature_cache_mode),
        validation_feature_cache_storage=str(validation_feature_cache_storage),
        validation_feature_cache_mode=str(resolved_validation_feature_cache_mode),
        feature_cache_dir=None if feature_cache_dir is None else str(feature_cache_dir),
        validation_feature_cache_dir=(
            None
            if validation_feature_cache_dir is None
            else str(validation_feature_cache_dir)
        ),
        feature_cache_per_atom_energy=bool(feature_cache_per_atom_energy),
        compile_model=bool(resolved_compile_model),
        evaluate_initial=bool(evaluate_initial),
        evaluate_final_train=bool(evaluate_final_train),
        initial_onebody_energies=initial_onebody_energies,
        n_parameters=int(n_parameters),
    )