Source code for ufp.adapters._metatomic_export

"""Conversion and save helpers for UFP UF2+3 metatomic exports."""

from __future__ import annotations

from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any

import torch

from ufp.adapters._metatomic_uf23 import (
    UF23MetatomicModule,
    UF23ModelState,
    UF23PairSplineState,
    UF23ThreeBodySplineState,
    UF23TwoBodySplineState,
)
from ufp.adapters.metatomic import _dtype_name, _require_metatomic_torch
from ufp.terms.model import UFPModel
from ufp.terms.onebody import ElementOneBodyTerm
from ufp.terms.threebody import SplineThreeBodyTerm
from ufp.terms.twobody import SplinePairTerm, SplineTwoBodyTerm


def _normalized_devices(supported_devices: Sequence[str]) -> tuple[str, ...]:
    """Normalize and validate metatomic device names."""
    devices = tuple(dict.fromkeys(str(device).lower() for device in supported_devices))
    if not devices:
        raise ValueError("`supported_devices` must contain at least one device")

    unsupported = sorted(set(devices) - {"cpu", "cuda"})
    if unsupported:
        raise ValueError(
            "unsupported metatomic device names: " + ", ".join(unsupported)
        )
    return devices


def _floating_dtype(model: UFPModel) -> torch.dtype:
    """Return the export dtype and reject unsupported precision."""
    dtype = model.preferred_dtype()
    if dtype not in (torch.float32, torch.float64):
        raise ValueError(
            "UF2+3 metatomic export supports only float32 and float64 parameters, "
            f"got {dtype}"
        )
    return dtype


def _clone_export_tensor(
    tensor: torch.Tensor,
    *,
    dtype: torch.dtype,
    name: str,
) -> torch.Tensor:
    """Clone one floating coefficient tensor into export storage."""
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"`{name}` must be exposed as a torch.Tensor")
    if not tensor.is_floating_point():
        raise ValueError(f"`{name}` must be floating point")
    if tensor.dtype not in (torch.float32, torch.float64):
        raise ValueError(
            f"`{name}` must use float32 or float64 for UF2+3 export, got {tensor.dtype}"
        )
    return tensor.detach().to(device="cpu", dtype=dtype).contiguous().clone()


def _term_label(term: torch.nn.Module) -> str:
    """Return a concise term label for diagnostics."""
    return f"{type(term).__module__}.{type(term).__name__}"


def _ensure_supported_model(model: UFPModel) -> None:
    """Check that ``model`` only contains terms the UF2+3 export understands.

    Raises:
        TypeError: If ``model`` is not a ``UFPModel``.
        ValueError: If ``model.atomic_types`` is missing or empty, if
            ``model.other_terms`` is non-empty, if a one-body, pair or
            three-body term has an unsupported class, or if a
            ``SplinePairTerm`` is disabled.
    """
    if not isinstance(model, UFPModel):
        raise TypeError("UF2+3 metatomic export requires a UFPModel instance")

    atomic_types = model.atomic_types
    if atomic_types is None or len(atomic_types) == 0:
        raise ValueError("UF2+3 metatomic export requires `model.atomic_types`")

    extra_terms = model.other_terms
    if len(extra_terms) != 0:
        raise ValueError(
            "unsupported non-UF2+3 `other_terms`: "
            + ", ".join(map(_term_label, extra_terms))
        )

    for onebody in model.onebody_terms:
        if isinstance(onebody, ElementOneBodyTerm):
            continue
        raise ValueError(
            "unsupported one-body term for UF2+3 metatomic export: "
            + _term_label(onebody)
        )

    for pair_term in model.pair_terms:
        if isinstance(pair_term, SplinePairTerm):
            # Only SplinePairTerm carries the `enabled` switch checked here.
            if not pair_term.enabled:
                raise ValueError(
                    "disabled SplinePairTerm entries can not be exported to the "
                    "production UF2+3 metatomic runtime"
                )
        elif not isinstance(pair_term, SplineTwoBodyTerm):
            raise ValueError(
                "unsupported pair term for UF2+3 metatomic export: "
                + _term_label(pair_term)
            )

    for threebody in model.threebody_terms:
        if isinstance(threebody, SplineThreeBodyTerm):
            continue
        raise ValueError(
            "unsupported three-body term for UF2+3 metatomic export: "
            + _term_label(threebody)
        )


