Source code for ufp.benchmarks._common

"""Shared helpers for benchmark and microbenchmark modules."""

from __future__ import annotations

import argparse
from dataclasses import dataclass
from typing import Sequence

import torch


@dataclass(frozen=True)
class BenchmarkPoint:
    """Single measurement collected from one benchmark method.

    Instances are immutable (``frozen=True``). No field declares a default,
    so every value must be supplied at construction time; when passed
    positionally they follow the declaration order below.
    """

    # Identification of the measurement.
    method: str
    budget_kind: str
    # None is permitted; presumably "no budget applies" -- confirm with the producers.
    budget: int | None
    label: str

    # NOTE(review): the ``_s`` suffix suggests seconds -- confirm against the timer used.
    optimize_time_s: float

    train_loss: float

    # Metrics named for the validation split; the two MAE entries may be None.
    validation_loss: float
    validation_energy_mae: float | None
    validation_forces_mae: float | None

    # Metrics named for the test split; same layout as the validation fields.
    test_loss: float
    test_energy_mae: float | None
    test_forces_mae: float | None
@dataclass(frozen=True)
class BenchmarkResult:
    """Structured result for one deterministic microbenchmark run.

    Instances are immutable (``frozen=True``) and every field is required.
    Two fields (``loss_weights`` and ``native_extensions``) are plain dicts:
    freezing the dataclass only blocks attribute rebinding, so their contents
    can still be mutated, and because dicts are unhashable the generated
    ``__hash__`` raises ``TypeError`` when an instance is hashed.
    """

    # Run identification and environment, all stored as plain strings/ints.
    scenario: str
    description: str
    seed: int
    checkpoint: str
    device: str
    leastsquares_device: str
    dtype: str
    precomputed_neighbor_lists: bool

    # Size counters. NOTE(review): exact meaning of ``n_rows`` is not visible
    # in this module -- confirm with the code that fills it in.
    n_train: int
    n_validation: int
    n_test: int
    n_parameters: int
    n_rows: int

    # Workload configuration used for the run.
    training_batch_size: int
    leastsquares_batch_size: int
    training_epochs: int
    cg_checkpoints: tuple[int, ...]
    loss_weights: dict[str, float]  # mutable and unhashable, see class docstring

    # Least-squares timings; the ``_s`` suffix suggests seconds -- TODO confirm.
    leastsquares_build_time_s: float
    leastsquares_solve_time_s: float
    leastsquares_total_time_s: float

    # Least-squares matrix storage footprint, as an element count and a byte count.
    leastsquares_matrix_storage_elements: int
    leastsquares_matrix_storage_bytes: int

    # Free-form descriptors of how the system was assembled, stored and solved.
    assembly_contract: str
    matrix_storage: str
    native_extensions: dict[str, bool]  # mutable and unhashable, see class docstring
    direct_solver: str

    # Per-method measurements belonging to this run.
    records: tuple[BenchmarkPoint, ...]
@dataclass(frozen=True)
class BenchmarkCheckpoint:
    """Named run configuration for a benchmark preset.

    Instances are immutable (``frozen=True``). ``name``, ``description`` and
    ``dtype`` are required; ``precompute_neighbor_lists`` defaults to
    ``False``.
    """

    name: str
    description: str
    # Same union that ``resolve_dtype`` accepts for its ``dtype`` argument;
    # presumably resolved through it -- verify against the caller.
    dtype: str | torch.dtype | None
    precompute_neighbor_lists: bool = False
@dataclass(frozen=True)
class BenchmarkWorkloadDefaults:
    """Default workload knobs shared by toy microbenchmark scenarios.

    Instances are immutable (``frozen=True``) and every field is required.
    """

    # Sizes of the three data splits.
    train_size: int
    validation_size: int
    test_size: int

    # Training-loop settings.
    training_batch_size: int
    training_epochs: int
    learning_rate: float

    # NOTE(review): "cg" presumably stands for conjugate gradient and the
    # entries for iteration counts at which results are recorded -- confirm.
    cg_checkpoints: tuple[int, ...]
def resolve_device(device: str | torch.device | None) -> torch.device:
    """Turn a benchmark device specifier into a ``torch.device``.

    Args:
        device: One of

            * a ``torch.device`` -- handed back untouched, without any
              availability check;
            * ``None`` or ``"auto"`` -- CUDA when
              ``torch.cuda.is_available()`` is true, CPU otherwise;
            * ``"gpu"`` -- treated as ``"cuda"``;
            * any other string understood by ``torch.device``.

            String specifiers are stripped and lower-cased before matching.

    Returns:
        The resolved ``torch.device``.

    Raises:
        RuntimeError: If a string specifier resolves to a CUDA device while
            CUDA is unavailable.
    """
    if isinstance(device, torch.device):
        return device

    # ``None`` and "auto" share the same pick-the-best-available behaviour.
    spec = "auto" if device is None else str(device).strip().lower()
    cuda_available = torch.cuda.is_available()
    if spec == "auto":
        return torch.device("cuda" if cuda_available else "cpu")

    target = torch.device("cuda" if spec == "gpu" else spec)
    if target.type == "cuda" and not cuda_available:
        raise RuntimeError("CUDA was requested, but no CUDA device is available")
    return target
