Source code for ufp.leastsquares.alchemical

"""
Alternating least-squares fitting for alchemical coefficient providers.

Use this module when spline blocks are shared through proxy coefficients and
the proxy tensors and mixing weights should be fit in alternating subproblems.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Optional, Sequence

import torch

from ufp.leastsquares._block import (
    ColumnRowIndexedBlockMatrix,
    ColumnRowIndexedChunk,
    RowIndexedBlockMatrix,
    _block_matrix_diagonal,
    _block_matrix_rmatvec,
)
from ufp.leastsquares._layout import ParameterLayout, ProviderGroup
from ufp.leastsquares.dataset import FitSample
from ufp.leastsquares.linear import (
    AssembledBatchCacheMode,
    BlockLinearProblem,
    BlockMatrix,
    BlockProblemLayout,
    BlockSolveBatch,
    LinearFitter,
    LinearSolveResult,
    SolveBlock,
    _block_matrix_matvec,
    _make_block_regularization,
    _materialize_block_matrix,
    _twobody_shape_regularization_rows,
    load_cg_checkpoint,
)
from ufp.terms._twobody_shape import (
    TwoBodySplineShapePenalty,
    normalize_twobody_shape_penalty,
)
from ufp.terms.model import UFPModel


def _provider_proxy_key(
    provider_group: ProviderGroup, proxy_index: int
) -> tuple[str, int, int]:
    """Return the solve key that identifies one proxy block of a provider.

    The key is ``("proxy", id(provider), proxy_index)``. The provider is
    identified by object identity, so the key is only meaningful while that
    provider object is alive in the current process.
    """
    provider_identity = id(provider_group.provider)
    return ("proxy", provider_identity, int(proxy_index))


def _provider_weight_key(
    provider_group: ProviderGroup, true_index: int
) -> tuple[str, int, int]:
    """Return the solve key that identifies one weight row of a provider.

    The key is ``("weight", id(provider), true_index)``. The provider is
    identified by object identity, so the key is only meaningful while that
    provider object is alive in the current process.
    """
    provider_identity = id(provider_group.provider)
    return ("weight", provider_identity, int(true_index))


def _svd_initialize(
    true_coeffs: torch.Tensor, n_proxy_terms: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Factor a true-coefficient matrix into initial weights and proxies.

    A reduced SVD of ``true_coeffs`` is truncated to the usable rank and the
    square root of each retained singular value is folded into both factors.
    Columns of the weights and rows of the proxies beyond that rank stay zero.

    Returns:
        ``(weights, proxies)`` with shapes ``(n_true_terms, n_proxy_terms)``
        and ``(n_proxy_terms, width)``, matching the dtype and device of
        ``true_coeffs``.
    """
    n_proxy = int(n_proxy_terms)
    n_true_terms, width = true_coeffs.shape
    like_input = {"dtype": true_coeffs.dtype, "device": true_coeffs.device}
    weights = torch.zeros((n_true_terms, n_proxy), **like_input)
    proxies = torch.zeros((n_proxy, width), **like_input)

    rank = min(n_proxy, n_true_terms, width)
    if rank == 0:
        # Nothing to factor: at least one dimension is empty.
        return weights, proxies

    left, singular, right_h = torch.linalg.svd(true_coeffs, full_matrices=False)
    # Clamp guards the square root against tiny negative round-off.
    root = torch.sqrt(torch.clamp(singular[:rank], min=0.0))
    weights[:, :rank] = left[:, :rank] * root.unsqueeze(0)
    proxies[:rank] = root.unsqueeze(1) * right_h[:rank]
    return weights, proxies


def _normalize_provider(provider_group: ProviderGroup) -> None:
    """Rescale proxy rows to unit norm without changing true coefficients.

    Each proxy row is divided by its Euclidean norm and the matching column of
    the weight matrix is multiplied by the same norm, so the product of the
    two factors is unchanged. Rows whose norm is not positive are left as they
    are. The function returns without touching anything when the provider has
    no weights, when the weights are not trainable, or when the provider holds
    no proxy coefficients at all.

    Args:
        provider_group: Group whose ``provider.proxy_coeffs`` and
            ``provider.weights`` are rescaled in place.
    """
    provider = provider_group.provider
    if provider.weights is None or not _provider_weights_are_trainable(provider):
        return
    if provider.proxy_coeffs.numel() == 0:
        # ``reshape(0, -1)`` cannot infer the trailing dimension of an empty
        # tensor and raises; with no coefficients there is nothing to rescale.
        return

    proxy = provider.proxy_coeffs.data.reshape(provider_group.n_proxy_terms, -1)
    weights = provider.weights.data
    for proxy_i in range(provider_group.n_proxy_terms):
        norm = torch.linalg.norm(proxy[proxy_i])
        if float(norm.item()) <= 0.0:
            continue
        proxy[proxy_i] /= norm
        weights[:, proxy_i] *= norm
    # ``reshape`` may have produced a copy, so write the rescaled rows back.
    provider.proxy_coeffs.data.copy_(proxy.reshape_as(provider.proxy_coeffs))


def _provider_weights_are_trainable(provider) -> bool:
    """Return whether ALS is allowed to update provider mixing weights.

    Weights count as trainable only when they exist, are a
    ``torch.nn.Parameter``, and have ``requires_grad`` enabled.
    """
    return (
        provider.weights is not None
        and isinstance(provider.weights, torch.nn.Parameter)
        and provider.weights.requires_grad
    )


def _fit_proxies_to_fixed_weights(
    true_coeffs: torch.Tensor,
    weights: torch.Tensor,
    n_proxy_terms: int,
) -> torch.Tensor:
    """Solve for initial proxy coefficients while the weights stay fixed.

    The proxies are the least-squares solution of ``weights @ proxies =
    true_coeffs``. The result always has ``n_proxy_terms`` rows; rows not
    covered by the solution remain zero, and an empty weight matrix yields an
    all-zero result.
    """
    n_proxy = int(n_proxy_terms)
    n_columns = true_coeffs.shape[1]
    proxies = torch.zeros(
        (n_proxy, n_columns),
        dtype=true_coeffs.dtype,
        device=true_coeffs.device,
    )
    if weights.numel() == 0:
        return proxies

    fitted = torch.linalg.lstsq(weights, true_coeffs).solution
    n_fitted = fitted.shape[0]
    # Both slices clip to the smaller of the two row counts.
    proxies[:n_fitted] = fitted[:n_proxy]
    return proxies


