"""Coefficient-channel interchange helpers for compatible UFP models."""
from __future__ import annotations
import copy
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import cast
import torch
from ufp.leastsquares import CoefficientSelector
from ufp.leastsquares._layout import ParameterLayout, TermBlock
from ufp.leastsquares._selection import block_matches_selector
from ufp.terms.model import UFPModel
from ufp.terms.threebody import SplineThreeBodyTerm
from ufp.terms.triplet2d import SplineTriplet2DTerm
from ufp.terms.twobody import SplinePairTerm, SplineTwoBodyTerm
# Index into a coefficient block tensor: either trailing coefficient slices
# only, or a leading integer category index followed by those slices.
CoefficientIndex = tuple[int | slice, ...]
[docs]
class CoefficientCompatibilityError(ValueError):
    """Signal that two coefficient channels cannot be safely interchanged.

    Derives from :class:`ValueError`, so handlers catching ``ValueError`` also
    receive it.
    """
[docs]
@dataclass(frozen=True)
class CoefficientChannel:
    """One physical pair/triplet channel located inside a coefficient block.

    Metadata fields are captured when the channel is resolved; ``dtype`` and
    ``device`` are read from the owning block each time they are accessed.
    """

    # Owning coefficient block from the parameter layout.
    block: TermBlock
    # Term family tag assigned by the resolvers ("pair", "triplet", "triplet2d").
    family: str
    # Canonicalized atomic-number identity of the channel.
    channel: tuple[int, ...]
    # Index selecting this channel's values inside ``block.read()``.
    index: CoefficientIndex
    # Shape of the values selected by ``index``.
    value_shape: tuple[int, ...]
    # Spline/grid metadata compared between channels.
    grid: tuple[object, ...]
    # Cutoff metadata compared between channels.
    cutoff: tuple[object, ...]
    # Category channels of the owning term, in storage order.
    category_order: tuple[tuple[int, ...], ...]
    # Active category channels of the owning term.
    active_channels: tuple[tuple[int, ...], ...]
    # Pair symmetry flag; ``None`` for triplet families.
    symmetric: bool | None

    @property
    def block_label(self) -> str:
        """Label of the coefficient block that owns this channel."""
        owner = self.block
        return owner.label

    @property
    def block_kind(self) -> str:
        """Kind of the coefficient block that owns this channel."""
        owner = self.block
        return owner.kind

    @property
    def term_type(self) -> str:
        """Class name of the term behind the owning block."""
        term_class = type(self.block.term)
        return term_class.__name__

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the owning block's coefficient tensor."""
        values = self.block.read()
        return values.dtype

    @property
    def device(self) -> torch.device:
        """Device of the owning block's coefficient tensor."""
        values = self.block.read()
        return values.device
[docs]
@dataclass(frozen=True)
class CoefficientCompatibility:
    """Outcome of comparing one source channel against one target channel."""

    # Resolved channel the values would be read from.
    source: CoefficientChannel
    # Resolved channel the values would be written to.
    target: CoefficientChannel
    # Human-readable incompatibility reasons; empty when compatible.
    reasons: tuple[str, ...]

    @property
    def compatible(self) -> bool:
        """``True`` exactly when no incompatibility reason was recorded."""
        return len(self.reasons) == 0
[docs]
@dataclass(frozen=True)
class CoefficientCopy:
    """Summary of one copied coefficient channel.

    Instances are built by ``copy_coefficient_channel`` and
    ``copy_matching_coefficients``.
    """
    # Term family tag of the copied channel.
    family: str
    # Canonical atomic-number identity of the copied channel.
    channel: tuple[int, ...]
    # Label of the block the values were read from.
    source_block_label: str
    # Label of the block the values were written into.
    target_block_label: str
    # Value shape of the copied channel.
    shape: tuple[int, ...]
