"""Coefficient selection helpers for least-squares fitting."""
from __future__ import annotations

import operator
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Literal

from ufp.leastsquares._layout import ParameterLayout, TermBlock
[docs]
@dataclass(frozen=True)
class CoefficientSelector:
"""
Select a subset of one or more coefficient blocks.
``block`` uses the existing block selector semantics: block index, block name,
kind, label, or regularization group. ``channel`` optionally selects a pair or
triplet channel inside categorized coefficient tensors. ``coeff_slice`` selects
trailing coefficient dimensions inside the selected channel or channels.
"""
block: int | str
channel: tuple[int, ...] | None = None
coeff_slice: Any = None
BlockSelector = int | str | CoefficientSelector
@dataclass(frozen=True)
class SelectedCoefficientBlock:
    """Compact solve metadata for the coefficients kept from one term block.

    ``indices`` are flattened positions inside ``block`` (sorted when built by
    ``resolve_coefficient_selection_summary``); ``shape`` is the shape the
    compact solve uses for them.
    """

    block: TermBlock
    indices: tuple[int, ...]
    shape: tuple[int, ...]

    @property
    def key(self) -> int:
        """Solve key for this selection: the original block index as ``int``."""
        return int(self.block.index)

    @property
    def size(self) -> int:
        """Number of coefficients this selection contributes to the solve."""
        return len(self.indices)

    @property
    def is_full_block(self) -> bool:
        """``True`` when every coefficient of ``block`` is selected, in order."""
        if self.size != self.block.size:
            return False
        return self.indices == tuple(range(self.block.size))
@dataclass(frozen=True)
class ResolvedCoefficientSelectionEntry:
    """Inspectable metadata for one resolved fit or freeze selection entry.

    Each entry covers the coefficients of a single layout block that either
    remain in the compact solve (``source == "fit"``) or are held fixed
    (``source == "freeze"``).
    """

    # Layout block this entry describes.
    block: TermBlock
    # Whether the coefficients are fit or frozen.
    source: Literal["fit", "freeze"]
    # Flattened positions inside ``block``.
    original_indices: tuple[int, ...]
    # The same positions offset into the full parameter layout.
    layout_indices: tuple[int, ...]
    # Span inside the compact solve vector; ``None`` for frozen entries.
    compact_slice: slice | None
    # Full shape of ``block``.
    block_shape: tuple[int, ...]
    # Per contributing selector resolution: channel, normalized slices, and the
    # selector itself (``None`` for implicit full-block resolutions).
    channels: tuple[tuple[int, ...] | None, ...]
    coefficient_slices: tuple[tuple[slice, ...], ...]
    selectors: tuple[BlockSelector | None, ...]

    @property
    def block_index(self) -> int:
        """Index of ``block`` in the parameter layout, as ``int``."""
        return int(self.block.index)

    @property
    def block_label(self) -> str:
        """Label of ``block`` in the parameter layout."""
        return self.block.label

    @property
    def size(self) -> int:
        """Count of original coefficients this entry covers."""
        return len(self.original_indices)

    @property
    def is_fit(self) -> bool:
        """``True`` when these coefficients feed the compact solve."""
        return self.source == "fit"

    @property
    def is_freeze(self) -> bool:
        """``True`` when these coefficients are excluded from the compact solve."""
        return self.source == "freeze"
@dataclass(frozen=True)
class ResolvedCoefficientSelection:
    """Resolved fit/freeze selector metadata together with compact solve blocks."""

    entries: tuple[ResolvedCoefficientSelectionEntry, ...]
    selected_blocks: tuple[SelectedCoefficientBlock, ...]

    @property
    def fit_entries(self) -> tuple[ResolvedCoefficientSelectionEntry, ...]:
        """Entries whose coefficients stay in the compact solve, in order."""
        return tuple(filter(lambda entry: entry.is_fit, self.entries))

    @property
    def freeze_entries(self) -> tuple[ResolvedCoefficientSelectionEntry, ...]:
        """Entries whose coefficients are held fixed outside the solve, in order."""
        return tuple(filter(lambda entry: entry.is_freeze, self.entries))

    @property
    def selected_block_labels(self) -> tuple[str, ...]:
        """Block labels of the fit entries, in entry order."""
        labels = [entry.block_label for entry in self.fit_entries]
        return tuple(labels)
