# Source code for ufp.benchmarks._leastsquares_vs_training

"""Least-squares-vs-training microbenchmarks built on toy UFP problems."""

from __future__ import annotations

import argparse
import json
import time
from dataclasses import asdict, dataclass
from itertools import combinations_with_replacement
from typing import Callable, Sequence

import ase
import numpy as np
import torch

from ufp.benchmarks._common import (
    BenchmarkCheckpoint,
    BenchmarkPoint,
    BenchmarkResult,
    BenchmarkWorkloadDefaults,
    format_number,
    parse_positive_int_sequence,
    resolve_device,
    resolve_dtype,
    scenario_choices,
)
from ufp.leastsquares import FitSample, LinearFitter
from ufp.leastsquares._block import (
    _block_matrix_diagonal,
    _block_solve_batch_storage_elements,
    _block_solve_batch_storage_nbytes,
)
from ufp.leastsquares._layout import ParameterLayout
from ufp.leastsquares.linear import BlockLinearProblem
from ufp.neighbors._neighbors import build_neighbor_list
from ufp.terms._threebody_kernels import (
    native_threebody_dense_feature_cache_available,
    native_threebody_feature_cache_available,
    native_threebody_lstsq_assemble_available,
)
from ufp.terms.model import UFPModel
from ufp.terms.threebody import SplineThreeBodyTerm
from ufp.terms.twobody import SplinePairTerm
from ufp.training import (
    ASEAtomsDataset,
    LossWeights,
    build_ase_dataloader,
    evaluate_model,
    train_one_epoch,
)


@dataclass(frozen=True)
class _ScenarioPreset:
    """Reusable definition of one synthetic benchmark scenario.

    A preset bundles the default workload, the loss weights, and the factory
    callables that ``_build_scenario`` invokes to generate structures, build
    the model, and obtain the teacher coefficients.
    """

    # Scenario identifier; the presets in ``_SCENARIOS`` use it as their key.
    name: str
    # Human-readable summary of the toy problem.
    description: str
    # Default dataset sizes, batch size, and other workload settings.
    workload: BenchmarkWorkloadDefaults
    # Energy/force weights shared by the training metrics and the fit samples.
    loss_weights: LossWeights
    # Called as ``build_atoms(count, seed)``; returns unlabeled frames.
    build_atoms: Callable[[int, int], list[ase.Atoms]]
    # Called as ``build_model(pair_coeffs, threebody_coeffs)``.
    build_model: Callable[[object, object | None], UFPModel]
    # Returns ``(pair_coeffs, threebody_coeffs)`` used to build the teacher.
    build_teacher_coeffs: Callable[[], tuple[object, object | None]]


@dataclass(frozen=True)
class _ScenarioData:
    """Materialized datasets, loaders, and factories for one benchmark run.

    Instances are produced by ``_build_scenario``; every field is filled in
    there from the preset and the caller-supplied runtime options.
    """

    # Preset this scenario was materialized from.
    preset: _ScenarioPreset
    # Device and dtype the teacher model was moved to and labels computed with.
    device: torch.device
    dtype: torch.dtype
    # Checkpoint label passed through unchanged from the caller.
    checkpoint: str
    # True when neighbor lists were built once up front and reused.
    precomputed_neighbor_lists: bool
    # Copied from ``preset.loss_weights``.
    loss_weights: LossWeights
    # Teacher-labeled frames, split contiguously as train / validation / test.
    train_atoms: tuple[ase.Atoms, ...]
    validation_atoms: tuple[ase.Atoms, ...]
    test_atoms: tuple[ase.Atoms, ...]
    # Unshuffled loaders over the three splits.
    train_loader: torch.utils.data.DataLoader
    validation_loader: torch.utils.data.DataLoader
    test_loader: torch.utils.data.DataLoader
    # Least-squares samples built from the training split only.
    fit_samples: tuple[FitSample, ...]
    # Detached copy of the teacher model's parameter vector.
    teacher_theta: torch.Tensor
    # Factory returning a fresh model whose coefficients are all zero.
    make_student_model: Callable[[], UFPModel]


@dataclass(frozen=True)
class _CGCheckpoint:
    """Snapshot of the conjugate-gradient iterate at one requested step."""

    # Copy of the solution vector at the time the snapshot was taken.
    theta: torch.Tensor
    # Seconds elapsed (``time.perf_counter``) since the solve loop was started.
    solve_time_s: float


# Named benchmark configurations, keyed by checkpoint name. Each entry pairs a
# dtype policy (a fixed torch dtype, or "auto" for the device-aware choice
# described in the entry) with whether neighbor lists are precomputed.
_CHECKPOINTS: dict[str, BenchmarkCheckpoint] = {
    "baseline": BenchmarkCheckpoint(
        name="baseline",
        description="Float64 everywhere with dynamic neighbor lists.",
        dtype=torch.float64,
        precompute_neighbor_lists=False,
    ),
    "device_default_dtype": BenchmarkCheckpoint(
        name="device_default_dtype",
        description="Device-aware dtype with dynamic neighbor lists.",
        dtype="auto",
        precompute_neighbor_lists=False,
    ),
    "cached_neighbor_lists": BenchmarkCheckpoint(
        name="cached_neighbor_lists",
        description="Device-aware dtype plus reused neighbor lists.",
        dtype="auto",
        precompute_neighbor_lists=True,
    ),
}


def _hh_pair_model(
    pair_coeffs: torch.Tensor,
    threebody_coeffs: torch.Tensor | None = None,
) -> UFPModel:
    """Build the pair-only hydrogen dimer benchmark model."""
    if threebody_coeffs is not None:
        raise ValueError("pair-only scenarios must not define three-body coefficients")

    return UFPModel(
        pair_terms=[
            SplinePairTerm(
                cutoff=2.5,
                pair=(1, 1),
                coeffs=pair_coeffs,
                spline="cubic",
                full_support_start=0.0,
            )
        ],
        atomic_types=[1],
    )


def _triangle_model(
    pair_coeffs: object,
    threebody_coeffs: object | None,
) -> UFPModel:
    """Build the single-element triangle benchmark model."""
    if not isinstance(pair_coeffs, torch.Tensor):
        raise TypeError("triangle scenarios expect pair coefficients as a tensor")
    if threebody_coeffs is None:
        raise ValueError("triangle scenarios require three-body coefficients")
    if not isinstance(threebody_coeffs, torch.Tensor):
        raise TypeError("triangle scenarios expect three-body coefficients as a tensor")

    return UFPModel(
        pair_terms=[
            SplinePairTerm(
                cutoff=2.5,
                pair=(1, 1),
                coeffs=pair_coeffs,
                spline="cubic",
                full_support_start=0.0,
            )
        ],
        threebody_terms=[
            SplineThreeBodyTerm(
                cutoff=2.5,
                atomic_types=[1],
                coeffs_by_triplet=threebody_coeffs,
                spline="cubic",
                full_support_start_xy=0.0,
                full_support_start_z=0.0,
            )
        ],
        atomic_types=[1],
    )


