"""Dynamic three-body inference microbenchmarks."""
from __future__ import annotations
import time
from dataclasses import dataclass
from pathlib import Path
import torch
from ufp.benchmarks._common import resolve_device, resolve_dtype
from ufp.neighbors import build_neighbor_list
from ufp.terms._threebody_kernels import (
native_threebody_backend_available,
preprocess_sources_native_or_torch,
)
from ufp.terms._threebody_ops import build_edge_category_table
from ufp.terms._threebody_runtime import (
ThreeBodyRuntimeConfig,
resolve_threebody_runtime_config,
)
from ufp.terms.threebody import (
_build_dense_feature_cache_from_feature_cache,
_build_feature_cache_from_buckets,
_evaluate_dense_feature_cache_energy_forces,
evaluate_bucketed_energy_forces,
)
[docs]
@dataclass(frozen=True)
class ThreeBodyDynamicBenchmarkResult:
    """Timing result for one dynamic three-body evaluator backend.

    Returned by ``run_threebody_dynamic_inference_benchmark``. Buckets are
    built once before timing starts, so the timings cover only the evaluator
    call plus a device synchronization.
    """

    scenario: str  # scenario name
    backend: str  # resolved dynamic backend (``config.dynamic_backend``)
    device: str  # ``str()`` of the resolved torch device
    dtype: str  # floating dtype name with the ``"torch."`` prefix removed
    n_nodes: int  # number of nodes in the workload
    n_systems: int  # number of independent systems (1 for synthetic scenarios)
    n_categories: int  # number of node categories
    n_sources: int  # number of source nodes (equals n_nodes for XYZ scenarios)
    mean_degree: int  # pairs per source node (rounded for XYZ scenarios)
    n_pattern_plans: int  # ``buckets.patterns.shape[0]``
    n_triplet_layout_entries: int  # ``tensor_pattern_plans.row.numel()``, 0 if absent
    median_time_s: float  # median wall time of one evaluator call, seconds
    min_time_s: float  # fastest evaluator call, seconds
[docs]
@dataclass(frozen=True)
class ThreeBodyDynamicBreakdownBenchmarkResult:
    """End-to-end dynamic inference timing split into bucket and evaluator stages.

    Returned by ``run_threebody_dynamic_breakdown_benchmark``, which rebuilds
    the buckets on every timed iteration.
    """

    scenario: str  # scenario name
    backend: str  # resolved dynamic backend (``config.dynamic_backend``)
    device: str  # ``str()`` of the resolved torch device
    dtype: str  # floating dtype name with the ``"torch."`` prefix removed
    n_nodes: int  # number of nodes in the workload
    n_systems: int  # number of independent systems
    n_categories: int  # number of node categories
    n_sources: int  # number of source nodes
    mean_degree: int  # pairs per source node (rounded for XYZ scenarios)
    n_pattern_plans: int  # ``patterns.shape[0]`` of the last measured buckets
    n_triplet_layout_entries: int  # ``tensor_pattern_plans.row.numel()``, 0 if absent
    median_bucket_build_time_s: float  # median bucket build + pattern metadata time
    median_evaluator_time_s: float  # median evaluator call time
    # Sum of the two stage medians above, not the median of per-iteration totals.
    median_total_time_s: float
[docs]
@dataclass(frozen=True)
class ThreeBodyCacheBenchmarkResult:
    """Timing split for fixed-geometry three-body feature-cache generation.

    Returned by ``run_threebody_cache_benchmark``. Every stage is timed
    separately, with a device synchronization before each clock stops.
    """

    scenario: str  # scenario name
    backend: str  # resolved dynamic backend (``config.dynamic_backend``)
    device: str  # ``str()`` of the resolved torch device
    dtype: str  # floating dtype name with the ``"torch."`` prefix removed
    n_nodes: int  # number of nodes in the workload
    n_systems: int  # number of independent systems
    n_categories: int  # number of node categories
    n_sources: int  # number of source nodes
    mean_degree: int  # pairs per source node (rounded for XYZ scenarios)
    n_pattern_plans: int  # ``patterns.shape[0]`` of the last measured buckets
    n_triplet_layout_entries: int  # ``tensor_pattern_plans.row.numel()``, 0 if absent
    median_bucket_build_time_s: float  # median bucket build + pattern metadata time
    median_sparse_feature_time_s: float  # median ``_build_feature_cache_from_buckets`` time
    median_dense_feature_time_s: float  # median dense-cache conversion time
    # Sum of the sparse and dense stage medians.
    median_cache_build_time_s: float
    # Median ``_evaluate_dense_feature_cache_energy_forces`` time.
    median_cached_evaluator_time_s: float
