"""Disk cache metadata, memmap IO, and streamed cached problem support."""
from __future__ import annotations
import hashlib
import json
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Callable, Literal, Sequence
import numpy as np
import torch
from ufp.core._disk_cache import (
atomic_write_json,
atomic_write_npy_memmap,
settings_cache_dir,
write_cache_settings_summary,
)
from ufp.leastsquares._block import (
BlockMatrix,
BlockProblemLayout,
BlockSolveBatch,
ColumnRowIndexedBlockMatrix,
ColumnRowIndexedChunk,
RowIndexedBlockMatrix,
_apply_row_weights_to_assembled_batch,
_block_matrix_cross,
_block_matrix_diagonal,
_block_matrix_matvec,
_block_matrix_rmatvec,
_block_solve_batch_from_assembled,
_compact_block_matrix,
_compact_column_chunked_block_matrix,
_materialize_block_matrix,
)
from ufp.leastsquares._layout import ParameterLayout
from ufp.leastsquares._problem import (
LinearSolveResult,
_cg_checkpoint_metadata,
_conjugate_gradient,
load_cg_checkpoint,
)
from ufp.leastsquares._types import AssembledBatch
from ufp.leastsquares._utils import _iter_with_progress
from ufp.leastsquares.dataset import FitSample
# Version tag written into every assembled-batch manifest (per-batch and
# top-level); loaders reject manifests whose stored "schema_version" differs.
_ASSEMBLED_CACHE_SCHEMA_VERSION = 7
# Version tag written into ``normal_equations.json``; the loader rejects
# manifests whose stored "schema_version" differs.
_NORMAL_EQUATION_CACHE_SCHEMA_VERSION = 3
# Accepted assembled-batch cache modes. Their exact semantics are implemented
# by callers outside this section (presumably reuse / read-only / write-only /
# forced rebuild — TODO confirm against the solver entry points).
AssembledBatchCacheMode = Literal["auto", "read", "write", "refresh"]
def _hash_array(hasher, value: object, *, dtype: np.dtype | None = None) -> None:
    """Feed the dtype name, shape, and raw bytes of ``value`` into ``hasher``.

    The dtype and shape are hashed ahead of the payload so arrays whose bytes
    coincide but whose dtype or shape differ still produce different digests.

    Args:
        hasher: Any object with an ``update(bytes)`` method (e.g. hashlib).
        value: Array-like data to fingerprint.
        dtype: Optional dtype the data is coerced to before hashing.
    """
    contiguous = np.ascontiguousarray(np.asarray(value, dtype=dtype))
    dtype_tag = str(contiguous.dtype).encode("utf8")
    shape_tag = np.asarray(contiguous.shape, dtype=np.int64).tobytes()
    for piece in (dtype_tag, shape_tag, contiguous.tobytes()):
        hasher.update(piece)
def _to_numpy_for_hash(value: object) -> np.ndarray:
    """Return ``value`` as a host-side NumPy array ready for hashing.

    Torch tensors are detached from autograd and copied to the CPU before the
    conversion; every other input is passed straight to ``np.asarray``.
    """
    if not isinstance(value, torch.Tensor):
        return np.asarray(value)
    host_tensor = value.detach().cpu()
    return host_tensor.numpy()
def _sample_signature(
    samples: Sequence[FitSample],
    *,
    include_weights: bool = True,
) -> str:
    """Return a stable digest for input structures, targets, and weights.

    For each sample, in iteration order, the hash receives: a ``b"sample"``
    marker; ``atoms.numbers``, ``atoms.positions``, ``atoms.cell.array`` and
    ``atoms.pbc`` coerced to fixed dtypes; the neighbor list (or a
    ``neighbor_list:none`` marker); a presence flag plus values for energy,
    forces, and per-atom energy; and, if ``include_weights`` is true, the three
    target weights. The digest therefore depends on sample order and on the
    exact order of the ``hasher.update`` calls below. Reordering them changes
    the digest and so invalidates any cache metadata that embeds it.

    Args:
        samples: Fit samples to fingerprint.
        include_weights: Whether ``energy_weight``, ``force_weight`` and
            ``per_atom_weight`` contribute to the digest.

    Returns:
        Hex-encoded SHA-256 digest.
    """
    hasher = hashlib.sha256()
    for sample in samples:
        atoms = sample.atoms
        # Per-sample marker, then the structure, each coerced to a fixed dtype
        # so equal values always hash to identical bytes.
        hasher.update(b"sample")
        _hash_array(hasher, atoms.numbers, dtype=np.int64)
        _hash_array(hasher, atoms.positions, dtype=np.float64)
        _hash_array(hasher, atoms.cell.array, dtype=np.float64)
        _hash_array(hasher, atoms.pbc, dtype=np.bool_)
        if sample.neighbor_list is None:
            hasher.update(b"neighbor_list:none")
        else:
            neighbor_list = sample.neighbor_list
            hasher.update(b"neighbor_list")
            # pairs/shifts are hashed in their native dtype (no coercion).
            _hash_array(hasher, _to_numpy_for_hash(neighbor_list.pairs))
            _hash_array(hasher, _to_numpy_for_hash(neighbor_list.shifts))
            # NOTE(review): distances/vectors are hashed only when present and
            # without a presence marker; only the dtype/shape prefix written by
            # _hash_array distinguishes which of the two was hashed.
            if neighbor_list.distances is not None:
                _hash_array(
                    hasher,
                    _to_numpy_for_hash(neighbor_list.distances),
                    dtype=np.float64,
                )
            if neighbor_list.vectors is not None:
                _hash_array(
                    hasher,
                    _to_numpy_for_hash(neighbor_list.vectors),
                    dtype=np.float64,
                )
            hasher.update(str(bool(neighbor_list.full_list)).encode("utf8"))
        # Each optional target writes a "True"/"False" presence flag before its
        # values, so absent and present targets hash differently.
        hasher.update(str(sample.energy is not None).encode("utf8"))
        if sample.energy is not None:
            _hash_array(hasher, [sample.energy], dtype=np.float64)
        hasher.update(str(sample.forces is not None).encode("utf8"))
        if sample.forces is not None:
            _hash_array(hasher, sample.forces, dtype=np.float64)
        hasher.update(str(sample.per_atom_energy is not None).encode("utf8"))
        if sample.per_atom_energy is not None:
            _hash_array(hasher, sample.per_atom_energy, dtype=np.float64)
        # Weight-independent caches (include_weights=False) skip this so that
        # reweighting alone does not change the digest.
        if include_weights:
            _hash_array(
                hasher,
                [sample.energy_weight, sample.force_weight, sample.per_atom_weight],
                dtype=np.float64,
            )
    return hasher.hexdigest()
def _layout_signature(layout: ParameterLayout) -> list[dict[str, object]]:
    """Describe each parameter block of ``layout`` as a JSON-friendly dict.

    Shapes become lists of ints, coefficient providers are summarized by their
    shape and term counts, and non-string assemblers are recorded by class
    name. This lets the result be compared against metadata that was
    round-tripped through JSON.
    """

    def summarize_provider(provider) -> dict[str, object] | None:
        if provider is None:
            return None
        return {
            "coefficient_shape": [int(dim) for dim in provider.coefficient_shape],
            "n_proxy_terms": int(provider.n_proxy_terms),
            "n_true_terms": int(provider.n_true_terms),
            "uses_identity_weights": bool(provider.uses_identity_weights),
        }

    def assembler_label(assembler) -> object:
        return assembler if isinstance(assembler, str) else type(assembler).__name__

    def describe(block) -> dict[str, object]:
        # The provider is summarized before the other fields, as before.
        provider_summary = summarize_provider(block.coefficient_provider)
        return {
            "index": int(block.index),
            "name": block.name,
            "kind": block.kind,
            "label": block.label,
            "shape": [int(dim) for dim in block.shape],
            "start": int(block.start),
            "stop": int(block.stop),
            "coefficient_index": block.coefficient_index,
            "regularization_group": block.regularization_group,
            "assembler": assembler_label(block.assembler),
            "provider": provider_summary,
        }

    return [describe(block) for block in layout.blocks]
def _cache_metadata_for_fit(
    *,
    layout: ParameterLayout,
    samples: Sequence[FitSample],
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
    dtype: torch.dtype,
    batch_size: int,
    include_sample_weights: bool = True,
) -> dict[str, object]:
    """Build the metadata dict used to detect stale assembled batches.

    The payload combines the sample digest, the parameter-layout description,
    the enabled targets, the dtype name, and the batch size. Keys are inserted
    in a fixed order so serialized manifests stay stable.

    Args:
        layout: Parameter layout of the model being fitted.
        samples: Fit samples that feed the cache.
        fit_energy: Whether energy rows are fitted.
        fit_forces: Whether force rows are fitted.
        fit_per_atom_energy: Whether per-atom energy rows are fitted.
        dtype: Assembly dtype, recorded by its string name.
        batch_size: Number of samples per assembled batch.
        include_sample_weights: Whether sample weights enter the digest.
    """
    metadata: dict[str, object] = {"sample_count": len(samples)}
    metadata["sample_signature"] = _sample_signature(
        samples,
        include_weights=include_sample_weights,
    )
    metadata["layout_size"] = int(layout.size)
    metadata["layout"] = _layout_signature(layout)
    metadata.update(
        fit_energy=bool(fit_energy),
        fit_forces=bool(fit_forces),
        fit_per_atom_energy=bool(fit_per_atom_energy),
        dtype=str(dtype),
        batch_size=int(batch_size),
    )
    return metadata
