"""Block matrix containers and kernels for linear least-squares problems."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Literal, Sequence
import torch
from ufp.leastsquares._types import AssembledBatch
from ufp.leastsquares.regularization import BlockRegularization
@dataclass(frozen=True)
class SolveBlock:
    """Immutable solve-time description of one block of unknowns.

    Fields, in positional order:

    * ``key`` -- identifier the layout uses to find this block's slice.
    * ``size`` -- number of solve-vector entries the block occupies.
    * ``label`` -- descriptive string label for the block.
    * ``regularization`` -- optional block regularization; ``None`` if unset.
    """

    key: Any
    size: int
    label: str
    regularization: BlockRegularization | None = None
@dataclass
class BlockProblemLayout:
    """Ordered block layout used by an assembled linear problem.

    Blocks are placed back to back in the given order. Each block occupies
    the half-open solve-vector slice that starts where the previous block ends.
    """

    # Blocks in solve-vector order.
    blocks: tuple[SolveBlock, ...]
    # Block key -> half-open slice of the solve vector.
    slices: dict[Any, slice]
    # Total solve-vector length (sum of all block sizes).
    size: int

    @classmethod
    def from_blocks(cls, blocks: Sequence[SolveBlock]) -> "BlockProblemLayout":
        """Assign contiguous slices to an ordered sequence of solve blocks.

        Args:
            blocks: Blocks in solve-vector order. Each must expose a hashable
                ``key`` and a non-negative ``size``.

        Returns:
            A layout whose ``size`` is the sum of all block sizes.

        Raises:
            ValueError: If two blocks share a key, or a block has a negative
                size. A duplicate key would otherwise silently shadow the
                earlier block's slice. A negative size would make later
                slices overlap.
        """
        ordered = tuple(blocks)
        slices: dict[Any, slice] = {}
        start = 0
        for block in ordered:
            if block.key in slices:
                raise ValueError(f"duplicate solve block key: {block.key!r}")
            if block.size < 0:
                raise ValueError(
                    f"solve block {block.key!r} has negative size {block.size}"
                )
            slices[block.key] = slice(start, start + block.size)
            start += block.size
        return cls(
            blocks=ordered,
            slices=slices,
            size=start,
        )

    def theta_slice(self, key: Any) -> slice:
        """Return the solve-vector slice occupied by the requested block key.

        Raises:
            KeyError: If ``key`` is not part of this layout.
        """
        return self.slices[key]
@dataclass(frozen=True)
class RowIndexedBlockMatrix:
    """Dense block values stored only for rows that can be nonzero.

    Represents a logical ``(n_rows, values.shape[1])`` matrix. Row
    ``rows[i]`` equals ``values[i]``, and every other row is zero.
    """

    # Strictly increasing dense row ids of the stored rows (1-D).
    rows: torch.Tensor
    # Stored row values, one row per entry of ``rows`` (2-D).
    values: torch.Tensor
    # Row count of the logical dense matrix.
    n_rows: int

    def __post_init__(self) -> None:
        """Validate row-indexed matrix metadata.

        Raises:
            ValueError: If ranks or lengths disagree, ``n_rows`` is negative,
                a row id is outside ``[0, n_rows)``, or ``rows`` is not
                strictly increasing.
        """
        if self.rows.ndim != 1:
            raise ValueError("`rows` must be a one-dimensional tensor")
        if self.values.ndim != 2:
            raise ValueError("`values` must be a two-dimensional tensor")
        if self.rows.shape[0] != self.values.shape[0]:
            raise ValueError(
                "`rows` and `values` must contain the same number of active rows"
            )
        if self.n_rows < 0:
            raise ValueError("`n_rows` must be non-negative")
        if self.rows.numel() and (
            bool(torch.any(self.rows < 0)) or bool(torch.any(self.rows >= self.n_rows))
        ):
            raise ValueError("`rows` contains an out-of-range row index")
        if self.rows.numel() > 1 and bool(torch.any(self.rows[1:] <= self.rows[:-1])):
            raise ValueError("`rows` must be sorted and unique")

    @property
    def shape(self) -> tuple[int, int]:
        """Return the logical dense matrix shape."""
        return (int(self.n_rows), int(self.values.shape[1]))

    @property
    def dtype(self) -> torch.dtype:
        """Return the value dtype."""
        return self.values.dtype

    @property
    def device(self) -> torch.device:
        """Return the value device."""
        return self.values.device

    def materialize(self) -> torch.Tensor:
        """Return a dense tensor with the zero rows restored."""
        matrix = torch.zeros(
            self.shape,
            dtype=self.values.dtype,
            device=self.values.device,
        )
        if self.rows.numel():
            # ``index_copy_`` only accepts int64 indices. Validation allows any
            # integer dtype for ``rows``, so normalize here.
            index = self.rows.to(device=self.values.device, dtype=torch.int64)
            matrix.index_copy_(0, index, self.values)
        return matrix

    def matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply this matrix to one block parameter vector.

        Returns:
            A length-``n_rows`` vector in the value dtype/device. Rows that are
            not stored are zero.
        """
        output = torch.zeros(
            (int(self.n_rows),),
            dtype=self.values.dtype,
            device=self.values.device,
        )
        if self.rows.numel():
            output.index_add_(
                0,
                self.rows.to(device=self.values.device, dtype=torch.int64),
                self.values @ theta,
            )
        return output

    def rmatvec(self, residual: torch.Tensor) -> torch.Tensor:
        """Apply this matrix transpose to a residual vector.

        Only the residual entries at stored rows are read. The result uses the
        residual's dtype and device; values are cast to match.
        """
        if not self.rows.numel():
            return torch.zeros(
                (int(self.values.shape[1]),),
                dtype=residual.dtype,
                device=residual.device,
            )
        rows = self.rows.to(device=residual.device, dtype=torch.int64)
        values = self.values.to(device=residual.device, dtype=residual.dtype)
        return values.T @ residual.index_select(0, rows)