@dataclass(frozen=True)
class _DynamicScenario:
    """Static description of one benchmark scenario.

    Synthetic scenarios (``xyz_filename is None``) use the size fields to
    generate a random pair graph. File-backed scenarios set ``xyz_filename``
    and leave the size fields at 0; their real sizes are derived from the file.
    """

    name: str  # registry key and reported scenario name
    n_nodes: int  # synthetic only: total number of nodes
    n_categories: int  # synthetic only: number of node categories
    n_sources: int  # synthetic only: nodes 0..n_sources-1 own outgoing pairs
    degree: int  # synthetic only: number of pairs per source node
    coeff_size: int  # per-axis size of the (c, c, c) coefficient block
    xyz_filename: str | None = None  # file next to this module; None for synthetic
    cutoff: float = 1.8  # neighbor-list cutoff, used by file-backed scenarios
    # Spline knot layout forwarded to the evaluators.
    first_knot_xy: float = -0.75
    first_knot_z: float = -0.75
    knot_spacing_xy: float = 0.25
    knot_spacing_z: float = 0.25
    # Forwarded to ``evaluate_bucketed_energy_forces`` only.
    lower_support_xy: float = 0.0
    lower_support_z: float = 0.0
@dataclass(frozen=True)
class _RawWorkload:
    """Pair-level benchmark inputs before three-body bucketing.

    Built by ``_build_raw_workload`` (synthetic) or ``_build_xyz_workload``
    (file-backed). Consumed by ``_build_buckets`` and the timed stages.
    """

    name: str  # scenario name
    first: torch.Tensor  # int64 node index of the first member of each pair
    second: torch.Tensor  # int64 node index of the second member of each pair
    node_cat: torch.Tensor  # int64 category id per node
    pair_vectors: torch.Tensor  # per-pair vectors; (n_pairs, 3) in the synthetic builder
    pair_distances: torch.Tensor  # per-pair distances
    system_index: torch.Tensor  # int64 system id per node
    coeffs: torch.Tensor  # random coefficients, (n_triplet_categories, c, c, c)
    edge_cat_table: torch.Tensor  # from ``build_edge_category_table(n_categories)``
    n_nodes: int
    n_systems: int
    n_categories: int
    n_sources: int
    mean_degree: int  # pairs per source node (rounded for XYZ workloads)
    coeff_size: int  # ``c`` in the coeffs shape above
    # Spline knot layout copied from the scenario.
    first_knot_xy: float
    first_knot_z: float
    knot_spacing_xy: float
    knot_spacing_z: float
    lower_support_xy: float
    lower_support_z: float
_BENCHMARK_DIR = Path(__file__).resolve().parent

# Scenario registry keyed by each scenario's own ``name``. Synthetic entries
# describe a random pair graph directly. ``*_xyz`` entries load geometry from
# an extended-XYZ file in ``_BENCHMARK_DIR``, so their size fields are
# placeholders (0) and the real counts come from the file.
_SCENARIOS = {
    spec.name: spec
    for spec in (
        _DynamicScenario(
            name="triangle_pair_threebody",
            n_nodes=96,
            n_categories=1,
            n_sources=96,
            degree=8,
            coeff_size=8,
        ),
        _DynamicScenario(
            name="ternary_alloy",
            n_nodes=192,
            n_categories=3,
            n_sources=192,
            degree=14,
            coeff_size=9,
        ),
        _DynamicScenario(
            name="high_degree_cluster",
            n_nodes=256,
            n_categories=4,
            n_sources=192,
            degree=28,
            coeff_size=10,
        ),
        _DynamicScenario(
            name="molecules_xyz",
            n_nodes=0,
            n_categories=0,
            n_sources=0,
            degree=0,
            coeff_size=12,
            xyz_filename="molecules.xyz",
            cutoff=2.2,
            first_knot_xy=-0.75,
            first_knot_z=-0.75,
            knot_spacing_xy=0.25,
            knot_spacing_z=0.25,
        ),
        _DynamicScenario(
            name="unary_xyz",
            n_nodes=0,
            n_categories=0,
            n_sources=0,
            degree=0,
            coeff_size=14,
            xyz_filename="unary.xyz",
            cutoff=5.0,
            first_knot_xy=-1.0,
            first_knot_z=-1.0,
            knot_spacing_xy=0.5,
            knot_spacing_z=0.5,
        ),
        _DynamicScenario(
            name="alchemical_xyz",
            n_nodes=0,
            n_categories=0,
            n_sources=0,
            degree=0,
            coeff_size=14,
            xyz_filename="alchemical.xyz",
            cutoff=5.0,
            first_knot_xy=-1.0,
            first_knot_z=-1.0,
            knot_spacing_xy=0.5,
            knot_spacing_z=0.5,
        ),
    )
}
[docs]
def available_threebody_dynamic_scenarios() -> tuple[str, ...]:
    """List every name accepted by the ``scenario`` argument, alphabetically."""
    names = sorted(_SCENARIOS.keys())
    return tuple(names)
