"""Derivative helpers for uniform one-dimensional cubic splines."""
from __future__ import annotations
from dataclasses import dataclass
import torch
from ufp.splines._cubic import cubic_eval_1d_with_grads
from ufp.splines.representation import spline_support_mask_1d, uniform_stencil_1d
def _as_distance_tensor(
distances,
*,
dtype: torch.dtype | None = None,
device: torch.device | str | None = None,
) -> torch.Tensor:
"""Return distances as a flat tensor while preserving tensor dtype by default."""
tensor = torch.as_tensor(distances, dtype=dtype, device=device)
return tensor.reshape(-1)
def _validate_cubic_support(
    distances: torch.Tensor,
    *,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
) -> None:
    """Raise ``ValueError`` unless every distance has full cubic support.

    The per-point decision is delegated to ``spline_support_mask_1d``; this
    helper only reduces the returned mask with ``torch.all``.
    """
    mask = spline_support_mask_1d(
        distances,
        coeff_size=int(coeff_size),
        first_knot=float(first_knot),
        knot_spacing=float(knot_spacing),
        spline="cubic",
    )
    every_point_supported = bool(torch.all(mask))
    if every_point_supported:
        return
    raise ValueError("all distances must lie inside the cubic spline support")
def _row_matrix_from_stencil(
indices: torch.Tensor,
values: torch.Tensor,
*,
coeff_size: int,
) -> torch.Tensor:
"""Materialize a dense row matrix from local stencil columns and weights."""
matrix = values.new_zeros((values.shape[0], int(coeff_size)))
matrix.scatter_add_(1, indices, values)
return matrix
def cubic_value_rows_1d(
    distances,
    *,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
) -> torch.Tensor:
    """Build the dense cubic B-spline evaluation matrix for ``distances``.

    Each row holds the basis weights that, dotted with a length-``coeff_size``
    coefficient vector, give the spline value at one distance. The weights come
    from ``uniform_stencil_1d(..., spline="cubic")``.

    Args:
        distances: Evaluation points; flattened to one dimension.
        coeff_size: Number of spline coefficients (matrix width).
        first_knot: First knot position of the uniform knot grid.
        knot_spacing: Uniform knot spacing.
        dtype: Optional dtype for the distance tensor.
        device: Optional device for the distance tensor.

    Returns:
        Dense tensor with ``coeff_size`` columns; presumably one row per
        distance (the row count is taken from the stencil's ``values``).

    Raises:
        ValueError: If any distance lies outside the cubic spline support.
    """
    points = _as_distance_tensor(distances, dtype=dtype, device=device)
    _validate_cubic_support(
        points,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
    )
    n_coeffs = int(coeff_size)
    local = uniform_stencil_1d(
        points,
        coeff_size=n_coeffs,
        first_knot=float(first_knot),
        knot_spacing=float(knot_spacing),
        spline="cubic",
    )
    return _row_matrix_from_stencil(local.indices, local.values, coeff_size=n_coeffs)
def cubic_derivative_rows_1d(
    distances,
    *,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
) -> torch.Tensor:
    """Build the dense first-derivative matrix of a cubic spline at ``distances``.

    Each row holds the weights that, dotted with a length-``coeff_size``
    coefficient vector, give the spline's first derivative at one distance.
    The weights are the ``grads`` of ``uniform_stencil_1d(..., spline="cubic")``.
    NOTE(review): whether those grads already include the ``1 / knot_spacing``
    chain-rule factor is decided by ``uniform_stencil_1d`` — confirm there.

    Args:
        distances: Evaluation points; flattened to one dimension.
        coeff_size: Number of spline coefficients (matrix width).
        first_knot: First knot position of the uniform knot grid.
        knot_spacing: Uniform knot spacing.
        dtype: Optional dtype for the distance tensor.
        device: Optional device for the distance tensor.

    Returns:
        Dense tensor with ``coeff_size`` columns; presumably one row per
        distance (the row count is taken from the stencil's ``grads``).

    Raises:
        ValueError: If any distance lies outside the cubic spline support.
    """
    points = _as_distance_tensor(distances, dtype=dtype, device=device)
    _validate_cubic_support(
        points,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
    )
    n_coeffs = int(coeff_size)
    local = uniform_stencil_1d(
        points,
        coeff_size=n_coeffs,
        first_knot=float(first_knot),
        knot_spacing=float(knot_spacing),
        spline="cubic",
    )
    return _row_matrix_from_stencil(local.indices, local.grads, coeff_size=n_coeffs)