def resolve_dtype(
    device: torch.device,
    dtype: str | torch.dtype | None,
) -> torch.dtype:
    """Resolve a benchmark dtype specifier.

    Args:
        device: Device the benchmark runs on. Only consulted for the
            ``"auto"`` specifier.
        dtype: One of

            * a ``torch.dtype`` -- returned unchanged;
            * ``None`` -- ``torch.get_default_dtype()``;
            * ``"auto"`` -- ``torch.float32`` on CUDA devices,
              ``torch.float64`` on every other device type;
            * the name of a dtype attribute of ``torch`` (``"float32"``,
              ``"double"``, ...), optionally prefixed with ``"torch."`` so
              that ``str(some_dtype)`` can be passed back in.

            String specifiers are stripped and lower-cased before matching.

    Returns:
        The resolved ``torch.dtype``.

    Raises:
        ValueError: If the specifier does not name a ``torch.dtype``.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if dtype is None:
        return torch.get_default_dtype()

    normalized = str(dtype).strip().lower()
    if normalized == "auto":
        return torch.float32 if device.type == "cuda" else torch.float64

    # ``str(torch.float32)`` is "torch.float32"; accept that spelling so a
    # dtype that was stringified (e.g. for a report) round-trips.
    prefix = "torch."
    if normalized.startswith(prefix):
        normalized = normalized[len(prefix):]

    resolved = getattr(torch, normalized, None)
    # ``torch`` exposes many non-dtype attributes; only real dtypes qualify.
    if isinstance(resolved, torch.dtype):
        return resolved
    raise ValueError(f"unsupported dtype '{dtype}'")
def format_number(value: float | None) -> str:
    """Render a benchmark scalar for text reports.

    Args:
        value: Number to render, or ``None`` for a missing measurement.

    Returns:
        ``"-"`` for ``None``; fixed-point with six decimals when the
        magnitude is at least ``1e-3``; otherwise scientific notation with
        three decimals (this includes exact zero).
    """
    if value is None:
        return "-"
    # Tiny magnitudes would print as 0.000000 in fixed-point, so switch to
    # scientific notation below the threshold.
    spec = ".6f" if abs(value) >= 1.0e-3 else ".3e"
    return format(value, spec)
def parse_positive_int_sequence(value: str, *, label: str) -> tuple[int, ...]:
    """Parse a comma-separated sequence of positive integers.

    Whitespace around each entry is ignored. Entries that parse as integers
    but are not positive are dropped; the remaining values are de-duplicated
    and returned in ascending order.

    Args:
        value: Raw text such as ``"10, 20,40"``.
        label: Human-readable name of the option being parsed; included in
            every error message.

    Returns:
        The distinct positive integers, sorted ascending.

    Raises:
        argparse.ArgumentTypeError: If any entry is empty or not an integer,
            or if no positive integer remains after filtering.
    """
    pieces = [piece.strip() for piece in value.split(",")]
    # ``str.split`` always yields at least one element, so checking for empty
    # entries also covers an empty or whitespace-only ``value``.
    if not all(pieces):
        raise argparse.ArgumentTypeError(f"{label} must be comma-separated integers")

    numbers = []
    for piece in pieces:
        try:
            numbers.append(int(piece))
        except ValueError as exc:
            # Name the option and the offending entry instead of forwarding
            # the raw int() message, matching the other errors raised here.
            raise argparse.ArgumentTypeError(
                f"{label} must be comma-separated integers (got {piece!r})"
            ) from exc

    values = tuple(sorted({number for number in numbers if number > 0}))
    if not values:
        raise argparse.ArgumentTypeError(f"{label} must contain positive integers")
    return values
def scenario_choices(names: Sequence[str]) -> list[str]:
    """Return the argparse ``choices`` list for a scenario registry.

    Args:
        names: Registered scenario names, in any order.

    Returns:
        A new list starting with the literal ``"all"`` followed by ``names``
        in sorted order.
    """
    choices = ["all"]
    choices.extend(sorted(names))
    return choices
# Names exported by ``from <module> import *``; kept in sorted order
# (classes first because uppercase sorts before lowercase).
__all__ = [
    "BenchmarkCheckpoint",
    "BenchmarkPoint",
    "BenchmarkResult",
    "BenchmarkWorkloadDefaults",
    "format_number",
    "parse_positive_int_sequence",
    "resolve_device",
    "resolve_dtype",
    "scenario_choices",
]