@dataclass(frozen=True)
class ColumnRowIndexedChunk:
    """A contiguous run of columns, stored only for the rows active in it.

    Row ``i`` of ``values`` holds dense row ``rows[i]`` restricted to the
    columns ``column_start`` up to (but excluding) ``column_stop``.
    """

    # First dense column covered by this chunk.
    column_start: int
    # Strictly increasing dense row ids with stored values (1-D).
    rows: torch.Tensor
    # Stored values, shape ``(len(rows), chunk width)``.
    values: torch.Tensor

    def __post_init__(self) -> None:
        """Check chunk metadata and raise ``ValueError`` on the first violation."""
        row_ids = self.rows
        block = self.values
        if self.column_start < 0:
            raise ValueError("`column_start` must be non-negative")
        if row_ids.ndim != 1:
            raise ValueError("`rows` must be a one-dimensional tensor")
        if block.ndim != 2:
            raise ValueError("`values` must be a two-dimensional tensor")
        if row_ids.shape[0] != block.shape[0]:
            raise ValueError(
                "`rows` and `values` must contain the same number of active rows"
            )
        has_neighbours = row_ids.numel() > 1
        if has_neighbours and bool(torch.any(row_ids[:-1] >= row_ids[1:])):
            raise ValueError("`rows` must be sorted and unique")

    @property
    def column_stop(self) -> int:
        """Exclusive column bound: ``column_start`` plus the chunk width."""
        width = self.values.shape[1]
        return int(self.column_start + width)
