"""
Direct linear fitting over assembled spline coefficient blocks.
Use this module to build least-squares problems, materialize diagnostic
matrices, and solve for direct coefficient vectors with dense or iterative
backends.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Optional, Sequence
import torch
from ufp.core.output import UFPOutput
from ufp.leastsquares._assemble import (
assemble_true_blocks,
assemble_true_blocks_by_terms,
)
from ufp.leastsquares._block import (
BlockMatrix,
BlockProblemLayout,
BlockSolveBatch,
ColumnRowIndexedBlockMatrix,
MatrixStorageMode,
RowIndexedBlockMatrix,
SolveBlock,
_block_matrix_matvec, # noqa: F401 - private compatibility import
_compact_block_matrix_for_storage,
_materialize_block_matrix,
)
from ufp.leastsquares._cache import (
AssembledBatchCacheMode,
CachedBlockLinearProblem,
_assembled_cache_manifest_exists,
_assembled_cache_metadata_for_fit,
_cache_metadata_can_project,
_cache_metadata_matches,
_cache_metadata_mismatch_reasons,
_load_assembled_batch_entry_memmap,
_load_assembled_batch_manifest,
_load_assembled_batches_manifest,
_matching_assembled_batch_cache_size,
_normal_equation_target_weights,
_sample_signature,
_samples_with_unit_target_weights,
_write_assembled_batch_memmap,
_write_assembled_batches_manifest,
assembled_cache_dir,
load_assembled_batches_memmap,
save_assembled_batches_memmap,
)
from ufp.leastsquares._cache_layout import (
CacheWritePlan,
build_cache_projection_plan,
build_cache_write_plan,
cache_blocks_from_metadata,
project_batch_to_cache,
project_cache_batch_to_layout,
)
from ufp.leastsquares._layout import ParameterLayout, TermBlock
from ufp.leastsquares._normal_equations import DenseNormalEquationMixin
from ufp.leastsquares._problem import (
BlockLinearProblem,
CGCheckpointState,
LinearFitResult,
LinearSolveResult,
_cg_checkpoint_metadata,
load_cg_checkpoint,
save_cg_checkpoint,
)
from ufp.leastsquares._selection import (
CoefficientSelector,
SelectedCoefficientBlock,
block_matches_selector,
)
from ufp.leastsquares._setup import (
build_linear_fitter_selection_plan,
current_selected_vector,
fixed_coefficients_signature,
selected_coefficient_metadata,
write_selected_vector,
)
from ufp.leastsquares._types import AssembledBatch
from ufp.leastsquares._utils import _infer_tensor_options, _iter_with_progress
from ufp.leastsquares.dataset import FitSample, PreparedBatch, prepare_batches
from ufp.leastsquares.regularization import (
BlockRegularization,
RegularizationStencil,
_make_block_regularization,
)
from ufp.terms._twobody_shape import (
TwoBodySplineShapePenalty,
normalize_twobody_shape_penalty,
)
from ufp.terms.model import UFPModel
from ufp.terms.twobody import SplineTwoBodyTerm
def _twobody_shape_regularization_rows(block) -> tuple[int, ...] | None:
    """Return the coefficient rows that two-body shape regularization covers.

    Only ``SplineTwoBodyTerm`` blocks restrict the rows; their
    ``_active_pair_indices`` are converted to a tuple of Python ints. For any
    other term ``None`` is returned, which callers treat as "no restriction".
    """
    term = block.term
    if not isinstance(term, SplineTwoBodyTerm):
        return None
    return tuple(map(int, term._active_pair_indices))
# Version tag recorded in cache and checkpoint metadata; changing it makes
# stored artifacts built under older two-body regularization semantics
# compare unequal to freshly computed metadata.
_TWOBODY_SHAPE_REGULARIZATION_SEMANTICS = "boundary_aware_partial_v1"
# Coefficient weights of a forward third-difference window over four
# consecutive spline coefficients.
_THIRD_DIFFERENCE_STENCIL = (-1.0, 3.0, -3.0, 1.0)
[docs]
class LinearFitter(DenseNormalEquationMixin):
"""High-level entry point for direct least-squares fitting of spline models."""
def __init__(
self,
model: UFPModel,
*,
fit_energy: bool = True,
fit_forces: bool = True,
fit_per_atom_energy: bool = False,
solver: str = "cg",
ridge: float = 0.0,
onebody_ridge: float | None = None,
pair_ridge: float | None = None,
twobody_ridge: float | None = None,
threebody_ridge: float | None = None,
twobody_shape_penalty: TwoBodySplineShapePenalty | None = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
cg_tolerance: float = 1.0e-10,
cg_max_iter: int | None = None,
threebody_lstsq_backend: str | None = None,
threebody_bucket_backend: str | None = None,
fit_blocks: Sequence[int | str] | None = None,
freeze_blocks: Sequence[int | str] = (),
assembly_contract: str = "term",
matrix_storage: MatrixStorageMode = "auto",
) -> None:
"""Store solve options and freeze the model parameter layout."""
if not (fit_energy or fit_forces or fit_per_atom_energy):
raise ValueError("at least one target type must be enabled")
inferred_dtype, inferred_device = _infer_tensor_options(model)
self.model = model
self.fit_energy = bool(fit_energy)
self.fit_forces = bool(fit_forces)
self.fit_per_atom_energy = bool(fit_per_atom_energy)
self.solver = str(solver)
self.ridge = float(ridge)
self.onebody_ridge = (
self.ridge if onebody_ridge is None else float(onebody_ridge)
)
if pair_ridge is not None and twobody_ridge is not None:
raise ValueError("use only one of `pair_ridge` and `twobody_ridge`")
resolved_pair_ridge = pair_ridge if pair_ridge is not None else twobody_ridge
self.pair_ridge = (
self.ridge if resolved_pair_ridge is None else float(resolved_pair_ridge)
)
self.threebody_ridge = (
self.ridge if threebody_ridge is None else float(threebody_ridge)
)
if self.ridge < 0.0:
raise ValueError("`ridge` must be non-negative")
if self.onebody_ridge < 0.0:
raise ValueError("`onebody_ridge` must be non-negative")
if self.pair_ridge < 0.0:
raise ValueError("`pair_ridge` must be non-negative")
if self.threebody_ridge < 0.0:
raise ValueError("`threebody_ridge` must be non-negative")
self.twobody_shape_penalty = normalize_twobody_shape_penalty(
twobody_shape_penalty
)
self.dtype = inferred_dtype if dtype is None else dtype
self.device = inferred_device if device is None else device
self.cg_tolerance = float(cg_tolerance)
self.cg_max_iter = cg_max_iter
self.threebody_lstsq_backend = threebody_lstsq_backend
self.threebody_bucket_backend = threebody_bucket_backend
self.layout = ParameterLayout.from_model(model, include_frozen=True)
self.fit_blocks = None if fit_blocks is None else tuple(fit_blocks)
self.freeze_blocks = tuple(freeze_blocks)
self.selection_plan = build_linear_fitter_selection_plan(
self.layout,
fit_blocks=self.fit_blocks,
freeze_blocks=self.freeze_blocks,
)
self.selected_coefficients = self.selection_plan.selected_coefficients
self.assembly_contract = str(assembly_contract)
if self.assembly_contract not in {"block", "term"}:
raise ValueError("`assembly_contract` must be 'block' or 'term'")
self.matrix_storage = str(matrix_storage)
if self.matrix_storage not in {
"dense",
"row_indexed",
"column_chunked",
"auto",
}:
raise ValueError(
"`matrix_storage` must be one of: dense, row_indexed, "
"column_chunked, auto"
)
def _block_matches_selector(self, block, selector: int | str) -> bool:
"""Return whether a block matches one include/exclude selector."""
return block_matches_selector(block, selector)
def _selected_block_indices(self) -> tuple[int, ...]:
"""Return layout block indices included in direct assembly and solving."""
return self.selection_plan.selected_block_indices
def _selected_by_block(self) -> dict[int, SelectedCoefficientBlock]:
"""Return selected coefficient metadata keyed by original block index."""
return dict(self.selection_plan.selected_by_block)
def _current_selected_vector(
self,
*,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
) -> torch.Tensor:
"""Read the current selected coefficients into the compact solve layout."""
return current_selected_vector(
self.selected_coefficients,
dtype=dtype,
device=device,
)
def _write_selected_vector(self, theta: torch.Tensor) -> None:
"""Write compact selected coefficients back while preserving fixed entries."""
write_selected_vector(self.selected_coefficients, theta)
def _selection_metadata(self) -> list[dict[str, object]]:
"""Return compact metadata describing the selected coefficient columns."""
return selected_coefficient_metadata(self.selected_coefficients)
def _fixed_coefficients_signature(self) -> str:
"""Return a digest of current coefficient values outside the solve layout."""
return fixed_coefficients_signature(self.layout, self.selected_coefficients)
def _selected_size(self) -> int:
"""Return the compact solve vector length for selected coefficients."""
return self.selection_plan.selected_size
def _uses_full_direct_layout(self) -> bool:
"""Return whether the selected vector matches the full layout exactly."""
return self.selection_plan.uses_full_direct_layout
def _regularization_metadata(self) -> dict[str, object]:
"""Return serializable solve settings that affect CG checkpoint state."""
penalty = self.twobody_shape_penalty
return {
"semantics": self._regularization_semantics_metadata(),
"ridge": float(self.ridge),
"onebody_ridge": float(self.onebody_ridge),
"pair_ridge": float(self.pair_ridge),
"threebody_ridge": float(self.threebody_ridge),
"twobody_shape_penalty": {
"third_difference_weight": float(penalty.third_difference_weight),
"first_coefficient_min": float(penalty.first_coefficient_min),
"first_coefficient_weight": float(penalty.first_coefficient_weight),
"initial_curvature_weight": float(penalty.initial_curvature_weight),
"decreasing_weight": float(penalty.decreasing_weight),
"decreasing_count": int(penalty.decreasing_count),
},
}
def _regularization_semantics_metadata(self) -> dict[str, object]:
"""Return metadata for cache invalidation across regularization semantics."""
return {"twobody_shape_regularization": _TWOBODY_SHAPE_REGULARIZATION_SEMANTICS}
def _coefficient_checkpoint_metadata(
self,
*,
n_parameters: int,
dtype: torch.dtype,
) -> dict[str, object]:
"""Return CG metadata for the selected coefficient layout."""
metadata = _cg_checkpoint_metadata(
n_parameters=n_parameters,
dtype=dtype,
)
metadata.update(
{
"selected_block_indices": [
int(index) for index in self._selected_block_indices()
],
"coefficient_selection": self._selection_metadata(),
"fixed_coefficients_signature": (self._fixed_coefficients_signature()),
}
)
return metadata
def _fit_checkpoint_metadata(
self,
samples: Sequence[FitSample],
*,
n_parameters: int,
dtype: torch.dtype,
) -> dict[str, object]:
"""Return CG metadata for one concrete selected least-squares problem."""
metadata = self._coefficient_checkpoint_metadata(
n_parameters=n_parameters,
dtype=dtype,
)
metadata.update(self._threebody_backend_metadata())
metadata.update(
{
"sample_signature": _sample_signature(
samples,
include_weights=True,
),
"target_layout": {
"fit_energy": bool(self.fit_energy),
"fit_forces": bool(self.fit_forces),
"fit_per_atom_energy": bool(self.fit_per_atom_energy),
},
"regularization": self._regularization_metadata(),
}
)
return metadata
@staticmethod
def _checkpoint_metadata_contains(
actual: dict[str, object],
expected: dict[str, object],
) -> bool:
"""Return whether one checkpoint metadata payload contains all expected keys."""
return all(actual.get(key) == value for key, value in expected.items())
[docs]
@classmethod
def from_model(cls, model: UFPModel, **kwargs) -> "LinearFitter":
"""Convenience constructor mirroring ``__init__``."""
return cls(model, **kwargs)
def _prepare_batches(
self,
samples: Sequence[FitSample],
*,
batch_size: int,
) -> tuple[PreparedBatch, ...]:
"""Normalize fit samples into prepared geometry and target batches."""
return prepare_batches(
self.model,
samples,
batch_size=batch_size,
fit_energy=self.fit_energy,
fit_forces=self.fit_forces,
fit_per_atom_energy=self.fit_per_atom_energy,
dtype=self.dtype,
device=self.device,
)
def _iter_prepared_batches(
self,
samples: Sequence[FitSample],
*,
batch_size: int,
device: torch.device | None = None,
):
"""Yield prepared geometry batches without retaining all chunks."""
items = tuple(samples)
if not items:
raise ValueError("`samples` must contain at least one FitSample")
if batch_size <= 0:
raise ValueError("`batch_size` must be positive")
target_device = self.device if device is None else device
for start in range(0, len(items), batch_size):
chunk = items[start : start + batch_size]
prepared = prepare_batches(
self.model,
chunk,
batch_size=len(chunk),
fit_energy=self.fit_energy,
fit_forces=self.fit_forces,
fit_per_atom_energy=self.fit_per_atom_energy,
dtype=self.dtype,
device=target_device,
)
yield from prepared
def _assemble_true_blocks(self, batch: PreparedBatch) -> AssembledBatch:
"""Assemble one prepared batch and residualize fixed coefficients."""
selected = self._assemble_selected_true_blocks(batch)
if selected is not None:
return selected
assembler = (
assemble_true_blocks_by_terms
if self.assembly_contract == "term"
else assemble_true_blocks
)
assembled = assembler(
batch,
self.layout,
threebody_lstsq_backend=self.threebody_lstsq_backend,
threebody_bucket_backend=self.threebody_bucket_backend,
)
return self._compact_selected_assembled_batch(assembled)
def _assemble_selected_true_blocks(
self,
batch: PreparedBatch,
) -> AssembledBatch | None:
"""Assemble compact selected columns through optional term-level hooks."""
if self.assembly_contract != "term" or self._uses_full_direct_layout():
return None
if not self._can_residualize_fixed_coefficients_from_outputs():
return None
row_weights = (batch.targets.sqrt_weights * batch.targets.row_scales)[:, None]
matrices: dict[int, torch.Tensor] = {}
for selection in self.selected_coefficients:
matrix = self._assemble_selected_matrix(batch, selection)
if matrix is not None:
matrices[int(selection.block.index)] = matrix * row_weights
return AssembledBatch(
target=(
batch.targets.weighted_values
- self._fixed_coefficient_prediction(batch)
),
block_matrices=matrices,
)
def _assemble_selected_matrix(
self,
batch: PreparedBatch,
selection: SelectedCoefficientBlock,
) -> torch.Tensor | None:
"""Assemble one compact selected-column block, falling back when needed."""
selected_method = getattr(
selection.block.term,
"assemble_selected_linear_block",
None,
)
used_selected_method = False
matrix = None
if not selection.is_full_block and callable(selected_method):
used_selected_method = True
matrix = selected_method(
selection.block,
batch.inputs,
batch.targets,
selection.indices,
)
if matrix is not None and tuple(matrix.shape) != (
batch.targets.n_rows,
selection.size,
):
raise ValueError(
"selected linear assembler for block "
f"{selection.block.label!r} returned shape "
f"{tuple(matrix.shape)}, expected "
f"({batch.targets.n_rows}, {selection.size})"
)
if matrix is not None or used_selected_method:
return matrix
fallback = assemble_true_blocks_by_terms(
batch,
self.layout,
selected_block_indices=(int(selection.block.index),),
threebody_lstsq_backend=self.threebody_lstsq_backend,
threebody_bucket_backend=self.threebody_bucket_backend,
)
full_matrix = fallback.block_matrices.get(int(selection.block.index))
if full_matrix is None:
return None
if selection.is_full_block:
return full_matrix
selected_index = torch.tensor(
selection.indices,
dtype=torch.int64,
device=full_matrix.device,
)
return full_matrix.index_select(1, selected_index)
def _can_residualize_fixed_coefficients_from_outputs(self) -> bool:
"""Return whether blocks can be temporarily zeroed for residualization."""
for block in self.layout.blocks:
provider = block.coefficient_provider
if provider is not None and not provider.uses_identity_weights:
return False
return True
    def _fixed_coefficient_prediction(self, batch: PreparedBatch) -> torch.Tensor:
        """Evaluate fixed coefficient contributions as weighted target rows.

        The selected entries of every selected block are temporarily set to
        zero. Each unique layout term is then evaluated on ``batch.inputs``
        under ``torch.no_grad()``, and the outputs are accumulated into rows
        shaped like ``batch.targets.values``. This presumably isolates the
        fixed-coefficient contribution because the terms are linear in their
        coefficients (TODO confirm). The selected-assembly caller only reaches
        this after ``_can_residualize_fixed_coefficients_from_outputs`` passes.

        The original block values are written back in the ``finally`` clause,
        so the model is restored even if a term evaluation raises.

        Returns:
            The accumulated rows multiplied by ``sqrt_weights * row_scales``,
            in the form the selected-assembly path subtracts from
            ``batch.targets.weighted_values``.
        """
        originals: list[tuple[TermBlock, torch.Tensor]] = []
        rows = torch.zeros_like(batch.targets.values)
        try:
            for selection in self.selected_coefficients:
                block = selection.block
                current = block.read().detach().clone()
                # Snapshot is recorded before the write below, so `finally`
                # also restores a block whose write fails partway through.
                originals.append((block, current))
                flat = current.reshape(-1).clone()
                index = torch.tensor(
                    selection.indices,
                    dtype=torch.int64,
                    device=flat.device,
                )
                # Zero only the selected entries; fixed entries keep their values.
                flat.index_fill_(0, index, 0.0)
                block.write(flat.reshape(block.shape))
            with torch.no_grad():
                for term in self._layout_terms():
                    self._accumulate_output_rows(rows, batch, term(batch.inputs))
        finally:
            # Restore snapshots in reverse order of modification.
            for block, values in reversed(originals):
                block.write(values)
        return rows * batch.targets.sqrt_weights * batch.targets.row_scales
def _layout_terms(self) -> tuple[torch.nn.Module, ...]:
"""Return unique terms that own coefficient blocks in layout order."""
seen: set[int] = set()
terms: list[torch.nn.Module] = []
for block in self.layout.blocks:
term_id = id(block.term)
if term_id in seen:
continue
seen.add(term_id)
terms.append(block.term)
return tuple(terms)
def _accumulate_output_rows(
self,
rows: torch.Tensor,
batch: PreparedBatch,
output: UFPOutput,
) -> None:
"""Add one term output to unweighted least-squares target rows."""
inputs = batch.inputs
targets = batch.targets
if output.energy is not None:
energy = torch.as_tensor(
output.energy,
dtype=inputs.dtype,
device=inputs.device,
)
energy = energy.reshape(inputs.n_systems, -1)
if energy.shape[1] != 1:
raise ValueError("term energy must provide one value per system")
valid = targets.energy_rows >= 0
if torch.any(valid):
rows.index_add_(
0,
targets.energy_rows[valid],
energy[:, 0][valid],
)
if output.forces is not None:
forces = torch.as_tensor(
output.forces,
dtype=inputs.dtype,
device=inputs.device,
)
if tuple(forces.shape) != (inputs.n_atoms, 3):
raise ValueError(
"term forces must have shape "
f"({inputs.n_atoms}, 3), got {tuple(forces.shape)}"
)
valid = targets.force_rows >= 0
if torch.any(valid):
rows.index_add_(0, targets.force_rows[valid], forces[valid])
if output.per_atom_energy is not None:
per_atom_energy = torch.as_tensor(
output.per_atom_energy,
dtype=inputs.dtype,
device=inputs.device,
)
if per_atom_energy.ndim == 2 and tuple(per_atom_energy.shape) == (
inputs.n_atoms,
1,
):
per_atom_energy = per_atom_energy[:, 0]
else:
per_atom_energy = per_atom_energy.reshape(-1)
if tuple(per_atom_energy.shape) != (inputs.n_atoms,):
raise ValueError(
"term per-atom energy must have shape "
f"({inputs.n_atoms},), got {tuple(per_atom_energy.shape)}"
)
valid = targets.per_atom_rows >= 0
if torch.any(valid):
rows.index_add_(
0,
targets.per_atom_rows[valid],
per_atom_energy[valid],
)
def _compact_selected_assembled_batch(
self,
batch: AssembledBatch,
) -> AssembledBatch:
"""Split full-block matrices into selected columns and fixed target."""
selected_by_block = self._selected_by_block()
target = batch.target.clone()
matrices: dict[int, torch.Tensor] = {}
for block in self.layout.blocks:
matrix = batch.block_matrices.get(int(block.index))
if matrix is None:
continue
dense = (
matrix
if isinstance(matrix, torch.Tensor)
else _materialize_block_matrix(matrix)
)
selection = selected_by_block.get(int(block.index))
selected_indices = set(()) if selection is None else set(selection.indices)
fixed_indices = tuple(
index for index in range(block.size) if index not in selected_indices
)
current = (
block.read()
.reshape(-1)
.to(
dtype=dense.dtype,
device=dense.device,
)
)
if fixed_indices:
fixed_index = torch.tensor(
fixed_indices,
dtype=torch.int64,
device=dense.device,
)
fixed_matrix = dense.index_select(1, fixed_index)
fixed_theta = current.index_select(0, fixed_index)
target = target - fixed_matrix @ fixed_theta
if selection is None:
continue
selected_index = torch.tensor(
selection.indices,
dtype=torch.int64,
device=dense.device,
)
matrices[int(block.index)] = dense.index_select(1, selected_index)
return AssembledBatch(
target=target,
block_matrices=matrices,
)
def _threebody_backend_metadata(self) -> dict[str, object]:
"""Return backend settings that affect assembly cache construction."""
return {
"threebody_lstsq_backend": (
os.environ.get("UFP_THREEBODY_LSTSQ_BACKEND", "auto")
if self.threebody_lstsq_backend is None
else str(self.threebody_lstsq_backend)
),
"threebody_bucket_backend": (
os.environ.get("UFP_THREEBODY_BUCKET_BACKEND", "auto")
if self.threebody_bucket_backend is None
else str(self.threebody_bucket_backend)
),
"assembly_contract": self.assembly_contract,
}
def _partial_twobody_third_difference_stencils(
self,
selection: SelectedCoefficientBlock,
) -> tuple[RegularizationStencil, ...]:
"""Return boundary-aware third-difference rows for partial selections."""
block = selection.block
if block.kind not in {"pair", "twobody"}:
return ()
if selection.is_full_block:
return ()
if len(block.shape) not in {1, 2} or block.shape[-1] < 4:
return ()
selected_columns = {
int(original): compact for compact, original in enumerate(selection.indices)
}
values = block.read().detach().reshape(-1).cpu()
rows: list[RegularizationStencil] = []
if len(block.shape) == 1:
row_offsets = (0,)
n_coeffs = int(block.shape[0])
else:
n_coeffs = int(block.shape[1])
active_rows = _twobody_shape_regularization_rows(block)
if active_rows is None:
active_rows = tuple(range(int(block.shape[0])))
row_offsets = tuple(int(row) * n_coeffs for row in active_rows)
for offset in row_offsets:
for coeff in range(n_coeffs - 3):
compact_columns: list[int] = []
compact_weights: list[float] = []
fixed_term = 0.0
for local, weight in enumerate(_THIRD_DIFFERENCE_STENCIL):
original = int(offset + coeff + local)
compact = selected_columns.get(original)
if compact is None:
fixed_term += float(weight) * float(values[original].item())
else:
compact_columns.append(int(compact))
compact_weights.append(float(weight))
if not compact_columns:
continue
rows.append(
RegularizationStencil(
columns=tuple(compact_columns),
weights=tuple(compact_weights),
target=-fixed_term,
)
)
return tuple(rows)
def _direct_blocks(self) -> tuple[SolveBlock, ...]:
"""Build solve metadata for every block handled by the direct linear system."""
blocks = []
for selection in self.selected_coefficients:
block = selection.block
group = block.regularization_group
ridge = self.ridge
if group == "onebody":
ridge = self.onebody_ridge
elif group in {"pair", "twobody"}:
ridge = self.pair_ridge
elif group == "threebody":
ridge = self.threebody_ridge
third_difference_penalty = (
self.twobody_shape_penalty.third_difference_weight
if block.kind in {"pair", "twobody"}
else 0.0
)
active_rows = None
third_difference_stencils = None
if third_difference_penalty > 0.0:
if selection.is_full_block:
active_rows = _twobody_shape_regularization_rows(block)
else:
third_difference_stencils = (
self._partial_twobody_third_difference_stencils(selection)
)
blocks.append(
SolveBlock(
key=block.index,
size=selection.size,
label=block.label,
regularization=_make_block_regularization(
selection.shape,
ridge=ridge,
third_difference_penalty=third_difference_penalty,
active_rows=active_rows,
third_difference_stencils=third_difference_stencils,
),
)
)
return tuple(blocks)
def _column_chunk_sizes(self) -> dict[int, int]:
"""Return per-block coefficient category widths for sparse cache storage."""
chunk_sizes: dict[int, int] = {}
for selection in self.selected_coefficients:
if len(selection.shape) < 2 or int(selection.shape[0]) <= 1:
continue
chunk_size = 1
for dim in selection.shape[1:]:
chunk_size *= int(dim)
chunk_sizes[int(selection.block.index)] = int(chunk_size)
return chunk_sizes
def _apply_matrix_storage(
self,
batch: AssembledBatch,
*,
for_cache: bool = False,
) -> AssembledBatch:
"""Return an assembled batch using this fitter's block storage option."""
dense_auto = (
self.matrix_storage == "auto" and self.solver == "cg" and not for_cache
)
if self.matrix_storage == "dense" or dense_auto:
return AssembledBatch(
target=batch.target,
block_matrices={
key: _materialize_block_matrix(matrix)
for key, matrix in batch.block_matrices.items()
},
)
column_chunk_sizes = self._column_chunk_sizes()
return AssembledBatch(
target=batch.target,
block_matrices={
key: _compact_block_matrix_for_storage(
matrix,
mode=self.matrix_storage, # type: ignore[arg-type]
column_chunk_size=column_chunk_sizes.get(key),
)
for key, matrix in batch.block_matrices.items()
},
)
def _preserve_cached_compact_blocks(self) -> bool:
"""Return whether streamed cache loads should preserve compact storage."""
return self.matrix_storage != "dense"
def _cache_write_plan(self) -> CacheWritePlan:
"""Return the semantic cache layout for this fitter's solve layout."""
return build_cache_write_plan(
self.selected_coefficients,
reusable=self._uses_full_direct_layout(),
)
def _projected_cache_loader(
self,
metadata: object,
cache_write_plan: CacheWritePlan,
):
"""Return a batch loader projecting semantic cache blocks to solve blocks."""
source_blocks = cache_blocks_from_metadata(metadata)
projection_plan = build_cache_projection_plan(source_blocks, cache_write_plan)
def load_projected_batch(cache_dir, entry, **kwargs):
batch = _load_assembled_batch_entry_memmap(cache_dir, entry, **kwargs)
projected = project_cache_batch_to_layout(batch, projection_plan)
if self.matrix_storage == "dense":
return self._apply_matrix_storage(projected, for_cache=True)
return projected
return load_projected_batch
def _cached_problem_from_manifest(
self,
*,
manifest: dict[str, object],
cache_dir: Path,
cache_write_plan: CacheWritePlan,
items: Sequence[FitSample],
default_batch_size: int,
) -> CachedBlockLinearProblem:
"""Create a streamed cached problem from a compatible manifest."""
metadata = manifest.get("metadata")
if not isinstance(metadata, dict):
raise ValueError("least-squares cache metadata is missing")
cache_batch_size = int(metadata.get("batch_size", default_batch_size))
row_weights = self._prepared_batch_sqrt_weights(
items,
batch_size=cache_batch_size,
device=torch.device("cpu"),
)
resolved_cache_dir = Path(str(manifest.get("_cache_dir", cache_dir)))
return CachedBlockLinearProblem(
layout=BlockProblemLayout.from_blocks(self._direct_blocks()),
cache_dir=resolved_cache_dir,
entries=tuple(manifest.get("batches", ())),
dtype=self.dtype,
device=self.device,
row_weights=row_weights,
load_batch_entry=self._projected_cache_loader(
metadata,
cache_write_plan,
),
row_indexed_blocks=self._preserve_cached_compact_blocks(),
)
def _find_projectable_cache(
self,
*,
cache_parent: Path,
expected_metadata: dict[str, object],
cache_write_plan: CacheWritePlan,
items: Sequence[FitSample],
default_batch_size: int,
skip_dirs: set[Path] | None = None,
) -> CachedBlockLinearProblem | None:
"""Find a sibling assembled cache that can be projected to this layout."""
skipped = {path.resolve() for path in (() if skip_dirs is None else skip_dirs)}
for manifest_path in sorted(cache_parent.glob("*/assembled_batches.json")):
candidate_dir = manifest_path.parent
if candidate_dir.resolve() in skipped:
continue
try:
manifest = _load_assembled_batches_manifest(candidate_dir)
except (OSError, ValueError, json.JSONDecodeError):
continue
metadata = manifest.get("metadata")
if not _cache_metadata_can_project(metadata, expected_metadata):
continue
try:
return self._cached_problem_from_manifest(
manifest=manifest,
cache_dir=candidate_dir,
cache_write_plan=cache_write_plan,
items=items,
default_batch_size=default_batch_size,
)
except ValueError:
continue
return None
def _prepared_batch_sqrt_weights(
self,
samples: Sequence[FitSample],
*,
batch_size: int,
device: torch.device | None = None,
) -> tuple[torch.Tensor, ...]:
"""Return current target row weights in prepared-batch order."""
resolved_device = self.device if device is None else device
items = tuple(samples)
if not items:
raise ValueError("`samples` must contain at least one FitSample")
if batch_size <= 0:
raise ValueError("`batch_size` must be positive")
weights: list[torch.Tensor] = []
for start in range(0, len(items), batch_size):
chunk = items[start : start + batch_size]
prepared_batches = prepare_batches(
self.model,
chunk,
batch_size=len(chunk),
fit_energy=self.fit_energy,
fit_forces=self.fit_forces,
fit_per_atom_energy=self.fit_per_atom_energy,
dtype=self.dtype,
device=resolved_device,
)
weights.extend(
prepared.targets.sqrt_weights for prepared in prepared_batches
)
return tuple(weights)
[docs]
def build_problem(
self,
samples: Sequence[FitSample],
*,
batch_size: int = 32,
progress: bool = False,
cache_directory: Path | str | None = None,
cache_mode: AssembledBatchCacheMode = "auto",
) -> BlockLinearProblem:
"""Prepare batches, assemble block matrices, and return the linear problem."""
items = tuple(samples)
if not items:
raise ValueError("`samples` must contain at least one FitSample")
if batch_size <= 0:
raise ValueError("`batch_size` must be positive")
expected_metadata = _assembled_cache_metadata_for_fit(
layout=self.layout,
samples=items,
fit_energy=self.fit_energy,
fit_forces=self.fit_forces,
fit_per_atom_energy=self.fit_per_atom_energy,
dtype=self.dtype,
batch_size=batch_size,
)
expected_metadata.update(self._threebody_backend_metadata())
expected_metadata["regularization_semantics"] = (
self._regularization_semantics_metadata()
)
expected_metadata["selected_block_indices"] = [
int(index) for index in self._selected_block_indices()
]
expected_metadata["coefficient_selection"] = self._selection_metadata()
expected_metadata["fixed_coefficients_signature"] = (
self._fixed_coefficients_signature()
)
cache_write_plan = self._cache_write_plan()
expected_metadata["cache_blocks"] = cache_write_plan.metadata()
expected_metadata["cache_reusable"] = bool(cache_write_plan.reusable)
total_batches = (len(items) + batch_size - 1) // batch_size
cache_parent = None if cache_directory is None else Path(cache_directory)
cache_dir = (
None
if cache_parent is None
else assembled_cache_dir(cache_parent, expected_metadata)
)
if cache_directory is not None and cache_mode in {"auto", "read"}:
assert cache_parent is not None
assert cache_dir is not None
read_cache_dir = cache_dir
if not _assembled_cache_manifest_exists(read_cache_dir):
legacy_cache_dir = cache_parent
if _assembled_cache_manifest_exists(legacy_cache_dir):
read_cache_dir = legacy_cache_dir
if _assembled_cache_manifest_exists(read_cache_dir):
try:
manifest = _load_assembled_batches_manifest(read_cache_dir)
except ValueError:
if cache_mode == "read":
raise
manifest = None
if progress:
print(
"Ignoring incompatible least-squares cache in "
f"{read_cache_dir}; rebuilding..."
)
if manifest is None:
metadata_matches = False
metadata_projectable = False
else:
metadata_matches = _cache_metadata_matches(
manifest.get("metadata"),
expected_metadata,
)
metadata_projectable = _cache_metadata_can_project(
manifest.get("metadata"),
expected_metadata,
)
if metadata_matches:
if progress:
print(
"Using cached least-squares batches from "
f"{read_cache_dir}..."
)
return self._cached_problem_from_manifest(
manifest=manifest,
cache_dir=read_cache_dir,
cache_write_plan=cache_write_plan,
items=items,
default_batch_size=batch_size,
)
if metadata_projectable:
try:
if progress:
print(
"Using projectable least-squares cache from "
f"{read_cache_dir}..."
)
return self._cached_problem_from_manifest(
manifest=manifest,
cache_dir=read_cache_dir,
cache_write_plan=cache_write_plan,
items=items,
default_batch_size=batch_size,
)
except ValueError:
if cache_mode == "read":
raise
compatible_problem = self._find_projectable_cache(
cache_parent=cache_parent,
expected_metadata=expected_metadata,
cache_write_plan=cache_write_plan,
items=items,
default_batch_size=batch_size,
skip_dirs=(
set()
if manifest is None
else {Path(str(manifest.get("_cache_dir", read_cache_dir)))}
),
)
if compatible_problem is not None:
if progress:
print("Using projectable sibling least-squares cache...")
return compatible_problem
if cache_mode == "read":
raise ValueError(
"least-squares cache metadata does not match the requested "
"samples, targets, dtype, or model layout"
)
if progress:
mismatch_reasons = (
()
if manifest is None
else _cache_metadata_mismatch_reasons(
manifest.get("metadata"),
expected_metadata,
)
)
reason_suffix = (
""
if not mismatch_reasons
else f" ({', '.join(mismatch_reasons)})"
)
print(
"Ignoring stale least-squares cache in "
f"{read_cache_dir}{reason_suffix}; rebuilding..."
)
elif cache_mode == "read":
cached_batch_size = _matching_assembled_batch_cache_size(
cache_dir,
expected_metadata=expected_metadata,
)
restore_batch_size = (
batch_size if cached_batch_size is None else cached_batch_size
)
restore_total_batches = (
len(items) + restore_batch_size - 1
) // restore_batch_size
try:
manifest_batches = [
_load_assembled_batch_manifest(
cache_dir,
batch_index,
expected_metadata=expected_metadata,
)
for batch_index in range(restore_total_batches)
]
except ValueError:
raise
if all(entry is not None for entry in manifest_batches):
if progress:
print(
"Restoring least-squares manifest from completed "
f"batch caches in {cache_dir}..."
)
restored_entries = [
entry for entry in manifest_batches if entry is not None
]
restored_metadata = dict(expected_metadata)
restored_metadata["batch_size"] = int(restore_batch_size)
_write_assembled_batches_manifest(
cache_dir,
restored_entries,
metadata=restored_metadata,
)
return self._cached_problem_from_manifest(
manifest={
"_cache_dir": str(cache_dir),
"batches": restored_entries,
"metadata": restored_metadata,
},
cache_dir=cache_dir,
cache_write_plan=cache_write_plan,
items=items,
default_batch_size=restore_batch_size,
)
compatible_problem = self._find_projectable_cache(
cache_parent=cache_parent,
expected_metadata=expected_metadata,
cache_write_plan=cache_write_plan,
items=items,
default_batch_size=batch_size,
skip_dirs={read_cache_dir},
)
if compatible_problem is not None:
if progress:
print("Using projectable sibling least-squares cache...")
return compatible_problem
if cache_mode == "read":
if any(cache_parent.glob("*/assembled_batches.json")):
raise ValueError(
"least-squares cache metadata does not match the requested "
"samples, targets, dtype, or model layout"
)
raise FileNotFoundError(
"least-squares cache requested in read mode, but no complete "
f"cache was found in {cache_dir}"
)
if progress:
print(f"Preparing least-squares batches with batch_size={batch_size}...")
if cache_directory is not None and cache_mode in {"auto", "write", "refresh"}:
assert cache_dir is not None
cache_dir.mkdir(parents=True, exist_ok=True)
assembly_batch_size = batch_size
if cache_mode != "refresh":
cached_batch_size = _matching_assembled_batch_cache_size(
cache_dir,
expected_metadata=expected_metadata,
)
if cached_batch_size is not None:
assembly_batch_size = cached_batch_size
assembly_metadata = dict(expected_metadata)
assembly_metadata["batch_size"] = int(assembly_batch_size)
total_batches = (
len(items) + assembly_batch_size - 1
) // assembly_batch_size
manifest_batches: list[dict[str, object]] = []
unit_weight_items = _samples_with_unit_target_weights(items)
column_chunk_sizes = cache_write_plan.column_chunk_sizes
prepared_batches = _iter_with_progress(
self._iter_prepared_batches(
unit_weight_items,
batch_size=assembly_batch_size,
),
enabled=progress,
description="Assembling least-squares cache",
total=total_batches,
)
for batch_index, batch in enumerate(
prepared_batches,
start=0,
):
entry = None
if cache_mode != "refresh":
try:
entry = _load_assembled_batch_manifest(
cache_dir,
batch_index,
expected_metadata=assembly_metadata,
)
except (json.JSONDecodeError, ValueError):
entry = None
if entry is None:
assembled = self._apply_matrix_storage(
self._assemble_true_blocks(batch),
for_cache=True,
)
assembled = project_batch_to_cache(assembled, cache_write_plan)
if self.matrix_storage == "dense":
assembled = self._apply_matrix_storage(
assembled,
for_cache=True,
)
entry = _write_assembled_batch_memmap(
cache_dir,
batch_index,
assembled,
metadata=assembly_metadata,
column_chunk_sizes=column_chunk_sizes,
compact=False,
)
manifest_batches.append(entry)
if progress:
print("Finished assembling all batches.")
print(f"Writing cached least-squares manifest to {cache_dir}...")
_write_assembled_batches_manifest(
cache_dir,
manifest_batches,
metadata=assembly_metadata,
)
if progress:
print(f"Using memory-mapped least-squares batches from {cache_dir}...")
row_weights = self._prepared_batch_sqrt_weights(
items,
batch_size=assembly_batch_size,
device=torch.device("cpu"),
)
return CachedBlockLinearProblem(
layout=BlockProblemLayout.from_blocks(self._direct_blocks()),
cache_dir=cache_dir,
entries=tuple(manifest_batches),
dtype=self.dtype,
device=self.device,
row_weights=row_weights,
load_batch_entry=self._projected_cache_loader(
assembly_metadata,
cache_write_plan,
),
row_indexed_blocks=self._preserve_cached_compact_blocks(),
)
assembled_batches = []
prepared_batches = _iter_with_progress(
self._iter_prepared_batches(items, batch_size=batch_size),
enabled=progress,
description="Assembling least-squares batches",
total=total_batches,
)
for batch in prepared_batches:
assembled_batches.append(
self._apply_matrix_storage(self._assemble_true_blocks(batch))
)
if progress:
print("Finished assembling all batches.")
return self.problem_from_assembled_batches(assembled_batches)
[docs]
def problem_from_assembled_batches(
    self,
    assembled_batches: Sequence[AssembledBatch],
) -> BlockLinearProblem:
    """Wrap already-assembled batches in the common block-problem interface.

    Every batch is passed through ``self._apply_matrix_storage`` before its
    block matrices are stored, so the resulting problem uses this fitter's
    configured matrix storage.

    Args:
        assembled_batches: Batches from the assembly step. Each one must expose
            a ``target`` attribute and be accepted by ``_apply_matrix_storage``.

    Returns:
        A ``BlockLinearProblem`` whose layout is built from
        ``self._direct_blocks()``. Its batches mirror ``assembled_batches`` in
        the same order.
    """
    problem_layout = BlockProblemLayout.from_blocks(self._direct_blocks())
    solve_batches = []
    for assembled in assembled_batches:
        batch_target = assembled.target
        stored = self._apply_matrix_storage(assembled)
        solve_batches.append(
            BlockSolveBatch(target=batch_target, matrices=stored.block_matrices)
        )
    return BlockLinearProblem(layout=problem_layout, batches=tuple(solve_batches))
