"""Offline projection of radial callables onto 1D spline coefficient rows."""
from __future__ import annotations
import math
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass, replace
import torch
from ufp.projection.diagnostics import (
ProjectionChannelDiagnostic,
ProjectionDiagnostics,
ProjectionErrorSummary,
ProjectionSupportCoverage,
)
from ufp.splines.representation import uniform_stencil_1d, uniform_support_parameters
[docs]
@dataclass(frozen=True)
class RadialProjectionResult:
    """Projected spline coefficients and associated projection metadata.

    Instances are produced by ``project_radial_function``. ``project_pair_prior``
    additionally fills ``metadata`` through ``dataclasses.replace``. Every tensor
    field is detached from autograd when produced by those functions.
    """

    # Fitted coefficient row: the ``torch.linalg.lstsq`` solution for a design
    # matrix with ``coeff_size`` columns.
    coeffs: torch.Tensor
    # Spline family name, stored as ``str(spline)``.
    spline: str
    # Number of spline coefficients in the row.
    coeff_size: int
    # Bounds of the full-support projection interval; explicit sample and
    # diagnostic radii are validated to lie in ``[full_support_start, cutoff)``.
    full_support_start: float
    cutoff: float
    # Uniform knot layout returned by ``uniform_support_parameters``.
    first_knot: float
    knot_spacing: float
    # Radii used for the least-squares fit (explicit or a midpoint grid).
    sample_distances: torch.Tensor
    # Non-negative fit weights, or ``None`` for an unweighted fit.
    sample_weights: torch.Tensor | None
    # ``torch.linalg.matrix_rank`` of the (row-weighted) design matrix; the fit
    # raises unless this reaches ``coeff_size``.
    design_rank: int
    # Radii and optional weights used to evaluate projection error.
    diagnostic_distances: torch.Tensor
    diagnostic_weights: torch.Tensor | None
    # Values on ``diagnostic_distances``: the callable's targets, the spline
    # prediction, and ``residuals = predicted_values - target_values``.
    target_values: torch.Tensor
    predicted_values: torch.Tensor
    residuals: torch.Tensor
    # Weighted RMSE of ``residuals``; ``None`` when no diagnostic weights exist.
    weighted_rmse: float | None
    # Per-channel support-coverage and value-error summaries.
    diagnostics: ProjectionDiagnostics
    # Optional provenance mapping; ``None`` unless a caller such as
    # ``project_pair_prior`` attaches one.
    metadata: Mapping[str, object] | None = None
def _resolve_dtype(dtype: torch.dtype | None) -> torch.dtype:
"""Return a floating-point dtype for projection tensors."""
resolved = torch.get_default_dtype() if dtype is None else dtype
if not torch.empty((), dtype=resolved).is_floating_point():
raise ValueError("`dtype` must be a floating-point torch dtype")
return resolved
def _as_1d_tensor(
value: object,
*,
name: str,
dtype: torch.dtype,
device: torch.device | None,
) -> torch.Tensor:
"""Convert an array-like value to a flat tensor."""
tensor = torch.as_tensor(value, dtype=dtype, device=device)
if tensor.ndim == 0:
tensor = tensor.reshape(1)
else:
tensor = tensor.reshape(-1)
if tensor.numel() == 0:
raise ValueError(f"`{name}` must contain at least one value")
if not torch.all(torch.isfinite(tensor)):
raise ValueError(f"`{name}` must contain only finite values")
return tensor
def _validate_interval(full_support_start: float, cutoff: float) -> None:
"""Validate one radial full-support interval."""
if not math.isfinite(full_support_start):
raise ValueError("`full_support_start` must be finite")
if not math.isfinite(cutoff):
raise ValueError("`cutoff` must be finite")
if cutoff <= full_support_start:
raise ValueError("`cutoff` must be greater than `full_support_start`")
def _validate_distances(
distances: torch.Tensor,
*,
name: str,
full_support_start: float,
cutoff: float,
) -> None:
"""Validate sampled radii against the projection interval."""
if torch.any(distances < full_support_start) or torch.any(distances >= cutoff):
raise ValueError(f"`{name}` must lie inside [full_support_start, cutoff)")
def _midpoint_grid(
*,
n_samples: int,
full_support_start: float,
cutoff: float,
dtype: torch.dtype,
device: torch.device | None,
) -> torch.Tensor:
"""Build a deterministic midpoint grid over the full-support interval."""
if n_samples <= 0:
raise ValueError("`n_samples` must be positive")
step = (cutoff - full_support_start) / float(n_samples)
return full_support_start + step * (
torch.arange(n_samples, dtype=dtype, device=device) + 0.5
)
def _normalize_weights(
    weights: object | None,
    *,
    name: str,
    distances: torch.Tensor,
    dtype: torch.dtype,
    device: torch.device | None,
) -> torch.Tensor | None:
    """Validate optional non-negative projection weights.

    Args:
        weights: Array-like weights, or ``None`` for an unweighted fit.
        name: Argument name used in error messages.
        distances: Radii the weights pair with. It fixes the expected shape and,
            when ``device`` is ``None``, the device of the returned tensor.
        dtype: Floating-point dtype of the returned tensor.
        device: Explicit target device, or ``None`` to follow ``distances``.

    Returns:
        A flat tensor shaped like ``distances``, or ``None`` when ``weights`` is
        ``None``.

    Raises:
        ValueError: If the weights are empty or non-finite, have a different
            shape from ``distances``, contain negative entries, or are all zero.
    """
    if weights is None:
        return None
    # The weights are multiplied into tensors derived from ``distances``. Without
    # an explicit device, array-like weights would otherwise land on the default
    # device and mismatch e.g. CUDA-resident radii.
    target_device = distances.device if device is None else device
    tensor = _as_1d_tensor(weights, name=name, dtype=dtype, device=target_device)
    if tuple(tensor.shape) != tuple(distances.shape):
        raise ValueError(f"`{name}` must have the same shape as the distances")
    if torch.any(tensor < 0.0):
        raise ValueError(f"`{name}` must be non-negative")
    if not bool(torch.any(tensor > 0.0)):
        raise ValueError(f"`{name}` must contain at least one positive weight")
    return tensor
