# Source code for ufp.projection.radial

"""Offline projection of radial callables onto 1D spline coefficient rows."""

from __future__ import annotations

import math
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass, replace

import torch

from ufp.projection.diagnostics import (
    ProjectionChannelDiagnostic,
    ProjectionDiagnostics,
    ProjectionErrorSummary,
    ProjectionSupportCoverage,
)
from ufp.splines.representation import uniform_stencil_1d, uniform_support_parameters


@dataclass(frozen=True)
class RadialProjectionResult:
    """Immutable bundle of a radial spline projection and its diagnostics.

    Instances are returned by ``project_radial_function``; ``project_pair_prior``
    additionally fills in ``metadata`` via :func:`dataclasses.replace`.

    Attributes:
        coeffs: Least-squares spline coefficients (detached from autograd).
        spline: Spline family name used for the basis.
        coeff_size: Number of coefficients in the projected row.
        full_support_start: Inclusive lower radius of the fitted interval.
        cutoff: Exclusive upper radius of the fitted interval.
        first_knot: First knot location as returned by
            ``uniform_support_parameters``.
        knot_spacing: Knot spacing as returned by ``uniform_support_parameters``.
        sample_distances: Radii used for the least-squares fit.
        sample_weights: Fit weights, or ``None`` for an unweighted fit.
        design_rank: Numerical rank of the (weighted) fit design matrix.
        diagnostic_distances: Radii at which the fit was re-evaluated.
        diagnostic_weights: Weights behind ``weighted_rmse``, or ``None``.
        target_values: Callable values at ``diagnostic_distances``.
        predicted_values: Spline values at ``diagnostic_distances``.
        residuals: ``predicted_values - target_values``.
        weighted_rmse: Weighted RMSE of ``residuals``; ``None`` when no
            diagnostic weights are available.
        diagnostics: Structured per-channel projection diagnostics.
        metadata: Optional provenance mapping (``None`` by default).
    """

    # Fitted coefficient row and the uniform spline grid it belongs to.
    coeffs: torch.Tensor
    spline: str
    coeff_size: int
    full_support_start: float
    cutoff: float
    first_knot: float
    knot_spacing: float
    # Fitting grid, its optional least-squares weights, and the design rank.
    sample_distances: torch.Tensor
    sample_weights: torch.Tensor | None
    design_rank: int
    # Diagnostic grid and the errors evaluated on it.
    diagnostic_distances: torch.Tensor
    diagnostic_weights: torch.Tensor | None
    target_values: torch.Tensor
    predicted_values: torch.Tensor
    residuals: torch.Tensor
    weighted_rmse: float | None
    diagnostics: ProjectionDiagnostics
    # Provenance attached after construction (e.g. by project_pair_prior).
    metadata: Mapping[str, object] | None = None
def _resolve_dtype(dtype: torch.dtype | None) -> torch.dtype:
    """Return the floating-point dtype used for all projection tensors.

    Falls back to :func:`torch.get_default_dtype` when ``dtype`` is ``None``.

    Raises:
        ValueError: If the resolved dtype is not a floating-point dtype.
    """
    if dtype is None:
        chosen = torch.get_default_dtype()
    else:
        chosen = dtype
    # A zero-dim probe tensor lets torch decide whether the dtype is floating.
    if torch.empty((), dtype=chosen).is_floating_point():
        return chosen
    raise ValueError("`dtype` must be a floating-point torch dtype")


def _as_1d_tensor(
    value: object,
    *,
    name: str,
    dtype: torch.dtype,
    device: torch.device | None,
) -> torch.Tensor:
    """Flatten an array-like into a non-empty, finite 1D tensor.

    Scalars become length-one tensors; higher-rank inputs are flattened.

    Raises:
        ValueError: If the flattened tensor is empty or holds NaN/inf.
    """
    flat = torch.as_tensor(value, dtype=dtype, device=device)
    flat = flat.reshape(1) if flat.ndim == 0 else flat.reshape(-1)
    if flat.numel() == 0:
        raise ValueError(f"`{name}` must contain at least one value")
    if not bool(torch.isfinite(flat).all()):
        raise ValueError(f"`{name}` must contain only finite values")
    return flat


