"""Coefficient freezing helpers for optimizer-based training."""
from __future__ import annotations
from dataclasses import dataclass, field
from types import MethodType
from typing import Sequence
import torch
from ufp.leastsquares._layout import ParameterLayout, TermBlock
from ufp.leastsquares._selection import (
BlockSelector,
coefficient_indices_for_selector,
resolve_coefficient_selection,
)
from ufp.terms.model import UFPModel
# Attribute name set on a patched optimizer to record which
# CoefficientFreezeState wrapped it (checked on wrap and unwrap).
_OPTIMIZER_STATE_ATTR = "_ufp_coefficient_freeze_state"
# Attribute name set on a patched optimizer to keep its original ``step``
# callable so that unwrapping can put it back.
_OPTIMIZER_ORIGINAL_STEP_ATTR = "_ufp_coefficient_freeze_original_step"
def _direct_parameter_for_block(block: TermBlock) -> torch.nn.Parameter:
    """Locate the ``torch.nn.Parameter`` that stores one block's coefficients.

    Lookup order:

    1. When the block has a coefficient provider, the provider's
       ``proxy_coeffs`` is returned, but only if the provider reports
       ``uses_identity_weights``.
    2. Otherwise the term attributes ``coeffs``, ``coeffs_by_pair``,
       ``coeffs_by_triplet`` and ``values`` are probed in that order and the
       first one that is a ``torch.nn.Parameter`` wins.
    3. As a last resort the result of ``block.read()`` is used if it is a
       ``torch.nn.Parameter``.

    Raises:
        ValueError: If the provider does not use identity weights, or if no
            parameter could be found for the block.
    """
    coefficient_source = block.coefficient_provider

    if coefficient_source is None:
        owner = block.term
        candidate_names = ("coeffs", "coeffs_by_pair", "coeffs_by_triplet", "values")
        for attribute in candidate_names:
            candidate = getattr(owner, attribute, None)
            if isinstance(candidate, torch.nn.Parameter):
                return candidate

        # No known attribute matched; fall back to whatever the block reads.
        fallback = block.read()
        if isinstance(fallback, torch.nn.Parameter):
            return fallback
        raise ValueError(
            f"could not find a trainable parameter for block {block.label!r}"
        )

    if coefficient_source.uses_identity_weights:
        return coefficient_source.proxy_coeffs

    raise ValueError(
        "coefficient freezing for non-identity alchemical providers is not "
        f"supported for block {block.label!r}"
    )
def _block_indices_into_parameter(block: TermBlock) -> tuple[int, ...]:
    """Map one block's coefficients to flat positions in its backing parameter.

    Without a coefficient provider the block covers positions
    ``0 .. block.size - 1``.  With a provider, the block occupies the
    ``block.coefficient_index``-th consecutive run of ``block.size`` entries.

    Raises:
        ValueError: If the block has a provider but no ``coefficient_index``.
    """
    if block.coefficient_provider is None:
        return tuple(range(block.size))

    slot = block.coefficient_index
    if slot is None:
        raise ValueError("alchemical coefficient blocks require an index")

    width = block.size
    offset = int(slot) * width
    return tuple(range(offset, offset + width))
