# Source code for ufp.coefficients.interchange

"""Coefficient-channel interchange helpers for compatible UFP models."""

from __future__ import annotations

import copy
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import cast

import torch

from ufp.leastsquares import CoefficientSelector
from ufp.leastsquares._layout import ParameterLayout, TermBlock
from ufp.leastsquares._selection import block_matches_selector
from ufp.terms.model import UFPModel
from ufp.terms.threebody import SplineThreeBodyTerm
from ufp.terms.triplet2d import SplineTriplet2DTerm
from ufp.terms.twobody import SplinePairTerm, SplineTwoBodyTerm


# Index tuple applied to a block tensor to select one channel: an optional
# leading integer category index followed by per-dimension slices.
CoefficientIndex = tuple[int | slice, ...]


class CoefficientCompatibilityError(ValueError):
    """Signal that coefficient channels cannot be safely interchanged.

    Subclasses ``ValueError`` so callers that already handle invalid-value
    errors also catch incompatibility failures.
    """
@dataclass(frozen=True)
class CoefficientChannel:
    """One physical channel located inside a single model coefficient block.

    The instance pairs the owning block with the tensor index of the channel
    and with the metadata compared when channels are interchanged.
    """

    # Owning coefficient block; tensors are obtained through ``block.read()``.
    block: TermBlock
    # Channel family tag (this module uses "pair", "triplet", "triplet2d").
    family: str
    # Canonical atomic-number tuple identifying the channel.
    channel: tuple[int, ...]
    # Index applied to the block tensor to select this channel's values.
    index: CoefficientIndex
    # Shape of the tensor selected by ``index`` slices.
    value_shape: tuple[int, ...]
    # Spline/grid descriptors compared between source and target.
    grid: tuple[object, ...]
    # Cutoff descriptors compared between source and target.
    cutoff: tuple[object, ...]
    # Full ordered list of categories of the owning term.
    category_order: tuple[tuple[int, ...], ...]
    # Categories of the owning term that are currently active.
    active_channels: tuple[tuple[int, ...], ...]
    # Pair symmetry flag; ``None`` for families that do not define one.
    symmetric: bool | None

    @property
    def block_label(self) -> str:
        """Label of the coefficient block that owns this channel."""
        owner = self.block
        return owner.label

    @property
    def block_kind(self) -> str:
        """Kind of the coefficient block that owns this channel."""
        owner = self.block
        return owner.kind

    @property
    def term_type(self) -> str:
        """Class name of the term behind the owning block."""
        term_class = type(self.block.term)
        return term_class.__name__

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the tensor currently returned by the owning block."""
        tensor = self.block.read()
        return tensor.dtype

    @property
    def device(self) -> torch.device:
        """Device of the tensor currently returned by the owning block."""
        tensor = self.block.read()
        return tensor.device
@dataclass(frozen=True)
class CoefficientCompatibility:
    """Outcome of checking one source channel against one target channel."""

    # Channel the values would be read from.
    source: CoefficientChannel
    # Channel the values would be written into.
    target: CoefficientChannel
    # Human-readable incompatibility reasons; empty when compatible.
    reasons: tuple[str, ...]

    @property
    def compatible(self) -> bool:
        """Tell whether the reason list is empty."""
        return len(self.reasons) == 0
@dataclass(frozen=True)
class CoefficientCopy:
    """Record describing a single coefficient channel that was copied."""

    # Channel family tag of the copied channel.
    family: str
    # Canonical atomic-number tuple of the copied channel.
    channel: tuple[int, ...]
    # Label of the block the values were read from.
    source_block_label: str
    # Label of the block the values were written to.
    target_block_label: str
    # Shape of the copied value tensor.
    shape: tuple[int, ...]
@dataclass(frozen=True)
class CoefficientCopyReport:
    """Result of one copy-matching-coefficients run."""

    # One record per channel that was copied.
    copied: tuple[CoefficientCopy, ...]
    # One message per channel that was not copied.
    skipped: tuple[str, ...]

    @property
    def copied_count(self) -> int:
        """Number of entries in ``copied``."""
        records = self.copied
        return len(records)
def _as_tuple_channel(channel: Sequence[int] | None) -> tuple[int, ...]: """Normalize a required physical channel.""" if channel is None: raise ValueError("`selector.channel` is required for coefficient channels") return tuple(int(value) for value in channel) def _canonical_pair(term: object, channel: Sequence[int]) -> tuple[int, int]: """Normalize one pair channel according to a term's symmetry convention.""" if len(channel) != 2: raise ValueError("pair channels must contain exactly two atomic numbers") from ufp.terms.twobody import _canonical_pair as canonical_pair return canonical_pair( int(channel[0]), int(channel[1]), symmetric=bool(getattr(term, "symmetric", True)), ) def _canonical_triplet(channel: Sequence[int]) -> tuple[int, int, int]: """Normalize one source-distinguished triplet channel.""" if len(channel) != 3: raise ValueError("triplet channels must contain exactly three atomic numbers") from ufp.terms._threebody_eval import _canonical_triplet as canonical_triplet return canonical_triplet(int(channel[0]), int(channel[1]), int(channel[2])) def _active_pair_channels(term: SplineTwoBodyTerm) -> tuple[tuple[int, int], ...]: """Return active pair categories for a categorized two-body term.""" active_indices = cast(tuple[int, ...], term._active_pair_indices) return tuple(term.pair_categories[index] for index in active_indices) def _active_triplet_channels( term: SplineThreeBodyTerm | SplineTriplet2DTerm, ) -> tuple[tuple[int, int, int], ...]: """Return active triplet categories for a three-body-like term.""" categories = term.triplet_categories active_indices = cast(tuple[int, ...], term._active_triplet_indices) return tuple(categories[index] for index in active_indices) def _required_cutoff(term: object) -> float: """Return a term cutoff, rejecting terms without one.""" cutoff = getattr(term, "cutoff", None) if cutoff is None: raise ValueError(f"term {type(term).__name__} does not define a cutoff") return float(cutoff) def _slice_shape(slices: tuple[slice, 
...], shape: tuple[int, ...]) -> tuple[int, ...]: """Return the tensor shape selected by normalized unit-step slices.""" sizes = [] for item, dim in zip(slices, shape, strict=True): start = 0 if item.start is None else int(item.start) stop = int(dim) if item.stop is None else int(item.stop) sizes.append(stop - start) return tuple(sizes) def _normalize_slice_component(component: object, dim: int) -> slice: """Normalize one optional coefficient-slice component.""" if component is None: return slice(None) if isinstance(component, int): index = int(component) if index < 0 or index >= dim: raise ValueError(f"coefficient index {index} is outside [0, {dim})") return slice(index, index + 1) if isinstance(component, range): start = int(component.start) stop = int(component.stop) step = int(component.step) if step != 1: raise ValueError("coefficient ranges must be contiguous") if start < 0 or stop < 0 or start > dim or stop > dim or stop <= start: raise ValueError( f"coefficient range [{start}, {stop}) is outside [0, {dim})" ) return slice(start, stop) if isinstance(component, slice): step = 1 if component.step is None else int(component.step) if step != 1: raise ValueError("coefficient slices must be contiguous") start = 0 if component.start is None else int(component.start) stop = dim if component.stop is None else int(component.stop) if start < 0 or stop < 0 or start > dim or stop > dim or stop <= start: raise ValueError( f"coefficient slice [{start}, {stop}) is outside [0, {dim})" ) return slice(start, stop) raise TypeError("`coeff_slice` entries must be int, slice, range, or None") def _normalize_coeff_slices( coeff_slice: object, coeff_shape: tuple[int, ...], ) -> tuple[slice, ...]: """Normalize trailing coefficient slices for a channel coefficient shape.""" if coeff_slice is None: components: tuple[object, ...] 
= () elif isinstance(coeff_slice, tuple): components = coeff_slice else: components = (coeff_slice,) if len(components) > len(coeff_shape): raise ValueError( "`coeff_slice` has more dimensions than the selected coefficient shape" ) padded = components + (None,) * (len(coeff_shape) - len(components)) return tuple( _normalize_slice_component(component, int(dim)) for component, dim in zip(padded, coeff_shape, strict=True) ) def _pair_channel_location( block: TermBlock, selector: CoefficientSelector, ) -> CoefficientChannel | None: """Resolve a pair-family selector against one block.""" term = block.term channel = _as_tuple_channel(selector.channel) if isinstance(term, SplinePairTerm): pair = _canonical_pair(term, channel) if pair != tuple(int(value) for value in term.pair): return None if not term.enabled: raise CoefficientCompatibilityError( f"channel {pair!r} in block {block.label!r} is inactive" ) slices = _normalize_coeff_slices(selector.coeff_slice, block.shape) return CoefficientChannel( block=block, family="pair", channel=pair, index=slices, value_shape=_slice_shape(slices, block.shape), grid=( str(term.spline), float(term.full_support_start), float(term.first_knot), float(term.knot_spacing), ), cutoff=(_required_cutoff(term),), category_order=(pair,), active_channels=(pair,), symmetric=bool(term.symmetric), ) if isinstance(term, SplineTwoBodyTerm): pair = _canonical_pair(term, channel) pair_index_map = cast(dict[tuple[int, int], int], term._pair_index) pair_index = pair_index_map.get(pair) if pair_index is None: return None if not term.is_pair_active(pair[0], pair[1]): raise CoefficientCompatibilityError( f"channel {pair!r} in block {block.label!r} is inactive" ) coeff_shape = tuple(int(dim) for dim in block.shape[1:]) slices = _normalize_coeff_slices(selector.coeff_slice, coeff_shape) return CoefficientChannel( block=block, family="pair", channel=pair, index=(int(pair_index),) + slices, value_shape=_slice_shape(slices, coeff_shape), grid=( str(term.spline), 
float(term.full_support_start), float(term.first_knot), float(term.knot_spacing), ), cutoff=(_required_cutoff(term),), category_order=tuple(tuple(pair) for pair in term.pair_categories), active_channels=tuple(tuple(pair) for pair in _active_pair_channels(term)), symmetric=bool(term.symmetric), ) return None def _triplet_channel_location( block: TermBlock, selector: CoefficientSelector, ) -> CoefficientChannel | None: """Resolve a triplet-family selector against one block.""" term = block.term channel = _as_tuple_channel(selector.channel) if isinstance(term, SplineThreeBodyTerm): triplet = _canonical_triplet(channel) triplet_index_map = cast( dict[tuple[int, int, int], int], term._triplet_index, ) triplet_index = triplet_index_map.get(triplet) if triplet_index is None: return None if triplet not in _active_triplet_channels(term): raise CoefficientCompatibilityError( f"channel {triplet!r} in block {block.label!r} is inactive" ) if len(block.shape) == 3: active_indices = cast(tuple[int, ...], term._active_triplet_indices) active = tuple(int(index) for index in active_indices) if active != (int(triplet_index),): return None coeff_shape = tuple(int(dim) for dim in block.shape) slices = _normalize_coeff_slices(selector.coeff_slice, coeff_shape) index: CoefficientIndex = slices else: coeff_shape = tuple(int(dim) for dim in block.shape[1:]) slices = _normalize_coeff_slices(selector.coeff_slice, coeff_shape) index = (int(triplet_index),) + slices return CoefficientChannel( block=block, family="triplet", channel=triplet, index=index, value_shape=_slice_shape(slices, coeff_shape), grid=( str(term.spline), float(term.full_support_start_xy), float(term.full_support_start_z), float(term.first_knot_xy), float(term.first_knot_z), float(term.knot_spacing_xy), float(term.knot_spacing_z), ), cutoff=(_required_cutoff(term), float(term.neighbor_neighbor_cutoff)), category_order=tuple(tuple(item) for item in term.triplet_categories), active_channels=tuple( tuple(item) for item in 
_active_triplet_channels(term) ), symmetric=None, ) if isinstance(term, SplineTriplet2DTerm): triplet = _canonical_triplet(channel) triplet_categories = term.triplet_categories triplet_index = ( triplet_categories.index(triplet) if triplet in triplet_categories else None ) if triplet_index is None: return None if triplet not in _active_triplet_channels(term): raise CoefficientCompatibilityError( f"channel {triplet!r} in block {block.label!r} is inactive" ) coeff_shape = tuple(int(dim) for dim in block.shape[1:]) slices = _normalize_coeff_slices(selector.coeff_slice, coeff_shape) return CoefficientChannel( block=block, family="triplet2d", channel=triplet, index=(int(triplet_index),) + slices, value_shape=_slice_shape(slices, coeff_shape), grid=( str(term.spline), float(term.full_support_start), float(term.first_knot), float(term.knot_spacing), ), cutoff=(_required_cutoff(term),), category_order=tuple(tuple(item) for item in triplet_categories), active_channels=tuple( tuple(item) for item in _active_triplet_channels(term) ), symmetric=None, ) return None def _resolve_channel( model: UFPModel, selector: CoefficientSelector, ) -> CoefficientChannel: """Resolve one channel selector to exactly one model coefficient channel.""" layout = ParameterLayout.from_model(model, include_frozen=True) matches: list[CoefficientChannel] = [] for block in layout.blocks: if not block_matches_selector(block, selector.block): continue match = _pair_channel_location(block, selector) if match is None: match = _triplet_channel_location(block, selector) if match is not None: matches.append(match) if not matches: raise ValueError(f"selector {selector!r} did not match any coefficient channel") if len(matches) > 1: labels = ", ".join(match.block_label for match in matches) raise ValueError(f"selector {selector!r} matched multiple channels: {labels}") return matches[0] def _iter_model_channels(model: UFPModel) -> tuple[CoefficientChannel, ...]: """Return all active pair/triplet coefficient 
channels in layout order.""" layout = ParameterLayout.from_model(model, include_frozen=True) channels: list[CoefficientChannel] = [] for block in layout.blocks: term = block.term if isinstance(term, SplinePairTerm): selector = CoefficientSelector(block=block.index, channel=term.pair) channels.append(_resolve_channel(model, selector)) elif isinstance(term, SplineTwoBodyTerm): for pair in _active_pair_channels(term): selector = CoefficientSelector(block=block.index, channel=pair) channels.append(_resolve_channel(model, selector)) elif isinstance(term, SplineThreeBodyTerm): for triplet in _active_triplet_channels(term): selector = CoefficientSelector(block=block.index, channel=triplet) channels.append(_resolve_channel(model, selector)) elif isinstance(term, SplineTriplet2DTerm): for triplet in _active_triplet_channels(term): selector = CoefficientSelector(block=block.index, channel=triplet) channels.append(_resolve_channel(model, selector)) return tuple(channels) def _read_channel(channel: CoefficientChannel) -> torch.Tensor: """Read one resolved channel view as a detached clone.""" return channel.block.read()[channel.index].detach().clone()
def read_coefficient_channel(
    model: UFPModel,
    selector: CoefficientSelector,
) -> torch.Tensor:
    """Return the values of one physical pair or triplet coefficient channel.

    The selector is resolved against ``model`` and the channel is read
    through the module's channel reader.
    """
    resolved = _resolve_channel(model, selector)
    return _read_channel(resolved)
