Source code for ufp.leastsquares.linear

"""
Direct linear fitting over assembled spline coefficient blocks.

Use this module to build least-squares problems, materialize diagnostic
matrices, and solve for direct coefficient vectors with dense or iterative
backends.
"""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Optional, Sequence

import torch

from ufp.core.output import UFPOutput
from ufp.leastsquares._assemble import (
    assemble_true_blocks,
    assemble_true_blocks_by_terms,
)
from ufp.leastsquares._block import (
    BlockMatrix,
    BlockProblemLayout,
    BlockSolveBatch,
    ColumnRowIndexedBlockMatrix,
    MatrixStorageMode,
    RowIndexedBlockMatrix,
    SolveBlock,
    _block_matrix_matvec,  # noqa: F401 - private compatibility import
    _compact_block_matrix_for_storage,
    _materialize_block_matrix,
)
from ufp.leastsquares._cache import (
    AssembledBatchCacheMode,
    CachedBlockLinearProblem,
    _assembled_cache_manifest_exists,
    _assembled_cache_metadata_for_fit,
    _cache_metadata_can_project,
    _cache_metadata_matches,
    _cache_metadata_mismatch_reasons,
    _load_assembled_batch_entry_memmap,
    _load_assembled_batch_manifest,
    _load_assembled_batches_manifest,
    _matching_assembled_batch_cache_size,
    _normal_equation_target_weights,
    _sample_signature,
    _samples_with_unit_target_weights,
    _write_assembled_batch_memmap,
    _write_assembled_batches_manifest,
    assembled_cache_dir,
    load_assembled_batches_memmap,
    save_assembled_batches_memmap,
)
from ufp.leastsquares._cache_layout import (
    CacheWritePlan,
    build_cache_projection_plan,
    build_cache_write_plan,
    cache_blocks_from_metadata,
    project_batch_to_cache,
    project_cache_batch_to_layout,
)
from ufp.leastsquares._layout import ParameterLayout, TermBlock
from ufp.leastsquares._normal_equations import DenseNormalEquationMixin
from ufp.leastsquares._problem import (
    BlockLinearProblem,
    CGCheckpointState,
    LinearFitResult,
    LinearSolveResult,
    _cg_checkpoint_metadata,
    load_cg_checkpoint,
    save_cg_checkpoint,
)
from ufp.leastsquares._selection import (
    CoefficientSelector,
    SelectedCoefficientBlock,
    block_matches_selector,
)
from ufp.leastsquares._setup import (
    build_linear_fitter_selection_plan,
    current_selected_vector,
    fixed_coefficients_signature,
    selected_coefficient_metadata,
    write_selected_vector,
)
from ufp.leastsquares._types import AssembledBatch
from ufp.leastsquares._utils import _infer_tensor_options, _iter_with_progress
from ufp.leastsquares.dataset import FitSample, PreparedBatch, prepare_batches
from ufp.leastsquares.regularization import (
    BlockRegularization,
    RegularizationStencil,
    _make_block_regularization,
)
from ufp.terms._twobody_shape import (
    TwoBodySplineShapePenalty,
    normalize_twobody_shape_penalty,
)
from ufp.terms.model import UFPModel
from ufp.terms.twobody import SplineTwoBodyTerm


def _twobody_shape_regularization_rows(block) -> tuple[int, ...] | None:
    """Return active coefficient rows for two-body regularization."""
    if isinstance(block.term, SplineTwoBodyTerm):
        return tuple(int(index) for index in block.term._active_pair_indices)
    return None


_TWOBODY_SHAPE_REGULARIZATION_SEMANTICS = "boundary_aware_partial_v1"
_THIRD_DIFFERENCE_STENCIL = (-1.0, 3.0, -3.0, 1.0)


