Source code for ufp.leastsquares.regularization

"""Blockwise regularization operators for linear least-squares fits."""

from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class RegularizationStencil:
    """One least-squares regularization row over a compact block vector.

    The row's residual is ``sum(w * x[c] for c, w in zip(columns, weights))
    - target`` evaluated on the flattened block vector ``x``.

    Attributes:
        columns: Non-negative indices into the flattened block vector.
        weights: Coefficient for each entry of ``columns`` (same length).
        target: Value the weighted sum is pulled towards (``0.0`` by default).
    """

    columns: tuple[int, ...]
    weights: tuple[float, ...]
    target: float = 0.0

    def __post_init__(self) -> None:
        """Normalize and validate compact row metadata.

        Any sequence of columns/weights (lists, numpy or torch scalars) is
        coerced to plain ``tuple[int, ...]`` / ``tuple[float, ...]`` and the
        target to ``float``.  This keeps the frozen instance hashable and makes
        equality independent of the container type the caller used.

        Raises:
            ValueError: If ``columns`` and ``weights`` differ in length, or any
                column is negative.
        """
        columns = tuple(int(column) for column in self.columns)
        weights = tuple(float(weight) for weight in self.weights)
        if len(columns) != len(weights):
            raise ValueError("regularization stencil columns and weights must match")
        if any(column < 0 for column in columns):
            raise ValueError("regularization stencil columns must be non-negative")
        # Frozen dataclasses block normal assignment; object.__setattr__ is the
        # documented escape hatch for normalization inside __post_init__.
        object.__setattr__(self, "columns", columns)
        object.__setattr__(self, "weights", weights)
        object.__setattr__(self, "target", float(self.target))