def write_coefficient_channel(
    model: UFPModel,
    selector: CoefficientSelector,
    values: torch.Tensor,
) -> None:
    """Store ``values`` into one physical pair or triplet coefficient channel.

    The whole block tensor is cloned, the selected channel is overwritten with
    ``values`` (reshaped to the channel shape and cast to the block's dtype
    and device), and the result is written back through the block.
    """
    resolved = _resolve_channel(model, selector)
    owner = resolved.block
    buffer = owner.read().detach().clone()
    converted = values.reshape(resolved.value_shape).to(
        dtype=buffer.dtype,
        device=buffer.device,
    )
    buffer[resolved.index] = converted
    owner.write(buffer.reshape(owner.shape))
def _compatibility_reasons(
    source: CoefficientChannel,
    target: CoefficientChannel,
) -> tuple[str, ...]:
    """Collect every reason why ``source`` cannot be copied into ``target``.

    Checks run in a fixed order and all failures are reported; an empty tuple
    means the pair is compatible. Block labels and category ordering are only
    compared when both channels come from blocks of the same kind.
    """
    problems: list[str] = []
    same_kind = source.block_kind == target.block_kind

    if source.family != target.family:
        problems.append(f"term family differs: {source.family!r} != {target.family!r}")

    if same_kind and source.block_label != target.block_label:
        problems.append(
            f"block label differs: {source.block_label!r} != {target.block_label!r}"
        )

    if source.channel != target.channel:
        problems.append(
            f"channel identity differs: {source.channel!r} != {target.channel!r}"
        )

    if source.value_shape != target.value_shape:
        problems.append(
            f"coefficient shape differs: {source.value_shape} != {target.value_shape}"
        )

    source_dtype, target_dtype = source.dtype, target.dtype
    if source_dtype != target_dtype:
        problems.append(f"dtype differs: {source_dtype} != {target_dtype}")

    source_device, target_device = source.device, target.device
    if source_device != target_device:
        problems.append(f"device differs: {source_device} != {target_device}")

    if source.grid != target.grid:
        problems.append("spline/grid metadata differs")

    if source.cutoff != target.cutoff:
        problems.append("cutoff metadata differs")

    # Symmetry only matters for the pair family.
    if source.family == "pair" and source.symmetric != target.symmetric:
        problems.append(
            f"pair symmetry differs: {source.symmetric} != {target.symmetric}"
        )

    if source.channel not in source.active_channels:
        problems.append(f"source channel {source.channel!r} is inactive")

    if target.channel not in target.active_channels:
        problems.append(f"target channel {target.channel!r} is inactive")

    if same_kind and source.category_order != target.category_order:
        problems.append("category ordering differs")

    return tuple(problems)