@dataclass
class CoefficientFreezeState:
    """Frozen coefficient masks and reference values for optimizer training.

    Attributes:
        masks: Boolean mask per parameter; ``True`` marks a frozen entry.
        values: Reference tensor per parameter from which frozen entries are
            copied back by :meth:`restore`.
        parameter_metadata: Per-parameter summaries (name, frozen/trainable
            counts) backing the count properties below.
        selector_metadata: Per-selector summaries stored alongside the masks.
    """

    masks: dict[torch.nn.Parameter, torch.Tensor]
    values: dict[torch.nn.Parameter, torch.Tensor]
    parameter_metadata: tuple[FrozenParameterMetadata, ...] = field(
        default_factory=tuple
    )
    selector_metadata: tuple[FrozenSelectorMetadata, ...] = field(default_factory=tuple)

    @property
    def affected_parameter_names(self) -> tuple[str, ...]:
        """Return parameter names touched by this freeze state."""
        return tuple(metadata.name for metadata in self.parameter_metadata)

    @property
    def frozen_counts(self) -> dict[str, int]:
        """Return frozen-entry counts keyed by parameter name."""
        return {
            metadata.name: int(metadata.frozen_count)
            for metadata in self.parameter_metadata
        }

    @property
    def trainable_counts(self) -> dict[str, int]:
        """Return unfrozen-entry counts keyed by parameter name."""
        return {
            metadata.name: int(metadata.trainable_count)
            for metadata in self.parameter_metadata
        }

    @property
    def frozen_count(self) -> int:
        """Return the total number of frozen coefficient entries."""
        return sum(metadata.frozen_count for metadata in self.parameter_metadata)

    @property
    def trainable_count(self) -> int:
        """Return the total number of unfrozen entries in affected parameters."""
        return sum(metadata.trainable_count for metadata in self.parameter_metadata)

    def apply_gradient_masks(self) -> None:
        """Zero gradients for frozen coefficient entries.

        Parameters whose ``grad`` is ``None`` are skipped.
        """
        for parameter, mask in self.masks.items():
            if parameter.grad is None:
                continue
            parameter.grad.masked_fill_(mask, 0.0)

    def restore(self) -> None:
        """Restore frozen coefficient entries after an optimizer step.

        Frozen positions of every masked parameter are overwritten in place
        with the corresponding entries of :attr:`values`.
        """
        with torch.no_grad():
            for parameter, mask in self.masks.items():
                parameter.data[mask] = self.values[parameter][mask]

    def clear_optimizer_state(self, optimizer: torch.optim.Optimizer) -> int:
        """
        Clear tensor optimizer-state entries for frozen coefficient positions.

        Adam-style optimizers keep per-parameter moment buffers. Clearing the
        frozen entries prevents stale moments from reappearing when the same
        optimizer is later unfrozen or rewrapped with a different mask.

        Args:
            optimizer: Optimizer whose ``state`` mapping is inspected.

        Returns:
            The number of state tensors that had frozen positions zeroed.
        """
        cleared = 0
        with torch.no_grad():
            for parameter, mask in self.masks.items():
                optimizer_state = optimizer.state.get(parameter)
                if not optimizer_state:
                    continue
                for value in optimizer_state.values():
                    if not isinstance(value, torch.Tensor):
                        continue
                    if tuple(value.shape) == tuple(parameter.shape):
                        value.masked_fill_(mask, 0.0)
                        cleared += 1
                    elif value.ndim > 0 and int(value.numel()) == int(
                        parameter.numel()
                    ):
                        # Same element count, different shape: reshape the
                        # (read-only) mask instead of the state tensor.
                        # Flattening a non-contiguous state tensor would
                        # yield a copy and the in-place fill would be lost.
                        value.masked_fill_(mask.reshape(value.shape), 0.0)
                        cleared += 1
        return cleared

    def wrap_optimizer(self, optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer:
        """Patch one optimizer so frozen entries remain unchanged after ``step``.

        The patched ``step`` masks gradients before delegating to the original
        ``step`` (after the closure has run, when one is given), then restores
        the frozen values and clears the matching optimizer-state entries.
        Wrapping an optimizer already wrapped by this state is a no-op.

        Returns:
            The same optimizer instance, patched in place.

        Raises:
            RuntimeError: If the optimizer is already wrapped by a different
                ``CoefficientFreezeState``.
        """
        existing_state = getattr(optimizer, _OPTIMIZER_STATE_ATTR, None)
        if existing_state is self:
            return optimizer
        if existing_state is not None:
            raise RuntimeError(
                "optimizer is already wrapped by a different "
                "CoefficientFreezeState; unwrap it before applying another "
                "coefficient freeze mask"
            )
        original_step = optimizer.step
        state = self

        def step(wrapped_optimizer, closure=None):
            if closure is None:
                state.apply_gradient_masks()
                result = original_step()
            else:

                def wrapped_closure():
                    # Gradients are produced inside the closure, so masking
                    # has to happen after it has been evaluated.
                    loss = closure()
                    state.apply_gradient_masks()
                    return loss

                result = original_step(closure=wrapped_closure)
            state.restore()
            state.clear_optimizer_state(wrapped_optimizer)
            return result

        setattr(optimizer, _OPTIMIZER_ORIGINAL_STEP_ATTR, original_step)
        setattr(optimizer, _OPTIMIZER_STATE_ATTR, self)
        optimizer.step = MethodType(step, optimizer)  # type: ignore[method-assign]
        return optimizer

    def unwrap_optimizer(
        self,
        optimizer: torch.optim.Optimizer,
    ) -> torch.optim.Optimizer:
        """Restore an optimizer previously patched by this freeze state.

        An optimizer that is not wrapped at all is returned unchanged.

        Returns:
            The same optimizer instance with its original ``step`` reinstated.

        Raises:
            RuntimeError: If the optimizer is wrapped by a different
                ``CoefficientFreezeState``.
        """
        existing_state = getattr(optimizer, _OPTIMIZER_STATE_ATTR, None)
        if existing_state is None:
            return optimizer
        if existing_state is not self:
            raise RuntimeError(
                "optimizer is wrapped by a different CoefficientFreezeState"
            )
        original_step = getattr(optimizer, _OPTIMIZER_ORIGINAL_STEP_ATTR)
        optimizer.step = original_step  # type: ignore[method-assign]
        delattr(optimizer, _OPTIMIZER_ORIGINAL_STEP_ATTR)
        delattr(optimizer, _OPTIMIZER_STATE_ATTR)
        return optimizer
def _selector_metadata(
    layout: ParameterLayout,
    selectors: Sequence[BlockSelector],
) -> tuple[FrozenSelectorMetadata, ...]:
    """Summarize, per selector, which layout blocks it hits and how many entries.

    Every selector is matched against every block of ``layout`` through
    ``coefficient_indices_for_selector``.  Blocks with no matching indices are
    left out of the summary; the remaining blocks contribute their label and
    the number of matched indices.

    Returns:
        One ``FrozenSelectorMetadata`` per selector, in input order.
    """

    def _summarize(selector: BlockSelector) -> FrozenSelectorMetadata:
        hits_per_block = (
            (candidate.label, coefficient_indices_for_selector(candidate, selector))
            for candidate in layout.blocks
        )
        matched = [(label, len(hits)) for label, hits in hits_per_block if hits]
        return FrozenSelectorMetadata(
            selector=repr(selector),
            block_labels=tuple(label for label, _ in matched),
            frozen_count=int(sum(count for _, count in matched)),
        )

    return tuple(_summarize(selector) for selector in selectors)
def freeze_model_coefficients(
    model: UFPModel,
    selectors: Sequence[BlockSelector],
) -> CoefficientFreezeState:
    """Return masks that freeze selected coefficients during optimizer training.

    The model's parameter layout (frozen blocks included) is resolved against
    ``selectors``.  For every selected block the backing parameter is looked
    up and the selected coefficient positions are marked ``True`` in a
    boolean mask shaped like that parameter.  Parameters whose mask ends up
    all ``False`` are dropped from the result.

    Args:
        model: Model whose coefficients should be frozen.
        selectors: Block selectors naming the coefficients to freeze.

    Returns:
        A ``CoefficientFreezeState`` holding the masks, a detached copy of
        each affected parameter, and per-parameter / per-selector metadata.

    Raises:
        ValueError: Propagated from the parameter lookup helpers when a
            selected block has no usable backing parameter.
    """
    layout = ParameterLayout.from_model(model, include_frozen=True)
    selector_metadata = _selector_metadata(layout, selectors)
    selected = resolve_coefficient_selection(layout, fit_blocks=tuple(selectors))

    masks: dict[torch.nn.Parameter, torch.Tensor] = {}
    block_labels: dict[torch.nn.Parameter, set[str]] = {}
    for selection in selected:
        block = selection.block
        parameter = _direct_parameter_for_block(block)
        mask = masks.get(parameter)
        if mask is None:
            # Force a contiguous mask: ``zeros_like`` would otherwise copy
            # the strides of a dense non-contiguous parameter, flattening
            # would then produce a copy, and the writes below would be lost.
            mask = torch.zeros_like(
                parameter,
                dtype=torch.bool,
                memory_format=torch.contiguous_format,
            )
            masks[parameter] = mask
        # ``view`` aliases the mask's storage, so writes land in ``mask``.
        flat_mask = mask.view(-1)
        parameter_indices = _block_indices_into_parameter(block)
        for local_index in selection.indices:
            flat_mask[parameter_indices[int(local_index)]] = True
        block_labels.setdefault(parameter, set()).add(block.label)

    # Keep only parameters that actually have at least one frozen entry.
    masks = {
        parameter: mask for parameter, mask in masks.items() if bool(torch.any(mask))
    }
    values = {parameter: parameter.detach().clone() for parameter in masks}

    # Parameters are matched by identity to recover their registered names.
    named_parameters = {
        id(parameter): name for name, parameter in model.named_parameters()
    }
    parameter_metadata = []
    for index, (parameter, mask) in enumerate(masks.items()):
        frozen_count = int(torch.count_nonzero(mask).item())
        # Fall back to a positional name for parameters the model does not
        # report through ``named_parameters``.
        name = named_parameters.get(id(parameter), f"parameter_{index}")
        parameter_metadata.append(
            FrozenParameterMetadata(
                name=name,
                shape=tuple(int(dim) for dim in parameter.shape),
                frozen_count=frozen_count,
                trainable_count=int(parameter.numel()) - frozen_count,
                block_labels=tuple(sorted(block_labels.get(parameter, ()))),
            )
        )
    return CoefficientFreezeState(
        masks=masks,
        values=values,
        parameter_metadata=tuple(parameter_metadata),
        selector_metadata=selector_metadata,
    )
# Names exported by ``from <module> import *``.
# NOTE(review): ``FrozenParameterMetadata`` and ``FrozenSelectorMetadata`` are
# exported and used in this module, but their definitions are not visible in
# this section; confirm they are defined or imported elsewhere in the file.
__all__ = [
"CoefficientFreezeState",
"FrozenParameterMetadata",
"FrozenSelectorMetadata",
"freeze_model_coefficients",
]