Source code for ufp.benchmarks._threebody_dynamic

"""Dynamic three-body inference microbenchmarks."""

from __future__ import annotations

import time
from dataclasses import dataclass
from pathlib import Path

import torch

from ufp.benchmarks._common import resolve_device, resolve_dtype
from ufp.neighbors import build_neighbor_list
from ufp.terms._threebody_kernels import (
    native_threebody_backend_available,
    preprocess_sources_native_or_torch,
)
from ufp.terms._threebody_ops import build_edge_category_table
from ufp.terms._threebody_runtime import (
    ThreeBodyRuntimeConfig,
    resolve_threebody_runtime_config,
)
from ufp.terms.threebody import (
    _build_dense_feature_cache_from_feature_cache,
    _build_feature_cache_from_buckets,
    _evaluate_dense_feature_cache_energy_forces,
    evaluate_bucketed_energy_forces,
)


@dataclass(frozen=True)
class ThreeBodyDynamicBenchmarkResult:
    """Timing result for one dynamic three-body evaluator backend.

    Instances are immutable (``frozen=True``) so results can be hashed,
    compared and collected safely.
    """

    # Workload identity.
    scenario: str
    backend: str
    device: str
    dtype: str  # dtype name without the "torch." prefix, e.g. "float64"
    # Workload size.
    n_nodes: int
    n_systems: int
    n_categories: int
    n_sources: int
    mean_degree: int
    n_pattern_plans: int
    n_triplet_layout_entries: int
    # Evaluator wall-clock statistics, in seconds.
    median_time_s: float
    min_time_s: float
@dataclass(frozen=True)
class ThreeBodyDynamicBreakdownBenchmarkResult:
    """End-to-end dynamic inference timing split into bucket and evaluator stages.

    Instances are immutable (``frozen=True``).
    """

    # Workload identity.
    scenario: str
    backend: str
    device: str
    dtype: str  # dtype name without the "torch." prefix
    # Workload size.
    n_nodes: int
    n_systems: int
    n_categories: int
    n_sources: int
    mean_degree: int
    n_pattern_plans: int
    n_triplet_layout_entries: int
    # Per-stage median wall-clock times, in seconds.
    median_bucket_build_time_s: float
    median_evaluator_time_s: float
    # Sum of the two stage medians above.
    median_total_time_s: float
@dataclass(frozen=True)
class ThreeBodyCacheBenchmarkResult:
    """Timing split for fixed-geometry three-body feature-cache generation.

    Instances are immutable (``frozen=True``).
    """

    # Workload identity.
    scenario: str
    backend: str
    device: str
    dtype: str  # dtype name without the "torch." prefix
    # Workload size.
    n_nodes: int
    n_systems: int
    n_categories: int
    n_sources: int
    mean_degree: int
    n_pattern_plans: int
    n_triplet_layout_entries: int
    # Per-stage median wall-clock times, in seconds.
    median_bucket_build_time_s: float
    median_sparse_feature_time_s: float
    median_dense_feature_time_s: float
    # Sum of the sparse and dense feature medians.
    median_cache_build_time_s: float
    median_cached_evaluator_time_s: float