def _validate_term_atomic_types(model: UFPModel) -> None:
    """Ensure every term is covered by the model-level atomic type list."""
    assert model.atomic_types is not None
    model_types = set(int(value) for value in model.atomic_types)
    for term in model.terms:
        term_types = getattr(term, "atomic_types", None)
        if term_types is None:
            continue
        outside = sorted(set(int(value) for value in term_types) - model_types)
        if outside:
            raise ValueError(
                f"{_term_label(term)} uses atomic types outside model.atomic_types: "
                + ", ".join(str(value) for value in outside)
            )


def _validate_cuda_threebody_support(
    model: UFPModel,
    supported_devices: Sequence[str],
) -> None:
    """Fail early for CUDA exports that need the native three-body extension."""
    if "cuda" not in supported_devices or len(model.threebody_terms) == 0:
        return

    for term in model.threebody_terms:
        if str(term.spline) != "cubic":
            raise ValueError(
                "CUDA UF2+3 metatomic export supports only cubic three-body "
                f"splines, got {term.spline!r}"
            )

    from ufp.terms._threebody_kernels import native_threebody_backend_available

    if not native_threebody_backend_available(device="cuda", spline="cubic"):
        raise RuntimeError(
            "CUDA UF2+3 metatomic export requires the optional native UFP "
            "three-body extension with a CUDA kernel. Reinstall with "
            "`UFP_BUILD_NATIVE=1 UFP_BUILD_CUDA=1`, or export with "
            "`supported_devices=('cpu',)` for CPU validation."
        )


def _onebody_values(model: UFPModel, dtype: torch.dtype) -> torch.Tensor:
    """Collect one-body reference energies into model atomic-type order."""
    assert model.atomic_types is not None
    atomic_types = tuple(int(value) for value in model.atomic_types)
    index_by_type = {
        atomic_type: index for index, atomic_type in enumerate(atomic_types)
    }
    values = torch.zeros(len(atomic_types), dtype=dtype)
    for term in model.onebody_terms:
        term_values = _clone_export_tensor(
            term.values,
            dtype=dtype,
            name="ElementOneBodyTerm.values",
        )
        assert term.atomic_types is not None
        if tuple(term_values.shape) != (len(term.atomic_types),):
            raise ValueError(
                "ElementOneBodyTerm.values has shape "
                f"{tuple(term_values.shape)}, expected ({len(term.atomic_types)},)"
            )
        for source_index, atomic_type in enumerate(term.atomic_types):
            values[index_by_type[int(atomic_type)]] += term_values[source_index]
    return values


def _pair_state(term: SplinePairTerm, dtype: torch.dtype) -> UF23PairSplineState:
    """Build the export state of one ``SplinePairTerm``.

    The coefficients come from ``term.true_coeffs`` and are copied through
    ``_clone_export_tensor``; the remaining fields are plain Python
    conversions of the term's attributes.
    """
    exported_coeffs = _clone_export_tensor(
        term.true_coeffs,
        dtype=dtype,
        name="SplinePairTerm.true_coeffs",
    )
    species_pair = tuple(int(value) for value in term.pair)
    is_symmetric = bool(term.symmetric)
    spline_kind = str(term.spline)
    return UF23PairSplineState(
        pair=species_pair,
        coeffs=exported_coeffs,
        symmetric=is_symmetric,
        spline=spline_kind,
        first_knot=float(term.first_knot),
        knot_spacing=float(term.knot_spacing),
        eps=float(term.eps),
    )