def _cache_identity_metadata(
    metadata: dict[str, object] | None,
) -> dict[str, object]:
    """Return the part of ``metadata`` that defines cache identity.

    Every key currently participates, so this is a shallow copy (``{}`` for
    ``None``). The copy keeps callers from aliasing the input dict.
    """
    return {} if metadata is None else dict(metadata)
def _cache_metadata_matches(
    cached: dict[str, object] | None,
    expected: dict[str, object],
) -> bool:
    """Report whether stored cache metadata is identical to this fit's metadata.

    Both sides are normalized through ``_cache_identity_metadata`` first, so a
    ``None`` payload is compared as that helper's result for ``None``.
    """
    stored = _cache_identity_metadata(cached)
    wanted = _cache_identity_metadata(expected)
    return stored == wanted
def _cache_problem_metadata_matches(
    cached: dict[str, object] | None,
    expected: dict[str, object],
) -> bool:
    """Compare cache metadata while ignoring every layout-specific key.

    Two payloads match when they agree on everything except the parameter
    layout, block selection, fixed-coefficient, and cache-block bookkeeping.
    This is what allows rows to be reused across layouts. A missing cached
    payload never matches.
    """
    if cached is None:
        return False
    layout_keys = frozenset(
        (
            "layout",
            "layout_size",
            "selected_block_indices",
            "coefficient_selection",
            "fixed_coefficients_signature",
            "cache_blocks",
            "cache_reusable",
        )
    )

    def strip_layout(payload: dict[str, object]) -> dict[str, object]:
        return {key: val for key, val in payload.items() if key not in layout_keys}

    return strip_layout(cached) == strip_layout(expected)
def _cache_metadata_can_project(
    cached: dict[str, object] | None,
    expected: dict[str, object],
) -> bool:
    """Decide whether cached semantic blocks may be projected onto this layout.

    Projection requires that both payloads are flagged ``cache_reusable``, that
    both describe their ``cache_blocks``, and that they agree on all
    layout-independent problem metadata.
    """
    if cached is None:
        return False
    reusable = bool(cached.get("cache_reusable")) and bool(
        expected.get("cache_reusable")
    )
    if not reusable:
        return False
    describes_blocks = all(
        "cache_blocks" in payload for payload in (cached, expected)
    )
    return describes_blocks and _cache_problem_metadata_matches(cached, expected)
def _cache_metadata_mismatch_reasons(
    cached: dict[str, object] | None,
    expected: dict[str, object],
) -> tuple[str, ...]:
    """List, in sorted key order, the metadata fields that differ.

    Known keys are reported with a human-readable label; unknown keys are
    reported verbatim. A missing cached payload yields ``("missing metadata",)``.
    """
    if cached is None:
        return ("missing metadata",)
    readable = {
        "sample_count": "sample count",
        "sample_signature": "sample geometry or targets",
        "layout_size": "parameter count",
        "layout": "model parameter layout",
        "cache_blocks": "semantic cache block layout",
        "fit_energy": "energy target setting",
        "fit_forces": "force target setting",
        "fit_per_atom_energy": "per-atom energy target setting",
        "dtype": "dtype",
        "batch_size": "batch size",
    }
    old = _cache_identity_metadata(cached)
    new = _cache_identity_metadata(expected)
    every_key = sorted(set(old) | set(new))
    return tuple(
        readable.get(key, key) for key in every_key if old.get(key) != new.get(key)
    )
def _assembled_cache_metadata_for_fit(
    *,
    layout: ParameterLayout,
    samples: Sequence[FitSample],
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
    dtype: torch.dtype,
    batch_size: int,
) -> dict[str, object]:
    """Build assembled-batch cache metadata that ignores target weights.

    Sample weights are excluded from the digest, so reweighting alone reuses
    the same cache. An ``assembled_components`` tag records the on-disk
    component layout version.
    """
    fit_settings = {
        "layout": layout,
        "samples": samples,
        "fit_energy": fit_energy,
        "fit_forces": fit_forces,
        "fit_per_atom_energy": fit_per_atom_energy,
        "dtype": dtype,
        "batch_size": batch_size,
    }
    metadata = _cache_metadata_for_fit(**fit_settings, include_sample_weights=False)
    metadata["assembled_components"] = "column_chunked_per_atom_energy_rows_v1"
    return metadata
def _normal_equation_cache_metadata_for_fit(
    *,
    layout: ParameterLayout,
    samples: Sequence[FitSample],
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
    dtype: torch.dtype,
    batch_size: int,
) -> dict[str, object]:
    """Build cache metadata for split, unweighted normal-equation components.

    Like the assembled-batch metadata, sample weights are left out of the
    digest. A ``normal_equation_components`` tag records the split format.
    """
    fit_settings = {
        "layout": layout,
        "samples": samples,
        "fit_energy": fit_energy,
        "fit_forces": fit_forces,
        "fit_per_atom_energy": fit_per_atom_energy,
        "dtype": dtype,
        "batch_size": batch_size,
    }
    metadata = _cache_metadata_for_fit(**fit_settings, include_sample_weights=False)
    metadata["normal_equation_components"] = "per_atom_energy_force_split_v1"
    return metadata
def assembled_cache_dir(
    parent: Path | str,
    metadata: dict[str, object],
) -> Path:
    """Return the settings-named child directory for assembled batches.

    The directory name is derived by ``settings_cache_dir`` from the
    ``"assembled"`` prefix and the cache-identity view of ``metadata``.
    """
    identity = _cache_identity_metadata(metadata)
    return settings_cache_dir(parent, "assembled", identity)
def normal_equation_cache_dir(
    parent: Path | str,
    metadata: dict[str, object],
) -> Path:
    """Return the settings-named child directory for normal equations.

    The directory name is derived by ``settings_cache_dir`` from the
    ``"normal_equations"`` prefix and the cache-identity view of ``metadata``.
    """
    identity = _cache_identity_metadata(metadata)
    return settings_cache_dir(parent, "normal_equations", identity)
def _uniform_target_weight(
    samples: Sequence[FitSample],
    *,
    target_name: str,
    value_name: str,
    weight_name: str,
    enabled: bool,
) -> float:
    """Return the single weight shared by all samples that carry a target.

    Only samples whose ``value_name`` attribute is not ``None`` are considered.
    A disabled target, or one that no sample carries, yields ``0.0``.

    Raises:
        ValueError: If the active weights are not all equal (to 1e-12).
    """
    if not enabled:
        return 0.0
    active_weights = [
        float(getattr(sample, weight_name))
        for sample in samples
        if getattr(sample, value_name) is not None
    ]
    if not active_weights:
        return 0.0
    reference = active_weights[0]
    # The reference itself is included so a NaN weight is still rejected.
    uniform = all(
        np.isclose(candidate, reference, rtol=1.0e-12, atol=1.0e-12)
        for candidate in active_weights
    )
    if not uniform:
        raise ValueError(
            "normal-equation caching requires uniform "
            f"{target_name} weights across active samples"
        )
    return reference
def _normal_equation_target_weights(
    samples: Sequence[FitSample],
    *,
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
) -> tuple[float, float]:
    """Return the global ``(energy_weight, force_weight)`` for cached normal equations.

    Raises:
        ValueError: If per-atom energy targets are requested and present, or if
            energy/force weights are not uniform across the active samples.
    """
    per_atom_active = fit_per_atom_energy and any(
        sample.per_atom_energy is not None for sample in samples
    )
    if per_atom_active:
        raise ValueError(
            "normal-equation caching currently supports energy and force targets only"
        )
    energy_weight = _uniform_target_weight(
        samples,
        target_name="energy",
        value_name="energy",
        weight_name="energy_weight",
        enabled=fit_energy,
    )
    force_weight = _uniform_target_weight(
        samples,
        target_name="force",
        value_name="forces",
        weight_name="force_weight",
        enabled=fit_forces,
    )
    return energy_weight, force_weight
def _samples_with_unit_target_weights(
    samples: Sequence[FitSample],
) -> tuple[FitSample, ...]:
    """Copy every sample with energy, force, and per-atom weights reset to 1.0.

    Target values are untouched; new dataclass instances are created through
    ``dataclasses.replace`` so the inputs are not mutated.
    """
    unit_weights = {
        "energy_weight": 1.0,
        "force_weight": 1.0,
        "per_atom_weight": 1.0,
    }
    reweighted = [replace(sample, **unit_weights) for sample in samples]
    return tuple(reweighted)
