"""Core matrix-free least-squares problem and iterative solver."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence
import numpy as np
import torch
from ufp.leastsquares._block import (
BlockProblemLayout,
BlockSolveBatch,
_block_matrix_cross,
_block_matrix_diagonal,
_block_matrix_matvec,
_block_matrix_rmatvec,
_materialize_block_matrix,
)
from ufp.leastsquares._layout import ParameterLayout
@dataclass
class BlockLinearProblem:
    """Matrix-free linear least-squares problem assembled from block batches.

    The rows of every batch in ``batches`` are stacked in order. Within one
    batch, each entry of ``batch.matrices`` acts on the slice of the flat
    parameter vector given by ``layout.theta_slice(key)``. Blocks listed in
    ``layout.blocks`` may carry an optional ``regularization`` object whose
    contributions are added to the normal equations and the objective.
    """

    # Provides ``size``, ``blocks`` and ``theta_slice(key)`` for the flat
    # parameter vector.
    layout: BlockProblemLayout
    # Row batches; each one exposes ``n_rows``, ``target`` and ``matrices``
    # (a mapping from block key to block matrix).
    batches: Sequence[BlockSolveBatch]

    @property
    def n_rows(self) -> int:
        """Return the total number of target rows across all batches."""
        return sum(batch.n_rows for batch in self.batches)

    @property
    def dtype(self) -> torch.dtype:
        """Return the dtype of the first batch target.

        Falls back to the torch default dtype when there are no batches.
        """
        if self.batches:
            return self.batches[0].target.dtype
        return torch.get_default_dtype()

    @property
    def device(self) -> torch.device:
        """Return the device of the first batch target.

        Falls back to the CPU device when there are no batches.
        """
        if self.batches:
            return self.batches[0].target.device
        return torch.device("cpu")

    def target_vector(self) -> torch.Tensor:
        """Concatenate all batch targets into one right-hand-side vector.

        Returns an empty tensor when the problem has no batches.
        """
        if not self.batches:
            return torch.zeros(0, dtype=self.dtype, device=self.device)
        return torch.cat([batch.target for batch in self.batches], dim=0)

    def materialize_design_matrix(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the explicit dense design matrix for debugging or tiny problems.

        Returns:
            A ``(matrix, target)`` pair where ``matrix`` has shape
            ``(n_rows, layout.size)`` and ``target`` is ``target_vector()``.
        """
        matrix = torch.zeros(
            (self.n_rows, self.layout.size),
            dtype=self.dtype,
            device=self.device,
        )
        offset = 0
        for batch in self.batches:
            for key, block_matrix in batch.matrices.items():
                matrix[offset : offset + batch.n_rows, self.layout.theta_slice(key)] = (
                    _materialize_block_matrix(block_matrix)
                )
            # Advance to the row range owned by the next batch.
            offset += batch.n_rows
        return matrix, self.target_vector()

    def matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply the design matrix to a flat parameter vector.

        Args:
            theta: Parameter tensor; it is reshaped to ``(layout.size,)``.

        Returns:
            The stacked per-batch predictions (empty when there are no batches).
        """
        theta = theta.reshape(self.layout.size)
        outputs: list[torch.Tensor] = []
        for batch in self.batches:
            prediction = torch.zeros(
                (batch.n_rows,),
                dtype=batch.target.dtype,
                device=batch.target.device,
            )
            for key, block_matrix in batch.matrices.items():
                prediction = prediction + _block_matrix_matvec(
                    block_matrix,
                    theta[self.layout.theta_slice(key)],
                )
            outputs.append(prediction)
        if not outputs:
            return torch.zeros(0, dtype=self.dtype, device=self.device)
        return torch.cat(outputs, dim=0)

    def rmatvec(self, residual: torch.Tensor) -> torch.Tensor:
        """Apply the transpose design matrix to a residual vector.

        Args:
            residual: Row-space tensor; it is reshaped to ``(n_rows,)``.

        Returns:
            A tensor of shape ``(layout.size,)`` with the dtype and device of
            ``residual``.
        """
        residual = residual.reshape(self.n_rows)
        output = torch.zeros(
            (self.layout.size,),
            dtype=residual.dtype,
            device=residual.device,
        )
        offset = 0
        for batch in self.batches:
            batch_residual = residual[offset : offset + batch.n_rows]
            for key, block_matrix in batch.matrices.items():
                output[self.layout.theta_slice(key)] += _block_matrix_rmatvec(
                    block_matrix,
                    batch_residual,
                )
            offset += batch.n_rows
        return output

    def regularization_apply(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply all block regularizers to a flat parameter vector.

        Blocks without a regularizer contribute zeros to their slice.
        """
        theta = theta.reshape(self.layout.size)
        output = torch.zeros_like(theta)
        for block in self.layout.blocks:
            if block.regularization is None:
                continue
            theta_slice = self.layout.theta_slice(block.key)
            output[theta_slice] += block.regularization.apply(theta[theta_slice])
        return output

    def regularization_diagonal(self) -> torch.Tensor:
        """Return the summed diagonal preconditioner implied by block regularizers."""
        diag = torch.zeros((self.layout.size,), dtype=self.dtype, device=self.device)
        for block in self.layout.blocks:
            if block.regularization is None:
                continue
            theta_slice = self.layout.theta_slice(block.key)
            diag[theta_slice] += block.regularization.diagonal(
                dtype=self.dtype,
                device=self.device,
            )
        return diag

    def regularization_rhs(self) -> torch.Tensor:
        """Return the summed RHS shifts implied by block regularizers."""
        rhs = torch.zeros((self.layout.size,), dtype=self.dtype, device=self.device)
        for block in self.layout.blocks:
            if block.regularization is None:
                continue
            theta_slice = self.layout.theta_slice(block.key)
            rhs[theta_slice] += block.regularization.rhs(
                dtype=self.dtype,
                device=self.device,
            )
        return rhs

    def normal_matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply the regularized normal-equation operator to a flat vector.

        Computes ``rmatvec(matvec(theta)) + regularization_apply(theta)``.
        """
        return self.rmatvec(self.matvec(theta)) + self.regularization_apply(theta)

    def normal_equation_diagonal(self) -> torch.Tensor:
        """Return the diagonal of the weighted design ``A.T @ A``."""
        return _normal_equation_diagonal(self)

    def design_trace_by_block(self) -> dict[object, float]:
        """Return weighted design-matrix trace contributions by solve block.

        Each value is the sum of ``normal_equation_diagonal()`` over the
        parameter slice of the corresponding block key.
        """
        diagonal = self.normal_equation_diagonal()
        return {
            block.key: float(diagonal[self.layout.theta_slice(block.key)].sum().item())
            for block in self.layout.blocks
        }

    def rhs(self) -> torch.Tensor:
        """Return the right-hand side of the normal equations."""
        return self.rmatvec(self.target_vector()) + self.regularization_rhs()

    def objective(self, theta: torch.Tensor) -> torch.Tensor:
        """Evaluate the regularized least-squares objective at ``theta``.

        The value is the squared residual norm plus the ``quadratic`` term of
        every block regularizer evaluated on that block's parameter slice.

        Args:
            theta: Parameter tensor; it is reshaped to ``(layout.size,)``.
        """
        # Flatten first, as ``matvec`` and ``regularization_apply`` do, so the
        # block slices below always index parameters rather than rows of a
        # 2-D input.
        theta = theta.reshape(self.layout.size)
        residual = self.matvec(theta) - self.target_vector()
        value = torch.dot(residual, residual)
        for block in self.layout.blocks:
            if block.regularization is None:
                continue
            theta_slice = self.layout.theta_slice(block.key)
            value = value + block.regularization.quadratic(theta[theta_slice])
        return value

    def accumulate_normal_equations(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Materialize the Gram matrix and right-hand side from all batches.

        Returns:
            A ``(gram, rhs)`` pair with shapes ``(layout.size, layout.size)``
            and ``(layout.size,)``, including regularizer contributions.
        """
        gram = torch.zeros(
            (self.layout.size, self.layout.size),
            dtype=self.dtype,
            device=self.device,
        )
        rhs = torch.zeros((self.layout.size,), dtype=self.dtype, device=self.device)
        for batch in self.batches:
            keys = tuple(batch.matrices)
            for key in keys:
                theta_slice = self.layout.theta_slice(key)
                block_matrix = batch.matrices[key]
                rhs[theta_slice] += _block_matrix_rmatvec(block_matrix, batch.target)
            for index_i, key_i in enumerate(keys):
                slice_i = self.layout.theta_slice(key_i)
                matrix_i = batch.matrices[key_i]
                gram[slice_i, slice_i] += _block_matrix_cross(matrix_i, matrix_i)
                # Each off-diagonal pair is computed once and mirrored.
                for key_j in keys[index_i + 1 :]:
                    slice_j = self.layout.theta_slice(key_j)
                    cross = _block_matrix_cross(matrix_i, batch.matrices[key_j])
                    gram[slice_i, slice_j] += cross
                    gram[slice_j, slice_i] += cross.T
        for block in self.layout.blocks:
            if block.regularization is None:
                continue
            theta_slice = self.layout.theta_slice(block.key)
            gram[theta_slice, theta_slice] += block.regularization.materialize(
                dtype=self.dtype,
                device=self.device,
            )
            rhs[theta_slice] += block.regularization.rhs(
                dtype=self.dtype,
                device=self.device,
            )
        return gram, rhs

    def solve(
        self,
        *,
        solver: str,
        cg_tolerance: float,
        cg_max_iter: int | None,
        progress: bool = False,
        progress_frequency: int = 10,
        initial_theta: torch.Tensor | None = None,
        cg_checkpoint_path: Path | str | None = None,
        cg_checkpoint_frequency: int = 1,
        cg_resume: bool = False,
        cg_checkpoint_metadata: dict[str, object] | None = None,
        fallback_theta: torch.Tensor | None = None,
        return_info: bool = False,
    ) -> torch.Tensor | LinearSolveResult:
        """Solve the assembled problem with the selected dense or iterative backend.

        Args:
            solver: One of ``"dense_lstsq"``, ``"normal_equation_direct"`` or
                ``"cg"``.
            cg_tolerance: Residual-norm tolerance forwarded to the CG solver.
            cg_max_iter: Iteration cap forwarded to the CG solver.
            progress: Print progress messages when true.
            progress_frequency: CG progress print interval in iterations.
            initial_theta: Initial guess for CG; also used as the fallback
                when ``fallback_theta`` is not given.
            cg_checkpoint_path: Where CG restart state is written, if set.
            cg_checkpoint_frequency: CG checkpoint interval in iterations.
            cg_resume: Load ``cg_checkpoint_path`` before starting CG.
            cg_checkpoint_metadata: Validation metadata for checkpoints;
                a default is derived from the problem size and dtype.
            fallback_theta: Coefficients returned if a solve is interrupted.
            return_info: Return a ``LinearSolveResult`` instead of a bare
                tensor. Only in this mode is ``KeyboardInterrupt`` caught
                and turned into an ``interrupted`` result.

        Raises:
            ValueError: If ``solver`` is not a supported name.
        """

        def fallback_result() -> LinearSolveResult:
            # Preference order: explicit fallback, then the initial guess,
            # then zeros.
            if fallback_theta is not None:
                theta = fallback_theta.to(dtype=self.dtype, device=self.device)
            elif initial_theta is not None:
                theta = initial_theta.to(dtype=self.dtype, device=self.device)
            else:
                theta = torch.zeros(
                    (self.layout.size,),
                    dtype=self.dtype,
                    device=self.device,
                )
            return LinearSolveResult(theta=theta, interrupted=True)

        def maybe_return(
            result: LinearSolveResult,
        ) -> torch.Tensor | LinearSolveResult:
            return result if return_info else result.theta

        if solver == "dense_lstsq":
            try:
                matrix, target = self.materialize_design_matrix()
                reg_rows = []
                reg_targets = []
                for block in self.layout.blocks:
                    if block.regularization is None:
                        continue
                    block_rows, block_target = block.regularization.least_squares_rows(
                        dtype=self.dtype,
                        device=self.device,
                    )
                    if block_rows.shape[0] == 0:
                        continue
                    # Embed the block-local rows into full-width rows.
                    row_block = torch.zeros(
                        (block_rows.shape[0], self.layout.size),
                        dtype=self.dtype,
                        device=self.device,
                    )
                    row_block[:, self.layout.theta_slice(block.key)] = block_rows
                    reg_rows.append(row_block)
                    reg_targets.append(block_target)
                if reg_rows:
                    matrix = torch.cat([matrix, *reg_rows], dim=0)
                    target = torch.cat([target, *reg_targets], dim=0)
                return maybe_return(
                    LinearSolveResult(torch.linalg.lstsq(matrix, target).solution)
                )
            except KeyboardInterrupt:
                if not return_info:
                    raise
                if progress:
                    print(
                        "Interrupted dense least-squares solve; "
                        "using fallback coefficients."
                    )
                return fallback_result()

        if solver == "normal_equation_direct":
            try:
                gram, rhs = self.accumulate_normal_equations()
                if progress:
                    print("Solving normal equations directly...")
                try:
                    theta = torch.linalg.solve(gram, rhs)
                except RuntimeError:
                    if progress:
                        print(
                            "Direct solve failed; falling back to torch.linalg.lstsq."
                        )
                    theta = torch.linalg.lstsq(gram, rhs).solution
                return maybe_return(LinearSolveResult(theta))
            except KeyboardInterrupt:
                if not return_info:
                    raise
                if progress:
                    print(
                        "Interrupted normal-equation solve; "
                        "using fallback coefficients."
                    )
                return fallback_result()

        if solver == "cg":
            try:
                rhs = self.rhs()
                if cg_checkpoint_metadata is None:
                    checkpoint_metadata = _cg_checkpoint_metadata(
                        n_parameters=self.layout.size,
                        dtype=self.dtype,
                    )
                else:
                    checkpoint_metadata = dict(cg_checkpoint_metadata)
                if cg_resume and cg_checkpoint_path is not None:
                    checkpoint_state = load_cg_checkpoint(
                        cg_checkpoint_path,
                        dtype=self.dtype,
                        device=self.device,
                        expected_metadata=checkpoint_metadata,
                    )
                else:
                    checkpoint_state = None
                result = _conjugate_gradient(
                    self.normal_matvec,
                    rhs,
                    diagonal_preconditioner=self.regularization_diagonal()
                    + self.normal_equation_diagonal(),
                    tolerance=cg_tolerance,
                    max_iter=cg_max_iter,
                    progress=progress,
                    progress_frequency=progress_frequency,
                    initial_guess=initial_theta,
                    checkpoint_state=checkpoint_state,
                    checkpoint_path=cg_checkpoint_path,
                    checkpoint_frequency=cg_checkpoint_frequency,
                    checkpoint_metadata=checkpoint_metadata,
                    handle_interrupts=return_info,
                )
                return maybe_return(result)
            except KeyboardInterrupt:
                if not return_info:
                    raise
                if progress:
                    print("Interrupted CG setup; using fallback coefficients.")
                return fallback_result()

        choices = ", ".join(["dense_lstsq", "normal_equation_direct", "cg"])
        raise ValueError(f"Unsupported solver '{solver}'. Expected one of: {choices}.")
