Source code for ufp.workflows.stages

"""Explicit workflow stage objects for composing existing UFP helpers."""

from __future__ import annotations

from collections.abc import Callable, Mapping, MutableMapping, Sequence
from dataclasses import dataclass, field
from typing import Any, Protocol, cast

import torch

from ufp.leastsquares import LinearFitter
from ufp.terms import UFPModel
from ufp.training import evaluate_model, fit_model, freeze_model_coefficients
from ufp.workflows.checkpoints import normalize_checkpoint_metadata
from ufp.workflows.residuals import materialize_residual_dataset


# Schema version written under the ``"version"`` key of both the per-stage
# metadata built by ``_stage_metadata`` and the workflow-level record built
# by ``workflow_stage_metadata``.
STAGE_METADATA_VERSION = 1


@dataclass(frozen=True)
class StageResult:
    """Immutable record of what a single workflow stage produced.

    Instances are frozen dataclasses: the two fields are fixed at
    construction time.
    """

    # Context entries the stage publishes, keyed by context key.
    outputs: Mapping[str, object]
    # Metadata describing the stage that produced these outputs.
    metadata: Mapping[str, object]

    def update_context(
        self,
        context: MutableMapping[str, object],
    ) -> MutableMapping[str, object]:
        """Merge this result's ``outputs`` into ``context`` in place.

        Keys already present in ``context`` are overwritten by the values
        from ``outputs``.

        Args:
            context: Caller-owned mutable mapping to update.

        Returns:
            The very same ``context`` object after the merge, so calls can
            be chained.
        """
        target = context
        target.update(self.outputs)
        return target
class WorkflowStage(Protocol):
    """Structural interface shared by the lightweight stages in this module.

    Any object that exposes the three read-only properties below and a
    ``run`` method satisfies the protocol; no inheritance is required.
    """

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Keys the stage reads from the context."""
        ...

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """Keys the stage publishes as outputs."""
        ...

    @property
    def metadata(self) -> Mapping[str, object]:
        """JSON-friendly description of the stage configuration."""
        ...

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Execute the stage against ``context`` and return its result."""
        ...
def _stage_metadata(
    *,
    stage_type: str,
    name: str,
    required_inputs: Sequence[str],
    produced_outputs: Sequence[str],
    config: Mapping[str, object],
) -> dict[str, object]:
    """Build the metadata mapping shared by every stage in this module.

    Args:
        stage_type: Short identifier of the stage kind (e.g. ``"train"``).
        name: User-facing stage name.
        required_inputs: Context keys the stage reads.
        produced_outputs: Context keys the stage produces.
        config: Stage-specific configuration. It is copied into a plain
            ``dict`` and passed through ``normalize_checkpoint_metadata``.

    Returns:
        A dict with ``version``, ``stage_type``, ``name``,
        ``required_inputs`` and ``produced_outputs`` (both as tuples), and
        the normalized ``config``.
    """
    inputs = tuple(required_inputs)
    outputs = tuple(produced_outputs)
    return {
        "version": STAGE_METADATA_VERSION,
        "stage_type": stage_type,
        "name": name,
        "required_inputs": inputs,
        "produced_outputs": outputs,
        "config": normalize_checkpoint_metadata(dict(config)),
    }
def workflow_stage_metadata(
    stages_or_results: Sequence[WorkflowStage | StageResult | Mapping[str, object]],
    *,
    name: str = "workflow",
) -> dict[str, object]:
    """Collect per-stage metadata into one checkpoint-ready record.

    Each item may be a stage, a :class:`StageResult`, or a raw metadata
    mapping. Items that have a ``metadata`` attribute contribute that
    attribute; any other item is treated as the metadata itself.

    Args:
        stages_or_results: Ordered stages, results, or metadata mappings.
        name: Name stored at the top level of the returned record.

    Returns:
        The result of ``normalize_checkpoint_metadata`` applied to a dict
        with ``version``, ``name`` and a ``stages`` list. Each stage entry
        starts with its positional ``index`` followed by the item's metadata
        (keys converted to ``str``). A metadata key literally named
        ``"index"`` overrides the positional index.

    Raises:
        TypeError: If an item's metadata is not a mapping.
    """

    def _as_entry(position: int, item: object) -> dict[str, object]:
        # Stages and results expose metadata as an attribute; raw mappings
        # have no such attribute and are their own metadata.
        stage_meta = getattr(item, "metadata", item)
        if not isinstance(stage_meta, Mapping):
            raise TypeError("workflow stage metadata entries must be mappings")
        entry: dict[str, object] = {"index": int(position)}
        for meta_key, meta_value in stage_meta.items():
            entry[str(meta_key)] = meta_value
        return entry

    records = [
        _as_entry(position, item)
        for position, item in enumerate(stages_or_results)
    ]
    payload = {
        "version": STAGE_METADATA_VERSION,
        "name": name,
        "stages": records,
    }
    return cast(dict[str, object], normalize_checkpoint_metadata(payload))
