# Source code for ufp.leastsquares._cache

"""Disk cache metadata, memmap IO, and streamed cached problem support."""

from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Callable, Literal, Sequence

import numpy as np
import torch

from ufp.core._disk_cache import (
    atomic_write_json,
    atomic_write_npy_memmap,
    settings_cache_dir,
    write_cache_settings_summary,
)
from ufp.leastsquares._block import (
    BlockMatrix,
    BlockProblemLayout,
    BlockSolveBatch,
    ColumnRowIndexedBlockMatrix,
    ColumnRowIndexedChunk,
    RowIndexedBlockMatrix,
    _apply_row_weights_to_assembled_batch,
    _block_matrix_cross,
    _block_matrix_diagonal,
    _block_matrix_matvec,
    _block_matrix_rmatvec,
    _block_solve_batch_from_assembled,
    _compact_block_matrix,
    _compact_column_chunked_block_matrix,
    _materialize_block_matrix,
)
from ufp.leastsquares._layout import ParameterLayout
from ufp.leastsquares._problem import (
    LinearSolveResult,
    _cg_checkpoint_metadata,
    _conjugate_gradient,
    load_cg_checkpoint,
)
from ufp.leastsquares._types import AssembledBatch
from ufp.leastsquares._utils import _iter_with_progress
from ufp.leastsquares.dataset import FitSample


_ASSEMBLED_CACHE_SCHEMA_VERSION = 7
_NORMAL_EQUATION_CACHE_SCHEMA_VERSION = 3
AssembledBatchCacheMode = Literal["auto", "read", "write", "refresh"]


def _hash_array(hasher, value: object, *, dtype: np.dtype | None = None) -> None:
    """Feed one array's dtype name, shape, and raw bytes into ``hasher``.

    ``value`` is converted with ``numpy.asarray`` (cast to ``dtype`` when one is
    given) and made C-contiguous before hashing. The shape is hashed as
    ``int64`` values so arrays with equal bytes but different shapes differ.
    """
    contiguous = np.ascontiguousarray(np.asarray(value, dtype=dtype))
    dtype_tag = str(contiguous.dtype).encode("utf8")
    shape_bytes = np.asarray(contiguous.shape, dtype=np.int64).tobytes()
    # Order matters: dtype tag, then shape, then the element buffer.
    for chunk in (dtype_tag, shape_bytes, contiguous.tobytes()):
        hasher.update(chunk)


def _to_numpy_for_hash(value: object) -> np.ndarray:
    """Return ``value`` as a numpy array so it can be hashed.

    Torch tensors are detached and moved to the CPU before conversion; any
    other input goes through ``numpy.asarray``.
    """
    if not isinstance(value, torch.Tensor):
        return np.asarray(value)
    host_tensor = value.detach().cpu()
    return host_tensor.numpy()


def _sample_signature(
    samples: Sequence[FitSample],
    *,
    include_weights: bool = True,
) -> str:
    """Return a SHA-256 hex digest over the samples' inputs and targets.

    For every sample the digest covers the atomic numbers, positions, cell,
    and periodic flags, then the optional neighbor list, then a presence flag
    plus value for each of energy, forces, and per-atom energy. The three
    target weights are appended only when ``include_weights`` is true.
    """
    digest = hashlib.sha256()

    def add_float64(payload: object) -> None:
        _hash_array(digest, payload, dtype=np.float64)

    for sample in samples:
        structure = sample.atoms
        digest.update(b"sample")
        _hash_array(digest, structure.numbers, dtype=np.int64)
        add_float64(structure.positions)
        add_float64(structure.cell.array)
        _hash_array(digest, structure.pbc, dtype=np.bool_)

        neighbors = sample.neighbor_list
        if neighbors is None:
            digest.update(b"neighbor_list:none")
        else:
            digest.update(b"neighbor_list")
            # Pairs and shifts keep their own dtype; optional geometry is
            # hashed as float64 and simply skipped when absent.
            _hash_array(digest, _to_numpy_for_hash(neighbors.pairs))
            _hash_array(digest, _to_numpy_for_hash(neighbors.shifts))
            for optional_name in ("distances", "vectors"):
                optional_value = getattr(neighbors, optional_name)
                if optional_value is not None:
                    add_float64(_to_numpy_for_hash(optional_value))
            digest.update(str(bool(neighbors.full_list)).encode("utf8"))

        # The scalar energy is wrapped in a list so it hashes as a 1-element array.
        energy = sample.energy
        digest.update(str(energy is not None).encode("utf8"))
        if energy is not None:
            add_float64([energy])
        for target_name in ("forces", "per_atom_energy"):
            target = getattr(sample, target_name)
            digest.update(str(target is not None).encode("utf8"))
            if target is not None:
                add_float64(target)

        if include_weights:
            add_float64(
                [sample.energy_weight, sample.force_weight, sample.per_atom_weight]
            )
    return digest.hexdigest()


def _layout_signature(layout: ParameterLayout) -> list[dict[str, object]]:
    """Describe every block of ``layout`` as plain, comparable metadata.

    Each entry records the block's index, naming fields, shape, slice bounds,
    coefficient/regularization tags, an assembler tag (the string itself, or
    the assembler's class name), and a summary of the coefficient provider
    when the block has one.
    """

    def describe_provider(provider: object) -> dict[str, object] | None:
        if provider is None:
            return None
        return {
            "coefficient_shape": [int(dim) for dim in provider.coefficient_shape],
            "n_proxy_terms": int(provider.n_proxy_terms),
            "n_true_terms": int(provider.n_true_terms),
            "uses_identity_weights": bool(provider.uses_identity_weights),
        }

    entries: list[dict[str, object]] = []
    for block in layout.blocks:
        provider_entry = describe_provider(block.coefficient_provider)
        assembler = block.assembler
        if isinstance(assembler, str):
            assembler_tag = assembler
        else:
            assembler_tag = type(assembler).__name__
        entry = {
            "index": int(block.index),
            "name": block.name,
            "kind": block.kind,
            "label": block.label,
            "shape": [int(dim) for dim in block.shape],
            "start": int(block.start),
            "stop": int(block.stop),
            "coefficient_index": block.coefficient_index,
            "regularization_group": block.regularization_group,
            "assembler": assembler_tag,
            "provider": provider_entry,
        }
        entries.append(entry)
    return entries


def _cache_metadata_for_fit(
    *,
    layout: ParameterLayout,
    samples: Sequence[FitSample],
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
    dtype: torch.dtype,
    batch_size: int,
    include_sample_weights: bool = True,
) -> dict[str, object]:
    """Build the metadata payload that identifies one fit for caching.

    The payload combines the sample count and digest, the layout size and
    per-block signature, the three target switches, the dtype name, and the
    batch size. ``include_sample_weights`` controls whether target weights
    take part in the sample digest.
    """
    sample_count = len(samples)
    signature = _sample_signature(samples, include_weights=include_sample_weights)
    target_flags = {
        "fit_energy": bool(fit_energy),
        "fit_forces": bool(fit_forces),
        "fit_per_atom_energy": bool(fit_per_atom_energy),
    }
    return {
        "sample_count": sample_count,
        "sample_signature": signature,
        "layout_size": int(layout.size),
        "layout": _layout_signature(layout),
        **target_flags,
        "dtype": str(dtype),
        "batch_size": int(batch_size),
    }


def _cache_identity_metadata(
    metadata: dict[str, object] | None,
) -> dict[str, object]:
    """Return a shallow copy of ``metadata``, or an empty dict for ``None``.

    The copy is what cache identity comparisons and directory naming use.
    """
    return {} if metadata is None else dict(metadata)


def _cache_metadata_matches(
    cached: dict[str, object] | None,
    expected: dict[str, object],
) -> bool:
    """Tell whether cached metadata is identical to the requested fit's.

    Both payloads are normalized through ``_cache_identity_metadata`` first,
    so a missing (``None``) cached payload compares as an empty dict.
    """
    cached_identity = _cache_identity_metadata(cached)
    expected_identity = _cache_identity_metadata(expected)
    return cached_identity == expected_identity