def _normal_equation_diagonal(problem: BlockLinearProblem) -> torch.Tensor:
"""Return the diagonal of ``A^T A`` for simple preconditioning."""
diagonal = torch.zeros(
(problem.layout.size,),
dtype=problem.dtype,
device=problem.device,
)
for batch in problem.batches:
for key, block_matrix in batch.matrices.items():
diagonal[problem.layout.theta_slice(key)] += _block_matrix_diagonal(
block_matrix
)
return diagonal
def _conjugate_gradient(
    matvec,
    rhs: torch.Tensor,
    *,
    diagonal_preconditioner: torch.Tensor,
    tolerance: float,
    max_iter: int | None,
    progress: bool,
    progress_frequency: int,
    initial_guess: torch.Tensor | None = None,
    checkpoint_state: CGCheckpointState | None = None,
    checkpoint_path: Path | str | None = None,
    checkpoint_frequency: int = 1,
    checkpoint_metadata: dict[str, object] | None = None,
    handle_interrupts: bool = False,
) -> LinearSolveResult:
    """Solve a system with diagonally preconditioned CG.

    Args:
        matvec: Callable applying the system operator to a vector.
        rhs: Right-hand side; its shape, dtype and device define the solve.
        diagonal_preconditioner: Diagonal used as the preconditioner; entries
            with magnitude at most ``1e-14`` are replaced by one.
        tolerance: Absolute threshold on the 2-norm of the residual.
        max_iter: Iteration cap; ``None`` means ``max(10, 4 * rhs.numel())``.
        progress: Print progress messages when true.
        progress_frequency: Print interval in iterations (at least 1).
        initial_guess: Starting iterate; zeros when omitted. Ignored when
            ``checkpoint_state`` is given.
        checkpoint_state: Restart state to resume from.
        checkpoint_path: Destination for checkpoints; none are written when
            this is ``None``.
        checkpoint_frequency: Checkpoint interval in iterations (at least 1).
        checkpoint_metadata: Metadata stored with, and validated against,
            checkpoints.
        handle_interrupts: Catch ``KeyboardInterrupt`` and return the last
            committed iterate instead of propagating.

    Returns:
        A ``LinearSolveResult`` with the final iterate and interrupt metadata.
    """
    if max_iter is None:
        max_iter = max(10, 4 * rhs.numel())
    progress_frequency = max(int(progress_frequency), 1)
    # Guard against division by (near-)zero preconditioner entries.
    safe_diag = torch.where(
        torch.abs(diagonal_preconditioner) > 1.0e-14,
        diagonal_preconditioner,
        torch.ones_like(diagonal_preconditioner),
    )
    if checkpoint_state is not None:
        x = checkpoint_state.x.to(dtype=rhs.dtype, device=rhs.device).reshape_as(rhs)
        residual = checkpoint_state.residual.to(
            dtype=rhs.dtype,
            device=rhs.device,
        ).reshape_as(rhs)
        direction = checkpoint_state.direction.to(
            dtype=rhs.dtype,
            device=rhs.device,
        ).reshape_as(rhs)
        rz_old = checkpoint_state.rz_old.to(dtype=rhs.dtype, device=rhs.device)
        start_iteration = int(checkpoint_state.iteration)
    else:
        x = (
            torch.zeros_like(rhs)
            if initial_guess is None
            else initial_guess.to(dtype=rhs.dtype, device=rhs.device).reshape_as(rhs)
        )
        residual = rhs - matvec(x)
        z = residual / safe_diag
        direction = z.clone()
        rz_old = torch.dot(residual, z)
        start_iteration = 0

    # Last mutually consistent (x, residual, direction, rz_old, iteration).
    # It is replaced by a single assignment, so an interrupt arriving in the
    # middle of an iteration can never expose a half-updated state to the
    # checkpoint writer.
    committed = (x, residual, direction, rz_old, start_iteration)

    def write_checkpoint() -> None:
        if checkpoint_path is None:
            return
        saved_x, saved_residual, saved_direction, saved_rz, saved_iteration = committed
        save_cg_checkpoint(
            checkpoint_path,
            CGCheckpointState(
                x=saved_x,
                residual=saved_residual,
                direction=saved_direction,
                rz_old=saved_rz,
                iteration=int(saved_iteration),
                metadata={} if checkpoint_metadata is None else checkpoint_metadata,
            ),
        )

    residual_norm = torch.linalg.norm(residual).item()
    if progress:
        resume_suffix = (
            "" if start_iteration == 0 else f", resumed_at={start_iteration}"
        )
        print(
            "CG start: "
            f"||residual||={residual_norm:.6e}, "
            f"tolerance={tolerance:.3e}, max_iter={max_iter}{resume_suffix}"
        )
    if residual_norm <= tolerance:
        if progress:
            print("CG converged without iterations.")
        write_checkpoint()
        return LinearSolveResult(theta=x)

    checkpoint_frequency = max(int(checkpoint_frequency), 1)
    completed_iteration = start_iteration
    interrupted = False
    restored_checkpoint_path = None
    try:
        for iteration in range(start_iteration, max_iter):
            mat_direction = matvec(direction)
            denom = torch.dot(direction, mat_direction)
            if torch.abs(denom) <= 1.0e-30:
                if progress:
                    print(
                        f"CG stopped at iteration {iteration + 1}: "
                        "near-zero denominator."
                    )
                break
            alpha = rz_old / denom
            x = x + alpha * direction
            residual = residual - alpha * mat_direction
            residual_norm = torch.linalg.norm(residual).item()
            completed_iteration = iteration + 1
            if progress and (
                iteration == 0
                or completed_iteration % progress_frequency == 0
                or residual_norm <= tolerance
            ):
                print(
                    f"CG iter {completed_iteration}: "
                    f"||residual||={residual_norm:.6e}"
                )
            if residual_norm <= tolerance:
                committed = (x, residual, direction, rz_old, completed_iteration)
                write_checkpoint()
                break
            z = residual / safe_diag
            rz_new = torch.dot(residual, z)
            # ``beta`` divides by ``rz_old``; stop rather than divide by ~0.
            if torch.abs(rz_old) <= 1.0e-30:
                if progress:
                    print(
                        f"CG stopped at iteration {iteration + 1}: "
                        "near-zero rz_old."
                    )
                committed = (x, residual, direction, rz_old, completed_iteration)
                break
            beta = rz_new / rz_old
            direction = z + beta * direction
            rz_old = rz_new
            committed = (x, residual, direction, rz_old, completed_iteration)
            if completed_iteration % checkpoint_frequency == 0:
                write_checkpoint()
    except KeyboardInterrupt:
        if not handle_interrupts:
            raise
        interrupted = True
        # Discard any partially applied update and fall back to the last
        # committed state, so the saved checkpoint is safe to resume from.
        x, residual, direction, rz_old, completed_iteration = committed
        if progress:
            print(
                "Interrupted CG solve; saving the latest complete iterate "
                f"at iteration {completed_iteration}."
            )
        write_checkpoint()
        if checkpoint_path is not None:
            checkpoint = load_cg_checkpoint(
                checkpoint_path,
                dtype=rhs.dtype,
                device=rhs.device,
                expected_metadata=checkpoint_metadata,
            )
            x = checkpoint.x.reshape_as(rhs)
            restored_checkpoint_path = str(Path(checkpoint_path))
    if progress:
        print(f"CG done: ||residual||={torch.linalg.norm(residual).item():.6e}")
    write_checkpoint()
    return LinearSolveResult(
        theta=x,
        interrupted=interrupted,
        restored_checkpoint_path=restored_checkpoint_path,
    )