def _twobody_state(
    term: SplineTwoBodyTerm,
    dtype: torch.dtype,
) -> UF23TwoBodySplineState:
    """Convert one SplineTwoBodyTerm through its public coefficient API.

    Args:
        term: Term whose ``true_coeffs_by_pair`` tensor is exported.
        dtype: Dtype used for the exported coefficient tensor.

    Returns:
        The export state, holding a detached CPU copy of the coefficients and
        of ``term.active_pair_mask``.

    Raises:
        ValueError: If the coefficient tensor is not 2-D, or if its first
            dimension differs from ``len(term.pair_categories)``.
    """
    coeffs = _clone_export_tensor(
        term.true_coeffs_by_pair,
        dtype=dtype,
        name="SplineTwoBodyTerm.true_coeffs_by_pair",
    )
    n_pairs = len(term.pair_categories)
    if coeffs.ndim != 2:
        # Check the rank before reading `shape[1]`; otherwise a 0-D or 1-D
        # tensor would surface as a bare IndexError. The coefficient count is
        # unknown in this case, so it is reported symbolically.
        raise ValueError(
            "SplineTwoBodyTerm.true_coeffs_by_pair has shape "
            f"{tuple(coeffs.shape)}, expected ({n_pairs}, n_coeffs)"
        )
    expected_shape = (n_pairs, int(coeffs.shape[1]))
    if tuple(coeffs.shape) != expected_shape:
        raise ValueError(
            "SplineTwoBodyTerm.true_coeffs_by_pair has shape "
            f"{tuple(coeffs.shape)}, expected {expected_shape}"
        )
    return UF23TwoBodySplineState(
        atomic_types=tuple(int(value) for value in term.atomic_types or ()),
        coeffs_by_pair=coeffs,
        active_pair_mask=term.active_pair_mask.detach().cpu().clone(),
        symmetric=bool(term.symmetric),
        spline=str(term.spline),
        first_knot=float(term.first_knot),
        knot_spacing=float(term.knot_spacing),
        eps=float(term.eps),
    )


def _threebody_state(
    term: SplineThreeBodyTerm,
    dtype: torch.dtype,
) -> UF23ThreeBodySplineState:
    """Build the export state of one ``SplineThreeBodyTerm``.

    The coefficients come from ``term.true_coeffs_by_triplet``; the triplet
    mask and edge-category table are stored as detached CPU clones.

    Raises:
        ValueError: If the coefficient tensor shape is not
            ``(len(term.triplet_categories), *term.coeff_shape)``.
    """
    exported_coeffs = _clone_export_tensor(
        term.true_coeffs_by_triplet,
        dtype=dtype,
        name="SplineThreeBodyTerm.true_coeffs_by_triplet",
    )
    expected_shape = (len(term.triplet_categories),) + tuple(term.coeff_shape)
    actual_shape = tuple(exported_coeffs.shape)
    if actual_shape != expected_shape:
        raise ValueError(
            "SplineThreeBodyTerm.true_coeffs_by_triplet has shape "
            f"{actual_shape}, expected {expected_shape}"
        )

    species = tuple(int(value) for value in term.atomic_types or ())
    triplet_mask = term.active_triplet_mask.detach().cpu().clone()
    edge_categories = term.edge_cat_table.detach().cpu().clone()
    return UF23ThreeBodySplineState(
        atomic_types=species,
        coeffs_by_triplet=exported_coeffs,
        active_triplet_mask=triplet_mask,
        edge_cat_table=edge_categories,
        spline=str(term.spline),
        first_knot_xy=float(term.first_knot_xy),
        first_knot_z=float(term.first_knot_z),
        knot_spacing_xy=float(term.knot_spacing_xy),
        knot_spacing_z=float(term.knot_spacing_z),
        lower_support_xy=float(term.lower_support_xy),
        upper_support_xy=float(term.upper_support_xy),
        lower_support_z=float(term.lower_support_z),
        eps=float(term.eps),
    )


def build_uf23_model_state(
    model: UFPModel,
    *,
    supported_devices: Sequence[str],
) -> UF23ModelState:
    """Validate ``model`` and convert it into a ``UF23ModelState``.

    All validators run before any tensor is copied. A missing
    ``model.cutoff`` is exported as an interaction range of ``0.0``.

    Args:
        model: Model to convert.
        supported_devices: Device names the export should support.

    Returns:
        The converted state, with pair terms split into ``SplinePairTerm``
        and ``SplineTwoBodyTerm`` groups.
    """
    devices = _normalized_devices(supported_devices)
    _ensure_supported_model(model)
    _validate_term_atomic_types(model)
    _validate_cuda_threebody_support(model, devices)
    export_dtype = _floating_dtype(model)

    pair_states = []
    twobody_states = []
    for term in model.pair_terms:
        if isinstance(term, SplinePairTerm):
            pair_states.append(_pair_state(term, export_dtype))
            continue
        if isinstance(term, SplineTwoBodyTerm):
            twobody_states.append(_twobody_state(term, export_dtype))
            continue
        raise AssertionError("unsupported pair term escaped validation")

    species = tuple(int(value) for value in model.atomic_types or ())
    onebody = _onebody_values(model, export_dtype)
    threebody_states = tuple(
        _threebody_state(term, export_dtype) for term in model.threebody_terms
    )
    cutoff = model.cutoff
    interaction_range = float(cutoff) if cutoff is not None else 0.0

    return UF23ModelState(
        atomic_types=species,
        onebody_values=onebody,
        pair_terms=tuple(pair_states),
        twobody_terms=tuple(twobody_states),
        threebody_terms=threebody_states,
        interaction_range=interaction_range,
        dtype=export_dtype,
    )


