Source code for ufp.splines.fitting

"""Projection helpers for uniform one-dimensional splines."""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

import torch

from ufp.splines.representation import (
    uniform_stencil_1d,
    uniform_support_parameters,
)


@dataclass(frozen=True)
class UniformSpline1DFitResult:
    """Immutable record describing one projection onto a uniform 1D spline basis.

    The record pairs the fitted coefficients with the sample points that were
    used, the target and reconstructed values at those points, and scalar
    error summaries. The four gradient fields default to ``None`` and are only
    meaningful when gradient data was part of the fit.
    """

    # Fitted spline coefficients.
    coeffs: torch.Tensor
    # Distances at which the target callable was sampled.
    sample_distances: torch.Tensor
    # Values of the target callable at ``sample_distances``.
    target_values: torch.Tensor
    # Values reproduced by the fitted spline at ``sample_distances``.
    predicted_values: torch.Tensor
    # Root-mean-square difference between predicted and target values.
    rmse: float
    # Largest absolute difference between predicted and target values.
    max_abs_error: float
    # Derivatives of the target callable at the sample points, when fitted.
    target_gradients: torch.Tensor | None = None
    # Derivatives reproduced by the fitted spline, when fitted.
    predicted_gradients: torch.Tensor | None = None
    # Root-mean-square derivative error, when derivatives were fitted.
    gradient_rmse: float | None = None
    # Largest absolute derivative error, when derivatives were fitted.
    gradient_max_abs_error: float | None = None
def _row_matrix_from_stencil(
    indices: torch.Tensor,
    values: torch.Tensor,
    *,
    coeff_size: int,
) -> torch.Tensor:
    """Expand sparse per-sample stencil rows into one dense design matrix.

    Row ``i`` of the result receives ``values[i, j]`` added into column
    ``indices[i, j]``; columns that no stencil entry touches stay zero. The
    result has ``values.shape[0]`` rows and ``coeff_size`` columns and is
    allocated with ``values.new_zeros``.
    """
    n_rows = values.shape[0]
    n_cols = int(coeff_size)
    dense = values.new_zeros((n_rows, n_cols))
    # Accumulate rather than assign, so repeated column indices in a row add up.
    return dense.scatter_add_(1, indices, values)


def _as_float(value: torch.Tensor) -> float:
    """Convert a single-element tensor into a plain Python ``float``."""
    host_scalar = value.detach().cpu()
    return float(host_scalar.item())
def _resolve_sample_distances(
    sample_distances: torch.Tensor | None,
    n_samples: int | None,
    *,
    coeff_size: int,
    lower_full_support: float,
    upper_full_support: float,
    dtype: torch.dtype,
    device: torch.device | None,
) -> torch.Tensor:
    """Return the flat tensor of sample distances used for one fit.

    Generates interval midpoints when ``sample_distances`` is ``None``;
    otherwise flattens and validates the caller-supplied distances.
    """
    if sample_distances is None:
        count = (
            max(4 * coeff_size, coeff_size + 8)
            if n_samples is None
            else int(n_samples)
        )
        if count <= 0:
            raise ValueError("`n_samples` must be positive")
        step = (upper_full_support - lower_full_support) / float(count)
        # Midpoints of `count` equal sub-intervals of the support interval.
        offsets = torch.arange(count, dtype=dtype, device=device) + 0.5
        return lower_full_support + step * offsets

    distances = torch.as_tensor(
        sample_distances,
        dtype=dtype,
        device=device,
    ).reshape(-1)
    if distances.numel() == 0:
        raise ValueError("`sample_distances` must contain at least one value")
    # Stated positively so that NaN entries (for which every comparison is
    # False) are rejected instead of slipping through.
    inside = (distances >= lower_full_support) & (distances < upper_full_support)
    if not bool(torch.all(inside)):
        raise ValueError(
            "`sample_distances` must lie inside the full-support interval "
            "[lower_full_support, upper_full_support)"
        )
    return distances