@dataclass
class NormalEquationComponents:
    """Unweighted normal-equation pieces kept separate for energy and force rows.

    Keeping the two row groups apart presumably lets global energy/force weights
    be applied after loading rather than baked into the cache (TODO confirm at
    the solve site).
    """

    # Gram matrix accumulated over energy rows (cached as ``energy_ata.npy``).
    energy_gram: torch.Tensor
    # Right-hand side accumulated over energy rows (``energy_atb.npy``).
    energy_rhs: torch.Tensor
    # Gram matrix accumulated over force rows (``force_ata.npy``).
    force_gram: torch.Tensor
    # Right-hand side accumulated over force rows (``force_atb.npy``).
    force_rhs: torch.Tensor
    # Scalar tensor; persisted in the manifest via ``.item()``.
    energy_target_norm: torch.Tensor
    # Scalar tensor; persisted in the manifest via ``.item()``.
    force_target_norm: torch.Tensor
    # Number of energy target rows folded into the energy components.
    n_energy_rows: int
    # Number of force target rows folded into the force components.
    n_force_rows: int

    @property
    def n_rows(self) -> int:
        """Total number of cached target rows (energy rows plus force rows)."""
        total = self.n_energy_rows + self.n_force_rows
        return int(total)
def _normal_equation_components_to(
    components: NormalEquationComponents,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> NormalEquationComponents:
    """Return a copy of ``components`` with every tensor on ``dtype``/``device``.

    Row counts are carried over unchanged.
    """
    tensor_fields = (
        "energy_gram",
        "energy_rhs",
        "force_gram",
        "force_rhs",
        "energy_target_norm",
        "force_target_norm",
    )
    converted = {
        name: getattr(components, name).to(dtype=dtype, device=device)
        for name in tensor_fields
    }
    return NormalEquationComponents(
        **converted,
        n_energy_rows=components.n_energy_rows,
        n_force_rows=components.n_force_rows,
    )
def _write_npy_memmap(path: Path, tensor: torch.Tensor) -> None:
    """Save ``tensor`` to ``path`` as a memory-mappable ``.npy`` file.

    The tensor is detached and copied to host memory. The write itself is
    delegated to ``atomic_write_npy_memmap``.
    """
    host_array = tensor.detach().cpu().numpy()
    atomic_write_npy_memmap(path, host_array)
def save_assembled_batches_memmap(
    directory: Path | str,
    assembled_batches: Sequence[AssembledBatch],
    *,
    manifest_name: str = "assembled_batches.json",
    metadata: dict[str, object] | None = None,
    column_chunk_sizes: dict[object, int] | None = None,
    compact: bool = True,
) -> None:
    """
    Persist assembled least-squares batches as block-separated ``.npy`` files.

    The cache stores each batch target and each term-block matrix separately. This
    keeps the existing block layout visible on disk and avoids materializing one
    global dense design matrix.

    Each batch is written through ``_write_assembled_batch_memmap``, which also
    emits a per-batch ``batch{i}_manifest.json`` once that batch's files exist.
    The top-level manifest ``manifest_name`` is written last.

    Args:
        directory: Directory where cache files should be written. Created if
            missing.
        assembled_batches: Batches returned by ``assemble_true_blocks``.
        manifest_name: JSON manifest filename inside ``directory``.
        metadata: Optional metadata copied into the manifest.
        column_chunk_sizes: Optional column chunk sizes keyed by cache block key.
        compact: Whether dense block matrices should be compacted during writing.
    """
    # (A stray Sphinx ``[docs]`` expression preceded this def; it evaluated the
    # undefined name ``docs`` at import time and has been removed.)
    cache_dir = Path(directory)
    cache_dir.mkdir(parents=True, exist_ok=True)
    manifest_batches = [
        _write_assembled_batch_memmap(
            cache_dir,
            batch_index,
            batch,
            metadata=metadata,
            column_chunk_sizes=column_chunk_sizes,
            compact=compact,
        )
        for batch_index, batch in enumerate(assembled_batches)
    ]
    _write_assembled_batches_manifest(
        cache_dir,
        manifest_batches,
        manifest_name=manifest_name,
        metadata=metadata,
    )
def _write_assembled_batch_memmap(
    cache_dir: Path,
    batch_index: int,
    batch: AssembledBatch,
    *,
    metadata: dict[str, object] | None = None,
    column_chunk_sizes: dict[object, int] | None = None,
    compact: bool = True,
) -> dict[str, object]:
    """Persist one assembled batch and return its manifest entry.

    The target is written first, then every block matrix in ``str(block_key)``
    order. Each block uses the first storage format that applies:

    1. already a ``ColumnRowIndexedBlockMatrix`` -> column-row-indexed chunks;
    2. already a ``RowIndexedBlockMatrix`` -> row-indexed rows/values files;
    3. otherwise it is converted with ``_materialize_block_matrix`` (unless it
       is already a tensor). When ``compact`` is set it is then stored
       column-chunked if a chunk size is configured for it and chunk compaction
       returns a result, or row-indexed if row compaction drops at least one
       row;
    4. as a last resort, one dense ``batch{i}_block{j}.npy`` file.

    The per-batch manifest is written only after all array files. Its
    presence therefore marks the batch as completed for resumable assembly.

    Args:
        cache_dir: Existing cache directory to write into.
        batch_index: Index used in every file name of this batch.
        batch: Assembled batch to persist.
        metadata: Metadata copied into the per-batch manifest.
        column_chunk_sizes: Optional chunk sizes, looked up by the original
            block key first and then by its string form.
        compact: Whether dense blocks may be stored in compacted form.

    Returns:
        ``{"target": <file>, "blocks": {<str key>: <file name or dict>}}``.
    """
    target_name = f"batch{batch_index}_target.npy"
    _write_npy_memmap(cache_dir / target_name, batch.target)
    block_files: dict[str, str | dict[str, object]] = {}
    # Sorting by str(key) makes block numbering (and file names) deterministic.
    for block_number, (block_key, matrix) in enumerate(
        sorted(batch.block_matrices.items(), key=lambda item: str(item[0]))
    ):
        manifest_key = str(block_key)
        block_token = f"block{block_number}"
        # Already-sparse inputs are written in their native format.
        if isinstance(matrix, ColumnRowIndexedBlockMatrix):
            block_files[manifest_key] = _write_column_row_indexed_block(
                cache_dir,
                batch_index,
                block_token,
                matrix,
            )
            continue
        if isinstance(matrix, RowIndexedBlockMatrix):
            block_files[manifest_key] = _write_row_indexed_block(
                cache_dir,
                batch_index,
                block_token,
                matrix,
            )
            continue
        if not isinstance(matrix, torch.Tensor):
            matrix = _materialize_block_matrix(matrix)
        chunk_size = (
            None
            if column_chunk_sizes is None
            else column_chunk_sizes.get(block_key, column_chunk_sizes.get(manifest_key))
        )
        compact_chunks = None
        if compact and chunk_size is not None:
            # May return None (checked below); presumably when chunked storage
            # is not worthwhile for this block — TODO confirm in _block.
            compact_chunks = _compact_column_chunked_block_matrix(
                matrix,
                chunk_size=chunk_size,
            )
        if compact_chunks is not None:
            block_files[manifest_key] = _write_column_row_indexed_block(
                cache_dir,
                batch_index,
                block_token,
                compact_chunks,
            )
            continue
        if compact:
            compact_rows = _compact_block_matrix(matrix)
            # Row-indexed storage is used only if it actually drops rows.
            if compact_rows.rows.numel() < matrix.shape[0]:
                block_files[manifest_key] = _write_row_indexed_block(
                    cache_dir,
                    batch_index,
                    block_token,
                    compact_rows,
                )
                continue
        # Dense fallback.
        matrix_name = f"batch{batch_index}_{block_token}.npy"
        _write_npy_memmap(cache_dir / matrix_name, matrix)
        block_files[manifest_key] = matrix_name
    entry = {
        "target": target_name,
        "blocks": block_files,
    }
    # Written last: marks this batch as completed.
    _write_assembled_batch_manifest(
        cache_dir,
        batch_index,
        entry,
        metadata=metadata,
    )
    return entry
def _write_row_indexed_block(
    cache_dir: Path,
    batch_index: int,
    block_token: str,
    matrix: RowIndexedBlockMatrix,
) -> dict[str, object]:
    """Write the rows (as int64) and values of a row-indexed block.

    Returns:
        Manifest entry with ``storage="row_indexed"``, both file names, and
        the block's full row count.
    """
    stem = f"batch{batch_index}_{block_token}"
    rows_file = f"{stem}_rows.npy"
    values_file = f"{stem}_values.npy"
    _write_npy_memmap(cache_dir / rows_file, matrix.rows.to(dtype=torch.int64))
    _write_npy_memmap(cache_dir / values_file, matrix.values)
    return dict(
        storage="row_indexed",
        rows=rows_file,
        values=values_file,
        n_rows=int(matrix.n_rows),
    )
def _write_column_row_indexed_block(
    cache_dir: Path,
    batch_index: int,
    block_token: str,
    matrix: ColumnRowIndexedBlockMatrix,
) -> dict[str, object]:
    """Write each column chunk of a column-row-indexed block to its own files.

    Returns:
        Manifest entry with ``storage="column_row_indexed"``, one record per
        chunk (column start, width, file names), and the block dimensions.
    """
    stem = f"batch{batch_index}_{block_token}"

    def write_chunk(position: int, chunk) -> dict[str, object]:
        rows_file = f"{stem}_chunk{position}_rows.npy"
        values_file = f"{stem}_chunk{position}_values.npy"
        _write_npy_memmap(cache_dir / rows_file, chunk.rows.to(dtype=torch.int64))
        _write_npy_memmap(cache_dir / values_file, chunk.values)
        return {
            "column_start": int(chunk.column_start),
            "n_columns": int(chunk.values.shape[1]),
            "rows": rows_file,
            "values": values_file,
        }

    chunk_records = [
        write_chunk(position, chunk) for position, chunk in enumerate(matrix.chunks)
    ]
    return {
        "storage": "column_row_indexed",
        "chunks": chunk_records,
        "n_rows": int(matrix.n_rows),
        "n_cols": int(matrix.n_cols),
    }
def _write_assembled_batch_manifest(
    cache_dir: Path,
    batch_index: int,
    entry: dict[str, object],
    *,
    metadata: dict[str, object] | None = None,
) -> None:
    """Atomically write ``batch{batch_index}_manifest.json`` for one finished batch.

    The manifest records the schema version, the batch index, a shallow copy
    of ``entry``, and ``metadata`` (``{}`` when omitted).
    """
    manifest_path = cache_dir / f"batch{batch_index}_manifest.json"
    payload = {
        "schema_version": _ASSEMBLED_CACHE_SCHEMA_VERSION,
        "batch_index": int(batch_index),
        "entry": dict(entry),
        "metadata": metadata if metadata is not None else {},
    }
    atomic_write_json(manifest_path, payload)
def _cached_block_files_exist(
    cache_dir: Path,
    block_entry: object,
    *,
    manifest_path: Path,
) -> bool:
    """Return whether all files referenced by one cached block exist.

    Args:
        cache_dir: Directory the manifest's file names are relative to.
        block_entry: A dense file name (``str``) or a ``row_indexed`` /
            ``column_row_indexed`` storage dict from a batch manifest.
        manifest_path: Manifest the entry came from (used in error messages).

    Returns:
        ``True`` when every referenced file is present. ``False`` when one is
        missing, meaning the batch is incomplete.

    Raises:
        ValueError: If the entry is structurally invalid. This covers an
            unknown storage kind, a non-dict chunk, a missing chunk list, and a
            count that is absent, non-integer, or out of range.
    """
    invalid = f"invalid least-squares batch manifest: {manifest_path}"

    def as_int(source: dict, key: str) -> int:
        # Malformed counts (null, lists, non-numeric strings) previously escaped
        # as a bare TypeError/ValueError with no reference to the manifest;
        # report them uniformly as an invalid manifest instead.
        try:
            return int(source.get(key, -1))
        except (TypeError, ValueError) as exc:
            raise ValueError(invalid) from exc

    def files_present(source: dict) -> bool:
        return all(
            (cache_dir / str(source.get(field))).is_file()
            for field in ("rows", "values")
        )

    if isinstance(block_entry, str):
        return (cache_dir / block_entry).is_file()
    if not isinstance(block_entry, dict):
        raise ValueError(invalid)
    storage = str(block_entry.get("storage"))
    if storage == "row_indexed":
        # Missing files mean "incomplete", checked before count validation.
        if not files_present(block_entry):
            return False
        if as_int(block_entry, "n_rows") < 0:
            raise ValueError(invalid)
        return True
    if storage == "column_row_indexed":
        if as_int(block_entry, "n_rows") < 0 or as_int(block_entry, "n_cols") < 0:
            raise ValueError(invalid)
        chunks = block_entry.get("chunks")
        if not isinstance(chunks, list):
            raise ValueError(invalid)
        for chunk in chunks:
            if not isinstance(chunk, dict):
                raise ValueError(invalid)
            if as_int(chunk, "column_start") < 0 or as_int(chunk, "n_columns") <= 0:
                raise ValueError(invalid)
            if not files_present(chunk):
                return False
        return True
    raise ValueError(invalid)
def _load_assembled_batch_manifest(
    cache_dir: Path,
    batch_index: int,
    *,
    expected_metadata: dict[str, object],
) -> dict[str, object] | None:
    """Return the manifest entry of a completed batch, or ``None`` if incomplete.

    ``None`` means the batch must be (re)assembled: either the per-batch
    manifest is absent, or a file it references is missing.

    Raises:
        ValueError: On schema-version mismatch, wrong batch index, metadata
            mismatch, or a structurally invalid manifest.
    """
    manifest_path = cache_dir / f"batch{batch_index}_manifest.json"
    if not manifest_path.is_file():
        return None
    with manifest_path.open("r", encoding="utf8") as handle:
        manifest = json.load(handle)
    invalid = f"invalid least-squares batch manifest: {manifest_path}"
    if int(manifest.get("schema_version", -1)) != _ASSEMBLED_CACHE_SCHEMA_VERSION:
        raise ValueError(
            "unsupported least-squares batch cache schema version: "
            f"{manifest.get('schema_version')}"
        )
    if int(manifest.get("batch_index", -1)) != int(batch_index):
        raise ValueError(
            f"least-squares batch manifest {manifest_path} has the wrong index"
        )
    if not _cache_metadata_matches(manifest.get("metadata"), expected_metadata):
        raise ValueError(
            "least-squares batch cache metadata does not match the requested "
            "samples, targets, dtype, or model layout"
        )
    entry = manifest.get("entry")
    if not isinstance(entry, dict):
        raise ValueError(invalid)
    if not (cache_dir / str(entry.get("target"))).is_file():
        return None
    blocks = entry.get("blocks")
    if not isinstance(blocks, dict):
        raise ValueError(invalid)
    complete = all(
        _cached_block_files_exist(cache_dir, block_entry, manifest_path=manifest_path)
        for block_entry in blocks.values()
    )
    return entry if complete else None
def _matching_assembled_batch_cache_size(
    cache_dir: Path,
    *,
    expected_metadata: dict[str, object],
) -> int | None:
    """Return the batch size used by matching resumable batch manifests.

    Per-batch manifests are scanned in sorted file-name order. The first one
    with the current schema version, matching metadata, and a positive
    ``batch_size`` wins. The scan is best-effort: unreadable, non-UTF-8,
    undecodable, or structurally malformed manifests are skipped instead of
    aborting the search.

    NOTE(review): ``_cache_metadata_matches`` compares the full metadata,
    including ``batch_size``. Whenever ``expected_metadata`` carries a
    ``batch_size``, that value is therefore the only one that can be
    returned. Confirm this is the intended contract.

    Returns:
        The recorded batch size, or ``None`` if no manifest qualifies.
    """
    for manifest_path in sorted(cache_dir.glob("batch*_manifest.json")):
        try:
            with manifest_path.open("r", encoding="utf8") as handle:
                manifest = json.load(handle)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            continue
        if not isinstance(manifest, dict):
            continue
        try:
            schema_version = int(manifest.get("schema_version", -1))
        except (TypeError, ValueError):
            continue
        if schema_version != _ASSEMBLED_CACHE_SCHEMA_VERSION:
            continue
        metadata = manifest.get("metadata")
        if not isinstance(metadata, dict):
            continue
        if not _cache_metadata_matches(metadata, expected_metadata):
            continue
        try:
            batch_size = int(metadata["batch_size"])
        except (KeyError, TypeError, ValueError):
            continue
        if batch_size > 0:
            return batch_size
    return None
def _write_assembled_batches_manifest(
    cache_dir: Path,
    manifest_batches: Sequence[dict[str, object]],
    *,
    manifest_name: str = "assembled_batches.json",
    metadata: dict[str, object] | None = None,
) -> None:
    """Write the top-level assembled-batch manifest and its settings summary.

    The manifest lists every batch entry together with the schema version and
    ``metadata`` (``{}`` when omitted). The same metadata object is then handed
    to ``write_cache_settings_summary``.
    """
    stored_metadata = metadata if metadata is not None else {}
    manifest = {
        "schema_version": _ASSEMBLED_CACHE_SCHEMA_VERSION,
        "n_batches": len(manifest_batches),
        "batches": list(manifest_batches),
        "metadata": stored_metadata,
    }
    atomic_write_json(cache_dir / manifest_name, manifest)
    summary_extra = {
        "cache_schema_version": _ASSEMBLED_CACHE_SCHEMA_VERSION,
        "manifest": manifest_name,
    }
    write_cache_settings_summary(
        cache_dir,
        cache_kind="least_squares_assembled_batches",
        prefix="assembled",
        settings=stored_metadata,
        extra=summary_extra,
    )
def _normal_equation_cache_exists(cache_directory: Path | str) -> bool:
    """Report whether ``normal_equations.json`` is a file in ``cache_directory``.

    Only the manifest is checked; the arrays it references are not.
    """
    manifest_path = Path(cache_directory).joinpath("normal_equations.json")
    return manifest_path.is_file()
def _write_normal_equation_cache(
    cache_directory: Path | str,
    components: NormalEquationComponents,
    *,
    metadata: dict[str, object],
) -> None:
    """Persist split normal-equation components as four ``.npy`` arrays.

    The arrays are written first, then ``normal_equations.json``. The manifest
    holds the file map, row counts, target norms (as floats), and ``metadata``.
    Last comes a settings summary for the cache directory.
    """
    cache_dir = Path(cache_directory)
    cache_dir.mkdir(parents=True, exist_ok=True)
    files = {
        "energy_gram": "energy_ata.npy",
        "energy_rhs": "energy_atb.npy",
        "force_gram": "force_ata.npy",
        "force_rhs": "force_atb.npy",
    }
    for field_name, file_name in files.items():
        _write_npy_memmap(cache_dir / file_name, getattr(components, field_name))
    manifest = {
        "schema_version": _NORMAL_EQUATION_CACHE_SCHEMA_VERSION,
        "files": files,
        "n_energy_rows": int(components.n_energy_rows),
        "n_force_rows": int(components.n_force_rows),
        "energy_target_norm": float(components.energy_target_norm.item()),
        "force_target_norm": float(components.force_target_norm.item()),
        "metadata": metadata,
    }
    manifest_file = "normal_equations.json"
    atomic_write_json(cache_dir / manifest_file, manifest)
    write_cache_settings_summary(
        cache_dir,
        cache_kind="least_squares_normal_equations",
        prefix="normal_equations",
        settings=metadata,
        extra={
            "cache_schema_version": _NORMAL_EQUATION_CACHE_SCHEMA_VERSION,
            "manifest": manifest_file,
        },
    )
def _load_assembled_batches_manifest(
    directory: Path | str,
    *,
    manifest_name: str = "assembled_batches.json",
) -> dict[str, object]:
    """Load and validate an assembled-batch cache manifest.

    If ``directory`` does not hold ``manifest_name`` itself, exactly one
    immediate child directory holding it is accepted instead. The resolved
    directory is stored in the returned dict under ``"_cache_dir"``.

    Raises:
        FileNotFoundError: If no unique manifest can be located.
        ValueError: If the manifest schema version is unsupported.
    """
    cache_dir = Path(directory)
    manifest_path = cache_dir / manifest_name
    if not manifest_path.is_file():
        candidates = sorted(cache_dir.glob(f"*/{manifest_name}"))
        if len(candidates) != 1:
            raise FileNotFoundError(
                f"least-squares cache manifest not found: {manifest_path}"
            )
        manifest_path = candidates[0]
    with manifest_path.open("r", encoding="utf8") as handle:
        manifest = json.load(handle)
    version_matches = (
        int(manifest.get("schema_version", -1)) == _ASSEMBLED_CACHE_SCHEMA_VERSION
    )
    if not version_matches:
        raise ValueError(
            "unsupported least-squares cache schema version: "
            f"{manifest.get('schema_version')}"
        )
    manifest["_cache_dir"] = str(manifest_path.parent)
    return manifest
def _load_normal_equation_cache(
    cache_directory: Path | str,
    *,
    dtype: torch.dtype,
    device: torch.device,
    expected_metadata: dict[str, object],
) -> NormalEquationComponents:
    """Load split normal-equation components from a validated cache.

    The four arrays are memory-mapped copy-on-write (``mmap_mode="c"``). The
    returned tensors are writable, but in-place edits stay in process memory
    and never reach the cache files, which need only read permission.

    Args:
        cache_directory: Directory holding ``normal_equations.json``.
        dtype: Target dtype for every returned tensor.
        device: Target device for every returned tensor.
        expected_metadata: Metadata the stored manifest must match exactly.

    Returns:
        The cached components converted to ``dtype`` and ``device``.

    Raises:
        FileNotFoundError: If the manifest or any referenced array is missing.
        ValueError: On schema-version or metadata mismatch, or if the manifest's
            ``files`` entry is not a dict.
    """
    cache_dir = Path(cache_directory)
    manifest_path = cache_dir / "normal_equations.json"
    if not manifest_path.is_file():
        raise FileNotFoundError(
            f"normal-equation cache manifest not found: {manifest_path}"
        )
    with manifest_path.open("r", encoding="utf8") as handle:
        manifest = json.load(handle)
    if int(manifest.get("schema_version", -1)) != _NORMAL_EQUATION_CACHE_SCHEMA_VERSION:
        raise ValueError(
            "unsupported normal-equation cache schema version: "
            f"{manifest.get('schema_version')}"
        )
    if not _cache_metadata_matches(manifest.get("metadata"), expected_metadata):
        raise ValueError(
            "normal-equation cache metadata does not match the requested samples, "
            "targets, dtype, or model layout"
        )
    files = manifest.get("files")
    if not isinstance(files, dict):
        raise ValueError(f"invalid normal-equation cache manifest: {manifest_path}")
    array_paths = {
        field: cache_dir / str(files.get(field))
        for field in ("energy_gram", "energy_rhs", "force_gram", "force_rhs")
    }
    # Validate every file first so a partial cache fails before any mapping.
    for array_path in array_paths.values():
        if not array_path.is_file():
            raise FileNotFoundError(
                f"normal-equation cache file not found: {array_path}"
            )
    # Copy-on-write rather than read-write ("r+"): when dtype/device already
    # match, Tensor.to returns the memmap-backed tensor itself, so a read-write
    # mapping would let any in-place update by the solver silently rewrite the
    # cache on disk (and would require write permission just to load it).
    arrays = {
        field: _load_cached_tensor(
            array_path,
            mmap_mode="c",
            dtype=dtype,
            device=device,
        )
        for field, array_path in array_paths.items()
    }

    def scalar(key: str) -> torch.Tensor:
        return torch.tensor(float(manifest.get(key, 0.0)), dtype=dtype, device=device)

    return NormalEquationComponents(
        **arrays,
        energy_target_norm=scalar("energy_target_norm"),
        force_target_norm=scalar("force_target_norm"),
        n_energy_rows=int(manifest.get("n_energy_rows", 0)),
        n_force_rows=int(manifest.get("n_force_rows", 0)),
    )
def _load_cached_tensor(
    path: Path,
    *,
    mmap_mode: Literal["r", "r+", "c"] | None,
    dtype: torch.dtype | None,
    device: torch.device | None,
) -> torch.Tensor:
    """Read one ``.npy`` file and wrap it as a torch tensor.

    ``torch.as_tensor`` shares memory with the loaded (possibly memory-mapped)
    array. A dtype/device conversion is applied only when at least one of
    ``dtype`` or ``device`` is given.
    """
    tensor = torch.as_tensor(np.load(path, mmap_mode=mmap_mode))
    wants_conversion = dtype is not None or device is not None
    return tensor.to(dtype=dtype, device=device) if wants_conversion else tensor
def _load_cached_block_matrix(
    cache_dir: Path,
    entry: object,
    *,
    mmap_mode: Literal["r", "r+", "c"] | None,
    dtype: torch.dtype | None,
    device: torch.device | None,
    row_indexed_blocks: bool,
) -> BlockMatrix:
    """Load one cached block matrix described by a manifest ``entry``.

    A string entry names a dense ``.npy`` file. A dict entry describes
    ``row_indexed`` or ``column_row_indexed`` storage. Sparse blocks are
    returned as-is when ``row_indexed_blocks`` is true and materialized
    otherwise. Row indices are always loaded as int64.

    Raises:
        ValueError: If the entry is neither a string nor a known storage dict.
    """

    def load(file_name: object, array_dtype: torch.dtype | None) -> torch.Tensor:
        return _load_cached_tensor(
            cache_dir / str(file_name),
            mmap_mode=mmap_mode,
            dtype=array_dtype,
            device=device,
        )

    if isinstance(entry, str):
        return load(entry, dtype)
    if not isinstance(entry, dict):
        raise ValueError("invalid least-squares cache block entry")
    storage = str(entry.get("storage"))
    if storage == "row_indexed":
        sparse = RowIndexedBlockMatrix(
            rows=load(entry["rows"], torch.int64),
            values=load(entry["values"], dtype),
            n_rows=int(entry["n_rows"]),
        )
    elif storage == "column_row_indexed":
        loaded_chunks = tuple(
            ColumnRowIndexedChunk(
                rows=load(chunk["rows"], torch.int64),
                values=load(chunk["values"], dtype),
                column_start=int(chunk["column_start"]),
            )
            for chunk in entry["chunks"]
        )
        sparse = ColumnRowIndexedBlockMatrix(
            chunks=loaded_chunks,
            n_rows=int(entry["n_rows"]),
            n_cols=int(entry["n_cols"]),
        )
    else:
        raise ValueError("invalid least-squares cache block entry")
    return sparse if row_indexed_blocks else sparse.materialize()
def _load_assembled_batch_entry_memmap(
    cache_dir: Path,
    entry: object,
    *,
    mmap_mode: Literal["r", "r+", "c"] | None,
    dtype: torch.dtype | None,
    device: torch.device | None,
    row_indexed_blocks: bool,
    row_weights: torch.Tensor | None = None,
) -> AssembledBatch:
    """Rebuild one ``AssembledBatch`` from its manifest entry.

    The target is loaded first, then the blocks in ``str(key)`` order.
    ``row_weights``, when given, is applied with
    ``_apply_row_weights_to_assembled_batch``.

    Raises:
        ValueError: If ``entry`` is not a dict.
    """
    if not isinstance(entry, dict):
        raise ValueError("invalid least-squares cache batch entry")
    target = _load_cached_tensor(
        cache_dir / str(entry["target"]),
        mmap_mode=mmap_mode,
        dtype=dtype,
        device=device,
    )
    ordered_blocks = sorted(entry["blocks"].items(), key=lambda item: str(item[0]))
    block_matrices = {}
    for block_key, block_entry in ordered_blocks:
        block_matrices[str(block_key)] = _load_cached_block_matrix(
            cache_dir,
            block_entry,
            mmap_mode=mmap_mode,
            dtype=dtype,
            device=device,
            row_indexed_blocks=row_indexed_blocks,
        )
    batch = AssembledBatch(target=target, block_matrices=block_matrices)
    if row_weights is not None:
        return _apply_row_weights_to_assembled_batch(batch, row_weights)
    return batch
def load_assembled_batches_memmap(
    directory: Path | str,
    *,
    manifest_name: str = "assembled_batches.json",
    mmap_mode: Literal["r", "r+", "c"] | None = "r+",
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
    expected_metadata: dict[str, object] | None = None,
    row_indexed_blocks: bool = False,
    row_weights: Sequence[torch.Tensor] | None = None,
) -> tuple[AssembledBatch, ...]:
    """
    Load block-separated assembled batches from a disk cache.

    If ``directory`` does not contain ``manifest_name`` directly, a single
    child directory containing it is used instead (see
    ``_load_assembled_batches_manifest``).

    Args:
        directory: Directory containing a cache manifest and ``.npy`` files.
        manifest_name: JSON manifest filename inside ``directory``.
        mmap_mode: Mode passed to ``numpy.load``. Use ``None`` to load ordinary
            arrays. With the default ``"r+"`` the files are mapped read-write.
            Tensors returned without a dtype/device conversion share that
            mapping, so in-place edits to them may be written to the cache.
        dtype: Optional dtype conversion for loaded tensors.
        device: Optional device conversion for loaded tensors.
        expected_metadata: Optional metadata payload that must match the cache.
        row_indexed_blocks: Whether to preserve sparse row-indexed block storage.
        row_weights: Optional row weights applied to loaded batches, one per
            manifest batch.

    Returns:
        Assembled least-squares batches.

    Raises:
        FileNotFoundError: If no manifest can be located.
        ValueError: If cache metadata or row-weight counts do not match.
    """
    # (A stray Sphinx ``[docs]`` expression preceded this def; it evaluated the
    # undefined name ``docs`` at import time and has been removed.)
    cache_dir = Path(directory)
    manifest = _load_assembled_batches_manifest(
        cache_dir,
        manifest_name=manifest_name,
    )
    if expected_metadata is not None and not _cache_metadata_matches(
        manifest.get("metadata"),
        expected_metadata,
    ):
        raise ValueError(
            "least-squares cache metadata does not match the requested samples, "
            "targets, dtype, or model layout"
        )
    # The manifest may have been found in a child directory; load from there.
    cache_dir = Path(str(manifest.get("_cache_dir", cache_dir)))
    manifest_batches = tuple(manifest.get("batches", ()))
    if row_weights is not None and len(row_weights) != len(manifest_batches):
        raise ValueError(
            "number of least-squares cache weight batches does not match manifest"
        )
    return tuple(
        _load_assembled_batch_entry_memmap(
            cache_dir,
            batch,
            mmap_mode=mmap_mode,
            dtype=dtype,
            device=device,
            row_indexed_blocks=row_indexed_blocks,
            row_weights=None if row_weights is None else row_weights[batch_index],
        )
        for batch_index, batch in enumerate(manifest_batches)
    )
def _assembled_cache_manifest_exists(
cache_directory: Path | str,
*,
manifest_name: str = "assembled_batches.json",
) -> bool:
"""Return whether a least-squares assembled-batch cache manifest exists."""
return (Path(cache_directory) / manifest_name).is_file()
class _CachedBlockBatchSequence:
    """Read-only, lazily loaded view over cached solve batches.

    Every access re-reads the requested batch through ``load_batch_entry`` and
    converts it with ``_block_solve_batch_from_assembled``; nothing is kept
    between accesses. When ``repulsion_batch`` is supplied it occupies the
    final index, after all cached data batches.
    """

    def __init__(
        self,
        *,
        cache_dir: Path,
        entries: Sequence[object],
        dtype: torch.dtype,
        device: torch.device,
        row_weights: Sequence[torch.Tensor] | None,
        repulsion_batch: BlockSolveBatch | None,
        load_batch_entry: Callable[..., AssembledBatch] = (
            _load_assembled_batch_entry_memmap
        ),
        mmap_mode: Literal["r", "r+", "c"] | None = "r+",
        row_indexed_blocks: bool = True,
    ) -> None:
        """Record everything needed to load a single batch on demand.

        Raises:
            ValueError: If ``row_weights`` is given and its length differs from
                the number of manifest ``entries``.
        """
        weights_mismatch = row_weights is not None and len(row_weights) != len(
            entries
        )
        if weights_mismatch:
            raise ValueError(
                "number of least-squares cache weight batches does not match manifest"
            )
        self._cache_dir = cache_dir
        self._entries = tuple(entries)
        self._dtype = dtype
        self._device = device
        self._row_weights = tuple(row_weights) if row_weights is not None else None
        self._repulsion_batch = repulsion_batch
        self._load_batch_entry = load_batch_entry
        self._mmap_mode = mmap_mode
        self._row_indexed_blocks = row_indexed_blocks

    def __len__(self) -> int:
        """Count cached data batches, plus one when a pseudo-batch is attached."""
        count = len(self._entries)
        if self._repulsion_batch is not None:
            count += 1
        return count

    def __iter__(self):
        """Stream data batches one at a time, then the optional pseudo-batch."""
        yield from map(self._load_data_batch, range(len(self._entries)))
        if self._repulsion_batch is not None:
            yield self._repulsion_batch

    def __getitem__(self, index):
        """Return one batch for an integer index, or a tuple for a slice."""
        if isinstance(index, slice):
            positions = range(*index.indices(len(self)))
            return tuple(self[position] for position in positions)
        position = int(index)
        size = len(self)
        if position < 0:
            position += size
        if not 0 <= position < size:
            raise IndexError("cached least-squares batch index out of range")
        if position == len(self._entries):
            # The only index past the cached entries is the pseudo-batch slot.
            assert self._repulsion_batch is not None
            return self._repulsion_batch
        return self._load_data_batch(position)

    def _load_data_batch(self, index: int) -> BlockSolveBatch:
        """Read cached entry ``index`` with its row weights and wrap it for solving."""
        weight = self._row_weights[index] if self._row_weights is not None else None
        assembled = self._load_batch_entry(
            self._cache_dir,
            self._entries[index],
            mmap_mode=self._mmap_mode,
            dtype=self._dtype,
            device=self._device,
            row_indexed_blocks=self._row_indexed_blocks,
            row_weights=weight,
        )
        return _block_solve_batch_from_assembled(assembled)
class CachedBlockLinearProblem:
    """Matrix-free linear problem that streams cached batches from disk.

    Cached data batches are exposed through a lazy batch sequence, so the
    streaming operations (``normal_matvec``, ``rhs``,
    ``normal_equation_diagonal``, ``objective``, ``residual_norm``) load one
    batch at a time instead of concatenating every row. An optional in-memory
    ``repulsion_batch`` is appended after the cached data batches.
    """

    def __init__(
        self,
        *,
        layout: BlockProblemLayout,
        cache_dir: Path,
        entries: Sequence[object],
        dtype: torch.dtype,
        device: torch.device,
        row_weights: Sequence[torch.Tensor],
        repulsion_batch: BlockSolveBatch | None = None,
        load_batch_entry: Callable[..., AssembledBatch] = (
            _load_assembled_batch_entry_memmap
        ),
        row_indexed_blocks: bool = True,
    ) -> None:
        """Create a cached problem without materializing all batches.

        Args:
            layout: Layout mapping block keys to slices of the flat ``theta``.
            cache_dir: Directory holding the cached batch files.
            entries: Manifest entries, one per cached data batch.
            dtype: Dtype used when cached batches are streamed.
            device: Device used when cached batches are streamed.
            row_weights: One row-weight tensor per cached data batch; their
                element counts define the per-batch row counts.
            repulsion_batch: Optional in-memory batch appended after the
                cached data batches.
            load_batch_entry: Loader called for one manifest entry.
            row_indexed_blocks: Forwarded to the loader for every batch.

        Raises:
            ValueError: If the number of row-weight tensors differs from the
                number of manifest entries.
        """
        self.layout = layout
        self._dtype = dtype
        self._device = device
        # Row counts come from the weights so ``n_rows`` never reads the cache.
        self._batch_row_counts = tuple(int(weight.numel()) for weight in row_weights)
        if len(self._batch_row_counts) != len(entries):
            raise ValueError(
                "number of least-squares cache weight batches does not match manifest"
            )
        self._repulsion_batch = repulsion_batch
        # Progress bars are only enabled for the duration of a CG solve.
        self._stream_progress = False
        self._stream_pass = 0
        self._batches = _CachedBlockBatchSequence(
            cache_dir=cache_dir,
            entries=entries,
            dtype=dtype,
            device=device,
            row_weights=row_weights,
            repulsion_batch=repulsion_batch,
            load_batch_entry=load_batch_entry,
            row_indexed_blocks=bool(row_indexed_blocks),
        )

    @property
    def batches(self) -> _CachedBlockBatchSequence:
        """Return a lazy sequence that loads cached batches on demand."""
        return self._batches

    @property
    def n_rows(self) -> int:
        """Return the total number of target rows across all batches."""
        total = sum(self._batch_row_counts)
        if self._repulsion_batch is not None:
            total += self._repulsion_batch.n_rows
        return int(total)

    @property
    def dtype(self) -> torch.dtype:
        """Return the dtype used when cached batches are streamed."""
        return self._dtype

    @property
    def device(self) -> torch.device:
        """Return the device used when cached batches are streamed."""
        return self._device

    def _iter_stream_batches(self, description: str):
        """Yield cached batches with an optional progress bar."""
        return _iter_with_progress(
            self.batches,
            enabled=self._stream_progress,
            description=description,
            total=len(self.batches),
        )

    def _batch_prediction(
        self,
        batch: BlockSolveBatch,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``A_batch @ theta`` cast to ``theta``'s dtype and device."""
        prediction = torch.zeros(
            (batch.n_rows,),
            dtype=theta.dtype,
            device=theta.device,
        )
        for key, block_matrix in batch.matrices.items():
            prediction = prediction + _block_matrix_matvec(
                block_matrix,
                theta[self.layout.theta_slice(key)],
            ).to(device=theta.device, dtype=theta.dtype)
        return prediction

    def _squared_residual(self, theta: torch.Tensor) -> torch.Tensor:
        """Return ``||A theta - b||**2`` accumulated one batch at a time."""
        squared = torch.zeros((), dtype=theta.dtype, device=theta.device)
        for batch in self.batches:
            prediction = self._batch_prediction(batch, theta)
            target = batch.target.to(device=theta.device, dtype=theta.dtype)
            residual = prediction - target
            squared = squared + torch.dot(residual, residual)
        return squared

    def _regularized_blocks(self):
        """Yield ``(block, theta_slice)`` for each layout block with a regularizer."""
        for block in self.layout.blocks:
            if block.regularization is None:
                continue
            yield block, self.layout.theta_slice(block.key)

    def normal_matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply the regularized normal operator without concatenating rows.

        Computes ``A.T @ (A @ theta) + R(theta)`` batch by batch, where ``R``
        is the summed block regularizer operator.
        """
        self._stream_pass += 1
        theta = theta.reshape(self.layout.size)
        output = torch.zeros(
            (self.layout.size,),
            dtype=theta.dtype,
            device=theta.device,
        )
        for batch in self._iter_stream_batches(
            f"Streaming CG matvec {self._stream_pass}"
        ):
            prediction = self._batch_prediction(batch, theta)
            for key, block_matrix in batch.matrices.items():
                output[self.layout.theta_slice(key)] += _block_matrix_rmatvec(
                    block_matrix,
                    prediction,
                ).to(device=theta.device, dtype=theta.dtype)
        return output + self.regularization_apply(theta)

    def rhs(self) -> torch.Tensor:
        """Return the right-hand side ``A.T @ b`` plus regularizer shifts, streamed."""
        output = torch.zeros(
            (self.layout.size,),
            dtype=self.dtype,
            device=self.device,
        )
        for batch in self._iter_stream_batches("Streaming cached RHS"):
            for key, block_matrix in batch.matrices.items():
                output[self.layout.theta_slice(key)] += _block_matrix_rmatvec(
                    block_matrix,
                    batch.target,
                ).to(device=self.device, dtype=self.dtype)
        return output + self.regularization_rhs()

    def normal_equation_diagonal(self) -> torch.Tensor:
        """Return the diagonal of ``A.T @ A`` by streaming cached batches."""
        diagonal = torch.zeros(
            (self.layout.size,),
            dtype=self.dtype,
            device=self.device,
        )
        for batch in self._iter_stream_batches("Streaming normal diagonal"):
            for key, block_matrix in batch.matrices.items():
                # Cast like ``rhs`` does: the optional in-memory repulsion
                # batch is not guaranteed to share the cache dtype/device.
                diagonal[self.layout.theta_slice(key)] += _block_matrix_diagonal(
                    block_matrix
                ).to(device=self.device, dtype=self.dtype)
        return diagonal

    def design_trace_by_block(self) -> dict[object, float]:
        """Return weighted design-matrix trace contributions by solve block."""
        diagonal = self.normal_equation_diagonal()
        return {
            block.key: float(diagonal[self.layout.theta_slice(block.key)].sum().item())
            for block in self.layout.blocks
        }

    def target_vector(self) -> torch.Tensor:
        """Concatenate targets, loading cached batches on demand."""
        if not self.batches:
            return torch.zeros(0, dtype=self.dtype, device=self.device)
        return torch.cat([batch.target for batch in self.batches], dim=0)

    def materialize_design_matrix(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the explicit dense design matrix for debugging or tiny problems.

        Returns:
            ``(matrix, targets)`` with ``matrix`` of shape
            ``(n_rows, layout.size)``; batches are stacked in iteration order.
        """
        matrix = torch.zeros(
            (self.n_rows, self.layout.size),
            dtype=self.dtype,
            device=self.device,
        )
        targets: list[torch.Tensor] = []
        offset = 0
        for batch in self.batches:
            for key, block_matrix in batch.matrices.items():
                matrix[offset : offset + batch.n_rows, self.layout.theta_slice(key)] = (
                    _materialize_block_matrix(block_matrix)
                )
            targets.append(batch.target)
            offset += batch.n_rows
        if not targets:
            return matrix, torch.zeros(0, dtype=self.dtype, device=self.device)
        return matrix, torch.cat(targets, dim=0)

    def matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply the design matrix, loading cached batches on demand."""
        theta = theta.reshape(self.layout.size)
        outputs: list[torch.Tensor] = []
        for batch in self.batches:
            # Unlike the streaming helpers, predictions here follow the
            # batch target's dtype/device and are not cast.
            prediction = torch.zeros(
                (batch.n_rows,),
                dtype=batch.target.dtype,
                device=batch.target.device,
            )
            for key, block_matrix in batch.matrices.items():
                prediction = prediction + _block_matrix_matvec(
                    block_matrix,
                    theta[self.layout.theta_slice(key)],
                )
            outputs.append(prediction)
        if not outputs:
            return torch.zeros(0, dtype=self.dtype, device=self.device)
        return torch.cat(outputs, dim=0)

    def rmatvec(self, residual: torch.Tensor) -> torch.Tensor:
        """Apply the transpose design matrix to a residual vector of ``n_rows`` rows."""
        residual = residual.reshape(self.n_rows)
        output = torch.zeros(
            (self.layout.size,),
            dtype=residual.dtype,
            device=residual.device,
        )
        offset = 0
        for batch in self.batches:
            batch_residual = residual[offset : offset + batch.n_rows]
            for key, block_matrix in batch.matrices.items():
                output[self.layout.theta_slice(key)] += _block_matrix_rmatvec(
                    block_matrix,
                    batch_residual,
                )
            offset += batch.n_rows
        return output

    def regularization_apply(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply all block regularizers to a flat parameter vector."""
        theta = theta.reshape(self.layout.size)
        output = torch.zeros_like(theta)
        for block, theta_slice in self._regularized_blocks():
            output[theta_slice] += block.regularization.apply(theta[theta_slice])
        return output

    def regularization_diagonal(self) -> torch.Tensor:
        """Return the summed diagonal preconditioner implied by regularizers."""
        diag = torch.zeros((self.layout.size,), dtype=self.dtype, device=self.device)
        for block, theta_slice in self._regularized_blocks():
            diag[theta_slice] += block.regularization.diagonal(
                dtype=self.dtype,
                device=self.device,
            )
        return diag

    def regularization_rhs(self) -> torch.Tensor:
        """Return the summed RHS shifts implied by block regularizers."""
        rhs = torch.zeros((self.layout.size,), dtype=self.dtype, device=self.device)
        for block, theta_slice in self._regularized_blocks():
            rhs[theta_slice] += block.regularization.rhs(
                dtype=self.dtype,
                device=self.device,
            )
        return rhs

    def accumulate_normal_equations(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Materialize the Gram matrix and right-hand side from streamed batches.

        Returns:
            ``(gram, rhs)`` where ``gram`` is ``(layout.size, layout.size)`` and
            both include the block regularizer contributions.
        """
        gram = torch.zeros(
            (self.layout.size, self.layout.size),
            dtype=self.dtype,
            device=self.device,
        )
        rhs = torch.zeros((self.layout.size,), dtype=self.dtype, device=self.device)
        for batch in self.batches:
            keys = tuple(batch.matrices)
            for key in keys:
                theta_slice = self.layout.theta_slice(key)
                block_matrix = batch.matrices[key]
                rhs[theta_slice] += _block_matrix_rmatvec(block_matrix, batch.target)
            for index_i, key_i in enumerate(keys):
                slice_i = self.layout.theta_slice(key_i)
                matrix_i = batch.matrices[key_i]
                gram[slice_i, slice_i] += _block_matrix_cross(matrix_i, matrix_i)
                # Only the upper off-diagonal blocks are computed; the lower
                # ones are filled by symmetry.
                for key_j in keys[index_i + 1 :]:
                    slice_j = self.layout.theta_slice(key_j)
                    cross = _block_matrix_cross(matrix_i, batch.matrices[key_j])
                    gram[slice_i, slice_j] += cross
                    gram[slice_j, slice_i] += cross.T
        for block, theta_slice in self._regularized_blocks():
            gram[theta_slice, theta_slice] += block.regularization.materialize(
                dtype=self.dtype,
                device=self.device,
            )
            rhs[theta_slice] += block.regularization.rhs(
                dtype=self.dtype,
                device=self.device,
            )
        return gram, rhs

    def objective(self, theta: torch.Tensor) -> torch.Tensor:
        """Evaluate the regularized objective without materializing residuals.

        Returns ``||A theta - b||**2`` plus each block regularizer's quadratic.
        """
        theta = theta.reshape(self.layout.size)
        value = self._squared_residual(theta)
        for block, theta_slice in self._regularized_blocks():
            value = value + block.regularization.quadratic(theta[theta_slice])
        return value

    def residual_norm(self, theta: torch.Tensor) -> torch.Tensor:
        """Return ``||A theta - b||`` without materializing all residual rows."""
        theta = theta.reshape(self.layout.size)
        squared = self._squared_residual(theta)
        return torch.sqrt(torch.clamp(squared, min=0.0))

    def _regularization_design_rows(
        self,
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Return full-width regularization rows and targets for dense solves.

        Each regularizer's least-squares rows are zero-padded to
        ``layout.size`` columns; regularizers contributing no rows are skipped.
        """
        rows: list[torch.Tensor] = []
        targets: list[torch.Tensor] = []
        for block, theta_slice in self._regularized_blocks():
            block_rows, block_target = block.regularization.least_squares_rows(
                dtype=self.dtype,
                device=self.device,
            )
            if block_rows.shape[0] == 0:
                continue
            row_block = torch.zeros(
                (block_rows.shape[0], self.layout.size),
                dtype=self.dtype,
                device=self.device,
            )
            row_block[:, theta_slice] = block_rows
            rows.append(row_block)
            targets.append(block_target)
        return rows, targets

    def solve(
        self,
        *,
        solver: str,
        cg_tolerance: float,
        cg_max_iter: int | None,
        progress: bool = False,
        progress_frequency: int = 10,
        initial_theta: torch.Tensor | None = None,
        cg_checkpoint_path: Path | str | None = None,
        cg_checkpoint_frequency: int = 1,
        cg_resume: bool = False,
        cg_checkpoint_metadata: dict[str, object] | None = None,
        fallback_theta: torch.Tensor | None = None,
        return_info: bool = False,
    ) -> torch.Tensor | LinearSolveResult:
        """Solve the streamed problem with the selected backend.

        Args:
            solver: One of ``"dense_lstsq"``, ``"normal_equation_direct"`` or
                ``"cg"``.
            cg_tolerance: Convergence tolerance forwarded to the CG solver.
            cg_max_iter: Iteration cap forwarded to the CG solver.
            progress: Whether to print status messages and, for CG, show
                streaming progress bars.
            progress_frequency: Forwarded to the CG solver.
            initial_theta: Optional CG starting point; also used as the
                interrupt fallback when ``fallback_theta`` is not given.
            cg_checkpoint_path: Optional CG checkpoint file.
            cg_checkpoint_frequency: Forwarded to the CG solver.
            cg_resume: Whether to load ``cg_checkpoint_path`` before solving.
            cg_checkpoint_metadata: Checkpoint metadata; when omitted, one is
                built from the parameter count and dtype.
            fallback_theta: Coefficients returned if the solve is interrupted.
            return_info: Return the full :class:`LinearSolveResult` instead of
                only ``theta``. When true, ``KeyboardInterrupt`` becomes an
                interrupted result instead of being re-raised.

        Returns:
            The solution vector, or a :class:`LinearSolveResult` when
            ``return_info`` is true.

        Raises:
            ValueError: If ``solver`` is not a supported backend.
            KeyboardInterrupt: If interrupted while ``return_info`` is false.
        """

        def fallback_result() -> LinearSolveResult:
            # Prefer explicit fallback coefficients, then the initial guess,
            # then zeros.
            if fallback_theta is not None:
                theta = fallback_theta.to(dtype=self.dtype, device=self.device)
            elif initial_theta is not None:
                theta = initial_theta.to(dtype=self.dtype, device=self.device)
            else:
                theta = torch.zeros(
                    (self.layout.size,),
                    dtype=self.dtype,
                    device=self.device,
                )
            return LinearSolveResult(theta=theta, interrupted=True)

        def maybe_return(result: LinearSolveResult) -> torch.Tensor | LinearSolveResult:
            return result if return_info else result.theta

        if solver == "dense_lstsq":
            try:
                matrix, target = self.materialize_design_matrix()
                reg_rows, reg_targets = self._regularization_design_rows()
                if reg_rows:
                    matrix = torch.cat([matrix, *reg_rows], dim=0)
                    target = torch.cat([target, *reg_targets], dim=0)
                return maybe_return(
                    LinearSolveResult(torch.linalg.lstsq(matrix, target).solution)
                )
            except KeyboardInterrupt:
                if not return_info:
                    raise
                if progress:
                    print(
                        "Interrupted dense least-squares solve; "
                        "using fallback coefficients."
                    )
                return fallback_result()
        if solver == "normal_equation_direct":
            try:
                gram, rhs = self.accumulate_normal_equations()
                if progress:
                    print("Solving normal equations directly...")
                try:
                    theta = torch.linalg.solve(gram, rhs)
                except RuntimeError:
                    # Singular/ill-posed Gram matrix: fall back to least squares.
                    if progress:
                        print(
                            "Direct solve failed; falling back to torch.linalg.lstsq."
                        )
                    theta = torch.linalg.lstsq(gram, rhs).solution
                return maybe_return(LinearSolveResult(theta))
            except KeyboardInterrupt:
                if not return_info:
                    raise
                if progress:
                    print(
                        "Interrupted normal-equation solve; "
                        "using fallback coefficients."
                    )
                return fallback_result()
        if solver == "cg":
            self._stream_progress = progress
            self._stream_pass = 0
            try:
                rhs = self.rhs()
                if cg_checkpoint_metadata is None:
                    checkpoint_metadata = _cg_checkpoint_metadata(
                        n_parameters=self.layout.size,
                        dtype=self.dtype,
                    )
                else:
                    checkpoint_metadata = dict(cg_checkpoint_metadata)
                checkpoint_state = (
                    None
                    if not cg_resume or cg_checkpoint_path is None
                    else load_cg_checkpoint(
                        cg_checkpoint_path,
                        dtype=self.dtype,
                        device=self.device,
                        expected_metadata=checkpoint_metadata,
                    )
                )
                result = _conjugate_gradient(
                    self.normal_matvec,
                    rhs,
                    diagonal_preconditioner=self.regularization_diagonal()
                    + self.normal_equation_diagonal(),
                    tolerance=cg_tolerance,
                    max_iter=cg_max_iter,
                    progress=progress,
                    progress_frequency=progress_frequency,
                    initial_guess=initial_theta,
                    checkpoint_state=checkpoint_state,
                    checkpoint_path=cg_checkpoint_path,
                    checkpoint_frequency=cg_checkpoint_frequency,
                    checkpoint_metadata=checkpoint_metadata,
                    handle_interrupts=return_info,
                )
                return maybe_return(result)
            except KeyboardInterrupt:
                if not return_info:
                    raise
                if progress:
                    print("Interrupted CG setup; using fallback coefficients.")
                return fallback_result()
            finally:
                self._stream_progress = False
        choices = ", ".join(["dense_lstsq", "normal_equation_direct", "cg"])
        raise ValueError(f"Unsupported solver '{solver}'. Expected one of: {choices}.")
# Public API of this module; every other name here is an implementation detail.
__all__ = [
"AssembledBatchCacheMode",
"CachedBlockLinearProblem",
"NormalEquationComponents",
"assembled_cache_dir",
"load_assembled_batches_memmap",
"normal_equation_cache_dir",
"save_assembled_batches_memmap",
]