def _make_atomistic_model(
    state: UF23ModelState,
    *,
    length_unit: str,
    energy_unit: str,
    supported_devices: Sequence[str],
):
    """Wrap converted UF2+3 state in a metatomic ``AtomisticModel``.

    Two outputs are declared: a total ``"energy"`` and per-atom
    ``"non_conservative_forces"`` in ``energy_unit/length_unit``. The runtime
    module is put in eval mode and the metadata is left at its defaults.
    """
    mta = _require_metatomic_torch()

    energy_output = mta.ModelOutput(
        quantity="energy",
        unit=energy_unit,
        per_atom=False,
    )
    forces_output = mta.ModelOutput(
        quantity="force",
        unit=f"{energy_unit}/{length_unit}",
        per_atom=True,
    )
    capabilities = mta.ModelCapabilities(
        length_unit=length_unit,
        atomic_types=list(state.atomic_types),
        interaction_range=float(state.interaction_range),
        outputs={
            "energy": energy_output,
            "non_conservative_forces": forces_output,
        },
        supported_devices=list(supported_devices),
        dtype=_dtype_name(state.dtype),
    )
    metadata = mta.ModelMetadata()
    runtime = UF23MetatomicModule(state).eval()
    return mta.AtomisticModel(runtime, metadata, capabilities)


def _save_atomistic_model(
    atomistic_model: Any,
    path: Path,
    *,
    collect_extensions: str | Path | None,
) -> None:
    """Save a metatomic AtomisticModel with optional native-extension collection."""
    save = getattr(atomistic_model, "save", None)
    if not callable(save):
        raise TypeError("metatomic AtomisticModel does not provide a callable save()")

    path.parent.mkdir(parents=True, exist_ok=True)
    if collect_extensions is None:
        save(str(path))
    else:
        save(str(path), collect_extensions=str(collect_extensions))


def export_uf23_atomistic_model(
    model: UFPModel,
    path: str | Path,
    *,
    length_unit: str = "Angstrom",
    energy_unit: str = "eV",
    supported_devices: Sequence[str] = ("cuda", "cpu"),
    collect_extensions: str | Path | None = None,
    check: bool = True,
) -> None:
    """Export a fitted UFP UF2+3 model as a metatomic AtomisticModel file.

    Args:
        model: Model to validate, convert and export.
        path: Destination file; its parent directory is created if needed.
        length_unit: Length unit declared in the model capabilities.
        energy_unit: Energy unit declared for the energy output; the force
            output uses ``energy_unit/length_unit``.
        supported_devices: Device names the exported model declares.
        collect_extensions: Optional location forwarded to the model's
            ``save`` call as ``collect_extensions``.
        check: Accepted but not consulted by this function.
    """
    # The flag is part of the signature but is discarded immediately.
    del check

    devices = _normalized_devices(supported_devices)
    state = build_uf23_model_state(model, supported_devices=devices)
    atomistic_model = _make_atomistic_model(
        state,
        length_unit=length_unit,
        energy_unit=energy_unit,
        supported_devices=devices,
    )
    _save_atomistic_model(
        atomistic_model,
        Path(path),
        collect_extensions=collect_extensions,
    )


