Source code for ufp.splines.representation

"""
Stencil representations for uniform spline bases.

Use this module to inspect which coefficients a geometry touches and the basis
weights or gradients that should be applied to those coefficients. The support
helpers exported here are low-level expert APIs; hot callers may use them
directly after doing their own setup-time dispatch and filtering.
"""

from __future__ import annotations

import math
import warnings
from collections.abc import Callable
from dataclasses import dataclass

import torch

from ufp.splines._cubic import uniform_basis_and_grad as cubic_basis_and_grad
from ufp.splines._quadratic import uniform_basis_and_grad as quadratic_basis_and_grad
from ufp.splines._quartic import uniform_basis_and_grad as quartic_basis_and_grad


# Signature shared by every spline-family evaluator: it takes the local
# (in-cell) coordinate tensor and returns a pair ``(basis, grad)``. Callers in
# this module treat ``grad`` as the derivative with respect to that local
# coordinate and rescale it by the knot spacing themselves.
BasisFn = Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]


@dataclass(frozen=True)
class Stencil1D:
    """
    Immutable 1D spline stencil: the coefficients a point touches and their weights.

    As produced by ``uniform_stencil_1d``, ``indices`` has one row per input
    coordinate and one column per touched coefficient (``degree + 1`` of them).
    ``values`` and ``grads`` are expected to share that layout.
    """

    # int64 coefficient indices, wrapped into ``[0, coeff_size)``.
    indices: torch.Tensor
    # Basis weights paired element-wise with ``indices``.
    values: torch.Tensor
    # Basis derivatives with respect to the physical coordinate (the local
    # derivative already divided by the knot spacing).
    grads: torch.Tensor
@dataclass(frozen=True)
class Stencil2D:
    """
    Immutable 2D spline stencil over a flattened coefficient grid.

    As produced by ``uniform_stencil_2d``, each row corresponds to one input
    point and each column to one touched coefficient, whose flat index is
    ``ix * ny + iy`` for a coefficient grid of shape ``(nx, ny)``.
    """

    # int64 flat coefficient indices.
    indices: torch.Tensor
    # Tensor-product basis weights for each entry of ``indices``.
    values: torch.Tensor
    # Partial derivative of each weight with respect to the physical x coordinate.
    grad_x: torch.Tensor
    # Partial derivative of each weight with respect to the physical y coordinate.
    grad_y: torch.Tensor
@dataclass(frozen=True)
class Stencil3D:
    """
    Immutable 3D spline stencil over a flattened coefficient grid.

    Each row corresponds to one input point and each column to one touched
    coefficient, whose flat index is ``(ix * ny + iy) * nz + iz`` for a grid of
    shape ``(nx, ny, nz)``. Stencils built on the unchecked fast path are not
    wrapped or bounds-checked, so their indices are only valid for coordinates
    inside the full support.
    """

    # int64 flat coefficient indices.
    indices: torch.Tensor
    # Tensor-product basis weights for each entry of ``indices``.
    values: torch.Tensor
    # Partial derivative of each weight along the first coordinate.
    grad_x: torch.Tensor
    # Partial derivative of each weight along the second coordinate.
    grad_y: torch.Tensor
    # Partial derivative of each weight along the third coordinate.
    grad_z: torch.Tensor
@dataclass(frozen=True)
class SupportedStencil3D:
    """
    3D stencil restricted to the coordinates that fall inside the spline support.

    ``mask`` has the shape of the original coordinate tensors. ``x``, ``y`` and
    ``z`` are the flattened coordinates selected by that mask, and the rows of
    ``stencil`` line up with them one-to-one.
    """

    # Boolean selection of supported points, shaped like the input coordinates.
    mask: torch.Tensor
    # Flattened first coordinates of the supported points.
    x: torch.Tensor
    # Flattened second coordinates of the supported points.
    y: torch.Tensor
    # Flattened third coordinates of the supported points.
    z: torch.Tensor
    # Stencil for the supported points only, in the same order as x/y/z.
    stencil: Stencil3D
@dataclass(frozen=True)
class Stencil6D:
    """
    Immutable 6D spline stencil over a flattened coefficient grid.

    Each row corresponds to one input point and each column to one touched
    coefficient. Flat indices use row-major (C-order) strides over the six
    coefficient axes.
    """

    # int64 flat coefficient indices.
    indices: torch.Tensor
    # Tensor-product basis weights for each entry of ``indices``.
    values: torch.Tensor
    # One gradient tensor per axis, in coordinate order, each the partial
    # derivative of the weights with respect to that axis's physical coordinate.
    grads: tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]