@dataclass(frozen=True)
class _SelectorResolution:
"""Internal metadata for one selector resolved against one block."""
indices: tuple[int, ...]
channel: tuple[int, ...] | None
coefficient_slices: tuple[slice, ...]
selector: BlockSelector | None
def block_matches_selector(block: TermBlock, selector: BlockSelector) -> bool:
    """Decide whether ``block`` is addressed by one include/exclude selector.

    A :class:`CoefficientSelector` is first reduced to its ``block`` part.
    Integer keys compare against the block index. Any other key is converted
    with ``str`` and compared against the block's index (as text), name, kind,
    label, and regularization group.
    """
    key = selector.block if isinstance(selector, CoefficientSelector) else selector
    if not isinstance(key, int):
        candidates = {
            str(block.index),
            block.name,
            block.kind,
            block.label,
            block.regularization_group,
        }
        return str(key) in candidates
    return int(block.index) == int(key)
def _as_tuple_channel(
channel: tuple[int, ...] | Sequence[int] | None,
) -> tuple[int, ...] | None:
"""Normalize an optional channel key."""
if channel is None:
return None
return tuple(int(value) for value in channel)
def _normalize_slice_component(component: object, dim: int) -> slice:
"""Normalize one coefficient selector component into a non-empty unit slice."""
if component is None:
return slice(None)
if isinstance(component, int):
index = int(component)
if index < 0 or index >= dim:
raise ValueError(f"coefficient index {index} is outside [0, {int(dim)})")
return slice(index, index + 1)
if isinstance(component, range):
step = 1 if component.step is None else int(component.step)
if step != 1:
raise ValueError("coefficient ranges must be contiguous")
start = int(component.start)
stop = int(component.stop)
if start < 0 or stop < 0 or start > dim or stop > dim:
raise ValueError(
f"coefficient range [{start}, {stop}) is outside [0, {int(dim)})"
)
if stop <= start:
raise ValueError("coefficient ranges must select at least one entry")
return slice(start, stop)
if isinstance(component, slice):
step = 1 if component.step is None else int(component.step)
if step != 1:
raise ValueError("coefficient slices must be contiguous")
start = 0 if component.start is None else int(component.start)
stop = dim if component.stop is None else int(component.stop)
if start < 0 or stop < 0 or start > dim or stop > dim:
raise ValueError(
f"coefficient slice [{start}, {stop}) is outside [0, {int(dim)})"
)
if stop <= start:
raise ValueError("coefficient slices must select at least one entry")
return slice(start, stop)
raise TypeError("`coeff_slice` entries must be int, slice, range, or None")
def _normalize_coeff_slices(
coeff_slice: object,
coeff_shape: tuple[int, ...],
) -> tuple[slice, ...]:
"""Normalize a trailing coefficient slice tuple for a coefficient shape."""
if coeff_slice is None:
components: tuple[object, ...] = ()
elif isinstance(coeff_slice, tuple):
components = coeff_slice
else:
components = (coeff_slice,)
if len(components) > len(coeff_shape):
raise ValueError(
"`coeff_slice` has more dimensions than the selected coefficient shape"
)
padded = components + (None,) * (len(coeff_shape) - len(components))
return tuple(
_normalize_slice_component(component, int(dim))
for component, dim in zip(padded, coeff_shape, strict=True)
)
def _normalize_selector_coeff_slices(
    block: TermBlock,
    selector: CoefficientSelector,
    coeff_shape: tuple[int, ...],
) -> tuple[slice, ...]:
    """Normalize ``selector.coeff_slice`` against ``coeff_shape``.

    Any ``TypeError`` or ``ValueError`` is re-raised as the same exception type.
    The selector and block label are appended to the message, and the new
    exception is chained to the original.
    """
    try:
        normalized = _normalize_coeff_slices(selector.coeff_slice, coeff_shape)
    except (TypeError, ValueError) as exc:
        context = f" for selector {selector!r} in block {block.label!r}"
        raise type(exc)(f"{exc}{context}") from exc
    return normalized
