"""Self-describing workflow checkpoint helpers."""
from __future__ import annotations
import hashlib
from collections.abc import Mapping, Sequence
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Any, cast
import torch
from torch.serialization import MAP_LOCATION
from ufp.leastsquares._layout import ParameterLayout, TermBlock
from ufp.leastsquares._selection import (
BlockSelector,
ResolvedCoefficientSelectionEntry,
SelectedCoefficientBlock,
resolve_coefficient_selection_summary,
)
from ufp.terms.contracts import UFPTerm
from ufp.terms.model import UFPModel
from ufp.version import __version__
# Version number written to `schema.version` of every payload; checkpoint
# validation in this module rejects any other value.
WORKFLOW_CHECKPOINT_SCHEMA_VERSION = 1

# Identifier written to `schema.name` of every payload.
WORKFLOW_CHECKPOINT_SCHEMA_NAME = "ufp.workflow_checkpoint"

# Top-level keys that must be present for a payload to pass validation.
_REQUIRED_FIELDS = {
    "schema",
    "package",
    "model",
    "term_metadata",
    "coefficient_layout",
    "selector_metadata",
    "fixed_coefficient_hashes",
    "stage_metadata",
    "projection_diagnostics",
    "validation_metrics",
    "state_dict",
}
[docs]
class WorkflowCheckpointError(ValueError):
    """Signal that a workflow checkpoint is missing or incompatible.

    Subclasses ``ValueError``, so handlers written for ``ValueError`` also
    receive this exception.
    """
def _callable_label(value: object) -> str:
"""Return a stable label for a callable or type."""
module = getattr(value, "__module__", type(value).__module__)
qualname = getattr(value, "__qualname__", type(value).__qualname__)
return f"{module}.{qualname}"
def _checkpoint_value(value: Any) -> Any:
"""Convert common result objects into plain checkpoint metadata values."""
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, torch.dtype):
return str(value)
if isinstance(value, torch.device):
return str(value)
if isinstance(value, Path):
return str(value)
if isinstance(value, torch.Tensor):
tensor = value.detach().cpu()
if tensor.ndim == 0:
return tensor.item()
return tensor.tolist()
if isinstance(value, type) or callable(value):
return _callable_label(value)
to_dict = getattr(value, "to_dict", None)
if callable(to_dict):
return _checkpoint_value(to_dict())
if is_dataclass(value) and not isinstance(value, type):
return _checkpoint_value(asdict(value))
if isinstance(value, Mapping):
return {str(key): _checkpoint_value(item) for key, item in value.items()}
if isinstance(value, tuple):
return [_checkpoint_value(item) for item in value]
if isinstance(value, list):
return [_checkpoint_value(item) for item in value]
if isinstance(value, set):
return sorted((_checkpoint_value(item) for item in value), key=str)
tolist = getattr(value, "tolist", None)
if callable(tolist):
return _checkpoint_value(tolist())
return repr(value)
def _dtype_name(dtype: torch.dtype) -> str:
"""Return a stable dtype name for metadata."""
return str(dtype)
def _shape_list(shape: Sequence[int]) -> list[int]:
"""Return a plain integer shape list."""
return [int(dim) for dim in shape]
def _tensor_metadata(tensor: torch.Tensor) -> dict[str, object]:
    """Describe one tensor by dtype name, shape list and gradient flag.

    ``requires_grad`` is read with a ``False`` fallback, so an object without
    that attribute is reported as not requiring gradients.
    """
    description: dict[str, object] = {}
    description["dtype"] = _dtype_name(tensor.dtype)
    description["shape"] = _shape_list(tuple(tensor.shape))
    description["requires_grad"] = bool(getattr(tensor, "requires_grad", False))
    return description
