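"""Reusable workflow helpers for examples and small supervised UFP studies.

Covers model-schema serialization (``model_schema`` / ``model_from_schema``
with a pluggable per-term codec registry), checkpoint save/load, adding an
element reference term, and ASE-based evaluation and relaxation.
"""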
from __future__ import annotations
import copy
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Sequence, cast
import ase
import numpy as np
import torch
from ase.optimize import BFGS
from ufp.adapters.ase import UFPASECalculator
from ufp.terms import (
AlchemicalCoefficients,
ChargeScaledSplinePairTerm,
ChargeSelfEnergyTerm,
CollinearSpinExchangeTerm,
CollinearSpinLandauTerm,
CutoffEnvelope,
ElementOneBodyTerm,
LocalChargeCoulombTerm,
RepulsiveSplineTwoBodyTerm,
SplinePairTerm,
SplineThreeBodyTerm,
SplineTriplet2DTerm,
SplineTwoBodyTerm,
UFPModel,
)
# Identifier written into every serialized schema; model_from_schema rejects
# schemas carrying any other name.
MODEL_SCHEMA_NAME: str = "ufp.model_schema"
# Schema layout revision; model_from_schema rejects any other version.
MODEL_SCHEMA_VERSION: int = 1
def _shape(value: torch.Tensor | Sequence[int]) -> list[int]:
    """Convert a tensor's shape, or any iterable of dimensions, to a list of ints.

    The result contains plain Python ``int`` values, so it is JSON-friendly.
    """
    dims = value.shape if isinstance(value, torch.Tensor) else value
    return list(map(int, dims))
def _object_sequence(value: object) -> Sequence[object]:
    """View ``value`` as a sequence of objects for static typing only.

    Nothing is converted or validated at runtime; ``value`` is returned as-is.
    """
    as_sequence = cast(Sequence[object], value)
    return as_sequence
def _mapping_sequence(value: object) -> Sequence[Mapping[str, object]]:
    """View ``value`` as a sequence of string-keyed mappings (typing-only cast).

    No runtime conversion or validation is performed.
    """
    entries = cast(Sequence[Mapping[str, object]], value)
    return entries
def _schema_shape(value: object) -> tuple[int, ...]:
    """Decode a serialized shape (an iterable of integer-like items) to a tuple."""
    dims: list[int] = []
    for dim in _object_sequence(value):
        dims.append(int(cast(Any, dim)))
    return tuple(dims)
def _schema_int(value: object) -> int:
    """Coerce a schema scalar with ``int()`` semantics (floats truncate)."""
    untyped: Any = value
    return int(untyped)
def _schema_float(value: object) -> float:
    """Coerce a schema scalar with ``float()`` semantics."""
    untyped: Any = value
    return float(untyped)
def _optional_schema_int(value: object) -> int | None:
    """Coerce like ``_schema_int``, but let ``None`` pass through unchanged."""
    if value is None:
        return None
    return _schema_int(value)
def _schema_atomic_types(value: object) -> tuple[int, ...] | None:
    """Decode optional atomic-type metadata into a tuple of ints, or ``None``."""
    if value is not None:
        return tuple(map(_schema_int, _object_sequence(value)))
    return None
def _schema_pair(value: object) -> tuple[int, int]:
    """Decode a two-item integer entry such as ``pair`` or an ``active_pairs`` item.

    Raises:
        ValueError: if the entry does not contain exactly two items.
    """
    decoded = tuple(_schema_int(item) for item in _object_sequence(value))
    if len(decoded) == 2:
        first, second = decoded
        return first, second
    raise ValueError("schema pair entries must contain exactly two integers")
def _schema_triplet(value: object) -> tuple[int, int, int]:
    """Decode a three-item integer entry such as an ``active_triplets`` item.

    Raises:
        ValueError: if the entry does not contain exactly three items.
    """
    decoded = tuple(_schema_int(item) for item in _object_sequence(value))
    if len(decoded) == 3:
        source, first, second = decoded
        return source, first, second
    raise ValueError("schema triplet entries must contain exactly three integers")
def _schema_pairs(value: object) -> list[tuple[int, int]]:
    """Decode every item of a sequence of pair entries via ``_schema_pair``."""
    return list(map(_schema_pair, _object_sequence(value)))
def _schema_triplets(value: object) -> list[tuple[int, int, int]]:
    """Decode every item of a sequence of triplet entries via ``_schema_triplet``."""
    return list(map(_schema_triplet, _object_sequence(value)))
def _provider_schema(
    provider: AlchemicalCoefficients,
    *,
    index: int,
) -> dict[str, object]:
    """Describe one alchemical coefficient provider for the model schema.

    Only shapes and trainability flags are recorded; tensor values are saved
    separately in the checkpoint state dict.

    Args:
        provider: Provider to describe.
        index: Position of ``provider`` in the model's provider list.
    """
    weights = provider.weights
    proxy = provider.proxy_coeffs
    has_weights = weights is not None
    return {
        "index": int(index),
        "n_true_terms": int(provider.n_true_terms),
        "proxy_coeff_shape": _shape(proxy),
        "weights_shape": _shape(weights) if has_weights else None,
        "proxy_trainable": bool(proxy.requires_grad),
        # NOTE(review): weights trainability is inferred from being an
        # nn.Parameter, not from requires_grad as for the proxy coefficients.
        "weights_trainable": isinstance(weights, torch.nn.Parameter),
    }
def _provider_index(
    providers: Sequence[AlchemicalCoefficients],
    provider: AlchemicalCoefficients | None,
) -> int | None:
    """Return the position of ``provider`` in ``providers``, compared by identity.

    Returns ``None`` when the term has no provider.

    Raises:
        ValueError: if ``provider`` is not one of ``providers``.
    """
    if provider is None:
        return None
    position = next(
        (i for i, candidate in enumerate(providers) if candidate is provider),
        None,
    )
    if position is None:
        raise ValueError("term references an alchemical provider outside the model")
    return position
[docs]
@dataclass(frozen=True)
class TermSchemaContext:
    """Immutable bundle of shared state handed to each term schema decoder.

    One instance is built per term entry while a model is rebuilt from its
    schema.
    """

    # Every provider rebuilt for the model, in schema order.
    providers: Sequence[AlchemicalCoefficients]
    # Floating-point dtype for the placeholder tensors that decoders create.
    dtype: torch.dtype
    # Decoded ``atomic_types`` of the term entry, or ``None``.
    atomic_types: tuple[int, ...] | None
    # Provider referenced by the entry's ``coefficient_provider`` index, if any.
    provider: AlchemicalCoefficients | None
    # Decoded ``coefficient_index`` of the term entry, if any.
    coefficient_index: int | None
# Encoder signature: (term instance, model-level provider index or None)
# -> class-specific schema fields merged into the term's schema entry.
TermSchemaEncoder = Callable[[object, int | None], dict[str, object]]
# Decoder signature: (term schema entry, shared reconstruction context)
# -> rebuilt term instance.
TermSchemaDecoder = Callable[[Mapping[str, object], TermSchemaContext], object]
[docs]
@dataclass(frozen=True)
class TermSchemaCodec:
    """Immutable record tying a term class to its schema encoder and decoder."""

    # Term class handled by this codec; when encoding, subclasses without a
    # codec of their own fall back to a registered base class.
    term_type: type
    # Value written to, and matched against, a schema entry's ``"class"`` field.
    class_name: str
    # Produces the class-specific schema fields for a term instance.
    encode: TermSchemaEncoder
    # Rebuilds a term instance from a schema entry plus shared context.
    decode: TermSchemaDecoder
# Registered codecs keyed by exact term class (primary lookup when encoding).
_TERM_SCHEMA_CODECS_BY_TYPE: dict[type, TermSchemaCodec] = dict()
# The same codecs keyed by their ``class_name`` (lookup when decoding).
_TERM_SCHEMA_CODECS_BY_NAME: dict[str, TermSchemaCodec] = dict()
[docs]
def register_term_schema_codec(
    term_type: type,
    encode: TermSchemaEncoder,
    decode: TermSchemaDecoder,
    *,
    class_name: str | None = None,
) -> None:
    """Make ``term_type`` serializable by ``model_schema``/``model_from_schema``.

    Args:
        term_type: Term class handled by the codec.
        encode: Produces the class-specific schema fields for an instance.
        decode: Rebuilds an instance from a schema entry and shared context.
        class_name: Name stored in the schema's ``"class"`` field. A falsy
            value (``None`` or ``""``) falls back to ``term_type.__name__``.

    Registering the same type or name again replaces the earlier codec.
    """
    if not class_name:
        class_name = term_type.__name__
    codec = TermSchemaCodec(term_type, class_name, encode, decode)
    _TERM_SCHEMA_CODECS_BY_TYPE[term_type] = codec
    _TERM_SCHEMA_CODECS_BY_NAME[class_name] = codec
def _schema_codec_for_term(term: object) -> TermSchemaCodec:
    """Find the codec for ``term``: exact class first, then a registered base.

    If the exact class is not registered, the first codec in registration
    order whose ``term_type`` matches via ``isinstance`` is used.

    Raises:
        TypeError: if no registered codec accepts ``term``.
    """
    term_class = type(term)
    exact = _TERM_SCHEMA_CODECS_BY_TYPE.get(term_class)
    if exact is not None:
        return exact
    inherited = next(
        (
            candidate
            for candidate in _TERM_SCHEMA_CODECS_BY_TYPE.values()
            if isinstance(term, candidate.term_type)
        ),
        None,
    )
    if inherited is None:
        raise TypeError(
            f"unsupported term type for model schema: {term_class.__name__}"
        )
    return inherited
def _term_schema(
    term,
    *,
    group: str,
    providers: Sequence[AlchemicalCoefficients],
) -> dict[str, object]:
    """Serialize one term: shared fields, then its codec's class-specific fields.

    Args:
        term: Term instance; must expose ``cutoff`` and ``atomic_types``.
        group: Model group the term was found in. It is recorded, but
            ``model_from_schema`` does not read it.
        providers: The model's providers, used to convert the term's
            ``coefficient_provider`` into an index.
    """
    codec = _schema_codec_for_term(term)
    provider_index = _provider_index(
        providers, getattr(term, "coefficient_provider", None)
    )
    cutoff = term.cutoff
    atomic_types = term.atomic_types
    entry: dict[str, object] = {
        "group": group,
        "class": codec.class_name,
        "cutoff": float(cutoff) if cutoff is not None else None,
        "atomic_types": (
            [int(value) for value in atomic_types]
            if atomic_types is not None
            else None
        ),
    }
    # Codec fields are merged last, so they could override the shared keys.
    entry.update(codec.encode(term, provider_index))
    return entry
def _encode_element_onebody(
    term: object, provider_index: int | None
) -> dict[str, object]:
    """Schema fields for an ``ElementOneBodyTerm``.

    Only the shape of ``values`` and the flags are stored; the values
    themselves are restored from the state dict.
    """
    del provider_index  # ignored: this term is not provider-backed here
    element_term = cast(ElementOneBodyTerm, term)
    values = element_term.values
    return {
        "values_shape": _shape(values),
        "trainable": bool(values.requires_grad),
        "fittable": bool(element_term.fittable),
        "frozen": bool(element_term.frozen),
    }
def _decode_element_onebody(
    entry: Mapping[str, object],
    context: TermSchemaContext,
) -> ElementOneBodyTerm:
    """Rebuild an ``ElementOneBodyTerm`` with a zero-filled ``values`` placeholder."""
    placeholder = torch.zeros(
        _schema_shape(entry["values_shape"]), dtype=context.dtype
    )
    is_trainable = bool(entry.get("trainable", True))
    is_fittable = bool(entry.get("fittable", True))
    is_frozen = bool(entry.get("frozen", False))
    return ElementOneBodyTerm(
        atomic_types=context.atomic_types,
        values=placeholder,
        trainable=is_trainable,
        fittable=is_fittable,
        frozen=is_frozen,
    )
def _encode_charge_self_energy(
    term: object,
    provider_index: int | None,
) -> dict[str, object]:
    """Schema fields for a ``ChargeSelfEnergyTerm``; ``provider_index`` is ignored.

    ``trainable`` is recorded as True only when both the electronegativities
    and the hardnesses require gradients.
    """
    del provider_index
    self_energy = cast(ChargeSelfEnergyTerm, term)
    electronegativities = self_energy.electronegativities
    hardnesses = self_energy.hardnesses
    return {
        "electronegativities_shape": _shape(electronegativities),
        "hardnesses_shape": _shape(hardnesses),
        "trainable": bool(
            electronegativities.requires_grad and hardnesses.requires_grad
        ),
        "fittable": bool(self_energy.fittable),
        "frozen": bool(self_energy.frozen),
    }
def _decode_charge_self_energy(
    entry: Mapping[str, object],
    context: TermSchemaContext,
) -> ChargeSelfEnergyTerm:
    """Rebuild a ``ChargeSelfEnergyTerm`` with zero-filled parameter placeholders."""
    dtype = context.dtype
    electronegativities = torch.zeros(
        _schema_shape(entry["electronegativities_shape"]), dtype=dtype
    )
    hardnesses = torch.zeros(_schema_shape(entry["hardnesses_shape"]), dtype=dtype)
    return ChargeSelfEnergyTerm(
        atomic_types=context.atomic_types,
        electronegativities=electronegativities,
        hardnesses=hardnesses,
        trainable=bool(entry.get("trainable", True)),
        fittable=bool(entry.get("fittable", True)),
        frozen=bool(entry.get("frozen", False)),
        dtype=dtype,
    )
def _encode_spin_landau(
    term: object,
    provider_index: int | None,
) -> dict[str, object]:
    """Schema fields for a ``CollinearSpinLandauTerm``; ``provider_index`` is ignored.

    ``trainable`` is True only when both the quadratic and quartic
    coefficients require gradients.
    """
    del provider_index
    landau = cast(CollinearSpinLandauTerm, term)
    quadratic = landau.quadratic
    quartic = landau.quartic
    return {
        "quadratic_shape": _shape(quadratic),
        "quartic_shape": _shape(quartic),
        "trainable": bool(quadratic.requires_grad and quartic.requires_grad),
        "fittable": bool(landau.fittable),
        "frozen": bool(landau.frozen),
    }
def _decode_spin_landau(
    entry: Mapping[str, object],
    context: TermSchemaContext,
) -> CollinearSpinLandauTerm:
    """Rebuild a ``CollinearSpinLandauTerm`` with zero-filled coefficient placeholders."""
    dtype = context.dtype
    quadratic = torch.zeros(_schema_shape(entry["quadratic_shape"]), dtype=dtype)
    quartic = torch.zeros(_schema_shape(entry["quartic_shape"]), dtype=dtype)
    return CollinearSpinLandauTerm(
        atomic_types=context.atomic_types,
        quadratic=quadratic,
        quartic=quartic,
        trainable=bool(entry.get("trainable", True)),
        fittable=bool(entry.get("fittable", True)),
        frozen=bool(entry.get("frozen", False)),
        dtype=dtype,
    )
def _encode_spline_pair(term: object, provider_index: int | None) -> dict[str, object]:
    """Schema fields for a single-pair ``SplinePairTerm``.

    When a provider supplies the coefficients, ``coeff_shape`` is ``None`` and
    the term is recorded as trainable.
    """
    pair_term = cast(SplinePairTerm, term)
    uses_provider = provider_index is not None
    return {
        "pair": list(map(int, pair_term.pair)),
        "coeff_shape": None if uses_provider else _shape(pair_term.coeffs),
        "coefficient_provider": provider_index,
        "coefficient_index": pair_term.coefficient_index,
        "symmetric": bool(pair_term.symmetric),
        "spline": str(pair_term.spline),
        "full_support_start": float(pair_term.full_support_start),
        "eps": float(pair_term.eps),
        "enabled": bool(pair_term.enabled),
        "trainable": uses_provider or bool(pair_term.coeffs.requires_grad),
        "fittable": bool(pair_term.fittable),
        "frozen": bool(pair_term.frozen),
    }
def _decode_spline_pair(
    entry: Mapping[str, object],
    context: TermSchemaContext,
) -> SplinePairTerm:
    """Rebuild a ``SplinePairTerm`` from its schema entry.

    Coefficients come from ``context.provider`` when one is referenced;
    otherwise a zero tensor of the recorded ``coeff_shape`` is the placeholder.
    """
    cutoff = _schema_float(entry["cutoff"])
    pair = _schema_pair(entry["pair"])
    provider = context.provider
    if provider is None:
        coeffs = torch.zeros(
            _schema_shape(entry.get("coeff_shape")), dtype=context.dtype
        )
    else:
        coeffs = None
    return SplinePairTerm(
        cutoff=cutoff,
        pair=pair,
        coeffs=coeffs,
        coefficient_provider=provider,
        coefficient_index=context.coefficient_index,
        symmetric=bool(entry.get("symmetric", True)),
        spline=str(entry.get("spline", "cubic")),
        full_support_start=_schema_float(entry.get("full_support_start", 0.0)),
        eps=_schema_float(entry.get("eps", 1.0e-12)),
        enabled=bool(entry.get("enabled", True)),
        trainable=bool(entry.get("trainable", True)),
        fittable=bool(entry.get("fittable", True)),
        frozen=bool(entry.get("frozen", False)),
        dtype=context.dtype,
    )
def _encode_spline_twobody(
    term: object,
    provider_index: int | None,
) -> dict[str, object]:
    """Schema fields for a categorized ``SplineTwoBodyTerm``.

    With a coefficient provider, ``coeff_shape`` is ``None`` and the term is
    recorded as trainable. Otherwise the local ``coeffs_by_pair`` shape and
    its ``requires_grad`` flag are stored.
    """
    twobody = cast(SplineTwoBodyTerm, term)
    if provider_index is None:
        coeff_shape = _shape(twobody.coeffs_by_pair)
        trainable = bool(twobody.coeffs_by_pair.requires_grad)
    else:
        coeff_shape = None
        trainable = True
    active_pairs = [
        [int(first), int(second)] for first, second in twobody.active_pair_categories
    ]
    return {
        "coeff_shape": coeff_shape,
        "coefficient_provider": provider_index,
        "coefficient_index": twobody.coefficient_index,
        "active_pairs": active_pairs,
        "symmetric": bool(twobody.symmetric),
        "spline": str(twobody.spline),
        "full_support_start": float(twobody.full_support_start),
        "eps": float(twobody.eps),
        "trainable": trainable,
        "fittable": bool(twobody.fittable),
        "frozen": bool(twobody.frozen),
    }
def _decode_spline_twobody(
    entry: Mapping[str, object],
    context: TermSchemaContext,
) -> SplineTwoBodyTerm:
    """Rebuild a categorized ``SplineTwoBodyTerm`` from its schema entry.

    Provider-backed entries get no local coefficients; otherwise a zero tensor
    of the recorded ``coeff_shape`` is used as a placeholder.
    """
    cutoff = _schema_float(entry["cutoff"])
    provider = context.provider
    if provider is None:
        coeffs = torch.zeros(
            _schema_shape(entry.get("coeff_shape")), dtype=context.dtype
        )
    else:
        coeffs = None
    return SplineTwoBodyTerm(
        cutoff=cutoff,
        atomic_types=context.atomic_types,
        coeffs_by_pair=coeffs,
        coefficient_provider=provider,
        coefficient_index=context.coefficient_index,
        active_pairs=_schema_pairs(entry.get("active_pairs", ())),
        symmetric=bool(entry.get("symmetric", True)),
        spline=str(entry.get("spline", "cubic")),
        full_support_start=_schema_float(entry.get("full_support_start", 0.0)),
        eps=_schema_float(entry.get("eps", 1.0e-12)),
        trainable=bool(entry.get("trainable", True)),
        fittable=bool(entry.get("fittable", True)),
        frozen=bool(entry.get("frozen", False)),
        dtype=context.dtype,
    )
def _encode_state_scaled_twobody(
    term: ChargeScaledSplinePairTerm | CollinearSpinExchangeTerm,
    provider_index: int | None,
) -> dict[str, object]:
    """Schema fields shared by the charge-scaled and spin-exchange pair terms.

    With a coefficient provider, ``coeff_shape`` is ``None`` and the term is
    recorded as trainable.
    """
    if provider_index is None:
        coeff_shape = _shape(term.coeffs_by_pair)
        trainable = bool(term.coeffs_by_pair.requires_grad)
    else:
        coeff_shape = None
        trainable = True
    active_pairs = [
        [int(first), int(second)] for first, second in term.active_pair_categories
    ]
    return {
        "coeff_shape": coeff_shape,
        "coefficient_provider": provider_index,
        "coefficient_index": term.coefficient_index,
        "active_pairs": active_pairs,
        "symmetric": bool(term.symmetric),
        "spline": str(term.spline),
        "full_support_start": float(term.full_support_start),
        "eps": float(term.eps),
        "trainable": trainable,
        "fittable": bool(term.fittable),
        "frozen": bool(term.frozen),
    }
def _decode_charge_scaled_twobody(
    entry: Mapping[str, object],
    context: TermSchemaContext,
) -> ChargeScaledSplinePairTerm:
    """Rebuild a ``ChargeScaledSplinePairTerm`` from its schema entry.

    Provider-backed entries get no local coefficients; otherwise a zero tensor
    of the recorded ``coeff_shape`` is used as a placeholder.
    """
    cutoff = _schema_float(entry["cutoff"])
    provider = context.provider
    if provider is None:
        coeffs = torch.zeros(
            _schema_shape(entry.get("coeff_shape")), dtype=context.dtype
        )
    else:
        coeffs = None
    return ChargeScaledSplinePairTerm(
        cutoff=cutoff,
        atomic_types=context.atomic_types,
        coeffs_by_pair=coeffs,
        coefficient_provider=provider,
        coefficient_index=context.coefficient_index,
        active_pairs=_schema_pairs(entry.get("active_pairs", ())),
        symmetric=bool(entry.get("symmetric", True)),
        spline=str(entry.get("spline", "cubic")),
        full_support_start=_schema_float(entry.get("full_support_start", 0.0)),
        eps=_schema_float(entry.get("eps", 1.0e-12)),
        trainable=bool(entry.get("trainable", True)),
        fittable=bool(entry.get("fittable", True)),
        frozen=bool(entry.get("frozen", False)),
        dtype=context.dtype,
    )
def _decode_spin_exchange(
    entry: Mapping[str, object],
    context: TermSchemaContext,
) -> CollinearSpinExchangeTerm:
    """Rebuild a ``CollinearSpinExchangeTerm`` from its schema entry.

    Provider-backed entries get no local coefficients; otherwise a zero tensor
    of the recorded ``coeff_shape`` is used as a placeholder.
    """
    cutoff = _schema_float(entry["cutoff"])
    provider = context.provider
    if provider is None:
        coeffs = torch.zeros(
            _schema_shape(entry.get("coeff_shape")), dtype=context.dtype
        )
    else:
        coeffs = None
    return CollinearSpinExchangeTerm(
        cutoff=cutoff,
        atomic_types=context.atomic_types,
        coeffs_by_pair=coeffs,
        coefficient_provider=provider,
        coefficient_index=context.coefficient_index,
        active_pairs=_schema_pairs(entry.get("active_pairs", ())),
        symmetric=bool(entry.get("symmetric", True)),
        spline=str(entry.get("spline", "cubic")),
        full_support_start=_schema_float(entry.get("full_support_start", 0.0)),
        eps=_schema_float(entry.get("eps", 1.0e-12)),
        trainable=bool(entry.get("trainable", True)),
        fittable=bool(entry.get("fittable", True)),
        frozen=bool(entry.get("frozen", False)),
        dtype=context.dtype,
    )
def _encode_repulsive_twobody(
    term: object,
    provider_index: int | None,
) -> dict[str, object]:
    """Schema fields for a ``RepulsiveSplineTwoBodyTerm``; ``provider_index`` is ignored.

    ``trainable`` is True when any parameter of the term requires gradients.
    """
    del provider_index
    repulsive = cast(RepulsiveSplineTwoBodyTerm, term)
    any_trainable = False
    for parameter in repulsive.parameters():
        if parameter.requires_grad:
            any_trainable = True
            break
    active_pairs = [
        [int(first), int(second)] for first, second in repulsive.active_pair_categories
    ]
    return {
        "coeff_size": int(repulsive.coeff_size),
        "transition_span": int(repulsive.transition_span),
        "active_pairs": active_pairs,
        "symmetric": bool(repulsive.symmetric),
        "spline": str(repulsive.spline),
        "full_support_start": float(repulsive.full_support_start),
        "eps": float(repulsive.eps),
        "trainable": any_trainable,
    }
def _decode_repulsive_twobody(
    entry: Mapping[str, object],
    context: TermSchemaContext,
) -> RepulsiveSplineTwoBodyTerm:
    """Rebuild a ``RepulsiveSplineTwoBodyTerm`` from its size and layout metadata."""
    cutoff = _schema_float(entry["cutoff"])
    atomic_types = context.atomic_types
    coeff_size = _schema_int(entry["coeff_size"])
    transition_span = _schema_int(entry["transition_span"])
    return RepulsiveSplineTwoBodyTerm(
        cutoff=cutoff,
        atomic_types=atomic_types,
        coeff_size=coeff_size,
        transition_span=transition_span,
        active_pairs=_schema_pairs(entry.get("active_pairs", ())),
        symmetric=bool(entry.get("symmetric", True)),
        spline=str(entry.get("spline", "cubic")),
        full_support_start=_schema_float(entry.get("full_support_start", 0.0)),
        eps=_schema_float(entry.get("eps", 1.0e-12)),
        trainable=bool(entry.get("trainable", True)),
        dtype=context.dtype,
    )
def _encode_cutoff_envelope(
    envelope: CutoffEnvelope | str | None,
) -> dict[str, object] | str | None:
    """Encode cutoff-envelope metadata for the checkpoint schema.

    This is the inverse of ``_decode_cutoff_envelope``, which accepts ``None``,
    a mapping, or a plain string:

    * ``None`` is stored as ``None``;
    * a string is stored unchanged;
    * anything else is serialized as a copy of its ``to_dict()`` mapping.

    Passing through ``None``/strings keeps encoding from failing with
    ``AttributeError`` for envelopes that the decoder is able to restore.
    """
    if envelope is None or isinstance(envelope, str):
        return envelope
    return dict(envelope.to_dict())
def _decode_cutoff_envelope(value: object) -> CutoffEnvelope | str | None:
    """Decode cutoff-envelope metadata from the checkpoint schema.

    A mapping becomes a ``CutoffEnvelope`` (``onset`` defaults to 0.0 and
    ``kind`` to ``"smoothstep"``). ``None`` passes through, and any other
    value is converted with ``str``.
    """
    if isinstance(value, Mapping):
        return CutoffEnvelope(
            cutoff=_schema_float(value["cutoff"]),
            onset=_schema_float(value.get("onset", 0.0)),
            kind=str(value.get("kind", "smoothstep")),
        )
    return None if value is None else str(value)
def _encode_local_charge_coulomb(
    term: object,
    provider_index: int | None,
) -> dict[str, object]:
    """Schema fields for a ``LocalChargeCoulombTerm``; ``provider_index`` is ignored.

    No trainability flags are recorded for this term. Only its pair
    selection, settings and cutoff envelope are stored.
    """
    del provider_index
    coulomb = cast(LocalChargeCoulombTerm, term)
    active_pairs = [
        [int(first), int(second)] for first, second in coulomb.active_pair_categories
    ]
    return {
        "active_pairs": active_pairs,
        "symmetric": bool(coulomb.symmetric),
        "softening": float(coulomb.softening),
        "scale": float(coulomb.scale),
        "cutoff_envelope": _encode_cutoff_envelope(coulomb.cutoff_envelope),
        "eps": float(coulomb.eps),
    }
def _decode_local_charge_coulomb(
    entry: Mapping[str, object],
    context: TermSchemaContext,
) -> LocalChargeCoulombTerm:
    """Rebuild a ``LocalChargeCoulombTerm``; missing settings use fixed defaults."""
    cutoff = _schema_float(entry["cutoff"])
    atomic_types = context.atomic_types
    active_pairs = _schema_pairs(entry.get("active_pairs", ()))
    return LocalChargeCoulombTerm(
        cutoff=cutoff,
        atomic_types=atomic_types,
        active_pairs=active_pairs,
        symmetric=bool(entry.get("symmetric", True)),
        softening=_schema_float(entry.get("softening", 1.0e-6)),
        scale=_schema_float(entry.get("scale", 1.0)),
        cutoff_envelope=_decode_cutoff_envelope(entry.get("cutoff_envelope")),
        eps=_schema_float(entry.get("eps", 1.0e-12)),
    )
def _encode_spline_threebody(
    term: object,
    provider_index: int | None,
) -> dict[str, object]:
    """Schema fields for a ``SplineThreeBodyTerm``.

    With a coefficient provider, ``coeff_shape`` is ``None`` and the term is
    recorded as trainable.
    """
    threebody = cast(SplineThreeBodyTerm, term)
    if provider_index is None:
        coeff_shape = _shape(threebody.coeffs_by_triplet)
        trainable = bool(threebody.coeffs_by_triplet.requires_grad)
    else:
        coeff_shape = None
        trainable = True
    active_triplets = [
        [int(source), int(first), int(second)]
        for source, first, second in threebody.active_triplet_categories
    ]
    return {
        "coeff_shape": coeff_shape,
        "coefficient_provider": provider_index,
        "coefficient_index": threebody.coefficient_index,
        "active_triplets": active_triplets,
        "spline": str(threebody.spline),
        "full_support_start_xy": float(threebody.full_support_start_xy),
        "full_support_start_z": float(threebody.full_support_start_z),
        "eps": float(threebody.eps),
        "trainable": trainable,
        "fittable": bool(threebody.fittable),
        "frozen": bool(threebody.frozen),
    }
def _decode_spline_threebody(
    entry: Mapping[str, object],
    context: TermSchemaContext,
) -> SplineThreeBodyTerm:
    """Rebuild a ``SplineThreeBodyTerm`` from its schema entry.

    Provider-backed entries get no local coefficients; otherwise a zero tensor
    of the recorded ``coeff_shape`` is used as a placeholder.
    """
    cutoff = _schema_float(entry["cutoff"])
    provider = context.provider
    if provider is None:
        coeffs = torch.zeros(
            _schema_shape(entry.get("coeff_shape")), dtype=context.dtype
        )
    else:
        coeffs = None
    return SplineThreeBodyTerm(
        cutoff=cutoff,
        atomic_types=context.atomic_types,
        coeffs_by_triplet=coeffs,
        coefficient_provider=provider,
        coefficient_index=context.coefficient_index,
        active_triplets=_schema_triplets(entry.get("active_triplets", ())),
        spline=str(entry.get("spline", "cubic")),
        full_support_start_xy=_schema_float(entry.get("full_support_start_xy", 0.0)),
        full_support_start_z=_schema_float(entry.get("full_support_start_z", 2.0)),
        eps=_schema_float(entry.get("eps", 1.0e-12)),
        trainable=bool(entry.get("trainable", True)),
        fittable=bool(entry.get("fittable", True)),
        frozen=bool(entry.get("frozen", False)),
        dtype=context.dtype,
    )
def _encode_spline_triplet2d(
    term: object,
    provider_index: int | None,
) -> dict[str, object]:
    """Schema fields for a two-distance ``SplineTriplet2DTerm``.

    Raises:
        TypeError: if a coefficient provider is attached, because this term
            only supports local coefficients.
    """
    if provider_index is not None:
        raise TypeError("SplineTriplet2DTerm does not support coefficient providers")
    triplet_term = cast(SplineTriplet2DTerm, term)
    coeffs = triplet_term.coeffs_by_triplet
    active_triplets = [
        [int(source), int(first), int(second)]
        for source, first, second in triplet_term.active_triplet_categories
    ]
    return {
        "coeff_shape": _shape(coeffs),
        "active_triplets": active_triplets,
        "spline": str(triplet_term.spline),
        "full_support_start": float(triplet_term.full_support_start),
        "eps": float(triplet_term.eps),
        "trainable": bool(coeffs.requires_grad),
        "fittable": bool(triplet_term.fittable),
        "frozen": bool(triplet_term.frozen),
    }
def _decode_spline_triplet2d(
    entry: Mapping[str, object],
    context: TermSchemaContext,
) -> SplineTriplet2DTerm:
    """Rebuild a ``SplineTriplet2DTerm`` with a zero-filled coefficient placeholder."""
    cutoff = _schema_float(entry["cutoff"])
    atomic_types = context.atomic_types
    placeholder = torch.zeros(
        _schema_shape(entry["coeff_shape"]), dtype=context.dtype
    )
    return SplineTriplet2DTerm(
        cutoff=cutoff,
        atomic_types=atomic_types,
        coeffs_by_triplet=placeholder,
        active_triplets=_schema_triplets(entry.get("active_triplets", ())),
        spline=str(entry.get("spline", "cubic")),
        full_support_start=_schema_float(entry.get("full_support_start", 0.0)),
        eps=_schema_float(entry.get("eps", 1.0e-12)),
        trainable=bool(entry.get("trainable", True)),
        fittable=bool(entry.get("fittable", True)),
        frozen=bool(entry.get("frozen", False)),
        dtype=context.dtype,
    )
# Built-in codecs, registered in this exact order. Order matters: when a term's
# exact class is not registered, ``_schema_codec_for_term`` uses the first
# registered codec whose class matches via ``isinstance``.
for _term_type, _encode, _decode in (
    (ElementOneBodyTerm, _encode_element_onebody, _decode_element_onebody),
    (ChargeSelfEnergyTerm, _encode_charge_self_energy, _decode_charge_self_energy),
    (CollinearSpinLandauTerm, _encode_spin_landau, _decode_spin_landau),
    (SplinePairTerm, _encode_spline_pair, _decode_spline_pair),
    (SplineTwoBodyTerm, _encode_spline_twobody, _decode_spline_twobody),
    (
        ChargeScaledSplinePairTerm,
        lambda term, provider_index: _encode_state_scaled_twobody(
            cast(ChargeScaledSplinePairTerm, term),
            provider_index,
        ),
        _decode_charge_scaled_twobody,
    ),
    (
        CollinearSpinExchangeTerm,
        lambda term, provider_index: _encode_state_scaled_twobody(
            cast(CollinearSpinExchangeTerm, term),
            provider_index,
        ),
        _decode_spin_exchange,
    ),
    (
        RepulsiveSplineTwoBodyTerm,
        _encode_repulsive_twobody,
        _decode_repulsive_twobody,
    ),
    (
        LocalChargeCoulombTerm,
        _encode_local_charge_coulomb,
        _decode_local_charge_coulomb,
    ),
    (SplineThreeBodyTerm, _encode_spline_threebody, _decode_spline_threebody),
    (SplineTriplet2DTerm, _encode_spline_triplet2d, _decode_spline_triplet2d),
):
    register_term_schema_codec(_term_type, _encode, _decode)
# Keep the loop variables from leaking into the module namespace.
del _term_type, _encode, _decode
[docs]
def model_schema(model: UFPModel) -> dict[str, object]:
    """Describe ``model`` as plain data that ``model_from_schema`` can rebuild.

    The schema holds structure only: term classes, shapes, flags and the
    provider layout. Parameter values belong in the state dict.

    Raises:
        TypeError: if a term has no registered schema codec.
        ValueError: if a term references a provider not owned by ``model``.
        KeyError: presumably if ``model.terms`` contains a term absent from
            all four group lists — TODO confirm this cannot happen.
    """
    providers = tuple(model.alchemical_coefficients)
    # Map each term (by identity) to the group it was found in. Groups are
    # scanned in this order, so a term listed twice keeps the later group.
    group_of: dict[int, str] = {}
    for group_name, group_terms in (
        ("onebody", model.onebody_terms),
        ("pair", model.pair_terms),
        ("threebody", model.threebody_terms),
        ("other", model.other_terms),
    ):
        for term in group_terms:
            group_of[id(term)] = group_name
    atomic_types = model.atomic_types
    return {
        "schema": {"name": MODEL_SCHEMA_NAME, "version": MODEL_SCHEMA_VERSION},
        "model_class": type(model).__name__,
        "atomic_types": (
            None if atomic_types is None else [int(value) for value in atomic_types]
        ),
        "neighbor_backend": str(model.neighbor_backend.value),
        "alchemical_coefficients": [
            _provider_schema(provider, index=position)
            for position, provider in enumerate(providers)
        ],
        "terms": [
            _term_schema(term, group=group_of[id(term)], providers=providers)
            for term in model.terms
        ],
    }
def _provider_from_schema(
    entry: Mapping[str, object],
    *,
    dtype: torch.dtype,
) -> AlchemicalCoefficients:
    """Rebuild one provider with zero placeholder tensors of the recorded shapes.

    Real values are expected to be restored afterwards from the state dict.
    """
    proxy_shape = _schema_shape(entry["proxy_coeff_shape"])
    raw_weights_shape = entry.get("weights_shape")
    if raw_weights_shape is None:
        weights_placeholder = None
    else:
        weights_placeholder = torch.zeros(
            _schema_shape(raw_weights_shape), dtype=dtype
        )
    return AlchemicalCoefficients(
        proxy_coeffs=torch.zeros(proxy_shape, dtype=dtype),
        n_true_terms=_schema_int(entry["n_true_terms"]),
        weights=weights_placeholder,
        proxy_trainable=bool(entry.get("proxy_trainable", True)),
        weights_trainable=bool(entry.get("weights_trainable", True)),
    )
def _provider_ref(
    providers: Sequence[AlchemicalCoefficients],
    index: object,
) -> AlchemicalCoefficients | None:
    """Resolve a term entry's optional ``coefficient_provider`` index.

    Args:
        providers: Providers rebuilt for the model, in schema order.
        index: Provider index from the term entry, or ``None``.

    Returns:
        The referenced provider, or ``None`` when ``index`` is ``None``.

    Raises:
        IndexError: if the index is outside ``0 <= index < len(providers)``.
            Negative values are rejected explicitly. The encoder only writes
            non-negative positions, and Python's wrap-around indexing would
            otherwise silently bind the term to the wrong provider.
    """
    if index is None:
        return None
    position = _schema_int(index)
    if not 0 <= position < len(providers):
        raise IndexError(
            f"term references alchemical provider {position}, but the schema "
            f"defines {len(providers)} provider(s)"
        )
    return providers[position]
def _term_from_schema(
    entry: Mapping[str, object],
    *,
    providers: Sequence[AlchemicalCoefficients],
    dtype: torch.dtype,
):
    """Rebuild one term from its schema entry using the registered decoder.

    Raises:
        ValueError: if the entry's ``"class"`` has no registered codec.
    """
    class_name = str(entry["class"])
    try:
        codec = _TERM_SCHEMA_CODECS_BY_NAME[class_name]
    except KeyError as exc:
        raise ValueError(
            f"unsupported term class in model schema: {class_name}"
        ) from exc
    context = TermSchemaContext(
        providers=providers,
        dtype=dtype,
        atomic_types=_schema_atomic_types(entry.get("atomic_types")),
        provider=_provider_ref(providers, entry.get("coefficient_provider")),
        coefficient_index=_optional_schema_int(entry.get("coefficient_index")),
    )
    return codec.decode(entry, context)
[docs]
def model_from_schema(
    schema: Mapping[str, object],
    *,
    dtype: torch.dtype = torch.float64,
) -> UFPModel:
    """Rebuild a model skeleton from a schema produced by ``model_schema``.

    Parameters are zero placeholders (or provider-backed); load the matching
    state dict afterwards to restore values.

    Args:
        schema: Mapping with ``schema`` metadata, ``terms`` and optional
            ``alchemical_coefficients``, ``atomic_types``, ``neighbor_backend``.
        dtype: Dtype for all placeholder tensors.

    Raises:
        ValueError: if the schema metadata is missing, or the name/version is
            unsupported, or a term class is not registered.
    """
    header = schema.get("schema")
    if not isinstance(header, Mapping):
        raise ValueError("model schema is missing schema metadata")
    if header.get("name") != MODEL_SCHEMA_NAME:
        raise ValueError("unsupported model schema name")
    if int(header.get("version", 0)) != MODEL_SCHEMA_VERSION:
        raise ValueError("unsupported model schema version")

    provider_entries = _mapping_sequence(schema.get("alchemical_coefficients", ()))
    providers = tuple(
        _provider_from_schema(provider_entry, dtype=dtype)
        for provider_entry in provider_entries
    )
    terms = [
        _term_from_schema(term_entry, providers=providers, dtype=dtype)
        for term_entry in _mapping_sequence(schema["terms"])
    ]
    return UFPModel(
        terms=terms,
        atomic_types=_schema_atomic_types(schema.get("atomic_types")),
        neighbor_backend=str(schema.get("neighbor_backend", "auto")),
    )
[docs]
def load_model_from_checkpoint(
    checkpoint: Mapping[str, object],
    *,
    dtype: torch.dtype = torch.float64,
    strict: bool = True,
) -> UFPModel:
    """Rebuild the model described by ``checkpoint["model_schema"]`` and load its state.

    Args:
        checkpoint: Mapping such as the payload written by ``save_checkpoint``.
        dtype: Dtype for placeholder tensors created during reconstruction.
        strict: Forwarded to ``load_state_dict``.

    Raises:
        ValueError: if ``model_schema`` is missing or not a mapping, or if
            ``state_dict`` is missing or not a mapping. The schema is checked
            and the model built before ``state_dict`` is inspected.
    """
    if "model_schema" not in checkpoint:
        raise ValueError("checkpoint does not contain a `model_schema` entry")
    schema = checkpoint["model_schema"]
    if not isinstance(schema, Mapping):
        raise ValueError("checkpoint `model_schema` must be a mapping")
    rebuilt = model_from_schema(schema, dtype=dtype)
    saved_state = checkpoint.get("state_dict")
    if not isinstance(saved_state, Mapping):
        raise ValueError("checkpoint does not contain a state_dict mapping")
    rebuilt.load_state_dict(saved_state, strict=strict)
    return rebuilt
[docs]
def add_element_reference_term(
    interaction_model: UFPModel,
    atomic_type: int,
    onebody_energy: float,
    *,
    dtype: torch.dtype | None = None,
    trainable: bool = False,
) -> UFPModel:
    """Return a single-element model with a one-body reference term prepended.

    The new model contains a fresh ``ElementOneBodyTerm`` for ``atomic_type``
    followed by deep copies of ``interaction_model``'s pair and three-body
    terms. ``interaction_model`` itself is not modified.

    Args:
        interaction_model: Source model whose interaction terms are copied.
        atomic_type: The single atomic type of the returned model.
        onebody_energy: Reference value stored in the new one-body term.
        dtype: Dtype of the reference value; defaults to
            ``interaction_model.preferred_dtype()``.
        trainable: Whether the reference value is trainable.

    NOTE(review): existing one-body terms and ``other_terms`` of
    ``interaction_model`` are not carried over. Confirm this is intended.
    """
    resolved_dtype = interaction_model.preferred_dtype() if dtype is None else dtype
    # Deep-copy every carried-over term with ONE shared memo, so objects that
    # several terms reference (e.g. a shared AlchemicalCoefficients provider
    # addressed via ``coefficient_index``) are copied once and stay shared in
    # the new model, instead of being split into independent duplicates.
    memo: dict[int, object] = {}
    pair_terms = tuple(
        copy.deepcopy(term, memo) for term in interaction_model.pair_terms
    )
    threebody_terms = tuple(
        copy.deepcopy(term, memo) for term in interaction_model.threebody_terms
    )
    reference_term = ElementOneBodyTerm(
        atomic_types=[int(atomic_type)],
        values=torch.tensor([onebody_energy], dtype=resolved_dtype),
        trainable=trainable,
    )
    return UFPModel(
        onebody_terms=[reference_term],
        pair_terms=pair_terms,
        threebody_terms=threebody_terms,
        atomic_types=[int(atomic_type)],
        neighbor_backend=interaction_model.neighbor_backend,
    )
[docs]
def save_checkpoint(
    filename: Path,
    *,
    interaction_model: UFPModel,
    metadata: dict[str, object],
    onebody_energy: float | None = None,
) -> None:
    """Write ``interaction_model`` to ``filename`` using ``torch.save``.

    The payload maps ``state_dict``, ``model_schema`` (from ``model_schema``)
    and a shallow copy of ``metadata``. ``onebody_energy`` is added as a
    float when given. ``load_model_from_checkpoint`` reads this layout.
    """
    payload: dict[str, object] = {"state_dict": interaction_model.state_dict()}
    payload["model_schema"] = model_schema(interaction_model)
    payload["metadata"] = dict(metadata)
    if onebody_energy is not None:
        payload["onebody_energy"] = float(onebody_energy)
    torch.save(payload, filename)
[docs]
def demonstrate_calculator(
    model: UFPModel,
    frame: ase.Atoms,
) -> dict[str, object]:
    """Attach a ``UFPASECalculator`` to a copy of ``frame`` and evaluate it once.

    ``frame`` itself is not given a calculator. ``max_force`` is the largest
    absolute Cartesian force component, not the largest per-atom force norm.
    """
    probe = frame.copy()
    probe.calc = UFPASECalculator(model)
    forces = probe.get_forces()
    energy = float(probe.get_potential_energy())
    peak_component = float(np.abs(forces).max())
    return {"energy": energy, "forces": forces, "max_force": peak_component}
[docs]
def relax_structure(
    model: UFPModel,
    frame: ase.Atoms,
    *,
    fmax: float = 0.05,
    steps: int = 50,
) -> dict[str, object]:
    """Relax a copy of ``frame`` with ASE's BFGS optimizer driven by UFP forces.

    Args:
        model: Model evaluated through ``UFPASECalculator``.
        frame: Starting structure; it is copied, not modified.
        fmax: Force convergence threshold passed to ``BFGS.run``.
        steps: Maximum number of optimizer steps passed to ``BFGS.run``.

    Returns:
        ``atoms`` (the relaxed copy), its final ``energy`` and ``forces``,
        ``max_force`` (largest absolute Cartesian component), and ``steps``
        (``nsteps`` reported by the optimizer).

    NOTE(review): ASE's ``fmax`` criterion is believed to use per-atom force
    norms, so ``max_force`` is not directly comparable with it — confirm.
    """
    relaxed = frame.copy()
    relaxed.calc = UFPASECalculator(model)
    bfgs = BFGS(relaxed, logfile=None)
    bfgs.run(fmax=fmax, steps=steps)
    final_forces = relaxed.get_forces()
    final_energy = float(relaxed.get_potential_energy())
    return {
        "atoms": relaxed,
        "energy": final_energy,
        "forces": final_forces,
        "max_force": float(np.abs(final_forces).max()),
        "steps": bfgs.nsteps,
    }