# Source code for ufp.training.freezing

"""Coefficient freezing helpers for optimizer-based training."""

from __future__ import annotations

from dataclasses import dataclass, field
from types import MethodType
from typing import Sequence

import torch

from ufp.leastsquares._layout import ParameterLayout, TermBlock
from ufp.leastsquares._selection import (
    BlockSelector,
    coefficient_indices_for_selector,
    resolve_coefficient_selection,
)
from ufp.terms.model import UFPModel


# Attribute names attached to a patched optimizer instance. The first stores
# the CoefficientFreezeState that wrapped the optimizer; the second stores the
# optimizer's ``step`` callable as it was before wrapping, so that it can be
# put back on unwrap.
_OPTIMIZER_STATE_ATTR = "_ufp_coefficient_freeze_state"
_OPTIMIZER_ORIGINAL_STEP_ATTR = "_ufp_coefficient_freeze_original_step"


def _direct_parameter_for_block(block: TermBlock) -> torch.nn.Parameter:
    """Locate the ``torch.nn.Parameter`` that stores one block's coefficients.

    Resolution order:

    1. If the block has a coefficient provider, its ``proxy_coeffs`` is
       returned, but only when the provider reports identity weights.
    2. Otherwise the block's term is probed for the attributes ``coeffs``,
       ``coeffs_by_pair``, ``coeffs_by_triplet`` and ``values`` (in that
       order); the first one that is a ``Parameter`` wins.
    3. As a last resort the result of ``block.read()`` is used if it is
       itself a ``Parameter``.

    Raises:
        ValueError: if the provider uses non-identity weights, or if no
            ``Parameter`` could be found for the block.
    """
    provider = block.coefficient_provider

    if provider is None:
        term = block.term
        # Lazy generator: attributes are probed one at a time, in order, and
        # probing stops at the first Parameter found.
        candidates = (
            getattr(term, attribute, None)
            for attribute in (
                "coeffs",
                "coeffs_by_pair",
                "coeffs_by_triplet",
                "values",
            )
        )
        for candidate in candidates:
            if isinstance(candidate, torch.nn.Parameter):
                return candidate

        fallback = block.read()
        if not isinstance(fallback, torch.nn.Parameter):
            raise ValueError(
                f"could not find a trainable parameter for block {block.label!r}"
            )
        return fallback

    if provider.uses_identity_weights:
        return provider.proxy_coeffs

    raise ValueError(
        "coefficient freezing for non-identity alchemical providers is not "
        f"supported for block {block.label!r}"
    )


def _block_indices_into_parameter(block: TermBlock) -> tuple[int, ...]:
    """Map a block's local coefficient positions to flat parameter indices.

    Without a coefficient provider the block is taken to start at flat index
    zero, so the result is simply ``0 .. size-1``. With a provider, the block
    is addressed as the ``coefficient_index``-th consecutive chunk of
    ``block.size`` entries.

    Raises:
        ValueError: if the block has a provider but no ``coefficient_index``.
    """
    if block.coefficient_provider is None:
        return tuple(range(block.size))

    slot = block.coefficient_index
    if slot is None:
        raise ValueError("alchemical coefficient blocks require an index")

    width = block.size
    offset = int(slot) * width
    return tuple(range(offset, offset + width))


@dataclass(frozen=True)
class FrozenParameterMetadata:
    """Inspection summary for one optimizer parameter touched by a freeze mask.

    Instances are immutable; they only describe a freeze mask and hold no
    tensors.
    """

    # Name used to identify the parameter in reports.
    name: str
    # Shape of the parameter tensor, as plain ints.
    shape: tuple[int, ...]
    # Number of entries of the parameter that are frozen.
    frozen_count: int
    # Number of entries of the parameter that are not frozen.
    trainable_count: int
    # Labels of the coefficient blocks that contributed frozen entries.
    block_labels: tuple[str, ...]
@dataclass(frozen=True)
class FrozenSelectorMetadata:
    """Inspection summary for one user selector after layout resolution.

    Instances are immutable and hold only plain Python values.
    """

    # Textual form of the selector the user supplied.
    selector: str
    # Labels of the blocks in which the selector matched at least one entry.
    block_labels: tuple[str, ...]
    # Total number of coefficient entries matched by the selector.
    frozen_count: int