def validate_coefficient_compatibility(
    source_model: UFPModel,
    target_model: UFPModel,
    selector: CoefficientSelector,
    *,
    target_selector: CoefficientSelector | None = None,
) -> CoefficientCompatibility:
    """Check that one source channel may be copied into one target channel.

    ``selector`` picks the source channel; the target channel is picked by
    ``target_selector`` when given and by ``selector`` otherwise.

    Raises:
        CoefficientCompatibilityError: If any incompatibility reason is found;
            the message joins all reasons with ``"; "``.
    """
    effective_target = selector if target_selector is None else target_selector
    source = _resolve_channel(source_model, selector)
    target = _resolve_channel(target_model, effective_target)
    result = CoefficientCompatibility(
        source=source,
        target=target,
        reasons=_compatibility_reasons(source, target),
    )
    if result.compatible:
        return result
    raise CoefficientCompatibilityError("; ".join(result.reasons))
def copy_coefficient_channel(
    source_model: UFPModel,
    target_model: UFPModel,
    selector: CoefficientSelector,
    *,
    target_selector: CoefficientSelector | None = None,
) -> CoefficientCopy:
    """Copy one compatible source coefficient channel into a target model.

    ``selector`` picks the source channel; the target channel is picked by
    ``target_selector`` when given and by ``selector`` otherwise.

    Returns:
        A summary of the copied channel.

    Raises:
        CoefficientCompatibilityError: If the two channels are incompatible.
        ValueError: If either selector does not resolve to exactly one channel.
    """
    compatibility = validate_coefficient_compatibility(
        source_model,
        target_model,
        selector,
        target_selector=target_selector,
    )
    source = compatibility.source
    target = compatibility.target
    # Write through the target channel validation already resolved, rather
    # than resolving the target selector a second time.
    _write_channel(target, _read_channel(source))
    return CoefficientCopy(
        family=source.family,
        channel=source.channel,
        source_block_label=source.block_label,
        target_block_label=target.block_label,
        shape=source.value_shape,
    )
