"""Workflow helpers for choosing efficient three-body runtime defaults."""
from __future__ import annotations
from dataclasses import dataclass
import torch
from ufp.terms._threebody_kernels import (
native_threebody_backend_available,
native_threebody_dense_feature_cache_available,
native_threebody_feature_cache_available,
native_threebody_lstsq_assemble_available,
native_threebody_preprocess_sources_available,
)
from ufp.terms._threebody_runtime import (
resolve_threebody_runtime_config,
set_default_threebody_runtime_env,
set_threebody_lstsq_runtime_env,
set_threebody_runtime_env,
)
@dataclass(frozen=True)
class ThreeBodyRuntimeStatus:
    """
    Immutable snapshot of the three-body runtime choices for a training workflow.

    Instances are built by ``configure_threebody_training_runtime``. The
    ``*_available`` fields hold the results of the native-extension capability
    probes; the ``*_used`` fields hold what the training path dispatches to.
    """

    # Backend names as normalized by the runtime-config resolver.
    requested_backend: str
    requested_bucket_backend: str
    # Device the training loop runs on.
    selected_device: torch.device
    # Native capability probes.
    native_cxx_evaluator_available: bool
    native_cuda_evaluator_available: bool
    native_cxx_bucketing_available: bool
    native_cxx_dense_cache_available: bool
    # Dispatch decisions for training.
    native_cxx_bucketing_used: bool
    native_cxx_dense_cache_used: bool
    native_cxx_dynamic_used: bool
    native_cuda_dynamic_used: bool
    # Human-readable explanation of the dynamic-evaluator choice.
    dynamic_note: str

    def as_metadata(self) -> dict[str, object]:
        """
        Return the status as a flat, JSON/checkpoint-friendly dictionary.

        ``selected_device`` is stored under ``"train_device"`` in its string
        form so the mapping does not contain ``torch.device`` objects.
        """
        metadata: dict[str, object] = {
            "threebody_backend": self.requested_backend,
            "threebody_bucket_backend": self.requested_bucket_backend,
            "train_device": str(self.selected_device),
        }
        # For these entries the metadata key is exactly the field name.
        for field_name in (
            "native_cxx_evaluator_available",
            "native_cuda_evaluator_available",
            "native_cxx_bucketing_available",
            "native_cxx_dense_cache_available",
            "native_cxx_bucketing_used",
            "native_cxx_dense_cache_used",
            "native_cxx_dynamic_used",
            "native_cuda_dynamic_used",
        ):
            metadata[field_name] = getattr(self, field_name)
        metadata["threebody_dynamic_note"] = self.dynamic_note
        return metadata
@dataclass(frozen=True)
class ThreeBodyLeastSquaresRuntimeStatus:
    """
    Immutable snapshot of the three-body runtime choices for least-squares fits.

    Built by ``configure_threebody_leastsquares_runtime``. ``selected_backend``
    is the assembly backend that will actually run, which may differ from the
    requested one.
    """

    # Requested (normalized) and effective assembly backends.
    requested_backend: str
    selected_backend: str
    requested_bucket_backend: str
    # Devices used for fitting and for prediction.
    fit_device: torch.device
    prediction_device: torch.device
    # Cache-mode labels, stored as given by the caller.
    assembled_cache_mode: str
    normal_equation_cache_mode: str
    # Optional override of where normal equations are built (None = unset).
    normal_equation_build_device: torch.device | None
    # Native capability probes.
    native_cxx_assembly_available: bool
    native_cuda_assembly_available: bool
    native_cxx_bucketing_available: bool
    # Dispatch decisions.
    native_cxx_assembly_used: bool
    native_cuda_assembly_used: bool
    native_cxx_bucketing_used: bool

    def as_metadata(self) -> dict[str, object]:
        """
        Return the status as a flat, JSON/checkpoint-friendly dictionary.

        Devices are stored as strings. An unset normal-equation build device
        is stored as ``None``.
        """
        build_device = self.normal_equation_build_device
        build_device_label = str(build_device) if build_device is not None else None
        entries: list[tuple[str, object]] = [
            ("threebody_lstsq_backend", self.requested_backend),
            ("threebody_lstsq_selected_backend", self.selected_backend),
            ("threebody_bucket_backend", self.requested_bucket_backend),
            ("fit_device", str(self.fit_device)),
            ("prediction_device", str(self.prediction_device)),
            ("lstsq_cache_mode", self.assembled_cache_mode),
            ("normal_equation_cache_mode", self.normal_equation_cache_mode),
            ("normal_equation_build_device", build_device_label),
            ("native_cxx_lstsq_available", self.native_cxx_assembly_available),
            ("native_cuda_lstsq_available", self.native_cuda_assembly_available),
            ("native_cxx_bucketing_available", self.native_cxx_bucketing_available),
            ("native_cxx_lstsq_used", self.native_cxx_assembly_used),
            ("native_cuda_lstsq_used", self.native_cuda_assembly_used),
            ("native_cxx_bucketing_used", self.native_cxx_bucketing_used),
        ]
        return dict(entries)