def cubic_second_derivative_operator(
*,
coeff_size: int,
knot_spacing: float,
dtype: torch.dtype | None = None,
device: torch.device | str | None = None,
) -> torch.Tensor:
"""Return the operator mapping cubic coefficients to curvature coefficients."""
coeff_size = int(coeff_size)
if coeff_size <= 3:
raise ValueError("cubic coefficient size must be greater than 3")
knot_spacing = float(knot_spacing)
if knot_spacing <= 0.0:
raise ValueError("`knot_spacing` must be positive")
matrix = torch.zeros(
(coeff_size - 2, coeff_size),
dtype=torch.get_default_dtype() if dtype is None else dtype,
device=device,
)
rows = torch.arange(coeff_size - 2, dtype=torch.int64, device=matrix.device)
scale = 1.0 / (knot_spacing * knot_spacing)
matrix[rows, rows] = scale
matrix[rows, rows + 1] = -2.0 * scale
matrix[rows, rows + 2] = scale
return matrix
def cubic_second_derivative_values_1d(
    distances,
    coeffs: torch.Tensor,
    *,
    first_knot: float,
    knot_spacing: float,
) -> torch.Tensor:
    """Evaluate the second derivative of a uniform cubic spline.

    The cubic coefficients are mapped to piecewise-linear curvature
    coefficients with ``cubic_second_derivative_operator``. Those coefficients
    are then interpolated linearly, one ``knot_spacing``-wide cell at a time.

    Args:
        distances: Evaluation points; flattened to shape ``(n_points,)``.
        coeffs: Cubic coefficients of shape ``(..., n_coeffs)``. Leading
            dimensions are carried through as a batch. Integer or bool inputs
            are promoted to ``torch.get_default_dtype()``.
        first_knot: First knot position of the uniform knot grid.
        knot_spacing: Uniform knot spacing ``h``.

    Returns:
        Tensor of shape ``(..., n_points)`` with the second derivative at each
        distance.

    Raises:
        ValueError: If ``coeffs`` is zero-dimensional, a distance lies outside
            the cubic support, or the operator rejects ``n_coeffs`` or
            ``knot_spacing``.
    """
    coeffs = torch.as_tensor(coeffs)
    if coeffs.ndim < 1:
        raise ValueError("`coeffs` must have at least one dimension")
    if not (coeffs.is_floating_point() or coeffs.is_complex()):
        # Integer coefficients would otherwise force integer distances
        # (truncating positions) and an integer operator (truncating 1 / h**2).
        coeffs = coeffs.to(torch.get_default_dtype())
    coeff_size = int(coeffs.shape[-1])
    distances = _as_distance_tensor(
        distances,
        dtype=coeffs.dtype,
        device=coeffs.device,
    )
    _validate_cubic_support(
        distances,
        coeff_size=coeff_size,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
    )
    operator = cubic_second_derivative_operator(
        coeff_size=coeff_size,
        knot_spacing=float(knot_spacing),
        dtype=coeffs.dtype,
        device=coeffs.device,
    )
    curvature_coeffs = torch.matmul(coeffs, operator.transpose(0, 1))
    curvature_size = int(curvature_coeffs.shape[-1])

    spacing = float(knot_spacing)
    # Positions are measured in cells from first_knot + 2h. At u == 0 in cell
    # k the result is curvature coefficient k - 1, so coefficient r is the
    # value at first_knot + (r + 3) * h.
    linear_first_knot = float(first_knot) + 2.0 * spacing
    scaled = (distances - linear_first_knot) / spacing
    # The operator guarantees curvature_size >= 2 (coeff_size > 3). Clamping
    # the cell to [1, curvature_size - 1] therefore keeps both gathered
    # indices (cell - 1, cell) inside [0, curvature_size - 1], so no separate
    # range check is needed.
    cell = torch.clamp(torch.floor(scaled).to(torch.int64), 1, curvature_size - 1)
    u = torch.clamp(scaled - cell, 0.0, 1.0)
    indices = torch.stack((cell - 1, cell), dim=1)
    basis = torch.stack((1.0 - u, u), dim=1)
    selected = curvature_coeffs[..., indices]
    return (selected * basis).sum(dim=-1)
def cubic_spline_diagnostics_1d(
    distances,
    coeffs: torch.Tensor,
    *,
    first_knot: float,
    knot_spacing: float,
) -> UniformCubicSpline1DDiagnostics:
    """Sample one cubic spline's value, gradient, and curvature at ``distances``.

    Values and gradients come from ``cubic_eval_1d_with_grads``, evaluated at
    offsets ``distances - first_knot``. Curvatures come from
    ``cubic_second_derivative_values_1d``.

    Args:
        distances: Evaluation points; flattened and cast to the dtype and
            device of ``coeffs``.
        coeffs: Coefficient vector of shape ``(n_coeffs,)``.
        first_knot: First knot position of the uniform knot grid.
        knot_spacing: Uniform knot spacing.

    Returns:
        A ``UniformCubicSpline1DDiagnostics`` holding the flattened distances
        and the sampled values, gradients and curvatures.

    Raises:
        ValueError: If ``coeffs`` is not one-dimensional or a distance lies
            outside the cubic spline support.
    """
    coefficients = torch.as_tensor(coeffs)
    if coefficients.ndim != 1:
        raise ValueError("`coeffs` must have shape (n_coeffs,)")
    points = _as_distance_tensor(
        distances,
        dtype=coefficients.dtype,
        device=coefficients.device,
    )
    _validate_cubic_support(
        points,
        coeff_size=int(coefficients.shape[0]),
        first_knot=first_knot,
        knot_spacing=knot_spacing,
    )
    spacing = float(knot_spacing)
    offsets = points - float(first_knot)
    values, gradients = cubic_eval_1d_with_grads(spacing, offsets, coefficients)
    curvatures = cubic_second_derivative_values_1d(
        points,
        coefficients,
        first_knot=first_knot,
        knot_spacing=knot_spacing,
    )
    # NOTE(review): confirm ``UniformCubicSpline1DDiagnostics`` is defined in
    # this module (it is exported via ``__all__``); no definition is visible
    # alongside this function.
    return UniformCubicSpline1DDiagnostics(
        distances=points,
        values=values,
        gradients=gradients,
        curvatures=curvatures,
    )
# Public API re-exported by ``from <module> import *``.
# NOTE(review): ``cubic_spline_diagnostics_1d`` constructs
# ``UniformCubicSpline1DDiagnostics``; confirm that class is defined in this
# module so the export below resolves.
__all__ = [
    "UniformCubicSpline1DDiagnostics",
    "cubic_derivative_rows_1d",
    "cubic_second_derivative_operator",
    "cubic_second_derivative_values_1d",
    "cubic_spline_diagnostics_1d",
    "cubic_value_rows_1d",
]