"""Conversion and save helpers for UFP UF2+3 metatomic exports."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any
import torch
from ufp.adapters._metatomic_uf23 import (
UF23MetatomicModule,
UF23ModelState,
UF23PairSplineState,
UF23ThreeBodySplineState,
UF23TwoBodySplineState,
)
from ufp.adapters.metatomic import _dtype_name, _require_metatomic_torch
from ufp.terms.model import UFPModel
from ufp.terms.onebody import ElementOneBodyTerm
from ufp.terms.threebody import SplineThreeBodyTerm
from ufp.terms.twobody import SplinePairTerm, SplineTwoBodyTerm
def _normalized_devices(supported_devices: Sequence[str]) -> tuple[str, ...]:
    """Normalize and validate metatomic device names.

    Names are lower-cased and de-duplicated while keeping first-seen order.
    A bare string such as ``"cpu"`` is treated as a single device name rather
    than being iterated character by character.

    Args:
        supported_devices: Device names to validate. Only ``"cpu"`` and
            ``"cuda"`` are accepted (case-insensitively).

    Returns:
        The normalized device names as a tuple.

    Raises:
        ValueError: If no device is given or any name is not supported.
    """
    if isinstance(supported_devices, str):
        # A str is itself a Sequence[str]; without this guard "cpu" would be
        # split into "c", "p", "u" and rejected with a misleading message.
        # An empty string still counts as "no devices".
        supported_devices = (supported_devices,) if supported_devices else ()
    devices = tuple(dict.fromkeys(str(device).lower() for device in supported_devices))
    if not devices:
        raise ValueError("`supported_devices` must contain at least one device")
    unsupported = sorted(set(devices) - {"cpu", "cuda"})
    if unsupported:
        raise ValueError(
            "unsupported metatomic device names: " + ", ".join(unsupported)
        )
    return devices
def _floating_dtype(model: UFPModel) -> torch.dtype:
    """Resolve the floating dtype used for the exported coefficients.

    Args:
        model: Model whose ``preferred_dtype()`` decides the export precision.

    Returns:
        ``torch.float32`` or ``torch.float64``.

    Raises:
        ValueError: If the preferred dtype is anything else.
    """
    export_dtype = model.preferred_dtype()
    if export_dtype in (torch.float32, torch.float64):
        return export_dtype
    raise ValueError(
        "UF2+3 metatomic export supports only float32 and float64 parameters, "
        f"got {export_dtype}"
    )
def _clone_export_tensor(
    tensor: torch.Tensor,
    *,
    dtype: torch.dtype,
    name: str,
) -> torch.Tensor:
    """Copy a floating coefficient tensor into detached CPU export storage.

    Args:
        tensor: Source coefficients; must be a float32 or float64 tensor.
        dtype: Dtype of the returned copy.
        name: Label used in error messages.

    Returns:
        A detached, contiguous CPU clone of ``tensor`` converted to ``dtype``.

    Raises:
        TypeError: If ``tensor`` is not a ``torch.Tensor``.
        ValueError: If ``tensor`` is not floating point, or uses a floating
            dtype other than float32/float64.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"`{name}` must be exposed as a torch.Tensor")
    if not tensor.is_floating_point():
        raise ValueError(f"`{name}` must be floating point")
    source_dtype = tensor.dtype
    if source_dtype != torch.float32 and source_dtype != torch.float64:
        raise ValueError(
            f"`{name}` must use float32 or float64 for UF2+3 export, got {source_dtype}"
        )
    converted = tensor.detach().to(device="cpu", dtype=dtype)
    # `.to` may return the input unchanged, so clone to guarantee fresh storage.
    return converted.contiguous().clone()
def _term_label(term: torch.nn.Module) -> str:
    """Build the ``module.ClassName`` label of ``term`` for diagnostics."""
    term_class = type(term)
    return "{}.{}".format(term_class.__module__, term_class.__name__)