[docs] class LinearFitter(DenseNormalEquationMixin): """High-level entry point for direct least-squares fitting of spline models.""" def __init__( self, model: UFPModel, *, fit_energy: bool = True, fit_forces: bool = True, fit_per_atom_energy: bool = False, solver: str = "cg", ridge: float = 0.0, onebody_ridge: float | None = None, pair_ridge: float | None = None, twobody_ridge: float | None = None, threebody_ridge: float | None = None, twobody_shape_penalty: TwoBodySplineShapePenalty | None = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, cg_tolerance: float = 1.0e-10, cg_max_iter: int | None = None, threebody_lstsq_backend: str | None = None, threebody_bucket_backend: str | None = None, fit_blocks: Sequence[int | str] | None = None, freeze_blocks: Sequence[int | str] = (), assembly_contract: str = "term", matrix_storage: MatrixStorageMode = "auto", ) -> None: """Store solve options and freeze the model parameter layout.""" if not (fit_energy or fit_forces or fit_per_atom_energy): raise ValueError("at least one target type must be enabled") inferred_dtype, inferred_device = _infer_tensor_options(model) self.model = model self.fit_energy = bool(fit_energy) self.fit_forces = bool(fit_forces) self.fit_per_atom_energy = bool(fit_per_atom_energy) self.solver = str(solver) self.ridge = float(ridge) self.onebody_ridge = ( self.ridge if onebody_ridge is None else float(onebody_ridge) ) if pair_ridge is not None and twobody_ridge is not None: raise ValueError("use only one of `pair_ridge` and `twobody_ridge`") resolved_pair_ridge = pair_ridge if pair_ridge is not None else twobody_ridge self.pair_ridge = ( self.ridge if resolved_pair_ridge is None else float(resolved_pair_ridge) ) self.threebody_ridge = ( self.ridge if threebody_ridge is None else float(threebody_ridge) ) if self.ridge < 0.0: raise ValueError("`ridge` must be non-negative") if self.onebody_ridge < 0.0: raise ValueError("`onebody_ridge` must be non-negative") if self.pair_ridge < 0.0: raise ValueError("`pair_ridge` must be non-negative") if self.threebody_ridge < 0.0: raise ValueError("`threebody_ridge` must be non-negative") self.twobody_shape_penalty = normalize_twobody_shape_penalty( twobody_shape_penalty ) self.dtype = inferred_dtype if dtype is None else dtype self.device = inferred_device if device is None else device self.cg_tolerance = float(cg_tolerance) self.cg_max_iter = cg_max_iter self.threebody_lstsq_backend = threebody_lstsq_backend self.threebody_bucket_backend = threebody_bucket_backend self.layout = ParameterLayout.from_model(model, include_frozen=True) self.fit_blocks = None if fit_blocks is None else tuple(fit_blocks) self.freeze_blocks = tuple(freeze_blocks) self.selection_plan = build_linear_fitter_selection_plan( self.layout, fit_blocks=self.fit_blocks, freeze_blocks=self.freeze_blocks, ) self.selected_coefficients = self.selection_plan.selected_coefficients self.assembly_contract = str(assembly_contract) if self.assembly_contract not in {"block", "term"}: raise ValueError("`assembly_contract` must be 'block' or 'term'") self.matrix_storage = str(matrix_storage) if self.matrix_storage not in { "dense", "row_indexed", "column_chunked", "auto", }: raise ValueError( "`matrix_storage` must be one of: dense, row_indexed, " "column_chunked, auto" ) def _block_matches_selector(self, block, selector: int | str) -> bool: """Return whether a block matches one include/exclude selector.""" return block_matches_selector(block, selector) def _selected_block_indices(self) -> tuple[int, ...]: """Return layout block indices included in direct assembly and solving.""" return self.selection_plan.selected_block_indices def _selected_by_block(self) -> dict[int, SelectedCoefficientBlock]: """Return selected coefficient metadata keyed by original block index.""" return dict(self.selection_plan.selected_by_block) def _current_selected_vector( self, *, dtype: torch.dtype | None = None, device: torch.device | None = None, ) -> torch.Tensor: """Read the current selected coefficients into the compact solve layout.""" return current_selected_vector( self.selected_coefficients, dtype=dtype, device=device, ) def _write_selected_vector(self, theta: torch.Tensor) -> None: """Write compact selected coefficients back while preserving fixed entries.""" write_selected_vector(self.selected_coefficients, theta) def _selection_metadata(self) -> list[dict[str, object]]: """Return compact metadata describing the selected coefficient columns.""" return selected_coefficient_metadata(self.selected_coefficients) def _fixed_coefficients_signature(self) -> str: """Return a digest of current coefficient values outside the solve layout.""" return fixed_coefficients_signature(self.layout, self.selected_coefficients) def _selected_size(self) -> int: """Return the compact solve vector length for selected coefficients.""" return self.selection_plan.selected_size def _uses_full_direct_layout(self) -> bool: """Return whether the selected vector matches the full layout exactly.""" return self.selection_plan.uses_full_direct_layout def _regularization_metadata(self) -> dict[str, object]: """Return serializable solve settings that affect CG checkpoint state.""" penalty = self.twobody_shape_penalty return { "semantics": self._regularization_semantics_metadata(), "ridge": float(self.ridge), "onebody_ridge": float(self.onebody_ridge), "pair_ridge": float(self.pair_ridge), "threebody_ridge": float(self.threebody_ridge), "twobody_shape_penalty": { "third_difference_weight": float(penalty.third_difference_weight), "first_coefficient_min": float(penalty.first_coefficient_min), "first_coefficient_weight": float(penalty.first_coefficient_weight), "initial_curvature_weight": float(penalty.initial_curvature_weight), "decreasing_weight": float(penalty.decreasing_weight), "decreasing_count": int(penalty.decreasing_count), }, } def _regularization_semantics_metadata(self) -> dict[str, object]: """Return metadata for cache invalidation across regularization semantics.""" return {"twobody_shape_regularization": _TWOBODY_SHAPE_REGULARIZATION_SEMANTICS} def _coefficient_checkpoint_metadata( self, *, n_parameters: int, dtype: torch.dtype, ) -> dict[str, object]: """Return CG metadata for the selected coefficient layout.""" metadata = _cg_checkpoint_metadata( n_parameters=n_parameters, dtype=dtype, ) metadata.update( { "selected_block_indices": [ int(index) for index in self._selected_block_indices() ], "coefficient_selection": self._selection_metadata(), "fixed_coefficients_signature": (self._fixed_coefficients_signature()), } ) return metadata def _fit_checkpoint_metadata( self, samples: Sequence[FitSample], *, n_parameters: int, dtype: torch.dtype, ) -> dict[str, object]: """Return CG metadata for one concrete selected least-squares problem.""" metadata = self._coefficient_checkpoint_metadata( n_parameters=n_parameters, dtype=dtype, ) metadata.update(self._threebody_backend_metadata()) metadata.update( { "sample_signature": _sample_signature( samples, include_weights=True, ), "target_layout": { "fit_energy": bool(self.fit_energy), "fit_forces": bool(self.fit_forces), "fit_per_atom_energy": bool(self.fit_per_atom_energy), }, "regularization": self._regularization_metadata(), } ) return metadata @staticmethod def _checkpoint_metadata_contains( actual: dict[str, object], expected: dict[str, object], ) -> bool: """Return whether one checkpoint metadata payload contains all expected keys.""" return all(actual.get(key) == value for key, value in expected.items())
[docs] @classmethod def from_model(cls, model: UFPModel, **kwargs) -> "LinearFitter": """Convenience constructor mirroring ``__init__``.""" return cls(model, **kwargs)
def _prepare_batches( self, samples: Sequence[FitSample], *, batch_size: int, ) -> tuple[PreparedBatch, ...]: """Normalize fit samples into prepared geometry and target batches.""" return prepare_batches( self.model, samples, batch_size=batch_size, fit_energy=self.fit_energy, fit_forces=self.fit_forces, fit_per_atom_energy=self.fit_per_atom_energy, dtype=self.dtype, device=self.device, ) def _iter_prepared_batches( self, samples: Sequence[FitSample], *, batch_size: int, device: torch.device | None = None, ): """Yield prepared geometry batches without retaining all chunks.""" items = tuple(samples) if not items: raise ValueError("`samples` must contain at least one FitSample") if batch_size <= 0: raise ValueError("`batch_size` must be positive") target_device = self.device if device is None else device for start in range(0, len(items), batch_size): chunk = items[start : start + batch_size] prepared = prepare_batches( self.model, chunk, batch_size=len(chunk), fit_energy=self.fit_energy, fit_forces=self.fit_forces, fit_per_atom_energy=self.fit_per_atom_energy, dtype=self.dtype, device=target_device, ) yield from prepared def _assemble_true_blocks(self, batch: PreparedBatch) -> AssembledBatch: """Assemble one prepared batch and residualize fixed coefficients.""" selected = self._assemble_selected_true_blocks(batch) if selected is not None: return selected assembler = ( assemble_true_blocks_by_terms if self.assembly_contract == "term" else assemble_true_blocks ) assembled = assembler( batch, self.layout, threebody_lstsq_backend=self.threebody_lstsq_backend, threebody_bucket_backend=self.threebody_bucket_backend, ) return self._compact_selected_assembled_batch(assembled) def _assemble_selected_true_blocks( self, batch: PreparedBatch, ) -> AssembledBatch | None: """Assemble compact selected columns through optional term-level hooks.""" if self.assembly_contract != "term" or self._uses_full_direct_layout(): return None if not self._can_residualize_fixed_coefficients_from_outputs(): return None row_weights = (batch.targets.sqrt_weights * batch.targets.row_scales)[:, None] matrices: dict[int, torch.Tensor] = {} for selection in self.selected_coefficients: matrix = self._assemble_selected_matrix(batch, selection) if matrix is not None: matrices[int(selection.block.index)] = matrix * row_weights return AssembledBatch( target=( batch.targets.weighted_values - self._fixed_coefficient_prediction(batch) ), block_matrices=matrices, ) def _assemble_selected_matrix( self, batch: PreparedBatch, selection: SelectedCoefficientBlock, ) -> torch.Tensor | None: """Assemble one compact selected-column block, falling back when needed.""" selected_method = getattr( selection.block.term, "assemble_selected_linear_block", None, ) used_selected_method = False matrix = None if not selection.is_full_block and callable(selected_method): used_selected_method = True matrix = selected_method( selection.block, batch.inputs, batch.targets, selection.indices, ) if matrix is not None and tuple(matrix.shape) != ( batch.targets.n_rows, selection.size, ): raise ValueError( "selected linear assembler for block " f"{selection.block.label!r} returned shape " f"{tuple(matrix.shape)}, expected " f"({batch.targets.n_rows}, {selection.size})" ) if matrix is not None or used_selected_method: return matrix fallback = assemble_true_blocks_by_terms( batch, self.layout, selected_block_indices=(int(selection.block.index),), threebody_lstsq_backend=self.threebody_lstsq_backend, threebody_bucket_backend=self.threebody_bucket_backend, ) full_matrix = fallback.block_matrices.get(int(selection.block.index)) if full_matrix is None: return None if selection.is_full_block: return full_matrix selected_index = torch.tensor( selection.indices, dtype=torch.int64, device=full_matrix.device, ) return full_matrix.index_select(1, selected_index) def _can_residualize_fixed_coefficients_from_outputs(self) -> bool: """Return whether blocks can be temporarily zeroed for residualization.""" for block in self.layout.blocks: provider = block.coefficient_provider if provider is not None and not provider.uses_identity_weights: return False return True def _fixed_coefficient_prediction(self, batch: PreparedBatch) -> torch.Tensor: """Evaluate fixed coefficient contributions as weighted target rows.""" originals: list[tuple[TermBlock, torch.Tensor]] = [] rows = torch.zeros_like(batch.targets.values) try: for selection in self.selected_coefficients: block = selection.block current = block.read().detach().clone() originals.append((block, current)) flat = current.reshape(-1).clone() index = torch.tensor( selection.indices, dtype=torch.int64, device=flat.device, ) flat.index_fill_(0, index, 0.0) block.write(flat.reshape(block.shape)) with torch.no_grad(): for term in self._layout_terms(): self._accumulate_output_rows(rows, batch, term(batch.inputs)) finally: for block, values in reversed(originals): block.write(values) return rows * batch.targets.sqrt_weights * batch.targets.row_scales def _layout_terms(self) -> tuple[torch.nn.Module, ...]: """Return unique terms that own coefficient blocks in layout order.""" seen: set[int] = set() terms: list[torch.nn.Module] = [] for block in self.layout.blocks: term_id = id(block.term) if term_id in seen: continue seen.add(term_id) terms.append(block.term) return tuple(terms) def _accumulate_output_rows( self, rows: torch.Tensor, batch: PreparedBatch, output: UFPOutput, ) -> None: """Add one term output to unweighted least-squares target rows.""" inputs = batch.inputs targets = batch.targets if output.energy is not None: energy = torch.as_tensor( output.energy, dtype=inputs.dtype, device=inputs.device, ) energy = energy.reshape(inputs.n_systems, -1) if energy.shape[1] != 1: raise ValueError("term energy must provide one value per system") valid = targets.energy_rows >= 0 if torch.any(valid): rows.index_add_( 0, targets.energy_rows[valid], energy[:, 0][valid], ) if output.forces is not None: forces = torch.as_tensor( output.forces, dtype=inputs.dtype, device=inputs.device, ) if tuple(forces.shape) != (inputs.n_atoms, 3): raise ValueError( "term forces must have shape " f"({inputs.n_atoms}, 3), got {tuple(forces.shape)}" ) valid = targets.force_rows >= 0 if torch.any(valid): rows.index_add_(0, targets.force_rows[valid], forces[valid]) if output.per_atom_energy is not None: per_atom_energy = torch.as_tensor( output.per_atom_energy, dtype=inputs.dtype, device=inputs.device, ) if per_atom_energy.ndim == 2 and tuple(per_atom_energy.shape) == ( inputs.n_atoms, 1, ): per_atom_energy = per_atom_energy[:, 0] else: per_atom_energy = per_atom_energy.reshape(-1) if tuple(per_atom_energy.shape) != (inputs.n_atoms,): raise ValueError( "term per-atom energy must have shape " f"({inputs.n_atoms},), got {tuple(per_atom_energy.shape)}" ) valid = targets.per_atom_rows >= 0 if torch.any(valid): rows.index_add_( 0, targets.per_atom_rows[valid], per_atom_energy[valid], ) def _compact_selected_assembled_batch( self, batch: AssembledBatch, ) -> AssembledBatch: """Split full-block matrices into selected columns and fixed target.""" selected_by_block = self._selected_by_block() target = batch.target.clone() matrices: dict[int, torch.Tensor] = {} for block in self.layout.blocks: matrix = batch.block_matrices.get(int(block.index)) if matrix is None: continue dense = ( matrix if isinstance(matrix, torch.Tensor) else _materialize_block_matrix(matrix) ) selection = selected_by_block.get(int(block.index)) selected_indices = set(()) if selection is None else set(selection.indices) fixed_indices = tuple( index for index in range(block.size) if index not in selected_indices ) current = ( block.read() .reshape(-1) .to( dtype=dense.dtype, device=dense.device, ) ) if fixed_indices: fixed_index = torch.tensor( fixed_indices, dtype=torch.int64, device=dense.device, ) fixed_matrix = dense.index_select(1, fixed_index) fixed_theta = current.index_select(0, fixed_index) target = target - fixed_matrix @ fixed_theta if selection is None: continue selected_index = torch.tensor( selection.indices, dtype=torch.int64, device=dense.device, ) matrices[int(block.index)] = dense.index_select(1, selected_index) return AssembledBatch( target=target, block_matrices=matrices, ) def _threebody_backend_metadata(self) -> dict[str, object]: """Return backend settings that affect assembly cache construction.""" return { "threebody_lstsq_backend": ( os.environ.get("UFP_THREEBODY_LSTSQ_BACKEND", "auto") if self.threebody_lstsq_backend is None else str(self.threebody_lstsq_backend) ), "threebody_bucket_backend": ( os.environ.get("UFP_THREEBODY_BUCKET_BACKEND", "auto") if self.threebody_bucket_backend is None else str(self.threebody_bucket_backend) ), "assembly_contract": self.assembly_contract, } def _partial_twobody_third_difference_stencils( self, selection: SelectedCoefficientBlock, ) -> tuple[RegularizationStencil, ...]: """Return boundary-aware third-difference rows for partial selections.""" block = selection.block if block.kind not in {"pair", "twobody"}: return () if selection.is_full_block: return () if len(block.shape) not in {1, 2} or block.shape[-1] < 4: return () selected_columns = { int(original): compact for compact, original in enumerate(selection.indices) } values = block.read().detach().reshape(-1).cpu() rows: list[RegularizationStencil] = [] if len(block.shape) == 1: row_offsets = (0,) n_coeffs = int(block.shape[0]) else: n_coeffs = int(block.shape[1]) active_rows = _twobody_shape_regularization_rows(block) if active_rows is None: active_rows = tuple(range(int(block.shape[0]))) row_offsets = tuple(int(row) * n_coeffs for row in active_rows) for offset in row_offsets: for coeff in range(n_coeffs - 3): compact_columns: list[int] = [] compact_weights: list[float] = [] fixed_term = 0.0 for local, weight in enumerate(_THIRD_DIFFERENCE_STENCIL): original = int(offset + coeff + local) compact = selected_columns.get(original) if compact is None: fixed_term += float(weight) * float(values[original].item()) else: compact_columns.append(int(compact)) compact_weights.append(float(weight)) if not compact_columns: continue rows.append( RegularizationStencil( columns=tuple(compact_columns), weights=tuple(compact_weights), target=-fixed_term, ) ) return tuple(rows) def _direct_blocks(self) -> tuple[SolveBlock, ...]: """Build solve metadata for every block handled by the direct linear system.""" blocks = [] for selection in self.selected_coefficients: block = selection.block group = block.regularization_group ridge = self.ridge if group == "onebody": ridge = self.onebody_ridge elif group in {"pair", "twobody"}: ridge = self.pair_ridge elif group == "threebody": ridge = self.threebody_ridge third_difference_penalty = ( self.twobody_shape_penalty.third_difference_weight if block.kind in {"pair", "twobody"} else 0.0 ) active_rows = None third_difference_stencils = None if third_difference_penalty > 0.0: if selection.is_full_block: active_rows = _twobody_shape_regularization_rows(block) else: third_difference_stencils = ( self._partial_twobody_third_difference_stencils(selection) ) blocks.append( SolveBlock( key=block.index, size=selection.size, label=block.label, regularization=_make_block_regularization( selection.shape, ridge=ridge, third_difference_penalty=third_difference_penalty, active_rows=active_rows, third_difference_stencils=third_difference_stencils, ), ) ) return tuple(blocks) def _column_chunk_sizes(self) -> dict[int, int]: """Return per-block coefficient category widths for sparse cache storage.""" chunk_sizes: dict[int, int] = {} for selection in self.selected_coefficients: if len(selection.shape) < 2 or int(selection.shape[0]) <= 1: continue chunk_size = 1 for dim in selection.shape[1:]: chunk_size *= int(dim) chunk_sizes[int(selection.block.index)] = int(chunk_size) return chunk_sizes def _apply_matrix_storage( self, batch: AssembledBatch, *, for_cache: bool = False, ) -> AssembledBatch: """Return an assembled batch using this fitter's block storage option.""" dense_auto = ( self.matrix_storage == "auto" and self.solver == "cg" and not for_cache ) if self.matrix_storage == "dense" or dense_auto: return AssembledBatch( target=batch.target, block_matrices={ key: _materialize_block_matrix(matrix) for key, matrix in batch.block_matrices.items() }, ) column_chunk_sizes = self._column_chunk_sizes() return AssembledBatch( target=batch.target, block_matrices={ key: _compact_block_matrix_for_storage( matrix, mode=self.matrix_storage, # type: ignore[arg-type] column_chunk_size=column_chunk_sizes.get(key), ) for key, matrix in batch.block_matrices.items() }, ) def _preserve_cached_compact_blocks(self) -> bool: """Return whether streamed cache loads should preserve compact storage.""" return self.matrix_storage != "dense" def _cache_write_plan(self) -> CacheWritePlan: """Return the semantic cache layout for this fitter's solve layout.""" return build_cache_write_plan( self.selected_coefficients, reusable=self._uses_full_direct_layout(), ) def _projected_cache_loader( self, metadata: object, cache_write_plan: CacheWritePlan, ): """Return a batch loader projecting semantic cache blocks to solve blocks.""" source_blocks = cache_blocks_from_metadata(metadata) projection_plan = build_cache_projection_plan(source_blocks, cache_write_plan) def load_projected_batch(cache_dir, entry, **kwargs): batch = _load_assembled_batch_entry_memmap(cache_dir, entry, **kwargs) projected = project_cache_batch_to_layout(batch, projection_plan) if self.matrix_storage == "dense": return self._apply_matrix_storage(projected, for_cache=True) return projected return load_projected_batch def _cached_problem_from_manifest( self, *, manifest: dict[str, object], cache_dir: Path, cache_write_plan: CacheWritePlan, items: Sequence[FitSample], default_batch_size: int, ) -> CachedBlockLinearProblem: """Create a streamed cached problem from a compatible manifest.""" metadata = manifest.get("metadata") if not isinstance(metadata, dict): raise ValueError("least-squares cache metadata is missing") cache_batch_size = int(metadata.get("batch_size", default_batch_size)) row_weights = self._prepared_batch_sqrt_weights( items, batch_size=cache_batch_size, device=torch.device("cpu"), ) resolved_cache_dir = Path(str(manifest.get("_cache_dir", cache_dir))) return CachedBlockLinearProblem( layout=BlockProblemLayout.from_blocks(self._direct_blocks()), cache_dir=resolved_cache_dir, entries=tuple(manifest.get("batches", ())), dtype=self.dtype, device=self.device, row_weights=row_weights, load_batch_entry=self._projected_cache_loader( metadata, cache_write_plan, ), row_indexed_blocks=self._preserve_cached_compact_blocks(), ) def _find_projectable_cache( self, *, cache_parent: Path, expected_metadata: dict[str, object], cache_write_plan: CacheWritePlan, items: Sequence[FitSample], default_batch_size: int, skip_dirs: set[Path] | None = None, ) -> CachedBlockLinearProblem | None: """Find a sibling assembled cache that can be projected to this layout.""" skipped = {path.resolve() for path in (() if skip_dirs is None else skip_dirs)} for manifest_path in sorted(cache_parent.glob("*/assembled_batches.json")): candidate_dir = manifest_path.parent if candidate_dir.resolve() in skipped: continue try: manifest = _load_assembled_batches_manifest(candidate_dir) except (OSError, ValueError, json.JSONDecodeError): continue metadata = manifest.get("metadata") if not _cache_metadata_can_project(metadata, expected_metadata): continue try: return self._cached_problem_from_manifest( manifest=manifest, cache_dir=candidate_dir, cache_write_plan=cache_write_plan, items=items, default_batch_size=default_batch_size, ) except ValueError: continue return None def _prepared_batch_sqrt_weights( self, samples: Sequence[FitSample], *, batch_size: int, device: torch.device | None = None, ) -> tuple[torch.Tensor, ...]: """Return current target row weights in prepared-batch order.""" resolved_device = self.device if device is None else device items = tuple(samples) if not items: raise ValueError("`samples` must contain at least one FitSample") if batch_size <= 0: raise ValueError("`batch_size` must be positive") weights: list[torch.Tensor] = [] for start in range(0, len(items), batch_size): chunk = items[start : start + batch_size] prepared_batches = prepare_batches( self.model, chunk, batch_size=len(chunk), fit_energy=self.fit_energy, fit_forces=self.fit_forces, fit_per_atom_energy=self.fit_per_atom_energy, dtype=self.dtype, device=resolved_device, ) weights.extend( prepared.targets.sqrt_weights for prepared in prepared_batches ) return tuple(weights)
[docs] def build_problem( self, samples: Sequence[FitSample], *, batch_size: int = 32, progress: bool = False, cache_directory: Path | str | None = None, cache_mode: AssembledBatchCacheMode = "auto", ) -> BlockLinearProblem: """Prepare batches, assemble block matrices, and return the linear problem.""" items = tuple(samples) if not items: raise ValueError("`samples` must contain at least one FitSample") if batch_size <= 0: raise ValueError("`batch_size` must be positive") expected_metadata = _assembled_cache_metadata_for_fit( layout=self.layout, samples=items, fit_energy=self.fit_energy, fit_forces=self.fit_forces, fit_per_atom_energy=self.fit_per_atom_energy, dtype=self.dtype, batch_size=batch_size, ) expected_metadata.update(self._threebody_backend_metadata()) expected_metadata["regularization_semantics"] = ( self._regularization_semantics_metadata() ) expected_metadata["selected_block_indices"] = [ int(index) for index in self._selected_block_indices() ] expected_metadata["coefficient_selection"] = self._selection_metadata() expected_metadata["fixed_coefficients_signature"] = ( self._fixed_coefficients_signature() ) cache_write_plan = self._cache_write_plan() expected_metadata["cache_blocks"] = cache_write_plan.metadata() expected_metadata["cache_reusable"] = bool(cache_write_plan.reusable) total_batches = (len(items) + batch_size - 1) // batch_size cache_parent = None if cache_directory is None else Path(cache_directory) cache_dir = ( None if cache_parent is None else assembled_cache_dir(cache_parent, expected_metadata) ) if cache_directory is not None and cache_mode in {"auto", "read"}: assert cache_parent is not None assert cache_dir is not None read_cache_dir = cache_dir if not _assembled_cache_manifest_exists(read_cache_dir): legacy_cache_dir = cache_parent if _assembled_cache_manifest_exists(legacy_cache_dir): read_cache_dir = legacy_cache_dir if _assembled_cache_manifest_exists(read_cache_dir): try: manifest = _load_assembled_batches_manifest(read_cache_dir) except ValueError: if cache_mode == "read": raise manifest = None if progress: print( "Ignoring incompatible least-squares cache in " f"{read_cache_dir}; rebuilding..." ) if manifest is None: metadata_matches = False metadata_projectable = False else: metadata_matches = _cache_metadata_matches( manifest.get("metadata"), expected_metadata, ) metadata_projectable = _cache_metadata_can_project( manifest.get("metadata"), expected_metadata, ) if metadata_matches: if progress: print( "Using cached least-squares batches from " f"{read_cache_dir}..." ) return self._cached_problem_from_manifest( manifest=manifest, cache_dir=read_cache_dir, cache_write_plan=cache_write_plan, items=items, default_batch_size=batch_size, ) if metadata_projectable: try: if progress: print( "Using projectable least-squares cache from " f"{read_cache_dir}..." ) return self._cached_problem_from_manifest( manifest=manifest, cache_dir=read_cache_dir, cache_write_plan=cache_write_plan, items=items, default_batch_size=batch_size, ) except ValueError: if cache_mode == "read": raise compatible_problem = self._find_projectable_cache( cache_parent=cache_parent, expected_metadata=expected_metadata, cache_write_plan=cache_write_plan, items=items, default_batch_size=batch_size, skip_dirs=( set() if manifest is None else {Path(str(manifest.get("_cache_dir", read_cache_dir)))} ), ) if compatible_problem is not None: if progress: print("Using projectable sibling least-squares cache...") return compatible_problem if cache_mode == "read": raise ValueError( "least-squares cache metadata does not match the requested " "samples, targets, dtype, or model layout" ) if progress: mismatch_reasons = ( () if manifest is None else _cache_metadata_mismatch_reasons( manifest.get("metadata"), expected_metadata, ) ) reason_suffix = ( "" if not mismatch_reasons else f" ({', '.join(mismatch_reasons)})" ) print( "Ignoring stale least-squares cache in " f"{read_cache_dir}{reason_suffix}; rebuilding..." ) elif cache_mode == "read": cached_batch_size = _matching_assembled_batch_cache_size( cache_dir, expected_metadata=expected_metadata, ) restore_batch_size = ( batch_size if cached_batch_size is None else cached_batch_size ) restore_total_batches = ( len(items) + restore_batch_size - 1 ) // restore_batch_size try: manifest_batches = [ _load_assembled_batch_manifest( cache_dir, batch_index, expected_metadata=expected_metadata, ) for batch_index in range(restore_total_batches) ] except ValueError: raise if all(entry is not None for entry in manifest_batches): if progress: print( "Restoring least-squares manifest from completed " f"batch caches in {cache_dir}..." ) restored_entries = [ entry for entry in manifest_batches if entry is not None ] restored_metadata = dict(expected_metadata) restored_metadata["batch_size"] = int(restore_batch_size) _write_assembled_batches_manifest( cache_dir, restored_entries, metadata=restored_metadata, ) return self._cached_problem_from_manifest( manifest={ "_cache_dir": str(cache_dir), "batches": restored_entries, "metadata": restored_metadata, }, cache_dir=cache_dir, cache_write_plan=cache_write_plan, items=items, default_batch_size=restore_batch_size, ) compatible_problem = self._find_projectable_cache( cache_parent=cache_parent, expected_metadata=expected_metadata, cache_write_plan=cache_write_plan, items=items, default_batch_size=batch_size, skip_dirs={read_cache_dir}, ) if compatible_problem is not None: if progress: print("Using projectable sibling least-squares cache...") return compatible_problem if cache_mode == "read": if any(cache_parent.glob("*/assembled_batches.json")): raise ValueError( "least-squares cache metadata does not match the requested " "samples, targets, dtype, or model layout" ) raise FileNotFoundError( "least-squares cache requested in read mode, but no complete " f"cache was found in {cache_dir}" ) if progress: print(f"Preparing least-squares batches with batch_size={batch_size}...") if cache_directory is not None and cache_mode in {"auto", "write", "refresh"}: assert cache_dir is not None cache_dir.mkdir(parents=True, exist_ok=True) assembly_batch_size = batch_size if cache_mode != "refresh": cached_batch_size = _matching_assembled_batch_cache_size( cache_dir, expected_metadata=expected_metadata, ) if cached_batch_size is not None: assembly_batch_size = cached_batch_size assembly_metadata = dict(expected_metadata) assembly_metadata["batch_size"] = int(assembly_batch_size) total_batches = ( len(items) + assembly_batch_size - 1 ) // assembly_batch_size manifest_batches: list[dict[str, object]] = [] unit_weight_items = _samples_with_unit_target_weights(items) column_chunk_sizes = cache_write_plan.column_chunk_sizes prepared_batches = _iter_with_progress( self._iter_prepared_batches( unit_weight_items, batch_size=assembly_batch_size, ), enabled=progress, description="Assembling least-squares cache", total=total_batches, ) for batch_index, batch in enumerate( prepared_batches, start=0, ): entry = None if cache_mode != "refresh": try: entry = _load_assembled_batch_manifest( cache_dir, batch_index, expected_metadata=assembly_metadata, ) except (json.JSONDecodeError, ValueError): entry = None if entry is None: assembled = self._apply_matrix_storage( self._assemble_true_blocks(batch), for_cache=True, ) assembled = project_batch_to_cache(assembled, cache_write_plan) if self.matrix_storage == "dense": assembled = self._apply_matrix_storage( assembled, for_cache=True, ) entry = _write_assembled_batch_memmap( cache_dir, batch_index, assembled, metadata=assembly_metadata, column_chunk_sizes=column_chunk_sizes, compact=False, ) manifest_batches.append(entry) if progress: print("Finished assembling all batches.") print(f"Writing cached least-squares manifest to {cache_dir}...") _write_assembled_batches_manifest( cache_dir, manifest_batches, metadata=assembly_metadata, ) if progress: print(f"Using memory-mapped least-squares batches from {cache_dir}...") row_weights = self._prepared_batch_sqrt_weights( items, batch_size=assembly_batch_size, device=torch.device("cpu"), ) return CachedBlockLinearProblem( layout=BlockProblemLayout.from_blocks(self._direct_blocks()), cache_dir=cache_dir, entries=tuple(manifest_batches), dtype=self.dtype, device=self.device, row_weights=row_weights, load_batch_entry=self._projected_cache_loader( assembly_metadata, cache_write_plan, ), row_indexed_blocks=self._preserve_cached_compact_blocks(), ) assembled_batches = [] prepared_batches = _iter_with_progress( self._iter_prepared_batches(items, batch_size=batch_size), enabled=progress, description="Assembling least-squares batches", total=total_batches, ) for batch in prepared_batches: assembled_batches.append( self._apply_matrix_storage(self._assemble_true_blocks(batch)) ) if progress: print("Finished assembling all batches.") return self.problem_from_assembled_batches(assembled_batches)
[docs] def problem_from_assembled_batches( self, assembled_batches: Sequence[AssembledBatch], ) -> BlockLinearProblem: """Wrap preassembled batches in the common block-problem interface.""" layout = BlockProblemLayout.from_blocks(self._direct_blocks()) batches = tuple( BlockSolveBatch( target=batch.target, matrices=self._apply_matrix_storage(batch).block_matrices, ) for batch in assembled_batches ) return BlockLinearProblem(layout=layout, batches=batches)
[docs] def make_linear_operator( self, samples: Sequence[FitSample], *, batch_size: int = 32, cache_directory: Path | str | None = None, cache_mode: AssembledBatchCacheMode = "auto", ) -> BlockLinearProblem: """Alias ``build_problem`` for callers that want a matrix-free operator view.""" return self.build_problem( samples, batch_size=batch_size, cache_directory=cache_directory, cache_mode=cache_mode, )
[docs] def materialize_design_matrix( self, samples: Sequence[FitSample], *, batch_size: int = 32, cache_directory: Path | str | None = None, cache_mode: AssembledBatchCacheMode = "auto", ) -> tuple[torch.Tensor, torch.Tensor]: """Build the explicit design matrix for the provided samples.""" return self.build_problem( samples, batch_size=batch_size, cache_directory=cache_directory, cache_mode=cache_mode, ).materialize_design_matrix()
[docs] def accumulate_normal_equations( self, samples: Sequence[FitSample], *, batch_size: int = 32, cache_directory: Path | str | None = None, cache_mode: AssembledBatchCacheMode = "auto", normal_equation_cache_directory: Path | str | None = None, normal_equation_cache_mode: AssembledBatchCacheMode = "auto", normal_equation_build_device: torch.device | str | None = None, progress: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """Build the normal equations directly from the provided samples.""" if normal_equation_cache_directory is not None: build_device = ( None if normal_equation_build_device is None else torch.device(normal_equation_build_device) ) energy_weight, force_weight = _normal_equation_target_weights( samples, fit_energy=self.fit_energy, fit_forces=self.fit_forces, fit_per_atom_energy=self.fit_per_atom_energy, ) components = self._load_or_build_normal_equation_components( tuple(samples), batch_size=batch_size, progress=progress, cache_directory=normal_equation_cache_directory, cache_mode=normal_equation_cache_mode, build_device=build_device, ) gram, rhs, _, _ = self._finalize_normal_equations( components, energy_weight=energy_weight, force_weight=force_weight, ) return gram, rhs return self.build_problem( samples, batch_size=batch_size, cache_directory=cache_directory, cache_mode=cache_mode, ).accumulate_normal_equations()
[docs] def write_back(self, theta: torch.Tensor) -> None: """Write one solved direct vector back into the model coefficients.""" if self._selected_size() == self.layout.size: self.layout.write_direct_vector(theta) return self._write_selected_vector(theta)
def _empty_problem(self) -> BlockLinearProblem: """Return an empty problem shell for interrupted pre-assembly fits.""" return BlockLinearProblem( layout=BlockProblemLayout.from_blocks(self._direct_blocks()), batches=(), ) def _interrupted_result( self, theta: torch.Tensor, *, problem: BlockLinearProblem | None = None, restored_checkpoint_path: str | None = None, ) -> LinearFitResult: """Build a fallback fit result after an interrupted fit stage.""" result_problem = self._empty_problem() if problem is None else problem return LinearFitResult( theta=theta, objective=float("nan"), residual_norm=float("nan"), solver=self.solver, n_rows=result_problem.n_rows, n_parameters=self._selected_size(), layout=self.layout, problem=result_problem, interrupted=True, restored_checkpoint_path=restored_checkpoint_path, )
[docs] def write_checkpoint_to_model(self, checkpoint_path: Path | str) -> None: """Write the coefficient vector from a CG checkpoint into the model.""" checkpoint = load_cg_checkpoint( checkpoint_path, dtype=self.dtype, device=self.device, ) expected = self._coefficient_checkpoint_metadata( n_parameters=self._selected_size(), dtype=self.dtype, ) legacy_metadata = _cg_checkpoint_metadata( n_parameters=self._selected_size(), dtype=self.dtype, ) if not self._checkpoint_metadata_contains(checkpoint.metadata, expected): if not ( self._uses_full_direct_layout() and checkpoint.metadata == legacy_metadata ): raise ValueError( "CG checkpoint metadata does not match this coefficient selection" ) self.write_back(checkpoint.x)
[docs] def fit( self, samples: Sequence[FitSample], *, batch_size: int = 32, write_back: bool = True, progress: bool = False, progress_frequency: int = 10, cache_directory: Path | str | None = None, cache_mode: AssembledBatchCacheMode = "auto", normal_equation_cache: bool = False, normal_equation_cache_directory: Path | str | None = None, normal_equation_cache_mode: AssembledBatchCacheMode = "auto", normal_equation_build_device: torch.device | str | None = None, warm_start: bool = False, cg_checkpoint_path: Path | str | None = None, cg_checkpoint_frequency: int = 1, cg_resume: bool = False, ) -> LinearFitResult: """Build, solve, optionally write back, and summarize one direct linear fit.""" items = tuple(samples) fallback_theta = self._current_selected_vector( dtype=self.dtype, device=self.device, ) if normal_equation_cache: if self.solver != "normal_equation_direct": raise ValueError( "normal-equation caching is only supported with " "solver='normal_equation_direct'" ) if normal_equation_cache_directory is None: if cache_directory is None: raise ValueError( "`cache_directory` or `normal_equation_cache_directory` is " "required when `normal_equation_cache=True`" ) normal_equation_cache_directory = ( Path(cache_directory) / "normal_equations" ) build_device = ( None if normal_equation_build_device is None else torch.device(normal_equation_build_device) ) try: theta, objective, residual_norm, n_rows = ( self._solve_cached_normal_equations( items, batch_size=batch_size, progress=progress, cache_directory=normal_equation_cache_directory, cache_mode=normal_equation_cache_mode, build_device=build_device, ) ) except KeyboardInterrupt: if progress: print( "Interrupted cached normal-equation fit; " "using current coefficients." ) if write_back: self.write_back(fallback_theta) return self._interrupted_result(fallback_theta) if write_back: self.write_back(theta) return LinearFitResult( theta=theta, objective=objective, residual_norm=residual_norm, solver=self.solver, n_rows=n_rows, n_parameters=self._selected_size(), layout=self.layout, problem=BlockLinearProblem( layout=BlockProblemLayout.from_blocks(self._direct_blocks()), batches=(), ), ) try: problem = self.build_problem( items, batch_size=batch_size, progress=progress, cache_directory=cache_directory, cache_mode=cache_mode, ) except KeyboardInterrupt: if progress: print("Interrupted least-squares assembly; using current coefficients.") if write_back: self.write_back(fallback_theta) return self._interrupted_result(fallback_theta) if progress: print( "Solving least-squares system: " f"solver={self.solver}, rows={problem.n_rows}, " f"parameters={problem.layout.size}" ) fallback_theta = fallback_theta.to(dtype=problem.dtype, device=problem.device) solve_result = problem.solve( solver=self.solver, cg_tolerance=self.cg_tolerance, cg_max_iter=self.cg_max_iter, progress=progress, progress_frequency=progress_frequency, initial_theta=( fallback_theta if warm_start else None ), fallback_theta=fallback_theta, return_info=True, cg_checkpoint_path=cg_checkpoint_path, cg_checkpoint_frequency=cg_checkpoint_frequency, cg_resume=cg_resume, cg_checkpoint_metadata=( self._fit_checkpoint_metadata( items, n_parameters=problem.layout.size, dtype=problem.dtype, ) if self.solver == "cg" else None ), ) assert isinstance(solve_result, LinearSolveResult) theta = solve_result.theta if write_back: self.write_back(theta) if solve_result.interrupted: return self._interrupted_result( theta, problem=problem, restored_checkpoint_path=solve_result.restored_checkpoint_path, ) residual_norm_tensor = ( problem.residual_norm(theta) if isinstance(problem, CachedBlockLinearProblem) else torch.linalg.norm(problem.matvec(theta) - problem.target_vector()) ) if progress: print( "Least-squares solve complete: " f"||residual||={residual_norm_tensor.item():.6e}" ) return LinearFitResult( theta=theta, objective=float(problem.objective(theta).item()), residual_norm=float(residual_norm_tensor.item()), solver=self.solver, n_rows=problem.n_rows, n_parameters=problem.layout.size, layout=self.layout, problem=problem, )
__all__ = [ "BlockLinearProblem", "BlockMatrix", "BlockProblemLayout", "BlockRegularization", "BlockSolveBatch", "CGCheckpointState", "CachedBlockLinearProblem", "ColumnRowIndexedBlockMatrix", "CoefficientSelector", "LinearFitResult", "LinearFitter", "LinearSolveResult", "MatrixStorageMode", "RowIndexedBlockMatrix", "SolveBlock", "load_assembled_batches_memmap", "load_cg_checkpoint", "save_assembled_batches_memmap", "save_cg_checkpoint", ]