def _term_group_metadata(model: UFPModel) -> dict[int, tuple[str, int]]:
"""Map term object ids to their model group and group-local index."""
groups = (
("onebody", model.onebody_terms),
("pair", model.pair_terms),
("threebody", model.threebody_terms),
("other", model.other_terms),
)
metadata: dict[int, tuple[str, int]] = {}
for group_name, terms in groups:
for group_index, term in enumerate(terms):
metadata[id(term)] = (group_name, int(group_index))
return metadata
def _term_metadata(
    term: UFPTerm,
    *,
    index: int,
    group: str,
    group_index: int,
) -> dict[str, object]:
    """Describe one model term in plain, self-describing metadata.

    Args:
        term: Term to describe.
        index: Position of the term, stored under ``"index"``.
        group: Name of the group holding the term, stored under ``"group"``.
        group_index: Position inside that group, stored under
            ``"group_index"``.

    Returns:
        A dictionary with the term's module and class name, cutoff, atomic
        types, force flag, per-tensor state metadata (sorted by state name)
        and one entry per parameter block.
    """
    state = term.state_dict()
    cutoff = term.cutoff
    atomic_types = term.atomic_types

    def describe_provider(provider: object) -> dict[str, object] | None:
        # A block without a coefficient provider is recorded as None.
        if provider is None:
            return None
        provider_type = type(provider)
        return {
            "module": provider_type.__module__,
            "class": provider_type.__name__,
            "uses_identity_weights": bool(provider.uses_identity_weights),
        }

    def describe_block(block: object) -> dict[str, object]:
        provider = block.coefficient_provider
        record: dict[str, object] = {
            "name": block.name,
            "kind": block.kind,
            "label": block.label,
            "shape": _shape_list(block.shape),
            "regularization_group": block.regularization_group,
            "fittable": bool(block.fittable),
            "frozen": bool(block.frozen),
        }
        assembler = block.assembler
        record["assembler"] = str(assembler) if assembler is not None else None
        coefficient_index = block.coefficient_index
        record["coefficient_index"] = (
            int(coefficient_index) if coefficient_index is not None else None
        )
        record["coefficient_provider"] = describe_provider(provider)
        return record

    parameter_blocks = [describe_block(block) for block in term.parameter_blocks()]
    term_type = type(term)

    summary: dict[str, object] = {
        "index": int(index),
        "group": group,
        "group_index": int(group_index),
        "module": term_type.__module__,
        "class": term_type.__name__,
    }
    summary["cutoff"] = float(cutoff) if cutoff is not None else None
    summary["atomic_types"] = (
        [int(value) for value in atomic_types] if atomic_types is not None else None
    )
    summary["provides_forces"] = bool(term.provides_forces)
    summary["state"] = {
        name: _tensor_metadata(state[name].detach()) for name in sorted(state)
    }
    summary["parameter_blocks"] = parameter_blocks
    return summary
def _block_term_index(
block: TermBlock,
term_indices: Mapping[int, int],
) -> int:
"""Return the model-order term index that owns one layout block."""
return int(term_indices[id(block.term)])
def _slice_metadata(value: slice) -> dict[str, object]:
"""Return a plain representation of one coefficient slice."""
return {
"start": value.start,
"stop": value.stop,
"step": value.step,
}
def _compact_slice_metadata(value: slice | None) -> dict[str, int] | None:
"""Return a plain compact-solve slice representation."""
if value is None:
return None
return {"start": int(value.start), "stop": int(value.stop)}
def _indices_hash(indices: Sequence[int]) -> str:
"""Hash a sequence of coefficient indices."""
tensor = torch.tensor(tuple(int(index) for index in indices), dtype=torch.int64)
return hashlib.sha256(tensor.numpy().tobytes()).hexdigest()
def _selection_entry_metadata(
    entry: ResolvedCoefficientSelectionEntry,
) -> dict[str, object]:
    """Convert one resolved selector entry into plain checkpoint metadata.

    Index sequences are stored both as integer lists and as SHA-256 digests;
    selectors are stored as their ``repr`` (``None`` stays ``None``).
    """

    def as_int_list(values: Sequence[int]) -> list[int]:
        return [int(item) for item in values]

    channels = [
        as_int_list(channel) if channel is not None else None
        for channel in entry.channels
    ]
    coefficient_slices = [
        [_slice_metadata(component) for component in selector_slices]
        for selector_slices in entry.coefficient_slices
    ]
    selectors = [
        repr(selector) if selector is not None else None
        for selector in entry.selectors
    ]

    record: dict[str, object] = {}
    record["source"] = entry.source
    record["block_index"] = int(entry.block_index)
    record["block_label"] = entry.block_label
    record["size"] = int(entry.size)
    record["original_indices"] = as_int_list(entry.original_indices)
    record["original_indices_hash"] = _indices_hash(entry.original_indices)
    record["layout_indices"] = as_int_list(entry.layout_indices)
    record["layout_indices_hash"] = _indices_hash(entry.layout_indices)
    record["compact_slice"] = _compact_slice_metadata(entry.compact_slice)
    record["block_shape"] = _shape_list(entry.block_shape)
    record["channels"] = channels
    record["coefficient_slices"] = coefficient_slices
    record["selectors"] = selectors
    return record
