Source code for ufp.projection.spline

"""Offline projection helpers for uniform spline coefficient blocks."""

from __future__ import annotations

import math
from collections.abc import Sequence
from dataclasses import dataclass, replace
from typing import Any, cast

import torch

from ufp.projection.diagnostics import (
    ProjectionChannelDiagnostic,
    ProjectionDiagnostics,
    ProjectionErrorSummary,
    ProjectionSupportCoverage,
)
from ufp.splines.representation import (
    Stencil1D,
    Stencil2D,
    Stencil3D,
    uniform_stencil_1d,
    uniform_stencil_2d,
    uniform_stencil_3d,
    uniform_support_parameters,
)
from ufp.terms.threebody import SplineThreeBodyTerm
from ufp.terms.triplet2d import SplineTriplet2DTerm
from ufp.terms.twobody import SplinePairTerm, SplineTwoBodyTerm


class SplineProjectionError(ValueError):
    """
    Signal that spline coefficients cannot be projected safely.

    This derives from ``ValueError``, so existing ``except ValueError``
    handlers also catch projection failures.
    """
@dataclass(eq=True, frozen=True)
class SplineProjectionResult:
    """
    Immutable bundle of projected coefficients plus sample-space error data.

    The projectors in this module fill the gradient-related fields only when
    derivative rows took part in the least-squares fit; otherwise those
    fields stay ``None``.
    """

    # Target-basis coefficients, reshaped to the requested target shape.
    coeffs: torch.Tensor
    # One flattened coordinate tensor per spline axis.
    sample_coordinates: tuple[torch.Tensor, ...]
    # Source spline values at the sample points.
    source_values: torch.Tensor
    # Target spline values at the sample points after projection.
    projected_values: torch.Tensor
    # Root-mean-square and maximum absolute value residuals.
    rmse: float
    max_abs_error: float
    # Structured per-channel diagnostics.
    diagnostics: ProjectionDiagnostics

    # Per-axis source/projected derivatives at the samples (or None).
    source_gradients: tuple[torch.Tensor, ...] | None = None
    projected_gradients: tuple[torch.Tensor, ...] | None = None
    # Derivative residual statistics (or None when derivatives were unused).
    gradient_rmse: float | None = None
    gradient_max_abs_error: float | None = None

    # Coefficient shapes before and after projection.
    source_shape: tuple[int, ...] = ()
    target_shape: tuple[int, ...] = ()
    # Physical channel indices attached after projection, if any.
    channel: tuple[int, ...] | None = None