def _require_context(context: Mapping[str, object], key: str) -> object: """Read one required context key with a clear error.""" try: return context[key] except KeyError as exc: raise KeyError(f"workflow stage input {key!r} is missing") from exc
@dataclass(frozen=True)
class LinearFitStage:
    """Selector-aware wrapper around :class:`ufp.leastsquares.LinearFitter`.

    The model is read from ``context[model_key]`` and the fit samples from
    ``context[samples_key]``. The fit result is published under
    ``result_key``, and the same model object is re-published under
    ``model_key``.
    """

    name: str = "linear_fit"
    model_key: str = "model"
    samples_key: str = "fit_samples"
    result_key: str = "linear_fit_result"
    # Extra keyword arguments forwarded to the ``LinearFitter`` constructor.
    fitter_kwargs: Mapping[str, object] = field(default_factory=dict)
    # Extra keyword arguments forwarded to ``LinearFitter.fit``.
    fit_kwargs: Mapping[str, object] = field(default_factory=dict)

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Model key followed by the samples key."""
        return self.model_key, self.samples_key

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """Result key followed by the model key."""
        return self.result_key, self.model_key

    @property
    def metadata(self) -> Mapping[str, object]:
        """Stage metadata, including copies of both keyword-argument maps."""
        config = {
            "fitter_kwargs": dict(self.fitter_kwargs),
            "fit_kwargs": dict(self.fit_kwargs),
        }
        return _stage_metadata(
            stage_type="linear_fit",
            name=self.name,
            required_inputs=self.required_inputs,
            produced_outputs=self.produced_outputs,
            config=config,
        )

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Fit the model with a freshly constructed ``LinearFitter``.

        Raises:
            KeyError: If the model or samples key is missing from ``context``.
            TypeError: If ``context[model_key]`` is not a ``UFPModel``.
        """
        model = _require_context(context, self.model_key)
        if not isinstance(model, UFPModel):
            raise TypeError(f"context[{self.model_key!r}] must be a UFPModel")
        samples = _require_context(context, self.samples_key)
        fitter = LinearFitter(model, **dict(self.fitter_kwargs))
        fit_result = fitter.fit(cast(Any, samples), **dict(self.fit_kwargs))
        produced = {self.result_key: fit_result, self.model_key: model}
        return StageResult(outputs=produced, metadata=self.metadata)
