"""Explicit workflow stage objects for composing existing UFP helpers."""
from __future__ import annotations
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from dataclasses import dataclass, field
from typing import Any, Protocol, cast
import torch
from ufp.leastsquares import LinearFitter
from ufp.terms import UFPModel
from ufp.training import evaluate_model, fit_model, freeze_model_coefficients
from ufp.workflows.checkpoints import normalize_checkpoint_metadata
from ufp.workflows.residuals import materialize_residual_dataset
# Layout version recorded under the "version" key of every stage metadata
# mapping built by ``_stage_metadata``.
STAGE_METADATA_VERSION: int = 1
@dataclass(frozen=True)
class StageResult:
    """Immutable record of what a single workflow stage produced.

    Attributes:
        outputs: Values keyed by the context names the stage writes.
        metadata: Description of the stage run; the stages in this module
            pass their own ``metadata`` property here.
    """

    outputs: Mapping[str, object]
    metadata: Mapping[str, object]

    def update_context(
        self,
        context: MutableMapping[str, object],
    ) -> MutableMapping[str, object]:
        """Merge :attr:`outputs` into *context* in place and hand it back.

        Same-named entries already present in *context* are overwritten.  The
        very same mapping object is returned so the call can be chained.
        """
        # ``update`` is used (rather than a per-key loop) so dict subclasses
        # see exactly the same mutation path as before.
        target = context
        target.update(self.outputs)
        return target
class WorkflowStage(Protocol):
    """Structural interface shared by the lightweight workflow stages.

    A stage declares which context keys it reads and writes, exposes a
    metadata mapping describing itself, and turns a user-owned context
    mapping into a :class:`StageResult`.
    """

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Run the stage against *context* and return its outputs and metadata."""
        ...

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Names of the context entries the stage reads."""
        ...

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """Names of the context entries the stage writes."""
        ...

    @property
    def metadata(self) -> Mapping[str, object]:
        """Metadata describing the stage; intended to be JSON-friendly."""
        ...
def _stage_metadata(
    *,
    stage_type: str,
    name: str,
    required_inputs: Sequence[str],
    produced_outputs: Sequence[str],
    config: Mapping[str, object],
) -> dict[str, object]:
    """Assemble the metadata mapping shared by every stage class.

    The input/output key sequences are frozen into tuples.  A plain-dict copy
    of ``config`` is passed through ``normalize_checkpoint_metadata`` and
    stored under ``"config"``.  Keys are emitted in the order: version,
    stage_type, name, required_inputs, produced_outputs, config.
    """
    inputs = tuple(required_inputs)
    outputs = tuple(produced_outputs)
    normalized_config = normalize_checkpoint_metadata(dict(config))
    return dict(
        version=STAGE_METADATA_VERSION,
        stage_type=stage_type,
        name=name,
        required_inputs=inputs,
        produced_outputs=outputs,
        config=normalized_config,
    )
def _require_context(context: Mapping[str, object], key: str) -> object:
    """Fetch ``context[key]``, turning a miss into a descriptive ``KeyError``.

    The original lookup error is chained as ``__cause__``.
    """
    try:
        value = context[key]
    except KeyError as missing:
        message = f"workflow stage input {key!r} is missing"
        raise KeyError(message) from missing
    else:
        return value