[docs]
@dataclass(frozen=True)
class CoefficientCopyReport:
    """Result of a bulk copy of matching coefficient channels."""

    # One summary per channel that was written into the target model.
    copied: tuple[CoefficientCopy, ...]
    # One human-readable reason per source channel that was not copied.
    skipped: tuple[str, ...]

    @property
    def copied_count(self) -> int:
        """How many channels were copied."""
        entries = self.copied
        return len(entries)
def _as_tuple_channel(channel: Sequence[int] | None) -> tuple[int, ...]:
    """Coerce a selector channel to a tuple of ints.

    Raises:
        ValueError: if ``channel`` is ``None``.
    """
    if channel is not None:
        return tuple(map(int, channel))
    raise ValueError("`selector.channel` is required for coefficient channels")
def _canonical_pair(term: object, channel: Sequence[int]) -> tuple[int, int]:
    """Canonicalize a two-element pair channel using ``term``'s symmetry flag.

    Terms without a ``symmetric`` attribute are treated as symmetric.

    Raises:
        ValueError: if ``channel`` does not have exactly two entries.
    """
    if len(channel) != 2:
        raise ValueError("pair channels must contain exactly two atomic numbers")
    from ufp.terms.twobody import _canonical_pair as canonical_pair

    first, second = (int(value) for value in channel)
    is_symmetric = bool(getattr(term, "symmetric", True))
    return canonical_pair(first, second, symmetric=is_symmetric)
def _canonical_triplet(channel: Sequence[int]) -> tuple[int, int, int]:
    """Canonicalize a three-element, source-distinguished triplet channel.

    Raises:
        ValueError: if ``channel`` does not have exactly three entries.
    """
    if len(channel) != 3:
        raise ValueError("triplet channels must contain exactly three atomic numbers")
    from ufp.terms._threebody_eval import _canonical_triplet as canonical_triplet

    first, second, third = (int(value) for value in channel)
    return canonical_triplet(first, second, third)
def _active_pair_channels(term: SplineTwoBodyTerm) -> tuple[tuple[int, int], ...]:
    """List the pair categories whose indices ``term`` marks as active.

    The result follows the order of ``term._active_pair_indices``.
    """
    categories = term.pair_categories
    selected: list[tuple[int, int]] = []
    for position in cast(tuple[int, ...], term._active_pair_indices):
        selected.append(categories[position])
    return tuple(selected)
def _active_triplet_channels(
    term: SplineThreeBodyTerm | SplineTriplet2DTerm,
) -> tuple[tuple[int, int, int], ...]:
    """List the triplet categories whose indices ``term`` marks as active.

    The result follows the order of ``term._active_triplet_indices``.
    """
    active = cast(tuple[int, ...], term._active_triplet_indices)
    lookup = term.triplet_categories
    return tuple(map(lookup.__getitem__, active))
def _required_cutoff(term: object) -> float:
    """Return ``term.cutoff`` converted to ``float``.

    Raises:
        ValueError: if the term has no ``cutoff`` attribute or it is ``None``.
    """
    raw = getattr(term, "cutoff", None)
    if raw is not None:
        return float(raw)
    raise ValueError(f"term {type(term).__name__} does not define a cutoff")
def _slice_shape(slices: tuple[slice, ...], shape: tuple[int, ...]) -> tuple[int, ...]:
"""Return the tensor shape selected by normalized unit-step slices."""
sizes = []
for item, dim in zip(slices, shape, strict=True):
start = 0 if item.start is None else int(item.start)
stop = int(dim) if item.stop is None else int(item.stop)
sizes.append(stop - start)
return tuple(sizes)
def _checked_span(start: int, stop: int, step: int, dim: int, *, kind: str) -> slice:
    """Validate a unit-step ``[start, stop)`` span against ``[0, dim)``.

    ``kind`` ("range" or "slice") is only used to word the error messages.
    """
    if int(step) != 1:
        raise ValueError(f"coefficient {kind}s must be contiguous")
    lower, upper = int(start), int(stop)
    # Negative bounds are rejected rather than counted from the end, and empty
    # spans are rejected as well.
    if lower < 0 or upper < 0 or lower > dim or upper > dim or upper <= lower:
        raise ValueError(
            f"coefficient {kind} [{lower}, {upper}) is outside [0, {dim})"
        )
    return slice(lower, upper)