def _evaluate_function(
function: Callable[[torch.Tensor], object],
distances: torch.Tensor,
*,
dtype: torch.dtype,
) -> torch.Tensor:
"""Evaluate a radial callable and validate its vectorized output shape."""
values = torch.as_tensor(
function(distances),
dtype=dtype,
device=distances.device,
)
if values.ndim == 0 and distances.numel() == 1:
values = values.reshape(1)
else:
values = values.reshape(-1)
if tuple(values.shape) != tuple(distances.shape):
raise ValueError("`function` must return values with the same shape as input")
if not torch.all(torch.isfinite(values)):
raise ValueError("`function` must return only finite values")
return values
def _row_matrix_from_stencil(
indices: torch.Tensor,
values: torch.Tensor,
*,
coeff_size: int,
) -> torch.Tensor:
"""Materialize a dense design matrix from local 1D spline stencil rows."""
matrix = values.new_zeros((values.shape[0], int(coeff_size)))
matrix.scatter_add_(1, indices, values)
return matrix
def _design_matrix(
    distances: torch.Tensor,
    *,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
    spline: str,
) -> torch.Tensor:
    """Build the dense spline basis matrix (``coeff_size`` columns) for ``distances``.

    Rows follow the stencil returned by ``uniform_stencil_1d`` (presumably one
    row per distance; confirm against that helper).
    """
    local_stencil = uniform_stencil_1d(
        distances,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
    )
    column_indices = local_stencil.indices
    basis_values = local_stencil.values
    return _row_matrix_from_stencil(column_indices, basis_values, coeff_size=coeff_size)
def _as_float(value: torch.Tensor) -> float:
"""Return one scalar tensor as a Python float."""
return float(value.detach().cpu().item())
def _weighted_rmse(
    residuals: torch.Tensor,
    weights: torch.Tensor | None,
) -> float | None:
    """Return ``sqrt(sum(w * r**2) / sum(w))``, or ``None`` when unweighted."""
    if weights is not None:
        weight_total = weights.sum()
        weighted_square_sum = (weights * residuals.square()).sum()
        return _as_float((weighted_square_sum / weight_total).sqrt())
    return None
def _error_summary(residuals: torch.Tensor) -> ProjectionErrorSummary:
    """Summarize residuals by handing a flat Python list to ``ProjectionErrorSummary``."""
    residual_list = residuals.detach().cpu().reshape(-1).tolist()
    return ProjectionErrorSummary.from_residuals(residual_list)