@dataclass(frozen=True)
class TrainStage:
    """Wrapper around optimizer training with optional coefficient freeze masks.

    Trains ``context[model_key]`` on ``context[train_loader_key]`` via
    ``fit_model``. It publishes the training history, the optimizer used,
    and the freeze state (``None`` when no selectors are configured).
    """

    name: str = "train"
    model_key: str = "model"
    train_loader_key: str = "train_loader"
    history_key: str = "training_history"
    optimizer_key: str = "optimizer"
    freeze_state_key: str = "freeze_state"
    # Selectors handed to ``freeze_model_coefficients``; empty disables freezing.
    freeze_selectors: Sequence[int | str] = ()
    # Only used when the context does not already provide an optimizer.
    optimizer_factory: Callable[..., torch.optim.Optimizer] = torch.optim.Adam
    optimizer_kwargs: Mapping[str, object] = field(
        default_factory=lambda: {"lr": 1.0e-3}
    )
    fit_kwargs: Mapping[str, object] = field(default_factory=lambda: {"epochs": 1})

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Model key followed by the training-loader key."""
        return self.model_key, self.train_loader_key

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """History, optimizer and freeze-state keys, in that order."""
        return self.history_key, self.optimizer_key, self.freeze_state_key

    @property
    def metadata(self) -> Mapping[str, object]:
        """Stage metadata describing the freeze, optimizer and fit settings."""
        config = {
            "freeze_selectors": tuple(self.freeze_selectors),
            "optimizer_factory": self.optimizer_factory,
            "optimizer_kwargs": dict(self.optimizer_kwargs),
            "fit_kwargs": dict(self.fit_kwargs),
        }
        return _stage_metadata(
            stage_type="train",
            name=self.name,
            required_inputs=self.required_inputs,
            produced_outputs=self.produced_outputs,
            config=config,
        )

    def _optimizer(
        self, model: UFPModel, context: Mapping[str, object]
    ) -> torch.optim.Optimizer:
        """Reuse ``context[optimizer_key]`` when set, else build a new optimizer.

        A new optimizer is created with ``optimizer_factory`` over
        ``model.parameters()`` and ``optimizer_kwargs``.

        Raises:
            TypeError: If the context value is set but is not a torch optimizer.
        """
        existing = context.get(self.optimizer_key)
        if existing is None:
            return self.optimizer_factory(
                model.parameters(), **dict(self.optimizer_kwargs)
            )
        if isinstance(existing, torch.optim.Optimizer):
            return existing
        raise TypeError(f"context[{self.optimizer_key!r}] must be a torch optimizer")

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Train the model, applying coefficient freeze masks when configured.

        Raises:
            KeyError: If the model or loader key is missing from ``context``.
            TypeError: If the model is not a ``UFPModel`` or the context holds
                a non-optimizer under ``optimizer_key``.
        """
        model = _require_context(context, self.model_key)
        if not isinstance(model, UFPModel):
            raise TypeError(f"context[{self.model_key!r}] must be a UFPModel")
        train_loader = _require_context(context, self.train_loader_key)
        optimizer = self._optimizer(model, context)

        freeze_state = None
        if self.freeze_selectors:
            freeze_state = freeze_model_coefficients(model, self.freeze_selectors)
        if freeze_state is not None:
            # The wrapper's exact effect lives in ufp.training (presumably it
            # masks updates to frozen coefficients). It is always undone in
            # the ``finally`` below, even if training raises.
            freeze_state.wrap_optimizer(optimizer)
        try:
            trainer = cast(Any, fit_model)
            history = trainer(
                model,
                train_loader,
                optimizer=optimizer,
                **dict(self.fit_kwargs),
            )
        finally:
            if freeze_state is not None:
                freeze_state.unwrap_optimizer(optimizer)

        produced = {
            self.history_key: history,
            self.optimizer_key: optimizer,
            self.freeze_state_key: freeze_state,
        }
        return StageResult(outputs=produced, metadata=self.metadata)
@dataclass(frozen=True)
class ProjectStage:
    """Projection stage that delegates to an explicit projection callable.

    ``input_bindings`` maps projector keyword names to context keys. Each
    bound value is read from the context at run time. ``projector_kwargs``
    supplies additional static keyword arguments. The two sets of keyword
    names must be disjoint (checked at construction).
    """

    projector: Callable[..., object]
    name: str = "project"
    # projector keyword name -> context key to read it from
    input_bindings: Mapping[str, str] = field(default_factory=dict)
    # static keyword arguments passed to the projector as-is
    projector_kwargs: Mapping[str, object] = field(default_factory=dict)
    result_key: str = "projection_result"

    def __post_init__(self) -> None:
        """Reject keyword names that are bound both from context and statically.

        Without this check the static value would silently replace the
        context value, even though the context key is still listed in
        ``required_inputs`` and must be present at run time.

        Raises:
            ValueError: If ``input_bindings`` and ``projector_kwargs`` share a
                keyword name.
        """
        overlap = sorted(
            str(argument)
            for argument in self.input_bindings
            if argument in self.projector_kwargs
        )
        if overlap:
            raise ValueError(
                "ProjectStage arguments bound both from context and "
                f"projector_kwargs: {overlap}"
            )

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Distinct context keys referenced by ``input_bindings``, in first-seen order."""
        return tuple(dict.fromkeys(self.input_bindings.values()))

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """The single result key."""
        return (self.result_key,)

    @property
    def metadata(self) -> Mapping[str, object]:
        """Return JSON-friendly stage metadata."""
        return _stage_metadata(
            stage_type="project",
            name=self.name,
            required_inputs=self.required_inputs,
            produced_outputs=self.produced_outputs,
            config={
                "projector": self.projector,
                "input_bindings": dict(self.input_bindings),
                "projector_kwargs": dict(self.projector_kwargs),
            },
        )

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Call the configured projection helper.

        Bound context values and static ``projector_kwargs`` are combined
        into one keyword set. Their names are disjoint (enforced in
        ``__post_init__``), so neither overrides the other.

        Raises:
            KeyError: If a bound context key is missing.
        """
        kwargs = {
            argument: _require_context(context, key)
            for argument, key in self.input_bindings.items()
        }
        kwargs.update(dict(self.projector_kwargs))
        result = self.projector(**kwargs)
        return StageResult(
            outputs={self.result_key: result},
            metadata=self.metadata,
        )
