"""Projection diagnostic containers.
The objects in this module are intentionally independent of projection
algorithms and plotting libraries. Projection code can return them directly, and
workflow/checkpoint code can serialize them through ``to_dict()``.
"""
from __future__ import annotations
import math
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
def _check_count(name: str, value: int) -> int:
    """Return ``value`` normalized to a non-negative ``int``.

    Integral floats such as ``3.0`` and other ``int()``-convertible values
    (e.g. the string ``"7"``) are accepted. A float with a fractional part,
    or a non-finite float, is rejected instead of being silently truncated
    toward zero by ``int()``.

    Args:
        name: Field name used in error messages.
        value: Candidate count.

    Returns:
        The count as a plain ``int``.

    Raises:
        ValueError: If ``value`` is negative, a non-integral or non-finite
            float, or a string that ``int()`` cannot parse.
        TypeError: If ``int()`` does not accept ``value`` at all.
    """
    # `int(2.7)` would quietly yield 2, and `int(inf)` raises OverflowError;
    # both are caller errors for a count, so report them as ValueError.
    if isinstance(value, float) and not value.is_integer():
        raise ValueError(f"`{name}` must be a whole number")
    count = int(value)
    if count < 0:
        raise ValueError(f"`{name}` must be non-negative")
    return count
def _check_optional_non_negative(name: str, value: float | None) -> float | None:
"""Return a normalized finite non-negative float or ``None``."""
if value is None:
return None
normalized = float(value)
if not math.isfinite(normalized) or normalized < 0.0:
raise ValueError(f"`{name}` must be a finite non-negative value")
return normalized
def _check_optional_finite(name: str, value: float | None) -> float | None:
"""Return a normalized finite float or ``None``."""
if value is None:
return None
normalized = float(value)
if not math.isfinite(normalized):
raise ValueError(f"`{name}` must be finite")
return normalized
def _as_warning_tuple(
    warnings: Sequence[str] | Iterable[str] | str | None,
) -> tuple[str, ...]:
    """Normalize warning messages to a tuple of strings.

    A bare string is treated as one message rather than being iterated
    character by character, and ``None`` means "no warnings". Any other
    iterable contributes one ``str()``-converted entry per item, in order.
    """
    if warnings is None:
        return ()
    # `str` is itself an iterable of `str`, so without this guard a single
    # message like "low support" would become ("l", "o", "w", ...).
    if isinstance(warnings, str):
        return (str(warnings),)
    return tuple(str(message) for message in warnings)