def _hash_tensor_values(tensor: torch.Tensor) -> str:
"""Return a value digest for a CPU-contiguous tensor."""
contiguous = tensor.detach().cpu().contiguous()
hasher = hashlib.sha256()
hasher.update(str(contiguous.dtype).encode("utf8"))
hasher.update(torch.tensor(contiguous.shape, dtype=torch.int64).numpy().tobytes())
hasher.update(contiguous.numpy().tobytes())
return hasher.hexdigest()
def _fixed_indices_by_block(
layout: ParameterLayout,
selected_blocks: Sequence[SelectedCoefficientBlock],
) -> dict[int, tuple[int, ...]]:
"""Return fixed coefficient indices for each layout block."""
selected = {
int(selection.block.index): set(int(index) for index in selection.indices)
for selection in selected_blocks
}
fixed: dict[int, tuple[int, ...]] = {}
for block in layout.blocks:
selected_indices = selected.get(int(block.index), set())
indices = tuple(
int(index)
for index in range(block.size)
if int(index) not in selected_indices
)
if indices:
fixed[int(block.index)] = indices
return fixed
[docs]
def fixed_coefficient_hashes(
    layout: ParameterLayout,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
) -> dict[str, object]:
    """Hash the coefficient values that lie outside the selected solve layout.

    Args:
        layout: Parameter layout whose blocks are inspected.
        fit_blocks: Selectors forwarded to the selection resolver.
        freeze_blocks: Selectors forwarded to the selection resolver.

    Returns:
        A dictionary with the hash ``"algorithm"``, a combined
        ``"signature"`` digest and one ``"blocks"`` entry for each layout
        block that has at least one fixed coefficient.
    """
    summary = resolve_coefficient_selection_summary(
        layout,
        fit_blocks=fit_blocks,
        freeze_blocks=freeze_blocks,
    )
    fixed_by_block = _fixed_indices_by_block(layout, summary.selected_blocks)

    signature = hashlib.sha256()
    entries: list[dict[str, object]] = []
    for block in layout.blocks:
        block_index = int(block.index)
        positions = fixed_by_block.get(block_index, ())
        if not positions:
            # Every coefficient of this block is part of the solve.
            continue

        flat_values = block.read().detach().reshape(-1).cpu()
        selector = torch.tensor(positions, dtype=torch.int64)
        frozen_values = flat_values.index_select(0, selector).contiguous()
        values_digest = _hash_tensor_values(frozen_values)
        positions_digest = _indices_hash(positions)

        entries.append(
            {
                "block_index": block_index,
                "block_label": block.label,
                "fixed_indices": [int(position) for position in positions],
                "fixed_indices_hash": positions_digest,
                "dtype": _dtype_name(frozen_values.dtype),
                "shape": _shape_list(tuple(frozen_values.shape)),
                "values_hash": values_digest,
            }
        )
        # The signature chains block index, index digest and value digest of
        # each block, in layout order.
        for piece in (str(block_index), positions_digest, values_digest):
            signature.update(piece.encode("utf8"))

    return {
        "algorithm": "sha256",
        "signature": signature.hexdigest(),
        "blocks": entries,
    }