def _quaternary_cluster_model(
    pair_coeffs: object,
    threebody_coeffs: object | None,
) -> UFPModel:
    """Build the multi-element molecular cluster benchmark model."""
    if not isinstance(pair_coeffs, dict):
        raise TypeError(
            "quaternary cluster scenarios expect pair coefficients as a dict"
        )
    if threebody_coeffs is None:
        raise ValueError("quaternary cluster scenarios require three-body coefficients")
    if not isinstance(threebody_coeffs, dict):
        raise TypeError(
            "quaternary cluster scenarios expect three-body coefficients as a dict"
        )

    atomic_types = [1, 6, 8, 14]
    pair_terms = [
        SplinePairTerm(
            cutoff=3.2,
            pair=pair,
            coeffs=pair_coeffs[pair],
            spline="cubic",
            full_support_start=0.0,
        )
        for pair in combinations_with_replacement(atomic_types, 2)
    ]
    threebody_terms = [
        SplineThreeBodyTerm(
            cutoff=3.2,
            atomic_types=list(atomic_types),
            coeffs_by_triplet=threebody_coeffs["all"],
            spline="cubic",
            full_support_start_xy=0.0,
            full_support_start_z=2.0,
        )
    ]
    return UFPModel(
        pair_terms=pair_terms,
        threebody_terms=threebody_terms,
        atomic_types=atomic_types,
    )


def _ternary_alloy_model(
    pair_coeffs: object,
    threebody_coeffs: object | None,
) -> UFPModel:
    """Build the ternary alloy cluster benchmark model."""
    if not isinstance(pair_coeffs, dict):
        raise TypeError("ternary alloy scenarios expect pair coefficients as a dict")
    if threebody_coeffs is None:
        raise ValueError("ternary alloy scenarios require three-body coefficients")
    if not isinstance(threebody_coeffs, dict):
        raise TypeError(
            "ternary alloy scenarios expect three-body coefficients as a dict"
        )

    atomic_types = [13, 28, 29]
    pair_terms = [
        SplinePairTerm(
            cutoff=3.4,
            pair=pair,
            coeffs=pair_coeffs[pair],
            spline="cubic",
            full_support_start=0.0,
        )
        for pair in combinations_with_replacement(atomic_types, 2)
    ]
    threebody_terms = [
        SplineThreeBodyTerm(
            cutoff=3.4,
            atomic_types=list(atomic_types),
            coeffs_by_triplet=threebody_coeffs["all"],
            spline="cubic",
            full_support_start_xy=0.0,
            full_support_start_z=2.0,
        )
    ]
    return UFPModel(
        pair_terms=pair_terms,
        threebody_terms=threebody_terms,
        atomic_types=atomic_types,
    )


def _pair_teacher_coeffs() -> tuple[torch.Tensor, torch.Tensor | None]:
    """Return teacher coefficients for the pair-only scenario."""
    pair_coeffs = torch.zeros(10, dtype=torch.float64)
    pair_coeffs[2] = 0.22
    pair_coeffs[3] = -0.31
    pair_coeffs[4] = 0.27
    pair_coeffs[5] = -0.12
    pair_coeffs[6] = 0.05
    return pair_coeffs, None


def _triangle_teacher_coeffs() -> tuple[torch.Tensor, torch.Tensor | None]:
    """Return teacher coefficients for the triangle scenario."""
    pair_coeffs = torch.zeros(8, dtype=torch.float64)
    pair_coeffs[2] = 0.18
    pair_coeffs[3] = -0.27
    pair_coeffs[4] = 0.31
    pair_coeffs[5] = -0.10

    threebody_coeffs = torch.zeros((1, 8, 8, 8), dtype=torch.float64)
    threebody_coeffs[0, 4, 4, 4] = 0.06
    threebody_coeffs[0, 4, 5, 4] = -0.04
    threebody_coeffs[0, 5, 4, 4] = 0.03
    threebody_coeffs[0, 4, 4, 5] = -0.02
    threebody_coeffs[0, 5, 5, 4] = 0.015
    return pair_coeffs, threebody_coeffs


