"""
Alternating least-squares fitting for alchemical coefficient providers.
Use this module when spline blocks are shared through proxy coefficients and
the proxy tensors and mixing weights should be fit in alternating subproblems.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Optional, Sequence
import torch
from ufp.leastsquares._block import (
ColumnRowIndexedBlockMatrix,
ColumnRowIndexedChunk,
RowIndexedBlockMatrix,
_block_matrix_diagonal,
_block_matrix_rmatvec,
)
from ufp.leastsquares._layout import ParameterLayout, ProviderGroup
from ufp.leastsquares.dataset import FitSample
from ufp.leastsquares.linear import (
AssembledBatchCacheMode,
BlockLinearProblem,
BlockMatrix,
BlockProblemLayout,
BlockSolveBatch,
LinearFitter,
LinearSolveResult,
SolveBlock,
_block_matrix_matvec,
_make_block_regularization,
_materialize_block_matrix,
_twobody_shape_regularization_rows,
load_cg_checkpoint,
)
from ufp.terms._twobody_shape import (
TwoBodySplineShapePenalty,
normalize_twobody_shape_penalty,
)
from ufp.terms.model import UFPModel
def _provider_proxy_key(
    provider_group: ProviderGroup, proxy_index: int
) -> tuple[str, int, int]:
    """Return the solve-layout key that identifies one proxy block.

    The key is the literal tag ``"proxy"``, the ``id()`` of the group's
    provider object, and the proxy index coerced to ``int``.
    """
    provider_identity = id(provider_group.provider)
    index = int(proxy_index)
    return ("proxy", provider_identity, index)
def _provider_weight_key(
    provider_group: ProviderGroup, true_index: int
) -> tuple[str, int, int]:
    """Return the solve-layout key that identifies one weight row.

    The key is the literal tag ``"weight"``, the ``id()`` of the group's
    provider object, and the true-term index coerced to ``int``.
    """
    provider_identity = id(provider_group.provider)
    index = int(true_index)
    return ("weight", provider_identity, index)
def _svd_initialize(
    true_coeffs: torch.Tensor, n_proxy_terms: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Factor a true-coefficient matrix into weight and proxy factors.

    A reduced SVD of ``true_coeffs`` is truncated to
    ``min(n_proxy_terms, n_true_terms, width)`` components and the square
    root of each singular value is folded into both factors.  Factor rows and
    columns beyond that rank stay zero.

    Returns:
        ``(weights, proxies)`` with shapes ``(n_true_terms, n_proxy_terms)``
        and ``(n_proxy_terms, width)``, matching the input dtype and device.
    """
    n_proxy = int(n_proxy_terms)
    n_true_terms, width = true_coeffs.shape
    like_input = {"dtype": true_coeffs.dtype, "device": true_coeffs.device}
    weights = torch.zeros((n_true_terms, n_proxy), **like_input)
    proxies = torch.zeros((n_proxy, width), **like_input)

    rank = min(n_proxy, n_true_terms, width)
    if rank == 0:
        # Nothing to factor; hand back the all-zero factors.
        return weights, proxies

    left, singular, right_h = torch.linalg.svd(true_coeffs, full_matrices=False)
    root = torch.clamp(singular[:rank], min=0.0).sqrt()
    weights[:, :rank] = left[:, :rank] * root.unsqueeze(0)
    proxies[:rank] = root.unsqueeze(1) * right_h[:rank]
    return weights, proxies
def _normalize_provider(provider_group: ProviderGroup) -> None:
    """Give every proxy row unit norm and fold the scale into the weights.

    Each proxy row is divided by its norm while the matching weight column is
    multiplied by the same norm, so products of weights and proxies are
    unchanged.  Providers without trainable weights are left untouched.
    """
    provider = provider_group.provider
    if provider.weights is None:
        return
    if not _provider_weights_are_trainable(provider):
        return

    proxy_rows = provider.proxy_coeffs.data.reshape(provider_group.n_proxy_terms, -1)
    weight_data = provider.weights.data
    for column in range(provider_group.n_proxy_terms):
        row_norm = torch.linalg.norm(proxy_rows[column])
        if float(row_norm.item()) <= 0.0:
            # A zero row cannot be rescaled; leave row and column as they are.
            continue
        proxy_rows[column].div_(row_norm)
        weight_data[:, column].mul_(row_norm)
    # ``reshape`` may have produced a copy, so write the rows back explicitly.
    provider.proxy_coeffs.data.copy_(proxy_rows.reshape_as(provider.proxy_coeffs))
def _provider_weights_are_trainable(provider) -> bool:
    """Tell whether ALS may update the provider's mixing weights.

    Weights count as trainable only when they exist, are a
    ``torch.nn.Parameter``, and have ``requires_grad`` set.
    """
    weights = provider.weights
    if weights is None:
        return False
    if not isinstance(weights, torch.nn.Parameter):
        return False
    return weights.requires_grad
def _fit_proxies_to_fixed_weights(
    true_coeffs: torch.Tensor,
    weights: torch.Tensor,
    n_proxy_terms: int,
) -> torch.Tensor:
    """Least-squares fit of proxy coefficients for fixed provider weights.

    Solves ``weights @ proxies ~= true_coeffs`` with ``torch.linalg.lstsq``
    and copies the solution rows into a zero tensor of shape
    ``(n_proxy_terms, width)``.  An empty ``weights`` tensor yields all-zero
    proxies.
    """
    n_proxy = int(n_proxy_terms)
    proxies = torch.zeros(
        (n_proxy, true_coeffs.shape[1]),
        dtype=true_coeffs.dtype,
        device=true_coeffs.device,
    )
    if weights.numel() != 0:
        fitted = torch.linalg.lstsq(weights, true_coeffs).solution
        # Both slices clip to the smaller of the two row counts.
        proxies[: fitted.shape[0]] = fitted[:n_proxy]
    return proxies
def _subtract_block_prediction(
    target: torch.Tensor,
    matrix: BlockMatrix,
    theta: torch.Tensor,
) -> torch.Tensor:
    """Return ``target`` minus the prediction of one fixed block.

    The block prediction ``matrix @ theta`` is cast to the target's dtype and
    device before the subtraction; ``target`` itself is not modified.
    """
    prediction = _block_matrix_matvec(matrix, theta)
    aligned = prediction.to(dtype=target.dtype, device=target.device)
    return target - aligned