def _normalize_slice_component(component: object, dim: int) -> slice:
    """Normalize one optional coefficient-slice component to a concrete slice.

    Accepted forms:

    * ``None`` -> ``slice(None)`` (the whole dimension).
    * ``range`` or ``slice`` with step 1 -> ``slice(start, stop)``; missing
      slice bounds default to ``0`` and ``dim``.
    * an integer-like scalar (``int``, ``bool``, or any object implementing
      ``__index__``, e.g. NumPy integers or 0-d integer tensors) ->
      the length-1 slice ``[index, index + 1)``.

    Raises:
        ValueError: if an index, range, or slice lies outside ``[0, dim)``,
            is empty, or has a step other than 1.
        TypeError: for any other component type.
    """
    import operator

    if component is None:
        return slice(None)
    if isinstance(component, range):
        return _checked_span(
            component.start, component.stop, component.step, dim, kind="range"
        )
    if isinstance(component, slice):
        step = 1 if component.step is None else component.step
        start = 0 if component.start is None else component.start
        stop = dim if component.stop is None else component.stop
        return _checked_span(start, stop, step, dim, kind="slice")
    try:
        # operator.index accepts exactly the integer-like types Python itself
        # allows as sequence indices, not only built-in ints.
        index = operator.index(component)
    except TypeError:
        raise TypeError(
            "`coeff_slice` entries must be int, slice, range, or None"
        ) from None
    if index < 0 or index >= dim:
        raise ValueError(f"coefficient index {index} is outside [0, {dim})")
    return slice(index, index + 1)
def _normalize_coeff_slices(
    coeff_slice: object,
    coeff_shape: tuple[int, ...],
) -> tuple[slice, ...]:
    """Expand ``coeff_slice`` into one normalized slice per coefficient dimension.

    ``None`` selects everything, a non-tuple value applies to the first
    coefficient dimension, and missing trailing entries select whole
    dimensions.

    Raises:
        ValueError: if ``coeff_slice`` has more entries than ``coeff_shape``
            has dimensions, or an entry is out of range.
        TypeError: if an entry has an unsupported type.
    """
    if isinstance(coeff_slice, tuple):
        parts: tuple[object, ...] = coeff_slice
    elif coeff_slice is None:
        parts = ()
    else:
        parts = (coeff_slice,)
    if len(parts) > len(coeff_shape):
        raise ValueError(
            "`coeff_slice` has more dimensions than the selected coefficient shape"
        )
    normalized: list[slice] = []
    for position, dim in enumerate(coeff_shape):
        part = parts[position] if position < len(parts) else None
        normalized.append(_normalize_slice_component(part, int(dim)))
    return tuple(normalized)
def _pair_spline_grid(
    term: SplinePairTerm | SplineTwoBodyTerm,
) -> tuple[object, ...]:
    """Return the 1-D spline/grid metadata compared between pair channels."""
    return (
        str(term.spline),
        float(term.full_support_start),
        float(term.first_knot),
        float(term.knot_spacing),
    )