@dataclass(frozen=True)
class LinearFitStage:
    """Selector-aware wrapper around :class:`ufp.leastsquares.LinearFitter`.

    Reads a :class:`UFPModel` and fit samples from the context.  It builds a
    ``LinearFitter`` for the model using ``fitter_kwargs`` and calls its
    ``fit`` method with ``fit_kwargs``.  The fit result and the model object
    are written back under ``result_key`` and ``model_key``.
    """

    name: str = "linear_fit"
    model_key: str = "model"
    samples_key: str = "fit_samples"
    result_key: str = "linear_fit_result"
    fitter_kwargs: Mapping[str, object] = field(default_factory=dict)
    fit_kwargs: Mapping[str, object] = field(default_factory=dict)

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Context keys read by :meth:`run`: the model and the fit samples."""
        return self.model_key, self.samples_key

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """Context keys written by :meth:`run`: the fit result and the model."""
        return self.result_key, self.model_key

    @property
    def metadata(self) -> Mapping[str, object]:
        """Describe this stage and its keyword-argument configuration."""
        stage_config = {
            "fitter_kwargs": dict(self.fitter_kwargs),
            "fit_kwargs": dict(self.fit_kwargs),
        }
        return _stage_metadata(
            stage_type="linear_fit",
            name=self.name,
            required_inputs=self.required_inputs,
            produced_outputs=self.produced_outputs,
            config=stage_config,
        )

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Fit the model's linear coefficients with a fresh ``LinearFitter``.

        Raises:
            KeyError: If the model or samples key is missing from *context*.
            TypeError: If ``context[model_key]`` is not a ``UFPModel``.
        """
        model = _require_context(context, self.model_key)
        if not isinstance(model, UFPModel):
            raise TypeError(f"context[{self.model_key!r}] must be a UFPModel")
        samples = _require_context(context, self.samples_key)
        fitter = LinearFitter(model, **dict(self.fitter_kwargs))
        fit_result = fitter.fit(cast(Any, samples), **dict(self.fit_kwargs))
        stage_outputs = {self.result_key: fit_result, self.model_key: model}
        return StageResult(outputs=stage_outputs, metadata=self.metadata)
@dataclass(frozen=True)
class TrainStage:
    """Wrapper around optimizer training with optional coefficient freeze masks.

    The stage trains the context model with ``fit_model``.  When
    ``freeze_selectors`` is non-empty, ``freeze_model_coefficients`` is applied
    first, and the returned freeze state wraps the optimizer while training
    runs.  The wrapper is removed again even if training raises.
    """

    name: str = "train"
    model_key: str = "model"
    train_loader_key: str = "train_loader"
    history_key: str = "training_history"
    optimizer_key: str = "optimizer"
    freeze_state_key: str = "freeze_state"
    freeze_selectors: Sequence[int | str] = ()
    optimizer_factory: Callable[..., torch.optim.Optimizer] = torch.optim.Adam
    optimizer_kwargs: Mapping[str, object] = field(
        default_factory=lambda: {"lr": 1.0e-3}
    )
    fit_kwargs: Mapping[str, object] = field(default_factory=lambda: {"epochs": 1})

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Context keys read by :meth:`run`: the model and the training loader.

        An optimizer stored under ``optimizer_key`` is reused when present, but
        it is not required.
        """
        return self.model_key, self.train_loader_key

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """Context keys written by :meth:`run`: history, optimizer, freeze state."""
        return self.history_key, self.optimizer_key, self.freeze_state_key

    @property
    def metadata(self) -> Mapping[str, object]:
        """Describe this stage and its training configuration.

        NOTE(review): ``optimizer_factory`` is handed to the normalizer as a
        raw callable; presumably ``normalize_checkpoint_metadata`` makes it
        serializable — confirm.
        """
        training_config = {
            "freeze_selectors": tuple(self.freeze_selectors),
            "optimizer_factory": self.optimizer_factory,
            "optimizer_kwargs": dict(self.optimizer_kwargs),
            "fit_kwargs": dict(self.fit_kwargs),
        }
        return _stage_metadata(
            stage_type="train",
            name=self.name,
            required_inputs=self.required_inputs,
            produced_outputs=self.produced_outputs,
            config=training_config,
        )

    def _optimizer(
        self, model: UFPModel, context: Mapping[str, object]
    ) -> torch.optim.Optimizer:
        """Reuse the optimizer stored in *context*, or build a new one.

        A missing key and a ``None`` value are treated the same way: the
        factory is called with ``model.parameters()`` and ``optimizer_kwargs``.

        Raises:
            TypeError: If the stored value is not a ``torch.optim.Optimizer``.
        """
        current = context.get(self.optimizer_key)
        if current is None:
            return self.optimizer_factory(
                model.parameters(), **dict(self.optimizer_kwargs)
            )
        if not isinstance(current, torch.optim.Optimizer):
            raise TypeError(
                f"context[{self.optimizer_key!r}] must be a torch optimizer"
            )
        return current

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Train the context model, honouring any coefficient freeze selectors.

        ``freeze_state`` in the outputs is ``None`` when no selectors are set.

        Raises:
            KeyError: If the model or training-loader key is missing.
            TypeError: If the model is not a ``UFPModel`` or the stored
                optimizer is not a torch optimizer.
        """
        model = _require_context(context, self.model_key)
        if not isinstance(model, UFPModel):
            raise TypeError(f"context[{self.model_key!r}] must be a UFPModel")
        train_loader = _require_context(context, self.train_loader_key)
        optimizer = self._optimizer(model, context)

        freeze_state = None
        if self.freeze_selectors:
            freeze_state = freeze_model_coefficients(model, self.freeze_selectors)
        if freeze_state is not None:
            freeze_state.wrap_optimizer(optimizer)

        train = cast(Any, fit_model)
        try:
            history = train(
                model,
                train_loader,
                optimizer=optimizer,
                **dict(self.fit_kwargs),
            )
        finally:
            # Always restore the optimizer, even when training fails.
            if freeze_state is not None:
                freeze_state.unwrap_optimizer(optimizer)

        stage_outputs = {
            self.history_key: history,
            self.optimizer_key: optimizer,
            self.freeze_state_key: freeze_state,
        }
        return StageResult(outputs=stage_outputs, metadata=self.metadata)