def _synchronize(device: torch.device) -> None:
    """Wait for queued device work so wall-clock timings are accurate.

    Only CUDA devices get an explicit barrier. Every other device type is a
    no-op.
    """
    if device.type != "cuda":
        return
    torch.cuda.synchronize(device)
def _build_workload(
    scenario: _DynamicScenario,
    *,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
    runtime_config: ThreeBodyRuntimeConfig,
):
    """Materialize a scenario and pre-bucket it for evaluator-only timing.

    Returns:
        ``(raw, buckets, node_cat, coeffs, edge_cat_table)``. The last three
        entries are forwarded unchanged from ``raw`` for convenience.
    """
    raw = _build_raw_workload(scenario, device=device, dtype=dtype, seed=seed)
    fresh_buckets = _build_buckets(raw, runtime_config=runtime_config)
    buckets = _with_pattern_metadata(fresh_buckets, device)
    return raw, buckets, raw.node_cat, raw.coeffs, raw.edge_cat_table
def _build_raw_workload(
    scenario: _DynamicScenario,
    *,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
) -> _RawWorkload:
    """Build deterministic pair tensors before three-body bucketing.

    File-backed scenarios (``scenario.xyz_filename`` set) are delegated to
    ``_build_xyz_workload``. A synthetic scenario works as follows:

    * each of the first ``n_sources`` nodes gets ``degree`` pairs pointing
      at random nodes;
    * each pair gets a random unit direction scaled to a length in
      ``[0.25, 0.70)``;
    * node categories and spline coefficients are also drawn at random.

    Every random draw uses one private generator seeded with ``seed``. The
    workload is therefore reproducible for a given scenario, seed, device and
    dtype, and it neither reads nor advances the global torch RNG.

    Args:
        scenario: Scenario to materialize.
        device: Device on which every tensor is allocated.
        dtype: Floating dtype of geometry and coefficients.
        seed: Seed for the workload's random generator.

    Returns:
        The un-bucketed workload.
    """
    if scenario.xyz_filename is not None:
        return _build_xyz_workload(
            scenario,
            device=device,
            dtype=dtype,
            seed=seed,
        )
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)
    src_ids = torch.arange(scenario.n_sources, device=device, dtype=torch.int64)
    # Source i owns the contiguous run of ``degree`` pairs starting at degree * i.
    first = torch.repeat_interleave(src_ids, scenario.degree)
    second = torch.randint(
        0,
        scenario.n_nodes,
        (first.numel(),),
        device=device,
        generator=generator,
        dtype=torch.int64,
    )
    # Redirect accidental self-pairs to the next node (wrapping around).
    duplicate = second == first
    second = torch.where(duplicate, (second + 1) % scenario.n_nodes, second)
    node_cat = torch.randint(
        0,
        scenario.n_categories,
        (scenario.n_nodes,),
        device=device,
        generator=generator,
        dtype=torch.int64,
    )
    # Use the seeded generator here too. Drawing directions from the global
    # RNG made the geometry depend on global RNG state rather than on ``seed``.
    directions = torch.randn(
        (first.numel(), 3),
        device=device,
        dtype=dtype,
        generator=generator,
    )
    directions = directions / torch.linalg.vector_norm(
        directions,
        dim=1,
        keepdim=True,
    ).clamp_min(torch.finfo(dtype).eps)
    distances = 0.25 + 0.45 * torch.rand(
        (first.numel(),),
        device=device,
        dtype=dtype,
        generator=generator,
    )
    pair_vectors = directions * distances[:, None]
    # n_cat * n_cat*(n_cat+1)/2 coefficient tables. Presumably this is one per
    # centre category and unordered neighbour-category pair.
    n_triplet_categories = scenario.n_categories * (
        scenario.n_categories * (scenario.n_categories + 1) // 2
    )
    coeffs = torch.randn(
        (
            n_triplet_categories,
            scenario.coeff_size,
            scenario.coeff_size,
            scenario.coeff_size,
        ),
        device=device,
        dtype=dtype,
        generator=generator,
    )
    return _RawWorkload(
        name=scenario.name,
        first=first,
        second=second,
        node_cat=node_cat,
        pair_vectors=pair_vectors,
        pair_distances=distances,
        system_index=torch.zeros((scenario.n_nodes,), device=device, dtype=torch.int64),
        coeffs=coeffs,
        edge_cat_table=build_edge_category_table(
            scenario.n_categories,
            device=device,
        ),
        n_nodes=scenario.n_nodes,
        n_systems=1,
        n_categories=scenario.n_categories,
        n_sources=scenario.n_sources,
        mean_degree=scenario.degree,
        coeff_size=scenario.coeff_size,
        first_knot_xy=scenario.first_knot_xy,
        first_knot_z=scenario.first_knot_z,
        knot_spacing_xy=scenario.knot_spacing_xy,
        knot_spacing_z=scenario.knot_spacing_z,
        lower_support_xy=scenario.lower_support_xy,
        lower_support_z=scenario.lower_support_z,
    )
