Source code for ufp.splines.representation
"""
Stencil representations for uniform spline bases.
Use this module to inspect which coefficients a geometry touches and the basis
weights or gradients that should be applied to those coefficients. The support
helpers exported here are low-level expert APIs; hot callers may use them
directly after doing their own setup-time dispatch and filtering.
"""
from __future__ import annotations

import math
import warnings
from collections.abc import Callable
from dataclasses import dataclass

import torch

from ufp.splines._cubic import uniform_basis_and_grad as cubic_basis_and_grad
from ufp.splines._quadratic import uniform_basis_and_grad as quadratic_basis_and_grad
from ufp.splines._quartic import uniform_basis_and_grad as quartic_basis_and_grad
# Signature shared by the per-family basis evaluators: maps local in-cell
# coordinates ``u`` to ``(basis_values, d(basis)/du)``. Callers in this module
# divide the second output by the knot spacing to obtain derivatives with
# respect to the physical coordinate, and treat both outputs as
# ``(num_points, degree + 1)`` tables.
BasisFn = Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]
[docs]
@dataclass(frozen=True)
class Stencil1D:
    """
    Immutable 1D spline stencil.

    As produced by ``uniform_stencil_1d``, every tensor has one row per input
    point and one column per basis function (``degree + 1`` columns).
    """

    # Coefficient indices (int64), already wrapped into ``[0, coeff_size)``.
    indices: torch.Tensor
    # Basis-function values, aligned column-by-column with ``indices``.
    values: torch.Tensor
    # Derivatives of ``values`` with respect to the physical coordinate.
    grads: torch.Tensor
[docs]
@dataclass(frozen=True)
class Stencil2D:
    """
    Immutable 2D tensor-product spline stencil.

    As produced by ``uniform_stencil_2d``, each tensor holds one row per point
    and ``(degree + 1) ** 2`` columns, ordered with the first axis varying
    slowest.
    """

    # Row-major flat coefficient indices ``ix * ny + iy`` (int64).
    indices: torch.Tensor
    # Products of the per-axis basis values.
    values: torch.Tensor
    # Derivative of ``values`` with respect to the physical first coordinate.
    grad_x: torch.Tensor
    # Derivative of ``values`` with respect to the physical second coordinate.
    grad_y: torch.Tensor
[docs]
@dataclass(frozen=True)
class Stencil3D:
    """
    Immutable 3D tensor-product spline stencil.

    The builders in this module emit one row per point and
    ``(degree + 1) ** 3`` columns. Indices are row-major flat offsets
    ``(ix * ny + iy) * nz + iz``; they are wrapped only on the bounds-checked
    path (``uniform_stencil_3d``).
    """

    # Flat coefficient indices (int64).
    indices: torch.Tensor
    # Products of the three per-axis basis values.
    values: torch.Tensor
    # Partial derivative of ``values`` along the first physical coordinate.
    grad_x: torch.Tensor
    # Partial derivative of ``values`` along the second physical coordinate.
    grad_y: torch.Tensor
    # Partial derivative of ``values`` along the third physical coordinate.
    grad_z: torch.Tensor
[docs]
@dataclass(frozen=True)
class SupportedStencil3D:
    """
    Result of ``supported_uniform_stencil_3d``: support mask plus the stencil.

    ``mask`` keeps the caller's original coordinate shape. ``x``/``y``/``z``
    are the flattened coordinates that passed the mask, in the same order as
    the rows of ``stencil``.
    """

    # Boolean mask with the shape of the original coordinate tensors.
    mask: torch.Tensor
    # Flattened first coordinates selected by ``mask``.
    x: torch.Tensor
    # Flattened second coordinates selected by ``mask``.
    y: torch.Tensor
    # Flattened third coordinates selected by ``mask``.
    z: torch.Tensor
    # Stencil for the selected coordinates only.
    stencil: Stencil3D
[docs]
@dataclass(frozen=True)
class Stencil6D:
    """
    Immutable 6D tensor-product spline stencil.

    As produced by ``uniform_stencil_6d``, each tensor has one row per point
    and ``(degree + 1) ** 6`` columns. Indices are row-major flat offsets into
    the six-axis coefficient grid.
    """

    # Flat coefficient indices (int64).
    indices: torch.Tensor
    # Products of the six per-axis basis values.
    values: torch.Tensor
    # One derivative tensor per axis, in coordinate order; each is taken with
    # respect to that axis's physical coordinate.
    grads: tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]