def _pair_channel_location(
    block: TermBlock,
    selector: CoefficientSelector,
) -> CoefficientChannel | None:
    """Resolve a pair-family selector against one coefficient block.

    ``SplinePairTerm`` blocks hold a single pair and are indexed with trailing
    coefficient slices only. ``SplineTwoBodyTerm`` blocks carry a leading
    pair-category axis that is indexed first.

    Returns:
        The resolved channel, or ``None`` when ``block`` does not hold the
        requested pair or is not a pair-family block.

    Raises:
        ValueError: if ``selector.channel`` is missing or is not a pair, or
            ``selector.coeff_slice`` is invalid for the block.
        CoefficientCompatibilityError: if the block holds the pair but the
            channel is inactive.
    """
    term = block.term
    channel = _as_tuple_channel(selector.channel)
    if isinstance(term, SplinePairTerm):
        pair = _canonical_pair(term, channel)
        # Canonicalize the term's own pair too, so that a term declared with a
        # non-canonical order still matches a selector for the same pair.
        if pair != _canonical_pair(term, term.pair):
            return None
        if not term.enabled:
            raise CoefficientCompatibilityError(
                f"channel {pair!r} in block {block.label!r} is inactive"
            )
        coeff_shape = tuple(int(dim) for dim in block.shape)
        slices = _normalize_coeff_slices(selector.coeff_slice, coeff_shape)
        return CoefficientChannel(
            block=block,
            family="pair",
            channel=pair,
            index=slices,
            value_shape=_slice_shape(slices, coeff_shape),
            grid=_pair_spline_grid(term),
            cutoff=(_required_cutoff(term),),
            category_order=(pair,),
            active_channels=(pair,),
            symmetric=bool(term.symmetric),
        )
    if isinstance(term, SplineTwoBodyTerm):
        pair = _canonical_pair(term, channel)
        pair_index_map = cast(dict[tuple[int, int], int], term._pair_index)
        pair_index = pair_index_map.get(pair)
        if pair_index is None:
            return None
        if not term.is_pair_active(pair[0], pair[1]):
            raise CoefficientCompatibilityError(
                f"channel {pair!r} in block {block.label!r} is inactive"
            )
        # The first block axis enumerates pair categories; slices apply to the
        # remaining coefficient axes.
        coeff_shape = tuple(int(dim) for dim in block.shape[1:])
        slices = _normalize_coeff_slices(selector.coeff_slice, coeff_shape)
        return CoefficientChannel(
            block=block,
            family="pair",
            channel=pair,
            index=(int(pair_index),) + slices,
            value_shape=_slice_shape(slices, coeff_shape),
            grid=_pair_spline_grid(term),
            cutoff=(_required_cutoff(term),),
            category_order=tuple(tuple(item) for item in term.pair_categories),
            active_channels=tuple(
                tuple(item) for item in _active_pair_channels(term)
            ),
            symmetric=bool(term.symmetric),
        )
    return None