def configure_threebody_runtime_defaults() -> None:
    """
    Install conservative three-body runtime defaults without overriding the user.

    Explicit ``UFP_THREEBODY_BACKEND`` and ``UFP_THREEBODY_BUCKET_BACKEND``
    environment settings are kept as they are. With ``auto``, the evaluator
    prefers native dynamic kernels when they are available. Bucketing keeps
    CUDA inputs on the Python/Torch path, which avoids CPU-GPU transfer
    overhead, and may use native CPU source preprocessing for CPU inputs.
    """
    # The runtime module owns the environment handling; this is the public hook.
    set_default_threebody_runtime_env()
def preferred_threebody_inference_device() -> torch.device:
    """
    Pick the device that dynamic three-body inference should run on.

    Runtime defaults are installed first. Candidates are then ranked by the
    current dynamic-inference benchmarks:

    1. CUDA, when CUDA is present and the native CUDA evaluator is available;
    2. CPU, when the native C++ evaluator is available;
    3. CUDA through PyTorch, when ``torch.cuda.is_available()`` is true;
    4. CPU through PyTorch otherwise.

    Returns:
        Preferred torch device for dynamic three-body inference.
    """
    configure_threebody_runtime_defaults()
    has_cuda = torch.cuda.is_available()
    if has_cuda and native_threebody_backend_available(device="cuda"):
        return torch.device("cuda")
    if native_threebody_backend_available(device="cpu"):
        return torch.device("cpu")
    # No native evaluator available: fall back to plain PyTorch, GPU first.
    return torch.device("cuda" if has_cuda else "cpu")
def _resolve_device_option(
    *,
    selection: str | torch.device,
    explicit: str | torch.device | None,
) -> torch.device:
    """
    Turn an explicit device override or a selection label into a ``torch.device``.

    Args:
        selection: ``"auto"``, ``"gpu"`` (an alias for ``"cuda"``), or any
            device spec. Strings and ``torch.device`` values are converted with
            ``str``, stripped, and lower-cased before use.
        explicit: Device override. When not ``None`` it wins over ``selection``
            and is passed to ``torch.device`` unchanged.

    Returns:
        The resolved device. ``"auto"`` defers to
        ``preferred_threebody_inference_device``.
    """
    if explicit is not None:
        return torch.device(explicit)
    label = str(selection).strip().lower()
    if label == "auto":
        return preferred_threebody_inference_device()
    return torch.device({"gpu": "cuda"}.get(label, label))
def _resolve_lstsq_fit_device(
    *,
    selection: str | torch.device,
    explicit: str | torch.device | None,
    native_cuda_assembly_available: bool,
    native_cxx_assembly_available: bool,
) -> torch.device:
    """Resolve the least-squares fit device, ranking ``auto`` by assembly support."""
    if explicit is not None or str(selection).strip().lower() != "auto":
        return _resolve_device_option(selection=selection, explicit=explicit)
    # ``auto`` uses the least-squares assembly probes (not the dynamic-evaluator
    # probes used by preferred_threebody_inference_device): native CUDA,
    # native CPU C++, CUDA Torch, then CPU Torch.
    if native_cuda_assembly_available:
        return torch.device("cuda")
    if native_cxx_assembly_available:
        return torch.device("cpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _select_lstsq_backend(
    backend: str,
    *,
    build_device: torch.device,
    native_cuda_assembly_available: bool,
    native_cxx_assembly_available: bool,
) -> str:
    """Map a normalized requested assembly backend to the backend that will run."""
    if backend == "torch":
        return "torch"
    if backend == "cuda":
        return "cuda" if native_cuda_assembly_available else "unavailable"
    # Remaining requests (e.g. ``auto``): native CUDA only when building on a
    # CUDA device, otherwise native CPU C++ if present.
    if build_device.type == "cuda" and native_cuda_assembly_available:
        return "cuda"
    if native_cxx_assembly_available:
        return "native"
    # Only ``auto`` may silently fall back to the Torch path.
    return "torch" if backend == "auto" else "unavailable"