@dataclass(frozen=True)
class ProjectionSupportCoverage:
    """Coverage of sampled points by a projection support interval or domain.

    Attributes:
        sample_count: Total number of sampled points (non-negative ``int``).
        covered_count: How many of those points lie inside support; may not
            exceed ``sample_count``.
        lower_bound: Optional finite lower support limit.
        upper_bound: Optional finite upper support limit; when both bounds
            are given it must be ``>= lower_bound``.
        minimum_sample: Optional smallest sampled coordinate.
        maximum_sample: Optional largest sampled coordinate; when both are
            given it must be ``>= minimum_sample``.
    """

    sample_count: int
    covered_count: int
    lower_bound: float | None = None
    upper_bound: float | None = None
    minimum_sample: float | None = None
    maximum_sample: float | None = None

    def __post_init__(self) -> None:
        """Normalize counts and optional floats, then check their ordering."""
        total = _check_count("sample_count", self.sample_count)
        inside = _check_count("covered_count", self.covered_count)
        if inside > total:
            raise ValueError("`covered_count` can not exceed `sample_count`")
        # The dataclass is frozen, so normalized values are written with
        # object.__setattr__ instead of ordinary attribute assignment.
        object.__setattr__(self, "sample_count", total)
        object.__setattr__(self, "covered_count", inside)
        for attr in (
            "lower_bound",
            "upper_bound",
            "minimum_sample",
            "maximum_sample",
        ):
            object.__setattr__(
                self, attr, _check_optional_finite(attr, getattr(self, attr))
            )
        ordering_rules = (
            (
                self.lower_bound,
                self.upper_bound,
                "`upper_bound` must be greater than or equal to `lower_bound`",
            ),
            (
                self.minimum_sample,
                self.maximum_sample,
                "`maximum_sample` must be greater than or equal to `minimum_sample`",
            ),
        )
        for low, high, message in ordering_rules:
            if low is None or high is None:
                continue
            if high < low:
                raise ValueError(message)

    @classmethod
    def from_samples(
        cls,
        samples: Sequence[float],
        *,
        lower_bound: float | None = None,
        upper_bound: float | None = None,
    ) -> "ProjectionSupportCoverage":
        """Count how many ``samples`` fall inside ``[lower_bound, upper_bound]``.

        A ``None`` bound leaves that side unlimited; both bounds are
        inclusive. The smallest and largest samples are recorded as well
        (``None`` when ``samples`` is empty).

        Raises:
            ValueError: If any sample is NaN or infinite, or if the resulting
                metadata fails validation.
        """
        points = tuple(float(sample) for sample in samples)
        if not all(math.isfinite(point) for point in points):
            raise ValueError("support samples must be finite")

        def _within(point: float) -> bool:
            above_floor = lower_bound is None or point >= float(lower_bound)
            return above_floor and (
                upper_bound is None or point <= float(upper_bound)
            )

        return cls(
            sample_count=len(points),
            covered_count=sum(1 for point in points if _within(point)),
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            minimum_sample=min(points) if points else None,
            maximum_sample=max(points) if points else None,
        )

    @property
    def coverage_fraction(self) -> float | None:
        """Share of samples inside support, or ``None`` when there are none."""
        if self.sample_count:
            return float(self.covered_count) / float(self.sample_count)
        return None

    @property
    def fully_covered(self) -> bool:
        """Whether ``covered_count`` equals ``sample_count``.

        This is vacuously ``True`` for an empty sample set.
        """
        return self.sample_count == self.covered_count

    def to_dict(self) -> dict[str, int | float | None | bool]:
        """Serialize counts, derived coverage values, and bounds to plain values."""
        payload: dict[str, int | float | None | bool] = {
            "sample_count": self.sample_count,
            "covered_count": self.covered_count,
        }
        payload["coverage_fraction"] = self.coverage_fraction
        payload["fully_covered"] = self.fully_covered
        for attr in (
            "lower_bound",
            "upper_bound",
            "minimum_sample",
            "maximum_sample",
        ):
            payload[attr] = getattr(self, attr)
        return payload
@dataclass(frozen=True)
class ProjectionErrorSummary:
    """Scalar error metrics for projected values or derivatives.

    Attributes:
        sample_count: Number of residuals summarized (non-negative ``int``).
        rmse: Root-mean-square residual, or ``None``.
        mean_absolute: Mean absolute residual, or ``None``.
        max_absolute: Largest absolute residual, or ``None``.

    Every metric that is present must be finite and non-negative.
    """

    sample_count: int
    rmse: float | None = None
    mean_absolute: float | None = None
    max_absolute: float | None = None

    def __post_init__(self) -> None:
        """Validate count and optional error metrics."""
        object.__setattr__(
            self,
            "sample_count",
            _check_count("sample_count", self.sample_count),
        )
        # Frozen dataclass: store normalized floats via object.__setattr__.
        for metric in ("rmse", "mean_absolute", "max_absolute"):
            object.__setattr__(
                self,
                metric,
                _check_optional_non_negative(metric, getattr(self, metric)),
            )

    @classmethod
    def from_residuals(
        cls,
        residuals: Sequence[float],
    ) -> "ProjectionErrorSummary":
        """Build error metrics from projection residuals.

        Args:
            residuals: Residual values; any iterable convertible to ``float``.

        Returns:
            A summary with ``sample_count == 0`` and no metrics when
            ``residuals`` is empty; otherwise all metrics are populated.

        Raises:
            ValueError: If any residual is NaN or infinite.
        """
        values = tuple(float(value) for value in residuals)
        for value in values:
            if not math.isfinite(value):
                raise ValueError("projection residuals must be finite")
        if not values:
            return cls(sample_count=0)
        count = len(values)
        absolute = tuple(abs(value) for value in values)
        peak = max(absolute)
        if peak == 0.0:
            return cls(
                sample_count=count,
                rmse=0.0,
                mean_absolute=0.0,
                max_absolute=0.0,
            )
        # Scale by the largest magnitude before squaring/summing. Squaring raw
        # residuals above ~1.3e154 overflows to inf (which __post_init__ then
        # rejects even though every input was finite), and squaring residuals
        # below ~1e-162 underflows to 0. Each scaled term lies in [0, 1], so
        # neither sum can overflow and the rescaled metrics never exceed peak.
        scaled = tuple(magnitude / peak for magnitude in absolute)
        return cls(
            sample_count=count,
            rmse=peak * math.sqrt(math.fsum(s * s for s in scaled) / count),
            mean_absolute=peak * (math.fsum(scaled) / count),
            max_absolute=peak,
        )

    def to_dict(self) -> dict[str, int | float | None]:
        """Return JSON/checkpoint-friendly error metadata."""
        return {
            "sample_count": self.sample_count,
            "rmse": self.rmse,
            "mean_absolute": self.mean_absolute,
            "max_absolute": self.max_absolute,
        }
