"""Blockwise regularization operators for linear least-squares fits."""
from __future__ import annotations
from dataclasses import dataclass
import torch
@dataclass(frozen=True)
class RegularizationStencil:
    """A single sparse row of a least-squares regularizer.

    ``columns`` index entries of the flattened (compact) block vector and
    ``weights`` holds the matching row coefficients, so the row evaluates
    to ``sum(w * x[c] for c, w in zip(columns, weights))``. ``target`` is
    the value that weighted sum is compared against in the squared
    residual.
    """

    # Flattened block indices touched by this row; must be non-negative.
    columns: tuple[int, ...]
    # Coefficient for each entry of ``columns`` (same length).
    weights: tuple[float, ...]
    # Desired value of the weighted sum.
    target: float = 0.0

    def __post_init__(self) -> None:
        """Reject rows whose column/weight metadata is inconsistent.

        Raises:
            ValueError: If ``columns`` and ``weights`` differ in length, or
                any column index is negative.
        """
        lengths_agree = len(self.columns) == len(self.weights)
        if not lengths_agree:
            raise ValueError("regularization stencil columns and weights must match")
        for column in self.columns:
            if int(column) < 0:
                raise ValueError("regularization stencil columns must be non-negative")
[docs]
@dataclass(frozen=True)
class BlockRegularization:
"""Regularization recipe applied blockwise inside the linear problem."""
shape: tuple[int, ...]
ridge: float = 0.0
third_difference_penalty: float = 0.0
active_rows: tuple[int, ...] | None = None
third_difference_stencils: tuple[RegularizationStencil, ...] | None = None
@property
def size(self) -> int:
"""Return the flattened size of the regularized block."""
size = 1
for dim in self.shape:
size *= int(dim)
return size
def _active_1d_rows(self) -> tuple[int, ...]:
"""Return row indices regularized along the final coefficient axis."""
if len(self.shape) == 1:
return (0,)
if len(self.shape) == 2:
if self.active_rows is None:
return tuple(range(int(self.shape[0])))
return self.active_rows
return ()
[docs]
def apply(self, values: torch.Tensor) -> torch.Tensor:
"""Apply ridge and optional third-difference penalties to one block."""
flat_values = values.reshape(-1)
output = self.ridge * flat_values.clone()
self._apply_third_difference(output, flat_values)
return output
def _apply_third_difference(
self,
output: torch.Tensor,
values: torch.Tensor,
) -> None:
"""Accumulate the third-difference normal operator into ``output``."""
penalty = self.third_difference_penalty
if penalty <= 0.0:
return
flat_values = values.reshape(-1)
flat_output = output.reshape(-1)
for stencil in self._iter_third_difference_stencils():
if not stencil.columns:
continue
current = torch.zeros(
(), dtype=flat_values.dtype, device=flat_values.device
)
for column, weight in zip(stencil.columns, stencil.weights, strict=True):
current = current + float(weight) * flat_values[int(column)]
for column, weight in zip(stencil.columns, stencil.weights, strict=True):
flat_output[int(column)] += penalty * float(weight) * current
def _iter_third_difference_stencils(
self,
) -> tuple[RegularizationStencil, ...]:
"""Return third-difference rows in compact block coordinates."""
if self.third_difference_penalty <= 0.0:
return ()
if self.third_difference_stencils is not None:
return self.third_difference_stencils
if len(self.shape) not in {1, 2} or self.shape[-1] < 4:
return ()
stencil = (-1.0, 3.0, -3.0, 1.0)
rows: list[RegularizationStencil] = []
if len(self.shape) == 1:
for coeff in range(int(self.shape[0]) - 3):
rows.append(
RegularizationStencil(
columns=(coeff, coeff + 1, coeff + 2, coeff + 3),
weights=stencil,
)
)
return tuple(rows)
n_coeffs = int(self.shape[1])
for row in self._active_1d_rows():
offset = int(row) * n_coeffs
for coeff in range(n_coeffs - 3):
rows.append(
RegularizationStencil(
columns=(
offset + coeff,
offset + coeff + 1,
offset + coeff + 2,
offset + coeff + 3,
),
weights=stencil,
)
)
return tuple(rows)
[docs]
def rhs(self, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
"""Return the regularization contribution to the normal-equation RHS."""
output = torch.zeros((self.size,), dtype=dtype, device=device)
penalty = self.third_difference_penalty
if penalty <= 0.0:
return output
for stencil in self._iter_third_difference_stencils():
if stencil.target == 0.0:
continue
for column, weight in zip(stencil.columns, stencil.weights, strict=True):
output[int(column)] += penalty * float(stencil.target) * float(weight)
return output
[docs]
def constant(self, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
"""Return the regularization constant from fixed-coefficient targets."""
value = torch.zeros((), dtype=dtype, device=device)
penalty = self.third_difference_penalty
if penalty <= 0.0:
return value
for stencil in self._iter_third_difference_stencils():
if stencil.target == 0.0:
continue
target = torch.as_tensor(float(stencil.target), dtype=dtype, device=device)
value = value + penalty * target.square()
return value
[docs]
def least_squares_rows(
self,
*,
dtype: torch.dtype,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Materialize regularization as least-squares rows and targets."""
rows: list[torch.Tensor] = []
targets: list[torch.Tensor] = []
if self.ridge > 0.0:
scale = self.ridge**0.5
rows.append(scale * torch.eye(self.size, dtype=dtype, device=device))
targets.append(torch.zeros((self.size,), dtype=dtype, device=device))
stencils = self._iter_third_difference_stencils()
if self.third_difference_penalty > 0.0 and stencils:
scale = self.third_difference_penalty**0.5
matrix = torch.zeros((len(stencils), self.size), dtype=dtype, device=device)
target = torch.zeros((len(stencils),), dtype=dtype, device=device)
for row, stencil in enumerate(stencils):
for column, weight in zip(
stencil.columns,
stencil.weights,
strict=True,
):
matrix[row, int(column)] += scale * float(weight)
target[row] = scale * float(stencil.target)
rows.append(matrix)
targets.append(target)
if not rows:
return (
torch.zeros((0, self.size), dtype=dtype, device=device),
torch.zeros((0,), dtype=dtype, device=device),
)
return torch.cat(rows, dim=0), torch.cat(targets, dim=0)
[docs]
def diagonal(self, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
"""Return the diagonal of the block regularizer in vector form."""
diag = torch.full(
(self.size,),
fill_value=self.ridge,
dtype=dtype,
device=device,
)
return self._add_third_difference_diagonal(diag)
def _add_third_difference_diagonal(self, diag: torch.Tensor) -> torch.Tensor:
"""Add the third-difference diagonal contribution in place."""
penalty = self.third_difference_penalty
if penalty <= 0.0:
return diag
flat_diag = diag.reshape(-1)
for stencil in self._iter_third_difference_stencils():
for column, weight in zip(stencil.columns, stencil.weights, strict=True):
flat_diag[int(column)] += penalty * float(weight) * float(weight)
return diag
[docs]
def materialize(self, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
"""Materialize the block regularizer as an explicit square matrix."""
matrix = torch.zeros((self.size, self.size), dtype=dtype, device=device)
if self.ridge != 0.0:
matrix.diagonal().add_(self.ridge)
self._materialize_third_difference(matrix)
return matrix
def _materialize_third_difference(self, matrix: torch.Tensor) -> None:
"""Accumulate the explicit third-difference normal matrix."""
penalty = self.third_difference_penalty
if penalty <= 0.0:
return
for stencil in self._iter_third_difference_stencils():
for col_i, weight_i in zip(
stencil.columns,
stencil.weights,
strict=True,
):
for col_j, weight_j in zip(
stencil.columns,
stencil.weights,
strict=True,
):
matrix[int(col_i), int(col_j)] += (
penalty * float(weight_i) * float(weight_j)
)
def _third_difference_quadratic(
self,
values: torch.Tensor,
) -> torch.Tensor:
"""Evaluate explicit third-difference least-squares rows."""
penalty = self.third_difference_penalty
total = torch.zeros((), dtype=values.dtype, device=values.device)
if penalty <= 0.0:
return total
flat_values = values.reshape(-1)
for stencil in self._iter_third_difference_stencils():
residual = -torch.as_tensor(
float(stencil.target),
dtype=values.dtype,
device=values.device,
)
for column, weight in zip(stencil.columns, stencil.weights, strict=True):
residual = residual + float(weight) * flat_values[int(column)]
total = total + penalty * residual.square()
return total
[docs]
def quadratic(self, values: torch.Tensor) -> torch.Tensor:
"""Evaluate the quadratic penalty contributed by one block vector."""
flat_values = values.reshape(-1)
value = torch.zeros((), dtype=flat_values.dtype, device=flat_values.device)
if self.ridge != 0.0:
value = value + self.ridge * torch.dot(flat_values, flat_values)
return value + self._third_difference_quadratic(flat_values)
def _make_block_regularization(
shape: tuple[int, ...],
*,
ridge: float,
third_difference_penalty: float = 0.0,
active_rows: tuple[int, ...] | None = None,
third_difference_stencils: tuple[RegularizationStencil, ...] | None = None,
) -> BlockRegularization | None:
"""Build a block regularizer only when one of the penalties is active."""
if ridge == 0.0 and third_difference_penalty == 0.0:
return None
if (
ridge == 0.0
and third_difference_penalty > 0.0
and third_difference_stencils == ()
):
return None
return BlockRegularization(
shape=shape,
ridge=float(ridge),
third_difference_penalty=float(third_difference_penalty),
active_rows=active_rows,
third_difference_stencils=third_difference_stencils,
)
# Public API of this module; helpers with a leading underscore stay private.
__all__ = [
    "BlockRegularization",
    "RegularizationStencil",
]