def _dense_block_matrix(matrix: BlockMatrix) -> torch.Tensor:
    """Materialize one block matrix as a dense tensor.

    Thin wrapper around ``_materialize_block_matrix`` used when transformed
    ALS batches need dense arithmetic.
    """
    dense = _materialize_block_matrix(matrix)
    return dense
def _scale_block_matrix(matrix: BlockMatrix, scalar: torch.Tensor) -> BlockMatrix:
    """Multiply a block matrix by a scalar, keeping its storage class.

    Row-indexed and column/row-indexed matrices are rebuilt around scaled
    value tensors with their index structure reused as is; any other matrix
    is multiplied directly.  The scalar is cast to the dtype and device of
    each value tensor it scales.
    """

    def _scalar_like(values: torch.Tensor) -> torch.Tensor:
        # Align the scalar with the tensor it is about to multiply.
        return scalar.to(dtype=values.dtype, device=values.device)

    if isinstance(matrix, RowIndexedBlockMatrix):
        factor = _scalar_like(matrix.values)
        return RowIndexedBlockMatrix(
            rows=matrix.rows,
            values=matrix.values * factor,
            n_rows=matrix.n_rows,
        )

    if isinstance(matrix, ColumnRowIndexedBlockMatrix):
        scaled_chunks = []
        for chunk in matrix.chunks:
            scaled_chunks.append(
                ColumnRowIndexedChunk(
                    column_start=chunk.column_start,
                    rows=chunk.rows,
                    values=chunk.values * _scalar_like(chunk.values),
                )
            )
        return ColumnRowIndexedBlockMatrix(
            chunks=tuple(scaled_chunks),
            n_rows=matrix.n_rows,
            n_cols=matrix.n_cols,
        )

    return matrix * _scalar_like(matrix)
def _nonzero_weight_indices(weights: torch.Tensor) -> tuple[int, ...]:
    """List the positions whose weight is exactly nonzero.

    Returns an empty tuple for an empty tensor; otherwise the flattened
    output of ``torch.nonzero`` as Python ints.
    """
    if weights.numel() == 0:
        return ()
    mask = weights.detach() != 0
    positions = torch.nonzero(mask, as_tuple=False).reshape(-1).tolist()
    return tuple(map(int, positions))