@dataclass(frozen=True)
class ProjectionChannelDiagnostic:
    """Projection diagnostics for one physical or logical coefficient channel.

    A channel counts as failed exactly when ``failure_reason`` is not ``None``.

    Attributes:
        channel_label: Name of the channel; stored as ``str``.
        sample_count: Number of samples for the channel (non-negative ``int``).
        support_coverage: Optional support-coverage metadata.
        value_error: Optional error metrics for projected values.
        derivative_error: Optional error metrics for projected derivatives.
        warnings: Warning messages, normalized to a tuple of strings.
        failure_reason: Why projection failed, or ``None`` on success.
    """

    channel_label: str
    sample_count: int
    support_coverage: ProjectionSupportCoverage | None = None
    value_error: ProjectionErrorSummary | None = None
    derivative_error: ProjectionErrorSummary | None = None
    warnings: tuple[str, ...] = field(default_factory=tuple)
    failure_reason: str | None = None

    def __post_init__(self) -> None:
        """Validate channel metadata and normalize plain values."""
        # Frozen dataclass: normalized values are stored via object.__setattr__.
        object.__setattr__(self, "channel_label", str(self.channel_label))
        object.__setattr__(
            self,
            "sample_count",
            _check_count("sample_count", self.sample_count),
        )
        object.__setattr__(self, "warnings", _as_warning_tuple(self.warnings))
        if self.failure_reason is not None:
            object.__setattr__(self, "failure_reason", str(self.failure_reason))

    @classmethod
    def failure(
        cls,
        *,
        channel_label: str,
        failure_reason: str,
        sample_count: int = 0,
        warnings: Sequence[str] = (),
        support_coverage: ProjectionSupportCoverage | None = None,
    ) -> "ProjectionChannelDiagnostic":
        """Build a failed channel diagnostic with a required failure reason.

        Args:
            channel_label: Label of the channel that failed.
            failure_reason: Non-empty explanation; stored via ``str()``.
            sample_count: Number of samples associated with the channel.
            warnings: Warning messages for the channel.
            support_coverage: Optional coverage metadata to attach.

        Raises:
            ValueError: If ``failure_reason`` is falsy or its ``str()`` is
                empty (e.g. an exception created without a message).
        """
        # An exception such as `RuntimeError()` is truthy but converts to "",
        # which would record a failure with no explanation.
        if not failure_reason or not str(failure_reason):
            raise ValueError("`failure_reason` must not be empty")
        return cls(
            channel_label=channel_label,
            sample_count=sample_count,
            support_coverage=support_coverage,
            # Use the shared normalizer rather than a bare tuple() so this path
            # treats warnings exactly like direct construction does.
            warnings=_as_warning_tuple(warnings),
            failure_reason=failure_reason,
        )

    @property
    def succeeded(self) -> bool:
        """Return whether projection succeeded for this channel."""
        return self.failure_reason is None

    @property
    def failed(self) -> bool:
        """Return whether projection failed for this channel."""
        return self.failure_reason is not None

    def to_dict(self) -> dict[str, object]:
        """Return JSON/checkpoint-friendly channel diagnostic metadata."""
        return {
            "channel_label": self.channel_label,
            "sample_count": self.sample_count,
            "support_coverage": (
                None
                if self.support_coverage is None
                else self.support_coverage.to_dict()
            ),
            "value_error": (
                None if self.value_error is None else self.value_error.to_dict()
            ),
            "derivative_error": (
                None
                if self.derivative_error is None
                else self.derivative_error.to_dict()
            ),
            "warnings": list(self.warnings),
            "failure_reason": self.failure_reason,
            "succeeded": self.succeeded,
        }