@dataclass(frozen=True) class _DynamicScenario: name: str n_nodes: int n_categories: int n_sources: int degree: int coeff_size: int xyz_filename: str | None = None cutoff: float = 1.8 first_knot_xy: float = -0.75 first_knot_z: float = -0.75 knot_spacing_xy: float = 0.25 knot_spacing_z: float = 0.25 lower_support_xy: float = 0.0 lower_support_z: float = 0.0 @dataclass(frozen=True) class _RawWorkload: name: str first: torch.Tensor second: torch.Tensor node_cat: torch.Tensor pair_vectors: torch.Tensor pair_distances: torch.Tensor system_index: torch.Tensor coeffs: torch.Tensor edge_cat_table: torch.Tensor n_nodes: int n_systems: int n_categories: int n_sources: int mean_degree: int coeff_size: int first_knot_xy: float first_knot_z: float knot_spacing_xy: float knot_spacing_z: float lower_support_xy: float lower_support_z: float _BENCHMARK_DIR = Path(__file__).resolve().parent _SCENARIOS = { "triangle_pair_threebody": _DynamicScenario( name="triangle_pair_threebody", n_nodes=96, n_categories=1, n_sources=96, degree=8, coeff_size=8, ), "ternary_alloy": _DynamicScenario( name="ternary_alloy", n_nodes=192, n_categories=3, n_sources=192, degree=14, coeff_size=9, ), "high_degree_cluster": _DynamicScenario( name="high_degree_cluster", n_nodes=256, n_categories=4, n_sources=192, degree=28, coeff_size=10, ), "molecules_xyz": _DynamicScenario( name="molecules_xyz", n_nodes=0, n_categories=0, n_sources=0, degree=0, coeff_size=12, xyz_filename="molecules.xyz", cutoff=2.2, first_knot_xy=-0.75, first_knot_z=-0.75, knot_spacing_xy=0.25, knot_spacing_z=0.25, ), "unary_xyz": _DynamicScenario( name="unary_xyz", n_nodes=0, n_categories=0, n_sources=0, degree=0, coeff_size=14, xyz_filename="unary.xyz", cutoff=5.0, first_knot_xy=-1.0, first_knot_z=-1.0, knot_spacing_xy=0.5, knot_spacing_z=0.5, ), "alchemical_xyz": _DynamicScenario( name="alchemical_xyz", n_nodes=0, n_categories=0, n_sources=0, degree=0, coeff_size=14, xyz_filename="alchemical.xyz", cutoff=5.0, first_knot_xy=-1.0, 
first_knot_z=-1.0, knot_spacing_xy=0.5, knot_spacing_z=0.5, ), }
def available_threebody_dynamic_scenarios() -> tuple[str, ...]:
    """Return available dynamic three-body microbenchmark scenario names.

    Names are returned in lexicographic order.
    """
    names = sorted(_SCENARIOS.keys())
    return tuple(names)