def _support_coverage(
    distances: torch.Tensor,
    *,
    full_support_start: float,
    cutoff: float,
) -> ProjectionSupportCoverage:
    """Describe how the sample radii cover ``[full_support_start, cutoff]``."""
    sample_list = distances.detach().cpu().reshape(-1).tolist()
    lower = float(full_support_start)
    upper = float(cutoff)
    return ProjectionSupportCoverage.from_samples(
        sample_list,
        lower_bound=lower,
        upper_bound=upper,
    )
def _projection_diagnostics(
    *,
    channel_label: str,
    distances: torch.Tensor,
    residuals: torch.Tensor,
    full_support_start: float,
    cutoff: float,
) -> ProjectionDiagnostics:
    """Package one channel's sample count, coverage and error summary as diagnostics."""
    sample_count = int(distances.numel())
    coverage = _support_coverage(
        distances,
        full_support_start=full_support_start,
        cutoff=cutoff,
    )
    value_error = _error_summary(residuals)
    channel = ProjectionChannelDiagnostic(
        channel_label=channel_label,
        sample_count=sample_count,
        support_coverage=coverage,
        value_error=value_error,
    )
    return ProjectionDiagnostics((channel,))
def _diagnostics(
    function: Callable[[torch.Tensor], object],
    coeffs: torch.Tensor,
    *,
    distances: torch.Tensor,
    weights: torch.Tensor | None,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
    spline: str,
    dtype: torch.dtype,
    full_support_start: float,
    cutoff: float,
    channel_label: str,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    float | None,
    ProjectionDiagnostics,
]:
    """Compare the fitted spline against ``function`` on one radial grid.

    Returns:
        ``(targets, predictions, residuals, weighted_rmse, diagnostics)``. The
        three tensors are detached and ``residuals = predictions - targets``.
        ``weighted_rmse`` is ``None`` when ``weights`` is ``None``.
    """
    targets = _evaluate_function(function, distances, dtype=dtype).detach()
    basis = _design_matrix(
        distances,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
    )
    predictions = basis @ coeffs
    residuals = (predictions - targets).detach()
    rmse = _weighted_rmse(residuals, weights)
    summary = _projection_diagnostics(
        channel_label=channel_label,
        distances=distances,
        residuals=residuals,
        full_support_start=full_support_start,
        cutoff=cutoff,
    )
    return targets, predictions.detach(), residuals, rmse, summary
[docs]
def _resolve_fit_distances(
    sample_distances: object | None,
    *,
    n_samples: int | None,
    coeff_size: int,
    full_support_start: float,
    cutoff: float,
    dtype: torch.dtype,
    device: torch.device | None,
) -> torch.Tensor:
    """Return validated explicit fitting radii, or a default midpoint grid."""
    if sample_distances is None:
        # Default sample count: at least four samples per coefficient and never
        # fewer than ``coeff_size + 8``.
        resolved_n_samples = (
            max(4 * coeff_size, coeff_size + 8) if n_samples is None else int(n_samples)
        )
        return _midpoint_grid(
            n_samples=resolved_n_samples,
            full_support_start=full_support_start,
            cutoff=cutoff,
            dtype=dtype,
            device=device,
        )
    distances = _as_1d_tensor(
        sample_distances,
        name="sample_distances",
        dtype=dtype,
        device=device,
    )
    _validate_distances(
        distances,
        name="sample_distances",
        full_support_start=full_support_start,
        cutoff=cutoff,
    )
    return distances


def _solve_weighted_least_squares(
    function: Callable[[torch.Tensor], object],
    distances: torch.Tensor,
    weights: torch.Tensor | None,
    *,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
    spline: str,
    dtype: torch.dtype,
    rcond: float | None,
) -> tuple[torch.Tensor, int]:
    """Fit spline coefficients to ``function`` samples; return ``(coeffs, rank)``."""
    target_values = _evaluate_function(function, distances, dtype=dtype)
    design = _design_matrix(
        distances,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
    )
    if weights is not None:
        # Scaling rows by sqrt(w) turns minimizing sum_i w_i (A_i c - y_i)^2
        # into an ordinary least-squares problem.
        sqrt_weights = torch.sqrt(weights)
        design = design * sqrt_weights[:, None]
        target_values = target_values * sqrt_weights
    rank = int(torch.linalg.matrix_rank(design).detach().cpu().item())
    if rank < coeff_size:
        raise ValueError(
            "weighted spline design matrix is rank deficient; provide more "
            "or better distributed samples"
        )
    coeffs = torch.linalg.lstsq(design, target_values, rcond=rcond).solution.detach()
    return coeffs, rank


