"""Least-squares-vs-training microbenchmarks built on toy UFP problems."""
from __future__ import annotations
import argparse
import json
import time
from dataclasses import asdict, dataclass
from itertools import combinations_with_replacement
from typing import Callable, Sequence
import ase
import numpy as np
import torch
from ufp.benchmarks._common import (
BenchmarkCheckpoint,
BenchmarkPoint,
BenchmarkResult,
BenchmarkWorkloadDefaults,
format_number,
parse_positive_int_sequence,
resolve_device,
resolve_dtype,
scenario_choices,
)
from ufp.leastsquares import FitSample, LinearFitter
from ufp.leastsquares._block import (
_block_matrix_diagonal,
_block_solve_batch_storage_elements,
_block_solve_batch_storage_nbytes,
)
from ufp.leastsquares._layout import ParameterLayout
from ufp.leastsquares.linear import BlockLinearProblem
from ufp.neighbors._neighbors import build_neighbor_list
from ufp.terms._threebody_kernels import (
native_threebody_dense_feature_cache_available,
native_threebody_feature_cache_available,
native_threebody_lstsq_assemble_available,
)
from ufp.terms.model import UFPModel
from ufp.terms.threebody import SplineThreeBodyTerm
from ufp.terms.twobody import SplinePairTerm
from ufp.training import (
ASEAtomsDataset,
LossWeights,
build_ase_dataloader,
evaluate_model,
train_one_epoch,
)
@dataclass(frozen=True)
class _ScenarioPreset:
    """Immutable recipe describing one synthetic benchmark scenario.

    A preset bundles the default workload, the loss weighting, and the three
    factories needed to materialize a scenario: structure generation, model
    construction, and teacher coefficients.
    """

    # Registry key of the scenario.
    name: str
    # Human-readable summary of the toy problem.
    description: str
    # Default split sizes, batch size, epochs, learning rate, and CG checkpoints.
    workload: BenchmarkWorkloadDefaults
    # Energy/force weights applied by both the training and least-squares paths.
    loss_weights: LossWeights
    # ``build_atoms(count, seed)`` returns ``count`` unlabeled structures.
    build_atoms: Callable[[int, int], list[ase.Atoms]]
    # ``build_model(pair_coeffs, threebody_coeffs)`` returns a model using them.
    build_model: Callable[[object, object | None], UFPModel]
    # Returns ``(pair_coeffs, threebody_coeffs)``; the second entry may be None.
    build_teacher_coeffs: Callable[[], tuple[object, object | None]]
@dataclass(frozen=True)
class _ScenarioData:
    """Fully materialized benchmark scenario: labeled data, loaders, factories."""

    # Preset this scenario was built from.
    preset: _ScenarioPreset
    # Device and dtype used to label the frames and to run training/evaluation.
    device: torch.device
    dtype: torch.dtype
    # Checkpoint label supplied by the caller; stored as given.
    checkpoint: str
    # True when neighbor lists were built once up front and attached to the data.
    precomputed_neighbor_lists: bool
    # Energy/force weights taken from the preset.
    loss_weights: LossWeights
    # Teacher-labeled frames for each split.
    train_atoms: tuple[ase.Atoms, ...]
    validation_atoms: tuple[ase.Atoms, ...]
    test_atoms: tuple[ase.Atoms, ...]
    # Data loaders over the three splits.
    train_loader: torch.utils.data.DataLoader
    validation_loader: torch.utils.data.DataLoader
    test_loader: torch.utils.data.DataLoader
    # Least-squares samples derived from the training split.
    fit_samples: tuple[FitSample, ...]
    # Parameter vector of the teacher model that generated the labels.
    teacher_theta: torch.Tensor
    # Factory returning a fresh student model with all coefficients zeroed.
    make_student_model: Callable[[], UFPModel]
@dataclass(frozen=True)
class _CGCheckpoint:
    """Snapshot of a conjugate-gradient solve at one requested iteration."""

    # Copy of the solution iterate at the time of the snapshot.
    theta: torch.Tensor
    # Seconds elapsed since the CG solve started when the snapshot was taken.
    solve_time_s: float
# Named dtype / neighbor-list configurations, keyed by each entry's own name so
# the registry key and ``BenchmarkCheckpoint.name`` cannot drift apart.
_CHECKPOINTS: dict[str, BenchmarkCheckpoint] = {
    entry.name: entry
    for entry in (
        BenchmarkCheckpoint(
            name="baseline",
            description="Float64 everywhere with dynamic neighbor lists.",
            dtype=torch.float64,
            precompute_neighbor_lists=False,
        ),
        BenchmarkCheckpoint(
            name="device_default_dtype",
            description="Device-aware dtype with dynamic neighbor lists.",
            dtype="auto",
            precompute_neighbor_lists=False,
        ),
        BenchmarkCheckpoint(
            name="cached_neighbor_lists",
            description="Device-aware dtype plus reused neighbor lists.",
            dtype="auto",
            precompute_neighbor_lists=True,
        ),
    )
}
def _hh_pair_model(
    pair_coeffs: torch.Tensor,
    threebody_coeffs: torch.Tensor | None = None,
) -> UFPModel:
    """Build the pair-only hydrogen dimer benchmark model.

    Args:
        pair_coeffs: Spline coefficients for the single (1, 1) pair term.
        threebody_coeffs: Must be ``None``; this scenario has no three-body term.

    Returns:
        A model with one cubic pair spline (cutoff 2.5) for atomic type 1.

    Raises:
        ValueError: If three-body coefficients are supplied.
    """
    if threebody_coeffs is not None:
        raise ValueError("pair-only scenarios must not define three-body coefficients")
    hydrogen_pair = SplinePairTerm(
        cutoff=2.5,
        pair=(1, 1),
        coeffs=pair_coeffs,
        spline="cubic",
        full_support_start=0.0,
    )
    return UFPModel(pair_terms=[hydrogen_pair], atomic_types=[1])
def _triangle_model(
    pair_coeffs: object,
    threebody_coeffs: object | None,
) -> UFPModel:
    """Build the single-element triangle benchmark model.

    Args:
        pair_coeffs: Tensor of coefficients for the (1, 1) pair spline.
        threebody_coeffs: Tensor of coefficients for the three-body spline block.

    Returns:
        A model with one pair term and one three-body term, both cubic with
        cutoff 2.5, for atomic type 1.

    Raises:
        TypeError: If either coefficient argument is not a tensor.
        ValueError: If ``threebody_coeffs`` is ``None``.
    """
    # Validation order matters: callers see the pair-coefficient error first.
    if not isinstance(pair_coeffs, torch.Tensor):
        raise TypeError("triangle scenarios expect pair coefficients as a tensor")
    if threebody_coeffs is None:
        raise ValueError("triangle scenarios require three-body coefficients")
    if not isinstance(threebody_coeffs, torch.Tensor):
        raise TypeError("triangle scenarios expect three-body coefficients as a tensor")

    pair_term = SplinePairTerm(
        cutoff=2.5,
        pair=(1, 1),
        coeffs=pair_coeffs,
        spline="cubic",
        full_support_start=0.0,
    )
    triplet_term = SplineThreeBodyTerm(
        cutoff=2.5,
        atomic_types=[1],
        coeffs_by_triplet=threebody_coeffs,
        spline="cubic",
        full_support_start_xy=0.0,
        full_support_start_z=0.0,
    )
    return UFPModel(
        pair_terms=[pair_term],
        threebody_terms=[triplet_term],
        atomic_types=[1],
    )