def _validate_interval(full_support_start: float, cutoff: float) -> None:
    """Check that ``[full_support_start, cutoff)`` is finite and non-empty.

    Raises:
        ValueError: If either bound is not finite, or ``cutoff`` does not
            exceed ``full_support_start``.
    """
    if not math.isfinite(full_support_start):
        raise ValueError("`full_support_start` must be finite")
    if not math.isfinite(cutoff):
        raise ValueError("`cutoff` must be finite")
    if not full_support_start < cutoff:
        raise ValueError("`cutoff` must be greater than `full_support_start`")


def _validate_distances(
    distances: torch.Tensor,
    *,
    name: str,
    full_support_start: float,
    cutoff: float,
) -> None:
    """Reject any radius outside the half-open interval ``[start, cutoff)``.

    Raises:
        ValueError: If a radius is below ``full_support_start`` or at/above
            ``cutoff``.
    """
    outside = (distances < full_support_start) | (distances >= cutoff)
    if bool(outside.any()):
        raise ValueError(f"`{name}` must lie inside [full_support_start, cutoff)")


def _midpoint_grid(
    *,
    n_samples: int,
    full_support_start: float,
    cutoff: float,
    dtype: torch.dtype,
    device: torch.device | None,
) -> torch.Tensor:
    """Return the centres of ``n_samples`` equal cells spanning the interval.

    Raises:
        ValueError: If ``n_samples`` is not positive.
    """
    if n_samples <= 0:
        raise ValueError("`n_samples` must be positive")
    step = (cutoff - full_support_start) / float(n_samples)
    # Offsetting by half a cell keeps every sample strictly below ``cutoff``.
    cell_centres = torch.arange(n_samples, dtype=dtype, device=device) + 0.5
    return full_support_start + step * cell_centres


def _normalize_weights(
    weights: object | None,
    *,
    name: str,
    distances: torch.Tensor,
    dtype: torch.dtype,
    device: torch.device | None,
) -> torch.Tensor | None:
    """Validate optional per-sample weights against their distances.

    Returns ``None`` unchanged; otherwise a flat tensor matching ``distances``.

    Raises:
        ValueError: On a shape mismatch, any negative weight, or no positive
            weight at all (plus the checks of ``_as_1d_tensor``).
    """
    if weights is None:
        return None
    weight_tensor = _as_1d_tensor(weights, name=name, dtype=dtype, device=device)
    if weight_tensor.shape != distances.shape:
        raise ValueError(f"`{name}` must have the same shape as the distances")
    if bool((weight_tensor < 0.0).any()):
        raise ValueError(f"`{name}` must be non-negative")
    if not bool((weight_tensor > 0.0).any()):
        raise ValueError(f"`{name}` must contain at least one positive weight")
    return weight_tensor