def _channel_key(channel: CoefficientChannel) -> tuple[str, tuple[int, ...]]:
    """Build the ``(family, channel)`` identity used to match channels."""
    identity = (channel.family, channel.channel)
    return identity


def _unique_channel_map(
    channels: Iterable[CoefficientChannel],
    *,
    model_label: str,
) -> dict[tuple[str, tuple[int, ...]], CoefficientChannel]:
    """Index channels by identity key, rejecting ambiguous duplicates.

    Insertion order follows the order of ``channels``.

    Raises:
        CoefficientCompatibilityError: If two channels share one identity key;
            ``model_label`` names the offending model in the message.
    """
    by_key: dict[tuple[str, tuple[int, ...]], CoefficientChannel] = {}
    for candidate in channels:
        identity = _channel_key(candidate)
        if identity not in by_key:
            by_key[identity] = candidate
            continue
        raise CoefficientCompatibilityError(
            f"{model_label} has multiple coefficient channels for {identity!r}"
        )
    return by_key


def _write_channel(channel: CoefficientChannel, values: torch.Tensor) -> None:
    """Store ``values`` into a channel that has already been resolved.

    The block tensor is cloned, the channel region is overwritten with
    ``values`` (reshaped to the channel shape and cast to the block's dtype
    and device), and the full tensor is written back through the block.
    """
    owner = channel.block
    buffer = owner.read().detach().clone()
    converted = values.reshape(channel.value_shape).to(
        dtype=buffer.dtype,
        device=buffer.device,
    )
    buffer[channel.index] = converted
    owner.write(buffer.reshape(owner.shape))