# Registry of supported spline families: name -> (polynomial degree, evaluator).
# The evaluator returns ``(basis, d basis / du)`` for local coordinates ``u``.
_BASIS_AND_DEGREE: dict[str, tuple[int, BasisFn]] = {
    "quadratic": (2, quadratic_basis_and_grad),
    "cubic": (3, cubic_basis_and_grad),
    "quartic": (4, quartic_basis_and_grad),
}


def _get_basis_and_degree(spline: str) -> tuple[int, BasisFn]:
    """
    Look up a spline family by name.

    Args:
        spline: Registered family name, e.g. ``"cubic"``.

    Returns:
        ``(degree, basis_and_grad)`` for the requested family.

    Raises:
        ValueError: If ``spline`` is not a registered family name. The original
            ``KeyError`` is chained as the cause.
    """
    try:
        entry = _BASIS_AND_DEGREE[spline]
    except KeyError as missing:
        known = ", ".join(sorted(_BASIS_AND_DEGREE))
        raise ValueError(
            f"Unsupported spline '{spline}'. Expected one of: {known}."
        ) from missing
    return entry


def _scaled_coordinate(
    x: torch.Tensor,
    first_knot: float,
    knot_spacing: float,
    *,
    nonnegative: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Map physical coordinates onto the uniform knot grid.

    ``scaled = (x - first_knot) / knot_spacing`` is split into an int64 cell
    index and the local in-cell coordinate ``scaled - cell``.

    Args:
        x: Physical coordinates.
        first_knot: Position of the first knot.
        knot_spacing: Distance between consecutive knots.
        nonnegative: Caller promise that ``x >= 0``. Together with
            ``first_knot <= 0`` (and a positive spacing) this makes ``scaled``
            non-negative, so a plain integer cast (truncation toward zero)
            gives the same cell as ``floor`` and is used instead.

    Returns:
        ``(cell, local)``, where ``cell`` is int64 and ``local`` keeps the
        floating dtype of ``scaled``.
    """
    position = (x - first_knot) / knot_spacing
    truncate = nonnegative and first_knot <= 0.0
    cell = position.to(torch.int64) if truncate else torch.floor(position)
    return cell.to(torch.int64), position - cell


def _wrap_indices(indices: torch.Tensor, size: int) -> torch.Tensor:
    """
    Map stencil indices in ``[-size, size)`` onto ``[0, size)``.

    Negative indices wrap once, Python-style (``-1`` becomes ``size - 1``).

    Args:
        indices: Integer index tensor.
        size: Number of coefficients along the axis.

    Returns:
        ``indices`` reduced modulo ``size``.

    Raises:
        ValueError: If ``size`` is not positive.
        IndexError: If any index lies outside ``[-size, size)``.
    """
    n = int(size)
    if n <= 0:
        raise ValueError("coefficient size must be positive")
    # One combined reduction instead of two separate any() checks.
    outside = (indices < -n) | (indices >= n)
    if bool(outside.any()):
        raise IndexError(
            f"spline stencil touched coefficients outside [-{n}, {n - 1}]"
        )
    return torch.remainder(indices, n)
def uniform_support_parameters(
    *,
    coeff_size: int,
    lower_full_support: float,
    upper_full_support: float,
    spline: str,
) -> tuple[float, float]:
    """
    Compute the first knot and spacing for a uniform spline support.

    The ``coeff_size - degree`` knot intervals are spread evenly over
    ``[lower_full_support, upper_full_support]``, and ``degree`` further
    intervals are placed below the lower boundary. As a result, coordinates in
    ``[lower_full_support, upper_full_support)`` land in cells
    ``degree .. coeff_size - 1`` (up to floating-point rounding).

    Args:
        coeff_size: Number of coefficients along the axis.
        lower_full_support: Lower boundary of the fully supported range.
        upper_full_support: Upper boundary of the fully supported range.
        spline: Spline family name.

    Returns:
        ``(first_knot, knot_spacing)``.

    Raises:
        ValueError: If any of the following holds:

            * ``coeff_size`` is not positive.
            * ``coeff_size`` does not exceed the spline degree.
            * The spline family is unknown.
            * A boundary is NaN or infinite.
            * The upper boundary does not exceed the lower one.
    """
    coeff_size = int(coeff_size)
    if coeff_size <= 0:
        raise ValueError("coefficient size must be positive")
    degree, _ = _get_basis_and_degree(spline)
    if coeff_size <= degree:
        raise ValueError(
            f"coefficient size must be larger than the spline degree ({degree})"
        )
    lower_full_support = float(lower_full_support)
    upper_full_support = float(upper_full_support)
    # Without this check a NaN boundary slips through (``NaN <= 0.0`` is False)
    # and an infinite one yields infinite/NaN knots.
    if not (math.isfinite(lower_full_support) and math.isfinite(upper_full_support)):
        raise ValueError("full-support boundaries must be finite")
    span = upper_full_support - lower_full_support
    if span <= 0.0:
        raise ValueError("upper full-support boundary must exceed the lower boundary")
    knot_spacing = span / float(coeff_size - degree)
    first_knot = lower_full_support - degree * knot_spacing
    return first_knot, knot_spacing
def spline_support_mask_1d(
    x: torch.Tensor,
    *,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
    spline: str,
) -> torch.Tensor:
    """
    Flag the 1D coordinates whose stencil stays inside the coefficient range.

    A point in cell ``c`` touches coefficients ``c - degree .. c``. Those are
    all valid exactly when ``degree <= c < coeff_size``.

    Args:
        x: Physical coordinates.
        coeff_size: Number of coefficients.
        first_knot: Position of the first knot.
        knot_spacing: Distance between consecutive knots.
        spline: Spline family name.

    Returns:
        Boolean tensor shaped like ``x``.

    Raises:
        ValueError: If the spline family is unknown.
    """
    degree = _get_basis_and_degree(spline)[0]
    cell = _scaled_coordinate(x, first_knot, knot_spacing)[0]
    above_lower = cell >= degree
    below_upper = cell < int(coeff_size)
    return above_lower & below_upper
def spline_support_mask_2d(
    x: torch.Tensor,
    y: torch.Tensor,
    *,
    coeff_shape: tuple[int, int],
    first_knot_x: float,
    first_knot_y: float,
    knot_spacing_x: float,
    knot_spacing_y: float,
    spline: str,
) -> torch.Tensor:
    """
    Flag the 2D points whose stencil stays inside the coefficient grid.

    A point is supported when, on both axes, its cell ``c`` satisfies
    ``degree <= c < size``.

    Args:
        x: First physical coordinate.
        y: Second physical coordinate.
        coeff_shape: Coefficient grid shape ``(nx, ny)``.
        first_knot_x: First knot along ``x``.
        first_knot_y: First knot along ``y``.
        knot_spacing_x: Knot spacing along ``x``.
        knot_spacing_y: Knot spacing along ``y``.
        spline: Spline family name.

    Returns:
        Boolean tensor, the broadcast of ``x`` and ``y``.

    Raises:
        ValueError: If the spline family is unknown or ``coeff_shape`` does not
            have two entries.
    """
    nx, ny = (int(value) for value in coeff_shape)
    degree = _get_basis_and_degree(spline)[0]
    cell_x = _scaled_coordinate(x, first_knot_x, knot_spacing_x)[0]
    cell_y = _scaled_coordinate(y, first_knot_y, knot_spacing_y)[0]
    inside_x = (cell_x >= degree) & (cell_x < nx)
    inside_y = (cell_y >= degree) & (cell_y < ny)
    return inside_x & inside_y
def spline_support_mask_3d(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    spline: str,
) -> torch.Tensor:
    """
    Flag the 3D points whose stencil stays inside the coefficient grid.

    ``x`` and ``y`` share one knot layout (``first_knot_xy``,
    ``knot_spacing_xy``), while ``z`` has its own. A point is supported when,
    on every axis, its cell ``c`` satisfies ``degree <= c < size``.

    Args:
        x: First physical coordinate.
        y: Second physical coordinate.
        z: Third physical coordinate.
        coeff_shape: Coefficient grid shape ``(nx, ny, nz)``.
        first_knot_xy: First knot shared by ``x`` and ``y``.
        first_knot_z: First knot along ``z``.
        knot_spacing_xy: Knot spacing shared by ``x`` and ``y``.
        knot_spacing_z: Knot spacing along ``z``.
        spline: Spline family name.

    Returns:
        Boolean tensor, the broadcast of ``x``, ``y`` and ``z``.

    Raises:
        ValueError: If the spline family is unknown or ``coeff_shape`` does not
            have three entries.
    """
    nx, ny, nz = (int(value) for value in coeff_shape)
    degree = _get_basis_and_degree(spline)[0]
    axes = (
        (x, first_knot_xy, knot_spacing_xy, nx),
        (y, first_knot_xy, knot_spacing_xy, ny),
        (z, first_knot_z, knot_spacing_z, nz),
    )
    per_axis = []
    for coord, first_knot, knot_spacing, size in axes:
        cell = _scaled_coordinate(coord, first_knot, knot_spacing)[0]
        per_axis.append((cell >= degree) & (cell < size))
    return per_axis[0] & per_axis[1] & per_axis[2]
def spline_support_mask_6d(
    coords: tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ],
    *,
    coeff_shape: tuple[int, int, int, int, int, int],
    first_knots: tuple[float, float, float, float, float, float],
    knot_spacings: tuple[float, float, float, float, float, float],
    spline: str,
) -> torch.Tensor:
    """
    Flag the 6D points whose stencil stays inside the coefficient grid.

    A point is supported when, on every axis, its cell ``c`` satisfies
    ``degree <= c < size``.

    Args:
        coords: Per-axis physical coordinates.
        coeff_shape: Coefficient grid shape, one entry per axis.
        first_knots: First knot per axis.
        knot_spacings: Knot spacing per axis.
        spline: Spline family name.

    Returns:
        Boolean tensor starting from ``coords[0]``'s shape and broadcast with
        every axis.

    Raises:
        ValueError: If the spline family is unknown or the per-axis arguments
            have mismatched lengths.
    """
    degree = _get_basis_and_degree(spline)[0]
    inside = torch.ones_like(coords[0], dtype=torch.bool)
    per_axis = zip(coords, coeff_shape, first_knots, knot_spacings, strict=True)
    for coord, size, first_knot, knot_spacing in per_axis:
        cell = _scaled_coordinate(coord, first_knot, knot_spacing)[0]
        lower_ok = cell >= degree
        upper_ok = cell < int(size)
        inside = inside & lower_ok & upper_ok
    return inside
def uniform_stencil_2d(
    x: torch.Tensor,
    y: torch.Tensor,
    *,
    coeff_shape: tuple[int, int],
    first_knot_x: float,
    first_knot_y: float,
    knot_spacing_x: float,
    knot_spacing_y: float,
    spline: str,
) -> Stencil2D:
    """
    Return coefficient indices, values, and gradients for a 2D uniform stencil.

    ``x`` and ``y`` are indexed as 1D batches of points. Row ``i`` of every
    output lists the ``(degree + 1) ** 2`` coefficients touched by point
    ``(x[i], y[i])``, with flat index ``ix * ny + iy``. Window indices in
    ``[-size, 0)`` are wrapped modulo the axis size.

    Args:
        x: First coordinate of each point.
        y: Second coordinate of each point.
        coeff_shape: Coefficient grid shape ``(nx, ny)``.
        first_knot_x: First knot along ``x``.
        first_knot_y: First knot along ``y``.
        knot_spacing_x: Knot spacing along ``x``.
        knot_spacing_y: Knot spacing along ``y``.
        spline: Spline family name.

    Returns:
        Flat indices, basis values, and physical-coordinate gradients, each with
        one row per point. Empty input yields ``(0, (degree + 1) ** 2)`` tensors.

    Raises:
        ValueError: If the spline family is unknown or, for non-empty input, a
            coefficient size is not positive.
        IndexError: If a stencil window lies outside ``[-size, size)``.
    """
    nx, ny = (int(value) for value in coeff_shape)
    degree, basis_and_grad = _get_basis_and_degree(spline)
    cell_x, ux = _scaled_coordinate(x, first_knot_x, knot_spacing_x)
    cell_y, uy = _scaled_coordinate(y, first_knot_y, knot_spacing_y)
    if ux.numel() == 0:
        # ``reshape(0, -1)`` cannot infer the width of a zero-element tensor, so
        # the general path below would raise. Return correctly shaped empties
        # instead.
        width = (degree + 1) ** 2
        empty_indices = torch.empty((0, width), dtype=torch.int64, device=x.device)
        empty_values = ux.new_empty((0, width))
        return Stencil2D(
            indices=empty_indices,
            values=empty_values,
            grad_x=empty_values,
            grad_y=empty_values,
        )

    bx, dbx_du = basis_and_grad(ux)
    by, dby_du = basis_and_grad(uy)
    # Chain rule: u = (coord - first_knot) / spacing, so d/dcoord = (d/du) / spacing.
    dbx = dbx_du / knot_spacing_x
    dby = dby_du / knot_spacing_y
    offsets = torch.arange(degree + 1, dtype=torch.int64, device=x.device)
    ix = _wrap_indices((cell_x - degree)[:, None] + offsets[None, :], nx)
    iy = _wrap_indices((cell_y - degree)[:, None] + offsets[None, :], ny)
    count = x.shape[0]
    flat = (ix[:, :, None] * ny + iy[:, None, :]).reshape(count, -1)
    bx3 = bx[:, :, None]
    by3 = by[:, None, :]
    dbx3 = dbx[:, :, None]
    dby3 = dby[:, None, :]
    return Stencil2D(
        indices=flat,
        values=(bx3 * by3).reshape(count, -1),
        grad_x=(dbx3 * by3).reshape(count, -1),
        grad_y=(bx3 * dby3).reshape(count, -1),
    )
def _uniform_stencil_3d_from_scaled(
    cell_x: torch.Tensor,
    ux: torch.Tensor,
    cell_y: torch.Tensor,
    uy: torch.Tensor,
    cell_z: torch.Tensor,
    uz: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    knot_spacing_xy: float,
    knot_spacing_z: float,
    spline: str,
    check_bounds: bool,
) -> Stencil3D:
    """
    Assemble a 3D tensor-product stencil from per-axis cells and local coordinates.

    The ``cell_*`` / ``u*`` pairs are outputs of ``_scaled_coordinate`` and are
    indexed as 1D batches. Each output row holds the ``(degree + 1) ** 3``
    touched coefficients, with flat index ``(ix * ny + iy) * nz + iz``.

    Args:
        cell_x: int64 cell index along the first axis.
        ux: Local coordinate along the first axis.
        cell_y: int64 cell index along the second axis.
        uy: Local coordinate along the second axis.
        cell_z: int64 cell index along the third axis.
        uz: Local coordinate along the third axis.
        coeff_shape: Coefficient grid shape ``(nx, ny, nz)``.
        knot_spacing_xy: Knot spacing of the first two axes.
        knot_spacing_z: Knot spacing of the third axis.
        spline: Spline family name.
        check_bounds: When true, window indices are validated and wrapped via
            ``_wrap_indices``. When false they are used as-is, which is only
            correct if every point lies inside the full support.

    Returns:
        Stencil with one row per point. Empty input yields
        ``(0, (degree + 1) ** 3)`` tensors.
    """
    nx, ny, nz = (int(value) for value in coeff_shape)
    degree, basis_and_grad = _get_basis_and_degree(spline)
    width = (degree + 1) ** 3

    if ux.numel() == 0:
        # reshape(rows, -1) cannot infer a width for zero rows; return empties.
        no_indices = torch.empty(
            (0, width),
            dtype=torch.int64,
            device=ux.device,
        )
        no_values = ux.new_empty((0, width))
        return Stencil3D(
            indices=no_indices,
            values=no_values,
            grad_x=no_values,
            grad_y=no_values,
            grad_z=no_values,
        )

    bx, dbx_du = basis_and_grad(ux)
    by, dby_du = basis_and_grad(uy)
    bz, dbz_du = basis_and_grad(uz)
    # Chain rule: convert local-coordinate derivatives to physical ones.
    dbx = dbx_du / knot_spacing_xy
    dby = dby_du / knot_spacing_xy
    dbz = dbz_du / knot_spacing_z

    window = torch.arange(degree + 1, dtype=torch.int64, device=ux.device)
    ix, iy, iz = (
        (cell - degree)[:, None] + window[None, :]
        for cell in (cell_x, cell_y, cell_z)
    )
    if check_bounds:
        ix, iy, iz = (
            _wrap_indices(ix, nx),
            _wrap_indices(iy, ny),
            _wrap_indices(iz, nz),
        )

    # Place each per-axis (rows, degree + 1) tensor on its own broadcast axis.
    def on_x(t: torch.Tensor) -> torch.Tensor:
        return t[:, :, None, None]

    def on_y(t: torch.Tensor) -> torch.Tensor:
        return t[:, None, :, None]

    def on_z(t: torch.Tensor) -> torch.Tensor:
        return t[:, None, None, :]

    rows = ux.shape[0]
    flat = ((on_x(ix) * ny + on_y(iy)) * nz + on_z(iz)).reshape(rows, -1)
    values = (on_x(bx) * on_y(by) * on_z(bz)).reshape(rows, -1)
    grad_x = (on_x(dbx) * on_y(by) * on_z(bz)).reshape(rows, -1)
    grad_y = (on_x(bx) * on_y(dby) * on_z(bz)).reshape(rows, -1)
    grad_z = (on_x(bx) * on_y(by) * on_z(dbz)).reshape(rows, -1)
    return Stencil3D(
        indices=flat,
        values=values,
        grad_x=grad_x,
        grad_y=grad_y,
        grad_z=grad_z,
    )
def all_supported_uniform_stencil_3d(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    spline: str,
) -> Stencil3D:
    """
    Build an unchecked 3D stencil for points already known to be supported.

    No support mask is applied and no index is bounds-checked or wrapped, so
    callers must filter out unsupported points beforehand. Coordinates are
    assumed non-negative, as they are for the distance triples used in
    three-body evaluation. When a first knot is <= 0, that assumption lets the
    cell index be taken by a plain integer cast instead of ``floor``.

    Args:
        x: First coordinate tensor.
        y: Second coordinate tensor.
        z: Third coordinate tensor.
        coeff_shape: Three-dimensional coefficient grid shape.
        first_knot_xy: First knot for the first two distance coordinates.
        first_knot_z: First knot for the neighbor-neighbor distance coordinate.
        knot_spacing_xy: Knot spacing for the first two distance coordinates.
        knot_spacing_z: Knot spacing for the neighbor-neighbor distance coordinate.
        spline: Spline family name.

    Returns:
        Stencil values and gradients with one row per element of the
        flattened inputs.

    Raises:
        ValueError: If coordinate tensors have different shapes.
    """
    if not (x.shape == y.shape == z.shape):
        raise ValueError("`x`, `y`, and `z` must have matching shapes")

    (cell_x, ux), (cell_y, uy), (cell_z, uz) = (
        _scaled_coordinate(
            coord.reshape(-1),
            first_knot,
            knot_spacing,
            nonnegative=True,
        )
        for coord, first_knot, knot_spacing in (
            (x, first_knot_xy, knot_spacing_xy),
            (y, first_knot_xy, knot_spacing_xy),
            (z, first_knot_z, knot_spacing_z),
        )
    )
    return _uniform_stencil_3d_from_scaled(
        cell_x=cell_x,
        ux=ux,
        cell_y=cell_y,
        uy=uy,
        cell_z=cell_z,
        uz=uz,
        coeff_shape=coeff_shape,
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        spline=spline,
        check_bounds=False,
    )
def supported_uniform_stencil_3d(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    spline: str,
) -> SupportedStencil3D:
    """
    Mask 3D points to the spline support and build their stencil (deprecated).

    Emits a ``DeprecationWarning`` on every call. New code should filter
    coordinates itself and call :func:`all_supported_uniform_stencil_3d`.

    Args:
        x: First coordinate tensor.
        y: Second coordinate tensor.
        z: Third coordinate tensor.
        coeff_shape: Three-dimensional coefficient grid shape.
        first_knot_xy: First knot for the first two distance coordinates.
        first_knot_z: First knot for the neighbor-neighbor distance coordinate.
        knot_spacing_xy: Knot spacing for the first two distance coordinates.
        knot_spacing_z: Knot spacing for the neighbor-neighbor distance coordinate.
        spline: Spline family name.

    Returns:
        A mask shaped like ``x``, the flattened supported coordinates, and the
        stencil for those coordinates, with rows in the same order.

    Raises:
        ValueError: If coordinate tensors have different shapes.
    """
    warnings.warn(
        "`supported_uniform_stencil_3d` is deprecated; filter coordinates before "
        "calling `all_supported_uniform_stencil_3d` instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    if not (x.shape == y.shape == z.shape):
        raise ValueError("`x`, `y`, and `z` must have matching shapes")
    nx, ny, nz = (int(value) for value in coeff_shape)
    degree, _ = _get_basis_and_degree(spline)

    flat_x, flat_y, flat_z = (coord.reshape(-1) for coord in (x, y, z))
    axes = (
        (flat_x, first_knot_xy, knot_spacing_xy, nx),
        (flat_y, first_knot_xy, knot_spacing_xy, ny),
        (flat_z, first_knot_z, knot_spacing_z, nz),
    )
    scaled = []
    keep: torch.Tensor | None = None
    for coord, first_knot, knot_spacing, size in axes:
        cell, local = _scaled_coordinate(coord, first_knot, knot_spacing)
        scaled.append((cell, local))
        # A point is supported only if its cell is in [degree, size) on every axis.
        inside = (cell >= degree) & (cell < size)
        keep = inside if keep is None else keep & inside

    (cell_x, ux), (cell_y, uy), (cell_z, uz) = scaled
    stencil = _uniform_stencil_3d_from_scaled(
        cell_x[keep],
        ux[keep],
        cell_y[keep],
        uy[keep],
        cell_z[keep],
        uz[keep],
        coeff_shape=(nx, ny, nz),
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        spline=spline,
        check_bounds=False,
    )
    return SupportedStencil3D(
        mask=keep.reshape(x.shape),
        x=flat_x[keep],
        y=flat_y[keep],
        z=flat_z[keep],
        stencil=stencil,
    )
def uniform_stencil_1d(
    x: torch.Tensor,
    *,
    coeff_size: int,
    first_knot: float,
    knot_spacing: float,
    spline: str,
) -> Stencil1D:
    """
    Build the 1D stencil for each coordinate.

    ``x`` is indexed as a 1D batch. Each point in cell ``c`` touches the window
    of coefficients ``c - degree .. c``. A window reaching into
    ``[-coeff_size, 0)`` is wrapped modulo ``coeff_size``, and one reaching any
    further raises.

    Args:
        x: Physical coordinates.
        coeff_size: Number of coefficients.
        first_knot: Position of the first knot.
        knot_spacing: Distance between consecutive knots.
        spline: Spline family name.

    Returns:
        Indices of shape ``(len(x), degree + 1)``, the basis values, and basis
        derivatives with respect to the physical coordinate.

    Raises:
        ValueError: If the spline family is unknown or ``coeff_size`` is not
            positive.
        IndexError: If a window lies outside ``[-coeff_size, coeff_size)``.
    """
    degree, basis_and_grad = _get_basis_and_degree(spline)
    cell, u = _scaled_coordinate(x, first_knot, knot_spacing)
    basis, grad_u = basis_and_grad(u)
    window = torch.arange(degree + 1, dtype=torch.int64, device=x.device)
    raw_indices = (cell - degree)[:, None] + window[None, :]
    return Stencil1D(
        indices=_wrap_indices(raw_indices, coeff_size),
        values=basis,
        # Chain rule: u = (x - first_knot) / spacing.
        grads=grad_u / knot_spacing,
    )
def uniform_stencil_6d(
    coords: tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ],
    *,
    coeff_shape: tuple[int, int, int, int, int, int],
    first_knots: tuple[float, float, float, float, float, float],
    knot_spacings: tuple[float, float, float, float, float, float],
    spline: str,
) -> Stencil6D:
    """
    Return coefficient indices, values, and gradients for a 6D uniform stencil.

    Coordinates are flattened, so row ``i`` of every output corresponds to
    element ``i`` of ``coords[k].reshape(-1)``. Each row lists the
    ``(degree + 1) ** 6`` touched coefficients, with flat indices using
    row-major strides over ``coeff_shape``. Window indices in ``[-size, 0)``
    are wrapped modulo the axis size.

    Args:
        coords: Six coordinate tensors of identical shape, one per axis.
        coeff_shape: Six-dimensional coefficient grid shape.
        first_knots: First knot per axis.
        knot_spacings: Knot spacing per axis.
        spline: Spline family name.

    Returns:
        Flat indices, tensor-product basis values, and one physical-coordinate
        gradient per axis, each with one row per point. Empty input yields
        ``(0, (degree + 1) ** 6)`` tensors.

    Raises:
        ValueError: If the spline family is unknown, ``coords`` does not hold
            six tensors, the coordinate shapes differ, the per-axis arguments
            have mismatched lengths, or (for non-empty input) an axis size is
            not positive.
        IndexError: If a stencil window lies outside ``[-size, size)``.
    """
    num_axes = 6
    degree, basis_and_grad = _get_basis_and_degree(spline)
    if len(coords) != num_axes:
        # Previously extra axes were silently ignored and missing ones failed
        # with an unrelated IndexError.
        raise ValueError("uniform_stencil_6d expects exactly six coordinate tensors")
    if any(coord.shape != coords[0].shape for coord in coords[1:]):
        raise ValueError("all 6D coordinates must have matching shapes")
    device = coords[0].device
    flat_coords = tuple(coord.reshape(-1) for coord in coords)
    count = flat_coords[0].shape[0]
    width = (degree + 1) ** num_axes

    # Per axis: (cell, local coordinate, coefficient size, knot spacing).
    # zip(strict=True) rejects per-axis arguments of mismatched length.
    axes = []
    for coord, size, first_knot, knot_spacing in zip(
        flat_coords,
        coeff_shape,
        first_knots,
        knot_spacings,
        strict=True,
    ):
        cell, u = _scaled_coordinate(coord, first_knot, knot_spacing)
        axes.append((cell, u, int(size), knot_spacing))

    if count == 0:
        # reshape(0, -1) cannot infer the width of an empty tensor, so build
        # correctly shaped empties instead of evaluating anything.
        empty_indices = torch.empty((0, width), dtype=torch.int64, device=device)
        empty_values = axes[0][1].new_empty((0, width))
        return Stencil6D(
            indices=empty_indices,
            values=empty_values,
            grads=(empty_values,) * num_axes,  # type: ignore[arg-type]
        )

    sizes = [size for _, _, size, _ in axes]
    # Row-major strides: axis k advances by the product of all later sizes.
    strides = [math.prod(sizes[axis + 1 :]) for axis in range(num_axes)]
    offsets = torch.arange(degree + 1, dtype=torch.int64, device=device)

    def along(tensor: torch.Tensor, axis: int) -> torch.Tensor:
        """View a (count, degree + 1) tensor so it broadcasts along ``axis``."""
        shape = [count] + [1] * num_axes
        shape[axis + 1] = degree + 1
        return tensor.reshape(shape)

    weighted_indices = []
    shaped_basis = []
    shaped_grads = []
    for axis, (cell, u, size, knot_spacing) in enumerate(axes):
        basis, grad_u = basis_and_grad(u)
        axis_indices = _wrap_indices((cell - degree)[:, None] + offsets[None, :], size)
        weighted_indices.append(along(axis_indices, axis) * strides[axis])
        shaped_basis.append(along(basis, axis))
        # Chain rule: u = (coord - first_knot) / spacing.
        shaped_grads.append(along(grad_u / knot_spacing, axis))

    flat_indices = sum(weighted_indices).reshape(count, width)  # type: ignore[union-attr]

    # Multiply in axis order so floating-point results match axis-by-axis products.
    values = shaped_basis[0]
    for basis in shaped_basis[1:]:
        values = values * basis

    grads = []
    for axis, grad in enumerate(shaped_grads):
        product = grad
        for other_axis, basis in enumerate(shaped_basis):
            if other_axis != axis:
                product = product * basis
        grads.append(product.reshape(count, width))

    return Stencil6D(
        indices=flat_indices,
        values=values.reshape(count, width),
        grads=tuple(grads),  # type: ignore[arg-type]
    )
def uniform_stencil_3d(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    spline: str,
) -> Stencil3D:
    """
    Build a bounds-checked 3D stencil for arbitrary coordinates.

    ``x`` and ``y`` share one knot layout, while ``z`` has its own. Unlike the
    unchecked fast path, window indices are validated and wrapped modulo each
    axis size.

    Args:
        x: First coordinate tensor, indexed as a 1D batch.
        y: Second coordinate tensor.
        z: Third coordinate tensor.
        coeff_shape: Coefficient grid shape ``(nx, ny, nz)``.
        first_knot_xy: First knot shared by ``x`` and ``y``.
        first_knot_z: First knot along ``z``.
        knot_spacing_xy: Knot spacing shared by ``x`` and ``y``.
        knot_spacing_z: Knot spacing along ``z``.
        spline: Spline family name.

    Returns:
        Stencil with one row per point.
    """
    (cell_x, ux), (cell_y, uy), (cell_z, uz) = (
        _scaled_coordinate(coord, first_knot, knot_spacing)
        for coord, first_knot, knot_spacing in (
            (x, first_knot_xy, knot_spacing_xy),
            (y, first_knot_xy, knot_spacing_xy),
            (z, first_knot_z, knot_spacing_z),
        )
    )
    return _uniform_stencil_3d_from_scaled(
        cell_x=cell_x,
        ux=ux,
        cell_y=cell_y,
        uy=uy,
        cell_z=cell_z,
        uz=uz,
        coeff_shape=coeff_shape,
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        spline=spline,
        check_bounds=True,
    )
# Public API, kept in alphabetical order.
__all__ = [
    "Stencil1D",
    "Stencil2D",
    "Stencil3D",
    "Stencil6D",
    "SupportedStencil3D",
    "all_supported_uniform_stencil_3d",
    "spline_support_mask_1d",
    "spline_support_mask_2d",
    "spline_support_mask_3d",
    "spline_support_mask_6d",
    "supported_uniform_stencil_3d",
    "uniform_stencil_1d",
    "uniform_stencil_2d",
    "uniform_stencil_3d",
    "uniform_stencil_6d",
    "uniform_support_parameters",
]