def configure_threebody_leastsquares_runtime(
    *,
    backend: str = "auto",
    bucket_backend: str = "python",
    fit_device_selection: str | torch.device = "auto",
    fit_device: str | torch.device | None = None,
    prediction_device_selection: str | torch.device = "auto",
    prediction_device: str | torch.device | None = None,
    assembled_cache_mode: str = "auto",
    normal_equation_cache_mode: str = "auto",
    normal_equation_build_device: str | torch.device | None = None,
) -> tuple[torch.device, torch.device, ThreeBodyLeastSquaresRuntimeStatus]:
    """
    Apply three-body least-squares runtime options and report dispatch status.

    Least-squares ``auto`` follows the current assembly benchmarks: native CUDA,
    native CPU C++, CUDA Torch, then CPU Torch. Native source bucketing is not
    part of the default least-squares path because the benchmarked workloads do
    not justify its overhead.

    Args:
        backend: Requested least-squares assembly backend.
        bucket_backend: Requested source-bucketing backend.
        fit_device_selection: Automatic or explicit fit-device selection.
        fit_device: Explicit fit device override.
        prediction_device_selection: Automatic or explicit prediction-device
            selection. ``auto`` uses ``preferred_threebody_inference_device``.
        prediction_device: Explicit prediction device override.
        assembled_cache_mode: Assembled-batch cache mode label; stored in the
            status as a string.
        normal_equation_cache_mode: Normal-equation cache mode label; stored in
            the status as a string.
        normal_equation_build_device: Optional device used to build normal
            equations. When omitted, the fit device drives the backend and
            bucketing decisions.

    Returns:
        Fit device, prediction device, and selected runtime status. The
        status' ``selected_backend`` is one of ``"torch"``, ``"cuda"``,
        ``"native"`` or ``"unavailable"``. ``"unavailable"`` means a specific
        non-Torch backend was requested but cannot run here.
    """
    runtime_config = resolve_threebody_runtime_config(
        lstsq_backend=backend,
        bucket_backend=bucket_backend,
    )
    backend_normalized = runtime_config.lstsq_backend
    bucket_backend_normalized = runtime_config.bucket_backend
    set_threebody_lstsq_runtime_env(
        backend=backend_normalized,
        bucket_backend=bucket_backend_normalized,
    )
    configure_threebody_runtime_defaults()

    # Native capability probes.
    native_cxx_assembly_available = native_threebody_feature_cache_available(
        device="cpu"
    ) and native_threebody_dense_feature_cache_available(device="cpu")
    native_cuda_assembly_available = (
        torch.cuda.is_available()
        and native_threebody_lstsq_assemble_available(device="cuda")
    )
    native_cxx_bucketing_available = native_threebody_preprocess_sources_available(
        device="cpu"
    )

    resolved_fit_device = _resolve_lstsq_fit_device(
        selection=fit_device_selection,
        explicit=fit_device,
        native_cuda_assembly_available=native_cuda_assembly_available,
        native_cxx_assembly_available=native_cxx_assembly_available,
    )
    resolved_prediction_device = _resolve_device_option(
        selection=prediction_device_selection,
        explicit=prediction_device,
    )
    resolved_normal_device = (
        None
        if normal_equation_build_device is None
        else torch.device(normal_equation_build_device)
    )
    # Effective build device: the explicit override, else the fit device.
    build_device = (
        resolved_fit_device if resolved_normal_device is None else resolved_normal_device
    )

    selected_backend = _select_lstsq_backend(
        backend_normalized,
        build_device=build_device,
        native_cuda_assembly_available=native_cuda_assembly_available,
        native_cxx_assembly_available=native_cxx_assembly_available,
    )
    # ``auto`` bucketing only goes native when the build device is the CPU.
    native_cxx_bucketing_used = (
        bucket_backend_normalized == "native"
        or (bucket_backend_normalized == "auto" and build_device.type == "cpu")
    ) and native_cxx_bucketing_available

    status = ThreeBodyLeastSquaresRuntimeStatus(
        requested_backend=backend_normalized,
        selected_backend=selected_backend,
        requested_bucket_backend=bucket_backend_normalized,
        fit_device=resolved_fit_device,
        prediction_device=resolved_prediction_device,
        assembled_cache_mode=str(assembled_cache_mode),
        normal_equation_cache_mode=str(normal_equation_cache_mode),
        normal_equation_build_device=resolved_normal_device,
        native_cxx_assembly_available=native_cxx_assembly_available,
        native_cuda_assembly_available=native_cuda_assembly_available,
        native_cxx_bucketing_available=native_cxx_bucketing_available,
        native_cxx_assembly_used=selected_backend == "native",
        native_cuda_assembly_used=selected_backend == "cuda",
        native_cxx_bucketing_used=native_cxx_bucketing_used,
    )
    return resolved_fit_device, resolved_prediction_device, status