def _synchronize(device: torch.device) -> None: """Synchronize accelerator work before reading wall-clock time.""" if device.type == "cuda": torch.cuda.synchronize(device) def _build_workload( scenario: _DynamicScenario, *, device: torch.device, dtype: torch.dtype, seed: int, runtime_config: ThreeBodyRuntimeConfig, ): """Build one deterministic bucketed dynamic-inference workload.""" raw = _build_raw_workload( scenario, device=device, dtype=dtype, seed=seed, ) buckets = _with_pattern_metadata( _build_buckets(raw, runtime_config=runtime_config), device, ) return ( raw, buckets, raw.node_cat, raw.coeffs, raw.edge_cat_table, ) def _build_raw_workload( scenario: _DynamicScenario, *, device: torch.device, dtype: torch.dtype, seed: int, ) -> _RawWorkload: """Build deterministic pair tensors before three-body bucketing.""" if scenario.xyz_filename is not None: return _build_xyz_workload( scenario, device=device, dtype=dtype, seed=seed, ) generator = torch.Generator(device=device) generator.manual_seed(seed) src_ids = torch.arange(scenario.n_sources, device=device, dtype=torch.int64) first = torch.repeat_interleave(src_ids, scenario.degree) second = torch.randint( 0, scenario.n_nodes, (first.numel(),), device=device, generator=generator, dtype=torch.int64, ) duplicate = second == first second = torch.where(duplicate, (second + 1) % scenario.n_nodes, second) node_cat = torch.randint( 0, scenario.n_categories, (scenario.n_nodes,), device=device, generator=generator, dtype=torch.int64, ) directions = torch.randn((first.numel(), 3), device=device, dtype=dtype) directions = directions / torch.linalg.vector_norm( directions, dim=1, keepdim=True, ).clamp_min(torch.finfo(dtype).eps) distances = 0.25 + 0.45 * torch.rand( (first.numel(),), device=device, dtype=dtype, generator=generator, ) pair_vectors = directions * distances[:, None] n_triplet_categories = scenario.n_categories * ( scenario.n_categories * (scenario.n_categories + 1) // 2 ) coeffs = torch.randn( ( 
n_triplet_categories, scenario.coeff_size, scenario.coeff_size, scenario.coeff_size, ), device=device, dtype=dtype, generator=generator, ) return _RawWorkload( name=scenario.name, first=first, second=second, node_cat=node_cat, pair_vectors=pair_vectors, pair_distances=distances, system_index=torch.zeros((scenario.n_nodes,), device=device, dtype=torch.int64), coeffs=coeffs, edge_cat_table=build_edge_category_table( scenario.n_categories, device=device, ), n_nodes=scenario.n_nodes, n_systems=1, n_categories=scenario.n_categories, n_sources=scenario.n_sources, mean_degree=scenario.degree, coeff_size=scenario.coeff_size, first_knot_xy=scenario.first_knot_xy, first_knot_z=scenario.first_knot_z, knot_spacing_xy=scenario.knot_spacing_xy, knot_spacing_z=scenario.knot_spacing_z, lower_support_xy=scenario.lower_support_xy, lower_support_z=scenario.lower_support_z, ) def _read_extxyz_frames(path: Path): """Read an extended XYZ file with ASE.""" try: from ase.io import read except ImportError as exc: raise RuntimeError( "ASE is required for file-backed three-body benchmarks" ) from exc return read(path, ":") def _atomic_categories( atomic_numbers: torch.Tensor, ) -> tuple[torch.Tensor, int]: """Map atomic numbers to compact sorted categories.""" atomic_types = torch.unique(atomic_numbers.cpu(), sorted=True) categories = torch.full_like(atomic_numbers, -1, dtype=torch.int64) for category, atomic_number in enumerate(atomic_types.tolist()): categories[atomic_numbers == int(atomic_number)] = category return categories, int(atomic_types.numel()) def _build_xyz_workload( scenario: _DynamicScenario, *, device: torch.device, dtype: torch.dtype, seed: int, ) -> _RawWorkload: """Build one deterministic workload from an extended XYZ benchmark file.""" path = _BENCHMARK_DIR / str(scenario.xyz_filename) frames = _read_extxyz_frames(path) if not frames: raise RuntimeError(f"benchmark XYZ file is empty: {path}") first_parts: list[torch.Tensor] = [] second_parts: list[torch.Tensor] = [] 
vector_parts: list[torch.Tensor] = [] distance_parts: list[torch.Tensor] = [] atomic_number_parts: list[torch.Tensor] = [] system_index_parts: list[torch.Tensor] = [] atom_offset = 0 for system_id, atoms in enumerate(frames): neighbor_list = build_neighbor_list( atoms, cutoff=scenario.cutoff, arrays="torch", full_list=True, sorted=True, ) pairs = neighbor_list.pairs.to(dtype=torch.int64) first_parts.append(pairs[0] + atom_offset) second_parts.append(pairs[1] + atom_offset) if neighbor_list.vectors is None or neighbor_list.distances is None: raise RuntimeError("benchmark neighbor list did not return pair geometry") vector_parts.append(neighbor_list.vectors.to(dtype=dtype)) distance_parts.append(neighbor_list.distances.to(dtype=dtype)) atomic_numbers = torch.as_tensor(atoms.numbers, dtype=torch.int64) atomic_number_parts.append(atomic_numbers) system_index_parts.append( torch.full((len(atoms),), system_id, dtype=torch.int64) ) atom_offset += len(atoms) first = torch.cat(first_parts).to(device=device) second = torch.cat(second_parts).to(device=device) pair_vectors = torch.cat(vector_parts).to(device=device, dtype=dtype) pair_distances = torch.cat(distance_parts).to(device=device, dtype=dtype) atomic_numbers = torch.cat(atomic_number_parts).to(device=device) node_cat, n_categories = _atomic_categories(atomic_numbers) node_cat = node_cat.to(device=device) system_index = torch.cat(system_index_parts).to(device=device) n_triplet_categories = n_categories * (n_categories * (n_categories + 1) // 2) generator = torch.Generator(device=device) generator.manual_seed(seed) coeffs = torch.randn( ( n_triplet_categories, scenario.coeff_size, scenario.coeff_size, scenario.coeff_size, ), device=device, dtype=dtype, generator=generator, ) n_nodes = int(atomic_numbers.numel()) mean_degree = int(round(float(first.numel()) / max(n_nodes, 1))) return _RawWorkload( name=scenario.name, first=first, second=second, node_cat=node_cat, pair_vectors=pair_vectors, pair_distances=pair_distances, 
system_index=system_index, coeffs=coeffs, edge_cat_table=build_edge_category_table(n_categories, device=device), n_nodes=n_nodes, n_systems=len(frames), n_categories=n_categories, n_sources=n_nodes, mean_degree=mean_degree, coeff_size=scenario.coeff_size, first_knot_xy=scenario.first_knot_xy, first_knot_z=scenario.first_knot_z, knot_spacing_xy=scenario.knot_spacing_xy, knot_spacing_z=scenario.knot_spacing_z, lower_support_xy=scenario.lower_support_xy, lower_support_z=scenario.lower_support_z, ) def _build_buckets( raw: _RawWorkload, *, runtime_config: ThreeBodyRuntimeConfig, ): """Build reusable source buckets from a raw synthetic workload.""" return preprocess_sources_native_or_torch( raw.first, raw.second, raw.node_cat, raw.n_categories, raw.pair_vectors, raw.pair_distances, runtime_config=runtime_config, ) def _with_pattern_metadata(buckets, device: torch.device): """Attach reusable pattern metadata when native preprocessing did not.""" if buckets.tensor_pattern_plans is not None: return buckets return buckets.with_pattern_plans(device) def _timing_stats(timings: list[float]) -> tuple[float, float]: """Return median and minimum timing values.""" timing_tensor = torch.tensor(timings, dtype=torch.float64) return float(torch.median(timing_tensor).item()), float( torch.min(timing_tensor).item() )
def run_threebody_dynamic_inference_benchmark(
    *,
    scenario: str = "ternary_alloy",
    backend: str = "torch",
    bucket_backend: str | None = None,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
    device: str | torch.device | None = None,
    dtype: str | torch.dtype | None = "auto",
    seed: int = 0,
    repeats: int = 20,
    warmup: int = 5,
) -> ThreeBodyDynamicBenchmarkResult:
    """Run one dynamic three-body evaluator microbenchmark.

    Source buckets (with pattern metadata) are built once, outside the timed
    region. Only ``evaluate_bucketed_energy_forces`` is timed: ``warmup``
    untimed calls run first, then ``repeats`` timed calls, each followed by a
    device synchronization.

    Args:
        scenario: One of ``available_threebody_dynamic_scenarios()``.
        backend: Dynamic evaluator backend. When ``runtime_config`` is given,
            no backend override is forwarded to config resolution.
        bucket_backend: Optional bucket-preprocessing backend override.
        runtime_config: Optional explicit runtime configuration.
        device: Passed to ``resolve_device``.
        dtype: Passed to ``resolve_dtype``.
        seed: Seed for the deterministic workload.
        repeats: Number of timed evaluations; must be positive.
        warmup: Number of untimed evaluations; must be non-negative.

    Returns:
        Median and minimum evaluator wall-clock time plus workload sizes.

    Raises:
        ValueError: For an unknown scenario or invalid ``repeats``/``warmup``.
    """
    if scenario not in _SCENARIOS:
        known = ", ".join(available_threebody_dynamic_scenarios())
        raise ValueError(f"unknown scenario '{scenario}'; expected one of: {known}")
    if not (repeats > 0 and warmup >= 0):
        raise ValueError("`repeats` must be positive and `warmup` non-negative")

    config = resolve_threebody_runtime_config(
        runtime_config,
        dynamic_backend=backend if runtime_config is None else None,
        bucket_backend=bucket_backend,
    )
    dynamic_backend = config.dynamic_backend
    target_device = resolve_device(device)
    target_dtype = resolve_dtype(target_device, dtype)

    raw, buckets, node_cat, coeffs, edge_cat_table = _build_workload(
        _SCENARIOS[scenario],
        device=target_device,
        dtype=target_dtype,
        seed=seed,
        runtime_config=config,
    )
    # Keyword arguments shared by every evaluator call.
    evaluator_kwargs = {
        "spline": "cubic",
        "n_nodes": raw.n_nodes,
        "n_cat": raw.n_categories,
        "first_knot_xy": raw.first_knot_xy,
        "first_knot_z": raw.first_knot_z,
        "knot_spacing_xy": raw.knot_spacing_xy,
        "knot_spacing_z": raw.knot_spacing_z,
        "lower_support_xy": raw.lower_support_xy,
        "lower_support_z": raw.lower_support_z,
        "runtime_config": config,
    }

    def run_once() -> tuple[torch.Tensor, torch.Tensor]:
        return evaluate_bucketed_energy_forces(
            buckets, node_cat, coeffs, edge_cat_table, **evaluator_kwargs
        )

    for _ in range(warmup):
        run_once()
    _synchronize(target_device)

    samples: list[float] = []
    for _ in range(repeats):
        tic = time.perf_counter()
        run_once()
        _synchronize(target_device)
        samples.append(time.perf_counter() - tic)
    median_s, min_s = _timing_stats(samples)

    plans = buckets.tensor_pattern_plans
    layout_entries = int(plans.row.numel()) if plans is not None else 0
    return ThreeBodyDynamicBenchmarkResult(
        scenario=raw.name,
        backend=dynamic_backend,
        device=str(target_device),
        dtype=str(target_dtype).replace("torch.", ""),
        n_nodes=raw.n_nodes,
        n_systems=raw.n_systems,
        n_categories=raw.n_categories,
        n_sources=raw.n_sources,
        mean_degree=raw.mean_degree,
        n_pattern_plans=int(buckets.patterns.shape[0]),
        n_triplet_layout_entries=layout_entries,
        median_time_s=median_s,
        min_time_s=min_s,
    )
