Source code for ufp.leastsquares._block

"""Block matrix containers and kernels for linear least-squares problems."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Literal, Sequence

import torch

from ufp.leastsquares._types import AssembledBatch
from ufp.leastsquares.regularization import BlockRegularization


@dataclass(frozen=True)
class SolveBlock:
    """Immutable solve-time description of one block of the linear system.

    Instances are consumed by ``BlockProblemLayout.from_blocks``, which reads
    ``key`` and ``size`` to assign each block a slice of the solve vector.
    """

    # Identifier used to look the block up in the layout's slice mapping.
    key: Any
    # Number of solve-vector entries occupied by this block.
    size: int
    # Free-form label for the block.
    label: str
    # Optional regularization attached to the block; ``None`` when absent.
    regularization: BlockRegularization | None = None
@dataclass
class BlockProblemLayout:
    """Ordered block layout used by an assembled linear problem.

    Attributes are documented inline; use :meth:`from_blocks` to build a
    layout whose slices are contiguous and follow the block order.
    """

    # Blocks in solve-vector order.
    blocks: tuple[SolveBlock, ...]
    # Mapping from block key to the slice it occupies in the solve vector.
    slices: dict[Any, slice]
    # Total length of the solve vector (sum of all block sizes).
    size: int

    @classmethod
    def from_blocks(cls, blocks: Sequence[SolveBlock]) -> "BlockProblemLayout":
        """Assign contiguous slices to an ordered sequence of solve blocks.

        Args:
            blocks: Blocks in the order they should appear in the solve
                vector. Only ``key`` and ``size`` are read here.

        Returns:
            A layout whose ``slices`` tile ``range(size)`` in block order.

        Raises:
            ValueError: If a block has a negative size, or if two blocks
                share the same key.
        """
        ordered = tuple(blocks)
        slices: dict[Any, slice] = {}
        offset = 0
        for block in ordered:
            if block.size < 0:
                raise ValueError(
                    f"solve block {block.key!r} has a negative size: {block.size}"
                )
            # A repeated key would overwrite the earlier block's slice while
            # `size` still counted both blocks, leaving part of the solve
            # vector unreachable through `theta_slice`.
            if block.key in slices:
                raise ValueError(f"duplicate solve block key: {block.key!r}")
            stop = offset + block.size
            slices[block.key] = slice(offset, stop)
            offset = stop
        return cls(
            blocks=ordered,
            slices=slices,
            size=offset,
        )

    def theta_slice(self, key: Any) -> slice:
        """Return the solve-vector slice occupied by the requested block key.

        Raises:
            KeyError: If ``key`` is not part of this layout.
        """
        return self.slices[key]
@dataclass(frozen=True)
class RowIndexedBlockMatrix:
    """Dense block values stored only for rows that can be nonzero.

    The logical matrix has shape ``(n_rows, values.shape[1])``. Row
    ``rows[i]`` of the logical matrix equals ``values[i]``; every row not
    listed in ``rows`` is implicitly zero.
    """

    # Sorted, unique indices of the stored rows, one per row of `values`.
    rows: torch.Tensor
    # Stored row values, shape (number of stored rows, number of columns).
    values: torch.Tensor
    # Number of rows of the logical dense matrix.
    n_rows: int

    def __post_init__(self) -> None:
        """Validate row-indexed matrix metadata.

        Raises:
            ValueError: If the tensors have the wrong rank, disagree on the
                number of stored rows, or `rows` is out of range, unsorted
                or contains duplicates.
        """
        if self.rows.ndim != 1:
            raise ValueError("`rows` must be a one-dimensional tensor")
        if self.values.ndim != 2:
            raise ValueError("`values` must be a two-dimensional tensor")
        if self.rows.shape[0] != self.values.shape[0]:
            raise ValueError(
                "`rows` and `values` must contain the same number of active rows"
            )
        if self.n_rows < 0:
            raise ValueError("`n_rows` must be non-negative")
        if self.rows.numel() and (
            bool(torch.any(self.rows < 0))
            or bool(torch.any(self.rows >= self.n_rows))
        ):
            raise ValueError("`rows` contains an out-of-range row index")
        # Strictly increasing rows imply both sortedness and uniqueness.
        if self.rows.numel() > 1 and bool(torch.any(self.rows[1:] <= self.rows[:-1])):
            raise ValueError("`rows` must be sorted and unique")

    @property
    def shape(self) -> tuple[int, int]:
        """Return the logical dense matrix shape."""
        return (int(self.n_rows), int(self.values.shape[1]))

    @property
    def dtype(self) -> torch.dtype:
        """Return the value dtype."""
        return self.values.dtype

    @property
    def device(self) -> torch.device:
        """Return the value device."""
        return self.values.device

    def _index_rows(self, device: torch.device) -> torch.Tensor:
        """Return `rows` on ``device`` as int64, the dtype index kernels expect."""
        return self.rows.to(device=device, dtype=torch.int64)

    def materialize(self) -> torch.Tensor:
        """Return a dense tensor with zero rows restored."""
        matrix = torch.zeros(
            self.shape,
            dtype=self.values.dtype,
            device=self.values.device,
        )
        if self.rows.numel():
            # `index_copy_` is documented to take a LongTensor index, so the
            # rows are normalized to int64 rather than used in whatever
            # integer dtype they were stored with.
            matrix.index_copy_(0, self._index_rows(self.values.device), self.values)
        return matrix

    def matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply this matrix to one block parameter vector.

        Args:
            theta: Parameter vector with one entry per matrix column.

        Returns:
            A vector of length ``n_rows`` in the dtype/device of ``values``.
        """
        output = torch.zeros(
            (int(self.n_rows),),
            dtype=self.values.dtype,
            device=self.values.device,
        )
        if self.rows.numel():
            output.index_add_(
                0,
                self._index_rows(self.values.device),
                self.values @ theta,
            )
        return output

    def rmatvec(self, residual: torch.Tensor) -> torch.Tensor:
        """Apply this matrix transpose to a residual vector.

        Args:
            residual: Vector with one entry per logical matrix row.

        Returns:
            A vector with one entry per column, in the dtype/device of
            ``residual`` (the stored values are converted to match).
        """
        if not self.rows.numel():
            return torch.zeros(
                (int(self.values.shape[1]),),
                dtype=residual.dtype,
                device=residual.device,
            )
        rows = self._index_rows(residual.device)
        values = self.values.to(device=residual.device, dtype=residual.dtype)
        # Only the stored rows contribute; all other rows are zero.
        return values.T @ residual.index_select(0, rows)