def _triplet_channel_location(
    block: TermBlock,
    selector: CoefficientSelector,
) -> CoefficientChannel | None:
    """Resolve a triplet-family selector against one coefficient block.

    ``SplineThreeBodyTerm`` blocks are either rank 3, with no category axis
    (matched only when the term's sole active triplet is the requested one),
    or carry a leading triplet-category axis. ``SplineTriplet2DTerm`` blocks
    always carry a leading triplet-category axis.

    Returns:
        The resolved channel, or ``None`` when ``block`` does not hold the
        requested triplet or is not a triplet-family block.

    Raises:
        ValueError: if ``selector.channel`` is missing or is not a triplet, or
            ``selector.coeff_slice`` is invalid for the block.
        CoefficientCompatibilityError: if the block holds the triplet but the
            channel is inactive.
    """
    term = block.term
    channel = _as_tuple_channel(selector.channel)
    if isinstance(term, SplineThreeBodyTerm):
        triplet = _canonical_triplet(channel)
        triplet_index_map = cast(
            dict[tuple[int, int, int], int],
            term._triplet_index,
        )
        triplet_index = triplet_index_map.get(triplet)
        if triplet_index is None:
            return None
        single_triplet_block = len(block.shape) == 3
        if single_triplet_block:
            active_indices = cast(tuple[int, ...], term._active_triplet_indices)
            active = tuple(int(index) for index in active_indices)
            # A rank-3 block has no category axis, so it only holds the
            # requested triplet when that is the term's sole active triplet.
            # This runs before the activity check so that a block owning a
            # different triplet reports "no match" (letting other blocks be
            # searched) instead of raising an "inactive" error.
            if active != (int(triplet_index),):
                return None
        active_triplets = _active_triplet_channels(term)
        if triplet not in active_triplets:
            raise CoefficientCompatibilityError(
                f"channel {triplet!r} in block {block.label!r} is inactive"
            )
        if single_triplet_block:
            coeff_shape = tuple(int(dim) for dim in block.shape)
            slices = _normalize_coeff_slices(selector.coeff_slice, coeff_shape)
            index: CoefficientIndex = slices
        else:
            # The first block axis enumerates triplet categories.
            coeff_shape = tuple(int(dim) for dim in block.shape[1:])
            slices = _normalize_coeff_slices(selector.coeff_slice, coeff_shape)
            index = (int(triplet_index),) + slices
        return CoefficientChannel(
            block=block,
            family="triplet",
            channel=triplet,
            index=index,
            value_shape=_slice_shape(slices, coeff_shape),
            grid=(
                str(term.spline),
                float(term.full_support_start_xy),
                float(term.full_support_start_z),
                float(term.first_knot_xy),
                float(term.first_knot_z),
                float(term.knot_spacing_xy),
                float(term.knot_spacing_z),
            ),
            cutoff=(_required_cutoff(term), float(term.neighbor_neighbor_cutoff)),
            category_order=tuple(tuple(item) for item in term.triplet_categories),
            active_channels=tuple(tuple(item) for item in active_triplets),
            symmetric=None,
        )
    if isinstance(term, SplineTriplet2DTerm):
        triplet = _canonical_triplet(channel)
        triplet_categories = term.triplet_categories
        if triplet not in triplet_categories:
            return None
        triplet_index = triplet_categories.index(triplet)
        active_triplets = _active_triplet_channels(term)
        if triplet not in active_triplets:
            raise CoefficientCompatibilityError(
                f"channel {triplet!r} in block {block.label!r} is inactive"
            )
        # The first block axis enumerates triplet categories.
        coeff_shape = tuple(int(dim) for dim in block.shape[1:])
        slices = _normalize_coeff_slices(selector.coeff_slice, coeff_shape)
        return CoefficientChannel(
            block=block,
            family="triplet2d",
            channel=triplet,
            index=(int(triplet_index),) + slices,
            value_shape=_slice_shape(slices, coeff_shape),
            grid=(
                str(term.spline),
                float(term.full_support_start),
                float(term.first_knot),
                float(term.knot_spacing),
            ),
            cutoff=(_required_cutoff(term),),
            category_order=tuple(tuple(item) for item in triplet_categories),
            active_channels=tuple(tuple(item) for item in active_triplets),
            symmetric=None,
        )
    return None
def _resolve_channel(
    model: UFPModel,
    selector: CoefficientSelector,
) -> CoefficientChannel:
    """Find the unique channel of ``model`` addressed by ``selector``.

    Every block accepted by ``selector.block`` is probed as a pair-family block
    first and, failing that, as a triplet-family block.

    Raises:
        ValueError: if no block, or more than one block, yields a channel.
        CoefficientCompatibilityError: propagated from the per-block resolvers
            when a located channel is inactive.
    """
    resolvers = (_pair_channel_location, _triplet_channel_location)
    found: list[CoefficientChannel] = []
    for block in ParameterLayout.from_model(model, include_frozen=True).blocks:
        if not block_matches_selector(block, selector.block):
            continue
        for locate in resolvers:
            located = locate(block, selector)
            if located is not None:
                found.append(located)
                break
    if not found:
        raise ValueError(f"selector {selector!r} did not match any coefficient channel")
    if len(found) == 1:
        return found[0]
    labels = ", ".join(item.block_label for item in found)
    raise ValueError(f"selector {selector!r} matched multiple channels: {labels}")