def _cache_problem_metadata_matches(
    cached: dict[str, object] | None,
    expected: dict[str, object],
) -> bool:
    """Compare metadata while disregarding layout-specific keys.

    Returns ``False`` when there is no cached payload. Otherwise the layout,
    coefficient-selection, and cache-block keys are dropped from both sides
    and the remaining entries must be equal.
    """
    if cached is None:
        return False

    layout_specific = frozenset(
        (
            "layout",
            "layout_size",
            "selected_block_indices",
            "coefficient_selection",
            "fixed_coefficients_signature",
            "cache_blocks",
            "cache_reusable",
        )
    )

    def without_layout_keys(payload: dict[str, object]) -> dict[str, object]:
        return {
            name: entry
            for name, entry in payload.items()
            if name not in layout_specific
        }

    return without_layout_keys(cached) == without_layout_keys(expected)


def _cache_metadata_can_project(
    cached: dict[str, object] | None,
    expected: dict[str, object],
) -> bool:
    """Decide whether a cache may be projected onto the requested layout.

    Projection needs a cached payload, a truthy ``cache_reusable`` flag and a
    ``cache_blocks`` entry on both sides, and matching problem metadata once
    layout-specific keys are ignored.
    """
    if cached is None:
        return False
    both_reusable = bool(cached.get("cache_reusable")) and bool(
        expected.get("cache_reusable")
    )
    if not both_reusable:
        return False
    both_describe_blocks = "cache_blocks" in cached and "cache_blocks" in expected
    if not both_describe_blocks:
        return False
    return _cache_problem_metadata_matches(cached, expected)


def _cache_metadata_mismatch_reasons(
    cached: dict[str, object] | None,
    expected: dict[str, object],
) -> tuple[str, ...]:
    """List, in sorted key order, which metadata entries differ.

    A missing cached payload yields ``("missing metadata",)``. Known keys are
    reported with a readable label; any other key is reported verbatim.
    """
    if cached is None:
        return ("missing metadata",)

    have = _cache_identity_metadata(cached)
    want = _cache_identity_metadata(expected)
    readable = {
        "sample_count": "sample count",
        "sample_signature": "sample geometry or targets",
        "layout_size": "parameter count",
        "layout": "model parameter layout",
        "cache_blocks": "semantic cache block layout",
        "fit_energy": "energy target setting",
        "fit_forces": "force target setting",
        "fit_per_atom_energy": "per-atom energy target setting",
        "dtype": "dtype",
        "batch_size": "batch size",
    }
    differing = (
        name
        for name in sorted(have.keys() | want.keys())
        if have.get(name) != want.get(name)
    )
    return tuple(readable.get(name, name) for name in differing)


def _assembled_cache_metadata_for_fit(
    *,
    layout: ParameterLayout,
    samples: Sequence[FitSample],
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
    dtype: torch.dtype,
    batch_size: int,
) -> dict[str, object]:
    """Build assembled-batch cache metadata that ignores target weights.

    The fit metadata is computed with sample weights excluded from the
    digest, then tagged with the assembled-component format identifier.
    """
    payload = _cache_metadata_for_fit(
        layout=layout,
        samples=samples,
        fit_energy=fit_energy,
        fit_forces=fit_forces,
        fit_per_atom_energy=fit_per_atom_energy,
        dtype=dtype,
        batch_size=batch_size,
        include_sample_weights=False,
    )
    payload.update(assembled_components="column_chunked_per_atom_energy_rows_v1")
    return payload


def _normal_equation_cache_metadata_for_fit(
    *,
    layout: ParameterLayout,
    samples: Sequence[FitSample],
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
    dtype: torch.dtype,
    batch_size: int,
) -> dict[str, object]:
    """Build metadata for the split, unweighted normal-equation cache.

    The fit metadata is computed with sample weights excluded from the
    digest, then tagged with the normal-equation component format identifier.
    """
    payload = _cache_metadata_for_fit(
        layout=layout,
        samples=samples,
        fit_energy=fit_energy,
        fit_forces=fit_forces,
        fit_per_atom_energy=fit_per_atom_energy,
        dtype=dtype,
        batch_size=batch_size,
        include_sample_weights=False,
    )
    payload.update(normal_equation_components="per_atom_energy_force_split_v1")
    return payload


def assembled_cache_dir(
    parent: Path | str,
    metadata: dict[str, object],
) -> Path:
    """Return the ``assembled`` cache directory under ``parent`` for ``metadata``.

    The directory is derived by ``settings_cache_dir`` from the identity copy
    of ``metadata``.
    """
    identity = _cache_identity_metadata(metadata)
    return settings_cache_dir(parent, "assembled", identity)


def normal_equation_cache_dir(
    parent: Path | str,
    metadata: dict[str, object],
) -> Path:
    """Return the ``normal_equations`` cache directory under ``parent``.

    The directory is derived by ``settings_cache_dir`` from the identity copy
    of ``metadata``.
    """
    identity = _cache_identity_metadata(metadata)
    return settings_cache_dir(parent, "normal_equations", identity)


def _uniform_target_weight(
    samples: Sequence[FitSample],
    *,
    target_name: str,
    value_name: str,
    weight_name: str,
    enabled: bool,
) -> float:
    """Return the single weight shared by all samples carrying a target.

    Only samples whose ``value_name`` attribute is not ``None`` are considered.
    The result is ``0.0`` when the target is disabled or no sample carries it.

    Raises:
        ValueError: If the active samples do not all share the same weight
            (compared with ``numpy.isclose`` at 1e-12 tolerances).
    """
    if not enabled:
        return 0.0

    active_weights = []
    for sample in samples:
        if getattr(sample, value_name) is None:
            continue
        active_weights.append(float(getattr(sample, weight_name)))
    if not active_weights:
        return 0.0

    reference = active_weights[0]
    for candidate in active_weights:
        if not np.isclose(candidate, reference, rtol=1.0e-12, atol=1.0e-12):
            raise ValueError(
                "normal-equation caching requires uniform "
                f"{target_name} weights across active samples"
            )
    return reference


def _normal_equation_target_weights(
    samples: Sequence[FitSample],
    *,
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
) -> tuple[float, float]:
    """Return the ``(energy, force)`` weights used with cached normal equations.

    Raises:
        ValueError: If per-atom energy fitting is requested and any sample
            carries per-atom energies, or if either target's weights are not
            uniform across the samples that carry it.
    """
    if fit_per_atom_energy:
        for sample in samples:
            if sample.per_atom_energy is not None:
                raise ValueError(
                    "normal-equation caching currently supports energy and force "
                    "targets only"
                )

    energy_weight = _uniform_target_weight(
        samples,
        target_name="energy",
        value_name="energy",
        weight_name="energy_weight",
        enabled=fit_energy,
    )
    force_weight = _uniform_target_weight(
        samples,
        target_name="force",
        value_name="forces",
        weight_name="force_weight",
        enabled=fit_forces,
    )
    return energy_weight, force_weight


def _samples_with_unit_target_weights(
    samples: Sequence[FitSample],
) -> tuple[FitSample, ...]:
    """Copy each sample with all three target weights reset to ``1.0``.

    Copies are made with ``dataclasses.replace``, so the input samples and
    every other field are left untouched.
    """
    unit_weights = {
        "energy_weight": 1.0,
        "force_weight": 1.0,
        "per_atom_weight": 1.0,
    }
    return tuple(replace(sample, **unit_weights) for sample in samples)


@dataclass
class NormalEquationComponents:
    """Split normal-equation pieces kept separately for energy and force rows.

    The Gram matrices, right-hand sides, and target norms are stored without
    target weights applied, alongside the number of rows behind each part.
    """

    # Energy-row and force-row contributions, stored as independent tensors.
    energy_gram: torch.Tensor
    energy_rhs: torch.Tensor
    force_gram: torch.Tensor
    force_rhs: torch.Tensor
    # Scalar tensors; they are persisted to the manifest through ``.item()``.
    energy_target_norm: torch.Tensor
    force_target_norm: torch.Tensor
    # Row counts behind each contribution.
    n_energy_rows: int
    n_force_rows: int

    @property
    def n_rows(self) -> int:
        """Return the combined count of energy and force rows."""
        total = self.n_energy_rows + self.n_force_rows
        return int(total)