@dataclass(frozen=True)
class ResidualizeStage:
    """Wrapper around nonlinear frozen-component residual materialization.

    Calls ``materialize_residual_dataset`` on ``context[model_key]`` and
    ``context[dataset_key]``. It publishes the full result under
    ``result_key`` and its ``dataset`` attribute under
    ``dataset_output_key``.
    """

    name: str = "residualize"
    model_key: str = "model"
    dataset_key: str = "dataset"
    result_key: str = "residual_result"
    dataset_output_key: str = "residual_dataset"
    # Extra keyword arguments forwarded to ``materialize_residual_dataset``.
    residual_kwargs: Mapping[str, object] = field(default_factory=dict)

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Model key followed by the dataset key."""
        return self.model_key, self.dataset_key

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """Result key followed by the residual-dataset key."""
        return self.result_key, self.dataset_output_key

    @property
    def metadata(self) -> Mapping[str, object]:
        """Stage metadata including a copy of ``residual_kwargs``."""
        config = {"residual_kwargs": dict(self.residual_kwargs)}
        return _stage_metadata(
            stage_type="residualize",
            name=self.name,
            required_inputs=self.required_inputs,
            produced_outputs=self.produced_outputs,
            config=config,
        )

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Materialize residual labels for the configured dataset.

        Raises:
            KeyError: If the model or dataset key is missing from ``context``.
            TypeError: If ``context[model_key]`` is not a ``UFPModel``.
        """
        model = _require_context(context, self.model_key)
        if not isinstance(model, UFPModel):
            raise TypeError(f"context[{self.model_key!r}] must be a UFPModel")
        dataset = _require_context(context, self.dataset_key)
        materialize = cast(Any, materialize_residual_dataset)
        residual = materialize(model, dataset, **dict(self.residual_kwargs))
        residual_dataset = residual.dataset
        produced = {
            self.result_key: residual,
            self.dataset_output_key: residual_dataset,
        }
        return StageResult(outputs=produced, metadata=self.metadata)
@dataclass(frozen=True)
class ValidateStage:
    """Validation/metrics stage that delegates to ``evaluate_model``.

    Evaluates ``context[model_key]`` on ``context[loader_key]`` and
    publishes whatever ``evaluate_model`` returns under ``metrics_key``.
    Unlike the other model stages in this module, no ``UFPModel`` type
    check is performed here.
    """

    name: str = "validate"
    model_key: str = "model"
    loader_key: str = "validation_loader"
    metrics_key: str = "validation_metrics"
    # Extra keyword arguments forwarded to ``evaluate_model``.
    evaluate_kwargs: Mapping[str, object] = field(default_factory=dict)

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Model key followed by the loader key."""
        return self.model_key, self.loader_key

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """The single metrics key."""
        return (self.metrics_key,)

    @property
    def metadata(self) -> Mapping[str, object]:
        """Stage metadata including a copy of ``evaluate_kwargs``."""
        config = {"evaluate_kwargs": dict(self.evaluate_kwargs)}
        return _stage_metadata(
            stage_type="validate",
            name=self.name,
            required_inputs=self.required_inputs,
            produced_outputs=self.produced_outputs,
            config=config,
        )

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Evaluate the model and return the aggregate metrics.

        Raises:
            KeyError: If the model or loader key is missing from ``context``.
        """
        model = _require_context(context, self.model_key)
        loader = _require_context(context, self.loader_key)
        evaluate = cast(Any, evaluate_model)
        metrics = evaluate(model, loader, **dict(self.evaluate_kwargs))
        return StageResult(
            outputs={self.metrics_key: metrics},
            metadata=self.metadata,
        )
# Public API of this module: the stage classes, their shared result type and
# protocol, and the workflow-level metadata helper.
__all__ = [ "LinearFitStage", "ProjectStage", "ResidualizeStage", "StageResult", "TrainStage", "ValidateStage", "WorkflowStage", "workflow_stage_metadata", ]