def _ensure_supported_model(model: UFPModel) -> None:
    """Reject models whose term graph falls outside the UF2+3 export family.

    Accepted models are ``UFPModel`` instances with a non-empty
    ``atomic_types``, no ``other_terms``, only ``ElementOneBodyTerm`` one-body
    terms, only enabled ``SplinePairTerm`` or ``SplineTwoBodyTerm`` pair terms,
    and only ``SplineThreeBodyTerm`` three-body terms.

    Raises:
        TypeError: If ``model`` is not a ``UFPModel``.
        ValueError: If any of the structural requirements above is violated.
    """
    if not isinstance(model, UFPModel):
        raise TypeError("UF2+3 metatomic export requires a UFPModel instance")

    atomic_types = model.atomic_types
    if atomic_types is None or len(atomic_types) == 0:
        raise ValueError("UF2+3 metatomic export requires `model.atomic_types`")

    extra_terms = model.other_terms
    if len(extra_terms) != 0:
        raise ValueError(
            "unsupported non-UF2+3 `other_terms`: "
            + ", ".join(map(_term_label, extra_terms))
        )

    for onebody in model.onebody_terms:
        if isinstance(onebody, ElementOneBodyTerm):
            continue
        raise ValueError(
            "unsupported one-body term for UF2+3 metatomic export: "
            + _term_label(onebody)
        )

    for pair in model.pair_terms:
        if isinstance(pair, SplinePairTerm):
            # Single-pair terms carry an `enabled` switch that must be on.
            if not pair.enabled:
                raise ValueError(
                    "disabled SplinePairTerm entries can not be exported to the "
                    "production UF2+3 metatomic runtime"
                )
        elif not isinstance(pair, SplineTwoBodyTerm):
            raise ValueError(
                "unsupported pair term for UF2+3 metatomic export: " + _term_label(pair)
            )

    for threebody in model.threebody_terms:
        if isinstance(threebody, SplineThreeBodyTerm):
            continue
        raise ValueError(
            "unsupported three-body term for UF2+3 metatomic export: "
            + _term_label(threebody)
        )
def _validate_term_atomic_types(model: UFPModel) -> None:
    """Check that no term references an atomic type the model does not list.

    Terms without an ``atomic_types`` attribute (or with it set to ``None``)
    are skipped.

    Raises:
        ValueError: If a term declares atomic types missing from
            ``model.atomic_types``.
    """
    assert model.atomic_types is not None
    allowed = {int(value) for value in model.atomic_types}
    for term in model.terms:
        declared = getattr(term, "atomic_types", None)
        if declared is None:
            continue
        stray = sorted({int(value) for value in declared} - allowed)
        if not stray:
            continue
        raise ValueError(
            f"{_term_label(term)} uses atomic types outside model.atomic_types: "
            + ", ".join(map(str, stray))
        )
def _validate_cuda_threebody_support(
    model: UFPModel,
    supported_devices: Sequence[str],
) -> None:
    """Check up front that a CUDA export with three-body terms can run.

    Nothing is checked when ``"cuda"`` is not requested or the model has no
    three-body terms.

    Raises:
        ValueError: If a three-body term uses a spline other than ``"cubic"``.
        RuntimeError: If the native CUDA three-body backend is reported as
            unavailable.
    """
    if "cuda" not in supported_devices:
        return
    threebody_terms = model.threebody_terms
    if len(threebody_terms) == 0:
        return

    non_cubic = next(
        (term for term in threebody_terms if str(term.spline) != "cubic"),
        None,
    )
    if non_cubic is not None:
        raise ValueError(
            "CUDA UF2+3 metatomic export supports only cubic three-body "
            f"splines, got {non_cubic.spline!r}"
        )

    # Imported lazily so the kernel module is only loaded when it is needed.
    from ufp.terms._threebody_kernels import native_threebody_backend_available

    if native_threebody_backend_available(device="cuda", spline="cubic"):
        return
    raise RuntimeError(
        "CUDA UF2+3 metatomic export requires the optional native UFP "
        "three-body extension with a CUDA kernel. Reinstall with "
        "`UFP_BUILD_NATIVE=1 UFP_BUILD_CUDA=1`, or export with "
        "`supported_devices=('cpu',)` for CPU validation."
    )