def _normal_equation_components_to(
    components: NormalEquationComponents,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> NormalEquationComponents:
    """Return new components with every tensor converted to ``dtype``/``device``.

    Row counts are carried over unchanged.
    """

    def move(tensor: torch.Tensor) -> torch.Tensor:
        return tensor.to(dtype=dtype, device=device)

    return NormalEquationComponents(
        energy_gram=move(components.energy_gram),
        energy_rhs=move(components.energy_rhs),
        force_gram=move(components.force_gram),
        force_rhs=move(components.force_rhs),
        energy_target_norm=move(components.energy_target_norm),
        force_target_norm=move(components.force_target_norm),
        n_energy_rows=components.n_energy_rows,
        n_force_rows=components.n_force_rows,
    )


def _write_npy_memmap(path: Path, tensor: torch.Tensor) -> None:
    """Write ``tensor`` to ``path`` as a ``.npy`` file via the atomic writer.

    The tensor is detached and copied to the CPU before being handed to
    ``atomic_write_npy_memmap`` as a numpy array.
    """
    host_array = tensor.detach().cpu().numpy()
    atomic_write_npy_memmap(path, host_array)


def save_assembled_batches_memmap(
    directory: Path | str,
    assembled_batches: Sequence[AssembledBatch],
    *,
    manifest_name: str = "assembled_batches.json",
    metadata: dict[str, object] | None = None,
    column_chunk_sizes: dict[object, int] | None = None,
    compact: bool = True,
) -> None:
    """
    Persist assembled least-squares batches as block-separated ``.npy`` files.

    The cache stores each batch target and each term-block matrix separately.
    This keeps the existing block layout visible on disk and avoids
    materializing one global dense design matrix.

    Args:
        directory: Directory where cache files should be written.
        assembled_batches: Batches returned by ``assemble_true_blocks``.
        manifest_name: JSON manifest filename inside ``directory``.
        metadata: Optional metadata copied into the manifest.
        column_chunk_sizes: Optional column chunk sizes keyed by cache block key.
        compact: Whether dense block matrices should be compacted during writing.
    """
    cache_dir = Path(directory)
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Each batch is written (with its own per-batch manifest) before the
    # combined manifest, which is only produced once every batch succeeded.
    manifest_batches = [
        _write_assembled_batch_memmap(
            cache_dir,
            batch_index,
            batch,
            metadata=metadata,
            column_chunk_sizes=column_chunk_sizes,
            compact=compact,
        )
        for batch_index, batch in enumerate(assembled_batches)
    ]
    _write_assembled_batches_manifest(
        cache_dir,
        manifest_batches,
        manifest_name=manifest_name,
        metadata=metadata,
    )
def _write_assembled_batch_memmap(
    cache_dir: Path,
    batch_index: int,
    batch: AssembledBatch,
    *,
    metadata: dict[str, object] | None = None,
    column_chunk_sizes: dict[object, int] | None = None,
    compact: bool = True,
) -> dict[str, object]:
    """Persist one assembled batch and return its manifest entry.

    Blocks are visited in order of their stringified key. Already-indexed
    matrices are written in their own storage format; other matrices are
    materialized if needed and then stored column-chunked, row-compacted, or
    dense, in that order of preference. A per-batch manifest is written last.
    """
    target_name = f"batch{batch_index}_target.npy"
    _write_npy_memmap(cache_dir / target_name, batch.target)
    block_files: dict[str, str | dict[str, object]] = {}
    for block_number, (block_key, matrix) in enumerate(
        sorted(batch.block_matrices.items(), key=lambda item: str(item[0]))
    ):
        manifest_key = str(block_key)
        block_token = f"block{block_number}"
        if isinstance(matrix, ColumnRowIndexedBlockMatrix):
            block_files[manifest_key] = _write_column_row_indexed_block(
                cache_dir,
                batch_index,
                block_token,
                matrix,
            )
            continue
        if isinstance(matrix, RowIndexedBlockMatrix):
            block_files[manifest_key] = _write_row_indexed_block(
                cache_dir,
                batch_index,
                block_token,
                matrix,
            )
            continue
        if not isinstance(matrix, torch.Tensor):
            matrix = _materialize_block_matrix(matrix)
        # Chunk sizes may be keyed by the original block key or its string form.
        chunk_size = (
            None
            if column_chunk_sizes is None
            else column_chunk_sizes.get(
                block_key, column_chunk_sizes.get(manifest_key)
            )
        )
        compact_chunks = None
        if compact and chunk_size is not None:
            compact_chunks = _compact_column_chunked_block_matrix(
                matrix,
                chunk_size=chunk_size,
            )
        if compact_chunks is not None:
            block_files[manifest_key] = _write_column_row_indexed_block(
                cache_dir,
                batch_index,
                block_token,
                compact_chunks,
            )
            continue
        if compact:
            compact_rows = _compact_block_matrix(matrix)
            # Row-indexed storage is only used when it keeps fewer rows.
            if compact_rows.rows.numel() < matrix.shape[0]:
                block_files[manifest_key] = _write_row_indexed_block(
                    cache_dir,
                    batch_index,
                    block_token,
                    compact_rows,
                )
                continue
        matrix_name = f"batch{batch_index}_{block_token}.npy"
        _write_npy_memmap(cache_dir / matrix_name, matrix)
        block_files[manifest_key] = matrix_name
    entry = {
        "target": target_name,
        "blocks": block_files,
    }
    _write_assembled_batch_manifest(
        cache_dir,
        batch_index,
        entry,
        metadata=metadata,
    )
    return entry


def _write_row_indexed_block(
    cache_dir: Path,
    batch_index: int,
    block_token: str,
    matrix: RowIndexedBlockMatrix,
) -> dict[str, object]:
    """Persist one row-indexed block matrix and return its manifest entry."""
    row_name = f"batch{batch_index}_{block_token}_rows.npy"
    value_name = f"batch{batch_index}_{block_token}_values.npy"
    _write_npy_memmap(cache_dir / row_name, matrix.rows.to(dtype=torch.int64))
    _write_npy_memmap(cache_dir / value_name, matrix.values)
    return {
        "storage": "row_indexed",
        "rows": row_name,
        "values": value_name,
        "n_rows": int(matrix.n_rows),
    }


def _write_column_row_indexed_block(
    cache_dir: Path,
    batch_index: int,
    block_token: str,
    matrix: ColumnRowIndexedBlockMatrix,
) -> dict[str, object]:
    """Persist one column-row indexed block matrix and return its manifest entry.

    Every chunk gets its own ``rows``/``values`` file pair; the chunk's column
    count is taken from the second dimension of its values.
    """
    chunk_entries = []
    for chunk_index, chunk in enumerate(matrix.chunks):
        row_name = f"batch{batch_index}_{block_token}_chunk{chunk_index}_rows.npy"
        value_name = f"batch{batch_index}_{block_token}_chunk{chunk_index}_values.npy"
        _write_npy_memmap(cache_dir / row_name, chunk.rows.to(dtype=torch.int64))
        _write_npy_memmap(cache_dir / value_name, chunk.values)
        chunk_entries.append(
            {
                "column_start": int(chunk.column_start),
                "n_columns": int(chunk.values.shape[1]),
                "rows": row_name,
                "values": value_name,
            }
        )
    return {
        "storage": "column_row_indexed",
        "chunks": chunk_entries,
        "n_rows": int(matrix.n_rows),
        "n_cols": int(matrix.n_cols),
    }


def _write_assembled_batch_manifest(
    cache_dir: Path,
    batch_index: int,
    entry: dict[str, object],
    *,
    metadata: dict[str, object] | None = None,
) -> None:
    """Write one completed-batch manifest for resumable cache assembly."""
    manifest = {
        "schema_version": _ASSEMBLED_CACHE_SCHEMA_VERSION,
        "batch_index": int(batch_index),
        "entry": dict(entry),
        "metadata": {} if metadata is None else metadata,
    }
    atomic_write_json(cache_dir / f"batch{batch_index}_manifest.json", manifest)


def _cached_block_files_exist(
    cache_dir: Path,
    block_entry: object,
    *,
    manifest_path: Path,
) -> bool:
    """Return whether all files referenced by one cached block exist.

    Returns ``False`` as soon as a referenced file is missing.

    Raises:
        ValueError: If the entry is neither a filename nor a dict with a known
            ``storage`` kind, or if its size fields are missing or invalid.
    """
    if isinstance(block_entry, str):
        return (cache_dir / block_entry).is_file()
    if not isinstance(block_entry, dict):
        raise ValueError(f"invalid least-squares batch manifest: {manifest_path}")
    storage = str(block_entry.get("storage"))
    if storage == "row_indexed":
        for field in ("rows", "values"):
            if not (cache_dir / str(block_entry.get(field))).is_file():
                return False
        if int(block_entry.get("n_rows", -1)) < 0:
            raise ValueError(f"invalid least-squares batch manifest: {manifest_path}")
        return True
    if storage == "column_row_indexed":
        if (
            int(block_entry.get("n_rows", -1)) < 0
            or int(block_entry.get("n_cols", -1)) < 0
        ):
            raise ValueError(f"invalid least-squares batch manifest: {manifest_path}")
        chunks = block_entry.get("chunks")
        if not isinstance(chunks, list):
            raise ValueError(f"invalid least-squares batch manifest: {manifest_path}")
        for chunk in chunks:
            if not isinstance(chunk, dict):
                raise ValueError(
                    f"invalid least-squares batch manifest: {manifest_path}"
                )
            if int(chunk.get("column_start", -1)) < 0:
                raise ValueError(
                    f"invalid least-squares batch manifest: {manifest_path}"
                )
            if int(chunk.get("n_columns", -1)) <= 0:
                raise ValueError(
                    f"invalid least-squares batch manifest: {manifest_path}"
                )
            for field in ("rows", "values"):
                if not (cache_dir / str(chunk.get(field))).is_file():
                    return False
        return True
    raise ValueError(f"invalid least-squares batch manifest: {manifest_path}")


def _load_assembled_batch_manifest(
    cache_dir: Path,
    batch_index: int,
    *,
    expected_metadata: dict[str, object],
) -> dict[str, object] | None:
    """Load one completed-batch manifest if it matches the expected fit.

    Returns ``None`` when the manifest or any file it references is missing,
    otherwise the manifest's ``entry`` dict.

    Raises:
        ValueError: If the schema version, batch index, or metadata does not
            match, or if the manifest is structurally invalid.
    """
    manifest_path = cache_dir / f"batch{batch_index}_manifest.json"
    if not manifest_path.is_file():
        return None
    with manifest_path.open("r", encoding="utf8") as handle:
        manifest = json.load(handle)
    if int(manifest.get("schema_version", -1)) != _ASSEMBLED_CACHE_SCHEMA_VERSION:
        raise ValueError(
            "unsupported least-squares batch cache schema version: "
            f"{manifest.get('schema_version')}"
        )
    if int(manifest.get("batch_index", -1)) != int(batch_index):
        raise ValueError(
            f"least-squares batch manifest {manifest_path} has the wrong index"
        )
    if not _cache_metadata_matches(manifest.get("metadata"), expected_metadata):
        raise ValueError(
            "least-squares batch cache metadata does not match the requested "
            "samples, targets, dtype, or model layout"
        )
    entry = manifest.get("entry")
    if not isinstance(entry, dict):
        raise ValueError(f"invalid least-squares batch manifest: {manifest_path}")
    if not (cache_dir / str(entry.get("target"))).is_file():
        return None
    blocks = entry.get("blocks")
    if not isinstance(blocks, dict):
        raise ValueError(f"invalid least-squares batch manifest: {manifest_path}")
    for block_entry in blocks.values():
        if not _cached_block_files_exist(
            cache_dir,
            block_entry,
            manifest_path=manifest_path,
        ):
            return None
    return entry


def _matching_assembled_batch_cache_size(
    cache_dir: Path,
    *,
    expected_metadata: dict[str, object],
) -> int | None:
    """Return the batch size used by matching resumable batch manifests.

    This is a best-effort scan over ``batch*_manifest.json`` files: manifests
    that cannot be read or decoded, are not JSON objects, carry a non-integer
    or different schema version, or do not match ``expected_metadata`` are
    skipped rather than raising. Returns ``None`` when no manifest qualifies.
    """
    for manifest_path in sorted(cache_dir.glob("batch*_manifest.json")):
        try:
            with manifest_path.open("r", encoding="utf8") as handle:
                manifest = json.load(handle)
        except (OSError, ValueError):
            # ValueError covers both JSON and text-decoding failures.
            continue
        if not isinstance(manifest, dict):
            continue
        try:
            schema_version = int(manifest.get("schema_version", -1))
        except (TypeError, ValueError):
            continue
        if schema_version != _ASSEMBLED_CACHE_SCHEMA_VERSION:
            continue
        metadata = manifest.get("metadata")
        # Without a metadata object there is no batch size to report.
        if not isinstance(metadata, dict):
            continue
        if not _cache_metadata_matches(metadata, expected_metadata):
            continue
        try:
            batch_size = int(metadata["batch_size"])
        except (KeyError, TypeError, ValueError):
            continue
        if batch_size > 0:
            return batch_size
    return None


def _write_assembled_batches_manifest(
    cache_dir: Path,
    manifest_batches: Sequence[dict[str, object]],
    *,
    manifest_name: str = "assembled_batches.json",
    metadata: dict[str, object] | None = None,
) -> None:
    """Write the assembled-batch cache manifest and its settings summary."""
    manifest = {
        "schema_version": _ASSEMBLED_CACHE_SCHEMA_VERSION,
        "n_batches": len(manifest_batches),
        "batches": list(manifest_batches),
        "metadata": {} if metadata is None else metadata,
    }
    atomic_write_json(cache_dir / manifest_name, manifest)
    write_cache_settings_summary(
        cache_dir,
        cache_kind="least_squares_assembled_batches",
        prefix="assembled",
        settings=manifest["metadata"],
        extra={
            "cache_schema_version": _ASSEMBLED_CACHE_SCHEMA_VERSION,
            "manifest": manifest_name,
        },
    )


def _normal_equation_cache_exists(cache_directory: Path | str) -> bool:
    """Return whether a split normal-equation cache manifest exists."""
    return (Path(cache_directory) / "normal_equations.json").is_file()


def _write_normal_equation_cache(
    cache_directory: Path | str,
    components: NormalEquationComponents,
    *,
    metadata: dict[str, object],
) -> None:
    """Persist split normal-equation components as four matrix/vector arrays.

    The arrays are written first; the manifest (row counts, target norms, and
    metadata) and the settings summary follow.
    """
    cache_dir = Path(cache_directory)
    cache_dir.mkdir(parents=True, exist_ok=True)
    files = {
        "energy_gram": "energy_ata.npy",
        "energy_rhs": "energy_atb.npy",
        "force_gram": "force_ata.npy",
        "force_rhs": "force_atb.npy",
    }
    # Keys of ``files`` are also the attribute names on ``components``.
    for field, filename in files.items():
        _write_npy_memmap(cache_dir / filename, getattr(components, field))
    manifest = {
        "schema_version": _NORMAL_EQUATION_CACHE_SCHEMA_VERSION,
        "files": files,
        "n_energy_rows": int(components.n_energy_rows),
        "n_force_rows": int(components.n_force_rows),
        "energy_target_norm": float(components.energy_target_norm.item()),
        "force_target_norm": float(components.force_target_norm.item()),
        "metadata": metadata,
    }
    atomic_write_json(cache_dir / "normal_equations.json", manifest)
    write_cache_settings_summary(
        cache_dir,
        cache_kind="least_squares_normal_equations",
        prefix="normal_equations",
        settings=metadata,
        extra={
            "cache_schema_version": _NORMAL_EQUATION_CACHE_SCHEMA_VERSION,
            "manifest": "normal_equations.json",
        },
    )


def _load_assembled_batches_manifest(
    directory: Path | str,
    *,
    manifest_name: str = "assembled_batches.json",
) -> dict[str, object]:
    """Load and validate an assembled-batch cache manifest.

    When ``directory`` has no manifest of its own, a manifest found in exactly
    one immediate child directory is used instead. The directory that holds
    the manifest is recorded under the ``_cache_dir`` key of the result.

    Raises:
        FileNotFoundError: If no manifest (or more than one child manifest)
            is found.
        ValueError: If the manifest schema version is not supported.
    """
    cache_dir = Path(directory)
    manifest_path = cache_dir / manifest_name
    if not manifest_path.is_file():
        child_manifests = sorted(cache_dir.glob(f"*/{manifest_name}"))
        if len(child_manifests) == 1:
            manifest_path = child_manifests[0]
        else:
            raise FileNotFoundError(
                f"least-squares cache manifest not found: {manifest_path}"
            )
    with manifest_path.open("r", encoding="utf8") as handle:
        manifest = json.load(handle)
    if int(manifest.get("schema_version", -1)) != _ASSEMBLED_CACHE_SCHEMA_VERSION:
        raise ValueError(
            "unsupported least-squares cache schema version: "
            f"{manifest.get('schema_version')}"
        )
    manifest["_cache_dir"] = str(manifest_path.parent)
    return manifest


def _load_normal_equation_cache(
    cache_directory: Path | str,
    *,
    dtype: torch.dtype,
    device: torch.device,
    expected_metadata: dict[str, object],
) -> NormalEquationComponents:
    """Load split normal-equation components from a validated cache.

    Raises:
        FileNotFoundError: If the manifest or any component file is missing.
        ValueError: If the schema version or metadata does not match, or the
            manifest has no ``files`` mapping.
    """
    cache_dir = Path(cache_directory)
    manifest_path = cache_dir / "normal_equations.json"
    if not manifest_path.is_file():
        raise FileNotFoundError(
            f"normal-equation cache manifest not found: {manifest_path}"
        )
    with manifest_path.open("r", encoding="utf8") as handle:
        manifest = json.load(handle)
    if int(manifest.get("schema_version", -1)) != _NORMAL_EQUATION_CACHE_SCHEMA_VERSION:
        raise ValueError(
            "unsupported normal-equation cache schema version: "
            f"{manifest.get('schema_version')}"
        )
    if not _cache_metadata_matches(manifest.get("metadata"), expected_metadata):
        raise ValueError(
            "normal-equation cache metadata does not match the requested samples, "
            "targets, dtype, or model layout"
        )
    files = manifest.get("files")
    if not isinstance(files, dict):
        raise ValueError(f"invalid normal-equation cache manifest: {manifest_path}")
    for field in ("energy_gram", "energy_rhs", "force_gram", "force_rhs"):
        missing_path = cache_dir / str(files.get(field))
        if not missing_path.is_file():
            raise FileNotFoundError(
                f"normal-equation cache file not found: {missing_path}"
            )

    def load_component(field: str) -> torch.Tensor:
        # NOTE(review): "r+" maps the cache file writable, so an in-place
        # update of the returned tensor could write through to disk when no
        # dtype/device conversion copies it — confirm callers never mutate.
        return _load_cached_tensor(
            cache_dir / str(files[field]),
            mmap_mode="r+",
            dtype=dtype,
            device=device,
        )

    return NormalEquationComponents(
        energy_gram=load_component("energy_gram"),
        energy_rhs=load_component("energy_rhs"),
        force_gram=load_component("force_gram"),
        force_rhs=load_component("force_rhs"),
        energy_target_norm=torch.tensor(
            float(manifest.get("energy_target_norm", 0.0)),
            dtype=dtype,
            device=device,
        ),
        force_target_norm=torch.tensor(
            float(manifest.get("force_target_norm", 0.0)),
            dtype=dtype,
            device=device,
        ),
        n_energy_rows=int(manifest.get("n_energy_rows", 0)),
        n_force_rows=int(manifest.get("n_force_rows", 0)),
    )


def _load_cached_tensor(
    path: Path,
    *,
    mmap_mode: Literal["r", "r+", "c"] | None,
    dtype: torch.dtype | None,
    device: torch.device | None,
) -> torch.Tensor:
    """Load one cached ``.npy`` array as a torch tensor.

    ``mmap_mode`` is passed to ``numpy.load``. The tensor is converted only
    when a ``dtype`` or ``device`` is given.
    """
    array = np.load(path, mmap_mode=mmap_mode)
    tensor = torch.as_tensor(array)
    if dtype is not None or device is not None:
        tensor = tensor.to(dtype=dtype, device=device)
    return tensor


def _load_cached_block_matrix(
    cache_dir: Path,
    entry: object,
    *,
    mmap_mode: Literal["r", "r+", "c"] | None,
    dtype: torch.dtype | None,
    device: torch.device | None,
    row_indexed_blocks: bool,
) -> BlockMatrix:
    """Load one dense, row-indexed, or column-row indexed cached block matrix.

    A string entry names a dense ``.npy`` file. Dict entries describe indexed
    storage; they are returned as indexed matrices when ``row_indexed_blocks``
    is true and materialized otherwise. Row indices are always loaded as
    ``int64``.

    Raises:
        ValueError: If the entry is neither a string nor a dict with a known
            ``storage`` kind.
    """
    if isinstance(entry, str):
        return _load_cached_tensor(
            cache_dir / entry,
            mmap_mode=mmap_mode,
            dtype=dtype,
            device=device,
        )
    if not isinstance(entry, dict):
        raise ValueError("invalid least-squares cache block entry")
    storage = str(entry.get("storage"))
    if storage == "row_indexed":
        rows = _load_cached_tensor(
            cache_dir / str(entry["rows"]),
            mmap_mode=mmap_mode,
            dtype=torch.int64,
            device=device,
        )
        values = _load_cached_tensor(
            cache_dir / str(entry["values"]),
            mmap_mode=mmap_mode,
            dtype=dtype,
            device=device,
        )
        block = RowIndexedBlockMatrix(
            rows=rows,
            values=values,
            n_rows=int(entry["n_rows"]),
        )
        if row_indexed_blocks:
            return block
        return block.materialize()
    if storage == "column_row_indexed":
        chunks = []
        for chunk in entry["chunks"]:
            rows = _load_cached_tensor(
                cache_dir / str(chunk["rows"]),
                mmap_mode=mmap_mode,
                dtype=torch.int64,
                device=device,
            )
            values = _load_cached_tensor(
                cache_dir / str(chunk["values"]),
                mmap_mode=mmap_mode,
                dtype=dtype,
                device=device,
            )
            chunks.append(
                ColumnRowIndexedChunk(
                    column_start=int(chunk["column_start"]),
                    rows=rows,
                    values=values,
                )
            )
        block = ColumnRowIndexedBlockMatrix(
            chunks=tuple(chunks),
            n_rows=int(entry["n_rows"]),
            n_cols=int(entry["n_cols"]),
        )
        if row_indexed_blocks:
            return block
        return block.materialize()
    raise ValueError("invalid least-squares cache block entry")


def _load_assembled_batch_entry_memmap(
    cache_dir: Path,
    entry: object,
    *,
    mmap_mode: Literal["r", "r+", "c"] | None,
    dtype: torch.dtype | None,
    device: torch.device | None,
    row_indexed_blocks: bool,
    row_weights: torch.Tensor | None = None,
) -> AssembledBatch:
    """Load one assembled-batch manifest entry from a disk cache.

    Blocks are loaded in order of their stringified key and stored under
    string keys. ``row_weights``, when given, are applied to the loaded batch.

    Raises:
        ValueError: If ``entry`` is not a dict.
    """
    if not isinstance(entry, dict):
        raise ValueError("invalid least-squares cache batch entry")
    target = _load_cached_tensor(
        cache_dir / str(entry["target"]),
        mmap_mode=mmap_mode,
        dtype=dtype,
        device=device,
    )
    block_matrices = {
        str(block_key): _load_cached_block_matrix(
            cache_dir,
            block_entry,
            mmap_mode=mmap_mode,
            dtype=dtype,
            device=device,
            row_indexed_blocks=row_indexed_blocks,
        )
        for block_key, block_entry in sorted(
            entry["blocks"].items(), key=lambda item: str(item[0])
        )
    }
    batch = AssembledBatch(
        target=target,
        block_matrices=block_matrices,
    )
    if row_weights is None:
        return batch
    return _apply_row_weights_to_assembled_batch(batch, row_weights)
def load_assembled_batches_memmap(
    directory: Path | str,
    *,
    manifest_name: str = "assembled_batches.json",
    mmap_mode: Literal["r", "r+", "c"] | None = "r+",
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
    expected_metadata: dict[str, object] | None = None,
    row_indexed_blocks: bool = False,
    row_weights: Sequence[torch.Tensor] | None = None,
) -> tuple[AssembledBatch, ...]:
    """
    Load block-separated assembled batches from a disk cache.

    Args:
        directory: Directory containing a cache manifest and ``.npy`` files.
        manifest_name: JSON manifest filename inside ``directory``.
        mmap_mode: Mode passed to ``numpy.load``. Use ``None`` to load ordinary
            arrays.
        dtype: Optional dtype conversion for loaded tensors.
        device: Optional device conversion for loaded tensors.
        expected_metadata: Optional metadata payload that must match the cache.
        row_indexed_blocks: Whether to preserve sparse row-indexed block storage.
        row_weights: Optional row weights applied to loaded batches.

    Returns:
        Assembled least-squares batches.

    Raises:
        ValueError: If cache metadata or row-weight counts do not match.
    """
    cache_dir = Path(directory)
    manifest = _load_assembled_batches_manifest(
        cache_dir,
        manifest_name=manifest_name,
    )
    if expected_metadata is not None and not _cache_metadata_matches(
        manifest.get("metadata"),
        expected_metadata,
    ):
        raise ValueError(
            "least-squares cache metadata does not match the requested samples, "
            "targets, dtype, or model layout"
        )
    # The manifest loader records the directory the manifest was found in,
    # which can differ from ``directory``; data files are read from there.
    cache_dir = Path(str(manifest.get("_cache_dir", cache_dir)))
    manifest_batches = tuple(manifest.get("batches", ()))
    if row_weights is not None and len(row_weights) != len(manifest_batches):
        raise ValueError(
            "number of least-squares cache weight batches does not match manifest"
        )
    assembled_batches = []
    for batch_index, batch in enumerate(manifest_batches):
        weight = None if row_weights is None else row_weights[batch_index]
        assembled_batches.append(
            _load_assembled_batch_entry_memmap(
                cache_dir,
                batch,
                mmap_mode=mmap_mode,
                dtype=dtype,
                device=device,
                row_indexed_blocks=row_indexed_blocks,
                row_weights=weight,
            )
        )
    return tuple(assembled_batches)
def _assembled_cache_manifest_exists(
    cache_directory: Path | str,
    *,
    manifest_name: str = "assembled_batches.json",
) -> bool:
    """Tell whether ``manifest_name`` is an existing file in ``cache_directory``."""
    manifest_path = Path(cache_directory).joinpath(manifest_name)
    return manifest_path.is_file()


class _CachedBlockBatchSequence:
    """Sequence-like view that loads cached solve batches one at a time.

    Data batches are produced by calling the configured loader for a manifest
    entry; an optional in-memory pseudo-batch (``repulsion_batch``) is exposed
    as the final element.
    """

    def __init__(
        self,
        *,
        cache_dir: Path,
        entries: Sequence[object],
        dtype: torch.dtype,
        device: torch.device,
        row_weights: Sequence[torch.Tensor] | None,
        repulsion_batch: BlockSolveBatch | None,
        load_batch_entry: Callable[..., AssembledBatch] = (
            _load_assembled_batch_entry_memmap
        ),
        mmap_mode: Literal["r", "r+", "c"] | None = "r+",
        row_indexed_blocks: bool = True,
    ) -> None:
        """Record everything needed to load any single batch later.

        Raises:
            ValueError: If ``row_weights`` is given but its length differs
                from the number of manifest entries.
        """
        if row_weights is not None:
            if len(row_weights) != len(entries):
                raise ValueError(
                    "number of least-squares cache weight batches does not match manifest"
                )
            stored_weights = tuple(row_weights)
        else:
            stored_weights = None

        self._cache_dir = cache_dir
        self._entries = tuple(entries)
        self._dtype = dtype
        self._device = device
        self._row_weights = stored_weights
        self._repulsion_batch = repulsion_batch
        self._load_batch_entry = load_batch_entry
        self._mmap_mode = mmap_mode
        self._row_indexed_blocks = row_indexed_blocks

    def __len__(self) -> int:
        """Count cached data batches, plus one if a pseudo-batch is present."""
        count = len(self._entries)
        if self._repulsion_batch is not None:
            count += 1
        return count

    def __iter__(self):
        """Yield each data batch (loaded on demand), then the pseudo-batch if any."""
        for position, _entry in enumerate(self._entries):
            yield self._load_data_batch(position)
        trailing = self._repulsion_batch
        if trailing is not None:
            yield trailing

    def __getitem__(self, index):
        """Return the batch at ``index``; a slice yields a tuple of batches."""
        if isinstance(index, slice):
            positions = range(*index.indices(len(self)))
            return tuple(self[position] for position in positions)

        position = int(index)
        size = len(self)
        if position < 0:
            position += size
        if not 0 <= position < size:
            raise IndexError("cached least-squares batch index out of range")

        if position != len(self._entries):
            return self._load_data_batch(position)
        # Only reachable when the pseudo-batch contributed to ``len(self)``.
        assert self._repulsion_batch is not None
        return self._repulsion_batch

    def _load_data_batch(self, index: int) -> BlockSolveBatch:
        """Load cached data batch ``index`` with its row weights and convert it."""
        weights = self._row_weights
        assembled = self._load_batch_entry(
            self._cache_dir,
            self._entries[index],
            mmap_mode=self._mmap_mode,
            dtype=self._dtype,
            device=self._device,
            row_indexed_blocks=self._row_indexed_blocks,
            row_weights=None if weights is None else weights[index],
        )
        return _block_solve_batch_from_assembled(assembled)
class CachedBlockLinearProblem:
    """Block least-squares problem whose data batches are streamed from a disk cache.

    Batches are never held all at once by this class: every operator pass
    iterates the lazy batch sequence, which loads one cached batch at a time.
    """

    def __init__(
        self,
        *,
        layout: BlockProblemLayout,
        cache_dir: Path,
        entries: Sequence[object],
        dtype: torch.dtype,
        device: torch.device,
        row_weights: Sequence[torch.Tensor],
        repulsion_batch: BlockSolveBatch | None = None,
        load_batch_entry: Callable[..., AssembledBatch] = (
            _load_assembled_batch_entry_memmap
        ),
        row_indexed_blocks: bool = True,
    ) -> None:
        """Set up the lazy batch sequence without loading any batch.

        Per-batch row counts are taken from the sizes of ``row_weights``.

        Raises:
            ValueError: If the number of weight tensors differs from the
                number of manifest entries.
        """
        row_counts = tuple(int(weight.numel()) for weight in row_weights)
        if len(row_counts) != len(entries):
            raise ValueError(
                "number of least-squares cache weight batches does not match manifest"
            )

        self.layout = layout
        self._dtype = dtype
        self._device = device
        self._batch_row_counts = row_counts
        self._repulsion_batch = repulsion_batch
        # Progress reporting is only switched on for the duration of a CG solve.
        self._stream_progress = False
        self._stream_pass = 0
        self._batches = _CachedBlockBatchSequence(
            cache_dir=cache_dir,
            entries=entries,
            dtype=dtype,
            device=device,
            row_weights=row_weights,
            repulsion_batch=repulsion_batch,
            load_batch_entry=load_batch_entry,
            row_indexed_blocks=bool(row_indexed_blocks),
        )

    @property
    def batches(self) -> _CachedBlockBatchSequence:
        """Lazy batch sequence; each access to an element loads it from cache."""
        return self._batches

    @property
    def n_rows(self) -> int:
        """Total target rows: cached batches plus the optional pseudo-batch."""
        pseudo = self._repulsion_batch
        extra_rows = 0 if pseudo is None else pseudo.n_rows
        return int(sum(self._batch_row_counts) + extra_rows)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype requested for streamed batches."""
        return self._dtype

    @property
    def device(self) -> torch.device:
        """Device requested for streamed batches."""
        return self._device

    def _iter_stream_batches(self, description: str):
        """Iterate all batches, wrapped in a progress display when enabled."""
        sequence = self.batches
        return _iter_with_progress(
            sequence,
            enabled=self._stream_progress,
            description=description,
            total=len(sequence),
        )

    def _regularized_blocks(self):
        """Yield the layout blocks that carry a regularizer, in layout order."""
        for block in self.layout.blocks:
            if block.regularization is not None:
                yield block

    def _batch_prediction(self, batch, theta: torch.Tensor) -> torch.Tensor:
        """Sum the block matvec contributions of one batch, cast like ``theta``."""
        prediction = torch.zeros(
            (batch.n_rows,),
            dtype=theta.dtype,
            device=theta.device,
        )
        for key, block_matrix in batch.matrices.items():
            contribution = _block_matrix_matvec(
                block_matrix,
                theta[self.layout.theta_slice(key)],
            )
            prediction = prediction + contribution.to(
                device=theta.device,
                dtype=theta.dtype,
            )
        return prediction

    def _squared_data_residual(self, theta: torch.Tensor) -> torch.Tensor:
        """Accumulate the squared residual over all batches, one batch at a time."""
        total = torch.zeros((), dtype=theta.dtype, device=theta.device)
        for batch in self.batches:
            prediction = self._batch_prediction(batch, theta)
            target = batch.target.to(device=theta.device, dtype=theta.dtype)
            residual = prediction - target
            total = total + torch.dot(residual, residual)
        return total

    def normal_matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply the regularized normal operator batch by batch."""
        self._stream_pass += 1
        theta = theta.reshape(self.layout.size)
        accumulated = torch.zeros(
            (self.layout.size,),
            dtype=theta.dtype,
            device=theta.device,
        )
        label = f"Streaming CG matvec {self._stream_pass}"
        for batch in self._iter_stream_batches(label):
            prediction = self._batch_prediction(batch, theta)
            for key, block_matrix in batch.matrices.items():
                back_projection = _block_matrix_rmatvec(block_matrix, prediction)
                accumulated[self.layout.theta_slice(key)] += back_projection.to(
                    device=theta.device,
                    dtype=theta.dtype,
                )
        return accumulated + self.regularization_apply(theta)

    def rhs(self) -> torch.Tensor:
        """Stream the batches to build the right-hand side, including regularizer shifts."""
        accumulated = torch.zeros(
            (self.layout.size,),
            dtype=self.dtype,
            device=self.device,
        )
        for batch in self._iter_stream_batches("Streaming cached RHS"):
            for key, block_matrix in batch.matrices.items():
                projected = _block_matrix_rmatvec(block_matrix, batch.target)
                accumulated[self.layout.theta_slice(key)] += projected.to(
                    device=self.device,
                    dtype=self.dtype,
                )
        return accumulated + self.regularization_rhs()

    def normal_equation_diagonal(self) -> torch.Tensor:
        """Stream the batches to accumulate the diagonal of ``A.T @ A``."""
        accumulated = torch.zeros(
            (self.layout.size,),
            dtype=self.dtype,
            device=self.device,
        )
        for batch in self._iter_stream_batches("Streaming normal diagonal"):
            for key, block_matrix in batch.matrices.items():
                span = self.layout.theta_slice(key)
                accumulated[span] += _block_matrix_diagonal(block_matrix)
        return accumulated

    def design_trace_by_block(self) -> dict[object, float]:
        """Sum the normal-equation diagonal over each solve block's parameters."""
        diagonal = self.normal_equation_diagonal()
        traces: dict[object, float] = {}
        for block in self.layout.blocks:
            span = self.layout.theta_slice(block.key)
            traces[block.key] = float(diagonal[span].sum().item())
        return traces

    def target_vector(self) -> torch.Tensor:
        """Concatenate the targets of all batches, loading each batch on demand."""
        if len(self.batches) == 0:
            return torch.zeros(0, dtype=self.dtype, device=self.device)
        pieces = [batch.target for batch in self.batches]
        return torch.cat(pieces, dim=0)

    def materialize_design_matrix(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Assemble the explicit dense design matrix and the stacked targets."""
        dense = torch.zeros(
            (self.n_rows, self.layout.size),
            dtype=self.dtype,
            device=self.device,
        )
        pieces: list[torch.Tensor] = []
        start = 0
        for batch in self.batches:
            stop = start + batch.n_rows
            for key, block_matrix in batch.matrices.items():
                columns = self.layout.theta_slice(key)
                dense[start:stop, columns] = _materialize_block_matrix(block_matrix)
            pieces.append(batch.target)
            start = stop
        if pieces:
            return dense, torch.cat(pieces, dim=0)
        return dense, torch.zeros(0, dtype=self.dtype, device=self.device)

    def matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply the design matrix to ``theta``, one streamed batch at a time."""
        theta = theta.reshape(self.layout.size)

        def predict(batch) -> torch.Tensor:
            # Unlike the normal-operator path, the result follows the batch
            # target's dtype/device and block outputs are not cast.
            rows = torch.zeros(
                (batch.n_rows,),
                dtype=batch.target.dtype,
                device=batch.target.device,
            )
            for key, block_matrix in batch.matrices.items():
                rows = rows + _block_matrix_matvec(
                    block_matrix,
                    theta[self.layout.theta_slice(key)],
                )
            return rows

        pieces = [predict(batch) for batch in self.batches]
        if pieces:
            return torch.cat(pieces, dim=0)
        return torch.zeros(0, dtype=self.dtype, device=self.device)

    def rmatvec(self, residual: torch.Tensor) -> torch.Tensor:
        """Apply the transposed design matrix to a stacked residual vector."""
        residual = residual.reshape(self.n_rows)
        accumulated = torch.zeros(
            (self.layout.size,),
            dtype=residual.dtype,
            device=residual.device,
        )
        start = 0
        for batch in self.batches:
            stop = start + batch.n_rows
            rows = residual[start:stop]
            for key, block_matrix in batch.matrices.items():
                span = self.layout.theta_slice(key)
                accumulated[span] += _block_matrix_rmatvec(block_matrix, rows)
            start = stop
        return accumulated

    def regularization_apply(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply each block's regularizer to its slice of a flat parameter vector."""
        theta = theta.reshape(self.layout.size)
        applied = torch.zeros_like(theta)
        for block in self._regularized_blocks():
            span = self.layout.theta_slice(block.key)
            applied[span] += block.regularization.apply(theta[span])
        return applied

    def regularization_diagonal(self) -> torch.Tensor:
        """Collect the regularizer diagonals into one flat vector."""
        diagonal = torch.zeros(
            (self.layout.size,), dtype=self.dtype, device=self.device
        )
        for block in self._regularized_blocks():
            span = self.layout.theta_slice(block.key)
            diagonal[span] += block.regularization.diagonal(
                dtype=self.dtype,
                device=self.device,
            )
        return diagonal

    def regularization_rhs(self) -> torch.Tensor:
        """Collect the regularizer right-hand-side shifts into one flat vector."""
        shift = torch.zeros((self.layout.size,), dtype=self.dtype, device=self.device)
        for block in self._regularized_blocks():
            span = self.layout.theta_slice(block.key)
            shift[span] += block.regularization.rhs(
                dtype=self.dtype,
                device=self.device,
            )
        return shift

    def accumulate_normal_equations(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the dense Gram matrix and right-hand side from streamed batches."""
        size = self.layout.size
        gram = torch.zeros((size, size), dtype=self.dtype, device=self.device)
        rhs = torch.zeros((size,), dtype=self.dtype, device=self.device)

        for batch in self.batches:
            placed = [
                (self.layout.theta_slice(key), block_matrix)
                for key, block_matrix in batch.matrices.items()
            ]
            for position, (span_i, matrix_i) in enumerate(placed):
                rhs[span_i] += _block_matrix_rmatvec(matrix_i, batch.target)
                gram[span_i, span_i] += _block_matrix_cross(matrix_i, matrix_i)
                # Each off-diagonal pair is computed once and mirrored.
                for span_j, matrix_j in placed[position + 1 :]:
                    coupling = _block_matrix_cross(matrix_i, matrix_j)
                    gram[span_i, span_j] += coupling
                    gram[span_j, span_i] += coupling.T

        for block in self._regularized_blocks():
            span = self.layout.theta_slice(block.key)
            gram[span, span] += block.regularization.materialize(
                dtype=self.dtype,
                device=self.device,
            )
            rhs[span] += block.regularization.rhs(
                dtype=self.dtype,
                device=self.device,
            )
        return gram, rhs

    def objective(self, theta: torch.Tensor) -> torch.Tensor:
        """Evaluate squared data residual plus regularizer quadratics at ``theta``."""
        theta = theta.reshape(self.layout.size)
        value = self._squared_data_residual(theta)
        for block in self._regularized_blocks():
            span = self.layout.theta_slice(block.key)
            value = value + block.regularization.quadratic(theta[span])
        return value

    def residual_norm(self, theta: torch.Tensor) -> torch.Tensor:
        """Return ``||A theta - b||`` accumulated batch by batch."""
        theta = theta.reshape(self.layout.size)
        squared = self._squared_data_residual(theta)
        return torch.sqrt(torch.clamp(squared, min=0.0))

    def _dense_system_with_regularization_rows(
        self,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the dense design matrix/targets with regularizer rows appended."""
        matrix, target = self.materialize_design_matrix()
        extra_rows = []
        extra_targets = []
        for block in self._regularized_blocks():
            block_rows, block_target = block.regularization.least_squares_rows(
                dtype=self.dtype,
                device=self.device,
            )
            n_extra = block_rows.shape[0]
            if n_extra == 0:
                continue
            # Embed the block's rows into full-width rows of the flat layout.
            padded = torch.zeros(
                (n_extra, self.layout.size),
                dtype=self.dtype,
                device=self.device,
            )
            padded[:, self.layout.theta_slice(block.key)] = block_rows
            extra_rows.append(padded)
            extra_targets.append(block_target)
        if extra_rows:
            matrix = torch.cat([matrix, *extra_rows], dim=0)
            target = torch.cat([target, *extra_targets], dim=0)
        return matrix, target

    def solve(
        self,
        *,
        solver: str,
        cg_tolerance: float,
        cg_max_iter: int | None,
        progress: bool = False,
        progress_frequency: int = 10,
        initial_theta: torch.Tensor | None = None,
        cg_checkpoint_path: Path | str | None = None,
        cg_checkpoint_frequency: int = 1,
        cg_resume: bool = False,
        cg_checkpoint_metadata: dict[str, object] | None = None,
        fallback_theta: torch.Tensor | None = None,
        return_info: bool = False,
    ) -> torch.Tensor | LinearSolveResult:
        """Solve the streamed problem with ``solver``.

        Supported solvers are ``"dense_lstsq"``, ``"normal_equation_direct"``
        and ``"cg"``. With ``return_info`` set, a ``LinearSolveResult`` is
        returned and a ``KeyboardInterrupt`` yields fallback coefficients
        (``fallback_theta``, else ``initial_theta``, else zeros) flagged as
        interrupted; otherwise only the coefficient tensor is returned and
        interrupts propagate.

        Raises:
            ValueError: If ``solver`` is not one of the supported names.
        """

        def unwrap(result: LinearSolveResult) -> torch.Tensor | LinearSolveResult:
            if return_info:
                return result
            return result.theta

        def interrupted(message: str) -> LinearSolveResult:
            if progress:
                print(message)
            seed = fallback_theta if fallback_theta is not None else initial_theta
            if seed is None:
                theta = torch.zeros(
                    (self.layout.size,),
                    dtype=self.dtype,
                    device=self.device,
                )
            else:
                theta = seed.to(dtype=self.dtype, device=self.device)
            return LinearSolveResult(theta=theta, interrupted=True)

        if solver == "dense_lstsq":
            try:
                matrix, target = self._dense_system_with_regularization_rows()
                solution = torch.linalg.lstsq(matrix, target).solution
                return unwrap(LinearSolveResult(solution))
            except KeyboardInterrupt:
                if not return_info:
                    raise
                return interrupted(
                    "Interrupted dense least-squares solve; "
                    "using fallback coefficients."
                )

        if solver == "normal_equation_direct":
            try:
                gram, rhs = self.accumulate_normal_equations()
                if progress:
                    print("Solving normal equations directly...")
                try:
                    theta = torch.linalg.solve(gram, rhs)
                except RuntimeError:
                    if progress:
                        print(
                            "Direct solve failed; falling back to torch.linalg.lstsq."
                        )
                    theta = torch.linalg.lstsq(gram, rhs).solution
                return unwrap(LinearSolveResult(theta))
            except KeyboardInterrupt:
                if not return_info:
                    raise
                return interrupted(
                    "Interrupted normal-equation solve; "
                    "using fallback coefficients."
                )

        if solver == "cg":
            self._stream_progress = progress
            self._stream_pass = 0
            try:
                rhs = self.rhs()

                if cg_checkpoint_metadata is not None:
                    checkpoint_metadata = dict(cg_checkpoint_metadata)
                else:
                    checkpoint_metadata = _cg_checkpoint_metadata(
                        n_parameters=self.layout.size,
                        dtype=self.dtype,
                    )

                checkpoint_state = None
                if cg_resume and cg_checkpoint_path is not None:
                    checkpoint_state = load_cg_checkpoint(
                        cg_checkpoint_path,
                        dtype=self.dtype,
                        device=self.device,
                        expected_metadata=checkpoint_metadata,
                    )

                preconditioner = (
                    self.regularization_diagonal() + self.normal_equation_diagonal()
                )
                result = _conjugate_gradient(
                    self.normal_matvec,
                    rhs,
                    diagonal_preconditioner=preconditioner,
                    tolerance=cg_tolerance,
                    max_iter=cg_max_iter,
                    progress=progress,
                    progress_frequency=progress_frequency,
                    initial_guess=initial_theta,
                    checkpoint_state=checkpoint_state,
                    checkpoint_path=cg_checkpoint_path,
                    checkpoint_frequency=cg_checkpoint_frequency,
                    checkpoint_metadata=checkpoint_metadata,
                    handle_interrupts=return_info,
                )
                return unwrap(result)
            except KeyboardInterrupt:
                if not return_info:
                    raise
                return interrupted("Interrupted CG setup; using fallback coefficients.")
            finally:
                # Streaming progress is only meaningful while CG is running.
                self._stream_progress = False

        choices = ", ".join(("dense_lstsq", "normal_equation_direct", "cg"))
        raise ValueError(f"Unsupported solver '{solver}'. Expected one of: {choices}.")
# Public names of this module: the set a star-import of the module exposes.
__all__ = [
    "AssembledBatchCacheMode",
    "CachedBlockLinearProblem",
    "NormalEquationComponents",
    "assembled_cache_dir",
    "load_assembled_batches_memmap",
    "normal_equation_cache_dir",
    "save_assembled_batches_memmap",
]