def _resolve_diagnostic_samples(
    diagnostic_distances: object | None,
    diagnostic_weights: object | None,
    *,
    distances: torch.Tensor,
    weights: torch.Tensor | None,
    full_support_start: float,
    cutoff: float,
    dtype: torch.dtype,
    device: torch.device | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Pick diagnostic radii and weights, defaulting to the fitting samples."""
    if diagnostic_distances is None:
        diag_distances = distances.detach().clone()
        if diagnostic_weights is None and weights is None:
            return diag_distances, None
        # Reusing the fitting radii: fall back to the fitting weights.
        diag_weights = _normalize_weights(
            weights if diagnostic_weights is None else diagnostic_weights,
            name="diagnostic_weights",
            distances=diag_distances,
            dtype=dtype,
            device=device,
        )
        return diag_distances, diag_weights
    diag_distances = _as_1d_tensor(
        diagnostic_distances,
        name="diagnostic_distances",
        dtype=dtype,
        device=device,
    )
    _validate_distances(
        diag_distances,
        name="diagnostic_distances",
        full_support_start=full_support_start,
        cutoff=cutoff,
    )
    # Separate diagnostic radii never inherit the fitting weights.
    diag_weights = _normalize_weights(
        diagnostic_weights,
        name="diagnostic_weights",
        distances=diag_distances,
        dtype=dtype,
        device=device,
    )
    return diag_distances, diag_weights


def project_radial_function(
    function: Callable[[torch.Tensor], object],
    *,
    coeff_size: int,
    full_support_start: float,
    cutoff: float,
    spline: str = "cubic",
    n_samples: int | None = None,
    sample_distances: object | None = None,
    sample_weights: object | None = None,
    diagnostic_distances: object | None = None,
    diagnostic_weights: object | None = None,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
    rcond: float | None = None,
    channel_label: str = "radial",
) -> RadialProjectionResult:
    """
    Project a radial callable onto a uniform 1D pair-spline coefficient row.

    Args:
        function: Callable accepting a distance tensor and returning values with
            the same shape.
        coeff_size: Number of 1D spline coefficients in the target row.
        full_support_start: Lower distance with full spline support.
        cutoff: Upper distance with full spline support and pair-term cutoff.
        spline: Spline family name.
        n_samples: Number of deterministic midpoint samples when
            ``sample_distances`` is not supplied. Defaults to
            ``max(4 * coeff_size, coeff_size + 8)``. Ignored when
            ``sample_distances`` is given.
        sample_distances: Optional explicit fitting radii in
            ``[full_support_start, cutoff)``.
        sample_weights: Optional non-negative least-squares row weights. Rows
            are scaled by ``sqrt(weight)`` before solving.
        diagnostic_distances: Optional radii for error diagnostics. Defaults to
            the fitting radii.
        diagnostic_weights: Optional non-negative diagnostic weights. Defaults to
            ``sample_weights`` when diagnostics reuse the fitting radii.
        dtype: Floating-point dtype for projection tensors (torch default when
            omitted).
        device: Device for projection tensors. When omitted, weights and
            diagnostic radii follow the device of the fitting radii.
        rcond: Optional cutoff passed to :func:`torch.linalg.lstsq`.
        channel_label: Diagnostic label for the projected coefficient channel.

    Returns:
        Projected coefficients and diagnostic data; ``metadata`` is ``None``.

    Raises:
        TypeError: If ``function`` is not callable.
        ValueError: For an invalid interval or dtype; out-of-range, empty or
            non-finite samples; invalid weights; fewer effective samples than
            ``coeff_size``; a rank-deficient design; or malformed ``function``
            output.
    """
    if not callable(function):
        raise TypeError("`function` must be callable")
    coeff_size = int(coeff_size)
    full_support_start = float(full_support_start)
    cutoff = float(cutoff)
    _validate_interval(full_support_start, cutoff)
    resolved_dtype = _resolve_dtype(dtype)
    resolved_device = None if device is None else torch.device(device)
    first_knot, knot_spacing = uniform_support_parameters(
        coeff_size=coeff_size,
        lower_full_support=full_support_start,
        upper_full_support=cutoff,
        spline=spline,
    )
    distances = _resolve_fit_distances(
        sample_distances,
        n_samples=n_samples,
        coeff_size=coeff_size,
        full_support_start=full_support_start,
        cutoff=cutoff,
        dtype=resolved_dtype,
        device=resolved_device,
    )
    if resolved_device is None:
        # Anchor every later tensor (weights, diagnostic radii) to the fitting
        # grid. Otherwise e.g. a CUDA ``sample_distances`` tensor with list-valued
        # weights or diagnostic radii mixes devices in the weighting and matmuls.
        resolved_device = distances.device
    weights = _normalize_weights(
        sample_weights,
        name="sample_weights",
        distances=distances,
        dtype=resolved_dtype,
        device=resolved_device,
    )
    positive_sample_count = (
        int(distances.numel())
        if weights is None
        else int(torch.count_nonzero(weights > 0.0).item())
    )
    if positive_sample_count < coeff_size:
        raise ValueError(
            "`sample_distances` and positive `sample_weights` must provide at "
            f"least `coeff_size` effective samples ({coeff_size})"
        )
    coeffs, rank = _solve_weighted_least_squares(
        function,
        distances,
        weights,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
        dtype=resolved_dtype,
        rcond=rcond,
    )
    diag_distances, diag_weights = _resolve_diagnostic_samples(
        diagnostic_distances,
        diagnostic_weights,
        distances=distances,
        weights=weights,
        full_support_start=full_support_start,
        cutoff=cutoff,
        dtype=resolved_dtype,
        device=resolved_device,
    )
    (
        diagnostic_target_values,
        diagnostic_predicted_values,
        diagnostic_residuals,
        weighted_rmse,
        diagnostics,
    ) = _diagnostics(
        function,
        coeffs,
        distances=diag_distances,
        weights=diag_weights,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
        dtype=resolved_dtype,
        full_support_start=full_support_start,
        cutoff=cutoff,
        channel_label=str(channel_label),
    )
    return RadialProjectionResult(
        coeffs=coeffs,
        spline=str(spline),
        coeff_size=coeff_size,
        full_support_start=full_support_start,
        cutoff=cutoff,
        first_knot=float(first_knot),
        knot_spacing=float(knot_spacing),
        sample_distances=distances.detach().clone(),
        sample_weights=None if weights is None else weights.detach().clone(),
        design_rank=rank,
        diagnostic_distances=diag_distances.detach().clone(),
        diagnostic_weights=(
            None if diag_weights is None else diag_weights.detach().clone()
        ),
        target_values=diagnostic_target_values,
        predicted_values=diagnostic_predicted_values,
        residuals=diagnostic_residuals,
        weighted_rmse=weighted_rmse,
        diagnostics=diagnostics,
    )
def _pair_tuple(pair: Sequence[int]) -> tuple[int, int]:
"""Normalize a pair key for projection metadata."""
if len(pair) != 2:
raise ValueError("`pair` must contain exactly two atomic numbers")
return int(pair[0]), int(pair[1])
def _iter_module_tensors(module: object):
"""Yield floating tensors from an optional torch module-like object."""
parameters = getattr(module, "parameters", None)
if callable(parameters):
yield from parameters()
buffers = getattr(module, "buffers", None)
if callable(buffers):
yield from buffers()
def _prior_tensor_options(
    prior: object,
    *,
    dtype: torch.dtype | None,
    device: torch.device | str | None,
) -> tuple[torch.dtype | None, torch.device | None]:
    """Fill in a missing dtype/device from the prior's first floating-point tensor.

    Explicit ``dtype``/``device`` arguments always win. Values stay ``None``
    when neither an argument nor a floating-point parameter/buffer supplies
    them.
    """
    explicit_device = None if device is None else torch.device(device)
    first_floating = next(
        (
            candidate
            for candidate in _iter_module_tensors(prior)
            if isinstance(candidate, torch.Tensor) and candidate.is_floating_point()
        ),
        None,
    )
    if first_floating is None:
        return dtype, explicit_device
    return (
        first_floating.dtype if dtype is None else dtype,
        first_floating.device if explicit_device is None else explicit_device,
    )
def _prior_projection_metadata(prior: object) -> Mapping[str, object]:
"""Return projection metadata from an analytic prior when available."""
metadata = {
"class": f"{type(prior).__module__}.{type(prior).__qualname__}",
}
projection_metadata = getattr(prior, "projection_metadata", None)
if callable(projection_metadata):
supplied = projection_metadata()
if not isinstance(supplied, Mapping):
raise TypeError("`prior.projection_metadata()` must return a mapping")
metadata.update(dict(supplied))
return metadata
[docs]
def project_pair_prior(
    prior: object,
    pair: Sequence[int],
    *,
    coeff_size: int,
    full_support_start: float = 0.0,
    cutoff: float | None = None,
    spline: str = "cubic",
    n_samples: int | None = None,
    sample_distances: object | None = None,
    sample_weights: object | None = None,
    diagnostic_distances: object | None = None,
    diagnostic_weights: object | None = None,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
    rcond: float | None = None,
    channel_label: str | None = None,
) -> RadialProjectionResult:
    """
    Project one analytic pair prior channel onto a 1D spline coefficient row.

    The prior must expose ``radial_values(pair, distances)``. If it also exposes
    ``projection_metadata()``, that metadata is attached to the returned result
    alongside the projection grid settings and pair channel.

    Args:
        prior: Object exposing ``radial_values(pair, distances)`` and optionally
            ``cutoff``, ``projection_metadata()``, ``parameters()``, ``buffers()``.
        pair: Two atomic numbers selecting the channel.
        coeff_size: Number of 1D spline coefficients in the target row.
        full_support_start: Lower distance with full spline support.
        cutoff: Upper projection bound. Falls back to ``prior.cutoff``.
        dtype: Projection dtype. When omitted, taken from the prior's first
            floating-point parameter or buffer, if any.
        device: Projection device, resolved like ``dtype``.
        channel_label: Diagnostic label. Defaults to ``"pair[(z1, z2)]"``.
        spline, n_samples, sample_distances, sample_weights,
        diagnostic_distances, diagnostic_weights, rcond: Forwarded unchanged to
            ``project_radial_function``.

    Returns:
        The projection result with ``metadata`` holding ``"source"``,
        ``"pair"``, ``"channel_label"``, ``"prior"`` and ``"projection"`` keys.

    Raises:
        TypeError: If ``prior`` lacks a callable ``radial_values`` or its
            ``projection_metadata()`` returns a non-mapping.
        ValueError: If ``pair`` is not a pair, no cutoff is available, or the
            projection itself rejects its inputs.
    """
    radial_values = getattr(prior, "radial_values", None)
    if not callable(radial_values):
        raise TypeError("`prior` must expose callable `radial_values(pair, distances)`")
    resolved_pair = _pair_tuple(pair)
    resolved_cutoff = cutoff
    if resolved_cutoff is None:
        resolved_cutoff = getattr(prior, "cutoff", None)
        if resolved_cutoff is None:
            raise ValueError("`cutoff` is required when `prior.cutoff` is not set")
    resolved_dtype, resolved_device = _prior_tensor_options(
        prior,
        dtype=dtype,
        device=device,
    )

    def _channel_values(distances: torch.Tensor) -> object:
        # Bind the resolved pair so the projector sees a plain radial callable.
        return radial_values(resolved_pair, distances)

    if channel_label is None:
        resolved_label = f"pair[{resolved_pair!r}]"
    else:
        resolved_label = str(channel_label)
    result = project_radial_function(
        _channel_values,
        coeff_size=coeff_size,
        full_support_start=full_support_start,
        cutoff=float(resolved_cutoff),
        spline=spline,
        n_samples=n_samples,
        sample_distances=sample_distances,
        sample_weights=sample_weights,
        diagnostic_distances=diagnostic_distances,
        diagnostic_weights=diagnostic_weights,
        dtype=resolved_dtype,
        device=resolved_device,
        rcond=rcond,
        channel_label=resolved_label,
    )
    prior_metadata = _prior_projection_metadata(prior)
    projection_settings = {
        "coeff_size": result.coeff_size,
        "spline": result.spline,
        "full_support_start": result.full_support_start,
        "cutoff": result.cutoff,
        "first_knot": result.first_knot,
        "knot_spacing": result.knot_spacing,
    }
    return replace(
        result,
        metadata={
            "source": "analytic_pair_prior",
            "pair": resolved_pair,
            "channel_label": resolved_label,
            "prior": prior_metadata,
            "projection": projection_settings,
        },
    )
# Public entry points of this module: the result container and both projectors.
__all__ = ["RadialProjectionResult", "project_pair_prior", "project_radial_function"]