def _onebody_values(model: UFPModel, dtype: torch.dtype) -> torch.Tensor:
    """Sum one-body reference energies per atomic type, in model order.

    Args:
        model: Model providing ``atomic_types`` and ``onebody_terms``.
        dtype: Dtype of the returned tensor.

    Returns:
        A 1-D tensor with one entry per ``model.atomic_types`` element. The
        contributions of all one-body terms are accumulated, and types no term
        covers stay at zero.

    Raises:
        ValueError: If a term's ``values`` is not 1-D with one entry per term
            atomic type.
    """
    assert model.atomic_types is not None
    ordered_types = tuple(int(value) for value in model.atomic_types)
    slot_of = dict(zip(ordered_types, range(len(ordered_types))))
    totals = torch.zeros(len(ordered_types), dtype=dtype)

    for term in model.onebody_terms:
        energies = _clone_export_tensor(
            term.values,
            dtype=dtype,
            name="ElementOneBodyTerm.values",
        )
        assert term.atomic_types is not None
        n_term_types = len(term.atomic_types)
        actual_shape = tuple(energies.shape)
        if actual_shape != (n_term_types,):
            raise ValueError(
                "ElementOneBodyTerm.values has shape "
                f"{actual_shape}, expected ({n_term_types},)"
            )
        # Scatter each term-local entry into the slot of its atomic type.
        for atomic_type, energy in zip(term.atomic_types, energies.unbind(0)):
            totals[slot_of[int(atomic_type)]] += energy
    return totals
def _pair_state(term: SplinePairTerm, dtype: torch.dtype) -> UF23PairSplineState:
    """Snapshot one ``SplinePairTerm`` as export state.

    Args:
        term: Pair term read through ``true_coeffs`` and its spline settings.
        dtype: Dtype for the exported coefficient tensor.

    Returns:
        A ``UF23PairSplineState`` holding a CPU copy of the coefficients and
        the term's settings coerced to plain Python values.
    """
    exported_coeffs = _clone_export_tensor(
        term.true_coeffs,
        dtype=dtype,
        name="SplinePairTerm.true_coeffs",
    )
    species_pair = tuple(int(value) for value in term.pair)
    return UF23PairSplineState(
        pair=species_pair,
        coeffs=exported_coeffs,
        symmetric=bool(term.symmetric),
        spline=str(term.spline),
        first_knot=float(term.first_knot),
        knot_spacing=float(term.knot_spacing),
        eps=float(term.eps),
    )
def _twobody_state(
    term: SplineTwoBodyTerm,
    dtype: torch.dtype,
) -> UF23TwoBodySplineState:
    """Convert one SplineTwoBodyTerm through its public coefficient API.

    Args:
        term: Multi-pair two-body term, read through ``true_coeffs_by_pair``,
            ``pair_categories``, ``active_pair_mask`` and its spline settings.
        dtype: Dtype for the exported coefficient tensor.

    Returns:
        A ``UF23TwoBodySplineState`` with CPU copies of the coefficient table
        and active-pair mask.

    Raises:
        ValueError: If the coefficient table is not 2-D or its leading
            dimension does not equal ``len(term.pair_categories)``.
    """
    coeffs = _clone_export_tensor(
        term.true_coeffs_by_pair,
        dtype=dtype,
        name="SplineTwoBodyTerm.true_coeffs_by_pair",
    )
    n_pairs = len(term.pair_categories)
    if coeffs.ndim != 2:
        # Check the rank before indexing shape[1]; a 0-D or 1-D table would
        # otherwise surface as an IndexError instead of a shape error.
        raise ValueError(
            "SplineTwoBodyTerm.true_coeffs_by_pair has shape "
            f"{tuple(coeffs.shape)}, expected ({n_pairs}, n_coeffs)"
        )
    expected_shape = (n_pairs, int(coeffs.shape[1]))
    if tuple(coeffs.shape) != expected_shape:
        raise ValueError(
            "SplineTwoBodyTerm.true_coeffs_by_pair has shape "
            f"{tuple(coeffs.shape)}, expected {expected_shape}"
        )
    return UF23TwoBodySplineState(
        atomic_types=tuple(int(value) for value in term.atomic_types or ()),
        coeffs_by_pair=coeffs,
        active_pair_mask=term.active_pair_mask.detach().cpu().clone(),
        symmetric=bool(term.symmetric),
        spline=str(term.spline),
        first_knot=float(term.first_knot),
        knot_spacing=float(term.knot_spacing),
        eps=float(term.eps),
    )