def _quaternary_cluster_model(
    pair_coeffs: object,
    threebody_coeffs: object | None,
) -> UFPModel:
    """Build the multi-element molecular cluster benchmark model.

    Args:
        pair_coeffs: Dict mapping each unordered pair of atomic types
            (drawn from 1, 6, 8, 14) to its pair-spline coefficients.
        threebody_coeffs: Dict whose ``"all"`` entry holds the three-body
            coefficients.

    Returns:
        A model with one cubic pair term per element pair and a single cubic
        three-body term, all with cutoff 3.2.

    Raises:
        TypeError: If either coefficient argument is not a dict.
        ValueError: If ``threebody_coeffs`` is ``None``.
    """
    if not isinstance(pair_coeffs, dict):
        raise TypeError(
            "quaternary cluster scenarios expect pair coefficients as a dict"
        )
    if threebody_coeffs is None:
        raise ValueError("quaternary cluster scenarios require three-body coefficients")
    if not isinstance(threebody_coeffs, dict):
        raise TypeError(
            "quaternary cluster scenarios expect three-body coefficients as a dict"
        )

    species = [1, 6, 8, 14]
    pair_terms = []
    for element_pair in combinations_with_replacement(species, 2):
        pair_terms.append(
            SplinePairTerm(
                cutoff=3.2,
                pair=element_pair,
                coeffs=pair_coeffs[element_pair],
                spline="cubic",
                full_support_start=0.0,
            )
        )
    # The three-body term receives its own copy of the species list.
    triplet_term = SplineThreeBodyTerm(
        cutoff=3.2,
        atomic_types=list(species),
        coeffs_by_triplet=threebody_coeffs["all"],
        spline="cubic",
        full_support_start_xy=0.0,
        full_support_start_z=2.0,
    )
    return UFPModel(
        pair_terms=pair_terms,
        threebody_terms=[triplet_term],
        atomic_types=species,
    )
def _ternary_alloy_model(
    pair_coeffs: object,
    threebody_coeffs: object | None,
) -> UFPModel:
    """Build the ternary alloy cluster benchmark model.

    Args:
        pair_coeffs: Dict mapping each unordered pair of atomic types
            (drawn from 13, 28, 29) to its pair-spline coefficients.
        threebody_coeffs: Dict whose ``"all"`` entry holds the three-body
            coefficients.

    Returns:
        A model with one cubic pair term per element pair and a single cubic
        three-body term, all with cutoff 3.4.

    Raises:
        TypeError: If either coefficient argument is not a dict.
        ValueError: If ``threebody_coeffs`` is ``None``.
    """
    if not isinstance(pair_coeffs, dict):
        raise TypeError("ternary alloy scenarios expect pair coefficients as a dict")
    if threebody_coeffs is None:
        raise ValueError("ternary alloy scenarios require three-body coefficients")
    if not isinstance(threebody_coeffs, dict):
        raise TypeError(
            "ternary alloy scenarios expect three-body coefficients as a dict"
        )

    species = [13, 28, 29]
    pair_terms = []
    for element_pair in combinations_with_replacement(species, 2):
        pair_terms.append(
            SplinePairTerm(
                cutoff=3.4,
                pair=element_pair,
                coeffs=pair_coeffs[element_pair],
                spline="cubic",
                full_support_start=0.0,
            )
        )
    # The three-body term receives its own copy of the species list.
    triplet_term = SplineThreeBodyTerm(
        cutoff=3.4,
        atomic_types=list(species),
        coeffs_by_triplet=threebody_coeffs["all"],
        spline="cubic",
        full_support_start_xy=0.0,
        full_support_start_z=2.0,
    )
    return UFPModel(
        pair_terms=pair_terms,
        threebody_terms=[triplet_term],
        atomic_types=species,
    )