def copy_matching_coefficients(
    source_model: UFPModel,
    target_model: UFPModel,
    *,
    strict: bool = True,
) -> CoefficientCopyReport:
    """Copy every matching compatible pair/triplet channel into ``target_model``.

    Channels are matched by their ``(family, channel)`` identity. Source
    channels without a target counterpart are always reported as skipped.

    Args:
        source_model: Model the coefficient values are read from.
        target_model: Model that receives the values in place.
        strict: When true, the first incompatible matching pair raises;
            when false, it is recorded in the report's ``skipped`` list.

    Returns:
        A report listing the copied channels and the skip messages.

    Raises:
        CoefficientCompatibilityError: If a model has duplicate channel
            identities, or (strict mode) a matching pair is incompatible.
            In the latter case ``target_model`` has not been modified.
    """
    source_channels = _unique_channel_map(
        _iter_model_channels(source_model),
        model_label="source model",
    )
    target_channels = _unique_channel_map(
        _iter_model_channels(target_model),
        model_label="target model",
    )

    # Pass 1: validate every pairing before touching the target, so a strict
    # failure cannot leave the target model partially overwritten.
    planned: list[tuple[CoefficientChannel, CoefficientChannel]] = []
    skipped: list[str] = []
    for key, source in source_channels.items():
        target = target_channels.get(key)
        if target is None:
            skipped.append(f"no target channel for {key!r}")
            continue
        reasons = _compatibility_reasons(source, target)
        if reasons:
            message = f"channel {key!r}: " + "; ".join(reasons)
            if strict:
                raise CoefficientCompatibilityError(message)
            skipped.append(message)
            continue
        planned.append((source, target))

    # Pass 2: perform the copies in source-channel order.
    copied: list[CoefficientCopy] = []
    for source, target in planned:
        _write_channel(target, _read_channel(source))
        copied.append(
            CoefficientCopy(
                family=source.family,
                channel=source.channel,
                source_block_label=source.block_label,
                target_block_label=target.block_label,
                shape=source.value_shape,
            )
        )
    return CoefficientCopyReport(copied=tuple(copied), skipped=tuple(skipped))