def _full_slices(shape: tuple[int, ...]) -> tuple[slice, ...]:
"""Return explicit full slices for a coefficient shape."""
return tuple(slice(None) for _ in shape)
def _flat_indices_from_mask(mask) -> tuple[int, ...]:
"""Return flattened true entries from a boolean tensor mask."""
import torch
return tuple(
int(value)
for value in torch.nonzero(mask.reshape(-1), as_tuple=False)
.reshape(-1)
.detach()
.cpu()
.tolist()
)
def _make_mask(shape: tuple[int, ...]):
"""Create a boolean mask for one coefficient block."""
import torch
return torch.zeros(shape, dtype=torch.bool)
def _indices_for_mask(mask) -> tuple[int, ...]:
    """Return the flat coefficient indices marked ``True`` in a selector mask.

    Thin alias over ``_flat_indices_from_mask``, so selector code reads in
    terms of masks.
    """
    selected = _flat_indices_from_mask(mask)
    return selected
def _canonical_pair_for_term(term, channel: tuple[int, ...]) -> tuple[int, int]:
"""Normalize a two-body channel for a pair-like term."""
if len(channel) != 2:
raise ValueError("pair channels must contain exactly two atomic numbers")
from ufp.terms.twobody import _canonical_pair
return _canonical_pair(
int(channel[0]),
int(channel[1]),
symmetric=bool(getattr(term, "symmetric", True)),
)
def _canonical_triplet_channel(channel: tuple[int, ...]) -> tuple[int, int, int]:
"""Normalize a three-body channel key."""
if len(channel) != 3:
raise ValueError("triplet channels must contain exactly three atomic numbers")
from ufp.terms._threebody_eval import _canonical_triplet
return _canonical_triplet(int(channel[0]), int(channel[1]), int(channel[2]))
def _selector_resolution_for_block(
    block: TermBlock,
    selector: CoefficientSelector,
) -> _SelectorResolution | None:
    """Resolve one coefficient selector against one block.

    Returns ``None`` in two cases: the selector's ``block`` part does not match,
    or the requested channel is absent from (or inactive in) the block's term.
    Otherwise returns the flattened indices picked by the ``channel`` and
    ``coeff_slice`` parts, together with the canonical channel and normalized
    slices.

    Without a channel, ``coeff_slice`` applies to the coefficient dimensions of
    every channel for blocks that carry a leading channel axis. For all other
    blocks it applies to the whole block shape.

    Raises:
        ValueError: If the channel has the wrong arity or fails to canonicalize,
            the term type does not support channels, or ``coeff_slice`` is out
            of bounds for the selected coefficient shape.
        TypeError: If ``coeff_slice`` contains unsupported entries.
    """
    if not block_matches_selector(block, selector.block):
        return None
    raw_channel = _as_tuple_channel(selector.channel)
    shape = tuple(int(dim) for dim in block.shape)
    mask = _make_mask(shape)
    term = block.term
    from ufp.terms.threebody import SplineThreeBodyTerm
    from ufp.terms.triplet2d import SplineTriplet2DTerm
    from ufp.terms.twobody import SplinePairTerm, SplineTwoBodyTerm

    def canonical(normalize, *args):
        """Call a channel normalizer, adding selector/block context on failure."""
        try:
            return normalize(*args)
        except ValueError as exc:
            raise ValueError(
                f"{exc} for selector {selector!r} in block {block.label!r}"
            ) from exc

    def resolved(channel, slices):
        """Package the filled-in ``mask`` as resolution metadata."""
        return _SelectorResolution(
            indices=_indices_for_mask(mask),
            channel=channel,
            coefficient_slices=slices,
            selector=selector,
        )

    if raw_channel is None:
        has_channel_axis = (
            block.kind in {"twobody", "threebody", "triplet2d"} and len(shape) > 1
        )
        if isinstance(term, SplineThreeBodyTerm) and len(shape) == 3:
            # A 3-D three-body block stores only the single active triplet's
            # coefficients, with no leading channel axis. The channel branch
            # below treats it the same way, so coeff_slice covers the full shape.
            has_channel_axis = False
        if has_channel_axis:
            # Apply the same coefficient slices to every channel.
            slices = _normalize_selector_coeff_slices(block, selector, shape[1:])
            mask[(slice(None),) + slices] = True
        else:
            slices = _normalize_selector_coeff_slices(block, selector, shape)
            mask[slices] = True
        return resolved(None, slices)

    if isinstance(term, SplinePairTerm):
        # The block holds coefficients for ``term.pair`` only.
        pair = canonical(_canonical_pair_for_term, term, raw_channel)
        if pair != tuple(int(value) for value in term.pair):
            return None
        slices = _normalize_selector_coeff_slices(block, selector, shape)
        mask[slices] = True
        return resolved(pair, slices)

    if isinstance(term, SplineTwoBodyTerm):
        # Leading axis enumerates pairs; ``_pair_index`` maps pair -> row.
        pair = canonical(_canonical_pair_for_term, term, raw_channel)
        pair_index = getattr(term, "_pair_index", {}).get(pair)
        if pair_index is None:
            return None
        slices = _normalize_selector_coeff_slices(block, selector, shape[1:])
        mask[(int(pair_index),) + slices] = True
        return resolved(pair, slices)

    if isinstance(term, SplineThreeBodyTerm):
        triplet = canonical(_canonical_triplet_channel, raw_channel)
        triplet_index = getattr(term, "_triplet_index", {}).get(triplet)
        if triplet_index is None:
            return None
        if len(shape) == 3:
            # Collapsed layout: only matches when this triplet is the sole
            # active one.
            active = tuple(int(index) for index in term._active_triplet_indices)
            if active != (int(triplet_index),):
                return None
            slices = _normalize_selector_coeff_slices(block, selector, shape)
            mask[slices] = True
        else:
            slices = _normalize_selector_coeff_slices(block, selector, shape[1:])
            mask[(int(triplet_index),) + slices] = True
        return resolved(triplet, slices)

    if isinstance(term, SplineTriplet2DTerm):
        # Leading axis enumerates ``triplet_categories`` in order.
        triplet = canonical(_canonical_triplet_channel, raw_channel)
        triplet_categories = getattr(term, "triplet_categories", ())
        triplet_index = (
            triplet_categories.index(triplet) if triplet in triplet_categories else None
        )
        if triplet_index is None:
            return None
        slices = _normalize_selector_coeff_slices(block, selector, shape[1:])
        mask[(int(triplet_index),) + slices] = True
        return resolved(triplet, slices)

    raise ValueError(
        f"selector {selector!r} requested channel-level coefficient selection "
        f"for block {block.label!r}, but term type {type(term).__name__} "
        "does not support channels"
    )