def _read_extxyz_frames(path: Path):
    """Load every frame of an extended-XYZ file through ASE.

    Raises:
        RuntimeError: if ASE cannot be imported.
    """
    try:
        from ase.io import read as ase_read
    except ImportError as missing_ase:
        raise RuntimeError(
            "ASE is required for file-backed three-body benchmarks"
        ) from missing_ase
    every_frame = ":"
    return ase_read(path, every_frame)
def _atomic_categories(
    atomic_numbers: torch.Tensor,
) -> tuple[torch.Tensor, int]:
    """Map atomic numbers to compact sorted categories.

    Each atomic number is replaced by its rank among the distinct values
    present, so the smallest atomic number becomes category 0.

    Args:
        atomic_numbers: Integer tensor of atomic numbers (any shape).

    Returns:
        ``(categories, n_categories)``. ``categories`` is an int64 tensor with
        the same shape and device as ``atomic_numbers``.
    """
    # ``return_inverse`` gives each element's index into the sorted unique
    # values directly. This replaces a per-type masking loop and avoids
    # copying the input to the host.
    atomic_types, categories = torch.unique(
        atomic_numbers,
        sorted=True,
        return_inverse=True,
    )
    return categories.to(dtype=torch.int64), int(atomic_types.numel())
def _build_xyz_workload(
    scenario: _DynamicScenario,
    *,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
) -> _RawWorkload:
    """Turn an extended-XYZ benchmark file into one batched raw workload.

    Each frame becomes its own system. Its full, sorted neighbor list (within
    ``scenario.cutoff``) is shifted by the running atom count, so node indices
    are global across frames. Categories are assigned by sorted atomic
    number. Coefficients come from a generator seeded with ``seed``.

    Raises:
        RuntimeError: if the file has no frames, or the neighbor list lacks
            pair vectors or distances.
    """
    path = _BENCHMARK_DIR / str(scenario.xyz_filename)
    frames = _read_extxyz_frames(path)
    if not frames:
        raise RuntimeError(f"benchmark XYZ file is empty: {path}")

    firsts: list[torch.Tensor] = []
    seconds: list[torch.Tensor] = []
    vectors: list[torch.Tensor] = []
    lengths: list[torch.Tensor] = []
    numbers: list[torch.Tensor] = []
    owners: list[torch.Tensor] = []
    offset = 0
    for frame_id, atoms in enumerate(frames):
        nl = build_neighbor_list(
            atoms,
            cutoff=scenario.cutoff,
            arrays="torch",
            full_list=True,
            sorted=True,
        )
        pair_ids = nl.pairs.to(dtype=torch.int64)
        firsts.append(pair_ids[0] + offset)
        seconds.append(pair_ids[1] + offset)
        if nl.vectors is None or nl.distances is None:
            raise RuntimeError("benchmark neighbor list did not return pair geometry")
        vectors.append(nl.vectors.to(dtype=dtype))
        lengths.append(nl.distances.to(dtype=dtype))
        numbers.append(torch.as_tensor(atoms.numbers, dtype=torch.int64))
        frame_size = len(atoms)
        owners.append(torch.full((frame_size,), frame_id, dtype=torch.int64))
        offset += frame_size

    first = torch.cat(firsts).to(device=device)
    second = torch.cat(seconds).to(device=device)
    pair_vectors = torch.cat(vectors).to(device=device, dtype=dtype)
    pair_distances = torch.cat(lengths).to(device=device, dtype=dtype)
    atomic_numbers = torch.cat(numbers).to(device=device)
    node_cat, n_categories = _atomic_categories(atomic_numbers)
    node_cat = node_cat.to(device=device)
    system_index = torch.cat(owners).to(device=device)

    unordered_pairs = n_categories * (n_categories + 1) // 2
    n_triplet_categories = n_categories * unordered_pairs
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    size = scenario.coeff_size
    coeffs = torch.randn(
        (n_triplet_categories, size, size, size),
        device=device,
        dtype=dtype,
        generator=rng,
    )

    n_nodes = int(atomic_numbers.numel())
    mean_degree = int(round(float(first.numel()) / max(n_nodes, 1)))
    return _RawWorkload(
        name=scenario.name,
        first=first,
        second=second,
        node_cat=node_cat,
        pair_vectors=pair_vectors,
        pair_distances=pair_distances,
        system_index=system_index,
        coeffs=coeffs,
        edge_cat_table=build_edge_category_table(n_categories, device=device),
        n_nodes=n_nodes,
        n_systems=len(frames),
        n_categories=n_categories,
        n_sources=n_nodes,
        mean_degree=mean_degree,
        coeff_size=scenario.coeff_size,
        first_knot_xy=scenario.first_knot_xy,
        first_knot_z=scenario.first_knot_z,
        knot_spacing_xy=scenario.knot_spacing_xy,
        knot_spacing_z=scenario.knot_spacing_z,
        lower_support_xy=scenario.lower_support_xy,
        lower_support_z=scenario.lower_support_z,
    )