[docs]
def build_workflow_checkpoint(
    model: UFPModel,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
    stage_metadata: Mapping[str, object] | None = None,
    projection_diagnostics: object | None = None,
    validation_metrics: object | None = None,
    metadata: Mapping[str, object] | None = None,
) -> dict[str, object]:
    """Assemble the workflow checkpoint payload for ``model`` as a dictionary.

    Args:
        model: Model whose metadata and ``state_dict`` are recorded.
        fit_blocks: Selectors forwarded to the selector and fixed-hash
            metadata builders.
        freeze_blocks: Selectors forwarded to the same builders.
        stage_metadata: Optional stage description; ``None`` is stored as an
            empty dictionary.
        projection_diagnostics: Optional diagnostics, converted to plain
            values.
        validation_metrics: Optional metrics, converted to plain values.
        metadata: Optional extra metadata; ``None`` is stored as an empty
            dictionary.

    Returns:
        The checkpoint payload, including the model's ``state_dict``.
    """
    # Function-scope import kept from the original; presumably it avoids an
    # import cycle with the workflows package (confirm before moving it).
    from ufp.workflows.models import model_schema

    layout = ParameterLayout.from_model(model, include_frozen=True)
    model_info = model_metadata(model)
    selection = {"fit_blocks": fit_blocks, "freeze_blocks": freeze_blocks}

    payload: dict[str, object] = {}
    payload["schema"] = {
        "name": WORKFLOW_CHECKPOINT_SCHEMA_NAME,
        "version": WORKFLOW_CHECKPOINT_SCHEMA_VERSION,
    }
    payload["package"] = {"name": "ufp", "version": __version__}
    payload["model"] = model_info
    payload["model_schema"] = model_schema(model)
    payload["term_metadata"] = model_info["terms"]
    payload["coefficient_layout"] = coefficient_layout_metadata(layout)
    payload["selector_metadata"] = selector_metadata(layout, **selection)
    payload["fixed_coefficient_hashes"] = fixed_coefficient_hashes(layout, **selection)
    payload["stage_metadata"] = _checkpoint_value(
        stage_metadata if stage_metadata is not None else {}
    )
    payload["projection_diagnostics"] = _checkpoint_value(projection_diagnostics)
    payload["validation_metrics"] = _checkpoint_value(validation_metrics)
    payload["metadata"] = _checkpoint_value(metadata if metadata is not None else {})
    payload["state_dict"] = model.state_dict()
    return payload
def _torch_load(path: Path | str, *, map_location: MAP_LOCATION = "cpu") -> object:
"""Load a torch checkpoint with compatibility for older PyTorch releases."""
try:
return torch.load(path, map_location=map_location, weights_only=False)
except TypeError:
return torch.load(path, map_location=map_location)
[docs]
def save_workflow_checkpoint(
    path: Path | str,
    model: UFPModel,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
    stage_metadata: Mapping[str, object] | None = None,
    projection_diagnostics: object | None = None,
    validation_metrics: object | None = None,
    metadata: Mapping[str, object] | None = None,
) -> dict[str, object]:
    """Build a workflow checkpoint for ``model`` and write it to ``path``.

    All keyword arguments are forwarded unchanged to
    ``build_workflow_checkpoint``; the payload is written with
    ``torch.save``.

    Returns:
        The payload that was written.
    """
    build_options = {
        "fit_blocks": fit_blocks,
        "freeze_blocks": freeze_blocks,
        "stage_metadata": stage_metadata,
        "projection_diagnostics": projection_diagnostics,
        "validation_metrics": validation_metrics,
        "metadata": metadata,
    }
    checkpoint = build_workflow_checkpoint(model, **build_options)
    torch.save(checkpoint, path)
    return checkpoint
def _require_mapping(payload: object, *, name: str) -> Mapping[str, object]:
"""Validate that one payload section is a mapping."""
if not isinstance(payload, Mapping):
raise WorkflowCheckpointError(f"`{name}` must be a dictionary")
return payload
[docs]
def validate_workflow_checkpoint(payload: object) -> Mapping[str, object]:
    """Check the outer schema and required sections of a checkpoint payload.

    Checks run in this order: the payload is a mapping, every required field
    is present, the schema name and version match this module, the
    structural sections are mappings, and every ``state_dict`` value is a
    tensor.

    Returns:
        The payload itself, typed as a mapping.

    Raises:
        WorkflowCheckpointError: On the first failed check.
    """
    checkpoint = _require_mapping(payload, name="checkpoint")

    absent = [name for name in sorted(_REQUIRED_FIELDS) if name not in checkpoint]
    if absent:
        joined = ", ".join(f"`{name}`" for name in absent)
        raise WorkflowCheckpointError(
            f"workflow checkpoint missing required field(s): {joined}"
        )

    schema = _require_mapping(checkpoint["schema"], name="schema")
    found_name = schema.get("name")
    if found_name != WORKFLOW_CHECKPOINT_SCHEMA_NAME:
        raise WorkflowCheckpointError(
            f"unsupported workflow checkpoint schema {found_name!r}; "
            f"expected {WORKFLOW_CHECKPOINT_SCHEMA_NAME!r}"
        )
    found_version = schema.get("version")
    if found_version != WORKFLOW_CHECKPOINT_SCHEMA_VERSION:
        raise WorkflowCheckpointError(
            f"unsupported workflow checkpoint schema version {found_version!r}; "
            f"expected {WORKFLOW_CHECKPOINT_SCHEMA_VERSION}"
        )

    # Only these sections are required to be dictionaries; the remaining
    # required fields are checked for presence only.
    mapping_sections = (
        "package",
        "model",
        "coefficient_layout",
        "selector_metadata",
        "fixed_coefficient_hashes",
    )
    for section in mapping_sections:
        _require_mapping(checkpoint[section], name=section)

    state_dict = _require_mapping(checkpoint["state_dict"], name="state_dict")
    for entry in state_dict.values():
        if not isinstance(entry, torch.Tensor):
            raise WorkflowCheckpointError("`state_dict` values must be tensors")
    return checkpoint