def _threebody_state(
    term: SplineThreeBodyTerm,
    dtype: torch.dtype,
) -> UF23ThreeBodySplineState:
    """Snapshot one ``SplineThreeBodyTerm`` as export state.

    Args:
        term: Three-body term, read through ``true_coeffs_by_triplet``, its
            category tables and its spline settings.
        dtype: Dtype for the exported coefficient tensor.

    Returns:
        A ``UF23ThreeBodySplineState`` with CPU copies of the coefficient
        table, active-triplet mask and edge-category table.

    Raises:
        ValueError: If the coefficient table's shape is not
            ``(len(term.triplet_categories), *term.coeff_shape)``.
    """
    exported_coeffs = _clone_export_tensor(
        term.true_coeffs_by_triplet,
        dtype=dtype,
        name="SplineThreeBodyTerm.true_coeffs_by_triplet",
    )
    wanted_shape = (len(term.triplet_categories), *tuple(term.coeff_shape))
    actual_shape = tuple(exported_coeffs.shape)
    if actual_shape != wanted_shape:
        raise ValueError(
            "SplineThreeBodyTerm.true_coeffs_by_triplet has shape "
            f"{actual_shape}, expected {wanted_shape}"
        )

    species = tuple(int(value) for value in term.atomic_types or ())
    triplet_mask = term.active_triplet_mask.detach().cpu().clone()
    edge_categories = term.edge_cat_table.detach().cpu().clone()
    return UF23ThreeBodySplineState(
        atomic_types=species,
        coeffs_by_triplet=exported_coeffs,
        active_triplet_mask=triplet_mask,
        edge_cat_table=edge_categories,
        spline=str(term.spline),
        first_knot_xy=float(term.first_knot_xy),
        first_knot_z=float(term.first_knot_z),
        knot_spacing_xy=float(term.knot_spacing_xy),
        knot_spacing_z=float(term.knot_spacing_z),
        lower_support_xy=float(term.lower_support_xy),
        upper_support_xy=float(term.upper_support_xy),
        lower_support_z=float(term.lower_support_z),
        eps=float(term.eps),
    )
def build_uf23_model_state(
    model: UFPModel,
    *,
    supported_devices: Sequence[str],
) -> UF23ModelState:
    """Validate a ``UFPModel`` and convert it into UF2+3 export state.

    Args:
        model: Model to convert.
        supported_devices: Device names the export should support; they are
            normalized and validated before use.

    Returns:
        A ``UF23ModelState`` holding the converted one-, two- and three-body
        terms. ``interaction_range`` is ``0.0`` when ``model.cutoff`` is
        ``None``.

    Raises:
        TypeError, ValueError, RuntimeError: Propagated from the validation
            and conversion helpers this function calls.
    """
    devices = _normalized_devices(supported_devices)
    _ensure_supported_model(model)
    _validate_term_atomic_types(model)
    _validate_cuda_threebody_support(model, devices)
    dtype = _floating_dtype(model)

    # Split pair terms by concrete class; each kind has its own state type.
    pair_states = []
    twobody_states = []
    for term in model.pair_terms:
        if isinstance(term, SplinePairTerm):
            pair_states.append(_pair_state(term, dtype))
            continue
        if isinstance(term, SplineTwoBodyTerm):
            twobody_states.append(_twobody_state(term, dtype))
            continue
        raise AssertionError("unsupported pair term escaped validation")

    species = tuple(int(value) for value in model.atomic_types or ())
    onebody = _onebody_values(model, dtype)
    threebody_states = tuple(
        _threebody_state(term, dtype) for term in model.threebody_terms
    )
    cutoff = model.cutoff
    return UF23ModelState(
        atomic_types=species,
        onebody_values=onebody,
        pair_terms=tuple(pair_states),
        twobody_terms=tuple(twobody_states),
        threebody_terms=threebody_states,
        interaction_range=float(cutoff) if cutoff is not None else 0.0,
        dtype=dtype,
    )