def run_threebody_dynamic_breakdown_benchmark(
    *,
    scenario: str = "ternary_alloy",
    backend: str = "torch",
    bucket_backend: str | None = None,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
    device: str | torch.device | None = None,
    dtype: str | torch.dtype | None = "auto",
    seed: int = 0,
    repeats: int = 20,
    warmup: int = 5,
) -> ThreeBodyDynamicBreakdownBenchmarkResult:
    """Run dynamic inference with separate bucket-build and evaluator timings.

    Each iteration rebuilds the source buckets (plus pattern metadata) from
    the raw pair tensors and then evaluates energy and forces. The two stages
    are timed separately, each ending with a device synchronization. The
    reported total is the sum of the two stage medians.

    Args:
        scenario: One of ``available_threebody_dynamic_scenarios()``.
        backend: Dynamic evaluator backend. When ``runtime_config`` is given,
            no backend override is forwarded to config resolution.
        bucket_backend: Optional bucket-preprocessing backend override.
        runtime_config: Optional explicit runtime configuration.
        device: Passed to ``resolve_device``.
        dtype: Passed to ``resolve_dtype``.
        seed: Seed for the deterministic workload.
        repeats: Number of timed iterations; must be positive.
        warmup: Number of untimed iterations; must be non-negative.

    Raises:
        ValueError: For an unknown scenario or invalid ``repeats``/``warmup``.
    """
    if scenario not in _SCENARIOS:
        known = ", ".join(available_threebody_dynamic_scenarios())
        raise ValueError(f"unknown scenario '{scenario}'; expected one of: {known}")
    if not (repeats > 0 and warmup >= 0):
        raise ValueError("`repeats` must be positive and `warmup` non-negative")

    config = resolve_threebody_runtime_config(
        runtime_config,
        dynamic_backend=backend if runtime_config is None else None,
        bucket_backend=bucket_backend,
    )
    dynamic_backend = config.dynamic_backend
    target_device = resolve_device(device)
    target_dtype = resolve_dtype(target_device, dtype)
    raw = _build_raw_workload(
        _SCENARIOS[scenario],
        device=target_device,
        dtype=target_dtype,
        seed=seed,
    )
    evaluator_kwargs = {
        "spline": "cubic",
        "n_nodes": raw.n_nodes,
        "n_cat": raw.n_categories,
        "first_knot_xy": raw.first_knot_xy,
        "first_knot_z": raw.first_knot_z,
        "knot_spacing_xy": raw.knot_spacing_xy,
        "knot_spacing_z": raw.knot_spacing_z,
        "lower_support_xy": raw.lower_support_xy,
        "lower_support_z": raw.lower_support_z,
        "runtime_config": config,
    }

    def make_buckets():
        return _with_pattern_metadata(
            _build_buckets(raw, runtime_config=config),
            target_device,
        )

    def evaluate(bucket_set) -> None:
        evaluate_bucketed_energy_forces(
            bucket_set,
            raw.node_cat,
            raw.coeffs,
            raw.edge_cat_table,
            **evaluator_kwargs,
        )

    for _ in range(warmup):
        evaluate(make_buckets())
    _synchronize(target_device)

    bucket_samples: list[float] = []
    evaluator_samples: list[float] = []
    last_buckets = None
    for _ in range(repeats):
        tic = time.perf_counter()
        bucket_set = make_buckets()
        _synchronize(target_device)
        bucket_samples.append(time.perf_counter() - tic)
        last_buckets = bucket_set

        tic = time.perf_counter()
        evaluate(bucket_set)
        _synchronize(target_device)
        evaluator_samples.append(time.perf_counter() - tic)
    # ``repeats > 0`` was validated above, so at least one build happened.
    assert last_buckets is not None

    bucket_median, _ = _timing_stats(bucket_samples)
    evaluator_median, _ = _timing_stats(evaluator_samples)
    plans = last_buckets.tensor_pattern_plans
    layout_entries = int(plans.row.numel()) if plans is not None else 0
    return ThreeBodyDynamicBreakdownBenchmarkResult(
        scenario=raw.name,
        backend=dynamic_backend,
        device=str(target_device),
        dtype=str(target_dtype).replace("torch.", ""),
        n_nodes=raw.n_nodes,
        n_systems=raw.n_systems,
        n_categories=raw.n_categories,
        n_sources=raw.n_sources,
        mean_degree=raw.mean_degree,
        n_pattern_plans=int(last_buckets.patterns.shape[0]),
        n_triplet_layout_entries=layout_entries,
        median_bucket_build_time_s=bucket_median,
        median_evaluator_time_s=evaluator_median,
        median_total_time_s=bucket_median + evaluator_median,
    )