# Registry of supported spline families: name -> (degree, basis evaluator).
# Note the tuple order is (degree, evaluator) despite the constant's name.
# A degree-d family touches d + 1 coefficients per axis, matching the
# ``torch.arange(degree + 1)`` offsets used by the stencil builders.
_BASIS_AND_DEGREE: dict[str, tuple[int, BasisFn]] = {
    "quadratic": (2, quadratic_basis_and_grad),
    "cubic": (3, cubic_basis_and_grad),
    "quartic": (4, quartic_basis_and_grad),
}
def _get_basis_and_degree(spline: str) -> tuple[int, BasisFn]:
    """
    Look up ``(degree, basis_and_grad)`` for the spline family ``spline``.

    Args:
        spline: Family name, e.g. ``"cubic"``.

    Returns:
        The degree and the evaluator mapping local coordinates to
        ``(basis, d(basis)/du)``.

    Raises:
        ValueError: If ``spline`` is not registered; the original ``KeyError``
            is chained as the cause.
    """
    try:
        entry = _BASIS_AND_DEGREE[spline]
    except KeyError as missing:
        supported = ", ".join(sorted(_BASIS_AND_DEGREE))
        raise ValueError(
            f"Unsupported spline '{spline}'. Expected one of: {supported}."
        ) from missing
    return entry