@dataclass(frozen=True)
class ColumnRowIndexedBlockMatrix:
    """Dense block values stored by contiguous column chunks and active rows.

    Represents a logical ``(n_rows, n_cols)`` matrix. Each chunk covers the
    columns ``[column_start, column_stop)`` and stores values only for its own
    active rows. Every entry not stored by some chunk is zero.
    """

    # Chunks ordered by column, with non-overlapping column ranges.
    chunks: tuple[ColumnRowIndexedChunk, ...]
    # Row count of the logical dense matrix.
    n_rows: int
    # Column count of the logical dense matrix.
    n_cols: int

    def __post_init__(self) -> None:
        """Validate column-chunked matrix metadata.

        Raises:
            ValueError: If a dimension is negative, chunks are unsorted or
                overlap, a chunk extends past ``n_cols``, or a chunk row id is
                outside ``[0, n_rows)``.
        """
        if self.n_rows < 0 or self.n_cols < 0:
            raise ValueError("`n_rows` and `n_cols` must be non-negative")
        last_stop = 0
        for chunk in self.chunks:
            if chunk.column_start < last_stop:
                raise ValueError("column chunks must be sorted and non-overlapping")
            if chunk.column_stop > self.n_cols:
                raise ValueError("column chunk exceeds `n_cols`")
            if chunk.rows.numel() and (
                bool(torch.any(chunk.rows < 0))
                or bool(torch.any(chunk.rows >= self.n_rows))
            ):
                raise ValueError("chunk rows contain an out-of-range row index")
            last_stop = chunk.column_stop

    @property
    def shape(self) -> tuple[int, int]:
        """Return the logical dense matrix shape ``(n_rows, n_cols)``."""
        return (int(self.n_rows), int(self.n_cols))

    @property
    def dtype(self) -> torch.dtype:
        """Return the first chunk's value dtype, or torch's default if there are no chunks."""
        for chunk in self.chunks:
            return chunk.values.dtype
        return torch.get_default_dtype()

    @property
    def device(self) -> torch.device:
        """Return the first chunk's value device, or CPU if there are no chunks."""
        for chunk in self.chunks:
            return chunk.values.device
        return torch.device("cpu")

    def materialize(self) -> torch.Tensor:
        """Return a dense tensor with the unstored entries filled with zeros."""
        device = self.device
        matrix = torch.zeros(
            self.shape,
            dtype=self.dtype,
            device=device,
        )
        for chunk in self.chunks:
            if not chunk.rows.numel():
                continue
            column_slice = slice(chunk.column_start, chunk.column_stop)
            # Basic slicing returns a view, so the in-place copy writes into
            # ``matrix``. ``index_copy_`` only accepts int64 indices.
            matrix[:, column_slice].index_copy_(
                0,
                chunk.rows.to(device=device, dtype=torch.int64),
                chunk.values,
            )
        return matrix

    def matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply this matrix to one block parameter vector.

        Each chunk reads ``theta[column_start:column_stop]``. The result is a
        length-``n_rows`` vector in this matrix's dtype/device.
        """
        device = self.device
        output = torch.zeros(
            (int(self.n_rows),),
            dtype=self.dtype,
            device=device,
        )
        for chunk in self.chunks:
            if not chunk.rows.numel():
                continue
            contribution = chunk.values @ theta[chunk.column_start : chunk.column_stop]
            output.index_add_(
                0,
                chunk.rows.to(device=device, dtype=torch.int64),
                contribution,
            )
        return output

    def rmatvec(self, residual: torch.Tensor) -> torch.Tensor:
        """Apply this matrix transpose to a residual vector.

        The result is a length-``n_cols`` vector in the residual's
        dtype/device. Columns not covered by any chunk are zero.
        """
        output = torch.zeros(
            (int(self.n_cols),),
            dtype=residual.dtype,
            device=residual.device,
        )
        for chunk in self.chunks:
            if not chunk.rows.numel():
                continue
            rows = chunk.rows.to(device=residual.device, dtype=torch.int64)
            values = chunk.values.to(device=residual.device, dtype=residual.dtype)
            output[chunk.column_start : chunk.column_stop] += (
                values.T @ residual.index_select(0, rows)
            )
        return output
# Any supported storage form for one block of the design matrix.
BlockMatrix = (
    torch.Tensor
    | RowIndexedBlockMatrix
    | ColumnRowIndexedBlockMatrix
)
# Accepted values for the ``mode`` argument of ``_compact_block_matrix_for_storage``.
MatrixStorageMode = Literal["dense", "row_indexed", "column_chunked", "auto"]
# In ``auto`` mode, dense blocks with fewer elements than this are never compacted.
_AUTO_COMPACT_MIN_DENSE_ELEMENTS = 1024
@dataclass(frozen=True)
class BlockSolveBatch:
    """Target rows of one assembled batch plus its per-block matrices."""

    # Right-hand-side values; dimension 0 is the batch row count.
    target: torch.Tensor
    # Block matrix for each solve-block key.
    matrices: dict[Any, BlockMatrix]

    @property
    def n_rows(self) -> int:
        """Number of target rows, i.e. ``target.shape[0]`` as a Python int."""
        rows_in_batch = self.target.shape[0]
        return int(rows_in_batch)
def _block_solve_batch_from_assembled(batch: AssembledBatch) -> BlockSolveBatch:
    """Wrap an assembled cache batch in the solve-batch container.

    The target tensor and the block-matrix dict are passed through by
    reference; nothing is copied.
    """
    target = batch.target
    matrices = batch.block_matrices
    return BlockSolveBatch(target=target, matrices=matrices)
def _compact_block_matrix(matrix: torch.Tensor) -> RowIndexedBlockMatrix:
    """Compress a dense 2-D block into row-indexed form.

    A row is kept when any of its entries compares unequal to zero (so NaN
    rows are kept). All-zero rows are dropped. Kept values are gathered with
    ``index_select``, i.e. copied rather than viewed.
    """
    row_has_nonzero = (matrix != 0).any(dim=1)
    kept_rows = row_has_nonzero.nonzero(as_tuple=False).reshape(-1)
    kept_values = matrix.index_select(0, kept_rows)
    dense_row_count = int(matrix.shape[0])
    return RowIndexedBlockMatrix(
        rows=kept_rows,
        values=kept_values,
        n_rows=dense_row_count,
    )
def _row_indexed_stored_elements(matrix: RowIndexedBlockMatrix) -> int:
"""Return stored value plus row-index elements for a row-indexed matrix."""
return int(matrix.values.numel() + matrix.rows.numel())
def _column_chunked_stored_elements(matrix: ColumnRowIndexedBlockMatrix) -> int:
"""Return stored value plus row-index elements for a column-chunked matrix."""
return int(
sum(chunk.values.numel() + chunk.rows.numel() for chunk in matrix.chunks)
)
def _block_matrix_storage_elements(matrix: BlockMatrix) -> int:
    """Count value and row-id elements held by one block matrix.

    Dense tensors count every element. Compact forms count only stored
    values plus their row ids.
    """
    if isinstance(matrix, ColumnRowIndexedBlockMatrix):
        return _column_chunked_stored_elements(matrix)
    if isinstance(matrix, RowIndexedBlockMatrix):
        return _row_indexed_stored_elements(matrix)
    dense_count = matrix.numel()
    return int(dense_count)
def _block_matrix_storage_nbytes(matrix: BlockMatrix) -> int:
    """Approximate bytes held by one block matrix's value and row-id tensors.

    Each tensor counts as ``numel() * element_size()``. Shared or
    over-allocated underlying storage is therefore not reflected.
    """

    def tensor_bytes(tensor: torch.Tensor) -> int:
        return tensor.numel() * tensor.element_size()

    if isinstance(matrix, RowIndexedBlockMatrix):
        return int(tensor_bytes(matrix.values) + tensor_bytes(matrix.rows))
    if isinstance(matrix, ColumnRowIndexedBlockMatrix):
        per_chunk = (
            tensor_bytes(chunk.values) + tensor_bytes(chunk.rows)
            for chunk in matrix.chunks
        )
        return int(sum(per_chunk))
    return int(tensor_bytes(matrix))
def _block_solve_batch_storage_nbytes(batch: "BlockSolveBatch") -> int:
    """Sum the approximate storage bytes of every block matrix in ``batch``."""
    total = 0
    for block_matrix in batch.matrices.values():
        total += _block_matrix_storage_nbytes(block_matrix)
    return int(total)
def _block_solve_batch_storage_elements(batch: "BlockSolveBatch") -> int:
    """Sum the value/row-id element counts of every block matrix in ``batch``."""
    total = 0
    for block_matrix in batch.matrices.values():
        total += _block_matrix_storage_elements(block_matrix)
    return int(total)
def _compact_column_chunked_block_matrix(
matrix: torch.Tensor,
*,
chunk_size: int,
) -> ColumnRowIndexedBlockMatrix | None:
"""Return column chunks stored only for rows active in that chunk."""
if chunk_size <= 0 or int(matrix.shape[1]) % int(chunk_size) != 0:
return None
chunks: list[ColumnRowIndexedChunk] = []
stored_elements = 0
n_rows = int(matrix.shape[0])
n_cols = int(matrix.shape[1])
for column_start in range(0, n_cols, int(chunk_size)):
column_stop = column_start + int(chunk_size)
values = matrix[:, column_start:column_stop]
active_rows = torch.nonzero(
torch.any(values != 0, dim=1),
as_tuple=False,
).reshape(-1)
if not active_rows.numel():
continue
active_values = values.index_select(0, active_rows)
stored_elements += int(active_values.numel()) + int(active_rows.numel())
chunks.append(
ColumnRowIndexedChunk(
column_start=column_start,
rows=active_rows,
values=active_values,
)
)
if not chunks or stored_elements >= int(matrix.numel()):
return None
return ColumnRowIndexedBlockMatrix(
chunks=tuple(chunks),
n_rows=n_rows,
n_cols=n_cols,
)
def _compact_block_matrix_for_storage(
matrix: BlockMatrix,
*,
mode: MatrixStorageMode,
column_chunk_size: int | None = None,
) -> BlockMatrix:
"""Return one block matrix in the requested solve/storage representation."""
if mode not in {"dense", "row_indexed", "column_chunked", "auto"}:
raise ValueError(
"`matrix_storage` must be one of: dense, row_indexed, column_chunked, auto"
)
if mode == "dense":
return _materialize_block_matrix(matrix)
if not isinstance(matrix, torch.Tensor):
return matrix
dense_elements = int(matrix.numel())
if dense_elements == 0:
return matrix
if mode in {"column_chunked", "auto"} and column_chunk_size is not None:
if (
mode == "column_chunked"
or dense_elements >= _AUTO_COMPACT_MIN_DENSE_ELEMENTS
):
chunked = _compact_column_chunked_block_matrix(
matrix,
chunk_size=int(column_chunk_size),
)
if chunked is not None:
return chunked
if mode == "row_indexed" or (
mode == "auto" and dense_elements >= _AUTO_COMPACT_MIN_DENSE_ELEMENTS
):
compact = _compact_block_matrix(matrix)
if _row_indexed_stored_elements(compact) < dense_elements:
return compact
return matrix
def _materialize_block_matrix(matrix: BlockMatrix) -> torch.Tensor:
    """Return ``matrix`` as a dense tensor.

    Compact forms are expanded; a dense tensor is returned as the same object.
    """
    if isinstance(matrix, (RowIndexedBlockMatrix, ColumnRowIndexedBlockMatrix)):
        return matrix.materialize()
    return matrix
def _block_matrix_matvec(matrix: BlockMatrix, theta: torch.Tensor) -> torch.Tensor:
    """Compute ``matrix @ theta`` for any supported block representation."""
    compact_types = (RowIndexedBlockMatrix, ColumnRowIndexedBlockMatrix)
    if isinstance(matrix, compact_types):
        return matrix.matvec(theta)
    return matrix @ theta
def _block_matrix_rmatvec(matrix: BlockMatrix, residual: torch.Tensor) -> torch.Tensor:
    """Compute ``matrix.T @ residual`` for any supported block representation."""
    compact_types = (RowIndexedBlockMatrix, ColumnRowIndexedBlockMatrix)
    if isinstance(matrix, compact_types):
        return matrix.rmatvec(residual)
    return matrix.T @ residual
def _block_matrix_diagonal(matrix: BlockMatrix) -> torch.Tensor:
    """Return this block's contribution to ``diag(A.T @ A)``.

    The result is the per-column sum of squared entries. For a column-chunked
    matrix, columns not covered by any chunk are zero.
    """
    if not isinstance(matrix, ColumnRowIndexedBlockMatrix):
        stored = matrix.values if isinstance(matrix, RowIndexedBlockMatrix) else matrix
        return (stored * stored).sum(dim=0)

    diagonal = torch.zeros(
        (int(matrix.n_cols),),
        dtype=matrix.dtype,
        device=matrix.device,
    )
    for chunk in matrix.chunks:
        squared = chunk.values * chunk.values
        diagonal[chunk.column_start : chunk.column_stop] = squared.sum(dim=0)
    return diagonal
def _row_indexed_values_on_rows(
matrix: RowIndexedBlockMatrix,
rows: torch.Tensor,
) -> torch.Tensor:
"""Return row-indexed matrix values aligned to requested dense row ids."""
rows = rows.to(device=matrix.values.device, dtype=torch.int64)
output = torch.zeros(
(int(rows.numel()), int(matrix.values.shape[1])),
dtype=matrix.values.dtype,
device=matrix.values.device,
)
if not rows.numel() or not matrix.rows.numel():
return output
stored_rows = matrix.rows.to(device=matrix.values.device, dtype=torch.int64)
positions = torch.searchsorted(stored_rows, rows)
in_bounds = positions < int(stored_rows.numel())
if not torch.any(in_bounds):
return output
candidate_rows = stored_rows.index_select(0, positions[in_bounds])
matching = candidate_rows == rows[in_bounds]
if not torch.any(matching):
return output
output_rows = torch.nonzero(in_bounds, as_tuple=False).reshape(-1)[matching]
output.index_copy_(
0,
output_rows,
matrix.values.index_select(0, positions[in_bounds][matching]),
)
return output
def _block_matrix_values_on_rows(
    matrix: BlockMatrix,
    rows: torch.Tensor,
) -> torch.Tensor:
    """Gather the dense rows ``rows`` of any block matrix.

    Returns a ``(len(rows), n_cols)`` tensor. Rows that a compact form does
    not store come back as zeros.
    """
    if isinstance(matrix, RowIndexedBlockMatrix):
        return _row_indexed_values_on_rows(matrix, rows)
    if not isinstance(matrix, ColumnRowIndexedBlockMatrix):
        return matrix.index_select(0, rows.to(device=matrix.device, dtype=torch.int64))

    wanted = rows.to(device=matrix.device, dtype=torch.int64)
    gathered = torch.zeros(
        (int(wanted.numel()), int(matrix.n_cols)),
        dtype=matrix.dtype,
        device=matrix.device,
    )
    for chunk in matrix.chunks:
        # Treat each chunk as a row-indexed matrix over its own column range.
        chunk_as_rows = RowIndexedBlockMatrix(
            rows=chunk.rows,
            values=chunk.values,
            n_rows=matrix.n_rows,
        )
        gathered[:, chunk.column_start : chunk.column_stop] = (
            _row_indexed_values_on_rows(chunk_as_rows, wanted)
        )
    return gathered
def _block_matrix_cross(lhs: BlockMatrix, rhs: BlockMatrix) -> torch.Tensor:
    """Return ``lhs.T @ rhs`` without materializing rows a compact operand omits.

    Dispatch order matters when both operands are compact. The checks run in
    this order: column-chunked ``lhs``, column-chunked ``rhs``, row-indexed
    ``lhs``, row-indexed ``rhs``. In each mixed case the other operand's
    gathered rows are cast to the dtype/device of the column-chunked operand
    if there is one, otherwise to those of ``lhs``.
    """
    if isinstance(lhs, ColumnRowIndexedBlockMatrix):
        result = torch.zeros(
            (int(lhs.n_cols), int(rhs.shape[1])),
            dtype=lhs.dtype,
            device=lhs.device,
        )
        for chunk in lhs.chunks:
            partner = _block_matrix_values_on_rows(rhs, chunk.rows).to(
                device=lhs.device, dtype=lhs.dtype
            )
            result[chunk.column_start : chunk.column_stop] += chunk.values.T @ partner
        return result

    if isinstance(rhs, ColumnRowIndexedBlockMatrix):
        result = torch.zeros(
            (int(lhs.shape[1]), int(rhs.n_cols)),
            dtype=rhs.dtype,
            device=rhs.device,
        )
        for chunk in rhs.chunks:
            partner = _block_matrix_values_on_rows(lhs, chunk.rows).to(
                device=rhs.device, dtype=rhs.dtype
            )
            result[:, chunk.column_start : chunk.column_stop] += partner.T @ chunk.values
        return result

    if isinstance(lhs, RowIndexedBlockMatrix):
        partner = _block_matrix_values_on_rows(rhs, lhs.rows).to(
            device=lhs.device, dtype=lhs.dtype
        )
        return lhs.values.T @ partner

    if isinstance(rhs, RowIndexedBlockMatrix):
        # Only the lhs rows that rhs stores can contribute to the product.
        selected = lhs.index_select(0, rhs.rows.to(device=lhs.device))
        return selected.T @ rhs.values.to(device=lhs.device, dtype=lhs.dtype)

    return lhs.T @ rhs
def _weight_block_matrix(
    matrix: BlockMatrix, sqrt_weights: torch.Tensor
) -> BlockMatrix:
    """Scale row ``r`` of one block matrix by ``sqrt_weights[r]``.

    Compact forms keep their representation and row ids; only stored rows
    are scaled. A new object is returned and the input is not modified.
    """

    def stored_row_weights(row_ids: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        # Look up each stored row's weight, matched to the values' dtype/device.
        index = row_ids.to(device=sqrt_weights.device, dtype=torch.int64)
        picked = sqrt_weights.index_select(0, index)
        return picked.to(device=values.device, dtype=values.dtype)

    if isinstance(matrix, ColumnRowIndexedBlockMatrix):
        weighted_chunks = tuple(
            ColumnRowIndexedChunk(
                column_start=chunk.column_start,
                rows=chunk.rows,
                values=chunk.values * stored_row_weights(chunk.rows, chunk.values)[:, None],
            )
            for chunk in matrix.chunks
        )
        return ColumnRowIndexedBlockMatrix(
            chunks=weighted_chunks,
            n_rows=matrix.n_rows,
            n_cols=matrix.n_cols,
        )

    if isinstance(matrix, RowIndexedBlockMatrix):
        scale = stored_row_weights(matrix.rows, matrix.values)
        return RowIndexedBlockMatrix(
            rows=matrix.rows,
            values=matrix.values * scale[:, None],
            n_rows=matrix.n_rows,
        )

    dense_scale = sqrt_weights.to(device=matrix.device, dtype=matrix.dtype)
    return matrix * dense_scale[:, None]
def _apply_row_weights_to_assembled_batch(
    batch: AssembledBatch,
    sqrt_weights: torch.Tensor,
) -> AssembledBatch:
    """Return a new batch with the target and every block scaled row-wise.

    Weights are first cast to the target's dtype/device.

    Raises:
        ValueError: If the number of weights differs from ``batch.n_rows``.
    """
    if int(sqrt_weights.numel()) != batch.n_rows:
        raise ValueError(
            "cached least-squares batch row count does not match current targets"
        )
    weights = sqrt_weights.to(device=batch.target.device, dtype=batch.target.dtype)
    weighted_target = weights * batch.target
    weighted_blocks = {}
    for key, block in batch.block_matrices.items():
        weighted_blocks[key] = _weight_block_matrix(block, weights)
    return AssembledBatch(target=weighted_target, block_matrices=weighted_blocks)
def _accumulate_block_rows_normal_equations(
*,
layout: BlockProblemLayout,
block_matrices: dict[int, torch.Tensor],
target: torch.Tensor,
rows: torch.Tensor,
gram: torch.Tensor,
rhs: torch.Tensor,
) -> torch.Tensor:
"""Accumulate selected target rows into normal-equation tensors."""
if rows.numel() == 0:
return torch.zeros((), dtype=target.dtype, device=target.device)
row_target = target.index_select(0, rows)
keys = tuple(block_matrices)
for key in keys:
theta_slice = layout.theta_slice(key)
block_matrix = block_matrices[key].index_select(0, rows)
rhs[theta_slice] += block_matrix.T @ row_target
for index_i, key_i in enumerate(keys):
slice_i = layout.theta_slice(key_i)
matrix_i = block_matrices[key_i].index_select(0, rows)
gram[slice_i, slice_i] += matrix_i.T @ matrix_i
for key_j in keys[index_i + 1 :]:
slice_j = layout.theta_slice(key_j)
matrix_j = block_matrices[key_j].index_select(0, rows)
cross = matrix_i.T @ matrix_j
gram[slice_i, slice_j] += cross
gram[slice_j, slice_i] += cross.T
return torch.dot(row_target, row_target)
# Public API of this module; underscore-prefixed helpers stay internal.
__all__ = [
    "BlockMatrix",
    "BlockProblemLayout",
    "BlockSolveBatch",
    "ColumnRowIndexedBlockMatrix",
    "ColumnRowIndexedChunk",
    "MatrixStorageMode",
    "RowIndexedBlockMatrix",
    "SolveBlock",
]