[docs]
def make_linear_operator(
    self,
    samples: Sequence[FitSample],
    *,
    batch_size: int = 32,
    cache_directory: Path | str | None = None,
    cache_mode: AssembledBatchCacheMode = "auto",
) -> BlockLinearProblem:
    """Return the ``build_problem`` result for callers wanting an operator view.

    This is a thin alias. ``samples``, ``batch_size``, ``cache_directory`` and
    ``cache_mode`` are forwarded unchanged to ``self.build_problem``.
    """
    build_options = {
        "batch_size": batch_size,
        "cache_directory": cache_directory,
        "cache_mode": cache_mode,
    }
    return self.build_problem(samples, **build_options)
[docs]
def materialize_design_matrix(
    self,
    samples: Sequence[FitSample],
    *,
    batch_size: int = 32,
    cache_directory: Path | str | None = None,
    cache_mode: AssembledBatchCacheMode = "auto",
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the problem for *samples* and return its explicit design matrix.

    The problem is built with ``self.build_problem`` using the given batching
    and cache options. The pair returned by its
    ``materialize_design_matrix()`` is handed back unchanged. It is presumably
    ``(design_matrix, target)``; TODO confirm against ``BlockLinearProblem``.
    """
    problem = self.build_problem(
        samples,
        batch_size=batch_size,
        cache_directory=cache_directory,
        cache_mode=cache_mode,
    )
    return problem.materialize_design_matrix()
[docs]
def accumulate_normal_equations(
    self,
    samples: Sequence[FitSample],
    *,
    batch_size: int = 32,
    cache_directory: Path | str | None = None,
    cache_mode: AssembledBatchCacheMode = "auto",
    normal_equation_cache_directory: Path | str | None = None,
    normal_equation_cache_mode: AssembledBatchCacheMode = "auto",
    normal_equation_build_device: torch.device | str | None = None,
    progress: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the normal-equation Gram matrix and right-hand side for *samples*.

    There are two paths:

    * When ``normal_equation_cache_directory`` is ``None``, the block problem
      is built with ``self.build_problem``, using ``batch_size``,
      ``cache_directory``, ``cache_mode`` and ``progress``. Its
      ``accumulate_normal_equations()`` result is returned.
    * Otherwise, normal-equation components are loaded or built through
      ``self._load_or_build_normal_equation_components``, using the
      ``normal_equation_*`` options. They are then finalized with the
      energy/force weights from ``_normal_equation_target_weights``. On this
      path ``cache_directory`` and ``cache_mode`` are not used.

    Args:
        samples: Fit samples. They are materialized once into a tuple so that
            every consumer sees the same sequence.
        batch_size: Number of samples per assembly batch.
        cache_directory: Assembled-batch cache location (non-cached path only).
        cache_mode: Assembled-batch cache mode (non-cached path only).
        normal_equation_cache_directory: Enables the normal-equation cache path
            when not ``None``.
        normal_equation_cache_mode: Cache mode for the normal-equation path.
        normal_equation_build_device: Optional device for building components.
            It is converted with ``torch.device``.
        progress: Whether to report progress. It is honored on both paths.

    Returns:
        ``(gram, rhs)``.
    """
    items = tuple(samples)
    if normal_equation_cache_directory is None:
        # Previously ``progress`` was silently dropped here even though
        # ``build_problem`` accepts it. Forward it so both paths behave alike.
        return self.build_problem(
            items,
            batch_size=batch_size,
            progress=progress,
            cache_directory=cache_directory,
            cache_mode=cache_mode,
        ).accumulate_normal_equations()
    build_device = (
        None
        if normal_equation_build_device is None
        else torch.device(normal_equation_build_device)
    )
    energy_weight, force_weight = _normal_equation_target_weights(
        items,
        fit_energy=self.fit_energy,
        fit_forces=self.fit_forces,
        fit_per_atom_energy=self.fit_per_atom_energy,
    )
    components = self._load_or_build_normal_equation_components(
        items,
        batch_size=batch_size,
        progress=progress,
        cache_directory=normal_equation_cache_directory,
        cache_mode=normal_equation_cache_mode,
        build_device=build_device,
    )
    # Only the Gram matrix and right-hand side are part of this API; the
    # remaining finalize outputs are discarded.
    gram, rhs, _, _ = self._finalize_normal_equations(
        components,
        energy_weight=energy_weight,
        force_weight=force_weight,
    )
    return gram, rhs
[docs]
def write_back(self, theta: torch.Tensor) -> None:
    """Store one solved direct coefficient vector into the model.

    The target depends on whether the coefficient selection covers the whole
    direct layout, i.e. ``self._selected_size() == self.layout.size``:

    * If it does, *theta* is written with ``self.layout.write_direct_vector``.
    * Otherwise, only the selected coefficients are written via
      ``self._write_selected_vector``.
    """
    covers_full_layout = self._selected_size() == self.layout.size
    if covers_full_layout:
        self.layout.write_direct_vector(theta)
    else:
        self._write_selected_vector(theta)
def _empty_problem(self) -> BlockLinearProblem:
    """Return a batch-free problem shell.

    The shell uses the direct-block layout and is used when no assembled
    problem exists, e.g. after an interruption before assembly.
    """
    direct_layout = BlockProblemLayout.from_blocks(self._direct_blocks())
    return BlockLinearProblem(layout=direct_layout, batches=())
def _interrupted_result(
    self,
    theta: torch.Tensor,
    *,
    problem: BlockLinearProblem | None = None,
    restored_checkpoint_path: str | None = None,
) -> LinearFitResult:
    """Summarize a fit stage that was interrupted.

    Args:
        theta: Coefficient vector to report, e.g. the pre-fit coefficients or
            the solver's last iterate.
        problem: Problem to attach. When omitted, an empty shell from
            ``self._empty_problem()`` is used.
        restored_checkpoint_path: Checkpoint path reported by the solver, if
            any.

    Returns:
        A ``LinearFitResult`` with ``interrupted=True``. Objective and residual
        norm are set to NaN because no complete solve is available.
    """
    if problem is None:
        problem = self._empty_problem()
    return LinearFitResult(
        theta=theta,
        objective=float("nan"),
        residual_norm=float("nan"),
        solver=self.solver,
        n_rows=problem.n_rows,
        n_parameters=self._selected_size(),
        layout=self.layout,
        problem=problem,
        interrupted=True,
        restored_checkpoint_path=restored_checkpoint_path,
    )
[docs]
def write_checkpoint_to_model(self, checkpoint_path: Path | str) -> None:
    """Load a CG checkpoint and write its coefficient vector into the model.

    The checkpoint is accepted in either of two cases:

    * Its metadata contains the metadata for the current coefficient selection
      (``self._coefficient_checkpoint_metadata``).
    * The fitter uses the full direct layout and the metadata equals the
      legacy ``_cg_checkpoint_metadata`` form.

    Raises:
        ValueError: If neither condition holds.
    """
    checkpoint = load_cg_checkpoint(
        checkpoint_path,
        dtype=self.dtype,
        device=self.device,
    )
    n_selected = self._selected_size()
    current_metadata = self._coefficient_checkpoint_metadata(
        n_parameters=n_selected,
        dtype=self.dtype,
    )
    legacy_metadata = _cg_checkpoint_metadata(
        n_parameters=n_selected,
        dtype=self.dtype,
    )
    # Short-circuit order matches the original: the legacy comparison is
    # only consulted when the current-format check fails.
    accepted = self._checkpoint_metadata_contains(
        checkpoint.metadata, current_metadata
    ) or (
        self._uses_full_direct_layout()
        and checkpoint.metadata == legacy_metadata
    )
    if not accepted:
        raise ValueError(
            "CG checkpoint metadata does not match this coefficient selection"
        )
    self.write_back(checkpoint.x)
[docs]
def fit(
    self,
    samples: Sequence[FitSample],
    *,
    batch_size: int = 32,
    write_back: bool = True,
    progress: bool = False,
    progress_frequency: int = 10,
    cache_directory: Path | str | None = None,
    cache_mode: AssembledBatchCacheMode = "auto",
    normal_equation_cache: bool = False,
    normal_equation_cache_directory: Path | str | None = None,
    normal_equation_cache_mode: AssembledBatchCacheMode = "auto",
    normal_equation_build_device: torch.device | str | None = None,
    warm_start: bool = False,
    cg_checkpoint_path: Path | str | None = None,
    cg_checkpoint_frequency: int = 1,
    cg_resume: bool = False,
) -> LinearFitResult:
    """Build, solve, optionally write back, and summarize one direct linear fit.

    Args:
        samples: Fit samples. They are materialized once into a tuple.
        batch_size: Number of samples per assembly batch.
        write_back: If true, the solved coefficients are written into the
            model. After an interruption, the pre-fit coefficients are written
            instead.
        progress: Print status messages.
        progress_frequency: Forwarded to ``problem.solve``.
        cache_directory: Assembled-batch cache location for ``build_problem``.
            With ``normal_equation_cache=True``, its ``normal_equations``
            subdirectory is also the default normal-equation cache location.
        cache_mode: Assembled-batch cache mode for ``build_problem``.
        normal_equation_cache: Solve through cached normal-equation components
            instead of a block problem. Requires
            ``solver == 'normal_equation_direct'``.
        normal_equation_cache_directory: Explicit normal-equation cache
            location. Only used when ``normal_equation_cache`` is true.
        normal_equation_cache_mode: Cache mode for the normal-equation path.
        normal_equation_build_device: Optional device for building the
            normal-equation components on that path.
        warm_start: Use the model's current coefficients as the initial
            iterate.
        cg_checkpoint_path: Forwarded to ``problem.solve``.
        cg_checkpoint_frequency: Forwarded to ``problem.solve``.
        cg_resume: Forwarded to ``problem.solve``.

    Returns:
        A ``LinearFitResult``. If ``KeyboardInterrupt`` is raised during the
        cached normal-equation solve or during assembly, an interrupted
        result carrying the pre-fit coefficients is returned. The same happens
        if the solver reports an interruption, but the result then carries the
        solver's theta.

    Raises:
        ValueError: If ``normal_equation_cache`` is set with a solver other
            than ``'normal_equation_direct'``, or without any cache directory.
    """
    items = tuple(samples)
    # Snapshot of the coefficients currently held by the model. It is
    # restored on interruption and doubles as the warm-start iterate.
    fallback_theta = self._current_selected_vector(
        dtype=self.dtype,
        device=self.device,
    )
    if normal_equation_cache:
        if self.solver != "normal_equation_direct":
            raise ValueError(
                "normal-equation caching is only supported with "
                "solver='normal_equation_direct'"
            )
        if normal_equation_cache_directory is None:
            if cache_directory is None:
                raise ValueError(
                    "`cache_directory` or `normal_equation_cache_directory` is "
                    "required when `normal_equation_cache=True`"
                )
            # Default: keep normal-equation components next to the
            # assembled-batch cache.
            normal_equation_cache_directory = (
                Path(cache_directory) / "normal_equations"
            )
        build_device = (
            None
            if normal_equation_build_device is None
            else torch.device(normal_equation_build_device)
        )
        try:
            theta, objective, residual_norm, n_rows = (
                self._solve_cached_normal_equations(
                    items,
                    batch_size=batch_size,
                    progress=progress,
                    cache_directory=normal_equation_cache_directory,
                    cache_mode=normal_equation_cache_mode,
                    build_device=build_device,
                )
            )
        except KeyboardInterrupt:
            if progress:
                print(
                    "Interrupted cached normal-equation fit; "
                    "using current coefficients."
                )
            if write_back:
                self.write_back(fallback_theta)
            return self._interrupted_result(fallback_theta)
        if write_back:
            self.write_back(theta)
        return LinearFitResult(
            theta=theta,
            objective=objective,
            residual_norm=residual_norm,
            solver=self.solver,
            n_rows=n_rows,
            n_parameters=self._selected_size(),
            layout=self.layout,
            # No block problem is materialized on this path; reuse the shared
            # empty-shell helper instead of rebuilding it inline.
            problem=self._empty_problem(),
        )
    try:
        problem = self.build_problem(
            items,
            batch_size=batch_size,
            progress=progress,
            cache_directory=cache_directory,
            cache_mode=cache_mode,
        )
    except KeyboardInterrupt:
        if progress:
            print("Interrupted least-squares assembly; using current coefficients.")
        if write_back:
            self.write_back(fallback_theta)
        return self._interrupted_result(fallback_theta)
    if progress:
        print(
            "Solving least-squares system: "
            f"solver={self.solver}, rows={problem.n_rows}, "
            f"parameters={problem.layout.size}"
        )
    # Match the problem's dtype/device so the fallback can serve as an
    # initial iterate.
    fallback_theta = fallback_theta.to(dtype=problem.dtype, device=problem.device)
    solve_result = problem.solve(
        solver=self.solver,
        cg_tolerance=self.cg_tolerance,
        cg_max_iter=self.cg_max_iter,
        progress=progress,
        progress_frequency=progress_frequency,
        initial_theta=(
            fallback_theta if warm_start else None
        ),
        fallback_theta=fallback_theta,
        return_info=True,
        cg_checkpoint_path=cg_checkpoint_path,
        cg_checkpoint_frequency=cg_checkpoint_frequency,
        cg_resume=cg_resume,
        # Checkpoint metadata is only needed by the CG solver.
        cg_checkpoint_metadata=(
            self._fit_checkpoint_metadata(
                items,
                n_parameters=problem.layout.size,
                dtype=problem.dtype,
            )
            if self.solver == "cg"
            else None
        ),
    )
    assert isinstance(solve_result, LinearSolveResult)
    theta = solve_result.theta
    if write_back:
        self.write_back(theta)
    if solve_result.interrupted:
        return self._interrupted_result(
            theta,
            problem=problem,
            restored_checkpoint_path=solve_result.restored_checkpoint_path,
        )
    residual_norm_tensor = (
        problem.residual_norm(theta)
        if isinstance(problem, CachedBlockLinearProblem)
        else torch.linalg.norm(problem.matvec(theta) - problem.target_vector())
    )
    # Convert once; the value is both printed and stored in the result.
    residual_norm = float(residual_norm_tensor.item())
    if progress:
        print(
            "Least-squares solve complete: "
            f"||residual||={residual_norm:.6e}"
        )
    return LinearFitResult(
        theta=theta,
        objective=float(problem.objective(theta).item()),
        residual_norm=residual_norm,
        solver=self.solver,
        n_rows=problem.n_rows,
        n_parameters=problem.layout.size,
        layout=self.layout,
        problem=problem,
    )
# Public API of this module. Many entries are re-exported from the private
# ``ufp.leastsquares._block`` / ``_cache`` / ``_problem`` / ``_selection``
# submodules imported at the top of the file.
# NOTE(review): ``BlockRegularization`` does not appear in the part of the
# import header visible here. Confirm it is imported or defined in this
# module; otherwise ``from <this module> import *`` raises AttributeError.
__all__ = [
"BlockLinearProblem",
"BlockMatrix",
"BlockProblemLayout",
"BlockRegularization",
"BlockSolveBatch",
"CGCheckpointState",
"CachedBlockLinearProblem",
"ColumnRowIndexedBlockMatrix",
"CoefficientSelector",
"LinearFitResult",
"LinearFitter",
"LinearSolveResult",
"MatrixStorageMode",
"RowIndexedBlockMatrix",
"SolveBlock",
"load_assembled_batches_memmap",
"load_cg_checkpoint",
"save_assembled_batches_memmap",
"save_cg_checkpoint",
]