def run_threebody_cache_benchmark(
    *,
    scenario: str = "ternary_alloy",
    backend: str = "torch",
    bucket_backend: str | None = None,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
    device: str | torch.device | None = None,
    dtype: str | torch.dtype | None = "auto",
    seed: int = 0,
    repeats: int = 10,
    warmup: int = 2,
) -> ThreeBodyCacheBenchmarkResult:
    """Run fixed-geometry cache generation with stage-level timing.

    Each iteration runs four stages, each timed separately and ending with a
    device synchronization:

    1. bucket build (plus pattern metadata),
    2. sparse feature cache from the buckets,
    3. dense feature cache from the sparse cache,
    4. energy/force evaluation from the dense cache.

    The reported cache-build time is the sum of the sparse and dense medians.

    Args:
        scenario: One of ``available_threebody_dynamic_scenarios()``.
        backend: Dynamic evaluator backend. When ``runtime_config`` is given,
            no backend override is forwarded to config resolution.
        bucket_backend: Optional bucket-preprocessing backend override.
        runtime_config: Optional explicit runtime configuration.
        device: Passed to ``resolve_device``.
        dtype: Passed to ``resolve_dtype``.
        seed: Seed for the deterministic workload.
        repeats: Number of timed iterations; must be positive.
        warmup: Number of untimed iterations; must be non-negative.

    Raises:
        ValueError: For an unknown scenario or invalid ``repeats``/``warmup``.
    """
    if scenario not in _SCENARIOS:
        known = ", ".join(available_threebody_dynamic_scenarios())
        raise ValueError(f"unknown scenario '{scenario}'; expected one of: {known}")
    if not (repeats > 0 and warmup >= 0):
        raise ValueError("`repeats` must be positive and `warmup` non-negative")

    config = resolve_threebody_runtime_config(
        runtime_config,
        dynamic_backend=backend if runtime_config is None else None,
        bucket_backend=bucket_backend,
    )
    dynamic_backend = config.dynamic_backend
    target_device = resolve_device(device)
    target_dtype = resolve_dtype(target_device, dtype)
    raw = _build_raw_workload(
        _SCENARIOS[scenario],
        device=target_device,
        dtype=target_dtype,
        seed=seed,
    )
    coeff_shape = (raw.coeff_size,) * 3

    def make_buckets():
        return _with_pattern_metadata(
            _build_buckets(raw, runtime_config=config),
            target_device,
        )

    def make_sparse(bucket_set):
        # NOTE(review): unlike evaluate_bucketed_energy_forces, this builder
        # is not given lower_support_xy/z; confirm that is intended.
        return _build_feature_cache_from_buckets(
            bucket_set,
            coeff_shape,
            spline="cubic",
            n_cat=raw.n_categories,
            first_knot_xy=raw.first_knot_xy,
            first_knot_z=raw.first_knot_z,
            knot_spacing_xy=raw.knot_spacing_xy,
            knot_spacing_z=raw.knot_spacing_z,
            runtime_config=config,
        )

    def make_dense(sparse_cache):
        return _build_dense_feature_cache_from_feature_cache(
            sparse_cache,
            raw.system_index,
            coeff_shape=coeff_shape,
            runtime_config=config,
        )

    def evaluate_cached(dense_cache) -> None:
        _evaluate_dense_feature_cache_energy_forces(
            dense_cache,
            raw.coeffs,
            n_nodes=raw.n_nodes,
            n_systems=raw.n_systems,
        )

    def timed(stage, *args):
        """Run ``stage(*args)`` and return ``(result, elapsed_seconds)``."""
        tic = time.perf_counter()
        out = stage(*args)
        _synchronize(target_device)
        return out, time.perf_counter() - tic

    for _ in range(warmup):
        evaluate_cached(make_dense(make_sparse(make_buckets())))
    _synchronize(target_device)

    bucket_samples: list[float] = []
    sparse_samples: list[float] = []
    dense_samples: list[float] = []
    cached_eval_samples: list[float] = []
    last_buckets = None
    for _ in range(repeats):
        bucket_set, elapsed = timed(make_buckets)
        bucket_samples.append(elapsed)
        last_buckets = bucket_set

        sparse_cache, elapsed = timed(make_sparse, bucket_set)
        sparse_samples.append(elapsed)

        dense_cache, elapsed = timed(make_dense, sparse_cache)
        dense_samples.append(elapsed)

        _, elapsed = timed(evaluate_cached, dense_cache)
        cached_eval_samples.append(elapsed)
    # ``repeats > 0`` was validated above, so at least one build happened.
    assert last_buckets is not None

    bucket_median, _ = _timing_stats(bucket_samples)
    sparse_median, _ = _timing_stats(sparse_samples)
    dense_median, _ = _timing_stats(dense_samples)
    cached_eval_median, _ = _timing_stats(cached_eval_samples)
    plans = last_buckets.tensor_pattern_plans
    layout_entries = int(plans.row.numel()) if plans is not None else 0
    return ThreeBodyCacheBenchmarkResult(
        scenario=raw.name,
        backend=dynamic_backend,
        device=str(target_device),
        dtype=str(target_dtype).replace("torch.", ""),
        n_nodes=raw.n_nodes,
        n_systems=raw.n_systems,
        n_categories=raw.n_categories,
        n_sources=raw.n_sources,
        mean_degree=raw.mean_degree,
        n_pattern_plans=int(last_buckets.patterns.shape[0]),
        n_triplet_layout_entries=layout_entries,
        median_bucket_build_time_s=bucket_median,
        median_sparse_feature_time_s=sparse_median,
        median_dense_feature_time_s=dense_median,
        median_cache_build_time_s=sparse_median + dense_median,
        median_cached_evaluator_time_s=cached_eval_median,
    )
__all__ = [ "ThreeBodyCacheBenchmarkResult", "ThreeBodyDynamicBreakdownBenchmarkResult", "ThreeBodyDynamicBenchmarkResult", "available_threebody_dynamic_scenarios", "native_threebody_backend_available", "run_threebody_cache_benchmark", "run_threebody_dynamic_breakdown_benchmark", "run_threebody_dynamic_inference_benchmark", ]