def _load_checkpoint_payload(checkpoint_path: str | Path) -> dict[str, Any]: """Load a torch checkpoint payload with compatibility for PyTorch versions.""" try: payload = torch.load( checkpoint_path, map_location="cpu", weights_only=False, ) except TypeError: payload = torch.load(checkpoint_path, map_location="cpu") if not isinstance(payload, dict): raise ValueError("checkpoint must be a dictionary containing `state_dict`") return payload def _checkpoint_state_dict(payload: dict[str, Any]) -> dict[str, torch.Tensor]: """Extract the model state dict from supported checkpoint payloads.""" state_dict = payload.get("state_dict") if state_dict is None: state_dict = payload.get("model_state_dict") if state_dict is None and all( isinstance(key, str) and isinstance(value, torch.Tensor) for key, value in payload.items() ): state_dict = payload if not isinstance(state_dict, dict): raise ValueError("checkpoint must contain `state_dict` or `model_state_dict`") return state_dict def _load_checkpoint_state_dict( model: UFPModel, state_dict: dict[str, torch.Tensor], payload: dict[str, Any], ) -> None: """Load checkpoint weights, allowing legacy one-body metadata when present.""" if "onebody_energy" not in payload: model.load_state_dict(state_dict) return result = model.load_state_dict(state_dict, strict=False) missing = [ key for key in result.missing_keys if not key.startswith("onebody_terms.") ] if missing or result.unexpected_keys: details = [] if missing: details.append("missing keys: " + ", ".join(missing)) if result.unexpected_keys: details.append("unexpected keys: " + ", ".join(result.unexpected_keys)) raise RuntimeError( "checkpoint state_dict does not match the model factory output (" + "; ".join(details) + ")" ) def _apply_onebody_energy_if_present( model: UFPModel, payload: dict[str, Any], ) -> UFPModel: """Apply legacy single-element one-body checkpoint metadata when possible.""" if "onebody_energy" not in payload: return model value = float(payload["onebody_energy"]) 
onebody_terms = list(model.onebody_terms) if len(onebody_terms) == 1 and onebody_terms[0].values.numel() == 1: onebody_terms[0].values.data.fill_(value) return model if len(onebody_terms) != 0: raise ValueError( "checkpoint contains scalar `onebody_energy`, but the model factory " "returned a model with incompatible one-body terms" ) if model.atomic_types is None or len(model.atomic_types) != 1: raise ValueError( "checkpoint contains scalar `onebody_energy`; the model factory must " "return either one compatible ElementOneBodyTerm or a single-element " "model so the reference energy can be inserted" ) return UFPModel( onebody_terms=[ ElementOneBodyTerm( atomic_types=model.atomic_types, values=torch.tensor([value], dtype=model.preferred_dtype()), trainable=False, ) ], pair_terms=tuple(model.pair_terms), threebody_terms=tuple(model.threebody_terms), atomic_types=model.atomic_types, neighbor_backend=model.neighbor_backend, )
def export_uf23_checkpoint(
    checkpoint_path: str | Path,
    model_factory: Callable[[], UFPModel],
    output_path: str | Path,
    *,
    length_unit: str = "Angstrom",
    energy_unit: str = "eV",
    supported_devices: Sequence[str] = ("cuda", "cpu"),
    collect_extensions: str | Path | None = None,
    check: bool = True,
) -> None:
    """Rebuild a UF2+3 model, load a checkpoint state_dict, and export it.

    Args:
        checkpoint_path: Torch checkpoint to read the weights from.
        model_factory: Zero-argument callable returning the ``UFPModel`` the
            weights are loaded into.
        output_path: Destination of the exported model file.
        length_unit: Forwarded to ``export_uf23_atomistic_model``.
        energy_unit: Forwarded to ``export_uf23_atomistic_model``.
        supported_devices: Forwarded to ``export_uf23_atomistic_model``.
        collect_extensions: Forwarded to ``export_uf23_atomistic_model``.
        check: Forwarded to ``export_uf23_atomistic_model``.

    Raises:
        TypeError: If ``model_factory`` does not return a ``UFPModel``.
    """
    model = model_factory()
    if not isinstance(model, UFPModel):
        raise TypeError("`model_factory` must return a UFPModel")

    payload = _load_checkpoint_payload(checkpoint_path)
    state_dict = _checkpoint_state_dict(payload)
    _load_checkpoint_state_dict(model, state_dict, payload)
    # May return a rebuilt model when a legacy scalar `onebody_energy` has to
    # be inserted as a new one-body term.
    model = _apply_onebody_energy_if_present(model, payload)

    export_uf23_atomistic_model(
        model,
        output_path,
        length_unit=length_unit,
        energy_unit=energy_unit,
        supported_devices=supported_devices,
        collect_extensions=collect_extensions,
        check=check,
    )


# Names exported by `from ... import *`.
__all__ = [
    "build_uf23_model_state",
    "export_uf23_atomistic_model",
    "export_uf23_checkpoint",
]