@dataclass(frozen=True)
class ColumnRowIndexedChunk:
    """A contiguous run of columns whose values are kept only for active rows.

    The chunk covers columns ``column_start`` up to (but excluding)
    ``column_stop``; ``values[i]`` holds those columns for row ``rows[i]``.
    """

    # First column covered by the chunk.
    column_start: int
    # Strictly increasing indices of the stored rows.
    rows: torch.Tensor
    # Stored values, one row per entry of `rows`, one column per chunk column.
    values: torch.Tensor

    def __post_init__(self) -> None:
        """Check the chunk's offset, tensor ranks, and row ordering."""
        row_ids, block = self.rows, self.values
        if self.column_start < 0:
            raise ValueError("`column_start` must be non-negative")
        if row_ids.ndim != 1:
            raise ValueError("`rows` must be a one-dimensional tensor")
        if block.ndim != 2:
            raise ValueError("`values` must be a two-dimensional tensor")
        if row_ids.shape[0] != block.shape[0]:
            raise ValueError(
                "`rows` and `values` must contain the same number of active rows"
            )
        if row_ids.numel() > 1:
            # Any entry that does not exceed its predecessor breaks the
            # sorted-and-unique requirement.
            not_increasing = row_ids[1:] <= row_ids[:-1]
            if bool(torch.any(not_increasing)):
                raise ValueError("`rows` must be sorted and unique")

    @property
    def column_stop(self) -> int:
        """Exclusive upper column bound of this chunk."""
        width = self.values.shape[1]
        return int(self.column_start + width)