def _build_buckets(
    raw: _RawWorkload,
    *,
    runtime_config: ThreeBodyRuntimeConfig,
):
    """Group a raw workload's pairs into reusable per-source buckets.

    Thin wrapper that forwards the workload's pair tensors, category count
    and ``runtime_config`` to ``preprocess_sources_native_or_torch``. It is
    used for both synthetic and XYZ workloads.
    """
    pair_inputs = (
        raw.first,
        raw.second,
        raw.node_cat,
        raw.n_categories,
        raw.pair_vectors,
        raw.pair_distances,
    )
    return preprocess_sources_native_or_torch(
        *pair_inputs,
        runtime_config=runtime_config,
    )
def _with_pattern_metadata(buckets, device: torch.device):
    """Ensure ``buckets`` carries tensor pattern plans.

    Buckets that already have ``tensor_pattern_plans`` are returned as-is.
    Otherwise the plans are attached via ``buckets.with_pattern_plans(device)``.
    """
    needs_plans = buckets.tensor_pattern_plans is None
    return buckets.with_pattern_plans(device) if needs_plans else buckets
def _timing_stats(timings: list[float]) -> tuple[float, float]:
    """Return the median and minimum of a list of timings.

    For an even number of samples the median is the mean of the two middle
    values. ``torch.median`` would instead return the *lower* middle value,
    which biases the reported median toward faster runs. The benchmarks'
    default repeat counts (20 and 10) are even, so this matters.

    Args:
        timings: Measured durations in seconds; must be non-empty.

    Returns:
        ``(median, minimum)`` as Python floats.

    Raises:
        ValueError: if ``timings`` is empty.
    """
    if not timings:
        raise ValueError("`timings` must contain at least one sample")
    timing_tensor = torch.tensor(timings, dtype=torch.float64)
    # Linear interpolation at q=0.5 averages the two middle samples.
    median = torch.quantile(timing_tensor, 0.5)
    return float(median.item()), float(timing_tensor.min().item())