@dataclass(frozen=True)
class LinearSolveResult:
    """Outcome of a single linear solve.

    Bundles the solved parameter vector with metadata describing whether the
    solve was interrupted and, if so, which checkpoint file the returned
    iterate was read back from.
    """

    # Solved (or fallback) parameter vector.
    theta: torch.Tensor
    # True when the solve stopped early because of an interrupt.
    interrupted: bool = False
    # Path of the checkpoint the iterate was restored from, if any.
    restored_checkpoint_path: str | None = None
@dataclass(frozen=True)
class LinearFitResult:
    """Record of one direct linear fit and the problem it was solved on.

    Holds the fitted coefficients, summary numbers for the fit, the layout
    and assembled problem objects, and the interrupt metadata of the solve.
    """

    # Fitted parameter vector.
    theta: torch.Tensor
    # Objective value reported for the fit.
    objective: float
    # Residual norm reported for the fit.
    residual_norm: float
    # Name of the solver backend that produced ``theta``.
    solver: str
    # Number of rows and parameters of the fitted problem.
    n_rows: int
    n_parameters: int
    # Parameter layout and assembled problem associated with the fit.
    layout: ParameterLayout
    problem: BlockLinearProblem
    # Interrupt metadata, with the same defaults as ``LinearSolveResult``.
    interrupted: bool = False
    restored_checkpoint_path: str | None = None