def _evaluate_function(
    function: Callable[[torch.Tensor], object],
    distances: torch.Tensor,
    *,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Call a radial function on ``distances`` and check its output.

    The output is cast to ``dtype`` on the distances' device and flattened; a
    scalar is accepted only when there is a single distance.

    Raises:
        ValueError: If the flattened output shape differs from ``distances``
            or contains non-finite values.
    """
    raw = function(distances)
    values = torch.as_tensor(raw, dtype=dtype, device=distances.device)
    single_scalar = values.ndim == 0 and distances.numel() == 1
    values = values.reshape(1) if single_scalar else values.reshape(-1)
    if values.shape != distances.shape:
        raise ValueError("`function` must return values with the same shape as input")
    if not bool(torch.isfinite(values).all()):
        raise ValueError("`function` must return only finite values")
    return values


def _row_matrix_from_stencil(
    indices: torch.Tensor,
    values: torch.Tensor,
    *,
    coeff_size: int,
) -> torch.Tensor:
    """Scatter local stencil rows into a dense ``(rows, coeff_size)`` matrix."""
    rows = values.shape[0]
    dense = values.new_zeros((rows, int(coeff_size)))
    # scatter_add_ accumulates, so repeated column indices in a row are summed.
    dense.scatter_add_(1, indices, values)
    return dense


def _design_matrix(
    distances: torch.Tensor,
    *,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
    spline: str,
) -> torch.Tensor:
    """Build the dense spline basis matrix for a set of radii."""
    local = uniform_stencil_1d(
        distances,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
    )
    return _row_matrix_from_stencil(
        local.indices,
        local.values,
        coeff_size=coeff_size,
    )


def _as_float(value: torch.Tensor) -> float:
    """Convert a single-element tensor into a Python float."""
    host_value = value.detach().cpu()
    return float(host_value.item())


def _weighted_rmse(
    residuals: torch.Tensor,
    weights: torch.Tensor | None,
) -> float | None:
    """Return ``sqrt(sum(w * r**2) / sum(w))``, or ``None`` without weights."""
    if weights is None:
        return None
    mean_square = (weights * residuals.square()).sum() / weights.sum()
    return _as_float(mean_square.sqrt())


def _error_summary(residuals: torch.Tensor) -> ProjectionErrorSummary:
    """Summarize residuals as a ``ProjectionErrorSummary``."""
    flat_residuals = residuals.detach().cpu().reshape(-1).tolist()
    return ProjectionErrorSummary.from_residuals(flat_residuals)


def _support_coverage(
    distances: torch.Tensor,
    *,
    full_support_start: float,
    cutoff: float,
) -> ProjectionSupportCoverage:
    """Describe how the diagnostic radii cover ``[start, cutoff]``."""
    samples = distances.detach().cpu().reshape(-1).tolist()
    return ProjectionSupportCoverage.from_samples(
        samples,
        lower_bound=float(full_support_start),
        upper_bound=float(cutoff),
    )


def _projection_diagnostics(
    *,
    channel_label: str,
    distances: torch.Tensor,
    residuals: torch.Tensor,
    full_support_start: float,
    cutoff: float,
) -> ProjectionDiagnostics:
    """Wrap one radial channel's diagnostics in a ``ProjectionDiagnostics``."""
    channel = ProjectionChannelDiagnostic(
        channel_label=channel_label,
        sample_count=int(distances.numel()),
        support_coverage=_support_coverage(
            distances,
            full_support_start=full_support_start,
            cutoff=cutoff,
        ),
        value_error=_error_summary(residuals),
    )
    return ProjectionDiagnostics((channel,))


def _diagnostics(
    function: Callable[[torch.Tensor], object],
    coeffs: torch.Tensor,
    *,
    distances: torch.Tensor,
    weights: torch.Tensor | None,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
    spline: str,
    dtype: torch.dtype,
    full_support_start: float,
    cutoff: float,
    channel_label: str,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    float | None,
    ProjectionDiagnostics,
]:
    """Re-evaluate fitted coefficients against the function on one grid.

    Returns:
        ``(targets, predictions, residuals, weighted_rmse, diagnostics)``,
        with all tensors detached; residuals are ``predictions - targets``.
    """
    targets = _evaluate_function(function, distances, dtype=dtype).detach()
    basis = _design_matrix(
        distances,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
    )
    predicted = basis @ coeffs
    residuals = predicted - targets
    frozen_residuals = residuals.detach()
    rmse = _weighted_rmse(residuals, weights)
    summary = _projection_diagnostics(
        channel_label=channel_label,
        distances=distances,
        residuals=frozen_residuals,
        full_support_start=full_support_start,
        cutoff=cutoff,
    )
    return targets.detach(), predicted.detach(), frozen_residuals, rmse, summary
def _fitting_distances(
    *,
    n_samples: int | None,
    sample_distances: object | None,
    coeff_size: int,
    full_support_start: float,
    cutoff: float,
    dtype: torch.dtype,
    device: torch.device | None,
) -> torch.Tensor:
    """Return validated fitting radii for ``project_radial_function``.

    Explicit ``sample_distances`` take precedence (``n_samples`` is then
    ignored); otherwise a midpoint grid of ``n_samples`` points is built,
    defaulting to ``max(4 * coeff_size, coeff_size + 8)`` points.
    """
    if sample_distances is not None:
        radii = _as_1d_tensor(
            sample_distances,
            name="sample_distances",
            dtype=dtype,
            device=device,
        )
    else:
        if n_samples is None:
            grid_size = max(4 * coeff_size, coeff_size + 8)
        else:
            grid_size = int(n_samples)
        radii = _midpoint_grid(
            n_samples=grid_size,
            full_support_start=full_support_start,
            cutoff=cutoff,
            dtype=dtype,
            device=device,
        )
    _validate_distances(
        radii,
        name="sample_distances",
        full_support_start=full_support_start,
        cutoff=cutoff,
    )
    return radii


def _diagnostic_inputs(
    *,
    fit_distances: torch.Tensor,
    fit_weights: torch.Tensor | None,
    diagnostic_distances: object | None,
    diagnostic_weights: object | None,
    full_support_start: float,
    cutoff: float,
    dtype: torch.dtype,
    device: torch.device | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Return the diagnostic radii and weights for ``project_radial_function``.

    Without explicit ``diagnostic_distances`` the fitting radii are reused (as
    a detached copy) and ``diagnostic_weights`` falls back to the fitting
    weights. With explicit radii, only ``diagnostic_weights`` is used.
    """
    if diagnostic_distances is not None:
        radii = _as_1d_tensor(
            diagnostic_distances,
            name="diagnostic_distances",
            dtype=dtype,
            device=device,
        )
        _validate_distances(
            radii,
            name="diagnostic_distances",
            full_support_start=full_support_start,
            cutoff=cutoff,
        )
        explicit_weights = _normalize_weights(
            diagnostic_weights,
            name="diagnostic_weights",
            distances=radii,
            dtype=dtype,
            device=device,
        )
        return radii, explicit_weights

    radii = fit_distances.detach().clone()
    candidate = fit_weights if diagnostic_weights is None else diagnostic_weights
    if candidate is None:
        return radii, None
    return radii, _normalize_weights(
        candidate,
        name="diagnostic_weights",
        distances=radii,
        dtype=dtype,
        device=device,
    )


def project_radial_function(
    function: Callable[[torch.Tensor], object],
    *,
    coeff_size: int,
    full_support_start: float,
    cutoff: float,
    spline: str = "cubic",
    n_samples: int | None = None,
    sample_distances: object | None = None,
    sample_weights: object | None = None,
    diagnostic_distances: object | None = None,
    diagnostic_weights: object | None = None,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
    rcond: float | None = None,
    channel_label: str = "radial",
) -> RadialProjectionResult:
    """Least-squares fit of a radial callable onto one uniform 1D spline row.

    The callable is sampled on a fitting grid (explicit ``sample_distances``
    or a midpoint grid over ``[full_support_start, cutoff)``). The optionally
    weighted spline design matrix is solved with :func:`torch.linalg.lstsq`,
    and the fit is then re-evaluated on a diagnostic grid to report errors.

    Args:
        function: Vectorized callable mapping a distance tensor to values of
            the same shape.
        coeff_size: Length of the target coefficient row.
        full_support_start: Inclusive lower radius of the fitted interval.
        cutoff: Exclusive upper radius of the fitted interval (pair cutoff).
        spline: Spline family name forwarded to the spline helpers.
        n_samples: Midpoint-grid size when ``sample_distances`` is omitted;
            defaults to ``max(4 * coeff_size, coeff_size + 8)``.
        sample_distances: Explicit fitting radii inside the interval.
        sample_weights: Non-negative per-sample least-squares weights.
        diagnostic_distances: Radii for error diagnostics; defaults to a copy
            of the fitting radii.
        diagnostic_weights: Non-negative diagnostic weights; when
            ``diagnostic_distances`` is omitted they default to
            ``sample_weights``.
        dtype: Floating-point dtype; defaults to torch's default dtype.
        device: Device for all projection tensors.
        rcond: Optional ``rcond`` forwarded to :func:`torch.linalg.lstsq`.
        channel_label: Label recorded in the diagnostics.

    Returns:
        A ``RadialProjectionResult`` holding detached coefficients, both
        sample grids and the diagnostic errors. Its ``weighted_rmse`` is
        ``None`` when no diagnostic weights apply.

    Raises:
        TypeError: If ``function`` is not callable.
        ValueError: On an invalid interval, dtype, grid or weights, fewer than
            ``coeff_size`` positively weighted samples, a rank-deficient
            weighted design matrix, or mis-shaped / non-finite callable output.
    """
    if not callable(function):
        raise TypeError("`function` must be callable")
    coeff_size = int(coeff_size)
    full_support_start = float(full_support_start)
    cutoff = float(cutoff)
    _validate_interval(full_support_start, cutoff)
    resolved_dtype = _resolve_dtype(dtype)
    resolved_device = torch.device(device) if device is not None else None
    first_knot, knot_spacing = uniform_support_parameters(
        coeff_size=coeff_size,
        lower_full_support=full_support_start,
        upper_full_support=cutoff,
        spline=spline,
    )

    # --- Fitting grid and weights -------------------------------------------
    distances = _fitting_distances(
        n_samples=n_samples,
        sample_distances=sample_distances,
        coeff_size=coeff_size,
        full_support_start=full_support_start,
        cutoff=cutoff,
        dtype=resolved_dtype,
        device=resolved_device,
    )
    weights = _normalize_weights(
        sample_weights,
        name="sample_weights",
        distances=distances,
        dtype=resolved_dtype,
        device=resolved_device,
    )
    # Zero-weight rows contribute nothing, so only positive weights count.
    if weights is None:
        effective_samples = int(distances.numel())
    else:
        effective_samples = int(torch.count_nonzero(weights > 0.0).item())
    if effective_samples < coeff_size:
        raise ValueError(
            "`sample_distances` and positive `sample_weights` must provide at "
            f"least `coeff_size` effective samples ({coeff_size})"
        )

    # --- Weighted least-squares solve ---------------------------------------
    targets = _evaluate_function(function, distances, dtype=resolved_dtype)
    design = _design_matrix(
        distances,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
    )
    if weights is None:
        system, rhs = design, targets
    else:
        # Scaling rows by sqrt(w) turns weighted least squares into ordinary.
        root_weights = torch.sqrt(weights)
        system = design * root_weights[:, None]
        rhs = targets * root_weights
    rank = int(torch.linalg.matrix_rank(system).detach().cpu().item())
    if rank < coeff_size:
        raise ValueError(
            "weighted spline design matrix is rank deficient; provide more "
            "or better distributed samples"
        )
    coeffs = torch.linalg.lstsq(system, rhs, rcond=rcond).solution.detach()

    # --- Diagnostics ----------------------------------------------------------
    diag_distances, diag_weights = _diagnostic_inputs(
        fit_distances=distances,
        fit_weights=weights,
        diagnostic_distances=diagnostic_distances,
        diagnostic_weights=diagnostic_weights,
        full_support_start=full_support_start,
        cutoff=cutoff,
        dtype=resolved_dtype,
        device=resolved_device,
    )
    (
        diag_targets,
        diag_predicted,
        diag_residuals,
        weighted_rmse,
        diagnostics,
    ) = _diagnostics(
        function,
        coeffs,
        distances=diag_distances,
        weights=diag_weights,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
        dtype=resolved_dtype,
        full_support_start=full_support_start,
        cutoff=cutoff,
        channel_label=str(channel_label),
    )

    stored_sample_weights = None if weights is None else weights.detach().clone()
    stored_diag_weights = (
        None if diag_weights is None else diag_weights.detach().clone()
    )
    return RadialProjectionResult(
        coeffs=coeffs,
        spline=str(spline),
        coeff_size=coeff_size,
        full_support_start=full_support_start,
        cutoff=cutoff,
        first_knot=float(first_knot),
        knot_spacing=float(knot_spacing),
        sample_distances=distances.detach().clone(),
        sample_weights=stored_sample_weights,
        design_rank=rank,
        diagnostic_distances=diag_distances.detach().clone(),
        diagnostic_weights=stored_diag_weights,
        target_values=diag_targets,
        predicted_values=diag_predicted,
        residuals=diag_residuals,
        weighted_rmse=weighted_rmse,
        diagnostics=diagnostics,
    )
def _pair_tuple(pair: Sequence[int]) -> tuple[int, int]:
    """Normalize a two-element pair key into a tuple of Python ints.

    Raises:
        ValueError: If ``pair`` does not have exactly two entries.
    """
    if len(pair) != 2:
        raise ValueError("`pair` must contain exactly two atomic numbers")
    return int(pair[0]), int(pair[1])


def _iter_module_tensors(module: object):
    """Yield every entry of ``module.parameters()`` then ``module.buffers()``.

    Entries are NOT filtered here: callers must check for floating-point
    tensors themselves. Each accessor is skipped when it is missing or not
    callable, so plain objects yield nothing.
    """
    parameters = getattr(module, "parameters", None)
    if callable(parameters):
        yield from parameters()
    buffers = getattr(module, "buffers", None)
    if callable(buffers):
        yield from buffers()


def _prior_tensor_options(
    prior: object,
    *,
    dtype: torch.dtype | None,
    device: torch.device | str | None,
) -> tuple[torch.dtype | None, torch.device | None]:
    """Resolve projection dtype/device from explicit args or prior storage.

    Explicit arguments always win. Any option left as ``None`` is taken from
    the first floating-point tensor among the prior's parameters and buffers;
    it stays ``None`` when the prior exposes no such tensor.
    """
    resolved_dtype = dtype
    resolved_device = None if device is None else torch.device(device)
    if resolved_dtype is not None and resolved_device is not None:
        # Both options are explicit: never touch the prior's storage, which
        # would be wasted work (and could fail for exotic module-like objects).
        return resolved_dtype, resolved_device
    for tensor in _iter_module_tensors(prior):
        if isinstance(tensor, torch.Tensor) and tensor.is_floating_point():
            # The first floating tensor fills every option still unresolved.
            return (
                tensor.dtype if resolved_dtype is None else resolved_dtype,
                tensor.device if resolved_device is None else resolved_device,
            )
    return resolved_dtype, resolved_device


def _prior_projection_metadata(prior: object) -> Mapping[str, object]:
    """Describe ``prior`` for inclusion in projection metadata.

    Always records the prior's fully qualified class under ``"class"``. If the
    prior has a callable ``projection_metadata()``, its mapping is merged on
    top and may override ``"class"``.

    Raises:
        TypeError: If ``projection_metadata()`` returns a non-mapping.
    """
    prior_type = type(prior)
    metadata: dict[str, object] = {
        "class": f"{prior_type.__module__}.{prior_type.__qualname__}",
    }
    projection_metadata = getattr(prior, "projection_metadata", None)
    if callable(projection_metadata):
        supplied = projection_metadata()
        if not isinstance(supplied, Mapping):
            raise TypeError("`prior.projection_metadata()` must return a mapping")
        metadata.update(dict(supplied))
    return metadata
def project_pair_prior(
    prior: object,
    pair: Sequence[int],
    *,
    coeff_size: int,
    full_support_start: float = 0.0,
    cutoff: float | None = None,
    spline: str = "cubic",
    n_samples: int | None = None,
    sample_distances: object | None = None,
    sample_weights: object | None = None,
    diagnostic_distances: object | None = None,
    diagnostic_weights: object | None = None,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
    rcond: float | None = None,
    channel_label: str | None = None,
) -> RadialProjectionResult:
    """Project one pair channel of an analytic prior onto a 1D spline row.

    ``prior`` must provide a callable ``radial_values(pair, distances)``. The
    cutoff defaults to ``prior.cutoff``, and unspecified dtype/device are
    inferred from the prior's parameters/buffers. The returned result carries
    ``metadata`` describing the source, pair, channel label, the prior (its
    class plus anything returned by an optional ``projection_metadata()``),
    and the projection grid settings.

    Args:
        prior: Analytic pair prior exposing ``radial_values``.
        pair: Two atomic numbers selecting the channel.
        coeff_size: Length of the target coefficient row.
        full_support_start: Inclusive lower radius of the fitted interval.
        cutoff: Upper radius; defaults to ``prior.cutoff`` when ``None``.
        spline: Spline family name.
        n_samples: Midpoint-grid size when ``sample_distances`` is omitted.
        sample_distances: Optional explicit fitting radii.
        sample_weights: Optional non-negative least-squares weights.
        diagnostic_distances: Optional radii for error diagnostics.
        diagnostic_weights: Optional non-negative diagnostic weights.
        dtype: Floating-point dtype; inferred from the prior when ``None``.
        device: Device; inferred from the prior when ``None``.
        rcond: Optional ``rcond`` forwarded to :func:`torch.linalg.lstsq`.
        channel_label: Diagnostic label; defaults to ``"pair[(Z1, Z2)]"``.

    Returns:
        The projection result with its ``metadata`` populated.

    Raises:
        TypeError: If ``prior`` lacks a callable ``radial_values`` or its
            ``projection_metadata()`` returns a non-mapping.
        ValueError: If ``pair`` is malformed, no cutoff is available, or the
            projection rejects its inputs.
    """
    radial_values = getattr(prior, "radial_values", None)
    if not callable(radial_values):
        raise TypeError("`prior` must expose callable `radial_values(pair, distances)`")
    resolved_pair = _pair_tuple(pair)

    if cutoff is not None:
        resolved_cutoff = cutoff
    else:
        resolved_cutoff = getattr(prior, "cutoff", None)
    if resolved_cutoff is None:
        raise ValueError("`cutoff` is required when `prior.cutoff` is not set")

    resolved_dtype, resolved_device = _prior_tensor_options(
        prior,
        dtype=dtype,
        device=device,
    )

    def channel(distances: torch.Tensor) -> object:
        # Bind the normalized pair so the projector sees a one-argument callable.
        return radial_values(resolved_pair, distances)

    if channel_label is None:
        resolved_label = f"pair[{resolved_pair!r}]"
    else:
        resolved_label = str(channel_label)

    result = project_radial_function(
        channel,
        coeff_size=coeff_size,
        full_support_start=full_support_start,
        cutoff=float(resolved_cutoff),
        spline=spline,
        n_samples=n_samples,
        sample_distances=sample_distances,
        sample_weights=sample_weights,
        diagnostic_distances=diagnostic_distances,
        diagnostic_weights=diagnostic_weights,
        dtype=resolved_dtype,
        device=resolved_device,
        rcond=rcond,
        channel_label=resolved_label,
    )

    projection_settings = {
        "coeff_size": result.coeff_size,
        "spline": result.spline,
        "full_support_start": result.full_support_start,
        "cutoff": result.cutoff,
        "first_knot": result.first_knot,
        "knot_spacing": result.knot_spacing,
    }
    metadata = {
        "source": "analytic_pair_prior",
        "pair": resolved_pair,
        "channel_label": resolved_label,
        "prior": _prior_projection_metadata(prior),
        "projection": projection_settings,
    }
    return replace(result, metadata=metadata)
# Names exported by ``from <this module> import *``.
__all__ = [
    "RadialProjectionResult",
    "project_pair_prior",
    "project_radial_function",
]