@dataclass(frozen=True)
class ColumnRowIndexedBlockMatrix:
    """Dense block values stored by contiguous column chunks and active rows.

    Each chunk covers columns ``[column_start, column_stop)`` and stores
    values only for its own set of rows. Columns not covered by any chunk,
    and rows not listed in a chunk, are implicitly zero.
    """

    # Column chunks, sorted by column and non-overlapping.
    chunks: tuple[ColumnRowIndexedChunk, ...]
    # Number of rows of the logical dense matrix.
    n_rows: int
    # Number of columns of the logical dense matrix.
    n_cols: int

    def __post_init__(self) -> None:
        """Validate column-chunked matrix metadata.

        Raises:
            ValueError: If the dimensions are negative, chunks overlap or are
                out of order, a chunk extends past ``n_cols``, or a chunk
                refers to a row outside ``[0, n_rows)``.
        """
        if self.n_rows < 0 or self.n_cols < 0:
            raise ValueError("`n_rows` and `n_cols` must be non-negative")
        last_stop = 0
        for chunk in self.chunks:
            if chunk.column_start < last_stop:
                raise ValueError("column chunks must be sorted and non-overlapping")
            if chunk.column_stop > self.n_cols:
                raise ValueError("column chunk exceeds `n_cols`")
            if chunk.rows.numel() and (
                bool(torch.any(chunk.rows < 0))
                or bool(torch.any(chunk.rows >= self.n_rows))
            ):
                raise ValueError("chunk rows contain an out-of-range row index")
            last_stop = chunk.column_stop

    @property
    def shape(self) -> tuple[int, int]:
        """Return the logical dense matrix shape."""
        return (int(self.n_rows), int(self.n_cols))

    @property
    def dtype(self) -> torch.dtype:
        """Return the first chunk's value dtype, or the default dtype if empty."""
        if self.chunks:
            return self.chunks[0].values.dtype
        return torch.get_default_dtype()

    @property
    def device(self) -> torch.device:
        """Return the first chunk's value device, or CPU if there are no chunks."""
        if self.chunks:
            return self.chunks[0].values.device
        return torch.device("cpu")

    def materialize(self) -> torch.Tensor:
        """Return a dense tensor with zero chunks restored."""
        device = self.device
        matrix = torch.zeros(
            self.shape,
            dtype=self.dtype,
            device=device,
        )
        for chunk in self.chunks:
            if not chunk.rows.numel():
                continue
            # `index_copy_` is documented to take a LongTensor index, so the
            # rows are normalized to int64 rather than used in whatever
            # integer dtype they were stored with.
            rows = chunk.rows.to(device=device, dtype=torch.int64)
            # The column slice is a view, so the in-place copy writes
            # straight into `matrix`.
            matrix[:, chunk.column_start : chunk.column_stop].index_copy_(
                0,
                rows,
                chunk.values,
            )
        return matrix

    def matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply this matrix to one block parameter vector.

        Args:
            theta: Parameter vector with one entry per logical column.

        Returns:
            A vector of length ``n_rows`` in this matrix's dtype/device.
        """
        device = self.device
        output = torch.zeros(
            (int(self.n_rows),),
            dtype=self.dtype,
            device=device,
        )
        for chunk in self.chunks:
            if not chunk.rows.numel():
                continue
            contribution = chunk.values @ theta[chunk.column_start : chunk.column_stop]
            # Different chunks may touch the same row, hence add, not copy.
            output.index_add_(
                0,
                chunk.rows.to(device=device, dtype=torch.int64),
                contribution,
            )
        return output

    def rmatvec(self, residual: torch.Tensor) -> torch.Tensor:
        """Apply this matrix transpose to a residual vector.

        Args:
            residual: Vector with one entry per logical matrix row.

        Returns:
            A vector of length ``n_cols`` in the dtype/device of ``residual``
            (chunk values are converted to match).
        """
        output = torch.zeros(
            (int(self.n_cols),),
            dtype=residual.dtype,
            device=residual.device,
        )
        for chunk in self.chunks:
            if not chunk.rows.numel():
                continue
            rows = chunk.rows.to(device=residual.device, dtype=torch.int64)
            values = chunk.values.to(device=residual.device, dtype=residual.dtype)
            output[chunk.column_start : chunk.column_stop] += (
                values.T @ residual.index_select(0, rows)
            )
        return output
# Any representation a block matrix may take: a plain dense tensor, or one of
# the two compact containers that omit all-zero rows.
BlockMatrix = torch.Tensor | RowIndexedBlockMatrix | ColumnRowIndexedBlockMatrix
# Storage modes accepted by `_compact_block_matrix_for_storage`.
MatrixStorageMode = Literal["dense", "row_indexed", "column_chunked", "auto"]
# In "auto" mode, dense matrices with fewer elements than this are not
# compacted (see `_compact_block_matrix_for_storage`).
_AUTO_COMPACT_MIN_DENSE_ELEMENTS = 1024
@dataclass(frozen=True)
class BlockSolveBatch:
    """A batch of target rows together with its per-block matrices.

    Each batch contributes ``n_rows`` rows to the block linear problem.
    """

    # Target values for the rows of this batch; the leading dimension is the
    # row count.
    target: torch.Tensor
    # Block matrices of this batch, looked up by block key.
    matrices: dict[Any, BlockMatrix]

    @property
    def n_rows(self) -> int:
        """Number of target rows held by this batch."""
        leading_dim = self.target.shape[0]
        return int(leading_dim)
def _block_solve_batch_from_assembled(batch: AssembledBatch) -> BlockSolveBatch:
    """Convert an assembled cache batch into the solve-batch container.

    The target tensor and block-matrix mapping are passed through as-is.
    """
    return BlockSolveBatch(
        target=batch.target,
        matrices=batch.block_matrices,
    )


def _compact_block_matrix(matrix: torch.Tensor) -> RowIndexedBlockMatrix:
    """Return a row-indexed representation holding only the nonzero rows."""
    # A row is "active" when at least one of its entries is nonzero.
    active_rows = torch.nonzero(
        torch.any(matrix != 0, dim=1),
        as_tuple=False,
    ).reshape(-1)
    return RowIndexedBlockMatrix(
        rows=active_rows,
        values=matrix.index_select(0, active_rows),
        n_rows=int(matrix.shape[0]),
    )


def _row_indexed_stored_elements(matrix: RowIndexedBlockMatrix) -> int:
    """Return stored value plus row-index elements for a row-indexed matrix."""
    return int(matrix.values.numel() + matrix.rows.numel())


def _column_chunked_stored_elements(matrix: ColumnRowIndexedBlockMatrix) -> int:
    """Return stored value plus row-index elements for a column-chunked matrix."""
    return int(
        sum(chunk.values.numel() + chunk.rows.numel() for chunk in matrix.chunks)
    )


def _block_matrix_storage_elements(matrix: BlockMatrix) -> int:
    """Return the number of value/index elements held by one block matrix."""
    if isinstance(matrix, RowIndexedBlockMatrix):
        return _row_indexed_stored_elements(matrix)
    if isinstance(matrix, ColumnRowIndexedBlockMatrix):
        return _column_chunked_stored_elements(matrix)
    return int(matrix.numel())


def _block_matrix_storage_nbytes(matrix: BlockMatrix) -> int:
    """Return the approximate tensor storage footprint of one block matrix.

    Counts ``numel() * element_size()`` of the value tensors and, for the
    compact representations, of the row-index tensors as well.
    """
    if isinstance(matrix, RowIndexedBlockMatrix):
        return int(
            matrix.values.numel() * matrix.values.element_size()
            + matrix.rows.numel() * matrix.rows.element_size()
        )
    if isinstance(matrix, ColumnRowIndexedBlockMatrix):
        return int(
            sum(
                chunk.values.numel() * chunk.values.element_size()
                + chunk.rows.numel() * chunk.rows.element_size()
                for chunk in matrix.chunks
            )
        )
    return int(matrix.numel() * matrix.element_size())


def _block_solve_batch_storage_nbytes(batch: "BlockSolveBatch") -> int:
    """Return approximate tensor storage bytes for all matrices in one batch."""
    return int(
        sum(_block_matrix_storage_nbytes(matrix) for matrix in batch.matrices.values())
    )


def _block_solve_batch_storage_elements(batch: "BlockSolveBatch") -> int:
    """Return value/index element count for all matrices in one batch."""
    return int(
        sum(
            _block_matrix_storage_elements(matrix)
            for matrix in batch.matrices.values()
        )
    )


def _compact_column_chunked_block_matrix(
    matrix: torch.Tensor,
    *,
    chunk_size: int,
) -> ColumnRowIndexedBlockMatrix | None:
    """Return column chunks stored only for rows active in that chunk.

    Args:
        matrix: Dense two-dimensional block matrix.
        chunk_size: Width of every column chunk.

    Returns:
        The chunked representation, or ``None`` when chunking is not
        applicable (non-positive ``chunk_size``, column count not a multiple
        of ``chunk_size``), when the matrix has no nonzero entry, or when the
        chunked form would not store fewer elements than the dense matrix.
    """
    if chunk_size <= 0 or int(matrix.shape[1]) % int(chunk_size) != 0:
        return None
    chunks: list[ColumnRowIndexedChunk] = []
    stored_elements = 0
    n_rows = int(matrix.shape[0])
    n_cols = int(matrix.shape[1])
    for column_start in range(0, n_cols, int(chunk_size)):
        column_stop = column_start + int(chunk_size)
        values = matrix[:, column_start:column_stop]
        active_rows = torch.nonzero(
            torch.any(values != 0, dim=1),
            as_tuple=False,
        ).reshape(-1)
        if not active_rows.numel():
            # Entirely-zero chunks are simply omitted.
            continue
        active_values = values.index_select(0, active_rows)
        stored_elements += int(active_values.numel()) + int(active_rows.numel())
        chunks.append(
            ColumnRowIndexedChunk(
                column_start=column_start,
                rows=active_rows,
                values=active_values,
            )
        )
    if not chunks or stored_elements >= int(matrix.numel()):
        return None
    return ColumnRowIndexedBlockMatrix(
        chunks=tuple(chunks),
        n_rows=n_rows,
        n_cols=n_cols,
    )


def _compact_block_matrix_for_storage(
    matrix: BlockMatrix,
    *,
    mode: MatrixStorageMode,
    column_chunk_size: int | None = None,
) -> BlockMatrix:
    """Return one block matrix in the requested solve/storage representation.

    Args:
        matrix: Block matrix in any supported representation.
        mode: ``"dense"`` materializes the matrix. The other modes leave
            already-compact matrices untouched and try to compact dense ones.
        column_chunk_size: Chunk width for column chunking; chunking is only
            attempted when this is not ``None``.

    Returns:
        The converted matrix, or the input itself when no compaction applies
        or compaction would not reduce the stored element count.

    Raises:
        ValueError: If ``mode`` is not a supported storage mode.
    """
    if mode not in {"dense", "row_indexed", "column_chunked", "auto"}:
        raise ValueError(
            "`matrix_storage` must be one of: dense, row_indexed, column_chunked, auto"
        )
    if mode == "dense":
        return _materialize_block_matrix(matrix)
    if not isinstance(matrix, torch.Tensor):
        return matrix
    dense_elements = int(matrix.numel())
    if dense_elements == 0:
        return matrix
    if mode in {"column_chunked", "auto"} and column_chunk_size is not None:
        # "auto" only bothers compacting matrices above the size threshold;
        # an explicit "column_chunked" request always tries.
        if (
            mode == "column_chunked"
            or dense_elements >= _AUTO_COMPACT_MIN_DENSE_ELEMENTS
        ):
            chunked = _compact_column_chunked_block_matrix(
                matrix,
                chunk_size=int(column_chunk_size),
            )
            if chunked is not None:
                return chunked
    if mode == "row_indexed" or (
        mode == "auto" and dense_elements >= _AUTO_COMPACT_MIN_DENSE_ELEMENTS
    ):
        compact = _compact_block_matrix(matrix)
        if _row_indexed_stored_elements(compact) < dense_elements:
            return compact
    return matrix


def _materialize_block_matrix(matrix: BlockMatrix) -> torch.Tensor:
    """Return a dense tensor for either supported block matrix representation."""
    if isinstance(matrix, (RowIndexedBlockMatrix, ColumnRowIndexedBlockMatrix)):
        return matrix.materialize()
    return matrix


def _block_matrix_matvec(matrix: BlockMatrix, theta: torch.Tensor) -> torch.Tensor:
    """Apply one block matrix to a parameter vector."""
    if isinstance(matrix, (RowIndexedBlockMatrix, ColumnRowIndexedBlockMatrix)):
        return matrix.matvec(theta)
    return matrix @ theta


def _block_matrix_rmatvec(matrix: BlockMatrix, residual: torch.Tensor) -> torch.Tensor:
    """Apply one block matrix transpose to a residual vector."""
    if isinstance(matrix, (RowIndexedBlockMatrix, ColumnRowIndexedBlockMatrix)):
        return matrix.rmatvec(residual)
    return matrix.T @ residual


def _block_matrix_diagonal(matrix: BlockMatrix) -> torch.Tensor:
    """Return the diagonal contribution of one block to ``A.T @ A``.

    This is the per-column sum of squared entries; implicit zero rows do not
    contribute, so only stored values are summed.
    """
    if isinstance(matrix, ColumnRowIndexedBlockMatrix):
        output = torch.zeros(
            (int(matrix.n_cols),),
            dtype=matrix.dtype,
            device=matrix.device,
        )
        for chunk in matrix.chunks:
            # Chunks never overlap in columns, so plain assignment suffices.
            output[chunk.column_start : chunk.column_stop] = torch.sum(
                chunk.values * chunk.values,
                dim=0,
            )
        return output
    values = matrix.values if isinstance(matrix, RowIndexedBlockMatrix) else matrix
    return torch.sum(values * values, dim=0)


def _sorted_row_values_on_rows(
    stored_rows: torch.Tensor,
    stored_values: torch.Tensor,
    rows: torch.Tensor,
) -> torch.Tensor:
    """Return stored values aligned to ``rows``, with zeros for absent rows.

    ``stored_rows`` must be sorted and unique (the invariant enforced by the
    compact containers), which is what makes the binary search valid.
    """
    rows = rows.to(device=stored_values.device, dtype=torch.int64)
    output = torch.zeros(
        (int(rows.numel()), int(stored_values.shape[1])),
        dtype=stored_values.dtype,
        device=stored_values.device,
    )
    if not rows.numel() or not stored_rows.numel():
        return output
    stored_rows = stored_rows.to(device=stored_values.device, dtype=torch.int64)
    # Insertion position of every requested row within the stored rows.
    positions = torch.searchsorted(stored_rows, rows)
    # A position equal to len(stored_rows) cannot be a match and must not be
    # used as an index.
    in_bounds = positions < int(stored_rows.numel())
    if not torch.any(in_bounds):
        return output
    candidate_positions = positions[in_bounds]
    candidate_rows = stored_rows.index_select(0, candidate_positions)
    matching = candidate_rows == rows[in_bounds]
    if not torch.any(matching):
        return output
    output_rows = torch.nonzero(in_bounds, as_tuple=False).reshape(-1)[matching]
    output.index_copy_(
        0,
        output_rows,
        stored_values.index_select(0, candidate_positions[matching]),
    )
    return output


def _row_indexed_values_on_rows(
    matrix: RowIndexedBlockMatrix,
    rows: torch.Tensor,
) -> torch.Tensor:
    """Return row-indexed matrix values aligned to requested dense row ids.

    Requested rows that are not stored in ``matrix`` yield zero rows.
    """
    return _sorted_row_values_on_rows(matrix.rows, matrix.values, rows)


def _block_matrix_values_on_rows(
    matrix: BlockMatrix,
    rows: torch.Tensor,
) -> torch.Tensor:
    """Return matrix values for selected dense rows.

    The result has one row per entry of ``rows`` and one column per logical
    matrix column, regardless of the storage representation.
    """
    if isinstance(matrix, RowIndexedBlockMatrix):
        return _row_indexed_values_on_rows(matrix, rows)
    if isinstance(matrix, ColumnRowIndexedBlockMatrix):
        rows = rows.to(device=matrix.device, dtype=torch.int64)
        output = torch.zeros(
            (int(rows.numel()), int(matrix.n_cols)),
            dtype=matrix.dtype,
            device=matrix.device,
        )
        for chunk in matrix.chunks:
            # The chunk was validated when it was constructed, so its tensors
            # are used directly instead of being wrapped in (and re-validated
            # by) a temporary RowIndexedBlockMatrix on every lookup.
            output[:, chunk.column_start : chunk.column_stop] = (
                _sorted_row_values_on_rows(chunk.rows, chunk.values, rows)
            )
        return output
    return matrix.index_select(0, rows.to(device=matrix.device, dtype=torch.int64))


def _block_matrix_cross(lhs: BlockMatrix, rhs: BlockMatrix) -> torch.Tensor:
    """Return ``lhs.T @ rhs`` without materializing zero rows.

    Only rows stored by the compact operand are gathered from the other
    operand, since every other row contributes nothing to the product.
    """
    if isinstance(lhs, ColumnRowIndexedBlockMatrix):
        output = torch.zeros(
            (int(lhs.n_cols), int(rhs.shape[1])),
            dtype=lhs.dtype,
            device=lhs.device,
        )
        for chunk in lhs.chunks:
            rhs_values = _block_matrix_values_on_rows(rhs, chunk.rows)
            output[chunk.column_start : chunk.column_stop] += (
                chunk.values.T @ rhs_values.to(device=lhs.device, dtype=lhs.dtype)
            )
        return output
    if isinstance(rhs, ColumnRowIndexedBlockMatrix):
        output = torch.zeros(
            (int(lhs.shape[1]), int(rhs.n_cols)),
            dtype=rhs.dtype,
            device=rhs.device,
        )
        for chunk in rhs.chunks:
            lhs_values = _block_matrix_values_on_rows(lhs, chunk.rows)
            output[:, chunk.column_start : chunk.column_stop] += (
                lhs_values.to(device=rhs.device, dtype=rhs.dtype).T @ chunk.values
            )
        return output
    if isinstance(lhs, RowIndexedBlockMatrix):
        rhs_values = _block_matrix_values_on_rows(rhs, lhs.rows)
        return lhs.values.T @ rhs_values.to(device=lhs.device, dtype=lhs.dtype)
    if isinstance(rhs, RowIndexedBlockMatrix):
        # `lhs` is a dense tensor here; every compact `lhs` returned above.
        lhs_values = lhs.index_select(0, rhs.rows.to(device=lhs.device))
        return lhs_values.T @ rhs.values.to(device=lhs.device, dtype=lhs.dtype)
    return lhs.T @ rhs


def _weight_block_matrix(
    matrix: BlockMatrix, sqrt_weights: torch.Tensor
) -> BlockMatrix:
    """Apply row square-root weights to one block matrix.

    Args:
        matrix: Block matrix in any supported representation.
        sqrt_weights: One weight per logical matrix row.

    Returns:
        A new matrix in the same representation whose stored rows are scaled
        by their row's weight.
    """
    if isinstance(matrix, ColumnRowIndexedBlockMatrix):
        chunks = []
        for chunk in matrix.chunks:
            rows = chunk.rows.to(device=sqrt_weights.device, dtype=torch.int64)
            row_weights = sqrt_weights.index_select(0, rows).to(
                device=chunk.values.device,
                dtype=chunk.values.dtype,
            )
            chunks.append(
                ColumnRowIndexedChunk(
                    column_start=chunk.column_start,
                    rows=chunk.rows,
                    values=chunk.values * row_weights[:, None],
                )
            )
        return ColumnRowIndexedBlockMatrix(
            chunks=tuple(chunks),
            n_rows=matrix.n_rows,
            n_cols=matrix.n_cols,
        )
    if isinstance(matrix, RowIndexedBlockMatrix):
        rows = matrix.rows.to(device=sqrt_weights.device, dtype=torch.int64)
        row_weights = sqrt_weights.index_select(0, rows).to(
            device=matrix.values.device,
            dtype=matrix.values.dtype,
        )
        return RowIndexedBlockMatrix(
            rows=matrix.rows,
            values=matrix.values * row_weights[:, None],
            n_rows=matrix.n_rows,
        )
    return matrix * sqrt_weights.to(device=matrix.device, dtype=matrix.dtype)[:, None]


def _apply_row_weights_to_assembled_batch(
    batch: AssembledBatch,
    sqrt_weights: torch.Tensor,
) -> AssembledBatch:
    """Apply current target weights to one unweighted assembled cache batch.

    Args:
        batch: Assembled batch to weight.
        sqrt_weights: One weight per batch row.

    Returns:
        A new batch whose target and block matrices are row-scaled.

    Raises:
        ValueError: If the number of weights differs from ``batch.n_rows``.
    """
    if int(sqrt_weights.numel()) != batch.n_rows:
        raise ValueError(
            "cached least-squares batch row count does not match current targets"
        )
    weights = sqrt_weights.to(device=batch.target.device, dtype=batch.target.dtype)
    return AssembledBatch(
        target=weights * batch.target,
        block_matrices={
            key: _weight_block_matrix(matrix, weights)
            for key, matrix in batch.block_matrices.items()
        },
    )


def _accumulate_block_rows_normal_equations(
    *,
    layout: BlockProblemLayout,
    block_matrices: dict[Any, torch.Tensor],
    target: torch.Tensor,
    rows: torch.Tensor,
    gram: torch.Tensor,
    rhs: torch.Tensor,
) -> torch.Tensor:
    """Accumulate selected target rows into normal-equation tensors.

    Args:
        layout: Layout giving each block key its slice of the solve vector.
        block_matrices: Dense block matrices keyed by block key.
        target: Target vector the rows are selected from.
        rows: Indices of the rows to accumulate.
        gram: Accumulator for ``A.T @ A``; updated in place.
        rhs: Accumulator for ``A.T @ target``; updated in place.

    Returns:
        The squared norm of the selected target rows as a scalar tensor
        (zero when ``rows`` is empty, in which case nothing is accumulated).
    """
    if rows.numel() == 0:
        return torch.zeros((), dtype=target.dtype, device=target.device)
    row_target = target.index_select(0, rows)
    # Select each block's rows exactly once. Gathering inside the pair loop
    # repeated the same `index_select` for every block pair. The trade-off is
    # that all selected blocks are alive at the same time.
    selected = [
        (layout.theta_slice(key), matrix.index_select(0, rows))
        for key, matrix in block_matrices.items()
    ]
    for index_i, (slice_i, matrix_i) in enumerate(selected):
        rhs[slice_i] += matrix_i.T @ row_target
        gram[slice_i, slice_i] += matrix_i.T @ matrix_i
        # Only the upper triangle is computed; the lower one is its transpose.
        for slice_j, matrix_j in selected[index_i + 1 :]:
            cross = matrix_i.T @ matrix_j
            gram[slice_i, slice_j] += cross
            gram[slice_j, slice_i] += cross.T
    return torch.dot(row_target, row_target)


__all__ = [
    "BlockMatrix",
    "BlockProblemLayout",
    "BlockSolveBatch",
    "ColumnRowIndexedBlockMatrix",
    "ColumnRowIndexedChunk",
    "MatrixStorageMode",
    "RowIndexedBlockMatrix",
    "SolveBlock",
]