# Source code for ufp.workflows.checkpoints

"""Self-describing workflow checkpoint helpers."""

from __future__ import annotations

import hashlib
from collections.abc import Mapping, Sequence
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Any, cast

import torch
from torch.serialization import MAP_LOCATION

from ufp.leastsquares._layout import ParameterLayout, TermBlock
from ufp.leastsquares._selection import (
    BlockSelector,
    ResolvedCoefficientSelectionEntry,
    SelectedCoefficientBlock,
    resolve_coefficient_selection_summary,
)
from ufp.terms.contracts import UFPTerm
from ufp.terms.model import UFPModel
from ufp.version import __version__


# Identity of the payload format written by ``build_workflow_checkpoint`` and
# checked by ``validate_workflow_checkpoint``.
WORKFLOW_CHECKPOINT_SCHEMA_VERSION = 1
WORKFLOW_CHECKPOINT_SCHEMA_NAME = "ufp.workflow_checkpoint"

# Top-level keys every checkpoint payload must contain.  Sections that the
# builder also writes but that are not listed here (``model_schema`` and
# ``metadata``) are not required, so payloads lacking them still validate.
_REQUIRED_FIELDS = set(
    (
        "schema",
        "package",
        "model",
        "term_metadata",
        "coefficient_layout",
        "selector_metadata",
        "fixed_coefficient_hashes",
        "stage_metadata",
        "projection_diagnostics",
        "validation_metrics",
        "state_dict",
    )
)


class WorkflowCheckpointError(ValueError):
    """Error for a workflow checkpoint that is missing, malformed, or incompatible.

    Derives from ``ValueError``, so existing ``except ValueError`` handlers
    also catch it.
    """
def _callable_label(value: object) -> str:
    """Return a stable ``module.qualname`` label for a callable or type.

    When ``value`` has no ``__module__``/``__qualname__`` attribute of its own,
    the corresponding attribute of ``type(value)`` is used instead.
    """
    module = getattr(value, "__module__", type(value).__module__)
    qualname = getattr(value, "__qualname__", type(value).__qualname__)
    return f"{module}.{qualname}"