def zero_model_coefficients(model: UFPModel) -> UFPModel:
    """Overwrite every coefficient block of ``model`` with zeros in place.

    The blocks come from a parameter layout built with
    ``include_frozen=True``. The same ``model`` object is returned.
    """
    layout = ParameterLayout.from_model(model, include_frozen=True)
    for coefficient_block in layout.blocks:
        zeros = torch.zeros_like(coefficient_block.read())
        coefficient_block.write(zeros)
    return model
def clone_model_with_copied_coefficients(model: UFPModel) -> UFPModel:
    """Return an independent deep copy of ``model``, coefficients included."""
    duplicate = copy.deepcopy(model)
    return duplicate
def clone_model_with_zeroed_coefficients(model: UFPModel) -> UFPModel:
    """Return a deep copy of ``model`` whose coefficient blocks are all zero.

    The input model is left untouched; only the copy is zeroed.
    """
    # ``zero_model_coefficients`` zeroes in place and returns its argument.
    return zero_model_coefficients(clone_model_with_copied_coefficients(model))
# Public API of this module; underscore-prefixed helpers are not exported.
__all__ = [
    "CoefficientChannel",
    "CoefficientCompatibility",
    "CoefficientCompatibilityError",
    "CoefficientCopy",
    "CoefficientCopyReport",
    "clone_model_with_copied_coefficients",
    "clone_model_with_zeroed_coefficients",
    "copy_coefficient_channel",
    "copy_matching_coefficients",
    "read_coefficient_channel",
    "validate_coefficient_compatibility",
    "write_coefficient_channel",
    "zero_model_coefficients",
]