def configure_threebody_training_runtime(
    *,
    backend: str = "auto",
    bucket_backend: str = "auto",
    device_selection: str | torch.device = "auto",
    device: str | torch.device | None = None,
    cache_batches: bool = False,
    cache_batches_on_device: bool = False,
    feature_cache_storage: str = "none",
) -> tuple[torch.device, ThreeBodyRuntimeStatus]:
    """
    Apply three-body runtime options and return the selected training status.

    Native dynamic three-body kernels are inference-only today, so gradient
    training always uses the PyTorch evaluator. The native CPU bucket path can
    still be used when its inputs remain on CPU.

    Args:
        backend: Requested dynamic three-body evaluator backend. It is recorded
            in the status, but the evaluator environment is always set to
            ``torch``.
        bucket_backend: Requested source-bucketing backend.
        device_selection: Automatic or explicit training-device selection
            (``"auto"``, ``"gpu"``/``"cuda"``, ``"cpu"``, ...).
        device: Explicit training device override.
        cache_batches: Whether training will reuse cached batches.
        cache_batches_on_device: Whether cached batches stay on the training device.
        feature_cache_storage: Feature-cache storage policy.

    Returns:
        Training device and selected runtime status.

    Note:
        ``cache_batches``, ``cache_batches_on_device`` and
        ``feature_cache_storage`` currently do not affect the result. Native
        dense-cache use and native dynamic-evaluator use are always reported
        as ``False``.
    """
    runtime_config = resolve_threebody_runtime_config(
        dynamic_backend=backend,
        bucket_backend=bucket_backend,
    )
    backend_normalized = runtime_config.dynamic_backend
    bucket_backend_normalized = runtime_config.bucket_backend
    # Training needs autograd, so the evaluator is pinned to the Torch path
    # regardless of the requested backend.
    set_threebody_runtime_env(
        backend="torch",
        bucket_backend=bucket_backend_normalized,
    )
    configure_threebody_runtime_defaults()

    # NOTE(review): "auto" ranks devices by native *inference* evaluator
    # availability, although training always uses the PyTorch evaluator;
    # confirm this is the intended training-device policy.
    selected_device = _resolve_device_option(
        selection=device_selection,
        explicit=device,
    )

    # Native capability probes (reported for diagnostics).
    native_cxx_evaluator_available = native_threebody_backend_available(device="cpu")
    native_cuda_evaluator_available = (
        torch.cuda.is_available() and native_threebody_backend_available(device="cuda")
    )
    native_cxx_bucketing_available = native_threebody_preprocess_sources_available(
        device="cpu"
    )
    native_cxx_dense_cache_available = native_threebody_feature_cache_available(
        device="cpu"
    ) and native_threebody_dense_feature_cache_available(device="cpu")

    # ``auto`` bucketing only goes native when training on the CPU.
    native_cxx_bucketing_used = (
        bucket_backend_normalized == "native"
        or (bucket_backend_normalized == "auto" and selected_device.type == "cpu")
    ) and native_cxx_bucketing_available

    if backend_normalized == "torch":
        dynamic_note = "UFP_THREEBODY_BACKEND=torch"
    else:
        dynamic_note = (
            "training requires autograd; native dynamic kernels are inference-only"
        )

    return selected_device, ThreeBodyRuntimeStatus(
        requested_backend=backend_normalized,
        requested_bucket_backend=bucket_backend_normalized,
        selected_device=selected_device,
        native_cxx_evaluator_available=native_cxx_evaluator_available,
        native_cuda_evaluator_available=native_cuda_evaluator_available,
        native_cxx_bucketing_available=native_cxx_bucketing_available,
        native_cxx_dense_cache_available=native_cxx_dense_cache_available,
        native_cxx_bucketing_used=native_cxx_bucketing_used,
        native_cxx_dense_cache_used=False,
        native_cxx_dynamic_used=False,
        native_cuda_dynamic_used=False,
        dynamic_note=dynamic_note,
    )