def _subtract_block_prediction(
    target: torch.Tensor,
    matrix: BlockMatrix,
    theta: torch.Tensor,
) -> torch.Tensor:
    """Return ``target`` with one fixed block's contribution removed.

    The contribution is whatever ``_block_matrix_matvec(matrix, theta)``
    produces, cast to the dtype and device of ``target`` before subtracting.
    ``target`` itself is not modified.
    """
    contribution = _block_matrix_matvec(matrix, theta)
    aligned = contribution.to(dtype=target.dtype, device=target.device)
    return target - aligned


def _dense_block_matrix(matrix: BlockMatrix) -> torch.Tensor:
    """Materialize one block matrix as a dense tensor for a transient ALS batch."""
    dense = _materialize_block_matrix(matrix)
    return dense


def _scale_block_matrix(matrix: BlockMatrix, scalar: torch.Tensor) -> BlockMatrix:
    """Multiply a block matrix by a scalar, keeping compact storage if it has one.

    Row-indexed and column/row-indexed matrices are rebuilt with scaled values
    and their original index structure; anything else is multiplied directly.
    """

    def _scaled(values: torch.Tensor) -> torch.Tensor:
        # Cast the scalar to each operand's own dtype/device before multiplying.
        return values * scalar.to(dtype=values.dtype, device=values.device)

    if isinstance(matrix, RowIndexedBlockMatrix):
        return RowIndexedBlockMatrix(
            rows=matrix.rows,
            values=_scaled(matrix.values),
            n_rows=matrix.n_rows,
        )

    if isinstance(matrix, ColumnRowIndexedBlockMatrix):
        scaled_chunks = [
            ColumnRowIndexedChunk(
                column_start=chunk.column_start,
                rows=chunk.rows,
                values=_scaled(chunk.values),
            )
            for chunk in matrix.chunks
        ]
        return ColumnRowIndexedBlockMatrix(
            chunks=tuple(scaled_chunks),
            n_rows=matrix.n_rows,
            n_cols=matrix.n_cols,
        )

    return _scaled(matrix)


def _nonzero_weight_indices(weights: torch.Tensor) -> tuple[int, ...]:
    """Return the positions of entries that are exactly nonzero.

    The comparison is exact (``!= 0``), with no tolerance. An empty tensor
    yields an empty tuple.
    """
    if weights.numel() == 0:
        return ()
    mask = weights.detach() != 0
    positions = torch.nonzero(mask, as_tuple=False).reshape(-1).tolist()
    return tuple(map(int, positions))