def _checkpoint_value(value: Any) -> Any:
    """Convert common result objects into plain checkpoint metadata values.

    Conversion is attempted in this order:

    * ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned as-is.
    * ``torch.dtype``, ``torch.device`` and ``Path`` become ``str(value)``.
    * Tensors are detached and moved to the CPU; 0-d tensors become a Python
      scalar via ``item()``, others a nested list via ``tolist()``.
    * Types and other callables become a ``module.qualname`` label.
    * Objects with a callable ``to_dict`` and dataclass instances are
      converted from their dictionary form.
    * Mappings become ``dict`` with ``str`` keys; tuples and lists become
      lists; sets and frozensets become lists sorted by ``str`` of the items.
    * Objects with a callable ``tolist`` (e.g. NumPy values) use it.
    * Anything else falls back to ``repr(value)``.

    Containers are converted recursively.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, torch.dtype):
        return str(value)
    if isinstance(value, torch.device):
        return str(value)
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, torch.Tensor):
        tensor = value.detach().cpu()
        if tensor.ndim == 0:
            return tensor.item()
        return tensor.tolist()
    if isinstance(value, type) or callable(value):
        return _callable_label(value)
    to_dict = getattr(value, "to_dict", None)
    if callable(to_dict):
        return _checkpoint_value(to_dict())
    if is_dataclass(value) and not isinstance(value, type):
        return _checkpoint_value(asdict(value))
    if isinstance(value, Mapping):
        return {str(key): _checkpoint_value(item) for key, item in value.items()}
    if isinstance(value, (tuple, list)):
        return [_checkpoint_value(item) for item in value]
    # ``frozenset`` is not a subclass of ``set``; handle both so frozensets are
    # serialized like sets instead of falling through to ``repr``.
    if isinstance(value, (set, frozenset)):
        return sorted((_checkpoint_value(item) for item in value), key=str)
    tolist = getattr(value, "tolist", None)
    if callable(tolist):
        return _checkpoint_value(tolist())
    return repr(value)
def normalize_checkpoint_metadata(value: Any) -> Any:
    """Normalize ``value`` into plain workflow-checkpoint metadata.

    Public entry point for the same conversion the checkpoint builder applies
    to stage metadata, projection diagnostics and validation metrics.
    """
    normalized = _checkpoint_value(value)
    return normalized
def _dtype_name(dtype: torch.dtype) -> str:
    """Return a stable dtype name for metadata."""
    return str(dtype)


def _shape_list(shape: Sequence[int]) -> list[int]:
    """Return a plain integer shape list."""
    return [int(dim) for dim in shape]


def _tensor_metadata(tensor: torch.Tensor) -> dict[str, object]:
    """Return dtype, shape and ``requires_grad`` metadata for one tensor."""
    return {
        "dtype": _dtype_name(tensor.dtype),
        "shape": _shape_list(tuple(tensor.shape)),
        "requires_grad": bool(getattr(tensor, "requires_grad", False)),
    }


def _term_group_metadata(model: UFPModel) -> dict[int, tuple[str, int]]:
    """Map term object ids to their model group and group-local index.

    Groups are scanned in the order onebody, pair, threebody, other.
    """
    groups = (
        ("onebody", model.onebody_terms),
        ("pair", model.pair_terms),
        ("threebody", model.threebody_terms),
        ("other", model.other_terms),
    )
    metadata: dict[int, tuple[str, int]] = {}
    for group_name, terms in groups:
        for group_index, term in enumerate(terms):
            metadata[id(term)] = (group_name, int(group_index))
    return metadata


def _assembler_label(assembler: object) -> str | None:
    """Return a process-independent label for a parameter-block assembler.

    ``str()`` of a function, lambda or bound method embeds its memory address
    (``<function f at 0x...>``), which differs between processes.  Because
    model metadata is compared by equality when a checkpoint is validated, such
    labels would reject an otherwise identical model.  Callables (other than
    types) are therefore labelled by ``module.qualname``.  ``None`` stays
    ``None``, and types and non-callable values keep their ``str()`` rendering
    so previously written metadata for them still compares equal.
    """
    if assembler is None:
        return None
    if callable(assembler) and not isinstance(assembler, type):
        return _callable_label(assembler)
    return str(assembler)


def _term_metadata(
    term: UFPTerm,
    *,
    index: int,
    group: str,
    group_index: int,
) -> dict[str, object]:
    """Return self-describing metadata for one model term.

    Records the term's class, cutoff, atomic types, force support, the dtype
    and shape (not values) of every ``state_dict`` entry, and a description of
    each block returned by ``term.parameter_blocks()``.
    """
    state = term.state_dict()
    cutoff = term.cutoff
    atomic_types = term.atomic_types
    parameter_blocks = []
    for block in term.parameter_blocks():
        provider = block.coefficient_provider
        parameter_blocks.append(
            {
                "name": block.name,
                "kind": block.kind,
                "label": block.label,
                "shape": _shape_list(block.shape),
                "regularization_group": block.regularization_group,
                "fittable": bool(block.fittable),
                "frozen": bool(block.frozen),
                "assembler": _assembler_label(block.assembler),
                "coefficient_index": (
                    None
                    if block.coefficient_index is None
                    else int(block.coefficient_index)
                ),
                "coefficient_provider": None
                if provider is None
                else {
                    "module": type(provider).__module__,
                    "class": type(provider).__name__,
                    "uses_identity_weights": bool(provider.uses_identity_weights),
                },
            }
        )
    return {
        "index": int(index),
        "group": group,
        "group_index": int(group_index),
        "module": type(term).__module__,
        "class": type(term).__name__,
        "cutoff": None if cutoff is None else float(cutoff),
        "atomic_types": None
        if atomic_types is None
        else [int(value) for value in atomic_types],
        "provides_forces": bool(term.provides_forces),
        "state": {
            name: _tensor_metadata(tensor.detach())
            for name, tensor in sorted(state.items())
        },
        "parameter_blocks": parameter_blocks,
    }
def model_metadata(model: UFPModel) -> dict[str, object]:
    """Describe ``model``'s architecture and terms without parameter values.

    Terms are listed in ``model.terms`` order, each tagged with its group and
    group-local index.  The model-level atomic types, neighbor backend and
    cutoff are recorded alongside the term count.
    """
    group_lookup = _term_group_metadata(model)
    term_entries: list[dict[str, object]] = []
    for position, term in enumerate(model.terms):
        group_name, index_in_group = group_lookup[id(term)]
        entry = _term_metadata(
            term,
            index=int(position),
            group=group_name,
            group_index=index_in_group,
        )
        term_entries.append(entry)

    model_type = type(model)
    atomic_types = model.atomic_types
    atomic_list = None if atomic_types is None else [int(z) for z in atomic_types]
    backend_name = str(model.neighbor_backend.value)
    cutoff = model.cutoff
    cutoff_value = None if cutoff is None else float(cutoff)
    return {
        "module": model_type.__module__,
        "class": model_type.__name__,
        "atomic_types": atomic_list,
        "neighbor_backend": backend_name,
        "cutoff": cutoff_value,
        "n_terms": len(term_entries),
        "terms": term_entries,
    }
def _block_term_index(
    block: TermBlock,
    term_indices: Mapping[int, int],
) -> int:
    """Look up the model-order position of the term that owns ``block``.

    ``term_indices`` is keyed by ``id(term)``.  A block whose term is not in
    the mapping raises ``KeyError``.
    """
    owner_key = id(block.term)
    position = term_indices[owner_key]
    return int(position)
def coefficient_layout_metadata(layout: ParameterLayout) -> dict[str, object]:
    """Describe a model coefficient layout for a workflow checkpoint.

    Each layout block is recorded with the position of its owning term in
    ``layout.model.terms``, its ``start``/``stop``/``size`` span, and its
    fit/freeze flags.  For a block's coefficient provider only its presence is
    stored (as a bool).  Each provider group is recorded with its block
    indices, coefficient shape and term counts.
    """
    position_of_term: dict[int, int] = {}
    for position, term in enumerate(layout.model.terms):
        position_of_term[id(term)] = position

    def describe_block(block: TermBlock) -> dict[str, object]:
        # One entry per layout block; keys mirror the block's attributes.
        return {
            "index": int(block.index),
            "term_index": _block_term_index(block, position_of_term),
            "name": block.name,
            "kind": block.kind,
            "label": block.label,
            "shape": _shape_list(block.shape),
            "start": int(block.start),
            "stop": int(block.stop),
            "size": int(block.size),
            "regularization_group": block.regularization_group,
            "fittable": bool(block.fittable),
            "frozen": bool(block.frozen),
            "coefficient_index": block.coefficient_index,
            "coefficient_provider": block.coefficient_provider is not None,
        }

    def describe_provider(group: Any) -> dict[str, object]:
        # One entry per provider group in ``layout.providers``.
        return {
            "block_indices": [int(member) for member in group.block_indices],
            "coefficient_shape": _shape_list(group.coefficient_shape),
            "n_true_terms": int(group.n_true_terms),
            "n_proxy_terms": int(group.n_proxy_terms),
            "uses_identity_weights": bool(group.uses_identity_weights),
        }

    return {
        "size": int(layout.size),
        "blocks": list(map(describe_block, layout.blocks)),
        "providers": list(map(describe_provider, layout.providers)),
    }
def _slice_metadata(value: slice) -> dict[str, object]:
    """Describe one coefficient slice as a plain ``start``/``stop``/``step`` dict."""
    return dict(start=value.start, stop=value.stop, step=value.step)


def _compact_slice_metadata(value: slice | None) -> dict[str, int] | None:
    """Describe a compact-solve slice with integer bounds, or ``None`` if absent."""
    if value is not None:
        return {"start": int(value.start), "stop": int(value.stop)}
    return None


def _indices_hash(indices: Sequence[int]) -> str:
    """Return the SHA-256 hex digest of ``indices`` packed as native int64 values."""
    packed = torch.tensor([int(i) for i in indices], dtype=torch.int64)
    raw = packed.numpy().tobytes()
    return hashlib.sha256(raw).hexdigest()


def _selection_entry_metadata(
    entry: ResolvedCoefficientSelectionEntry,
) -> dict[str, object]:
    """Serialize one resolved selector entry for a workflow checkpoint.

    Index lists are stored both verbatim and as SHA-256 digests.  Selectors are
    stored as their ``repr``, and slices as plain dictionaries.
    """
    metadata: dict[str, object] = {}
    metadata["source"] = entry.source
    metadata["block_index"] = int(entry.block_index)
    metadata["block_label"] = entry.block_label
    metadata["size"] = int(entry.size)
    metadata["original_indices"] = [int(i) for i in entry.original_indices]
    metadata["original_indices_hash"] = _indices_hash(entry.original_indices)
    metadata["layout_indices"] = [int(i) for i in entry.layout_indices]
    metadata["layout_indices_hash"] = _indices_hash(entry.layout_indices)
    metadata["compact_slice"] = _compact_slice_metadata(entry.compact_slice)
    metadata["block_shape"] = _shape_list(entry.block_shape)

    channels: list[list[int] | None] = []
    for channel in entry.channels:
        channels.append(None if channel is None else [int(c) for c in channel])
    metadata["channels"] = channels

    metadata["coefficient_slices"] = [
        list(map(_slice_metadata, per_selector))
        for per_selector in entry.coefficient_slices
    ]
    metadata["selectors"] = [
        repr(selector) if selector is not None else None
        for selector in entry.selectors
    ]
    return metadata
def selector_metadata(
    layout: ParameterLayout,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
) -> dict[str, object]:
    """Resolve fit/freeze selectors against ``layout`` and serialize the result.

    Selectors are stored as their ``repr``; ``fit_blocks`` stays ``None`` when
    not supplied.  The resolved selection contributes the selected block
    indices and one metadata entry per resolved selector entry.
    """
    summary = resolve_coefficient_selection_summary(
        layout,
        fit_blocks=fit_blocks,
        freeze_blocks=freeze_blocks,
    )
    fit_reprs = None if fit_blocks is None else list(map(repr, fit_blocks))
    freeze_reprs = list(map(repr, freeze_blocks))
    chosen_indices = [int(choice.block.index) for choice in summary.selected_blocks]
    entry_records = list(map(_selection_entry_metadata, summary.entries))
    return {
        "fit_blocks": fit_reprs,
        "freeze_blocks": freeze_reprs,
        "selected_block_indices": chosen_indices,
        "entries": entry_records,
    }
def _tensor_bytes(tensor: torch.Tensor) -> bytes:
    """Return the raw native-endian bytes of a contiguous CPU tensor."""
    try:
        return tensor.numpy().tobytes()
    except TypeError:
        # ``Tensor.numpy`` rejects dtypes with no NumPy equivalent (e.g.
        # bfloat16).  Reinterpret the same storage as bytes instead; for dtypes
        # NumPy supports, the ``numpy()`` path above yields identical bytes.
        return tensor.reshape(-1).view(torch.uint8).numpy().tobytes()


def _hash_tensor_values(tensor: torch.Tensor) -> str:
    """Return a SHA-256 digest of a tensor's dtype, shape and raw values.

    The tensor is detached, moved to the CPU and made contiguous before
    hashing.  The digest covers ``str(dtype)``, the shape as int64 values, and
    the element bytes.
    """
    contiguous = tensor.detach().cpu().contiguous()
    hasher = hashlib.sha256()
    hasher.update(str(contiguous.dtype).encode("utf8"))
    hasher.update(torch.tensor(contiguous.shape, dtype=torch.int64).numpy().tobytes())
    hasher.update(_tensor_bytes(contiguous))
    return hasher.hexdigest()


def _fixed_indices_by_block(
    layout: ParameterLayout,
    selected_blocks: Sequence[SelectedCoefficientBlock],
) -> dict[int, tuple[int, ...]]:
    """Return, per layout block, the flat indices that are not selected.

    Blocks whose coefficients are all selected are omitted from the result.
    """
    selected = {
        int(selection.block.index): {int(index) for index in selection.indices}
        for selection in selected_blocks
    }
    fixed: dict[int, tuple[int, ...]] = {}
    for block in layout.blocks:
        selected_indices = selected.get(int(block.index), set())
        indices = tuple(
            index for index in range(int(block.size)) if index not in selected_indices
        )
        if indices:
            fixed[int(block.index)] = indices
    return fixed
def fixed_coefficient_hashes(
    layout: ParameterLayout,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
) -> dict[str, object]:
    """Hash the coefficients that the resolved fit/freeze selection leaves fixed.

    For every layout block with unselected coefficients, the values at those
    flat indices are read from the block and hashed together with the index
    list.  A combined ``signature`` digest is built from each block's index,
    indices hash and values hash, in layout order.
    """
    summary = resolve_coefficient_selection_summary(
        layout,
        fit_blocks=fit_blocks,
        freeze_blocks=freeze_blocks,
    )
    fixed_by_block = _fixed_indices_by_block(layout, summary.selected_blocks)
    signature = hashlib.sha256()
    records: list[dict[str, object]] = []
    for block in layout.blocks:
        block_id = int(block.index)
        fixed = fixed_by_block.get(block_id, ())
        if not fixed:
            continue
        flat = block.read().detach().reshape(-1).cpu()
        positions = torch.tensor(fixed, dtype=torch.int64)
        picked = flat.index_select(0, positions).contiguous()
        values_digest = _hash_tensor_values(picked)
        indices_digest = _indices_hash(fixed)
        records.append(
            {
                "block_index": block_id,
                "block_label": block.label,
                "fixed_indices": [int(i) for i in fixed],
                "fixed_indices_hash": indices_digest,
                "dtype": _dtype_name(picked.dtype),
                "shape": _shape_list(tuple(picked.shape)),
                "values_hash": values_digest,
            }
        )
        # Fold this block into the signature: index, then indices, then values.
        for chunk in (str(block_id), indices_digest, values_digest):
            signature.update(chunk.encode("utf8"))
    return {
        "algorithm": "sha256",
        "signature": signature.hexdigest(),
        "blocks": records,
    }
def build_workflow_checkpoint(
    model: UFPModel,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
    stage_metadata: Mapping[str, object] | None = None,
    projection_diagnostics: object | None = None,
    validation_metrics: object | None = None,
    metadata: Mapping[str, object] | None = None,
) -> dict[str, object]:
    """Assemble a plain-dictionary workflow checkpoint payload for ``model``.

    The payload records the schema and package versions, model and term
    metadata, the coefficient layout (built with ``include_frozen=True``), the
    resolved selector metadata, fixed-coefficient hashes, normalized caller
    metadata, and ``model.state_dict()``.
    """
    # Imported at call time rather than module level (presumably to avoid a
    # circular import between workflow modules).
    from ufp.workflows.models import model_schema

    layout = ParameterLayout.from_model(model, include_frozen=True)
    model_info = model_metadata(model)

    payload: dict[str, object] = {}
    payload["schema"] = {
        "name": WORKFLOW_CHECKPOINT_SCHEMA_NAME,
        "version": WORKFLOW_CHECKPOINT_SCHEMA_VERSION,
    }
    payload["package"] = {"name": "ufp", "version": __version__}
    payload["model"] = model_info
    payload["model_schema"] = model_schema(model)
    payload["term_metadata"] = model_info["terms"]
    payload["coefficient_layout"] = coefficient_layout_metadata(layout)
    payload["selector_metadata"] = selector_metadata(
        layout,
        fit_blocks=fit_blocks,
        freeze_blocks=freeze_blocks,
    )
    payload["fixed_coefficient_hashes"] = fixed_coefficient_hashes(
        layout,
        fit_blocks=fit_blocks,
        freeze_blocks=freeze_blocks,
    )
    payload["stage_metadata"] = _checkpoint_value(
        stage_metadata if stage_metadata is not None else {}
    )
    payload["projection_diagnostics"] = _checkpoint_value(projection_diagnostics)
    payload["validation_metrics"] = _checkpoint_value(validation_metrics)
    payload["metadata"] = _checkpoint_value(metadata if metadata is not None else {})
    payload["state_dict"] = model.state_dict()
    return payload
def _torch_load(path: Path | str, *, map_location: MAP_LOCATION = "cpu") -> object:
    """Read a torch-serialized object from ``path``.

    ``weights_only=False`` is passed explicitly so the full pickled payload is
    restored regardless of the installed PyTorch's default.  If ``torch.load``
    raises ``TypeError`` (as releases without a ``weights_only`` parameter do),
    the call is retried without that argument.

    NOTE(security): with ``weights_only=False`` ``torch.load`` unpickles
    arbitrary objects, so loading an untrusted file can execute arbitrary code.
    Only load checkpoints from trusted sources.
    NOTE(review): any other ``TypeError`` raised by the first call also
    triggers the retry.
    """
    load_kwargs = {"map_location": map_location}
    try:
        return torch.load(path, weights_only=False, **load_kwargs)
    except TypeError:
        return torch.load(path, **load_kwargs)
def save_workflow_checkpoint(
    path: Path | str,
    model: UFPModel,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
    stage_metadata: Mapping[str, object] | None = None,
    projection_diagnostics: object | None = None,
    validation_metrics: object | None = None,
    metadata: Mapping[str, object] | None = None,
) -> dict[str, object]:
    """Build a workflow checkpoint for ``model``, write it with ``torch.save``, and return it.

    All keyword arguments are forwarded unchanged to
    ``build_workflow_checkpoint``.
    """
    options = {
        "fit_blocks": fit_blocks,
        "freeze_blocks": freeze_blocks,
        "stage_metadata": stage_metadata,
        "projection_diagnostics": projection_diagnostics,
        "validation_metrics": validation_metrics,
        "metadata": metadata,
    }
    checkpoint = build_workflow_checkpoint(model, **options)
    torch.save(checkpoint, path)
    return checkpoint
def _require_mapping(payload: object, *, name: str) -> Mapping[str, object]:
    """Return ``payload`` unchanged if it is a mapping.

    Raises:
        WorkflowCheckpointError: if ``payload`` is not a ``Mapping``; ``name``
            identifies the offending section in the message.
    """
    if isinstance(payload, Mapping):
        return payload
    raise WorkflowCheckpointError(f"`{name}` must be a dictionary")
def validate_workflow_checkpoint(payload: object) -> Mapping[str, object]:
    """Check a checkpoint payload's outer schema and required sections.

    Returns:
        The payload itself, once validated.

    Raises:
        WorkflowCheckpointError: if the payload or a required section is not a
            mapping, a required field is missing, the schema name or version
            does not match this module's, or a ``state_dict`` value is not a
            tensor.
    """
    checkpoint = _require_mapping(payload, name="checkpoint")

    absent = [field for field in sorted(_REQUIRED_FIELDS) if field not in checkpoint]
    if absent:
        listed = ", ".join(f"`{field}`" for field in absent)
        raise WorkflowCheckpointError(
            f"workflow checkpoint missing required field(s): {listed}"
        )

    schema = _require_mapping(checkpoint["schema"], name="schema")
    found_name = schema.get("name")
    if found_name != WORKFLOW_CHECKPOINT_SCHEMA_NAME:
        raise WorkflowCheckpointError(
            "unsupported workflow checkpoint schema "
            f"{found_name!r}; expected {WORKFLOW_CHECKPOINT_SCHEMA_NAME!r}"
        )
    found_version = schema.get("version")
    if found_version != WORKFLOW_CHECKPOINT_SCHEMA_VERSION:
        raise WorkflowCheckpointError(
            "unsupported workflow checkpoint schema version "
            f"{found_version!r}; expected {WORKFLOW_CHECKPOINT_SCHEMA_VERSION}"
        )

    for section in (
        "package",
        "model",
        "coefficient_layout",
        "selector_metadata",
        "fixed_coefficient_hashes",
    ):
        _require_mapping(checkpoint[section], name=section)

    state_dict = _require_mapping(checkpoint["state_dict"], name="state_dict")
    if any(not isinstance(item, torch.Tensor) for item in state_dict.values()):
        raise WorkflowCheckpointError("`state_dict` values must be tensors")
    return checkpoint
def _metadata_mismatch( expected: object, actual: object, *, path: str, ) -> str | None: """Return the first mismatch path for two plain metadata values.""" if isinstance(expected, Mapping) and isinstance(actual, Mapping): expected_keys = set(expected) actual_keys = set(actual) missing = sorted(expected_keys - actual_keys) extra = sorted(actual_keys - expected_keys) if missing: return f"{path}.{missing[0]} missing from current model" if extra: return f"{path}.{extra[0]} not present in checkpoint" for key in sorted(expected_keys): mismatch = _metadata_mismatch( expected[key], actual[key], path=f"{path}.{key}", ) if mismatch is not None: return mismatch return None if isinstance(expected, Sequence) and not isinstance(expected, (str, bytes)): if not isinstance(actual, Sequence) or isinstance(actual, (str, bytes)): return f"{path} type differs" if len(expected) != len(actual): return f"{path} length differs ({len(expected)} != {len(actual)})" for index, (left, right) in enumerate(zip(expected, actual, strict=True)): mismatch = _metadata_mismatch(left, right, path=f"{path}[{index}]") if mismatch is not None: return mismatch return None if expected != actual: return f"{path} differs ({expected!r} != {actual!r})" return None
def validate_model_metadata(
    model: UFPModel,
    checkpoint_model_metadata: Mapping[str, object],
) -> None:
    """Check that ``model`` matches the model metadata stored in a checkpoint.

    Raises:
        WorkflowCheckpointError: naming the first differing path (rooted at
            ``model``) between the stored metadata and ``model_metadata(model)``.
    """
    difference = _metadata_mismatch(
        checkpoint_model_metadata,
        model_metadata(model),
        path="model",
    )
    if difference is None:
        return
    raise WorkflowCheckpointError(f"incompatible model metadata: {difference}")
def _fixed_block_by_index(
    fixed_hashes: Mapping[str, object],
) -> dict[int, Mapping[str, object]]:
    """Index the stored fixed-coefficient block entries by ``block_index``.

    Raises:
        WorkflowCheckpointError: if ``blocks`` is not a non-string sequence, an
            entry is not a mapping, or an entry lacks an integer
            ``block_index``.
    """
    raw_blocks = fixed_hashes.get("blocks")
    is_list_like = isinstance(raw_blocks, Sequence) and not isinstance(
        raw_blocks, (str, bytes)
    )
    if not is_list_like:
        raise WorkflowCheckpointError(
            "`fixed_coefficient_hashes.blocks` must be a list"
        )
    indexed: dict[int, Mapping[str, object]] = {}
    for raw in raw_blocks:
        entry = _require_mapping(raw, name="fixed coefficient block")
        key = entry.get("block_index")
        if not isinstance(key, int):
            raise WorkflowCheckpointError(
                "fixed coefficient block metadata must include integer `block_index`"
            )
        indexed[int(key)] = entry
    return indexed
def validate_fixed_coefficient_hashes(
    model: UFPModel,
    fixed_hashes: Mapping[str, object],
) -> None:
    """Validate stored fixed-coefficient hashes against the current model.

    The current layout is built with ``include_frozen=True``.  For every stored
    block entry, the coefficients at its ``fixed_indices`` are read from the
    matching layout block and re-hashed, and the digest must equal the stored
    ``values_hash``.

    Raises:
        WorkflowCheckpointError: if ``algorithm`` is not ``"sha256"``, the
            stored metadata is malformed, a stored block index does not exist
            in the current layout, a stored fixed index is out of range for its
            block, or a recomputed value hash differs from the stored one.
    """
    if fixed_hashes.get("algorithm") != "sha256":
        raise WorkflowCheckpointError(
            "`fixed_coefficient_hashes.algorithm` must be 'sha256'"
        )
    layout = ParameterLayout.from_model(model, include_frozen=True)
    stored = _fixed_block_by_index(fixed_hashes)

    # Stored entries for blocks the current layout no longer has would
    # otherwise be skipped silently, letting a stale checkpoint pass.
    current_indices = {int(block.index) for block in layout.blocks}
    unknown = sorted(set(stored) - current_indices)
    if unknown:
        raise WorkflowCheckpointError(
            "fixed coefficient hashes reference block index(es) "
            f"{unknown} absent from the current model layout"
        )

    for block in layout.blocks:
        stored_block = stored.get(int(block.index))
        if stored_block is None:
            continue
        raw_indices = stored_block.get("fixed_indices")
        if not isinstance(raw_indices, Sequence) or isinstance(
            raw_indices, (str, bytes)
        ):
            raise WorkflowCheckpointError(
                "fixed coefficient block metadata must include `fixed_indices`"
            )
        indices = tuple(int(index) for index in raw_indices)
        values = block.read().detach().reshape(-1).cpu()
        n_values = int(values.numel())
        # Report a resized block as a checkpoint error rather than letting
        # ``index_select`` fail with a raw RuntimeError.
        if any(index < 0 or index >= n_values for index in indices):
            raise WorkflowCheckpointError(
                "fixed coefficient indices out of range for block "
                f"{block.label!r} (index {int(block.index)}, size {n_values})"
            )
        index = torch.tensor(indices, dtype=torch.int64)
        values_hash = _hash_tensor_values(values.index_select(0, index).contiguous())
        if values_hash != stored_block.get("values_hash"):
            raise WorkflowCheckpointError(
                "stale fixed coefficient hash for block "
                f"{block.label!r} (index {int(block.index)})"
            )
def load_workflow_checkpoint(
    path: Path | str,
    model: UFPModel | None = None,
    *,
    map_location: MAP_LOCATION = "cpu",
    load_state_dict: bool = True,
    strict: bool = True,
    validate_model: bool = True,
    validate_fixed_coefficients: bool = False,
) -> Mapping[str, object]:
    """Load a workflow checkpoint and optionally check and apply it to ``model``.

    The payload's outer schema is always validated.  When ``model`` is given:

    * ``validate_model`` compares the stored model metadata with ``model``;
    * ``validate_fixed_coefficients`` re-checks the fixed-coefficient hashes;
    * ``load_state_dict`` loads the stored state dict with ``strict``.

    Returns:
        The validated payload.

    NOTE(security): the file is read through ``_torch_load``, which passes
    ``weights_only=False`` to ``torch.load``; only load trusted checkpoints.
    """
    raw = _torch_load(path, map_location=map_location)
    payload = validate_workflow_checkpoint(raw)
    if model is None:
        return payload

    if validate_model:
        model_section = _require_mapping(payload["model"], name="model")
        validate_model_metadata(model, model_section)
    if validate_fixed_coefficients:
        hashes_section = _require_mapping(
            payload["fixed_coefficient_hashes"],
            name="fixed_coefficient_hashes",
        )
        validate_fixed_coefficient_hashes(model, hashes_section)
    if load_state_dict:
        stored_state = cast(Mapping[str, Any], payload["state_dict"])
        model.load_state_dict(stored_state, strict=strict)
    return payload
# Public API exported by ``from ufp.workflows.checkpoints import *``.
__all__ = [
    "WORKFLOW_CHECKPOINT_SCHEMA_NAME",
    "WORKFLOW_CHECKPOINT_SCHEMA_VERSION",
    "WorkflowCheckpointError",
    "build_workflow_checkpoint",
    "coefficient_layout_metadata",
    "fixed_coefficient_hashes",
    "load_workflow_checkpoint",
    "model_metadata",
    "normalize_checkpoint_metadata",
    "save_workflow_checkpoint",
    "selector_metadata",
    "validate_fixed_coefficient_hashes",
    "validate_model_metadata",
    "validate_workflow_checkpoint",
]