[docs] @dataclass(frozen=True) class BlockRegularization: """Regularization recipe applied blockwise inside the linear problem.""" shape: tuple[int, ...] ridge: float = 0.0 third_difference_penalty: float = 0.0 active_rows: tuple[int, ...] | None = None third_difference_stencils: tuple[RegularizationStencil, ...] | None = None @property def size(self) -> int: """Return the flattened size of the regularized block.""" size = 1 for dim in self.shape: size *= int(dim) return size def _active_1d_rows(self) -> tuple[int, ...]: """Return row indices regularized along the final coefficient axis.""" if len(self.shape) == 1: return (0,) if len(self.shape) == 2: if self.active_rows is None: return tuple(range(int(self.shape[0]))) return self.active_rows return ()
[docs] def apply(self, values: torch.Tensor) -> torch.Tensor: """Apply ridge and optional third-difference penalties to one block.""" flat_values = values.reshape(-1) output = self.ridge * flat_values.clone() self._apply_third_difference(output, flat_values) return output
def _apply_third_difference( self, output: torch.Tensor, values: torch.Tensor, ) -> None: """Accumulate the third-difference normal operator into ``output``.""" penalty = self.third_difference_penalty if penalty <= 0.0: return flat_values = values.reshape(-1) flat_output = output.reshape(-1) for stencil in self._iter_third_difference_stencils(): if not stencil.columns: continue current = torch.zeros( (), dtype=flat_values.dtype, device=flat_values.device ) for column, weight in zip(stencil.columns, stencil.weights, strict=True): current = current + float(weight) * flat_values[int(column)] for column, weight in zip(stencil.columns, stencil.weights, strict=True): flat_output[int(column)] += penalty * float(weight) * current def _iter_third_difference_stencils( self, ) -> tuple[RegularizationStencil, ...]: """Return third-difference rows in compact block coordinates.""" if self.third_difference_penalty <= 0.0: return () if self.third_difference_stencils is not None: return self.third_difference_stencils if len(self.shape) not in {1, 2} or self.shape[-1] < 4: return () stencil = (-1.0, 3.0, -3.0, 1.0) rows: list[RegularizationStencil] = [] if len(self.shape) == 1: for coeff in range(int(self.shape[0]) - 3): rows.append( RegularizationStencil( columns=(coeff, coeff + 1, coeff + 2, coeff + 3), weights=stencil, ) ) return tuple(rows) n_coeffs = int(self.shape[1]) for row in self._active_1d_rows(): offset = int(row) * n_coeffs for coeff in range(n_coeffs - 3): rows.append( RegularizationStencil( columns=( offset + coeff, offset + coeff + 1, offset + coeff + 2, offset + coeff + 3, ), weights=stencil, ) ) return tuple(rows)
[docs] def rhs(self, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor: """Return the regularization contribution to the normal-equation RHS.""" output = torch.zeros((self.size,), dtype=dtype, device=device) penalty = self.third_difference_penalty if penalty <= 0.0: return output for stencil in self._iter_third_difference_stencils(): if stencil.target == 0.0: continue for column, weight in zip(stencil.columns, stencil.weights, strict=True): output[int(column)] += penalty * float(stencil.target) * float(weight) return output
[docs] def constant(self, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor: """Return the regularization constant from fixed-coefficient targets.""" value = torch.zeros((), dtype=dtype, device=device) penalty = self.third_difference_penalty if penalty <= 0.0: return value for stencil in self._iter_third_difference_stencils(): if stencil.target == 0.0: continue target = torch.as_tensor(float(stencil.target), dtype=dtype, device=device) value = value + penalty * target.square() return value
[docs] def least_squares_rows( self, *, dtype: torch.dtype, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: """Materialize regularization as least-squares rows and targets.""" rows: list[torch.Tensor] = [] targets: list[torch.Tensor] = [] if self.ridge > 0.0: scale = self.ridge**0.5 rows.append(scale * torch.eye(self.size, dtype=dtype, device=device)) targets.append(torch.zeros((self.size,), dtype=dtype, device=device)) stencils = self._iter_third_difference_stencils() if self.third_difference_penalty > 0.0 and stencils: scale = self.third_difference_penalty**0.5 matrix = torch.zeros((len(stencils), self.size), dtype=dtype, device=device) target = torch.zeros((len(stencils),), dtype=dtype, device=device) for row, stencil in enumerate(stencils): for column, weight in zip( stencil.columns, stencil.weights, strict=True, ): matrix[row, int(column)] += scale * float(weight) target[row] = scale * float(stencil.target) rows.append(matrix) targets.append(target) if not rows: return ( torch.zeros((0, self.size), dtype=dtype, device=device), torch.zeros((0,), dtype=dtype, device=device), ) return torch.cat(rows, dim=0), torch.cat(targets, dim=0)
[docs] def diagonal(self, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor: """Return the diagonal of the block regularizer in vector form.""" diag = torch.full( (self.size,), fill_value=self.ridge, dtype=dtype, device=device, ) return self._add_third_difference_diagonal(diag)
def _add_third_difference_diagonal(self, diag: torch.Tensor) -> torch.Tensor: """Add the third-difference diagonal contribution in place.""" penalty = self.third_difference_penalty if penalty <= 0.0: return diag flat_diag = diag.reshape(-1) for stencil in self._iter_third_difference_stencils(): for column, weight in zip(stencil.columns, stencil.weights, strict=True): flat_diag[int(column)] += penalty * float(weight) * float(weight) return diag
[docs] def materialize(self, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor: """Materialize the block regularizer as an explicit square matrix.""" matrix = torch.zeros((self.size, self.size), dtype=dtype, device=device) if self.ridge != 0.0: matrix.diagonal().add_(self.ridge) self._materialize_third_difference(matrix) return matrix
def _materialize_third_difference(self, matrix: torch.Tensor) -> None: """Accumulate the explicit third-difference normal matrix.""" penalty = self.third_difference_penalty if penalty <= 0.0: return for stencil in self._iter_third_difference_stencils(): for col_i, weight_i in zip( stencil.columns, stencil.weights, strict=True, ): for col_j, weight_j in zip( stencil.columns, stencil.weights, strict=True, ): matrix[int(col_i), int(col_j)] += ( penalty * float(weight_i) * float(weight_j) ) def _third_difference_quadratic( self, values: torch.Tensor, ) -> torch.Tensor: """Evaluate explicit third-difference least-squares rows.""" penalty = self.third_difference_penalty total = torch.zeros((), dtype=values.dtype, device=values.device) if penalty <= 0.0: return total flat_values = values.reshape(-1) for stencil in self._iter_third_difference_stencils(): residual = -torch.as_tensor( float(stencil.target), dtype=values.dtype, device=values.device, ) for column, weight in zip(stencil.columns, stencil.weights, strict=True): residual = residual + float(weight) * flat_values[int(column)] total = total + penalty * residual.square() return total
[docs] def semantics_metadata(self) -> dict[str, object]: """Return metadata for cache/checkpoint invalidation.""" return { "third_difference": ( "boundary_aware_v1" if self.third_difference_stencils is not None else "full_block_v1" ) }
[docs] def stencils_metadata(self) -> dict[str, object]: """Return compact metadata for selected regularization rows.""" stencils = self._iter_third_difference_stencils() return { "count": len(stencils), "has_targets": any(stencil.target != 0.0 for stencil in stencils), }
[docs] def quadratic(self, values: torch.Tensor) -> torch.Tensor: """Evaluate the quadratic penalty contributed by one block vector.""" flat_values = values.reshape(-1) value = torch.zeros((), dtype=flat_values.dtype, device=flat_values.device) if self.ridge != 0.0: value = value + self.ridge * torch.dot(flat_values, flat_values) return value + self._third_difference_quadratic(flat_values)
def _make_block_regularization(
    shape: tuple[int, ...],
    *,
    ridge: float,
    third_difference_penalty: float = 0.0,
    active_rows: tuple[int, ...] | None = None,
    third_difference_stencils: tuple[RegularizationStencil, ...] | None = None,
) -> BlockRegularization | None:
    """Build a block regularizer only when one of the penalties is active.

    Args:
        shape: Block shape forwarded to :class:`BlockRegularization`.
        ridge: Ridge coefficient; any non-zero value counts as active.
        third_difference_penalty: Stencil-term coefficient. Values ``<= 0``
            count as inactive, matching how ``BlockRegularization`` ignores them.
        active_rows: Rows of a 2-D block that receive default stencils.
        third_difference_stencils: Explicit stencils replacing the defaults.

    Returns:
        ``None`` in either of these cases:

        - the ridge is zero and the stencil penalty is non-positive;
        - the ridge is zero and the stencil term would produce no rows
          (explicitly empty stencils, fewer than four coefficients along the
          final axis, an unsupported rank, or no active rows).

        Otherwise a :class:`BlockRegularization` with float coefficients.
    """
    if ridge == 0.0 and third_difference_penalty <= 0.0:
        return None
    regularization = BlockRegularization(
        shape=shape,
        ridge=float(ridge),
        third_difference_penalty=float(third_difference_penalty),
        active_rows=active_rows,
        third_difference_stencils=third_difference_stencils,
    )
    # With no ridge, the regularizer is a no-op unless the stencil term yields
    # at least one row; callers expect ``None`` for "nothing to regularize".
    if ridge == 0.0 and not regularization._iter_third_difference_stencils():
        return None
    return regularization


__all__ = ["BlockRegularization", "RegularizationStencil"]