@dataclass(frozen=True)
class CGCheckpointState:
    """Everything needed to restart one conjugate-gradient solve.

    The tensors mirror the solver's working variables, ``iteration`` is the
    number of completed iterations, and ``metadata`` is stored alongside them
    so a checkpoint can be validated against the problem that loads it.
    """

    # Current iterate.
    x: torch.Tensor
    # Residual associated with ``x``.
    residual: torch.Tensor
    # Current search direction.
    direction: torch.Tensor
    # Scalar residual/preconditioned-residual product carried between iterations.
    rz_old: torch.Tensor
    # Number of completed iterations.
    iteration: int
    # Validation metadata persisted with the checkpoint.
    metadata: dict[str, object]
def _cg_checkpoint_metadata(
*,
n_parameters: int,
dtype: torch.dtype,
) -> dict[str, object]:
"""Return validation metadata for a CG checkpoint."""
return {
"schema_version": 1,
"n_parameters": int(n_parameters),
"dtype": str(dtype),
}
def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
"""Convert one tensor to a detached CPU numpy array."""
return tensor.detach().cpu().numpy()
def save_cg_checkpoint(
    path: Path | str,
    state: CGCheckpointState,
) -> None:
    """Write a conjugate-gradient restart state to ``path`` in ``.npz`` form.

    The parent directory is created if needed. The archive is first written
    to a sibling ``<name>.tmp`` file and then moved over ``path``, so readers
    never observe a partially written checkpoint at the final location.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    staging = destination.with_name(f"{destination.name}.tmp")
    with staging.open("wb") as handle:
        arrays = {
            "x": _tensor_to_numpy(state.x),
            "residual": _tensor_to_numpy(state.residual),
            "direction": _tensor_to_numpy(state.direction),
            # Stored as a 0-d array.
            "rz_old": _tensor_to_numpy(state.rz_old.reshape(())),
            "iteration": np.asarray([int(state.iteration)], dtype=np.int64),
            # Metadata is serialized as a JSON string with sorted keys.
            "metadata": np.asarray(json.dumps(state.metadata, sort_keys=True)),
        }
        np.savez(handle, **arrays)
    staging.replace(destination)
def load_cg_checkpoint(
    path: Path | str,
    *,
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
    expected_metadata: dict[str, object] | None = None,
) -> CGCheckpointState:
    """Load and validate one conjugate-gradient restart state.

    Args:
        path: Location of the ``.npz`` checkpoint.
        dtype: Optional dtype to convert the loaded tensors to.
        device: Optional device to move the loaded tensors to.
        expected_metadata: When given, the stored metadata must equal it.

    Returns:
        The restored ``CGCheckpointState``.

    Raises:
        ValueError: If the stored metadata differs from ``expected_metadata``.
    """
    checkpoint_path = Path(path)
    with np.load(checkpoint_path) as data:
        metadata = json.loads(str(data["metadata"].item()))
        if expected_metadata is not None:
            # Stored metadata has been through JSON, which turns tuples into
            # lists and non-string keys into strings. Normalize the expected
            # value the same way so equivalent metadata compares equal.
            try:
                normalized_expected = json.loads(
                    json.dumps(expected_metadata, sort_keys=True)
                )
            except (TypeError, ValueError):
                # Not JSON-serializable: compare as-is and report a mismatch.
                normalized_expected = expected_metadata
            if metadata != normalized_expected:
                raise ValueError("CG checkpoint metadata does not match this problem")
        x = torch.as_tensor(data["x"])
        residual = torch.as_tensor(data["residual"])
        direction = torch.as_tensor(data["direction"])
        rz_old = torch.as_tensor(data["rz_old"]).reshape(())
        iteration = int(np.asarray(data["iteration"]).reshape(-1)[0])
    if dtype is not None or device is not None:
        x = x.to(dtype=dtype, device=device)
        residual = residual.to(dtype=dtype, device=device)
        direction = direction.to(dtype=dtype, device=device)
        rz_old = rz_old.to(dtype=dtype, device=device)
    return CGCheckpointState(
        x=x,
        residual=residual,
        direction=direction,
        rz_old=rz_old,
        iteration=iteration,
        metadata=metadata,
    )
# Public names exported by a star-import of this module.
__all__ = [
    "BlockLinearProblem",
    "CGCheckpointState",
    "LinearFitResult",
    "LinearSolveResult",
    "load_cg_checkpoint",
    "save_cg_checkpoint",
]