def _pair_teacher_coeffs() -> tuple[torch.Tensor, torch.Tensor | None]:
"""Return teacher coefficients for the pair-only scenario."""
pair_coeffs = torch.zeros(10, dtype=torch.float64)
pair_coeffs[2] = 0.22
pair_coeffs[3] = -0.31
pair_coeffs[4] = 0.27
pair_coeffs[5] = -0.12
pair_coeffs[6] = 0.05
return pair_coeffs, None
def _triangle_teacher_coeffs() -> tuple[torch.Tensor, torch.Tensor | None]:
"""Return teacher coefficients for the triangle scenario."""
pair_coeffs = torch.zeros(8, dtype=torch.float64)
pair_coeffs[2] = 0.18
pair_coeffs[3] = -0.27
pair_coeffs[4] = 0.31
pair_coeffs[5] = -0.10
threebody_coeffs = torch.zeros((1, 8, 8, 8), dtype=torch.float64)
threebody_coeffs[0, 4, 4, 4] = 0.06
threebody_coeffs[0, 4, 5, 4] = -0.04
threebody_coeffs[0, 5, 4, 4] = 0.03
threebody_coeffs[0, 4, 4, 5] = -0.02
threebody_coeffs[0, 5, 5, 4] = 0.015
return pair_coeffs, threebody_coeffs
def _quaternary_cluster_teacher_coeffs() -> tuple[object, object | None]:
    """Return teacher coefficients for the quaternary cluster scenario.

    Returns:
        ``(pair_coeffs, threebody_coeffs)``: a dict mapping each unordered
        pair of atomic types (1, 6, 8, 14) to a float64 vector of length 10,
        and a dict whose ``"all"`` entry is a float64 tensor of shape
        ``(n_triplet_categories, 10, 10, 10)``.
    """
    species = [1, 6, 8, 14]
    n_species = len(species)

    pair_coeffs: dict[tuple[int, int], torch.Tensor] = {}
    for rank, element_pair in enumerate(combinations_with_replacement(species, 2)):
        vector = torch.zeros(10, dtype=torch.float64)
        # Deterministic per-pair pattern so every pair channel differs.
        vector[2] = 0.03 * (rank + 1)
        vector[3] = -0.018 * ((rank % 4) + 1)
        vector[4] = 0.012 * (((rank + 2) % 5) + 1)
        vector[5] = -0.008 * (((2 * rank) % 3) + 1)
        vector[6] = 0.004 * (((rank + 1) % 6) + 1)
        pair_coeffs[element_pair] = vector

    # One category per (species, unordered species pair) combination.
    n_triplet_categories = n_species * (n_species * (n_species + 1) // 2)
    threebody = torch.zeros((n_triplet_categories, 10, 10, 10), dtype=torch.float64)
    for category in range(n_triplet_categories):
        magnitude = 0.0025 * ((category % 7) + 1)
        # Alternating sign; multiplying by +/-1.0 is exact in floating point.
        amplitude = (1.0 if category % 2 == 0 else -1.0) * magnitude
        x = 2 + (category % 2)
        y = 2 + ((category // 2) % 2)
        z = 2 + ((category // 4) % 2)
        threebody[category, x, y, z] = amplitude
        threebody[category, x + 1, y, z] = -0.5 * amplitude
        threebody[category, x, y + 1, z] = 0.35 * amplitude
    return pair_coeffs, {"all": threebody}
def _ternary_alloy_teacher_coeffs() -> tuple[object, object | None]:
    """Return teacher coefficients for the ternary alloy scenario.

    Returns:
        ``(pair_coeffs, threebody_coeffs)``: a dict mapping each unordered
        pair of atomic types (13, 28, 29) to a float64 vector of length 10,
        and a dict whose ``"all"`` entry is a float64 tensor of shape
        ``(n_triplet_categories, 9, 9, 9)``.
    """
    species = [13, 28, 29]
    n_species = len(species)

    pair_coeffs: dict[tuple[int, int], torch.Tensor] = {}
    for rank, element_pair in enumerate(combinations_with_replacement(species, 2)):
        vector = torch.zeros(10, dtype=torch.float64)
        # Deterministic per-pair pattern so every pair channel differs.
        vector[2] = 0.045 * (rank + 1)
        vector[3] = -0.025 * (((rank + 1) % 3) + 1)
        vector[4] = 0.018 * (((2 * rank) % 4) + 1)
        vector[5] = -0.010 * (((rank + 2) % 5) + 1)
        vector[6] = 0.006 * (((rank + 3) % 4) + 1)
        pair_coeffs[element_pair] = vector

    # One category per (species, unordered species pair) combination.
    n_triplet_categories = n_species * (n_species * (n_species + 1) // 2)
    threebody = torch.zeros((n_triplet_categories, 9, 9, 9), dtype=torch.float64)
    for category in range(n_triplet_categories):
        magnitude = 0.004 * ((category % 5) + 1)
        # Alternating sign; multiplying by +/-1.0 is exact in floating point.
        amplitude = (1.0 if category % 2 == 0 else -1.0) * magnitude
        x = 2 + (category % 3)
        y = 2 + ((category // 2) % 2)
        z = 2 + ((category // 3) % 2)
        threebody[category, x, y, z] = amplitude
        threebody[category, x + 1, y, z] = -0.45 * amplitude
        threebody[category, x, y + 1, z] = 0.30 * amplitude
        threebody[category, x, y, z + 1] = -0.20 * amplitude
    return pair_coeffs, {"all": threebody}
def _pair_atoms(count: int, seed: int) -> list[ase.Atoms]:
    """Generate random hydrogen dimer structures.

    Args:
        count: Number of dimers to create.
        seed: Seed for the NumPy random generator.

    Returns:
        ``count`` non-periodic H-H frames in an 8.0 cubic cell, with the
        second atom placed on the x axis at a distance drawn uniformly from
        [0.70, 1.85).
    """
    rng = np.random.default_rng(seed)
    # One scalar draw per frame keeps the random stream consumption sequential.
    separations = [float(rng.uniform(0.70, 1.85)) for _ in range(count)]
    return [
        ase.Atoms(
            symbols=["H", "H"],
            positions=[[0.0, 0.0, 0.0], [separation, 0.0, 0.0]],
            cell=np.eye(3) * 8.0,
            pbc=False,
        )
        for separation in separations
    ]
def _triangle_atoms(count: int, seed: int) -> list[ase.Atoms]:
    """Generate random single-element triangle structures.

    Args:
        count: Number of triangles to create.
        seed: Seed for the NumPy random generator.

    Returns:
        ``count`` non-periodic three-hydrogen frames in a 10.0 cubic cell,
        all lying in the z = 0 plane.
    """
    rng = np.random.default_rng(seed)

    def _draw_triangle():
        # Draw order (x edge, y edge, x offset) fixes the random stream.
        base_length = float(rng.uniform(0.85, 1.20))
        apex_height = float(rng.uniform(0.85, 1.20))
        apex_shift = float(rng.uniform(-0.15, 0.15))
        return ase.Atoms(
            symbols=["H", "H", "H"],
            positions=[
                [0.0, 0.0, 0.0],
                [base_length, 0.0, 0.0],
                [apex_shift, apex_height, 0.0],
            ],
            cell=np.eye(3) * 10.0,
            pbc=False,
        )

    return [_draw_triangle() for _ in range(count)]
def _quaternary_cluster_atoms(count: int, seed: int) -> list[ase.Atoms]:
    """Generate perturbed quaternary molecular cluster structures.

    Each frame starts from a fixed eight-atom Si/O/C/H template and is
    jittered, uniformly rescaled, rotated by small random angles about the
    x, y, and z axes, and jittered again.

    Args:
        count: Number of frames to create.
        seed: Seed for the NumPy random generator.

    Returns:
        ``count`` non-periodic frames in a 14.0 cubic cell.
    """
    rng = np.random.default_rng(seed)
    template_symbols = ["Si", "O", "O", "C", "H", "H", "H", "O"]
    template_positions = np.array(
        [
            [0.00, 0.00, 0.00],
            [1.62, 0.10, 0.05],
            [-1.56, -0.14, 0.08],
            [0.22, 1.72, -0.18],
            [0.34, 2.73, 0.32],
            [2.22, -0.54, -0.22],
            [-2.16, 0.48, -0.30],
            [0.08, -1.84, 0.26],
        ],
        dtype=float,
    )

    def _rotation(angles):
        """Compose rotations about x, y, then z for the given three angles."""
        cx, cy, cz = np.cos(angles)
        sx, sy, sz = np.sin(angles)
        rot_x = np.array([[1.0, 0.0, 0.0], [0.0, cx, -sx], [0.0, sx, cx]])
        rot_y = np.array([[cy, 0.0, sy], [0.0, 1.0, 0.0], [-sy, 0.0, cy]])
        rot_z = np.array([[cz, -sz, 0.0], [sz, cz, 0.0], [0.0, 0.0, 1.0]])
        return rot_z @ rot_y @ rot_x

    frames: list[ase.Atoms] = []
    for _ in range(count):
        # Random draws happen in a fixed order: jitter, scale, angles, jitter.
        coarse_jitter = rng.normal(scale=0.10, size=template_positions.shape)
        dilation = rng.uniform(0.94, 1.08)
        tilt_angles = rng.normal(scale=0.18, size=3)
        positions = (template_positions + coarse_jitter) * dilation
        positions = positions @ _rotation(tilt_angles).T
        positions = positions + rng.normal(scale=0.04, size=positions.shape)
        frames.append(
            ase.Atoms(
                symbols=template_symbols,
                positions=positions,
                cell=np.eye(3) * 14.0,
                pbc=False,
            )
        )
    return frames
def _ternary_alloy_atoms(count: int, seed: int) -> list[ase.Atoms]:
    """Generate perturbed ternary alloy cluster structures.

    Each frame starts from a fixed six-atom Al/Ni/Cu template and is
    jittered, uniformly rescaled, rotated by small random angles about the
    x, y, and z axes, and jittered again.

    Args:
        count: Number of frames to create.
        seed: Seed for the NumPy random generator.

    Returns:
        ``count`` non-periodic frames in a 14.0 cubic cell.
    """
    rng = np.random.default_rng(seed)
    template_symbols = ["Al", "Al", "Ni", "Ni", "Cu", "Cu"]
    template_positions = np.array(
        [
            [0.00, 0.00, 0.00],
            [2.45, 0.08, -0.06],
            [1.18, 2.06, 0.18],
            [-1.26, 2.02, -0.22],
            [-2.18, -0.18, 0.12],
            [0.12, -2.36, -0.10],
        ],
        dtype=float,
    )

    def _rotation(angles):
        """Compose rotations about x, y, then z for the given three angles."""
        cx, cy, cz = np.cos(angles)
        sx, sy, sz = np.sin(angles)
        rot_x = np.array([[1.0, 0.0, 0.0], [0.0, cx, -sx], [0.0, sx, cx]])
        rot_y = np.array([[cy, 0.0, sy], [0.0, 1.0, 0.0], [-sy, 0.0, cy]])
        rot_z = np.array([[cz, -sz, 0.0], [sz, cz, 0.0], [0.0, 0.0, 1.0]])
        return rot_z @ rot_y @ rot_x

    frames: list[ase.Atoms] = []
    for _ in range(count):
        # Random draws happen in a fixed order: jitter, scale, angles, jitter.
        coarse_jitter = rng.normal(scale=0.12, size=template_positions.shape)
        dilation = rng.uniform(0.95, 1.06)
        tilt_angles = rng.normal(scale=0.16, size=3)
        positions = (template_positions + coarse_jitter) * dilation
        positions = positions @ _rotation(tilt_angles).T
        positions = positions + rng.normal(scale=0.05, size=positions.shape)
        frames.append(
            ase.Atoms(
                symbols=template_symbols,
                positions=positions,
                cell=np.eye(3) * 14.0,
                pbc=False,
            )
        )
    return frames
# Registry of benchmark scenarios, keyed by each preset's own name so the
# lookup key and ``_ScenarioPreset.name`` cannot drift apart. Insertion order is
# pair_only, triangle_pair_threebody, ternary_alloy, quaternary_cluster.
_SCENARIOS: dict[str, _ScenarioPreset] = {
    preset.name: preset
    for preset in (
        _ScenarioPreset(
            name="pair_only",
            description="Diatomic H-H toy problem with one cubic two-body spline term.",
            workload=BenchmarkWorkloadDefaults(
                train_size=800,
                validation_size=200,
                test_size=200,
                training_batch_size=128,
                training_epochs=4,
                learning_rate=0.05,
                cg_checkpoints=(1, 2, 3, 4),
            ),
            loss_weights=LossWeights(energy=1.0, forces=2.0),
            build_atoms=_pair_atoms,
            build_model=_hh_pair_model,
            build_teacher_coeffs=_pair_teacher_coeffs,
        ),
        _ScenarioPreset(
            name="triangle_pair_threebody",
            description=(
                "Three-atom H triangle toy problem with one two-body spline term and "
                "one three-body spline block."
            ),
            workload=BenchmarkWorkloadDefaults(
                train_size=800,
                validation_size=200,
                test_size=200,
                training_batch_size=128,
                training_epochs=4,
                learning_rate=0.05,
                cg_checkpoints=(1, 2, 3, 4),
            ),
            loss_weights=LossWeights(energy=1.0, forces=5.0),
            build_atoms=_triangle_atoms,
            build_model=_triangle_model,
            build_teacher_coeffs=_triangle_teacher_coeffs,
        ),
        _ScenarioPreset(
            name="ternary_alloy",
            description=(
                "Six-atom Al/Ni/Cu alloy cluster with all 6 pair channels active and "
                "one three-category three-body spline block."
            ),
            workload=BenchmarkWorkloadDefaults(
                train_size=800,
                validation_size=200,
                test_size=200,
                training_batch_size=64,
                training_epochs=4,
                learning_rate=0.035,
                cg_checkpoints=(1, 2, 3, 4),
            ),
            loss_weights=LossWeights(energy=1.0, forces=6.0),
            build_atoms=_ternary_alloy_atoms,
            build_model=_ternary_alloy_model,
            build_teacher_coeffs=_ternary_alloy_teacher_coeffs,
        ),
        _ScenarioPreset(
            name="quaternary_cluster",
            description=(
                "Eight-atom H/C/O/Si cluster with all 10 pair channels active and one "
                "four-category three-body spline block."
            ),
            workload=BenchmarkWorkloadDefaults(
                train_size=800,
                validation_size=200,
                test_size=200,
                training_batch_size=32,
                training_epochs=4,
                learning_rate=0.03,
                cg_checkpoints=(1, 2, 3, 4),
            ),
            loss_weights=LossWeights(energy=1.0, forces=8.0),
            build_atoms=_quaternary_cluster_atoms,
            build_model=_quaternary_cluster_model,
            build_teacher_coeffs=_quaternary_cluster_teacher_coeffs,
        ),
    )
}
def _label_frames(
    teacher_model: UFPModel,
    frames: Sequence[ase.Atoms],
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[ase.Atoms, ...]:
    """Attach teacher-model energy and force labels to frames.

    Args:
        teacher_model: Model whose ``compute`` output provides the labels.
        frames: Unlabeled structures; they are copied, not modified.
        device: Device passed through to ``teacher_model.compute``.
        dtype: Dtype passed through to ``teacher_model.compute``.

    Returns:
        Copies of ``frames`` with ``info["energy"]`` set to a Python float and
        ``arrays["forces"]`` set to a NumPy array.

    Raises:
        RuntimeError: If the teacher output lacks energy or forces.
    """

    def _tag(atoms):
        prediction = teacher_model.compute(
            atoms,
            device=device,
            dtype=dtype,
            derive_forces=True,
        )
        if prediction.energy is None or prediction.forces is None:
            raise RuntimeError("teacher model must produce both energies and forces")
        labeled_copy = atoms.copy()
        # Only the first element of the flattened energy is stored.
        labeled_copy.info["energy"] = float(prediction.energy.reshape(-1)[0].item())
        labeled_copy.arrays["forces"] = prediction.forces.detach().cpu().numpy()
        return labeled_copy

    return tuple(_tag(atoms) for atoms in frames)
def _aligned_fit_samples(
    atoms_list: Sequence[ase.Atoms],
    *,
    loss_weights: LossWeights,
    neighbor_lists: Sequence[object | None] | None = None,
) -> tuple[FitSample, ...]:
    """Convert labeled frames into least-squares samples with metric-aligned weights.

    Args:
        atoms_list: Frames carrying ``info["energy"]`` and ``arrays["forces"]``.
        loss_weights: Energy/force weights to distribute over the rows.
        neighbor_lists: Optional per-frame neighbor lists; must match
            ``atoms_list`` in length when given.

    Returns:
        One ``FitSample`` per frame, all sharing the same energy and force
        weights.

    Raises:
        ValueError: If there are no frames or no atoms, or if
            ``neighbor_lists`` has a different length than ``atoms_list``.
    """
    frame_count = len(atoms_list)
    force_row_count = sum(len(atoms) * 3 for atoms in atoms_list)
    if frame_count <= 0 or force_row_count <= 0:
        raise ValueError(
            "benchmark datasets must contain at least one energy and force row"
        )

    # Dividing by the row counts gives uniform per-row weights, so the
    # least-squares objective matches the weighted global MSE used by the
    # training metrics.
    per_energy_row = loss_weights.energy / frame_count
    per_force_row = loss_weights.forces / force_row_count

    if neighbor_lists is None:
        neighbor_lists = [None] * frame_count

    samples = []
    for atoms, neighbor_list in zip(atoms_list, neighbor_lists, strict=True):
        samples.append(
            FitSample(
                atoms=atoms,
                neighbor_list=neighbor_list,
                energy=float(atoms.info["energy"]),
                forces=atoms.arrays["forces"],
                energy_weight=per_energy_row,
                force_weight=per_force_row,
            )
        )
    return tuple(samples)
def _zero_like_coeffs(coeffs: object) -> object:
    """Create zero-filled coefficient containers matching the teacher layout.

    Args:
        coeffs: A tensor, or a (possibly nested) dict of tensors.

    Returns:
        A new container with the same structure whose tensors are all zeros
        with the shape and dtype of the originals.

    Raises:
        TypeError: If a value is neither a tensor nor a dict.
    """
    if isinstance(coeffs, dict):
        # Recurse so nested dicts keep their keys and structure.
        return {name: _zero_like_coeffs(entry) for name, entry in coeffs.items()}
    if isinstance(coeffs, torch.Tensor):
        return torch.zeros_like(coeffs)
    raise TypeError(f"unsupported coefficient container type: {type(coeffs)}")
def _build_scenario(
    scenario: str,
    *,
    seed: int,
    device: torch.device,
    dtype: torch.dtype,
    checkpoint: str,
    precompute_neighbor_lists: bool,
    train_size: int | None,
    validation_size: int | None,
    test_size: int | None,
    training_batch_size: int | None,
) -> _ScenarioData:
    """Materialize datasets, loaders, and model factories for one benchmark scenario.

    Args:
        scenario: Key into the scenario registry.
        seed: Seed forwarded to the preset's structure generator.
        device: Device used for the teacher model and for loader pinning.
        dtype: Dtype used for the teacher model.
        checkpoint: Label stored on the result as given.
        precompute_neighbor_lists: Build one neighbor list per frame up front
            and attach them to the datasets and fit samples.
        train_size: Training frames; ``None`` uses the preset default.
        validation_size: Validation frames; ``None`` uses the preset default.
        test_size: Test frames; ``None`` uses the preset default.
        training_batch_size: Training batch size; ``None`` uses the preset
            default.

    Returns:
        The populated ``_ScenarioData``.

    Raises:
        ValueError: For an unknown scenario, a non-positive size or batch
            size, or when neighbor lists are requested but the teacher model
            has no cutoff.
    """
    preset = _SCENARIOS.get(scenario)
    if preset is None:
        choices = ", ".join(sorted(_SCENARIOS))
        raise ValueError(f"unknown scenario '{scenario}'. Expected one of: {choices}")

    defaults = preset.workload

    def _resolve(override, fallback):
        """Return the explicit override as an int, else the preset default."""
        return fallback if override is None else int(override)

    n_train = _resolve(train_size, defaults.train_size)
    n_validation = _resolve(validation_size, defaults.validation_size)
    n_test = _resolve(test_size, defaults.test_size)
    batch_size = _resolve(training_batch_size, defaults.training_batch_size)
    if any(value <= 0 for value in (n_train, n_validation, n_test, batch_size)):
        raise ValueError("benchmark sizes and batch size must be positive")

    # The teacher model both labels the data and defines the reference theta.
    pair_coeffs, threebody_coeffs = preset.build_teacher_coeffs()
    teacher_model = preset.build_model(pair_coeffs, threebody_coeffs)
    teacher_model.to(device=device, dtype=dtype)
    teacher_theta = (
        ParameterLayout.from_model(teacher_model).current_true_vector().detach().clone()
    )

    # All splits come from one generated sequence, cut into consecutive windows.
    unlabeled = preset.build_atoms(n_train + n_validation + n_test, seed)
    labeled = _label_frames(
        teacher_model,
        unlabeled,
        device=device,
        dtype=dtype,
    )

    neighbor_lists = None
    if precompute_neighbor_lists:
        if teacher_model.cutoff is None:
            raise ValueError("precomputed neighbor lists require a model cutoff")
        neighbor_lists = [
            build_neighbor_list(
                atoms=item,
                cutoff=teacher_model.cutoff,
                backend=teacher_model.neighbor_backend,
                arrays="torch",
            )
            for item in labeled
        ]

    train_window = slice(0, n_train)
    validation_window = slice(n_train, n_train + n_validation)
    test_window = slice(n_train + n_validation, None)

    def _neighbor_lists_for(window):
        """Return the neighbor lists for one split, or None when not cached."""
        return None if neighbor_lists is None else neighbor_lists[window]

    def _make_loader(atoms, window, loader_batch_size):
        """Wrap one split in an unshuffled loader, pinned only for CUDA."""
        return build_ase_dataloader(
            ASEAtomsDataset.from_atoms(
                atoms,
                forces_key="forces",
                neighbor_lists=_neighbor_lists_for(window),
            ),
            batch_size=loader_batch_size,
            shuffle=False,
            pin_memory=device.type == "cuda",
        )

    train_atoms = labeled[train_window]
    validation_atoms = labeled[validation_window]
    test_atoms = labeled[test_window]

    # Validation and test splits are each evaluated as a single batch.
    train_loader = _make_loader(train_atoms, train_window, batch_size)
    validation_loader = _make_loader(validation_atoms, validation_window, n_validation)
    test_loader = _make_loader(test_atoms, test_window, n_test)

    fit_samples = _aligned_fit_samples(
        train_atoms,
        loss_weights=preset.loss_weights,
        neighbor_lists=_neighbor_lists_for(train_window),
    )

    def make_student_model() -> UFPModel:
        """Create an unfitted student model with zeroed coefficients."""
        zero_pair = _zero_like_coeffs(pair_coeffs)
        zero_threebody = (
            None if threebody_coeffs is None else _zero_like_coeffs(threebody_coeffs)
        )
        return preset.build_model(zero_pair, zero_threebody)

    return _ScenarioData(
        preset=preset,
        device=device,
        dtype=dtype,
        checkpoint=checkpoint,
        precomputed_neighbor_lists=precompute_neighbor_lists,
        loss_weights=preset.loss_weights,
        train_atoms=train_atoms,
        validation_atoms=validation_atoms,
        test_atoms=test_atoms,
        train_loader=train_loader,
        validation_loader=validation_loader,
        test_loader=test_loader,
        fit_samples=fit_samples,
        teacher_theta=teacher_theta,
        make_student_model=make_student_model,
    )
def _problem_normal_diagonal(problem: BlockLinearProblem) -> torch.Tensor:
    """Return the diagonal of the unregularized normal matrix.

    Args:
        problem: Assembled block least-squares problem.

    Returns:
        A vector of length ``problem.layout.size`` on the problem's device and
        dtype, accumulating each block's diagonal contribution across batches.
    """
    layout = problem.layout
    accumulated = torch.zeros(
        (layout.size,),
        dtype=problem.dtype,
        device=problem.device,
    )
    for batch in problem.batches:
        for block_key, block_matrix in batch.matrices.items():
            # Each block writes into its own slice of the parameter vector.
            accumulated[layout.theta_slice(block_key)] += _block_matrix_diagonal(
                block_matrix
            )
    return accumulated
def _problem_matrix_storage_elements(problem: BlockLinearProblem) -> int:
"""Return stored matrix value/index elements for a least-squares problem."""
return int(
sum(_block_solve_batch_storage_elements(batch) for batch in problem.batches)
)
def _problem_matrix_storage_nbytes(problem: BlockLinearProblem) -> int:
"""Return approximate matrix storage bytes for a least-squares problem."""
return int(
sum(_block_solve_batch_storage_nbytes(batch) for batch in problem.batches)
)
def _native_extension_availability() -> dict[str, bool]:
    """Return optional native extension availability relevant to least-squares.

    Returns:
        A dict with one entry per probed extension. The CUDA assembler is
        only probed when ``torch.cuda.is_available()`` is true; otherwise its
        entry is ``False``.
    """
    feature_cache_cpu = native_threebody_feature_cache_available(
        device="cpu",
        spline="cubic",
    )
    dense_feature_cache_cpu = native_threebody_dense_feature_cache_available(
        device="cpu"
    )
    assemble_cpu = native_threebody_lstsq_assemble_available(
        device="cpu",
        spline="cubic",
    )
    if torch.cuda.is_available():
        assemble_cuda = native_threebody_lstsq_assemble_available(
            device="cuda",
            spline="cubic",
        )
    else:
        # Skip the CUDA probe entirely on machines without CUDA.
        assemble_cuda = False
    return {
        "threebody_feature_cache_cpu": feature_cache_cpu,
        "threebody_dense_feature_cache_cpu": dense_feature_cache_cpu,
        "threebody_lstsq_assemble_cpu": assemble_cpu,
        "threebody_lstsq_assemble_cuda": assemble_cuda,
    }
def _cg_solve_with_checkpoints(
    problem: BlockLinearProblem,
    checkpoints: Sequence[int],
    *,
    tolerance: float,
) -> dict[int, _CGCheckpoint]:
    """Run preconditioned CG and capture selected iterate checkpoints.

    The solve starts from zero, uses the diagonal of the regularized normal
    matrix as preconditioner, and runs for at most ``max(checkpoints)``
    iterations.

    Args:
        problem: Least-squares problem providing ``rhs``,
            ``regularization_diagonal``, and ``normal_matvec``.
        checkpoints: Iteration counts to snapshot; non-positive entries and
            duplicates are dropped.
        tolerance: Absolute threshold used for both the initial
            preconditioned residual check and the per-iteration residual norm.

    Returns:
        A dict from each requested iteration to its snapshot. Iterations
        requested beyond the point where the solve stopped early receive the
        final iterate and the total elapsed time.
    """
    requested = tuple(sorted({int(step) for step in checkpoints if int(step) > 0}))
    if not requested:
        return {}
    wanted = frozenset(requested)
    last_step = requested[-1]

    rhs = problem.rhs()
    solution = torch.zeros_like(rhs)
    residual = rhs.clone()

    # Jacobi preconditioner; near-zero entries fall back to 1 to avoid dividing
    # by (almost) zero.
    diagonal = problem.regularization_diagonal() + _problem_normal_diagonal(problem)
    safe_diagonal = torch.where(
        torch.abs(diagonal) > 1.0e-14,
        diagonal,
        torch.ones_like(diagonal),
    )
    preconditioned = residual / safe_diagonal
    search = preconditioned.clone()
    rz_previous = torch.dot(residual, preconditioned)

    snapshots: dict[int, _CGCheckpoint] = {}
    solve_start = time.perf_counter()
    completed = 0

    # Already converged at the zero start: every checkpoint gets that iterate.
    if torch.sqrt(torch.clamp(rz_previous, min=0.0)) <= tolerance:
        elapsed = time.perf_counter() - solve_start
        for step in requested:
            snapshots[step] = _CGCheckpoint(theta=solution.clone(), solve_time_s=elapsed)
        return snapshots

    for iteration in range(1, last_step + 1):
        applied = problem.normal_matvec(search)
        curvature = torch.dot(search, applied)
        if torch.abs(curvature) <= 1.0e-30:
            # Breakdown: no update was applied in this iteration.
            completed = iteration - 1
            break
        step_size = rz_previous / curvature
        solution = solution + step_size * search
        residual = residual - step_size * applied
        completed = iteration
        if iteration in wanted:
            snapshots[iteration] = _CGCheckpoint(
                theta=solution.clone(),
                solve_time_s=time.perf_counter() - solve_start,
            )
        if torch.linalg.norm(residual) <= tolerance:
            break
        preconditioned = residual / safe_diagonal
        rz_next = torch.dot(residual, preconditioned)
        if torch.abs(rz_previous) <= 1.0e-30:
            # Guard the division below.
            break
        search = preconditioned + (rz_next / rz_previous) * search
        rz_previous = rz_next

    if completed < last_step:
        # Early stop: later checkpoints share the final iterate and timing.
        elapsed = time.perf_counter() - solve_start
        for step in requested:
            if step > completed:
                snapshots[step] = _CGCheckpoint(
                    theta=solution.clone(), solve_time_s=elapsed
                )
    return snapshots
def _make_point(
    *,
    method: str,
    budget_kind: str,
    budget: int | None,
    label: str,
    optimize_time_s: float,
    train_loss: float,
    validation_metrics,
    test_metrics,
) -> BenchmarkPoint:
    """Package raw metrics into one benchmark record.

    Args:
        method: Name of the optimization method.
        budget_kind: What ``budget`` counts (for example epochs or iterations).
        budget: Budget spent, or ``None`` when no count applies.
        label: Display label for the point.
        optimize_time_s: Optimization time in seconds.
        train_loss: Training loss at this point.
        validation_metrics: Object exposing ``loss``, ``energy_mae``, and
            ``forces_mae`` for the validation split.
        test_metrics: Same interface for the test split.

    Returns:
        A ``BenchmarkPoint`` with times and losses coerced to ``float``; the
        MAE values are passed through unchanged.
    """
    elapsed = float(optimize_time_s)
    loss_on_train = float(train_loss)
    loss_on_validation = float(validation_metrics.loss)
    loss_on_test = float(test_metrics.loss)
    return BenchmarkPoint(
        method=method,
        budget_kind=budget_kind,
        budget=budget,
        label=label,
        optimize_time_s=elapsed,
        train_loss=loss_on_train,
        validation_loss=loss_on_validation,
        validation_energy_mae=validation_metrics.energy_mae,
        validation_forces_mae=validation_metrics.forces_mae,
        test_loss=loss_on_test,
        test_energy_mae=test_metrics.energy_mae,
        test_forces_mae=test_metrics.forces_mae,
    )
def _evaluate_split_metrics(
    model: UFPModel,
    scenario: _ScenarioData,
    *,
    device: torch.device | None = None,
):
    """Evaluate the model on the validation and test splits.

    Args:
        model: Model to evaluate.
        scenario: Scenario providing the loaders, dtype, and loss weights.
        device: Evaluation device; ``None`` uses ``scenario.device``.

    Returns:
        ``(validation_metrics, test_metrics)`` as returned by
        ``evaluate_model``.
    """
    target_device = scenario.device if device is None else device

    def _score(loader, split_name):
        return evaluate_model(
            model,
            loader,
            split=split_name,
            dtype=scenario.dtype,
            device=target_device,
            loss_weights=scenario.loss_weights,
        )

    # Validation is evaluated before test.
    validation_metrics = _score(scenario.validation_loader, "validation")
    test_metrics = _score(scenario.test_loader, "test")
    return validation_metrics, test_metrics
def _run_training_benchmark(
    scenario: _ScenarioData,
    *,
    epochs: int,
    learning_rate: float,
) -> tuple[BenchmarkPoint, ...]:
    """Train a zero-initialized student with AdamW and record one point per epoch.

    Args:
        scenario: Scenario providing the student factory, loaders, device,
            dtype, and loss weights.
        epochs: Number of training epochs to run.
        learning_rate: AdamW learning rate.

    Returns:
        ``epochs + 1`` benchmark points: an ``epoch=0`` baseline followed by
        one point after every epoch. ``optimize_time_s`` accumulates only the
        time spent inside ``train_one_epoch``; evaluation time is excluded.
    """
    model = scenario.make_student_model()
    model.to(device=scenario.device, dtype=scenario.dtype)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    def _record(epoch: int, elapsed_s: float, train_loss: float) -> BenchmarkPoint:
        """Evaluate the held-out splits and package one per-epoch point."""
        validation_metrics, test_metrics = _evaluate_split_metrics(model, scenario)
        return _make_point(
            method="training_adam",
            budget_kind="epoch",
            budget=epoch,
            label=f"epoch={epoch}",
            optimize_time_s=elapsed_s,
            train_loss=train_loss,
            validation_metrics=validation_metrics,
            test_metrics=test_metrics,
        )

    # Baseline before any optimizer step: the train loss is measured with a
    # separate evaluation pass.
    initial_train = evaluate_model(
        model,
        scenario.train_loader,
        split="train",
        dtype=scenario.dtype,
        device=scenario.device,
        loss_weights=scenario.loss_weights,
    )
    records: list[BenchmarkPoint] = [_record(0, 0.0, initial_train.loss)]

    optimize_time_s = 0.0
    for epoch in range(1, epochs + 1):
        start = time.perf_counter()
        train_metrics = train_one_epoch(
            model,
            scenario.train_loader,
            optimizer=optimizer,
            dtype=scenario.dtype,
            device=scenario.device,
            loss_weights=scenario.loss_weights,
        )
        optimize_time_s += time.perf_counter() - start
        # Every epoch is a checkpoint, so a point is recorded unconditionally.
        # The train loss is the value reported by ``train_one_epoch``.
        records.append(_record(epoch, optimize_time_s, train_metrics.loss))
    return tuple(records)
def _run_leastsquares_benchmark(
    scenario: _ScenarioData,
    *,
    cg_checkpoints: Sequence[int],
    leastsquares_batch_size: int,
    direct_solver: str,
    assembly_contract: str,
    matrix_storage: str,
) -> tuple[tuple[BenchmarkPoint, ...], float, float, int, int, int]:
    """Run direct and checkpointed least-squares benchmark solves.

    The least-squares path always runs on the CPU, independent of
    ``scenario.device``.

    Args:
        scenario: Scenario providing the student factory, loaders, fit
            samples, dtype, and loss weights.
        cg_checkpoints: CG iteration counts at which to record a point.
        leastsquares_batch_size: Batch size used when assembling the problem.
        direct_solver: Solver name for an additional direct solve, or
            ``"none"`` to skip it.
        assembly_contract: Forwarded to ``LinearFitter``.
        matrix_storage: Forwarded to ``LinearFitter``.

    Returns:
        ``(records, build_time_s, solve_time_s, n_rows,
        matrix_storage_elements, matrix_storage_bytes)``. ``solve_time_s`` is
        the direct-solve time when a direct solver ran, otherwise the largest
        CG checkpoint time.
    """
    cpu = torch.device("cpu")
    model = scenario.make_student_model()
    model.to(device=cpu, dtype=scenario.dtype)
    fitter = LinearFitter(
        model,
        fit_energy=True,
        fit_forces=True,
        solver="cg",
        dtype=scenario.dtype,
        device=cpu,
        assembly_contract=assembly_contract,
        matrix_storage=matrix_storage,
    )

    def _held_out_metrics():
        """Evaluate the current model state on validation and test, on CPU."""
        return _evaluate_split_metrics(model, scenario, device=cpu)

    # Baseline point for the untouched (zero-coefficient) student.
    initial_train = evaluate_model(
        model,
        scenario.train_loader,
        split="train",
        dtype=scenario.dtype,
        device=cpu,
        loss_weights=scenario.loss_weights,
    )
    initial_validation, initial_test = _held_out_metrics()
    records = [
        _make_point(
            method="leastsquares_cg",
            budget_kind="iteration",
            budget=0,
            label="iter=0",
            optimize_time_s=0.0,
            train_loss=initial_train.loss,
            validation_metrics=initial_validation,
            test_metrics=initial_test,
        )
    ]

    build_start = time.perf_counter()
    problem = fitter.build_problem(
        scenario.fit_samples,
        batch_size=leastsquares_batch_size,
    )
    build_time_s = time.perf_counter() - build_start

    n_rows = problem.n_rows
    matrix_storage_elements = _problem_matrix_storage_elements(problem)
    matrix_storage_bytes = _problem_matrix_storage_nbytes(problem)

    snapshots = _cg_solve_with_checkpoints(
        problem,
        cg_checkpoints,
        tolerance=1.0e-10,
    )
    solve_time_s = max(
        (snapshot.solve_time_s for snapshot in snapshots.values()),
        default=0.0,
    )

    for iteration, snapshot in sorted(snapshots.items(), key=lambda item: item[0]):
        # Load the iterate into the model before scoring the held-out splits.
        fitter.write_back(snapshot.theta)
        validation_metrics, test_metrics = _held_out_metrics()
        records.append(
            _make_point(
                method="leastsquares_cg",
                budget_kind="iteration",
                budget=iteration,
                label=f"iter={iteration}",
                # Problem assembly counts toward the optimization budget.
                optimize_time_s=build_time_s + snapshot.solve_time_s,
                train_loss=float(problem.objective(snapshot.theta).item()),
                validation_metrics=validation_metrics,
                test_metrics=test_metrics,
            )
        )

    if direct_solver == "none":
        return (
            tuple(records),
            build_time_s,
            solve_time_s,
            n_rows,
            matrix_storage_elements,
            matrix_storage_bytes,
        )

    direct_start = time.perf_counter()
    direct_theta = problem.solve(
        solver=direct_solver,
        cg_tolerance=1.0e-10,
        cg_max_iter=None,
    )
    direct_solve_time_s = time.perf_counter() - direct_start
    # The reported solve time switches to the direct solve when one runs.
    solve_time_s = direct_solve_time_s
    fitter.write_back(direct_theta)
    direct_validation, direct_test = _held_out_metrics()
    records.append(
        _make_point(
            method="leastsquares_direct",
            budget_kind="solver",
            budget=None,
            label=direct_solver,
            optimize_time_s=build_time_s + direct_solve_time_s,
            train_loss=float(problem.objective(direct_theta).item()),
            validation_metrics=direct_validation,
            test_metrics=direct_test,
        )
    )
    return (
        tuple(records),
        build_time_s,
        solve_time_s,
        n_rows,
        matrix_storage_elements,
        matrix_storage_bytes,
    )
# Public entry point of the benchmark module.
def run_leastsquares_vs_training_benchmark(
    *,
    scenario: str = "triangle_pair_threebody",
    seed: int = 0,
    device: str | torch.device | None = None,
    dtype: str | torch.dtype | None = "auto",
    checkpoint: str = "custom",
    precompute_neighbor_lists: bool = False,
    train_size: int | None = None,
    validation_size: int | None = None,
    test_size: int | None = None,
    training_batch_size: int | None = None,
    leastsquares_batch_size: int | None = None,
    training_epochs: int | None = None,
    learning_rate: float | None = None,
    cg_checkpoints: Sequence[int] | None = None,
    direct_solver: str = "none",
    assembly_contract: str = "term",
    matrix_storage: str = "auto",
) -> BenchmarkResult:
    """Benchmark gradient training against least-squares on one scenario.

    The scenario is built once, the training benchmark runs first, and the
    least-squares benchmark runs second; both record sets are merged into a
    single ``BenchmarkResult``.

    Args:
        scenario: Scenario name forwarded to ``_build_scenario``.
        seed: Seed forwarded to ``_build_scenario`` and stored in the result.
        device: Device request, resolved through ``resolve_device``.
        dtype: Dtype request, resolved through ``resolve_dtype``.
        checkpoint: Label forwarded to the scenario and stored in the result.
        precompute_neighbor_lists: Forwarded to ``_build_scenario``.
        train_size: Forwarded to ``_build_scenario``.
        validation_size: Forwarded to ``_build_scenario``.
        test_size: Forwarded to ``_build_scenario``.
        training_batch_size: Forwarded to ``_build_scenario``.
        leastsquares_batch_size: Least-squares batch size; ``None`` reuses
            the batch size of the scenario's training loader.
        training_epochs: Epoch count; ``None`` uses the preset workload value.
        learning_rate: Learning rate; ``None`` uses the preset workload value.
        cg_checkpoints: CG iteration counts to record; ``None`` uses the
            preset workload value. Explicit values are converted to ``int``,
            non-positive entries are dropped, and the rest are de-duplicated
            and sorted.
        direct_solver: Forwarded to ``_run_leastsquares_benchmark``.
        assembly_contract: Forwarded to ``_run_leastsquares_benchmark``.
        matrix_storage: Forwarded to ``_run_leastsquares_benchmark``.

    Returns:
        A ``BenchmarkResult`` whose ``records`` are the training records
        followed by the least-squares records.

    Raises:
        ValueError: If the resolved epoch count or least-squares batch size
            is not positive.
    """
    target_device = resolve_device(device)
    target_dtype = resolve_dtype(target_device, dtype)
    data = _build_scenario(
        scenario,
        seed=seed,
        device=target_device,
        dtype=target_dtype,
        checkpoint=checkpoint,
        precompute_neighbor_lists=precompute_neighbor_lists,
        train_size=train_size,
        validation_size=validation_size,
        test_size=test_size,
        training_batch_size=training_batch_size,
    )
    preset = data.preset

    # Explicit arguments win; anything left as None falls back to the preset.
    if training_epochs is None:
        epoch_count = preset.workload.training_epochs
    else:
        epoch_count = int(training_epochs)

    if learning_rate is None:
        step_size = preset.workload.learning_rate
    else:
        step_size = float(learning_rate)

    if cg_checkpoints is None:
        cg_steps = preset.workload.cg_checkpoints
    else:
        positive_steps = {int(step) for step in cg_checkpoints if int(step) > 0}
        cg_steps = tuple(sorted(positive_steps))

    if leastsquares_batch_size is None:
        lstsq_batch = data.train_loader.batch_size
    else:
        lstsq_batch = int(leastsquares_batch_size)

    if epoch_count <= 0 or lstsq_batch <= 0:
        raise ValueError(
            "training epochs and least-squares batch size must be positive"
        )

    training_records = _run_training_benchmark(
        data,
        epochs=epoch_count,
        learning_rate=step_size,
    )
    lstsq_outcome = _run_leastsquares_benchmark(
        data,
        cg_checkpoints=cg_steps,
        leastsquares_batch_size=lstsq_batch,
        direct_solver=direct_solver,
        assembly_contract=assembly_contract,
        matrix_storage=matrix_storage,
    )
    (
        lstsq_records,
        build_seconds,
        solve_seconds,
        row_count,
        storage_elements,
        storage_bytes,
    ) = lstsq_outcome

    parameter_count = int(data.teacher_theta.numel())
    weights = data.loss_weights
    weight_summary = {
        key: float(getattr(weights, key)) for key in ("energy", "forces", "stress")
    }
    return BenchmarkResult(
        scenario=preset.name,
        description=preset.description,
        seed=seed,
        checkpoint=checkpoint,
        device=str(target_device),
        leastsquares_device="cpu",
        dtype=str(target_dtype).replace("torch.", ""),
        precomputed_neighbor_lists=precompute_neighbor_lists,
        n_train=len(data.train_atoms),
        n_validation=len(data.validation_atoms),
        n_test=len(data.test_atoms),
        n_parameters=parameter_count,
        n_rows=row_count,
        training_batch_size=int(data.train_loader.batch_size),
        leastsquares_batch_size=lstsq_batch,
        training_epochs=epoch_count,
        cg_checkpoints=tuple(cg_steps),
        loss_weights=weight_summary,
        leastsquares_build_time_s=build_seconds,
        leastsquares_solve_time_s=solve_seconds,
        leastsquares_total_time_s=build_seconds + solve_seconds,
        leastsquares_matrix_storage_elements=storage_elements,
        leastsquares_matrix_storage_bytes=storage_bytes,
        assembly_contract=str(assembly_contract),
        matrix_storage=str(matrix_storage),
        native_extensions=_native_extension_availability(),
        direct_solver=direct_solver,
        records=training_records + lstsq_records,
    )
[docs]
def run_benchmark_checkpoints(
    *,
    checkpoints: Sequence[str],
    scenario: str = "triangle_pair_threebody",
    seed: int = 0,
    device: str | torch.device | None = None,
    train_size: int | None = None,
    validation_size: int | None = None,
    test_size: int | None = None,
    training_batch_size: int | None = None,
    leastsquares_batch_size: int | None = None,
    training_epochs: int | None = None,
    learning_rate: float | None = None,
    cg_checkpoints: Sequence[int] | None = None,
    direct_solver: str = "none",
    assembly_contract: str = "term",
    matrix_storage: str = "auto",
) -> tuple[BenchmarkResult, ...]:
    """Run one benchmark per named checkpoint preset and collect the results.

    Each name is looked up in ``_CHECKPOINTS``; the matching spec supplies
    the ``dtype``, ``checkpoint`` label and ``precompute_neighbor_lists``
    arguments of ``run_leastsquares_vs_training_benchmark``. Every other
    keyword is forwarded unchanged to each run.

    Args:
        checkpoints: Names of the checkpoint presets to run, in order.
        scenario: Scenario name used for every run.
        seed: Seed used for every run.
        device: Device request used for every run.
        train_size: Forwarded to each run.
        validation_size: Forwarded to each run.
        test_size: Forwarded to each run.
        training_batch_size: Forwarded to each run.
        leastsquares_batch_size: Forwarded to each run.
        training_epochs: Forwarded to each run.
        learning_rate: Forwarded to each run.
        cg_checkpoints: Forwarded to each run.
        direct_solver: Forwarded to each run.
        assembly_contract: Forwarded to each run.
        matrix_storage: Forwarded to each run.

    Returns:
        One ``BenchmarkResult`` per requested checkpoint, in request order.

    Raises:
        ValueError: If any requested name is not a key of ``_CHECKPOINTS``.
            All names are checked before the first benchmark starts.
    """
    # Materialize once so validation and execution see the same names even
    # if the caller passed a one-shot iterable.
    requested = tuple(checkpoints)

    # Fail fast: reject a bad name before spending time on earlier
    # checkpoints whose results would be thrown away by the exception.
    for checkpoint_name in requested:
        if checkpoint_name not in _CHECKPOINTS:
            choices = ", ".join(sorted(_CHECKPOINTS))
            raise ValueError(
                f"unknown checkpoint '{checkpoint_name}'. Expected one of: {choices}"
            )

    results: list[BenchmarkResult] = []
    for checkpoint_name in requested:
        spec = _CHECKPOINTS[checkpoint_name]
        results.append(
            run_leastsquares_vs_training_benchmark(
                scenario=scenario,
                seed=seed,
                device=device,
                dtype=spec.dtype,
                checkpoint=spec.name,
                precompute_neighbor_lists=spec.precompute_neighbor_lists,
                train_size=train_size,
                validation_size=validation_size,
                test_size=test_size,
                training_batch_size=training_batch_size,
                leastsquares_batch_size=leastsquares_batch_size,
                training_epochs=training_epochs,
                learning_rate=learning_rate,
                cg_checkpoints=cg_checkpoints,
                direct_solver=direct_solver,
                assembly_contract=assembly_contract,
                matrix_storage=matrix_storage,
            )
        )
    return tuple(results)
[docs]
def available_benchmark_scenarios() -> tuple[str, ...]:
    """List the registered toy microbenchmark scenario names in sorted order."""
    scenario_names = list(_SCENARIOS)
    scenario_names.sort()
    return tuple(scenario_names)
[docs]
def available_benchmark_checkpoints() -> tuple[str, ...]:
    """List the registered benchmark checkpoint preset names in sorted order."""
    checkpoint_names = list(_CHECKPOINTS)
    checkpoint_names.sort()
    return tuple(checkpoint_names)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the argparse parser that backs the benchmark command line.

    Returns:
        A parser exposing scenario/environment selection, named checkpoint
        runs, workload-size overrides, least-squares solver options and the
        ``--json`` output switch.
    """
    summary = (
        "Benchmark convergence of ufp.leastsquares against ufp.training on "
        "deterministic toy problems."
    )
    cli = argparse.ArgumentParser(description=summary)
    add = cli.add_argument

    # Scenario selection and execution environment.
    known_scenarios = scenario_choices(available_benchmark_scenarios())
    add("--scenario", choices=known_scenarios, default="triangle_pair_threebody")
    add("--seed", type=int, default=0)
    add("--device", choices=["auto", "cpu", "cuda", "gpu"], default="auto")
    add("--dtype", choices=["auto", "float32", "float64"], default="auto")

    # Single-run labelling versus named checkpoint presets.
    add(
        "--checkpoint-name",
        default="custom",
        help="Label used for a single benchmark run.",
    )
    add(
        "--precompute-neighbor-lists",
        action="store_true",
        help="Reuse one precomputed neighbor list per structure.",
    )
    add(
        "--checkpoint",
        action="append",
        choices=available_benchmark_checkpoints(),
        help="Run one of the named A/B benchmark checkpoints. Can be repeated.",
    )

    # Optional integer workload overrides; no default is given, so each
    # stays None unless supplied on the command line.
    integer_overrides = (
        "--train-size",
        "--validation-size",
        "--test-size",
        "--training-batch-size",
        "--leastsquares-batch-size",
        "--training-epochs",
    )
    for flag in integer_overrides:
        add(flag, type=int)
    add("--learning-rate", type=float)
    add(
        "--cg-checkpoints",
        type=lambda value: parse_positive_int_sequence(
            value,
            label="checkpoint lists",
        ),
    )

    # Least-squares solver configuration.
    add(
        "--direct-solver",
        choices=["none", "normal_equation_direct", "dense_lstsq"],
        default="none",
    )
    add(
        "--assembly-contract",
        choices=["block", "term"],
        default="term",
        help="Least-squares assembly dispatch contract to benchmark.",
    )
    add(
        "--matrix-storage",
        choices=["dense", "row_indexed", "column_chunked", "auto"],
        default="auto",
        help="In-memory block matrix storage representation to benchmark.",
    )

    # Output format.
    add(
        "--json",
        action="store_true",
        help="Emit JSON instead of the formatted text report.",
    )
    return cli
[docs]
def main(argv: Sequence[str] | None = None) -> int:
    """Parse command-line options, run the benchmarks and print the outcome.

    ``--scenario all`` expands to every key of ``_SCENARIOS`` in sorted
    order. When ``--checkpoint`` is given, the named presets are run through
    ``run_benchmark_checkpoints`` and ``--dtype``, ``--checkpoint-name`` and
    ``--precompute-neighbor-lists`` are not forwarded; otherwise one run per
    scenario goes through ``run_leastsquares_vs_training_benchmark``.

    Args:
        argv: Argument list handed to ``ArgumentParser.parse_args``.

    Returns:
        Always ``0``.
    """
    args = _build_arg_parser().parse_args(argv)

    if args.scenario == "all":
        scenario_names = sorted(_SCENARIOS)
    else:
        scenario_names = [str(args.scenario)]

    # Keyword arguments that both run paths accept.
    shared_options = {
        "seed": args.seed,
        "device": args.device,
        "train_size": args.train_size,
        "validation_size": args.validation_size,
        "test_size": args.test_size,
        "training_batch_size": args.training_batch_size,
        "leastsquares_batch_size": args.leastsquares_batch_size,
        "training_epochs": args.training_epochs,
        "learning_rate": args.learning_rate,
        "cg_checkpoints": args.cg_checkpoints,
        "direct_solver": args.direct_solver,
        "assembly_contract": args.assembly_contract,
        "matrix_storage": args.matrix_storage,
    }

    results = []
    if args.checkpoint:
        for scenario_name in scenario_names:
            results.extend(
                run_benchmark_checkpoints(
                    checkpoints=args.checkpoint,
                    scenario=scenario_name,
                    **shared_options,
                )
            )
    else:
        for scenario_name in scenario_names:
            results.append(
                run_leastsquares_vs_training_benchmark(
                    scenario=scenario_name,
                    dtype=args.dtype,
                    checkpoint=args.checkpoint_name,
                    precompute_neighbor_lists=args.precompute_neighbor_lists,
                    **shared_options,
                )
            )

    if not args.json:
        print("\n\n".join(format_benchmark_report(result) for result in results))
        return 0

    payload = [asdict(result) for result in results]
    # A single result is emitted as a bare object instead of a one-item list.
    document = payload[0] if len(payload) == 1 else payload
    print(json.dumps(document, indent=2))
    return 0
# Public names of this module. BenchmarkPoint, BenchmarkResult and
# BenchmarkCheckpoint are imported from ufp.benchmarks._common at the top of
# the file and listed here alongside the module's own entry points.
__all__ = [
"available_benchmark_checkpoints",
"available_benchmark_scenarios",
"BenchmarkPoint",
"BenchmarkResult",
"BenchmarkCheckpoint",
"format_benchmark_report",
"main",
"run_benchmark_checkpoints",
"run_leastsquares_vs_training_benchmark",
]