[docs]
def run_threebody_dynamic_inference_benchmark(
    *,
    scenario: str = "ternary_alloy",
    backend: str = "torch",
    bucket_backend: str | None = None,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
    device: str | torch.device | None = None,
    dtype: str | torch.dtype | None = "auto",
    seed: int = 0,
    repeats: int = 20,
    warmup: int = 5,
) -> ThreeBodyDynamicBenchmarkResult:
    """Time the dynamic three-body evaluator on pre-built buckets.

    The workload and its buckets are built once, outside the timed region.
    Each timed sample is one ``evaluate_bucketed_energy_forces`` call followed
    by a device synchronization.

    Args:
        scenario: A name from ``available_threebody_dynamic_scenarios()``.
        backend: Forwarded as ``dynamic_backend`` to
            ``resolve_threebody_runtime_config`` only when ``runtime_config``
            is None.
        bucket_backend: Forwarded to ``resolve_threebody_runtime_config``.
        runtime_config: Optional explicit runtime configuration.
        device: Device spec passed to ``resolve_device``.
        dtype: Dtype spec passed to ``resolve_dtype``.
        seed: Workload seed.
        repeats: Number of timed samples (must be positive).
        warmup: Number of untimed warm-up calls (must be non-negative).

    Raises:
        ValueError: on an unknown scenario, ``repeats <= 0`` or ``warmup < 0``.
    """
    if scenario not in _SCENARIOS:
        choices = ", ".join(available_threebody_dynamic_scenarios())
        raise ValueError(f"unknown scenario '{scenario}'; expected one of: {choices}")
    if repeats <= 0 or warmup < 0:
        raise ValueError("`repeats` must be positive and `warmup` non-negative")

    config = resolve_threebody_runtime_config(
        runtime_config,
        dynamic_backend=backend if runtime_config is None else None,
        bucket_backend=bucket_backend,
    )
    backend_name = config.dynamic_backend
    torch_device = resolve_device(device)
    torch_dtype = resolve_dtype(torch_device, dtype)
    raw, buckets, node_cat, coeffs, edge_cat_table = _build_workload(
        _SCENARIOS[scenario],
        device=torch_device,
        dtype=torch_dtype,
        seed=seed,
        runtime_config=config,
    )
    evaluator_kwargs = {
        "spline": "cubic",
        "n_nodes": raw.n_nodes,
        "n_cat": raw.n_categories,
        "first_knot_xy": raw.first_knot_xy,
        "first_knot_z": raw.first_knot_z,
        "knot_spacing_xy": raw.knot_spacing_xy,
        "knot_spacing_z": raw.knot_spacing_z,
        "lower_support_xy": raw.lower_support_xy,
        "lower_support_z": raw.lower_support_z,
        "runtime_config": config,
    }

    def run_once() -> None:
        # One evaluator call plus the barrier that makes its cost visible.
        evaluate_bucketed_energy_forces(
            buckets,
            node_cat,
            coeffs,
            edge_cat_table,
            **evaluator_kwargs,
        )
        _synchronize(torch_device)

    for _ in range(warmup):
        run_once()

    samples: list[float] = []
    for _ in range(repeats):
        tic = time.perf_counter()
        run_once()
        samples.append(time.perf_counter() - tic)
    median_s, fastest_s = _timing_stats(samples)

    plans = buckets.tensor_pattern_plans
    return ThreeBodyDynamicBenchmarkResult(
        scenario=raw.name,
        backend=backend_name,
        device=str(torch_device),
        dtype=str(torch_dtype).replace("torch.", ""),
        n_nodes=raw.n_nodes,
        n_systems=raw.n_systems,
        n_categories=raw.n_categories,
        n_sources=raw.n_sources,
        mean_degree=raw.mean_degree,
        n_pattern_plans=int(buckets.patterns.shape[0]),
        n_triplet_layout_entries=0 if plans is None else int(plans.row.numel()),
        median_time_s=median_s,
        min_time_s=fastest_s,
    )