def _make_atomistic_model(
    state: UF23ModelState,
    *,
    length_unit: str,
    energy_unit: str,
    supported_devices: Sequence[str],
):
    """Wrap converted UF2+3 state in a metatomic ``AtomisticModel``.

    Args:
        state: Converted model state.
        length_unit: Length unit declared in the model capabilities.
        energy_unit: Unit of the ``"energy"`` output; the force unit is
            declared as ``energy_unit/length_unit``.
        supported_devices: Device names declared in the model capabilities.

    Returns:
        An ``AtomisticModel`` around a ``UF23MetatomicModule`` in eval mode,
        with default metadata and two declared outputs: ``"energy"`` and
        per-atom ``"non_conservative_forces"``.
    """
    mta = _require_metatomic_torch()

    energy_output = mta.ModelOutput(
        quantity="energy",
        unit=energy_unit,
        per_atom=False,
    )
    forces_output = mta.ModelOutput(
        quantity="force",
        unit=f"{energy_unit}/{length_unit}",
        per_atom=True,
    )
    capabilities = mta.ModelCapabilities(
        length_unit=length_unit,
        atomic_types=list(state.atomic_types),
        interaction_range=float(state.interaction_range),
        outputs={
            "energy": energy_output,
            "non_conservative_forces": forces_output,
        },
        supported_devices=list(supported_devices),
        dtype=_dtype_name(state.dtype),
    )
    metadata = mta.ModelMetadata()
    runtime_module = UF23MetatomicModule(state).eval()
    return mta.AtomisticModel(runtime_module, metadata, capabilities)
def _save_atomistic_model(
    atomistic_model: Any,
    path: Path,
    *,
    collect_extensions: str | Path | None,
) -> None:
    """Write ``atomistic_model`` to ``path`` through its ``save()`` method.

    Missing parent directories of ``path`` are created first.

    Args:
        atomistic_model: Object exposing a callable ``save``.
        path: Destination file, passed to ``save`` as a string.
        collect_extensions: When not ``None``, forwarded to ``save`` as the
            ``collect_extensions`` keyword (converted to ``str``).

    Raises:
        TypeError: If ``atomistic_model`` has no callable ``save`` attribute.
    """
    save = getattr(atomistic_model, "save", None)
    if not callable(save):
        raise TypeError("metatomic AtomisticModel does not provide a callable save()")
    path.parent.mkdir(parents=True, exist_ok=True)
    # Only pass the keyword when requested so `save` keeps its own default.
    extra_kwargs = (
        {}
        if collect_extensions is None
        else {"collect_extensions": str(collect_extensions)}
    )
    save(str(path), **extra_kwargs)
# Public API.
def export_uf23_atomistic_model(
    model: UFPModel,
    path: str | Path,
    *,
    length_unit: str = "Angstrom",
    energy_unit: str = "eV",
    supported_devices: Sequence[str] = ("cuda", "cpu"),
    collect_extensions: str | Path | None = None,
    check: bool = True,
) -> None:
    """Export a fitted UFP UF2+3 model as a metatomic AtomisticModel file.

    Args:
        model: Model to validate, convert and export.
        path: Destination file; missing parent directories are created.
        length_unit: Length unit declared by the exported model.
        energy_unit: Energy unit declared by the exported model.
        supported_devices: Device names the exported model declares.
        collect_extensions: Optional location forwarded to the model's
            ``save()`` as ``collect_extensions``.
        check: Accepted but currently ignored by this function.

    Raises:
        TypeError, ValueError, RuntimeError: Propagated from validation,
            conversion and saving.
    """
    del check  # Kept in the signature only; it has no effect here.
    normalized = _normalized_devices(supported_devices)
    export_state = build_uf23_model_state(model, supported_devices=normalized)
    wrapped_model = _make_atomistic_model(
        export_state,
        length_unit=length_unit,
        energy_unit=energy_unit,
        supported_devices=normalized,
    )
    _save_atomistic_model(
        wrapped_model,
        Path(path),
        collect_extensions=collect_extensions,
    )