def _coefficient_selector_no_match_message(
selector: CoefficientSelector,
matched_blocks: Sequence[TermBlock],
) -> str:
"""Return a detailed validation error for a coefficient selector miss."""
if not matched_blocks:
return f"selector {selector!r} did not match any coefficient block"
block_labels = ", ".join(block.label for block in matched_blocks)
if selector.channel is None:
return (
f"selector {selector!r} did not select any coefficients in "
f"matched block(s): {block_labels}"
)
channel = _as_tuple_channel(selector.channel)
return (
f"selector {selector!r} requested channel {channel!r}, but no matching "
f"channel exists in matched block(s): {block_labels}"
)
def coefficient_indices_for_selector(
    block: TermBlock,
    selector: BlockSelector,
) -> tuple[int, ...]:
    """Return the flattened coefficient indices ``selector`` picks from ``block``.

    Plain ``int``/``str`` selectors take the whole block or nothing. A
    :class:`CoefficientSelector` is resolved by channel and slice. An empty
    tuple means the selector does not apply to this block.

    Raises:
        ValueError: If a coefficient selector applies to the block but selects
            no coefficients, or fails validation while resolving.
    """
    if isinstance(selector, CoefficientSelector):
        resolution = _selector_resolution_for_block(block, selector)
        if resolution is None:
            return ()
        if resolution.indices:
            return resolution.indices
        raise ValueError(f"selector did not select any coefficients in {block.label!r}")
    if block_matches_selector(block, selector):
        return tuple(range(block.size))
    return ()