def _iter_model_channels(model: UFPModel) -> tuple[CoefficientChannel, ...]:
    """Return all active pair/triplet coefficient channels in layout order.

    The parameter layout is built once, and each active channel is resolved
    directly against the block that owns it. Blocks of other term types are
    ignored.

    Raises:
        ValueError: if an active channel cannot be resolved in its own block.
    """
    layout = ParameterLayout.from_model(model, include_frozen=True)
    channels: list[CoefficientChannel] = []
    for block in layout.blocks:
        term = block.term
        if isinstance(term, SplinePairTerm):
            wanted: Iterable[Sequence[int]] = (term.pair,)
            locate = _pair_channel_location
        elif isinstance(term, SplineTwoBodyTerm):
            wanted = _active_pair_channels(term)
            locate = _pair_channel_location
        elif isinstance(term, (SplineThreeBodyTerm, SplineTriplet2DTerm)):
            wanted = _active_triplet_channels(term)
            locate = _triplet_channel_location
        else:
            continue
        for channel in wanted:
            selector = CoefficientSelector(block=block.index, channel=channel)
            resolved = locate(block, selector)
            if resolved is None:
                raise ValueError(
                    f"selector {selector!r} did not match any coefficient channel"
                )
            channels.append(resolved)
    return tuple(channels)
def _read_channel(channel: CoefficientChannel) -> torch.Tensor:
    """Return a detached copy of the values that ``channel.index`` selects."""
    full = channel.block.read()
    view = full[channel.index]
    return view.detach().clone()
[docs]
def read_coefficient_channel(
    model: UFPModel,
    selector: CoefficientSelector,
) -> torch.Tensor:
    """Return a detached copy of one physical pair or triplet channel.

    Raises:
        ValueError: if ``selector`` does not resolve to exactly one channel
            (``CoefficientCompatibilityError`` when the channel is inactive).
    """
    resolved = _resolve_channel(model, selector)
    return _read_channel(resolved)
[docs]
def write_coefficient_channel(
    model: UFPModel,
    selector: CoefficientSelector,
    values: torch.Tensor,
) -> None:
    """Overwrite one physical pair or triplet coefficient channel in ``model``.

    ``values`` may be a tensor or any array-like accepted by
    ``torch.as_tensor``. It is converted to the owning block's dtype and
    device, reshaped to the channel's value shape, and written back through the
    block.

    Raises:
        ValueError: if ``selector`` does not resolve to exactly one channel
            (``CoefficientCompatibilityError`` when the channel is inactive).
        RuntimeError: if ``values`` does not hold exactly as many elements as
            the channel (raised by ``reshape``).
    """
    channel = _resolve_channel(model, selector)
    current = channel.block.read().detach().clone()
    # Convert straight to the block's dtype/device so that Python floats or
    # float64 arrays are not rounded through torch's default dtype first.
    incoming = torch.as_tensor(values, dtype=current.dtype, device=current.device)
    current[channel.index] = incoming.reshape(channel.value_shape)
    channel.block.write(current.reshape(channel.block.shape))
def _compatibility_reasons(
    source: CoefficientChannel,
    target: CoefficientChannel,
) -> tuple[str, ...]:
    """Collect every reason ``source`` values may not be copied into ``target``.

    Returns an empty tuple when the channels are interchangeable. Block labels
    and category ordering are only compared between blocks of the same kind,
    and pair symmetry is only compared for the ``"pair"`` family.
    """
    problems: list[str] = []
    note = problems.append
    same_kind = source.block_kind == target.block_kind

    if source.family != target.family:
        note(f"term family differs: {source.family!r} != {target.family!r}")
    if same_kind and source.block_label != target.block_label:
        note(
            f"block label differs: {source.block_label!r} != {target.block_label!r}"
        )
    if source.channel != target.channel:
        note(f"channel identity differs: {source.channel!r} != {target.channel!r}")
    if source.value_shape != target.value_shape:
        note(
            f"coefficient shape differs: {source.value_shape} != {target.value_shape}"
        )
    source_dtype, target_dtype = source.dtype, target.dtype
    if source_dtype != target_dtype:
        note(f"dtype differs: {source_dtype} != {target_dtype}")
    source_device, target_device = source.device, target.device
    if source_device != target_device:
        note(f"device differs: {source_device} != {target_device}")
    if source.grid != target.grid:
        note("spline/grid metadata differs")
    if source.cutoff != target.cutoff:
        note("cutoff metadata differs")
    if source.family == "pair" and source.symmetric != target.symmetric:
        note(f"pair symmetry differs: {source.symmetric} != {target.symmetric}")
    for role, resolved in (("source", source), ("target", target)):
        if resolved.channel not in resolved.active_channels:
            note(f"{role} channel {resolved.channel!r} is inactive")
    if same_kind and source.category_order != target.category_order:
        note("category ordering differs")
    return tuple(problems)