def _metadata_mismatch(
expected: object,
actual: object,
*,
path: str,
) -> str | None:
"""Return the first mismatch path for two plain metadata values."""
if isinstance(expected, Mapping) and isinstance(actual, Mapping):
expected_keys = set(expected)
actual_keys = set(actual)
missing = sorted(expected_keys - actual_keys)
extra = sorted(actual_keys - expected_keys)
if missing:
return f"{path}.{missing[0]} missing from current model"
if extra:
return f"{path}.{extra[0]} not present in checkpoint"
for key in sorted(expected_keys):
mismatch = _metadata_mismatch(
expected[key],
actual[key],
path=f"{path}.{key}",
)
if mismatch is not None:
return mismatch
return None
if isinstance(expected, Sequence) and not isinstance(expected, (str, bytes)):
if not isinstance(actual, Sequence) or isinstance(actual, (str, bytes)):
return f"{path} type differs"
if len(expected) != len(actual):
return f"{path} length differs ({len(expected)} != {len(actual)})"
for index, (left, right) in enumerate(zip(expected, actual, strict=True)):
mismatch = _metadata_mismatch(left, right, path=f"{path}[{index}]")
if mismatch is not None:
return mismatch
return None
if expected != actual:
return f"{path} differs ({expected!r} != {actual!r})"
return None
def _fixed_block_by_index(
    fixed_hashes: Mapping[str, object],
) -> dict[int, Mapping[str, object]]:
    """Index the stored fixed-coefficient block entries by ``block_index``.

    Raises:
        WorkflowCheckpointError: If ``blocks`` is not a non-string sequence,
            an entry is not a mapping, or an entry lacks an integer
            ``block_index``.
    """
    blocks = fixed_hashes.get("blocks")
    is_list_like = isinstance(blocks, Sequence) and not isinstance(
        blocks, (str, bytes)
    )
    if not is_list_like:
        raise WorkflowCheckpointError(
            "`fixed_coefficient_hashes.blocks` must be a list"
        )

    indexed: dict[int, Mapping[str, object]] = {}
    for item in blocks:
        record = _require_mapping(item, name="fixed coefficient block")
        key = record.get("block_index")
        if isinstance(key, int):
            # A repeated block index keeps the last entry.
            indexed[int(key)] = record
            continue
        raise WorkflowCheckpointError(
            "fixed coefficient block metadata must include integer `block_index`"
        )
    return indexed