def _load_checkpoint_payload(checkpoint_path: str | Path) -> dict[str, Any]:
    """Read a torch checkpoint from disk onto the CPU.

    Args:
        checkpoint_path: Location of the checkpoint file.

    Returns:
        The loaded payload dictionary.

    Raises:
        ValueError: If the loaded object is not a dictionary.
    """
    common_kwargs = {"map_location": "cpu"}
    # NOTE(security): weights_only=False allows arbitrary pickled objects to be
    # deserialized, so only load checkpoints from trusted sources.
    try:
        loaded = torch.load(checkpoint_path, weights_only=False, **common_kwargs)
    except TypeError:
        # Fall back for PyTorch versions that reject the `weights_only` keyword.
        loaded = torch.load(checkpoint_path, **common_kwargs)
    if isinstance(loaded, dict):
        return loaded
    raise ValueError("checkpoint must be a dictionary containing `state_dict`")
def _checkpoint_state_dict(payload: dict[str, Any]) -> dict[str, torch.Tensor]:
    """Locate the model state dict inside a checkpoint payload.

    The lookup tries ``payload["state_dict"]``, then
    ``payload["model_state_dict"]``. If neither is present and every payload
    entry maps a string key to a tensor, the payload itself is used.

    Raises:
        ValueError: If no dictionary could be found by these rules.
    """
    candidate = None
    for key in ("state_dict", "model_state_dict"):
        candidate = payload.get(key)
        if candidate is not None:
            break

    if candidate is None:
        looks_like_bare_state_dict = all(
            isinstance(name, str) and isinstance(entry, torch.Tensor)
            for name, entry in payload.items()
        )
        if looks_like_bare_state_dict:
            candidate = payload

    if isinstance(candidate, dict):
        return candidate
    raise ValueError("checkpoint must contain `state_dict` or `model_state_dict`")
def _load_checkpoint_state_dict(
    model: UFPModel,
    state_dict: dict[str, torch.Tensor],
    payload: dict[str, Any],
) -> None:
    """Load checkpoint weights into ``model``.

    Without an ``"onebody_energy"`` entry in ``payload`` the weights are loaded
    with the default ``load_state_dict`` call. With that entry, the load is
    non-strict and missing keys that start with ``"onebody_terms."`` are
    ignored. Any other mismatch is an error.

    Raises:
        RuntimeError: If, in the non-strict path, keys outside
            ``onebody_terms.`` are missing or any keys are unexpected.
    """
    if "onebody_energy" not in payload:
        model.load_state_dict(state_dict)
        return

    outcome = model.load_state_dict(state_dict, strict=False)
    unexpected = outcome.unexpected_keys
    # One-body parameters may be absent from legacy checkpoints.
    missing = [
        key for key in outcome.missing_keys if not key.startswith("onebody_terms.")
    ]
    if not missing and not unexpected:
        return

    problems = []
    if missing:
        problems.append("missing keys: " + ", ".join(missing))
    if unexpected:
        problems.append("unexpected keys: " + ", ".join(unexpected))
    raise RuntimeError(
        "checkpoint state_dict does not match the model factory output ("
        + "; ".join(problems)
        + ")"
    )