def print_threebody_runtime_status(status: ThreeBodyRuntimeStatus) -> None:
    """
    Write a human-readable summary of a training runtime status to stdout.

    A header line is followed by one ``<label> <value>`` line per reported
    field. The fields are the requested backends, the training device, the
    native availability flags, the training dispatch flags, and the
    dynamic-evaluator note.
    """
    print("Three-body runtime:")
    # (label, attribute name) rows, printed in this order.
    report_rows = (
        (" requested evaluator backend:", "requested_backend"),
        (" requested bucket backend:", "requested_bucket_backend"),
        (" selected training device:", "selected_device"),
        (" native C++ evaluator available:", "native_cxx_evaluator_available"),
        (" native CUDA evaluator available:", "native_cuda_evaluator_available"),
        (" native C++ bucketing available:", "native_cxx_bucketing_available"),
        (" native C++ dense-cache available:", "native_cxx_dense_cache_available"),
        (" native C++ bucketing used:", "native_cxx_bucketing_used"),
        (" native C++ dense-cache used:", "native_cxx_dense_cache_used"),
        (
            " native C++ dynamic evaluator used during training:",
            "native_cxx_dynamic_used",
        ),
        (
            " native CUDA dynamic evaluator used during training:",
            "native_cuda_dynamic_used",
        ),
        (" dynamic evaluator note:", "dynamic_note"),
    )
    for label, attribute in report_rows:
        print(label, getattr(status, attribute))
def print_threebody_leastsquares_runtime_status(
    status: ThreeBodyLeastSquaresRuntimeStatus,
) -> None:
    """
    Write a human-readable summary of a least-squares runtime status to stdout.

    A header line is followed by one ``<label> <value>`` line per reported
    field. The fields are the backends, the devices, the cache modes, native
    availability, and dispatch flags.
    """
    print("Three-body least-squares runtime:")
    # (label, attribute name) rows, printed in this order.
    report_rows = (
        (" requested assembly backend:", "requested_backend"),
        (" selected assembly backend:", "selected_backend"),
        (" requested bucket backend:", "requested_bucket_backend"),
        (" fit device:", "fit_device"),
        (" prediction device:", "prediction_device"),
        (" assembled-cache mode:", "assembled_cache_mode"),
        (" normal-equation cache mode:", "normal_equation_cache_mode"),
        (" normal-equation build device:", "normal_equation_build_device"),
        (" native C++ assembly available:", "native_cxx_assembly_available"),
        (" native CUDA assembly available:", "native_cuda_assembly_available"),
        (" native C++ bucketing available:", "native_cxx_bucketing_available"),
        (" native C++ assembly used:", "native_cxx_assembly_used"),
        (" native CUDA assembly used:", "native_cuda_assembly_used"),
        (" native C++ bucketing used:", "native_cxx_bucketing_used"),
    )
    for label, attribute in report_rows:
        print(label, getattr(status, attribute))
# Public API of this workflow module: the two status dataclasses, the
# runtime-configuration entry points, and the status printers.
__all__ = [
    "ThreeBodyLeastSquaresRuntimeStatus",
    "ThreeBodyRuntimeStatus",
    "configure_threebody_leastsquares_runtime",
    "configure_threebody_training_runtime",
    "configure_threebody_runtime_defaults",
    "preferred_threebody_inference_device",
    "print_threebody_leastsquares_runtime_status",
    "print_threebody_runtime_status",
]