@dataclass(frozen=True)
class ProjectionDiagnosticsSummary:
    """Aggregate projection diagnostics across channels.

    Instances are normally produced by ``aggregate_channel_diagnostics``,
    which fills the fields as follows:

    Attributes:
        channel_count: Number of channels aggregated.
        successful_channel_count: Channels without a failure reason.
        failed_channel_count: Channels with a failure reason.
        sample_count: Sum of every channel's ``sample_count``.
        covered_count: Covered samples over channels that report support.
        support_sample_count: Samples over channels that report support.
        support_coverage_fraction: ``covered_count / support_sample_count``,
            or ``None`` when there are no support samples.
        max_value_rmse: Largest value RMSE reported, or ``None``.
        max_value_max_absolute: Largest value max-abs error, or ``None``.
        max_derivative_rmse: Largest derivative RMSE, or ``None``.
        max_derivative_max_absolute: Largest derivative max-abs error,
            or ``None``.
        warning_count: Number of entries in ``warning_messages``.
        failed_channel_labels: Labels of failed channels, as a tuple.
        warning_messages: Channel-qualified warning messages, as a tuple.
    """

    channel_count: int
    successful_channel_count: int
    failed_channel_count: int
    sample_count: int
    covered_count: int
    support_sample_count: int
    support_coverage_fraction: float | None
    max_value_rmse: float | None
    max_value_max_absolute: float | None
    max_derivative_rmse: float | None
    max_derivative_max_absolute: float | None
    warning_count: int
    failed_channel_labels: tuple[str, ...] = field(default_factory=tuple)
    warning_messages: tuple[str, ...] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        """Store the label/message collections as tuples.

        Callers may pass lists; converting them keeps a frozen instance
        genuinely immutable and hashable.
        """
        object.__setattr__(
            self, "failed_channel_labels", tuple(self.failed_channel_labels)
        )
        object.__setattr__(self, "warning_messages", tuple(self.warning_messages))

    def to_dict(self) -> dict[str, object]:
        """Return JSON/checkpoint-friendly aggregate metadata."""
        return {
            "channel_count": self.channel_count,
            "successful_channel_count": self.successful_channel_count,
            "failed_channel_count": self.failed_channel_count,
            "sample_count": self.sample_count,
            "covered_count": self.covered_count,
            "support_sample_count": self.support_sample_count,
            "support_coverage_fraction": self.support_coverage_fraction,
            "max_value_rmse": self.max_value_rmse,
            "max_value_max_absolute": self.max_value_max_absolute,
            "max_derivative_rmse": self.max_derivative_rmse,
            "max_derivative_max_absolute": self.max_derivative_max_absolute,
            "warning_count": self.warning_count,
            "failed_channel_labels": list(self.failed_channel_labels),
            "warning_messages": list(self.warning_messages),
        }
def _max_present(values: Iterable[float | None]) -> float | None:
    """Return the largest non-``None`` entry as a ``float``.

    ``None`` entries are skipped; if nothing remains, ``None`` is returned.
    """
    return max(
        (float(value) for value in values if value is not None),
        default=None,
    )