[docs]
def run_threebody_dynamic_breakdown_benchmark(
    *,
    scenario: str = "ternary_alloy",
    backend: str = "torch",
    bucket_backend: str | None = None,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
    device: str | torch.device | None = None,
    dtype: str | torch.dtype | None = "auto",
    seed: int = 0,
    repeats: int = 20,
    warmup: int = 5,
) -> ThreeBodyDynamicBreakdownBenchmarkResult:
    """Time bucket construction and evaluation separately on every iteration.

    Buckets are rebuilt from the raw pair tensors on each iteration, so the
    pair of timings reflects a full dynamic-inference step. Each stage is
    followed by a device synchronization before its clock stops. The reported
    total is the sum of the two stage medians.

    Raises:
        ValueError: on an unknown scenario, ``repeats <= 0`` or ``warmup < 0``.
    """
    if scenario not in _SCENARIOS:
        choices = ", ".join(available_threebody_dynamic_scenarios())
        raise ValueError(f"unknown scenario '{scenario}'; expected one of: {choices}")
    if repeats <= 0 or warmup < 0:
        raise ValueError("`repeats` must be positive and `warmup` non-negative")

    config = resolve_threebody_runtime_config(
        runtime_config,
        dynamic_backend=backend if runtime_config is None else None,
        bucket_backend=bucket_backend,
    )
    backend_name = config.dynamic_backend
    torch_device = resolve_device(device)
    torch_dtype = resolve_dtype(torch_device, dtype)
    raw = _build_raw_workload(
        _SCENARIOS[scenario],
        device=torch_device,
        dtype=torch_dtype,
        seed=seed,
    )

    def make_buckets():
        # Fresh buckets on every call: bucketing is part of what is measured.
        fresh = _build_buckets(raw, runtime_config=config)
        return _with_pattern_metadata(fresh, torch_device)

    def evaluate(bucketed) -> None:
        evaluate_bucketed_energy_forces(
            bucketed,
            raw.node_cat,
            raw.coeffs,
            raw.edge_cat_table,
            spline="cubic",
            n_nodes=raw.n_nodes,
            n_cat=raw.n_categories,
            first_knot_xy=raw.first_knot_xy,
            first_knot_z=raw.first_knot_z,
            knot_spacing_xy=raw.knot_spacing_xy,
            knot_spacing_z=raw.knot_spacing_z,
            lower_support_xy=raw.lower_support_xy,
            lower_support_z=raw.lower_support_z,
            runtime_config=config,
        )

    for _ in range(warmup):
        evaluate(make_buckets())
        _synchronize(torch_device)

    bucket_samples: list[float] = []
    evaluator_samples: list[float] = []
    last_buckets = None
    for _ in range(repeats):
        tic = time.perf_counter()
        last_buckets = make_buckets()
        _synchronize(torch_device)
        bucket_samples.append(time.perf_counter() - tic)

        tic = time.perf_counter()
        evaluate(last_buckets)
        _synchronize(torch_device)
        evaluator_samples.append(time.perf_counter() - tic)
    # ``repeats > 0`` was validated above, so at least one bucket set exists.
    assert last_buckets is not None

    bucket_median = _timing_stats(bucket_samples)[0]
    evaluator_median = _timing_stats(evaluator_samples)[0]
    plans = last_buckets.tensor_pattern_plans
    return ThreeBodyDynamicBreakdownBenchmarkResult(
        scenario=raw.name,
        backend=backend_name,
        device=str(torch_device),
        dtype=str(torch_dtype).replace("torch.", ""),
        n_nodes=raw.n_nodes,
        n_systems=raw.n_systems,
        n_categories=raw.n_categories,
        n_sources=raw.n_sources,
        mean_degree=raw.mean_degree,
        n_pattern_plans=int(last_buckets.patterns.shape[0]),
        n_triplet_layout_entries=0 if plans is None else int(plans.row.numel()),
        median_bucket_build_time_s=bucket_median,
        median_evaluator_time_s=evaluator_median,
        median_total_time_s=bucket_median + evaluator_median,
    )