[docs]
def validate_coefficient_compatibility(
    source_model: UFPModel,
    target_model: UFPModel,
    selector: CoefficientSelector,
    *,
    target_selector: CoefficientSelector | None = None,
) -> CoefficientCompatibility:
    """Check that one source channel can be copied into one target channel.

    ``target_selector`` defaults to ``selector`` when omitted.

    Returns:
        The (compatible) comparison result with both resolved channels.

    Raises:
        CoefficientCompatibilityError: listing every incompatibility reason,
            joined by ``"; "``.
        ValueError: if either selector does not resolve to exactly one channel.
    """
    chosen_target = selector if target_selector is None else target_selector
    source = _resolve_channel(source_model, selector)
    target = _resolve_channel(target_model, chosen_target)
    reasons = _compatibility_reasons(source, target)
    if reasons:
        raise CoefficientCompatibilityError("; ".join(reasons))
    return CoefficientCompatibility(source=source, target=target, reasons=reasons)
[docs]
def copy_coefficient_channel(
    source_model: UFPModel,
    target_model: UFPModel,
    selector: CoefficientSelector,
    *,
    target_selector: CoefficientSelector | None = None,
) -> CoefficientCopy:
    """Copy one compatible source coefficient channel into ``target_model``.

    ``target_selector`` defaults to ``selector`` when omitted.

    Returns:
        A summary of the copied channel.

    Raises:
        CoefficientCompatibilityError: if the channels are incompatible.
        ValueError: if either selector does not resolve to exactly one channel.
    """
    compatibility = validate_coefficient_compatibility(
        source_model,
        target_model,
        selector,
        target_selector=target_selector,
    )
    source, target = compatibility.source, compatibility.target
    # Write through the already-validated target channel rather than
    # re-resolving the selector, which would rebuild the target layout.
    _write_channel(target, _read_channel(source))
    return CoefficientCopy(
        family=source.family,
        channel=source.channel,
        source_block_label=source.block_label,
        target_block_label=target.block_label,
        shape=source.value_shape,
    )
def _channel_key(channel: CoefficientChannel) -> tuple[str, tuple[int, ...]]:
    """Identity used to pair channels across models: ``(family, channel)``."""
    key = (channel.family, channel.channel)
    return key
def _unique_channel_map(
    channels: Iterable[CoefficientChannel],
    *,
    model_label: str,
) -> dict[tuple[str, tuple[int, ...]], CoefficientChannel]:
    """Index ``channels`` by their copy-matching key, insisting on uniqueness.

    Insertion order of ``channels`` is preserved.

    Raises:
        CoefficientCompatibilityError: if two channels share a key;
            ``model_label`` names the offending model in the message.
    """
    by_key: dict[tuple[str, tuple[int, ...]], CoefficientChannel] = {}
    for channel in channels:
        key = _channel_key(channel)
        if key not in by_key:
            by_key[key] = channel
            continue
        raise CoefficientCompatibilityError(
            f"{model_label} has multiple coefficient channels for {key!r}"
        )
    return by_key
def _write_channel(channel: CoefficientChannel, values: torch.Tensor) -> None:
    """Store ``values`` into the slot of an already resolved channel.

    The owning block is read and copied, and the copy is patched at
    ``channel.index``. The copy is then written back whole. ``values`` is
    reshaped to ``channel.value_shape`` and converted to the block's dtype and
    device first.
    """
    snapshot = channel.block.read()
    patched = snapshot.detach().clone()
    converted = values.reshape(channel.value_shape).to(
        dtype=patched.dtype,
        device=patched.device,
    )
    patched[channel.index] = converted
    channel.block.write(patched.reshape(channel.block.shape))