class _StreamingBlockLinearProblem(BlockLinearProblem):
    """Block problem whose normal-equation operations stream lazy batches.

    Every operator below walks ``self.batches`` once and accumulates its
    result batch by batch instead of concatenating rows first.
    """

    @property
    def dtype(self) -> torch.dtype:
        """Return the dtype without forcing a transformed batch load."""
        batch_sequence = self.batches
        try:
            return batch_sequence.dtype
        except AttributeError:
            # Batch containers without their own dtype use the base logic.
            return super().dtype

    @property
    def device(self) -> torch.device:
        """Return the device without forcing a transformed batch load."""
        batch_sequence = self.batches
        try:
            return batch_sequence.device
        except AttributeError:
            # Batch containers without their own device use the base logic.
            return super().device

    def _prediction_for_batch(
        self,
        batch: BlockSolveBatch,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Sum the per-block predictions of one batch for ``theta``."""
        total = torch.zeros(
            (batch.n_rows,),
            dtype=theta.dtype,
            device=theta.device,
        )
        for key, block_matrix in batch.matrices.items():
            block_theta = theta[self.layout.theta_slice(key)]
            partial = _block_matrix_matvec(block_matrix, block_theta)
            total = total + partial.to(device=theta.device, dtype=theta.dtype)
        return total

    def _streamed_squared_residual(self, theta: torch.Tensor) -> torch.Tensor:
        """Accumulate the squared data residual over all batches."""
        squared = torch.zeros((), dtype=theta.dtype, device=theta.device)
        for batch in self.batches:
            predicted = self._prediction_for_batch(batch, theta)
            observed = batch.target.to(device=theta.device, dtype=theta.dtype)
            difference = predicted - observed
            squared = squared + torch.dot(difference, difference)
        return squared

    def normal_matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply the regularized normal operator without concatenating rows."""
        theta = theta.reshape(self.layout.size)
        accumulated = torch.zeros(
            (self.layout.size,),
            dtype=theta.dtype,
            device=theta.device,
        )
        for batch in self.batches:
            predicted = self._prediction_for_batch(batch, theta)
            for key, block_matrix in batch.matrices.items():
                back_projection = _block_matrix_rmatvec(block_matrix, predicted)
                accumulated[self.layout.theta_slice(key)] += back_projection.to(
                    device=theta.device, dtype=theta.dtype
                )
        return accumulated + self.regularization_apply(theta)

    def rhs(self) -> torch.Tensor:
        """Return the right-hand side by streaming transformed batches."""
        dtype, device = self.dtype, self.device
        accumulated = torch.zeros((self.layout.size,), dtype=dtype, device=device)
        for batch in self.batches:
            observed = batch.target.to(device=device, dtype=dtype)
            for key, block_matrix in batch.matrices.items():
                back_projection = _block_matrix_rmatvec(block_matrix, observed)
                accumulated[self.layout.theta_slice(key)] += back_projection.to(
                    device=device, dtype=dtype
                )
        return accumulated + self.regularization_rhs()

    def normal_equation_diagonal(self) -> torch.Tensor:
        """Return the diagonal of ``A.T @ A`` by streaming transformed batches."""
        dtype, device = self.dtype, self.device
        diagonal = torch.zeros((self.layout.size,), dtype=dtype, device=device)
        for batch in self.batches:
            for key, block_matrix in batch.matrices.items():
                block_diagonal = _block_matrix_diagonal(block_matrix)
                diagonal[self.layout.theta_slice(key)] += block_diagonal.to(
                    device=device, dtype=dtype
                )
        return diagonal

    def objective(self, theta: torch.Tensor) -> torch.Tensor:
        """Evaluate the regularized objective without materializing all rows."""
        theta = theta.reshape(self.layout.size)
        value = self._streamed_squared_residual(theta)
        for block in self.layout.blocks:
            penalty = block.regularization
            if penalty is None:
                continue
            block_theta = theta[self.layout.theta_slice(block.key)]
            value = value + penalty.quadratic(block_theta)
        return value

    def residual_norm(self, theta: torch.Tensor) -> torch.Tensor:
        """Return ``||A theta - b||`` without materializing all residual rows."""
        theta = theta.reshape(self.layout.size)
        squared = self._streamed_squared_residual(theta)
        return torch.sqrt(torch.clamp(squared, min=0.0))
class _AlchemicalSubproblemBatchSequence:
    """Lazy batch sequence for one alchemical ALS subproblem.

    Each true-coefficient batch of the wrapped problem is transformed on
    access into either the proxy subproblem (provider weights held fixed) or
    the weight subproblem (provider proxies held fixed).  Contributions of
    all other non-identity providers are subtracted from the target using
    the true coefficients captured at construction time.
    """

    def __init__(
        self,
        *,
        fitter: "AlchemicalALSFitter",
        true_problem: BlockLinearProblem,
        provider_group: ProviderGroup,
        mode: str,
    ) -> None:
        """Store the fixed state used to transform true batches on demand.

        Args:
            fitter: Fitter supplying the layout and current coefficients.
            true_problem: Problem whose batches are transformed.
            provider_group: Provider whose proxies or weights are solved.
            mode: ``"proxy"`` or ``"weight"``.

        Raises:
            ValueError: If ``mode`` is unknown, or if ``mode`` is ``"proxy"``
                and the provider has no weights.
        """
        if mode not in {"proxy", "weight"}:
            raise ValueError("`mode` must be 'proxy' or 'weight'")
        self._fitter = fitter
        self._true_problem = true_problem
        self._provider_group = provider_group
        self._mode = mode
        self._current_true = fitter._current_true_vector(true_problem)
        self._direct_block_indices = fitter._active_direct_blocks()
        # Hash-based membership test; the tuple is kept for ordering.
        self._direct_block_lookup = frozenset(self._direct_block_indices)
        solved_provider_id = id(provider_group.provider)
        self._fixed_provider_ids = {
            id(group.provider)
            for group in fitter.layout.non_identity_providers()
            if id(group.provider) != solved_provider_id
        }
        provider = provider_group.provider
        if mode == "proxy":
            if provider.weights is None:
                raise ValueError("proxy subproblems require provider weights")
            self._weights = provider.weights.detach().to(
                dtype=true_problem.dtype,
                device=true_problem.device,
            )
            self._proxy = None
        else:
            self._weights = None
            self._proxy = (
                provider.proxy_coeffs.detach()
                .reshape(
                    provider_group.n_proxy_terms,
                    provider_group.block_size,
                )
                .to(dtype=true_problem.dtype, device=true_problem.device)
            )

    def __len__(self) -> int:
        """Return the number of source true-problem batches."""
        return len(self._true_problem.batches)

    @property
    def dtype(self) -> torch.dtype:
        """Return the true-problem dtype for transformed batches."""
        return self._true_problem.dtype

    @property
    def device(self) -> torch.device:
        """Return the true-problem device for transformed batches."""
        return self._true_problem.device

    def __iter__(self):
        """Yield transformed subproblem batches lazily."""
        for batch in self._true_problem.batches:
            yield self._transform_batch(batch)

    def __getitem__(self, index):
        """Transform one source batch by index, or a tuple for slices."""
        if isinstance(index, slice):
            return tuple(self[item] for item in range(*index.indices(len(self))))
        return self._transform_batch(self._true_problem.batches[index])

    def _transform_batch(self, batch: BlockSolveBatch) -> BlockSolveBatch:
        """Transform one true-coefficient batch into the requested subproblem."""
        # Clone so that subtracting fixed contributions never touches the source.
        target = batch.target.clone()
        matrices: dict[Any, BlockMatrix] = {}
        solved_provider = self._provider_group.provider
        for block_index, matrix in batch.matrices.items():
            block = self._fitter.layout.block(int(block_index))
            provider_obj = block.coefficient_provider
            if (
                provider_obj is not None
                and id(provider_obj) in self._fixed_provider_ids
            ):
                # Other providers are frozen: move their prediction to the target.
                target = _subtract_block_prediction(
                    target,
                    matrix,
                    self._current_true[block.theta_slice],
                )
                continue
            if block_index in self._direct_block_lookup:
                matrices[block_index] = matrix
                continue
            if provider_obj is None or provider_obj is not solved_provider:
                continue
            if self._mode == "proxy":
                self._add_proxy_matrices(matrices, block, matrix)
            else:
                self._add_weight_matrix(matrices, block, matrix)
        return BlockSolveBatch(target=target, matrices=matrices)

    def _add_proxy_matrices(
        self,
        matrices: dict[Any, BlockMatrix],
        block,
        matrix: BlockMatrix,
    ) -> None:
        """Add fixed-weight proxy solve matrices for one true block.

        The true block matrix, scaled by each nonzero weight of its
        coefficient row, is added under the matching proxy key.  Several
        true blocks mapping to one proxy are summed densely.
        """
        assert block.coefficient_index is not None
        assert self._weights is not None
        weights = self._weights[block.coefficient_index]
        for proxy_index in _nonzero_weight_indices(weights):
            key = _provider_proxy_key(self._provider_group, proxy_index)
            contribution = _scale_block_matrix(matrix, weights[proxy_index])
            if key in matrices:
                matrices[key] = _dense_block_matrix(
                    matrices[key]
                ) + _dense_block_matrix(contribution)
            else:
                matrices[key] = contribution

    def _add_weight_matrix(
        self,
        matrices: dict[Any, BlockMatrix],
        block,
        matrix: BlockMatrix,
    ) -> None:
        """Add fixed-proxy weight solve matrix for one true block.

        The dense true block matrix is projected onto the fixed proxies.  If
        another block of the batch already contributed to the same weight
        row, the contributions are summed rather than overwritten, mirroring
        the accumulation done for proxy matrices.
        """
        assert block.coefficient_index is not None
        assert self._proxy is not None
        key = _provider_weight_key(self._provider_group, block.coefficient_index)
        contribution = _dense_block_matrix(matrix) @ self._proxy.T
        if key in matrices:
            matrices[key] = _dense_block_matrix(matrices[key]) + contribution
        else:
            matrices[key] = contribution
@dataclass(frozen=True)
class AlchemicalALSResult:
    """Summary of one alternating least-squares fit over alchemical coefficients."""

    # True-coefficient vector read back from the model when the fit returned.
    theta: torch.Tensor
    # Objective values recorded by the fit; NaN marks an interrupted solve.
    objective_history: tuple[float, ...]
    # Whether the fit reported convergence.
    converged: bool
    # Number of ALS sweeps that were started (0 when no sweep ran).
    sweeps: int
    # Parameter layout of the fitted model.
    layout: ParameterLayout
    # True-coefficient problem the fit was run against.
    problem: BlockLinearProblem
    # Whether the fit stopped because a solve or the assembly was interrupted.
    interrupted: bool = False
    # Checkpoint path reported as restored by a solve, if any.
    restored_checkpoint_path: str | None = None
class AlchemicalALSFitter:
    """Alternating least-squares driver for models with shared alchemical providers.

    The fitter wraps a ``LinearFitter`` that assembles the true-coefficient
    problem.  For every provider with non-identity weights it alternates a
    proxy-coefficient solve (weights fixed) and, when the weights are
    trainable, a mixing-weight solve (proxies fixed).
    """

    def __init__(
        self,
        model: UFPModel,
        *,
        fit_energy: bool = True,
        fit_forces: bool = True,
        fit_per_atom_energy: bool = False,
        solver: str = "cg",
        ridge: float = 0.0,
        onebody_ridge: float | None = None,
        pair_ridge: float | None = None,
        twobody_ridge: float | None = None,
        threebody_ridge: float | None = None,
        twobody_shape_penalty: TwoBodySplineShapePenalty | None = None,
        weight_ridge: float | None = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        cg_tolerance: float = 1.0e-10,
        cg_max_iter: int | None = None,
        threebody_lstsq_backend: str | None = None,
        threebody_bucket_backend: str | None = None,
        max_sweeps: int = 10,
        tolerance: float = 1.0e-8,
    ) -> None:
        """Store ALS settings and reuse one true-coefficient fitter.

        All arguments except ``weight_ridge``, ``max_sweeps`` and
        ``tolerance`` are forwarded unchanged to ``LinearFitter``.

        Args:
            weight_ridge: Ridge applied to mixing-weight blocks.  When
                ``None`` it defaults to ``max(ridge, 1e-12)``.
            max_sweeps: Upper bound on the number of ALS sweeps.
            tolerance: Relative threshold on the coefficient or objective
                change below which a sweep counts as converged.
        """
        self.linear_fitter = LinearFitter(
            model,
            fit_energy=fit_energy,
            fit_forces=fit_forces,
            fit_per_atom_energy=fit_per_atom_energy,
            solver=solver,
            ridge=ridge,
            onebody_ridge=onebody_ridge,
            pair_ridge=pair_ridge,
            twobody_ridge=twobody_ridge,
            threebody_ridge=threebody_ridge,
            twobody_shape_penalty=twobody_shape_penalty,
            dtype=dtype,
            device=device,
            cg_tolerance=cg_tolerance,
            cg_max_iter=cg_max_iter,
            threebody_lstsq_backend=threebody_lstsq_backend,
            threebody_bucket_backend=threebody_bucket_backend,
        )
        self.model = model
        self.solver = solver
        self.ridge = float(ridge)
        # Per-kind ridges are taken as resolved by the linear fitter.
        self.onebody_ridge = self.linear_fitter.onebody_ridge
        self.pair_ridge = self.linear_fitter.pair_ridge
        self.threebody_ridge = self.linear_fitter.threebody_ridge
        self.twobody_shape_penalty = normalize_twobody_shape_penalty(
            twobody_shape_penalty
        )
        self.weight_ridge = (
            max(float(ridge), 1.0e-12) if weight_ridge is None else float(weight_ridge)
        )
        self.cg_tolerance = float(cg_tolerance)
        self.cg_max_iter = cg_max_iter
        self.max_sweeps = int(max_sweeps)
        self.tolerance = float(tolerance)
        self.layout = self.linear_fitter.layout

    # ------------------------------------------------------------------
    # Regularization lookups
    # ------------------------------------------------------------------

    def _ridge_for_block_index(self, block_index: int) -> float:
        """Return the coefficient ridge assigned to one true block."""
        kind = self.layout.block(block_index).kind
        if kind == "onebody":
            return self.onebody_ridge
        if kind in ("pair", "twobody"):
            return self.pair_ridge
        if kind == "threebody":
            return self.threebody_ridge
        return self.ridge

    def _third_difference_for_block_index(self, block_index: int) -> float:
        """Return the two-body third-difference penalty for one block."""
        block = self.layout.block(block_index)
        if block.kind in {"pair", "twobody"}:
            return self.twobody_shape_penalty.third_difference_weight
        return 0.0

    def _provider_twobody_active_rows(
        self,
        provider_group: ProviderGroup,
    ) -> tuple[int, ...] | None:
        """Return active two-body rows for a provider-owned coefficient shape.

        The rows come from the first two-body block of the provider.  Both
        ``"pair"`` and ``"twobody"`` kinds qualify, matching the kinds that
        receive the third-difference penalty elsewhere in this class.
        """
        for block_index in provider_group.block_indices:
            block = self.layout.block(block_index)
            if block.kind in {"pair", "twobody"}:
                return _twobody_shape_regularization_rows(block)
        return None

    def _third_difference_for_provider(self, provider_group: ProviderGroup) -> float:
        """Return the provider proxy third-difference penalty."""
        if any(
            self.layout.block(block_index).kind in {"pair", "twobody"}
            for block_index in provider_group.block_indices
        ):
            return self.twobody_shape_penalty.third_difference_weight
        return 0.0

    def _ridge_for_provider(self, provider_group: ProviderGroup) -> float:
        """Return the proxy-coefficient ridge for one alchemical provider.

        The ridge of the provider's first block is used for all its proxies.
        """
        block_index = provider_group.block_indices[0]
        return self._ridge_for_block_index(block_index)

    # ------------------------------------------------------------------
    # Initialization
    # ------------------------------------------------------------------

    def _initialize_from_direct_solution(self, theta: torch.Tensor) -> None:
        """Seed proxy and weight factors from the direct true-coefficient solve.

        Blocks without a provider, or with identity weights, receive their
        slice of ``theta`` directly.  Providers with trainable weights are
        factored by SVD and normalized; providers with fixed weights only get
        proxies fitted by least squares against those weights.
        """
        for block in self.layout.blocks:
            block_provider = block.coefficient_provider
            if block_provider is None or block_provider.uses_identity_weights:
                self.layout.write_block_vector(
                    block.index,
                    theta[block.theta_slice],
                )
        for provider_group in self.layout.non_identity_providers():
            true_matrix = self.layout.provider_true_matrix(theta, provider_group)
            provider = provider_group.provider
            if _provider_weights_are_trainable(provider):
                weights, proxies = _svd_initialize(
                    true_matrix, provider_group.n_proxy_terms
                )
                provider.proxy_coeffs.data.copy_(
                    proxies.reshape_as(provider.proxy_coeffs).to(provider.proxy_coeffs)
                )
                provider.weights.data.copy_(weights.to(provider.weights))
                _normalize_provider(provider_group)
                continue
            # Fixed weights: no SVD needed, only fit proxies against them.
            assert provider.weights is not None
            fixed_weights = provider.weights.to(
                dtype=true_matrix.dtype,
                device=true_matrix.device,
            )
            proxies = _fit_proxies_to_fixed_weights(
                true_matrix,
                fixed_weights,
                provider_group.n_proxy_terms,
            )
            provider.proxy_coeffs.data.copy_(
                proxies.reshape_as(provider.proxy_coeffs).to(provider.proxy_coeffs)
            )

    def _current_true_vector(self, problem: BlockLinearProblem) -> torch.Tensor:
        """Read current true coefficients in the problem layout."""
        return self.layout.current_true_vector(
            dtype=problem.dtype, device=problem.device
        )

    def initialize_from_direct_cg_checkpoint(self, checkpoint_path: Path | str) -> None:
        """Initialize alchemical coefficients from a direct true-CG checkpoint.

        Raises:
            ValueError: If the checkpoint vector length, or the
                ``n_parameters`` entry of its metadata, does not match this
                layout, or if that metadata entry is not an integer.
        """
        checkpoint = load_cg_checkpoint(
            checkpoint_path,
            dtype=self.linear_fitter.dtype,
            device=self.linear_fitter.device,
        )
        theta = checkpoint.x.reshape(-1)
        if theta.numel() != self.layout.size:
            raise ValueError(
                "CG checkpoint parameter count does not match this alchemical layout"
            )
        metadata_n_parameters = checkpoint.metadata.get("n_parameters")
        if metadata_n_parameters is not None:
            try:
                metadata_n_parameters = int(metadata_n_parameters)
            except (TypeError, ValueError) as exc:
                raise ValueError(
                    "CG checkpoint metadata contains an invalid parameter count"
                ) from exc
            if metadata_n_parameters != self.layout.size:
                raise ValueError(
                    "CG checkpoint metadata does not match this alchemical layout"
                )
        if self.layout.non_identity_providers():
            self._initialize_from_direct_solution(theta)
        else:
            self.linear_fitter.write_back(theta)

    def _active_direct_blocks(self) -> tuple[int, ...]:
        """Return direct blocks that stay in every ALS subproblem."""
        return self.layout.direct_block_indices()

    # ------------------------------------------------------------------
    # Subproblem state vectors
    # ------------------------------------------------------------------

    def _direct_initial_vector(self, problem: BlockLinearProblem) -> torch.Tensor:
        """Return a zero subproblem vector with direct blocks filled in."""
        beta = torch.zeros(
            (problem.layout.size,),
            dtype=problem.dtype,
            device=problem.device,
        )
        current_true = self._current_true_vector(problem)
        for block_index in self._active_direct_blocks():
            beta[problem.layout.theta_slice(block_index)] = current_true[
                self.layout.block(block_index).theta_slice
            ]
        return beta

    def _proxy_initial_vector(
        self,
        problem: BlockLinearProblem,
        provider_group: ProviderGroup,
    ) -> torch.Tensor:
        """Return the current model state in one proxy-subproblem layout."""
        beta = self._direct_initial_vector(problem)
        # Detach so the start vector carries no autograd history.
        proxy = (
            provider_group.provider.proxy_coeffs.detach()
            .reshape(
                provider_group.n_proxy_terms,
                provider_group.block_size,
            )
            .to(dtype=problem.dtype, device=problem.device)
        )
        for proxy_index in range(provider_group.n_proxy_terms):
            key = _provider_proxy_key(provider_group, proxy_index)
            beta[problem.layout.theta_slice(key)] = proxy[proxy_index]
        return beta

    def _weight_initial_vector(
        self,
        problem: BlockLinearProblem,
        provider_group: ProviderGroup,
    ) -> torch.Tensor:
        """Return the current model state in one weight-subproblem layout."""
        beta = self._direct_initial_vector(problem)
        provider = provider_group.provider
        assert provider.weights is not None
        # Detach so the start vector carries no autograd history.
        weights = provider.weights.detach().to(
            dtype=problem.dtype, device=problem.device
        )
        for true_index in range(provider_group.n_true_terms):
            key = _provider_weight_key(provider_group, true_index)
            beta[problem.layout.theta_slice(key)] = weights[true_index]
        return beta

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------

    @staticmethod
    def _cg_checkpoint_path(
        directory: Path | str | None, filename: str
    ) -> Path | None:
        """Return ``directory / filename``, or ``None`` without a directory."""
        return None if directory is None else Path(directory) / filename

    def _write_checkpoint(
        self,
        checkpoint_directory: Path | str | None,
        *,
        stage: str,
        sweep: int,
        provider_index: int | None,
        objective_history: Sequence[float],
        true_problem: BlockLinearProblem,
    ) -> None:
        """Write a restorable alchemical model checkpoint.

        Does nothing when ``checkpoint_directory`` is ``None``.  Otherwise
        the same payload is saved twice: as ``alchemical_latest.pt`` and
        under a name built from the sweep, provider index and stage.
        """
        if checkpoint_directory is None:
            return
        checkpoint_dir = Path(checkpoint_directory)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        payload = {
            "model_state_dict": self.model.state_dict(),
            "theta": self._current_true_vector(true_problem).detach().cpu(),
            "objective_history": tuple(float(value) for value in objective_history),
            "sweep": int(sweep),
            "provider_index": provider_index,
            "stage": str(stage),
        }
        torch.save(payload, checkpoint_dir / "alchemical_latest.pt")
        stage_name = (
            f"sweep{sweep}_{stage}"
            if provider_index is None
            else f"sweep{sweep}_provider{provider_index}_{stage}"
        )
        torch.save(payload, checkpoint_dir / f"alchemical_{stage_name}.pt")

    # ------------------------------------------------------------------
    # Subproblem construction
    # ------------------------------------------------------------------

    def _direct_solve_blocks(self) -> list[SolveBlock]:
        """Build the solve blocks shared by every subproblem (direct blocks)."""
        solve_blocks: list[SolveBlock] = []
        for block_index in self._active_direct_blocks():
            block = self.layout.block(block_index)
            solve_blocks.append(
                SolveBlock(
                    key=block_index,
                    size=block.size,
                    label=block.label,
                    regularization=_make_block_regularization(
                        block.shape,
                        ridge=self._ridge_for_block_index(block_index),
                        third_difference_penalty=(
                            self._third_difference_for_block_index(block_index)
                        ),
                        active_rows=_twobody_shape_regularization_rows(block),
                    ),
                )
            )
        return solve_blocks

    def _make_proxy_problem(
        self,
        true_problem: BlockLinearProblem,
        provider_group: ProviderGroup,
    ) -> BlockLinearProblem:
        """Build the linear subproblem that updates one provider's proxy blocks."""
        solve_blocks = self._direct_solve_blocks()
        # These settings are identical for every proxy of the provider.
        proxy_ridge = self._ridge_for_provider(provider_group)
        proxy_third_difference = self._third_difference_for_provider(provider_group)
        proxy_active_rows = self._provider_twobody_active_rows(provider_group)
        for proxy_index in range(provider_group.n_proxy_terms):
            solve_blocks.append(
                SolveBlock(
                    key=_provider_proxy_key(provider_group, proxy_index),
                    size=provider_group.block_size,
                    label=f"proxy[{proxy_index}]",
                    regularization=_make_block_regularization(
                        provider_group.coefficient_shape,
                        ridge=proxy_ridge,
                        third_difference_penalty=proxy_third_difference,
                        active_rows=proxy_active_rows,
                    ),
                )
            )
        return _StreamingBlockLinearProblem(
            layout=BlockProblemLayout.from_blocks(tuple(solve_blocks)),
            batches=_AlchemicalSubproblemBatchSequence(
                fitter=self,
                true_problem=true_problem,
                provider_group=provider_group,
                mode="proxy",
            ),
        )

    def _make_weight_problem(
        self,
        true_problem: BlockLinearProblem,
        provider_group: ProviderGroup,
    ) -> BlockLinearProblem:
        """Build the linear subproblem that updates one provider's mixing weights."""
        solve_blocks = self._direct_solve_blocks()
        for true_index in range(provider_group.n_true_terms):
            solve_blocks.append(
                SolveBlock(
                    key=_provider_weight_key(provider_group, true_index),
                    size=provider_group.n_proxy_terms,
                    label=f"weights[{true_index}]",
                    regularization=_make_block_regularization(
                        (provider_group.n_proxy_terms,),
                        ridge=self.weight_ridge,
                    ),
                )
            )
        return _StreamingBlockLinearProblem(
            layout=BlockProblemLayout.from_blocks(tuple(solve_blocks)),
            batches=_AlchemicalSubproblemBatchSequence(
                fitter=self,
                true_problem=true_problem,
                provider_group=provider_group,
                mode="weight",
            ),
        )

    # ------------------------------------------------------------------
    # Subproblem solves
    # ------------------------------------------------------------------

    def _solve_subproblem(
        self,
        problem: BlockLinearProblem,
        initial_theta: torch.Tensor,
        *,
        cg_checkpoint_path: Path | str | None,
        cg_checkpoint_frequency: int,
        cg_resume: bool,
        progress: bool,
        progress_frequency: int,
    ) -> LinearSolveResult:
        """Solve one subproblem starting from, and falling back to, ``initial_theta``."""
        result = problem.solve(
            solver=self.solver,
            cg_tolerance=self.cg_tolerance,
            cg_max_iter=self.cg_max_iter,
            initial_theta=initial_theta,
            fallback_theta=initial_theta,
            return_info=True,
            progress=progress,
            progress_frequency=progress_frequency,
            cg_checkpoint_path=cg_checkpoint_path,
            cg_checkpoint_frequency=cg_checkpoint_frequency,
            cg_resume=cg_resume,
        )
        assert isinstance(result, LinearSolveResult)
        return result

    def _write_back_direct_blocks(
        self, problem: BlockLinearProblem, beta: torch.Tensor
    ) -> None:
        """Copy the direct-block part of a subproblem solution into the model."""
        for block_index in self._active_direct_blocks():
            theta_slice = problem.layout.theta_slice(block_index)
            self.layout.write_block_vector(block_index, beta[theta_slice])

    def _solve_provider_proxy_subproblem(
        self,
        true_problem: BlockLinearProblem,
        provider_group: ProviderGroup,
        *,
        cg_checkpoint_path: Path | str | None = None,
        cg_checkpoint_frequency: int = 1,
        cg_resume: bool = False,
        progress: bool = False,
        progress_frequency: int = 10,
    ) -> LinearSolveResult:
        """Solve and write back one provider's proxy-update subproblem.

        The solution is written back, and the provider normalized, whatever
        the solve's interruption status.
        """
        problem = self._make_proxy_problem(true_problem, provider_group)
        result = self._solve_subproblem(
            problem,
            self._proxy_initial_vector(problem, provider_group),
            cg_checkpoint_path=cg_checkpoint_path,
            cg_checkpoint_frequency=cg_checkpoint_frequency,
            cg_resume=cg_resume,
            progress=progress,
            progress_frequency=progress_frequency,
        )
        beta = result.theta
        self._write_back_direct_blocks(problem, beta)
        provider = provider_group.provider
        for proxy_index in range(provider_group.n_proxy_terms):
            theta_slice = problem.layout.theta_slice(
                _provider_proxy_key(provider_group, proxy_index)
            )
            provider.proxy_coeffs.data[proxy_index].copy_(
                beta[theta_slice]
                .reshape(provider_group.coefficient_shape)
                .to(provider.proxy_coeffs)
            )
        _normalize_provider(provider_group)
        return result

    def _solve_provider_weight_subproblem(
        self,
        true_problem: BlockLinearProblem,
        provider_group: ProviderGroup,
        *,
        cg_checkpoint_path: Path | str | None = None,
        cg_checkpoint_frequency: int = 1,
        cg_resume: bool = False,
        progress: bool = False,
        progress_frequency: int = 10,
    ) -> LinearSolveResult:
        """Solve and write back one provider's weight-update subproblem.

        The solution is written back, and the provider normalized, whatever
        the solve's interruption status.
        """
        problem = self._make_weight_problem(true_problem, provider_group)
        result = self._solve_subproblem(
            problem,
            self._weight_initial_vector(problem, provider_group),
            cg_checkpoint_path=cg_checkpoint_path,
            cg_checkpoint_frequency=cg_checkpoint_frequency,
            cg_resume=cg_resume,
            progress=progress,
            progress_frequency=progress_frequency,
        )
        beta = result.theta
        self._write_back_direct_blocks(problem, beta)
        provider = provider_group.provider
        assert provider.weights is not None
        for true_index in range(provider_group.n_true_terms):
            theta_slice = problem.layout.theta_slice(
                _provider_weight_key(provider_group, true_index)
            )
            provider.weights.data[true_index].copy_(
                beta[theta_slice].to(provider.weights)
            )
        _normalize_provider(provider_group)
        return result

    def _solve_direct_problem(
        self,
        true_problem: BlockLinearProblem,
        *,
        cg_checkpoint_directory: Path | str | None,
        cg_checkpoint_frequency: int,
        cg_resume: bool,
        progress: bool,
        progress_frequency: int,
        extra_solve_kwargs: dict[str, Any],
    ) -> LinearSolveResult:
        """Run the direct true-coefficient solve used before any ALS sweep.

        ``extra_solve_kwargs`` is forwarded verbatim to ``solve`` so callers
        decide whether ``initial_theta`` is passed at all.
        """
        result = true_problem.solve(
            solver=self.solver,
            cg_tolerance=self.cg_tolerance,
            cg_max_iter=self.cg_max_iter,
            fallback_theta=self._current_true_vector(true_problem),
            return_info=True,
            progress=progress,
            progress_frequency=progress_frequency,
            cg_checkpoint_path=self._cg_checkpoint_path(
                cg_checkpoint_directory, "direct_cg.npz"
            ),
            cg_checkpoint_frequency=cg_checkpoint_frequency,
            cg_resume=cg_resume,
            **extra_solve_kwargs,
        )
        assert isinstance(result, LinearSolveResult)
        return result

    # ------------------------------------------------------------------
    # Public driver
    # ------------------------------------------------------------------

    def fit(
        self,
        samples: Sequence[FitSample],
        *,
        batch_size: int = 32,
        cache_directory: Path | str | None = None,
        cache_mode: AssembledBatchCacheMode = "auto",
        initialize: Literal["svd", "current"] = "svd",
        checkpoint_directory: Path | str | None = None,
        checkpoint_frequency: int = 1,
        cg_checkpoint_directory: Path | str | None = None,
        cg_checkpoint_frequency: int = 1,
        cg_resume: bool = False,
        progress: bool = False,
        progress_frequency: int = 10,
    ) -> AlchemicalALSResult:
        """Alternate proxy and weight solves until convergence.

        Args:
            samples: Samples handed to the linear fitter to build the problem.
            batch_size: Batch size used when assembling the problem.
            cache_directory: Forwarded to the problem assembly.
            cache_mode: Forwarded to the problem assembly.
            initialize: ``"svd"`` runs a direct solve and factors it into
                proxies and weights; ``"current"`` starts from the model's
                present coefficients.
            checkpoint_directory: Where model checkpoints are written;
                ``None`` disables them.
            checkpoint_frequency: Write a sweep checkpoint every this many
                sweeps (values below 1 are treated as 1).
            cg_checkpoint_directory: Directory for per-solve CG checkpoints;
                ``None`` disables them.
            cg_checkpoint_frequency: Forwarded to every solve.
            cg_resume: Forwarded to every solve.
            progress: Forwarded to the assembly and every solve.
            progress_frequency: Forwarded to every solve; must be positive.

        Returns:
            An :class:`AlchemicalALSResult`.  With no non-identity provider
            the result of a single direct solve is returned with
            ``sweeps=0``.

        Raises:
            ValueError: If ``initialize`` is unknown or
                ``progress_frequency`` is not positive.
        """
        if initialize not in {"svd", "current"}:
            raise ValueError("`initialize` must be 'svd' or 'current'")
        if progress_frequency <= 0:
            raise ValueError("`progress_frequency` must be positive")
        try:
            true_problem = self.linear_fitter.build_problem(
                samples,
                batch_size=batch_size,
                progress=progress,
                cache_directory=cache_directory,
                cache_mode=cache_mode,
            )
        except KeyboardInterrupt:
            # Assembly was interrupted: report the current state with an
            # empty problem instead of propagating the interrupt.
            empty_problem = BlockLinearProblem(
                layout=BlockProblemLayout.from_blocks(
                    self.linear_fitter._direct_blocks()
                ),
                batches=(),
            )
            return AlchemicalALSResult(
                theta=self.layout.current_true_vector(
                    dtype=self.linear_fitter.dtype,
                    device=self.linear_fitter.device,
                ),
                objective_history=(float("nan"),),
                converged=False,
                sweeps=0,
                layout=self.layout,
                problem=empty_problem,
                interrupted=True,
            )

        direct_solve_options = {
            "cg_checkpoint_directory": cg_checkpoint_directory,
            "cg_checkpoint_frequency": cg_checkpoint_frequency,
            "cg_resume": cg_resume,
            "progress": progress,
            "progress_frequency": progress_frequency,
        }
        non_identity_providers = self.layout.non_identity_providers()
        if not non_identity_providers:
            # Nothing alchemical to alternate over: one direct solve suffices.
            initial_theta = (
                self._current_true_vector(true_problem)
                if initialize == "current"
                else None
            )
            direct_result = self._solve_direct_problem(
                true_problem,
                extra_solve_kwargs={"initial_theta": initial_theta},
                **direct_solve_options,
            )
            direct_theta = direct_result.theta
            self.linear_fitter.write_back(direct_theta)
            objective = (
                float("nan")
                if direct_result.interrupted
                else float(true_problem.objective(direct_theta).item())
            )
            return AlchemicalALSResult(
                theta=direct_theta,
                objective_history=(objective,),
                converged=not direct_result.interrupted,
                sweeps=0,
                layout=self.layout,
                problem=true_problem,
                interrupted=direct_result.interrupted,
                restored_checkpoint_path=direct_result.restored_checkpoint_path,
            )

        if initialize == "svd":
            direct_result = self._solve_direct_problem(
                true_problem,
                extra_solve_kwargs={},
                **direct_solve_options,
            )
            if direct_result.interrupted:
                interrupted_history = (float("nan"),)
                self._write_checkpoint(
                    checkpoint_directory,
                    stage="interrupted",
                    sweep=0,
                    provider_index=None,
                    objective_history=interrupted_history,
                    true_problem=true_problem,
                )
                return AlchemicalALSResult(
                    theta=self._current_true_vector(true_problem),
                    objective_history=interrupted_history,
                    converged=False,
                    sweeps=0,
                    layout=self.layout,
                    problem=true_problem,
                    interrupted=True,
                    restored_checkpoint_path=direct_result.restored_checkpoint_path,
                )
            self._initialize_from_direct_solution(direct_result.theta)

        objective_history: list[float] = [
            float(
                true_problem.objective(self._current_true_vector(true_problem)).item()
            )
        ]
        self._write_checkpoint(
            checkpoint_directory,
            stage="initialized",
            sweep=0,
            provider_index=None,
            objective_history=objective_history,
            true_problem=true_problem,
        )
        converged = False
        interrupted = False
        restored_checkpoint_path = None
        sweeps = 0
        checkpoint_frequency = max(int(checkpoint_frequency), 1)
        for sweep in range(1, self.max_sweeps + 1):
            previous_theta = self._current_true_vector(true_problem)
            previous_objective = objective_history[-1]
            for provider_index, provider_group in enumerate(non_identity_providers):
                solve_result = self._solve_provider_proxy_subproblem(
                    true_problem,
                    provider_group,
                    cg_checkpoint_path=self._cg_checkpoint_path(
                        cg_checkpoint_directory,
                        f"sweep{sweep}_provider{provider_index}_proxy.npz",
                    ),
                    cg_checkpoint_frequency=cg_checkpoint_frequency,
                    cg_resume=cg_resume,
                    progress=progress,
                    progress_frequency=progress_frequency,
                )
                if solve_result.interrupted:
                    interrupted = True
                    restored_checkpoint_path = solve_result.restored_checkpoint_path
                # The stage checkpoint is written even for an interrupted solve.
                self._write_checkpoint(
                    checkpoint_directory,
                    stage="proxy",
                    sweep=sweep,
                    provider_index=provider_index,
                    objective_history=objective_history,
                    true_problem=true_problem,
                )
                if interrupted:
                    break
                if _provider_weights_are_trainable(provider_group.provider):
                    solve_result = self._solve_provider_weight_subproblem(
                        true_problem,
                        provider_group,
                        cg_checkpoint_path=self._cg_checkpoint_path(
                            cg_checkpoint_directory,
                            f"sweep{sweep}_provider{provider_index}_weight.npz",
                        ),
                        cg_checkpoint_frequency=cg_checkpoint_frequency,
                        cg_resume=cg_resume,
                        progress=progress,
                        progress_frequency=progress_frequency,
                    )
                    if solve_result.interrupted:
                        interrupted = True
                        restored_checkpoint_path = solve_result.restored_checkpoint_path
                    self._write_checkpoint(
                        checkpoint_directory,
                        stage="weight",
                        sweep=sweep,
                        provider_index=provider_index,
                        objective_history=objective_history,
                        true_problem=true_problem,
                    )
                    if interrupted:
                        break
            if interrupted:
                sweeps = sweep
                self._write_checkpoint(
                    checkpoint_directory,
                    stage="interrupted",
                    sweep=sweep,
                    provider_index=None,
                    objective_history=objective_history,
                    true_problem=true_problem,
                )
                break
            current_theta = self._current_true_vector(true_problem)
            current_objective = float(true_problem.objective(current_theta).item())
            objective_history.append(current_objective)
            sweeps = sweep
            if sweep % checkpoint_frequency == 0:
                self._write_checkpoint(
                    checkpoint_directory,
                    stage="sweep",
                    sweep=sweep,
                    provider_index=None,
                    objective_history=objective_history,
                    true_problem=true_problem,
                )
            # Relative changes use a floor of 1.0 in the denominator so that
            # near-zero coefficients or objectives do not blow up the ratio.
            theta_norm = float(torch.linalg.norm(previous_theta).item())
            theta_delta = float(
                torch.linalg.norm(current_theta - previous_theta).item()
            )
            relative_theta_change = theta_delta / max(theta_norm, 1.0)
            relative_objective_change = abs(
                current_objective - previous_objective
            ) / max(abs(previous_objective), 1.0)
            if (
                relative_theta_change <= self.tolerance
                or relative_objective_change <= self.tolerance
            ):
                converged = True
                break
        return AlchemicalALSResult(
            theta=self._current_true_vector(true_problem),
            objective_history=tuple(objective_history),
            converged=converged,
            sweeps=sweeps,
            layout=self.layout,
            problem=true_problem,
            interrupted=interrupted,
            restored_checkpoint_path=restored_checkpoint_path,
        )
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
"AlchemicalALSResult",
"AlchemicalALSFitter",
]