Source code for ufp.leastsquares._problem

"""Core matrix-free least-squares problem and iterative solver."""

from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence

import numpy as np
import torch

from ufp.leastsquares._block import (
    BlockProblemLayout,
    BlockSolveBatch,
    _block_matrix_cross,
    _block_matrix_diagonal,
    _block_matrix_matvec,
    _block_matrix_rmatvec,
    _materialize_block_matrix,
)
from ufp.leastsquares._layout import ParameterLayout


@dataclass
class BlockLinearProblem:
    """Regularized linear least-squares problem kept in matrix-free form.

    The design matrix ``A`` is the vertical stack of all ``batches``. Inside a
    batch, each entry of ``batch.matrices`` maps a solve-block key to the
    sub-matrix that acts on ``layout.theta_slice(key)`` of the flat parameter
    vector. Products with ``A`` and ``A.T`` are evaluated batch by batch; the
    dense matrix is only built by :meth:`materialize_design_matrix`. Blocks in
    ``layout.blocks`` may carry a regularizer that contributes extra quadratic
    terms to the objective.
    """

    layout: BlockProblemLayout
    batches: Sequence[BlockSolveBatch]

    @property
    def n_rows(self) -> int:
        """Total number of target rows, summed over every batch."""
        total = 0
        for batch in self.batches:
            total += batch.n_rows
        return total

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first batch's target (torch's default dtype when empty)."""
        if not self.batches:
            return torch.get_default_dtype()
        return self.batches[0].target.dtype

    @property
    def device(self) -> torch.device:
        """Device of the first batch's target (the CPU when there are no batches)."""
        if not self.batches:
            return torch.device("cpu")
        return self.batches[0].target.device

    def target_vector(self) -> torch.Tensor:
        """Concatenate the batch targets, in batch order, into the RHS vector ``b``."""
        targets = [batch.target for batch in self.batches]
        if targets:
            return torch.cat(targets, dim=0)
        return torch.zeros(0, dtype=self.dtype, device=self.device)

    def materialize_design_matrix(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the explicit dense pair ``(A, b)``.

        ``A`` has shape ``(n_rows, layout.size)``, so this is only meant for
        debugging or tiny problems.
        """
        design = torch.zeros(
            (self.n_rows, self.layout.size),
            dtype=self.dtype,
            device=self.device,
        )
        row_start = 0
        for batch in self.batches:
            row_stop = row_start + batch.n_rows
            for key, block_matrix in batch.matrices.items():
                dense_block = _materialize_block_matrix(block_matrix)
                design[row_start:row_stop, self.layout.theta_slice(key)] = dense_block
            row_start = row_stop
        return design, self.target_vector()

    def matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Return ``A @ theta`` for a parameter vector of ``layout.size`` entries."""
        flat_theta = theta.reshape(self.layout.size)
        predictions: list[torch.Tensor] = []
        for batch in self.batches:
            batch_prediction = torch.zeros(
                (batch.n_rows,),
                dtype=batch.target.dtype,
                device=batch.target.device,
            )
            for key, block_matrix in batch.matrices.items():
                block_theta = flat_theta[self.layout.theta_slice(key)]
                # Out-of-place add: each block's contribution builds a new tensor.
                batch_prediction = batch_prediction + _block_matrix_matvec(
                    block_matrix,
                    block_theta,
                )
            predictions.append(batch_prediction)
        if predictions:
            return torch.cat(predictions, dim=0)
        return torch.zeros(0, dtype=self.dtype, device=self.device)

    def rmatvec(self, residual: torch.Tensor) -> torch.Tensor:
        """Return ``A.T @ residual`` for a vector with one entry per target row."""
        flat_residual = residual.reshape(self.n_rows)
        projected = torch.zeros(
            (self.layout.size,),
            dtype=flat_residual.dtype,
            device=flat_residual.device,
        )
        row_start = 0
        for batch in self.batches:
            row_stop = row_start + batch.n_rows
            batch_rows = flat_residual[row_start:row_stop]
            for key, block_matrix in batch.matrices.items():
                projected[self.layout.theta_slice(key)] += _block_matrix_rmatvec(
                    block_matrix,
                    batch_rows,
                )
            row_start = row_stop
        return projected

    def _regularized_slices(self):
        """Yield ``(theta_slice, regularization)`` for each regularized block, in layout order."""
        for block in self.layout.blocks:
            regularization = block.regularization
            if regularization is None:
                continue
            yield self.layout.theta_slice(block.key), regularization

    def regularization_apply(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply each block regularizer to its slice of ``theta``; other entries stay zero."""
        flat_theta = theta.reshape(self.layout.size)
        penalty = torch.zeros_like(flat_theta)
        for columns, regularization in self._regularized_slices():
            penalty[columns] += regularization.apply(flat_theta[columns])
        return penalty

    def regularization_diagonal(self) -> torch.Tensor:
        """Return the per-parameter diagonal contributed by the block regularizers."""
        diagonal = torch.zeros((self.layout.size,), dtype=self.dtype, device=self.device)
        for columns, regularization in self._regularized_slices():
            diagonal[columns] += regularization.diagonal(
                dtype=self.dtype,
                device=self.device,
            )
        return diagonal

    def regularization_rhs(self) -> torch.Tensor:
        """Return the right-hand-side shifts contributed by the block regularizers."""
        shift = torch.zeros((self.layout.size,), dtype=self.dtype, device=self.device)
        for columns, regularization in self._regularized_slices():
            shift[columns] += regularization.rhs(
                dtype=self.dtype,
                device=self.device,
            )
        return shift

    def normal_matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply ``A.T A`` plus the regularization operator to ``theta``."""
        data_term = self.rmatvec(self.matvec(theta))
        return data_term + self.regularization_apply(theta)

    def normal_equation_diagonal(self) -> torch.Tensor:
        """Return the diagonal of the weighted design ``A.T @ A``."""
        return _normal_equation_diagonal(self)

    def design_trace_by_block(self) -> dict[object, float]:
        """Map every solve-block key to its share of the trace of ``A.T @ A``."""
        diagonal = self.normal_equation_diagonal()
        traces: dict[object, float] = {}
        for block in self.layout.blocks:
            block_diagonal = diagonal[self.layout.theta_slice(block.key)]
            traces[block.key] = float(block_diagonal.sum().item())
        return traces

    def rhs(self) -> torch.Tensor:
        """Return ``A.T b`` plus the regularizers' right-hand-side shifts."""
        data_rhs = self.rmatvec(self.target_vector())
        return data_rhs + self.regularization_rhs()

    def objective(self, theta: torch.Tensor) -> torch.Tensor:
        """Return ``||A theta - b||^2`` plus every regularizer's quadratic term."""
        misfit = self.matvec(theta) - self.target_vector()
        total = torch.dot(misfit, misfit)
        for columns, regularization in self._regularized_slices():
            total = total + regularization.quadratic(theta[columns])
        return total

    def accumulate_normal_equations(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Materialize the regularized Gram matrix and normal-equation RHS.

        Each off-diagonal block product inside a batch is computed once and
        written to both mirrored positions (the second one transposed).
        """
        n_params = self.layout.size
        gram = torch.zeros((n_params, n_params), dtype=self.dtype, device=self.device)
        rhs = torch.zeros((n_params,), dtype=self.dtype, device=self.device)
        for batch in self.batches:
            entries = [
                (self.layout.theta_slice(key), block_matrix)
                for key, block_matrix in batch.matrices.items()
            ]
            for columns, block_matrix in entries:
                rhs[columns] += _block_matrix_rmatvec(block_matrix, batch.target)
            for position, (columns_i, matrix_i) in enumerate(entries):
                gram[columns_i, columns_i] += _block_matrix_cross(matrix_i, matrix_i)
                for columns_j, matrix_j in entries[position + 1 :]:
                    cross = _block_matrix_cross(matrix_i, matrix_j)
                    gram[columns_i, columns_j] += cross
                    gram[columns_j, columns_i] += cross.T
        for columns, regularization in self._regularized_slices():
            gram[columns, columns] += regularization.materialize(
                dtype=self.dtype,
                device=self.device,
            )
            rhs[columns] += regularization.rhs(
                dtype=self.dtype,
                device=self.device,
            )
        return gram, rhs

    def solve(
        self,
        *,
        solver: str,
        cg_tolerance: float,
        cg_max_iter: int | None,
        progress: bool = False,
        progress_frequency: int = 10,
        initial_theta: torch.Tensor | None = None,
        cg_checkpoint_path: Path | str | None = None,
        cg_checkpoint_frequency: int = 1,
        cg_resume: bool = False,
        cg_checkpoint_metadata: dict[str, object] | None = None,
        fallback_theta: torch.Tensor | None = None,
        return_info: bool = False,
    ) -> torch.Tensor | "LinearSolveResult":
        """Solve the problem with one of the supported backends.

        Args:
            solver: ``"dense_lstsq"`` runs ``torch.linalg.lstsq`` on the
                explicit design matrix, stacked with the regularizers'
                least-squares rows. ``"normal_equation_direct"`` runs
                ``torch.linalg.solve`` on the materialized normal equations,
                retrying with ``torch.linalg.lstsq`` if it raises
                ``RuntimeError``. ``"cg"`` runs matrix-free, diagonally
                preconditioned conjugate gradients.
            cg_tolerance, cg_max_iter: CG stopping criteria on the residual
                norm and on the iteration count.
            progress, progress_frequency: Print status messages; for CG,
                ``progress_frequency`` is passed through to the iteration log.
            initial_theta: CG warm start. It is also the interrupt fallback
                when ``fallback_theta`` is ``None``.
            cg_checkpoint_path, cg_checkpoint_frequency: Where, and every how
                many iterations, CG writes restart checkpoints.
            cg_resume: Load the CG state from ``cg_checkpoint_path`` before
                iterating.
            cg_checkpoint_metadata: Metadata stored in, and validated against,
                the checkpoint. Defaults to ``_cg_checkpoint_metadata``.
            fallback_theta: Coefficients returned after an interrupt.
            return_info: Return a ``LinearSolveResult`` instead of the bare
                tensor. This also enables graceful handling of
                ``KeyboardInterrupt``.

        Returns:
            The coefficient tensor, or a ``LinearSolveResult`` when
            ``return_info`` is true.

        Raises:
            ValueError: If ``solver`` is not one of the supported names.
            KeyboardInterrupt: If interrupted while ``return_info`` is false.
        """

        def finish(result: "LinearSolveResult") -> torch.Tensor | "LinearSolveResult":
            # Callers that did not ask for metadata only get the coefficients.
            if return_info:
                return result
            return result.theta

        def fallback() -> "LinearSolveResult":
            # Preference order: explicit fallback, then warm start, then zeros.
            for candidate in (fallback_theta, initial_theta):
                if candidate is not None:
                    return LinearSolveResult(
                        theta=candidate.to(dtype=self.dtype, device=self.device),
                        interrupted=True,
                    )
            zeros = torch.zeros(
                (self.layout.size,),
                dtype=self.dtype,
                device=self.device,
            )
            return LinearSolveResult(theta=zeros, interrupted=True)

        if solver == "dense_lstsq":
            try:
                design, target = self.materialize_design_matrix()
                extra_rows: list[torch.Tensor] = []
                extra_targets: list[torch.Tensor] = []
                for block in self.layout.blocks:
                    regularization = block.regularization
                    if regularization is None:
                        continue
                    rows, row_target = regularization.least_squares_rows(
                        dtype=self.dtype,
                        device=self.device,
                    )
                    if rows.shape[0] == 0:
                        continue
                    # Embed the block-local rows into full-width design rows.
                    padded = torch.zeros(
                        (rows.shape[0], self.layout.size),
                        dtype=self.dtype,
                        device=self.device,
                    )
                    padded[:, self.layout.theta_slice(block.key)] = rows
                    extra_rows.append(padded)
                    extra_targets.append(row_target)
                if extra_rows:
                    design = torch.cat([design, *extra_rows], dim=0)
                    target = torch.cat([target, *extra_targets], dim=0)
                solution = torch.linalg.lstsq(design, target).solution
                return finish(LinearSolveResult(solution))
            except KeyboardInterrupt:
                if not return_info:
                    raise
                if progress:
                    print(
                        "Interrupted dense least-squares solve; "
                        "using fallback coefficients."
                    )
                return fallback()

        if solver == "normal_equation_direct":
            try:
                gram, rhs = self.accumulate_normal_equations()
                if progress:
                    print("Solving normal equations directly...")
                try:
                    theta = torch.linalg.solve(gram, rhs)
                except RuntimeError:
                    # Singular or ill-posed Gram matrix: retry with least squares.
                    if progress:
                        print(
                            "Direct solve failed; falling back to torch.linalg.lstsq."
                        )
                    theta = torch.linalg.lstsq(gram, rhs).solution
                return finish(LinearSolveResult(theta))
            except KeyboardInterrupt:
                if not return_info:
                    raise
                if progress:
                    print(
                        "Interrupted normal-equation solve; "
                        "using fallback coefficients."
                    )
                return fallback()

        if solver == "cg":
            try:
                rhs = self.rhs()
                if cg_checkpoint_metadata is not None:
                    metadata = dict(cg_checkpoint_metadata)
                else:
                    metadata = _cg_checkpoint_metadata(
                        n_parameters=self.layout.size,
                        dtype=self.dtype,
                    )
                resume_state = None
                if cg_resume and cg_checkpoint_path is not None:
                    resume_state = load_cg_checkpoint(
                        cg_checkpoint_path,
                        dtype=self.dtype,
                        device=self.device,
                        expected_metadata=metadata,
                    )
                preconditioner = (
                    self.regularization_diagonal() + self.normal_equation_diagonal()
                )
                outcome = _conjugate_gradient(
                    self.normal_matvec,
                    rhs,
                    diagonal_preconditioner=preconditioner,
                    tolerance=cg_tolerance,
                    max_iter=cg_max_iter,
                    progress=progress,
                    progress_frequency=progress_frequency,
                    initial_guess=initial_theta,
                    checkpoint_state=resume_state,
                    checkpoint_path=cg_checkpoint_path,
                    checkpoint_frequency=cg_checkpoint_frequency,
                    checkpoint_metadata=metadata,
                    handle_interrupts=return_info,
                )
                return finish(outcome)
            except KeyboardInterrupt:
                if not return_info:
                    raise
                if progress:
                    print("Interrupted CG setup; using fallback coefficients.")
                return fallback()

        supported = ("dense_lstsq", "normal_equation_direct", "cg")
        raise ValueError(
            f"Unsupported solver '{solver}'. Expected one of: {', '.join(supported)}."
        )
def _normal_equation_diagonal(problem: BlockLinearProblem) -> torch.Tensor:
    """Return the diagonal of ``A^T A`` used for simple diagonal preconditioning.

    For every batch, ``_block_matrix_diagonal`` of each block matrix is added
    into that block's slice of a zero vector of length ``problem.layout.size``.
    """
    diagonal = torch.zeros(
        (problem.layout.size,),
        dtype=problem.dtype,
        device=problem.device,
    )
    for batch in problem.batches:
        for key, block_matrix in batch.matrices.items():
            diagonal[problem.layout.theta_slice(key)] += _block_matrix_diagonal(
                block_matrix
            )
    return diagonal


def _conjugate_gradient(
    matvec,
    rhs: torch.Tensor,
    *,
    diagonal_preconditioner: torch.Tensor,
    tolerance: float,
    max_iter: int | None,
    progress: bool,
    progress_frequency: int,
    initial_guess: torch.Tensor | None = None,
    checkpoint_state: CGCheckpointState | None = None,
    checkpoint_path: Path | str | None = None,
    checkpoint_frequency: int = 1,
    checkpoint_metadata: dict[str, object] | None = None,
    handle_interrupts: bool = False,
) -> "LinearSolveResult":
    """Solve ``matvec(x) = rhs`` with diagonally preconditioned conjugate gradients.

    Args:
        matvec: Callable applying the system operator to a vector shaped like
            ``rhs``. CG assumes the operator is symmetric positive
            (semi-)definite.
        rhs: Right-hand side vector.
        diagonal_preconditioner: Diagonal used to scale residuals. Entries with
            magnitude ``<= 1e-14`` are replaced by ``1``.
        tolerance: Stop once the Euclidean norm of the residual is ``<=`` this.
        max_iter: Iteration cap. ``None`` means ``max(10, 4 * rhs.numel())``.
        progress, progress_frequency: Print progress every
            ``progress_frequency`` iterations (and on the first and last).
        initial_guess: Starting iterate. Ignored when ``checkpoint_state`` is
            given; ``None`` means zeros.
        checkpoint_state: Restart state. When given, it fully replaces the
            initial guess, and iteration counting resumes from its
            ``iteration``.
        checkpoint_path, checkpoint_frequency: Where, and every how many
            completed iterations, to write restart checkpoints.
        checkpoint_metadata: Metadata stored in each checkpoint (``{}`` when
            ``None``). It is also used to validate the reload after an
            interrupt.
        handle_interrupts: If true, a ``KeyboardInterrupt`` stops iterating.
            The last complete iterate is checkpointed and returned with
            ``interrupted=True``.

    Returns:
        ``LinearSolveResult`` with the final (or last complete) iterate.

    Raises:
        KeyboardInterrupt: If interrupted and ``handle_interrupts`` is false.
    """
    if max_iter is None:
        max_iter = max(10, 4 * rhs.numel())
    progress_frequency = max(int(progress_frequency), 1)
    # Near-zero diagonal entries would blow up the scaled residual; use 1 there.
    safe_diag = torch.where(
        torch.abs(diagonal_preconditioner) > 1.0e-14,
        diagonal_preconditioner,
        torch.ones_like(diagonal_preconditioner),
    )

    if checkpoint_state is not None:
        x = checkpoint_state.x.to(dtype=rhs.dtype, device=rhs.device).reshape_as(rhs)
        residual = checkpoint_state.residual.to(
            dtype=rhs.dtype,
            device=rhs.device,
        ).reshape_as(rhs)
        direction = checkpoint_state.direction.to(
            dtype=rhs.dtype,
            device=rhs.device,
        ).reshape_as(rhs)
        rz_old = checkpoint_state.rz_old.to(dtype=rhs.dtype, device=rhs.device)
        start_iteration = int(checkpoint_state.iteration)
    else:
        x = (
            torch.zeros_like(rhs)
            if initial_guess is None
            else initial_guess.to(dtype=rhs.dtype, device=rhs.device).reshape_as(rhs)
        )
        residual = rhs - matvec(x)
        z = residual / safe_diag
        direction = z.clone()
        rz_old = torch.dot(residual, z)
        start_iteration = 0

    # Snapshot of the last *complete* iterate:
    # (x, residual, direction, rz_old, iteration).
    # It is only ever replaced by a single assignment of a fully built tuple,
    # so a KeyboardInterrupt can never observe a half-updated state. All
    # checkpoints (including the one written on interrupt) come from it.
    completed = (x, residual, direction, rz_old, start_iteration)

    def write_checkpoint(state) -> None:
        if checkpoint_path is None:
            return
        state_x, state_residual, state_direction, state_rz, state_iteration = state
        save_cg_checkpoint(
            checkpoint_path,
            CGCheckpointState(
                x=state_x,
                residual=state_residual,
                direction=state_direction,
                rz_old=state_rz,
                iteration=int(state_iteration),
                metadata={} if checkpoint_metadata is None else checkpoint_metadata,
            ),
        )

    residual_norm = torch.linalg.norm(residual).item()
    if progress:
        resume_suffix = (
            "" if start_iteration == 0 else f", resumed_at={start_iteration}"
        )
        print(
            "CG start: "
            f"||residual||={residual_norm:.6e}, "
            f"tolerance={tolerance:.3e}, max_iter={max_iter}{resume_suffix}"
        )
    if residual_norm <= tolerance:
        if progress:
            print("CG converged without iterations.")
        write_checkpoint(completed)
        return LinearSolveResult(theta=x)

    checkpoint_frequency = max(int(checkpoint_frequency), 1)
    interrupted = False
    restored_checkpoint_path = None
    try:
        for iteration in range(start_iteration, max_iter):
            x, residual, direction, rz_old, _ = completed
            mat_direction = matvec(direction)
            denom = torch.dot(direction, mat_direction)
            if torch.abs(denom) <= 1.0e-30:
                if progress:
                    print(
                        f"CG stopped at iteration {iteration + 1}: "
                        "near-zero denominator."
                    )
                break
            alpha = rz_old / denom
            # Stage the next iterate in locals; it is committed below in one step.
            next_x = x + alpha * direction
            next_residual = residual - alpha * mat_direction
            residual_norm = torch.linalg.norm(next_residual).item()
            step = iteration + 1
            if progress and (
                iteration == 0
                or step % progress_frequency == 0
                or residual_norm <= tolerance
            ):
                print(f"CG iter {step}: " f"||residual||={residual_norm:.6e}")
            if residual_norm <= tolerance:
                completed = (next_x, next_residual, direction, rz_old, step)
                write_checkpoint(completed)
                break
            z = next_residual / safe_diag
            rz_new = torch.dot(next_residual, z)
            if torch.abs(rz_old) <= 1.0e-30:
                if progress:
                    print(
                        f"CG stopped at iteration {iteration + 1}: "
                        "near-zero rz_old."
                    )
                completed = (next_x, next_residual, direction, rz_old, step)
                break
            beta = rz_new / rz_old
            completed = (next_x, next_residual, z + beta * direction, rz_new, step)
            if step % checkpoint_frequency == 0:
                write_checkpoint(completed)
    except KeyboardInterrupt:
        if not handle_interrupts:
            raise
        interrupted = True
        if progress:
            print(
                "Interrupted CG solve; saving the latest complete iterate "
                f"at iteration {completed[4]}."
            )
        write_checkpoint(completed)
        if checkpoint_path is not None:
            restored = load_cg_checkpoint(
                checkpoint_path,
                dtype=rhs.dtype,
                device=rhs.device,
                expected_metadata=checkpoint_metadata,
            )
            completed = (restored.x.reshape_as(rhs), *completed[1:])
            restored_checkpoint_path = str(Path(checkpoint_path))

    x, residual = completed[0], completed[1]
    if progress:
        print(f"CG done: ||residual||={torch.linalg.norm(residual).item():.6e}")
    write_checkpoint(completed)
    return LinearSolveResult(
        theta=x,
        interrupted=interrupted,
        restored_checkpoint_path=restored_checkpoint_path,
    )
@dataclass(eq=True, frozen=True)
class LinearSolveResult:
    """Immutable outcome of one linear solve.

    ``interrupted`` is set when a KeyboardInterrupt cut the solve short and
    ``theta`` holds fallback or last-complete coefficients instead.
    ``restored_checkpoint_path`` names the CG checkpoint those coefficients
    were reloaded from, if any.
    """

    # Flat coefficient vector produced by the solver.
    theta: torch.Tensor
    # True when the solve was stopped by an interrupt.
    interrupted: bool = False
    # Checkpoint file that ``theta`` was reloaded from, or None.
    restored_checkpoint_path: str | None = None
@dataclass(eq=True, frozen=True)
class LinearFitResult:
    """Immutable summary of one direct linear fit plus the problem it solved."""

    # Fitted coefficient vector.
    theta: torch.Tensor
    # Objective value recorded for ``theta``.
    objective: float
    # Residual norm recorded for ``theta``.
    residual_norm: float
    # Name of the solver backend that produced ``theta``.
    solver: str
    # Number of target rows in the assembled problem.
    n_rows: int
    # Number of fitted parameters.
    n_parameters: int
    # Parameter layout used to assemble the problem.
    layout: ParameterLayout
    # The assembled matrix-free problem itself.
    problem: BlockLinearProblem
    # True when the underlying solve was interrupted.
    interrupted: bool = False
    # Checkpoint file the coefficients were reloaded from, or None.
    restored_checkpoint_path: str | None = None
@dataclass(eq=True, frozen=True)
class CGCheckpointState:
    """Immutable restart state for one conjugate-gradient solve."""

    # Current iterate.
    x: torch.Tensor
    # Current residual vector.
    residual: torch.Tensor
    # Current search direction.
    direction: torch.Tensor
    # Scalar residual/preconditioned-residual inner product from the last step.
    rz_old: torch.Tensor
    # Number of iterations already completed.
    iteration: int
    # Validation metadata stored alongside the state.
    metadata: dict[str, object]
def _cg_checkpoint_metadata(
    *,
    n_parameters: int,
    dtype: torch.dtype,
) -> dict[str, object]:
    """Build the metadata a CG checkpoint must match before it can be resumed."""
    return dict(
        schema_version=1,
        n_parameters=int(n_parameters),
        dtype=str(dtype),
    )


def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Return ``tensor`` as a NumPy array, detached from autograd and moved to the CPU."""
    detached = tensor.detach()
    return detached.cpu().numpy()
def save_cg_checkpoint(
    path: Path | str,
    state: CGCheckpointState,
) -> None:
    """Write one CG restart state to ``path`` as an ``.npz`` archive.

    Missing parent directories are created. The archive is first written to a
    sibling ``<name>.tmp`` file and only then moved over ``path`` with
    ``Path.replace``, so an existing checkpoint is not touched until the new
    one has been fully written.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    staging = destination.with_name(destination.name + ".tmp")
    with staging.open("wb") as handle:
        arrays = {
            "x": _tensor_to_numpy(state.x),
            "residual": _tensor_to_numpy(state.residual),
            "direction": _tensor_to_numpy(state.direction),
            # Stored as a 0-d array so it reloads as a scalar.
            "rz_old": _tensor_to_numpy(state.rz_old.reshape(())),
            "iteration": np.asarray([int(state.iteration)], dtype=np.int64),
            # Sorted-key JSON text keeps the metadata readable without pickle.
            "metadata": np.asarray(json.dumps(state.metadata, sort_keys=True)),
        }
        np.savez(handle, **arrays)
    staging.replace(destination)
def load_cg_checkpoint(
    path: Path | str,
    *,
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
    expected_metadata: dict[str, object] | None = None,
) -> CGCheckpointState:
    """Load and validate one conjugate-gradient restart state.

    Args:
        path: ``.npz`` file written by ``save_cg_checkpoint``.
        dtype: Optional dtype to cast every tensor to.
        device: Optional device to move every tensor to.
        expected_metadata: If given, the stored metadata must equal it. The
            stored metadata went through a sorted-key JSON round trip, so
            ``expected_metadata`` is normalized the same way before the
            comparison. Tuples therefore compare equal to the lists they were
            saved as, and non-string keys to their string form.

    Returns:
        The restored ``CGCheckpointState``.

    Raises:
        ValueError: If the stored metadata does not match ``expected_metadata``.
    """
    checkpoint_path = Path(path)
    with np.load(checkpoint_path) as data:
        metadata = json.loads(str(data["metadata"].item()))
        if expected_metadata is not None:
            try:
                expected = json.loads(json.dumps(expected_metadata, sort_keys=True))
            except (TypeError, ValueError):
                # Not JSON-serializable: it cannot match a stored checkpoint
                # anyway, so keep the raw value for the (failing) comparison.
                expected = expected_metadata
            if metadata != expected:
                raise ValueError("CG checkpoint metadata does not match this problem")
        x = torch.as_tensor(data["x"])
        residual = torch.as_tensor(data["residual"])
        direction = torch.as_tensor(data["direction"])
        rz_old = torch.as_tensor(data["rz_old"]).reshape(())
        if dtype is not None or device is not None:
            x = x.to(dtype=dtype, device=device)
            residual = residual.to(dtype=dtype, device=device)
            direction = direction.to(dtype=dtype, device=device)
            rz_old = rz_old.to(dtype=dtype, device=device)
        return CGCheckpointState(
            x=x,
            residual=residual,
            direction=direction,
            rz_old=rz_old,
            iteration=int(np.asarray(data["iteration"]).reshape(-1)[0]),
            metadata=metadata,
        )
# Public API: the problem container, the solve/fit result records, and the
# CG checkpoint state together with its save/load helpers.
__all__ = [
    "BlockLinearProblem",
    "CGCheckpointState",
    "LinearFitResult",
    "LinearSolveResult",
    "load_cg_checkpoint",
    "save_cg_checkpoint",
]