[docs]
def copy_matching_coefficients(
    source_model: UFPModel,
    target_model: UFPModel,
    *,
    strict: bool = True,
) -> CoefficientCopyReport:
    """Copy every matching compatible pair/triplet channel into ``target_model``.

    Channels are matched by ``(family, channel)`` identity. A source channel
    with no counterpart in the target is always skipped. An incompatible pair
    of channels is handled according to ``strict``:

    * ``strict=True``: raise ``CoefficientCompatibilityError``. This happens
      before any coefficient in ``target_model`` is modified.
    * ``strict=False``: skip the channel and report the reason.

    Returns:
        A report of the copied channels and the skip reasons, both in the
        iteration order of the source channels.

    Raises:
        CoefficientCompatibilityError: on duplicate channel identities in
            either model, or (when ``strict``) on the first incompatible pair.
    """
    source_channels = _unique_channel_map(
        _iter_model_channels(source_model),
        model_label="source model",
    )
    target_channels = _unique_channel_map(
        _iter_model_channels(target_model),
        model_label="target model",
    )

    # Validate every channel pair before writing anything, so a strict-mode
    # failure leaves the target model untouched instead of half-overwritten.
    planned: list[tuple[CoefficientChannel, CoefficientChannel]] = []
    skipped: list[str] = []
    for key, source in source_channels.items():
        target = target_channels.get(key)
        if target is None:
            skipped.append(f"no target channel for {key!r}")
            continue
        reasons = _compatibility_reasons(source, target)
        if reasons:
            message = f"channel {key!r}: " + "; ".join(reasons)
            if strict:
                raise CoefficientCompatibilityError(message)
            skipped.append(message)
            continue
        planned.append((source, target))

    copied: list[CoefficientCopy] = []
    for source, target in planned:
        _write_channel(target, _read_channel(source))
        copied.append(
            CoefficientCopy(
                family=source.family,
                channel=source.channel,
                source_block_label=source.block_label,
                target_block_label=target.block_label,
                shape=source.value_shape,
            )
        )
    return CoefficientCopyReport(copied=tuple(copied), skipped=tuple(skipped))
[docs]
def zero_model_coefficients(model: UFPModel) -> UFPModel:
    """Overwrite every coefficient block of ``model`` with zeros, in place.

    Frozen blocks are included, because the layout is built with
    ``include_frozen=True``. Returns ``model`` itself so calls can be chained.
    """
    for block in ParameterLayout.from_model(model, include_frozen=True).blocks:
        zeros = torch.zeros_like(block.read())
        block.write(zeros)
    return model
[docs]
def clone_model_with_copied_coefficients(model: UFPModel) -> UFPModel:
    """Return a ``copy.deepcopy`` of ``model``.

    Both the architecture and the coefficient values are duplicated.
    """
    duplicate = copy.deepcopy(model)
    return duplicate
[docs]
def clone_model_with_zeroed_coefficients(model: UFPModel) -> UFPModel:
    """Return a deep copy of ``model`` whose coefficient blocks are all zero.

    The input ``model`` is not modified; only the copy is zeroed.
    """
    duplicate = clone_model_with_copied_coefficients(model)
    zero_model_coefficients(duplicate)
    return duplicate
# Public API of this module; underscore-prefixed helpers are intentionally
# not exported.
__all__ = [
    "CoefficientChannel",
    "CoefficientCompatibility",
    "CoefficientCompatibilityError",
    "CoefficientCopy",
    "CoefficientCopyReport",
    "clone_model_with_copied_coefficients",
    "clone_model_with_zeroed_coefficients",
    "copy_coefficient_channel",
    "copy_matching_coefficients",
    "read_coefficient_channel",
    "validate_coefficient_compatibility",
    "write_coefficient_channel",
    "zero_model_coefficients",
]