def _quaternary_cluster_teacher_coeffs() -> tuple[object, object | None]:
    """Return teacher coefficients for the quaternary cluster scenario."""
    atomic_types = [1, 6, 8, 14]
    pair_coeffs: dict[tuple[int, int], torch.Tensor] = {}
    for pair_index, pair in enumerate(combinations_with_replacement(atomic_types, 2)):
        coeffs = torch.zeros(10, dtype=torch.float64)
        coeffs[2] = 0.03 * (pair_index + 1)
        coeffs[3] = -0.018 * ((pair_index % 4) + 1)
        coeffs[4] = 0.012 * (((pair_index + 2) % 5) + 1)
        coeffs[5] = -0.008 * (((2 * pair_index) % 3) + 1)
        coeffs[6] = 0.004 * (((pair_index + 1) % 6) + 1)
        pair_coeffs[pair] = coeffs

    n_triplet_categories = len(atomic_types) * (
        len(atomic_types) * (len(atomic_types) + 1) // 2
    )
    threebody = torch.zeros((n_triplet_categories, 10, 10, 10), dtype=torch.float64)
    for triplet_index in range(n_triplet_categories):
        base = 0.0025 * ((triplet_index % 7) + 1)
        sign = 1.0 if triplet_index % 2 == 0 else -1.0
        x = 2 + (triplet_index % 2)
        y = 2 + ((triplet_index // 2) % 2)
        z = 2 + ((triplet_index // 4) % 2)
        threebody[triplet_index, x, y, z] = sign * base
        threebody[triplet_index, x + 1, y, z] = -0.5 * sign * base
        threebody[triplet_index, x, y + 1, z] = 0.35 * sign * base
    return pair_coeffs, {"all": threebody}


def _ternary_alloy_teacher_coeffs() -> tuple[object, object | None]:
    """Return teacher coefficients for the ternary alloy scenario."""
    atomic_types = [13, 28, 29]
    pair_coeffs: dict[tuple[int, int], torch.Tensor] = {}
    for pair_index, pair in enumerate(combinations_with_replacement(atomic_types, 2)):
        coeffs = torch.zeros(10, dtype=torch.float64)
        coeffs[2] = 0.045 * (pair_index + 1)
        coeffs[3] = -0.025 * (((pair_index + 1) % 3) + 1)
        coeffs[4] = 0.018 * (((2 * pair_index) % 4) + 1)
        coeffs[5] = -0.010 * (((pair_index + 2) % 5) + 1)
        coeffs[6] = 0.006 * (((pair_index + 3) % 4) + 1)
        pair_coeffs[pair] = coeffs

    n_triplet_categories = len(atomic_types) * (
        len(atomic_types) * (len(atomic_types) + 1) // 2
    )
    threebody = torch.zeros((n_triplet_categories, 9, 9, 9), dtype=torch.float64)
    for triplet_index in range(n_triplet_categories):
        base = 0.004 * ((triplet_index % 5) + 1)
        sign = 1.0 if triplet_index % 2 == 0 else -1.0
        x = 2 + (triplet_index % 3)
        y = 2 + ((triplet_index // 2) % 2)
        z = 2 + ((triplet_index // 3) % 2)
        threebody[triplet_index, x, y, z] = sign * base
        threebody[triplet_index, x + 1, y, z] = -0.45 * sign * base
        threebody[triplet_index, x, y + 1, z] = 0.30 * sign * base
        threebody[triplet_index, x, y, z + 1] = -0.20 * sign * base
    return pair_coeffs, {"all": threebody}


def _pair_atoms(count: int, seed: int) -> list[ase.Atoms]:
    """Generate random hydrogen dimer structures."""
    rng = np.random.default_rng(seed)
    frames: list[ase.Atoms] = []
    for _ in range(count):
        distance = float(rng.uniform(0.70, 1.85))
        frames.append(
            ase.Atoms(
                symbols=["H", "H"],
                positions=[[0.0, 0.0, 0.0], [distance, 0.0, 0.0]],
                cell=np.eye(3) * 8.0,
                pbc=False,
            )
        )
    return frames


def _triangle_atoms(count: int, seed: int) -> list[ase.Atoms]:
    """Generate random single-element triangle structures."""
    rng = np.random.default_rng(seed)
    frames: list[ase.Atoms] = []
    for _ in range(count):
        edge_x = float(rng.uniform(0.85, 1.20))
        edge_y = float(rng.uniform(0.85, 1.20))
        offset_x = float(rng.uniform(-0.15, 0.15))
        frames.append(
            ase.Atoms(
                symbols=["H", "H", "H"],
                positions=[
                    [0.0, 0.0, 0.0],
                    [edge_x, 0.0, 0.0],
                    [offset_x, edge_y, 0.0],
                ],
                cell=np.eye(3) * 10.0,
                pbc=False,
            )
        )
    return frames


def _quaternary_cluster_atoms(count: int, seed: int) -> list[ase.Atoms]:
    """Generate perturbed quaternary molecular cluster structures."""
    rng = np.random.default_rng(seed)
    base_symbols = ["Si", "O", "O", "C", "H", "H", "H", "O"]
    base_positions = np.array(
        [
            [0.00, 0.00, 0.00],
            [1.62, 0.10, 0.05],
            [-1.56, -0.14, 0.08],
            [0.22, 1.72, -0.18],
            [0.34, 2.73, 0.32],
            [2.22, -0.54, -0.22],
            [-2.16, 0.48, -0.30],
            [0.08, -1.84, 0.26],
        ],
        dtype=float,
    )
    frames: list[ase.Atoms] = []
    for _ in range(count):
        positions = base_positions.copy()
        positions += rng.normal(scale=0.10, size=positions.shape)
        positions *= rng.uniform(0.94, 1.08)
        angles = rng.normal(scale=0.18, size=3)
        cx, cy, cz = np.cos(angles)
        sx, sy, sz = np.sin(angles)
        rot_x = np.array([[1.0, 0.0, 0.0], [0.0, cx, -sx], [0.0, sx, cx]])
        rot_y = np.array([[cy, 0.0, sy], [0.0, 1.0, 0.0], [-sy, 0.0, cy]])
        rot_z = np.array([[cz, -sz, 0.0], [sz, cz, 0.0], [0.0, 0.0, 1.0]])
        positions = positions @ (rot_z @ rot_y @ rot_x).T
        positions += rng.normal(scale=0.04, size=positions.shape)
        frames.append(
            ase.Atoms(
                symbols=base_symbols,
                positions=positions,
                cell=np.eye(3) * 14.0,
                pbc=False,
            )
        )
    return frames


def _ternary_alloy_atoms(count: int, seed: int) -> list[ase.Atoms]:
    """Generate perturbed ternary alloy cluster structures."""
    rng = np.random.default_rng(seed)
    base_symbols = ["Al", "Al", "Ni", "Ni", "Cu", "Cu"]
    base_positions = np.array(
        [
            [0.00, 0.00, 0.00],
            [2.45, 0.08, -0.06],
            [1.18, 2.06, 0.18],
            [-1.26, 2.02, -0.22],
            [-2.18, -0.18, 0.12],
            [0.12, -2.36, -0.10],
        ],
        dtype=float,
    )
    frames: list[ase.Atoms] = []
    for _ in range(count):
        positions = base_positions.copy()
        positions += rng.normal(scale=0.12, size=positions.shape)
        positions *= rng.uniform(0.95, 1.06)
        angles = rng.normal(scale=0.16, size=3)
        cx, cy, cz = np.cos(angles)
        sx, sy, sz = np.sin(angles)
        rot_x = np.array([[1.0, 0.0, 0.0], [0.0, cx, -sx], [0.0, sx, cx]])
        rot_y = np.array([[cy, 0.0, sy], [0.0, 1.0, 0.0], [-sy, 0.0, cy]])
        rot_z = np.array([[cz, -sz, 0.0], [sz, cz, 0.0], [0.0, 0.0, 1.0]])
        positions = positions @ (rot_z @ rot_y @ rot_x).T
        positions += rng.normal(scale=0.05, size=positions.shape)
        frames.append(
            ase.Atoms(
                symbols=base_symbols,
                positions=positions,
                cell=np.eye(3) * 14.0,
                pbc=False,
            )
        )
    return frames


# Registry of synthetic benchmark scenarios, keyed by scenario name.
# ``_build_scenario`` looks presets up here and falls back to each preset's
# ``workload`` defaults when the caller does not override sizes or batch size.
_SCENARIOS: dict[str, _ScenarioPreset] = {
    "pair_only": _ScenarioPreset(
        name="pair_only",
        description="Diatomic H-H toy problem with one cubic two-body spline term.",
        workload=BenchmarkWorkloadDefaults(
            train_size=800,
            validation_size=200,
            test_size=200,
            training_batch_size=128,
            training_epochs=4,
            learning_rate=0.05,
            cg_checkpoints=(1, 2, 3, 4),
        ),
        loss_weights=LossWeights(energy=1.0, forces=2.0),
        build_atoms=_pair_atoms,
        build_model=_hh_pair_model,
        build_teacher_coeffs=_pair_teacher_coeffs,
    ),
    "triangle_pair_threebody": _ScenarioPreset(
        name="triangle_pair_threebody",
        description=(
            "Three-atom H triangle toy problem with one two-body spline term and "
            "one three-body spline block."
        ),
        workload=BenchmarkWorkloadDefaults(
            train_size=800,
            validation_size=200,
            test_size=200,
            training_batch_size=128,
            training_epochs=4,
            learning_rate=0.05,
            cg_checkpoints=(1, 2, 3, 4),
        ),
        loss_weights=LossWeights(energy=1.0, forces=5.0),
        build_atoms=_triangle_atoms,
        build_model=_triangle_model,
        build_teacher_coeffs=_triangle_teacher_coeffs,
    ),
    "ternary_alloy": _ScenarioPreset(
        name="ternary_alloy",
        # NOTE(review): the teacher three-body tensor for this scenario has 18
        # leading categories (3 * 6); "three-category" presumably refers to
        # the three atomic types -- confirm the wording.
        description=(
            "Six-atom Al/Ni/Cu alloy cluster with all 6 pair channels active and "
            "one three-category three-body spline block."
        ),
        workload=BenchmarkWorkloadDefaults(
            train_size=800,
            validation_size=200,
            test_size=200,
            training_batch_size=64,
            training_epochs=4,
            learning_rate=0.035,
            cg_checkpoints=(1, 2, 3, 4),
        ),
        loss_weights=LossWeights(energy=1.0, forces=6.0),
        build_atoms=_ternary_alloy_atoms,
        build_model=_ternary_alloy_model,
        build_teacher_coeffs=_ternary_alloy_teacher_coeffs,
    ),
    "quaternary_cluster": _ScenarioPreset(
        name="quaternary_cluster",
        # NOTE(review): the teacher three-body tensor for this scenario has 40
        # leading categories (4 * 10); "four-category" presumably refers to
        # the four atomic types -- confirm the wording.
        description=(
            "Eight-atom H/C/O/Si cluster with all 10 pair channels active and one "
            "four-category three-body spline block."
        ),
        workload=BenchmarkWorkloadDefaults(
            train_size=800,
            validation_size=200,
            test_size=200,
            training_batch_size=32,
            training_epochs=4,
            learning_rate=0.03,
            cg_checkpoints=(1, 2, 3, 4),
        ),
        loss_weights=LossWeights(energy=1.0, forces=8.0),
        build_atoms=_quaternary_cluster_atoms,
        build_model=_quaternary_cluster_model,
        build_teacher_coeffs=_quaternary_cluster_teacher_coeffs,
    ),
}


def _label_frames(
    teacher_model: UFPModel,
    frames: Sequence[ase.Atoms],
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[ase.Atoms, ...]:
    """Attach teacher-model energy and force labels to frames."""
    labeled: list[ase.Atoms] = []
    for atoms in frames:
        output = teacher_model.compute(
            atoms,
            device=device,
            dtype=dtype,
            derive_forces=True,
        )
        if output.energy is None or output.forces is None:
            raise RuntimeError("teacher model must produce both energies and forces")

        tagged = atoms.copy()
        tagged.info["energy"] = float(output.energy.reshape(-1)[0].item())
        tagged.arrays["forces"] = output.forces.detach().cpu().numpy()
        labeled.append(tagged)
    return tuple(labeled)


def _aligned_fit_samples(
    atoms_list: Sequence[ase.Atoms],
    *,
    loss_weights: LossWeights,
    neighbor_lists: Sequence[object | None] | None = None,
) -> tuple[FitSample, ...]:
    """Convert labeled frames into least-squares samples with metric-aligned weights."""
    n_energy_rows = len(atoms_list)
    n_force_rows = sum(len(atoms) * 3 for atoms in atoms_list)
    if n_energy_rows <= 0 or n_force_rows <= 0:
        raise ValueError(
            "benchmark datasets must contain at least one energy and force row"
        )

    # Match the weighted global MSE used by the training metrics with uniform
    # per-row weights in the least-squares objective.
    energy_weight = loss_weights.energy / n_energy_rows
    force_weight = loss_weights.forces / n_force_rows
    if neighbor_lists is None:
        neighbor_lists = [None] * len(atoms_list)
    return tuple(
        FitSample(
            atoms=atoms,
            neighbor_list=neighbor_list,
            energy=float(atoms.info["energy"]),
            forces=atoms.arrays["forces"],
            energy_weight=energy_weight,
            force_weight=force_weight,
        )
        for atoms, neighbor_list in zip(atoms_list, neighbor_lists, strict=True)
    )


def _zero_like_coeffs(coeffs: object) -> object:
    """Create zero-filled coefficient containers matching the teacher layout."""
    if isinstance(coeffs, torch.Tensor):
        return torch.zeros_like(coeffs)
    if isinstance(coeffs, dict):
        return {key: _zero_like_coeffs(value) for key, value in coeffs.items()}
    raise TypeError(f"unsupported coefficient container type: {type(coeffs)}")


def _build_scenario(
    scenario: str,
    *,
    seed: int,
    device: torch.device,
    dtype: torch.dtype,
    checkpoint: str,
    precompute_neighbor_lists: bool,
    train_size: int | None,
    validation_size: int | None,
    test_size: int | None,
    training_batch_size: int | None,
) -> _ScenarioData:
    """Build the datasets, loaders, and model factory for one scenario.

    Args:
        scenario: Key into ``_SCENARIOS``.
        seed: Seed forwarded to the preset's structure generator.
        device: Device for the teacher model and label computation.
        dtype: Dtype for the teacher model and label computation.
        checkpoint: Label stored unchanged on the returned data.
        precompute_neighbor_lists: Build neighbor lists once and attach them
            to the datasets and fit samples.
        train_size: Training frame count, or ``None`` for the preset default.
        validation_size: Validation frame count, or ``None`` for the default.
        test_size: Test frame count, or ``None`` for the preset default.
        training_batch_size: Training batch size, or ``None`` for the default.

    Returns:
        A fully populated ``_ScenarioData``.

    Raises:
        ValueError: If the scenario is unknown, any size is non-positive, or
            neighbor lists are requested but the model has no cutoff.
    """
    if scenario not in _SCENARIOS:
        choices = ", ".join(sorted(_SCENARIOS))
        raise ValueError(f"unknown scenario '{scenario}'. Expected one of: {choices}")

    preset = _SCENARIOS[scenario]
    defaults = preset.workload
    n_train = int(train_size) if train_size is not None else defaults.train_size
    n_validation = (
        int(validation_size)
        if validation_size is not None
        else defaults.validation_size
    )
    n_test = int(test_size) if test_size is not None else defaults.test_size
    batch_size = (
        int(training_batch_size)
        if training_batch_size is not None
        else defaults.training_batch_size
    )
    if min(n_train, n_validation, n_test, batch_size) <= 0:
        raise ValueError("benchmark sizes and batch size must be positive")

    # Teacher model: provides the reference parameters and all labels.
    pair_coeffs, threebody_coeffs = preset.build_teacher_coeffs()
    teacher_model = preset.build_model(pair_coeffs, threebody_coeffs)
    teacher_model.to(device=device, dtype=dtype)
    teacher_layout = ParameterLayout.from_model(teacher_model)
    teacher_theta = teacher_layout.current_true_vector().detach().clone()

    # All frames are generated in one call and then split contiguously.
    validation_end = n_train + n_validation
    labeled = _label_frames(
        teacher_model,
        preset.build_atoms(validation_end + n_test, seed),
        device=device,
        dtype=dtype,
    )

    neighbor_lists = None
    if precompute_neighbor_lists:
        if teacher_model.cutoff is None:
            raise ValueError("precomputed neighbor lists require a model cutoff")
        neighbor_lists = []
        for frame in labeled:
            neighbor_lists.append(
                build_neighbor_list(
                    atoms=frame,
                    cutoff=teacher_model.cutoff,
                    backend=teacher_model.neighbor_backend,
                    arrays="torch",
                )
            )

    train_span = slice(None, n_train)
    validation_span = slice(n_train, validation_end)
    test_span = slice(validation_end, None)
    train_atoms = labeled[train_span]
    validation_atoms = labeled[validation_span]
    test_atoms = labeled[test_span]

    def _make_loader(atoms, span, loader_batch_size):
        """Wrap one split in an unshuffled loader with matching neighbor lists."""
        dataset = ASEAtomsDataset.from_atoms(
            atoms,
            forces_key="forces",
            neighbor_lists=None if neighbor_lists is None else neighbor_lists[span],
        )
        return build_ase_dataloader(
            dataset,
            batch_size=loader_batch_size,
            shuffle=False,
            pin_memory=device.type == "cuda",
        )

    train_loader = _make_loader(train_atoms, train_span, batch_size)
    # Validation and test are each evaluated as one full-size batch.
    validation_loader = _make_loader(validation_atoms, validation_span, n_validation)
    test_loader = _make_loader(test_atoms, test_span, n_test)

    fit_samples = _aligned_fit_samples(
        train_atoms,
        loss_weights=preset.loss_weights,
        neighbor_lists=None if neighbor_lists is None else neighbor_lists[train_span],
    )

    def make_student_model() -> UFPModel:
        """Create an unfitted student model with zeroed coefficients."""
        blank_pair = _zero_like_coeffs(pair_coeffs)
        if threebody_coeffs is None:
            blank_threebody = None
        else:
            blank_threebody = _zero_like_coeffs(threebody_coeffs)
        return preset.build_model(blank_pair, blank_threebody)

    return _ScenarioData(
        preset=preset,
        device=device,
        dtype=dtype,
        checkpoint=checkpoint,
        precomputed_neighbor_lists=precompute_neighbor_lists,
        loss_weights=preset.loss_weights,
        train_atoms=train_atoms,
        validation_atoms=validation_atoms,
        test_atoms=test_atoms,
        train_loader=train_loader,
        validation_loader=validation_loader,
        test_loader=test_loader,
        fit_samples=fit_samples,
        teacher_theta=teacher_theta,
        make_student_model=make_student_model,
    )


def _problem_normal_diagonal(problem: BlockLinearProblem) -> torch.Tensor:
    """Return the diagonal of the unregularized normal matrix."""
    diagonal = torch.zeros(
        (problem.layout.size,),
        dtype=problem.dtype,
        device=problem.device,
    )
    for batch in problem.batches:
        for key, block_matrix in batch.matrices.items():
            diagonal[problem.layout.theta_slice(key)] += _block_matrix_diagonal(
                block_matrix
            )
    return diagonal


def _problem_matrix_storage_elements(problem: BlockLinearProblem) -> int:
    """Return stored matrix value/index elements for a least-squares problem."""
    return int(
        sum(_block_solve_batch_storage_elements(batch) for batch in problem.batches)
    )


def _problem_matrix_storage_nbytes(problem: BlockLinearProblem) -> int:
    """Return approximate matrix storage bytes for a least-squares problem."""
    return int(
        sum(_block_solve_batch_storage_nbytes(batch) for batch in problem.batches)
    )


def _native_extension_availability() -> dict[str, bool]:
    """Report which optional native three-body kernels can be used.

    Returns:
        A dict with one availability flag per probed kernel. The CUDA
        assembly kernel is only probed when ``torch.cuda.is_available()``.
    """
    feature_cache_cpu = native_threebody_feature_cache_available(
        device="cpu",
        spline="cubic",
    )
    dense_feature_cache_cpu = native_threebody_dense_feature_cache_available(
        device="cpu"
    )
    lstsq_assemble_cpu = native_threebody_lstsq_assemble_available(
        device="cpu",
        spline="cubic",
    )

    # Skip the CUDA probe entirely on machines without a usable GPU.
    lstsq_assemble_cuda = torch.cuda.is_available()
    if lstsq_assemble_cuda:
        lstsq_assemble_cuda = native_threebody_lstsq_assemble_available(
            device="cuda",
            spline="cubic",
        )

    return {
        "threebody_feature_cache_cpu": feature_cache_cpu,
        "threebody_dense_feature_cache_cpu": dense_feature_cache_cpu,
        "threebody_lstsq_assemble_cpu": lstsq_assemble_cpu,
        "threebody_lstsq_assemble_cuda": lstsq_assemble_cuda,
    }


def _cg_solve_with_checkpoints(
    problem: BlockLinearProblem,
    checkpoints: Sequence[int],
    *,
    tolerance: float,
) -> dict[int, _CGCheckpoint]:
    """Run preconditioned CG and capture selected iterate checkpoints."""
    requested = tuple(sorted({int(step) for step in checkpoints if int(step) > 0}))
    if not requested:
        return {}

    rhs = problem.rhs()
    x = torch.zeros_like(rhs)
    residual = rhs.clone()
    diagonal = problem.regularization_diagonal() + _problem_normal_diagonal(problem)
    safe_diagonal = torch.where(
        torch.abs(diagonal) > 1.0e-14,
        diagonal,
        torch.ones_like(diagonal),
    )
    z = residual / safe_diagonal
    direction = z.clone()
    rz_old = torch.dot(residual, z)

    results: dict[int, _CGCheckpoint] = {}
    solve_start = time.perf_counter()
    final_iteration = 0
    max_iter = requested[-1]

    if torch.sqrt(torch.clamp(rz_old, min=0.0)) <= tolerance:
        elapsed = time.perf_counter() - solve_start
        for step in requested:
            results[step] = _CGCheckpoint(theta=x.clone(), solve_time_s=elapsed)
        return results

    for iteration in range(1, max_iter + 1):
        mat_direction = problem.normal_matvec(direction)
        denom = torch.dot(direction, mat_direction)
        if torch.abs(denom) <= 1.0e-30:
            final_iteration = iteration - 1
            break

        alpha = rz_old / denom
        x = x + alpha * direction
        residual = residual - alpha * mat_direction
        final_iteration = iteration
        if iteration in requested:
            results[iteration] = _CGCheckpoint(
                theta=x.clone(),
                solve_time_s=time.perf_counter() - solve_start,
            )

        if torch.linalg.norm(residual) <= tolerance:
            break

        z = residual / safe_diagonal
        rz_new = torch.dot(residual, z)
        if torch.abs(rz_old) <= 1.0e-30:
            break
        beta = rz_new / rz_old
        direction = z + beta * direction
        rz_old = rz_new

    if final_iteration < max_iter:
        elapsed = time.perf_counter() - solve_start
        for step in requested:
            if step > final_iteration:
                results[step] = _CGCheckpoint(theta=x.clone(), solve_time_s=elapsed)

    return results


def _make_point(
    *,
    method: str,
    budget_kind: str,
    budget: int | None,
    label: str,
    optimize_time_s: float,
    train_loss: float,
    validation_metrics,
    test_metrics,
) -> BenchmarkPoint:
    """Assemble one ``BenchmarkPoint`` from raw timings and split metrics.

    Args:
        method: Name of the optimization method being measured.
        budget_kind: Unit of ``budget`` (callers in this file pass "epoch").
        budget: Budget consumed so far, or ``None``.
        label: Display label for the point.
        optimize_time_s: Accumulated optimization time in seconds.
        train_loss: Training loss at this point.
        validation_metrics: Object exposing ``loss``, ``energy_mae``, and
            ``forces_mae`` for the validation split.
        test_metrics: Same attributes for the test split.

    Returns:
        The populated record; times and losses are coerced to ``float`` while
        the MAE values are passed through unchanged.
    """
    elapsed_s = float(optimize_time_s)
    fitted_loss = float(train_loss)

    validation_loss = float(validation_metrics.loss)
    validation_energy_mae = validation_metrics.energy_mae
    validation_forces_mae = validation_metrics.forces_mae

    test_loss = float(test_metrics.loss)
    test_energy_mae = test_metrics.energy_mae
    test_forces_mae = test_metrics.forces_mae

    return BenchmarkPoint(
        method=method,
        budget_kind=budget_kind,
        budget=budget,
        label=label,
        optimize_time_s=elapsed_s,
        train_loss=fitted_loss,
        validation_loss=validation_loss,
        validation_energy_mae=validation_energy_mae,
        validation_forces_mae=validation_forces_mae,
        test_loss=test_loss,
        test_energy_mae=test_energy_mae,
        test_forces_mae=test_forces_mae,
    )


def _evaluate_split_metrics(
    model: UFPModel,
    scenario: _ScenarioData,
    *,
    device: torch.device | None = None,
):
    """Evaluate a model on the validation and test splits of a scenario.

    Args:
        model: Model to evaluate.
        scenario: Supplies the loaders, dtype, and loss weights.
        device: Evaluation device; defaults to ``scenario.device``.

    Returns:
        ``(validation_metrics, test_metrics)`` as returned by
        ``evaluate_model``; validation is evaluated first.
    """
    target_device = device if device is not None else scenario.device

    def _evaluate(loader, split_name):
        return evaluate_model(
            model,
            loader,
            split=split_name,
            dtype=scenario.dtype,
            device=target_device,
            loss_weights=scenario.loss_weights,
        )

    validation_metrics = _evaluate(scenario.validation_loader, "validation")
    test_metrics = _evaluate(scenario.test_loader, "test")
    return validation_metrics, test_metrics


def _run_training_benchmark(
    scenario: _ScenarioData,
    *,
    epochs: int,
    learning_rate: float,
) -> tuple[BenchmarkPoint, ...]:
    """Train a fresh student model with AdamW and record per-epoch metrics.

    An ``epoch=0`` record is taken before any optimisation step, using the
    training loss from ``evaluate_model``. Each later record uses the metrics
    returned by ``train_one_epoch`` as its training loss, and its
    ``optimize_time_s`` accumulates only the time spent inside
    ``train_one_epoch`` (evaluation time is excluded).
    """
    student = scenario.make_student_model()
    student.to(device=scenario.device, dtype=scenario.dtype)
    adamw = torch.optim.AdamW(student.parameters(), lr=learning_rate)
    # Epochs at which a record is emitted; as written this is every epoch.
    recorded_epochs = {0} | set(range(1, epochs + 1))
    points: list[BenchmarkPoint] = []

    def _record(epoch_index: int, elapsed_s: float, epoch_train_metrics) -> None:
        # Evaluate the held-out splits, then append one record for this epoch.
        validation_metrics, test_metrics = _evaluate_split_metrics(student, scenario)
        points.append(
            _make_point(
                method="training_adam",
                budget_kind="epoch",
                budget=epoch_index,
                label=f"epoch={epoch_index}",
                optimize_time_s=elapsed_s,
                train_loss=epoch_train_metrics.loss,
                validation_metrics=validation_metrics,
                test_metrics=test_metrics,
            )
        )

    baseline_train_metrics = evaluate_model(
        student,
        scenario.train_loader,
        split="train",
        dtype=scenario.dtype,
        device=scenario.device,
        loss_weights=scenario.loss_weights,
    )
    _record(0, 0.0, baseline_train_metrics)

    cumulative_time_s = 0.0
    for epoch in range(1, epochs + 1):
        tick = time.perf_counter()
        epoch_metrics = train_one_epoch(
            student,
            scenario.train_loader,
            optimizer=adamw,
            dtype=scenario.dtype,
            device=scenario.device,
            loss_weights=scenario.loss_weights,
        )
        cumulative_time_s += time.perf_counter() - tick
        if epoch in recorded_epochs:
            _record(epoch, cumulative_time_s, epoch_metrics)

    return tuple(points)


def _run_leastsquares_benchmark(
    scenario: _ScenarioData,
    *,
    cg_checkpoints: Sequence[int],
    leastsquares_batch_size: int,
    direct_solver: str,
    assembly_contract: str,
    matrix_storage: str,
) -> tuple[tuple[BenchmarkPoint, ...], float, float, int, int, int]:
    """Run direct and checkpointed least-squares benchmark solves.

    A fresh student model is fitted on the CPU (the least-squares device is
    hard-coded below, independent of ``scenario.device``). The function
    records an ``iter=0`` baseline, one record per conjugate-gradient
    checkpoint state, and, unless ``direct_solver`` is ``"none"``, one extra
    record for the direct solve.

    Args:
        scenario: Scenario data providing the student model factory, loaders,
            fit samples, dtype and loss weights.
        cg_checkpoints: Iteration counts forwarded to
            ``_cg_solve_with_checkpoints``.
        leastsquares_batch_size: Batch size passed to ``build_problem``.
        direct_solver: Solver name passed to ``problem.solve``; ``"none"``
            skips the direct solve.
        assembly_contract: Forwarded to ``LinearFitter``.
        matrix_storage: Forwarded to ``LinearFitter``.

    Returns:
        ``(records, build_time_s, solve_time_s, n_rows,
        matrix_storage_elements, matrix_storage_bytes)``. ``solve_time_s`` is
        the largest checkpoint solve time, or the direct solve time when a
        direct solver was requested.
    """
    # Least-squares assembly and solves always run on the CPU.
    leastsquares_device = torch.device("cpu")
    model = scenario.make_student_model()
    model.to(device=leastsquares_device, dtype=scenario.dtype)
    fitter = LinearFitter(
        model,
        fit_energy=True,
        fit_forces=True,
        solver="cg",
        dtype=scenario.dtype,
        device=leastsquares_device,
        assembly_contract=assembly_contract,
        matrix_storage=matrix_storage,
    )

    # Baseline metrics of the unfitted student model (reported as iter=0).
    initial_train = evaluate_model(
        model,
        scenario.train_loader,
        split="train",
        dtype=scenario.dtype,
        device=leastsquares_device,
        loss_weights=scenario.loss_weights,
    )
    initial_validation, initial_test = _evaluate_split_metrics(
        model,
        scenario,
        device=leastsquares_device,
    )
    records = [
        _make_point(
            method="leastsquares_cg",
            budget_kind="iteration",
            budget=0,
            label="iter=0",
            optimize_time_s=0.0,
            train_loss=initial_train.loss,
            validation_metrics=initial_validation,
            test_metrics=initial_test,
        )
    ]

    # Time only the problem assembly; it is added to every later record's
    # optimize_time_s.
    build_start = time.perf_counter()
    problem = fitter.build_problem(
        scenario.fit_samples,
        batch_size=leastsquares_batch_size,
    )
    build_time_s = time.perf_counter() - build_start
    n_rows = problem.n_rows
    matrix_storage_elements = _problem_matrix_storage_elements(problem)
    matrix_storage_bytes = _problem_matrix_storage_nbytes(problem)

    # Mapping of iteration count -> state carrying ``theta`` and the elapsed
    # ``solve_time_s`` at that checkpoint.
    checkpoint_states = _cg_solve_with_checkpoints(
        problem,
        cg_checkpoints,
        tolerance=1.0e-10,
    )
    # Longest checkpoint solve time; 0.0 when no checkpoint state exists.
    solve_time_s = max(
        (state.solve_time_s for state in checkpoint_states.values()),
        default=0.0,
    )
    for iteration in sorted(checkpoint_states):
        state = checkpoint_states[iteration]
        # Load this checkpoint's parameters into the model before evaluating.
        fitter.write_back(state.theta)
        validation_metrics, test_metrics = _evaluate_split_metrics(
            model,
            scenario,
            device=leastsquares_device,
        )
        records.append(
            _make_point(
                method="leastsquares_cg",
                budget_kind="iteration",
                budget=iteration,
                label=f"iter={iteration}",
                optimize_time_s=build_time_s + state.solve_time_s,
                # NOTE(review): train_loss here is the problem objective,
                # while iter=0 uses the evaluate_model loss - confirm the two
                # are directly comparable.
                train_loss=float(problem.objective(state.theta).item()),
                validation_metrics=validation_metrics,
                test_metrics=test_metrics,
            )
        )

    if direct_solver != "none":
        direct_start = time.perf_counter()
        direct_theta = problem.solve(
            solver=direct_solver,
            cg_tolerance=1.0e-10,
            cg_max_iter=None,
        )
        direct_solve_time_s = time.perf_counter() - direct_start
        # The reported solve time becomes the direct solve time, replacing
        # the CG checkpoint maximum computed above.
        solve_time_s = direct_solve_time_s
        fitter.write_back(direct_theta)
        direct_validation, direct_test = _evaluate_split_metrics(
            model,
            scenario,
            device=leastsquares_device,
        )
        records.append(
            _make_point(
                method="leastsquares_direct",
                budget_kind="solver",
                budget=None,
                label=direct_solver,
                optimize_time_s=build_time_s + direct_solve_time_s,
                train_loss=float(problem.objective(direct_theta).item()),
                validation_metrics=direct_validation,
                test_metrics=direct_test,
            )
        )

    return (
        tuple(records),
        build_time_s,
        solve_time_s,
        n_rows,
        matrix_storage_elements,
        matrix_storage_bytes,
    )


def run_leastsquares_vs_training_benchmark(
    *,
    scenario: str = "triangle_pair_threebody",
    seed: int = 0,
    device: str | torch.device | None = None,
    dtype: str | torch.dtype | None = "auto",
    checkpoint: str = "custom",
    precompute_neighbor_lists: bool = False,
    train_size: int | None = None,
    validation_size: int | None = None,
    test_size: int | None = None,
    training_batch_size: int | None = None,
    leastsquares_batch_size: int | None = None,
    training_epochs: int | None = None,
    learning_rate: float | None = None,
    cg_checkpoints: Sequence[int] | None = None,
    direct_solver: str = "none",
    assembly_contract: str = "term",
    matrix_storage: str = "auto",
) -> BenchmarkResult:
    """Run one benchmark scenario and return its checkpointed results.

    Workload settings left as ``None`` fall back to the scenario preset
    (epochs, learning rate, CG checkpoints) or to the training loader's batch
    size (least-squares batch size). Explicit ``cg_checkpoints`` are converted
    to ``int``, de-duplicated, restricted to positive values and sorted.

    Raises:
        ValueError: If the resolved number of training epochs or the resolved
            least-squares batch size is not positive.
    """
    resolved_device = resolve_device(device)
    resolved_dtype = resolve_dtype(resolved_device, dtype)
    scenario_data = _build_scenario(
        scenario,
        seed=seed,
        device=resolved_device,
        dtype=resolved_dtype,
        checkpoint=checkpoint,
        precompute_neighbor_lists=precompute_neighbor_lists,
        train_size=train_size,
        validation_size=validation_size,
        test_size=test_size,
        training_batch_size=training_batch_size,
    )
    preset = scenario_data.preset

    if training_epochs is None:
        resolved_training_epochs = preset.workload.training_epochs
    else:
        resolved_training_epochs = int(training_epochs)

    if learning_rate is None:
        resolved_learning_rate = preset.workload.learning_rate
    else:
        resolved_learning_rate = float(learning_rate)

    if cg_checkpoints is None:
        resolved_cg_checkpoints = preset.workload.cg_checkpoints
    else:
        positive_steps = {
            value for value in map(int, cg_checkpoints) if value > 0
        }
        resolved_cg_checkpoints = tuple(sorted(positive_steps))

    if leastsquares_batch_size is None:
        resolved_lstsq_batch_size = scenario_data.train_loader.batch_size
    else:
        resolved_lstsq_batch_size = int(leastsquares_batch_size)

    if resolved_training_epochs <= 0 or resolved_lstsq_batch_size <= 0:
        raise ValueError(
            "training epochs and least-squares batch size must be positive"
        )

    training_records = _run_training_benchmark(
        scenario_data,
        epochs=resolved_training_epochs,
        learning_rate=resolved_learning_rate,
    )
    leastsquares_outcome = _run_leastsquares_benchmark(
        scenario_data,
        cg_checkpoints=resolved_cg_checkpoints,
        leastsquares_batch_size=resolved_lstsq_batch_size,
        direct_solver=direct_solver,
        assembly_contract=assembly_contract,
        matrix_storage=matrix_storage,
    )
    (
        leastsquares_records,
        build_time_s,
        solve_time_s,
        n_rows,
        matrix_storage_elements,
        matrix_storage_bytes,
    ) = leastsquares_outcome

    n_parameters = int(scenario_data.teacher_theta.numel())
    weights = scenario_data.loss_weights
    loss_weight_summary = {
        "energy": float(weights.energy),
        "forces": float(weights.forces),
        "stress": float(weights.stress),
    }
    dtype_label = str(resolved_dtype).replace("torch.", "")

    return BenchmarkResult(
        scenario=preset.name,
        description=preset.description,
        seed=seed,
        checkpoint=checkpoint,
        device=str(resolved_device),
        leastsquares_device="cpu",
        dtype=dtype_label,
        precomputed_neighbor_lists=precompute_neighbor_lists,
        n_train=len(scenario_data.train_atoms),
        n_validation=len(scenario_data.validation_atoms),
        n_test=len(scenario_data.test_atoms),
        n_parameters=n_parameters,
        n_rows=n_rows,
        training_batch_size=int(scenario_data.train_loader.batch_size),
        leastsquares_batch_size=resolved_lstsq_batch_size,
        training_epochs=resolved_training_epochs,
        cg_checkpoints=tuple(resolved_cg_checkpoints),
        loss_weights=loss_weight_summary,
        leastsquares_build_time_s=build_time_s,
        leastsquares_solve_time_s=solve_time_s,
        leastsquares_total_time_s=build_time_s + solve_time_s,
        leastsquares_matrix_storage_elements=matrix_storage_elements,
        leastsquares_matrix_storage_bytes=matrix_storage_bytes,
        assembly_contract=str(assembly_contract),
        matrix_storage=str(matrix_storage),
        native_extensions=_native_extension_availability(),
        direct_solver=direct_solver,
        records=training_records + leastsquares_records,
    )
def run_benchmark_checkpoints(
    *,
    checkpoints: Sequence[str],
    scenario: str = "triangle_pair_threebody",
    seed: int = 0,
    device: str | torch.device | None = None,
    train_size: int | None = None,
    validation_size: int | None = None,
    test_size: int | None = None,
    training_batch_size: int | None = None,
    leastsquares_batch_size: int | None = None,
    training_epochs: int | None = None,
    learning_rate: float | None = None,
    cg_checkpoints: Sequence[int] | None = None,
    direct_solver: str = "none",
    assembly_contract: str = "term",
    matrix_storage: str = "auto",
) -> tuple[BenchmarkResult, ...]:
    """Run the benchmark once per named checkpoint preset.

    Each name in ``checkpoints`` is looked up in the checkpoint registry; the
    preset supplies the dtype, the checkpoint label and the neighbor-list
    mode, while every other option is forwarded unchanged to
    ``run_leastsquares_vs_training_benchmark``.

    Returns:
        One ``BenchmarkResult`` per requested checkpoint, in request order.

    Raises:
        ValueError: If any requested name is not a registered checkpoint. All
            names are checked before the first benchmark starts, so a typo
            late in the sequence does not waste the earlier runs.
    """
    requested = tuple(checkpoints)

    # Fail fast: validate every name before doing any expensive work.
    for checkpoint_name in requested:
        if checkpoint_name not in _CHECKPOINTS:
            choices = ", ".join(sorted(_CHECKPOINTS))
            raise ValueError(
                f"unknown checkpoint '{checkpoint_name}'. Expected one of: {choices}"
            )

    results: list[BenchmarkResult] = []
    for checkpoint_name in requested:
        spec = _CHECKPOINTS[checkpoint_name]
        results.append(
            run_leastsquares_vs_training_benchmark(
                scenario=scenario,
                seed=seed,
                device=device,
                dtype=spec.dtype,
                checkpoint=spec.name,
                precompute_neighbor_lists=spec.precompute_neighbor_lists,
                train_size=train_size,
                validation_size=validation_size,
                test_size=test_size,
                training_batch_size=training_batch_size,
                leastsquares_batch_size=leastsquares_batch_size,
                training_epochs=training_epochs,
                learning_rate=learning_rate,
                cg_checkpoints=cg_checkpoints,
                direct_solver=direct_solver,
                assembly_contract=assembly_contract,
                matrix_storage=matrix_storage,
            )
        )
    return tuple(results)
def available_benchmark_scenarios() -> tuple[str, ...]:
    """List the registered toy microbenchmark scenario names in sorted order."""
    scenario_names = sorted(_SCENARIOS)
    return tuple(scenario_names)
def available_benchmark_checkpoints() -> tuple[str, ...]:
    """List the registered benchmark checkpoint preset names in sorted order."""
    checkpoint_names = sorted(_CHECKPOINTS)
    return tuple(checkpoint_names)
def format_benchmark_report(result: BenchmarkResult) -> str:
    """Render a benchmark result as a plain-text report.

    The report starts with a summary block (scenario, split sizes, devices,
    dtype, timings, matrix storage, native extensions), followed by a blank
    line and a column-aligned table with one row per record in
    ``result.records``.

    A result without records still renders: the table then consists of the
    header and separator lines only.
    """
    neighbor_mode = "precomputed" if result.precomputed_neighbor_lists else "dynamic"
    native_summary = ", ".join(
        f"{key}={value}" for key, value in sorted(result.native_extensions.items())
    )
    lines = [
        f"Scenario: {result.scenario}",
        f"Description: {result.description}",
        (
            "Train/validation/test: "
            f"{result.n_train}/{result.n_validation}/{result.n_test} systems"
        ),
        f"Checkpoint: {result.checkpoint}",
        f"Training device: {result.device}",
        f"Least-squares device: {result.leastsquares_device}",
        f"Dtype: {result.dtype}",
        f"Neighbor lists: {neighbor_mode}",
        (
            "Rows/parameters: "
            f"{result.n_rows}/{result.n_parameters}; "
            f"loss weights energy={result.loss_weights['energy']}, "
            f"forces={result.loss_weights['forces']}"
        ),
        (
            "Training batch size: "
            f"{result.training_batch_size}; least-squares batch size: "
            f"{result.leastsquares_batch_size}; "
            f"least-squares build time: {result.leastsquares_build_time_s:.6f}s"
        ),
        (
            "Least-squares solve/total time: "
            f"{result.leastsquares_solve_time_s:.6f}s/"
            f"{result.leastsquares_total_time_s:.6f}s; "
            f"assembly={result.assembly_contract}; storage={result.matrix_storage}"
        ),
        (
            "Matrix storage: "
            f"{result.leastsquares_matrix_storage_elements} value/index elements; "
            f"{result.leastsquares_matrix_storage_bytes} bytes"
        ),
        "Native extensions: " + native_summary,
        "",
    ]

    headers = (
        "method",
        "label",
        "opt_s",
        "train_loss",
        "val_loss",
        "val_E_mae",
        "val_F_mae",
        "test_E_mae",
        "test_F_mae",
    )
    rows = [
        (
            point.method,
            point.label,
            format_number(point.optimize_time_s),
            format_number(point.train_loss),
            format_number(point.validation_loss),
            format_number(point.validation_energy_mae),
            format_number(point.validation_forces_mae),
            format_number(point.test_energy_mae),
            format_number(point.test_forces_mae),
        )
        for point in result.records
    ]

    # Each column is as wide as its longest cell or its header. Collecting the
    # candidates into a list keeps max() valid when there are no rows; the
    # star-args form degenerates to max(<int>) and raises TypeError.
    widths = [
        max([len(header), *(len(row[index]) for row in rows)])
        for index, header in enumerate(headers)
    ]

    header_line = " ".join(
        header.ljust(width) for header, width in zip(headers, widths)
    )
    separator_line = " ".join("-" * width for width in widths)
    lines.extend([header_line, separator_line])
    lines.extend(
        " ".join(cell.ljust(width) for cell, width in zip(row, widths))
        for row in rows
    )
    return "\n".join(lines)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the argument parser used by the benchmark command line."""
    parser = argparse.ArgumentParser(
        description=(
            "Benchmark convergence of ufp.leastsquares against ufp.training on "
            "deterministic toy problems."
        )
    )
    add = parser.add_argument

    # Scenario selection and run-level settings.
    add(
        "--scenario",
        choices=scenario_choices(available_benchmark_scenarios()),
        default="triangle_pair_threebody",
    )
    add("--seed", type=int, default=0)
    add("--device", choices=["auto", "cpu", "cuda", "gpu"], default="auto")
    add("--dtype", choices=["auto", "float32", "float64"], default="auto")
    add(
        "--checkpoint-name",
        default="custom",
        help="Label used for a single benchmark run.",
    )
    add(
        "--precompute-neighbor-lists",
        action="store_true",
        help="Reuse one precomputed neighbor list per structure.",
    )
    add(
        "--checkpoint",
        action="append",
        choices=available_benchmark_checkpoints(),
        help="Run one of the named A/B benchmark checkpoints. Can be repeated.",
    )

    # Optional workload overrides; these integer flags have no default.
    for int_flag in (
        "--train-size",
        "--validation-size",
        "--test-size",
        "--training-batch-size",
        "--leastsquares-batch-size",
        "--training-epochs",
    ):
        add(int_flag, type=int)
    add("--learning-rate", type=float)
    add(
        "--cg-checkpoints",
        type=lambda value: parse_positive_int_sequence(
            value,
            label="checkpoint lists",
        ),
    )

    # Least-squares solver configuration.
    add(
        "--direct-solver",
        choices=["none", "normal_equation_direct", "dense_lstsq"],
        default="none",
    )
    add(
        "--assembly-contract",
        choices=["block", "term"],
        default="term",
        help="Least-squares assembly dispatch contract to benchmark.",
    )
    add(
        "--matrix-storage",
        choices=["dense", "row_indexed", "column_chunked", "auto"],
        default="auto",
        help="In-memory block matrix storage representation to benchmark.",
    )

    # Output format.
    add(
        "--json",
        action="store_true",
        help="Emit JSON instead of the formatted text report.",
    )
    return parser
def main(argv: Sequence[str] | None = None) -> int:
    """Parse command-line arguments, run the benchmarks and print the results.

    With ``--checkpoint`` the named checkpoint presets are run for each
    selected scenario; otherwise a single run per scenario uses ``--dtype``,
    ``--checkpoint-name`` and ``--precompute-neighbor-lists``. Output is JSON
    when ``--json`` is given (a bare object for exactly one result, a list
    otherwise) and the formatted text report if not. Always returns ``0``.
    """
    args = _build_arg_parser().parse_args(argv)

    if args.scenario == "all":
        scenario_names = sorted(_SCENARIOS)
    else:
        scenario_names = [str(args.scenario)]

    # Options forwarded identically in both run modes.
    shared_options = {
        "seed": args.seed,
        "device": args.device,
        "train_size": args.train_size,
        "validation_size": args.validation_size,
        "test_size": args.test_size,
        "training_batch_size": args.training_batch_size,
        "leastsquares_batch_size": args.leastsquares_batch_size,
        "training_epochs": args.training_epochs,
        "learning_rate": args.learning_rate,
        "cg_checkpoints": args.cg_checkpoints,
        "direct_solver": args.direct_solver,
        "assembly_contract": args.assembly_contract,
        "matrix_storage": args.matrix_storage,
    }

    results = []
    for scenario_name in scenario_names:
        if args.checkpoint:
            results.extend(
                run_benchmark_checkpoints(
                    checkpoints=args.checkpoint,
                    scenario=scenario_name,
                    **shared_options,
                )
            )
        else:
            results.append(
                run_leastsquares_vs_training_benchmark(
                    scenario=scenario_name,
                    dtype=args.dtype,
                    checkpoint=args.checkpoint_name,
                    precompute_neighbor_lists=args.precompute_neighbor_lists,
                    **shared_options,
                )
            )

    if args.json:
        serialised = [asdict(result) for result in results]
        # A single result is emitted as a bare object rather than a list.
        document = serialised[0] if len(serialised) == 1 else serialised
        output = json.dumps(document, indent=2)
    else:
        output = "\n\n".join(format_benchmark_report(result) for result in results)

    print(output)
    return 0
# Names exported by ``from <module> import *``.
__all__ = [
    "available_benchmark_checkpoints",
    "available_benchmark_scenarios",
    "BenchmarkPoint",
    "BenchmarkResult",
    "BenchmarkCheckpoint",
    "format_benchmark_report",
    "main",
    "run_benchmark_checkpoints",
    "run_leastsquares_vs_training_benchmark",
]