def _resolutions_for_selector(
    layout: ParameterLayout,
    selector: BlockSelector,
) -> dict[int, tuple[_SelectorResolution, ...]]:
    """Resolve one selector against every block of ``layout``.

    Returns a map from block index to a one-element tuple of resolution
    metadata for each block the selector touches. Plain ``int``/``str``
    selectors select whole blocks and may match nothing. A
    :class:`CoefficientSelector` that selects nothing anywhere raises.

    Raises:
        ValueError: If a coefficient selector selects no coefficients in the
            whole layout, or fails validation against a block.
    """
    if isinstance(selector, CoefficientSelector):
        hits: dict[int, tuple[_SelectorResolution, ...]] = {}
        for block in layout.blocks:
            if not block_matches_selector(block, selector.block):
                continue
            resolution = _selector_resolution_for_block(block, selector)
            if resolution is not None:
                hits[int(block.index)] = (resolution,)
        if hits:
            return hits
        matched_blocks = tuple(
            candidate
            for candidate in layout.blocks
            if block_matches_selector(candidate, selector.block)
        )
        raise ValueError(
            _coefficient_selector_no_match_message(selector, matched_blocks)
        )

    whole: dict[int, tuple[_SelectorResolution, ...]] = {}
    for block in layout.blocks:
        indices = coefficient_indices_for_selector(block, selector)
        if not indices:
            continue
        whole[int(block.index)] = (
            _SelectorResolution(
                indices=indices,
                channel=None,
                coefficient_slices=_full_slices(block.shape),
                selector=selector,
            ),
        )
    return whole
def _union_resolutions(
target: dict[int, list[_SelectorResolution]],
source: dict[int, tuple[_SelectorResolution, ...]],
) -> None:
"""Append resolved selector metadata into a mutable block-index map."""
for block_index, resolutions in source.items():
target.setdefault(int(block_index), []).extend(resolutions)
def _union_indices(resolutions: Sequence[_SelectorResolution]) -> set[int]:
"""Return the union of selected indices from resolution metadata."""
return {int(index) for resolution in resolutions for index in resolution.indices}
def _selection_shape(block: TermBlock, indices: tuple[int, ...]) -> tuple[int, ...]:
"""Return the solve shape used for one selected coefficient block."""
if indices == tuple(range(block.size)):
return block.shape
return (len(indices),)
def _implicit_block_resolution(block: TermBlock) -> _SelectorResolution:
    """Build the selector-less, whole-block resolution.

    Used for the default fit set and for blocks whose term is already frozen.
    """
    every_index = tuple(range(block.size))
    return _SelectorResolution(
        indices=every_index,
        channel=None,
        coefficient_slices=_full_slices(block.shape),
        selector=None,
    )
def _selection_entry(
    block: TermBlock,
    *,
    source: Literal["fit", "freeze"],
    indices: tuple[int, ...],
    compact_slice: slice | None,
    resolutions: Sequence[_SelectorResolution],
) -> ResolvedCoefficientSelectionEntry:
    """Assemble the public summary entry for one block from its resolutions.

    ``indices`` are block-local flat positions. They are also reported shifted
    by ``block.start`` as layout positions. Per-resolution channel, slice, and
    selector metadata is carried through in resolution order.
    """
    channels = tuple(res.channel for res in resolutions)
    slices = tuple(res.coefficient_slices for res in resolutions)
    selectors = tuple(res.selector for res in resolutions)
    return ResolvedCoefficientSelectionEntry(
        block=block,
        source=source,
        original_indices=indices,
        layout_indices=tuple(block.start + int(position) for position in indices),
        compact_slice=compact_slice,
        block_shape=tuple(int(dim) for dim in block.shape),
        channels=channels,
        coefficient_slices=slices,
        selectors=selectors,
    )