[docs]
def run_threebody_cache_benchmark(
    *,
    scenario: str = "ternary_alloy",
    backend: str = "torch",
    bucket_backend: str | None = None,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
    device: str | torch.device | None = None,
    dtype: str | torch.dtype | None = "auto",
    seed: int = 0,
    repeats: int = 10,
    warmup: int = 2,
) -> ThreeBodyCacheBenchmarkResult:
    """Time the fixed-geometry feature-cache pipeline stage by stage.

    The stages run in order on every iteration:

    1. bucket build;
    2. sparse feature cache (``_build_feature_cache_from_buckets``);
    3. dense feature cache;
    4. energy/force evaluation from the dense cache.

    Each stage is followed by a device synchronization before its clock
    stops. ``median_cache_build_time_s`` is the sum of the sparse and dense
    stage medians.

    Raises:
        ValueError: on an unknown scenario, ``repeats <= 0`` or ``warmup < 0``.
    """
    if scenario not in _SCENARIOS:
        choices = ", ".join(available_threebody_dynamic_scenarios())
        raise ValueError(f"unknown scenario '{scenario}'; expected one of: {choices}")
    if repeats <= 0 or warmup < 0:
        raise ValueError("`repeats` must be positive and `warmup` non-negative")

    config = resolve_threebody_runtime_config(
        runtime_config,
        dynamic_backend=backend if runtime_config is None else None,
        bucket_backend=bucket_backend,
    )
    backend_name = config.dynamic_backend
    torch_device = resolve_device(device)
    torch_dtype = resolve_dtype(torch_device, dtype)
    raw = _build_raw_workload(
        _SCENARIOS[scenario],
        device=torch_device,
        dtype=torch_dtype,
        seed=seed,
    )
    coeff_shape = (raw.coeff_size,) * 3

    def make_buckets():
        return _with_pattern_metadata(
            _build_buckets(raw, runtime_config=config),
            torch_device,
        )

    def build_sparse(bucketed):
        return _build_feature_cache_from_buckets(
            bucketed,
            coeff_shape,
            spline="cubic",
            n_cat=raw.n_categories,
            first_knot_xy=raw.first_knot_xy,
            first_knot_z=raw.first_knot_z,
            knot_spacing_xy=raw.knot_spacing_xy,
            knot_spacing_z=raw.knot_spacing_z,
            runtime_config=config,
        )

    def build_dense(sparse):
        return _build_dense_feature_cache_from_feature_cache(
            sparse,
            raw.system_index,
            coeff_shape=coeff_shape,
            runtime_config=config,
        )

    def evaluate_cached(dense) -> None:
        _evaluate_dense_feature_cache_energy_forces(
            dense,
            raw.coeffs,
            n_nodes=raw.n_nodes,
            n_systems=raw.n_systems,
        )

    def timed(stage, *args):
        """Run ``stage(*args)``, synchronize, and return ``(result, seconds)``."""
        tic = time.perf_counter()
        result = stage(*args)
        _synchronize(torch_device)
        return result, time.perf_counter() - tic

    for _ in range(warmup):
        evaluate_cached(build_dense(build_sparse(make_buckets())))
        _synchronize(torch_device)

    bucket_samples: list[float] = []
    sparse_samples: list[float] = []
    dense_samples: list[float] = []
    evaluator_samples: list[float] = []
    last_buckets = None
    for _ in range(repeats):
        last_buckets, elapsed = timed(make_buckets)
        bucket_samples.append(elapsed)
        feature_cache, elapsed = timed(build_sparse, last_buckets)
        sparse_samples.append(elapsed)
        dense_cache, elapsed = timed(build_dense, feature_cache)
        dense_samples.append(elapsed)
        _, elapsed = timed(evaluate_cached, dense_cache)
        evaluator_samples.append(elapsed)
    # ``repeats > 0`` was validated above, so at least one bucket set exists.
    assert last_buckets is not None

    bucket_median = _timing_stats(bucket_samples)[0]
    sparse_median = _timing_stats(sparse_samples)[0]
    dense_median = _timing_stats(dense_samples)[0]
    evaluator_median = _timing_stats(evaluator_samples)[0]
    plans = last_buckets.tensor_pattern_plans
    return ThreeBodyCacheBenchmarkResult(
        scenario=raw.name,
        backend=backend_name,
        device=str(torch_device),
        dtype=str(torch_dtype).replace("torch.", ""),
        n_nodes=raw.n_nodes,
        n_systems=raw.n_systems,
        n_categories=raw.n_categories,
        n_sources=raw.n_sources,
        mean_degree=raw.mean_degree,
        n_pattern_plans=int(last_buckets.patterns.shape[0]),
        n_triplet_layout_entries=0 if plans is None else int(plans.row.numel()),
        median_bucket_build_time_s=bucket_median,
        median_sparse_feature_time_s=sparse_median,
        median_dense_feature_time_s=dense_median,
        median_cache_build_time_s=sparse_median + dense_median,
        median_cached_evaluator_time_s=evaluator_median,
    )
# Public API of this benchmark module. ``native_threebody_backend_available``
# is re-exported from ``ufp.terms._threebody_kernels``. Presumably this lets
# callers check native-kernel availability before choosing a backend.
__all__ = [
    "ThreeBodyCacheBenchmarkResult",
    "ThreeBodyDynamicBreakdownBenchmarkResult",
    "ThreeBodyDynamicBenchmarkResult",
    "available_threebody_dynamic_scenarios",
    "native_threebody_backend_available",
    "run_threebody_cache_benchmark",
    "run_threebody_dynamic_breakdown_benchmark",
    "run_threebody_dynamic_inference_benchmark",
]