class _StreamingBlockLinearProblem(BlockLinearProblem):
    """Block problem whose normal-equation operations stream lazy batches.

    Every quantity is accumulated by iterating ``self.batches`` once and
    summing per-batch contributions, so the rows are never concatenated.
    """

    @property
    def dtype(self) -> torch.dtype:
        """Return the dtype without forcing a transformed batch load."""
        source = self.batches
        if not hasattr(source, "dtype"):
            return super().dtype
        return source.dtype

    @property
    def device(self) -> torch.device:
        """Return the device without forcing a transformed batch load."""
        source = self.batches
        if not hasattr(source, "device"):
            return super().device
        return source.device

    def _prediction_for_batch(
        self,
        batch: BlockSolveBatch,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Sum every block's contribution to one batch's predicted rows."""
        prediction = torch.zeros(
            (batch.n_rows,),
            dtype=theta.dtype,
            device=theta.device,
        )
        for key, block_matrix in batch.matrices.items():
            block_theta = theta[self.layout.theta_slice(key)]
            contribution = _block_matrix_matvec(block_matrix, block_theta)
            prediction = prediction + contribution.to(
                device=theta.device, dtype=theta.dtype
            )
        return prediction

    def _streamed_squared_residual(self, theta: torch.Tensor) -> torch.Tensor:
        """Accumulate the squared residual over all batches for a flat ``theta``."""
        total = torch.zeros((), dtype=theta.dtype, device=theta.device)
        for batch in self.batches:
            prediction = self._prediction_for_batch(batch, theta)
            target = batch.target.to(device=theta.device, dtype=theta.dtype)
            residual = prediction - target
            total = total + torch.dot(residual, residual)
        return total

    def normal_matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Apply the regularized normal operator without concatenating rows."""
        theta = theta.reshape(self.layout.size)
        accumulated = torch.zeros(
            (self.layout.size,),
            dtype=theta.dtype,
            device=theta.device,
        )
        for batch in self.batches:
            prediction = self._prediction_for_batch(batch, theta)
            for key, block_matrix in batch.matrices.items():
                block_slice = self.layout.theta_slice(key)
                pulled_back = _block_matrix_rmatvec(block_matrix, prediction)
                accumulated[block_slice] += pulled_back.to(
                    device=theta.device, dtype=theta.dtype
                )
        return accumulated + self.regularization_apply(theta)

    def rhs(self) -> torch.Tensor:
        """Return the right-hand side by streaming transformed batches."""
        accumulated = torch.zeros(
            (self.layout.size,),
            dtype=self.dtype,
            device=self.device,
        )
        for batch in self.batches:
            target = batch.target.to(device=self.device, dtype=self.dtype)
            for key, block_matrix in batch.matrices.items():
                block_slice = self.layout.theta_slice(key)
                pulled_back = _block_matrix_rmatvec(block_matrix, target)
                accumulated[block_slice] += pulled_back.to(
                    device=self.device, dtype=self.dtype
                )
        return accumulated + self.regularization_rhs()

    def normal_equation_diagonal(self) -> torch.Tensor:
        """Return the diagonal of ``A.T @ A`` by streaming transformed batches."""
        diagonal = torch.zeros(
            (self.layout.size,),
            dtype=self.dtype,
            device=self.device,
        )
        for batch in self.batches:
            for key, block_matrix in batch.matrices.items():
                block_slice = self.layout.theta_slice(key)
                block_diagonal = _block_matrix_diagonal(block_matrix)
                diagonal[block_slice] += block_diagonal.to(
                    device=self.device, dtype=self.dtype
                )
        return diagonal

    def objective(self, theta: torch.Tensor) -> torch.Tensor:
        """Evaluate the regularized objective without materializing all rows."""
        theta = theta.reshape(self.layout.size)
        value = self._streamed_squared_residual(theta)
        for block in self.layout.blocks:
            if block.regularization is not None:
                block_theta = theta[self.layout.theta_slice(block.key)]
                value = value + block.regularization.quadratic(block_theta)
        return value

    def residual_norm(self, theta: torch.Tensor) -> torch.Tensor:
        """Return ``||A theta - b||`` without materializing all residual rows."""
        theta = theta.reshape(self.layout.size)
        squared = self._streamed_squared_residual(theta)
        return torch.sqrt(torch.clamp(squared, min=0.0))


class _AlchemicalSubproblemBatchSequence:
    """Lazy batch sequence for one alchemical ALS subproblem.

    Each access transforms a batch of the underlying true-coefficient problem
    on demand. The state captured at construction (current true coefficients
    and either the provider weights or the provider proxies) is held fixed.
    """

    def __init__(
        self,
        *,
        fitter: "AlchemicalALSFitter",
        true_problem: BlockLinearProblem,
        provider_group: ProviderGroup,
        mode: str,
    ) -> None:
        """Store the fixed state used to transform true batches on demand.

        Raises:
            ValueError: If ``mode`` is neither ``"proxy"`` nor ``"weight"``,
                or if a proxy subproblem is requested for a provider that has
                no weights.
        """
        if mode not in {"proxy", "weight"}:
            raise ValueError("`mode` must be 'proxy' or 'weight'")
        self._fitter = fitter
        self._true_problem = true_problem
        self._provider_group = provider_group
        self._mode = mode
        self._current_true = fitter._current_true_vector(true_problem)
        self._direct_block_indices = fitter._active_direct_blocks()

        provider = provider_group.provider
        solved_provider_id = id(provider)
        # Every other non-identity provider is held fixed in this subproblem.
        self._fixed_provider_ids = {
            id(group.provider)
            for group in fitter.layout.non_identity_providers()
            if id(group.provider) != solved_provider_id
        }

        if mode == "weight":
            # Weights are the unknowns; the proxies are frozen.
            self._weights = None
            self._proxy = (
                provider.proxy_coeffs.detach()
                .reshape(
                    provider_group.n_proxy_terms,
                    provider_group.block_size,
                )
                .to(dtype=true_problem.dtype, device=true_problem.device)
            )
            return

        # Proxies are the unknowns; the weights are frozen.
        if provider.weights is None:
            raise ValueError("proxy subproblems require provider weights")
        self._weights = provider.weights.detach().to(
            dtype=true_problem.dtype,
            device=true_problem.device,
        )
        self._proxy = None

    def __len__(self) -> int:
        """Return the number of source true-problem batches."""
        return len(self._true_problem.batches)

    @property
    def dtype(self) -> torch.dtype:
        """Return the true-problem dtype for transformed batches."""
        return self._true_problem.dtype

    @property
    def device(self) -> torch.device:
        """Return the true-problem device for transformed batches."""
        return self._true_problem.device

    def __iter__(self):
        """Yield transformed subproblem batches lazily."""
        for source_batch in self._true_problem.batches:
            yield self._transform_batch(source_batch)

    def __getitem__(self, index):
        """Transform one source batch by index, or a tuple for slices."""
        if not isinstance(index, slice):
            return self._transform_batch(self._true_problem.batches[index])
        positions = range(*index.indices(len(self)))
        return tuple(self[position] for position in positions)

    def _transform_batch(self, batch: BlockSolveBatch) -> BlockSolveBatch:
        """Transform one true-coefficient batch into the requested subproblem.

        Blocks owned by a fixed provider are folded into the target, direct
        blocks pass through unchanged, blocks of the solved provider are
        rewritten for the current mode, and any other block is dropped.
        """
        residual_target = batch.target.clone()
        transformed: dict[Any, BlockMatrix] = {}
        solved_provider = self._provider_group.provider
        if self._mode == "proxy":
            add_solved_block = self._add_proxy_matrices
        else:
            add_solved_block = self._add_weight_matrix

        for block_index, matrix in batch.matrices.items():
            block = self._fitter.layout.block(int(block_index))
            owner = block.coefficient_provider
            if owner is not None and id(owner) in self._fixed_provider_ids:
                residual_target = _subtract_block_prediction(
                    residual_target,
                    matrix,
                    self._current_true[block.theta_slice],
                )
            elif block_index in self._direct_block_indices:
                transformed[block_index] = matrix
            elif owner is not None and id(owner) == id(solved_provider):
                add_solved_block(transformed, block, matrix)

        return BlockSolveBatch(target=residual_target, matrices=transformed)

    def _add_proxy_matrices(
        self,
        matrices: dict[Any, BlockMatrix],
        block,
        matrix: BlockMatrix,
    ) -> None:
        """Add fixed-weight proxy solve matrices for one true block.

        The block matrix is scaled by each nonzero weight in this block's
        weight row. Contributions that land on an existing proxy key are
        summed as dense matrices.
        """
        assert block.coefficient_index is not None
        assert self._weights is not None
        weight_row = self._weights[block.coefficient_index]
        for proxy_index in _nonzero_weight_indices(weight_row):
            key = _provider_proxy_key(self._provider_group, proxy_index)
            scaled = _scale_block_matrix(matrix, weight_row[proxy_index])
            if key not in matrices:
                matrices[key] = scaled
                continue
            matrices[key] = _dense_block_matrix(
                matrices[key]
            ) + _dense_block_matrix(scaled)

    def _add_weight_matrix(
        self,
        matrices: dict[Any, BlockMatrix],
        block,
        matrix: BlockMatrix,
    ) -> None:
        """Add fixed-proxy weight solve matrix for one true block.

        The dense block matrix is multiplied by the transposed proxies and
        stored under this block's weight key.
        """
        assert block.coefficient_index is not None
        assert self._proxy is not None
        key = _provider_weight_key(self._provider_group, block.coefficient_index)
        dense = _dense_block_matrix(matrix)
        matrices[key] = dense @ self._proxy.T


@dataclass(frozen=True)
class AlchemicalALSResult:
    """Summary of one alternating least-squares fit over alchemical coefficients.

    Instances are immutable (frozen dataclass). ``interrupted`` defaults to
    ``False`` and ``restored_checkpoint_path`` to ``None``.
    """

    # NOTE(review): the per-field notes below are inferred from the field names
    # and types; the code that fills them is outside this view, so confirm.
    theta: torch.Tensor  # presumably the final parameter vector
    objective_history: tuple[float, ...]  # presumably one objective value per step
    converged: bool
    sweeps: int
    layout: ParameterLayout
    problem: BlockLinearProblem
    interrupted: bool = False
    restored_checkpoint_path: str | None = None
class AlchemicalALSFitter:
    """Alternating least-squares driver for models with shared alchemical providers."""

    def __init__(
        self,
        model: UFPModel,
        *,
        fit_energy: bool = True,
        fit_forces: bool = True,
        fit_per_atom_energy: bool = False,
        solver: str = "cg",
        ridge: float = 0.0,
        onebody_ridge: float | None = None,
        pair_ridge: float | None = None,
        twobody_ridge: float | None = None,
        threebody_ridge: float | None = None,
        twobody_shape_penalty: TwoBodySplineShapePenalty | None = None,
        weight_ridge: float | None = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        cg_tolerance: float = 1.0e-10,
        cg_max_iter: int | None = None,
        threebody_lstsq_backend: str | None = None,
        threebody_bucket_backend: str | None = None,
        max_sweeps: int = 10,
        tolerance: float = 1.0e-8,
    ) -> None:
        """Store ALS settings and reuse one true-coefficient fitter.

        All arguments except ``weight_ridge``, ``max_sweeps`` and
        ``tolerance`` are forwarded unchanged to ``LinearFitter``. The
        per-kind ridges and the parameter layout are then read back from that
        fitter.

        Args:
            weight_ridge: Ridge for the mixing weights. When ``None`` it
                defaults to ``max(ridge, 1e-12)``.
            max_sweeps: Stored as ``int``.
            tolerance: Stored as ``float``.
        """
        self.linear_fitter = LinearFitter(
            model,
            fit_energy=fit_energy,
            fit_forces=fit_forces,
            fit_per_atom_energy=fit_per_atom_energy,
            solver=solver,
            ridge=ridge,
            onebody_ridge=onebody_ridge,
            pair_ridge=pair_ridge,
            twobody_ridge=twobody_ridge,
            threebody_ridge=threebody_ridge,
            twobody_shape_penalty=twobody_shape_penalty,
            dtype=dtype,
            device=device,
            cg_tolerance=cg_tolerance,
            cg_max_iter=cg_max_iter,
            threebody_lstsq_backend=threebody_lstsq_backend,
            threebody_bucket_backend=threebody_bucket_backend,
        )
        self.model = model
        self.solver = solver
        self.ridge = float(ridge)
        self.onebody_ridge = self.linear_fitter.onebody_ridge
        self.pair_ridge = self.linear_fitter.pair_ridge
        self.threebody_ridge = self.linear_fitter.threebody_ridge
        self.twobody_shape_penalty = normalize_twobody_shape_penalty(
            twobody_shape_penalty
        )
        self.weight_ridge = (
            max(float(ridge), 1.0e-12) if weight_ridge is None else float(weight_ridge)
        )
        self.cg_tolerance = float(cg_tolerance)
        self.cg_max_iter = cg_max_iter
        self.max_sweeps = int(max_sweeps)
        self.tolerance = float(tolerance)
        self.layout = self.linear_fitter.layout

    def _ridge_for_block_index(self, block_index: int) -> float:
        """Return the coefficient ridge assigned to one true block.

        One-body, pair/two-body and three-body blocks use their dedicated
        ridge; every other kind falls back to the general ``ridge``.
        """
        block = self.layout.block(block_index)
        if block.kind == "onebody":
            return self.onebody_ridge
        if block.kind in ("pair", "twobody"):
            return self.pair_ridge
        if block.kind == "threebody":
            return self.threebody_ridge
        return self.ridge

    def _third_difference_for_block_index(self, block_index: int) -> float:
        """Return the two-body third-difference penalty for one block.

        The weight is ``0.0`` for blocks that are not pair/two-body blocks.
        """
        block = self.layout.block(block_index)
        if block.kind in {"pair", "twobody"}:
            return self.twobody_shape_penalty.third_difference_weight
        return 0.0

    def _provider_twobody_active_rows(
        self,
        provider_group: ProviderGroup,
    ) -> tuple[int, ...] | None:
        """Return active two-body rows for a provider-owned coefficient shape.

        The rows are taken from the first two-body block of the provider, or
        ``None`` when the provider owns no such block.
        """
        # NOTE(review): only "twobody" is matched here, while the
        # third-difference helpers also accept "pair" -- confirm this is
        # intentional.
        for block_index in provider_group.block_indices:
            block = self.layout.block(block_index)
            if block.kind == "twobody":
                return _twobody_shape_regularization_rows(block)
        return None

    def _third_difference_for_provider(self, provider_group: ProviderGroup) -> float:
        """Return the provider proxy third-difference penalty.

        The weight applies when any block of the provider is a pair/two-body
        block; otherwise it is ``0.0``.
        """
        if any(
            self.layout.block(block_index).kind in {"pair", "twobody"}
            for block_index in provider_group.block_indices
        ):
            return self.twobody_shape_penalty.third_difference_weight
        return 0.0

    def _ridge_for_provider(self, provider_group: ProviderGroup) -> float:
        """Return the proxy-coefficient ridge for one alchemical provider.

        The ridge of the provider's first block is used.
        """
        block_index = provider_group.block_indices[0]
        return self._ridge_for_block_index(block_index)

    def _initialize_from_direct_solution(self, theta: torch.Tensor) -> None:
        """Seed proxy and weight factors from the direct true-coefficient solve.

        Blocks without a provider, or whose provider uses identity weights,
        receive their slice of ``theta`` directly. For each remaining
        provider, trainable weights are initialized together with the proxies
        from an SVD of the provider's true-coefficient matrix and then
        normalized; fixed weights are kept and only the proxies are fitted to
        them by least squares.
        """
        for block in self.layout.blocks:
            if (
                block.coefficient_provider is None
                or block.coefficient_provider.uses_identity_weights
            ):
                self.layout.write_block_vector(
                    block.index,
                    theta[block.theta_slice],
                )

        for provider_group in self.layout.non_identity_providers():
            true_matrix = self.layout.provider_true_matrix(theta, provider_group)
            provider = provider_group.provider

            if not _provider_weights_are_trainable(provider):
                # Fixed weights: skip the SVD, whose factors would be discarded.
                assert provider.weights is not None
                fixed_weights = provider.weights.to(
                    dtype=true_matrix.dtype,
                    device=true_matrix.device,
                )
                proxies = _fit_proxies_to_fixed_weights(
                    true_matrix,
                    fixed_weights,
                    provider_group.n_proxy_terms,
                )
                provider.proxy_coeffs.data.copy_(
                    proxies.reshape_as(provider.proxy_coeffs).to(provider.proxy_coeffs)
                )
                continue

            weights, proxies = _svd_initialize(
                true_matrix, provider_group.n_proxy_terms
            )
            provider.proxy_coeffs.data.copy_(
                proxies.reshape_as(provider.proxy_coeffs).to(provider.proxy_coeffs)
            )
            assert provider.weights is not None
            provider.weights.data.copy_(weights.to(provider.weights))
            _normalize_provider(provider_group)

    def _current_true_vector(self, problem: BlockLinearProblem) -> torch.Tensor:
        """Read current true coefficients in the problem layout."""
        return self.layout.current_true_vector(
            dtype=problem.dtype, device=problem.device
        )
[docs] def initialize_from_direct_cg_checkpoint(self, checkpoint_path: Path | str) -> None: """Initialize alchemical coefficients from a direct true-CG checkpoint.""" checkpoint = load_cg_checkpoint( checkpoint_path, dtype=self.linear_fitter.dtype, device=self.linear_fitter.device, ) theta = checkpoint.x.reshape(-1) if theta.numel() != self.layout.size: raise ValueError( "CG checkpoint parameter count does not match this alchemical layout" ) metadata_n_parameters = checkpoint.metadata.get("n_parameters") if metadata_n_parameters is not None: try: metadata_n_parameters = int(metadata_n_parameters) except (TypeError, ValueError) as exc: raise ValueError( "CG checkpoint metadata contains an invalid parameter count" ) from exc if metadata_n_parameters != self.layout.size: raise ValueError( "CG checkpoint metadata does not match this alchemical layout" ) if self.layout.non_identity_providers(): self._initialize_from_direct_solution(theta) else: self.linear_fitter.write_back(theta)
def _active_direct_blocks(self) -> tuple[int, ...]:
    """Return direct blocks that stay in every ALS subproblem."""
    return self.layout.direct_block_indices()

def _seed_direct_entries(
    self,
    problem: BlockLinearProblem,
    beta: torch.Tensor,
) -> None:
    """Copy the current direct-block coefficients into a subproblem vector.

    ``beta`` is modified in place: for every active direct block, the slice
    addressed by the subproblem layout receives the block's current true
    coefficients.
    """
    current_true = self._current_true_vector(problem)
    for block_index in self._active_direct_blocks():
        beta[problem.layout.theta_slice(block_index)] = current_true[
            self.layout.block(block_index).theta_slice
        ]

def _store_direct_entries(
    self,
    problem: BlockLinearProblem,
    beta: torch.Tensor,
) -> None:
    """Write the direct-block part of a subproblem solution into the layout."""
    for block_index in self._active_direct_blocks():
        theta_slice = problem.layout.theta_slice(block_index)
        self.layout.write_block_vector(block_index, beta[theta_slice])

def _direct_solve_blocks(self) -> list[SolveBlock]:
    """Build the solve blocks for the direct blocks shared by all subproblems.

    Each block is keyed by its true block index and regularised with the
    per-block ridge and third-difference penalty.
    """
    solve_blocks: list[SolveBlock] = []
    for block_index in self._active_direct_blocks():
        block = self.layout.block(block_index)
        solve_blocks.append(
            SolveBlock(
                key=block_index,
                size=block.size,
                label=block.label,
                regularization=_make_block_regularization(
                    block.shape,
                    ridge=self._ridge_for_block_index(block_index),
                    third_difference_penalty=(
                        self._third_difference_for_block_index(block_index)
                    ),
                    active_rows=_twobody_shape_regularization_rows(block),
                ),
            )
        )
    return solve_blocks

def _run_als_subproblem_solve(
    self,
    problem: BlockLinearProblem,
    initial_theta: torch.Tensor,
    *,
    cg_checkpoint_path: Path | str | None,
    cg_checkpoint_frequency: int,
    cg_resume: bool,
    progress: bool,
    progress_frequency: int,
) -> LinearSolveResult:
    """Solve one ALS subproblem starting from ``initial_theta``.

    ``initial_theta`` is used both as the starting point and as the
    fallback solution passed to ``problem.solve``.
    """
    result = problem.solve(
        solver=self.solver,
        cg_tolerance=self.cg_tolerance,
        cg_max_iter=self.cg_max_iter,
        initial_theta=initial_theta,
        fallback_theta=initial_theta,
        return_info=True,
        progress=progress,
        progress_frequency=progress_frequency,
        cg_checkpoint_path=cg_checkpoint_path,
        cg_checkpoint_frequency=cg_checkpoint_frequency,
        cg_resume=cg_resume,
    )
    # ``return_info=True`` is requested above; the assert narrows the type.
    assert isinstance(result, LinearSolveResult)
    return result

def _proxy_initial_vector(
    self,
    problem: BlockLinearProblem,
    provider_group: ProviderGroup,
) -> torch.Tensor:
    """Return the current model state in one proxy-subproblem layout.

    The vector holds the current direct-block coefficients followed by the
    provider's current proxy coefficients, each placed at the slice the
    subproblem layout assigns to its key.
    """
    beta = torch.zeros(
        (problem.layout.size,),
        dtype=problem.dtype,
        device=problem.device,
    )
    self._seed_direct_entries(problem, beta)
    provider = provider_group.provider
    proxy = provider.proxy_coeffs.reshape(
        provider_group.n_proxy_terms,
        provider_group.block_size,
    ).to(dtype=problem.dtype, device=problem.device)
    for proxy_index in range(provider_group.n_proxy_terms):
        proxy_slice = problem.layout.theta_slice(
            _provider_proxy_key(provider_group, proxy_index)
        )
        beta[proxy_slice] = proxy[proxy_index]
    return beta

def _weight_initial_vector(
    self,
    problem: BlockLinearProblem,
    provider_group: ProviderGroup,
) -> torch.Tensor:
    """Return the current model state in one weight-subproblem layout.

    The vector holds the current direct-block coefficients followed by one
    row of the provider's current mixing weights per true term.
    """
    beta = torch.zeros(
        (problem.layout.size,),
        dtype=problem.dtype,
        device=problem.device,
    )
    self._seed_direct_entries(problem, beta)
    provider = provider_group.provider
    assert provider.weights is not None
    weights = provider.weights.to(dtype=problem.dtype, device=problem.device)
    for true_index in range(provider_group.n_true_terms):
        weight_slice = problem.layout.theta_slice(
            _provider_weight_key(provider_group, true_index)
        )
        beta[weight_slice] = weights[true_index]
    return beta

def _write_checkpoint(
    self,
    checkpoint_directory: Path | str | None,
    *,
    stage: str,
    sweep: int,
    provider_index: int | None,
    objective_history: Sequence[float],
    true_problem: BlockLinearProblem,
) -> None:
    """Write a restorable alchemical model checkpoint.

    Nothing is written when ``checkpoint_directory`` is ``None``. Otherwise
    the directory is created if needed and the same payload is stored twice:
    as ``alchemical_latest.pt`` and under a stage-specific file name derived
    from ``sweep``, ``provider_index`` and ``stage``.
    """
    if checkpoint_directory is None:
        return
    checkpoint_dir = Path(checkpoint_directory)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    payload = {
        "model_state_dict": self.model.state_dict(),
        "theta": self._current_true_vector(true_problem).detach().cpu(),
        "objective_history": tuple(float(value) for value in objective_history),
        "sweep": int(sweep),
        "provider_index": provider_index,
        "stage": str(stage),
    }
    # Write the rolling "latest" file to a sibling temp path and rename it
    # into place, so an interrupt during serialisation cannot leave a
    # truncated ``alchemical_latest.pt`` behind.
    latest_path = checkpoint_dir / "alchemical_latest.pt"
    staging_path = checkpoint_dir / "alchemical_latest.pt.tmp"
    torch.save(payload, staging_path)
    staging_path.replace(latest_path)
    stage_name = (
        f"sweep{sweep}_{stage}"
        if provider_index is None
        else f"sweep{sweep}_provider{provider_index}_{stage}"
    )
    torch.save(payload, checkpoint_dir / f"alchemical_{stage_name}.pt")

def _make_proxy_problem(
    self,
    true_problem: BlockLinearProblem,
    provider_group: ProviderGroup,
) -> BlockLinearProblem:
    """Build the linear subproblem that updates one provider's proxy blocks.

    The unknowns are the shared direct blocks plus one block per proxy term
    of ``provider_group``; batches are streamed from ``true_problem`` in
    ``"proxy"`` mode.
    """
    solve_blocks = self._direct_solve_blocks()
    for proxy_index in range(provider_group.n_proxy_terms):
        solve_blocks.append(
            SolveBlock(
                key=_provider_proxy_key(provider_group, proxy_index),
                size=provider_group.block_size,
                label=f"proxy[{proxy_index}]",
                regularization=_make_block_regularization(
                    provider_group.coefficient_shape,
                    ridge=self._ridge_for_provider(provider_group),
                    third_difference_penalty=(
                        self._third_difference_for_provider(provider_group)
                    ),
                    active_rows=self._provider_twobody_active_rows(provider_group),
                ),
            )
        )
    return _StreamingBlockLinearProblem(
        layout=BlockProblemLayout.from_blocks(tuple(solve_blocks)),
        batches=_AlchemicalSubproblemBatchSequence(
            fitter=self,
            true_problem=true_problem,
            provider_group=provider_group,
            mode="proxy",
        ),
    )

def _make_weight_problem(
    self,
    true_problem: BlockLinearProblem,
    provider_group: ProviderGroup,
) -> BlockLinearProblem:
    """Build the linear subproblem that updates one provider's mixing weights.

    The unknowns are the shared direct blocks plus one weight row (of length
    ``n_proxy_terms``) per true term of ``provider_group``; batches are
    streamed from ``true_problem`` in ``"weight"`` mode.
    """
    solve_blocks = self._direct_solve_blocks()
    for true_index in range(provider_group.n_true_terms):
        solve_blocks.append(
            SolveBlock(
                key=_provider_weight_key(provider_group, true_index),
                size=provider_group.n_proxy_terms,
                label=f"weights[{true_index}]",
                regularization=_make_block_regularization(
                    (provider_group.n_proxy_terms,),
                    ridge=self.weight_ridge,
                ),
            )
        )
    return _StreamingBlockLinearProblem(
        layout=BlockProblemLayout.from_blocks(tuple(solve_blocks)),
        batches=_AlchemicalSubproblemBatchSequence(
            fitter=self,
            true_problem=true_problem,
            provider_group=provider_group,
            mode="weight",
        ),
    )

def _solve_provider_proxy_subproblem(
    self,
    true_problem: BlockLinearProblem,
    provider_group: ProviderGroup,
    *,
    cg_checkpoint_path: Path | str | None = None,
    cg_checkpoint_frequency: int = 1,
    cg_resume: bool = False,
    progress: bool = False,
    progress_frequency: int = 10,
) -> LinearSolveResult:
    """Solve and write back one provider's proxy-update subproblem.

    The solution's direct-block part is written into the layout, each proxy
    row is copied into ``provider.proxy_coeffs``, and the provider is then
    normalised. The solver result is returned unchanged.
    """
    problem = self._make_proxy_problem(true_problem, provider_group)
    result = self._run_als_subproblem_solve(
        problem,
        self._proxy_initial_vector(problem, provider_group),
        cg_checkpoint_path=cg_checkpoint_path,
        cg_checkpoint_frequency=cg_checkpoint_frequency,
        cg_resume=cg_resume,
        progress=progress,
        progress_frequency=progress_frequency,
    )
    beta = result.theta
    self._store_direct_entries(problem, beta)
    provider = provider_group.provider
    for proxy_index in range(provider_group.n_proxy_terms):
        theta_slice = problem.layout.theta_slice(
            _provider_proxy_key(provider_group, proxy_index)
        )
        provider.proxy_coeffs.data[proxy_index].copy_(
            beta[theta_slice]
            .reshape(provider_group.coefficient_shape)
            .to(provider.proxy_coeffs)
        )
    _normalize_provider(provider_group)
    return result

def _solve_provider_weight_subproblem(
    self,
    true_problem: BlockLinearProblem,
    provider_group: ProviderGroup,
    *,
    cg_checkpoint_path: Path | str | None = None,
    cg_checkpoint_frequency: int = 1,
    cg_resume: bool = False,
    progress: bool = False,
    progress_frequency: int = 10,
) -> LinearSolveResult:
    """Solve and write back one provider's weight-update subproblem.

    The solution's direct-block part is written into the layout, each weight
    row is copied into ``provider.weights``, and the provider is then
    normalised. The solver result is returned unchanged.
    """
    problem = self._make_weight_problem(true_problem, provider_group)
    result = self._run_als_subproblem_solve(
        problem,
        self._weight_initial_vector(problem, provider_group),
        cg_checkpoint_path=cg_checkpoint_path,
        cg_checkpoint_frequency=cg_checkpoint_frequency,
        cg_resume=cg_resume,
        progress=progress,
        progress_frequency=progress_frequency,
    )
    beta = result.theta
    self._store_direct_entries(problem, beta)
    provider = provider_group.provider
    assert provider.weights is not None
    for true_index in range(provider_group.n_true_terms):
        theta_slice = problem.layout.theta_slice(
            _provider_weight_key(provider_group, true_index)
        )
        provider.weights.data[true_index].copy_(
            beta[theta_slice].to(provider.weights)
        )
    _normalize_provider(provider_group)
    return result
def fit(
    self,
    samples: Sequence[FitSample],
    *,
    batch_size: int = 32,
    cache_directory: Path | str | None = None,
    cache_mode: AssembledBatchCacheMode = "auto",
    initialize: Literal["svd", "current"] = "svd",
    checkpoint_directory: Path | str | None = None,
    checkpoint_frequency: int = 1,
    cg_checkpoint_directory: Path | str | None = None,
    cg_checkpoint_frequency: int = 1,
    cg_resume: bool = False,
    progress: bool = False,
    progress_frequency: int = 10,
) -> AlchemicalALSResult:
    """Alternate proxy and weight solves until convergence.

    The true (direct) linear problem is built from ``samples`` first. When
    the layout has no non-identity providers, a single direct solve is
    performed and returned with ``sweeps=0``. Otherwise the factors are
    optionally seeded from a direct solve (``initialize="svd"``) and then
    up to ``self.max_sweeps`` sweeps are run; each sweep solves the proxy
    subproblem and, for providers with trainable weights, the weight
    subproblem of every non-identity provider.

    A sweep loop stops early when either the relative change of the true
    coefficient vector or the relative change of the objective is at most
    ``self.tolerance``, or when a subproblem solve reports an interruption.

    Args:
        samples: Fit samples forwarded to ``linear_fitter.build_problem``.
        batch_size: Forwarded to ``linear_fitter.build_problem``.
        cache_directory: Forwarded to ``linear_fitter.build_problem``.
        cache_mode: Forwarded to ``linear_fitter.build_problem``.
        initialize: ``"svd"`` seeds the factors from a direct solve;
            ``"current"`` starts from the coefficients already stored.
        checkpoint_directory: Directory for model checkpoints written via
            ``_write_checkpoint``; ``None`` disables them.
        checkpoint_frequency: Per-sweep checkpoint period; values below 1
            are clamped to 1.
        cg_checkpoint_directory: Directory for the per-solve ``.npz`` CG
            checkpoint paths handed to the solver; ``None`` disables them.
        cg_checkpoint_frequency: Forwarded to every solve.
        cg_resume: Forwarded to every solve.
        progress: Forwarded to problem construction and every solve.
        progress_frequency: Forwarded to every solve; must be positive.

    Returns:
        An ``AlchemicalALSResult`` with the final true coefficient vector,
        the objective history, and convergence/interruption flags.

    Raises:
        ValueError: If ``initialize`` is not ``"svd"`` or ``"current"``, or
            if ``progress_frequency`` is not positive.
    """
    if initialize not in {"svd", "current"}:
        raise ValueError("`initialize` must be 'svd' or 'current'")
    if progress_frequency <= 0:
        raise ValueError("`progress_frequency` must be positive")
    try:
        true_problem = self.linear_fitter.build_problem(
            samples,
            batch_size=batch_size,
            progress=progress,
            cache_directory=cache_directory,
            cache_mode=cache_mode,
        )
    except KeyboardInterrupt:
        # Interrupted while assembling the problem: report the coefficients
        # currently stored, together with an empty problem and a NaN
        # objective, instead of propagating the interrupt.
        empty_problem = BlockLinearProblem(
            layout=BlockProblemLayout.from_blocks(
                self.linear_fitter._direct_blocks()
            ),
            batches=(),
        )
        return AlchemicalALSResult(
            theta=self.layout.current_true_vector(
                dtype=self.linear_fitter.dtype,
                device=self.linear_fitter.device,
            ),
            objective_history=(float("nan"),),
            converged=False,
            sweeps=0,
            layout=self.layout,
            problem=empty_problem,
            interrupted=True,
        )
    non_identity_providers = self.layout.non_identity_providers()
    if not non_identity_providers:
        # No alchemical factorisation to alternate over: one direct solve
        # of the true problem is the whole fit.
        initial_theta = (
            self._current_true_vector(true_problem)
            if initialize == "current"
            else None
        )
        direct_result = true_problem.solve(
            solver=self.solver,
            cg_tolerance=self.cg_tolerance,
            cg_max_iter=self.cg_max_iter,
            initial_theta=initial_theta,
            fallback_theta=self._current_true_vector(true_problem),
            return_info=True,
            progress=progress,
            progress_frequency=progress_frequency,
            cg_checkpoint_path=(
                None
                if cg_checkpoint_directory is None
                else Path(cg_checkpoint_directory) / "direct_cg.npz"
            ),
            cg_checkpoint_frequency=cg_checkpoint_frequency,
            cg_resume=cg_resume,
        )
        assert isinstance(direct_result, LinearSolveResult)
        direct_theta = direct_result.theta
        self.linear_fitter.write_back(direct_theta)
        # The objective is not evaluated for an interrupted solve.
        objective = (
            float("nan")
            if direct_result.interrupted
            else float(true_problem.objective(direct_theta).item())
        )
        return AlchemicalALSResult(
            theta=direct_theta,
            objective_history=(objective,),
            converged=not direct_result.interrupted,
            sweeps=0,
            layout=self.layout,
            problem=true_problem,
            interrupted=direct_result.interrupted,
            restored_checkpoint_path=direct_result.restored_checkpoint_path,
        )
    if initialize == "svd":
        # Seed the proxy/weight factors from a direct solve of the true
        # problem before alternating.
        direct_result = true_problem.solve(
            solver=self.solver,
            cg_tolerance=self.cg_tolerance,
            cg_max_iter=self.cg_max_iter,
            fallback_theta=self._current_true_vector(true_problem),
            return_info=True,
            progress=progress,
            progress_frequency=progress_frequency,
            cg_checkpoint_path=(
                None
                if cg_checkpoint_directory is None
                else Path(cg_checkpoint_directory) / "direct_cg.npz"
            ),
            cg_checkpoint_frequency=cg_checkpoint_frequency,
            cg_resume=cg_resume,
        )
        assert isinstance(direct_result, LinearSolveResult)
        direct_theta = direct_result.theta
        if direct_result.interrupted:
            # The interrupted direct solution is not written into the
            # model; the coefficients already stored are returned.
            objective_history = (float("nan"),)
            self._write_checkpoint(
                checkpoint_directory,
                stage="interrupted",
                sweep=0,
                provider_index=None,
                objective_history=objective_history,
                true_problem=true_problem,
            )
            return AlchemicalALSResult(
                theta=self._current_true_vector(true_problem),
                objective_history=objective_history,
                converged=False,
                sweeps=0,
                layout=self.layout,
                problem=true_problem,
                interrupted=True,
                restored_checkpoint_path=direct_result.restored_checkpoint_path,
            )
        self._initialize_from_direct_solution(direct_theta)
    # Entry 0 of the history is the objective at the initial state.
    objective_history = [
        float(
            true_problem.objective(self._current_true_vector(true_problem)).item()
        )
    ]
    self._write_checkpoint(
        checkpoint_directory,
        stage="initialized",
        sweep=0,
        provider_index=None,
        objective_history=objective_history,
        true_problem=true_problem,
    )
    converged = False
    interrupted = False
    restored_checkpoint_path = None
    sweeps = 0
    # A non-positive frequency is clamped so the modulo below is valid.
    checkpoint_frequency = max(int(checkpoint_frequency), 1)
    for sweep in range(1, self.max_sweeps + 1):
        previous_theta = self._current_true_vector(true_problem)
        previous_objective = objective_history[-1]
        for provider_index, provider_group in enumerate(non_identity_providers):
            cg_dir = (
                None
                if cg_checkpoint_directory is None
                else Path(cg_checkpoint_directory)
            )
            solve_result = self._solve_provider_proxy_subproblem(
                true_problem,
                provider_group,
                cg_checkpoint_path=(
                    None
                    if cg_dir is None
                    else cg_dir / f"sweep{sweep}_provider{provider_index}_proxy.npz"
                ),
                cg_checkpoint_frequency=cg_checkpoint_frequency,
                cg_resume=cg_resume,
                progress=progress,
                progress_frequency=progress_frequency,
            )
            # NOTE(review): nesting reconstructed from whitespace-collapsed
            # source -- confirm that `restored_checkpoint_path` is only
            # recorded for interrupted solves.
            if solve_result.interrupted:
                interrupted = True
                restored_checkpoint_path = solve_result.restored_checkpoint_path
            # The stage checkpoint is written before leaving the loop so
            # that an interrupted state is saved as well.
            self._write_checkpoint(
                checkpoint_directory,
                stage="proxy",
                sweep=sweep,
                provider_index=provider_index,
                objective_history=objective_history,
                true_problem=true_problem,
            )
            if interrupted:
                break
            # Providers with fixed weights only get the proxy update.
            # NOTE(review): nesting reconstructed from whitespace-collapsed
            # source -- confirm that the "weight" checkpoint is written
            # only when a weight solve ran.
            if _provider_weights_are_trainable(provider_group.provider):
                solve_result = self._solve_provider_weight_subproblem(
                    true_problem,
                    provider_group,
                    cg_checkpoint_path=(
                        None
                        if cg_dir is None
                        else cg_dir
                        / f"sweep{sweep}_provider{provider_index}_weight.npz"
                    ),
                    cg_checkpoint_frequency=cg_checkpoint_frequency,
                    cg_resume=cg_resume,
                    progress=progress,
                    progress_frequency=progress_frequency,
                )
                if solve_result.interrupted:
                    interrupted = True
                    restored_checkpoint_path = solve_result.restored_checkpoint_path
                self._write_checkpoint(
                    checkpoint_directory,
                    stage="weight",
                    sweep=sweep,
                    provider_index=provider_index,
                    objective_history=objective_history,
                    true_problem=true_problem,
                )
                if interrupted:
                    break
        if interrupted:
            # The partially completed sweep is counted, but no objective
            # is appended to the history for it.
            sweeps = sweep
            self._write_checkpoint(
                checkpoint_directory,
                stage="interrupted",
                sweep=sweep,
                provider_index=None,
                objective_history=objective_history,
                true_problem=true_problem,
            )
            break
        current_theta = self._current_true_vector(true_problem)
        current_objective = float(true_problem.objective(current_theta).item())
        objective_history.append(current_objective)
        sweeps = sweep
        if sweep % checkpoint_frequency == 0:
            self._write_checkpoint(
                checkpoint_directory,
                stage="sweep",
                sweep=sweep,
                provider_index=None,
                objective_history=objective_history,
                true_problem=true_problem,
            )
        # Both relative changes use a denominator floored at 1.0.
        theta_norm = float(torch.linalg.norm(previous_theta).item())
        theta_delta = float(
            torch.linalg.norm(current_theta - previous_theta).item()
        )
        relative_theta_change = theta_delta / max(theta_norm, 1.0)
        relative_objective_change = abs(
            current_objective - previous_objective
        ) / max(
            abs(previous_objective),
            1.0,
        )
        # Either criterion alone is sufficient to declare convergence.
        if (
            relative_theta_change <= self.tolerance
            or relative_objective_change <= self.tolerance
        ):
            converged = True
            break
    return AlchemicalALSResult(
        theta=self._current_true_vector(true_problem),
        objective_history=tuple(objective_history),
        converged=converged,
        sweeps=sweeps,
        layout=self.layout,
        problem=true_problem,
        interrupted=interrupted,
        restored_checkpoint_path=restored_checkpoint_path,
    )
# Names exported by ``from ufp.leastsquares.alchemical import *``.
__all__ = [
    "AlchemicalALSResult",
    "AlchemicalALSFitter",
]