def aggregate_channel_diagnostics(
    channels: Sequence[ProjectionChannelDiagnostic],
) -> ProjectionDiagnosticsSummary:
    """Combine per-channel projection diagnostics into one summary.

    Args:
        channels: Channel diagnostics; any iterable is accepted and is
            consumed exactly once.

    Returns:
        A summary in which:

        - channel and sample counts span every channel;
        - support totals only include channels with ``support_coverage``;
        - each ``max_*`` field is the largest metric reported by any
          channel, or ``None`` if none report it;
        - warnings are prefixed with ``"<channel_label>: "``.
    """
    snapshot = tuple(channels)

    coverages = [
        channel.support_coverage
        for channel in snapshot
        if channel.support_coverage is not None
    ]
    support_total = sum(coverage.sample_count for coverage in coverages)
    support_covered = sum(coverage.covered_count for coverage in coverages)
    if support_total == 0:
        fraction = None
    else:
        fraction = float(support_covered) / float(support_total)

    value_errors = [
        channel.value_error
        for channel in snapshot
        if channel.value_error is not None
    ]
    derivative_errors = [
        channel.derivative_error
        for channel in snapshot
        if channel.derivative_error is not None
    ]

    labelled_warnings = tuple(
        f"{channel.channel_label}: {message}"
        for channel in snapshot
        for message in channel.warnings
    )
    failed_labels = tuple(
        channel.channel_label for channel in snapshot if channel.failed
    )

    return ProjectionDiagnosticsSummary(
        channel_count=len(snapshot),
        successful_channel_count=sum(
            1 for channel in snapshot if channel.succeeded
        ),
        failed_channel_count=len(failed_labels),
        sample_count=sum(channel.sample_count for channel in snapshot),
        covered_count=support_covered,
        support_sample_count=support_total,
        support_coverage_fraction=fraction,
        max_value_rmse=_max_present(error.rmse for error in value_errors),
        max_value_max_absolute=_max_present(
            error.max_absolute for error in value_errors
        ),
        max_derivative_rmse=_max_present(
            error.rmse for error in derivative_errors
        ),
        max_derivative_max_absolute=_max_present(
            error.max_absolute for error in derivative_errors
        ),
        warning_count=len(labelled_warnings),
        failed_channel_labels=failed_labels,
        warning_messages=labelled_warnings,
    )
@dataclass(frozen=True)
class ProjectionDiagnostics:
    """All channel diagnostics produced by a single projection operation.

    ``channels`` is stored as a tuple whatever iterable the caller passes.
    The aggregate ``summary`` is recomputed on every access.
    """

    channels: tuple[ProjectionChannelDiagnostic, ...] = field(default_factory=tuple)

    def __post_init__(self) -> None:
        """Coerce ``channels`` into an immutable tuple."""
        frozen_channels = tuple(self.channels)
        object.__setattr__(self, "channels", frozen_channels)

    @property
    def summary(self) -> ProjectionDiagnosticsSummary:
        """Aggregate diagnostics over every channel, computed on demand."""
        return aggregate_channel_diagnostics(self.channels)

    @property
    def failed_channels(self) -> tuple[ProjectionChannelDiagnostic, ...]:
        """Channels whose ``failed`` flag is set, in their original order."""
        failures = [channel for channel in self.channels if channel.failed]
        return tuple(failures)

    @property
    def warning_messages(self) -> tuple[str, ...]:
        """Channel-qualified warning messages, taken from ``summary``."""
        aggregate = self.summary
        return aggregate.warning_messages

    def to_dict(self) -> dict[str, object]:
        """Serialize every channel plus the aggregate summary to plain values."""
        channel_dicts = [channel.to_dict() for channel in self.channels]
        return {"channels": channel_dicts, "summary": self.summary.to_dict()}
def ensure_projection_succeeded(
    diagnostics: ProjectionDiagnostics | Sequence[ProjectionChannelDiagnostic],
) -> None:
    """Raise ``ValueError`` if any channel diagnostic reports a failure.

    Args:
        diagnostics: A ``ProjectionDiagnostics`` collection, or any iterable
            of channel diagnostics (wrapped into a collection first).

    Raises:
        ValueError: Naming the number of failed channels and listing each as
            ``"label: reason"``, joined with ``"; "``.
    """
    if not isinstance(diagnostics, ProjectionDiagnostics):
        diagnostics = ProjectionDiagnostics(tuple(diagnostics))
    failed = diagnostics.failed_channels
    if failed:
        details = "; ".join(
            f"{channel.channel_label}: {channel.failure_reason}"
            for channel in failed
        )
        raise ValueError(
            f"projection failed for {len(failed)} channel(s): {details}"
        )
# Public API: the diagnostic containers, then the module-level functions.
# Private helpers (leading underscore) are intentionally not exported.
__all__ = [
    # Containers
    "ProjectionChannelDiagnostic",
    "ProjectionDiagnostics",
    "ProjectionDiagnosticsSummary",
    "ProjectionErrorSummary",
    "ProjectionSupportCoverage",
    # Functions
    "aggregate_channel_diagnostics",
    "ensure_projection_succeeded",
]