def _fit_resolution_map(
    layout: ParameterLayout,
    fit_blocks: Sequence[BlockSelector] | None,
) -> dict[int, list[_SelectorResolution]]:
    """Collect fit resolutions: explicit selectors, or every non-frozen block."""
    fit_map: dict[int, list[_SelectorResolution]] = {}
    if fit_blocks is not None:
        for selector in fit_blocks:
            _union_resolutions(fit_map, _resolutions_for_selector(layout, selector))
        return fit_map
    for block in layout.blocks:
        if not block.frozen:
            fit_map[int(block.index)] = [_implicit_block_resolution(block)]
    return fit_map


def _freeze_resolution_map(
    layout: ParameterLayout,
    freeze_blocks: Sequence[BlockSelector],
) -> dict[int, list[_SelectorResolution]]:
    """Collect freeze resolutions: frozen-term blocks plus explicit selectors."""
    freeze_map: dict[int, list[_SelectorResolution]] = {}
    for block in layout.blocks:
        if block.frozen:
            freeze_map[int(block.index)] = [_implicit_block_resolution(block)]
    for selector in freeze_blocks:
        _union_resolutions(freeze_map, _resolutions_for_selector(layout, selector))
    return freeze_map


def resolve_coefficient_selection_summary(
    layout: ParameterLayout,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
) -> ResolvedCoefficientSelection:
    """Resolve fit/freeze selectors into inspectable compact-selection metadata.

    With ``fit_blocks=None`` every non-frozen block is fit in full. Coefficients
    of frozen blocks and those picked by ``freeze_blocks`` are removed from the
    fit set. Compact slices are laid out contiguously in layout block order.
    Within a block, the freeze entry (if any) precedes the fit entry.
    """
    fit_map = _fit_resolution_map(layout, fit_blocks)
    freeze_map = _freeze_resolution_map(layout, freeze_blocks)

    entries: list[ResolvedCoefficientSelectionEntry] = []
    kept_blocks: list[SelectedCoefficientBlock] = []
    cursor = 0
    for block in layout.blocks:
        key = int(block.index)

        freeze_res = tuple(freeze_map.get(key, ()))
        frozen_set = _union_indices(freeze_res)
        if frozen_set:
            entries.append(
                _selection_entry(
                    block,
                    source="freeze",
                    indices=tuple(sorted(frozen_set)),
                    compact_slice=None,
                    resolutions=freeze_res,
                )
            )

        fit_res = tuple(fit_map.get(key, ()))
        remaining = _union_indices(fit_res) - frozen_set
        if not remaining:
            continue
        kept = tuple(sorted(remaining))
        width = len(kept)
        entries.append(
            _selection_entry(
                block,
                source="fit",
                indices=kept,
                compact_slice=slice(cursor, cursor + width),
                resolutions=fit_res,
            )
        )
        kept_blocks.append(
            SelectedCoefficientBlock(
                block=block,
                indices=kept,
                shape=_selection_shape(block, kept),
            )
        )
        cursor += width

    return ResolvedCoefficientSelection(
        entries=tuple(entries),
        selected_blocks=tuple(kept_blocks),
    )
def resolve_coefficient_selection(
    layout: ParameterLayout,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
) -> tuple[SelectedCoefficientBlock, ...]:
    """Return only the compact selected blocks from the full selection summary."""
    summary = resolve_coefficient_selection_summary(
        layout,
        fit_blocks=fit_blocks,
        freeze_blocks=freeze_blocks,
    )
    return summary.selected_blocks
def selected_block_indices(
    layout: ParameterLayout,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
) -> tuple[int, ...]:
    """Return the layout block indices that take part in assembly and solving."""
    chosen = resolve_coefficient_selection(
        layout,
        fit_blocks=fit_blocks,
        freeze_blocks=freeze_blocks,
    )
    block_ids = [int(selection.block.index) for selection in chosen]
    return tuple(block_ids)
__all__ = [
    # Selector and result types.
    "BlockSelector",
    "CoefficientSelector",
    "ResolvedCoefficientSelection",
    "ResolvedCoefficientSelectionEntry",
    "SelectedCoefficientBlock",
    # Resolution functions.
    "block_matches_selector",
    "coefficient_indices_for_selector",
    "resolve_coefficient_selection",
    "resolve_coefficient_selection_summary",
    "selected_block_indices",
]