@dataclass(frozen=True)
class ProjectStage:
    """Projection stage that delegates to an explicit projection callable.

    ``input_bindings`` maps projector argument names to context keys.  Bound
    context values are collected first, and ``projector_kwargs`` are merged on
    top, so a static keyword wins over a bound value of the same name.
    """

    projector: Callable[..., object]
    name: str = "project"
    input_bindings: Mapping[str, str] = field(default_factory=dict)
    projector_kwargs: Mapping[str, object] = field(default_factory=dict)
    result_key: str = "projection_result"

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Distinct context keys referenced by ``input_bindings``, first-seen order."""
        unique_keys: dict[str, None] = {}
        for context_key in self.input_bindings.values():
            unique_keys.setdefault(context_key, None)
        return tuple(unique_keys)

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """The single context key written by :meth:`run`."""
        return (self.result_key,)

    @property
    def metadata(self) -> Mapping[str, object]:
        """Describe this stage, its bindings and its static keyword arguments.

        NOTE(review): ``projector`` is handed to the normalizer as a raw
        callable; presumably ``normalize_checkpoint_metadata`` makes it
        serializable — confirm.
        """
        projection_config = {
            "projector": self.projector,
            "input_bindings": dict(self.input_bindings),
            "projector_kwargs": dict(self.projector_kwargs),
        }
        return _stage_metadata(
            stage_type="project",
            name=self.name,
            required_inputs=self.required_inputs,
            produced_outputs=self.produced_outputs,
            config=projection_config,
        )

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Call ``projector`` with bound context values plus static kwargs.

        Raises:
            KeyError: If a bound context key is missing.
        """
        call_kwargs: dict[str, object] = {}
        for argument, context_key in self.input_bindings.items():
            call_kwargs[argument] = _require_context(context, context_key)
        call_kwargs.update(dict(self.projector_kwargs))
        projection = self.projector(**call_kwargs)
        return StageResult(
            outputs={self.result_key: projection},
            metadata=self.metadata,
        )