@dataclass
class CoefficientFreezeState:
    """Frozen coefficient masks and reference values for optimizer training.

    ``masks`` maps each affected parameter to a boolean tensor that is
    ``True`` at frozen entries. ``values`` maps the same parameters to the
    tensor whose entries are written back at the frozen positions by
    :meth:`restore`. The two metadata tuples are purely descriptive.
    """

    masks: dict[torch.nn.Parameter, torch.Tensor]
    values: dict[torch.nn.Parameter, torch.Tensor]
    parameter_metadata: tuple[FrozenParameterMetadata, ...] = field(
        default_factory=tuple
    )
    selector_metadata: tuple[FrozenSelectorMetadata, ...] = field(
        default_factory=tuple
    )

    @property
    def affected_parameter_names(self) -> tuple[str, ...]:
        """Return parameter names touched by this freeze state."""
        return tuple(metadata.name for metadata in self.parameter_metadata)

    @property
    def frozen_counts(self) -> dict[str, int]:
        """Return frozen-entry counts keyed by parameter name."""
        return {
            metadata.name: int(metadata.frozen_count)
            for metadata in self.parameter_metadata
        }

    @property
    def trainable_counts(self) -> dict[str, int]:
        """Return unfrozen-entry counts keyed by parameter name."""
        return {
            metadata.name: int(metadata.trainable_count)
            for metadata in self.parameter_metadata
        }

    @property
    def frozen_count(self) -> int:
        """Return the total number of frozen coefficient entries."""
        return sum(metadata.frozen_count for metadata in self.parameter_metadata)

    @property
    def trainable_count(self) -> int:
        """Return the total number of unfrozen entries in affected parameters."""
        return sum(metadata.trainable_count for metadata in self.parameter_metadata)

    def apply_gradient_masks(self) -> None:
        """Zero gradients for frozen coefficient entries.

        Parameters whose ``grad`` is ``None`` are skipped.
        """
        for parameter, mask in self.masks.items():
            if parameter.grad is None:
                continue
            parameter.grad.masked_fill_(mask, 0.0)

    def restore(self) -> None:
        """Restore frozen coefficient entries after an optimizer step.

        Every frozen entry is overwritten with the corresponding entry of
        ``self.values``; unfrozen entries are left untouched.
        """
        with torch.no_grad():
            for parameter, mask in self.masks.items():
                parameter.data[mask] = self.values[parameter][mask]

    def clear_optimizer_state(self, optimizer: torch.optim.Optimizer) -> int:
        """
        Clear tensor optimizer-state entries for frozen coefficient positions.

        Adam-style optimizers keep per-parameter moment buffers. Clearing the
        frozen entries prevents stale moments from reappearing when the same
        optimizer is later unfrozen or rewrapped with a different mask.

        Only tensor-valued state entries are considered. An entry is cleared
        when it has the parameter's shape, or when it has at least one
        dimension and the same number of elements as the parameter.

        Returns:
            The number of state tensors that were masked.
        """
        cleared = 0
        with torch.no_grad():
            for parameter, mask in self.masks.items():
                optimizer_state = optimizer.state.get(parameter)
                if not optimizer_state:
                    continue
                for value in optimizer_state.values():
                    if not isinstance(value, torch.Tensor):
                        continue
                    if tuple(value.shape) == tuple(parameter.shape):
                        value.masked_fill_(mask, 0.0)
                        cleared += 1
                    elif value.ndim > 0 and int(value.numel()) == int(
                        parameter.numel()
                    ):
                        # Reshape the (read-only) mask rather than the state
                        # tensor: reshaping a non-contiguous state tensor
                        # would yield a copy and the fill would be lost.
                        value.masked_fill_(mask.reshape(value.shape), 0.0)
                        cleared += 1
        return cleared

    def wrap_optimizer(self, optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer:
        """Patch one optimizer so frozen entries remain unchanged after ``step``.

        The replacement ``step`` masks gradients (after the closure runs, if a
        closure is given), calls the original ``step``, restores the frozen
        values and clears optimizer state at frozen positions. Wrapping an
        optimizer already wrapped by this same state is a no-op.

        Returns:
            The same optimizer instance.

        Raises:
            RuntimeError: if the optimizer is already wrapped by a different
                ``CoefficientFreezeState``.
        """
        existing_state = getattr(optimizer, _OPTIMIZER_STATE_ATTR, None)
        if existing_state is self:
            return optimizer
        if existing_state is not None:
            raise RuntimeError(
                "optimizer is already wrapped by a different "
                "CoefficientFreezeState; unwrap it before applying another "
                "coefficient freeze mask"
            )

        original_step = optimizer.step
        state = self

        def step(wrapped_optimizer, closure=None):
            if closure is None:
                state.apply_gradient_masks()
                result = original_step()
            else:

                def wrapped_closure():
                    # Mask after every closure evaluation so the optimizer
                    # never sees gradients for frozen entries.
                    loss = closure()
                    state.apply_gradient_masks()
                    return loss

                result = original_step(closure=wrapped_closure)
            state.restore()
            state.clear_optimizer_state(wrapped_optimizer)
            return result

        setattr(optimizer, _OPTIMIZER_ORIGINAL_STEP_ATTR, original_step)
        setattr(optimizer, _OPTIMIZER_STATE_ATTR, self)
        optimizer.step = MethodType(step, optimizer)  # type: ignore[method-assign]
        return optimizer

    def unwrap_optimizer(
        self,
        optimizer: torch.optim.Optimizer,
    ) -> torch.optim.Optimizer:
        """Restore an optimizer previously patched by this freeze state.

        An optimizer that is not wrapped at all is returned unchanged.

        Returns:
            The same optimizer instance.

        Raises:
            RuntimeError: if the optimizer is wrapped by a different
                ``CoefficientFreezeState``.
        """
        existing_state = getattr(optimizer, _OPTIMIZER_STATE_ATTR, None)
        if existing_state is None:
            return optimizer
        if existing_state is not self:
            raise RuntimeError(
                "optimizer is wrapped by a different CoefficientFreezeState"
            )
        original_step = getattr(optimizer, _OPTIMIZER_ORIGINAL_STEP_ATTR)
        optimizer.step = original_step  # type: ignore[method-assign]
        delattr(optimizer, _OPTIMIZER_ORIGINAL_STEP_ATTR)
        delattr(optimizer, _OPTIMIZER_STATE_ATTR)
        return optimizer
def _selector_metadata(
    layout: ParameterLayout,
    selectors: Sequence[BlockSelector],
) -> tuple[FrozenSelectorMetadata, ...]:
    """Return per-selector summaries using the same block semantics as fitting.

    For every selector, each block in ``layout.blocks`` is queried with
    ``coefficient_indices_for_selector``. Blocks with no matching indices are
    ignored; for the rest, the block label is recorded and the number of
    matched indices is added to the selector's total.

    Returns:
        One ``FrozenSelectorMetadata`` per selector, in input order. The
        ``selector`` field holds ``repr(selector)``.
    """
    summaries: list[FrozenSelectorMetadata] = []
    for selector in selectors:
        matched_labels: list[str] = []
        matched_total = 0
        for block in layout.blocks:
            indices = coefficient_indices_for_selector(block, selector)
            if not indices:
                continue
            matched_labels.append(block.label)
            matched_total += len(indices)
        summaries.append(
            FrozenSelectorMetadata(
                selector=repr(selector),
                block_labels=tuple(matched_labels),
                frozen_count=int(matched_total),
            )
        )
    return tuple(summaries)
def freeze_model_coefficients(
    model: UFPModel,
    selectors: Sequence[BlockSelector],
) -> CoefficientFreezeState:
    """Return masks that freeze selected coefficients during optimizer training.

    The model's parameter layout is built with frozen blocks included, the
    selectors are resolved against it, and every selected coefficient is
    flagged in a boolean mask shaped like its backing parameter. Parameters
    that end up with no flagged entry are dropped. The current parameter
    values are cloned as the reference values to restore later.

    Args:
        model: Model whose coefficients should be frozen.
        selectors: Selectors naming the coefficients to freeze.

    Returns:
        A ``CoefficientFreezeState`` holding the masks, the reference values,
        and per-parameter / per-selector inspection metadata.

    Raises:
        ValueError: propagated from the block helpers when a block's backing
            parameter or flat indices cannot be determined.
    """
    layout = ParameterLayout.from_model(model, include_frozen=True)
    selector_metadata = _selector_metadata(layout, selectors)
    selected = resolve_coefficient_selection(layout, fit_blocks=tuple(selectors))

    all_masks: dict[torch.nn.Parameter, torch.Tensor] = {}
    block_labels: dict[torch.nn.Parameter, set[str]] = {}
    for selection in selected:
        block = selection.block
        parameter = _direct_parameter_for_block(block)
        mask = all_masks.get(parameter)
        if mask is None:
            # Force a contiguous mask: zeros_like otherwise copies the strides
            # of a dense non-contiguous parameter, and flattening such a mask
            # would produce a copy whose writes never reach the real mask.
            mask = torch.zeros_like(
                parameter,
                dtype=torch.bool,
                memory_format=torch.contiguous_format,
            )
            all_masks[parameter] = mask
        # view() (unlike reshape()) is guaranteed to alias the mask storage.
        flat_mask = mask.view(-1)
        parameter_indices = _block_indices_into_parameter(block)
        for local_index in selection.indices:
            flat_mask[parameter_indices[int(local_index)]] = True
        block_labels.setdefault(parameter, set()).add(block.label)

    # Keep only parameters that actually have at least one frozen entry.
    masks = {
        parameter: mask
        for parameter, mask in all_masks.items()
        if bool(torch.any(mask))
    }
    values = {parameter: parameter.detach().clone() for parameter in masks}

    # Parameters are matched to names by identity, not by value.
    named_parameters = {
        id(parameter): name for name, parameter in model.named_parameters()
    }
    parameter_metadata = []
    for index, (parameter, mask) in enumerate(masks.items()):
        frozen_count = int(torch.count_nonzero(mask).item())
        # Fall back to a positional name for parameters the model does not
        # expose through named_parameters().
        name = named_parameters.get(id(parameter), f"parameter_{index}")
        parameter_metadata.append(
            FrozenParameterMetadata(
                name=name,
                shape=tuple(int(dim) for dim in parameter.shape),
                frozen_count=frozen_count,
                trainable_count=int(parameter.numel()) - frozen_count,
                block_labels=tuple(sorted(block_labels.get(parameter, ()))),
            )
        )

    return CoefficientFreezeState(
        masks=masks,
        values=values,
        parameter_metadata=tuple(parameter_metadata),
        selector_metadata=selector_metadata,
    )
# Names exported by ``from ufp.training.freezing import *``.
__all__ = [ "CoefficientFreezeState", "FrozenParameterMetadata", "FrozenSelectorMetadata", "freeze_model_coefficients", ]