def _apply_onebody_energy_if_present(
    model: UFPModel,
    payload: dict[str, Any],
) -> UFPModel:
    """Apply a scalar ``"onebody_energy"`` checkpoint entry to ``model``.

    Args:
        model: Model produced by the factory, with weights already loaded.
        payload: Checkpoint payload, optionally holding ``"onebody_energy"``.

    Returns:
        ``model`` itself when the payload has no ``"onebody_energy"`` or when
        its single one-value one-body term was overwritten in place.
        Otherwise, for a model with no one-body terms and exactly one atomic
        type, a new ``UFPModel`` with a non-trainable ``ElementOneBodyTerm``
        holding the energy.

    Raises:
        ValueError: If the model's one-body terms cannot take a scalar energy,
            or the model has no one-body terms and is not single-element.
    """
    if "onebody_energy" not in payload:
        return model

    energy = float(payload["onebody_energy"])
    existing_terms = list(model.onebody_terms)

    if existing_terms:
        if len(existing_terms) == 1 and existing_terms[0].values.numel() == 1:
            # Overwrite the stored reference energy in place.
            existing_terms[0].values.data.fill_(energy)
            return model
        raise ValueError(
            "checkpoint contains scalar `onebody_energy`, but the model factory "
            "returned a model with incompatible one-body terms"
        )

    if model.atomic_types is None or len(model.atomic_types) != 1:
        raise ValueError(
            "checkpoint contains scalar `onebody_energy`; the model factory must "
            "return either one compatible ElementOneBodyTerm or a single-element "
            "model so the reference energy can be inserted"
        )

    reference_term = ElementOneBodyTerm(
        atomic_types=model.atomic_types,
        values=torch.tensor([energy], dtype=model.preferred_dtype()),
        trainable=False,
    )
    return UFPModel(
        onebody_terms=[reference_term],
        pair_terms=tuple(model.pair_terms),
        threebody_terms=tuple(model.threebody_terms),
        atomic_types=model.atomic_types,
        neighbor_backend=model.neighbor_backend,
    )
# Public API.
def export_uf23_checkpoint(
    checkpoint_path: str | Path,
    model_factory: Callable[[], UFPModel],
    output_path: str | Path,
    *,
    length_unit: str = "Angstrom",
    energy_unit: str = "eV",
    supported_devices: Sequence[str] = ("cuda", "cpu"),
    collect_extensions: str | Path | None = None,
    check: bool = True,
) -> None:
    """Rebuild a UF2+3 model, load a checkpoint state_dict, and export it.

    Args:
        checkpoint_path: Torch checkpoint holding the model weights.
        model_factory: Zero-argument callable returning a fresh ``UFPModel``
            to load the weights into.
        output_path: Destination of the exported model file.
        length_unit: Forwarded to ``export_uf23_atomistic_model``.
        energy_unit: Forwarded to ``export_uf23_atomistic_model``.
        supported_devices: Forwarded to ``export_uf23_atomistic_model``.
        collect_extensions: Forwarded to ``export_uf23_atomistic_model``.
        check: Forwarded to ``export_uf23_atomistic_model``.

    Raises:
        TypeError: If ``model_factory`` does not return a ``UFPModel``.
        ValueError, RuntimeError: Propagated from checkpoint loading and
            export.
    """
    built = model_factory()
    if not isinstance(built, UFPModel):
        raise TypeError("`model_factory` must return a UFPModel")

    payload = _load_checkpoint_payload(checkpoint_path)
    _load_checkpoint_state_dict(built, _checkpoint_state_dict(payload), payload)
    # May return a new model when a one-body energy has to be inserted.
    export_ready = _apply_onebody_energy_if_present(built, payload)

    export_uf23_atomistic_model(
        export_ready,
        output_path,
        length_unit=length_unit,
        energy_unit=energy_unit,
        supported_devices=supported_devices,
        collect_extensions=collect_extensions,
        check=check,
    )
# Names exported by a star-import of this module; the underscore-prefixed
# helpers are not listed here.
__all__ = [
    "build_uf23_model_state",
    "export_uf23_atomistic_model",
    "export_uf23_checkpoint",
]