@dataclass(frozen=True)
class ResidualizeStage:
    """Wrapper around nonlinear frozen-component residual materialization.

    Passes the context model and dataset to ``materialize_residual_dataset``
    together with ``residual_kwargs``.  It publishes both the full result and
    that result's ``dataset`` attribute.
    """

    name: str = "residualize"
    model_key: str = "model"
    dataset_key: str = "dataset"
    result_key: str = "residual_result"
    dataset_output_key: str = "residual_dataset"
    residual_kwargs: Mapping[str, object] = field(default_factory=dict)

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Context keys read by :meth:`run`: the model and the source dataset."""
        return self.model_key, self.dataset_key

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """Context keys written by :meth:`run`: the result and its dataset."""
        return self.result_key, self.dataset_output_key

    @property
    def metadata(self) -> Mapping[str, object]:
        """Describe this stage, including its residual keyword arguments."""
        return _stage_metadata(
            stage_type="residualize",
            name=self.name,
            required_inputs=self.required_inputs,
            produced_outputs=self.produced_outputs,
            config={"residual_kwargs": dict(self.residual_kwargs)},
        )

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Materialize residual labels for the context dataset.

        Raises:
            KeyError: If the model or dataset key is missing from *context*.
            TypeError: If ``context[model_key]`` is not a ``UFPModel``.
        """
        model = _require_context(context, self.model_key)
        if not isinstance(model, UFPModel):
            raise TypeError(f"context[{self.model_key!r}] must be a UFPModel")
        dataset = _require_context(context, self.dataset_key)
        materialize = cast(Any, materialize_residual_dataset)
        residual_result = materialize(model, dataset, **dict(self.residual_kwargs))
        residual_dataset = residual_result.dataset
        return StageResult(
            outputs={
                self.result_key: residual_result,
                self.dataset_output_key: residual_dataset,
            },
            metadata=self.metadata,
        )
@dataclass(frozen=True)
class ValidateStage:
    """Validation/metrics stage that delegates to ``evaluate_model``.

    Unlike the fitting stages, this stage does not type-check the model
    before handing it to ``evaluate_model``.
    """

    name: str = "validate"
    model_key: str = "model"
    loader_key: str = "validation_loader"
    metrics_key: str = "validation_metrics"
    evaluate_kwargs: Mapping[str, object] = field(default_factory=dict)

    @property
    def required_inputs(self) -> tuple[str, ...]:
        """Context keys read by :meth:`run`: the model and the validation loader."""
        return self.model_key, self.loader_key

    @property
    def produced_outputs(self) -> tuple[str, ...]:
        """The single context key written by :meth:`run`."""
        return (self.metrics_key,)

    @property
    def metadata(self) -> Mapping[str, object]:
        """Describe this stage, including its evaluation keyword arguments."""
        return _stage_metadata(
            stage_type="validate",
            name=self.name,
            required_inputs=self.required_inputs,
            produced_outputs=self.produced_outputs,
            config={"evaluate_kwargs": dict(self.evaluate_kwargs)},
        )

    def run(self, context: Mapping[str, object]) -> StageResult:
        """Evaluate the context model and publish whatever metrics are returned.

        Raises:
            KeyError: If the model or loader key is missing from *context*.
        """
        model = _require_context(context, self.model_key)
        loader = _require_context(context, self.loader_key)
        evaluate = cast(Any, evaluate_model)
        metrics = evaluate(model, loader, **dict(self.evaluate_kwargs))
        return StageResult(
            outputs={self.metrics_key: metrics},
            metadata=self.metadata,
        )
# Public API of this module.  Every entry must be a name bound in this module;
# otherwise ``from <module> import *`` raises AttributeError.  (A previous
# "workflow_stage_metadata" entry referred to a name that is never defined here.)
__all__ = [
    "LinearFitStage",
    "ProjectStage",
    "ResidualizeStage",
    "StageResult",
    "TrainStage",
    "ValidateStage",
    "WorkflowStage",
]