def fit_uniform_spline_1d(
    function: Callable[[torch.Tensor], torch.Tensor],
    *,
    coeff_size: int,
    lower_full_support: float,
    upper_full_support: float,
    spline: str = "cubic",
    n_samples: int | None = None,
    sample_distances: torch.Tensor | None = None,
    derivative_weight: float = 1.0,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
    rcond: float | None = None,
) -> UniformSpline1DFitResult:
    """
    Fit a callable on a uniform 1D spline basis with optional derivative rows.

    The callable is sampled at a set of distances, a dense design matrix is
    assembled from the spline stencil at those distances, and the coefficients
    are obtained from one least-squares solve. When ``derivative_weight`` is
    positive, the derivative of the callable (obtained by automatic
    differentiation) contributes a second, weighted set of rows.

    Args:
        function: Callable accepting a distance tensor and returning
            energy-like values with the same shape.
        coeff_size: Number of spline coefficients.
        lower_full_support: Lower physical distance with full spline support.
        upper_full_support: Upper physical distance with full spline support.
        spline: Spline family name.
        n_samples: Number of midpoint samples when ``sample_distances`` is not
            supplied. Defaults to ``max(4 * coeff_size, coeff_size + 8)``.
        sample_distances: Optional explicit sample distances. They are
            flattened and must lie in
            ``[lower_full_support, upper_full_support)``.
        derivative_weight: Weight applied to derivative rows. Set to ``0.0``
            to fit only function values.
        dtype: Optional fitting dtype. Defaults to the torch default dtype.
        device: Optional fitting device.
        rcond: Optional cutoff passed to :func:`torch.linalg.lstsq`.

    Returns:
        Coefficients and projection diagnostics. The gradient fields of the
        result are ``None`` when ``derivative_weight`` is ``0.0``.

    Raises:
        ValueError: If ``derivative_weight`` is negative or NaN, if
            ``upper_full_support`` is not greater than ``lower_full_support``,
            if ``n_samples`` is not positive, if ``sample_distances`` is
            empty or contains values outside the support interval (including
            NaN), or if ``function`` returns a different number of values
            than it was given.

    Note:
        With a positive ``derivative_weight`` the output of ``function`` must
        be differentiable with respect to its input; otherwise
        :func:`torch.autograd.grad` raises. Gradient recording is enabled
        locally for that evaluation, so the call also works inside
        ``torch.no_grad()``.
    """
    coeff_size = int(coeff_size)
    derivative_weight = float(derivative_weight)
    # `not (x >= 0)` instead of `x < 0` so that NaN is rejected as well.
    if not derivative_weight >= 0.0:
        raise ValueError("`derivative_weight` must be non-negative")
    resolved_device = None if device is None else torch.device(device)
    resolved_dtype = torch.get_default_dtype() if dtype is None else dtype
    lower_full_support = float(lower_full_support)
    upper_full_support = float(upper_full_support)
    # Positive form of the check: also rejects NaN bounds.
    if not upper_full_support > lower_full_support:
        raise ValueError(
            "`upper_full_support` must be greater than `lower_full_support`"
        )

    # Validate everything the caller handed in before delegating to the
    # spline-layout helpers.
    distances = _resolve_sample_distances(
        sample_distances,
        n_samples,
        coeff_size=coeff_size,
        lower_full_support=lower_full_support,
        upper_full_support=upper_full_support,
        dtype=resolved_dtype,
        device=resolved_device,
    )

    first_knot, knot_spacing = uniform_support_parameters(
        coeff_size=coeff_size,
        lower_full_support=lower_full_support,
        upper_full_support=upper_full_support,
        spline=spline,
    )

    want_gradients = derivative_weight > 0.0
    fit_distances = distances.detach().clone().requires_grad_(want_gradients)
    # Derivative targets need a recorded autograd graph even if the caller is
    # inside `torch.no_grad()`; without derivatives the ambient mode is kept.
    with torch.set_grad_enabled(want_gradients or torch.is_grad_enabled()):
        target_values = torch.as_tensor(
            function(fit_distances),
            dtype=resolved_dtype,
            device=fit_distances.device,
        ).reshape(-1)
        if target_values.shape != fit_distances.shape:
            raise ValueError(
                "`function` must return values with the same shape as input"
            )
        target_gradients = None
        if want_gradients:
            # Differentiating the sum yields the per-sample derivative as long
            # as `function` acts element-wise on its input.
            target_gradients = torch.autograd.grad(
                target_values.sum(),
                fit_distances,
                create_graph=False,
                retain_graph=False,
            )[0].detach()
    fit_distances = fit_distances.detach()
    target_values = target_values.detach()

    stencil = uniform_stencil_1d(
        fit_distances,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
    )
    value_matrix = _row_matrix_from_stencil(
        stencil.indices,
        stencil.values,
        coeff_size=coeff_size,
    )
    rows = [value_matrix]
    rhs = [target_values]
    gradient_matrix = None
    if target_gradients is not None:
        gradient_matrix = _row_matrix_from_stencil(
            stencil.indices,
            stencil.grads,
            coeff_size=coeff_size,
        )
        rows.append(derivative_weight * gradient_matrix)
        rhs.append(derivative_weight * target_gradients)

    design = torch.cat(rows, dim=0)
    target = torch.cat(rhs, dim=0)
    # `torch.linalg.lstsq` is documented for a right-hand side of shape
    # (m, k); pass an explicit single column and squeeze it back afterwards.
    solution = torch.linalg.lstsq(
        design,
        target.unsqueeze(-1),
        rcond=rcond,
    ).solution
    coeffs = solution.squeeze(-1)

    predicted_values = value_matrix @ coeffs
    value_error = predicted_values - target_values
    rmse = _as_float(torch.sqrt(torch.mean(value_error.square())))
    max_abs_error = _as_float(torch.max(torch.abs(value_error)))

    predicted_gradients = None
    gradient_rmse = None
    gradient_max_abs_error = None
    if gradient_matrix is not None and target_gradients is not None:
        predicted_gradients = gradient_matrix @ coeffs
        gradient_error = predicted_gradients - target_gradients
        gradient_rmse = _as_float(torch.sqrt(torch.mean(gradient_error.square())))
        gradient_max_abs_error = _as_float(torch.max(torch.abs(gradient_error)))

    return UniformSpline1DFitResult(
        coeffs=coeffs,
        sample_distances=fit_distances,
        target_values=target_values,
        predicted_values=predicted_values,
        rmse=rmse,
        max_abs_error=max_abs_error,
        target_gradients=target_gradients,
        predicted_gradients=predicted_gradients,
        gradient_rmse=gradient_rmse,
        gradient_max_abs_error=gradient_max_abs_error,
    )
# Names exported by a star-import of this module; the underscore-prefixed
# helpers defined in this file are intentionally not listed.
__all__ = [ "UniformSpline1DFitResult", "fit_uniform_spline_1d", ]