def _scaled_coordinate(
x: torch.Tensor,
first_knot: float,
knot_spacing: float,
*,
nonnegative: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert physical coordinates into cell indices and local spline coordinates."""
scaled = (x - first_knot) / knot_spacing
if nonnegative and first_knot <= 0.0:
cell = scaled.to(torch.int64)
else:
cell = torch.floor(scaled)
return cell.to(torch.int64), scaled - cell
def _wrap_indices(indices: torch.Tensor, size: int) -> torch.Tensor:
"""Wrap valid negative indices into the coefficient range."""
size = int(size)
if size <= 0:
raise ValueError("coefficient size must be positive")
if torch.any(indices < -size) or torch.any(indices >= size):
raise IndexError(
f"spline stencil touched coefficients outside [-{size}, {size - 1}]"
)
return torch.remainder(indices, size)
[docs]
def uniform_support_parameters(
    *,
    coeff_size: int,
    lower_full_support: float,
    upper_full_support: float,
    spline: str,
) -> tuple[float, float]:
    """
    Compute the first knot and spacing for a uniform spline support.

    The ``coeff_size - degree`` knot intervals are spread evenly over
    ``[lower_full_support, upper_full_support]``, and the first knot is placed
    ``degree`` intervals below the lower boundary. Up to floating-point
    rounding, a coordinate in ``[lower_full_support, upper_full_support)`` then
    falls in a cell ``c`` with ``degree <= c < coeff_size``. That is the range
    the ``spline_support_mask_*`` helpers accept.

    Args:
        coeff_size: Number of coefficients along the axis; must exceed the
            spline degree.
        lower_full_support: Lower boundary of the fully supported interval.
        upper_full_support: Upper boundary; must be strictly greater than the
            lower boundary.
        spline: Spline family name.

    Returns:
        ``(first_knot, knot_spacing)``.

    Raises:
        ValueError: If ``coeff_size`` is not positive or not larger than the
            degree, the spline is unsupported, a boundary is NaN or infinite,
            or the interval is empty.
    """
    coeff_size = int(coeff_size)
    if coeff_size <= 0:
        raise ValueError("coefficient size must be positive")
    degree, _ = _get_basis_and_degree(spline)
    if coeff_size <= degree:
        raise ValueError(
            f"coefficient size must be larger than the spline degree ({degree})"
        )
    lower_full_support = float(lower_full_support)
    upper_full_support = float(upper_full_support)
    # NaN slips through ``span <= 0.0`` (every NaN comparison is False), and
    # infinite bounds produce non-finite knots, so reject both explicitly.
    if not (math.isfinite(lower_full_support) and math.isfinite(upper_full_support)):
        raise ValueError("full-support boundaries must be finite")
    span = upper_full_support - lower_full_support
    if span <= 0.0:
        raise ValueError("upper full-support boundary must exceed the lower boundary")
    knot_spacing = span / float(coeff_size - degree)
    first_knot = lower_full_support - degree * knot_spacing
    return first_knot, knot_spacing
[docs]
def spline_support_mask_1d(
    x: torch.Tensor,
    *,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
    spline: str,
) -> torch.Tensor:
    """
    Flag 1D coordinates whose whole stencil lies inside the coefficient range.

    A point in cell ``c`` touches coefficients ``c - degree .. c``. Those are
    all valid exactly when ``degree <= c < coeff_size``.

    Args:
        x: Physical coordinates.
        coeff_size: Number of coefficients.
        first_knot: Position of the first knot.
        knot_spacing: Uniform knot spacing.
        spline: Spline family name.

    Returns:
        Boolean tensor with the shape of ``x``.
    """
    degree, _ = _get_basis_and_degree(spline)
    cell_index = _scaled_coordinate(x, first_knot, knot_spacing)[0]
    upper = int(coeff_size)
    return (degree <= cell_index) & (cell_index < upper)
[docs]
def spline_support_mask_2d(
    x: torch.Tensor,
    y: torch.Tensor,
    *,
    coeff_shape: tuple[int, int],
    first_knot_x: float,
    first_knot_y: float,
    knot_spacing_x: float,
    knot_spacing_y: float,
    spline: str,
) -> torch.Tensor:
    """
    Flag 2D coordinates whose stencil stays inside the coefficient grid.

    Along each axis the cell ``c`` must satisfy ``degree <= c < n`` for that
    axis's coefficient count ``n``.

    Args:
        x: First-axis coordinates.
        y: Second-axis coordinates.
        coeff_shape: ``(nx, ny)`` coefficient grid shape.
        first_knot_x: First knot along the first axis.
        first_knot_y: First knot along the second axis.
        knot_spacing_x: Knot spacing along the first axis.
        knot_spacing_y: Knot spacing along the second axis.
        spline: Spline family name.

    Returns:
        Boolean tensor with the broadcast shape of ``x`` and ``y``.
    """
    nx, ny = (int(value) for value in coeff_shape)
    degree, _ = _get_basis_and_degree(spline)
    col_x = _scaled_coordinate(x, first_knot_x, knot_spacing_x)[0]
    col_y = _scaled_coordinate(y, first_knot_y, knot_spacing_y)[0]
    inside_x = (col_x >= degree) & (col_x < nx)
    inside_y = (col_y >= degree) & (col_y < ny)
    return inside_x & inside_y
[docs]
def spline_support_mask_3d(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    spline: str,
) -> torch.Tensor:
    """
    Flag 3D coordinates whose stencil stays inside the coefficient grid.

    The first two axes share one knot layout (``*_xy``); the third axis has its
    own. Each axis requires ``degree <= cell < n``.

    Args:
        x: First coordinate tensor.
        y: Second coordinate tensor.
        z: Third coordinate tensor.
        coeff_shape: ``(nx, ny, nz)`` coefficient grid shape.
        first_knot_xy: First knot shared by the first two axes.
        first_knot_z: First knot of the third axis.
        knot_spacing_xy: Knot spacing shared by the first two axes.
        knot_spacing_z: Knot spacing of the third axis.
        spline: Spline family name.

    Returns:
        Boolean tensor with the broadcast shape of the inputs.
    """
    nx, ny, nz = (int(value) for value in coeff_shape)
    degree, _ = _get_basis_and_degree(spline)
    col_x = _scaled_coordinate(x, first_knot_xy, knot_spacing_xy)[0]
    col_y = _scaled_coordinate(y, first_knot_xy, knot_spacing_xy)[0]
    col_z = _scaled_coordinate(z, first_knot_z, knot_spacing_z)[0]
    inside_x = (col_x >= degree) & (col_x < nx)
    inside_y = (col_y >= degree) & (col_y < ny)
    inside_z = (col_z >= degree) & (col_z < nz)
    return inside_x & inside_y & inside_z
[docs]
def spline_support_mask_6d(
    coords: tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ],
    *,
    coeff_shape: tuple[int, int, int, int, int, int],
    first_knots: tuple[float, float, float, float, float, float],
    knot_spacings: tuple[float, float, float, float, float, float],
    spline: str,
) -> torch.Tensor:
    """
    Flag 6D coordinates whose stencil stays inside the coefficient grid.

    Args:
        coords: One coordinate tensor per axis.
        coeff_shape: Coefficient count per axis.
        first_knots: First knot per axis.
        knot_spacings: Knot spacing per axis.
        spline: Spline family name.

    Returns:
        Boolean tensor: true where every axis satisfies ``degree <= cell < n``.

    Raises:
        ValueError: If the per-axis sequences differ in length (strict ``zip``).
    """
    degree, _ = _get_basis_and_degree(spline)
    supported = torch.ones_like(coords[0], dtype=torch.bool)
    per_axis = zip(coords, coeff_shape, first_knots, knot_spacings, strict=True)
    for coordinate, extent, knot0, spacing in per_axis:
        cell_index = _scaled_coordinate(coordinate, knot0, spacing)[0]
        supported = supported & (cell_index >= degree) & (cell_index < int(extent))
    return supported
[docs]
def uniform_stencil_2d(
    x: torch.Tensor,
    y: torch.Tensor,
    *,
    coeff_shape: tuple[int, int],
    first_knot_x: float,
    first_knot_y: float,
    knot_spacing_x: float,
    knot_spacing_y: float,
    spline: str,
) -> Stencil2D:
    """
    Return coefficient indices, values, and gradients for a 2D uniform stencil.

    Per-axis indices are passed through ``_wrap_indices``, so negative indices
    down to ``-n`` wrap around and anything further out raises.

    Args:
        x: One-dimensional tensor of first-axis coordinates.
        y: One-dimensional tensor of second-axis coordinates.
        coeff_shape: ``(nx, ny)`` coefficient grid shape.
        first_knot_x: First knot along the first axis.
        first_knot_y: First knot along the second axis.
        knot_spacing_x: Knot spacing along the first axis.
        knot_spacing_y: Knot spacing along the second axis.
        spline: Spline family name.

    Returns:
        ``Stencil2D`` with one row per point and ``(degree + 1) ** 2`` columns.
        ``indices`` are row-major flat offsets ``ix * ny + iy``, and
        ``grad_x``/``grad_y`` are derivatives with respect to the physical
        coordinates. An empty batch yields ``(0, (degree + 1) ** 2)`` tensors.

    Raises:
        ValueError: If ``spline`` is unsupported.
        IndexError: If a stencil index falls outside ``[-n, n)`` on an axis.
    """
    nx, ny = (int(value) for value in coeff_shape)
    degree, basis_and_grad = _get_basis_and_degree(spline)
    # Use an explicit column count: torch cannot infer ``-1`` in
    # ``reshape(0, -1)``, so the old form crashed on a zero-point batch.
    width = (degree + 1) ** 2
    cell_x, ux = _scaled_coordinate(x, first_knot_x, knot_spacing_x)
    cell_y, uy = _scaled_coordinate(y, first_knot_y, knot_spacing_y)
    bx, dbx_du = basis_and_grad(ux)
    by, dby_du = basis_and_grad(uy)
    # Chain rule: d/dx = (d/du) / knot spacing.
    dbx = dbx_du / knot_spacing_x
    dby = dby_du / knot_spacing_y
    offsets = torch.arange(degree + 1, dtype=torch.int64, device=x.device)
    ix = _wrap_indices((cell_x - degree)[:, None] + offsets[None, :], nx)
    iy = _wrap_indices((cell_y - degree)[:, None] + offsets[None, :], ny)
    # Outer products over the two axes; first-axis offsets vary slowest.
    flat = (ix[:, :, None] * ny + iy[:, None, :]).reshape(-1, width)
    bx3 = bx[:, :, None]
    by3 = by[:, None, :]
    dbx3 = dbx[:, :, None]
    dby3 = dby[:, None, :]
    return Stencil2D(
        indices=flat,
        values=(bx3 * by3).reshape(-1, width),
        grad_x=(dbx3 * by3).reshape(-1, width),
        grad_y=(bx3 * dby3).reshape(-1, width),
    )
def _uniform_stencil_3d_from_scaled(
    cell_x: torch.Tensor,
    ux: torch.Tensor,
    cell_y: torch.Tensor,
    uy: torch.Tensor,
    cell_z: torch.Tensor,
    uz: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    knot_spacing_xy: float,
    knot_spacing_z: float,
    spline: str,
    check_bounds: bool,
) -> Stencil3D:
    """
    Assemble a tensor-product 3D stencil from per-axis cells and offsets.

    Args:
        cell_x: Integer cells along the first axis (from ``_scaled_coordinate``).
        ux: In-cell offsets along the first axis.
        cell_y: Integer cells along the second axis.
        uy: In-cell offsets along the second axis.
        cell_z: Integer cells along the third axis.
        uz: In-cell offsets along the third axis.
        coeff_shape: ``(nx, ny, nz)`` coefficient grid shape.
        knot_spacing_xy: Spacing shared by the first two axes; converts their
            basis derivatives into physical gradients.
        knot_spacing_z: Spacing of the third axis.
        spline: Spline family name.
        check_bounds: If true, validate and wrap indices via ``_wrap_indices``.
            If false, the caller guarantees every index is already in range.

    Returns:
        ``Stencil3D`` with ``(degree + 1) ** 3`` columns per point and row-major
        flat indices ``(ix * ny + iy) * nz + iz``. Empty input short-circuits
        to ``(0, (degree + 1) ** 3)`` tensors. The float placeholders are one
        shared tensor.
    """
    nx, ny, nz = (int(value) for value in coeff_shape)
    degree, basis_and_grad = _get_basis_and_degree(spline)
    taps = degree + 1
    if ux.numel() == 0:
        # Nothing to evaluate, and ``reshape(0, -1)`` below would be ambiguous.
        width = taps**3
        blank = ux.new_empty((0, width))
        return Stencil3D(
            indices=torch.empty((0, width), dtype=torch.int64, device=ux.device),
            values=blank,
            grad_x=blank,
            grad_y=blank,
            grad_z=blank,
        )

    bx, dbx_du = basis_and_grad(ux)
    by, dby_du = basis_and_grad(uy)
    bz, dbz_du = basis_and_grad(uz)
    # Chain rule: d/dx = (d/du) / knot spacing.
    dbx = dbx_du / knot_spacing_xy
    dby = dby_du / knot_spacing_xy
    dbz = dbz_du / knot_spacing_z

    tap_offsets = torch.arange(taps, dtype=torch.int64, device=ux.device)
    ix, iy, iz = (
        (cell - degree).unsqueeze(1) + tap_offsets.unsqueeze(0)
        for cell in (cell_x, cell_y, cell_z)
    )
    if check_bounds:
        ix, iy, iz = (
            _wrap_indices(ix, nx),
            _wrap_indices(iy, ny),
            _wrap_indices(iz, nz),
        )

    rows = ux.shape[0]
    flat = (
        (ix[:, :, None, None] * ny + iy[:, None, :, None]) * nz
        + iz[:, None, None, :]
    ).reshape(rows, -1)

    # Broadcast each axis onto its own dimension of the taps^3 outer product.
    wx, wy, wz = bx[:, :, None, None], by[:, None, :, None], bz[:, None, None, :]
    gx, gy, gz = dbx[:, :, None, None], dby[:, None, :, None], dbz[:, None, None, :]
    return Stencil3D(
        indices=flat,
        values=(wx * wy * wz).reshape(rows, -1),
        grad_x=(gx * wy * wz).reshape(rows, -1),
        grad_y=(wx * gy * wz).reshape(rows, -1),
        grad_z=(wx * wy * gz).reshape(rows, -1),
    )
[docs]
def all_supported_uniform_stencil_3d(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    spline: str,
) -> Stencil3D:
    """
    Return an unchecked 3D stencil for coordinates already inside support.

    This is the fast path for three-body evaluation. Center-neighbor and
    neighbor-neighbor distances are non-negative by construction, so cells can
    be computed with a truncating cast, and indices are not validated or
    wrapped. Callers must filter coordinates to the full support region first.

    Args:
        x: First coordinate tensor (any shape; flattened internally).
        y: Second coordinate tensor, same shape as ``x``.
        z: Third coordinate tensor, same shape as ``x``.
        coeff_shape: Three-dimensional coefficient grid shape.
        first_knot_xy: First knot for the first two distance coordinates.
        first_knot_z: First knot for the neighbor-neighbor distance coordinate.
        knot_spacing_xy: Knot spacing for the first two distance coordinates.
        knot_spacing_z: Knot spacing for the neighbor-neighbor distance coordinate.
        spline: Spline family name.

    Returns:
        Dense stencil values and gradients, one row per flattened coordinate.

    Raises:
        ValueError: If coordinate tensors have different shapes.
    """
    if x.shape != y.shape or x.shape != z.shape:
        raise ValueError("`x`, `y`, and `z` must have matching shapes")
    axes = (
        (x, first_knot_xy, knot_spacing_xy),
        (y, first_knot_xy, knot_spacing_xy),
        (z, first_knot_z, knot_spacing_z),
    )
    # Interleaved (cell, u) pairs in the positional order the builder expects.
    scaled: list[torch.Tensor] = []
    for coordinate, knot0, spacing in axes:
        scaled.extend(
            _scaled_coordinate(
                coordinate.reshape(-1), knot0, spacing, nonnegative=True
            )
        )
    return _uniform_stencil_3d_from_scaled(
        *scaled,
        coeff_shape=coeff_shape,
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        spline=spline,
        check_bounds=False,
    )
[docs]
def supported_uniform_stencil_3d(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    spline: str,
) -> SupportedStencil3D:
    """
    Return the support mask and unchecked stencil for supported 3D coordinates.

    Deprecated: emits a ``DeprecationWarning`` on every call. In
    performance-sensitive paths, filter coordinates yourself and call
    :func:`all_supported_uniform_stencil_3d` instead.

    Args:
        x: First coordinate tensor.
        y: Second coordinate tensor.
        z: Third coordinate tensor.
        coeff_shape: Three-dimensional coefficient grid shape.
        first_knot_xy: First knot for the first two distance coordinates.
        first_knot_z: First knot for the neighbor-neighbor distance coordinate.
        knot_spacing_xy: Knot spacing for the first two distance coordinates.
        knot_spacing_z: Knot spacing for the neighbor-neighbor distance coordinate.
        spline: Spline family name.

    Returns:
        ``SupportedStencil3D`` whose ``mask`` has the shape of ``x`` and whose
        coordinates and stencil cover only the masked, flattened points.

    Raises:
        ValueError: If coordinate tensors have different shapes.
    """
    warnings.warn(
        "`supported_uniform_stencil_3d` is deprecated; filter coordinates before "
        "calling `all_supported_uniform_stencil_3d` instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    if x.shape != y.shape or x.shape != z.shape:
        raise ValueError("`x`, `y`, and `z` must have matching shapes")
    nx, ny, nz = (int(value) for value in coeff_shape)
    degree, _ = _get_basis_and_degree(spline)
    points_x, points_y, points_z = x.reshape(-1), y.reshape(-1), z.reshape(-1)
    cell_x, ux = _scaled_coordinate(points_x, first_knot_xy, knot_spacing_xy)
    cell_y, uy = _scaled_coordinate(points_y, first_knot_xy, knot_spacing_xy)
    cell_z, uz = _scaled_coordinate(points_z, first_knot_z, knot_spacing_z)
    # A point is kept only if its whole stencil is in range on every axis.
    keep = (
        ((cell_x >= degree) & (cell_x < nx))
        & ((cell_y >= degree) & (cell_y < ny))
        & ((cell_z >= degree) & (cell_z < nz))
    )
    stencil = _uniform_stencil_3d_from_scaled(
        *(part[keep] for part in (cell_x, ux, cell_y, uy, cell_z, uz)),
        coeff_shape=(nx, ny, nz),
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        spline=spline,
        check_bounds=False,
    )
    return SupportedStencil3D(
        mask=keep.reshape(x.shape),
        x=points_x[keep],
        y=points_y[keep],
        z=points_z[keep],
        stencil=stencil,
    )
[docs]
def uniform_stencil_1d(
    x: torch.Tensor,
    *,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
    spline: str,
) -> Stencil1D:
    """
    Return coefficient indices, values, and gradients for a 1D uniform stencil.

    Args:
        x: One-dimensional tensor of coordinates. A tap axis is appended to it
            as dimension 1.
        coeff_size: Number of coefficients.
        first_knot: Position of the first knot.
        knot_spacing: Uniform knot spacing.
        spline: Spline family name.

    Returns:
        ``Stencil1D`` with ``degree + 1`` columns per point: wrapped coefficient
        indices, basis values, and derivatives with respect to ``x``.

    Raises:
        ValueError: If ``spline`` is unsupported or ``coeff_size`` is not
            positive.
        IndexError: If a stencil index falls outside ``[-coeff_size, coeff_size)``.
    """
    degree, basis_and_grad = _get_basis_and_degree(spline)
    cell, local = _scaled_coordinate(x, first_knot, knot_spacing)
    weights, dweights_du = basis_and_grad(local)
    taps = torch.arange(degree + 1, dtype=torch.int64, device=x.device)
    # The point in cell c touches coefficients c - degree .. c.
    raw_indices = (cell - degree).unsqueeze(1) + taps.unsqueeze(0)
    return Stencil1D(
        indices=_wrap_indices(raw_indices, coeff_size),
        values=weights,
        grads=dweights_du / knot_spacing,
    )
[docs]
def _axis_broadcast_view(
    tensor: torch.Tensor,
    axis: int,
    ndim: int,
) -> torch.Tensor:
    """Reshape ``(points, taps)`` so the taps occupy dim ``axis + 1`` of an ``ndim``-axis outer product."""
    shape = [tensor.shape[0]] + [1] * ndim
    shape[axis + 1] = tensor.shape[1]
    return tensor.reshape(shape)


def uniform_stencil_6d(
    coords: tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ],
    *,
    coeff_shape: tuple[int, int, int, int, int, int],
    first_knots: tuple[float, float, float, float, float, float],
    knot_spacings: tuple[float, float, float, float, float, float],
    spline: str,
) -> Stencil6D:
    """
    Return coefficient indices, values, and gradients for a 6D uniform stencil.

    All coordinate tensors are flattened, so the result has one row per
    element. Per-axis indices go through ``_wrap_indices``.

    Args:
        coords: Six coordinate tensors of identical shape.
        coeff_shape: Coefficient count per axis.
        first_knots: First knot per axis.
        knot_spacings: Knot spacing per axis.
        spline: Spline family name.

    Returns:
        ``Stencil6D`` with ``(degree + 1) ** 6`` columns per point.
        ``indices`` are row-major flat offsets into ``coeff_shape``, and
        ``grads[k]`` is the derivative along the physical coordinate of axis
        ``k``. An empty batch yields ``(0, (degree + 1) ** 6)`` tensors.

    Raises:
        ValueError: If the spline is unsupported, coordinate shapes differ, or
            the per-axis sequences differ in length.
        IndexError: If a stencil index falls outside ``[-n, n)`` on an axis.
    """
    degree, basis_and_grad = _get_basis_and_degree(spline)
    if any(coord.shape != coords[0].shape for coord in coords[1:]):
        raise ValueError("all 6D coordinates must have matching shapes")
    ndim = 6
    flat_coords = tuple(coord.reshape(-1) for coord in coords)
    num_points = flat_coords[0].shape[0]
    # Explicit column count: ``reshape(0, -1)`` is ambiguous in torch and
    # raised for an empty batch in the previous implementation.
    width = (degree + 1) ** ndim
    offsets = torch.arange(degree + 1, dtype=torch.int64, device=coords[0].device)

    indices_1d: list[torch.Tensor] = []
    basis_values: list[torch.Tensor] = []
    basis_grads: list[torch.Tensor] = []
    for coord, size, first_knot, knot_spacing in zip(
        flat_coords,
        coeff_shape,
        first_knots,
        knot_spacings,
        strict=True,
    ):
        cell, u = _scaled_coordinate(coord, first_knot, knot_spacing)
        basis, grad_u = basis_and_grad(u)
        indices_1d.append(
            _wrap_indices((cell - degree)[:, None] + offsets[None, :], int(size))
        )
        basis_values.append(basis)
        # Chain rule: d/dx = (d/du) / knot spacing.
        basis_grads.append(grad_u / knot_spacing)

    # Row-major strides: strides[k] is the product of coeff_shape[k + 1:].
    strides = [1] * ndim
    for axis in range(ndim - 2, -1, -1):
        strides[axis] = strides[axis + 1] * int(coeff_shape[axis + 1])

    shaped_indices = [
        _axis_broadcast_view(index, axis, ndim) for axis, index in enumerate(indices_1d)
    ]
    shaped_basis = [
        _axis_broadcast_view(basis, axis, ndim) for axis, basis in enumerate(basis_values)
    ]
    shaped_grads = [
        _axis_broadcast_view(grad, axis, ndim) for axis, grad in enumerate(basis_grads)
    ]

    flat_indices = shaped_indices[0] * strides[0]
    for index, stride in zip(shaped_indices[1:], strides[1:], strict=True):
        flat_indices = flat_indices + index * stride

    values_nd = shaped_basis[0]
    for basis in shaped_basis[1:]:
        values_nd = values_nd * basis

    # d/dx_k of the product: the axis-k derivative times every other axis's
    # basis, multiplied in axis order.
    grads = []
    for axis in range(ndim):
        product = shaped_grads[axis]
        for other_axis, basis in enumerate(shaped_basis):
            if other_axis != axis:
                product = product * basis
        grads.append(product.reshape(num_points, width))

    return Stencil6D(
        indices=flat_indices.reshape(num_points, width),
        values=values_nd.reshape(num_points, width),
        grads=tuple(grads),  # type: ignore[arg-type]
    )
[docs]
def uniform_stencil_3d(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    spline: str,
) -> Stencil3D:
    """
    Return coefficient indices, values, and gradients for a 3D uniform stencil.

    Unlike :func:`all_supported_uniform_stencil_3d`, this entry point floors
    every coordinate and validates/wraps indices. Coordinates outside the full
    support are therefore accepted as long as their stencils stay within
    ``[-n, n)`` on each axis.

    Args:
        x: One-dimensional tensor of first coordinates (not flattened here).
        y: One-dimensional tensor of second coordinates.
        z: One-dimensional tensor of third coordinates.
        coeff_shape: ``(nx, ny, nz)`` coefficient grid shape.
        first_knot_xy: First knot shared by the first two axes.
        first_knot_z: First knot of the third axis.
        knot_spacing_xy: Knot spacing shared by the first two axes.
        knot_spacing_z: Knot spacing of the third axis.
        spline: Spline family name.

    Returns:
        ``Stencil3D`` with ``(degree + 1) ** 3`` columns per point.

    Raises:
        ValueError: If ``spline`` is unsupported.
        IndexError: If a stencil index falls outside the wrappable range.
    """
    along_x = _scaled_coordinate(x, first_knot_xy, knot_spacing_xy)
    along_y = _scaled_coordinate(y, first_knot_xy, knot_spacing_xy)
    along_z = _scaled_coordinate(z, first_knot_z, knot_spacing_z)
    return _uniform_stencil_3d_from_scaled(
        *along_x,
        *along_y,
        *along_z,
        coeff_shape=coeff_shape,
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        spline=spline,
        check_bounds=True,
    )
# Public API of this module. ``supported_uniform_stencil_3d`` remains exported
# for backward compatibility; it emits a DeprecationWarning when called.
__all__ = [
    "Stencil1D",
    "Stencil2D",
    "Stencil3D",
    "Stencil6D",
    "SupportedStencil3D",
    "all_supported_uniform_stencil_3d",
    "spline_support_mask_1d",
    "spline_support_mask_2d",
    "spline_support_mask_3d",
    "spline_support_mask_6d",
    "supported_uniform_stencil_3d",
    "uniform_stencil_1d",
    "uniform_stencil_2d",
    "uniform_stencil_3d",
    "uniform_stencil_6d",
    "uniform_support_parameters",
]