[docs]
def validate_fixed_coefficient_hashes(
    model: UFPModel,
    fixed_hashes: Mapping[str, object],
) -> None:
    """Validate stored fixed-coefficient hashes against the current model.

    For every layout block of ``model`` that has a stored entry, the values
    at the stored ``fixed_indices`` are re-hashed and compared with the
    stored ``values_hash``. Layout blocks without a stored entry are skipped.

    Args:
        model: Model whose current coefficient values are checked.
        fixed_hashes: The ``fixed_coefficient_hashes`` checkpoint section.

    Raises:
        WorkflowCheckpointError: If the algorithm is not ``"sha256"``, the
            stored metadata is malformed, a stored index does not fit the
            current block, or a recomputed hash differs from the stored one.
    """
    if fixed_hashes.get("algorithm") != "sha256":
        raise WorkflowCheckpointError(
            "`fixed_coefficient_hashes.algorithm` must be 'sha256'"
        )
    layout = ParameterLayout.from_model(model, include_frozen=True)
    stored = _fixed_block_by_index(fixed_hashes)
    for block in layout.blocks:
        block_index = int(block.index)
        stored_block = stored.get(block_index)
        if stored_block is None:
            continue
        raw_indices = stored_block.get("fixed_indices")
        if not isinstance(raw_indices, Sequence) or isinstance(
            raw_indices, (str, bytes)
        ):
            raise WorkflowCheckpointError(
                "fixed coefficient block metadata must include `fixed_indices`"
            )
        try:
            indices = tuple(int(index) for index in raw_indices)
        except (TypeError, ValueError) as error:
            raise WorkflowCheckpointError(
                "fixed coefficient block metadata has non-integer "
                f"`fixed_indices` for block {block.label!r} (index {block_index})"
            ) from error
        values = block.read().detach().reshape(-1).cpu()
        count = int(values.numel())
        # Report a layout mismatch as a checkpoint error instead of letting
        # `index_select` fail with a raw indexing error.
        if any(index < 0 or index >= count for index in indices):
            raise WorkflowCheckpointError(
                "fixed coefficient indices out of range for block "
                f"{block.label!r} (index {block_index})"
            )
        index = torch.tensor(indices, dtype=torch.int64)
        values_hash = _hash_tensor_values(values.index_select(0, index).contiguous())
        if values_hash != stored_block.get("values_hash"):
            raise WorkflowCheckpointError(
                "stale fixed coefficient hash for block "
                f"{block.label!r} (index {block_index})"
            )
[docs]
def load_workflow_checkpoint(
    path: Path | str,
    model: UFPModel | None = None,
    *,
    map_location: MAP_LOCATION = "cpu",
    load_state_dict: bool = True,
    strict: bool = True,
    validate_model: bool = True,
    validate_fixed_coefficients: bool = False,
) -> Mapping[str, object]:
    """Read a workflow checkpoint and optionally apply it to ``model``.

    Args:
        path: Location of the checkpoint file.
        model: Model to validate against and load into. With ``None`` the
            payload is only read and validated.
        map_location: Forwarded to the torch loader.
        load_state_dict: Whether to load the stored state into ``model``.
        strict: Forwarded to ``model.load_state_dict``.
        validate_model: Whether to compare the stored model metadata with
            ``model``.
        validate_fixed_coefficients: Whether to verify the stored
            fixed-coefficient hashes against ``model``. This runs before the
            stored state is loaded.

    Returns:
        The validated checkpoint payload.
    """
    raw_payload = _torch_load(path, map_location=map_location)
    payload = validate_workflow_checkpoint(raw_payload)
    if model is None:
        return payload

    if validate_model:
        stored_model = _require_mapping(payload["model"], name="model")
        validate_model_metadata(model, stored_model)
    if validate_fixed_coefficients:
        stored_hashes = _require_mapping(
            payload["fixed_coefficient_hashes"],
            name="fixed_coefficient_hashes",
        )
        validate_fixed_coefficient_hashes(model, stored_hashes)
    if load_state_dict:
        model.load_state_dict(
            cast(Mapping[str, Any], payload["state_dict"]),
            strict=strict,
        )
    return payload
# Names exported by `from <module> import *`.
# NOTE(review): several entries (e.g. `model_metadata`, `selector_metadata`,
# `coefficient_layout_metadata`, `validate_model_metadata`,
# `normalize_checkpoint_metadata`) are not defined in the portion of the file
# shown here; presumably they are defined elsewhere in this module -- confirm.
__all__ = [
"WORKFLOW_CHECKPOINT_SCHEMA_NAME",
"WORKFLOW_CHECKPOINT_SCHEMA_VERSION",
"WorkflowCheckpointError",
"build_workflow_checkpoint",
"coefficient_layout_metadata",
"fixed_coefficient_hashes",
"load_workflow_checkpoint",
"model_metadata",
"normalize_checkpoint_metadata",
"save_workflow_checkpoint",
"selector_metadata",
"validate_fixed_coefficient_hashes",
"validate_model_metadata",
"validate_workflow_checkpoint",
]