def _as_float(value: torch.Tensor) -> float: """Return one scalar tensor as a Python float.""" return float(value.detach().cpu().item()) def _error_summary(residuals: torch.Tensor) -> ProjectionErrorSummary: """Build a scalar projection-error summary from tensor residuals.""" return ProjectionErrorSummary.from_residuals( residuals.detach().cpu().reshape(-1).tolist() ) def _support_coverage( sample_coordinates: tuple[torch.Tensor, ...], *, bounds: tuple[tuple[float, float], ...], ) -> ProjectionSupportCoverage: """Build support coverage metadata for tensor-product sample points.""" sample_count = 0 if not sample_coordinates else int(sample_coordinates[0].numel()) if sample_count == 0: minimum_sample = None maximum_sample = None else: coordinates = torch.cat( [coordinate.detach().cpu().reshape(-1) for coordinate in sample_coordinates] ) minimum_sample = _as_float(torch.min(coordinates)) maximum_sample = _as_float(torch.max(coordinates)) lower_bound = min(float(lower) for lower, _ in bounds) upper_bound = max(float(upper) for _, upper in bounds) return ProjectionSupportCoverage( sample_count=sample_count, covered_count=sample_count, lower_bound=lower_bound, upper_bound=upper_bound, minimum_sample=minimum_sample, maximum_sample=maximum_sample, ) def _gradient_error_summary( source_gradients: tuple[torch.Tensor, ...] | None, projected_gradients: tuple[torch.Tensor, ...] 
| None, ) -> ProjectionErrorSummary | None: """Build a derivative-error summary when gradients were projected.""" if source_gradients is None or projected_gradients is None: return None residuals = torch.cat( [ (projected - source).detach().reshape(-1) for source, projected in zip( source_gradients, projected_gradients, strict=True, ) ] ) return _error_summary(residuals) def _channel_label(channel: Sequence[int] | None) -> str: """Return a stable diagnostics label for a projected channel.""" if channel is None: return "spline" values = tuple(int(value) for value in channel) if len(values) == 2: return f"pair[{values!r}]" if len(values) == 3: return f"triplet[{values!r}]" return f"channel[{values!r}]" def _projection_diagnostics( *, channel: Sequence[int] | None, sample_coordinates: tuple[torch.Tensor, ...], bounds: tuple[tuple[float, float], ...], source_values: torch.Tensor, projected_values: torch.Tensor, source_gradients: tuple[torch.Tensor, ...] | None, projected_gradients: tuple[torch.Tensor, ...] 
| None, ) -> ProjectionDiagnostics: """Return shared diagnostics for one spline projection result.""" value_residuals = projected_values - source_values return ProjectionDiagnostics( ( ProjectionChannelDiagnostic( channel_label=_channel_label(channel), sample_count=int(source_values.numel()), support_coverage=_support_coverage( sample_coordinates, bounds=bounds, ), value_error=_error_summary(value_residuals), derivative_error=_gradient_error_summary( source_gradients, projected_gradients, ), ), ) ) def _diagnostics_with_channel( diagnostics: ProjectionDiagnostics, channel: Sequence[int], ) -> ProjectionDiagnostics: """Relabel the primary channel in shared projection diagnostics.""" if not diagnostics.channels: return diagnostics channels = ( replace(diagnostics.channels[0], channel_label=_channel_label(channel)), *diagnostics.channels[1:], ) return ProjectionDiagnostics(channels) def _result_with_channel( result: SplineProjectionResult, channel: Sequence[int], ) -> SplineProjectionResult: """Return a projection result with its physical channel metadata attached.""" resolved = tuple(int(value) for value in channel) return replace( result, channel=resolved, diagnostics=_diagnostics_with_channel(result.diagnostics, resolved), ) def _as_floating_tensor(value: torch.Tensor | Any, *, name: str) -> torch.Tensor: """Normalize a coefficient tensor and reject non-floating storage.""" tensor = torch.as_tensor(value) if not tensor.is_floating_point(): raise SplineProjectionError(f"`{name}` must be a floating-point tensor") return tensor def _row_matrix_from_stencil( indices: torch.Tensor, values: torch.Tensor, *, coeff_size: int, ) -> torch.Tensor: """Materialize one dense design matrix from sparse local stencil rows.""" matrix = values.new_zeros((values.shape[0], int(coeff_size))) matrix.scatter_add_(1, indices, values) return matrix def _midpoint_samples( *, lower: float, upper: float, n_samples: int, dtype: torch.dtype, device: torch.device, ) -> torch.Tensor: """Return 
midpoint samples over a half-open interval.""" n_samples = int(n_samples) if n_samples <= 0: raise SplineProjectionError("sample counts must be positive") step = (float(upper) - float(lower)) / float(n_samples) return float(lower) + step * ( torch.arange(n_samples, dtype=dtype, device=device) + 0.5 ) def _normalize_1d_samples( sample_distances: torch.Tensor | None, *, lower: float, upper: float, n_samples: int | None, dtype: torch.dtype, device: torch.device, ) -> torch.Tensor: """Return one-dimensional sample locations inside a support interval.""" if sample_distances is None: resolved_n_samples = n_samples if resolved_n_samples is None: raise SplineProjectionError("internal error: missing default sample count") return _midpoint_samples( lower=lower, upper=upper, n_samples=resolved_n_samples, dtype=dtype, device=device, ) distances = torch.as_tensor(sample_distances, dtype=dtype, device=device).reshape( -1 ) if distances.numel() == 0: raise SplineProjectionError("`sample_distances` must not be empty") if torch.any(distances < float(lower)) or torch.any(distances >= float(upper)): raise SplineProjectionError( "`sample_distances` must lie inside the target full-support interval" ) return distances def _normalize_axis_samples( n_samples_per_axis: int | Sequence[int] | None, *, n_axes: int, target_shape: tuple[int, ...], ) -> tuple[int, ...]: """Resolve per-axis sample counts.""" if n_samples_per_axis is None: return tuple(max(2 * int(dim), int(dim) + 4) for dim in target_shape) if isinstance(n_samples_per_axis, int): count = int(n_samples_per_axis) if count <= 0: raise SplineProjectionError("`n_samples_per_axis` must be positive") return (count,) * n_axes counts = tuple(int(count) for count in n_samples_per_axis) if len(counts) != n_axes: raise SplineProjectionError( f"`n_samples_per_axis` must contain {n_axes} entries" ) if any(count <= 0 for count in counts): raise SplineProjectionError("`n_samples_per_axis` entries must be positive") return counts def 
_normalize_sample_coordinates( sample_coordinates: Sequence[torch.Tensor] | None, *, bounds: tuple[tuple[float, float], ...], n_samples_per_axis: int | Sequence[int] | None, target_shape: tuple[int, ...], dtype: torch.dtype, device: torch.device, ) -> tuple[torch.Tensor, ...]: """Return flattened sample coordinates for a tensor-product spline grid.""" if sample_coordinates is not None: coordinates = tuple( torch.as_tensor(coordinate, dtype=dtype, device=device).reshape(-1) for coordinate in sample_coordinates ) if len(coordinates) != len(bounds): raise SplineProjectionError( "`sample_coordinates` must match the spline dimensionality" ) if not coordinates: raise SplineProjectionError("`sample_coordinates` must not be empty") shape = coordinates[0].shape if any(coordinate.shape != shape for coordinate in coordinates): raise SplineProjectionError( "`sample_coordinates` entries must have matching shapes" ) if coordinates[0].numel() == 0: raise SplineProjectionError("`sample_coordinates` must not be empty") for axis, (coordinate, (lower, upper)) in enumerate( zip(coordinates, bounds, strict=True) ): if torch.any(coordinate < float(lower)) or torch.any( coordinate >= float(upper) ): raise SplineProjectionError( f"sample coordinates for axis {axis} must lie inside the " "target full-support interval" ) return coordinates counts = _normalize_axis_samples( n_samples_per_axis, n_axes=len(bounds), target_shape=target_shape, ) axes = tuple( _midpoint_samples( lower=lower, upper=upper, n_samples=count, dtype=dtype, device=device, ) for count, (lower, upper) in zip(counts, bounds, strict=True) ) meshes = torch.meshgrid(*axes, indexing="ij") return tuple(mesh.reshape(-1) for mesh in meshes) def _validate_domain( *, lower: float, upper: float, source_lower: float, source_upper: float, ) -> tuple[float, float]: """Validate a target interval and reject source extrapolation.""" lower = float(lower) upper = float(upper) source_lower = float(source_lower) source_upper = 
float(source_upper) if upper <= lower: raise SplineProjectionError( "target upper full-support boundary must exceed the lower boundary" ) if lower < source_lower or upper > source_upper: raise SplineProjectionError( "target full-support interval must be contained in the source interval" ) return lower, upper def _require_close(name: str, source: float, target: float) -> None: """Validate one scalar metadata field.""" if not math.isclose(float(source), float(target), rel_tol=0.0, abs_tol=1.0e-12): raise SplineProjectionError( f"{name} differs: {float(source)!r} != {float(target)!r}" ) def _require_same_dtype_device(source: torch.Tensor, target: torch.Tensor) -> None: """Validate coefficient storage dtype and device.""" if source.dtype != target.dtype: raise SplineProjectionError(f"dtype differs: {source.dtype} != {target.dtype}") if source.device != target.device: raise SplineProjectionError( f"device differs: {source.device} != {target.device}" ) def _required_cutoff(term: Any) -> float: """Return a term cutoff, rejecting terms without one.""" cutoff = term.cutoff if cutoff is None: raise SplineProjectionError(f"{type(term).__name__} does not define a cutoff") return float(cutoff) def _check_derivative_weight(derivative_weight: float) -> float: """Normalize a derivative-row weight.""" derivative_weight = float(derivative_weight) if derivative_weight < 0.0: raise SplineProjectionError("`derivative_weight` must be non-negative") return derivative_weight def _solve_projection( *, coeff_shape: tuple[int, ...], value_matrix: torch.Tensor, gradient_matrices: tuple[torch.Tensor, ...], source_values: torch.Tensor, source_gradients: tuple[torch.Tensor, ...], derivative_weight: float, rcond: float | None, ) -> tuple[ torch.Tensor, torch.Tensor, tuple[torch.Tensor, ...] 
| None, float, float, float | None, float | None, ]: """Solve one dense least-squares projection problem.""" rows = [value_matrix] rhs = [source_values] if derivative_weight > 0.0: rows.extend(derivative_weight * matrix for matrix in gradient_matrices) rhs.extend(derivative_weight * gradient for gradient in source_gradients) design = torch.cat(rows, dim=0) target = torch.cat(rhs, dim=0) coeffs = torch.linalg.lstsq(design, target, rcond=rcond).solution.reshape( coeff_shape ) projected_values = value_matrix @ coeffs.reshape(-1) value_error = projected_values - source_values rmse = _as_float(torch.sqrt(torch.mean(value_error.square()))) max_abs_error = _as_float(torch.max(torch.abs(value_error))) if derivative_weight <= 0.0: return coeffs, projected_values, None, rmse, max_abs_error, None, None projected_gradients = tuple( matrix @ coeffs.reshape(-1) for matrix in gradient_matrices ) gradient_error = torch.cat( [ (projected - source).reshape(-1) for projected, source in zip( projected_gradients, source_gradients, strict=True, ) ] ) gradient_rmse = _as_float(torch.sqrt(torch.mean(gradient_error.square()))) gradient_max_abs_error = _as_float(torch.max(torch.abs(gradient_error))) return ( coeffs, projected_values, projected_gradients, rmse, max_abs_error, gradient_rmse, gradient_max_abs_error, ) def _evaluate_with_1d_stencil( coeffs: torch.Tensor, stencil: Stencil1D, ) -> tuple[torch.Tensor, torch.Tensor]: """Evaluate a one-dimensional coefficient vector from a stencil.""" coeff_window = coeffs.reshape(-1)[stencil.indices] values = (stencil.values * coeff_window).sum(dim=1) grads = (stencil.grads * coeff_window).sum(dim=1) return values, grads def _evaluate_with_2d_stencil( coeffs: torch.Tensor, stencil: Stencil2D, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Evaluate a two-dimensional coefficient grid from a stencil.""" coeff_window = coeffs.reshape(-1)[stencil.indices] values = (stencil.values * coeff_window).sum(dim=1) grad_x = (stencil.grad_x * 
coeff_window).sum(dim=1) grad_y = (stencil.grad_y * coeff_window).sum(dim=1) return values, grad_x, grad_y def _evaluate_with_3d_stencil( coeffs: torch.Tensor, stencil: Stencil3D, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Evaluate a three-dimensional coefficient grid from a stencil.""" coeff_window = coeffs.reshape(-1)[stencil.indices] values = (stencil.values * coeff_window).sum(dim=1) grad_x = (stencil.grad_x * coeff_window).sum(dim=1) grad_y = (stencil.grad_y * coeff_window).sum(dim=1) grad_z = (stencil.grad_z * coeff_window).sum(dim=1) return values, grad_x, grad_y, grad_z
def evaluate_uniform_spline_1d(
    coeffs: torch.Tensor,
    distances: torch.Tensor,
    *,
    first_knot: float,
    knot_spacing: float,
    spline: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Evaluate a 1D uniform spline and its derivative along the distance axis.

    Args:
        coeffs: One-dimensional floating-point coefficient vector.
        distances: Evaluation points, converted to the dtype/device of
            ``coeffs`` and flattened.
        first_knot: First knot position, forwarded to ``uniform_stencil_1d``.
        knot_spacing: Uniform knot spacing, forwarded to ``uniform_stencil_1d``.
        spline: Spline family name, forwarded to ``uniform_stencil_1d``.

    Returns:
        ``(values, derivatives)`` with one entry per flattened distance.

    Raises:
        SplineProjectionError: If ``coeffs`` is not floating point or not 1D.
    """
    vector = _as_floating_tensor(coeffs, name="coeffs")
    if vector.ndim != 1:
        raise SplineProjectionError("`coeffs` must be one-dimensional")
    points = torch.as_tensor(distances, dtype=vector.dtype, device=vector.device)
    points = points.reshape(-1)
    stencil = uniform_stencil_1d(
        points,
        coeff_size=int(vector.shape[0]),
        first_knot=float(first_knot),
        knot_spacing=float(knot_spacing),
        spline=spline,
    )
    return _evaluate_with_1d_stencil(vector, stencil)
def evaluate_uniform_spline_2d(
    coeffs: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    *,
    first_knot_x: float,
    first_knot_y: float,
    knot_spacing_x: float,
    knot_spacing_y: float,
    spline: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Evaluate a 2D uniform spline grid and its partial derivatives.

    ``x`` and ``y`` are converted to the dtype/device of ``coeffs`` and
    flattened; after flattening they must have the same length.

    Returns:
        ``(values, d/dx, d/dy)`` with one entry per sample point.

    Raises:
        SplineProjectionError: If ``coeffs`` is not a floating-point 2D
            tensor or the flattened coordinates differ in length.
    """
    grid = _as_floating_tensor(coeffs, name="coeffs")
    if grid.ndim != 2:
        raise SplineProjectionError("`coeffs` must be two-dimensional")
    px, py = (
        torch.as_tensor(axis, dtype=grid.dtype, device=grid.device).reshape(-1)
        for axis in (x, y)
    )
    if px.shape != py.shape:
        raise SplineProjectionError("`x` and `y` must have matching shapes")
    n_rows, n_cols = (int(extent) for extent in grid.shape)
    stencil = uniform_stencil_2d(
        px,
        py,
        coeff_shape=(n_rows, n_cols),
        first_knot_x=float(first_knot_x),
        first_knot_y=float(first_knot_y),
        knot_spacing_x=float(knot_spacing_x),
        knot_spacing_y=float(knot_spacing_y),
        spline=spline,
    )
    return _evaluate_with_2d_stencil(grid, stencil)
def evaluate_uniform_spline_3d(
    coeffs: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    *,
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    spline: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Evaluate a 3D uniform spline grid and its partial derivatives.

    The x and y axes share one knot layout (``*_xy``); z has its own. All
    coordinates are converted to the dtype/device of ``coeffs`` and
    flattened, and must then have equal length.

    Returns:
        ``(values, d/dx, d/dy, d/dz)`` with one entry per sample point.

    Raises:
        SplineProjectionError: If ``coeffs`` is not a floating-point 3D
            tensor or the flattened coordinates differ in length.
    """
    grid = _as_floating_tensor(coeffs, name="coeffs")
    if grid.ndim != 3:
        raise SplineProjectionError("`coeffs` must be three-dimensional")
    px, py, pz = (
        torch.as_tensor(axis, dtype=grid.dtype, device=grid.device).reshape(-1)
        for axis in (x, y, z)
    )
    if not px.shape == py.shape == pz.shape:
        raise SplineProjectionError("`x`, `y`, and `z` must have matching shapes")
    grid_shape = tuple(int(extent) for extent in grid.shape)
    stencil = uniform_stencil_3d(
        px,
        py,
        pz,
        coeff_shape=grid_shape,
        first_knot_xy=float(first_knot_xy),
        first_knot_z=float(first_knot_z),
        knot_spacing_xy=float(knot_spacing_xy),
        knot_spacing_z=float(knot_spacing_z),
        spline=spline,
    )
    return _evaluate_with_3d_stencil(grid, stencil)
def project_uniform_spline_1d(
    source_coeffs: torch.Tensor,
    *,
    source_lower_full_support: float,
    source_upper_full_support: float,
    target_coeff_size: int,
    target_lower_full_support: float | None = None,
    target_upper_full_support: float | None = None,
    source_spline: str = "cubic",
    target_spline: str | None = None,
    n_samples: int | None = None,
    sample_distances: torch.Tensor | None = None,
    derivative_weight: float = 1.0,
    rcond: float | None = None,
) -> SplineProjectionResult:
    """
    Project one 1D uniform spline onto another 1D uniform spline basis.

    The target interval must be contained in the source interval. The result
    uses the source coefficient dtype and device.

    The source spline is sampled at ``sample_distances`` or, by default, at
    ``max(4 * target_coeff_size, target_coeff_size + 8)`` midpoints of the
    target interval (``n_samples`` overrides the count). Target coefficients
    are fitted by dense least squares to the sampled values and, when
    ``derivative_weight > 0``, to the weighted sampled derivatives.

    Args:
        source_coeffs: One-dimensional floating-point source coefficients.
        source_lower_full_support: Lower full-support bound of the source.
        source_upper_full_support: Upper full-support bound of the source.
        target_coeff_size: Number of target coefficients.
        target_lower_full_support: Target lower bound; defaults to the source's.
        target_upper_full_support: Target upper bound; defaults to the source's.
        source_spline: Source spline family name.
        target_spline: Target spline family; defaults to ``source_spline``.
        n_samples: Midpoint sample count when ``sample_distances`` is omitted.
        sample_distances: Explicit sample points inside the target interval.
        derivative_weight: Weight of derivative rows; ``0`` fits values only.
        rcond: Forwarded to ``torch.linalg.lstsq``.

    Returns:
        A ``SplineProjectionResult``; gradient fields are ``None`` when
        ``derivative_weight`` is zero.

    Raises:
        SplineProjectionError: On invalid coefficients, non-finite or
            inverted bounds, a target interval outside the source, or
            samples outside the supported intervals.
    """
    source_coeffs = _as_floating_tensor(source_coeffs, name="source_coeffs")
    if source_coeffs.ndim != 1:
        raise SplineProjectionError("`source_coeffs` must be one-dimensional")
    target_coeff_size = int(target_coeff_size)
    target_spline = source_spline if target_spline is None else target_spline
    derivative_weight = _check_derivative_weight(derivative_weight)
    source_lower_full_support = float(source_lower_full_support)
    source_upper_full_support = float(source_upper_full_support)
    # The ordering check below is a no-op for NaN, and infinite bounds cannot
    # carry a uniform knot grid, so reject non-finite bounds first.
    if not (
        math.isfinite(source_lower_full_support)
        and math.isfinite(source_upper_full_support)
    ):
        raise SplineProjectionError("source full-support boundaries must be finite")
    if source_upper_full_support <= source_lower_full_support:
        raise SplineProjectionError(
            "source upper full-support boundary must exceed the lower boundary"
        )
    target_lower, target_upper = _validate_domain(
        lower=(
            source_lower_full_support
            if target_lower_full_support is None
            else target_lower_full_support
        ),
        upper=(
            source_upper_full_support
            if target_upper_full_support is None
            else target_upper_full_support
        ),
        source_lower=source_lower_full_support,
        source_upper=source_upper_full_support,
    )
    source_first_knot, source_knot_spacing = uniform_support_parameters(
        coeff_size=int(source_coeffs.shape[0]),
        lower_full_support=source_lower_full_support,
        upper_full_support=source_upper_full_support,
        spline=source_spline,
    )
    target_first_knot, target_knot_spacing = uniform_support_parameters(
        coeff_size=target_coeff_size,
        lower_full_support=target_lower,
        upper_full_support=target_upper,
        spline=target_spline,
    )
    resolved_n_samples = (
        max(4 * target_coeff_size, target_coeff_size + 8)
        if n_samples is None
        else int(n_samples)
    )
    distances = _normalize_1d_samples(
        sample_distances,
        lower=target_lower,
        upper=target_upper,
        n_samples=resolved_n_samples,
        dtype=source_coeffs.dtype,
        device=source_coeffs.device,
    )
    # Positive membership test so NaN samples are rejected, not waved through.
    inside_source = (distances >= source_lower_full_support) & (
        distances < source_upper_full_support
    )
    if not bool(torch.all(inside_source)):
        raise SplineProjectionError(
            "sample distances must lie inside the source full-support interval"
        )
    source_stencil = uniform_stencil_1d(
        distances,
        coeff_size=int(source_coeffs.shape[0]),
        first_knot=source_first_knot,
        knot_spacing=source_knot_spacing,
        spline=source_spline,
    )
    source_values, source_gradients = _evaluate_with_1d_stencil(
        source_coeffs,
        source_stencil,
    )
    target_stencil = uniform_stencil_1d(
        distances,
        coeff_size=target_coeff_size,
        first_knot=target_first_knot,
        knot_spacing=target_knot_spacing,
        spline=target_spline,
    )
    value_matrix = _row_matrix_from_stencil(
        target_stencil.indices,
        target_stencil.values,
        coeff_size=target_coeff_size,
    )
    gradient_matrix = _row_matrix_from_stencil(
        target_stencil.indices,
        target_stencil.grads,
        coeff_size=target_coeff_size,
    )
    (
        coeffs,
        projected_values,
        projected_gradients,
        rmse,
        max_abs_error,
        gradient_rmse,
        gradient_max_abs_error,
    ) = _solve_projection(
        coeff_shape=(target_coeff_size,),
        value_matrix=value_matrix,
        gradient_matrices=(gradient_matrix,),
        source_values=source_values,
        source_gradients=(source_gradients,),
        derivative_weight=derivative_weight,
        rcond=rcond,
    )
    # Source gradients are only reported when they took part in the fit.
    result_source_gradients = (source_gradients,) if derivative_weight > 0.0 else None
    return SplineProjectionResult(
        coeffs=coeffs,
        sample_coordinates=(distances,),
        source_values=source_values,
        projected_values=projected_values,
        rmse=rmse,
        max_abs_error=max_abs_error,
        diagnostics=_projection_diagnostics(
            channel=None,
            sample_coordinates=(distances,),
            bounds=((target_lower, target_upper),),
            source_values=source_values,
            projected_values=projected_values,
            source_gradients=result_source_gradients,
            projected_gradients=projected_gradients,
        ),
        source_gradients=result_source_gradients,
        projected_gradients=projected_gradients,
        gradient_rmse=gradient_rmse,
        gradient_max_abs_error=gradient_max_abs_error,
        source_shape=tuple(int(dim) for dim in source_coeffs.shape),
        target_shape=(target_coeff_size,),
    )
def project_uniform_spline_2d(
    source_coeffs: torch.Tensor,
    *,
    source_lower_full_support: float,
    source_upper_full_support: float,
    target_coeff_shape: Sequence[int],
    target_lower_full_support: float | None = None,
    target_upper_full_support: float | None = None,
    source_spline: str = "cubic",
    target_spline: str | None = None,
    n_samples_per_axis: int | Sequence[int] | None = None,
    sample_coordinates: Sequence[torch.Tensor] | None = None,
    derivative_weight: float = 1.0,
    rcond: float | None = None,
) -> SplineProjectionResult:
    """
    Project one square 2D uniform spline grid onto another grid.

    Source and target coefficient grids must be square, and both axes share
    one full-support interval and one knot layout. The target interval must
    be contained in the source interval; the result uses the source
    coefficient dtype and device.

    Samples are ``sample_coordinates`` (one tensor per axis) or, by default,
    a midpoint grid over the target square sized by ``n_samples_per_axis``.
    Target coefficients are fitted by dense least squares to sampled values
    and, when ``derivative_weight > 0``, to weighted x/y derivatives.

    Returns:
        A ``SplineProjectionResult``; gradient fields are ``None`` when
        ``derivative_weight`` is zero.

    Raises:
        SplineProjectionError: On invalid coefficients or shapes, non-finite
            or inverted bounds, a target interval outside the source, or
            invalid sample coordinates.
    """
    source_coeffs = _as_floating_tensor(source_coeffs, name="source_coeffs")
    if source_coeffs.ndim != 2:
        raise SplineProjectionError("`source_coeffs` must be two-dimensional")
    if int(source_coeffs.shape[0]) != int(source_coeffs.shape[1]):
        raise SplineProjectionError("2D source coefficients must have matching axes")
    target_shape = tuple(int(dim) for dim in target_coeff_shape)
    if len(target_shape) != 2:
        raise SplineProjectionError("`target_coeff_shape` must contain two entries")
    if target_shape[0] != target_shape[1]:
        raise SplineProjectionError("2D target coefficients must have matching axes")
    target_spline = source_spline if target_spline is None else target_spline
    derivative_weight = _check_derivative_weight(derivative_weight)
    source_lower_full_support = float(source_lower_full_support)
    source_upper_full_support = float(source_upper_full_support)
    # The ordering check below is a no-op for NaN, and infinite bounds cannot
    # carry a uniform knot grid, so reject non-finite bounds first.
    if not (
        math.isfinite(source_lower_full_support)
        and math.isfinite(source_upper_full_support)
    ):
        raise SplineProjectionError("source full-support boundaries must be finite")
    if source_upper_full_support <= source_lower_full_support:
        raise SplineProjectionError(
            "source upper full-support boundary must exceed the lower boundary"
        )
    target_lower, target_upper = _validate_domain(
        lower=(
            source_lower_full_support
            if target_lower_full_support is None
            else target_lower_full_support
        ),
        upper=(
            source_upper_full_support
            if target_upper_full_support is None
            else target_upper_full_support
        ),
        source_lower=source_lower_full_support,
        source_upper=source_upper_full_support,
    )
    source_first_knot, source_knot_spacing = uniform_support_parameters(
        coeff_size=int(source_coeffs.shape[0]),
        lower_full_support=source_lower_full_support,
        upper_full_support=source_upper_full_support,
        spline=source_spline,
    )
    target_first_knot, target_knot_spacing = uniform_support_parameters(
        coeff_size=target_shape[0],
        lower_full_support=target_lower,
        upper_full_support=target_upper,
        spline=target_spline,
    )
    # Both axes share the same support interval, so the sample box is square.
    target_bounds = ((target_lower, target_upper), (target_lower, target_upper))
    x, y = _normalize_sample_coordinates(
        sample_coordinates,
        bounds=target_bounds,
        n_samples_per_axis=n_samples_per_axis,
        target_shape=target_shape,
        dtype=source_coeffs.dtype,
        device=source_coeffs.device,
    )
    source_stencil = uniform_stencil_2d(
        x,
        y,
        coeff_shape=(int(source_coeffs.shape[0]), int(source_coeffs.shape[1])),
        first_knot_x=source_first_knot,
        first_knot_y=source_first_knot,
        knot_spacing_x=source_knot_spacing,
        knot_spacing_y=source_knot_spacing,
        spline=source_spline,
    )
    source_values, source_grad_x, source_grad_y = _evaluate_with_2d_stencil(
        source_coeffs,
        source_stencil,
    )
    target_stencil = uniform_stencil_2d(
        x,
        y,
        coeff_shape=target_shape,
        first_knot_x=target_first_knot,
        first_knot_y=target_first_knot,
        knot_spacing_x=target_knot_spacing,
        knot_spacing_y=target_knot_spacing,
        spline=target_spline,
    )
    target_coeff_size = math.prod(target_shape)
    value_matrix = _row_matrix_from_stencil(
        target_stencil.indices,
        target_stencil.values,
        coeff_size=target_coeff_size,
    )
    gradient_matrices = (
        _row_matrix_from_stencil(
            target_stencil.indices,
            target_stencil.grad_x,
            coeff_size=target_coeff_size,
        ),
        _row_matrix_from_stencil(
            target_stencil.indices,
            target_stencil.grad_y,
            coeff_size=target_coeff_size,
        ),
    )
    (
        coeffs,
        projected_values,
        projected_gradients,
        rmse,
        max_abs_error,
        gradient_rmse,
        gradient_max_abs_error,
    ) = _solve_projection(
        coeff_shape=target_shape,
        value_matrix=value_matrix,
        gradient_matrices=gradient_matrices,
        source_values=source_values,
        source_gradients=(source_grad_x, source_grad_y),
        derivative_weight=derivative_weight,
        rcond=rcond,
    )
    # Source gradients are only reported when they took part in the fit.
    result_source_gradients = (
        (source_grad_x, source_grad_y) if derivative_weight > 0.0 else None
    )
    return SplineProjectionResult(
        coeffs=coeffs,
        sample_coordinates=(x, y),
        source_values=source_values,
        projected_values=projected_values,
        rmse=rmse,
        max_abs_error=max_abs_error,
        diagnostics=_projection_diagnostics(
            channel=None,
            sample_coordinates=(x, y),
            bounds=target_bounds,
            source_values=source_values,
            projected_values=projected_values,
            source_gradients=result_source_gradients,
            projected_gradients=projected_gradients,
        ),
        source_gradients=result_source_gradients,
        projected_gradients=projected_gradients,
        gradient_rmse=gradient_rmse,
        gradient_max_abs_error=gradient_max_abs_error,
        source_shape=tuple(int(dim) for dim in source_coeffs.shape),
        target_shape=target_shape,
    )
def project_uniform_spline_3d(
    source_coeffs: torch.Tensor,
    *,
    source_lower_full_support_xy: float,
    source_upper_full_support_xy: float,
    source_lower_full_support_z: float,
    source_upper_full_support_z: float,
    target_coeff_shape: Sequence[int],
    target_lower_full_support_xy: float | None = None,
    target_upper_full_support_xy: float | None = None,
    target_lower_full_support_z: float | None = None,
    target_upper_full_support_z: float | None = None,
    source_spline: str = "cubic",
    target_spline: str | None = None,
    n_samples_per_axis: int | Sequence[int] | None = None,
    sample_coordinates: Sequence[torch.Tensor] | None = None,
    derivative_weight: float = 1.0,
    rcond: float | None = None,
) -> SplineProjectionResult:
    """
    Project one 3D uniform spline grid onto another grid.

    The x and y axes share one coefficient count, support interval and knot
    layout (``*_xy``); z has its own. Each target interval must be contained
    in the matching source interval; the result uses the source coefficient
    dtype and device.

    Samples are ``sample_coordinates`` (one tensor per axis) or, by default,
    a midpoint grid over the target box sized by ``n_samples_per_axis``.
    Target coefficients are fitted by dense least squares to sampled values
    and, when ``derivative_weight > 0``, to weighted x/y/z derivatives.

    Returns:
        A ``SplineProjectionResult``; gradient fields are ``None`` when
        ``derivative_weight`` is zero.

    Raises:
        SplineProjectionError: On invalid coefficients or shapes, non-finite
            or inverted bounds, target intervals outside the source, or
            invalid sample coordinates.
    """
    source_coeffs = _as_floating_tensor(source_coeffs, name="source_coeffs")
    if source_coeffs.ndim != 3:
        raise SplineProjectionError("`source_coeffs` must be three-dimensional")
    if int(source_coeffs.shape[0]) != int(source_coeffs.shape[1]):
        raise SplineProjectionError("3D source x/y axes must match")
    target_shape = tuple(int(dim) for dim in target_coeff_shape)
    if len(target_shape) != 3:
        raise SplineProjectionError("`target_coeff_shape` must contain three entries")
    if target_shape[0] != target_shape[1]:
        raise SplineProjectionError("3D target x/y axes must match")
    target_spline = source_spline if target_spline is None else target_spline
    derivative_weight = _check_derivative_weight(derivative_weight)
    source_lower_xy = float(source_lower_full_support_xy)
    source_upper_xy = float(source_upper_full_support_xy)
    source_lower_z = float(source_lower_full_support_z)
    source_upper_z = float(source_upper_full_support_z)
    # The ordering check below is a no-op for NaN, and infinite bounds cannot
    # carry a uniform knot grid, so reject non-finite bounds first.
    if not all(
        math.isfinite(bound)
        for bound in (source_lower_xy, source_upper_xy, source_lower_z, source_upper_z)
    ):
        raise SplineProjectionError("source full-support boundaries must be finite")
    if source_upper_xy <= source_lower_xy or source_upper_z <= source_lower_z:
        raise SplineProjectionError(
            "source upper full-support boundaries must exceed lower boundaries"
        )
    target_lower_xy, target_upper_xy = _validate_domain(
        lower=(
            source_lower_xy
            if target_lower_full_support_xy is None
            else target_lower_full_support_xy
        ),
        upper=(
            source_upper_xy
            if target_upper_full_support_xy is None
            else target_upper_full_support_xy
        ),
        source_lower=source_lower_xy,
        source_upper=source_upper_xy,
    )
    target_lower_z, target_upper_z = _validate_domain(
        lower=(
            source_lower_z
            if target_lower_full_support_z is None
            else target_lower_full_support_z
        ),
        upper=(
            source_upper_z
            if target_upper_full_support_z is None
            else target_upper_full_support_z
        ),
        source_lower=source_lower_z,
        source_upper=source_upper_z,
    )
    source_first_knot_xy, source_knot_spacing_xy = uniform_support_parameters(
        coeff_size=int(source_coeffs.shape[0]),
        lower_full_support=source_lower_xy,
        upper_full_support=source_upper_xy,
        spline=source_spline,
    )
    source_first_knot_z, source_knot_spacing_z = uniform_support_parameters(
        coeff_size=int(source_coeffs.shape[2]),
        lower_full_support=source_lower_z,
        upper_full_support=source_upper_z,
        spline=source_spline,
    )
    target_first_knot_xy, target_knot_spacing_xy = uniform_support_parameters(
        coeff_size=target_shape[0],
        lower_full_support=target_lower_xy,
        upper_full_support=target_upper_xy,
        spline=target_spline,
    )
    target_first_knot_z, target_knot_spacing_z = uniform_support_parameters(
        coeff_size=target_shape[2],
        lower_full_support=target_lower_z,
        upper_full_support=target_upper_z,
        spline=target_spline,
    )
    target_bounds = (
        (target_lower_xy, target_upper_xy),
        (target_lower_xy, target_upper_xy),
        (target_lower_z, target_upper_z),
    )
    x, y, z = _normalize_sample_coordinates(
        sample_coordinates,
        bounds=target_bounds,
        n_samples_per_axis=n_samples_per_axis,
        target_shape=target_shape,
        dtype=source_coeffs.dtype,
        device=source_coeffs.device,
    )
    source_stencil = uniform_stencil_3d(
        x,
        y,
        z,
        coeff_shape=(
            int(source_coeffs.shape[0]),
            int(source_coeffs.shape[1]),
            int(source_coeffs.shape[2]),
        ),
        first_knot_xy=source_first_knot_xy,
        first_knot_z=source_first_knot_z,
        knot_spacing_xy=source_knot_spacing_xy,
        knot_spacing_z=source_knot_spacing_z,
        spline=source_spline,
    )
    (
        source_values,
        source_grad_x,
        source_grad_y,
        source_grad_z,
    ) = _evaluate_with_3d_stencil(source_coeffs, source_stencil)
    target_stencil = uniform_stencil_3d(
        x,
        y,
        z,
        coeff_shape=target_shape,
        first_knot_xy=target_first_knot_xy,
        first_knot_z=target_first_knot_z,
        knot_spacing_xy=target_knot_spacing_xy,
        knot_spacing_z=target_knot_spacing_z,
        spline=target_spline,
    )
    target_coeff_size = math.prod(target_shape)
    value_matrix = _row_matrix_from_stencil(
        target_stencil.indices,
        target_stencil.values,
        coeff_size=target_coeff_size,
    )
    gradient_matrices = (
        _row_matrix_from_stencil(
            target_stencil.indices,
            target_stencil.grad_x,
            coeff_size=target_coeff_size,
        ),
        _row_matrix_from_stencil(
            target_stencil.indices,
            target_stencil.grad_y,
            coeff_size=target_coeff_size,
        ),
        _row_matrix_from_stencil(
            target_stencil.indices,
            target_stencil.grad_z,
            coeff_size=target_coeff_size,
        ),
    )
    (
        coeffs,
        projected_values,
        projected_gradients,
        rmse,
        max_abs_error,
        gradient_rmse,
        gradient_max_abs_error,
    ) = _solve_projection(
        coeff_shape=target_shape,
        value_matrix=value_matrix,
        gradient_matrices=gradient_matrices,
        source_values=source_values,
        source_gradients=(source_grad_x, source_grad_y, source_grad_z),
        derivative_weight=derivative_weight,
        rcond=rcond,
    )
    # Source gradients are only reported when they took part in the fit.
    result_source_gradients = (
        (source_grad_x, source_grad_y, source_grad_z)
        if derivative_weight > 0.0
        else None
    )
    return SplineProjectionResult(
        coeffs=coeffs,
        sample_coordinates=(x, y, z),
        source_values=source_values,
        projected_values=projected_values,
        rmse=rmse,
        max_abs_error=max_abs_error,
        diagnostics=_projection_diagnostics(
            channel=None,
            sample_coordinates=(x, y, z),
            bounds=target_bounds,
            source_values=source_values,
            projected_values=projected_values,
            source_gradients=result_source_gradients,
            projected_gradients=projected_gradients,
        ),
        source_gradients=result_source_gradients,
        projected_gradients=projected_gradients,
        gradient_rmse=gradient_rmse,
        gradient_max_abs_error=gradient_max_abs_error,
        source_shape=tuple(int(dim) for dim in source_coeffs.shape),
        target_shape=target_shape,
    )
def _canonical_triplet(channel: Sequence[int]) -> tuple[int, int, int]:
    """Normalize one source-distinguished triplet channel.

    The leading (source) entry stays in place. The two remaining entries are
    put in ascending order, so ``(s, b, a)`` and ``(s, a, b)`` map to the same
    key.

    Args:
        channel: Three integer-convertible entries ``(source, first, second)``.

    Returns:
        ``(source, min(first, second), max(first, second))`` as Python ints.

    Raises:
        SplineProjectionError: If ``channel`` does not have exactly three
            entries.
    """
    if len(channel) != 3:
        raise SplineProjectionError("triplet channels must contain three entries")
    lead = int(channel[0])
    low, high = sorted((int(channel[1]), int(channel[2])))
    return (lead, low, high)


def _active_triplet_categories(
    term: SplineThreeBodyTerm | SplineTriplet2DTerm,
) -> tuple[tuple[int, int, int], ...]:
    """Return the triplet categories that ``term`` marks as active.

    Reads the term's private ``_active_triplet_indices`` and looks each
    position up in ``term.triplet_categories``, keeping the index order.
    """
    positions = cast(tuple[int, ...], term._active_triplet_indices)
    categories = term.triplet_categories
    return tuple(categories[position] for position in positions)


def _write_indexed_channel(
    term: SplineTwoBodyTerm | SplineThreeBodyTerm | SplineTriplet2DTerm,
    index: int,
    values: torch.Tensor,
) -> None:
    """Write one indexed channel through a term parameter block.

    A detached clone of the term's first parameter block is edited and then
    passed back to ``block.write`` after being reshaped to ``block.shape``.

    * If the block has exactly one more dimension than ``values``, only
      ``block[index]`` is replaced. Its shape must equal ``values.shape``.
    * If the block has the same rank as ``values``, the whole block is
      overwritten with ``values`` reshaped to the block's shape, and
      ``index`` is not used.

    ``values`` is converted to the block's dtype and device via
    ``Tensor.to(other)`` before it is stored.

    Raises:
        SplineProjectionError: If the indexed slot's shape differs from
            ``values.shape``, or the block rank is neither equal to nor one
            more than the rank of ``values``.
    """
    block = term.parameter_blocks()[0]
    staged = block.read().detach().clone()
    extra_dims = staged.ndim - values.ndim
    if extra_dims == 1:
        slot_shape = tuple(int(dim) for dim in staged[index].shape)
        value_shape = tuple(int(dim) for dim in values.shape)
        if slot_shape != value_shape:
            raise SplineProjectionError(
                "projected coefficient output shape does not match target channel"
            )
        staged[index] = values.to(staged)
    elif extra_dims == 0:
        # Same rank: the whole block is treated as the channel and `index`
        # is ignored.
        staged.copy_(values.reshape(staged.shape).to(staged))
    else:
        raise SplineProjectionError(
            "target coefficient block dimensionality is incompatible"
        )
    block.write(staged.reshape(block.shape))
def project_pair_to_twobody(
    source: SplinePairTerm,
    target: SplineTwoBodyTerm,
    *,
    pair: Sequence[int] | None = None,
    n_samples: int | None = None,
    sample_distances: torch.Tensor | None = None,
    derivative_weight: float = 1.0,
    rcond: float | None = None,
    write: bool = True,
) -> SplineProjectionResult:
    """Project a pair-specific spline into one categorized two-body row.

    The requested pair channel is validated against both terms. The channel
    must be covered by ``source``, must canonicalize on ``target`` to
    ``source.pair``, and must be active on ``target``. The terms must also
    agree on symmetry and spline family, and their cutoff and
    ``full_support_start`` are checked with ``_require_close``. The target
    row is then fitted with ``project_uniform_spline_1d``.

    Args:
        source: Pair-specific spline term; its ``true_coeffs`` are projected.
        target: Categorized two-body term providing the destination row.
        pair: Channel to project; defaults to ``source.pair``.
        n_samples: Forwarded to ``project_uniform_spline_1d``.
        sample_distances: Forwarded to ``project_uniform_spline_1d``.
        derivative_weight: Forwarded to ``project_uniform_spline_1d``.
        rcond: Forwarded to ``project_uniform_spline_1d``.
        write: If true, store the projected coefficients in ``target`` via
            ``_write_indexed_channel``.

    Returns:
        The projection result tagged with the canonical target pair.

    Raises:
        TypeError: If ``source`` or ``target`` has the wrong term type.
        SplineProjectionError: On a channel or compatibility mismatch, or if
            the projected coefficients do not match the target row's shape.
            The ``_require_*`` helpers may also raise on mismatch.
    """
    if not isinstance(source, SplinePairTerm):
        raise TypeError("`source` must be a SplinePairTerm")
    if not isinstance(target, SplineTwoBodyTerm):
        raise TypeError("`target` must be a SplineTwoBodyTerm")

    # Resolve the requested channel and check it against both terms.
    requested = source.pair if pair is None else tuple(int(entry) for entry in pair)
    if len(requested) != 2:
        raise SplineProjectionError("pair channels must contain two entries")
    left, right = requested[0], requested[1]
    if not source.covers_pair(left, right):
        raise SplineProjectionError(
            f"source term does not cover pair channel {requested!r}"
        )
    canonical = target.canonical_pair(left, right)
    if canonical != source.pair:
        raise SplineProjectionError(
            f"channel identity differs: {source.pair!r} != {canonical!r}"
        )
    if not target.is_pair_active(canonical[0], canonical[1]):
        raise SplineProjectionError(f"target pair channel {canonical!r} is inactive")

    # Both terms must share symmetry, spline family and support geometry.
    if bool(source.symmetric) != bool(target.symmetric):
        raise SplineProjectionError(
            f"pair symmetry differs: {source.symmetric} != {target.symmetric}"
        )
    if str(source.spline) != str(target.spline):
        raise SplineProjectionError(
            f"spline family differs: {source.spline!r} != {target.spline!r}"
        )
    cutoff_in = _required_cutoff(source)
    cutoff_out = _required_cutoff(target)
    _require_close("cutoff", cutoff_in, cutoff_out)
    start_in = float(source.full_support_start)
    start_out = float(target.full_support_start)
    _require_close("full_support_start", start_in, start_out)

    coeffs_in = source.true_coeffs.detach()
    all_rows = target.true_coeffs_by_pair.detach()
    row_index = target.pair_category_index(canonical[0], canonical[1])
    row = all_rows[row_index]
    _require_same_dtype_device(coeffs_in, row)

    result = project_uniform_spline_1d(
        coeffs_in,
        source_lower_full_support=start_in,
        source_upper_full_support=cutoff_in,
        target_coeff_size=int(row.shape[0]),
        target_lower_full_support=start_out,
        target_upper_full_support=cutoff_out,
        source_spline=str(source.spline),
        target_spline=str(target.spline),
        n_samples=n_samples,
        sample_distances=sample_distances,
        derivative_weight=derivative_weight,
        rcond=rcond,
    )
    # Defensive check on the solver output before it is written or returned.
    if tuple(map(int, result.coeffs.shape)) != tuple(map(int, row.shape)):
        raise SplineProjectionError(
            "projected coefficient output shape does not match target channel"
        )
    if write:
        _write_indexed_channel(target, row_index, result.coeffs)
    return _result_with_channel(result, canonical)
def project_triplet2d_channel(
    source: SplineTriplet2DTerm,
    target: SplineTriplet2DTerm,
    channel: Sequence[int],
    *,
    n_samples_per_axis: int | Sequence[int] | None = None,
    sample_coordinates: Sequence[torch.Tensor] | None = None,
    derivative_weight: float = 1.0,
    rcond: float | None = None,
    write: bool = True,
) -> SplineProjectionResult:
    """Project one same-family 2D triplet channel between compatible terms.

    ``channel`` is canonicalized with ``_canonical_triplet``. It must appear
    in both terms' ``triplet_categories`` and be active on both. The terms
    must share a spline family, and their cutoff and ``full_support_start``
    are checked with ``_require_close``. The channel's ``coeffs_by_triplet``
    slice is then refitted with ``project_uniform_spline_2d``.

    Args:
        source: Term whose channel coefficients are projected.
        target: Term providing the destination coefficient shape.
        channel: Triplet ``(source, first, second)``; neighbor order is
            normalized.
        n_samples_per_axis: Forwarded to ``project_uniform_spline_2d``.
        sample_coordinates: Forwarded to ``project_uniform_spline_2d``.
        derivative_weight: Forwarded to ``project_uniform_spline_2d``.
        rcond: Forwarded to ``project_uniform_spline_2d``.
        write: If true, store the projected coefficients in ``target`` via
            ``_write_indexed_channel``.

    Returns:
        The projection result tagged with the canonical triplet.

    Raises:
        TypeError: If ``source`` or ``target`` has the wrong term type.
        SplineProjectionError: If the channel is missing, inactive or
            incompatible between the terms. The ``_require_*`` helpers may
            also raise on mismatch.
    """
    if not isinstance(source, SplineTriplet2DTerm):
        raise TypeError("`source` must be a SplineTriplet2DTerm")
    if not isinstance(target, SplineTriplet2DTerm):
        raise TypeError("`target` must be a SplineTriplet2DTerm")

    triplet = _canonical_triplet(channel)
    try:
        # Lookups run source-first; `list.index` raises ValueError if absent.
        source_index, target_index = (
            side.triplet_categories.index(triplet) for side in (source, target)
        )
    except ValueError as exc:
        raise SplineProjectionError(
            f"triplet channel {triplet!r} is not present in source and target"
        ) from exc
    for role, side in (("source", source), ("target", target)):
        if triplet not in _active_triplet_categories(side):
            raise SplineProjectionError(f"{role} triplet channel {triplet!r} is inactive")

    if str(source.spline) != str(target.spline):
        raise SplineProjectionError(
            f"spline family differs: {source.spline!r} != {target.spline!r}"
        )
    cutoff_in = _required_cutoff(source)
    cutoff_out = _required_cutoff(target)
    _require_close("cutoff", cutoff_in, cutoff_out)
    start_in = float(source.full_support_start)
    start_out = float(target.full_support_start)
    _require_close("full_support_start", start_in, start_out)

    block_in = source.coeffs_by_triplet.detach()[source_index]
    block_out = target.coeffs_by_triplet.detach()[target_index]
    _require_same_dtype_device(block_in, block_out)

    result = project_uniform_spline_2d(
        block_in,
        source_lower_full_support=start_in,
        source_upper_full_support=cutoff_in,
        target_coeff_shape=tuple(map(int, block_out.shape)),
        target_lower_full_support=start_out,
        target_upper_full_support=cutoff_out,
        source_spline=str(source.spline),
        target_spline=str(target.spline),
        n_samples_per_axis=n_samples_per_axis,
        sample_coordinates=sample_coordinates,
        derivative_weight=derivative_weight,
        rcond=rcond,
    )
    if write:
        _write_indexed_channel(target, target_index, result.coeffs)
    return _result_with_channel(result, triplet)
def project_threebody_channel(
    source: SplineThreeBodyTerm,
    target: SplineThreeBodyTerm,
    channel: Sequence[int],
    *,
    n_samples_per_axis: int | Sequence[int] | None = None,
    sample_coordinates: Sequence[torch.Tensor] | None = None,
    derivative_weight: float = 1.0,
    rcond: float | None = None,
    write: bool = True,
) -> SplineProjectionResult:
    """Project one same-family 3D triplet channel between compatible terms.

    ``channel`` is canonicalized with ``_canonical_triplet``. It must be
    present in both terms' private ``_triplet_index`` maps and be active on
    both. The terms must share a spline family. Their cutoff,
    ``neighbor_neighbor_cutoff``, ``full_support_start_xy`` and
    ``full_support_start_z`` are checked with ``_require_close``. The
    channel's ``true_coeffs_by_triplet`` slice is then refitted with
    ``project_uniform_spline_3d``, using ``neighbor_neighbor_cutoff`` as the
    upper z support.

    Args:
        source: Term whose channel coefficients are projected.
        target: Term providing the destination coefficient shape.
        channel: Triplet ``(source, first, second)``; neighbor order is
            normalized.
        n_samples_per_axis: Forwarded to ``project_uniform_spline_3d``.
        sample_coordinates: Forwarded to ``project_uniform_spline_3d``.
        derivative_weight: Forwarded to ``project_uniform_spline_3d``.
        rcond: Forwarded to ``project_uniform_spline_3d``.
        write: If true, store the projected coefficients in ``target`` via
            ``_write_indexed_channel``.

    Returns:
        The projection result tagged with the canonical triplet.

    Raises:
        TypeError: If ``source`` or ``target`` has the wrong term type.
        SplineProjectionError: If the channel is missing, inactive or
            incompatible between the terms. The ``_require_*`` helpers may
            also raise on mismatch.
    """
    if not isinstance(source, SplineThreeBodyTerm):
        raise TypeError("`source` must be a SplineThreeBodyTerm")
    if not isinstance(target, SplineThreeBodyTerm):
        raise TypeError("`target` must be a SplineThreeBodyTerm")

    triplet = _canonical_triplet(channel)
    index_map_type = dict[tuple[int, int, int], int]
    source_index = cast(index_map_type, source._triplet_index).get(triplet)
    target_index = cast(index_map_type, target._triplet_index).get(triplet)
    if source_index is None or target_index is None:
        raise SplineProjectionError(
            f"triplet channel {triplet!r} is not present in source and target"
        )
    for role, side in (("source", source), ("target", target)):
        if triplet not in _active_triplet_categories(side):
            raise SplineProjectionError(f"{role} triplet channel {triplet!r} is inactive")

    if str(source.spline) != str(target.spline):
        raise SplineProjectionError(
            f"spline family differs: {source.spline!r} != {target.spline!r}"
        )
    cutoff_in = _required_cutoff(source)
    cutoff_out = _required_cutoff(target)
    _require_close("cutoff", cutoff_in, cutoff_out)
    for attribute in (
        "neighbor_neighbor_cutoff",
        "full_support_start_xy",
        "full_support_start_z",
    ):
        _require_close(
            attribute,
            float(getattr(source, attribute)),
            float(getattr(target, attribute)),
        )

    block_in = source.true_coeffs_by_triplet.detach()[source_index]
    block_out = target.true_coeffs_by_triplet.detach()[target_index]
    _require_same_dtype_device(block_in, block_out)

    result = project_uniform_spline_3d(
        block_in,
        source_lower_full_support_xy=float(source.full_support_start_xy),
        source_upper_full_support_xy=cutoff_in,
        source_lower_full_support_z=float(source.full_support_start_z),
        source_upper_full_support_z=float(source.neighbor_neighbor_cutoff),
        target_coeff_shape=tuple(map(int, block_out.shape)),
        target_lower_full_support_xy=float(target.full_support_start_xy),
        target_upper_full_support_xy=cutoff_out,
        target_lower_full_support_z=float(target.full_support_start_z),
        target_upper_full_support_z=float(target.neighbor_neighbor_cutoff),
        source_spline=str(source.spline),
        target_spline=str(target.spline),
        n_samples_per_axis=n_samples_per_axis,
        sample_coordinates=sample_coordinates,
        derivative_weight=derivative_weight,
        rcond=rcond,
    )
    if write:
        _write_indexed_channel(target, int(target_index), result.coeffs)
    return _result_with_channel(result, triplet)
# Public names exported by ``from ufp.projection.spline import *``.
# The ``_``-prefixed helpers in this module are deliberately left out.
__all__ = [
    "SplineProjectionError",
    "SplineProjectionResult",
    "evaluate_uniform_spline_1d",
    "evaluate_uniform_spline_2d",
    "evaluate_uniform_spline_3d",
    "project_pair_to_twobody",
    "project_threebody_channel",
    "project_triplet2d_channel",
    "project_uniform_spline_1d",
    "project_uniform_spline_2d",
    "project_uniform_spline_3d",
]