Source code for ufp.splines.fitting
"""Projection helpers for uniform one-dimensional splines."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
import torch
from ufp.splines.representation import (
uniform_stencil_1d,
uniform_support_parameters,
)
[docs]
@dataclass(frozen=True)
class UniformSpline1DFitResult:
    """Immutable bundle of outputs produced by :func:`fit_uniform_spline_1d`.

    Stores the solved spline coefficients, the samples the projection used,
    and error metrics for the fitted values. The derivative-related fields are
    only populated when derivative rows took part in the fit; otherwise they
    keep their ``None`` defaults.
    """

    # --- always populated -------------------------------------------------
    # Least-squares solution, one entry per spline coefficient.
    coeffs: torch.Tensor
    # Distances at which the target callable was evaluated.
    sample_distances: torch.Tensor
    # Callable outputs at ``sample_distances``.
    target_values: torch.Tensor
    # Spline values at ``sample_distances`` reconstructed from ``coeffs``.
    predicted_values: torch.Tensor
    # Root-mean-square of ``predicted_values - target_values``.
    rmse: float
    # Largest absolute entry of ``predicted_values - target_values``.
    max_abs_error: float

    # --- derivative diagnostics (``None`` when no derivative rows) --------
    # Autograd derivatives of the callable at ``sample_distances``.
    target_gradients: torch.Tensor | None = None
    # Spline derivatives at ``sample_distances`` reconstructed from ``coeffs``.
    predicted_gradients: torch.Tensor | None = None
    # Root-mean-square of ``predicted_gradients - target_gradients``.
    gradient_rmse: float | None = None
    # Largest absolute entry of ``predicted_gradients - target_gradients``.
    gradient_max_abs_error: float | None = None
def _row_matrix_from_stencil(
    indices: torch.Tensor,
    values: torch.Tensor,
    *,
    coeff_size: int,
) -> torch.Tensor:
    """Scatter per-row stencil entries into a dense ``(rows, coeff_size)`` matrix.

    Row ``i`` of the result receives ``values[i, j]`` at column
    ``indices[i, j]``. Entries that target the same column of a row are summed
    (``scatter_add`` semantics), not overwritten. The result has the dtype and
    device of ``values``.
    """
    row_count = values.shape[0]
    blank = torch.zeros(
        (row_count, int(coeff_size)),
        dtype=values.dtype,
        device=values.device,
    )
    return blank.scatter_add(1, indices, values)
def _as_float(value: torch.Tensor) -> float:
    """Convert a single-element tensor into a plain Python ``float``.

    The tensor is detached and moved to the CPU before ``.item()`` is read,
    and the result is coerced with ``float`` so integer tensors also yield a
    ``float``.
    """
    host_scalar = value.detach().cpu()
    return float(host_scalar.item())
[docs]
def fit_uniform_spline_1d(
    function: Callable[[torch.Tensor], torch.Tensor],
    *,
    coeff_size: int,
    lower_full_support: float,
    upper_full_support: float,
    spline: str = "cubic",
    n_samples: int | None = None,
    sample_distances: torch.Tensor | None = None,
    derivative_weight: float = 1.0,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
    rcond: float | None = None,
) -> UniformSpline1DFitResult:
    """
    Fit a callable on a uniform 1D spline basis with optional derivative rows.

    The coefficients solve a single linear least-squares problem. Value rows
    are the spline basis evaluated at the sample distances, matched against
    ``function``'s values. When ``derivative_weight > 0``, derivative rows are
    added: the basis derivatives, matched against the autograd derivative of
    ``function``. Both sides of every derivative row are multiplied by
    ``derivative_weight``.

    Args:
        function: Callable accepting a distance tensor and returning energy-like
            values with the same number of elements. When derivative rows are
            requested it must be differentiable by autograd with respect to its
            input. In that case it is evaluated with gradient recording enabled,
            even if the caller is inside ``torch.no_grad()``.
        coeff_size: Number of spline coefficients.
        lower_full_support: Lower physical distance with full spline support.
        upper_full_support: Upper physical distance with full spline support.
        spline: Spline family name.
        n_samples: Number of midpoint samples when ``sample_distances`` is not
            supplied. Defaults to ``max(4 * coeff_size, coeff_size + 8)``.
            Ignored when ``sample_distances`` is given.
        sample_distances: Optional explicit sample distances (flattened to 1D).
        derivative_weight: Weight applied to derivative rows. Set to ``0.0`` to
            fit only function values.
        dtype: Optional fitting dtype. Defaults to ``torch.get_default_dtype()``.
        device: Optional fitting device.
        rcond: Optional cutoff passed to :func:`torch.linalg.lstsq`.

    Returns:
        Coefficients and projection diagnostics. The gradient fields are
        ``None`` when ``derivative_weight == 0``.

    Raises:
        ValueError: If any of the following holds:
            ``derivative_weight`` is negative or not finite; the support
            bounds are not finite or not increasing; ``n_samples`` is not
            positive; ``sample_distances`` is empty, non-finite, or outside
            ``[lower_full_support, upper_full_support)``; ``function``
            returns the wrong number of values; or ``function``'s values or
            derivatives are not finite.
    """
    import math

    coeff_size = int(coeff_size)
    derivative_weight = float(derivative_weight)
    if derivative_weight < 0.0:
        raise ValueError("`derivative_weight` must be non-negative")
    # NaN passes the sign test above and would silently disable derivative
    # rows; +inf would turn the design matrix into inf/NaN.
    if not math.isfinite(derivative_weight):
        raise ValueError("`derivative_weight` must be finite")
    use_gradients = derivative_weight > 0.0
    resolved_device = None if device is None else torch.device(device)
    resolved_dtype = torch.get_default_dtype() if dtype is None else dtype
    lower_full_support = float(lower_full_support)
    upper_full_support = float(upper_full_support)
    if upper_full_support <= lower_full_support:
        raise ValueError(
            "`upper_full_support` must be greater than `lower_full_support`"
        )
    # NaN bounds slip past the ordering comparison above.
    if not (math.isfinite(lower_full_support) and math.isfinite(upper_full_support)):
        raise ValueError(
            "`lower_full_support` and `upper_full_support` must be finite"
        )
    first_knot, knot_spacing = uniform_support_parameters(
        coeff_size=coeff_size,
        lower_full_support=lower_full_support,
        upper_full_support=upper_full_support,
        spline=spline,
    )
    distances = _resolve_fit_distances(
        sample_distances,
        n_samples=n_samples,
        coeff_size=coeff_size,
        lower_full_support=lower_full_support,
        upper_full_support=upper_full_support,
        dtype=resolved_dtype,
        device=resolved_device,
    )
    fit_distances, target_values, target_gradients = _evaluate_fit_targets(
        function,
        distances,
        dtype=resolved_dtype,
        with_gradients=use_gradients,
    )

    stencil = uniform_stencil_1d(
        fit_distances,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
        spline=spline,
    )
    value_matrix = _row_matrix_from_stencil(
        stencil.indices,
        stencil.values,
        coeff_size=coeff_size,
    )
    rows = [value_matrix]
    rhs = [target_values]
    gradient_matrix = None
    if target_gradients is not None:
        gradient_matrix = _row_matrix_from_stencil(
            stencil.indices,
            stencil.grads,
            coeff_size=coeff_size,
        )
        # Scaling both sides of a row by w weights its squared residual by w**2.
        rows.append(derivative_weight * gradient_matrix)
        rhs.append(derivative_weight * target_gradients)
    design = torch.cat(rows, dim=0)
    target = torch.cat(rhs, dim=0)
    coeffs = torch.linalg.lstsq(design, target, rcond=rcond).solution

    predicted_values = value_matrix @ coeffs
    rmse, max_abs_error = _error_statistics(predicted_values, target_values)
    predicted_gradients = None
    gradient_rmse = None
    gradient_max_abs_error = None
    if gradient_matrix is not None and target_gradients is not None:
        predicted_gradients = gradient_matrix @ coeffs
        gradient_rmse, gradient_max_abs_error = _error_statistics(
            predicted_gradients, target_gradients
        )
    return UniformSpline1DFitResult(
        coeffs=coeffs,
        sample_distances=fit_distances,
        target_values=target_values,
        predicted_values=predicted_values,
        rmse=rmse,
        max_abs_error=max_abs_error,
        target_gradients=target_gradients,
        predicted_gradients=predicted_gradients,
        gradient_rmse=gradient_rmse,
        gradient_max_abs_error=gradient_max_abs_error,
    )


def _resolve_fit_distances(
    sample_distances: torch.Tensor | None,
    *,
    n_samples: int | None,
    coeff_size: int,
    lower_full_support: float,
    upper_full_support: float,
    dtype: torch.dtype,
    device: torch.device | None,
) -> torch.Tensor:
    """Return 1D sample distances: explicit ones validated, or default midpoints."""
    if sample_distances is None:
        count = (
            max(4 * coeff_size, coeff_size + 8) if n_samples is None else int(n_samples)
        )
        if count <= 0:
            raise ValueError("`n_samples` must be positive")
        # Midpoints of ``count`` equal-width cells spanning the support interval.
        step = (upper_full_support - lower_full_support) / float(count)
        return lower_full_support + step * (
            torch.arange(count, dtype=dtype, device=device) + 0.5
        )
    distances = torch.as_tensor(
        sample_distances,
        dtype=dtype,
        device=device,
    ).reshape(-1)
    if distances.numel() == 0:
        raise ValueError("`sample_distances` must contain at least one value")
    if torch.any(distances < lower_full_support) or torch.any(
        distances >= upper_full_support
    ):
        raise ValueError(
            "`sample_distances` must lie inside the full-support interval "
            "[lower_full_support, upper_full_support)"
        )
    # NaN compares False against both bounds, so the range test misses it.
    if not bool(torch.isfinite(distances).all()):
        raise ValueError("`sample_distances` must be finite")
    return distances


def _evaluate_fit_targets(
    function: Callable[[torch.Tensor], torch.Tensor],
    distances: torch.Tensor,
    *,
    dtype: torch.dtype,
    with_gradients: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Evaluate ``function`` (and optionally its autograd derivative) at samples.

    Returns detached ``(fit_distances, target_values, target_gradients)``;
    ``target_gradients`` is ``None`` when ``with_gradients`` is false.
    """
    fit_distances = distances.detach().clone().requires_grad_(with_gradients)
    target_gradients = None
    # Force a graph to be recorded when derivatives are needed, even if the
    # caller is inside ``torch.no_grad()``; otherwise ``target_values`` has no
    # ``grad_fn`` and ``torch.autograd.grad`` raises. When derivatives are not
    # needed the caller's grad mode is left unchanged.
    with torch.set_grad_enabled(with_gradients or torch.is_grad_enabled()):
        target_values = torch.as_tensor(
            function(fit_distances),
            dtype=dtype,
            device=fit_distances.device,
        ).reshape(-1)
        if target_values.shape != fit_distances.shape:
            raise ValueError(
                "`function` must return values with the same shape as input"
            )
        if with_gradients:
            target_gradients = torch.autograd.grad(
                target_values.sum(),
                fit_distances,
                create_graph=False,
                retain_graph=False,
            )[0].detach()
    target_values = target_values.detach()
    # Non-finite targets would otherwise propagate into NaN coefficients.
    if not bool(torch.isfinite(target_values).all()):
        raise ValueError(
            "`function` returned non-finite values at the sample distances"
        )
    if target_gradients is not None and not bool(
        torch.isfinite(target_gradients).all()
    ):
        raise ValueError(
            "`function` has non-finite derivatives at the sample distances"
        )
    return fit_distances.detach(), target_values, target_gradients


def _error_statistics(
    predicted: torch.Tensor,
    target: torch.Tensor,
) -> tuple[float, float]:
    """Return ``(rmse, max_abs_error)`` of ``predicted - target`` as floats."""
    error = predicted - target
    return (
        _as_float(torch.sqrt(torch.mean(error.square()))),
        _as_float(torch.max(torch.abs(error))),
    )
# Public names exported by ``from <this module> import *``; helpers prefixed
# with an underscore stay private.
__all__ = ["UniformSpline1DFitResult", "fit_uniform_spline_1d"]