# Source code for ufp.leastsquares._selection

"""Coefficient selection helpers for least-squares fitting."""

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Literal

from ufp.leastsquares._layout import ParameterLayout, TermBlock


@dataclass(frozen=True)
class CoefficientSelector:
    """
    Select a subset of one or more coefficient blocks.

    ``block`` uses the existing block selector semantics: block index, block
    name, kind, label, or regularization group. ``channel`` optionally selects
    a pair or triplet channel inside categorized coefficient tensors.
    ``coeff_slice`` selects trailing coefficient dimensions inside the selected
    channel or channels.

    Instances are immutable (``frozen=True``), so a selector can be shared
    between fit and freeze lists without risk of being modified in place.
    """

    # Block-level selector: an integer block index, or a string compared
    # against the block's index, name, kind, label and regularization group.
    block: int | str
    # Optional channel key (two entries for a pair, three for a triplet).
    # ``None`` means "do not restrict to a single channel".
    channel: tuple[int, ...] | None = None
    # Optional coefficient sub-selection. The normalizers in this module accept
    # an int, a contiguous ``slice`` or ``range``, ``None``, or a tuple of
    # those (one entry per coefficient dimension).
    coeff_slice: Any = None
# A selector is either a plain block selector (index or string) that always
# covers a whole block, or a ``CoefficientSelector`` that can narrow the
# selection to a channel and/or a coefficient sub-range.
BlockSelector = int | str | CoefficientSelector


@dataclass(frozen=True)
class SelectedCoefficientBlock:
    """Compact solve metadata for selected coefficients from one term block."""

    # Parameter-layout block the coefficients were taken from.
    block: TermBlock
    # Flat (block-local) coefficient indices kept in the compact solve.
    indices: tuple[int, ...]
    # Shape used for the solve: the original block shape for a full selection,
    # otherwise a 1-D ``(len(indices),)`` shape (see ``_selection_shape``).
    shape: tuple[int, ...]

    @property
    def key(self) -> int:
        """Return the solve key for this selected block."""
        return int(self.block.index)

    @property
    def size(self) -> int:
        """Return the compact coefficient count."""
        return len(self.indices)

    @property
    def is_full_block(self) -> bool:
        """Return whether all coefficients from the original block are selected."""
        return self.size == self.block.size and self.indices == tuple(
            range(self.block.size)
        )


@dataclass(frozen=True)
class ResolvedCoefficientSelectionEntry:
    """Inspectable metadata for one resolved fit or freeze selection entry."""

    # Parameter-layout block this entry refers to.
    block: TermBlock
    # Whether the entry keeps coefficients in the solve or removes them.
    source: Literal["fit", "freeze"]
    # Sorted block-local flat indices covered by this entry.
    original_indices: tuple[int, ...]
    # The same indices offset by ``block.start``.
    layout_indices: tuple[int, ...]
    # Position of the entry inside the compact solve vector; ``None`` for
    # freeze entries.
    compact_slice: slice | None
    # Shape of the original block.
    block_shape: tuple[int, ...]
    # One item per contributing selector resolution, in resolution order.
    channels: tuple[tuple[int, ...] | None, ...]
    coefficient_slices: tuple[tuple[slice, ...], ...]
    selectors: tuple[BlockSelector | None, ...]

    @property
    def block_index(self) -> int:
        """Return the original parameter-layout block index."""
        return int(self.block.index)

    @property
    def block_label(self) -> str:
        """Return the original parameter-layout block label."""
        return self.block.label

    @property
    def size(self) -> int:
        """Return the number of original coefficients covered by this entry."""
        return len(self.original_indices)

    @property
    def is_fit(self) -> bool:
        """Return whether this entry contributes compact solve coefficients."""
        return self.source == "fit"

    @property
    def is_freeze(self) -> bool:
        """Return whether this entry removes coefficients from the compact solve."""
        return self.source == "freeze"


@dataclass(frozen=True)
class ResolvedCoefficientSelection:
    """Resolved fit/freeze selector metadata and compact solve blocks."""

    # Freeze and fit entries in layout-block order (freeze before fit for the
    # same block).
    entries: tuple[ResolvedCoefficientSelectionEntry, ...]
    # One item per block that still has coefficients in the compact solve.
    selected_blocks: tuple[SelectedCoefficientBlock, ...]

    @property
    def fit_entries(self) -> tuple[ResolvedCoefficientSelectionEntry, ...]:
        """Return entries that remain in the compact solve."""
        return tuple(entry for entry in self.entries if entry.is_fit)

    @property
    def freeze_entries(self) -> tuple[ResolvedCoefficientSelectionEntry, ...]:
        """Return entries that are fixed outside the compact solve."""
        return tuple(entry for entry in self.entries if entry.is_freeze)

    @property
    def selected_block_labels(self) -> tuple[str, ...]:
        """Return labels for blocks that retain selected coefficients."""
        return tuple(entry.block_label for entry in self.fit_entries)


@dataclass(frozen=True)
class _SelectorResolution:
    """Internal metadata for one selector resolved against one block."""

    # Block-local flat indices selected by the selector.
    indices: tuple[int, ...]
    # Canonicalized channel key, or ``None`` when no channel was requested.
    channel: tuple[int, ...] | None
    # Normalized per-dimension coefficient slices that were applied.
    coefficient_slices: tuple[slice, ...]
    # The selector that produced this resolution; ``None`` for implicit
    # (default fit / frozen-term) resolutions.
    selector: BlockSelector | None


def block_matches_selector(block: TermBlock, selector: BlockSelector) -> bool:
    """Return whether a parameter block matches one include/exclude selector.

    A ``CoefficientSelector`` is matched through its ``block`` attribute only.
    Integer selectors are compared with the block index; any other selector is
    converted to ``str`` and compared against the block's index (as a string),
    name, kind, label and regularization group.
    """
    if isinstance(selector, CoefficientSelector):
        selector = selector.block
    if isinstance(selector, int):
        return int(block.index) == int(selector)
    value = str(selector)
    return value in {
        str(block.index),
        block.name,
        block.kind,
        block.label,
        block.regularization_group,
    }


def _as_tuple_channel(
    channel: tuple[int, ...] | Sequence[int] | None,
) -> tuple[int, ...] | None:
    """Normalize an optional channel key."""
    if channel is None:
        return None
    return tuple(int(value) for value in channel)


def _normalize_slice_component(component: object, dim: int) -> slice:
    """Normalize one coefficient selector component into a non-empty unit slice.

    Args:
        component: ``None`` (whole dimension), an ``int`` index, a ``range`` or
            a ``slice``. Ranges and slices must have step 1.
        dim: Size of the coefficient dimension being indexed.

    Returns:
        ``slice(None)`` for ``None``; otherwise an explicit
        ``slice(start, stop)`` with ``0 <= start < stop <= dim``.

    Raises:
        ValueError: If the component is out of bounds, not contiguous, or empty.
            Negative indices and bounds are rejected rather than wrapped.
        TypeError: If the component is not one of the supported types.
    """
    if component is None:
        return slice(None)
    if isinstance(component, int):
        index = int(component)
        if index < 0 or index >= dim:
            raise ValueError(f"coefficient index {index} is outside [0, {int(dim)})")
        return slice(index, index + 1)
    if isinstance(component, range):
        step = 1 if component.step is None else int(component.step)
        if step != 1:
            raise ValueError("coefficient ranges must be contiguous")
        start = int(component.start)
        stop = int(component.stop)
        if start < 0 or stop < 0 or start > dim or stop > dim:
            raise ValueError(
                f"coefficient range [{start}, {stop}) is outside [0, {int(dim)})"
            )
        if stop <= start:
            raise ValueError("coefficient ranges must select at least one entry")
        return slice(start, stop)
    if isinstance(component, slice):
        step = 1 if component.step is None else int(component.step)
        if step != 1:
            raise ValueError("coefficient slices must be contiguous")
        # Open-ended bounds default to the full extent of the dimension.
        start = 0 if component.start is None else int(component.start)
        stop = dim if component.stop is None else int(component.stop)
        if start < 0 or stop < 0 or start > dim or stop > dim:
            raise ValueError(
                f"coefficient slice [{start}, {stop}) is outside [0, {int(dim)})"
            )
        if stop <= start:
            raise ValueError("coefficient slices must select at least one entry")
        return slice(start, stop)
    raise TypeError("`coeff_slice` entries must be int, slice, range, or None")


def _normalize_coeff_slices(
    coeff_slice: object,
    coeff_shape: tuple[int, ...],
) -> tuple[slice, ...]:
    """Normalize a trailing coefficient slice tuple for a coefficient shape.

    A non-tuple ``coeff_slice`` is treated as a single component. Missing
    trailing components are padded with ``None`` (select the whole dimension),
    so the result always has one slice per entry of ``coeff_shape``.

    Raises:
        ValueError: If more components are given than ``coeff_shape`` has
            dimensions, or a component is invalid for its dimension.
        TypeError: If a component has an unsupported type.
    """
    if coeff_slice is None:
        components: tuple[object, ...] = ()
    elif isinstance(coeff_slice, tuple):
        components = coeff_slice
    else:
        components = (coeff_slice,)
    if len(components) > len(coeff_shape):
        raise ValueError(
            "`coeff_slice` has more dimensions than the selected coefficient shape"
        )
    padded = components + (None,) * (len(coeff_shape) - len(components))
    return tuple(
        _normalize_slice_component(component, int(dim))
        for component, dim in zip(padded, coeff_shape, strict=True)
    )


def _normalize_selector_coeff_slices(
    block: TermBlock,
    selector: CoefficientSelector,
    coeff_shape: tuple[int, ...],
) -> tuple[slice, ...]:
    """Normalize coefficient slices and add selector/block context to failures."""
    try:
        return _normalize_coeff_slices(selector.coeff_slice, coeff_shape)
    except (TypeError, ValueError) as exc:
        # Re-raise the same exception type so callers' ``except`` clauses keep
        # working, but say which selector and block were involved.
        raise type(exc)(
            f"{exc} for selector {selector!r} in block {block.label!r}"
        ) from exc


def _full_slices(shape: tuple[int, ...]) -> tuple[slice, ...]:
    """Return explicit full slices for a coefficient shape."""
    return tuple(slice(None) for _ in shape)


def _flat_indices_from_mask(mask) -> tuple[int, ...]:
    """Return flattened true entries from a boolean tensor mask."""
    import torch

    return tuple(
        int(value)
        for value in torch.nonzero(mask.reshape(-1), as_tuple=False)
        .reshape(-1)
        .detach()
        .cpu()
        .tolist()
    )


def _make_mask(shape: tuple[int, ...]):
    """Create a boolean mask for one coefficient block."""
    import torch

    return torch.zeros(shape, dtype=torch.bool)


def _indices_for_mask(mask) -> tuple[int, ...]:
    """Return selected flat indices from one selector mask."""
    return _flat_indices_from_mask(mask)


def _canonical_pair_for_term(term, channel: tuple[int, ...]) -> tuple[int, int]:
    """Normalize a two-body channel for a pair-like term.

    Raises:
        ValueError: If ``channel`` does not contain exactly two entries.
    """
    if len(channel) != 2:
        raise ValueError("pair channels must contain exactly two atomic numbers")
    from ufp.terms.twobody import _canonical_pair

    return _canonical_pair(
        int(channel[0]),
        int(channel[1]),
        # Terms without a ``symmetric`` attribute are treated as symmetric.
        symmetric=bool(getattr(term, "symmetric", True)),
    )


def _canonical_triplet_channel(channel: tuple[int, ...]) -> tuple[int, int, int]:
    """Normalize a three-body channel key.

    Raises:
        ValueError: If ``channel`` does not contain exactly three entries.
    """
    if len(channel) != 3:
        raise ValueError("triplet channels must contain exactly three atomic numbers")
    from ufp.terms._threebody_eval import _canonical_triplet

    return _canonical_triplet(int(channel[0]), int(channel[1]), int(channel[2]))


def _selector_resolution_for_block(
    block: TermBlock,
    selector: CoefficientSelector,
) -> _SelectorResolution | None:
    """Return resolved metadata for one coefficient selector, or ``None``.

    ``None`` is returned when the selector's block part does not match
    ``block`` or when the requested channel is not present in the block's term.

    Raises:
        ValueError: If the channel has the wrong length for the term, a
            coefficient slice is invalid, or a channel is requested for a term
            type that does not support channels.
        TypeError: If a ``coeff_slice`` component has an unsupported type.
    """
    if not block_matches_selector(block, selector.block):
        return None
    raw_channel = _as_tuple_channel(selector.channel)
    shape = tuple(int(dim) for dim in block.shape)
    mask = _make_mask(shape)
    term = block.term

    from ufp.terms.threebody import SplineThreeBodyTerm
    from ufp.terms.triplet2d import SplineTriplet2DTerm
    from ufp.terms.twobody import SplinePairTerm, SplineTwoBodyTerm

    if raw_channel is None:
        # No channel requested: apply the coefficient slices to every channel.
        # NOTE(review): this treats axis 0 as the channel axis for any
        # twobody/threebody/triplet2d block with more than one dimension,
        # whereas the channel branches below treat SplinePairTerm blocks and
        # rank-3 SplineThreeBodyTerm blocks as having no channel axis -
        # confirm that those blocks cannot reach this branch with rank > 1.
        if block.kind in {"twobody", "threebody", "triplet2d"} and len(shape) > 1:
            coeff_shape = shape[1:]
            slices = _normalize_selector_coeff_slices(block, selector, coeff_shape)
            mask[(slice(None),) + slices] = True
        else:
            slices = _normalize_selector_coeff_slices(block, selector, shape)
            mask[slices] = True
        indices = _indices_for_mask(mask)
        return _SelectorResolution(
            indices=indices,
            channel=None,
            coefficient_slices=slices,
            selector=selector,
        )

    if isinstance(term, SplinePairTerm):
        try:
            pair = _canonical_pair_for_term(term, raw_channel)
        except ValueError as exc:
            raise ValueError(
                f"{exc} for selector {selector!r} in block {block.label!r}"
            ) from exc
        # The term holds exactly one pair; any other channel is simply a miss.
        if pair != tuple(int(value) for value in term.pair):
            return None
        slices = _normalize_selector_coeff_slices(block, selector, shape)
        mask[slices] = True
        return _SelectorResolution(
            indices=_indices_for_mask(mask),
            channel=pair,
            coefficient_slices=slices,
            selector=selector,
        )

    if isinstance(term, SplineTwoBodyTerm):
        try:
            pair = _canonical_pair_for_term(term, raw_channel)
        except ValueError as exc:
            raise ValueError(
                f"{exc} for selector {selector!r} in block {block.label!r}"
            ) from exc
        pair_index = getattr(term, "_pair_index", {}).get(pair)
        if pair_index is None:
            return None
        # Axis 0 indexes the pair channel; slices apply to the remaining axes.
        slices = _normalize_selector_coeff_slices(block, selector, shape[1:])
        mask[(int(pair_index),) + slices] = True
        return _SelectorResolution(
            indices=_indices_for_mask(mask),
            channel=pair,
            coefficient_slices=slices,
            selector=selector,
        )

    if isinstance(term, SplineThreeBodyTerm):
        try:
            triplet = _canonical_triplet_channel(raw_channel)
        except ValueError as exc:
            raise ValueError(
                f"{exc} for selector {selector!r} in block {block.label!r}"
            ) from exc
        triplet_index = getattr(term, "_triplet_index", {}).get(triplet)
        if triplet_index is None:
            return None
        if len(shape) == 3:
            # A rank-3 block is handled as a block without a channel axis: it
            # only matches when the requested triplet is the term's single
            # active triplet, and the slices then cover the whole block shape.
            active = tuple(int(index) for index in term._active_triplet_indices)
            if active != (int(triplet_index),):
                return None
            slices = _normalize_selector_coeff_slices(block, selector, shape)
            mask[slices] = True
        else:
            slices = _normalize_selector_coeff_slices(block, selector, shape[1:])
            mask[(int(triplet_index),) + slices] = True
        return _SelectorResolution(
            indices=_indices_for_mask(mask),
            channel=triplet,
            coefficient_slices=slices,
            selector=selector,
        )

    if isinstance(term, SplineTriplet2DTerm):
        try:
            triplet = _canonical_triplet_channel(raw_channel)
        except ValueError as exc:
            raise ValueError(
                f"{exc} for selector {selector!r} in block {block.label!r}"
            ) from exc
        triplet_categories = getattr(term, "triplet_categories", ())
        triplet_index = (
            triplet_categories.index(triplet)
            if triplet in triplet_categories
            else None
        )
        if triplet_index is None:
            return None
        slices = _normalize_selector_coeff_slices(block, selector, shape[1:])
        mask[(int(triplet_index),) + slices] = True
        return _SelectorResolution(
            indices=_indices_for_mask(mask),
            channel=triplet,
            coefficient_slices=slices,
            selector=selector,
        )

    raise ValueError(
        f"selector {selector!r} requested channel-level coefficient selection "
        f"for block {block.label!r}, but term type {type(term).__name__} "
        "does not support channels"
    )


def _coefficient_selector_no_match_message(
    selector: CoefficientSelector,
    matched_blocks: Sequence[TermBlock],
) -> str:
    """Return a detailed validation error for a coefficient selector miss.

    The message distinguishes three cases: no block matched at all, blocks
    matched but nothing was selected (no channel requested), and blocks matched
    but none of them contains the requested channel.
    """
    if not matched_blocks:
        return f"selector {selector!r} did not match any coefficient block"
    block_labels = ", ".join(block.label for block in matched_blocks)
    if selector.channel is None:
        return (
            f"selector {selector!r} did not select any coefficients in "
            f"matched block(s): {block_labels}"
        )
    channel = _as_tuple_channel(selector.channel)
    return (
        f"selector {selector!r} requested channel {channel!r}, but no matching "
        f"channel exists in matched block(s): {block_labels}"
    )


def coefficient_indices_for_selector(
    block: TermBlock,
    selector: BlockSelector,
) -> tuple[int, ...]:
    """Return flattened coefficient indices selected from one block.

    Plain (int/str) selectors select the whole block when they match. An empty
    tuple means the selector does not apply to this block.

    Raises:
        ValueError: If a ``CoefficientSelector`` resolves against the block but
            selects no coefficients, or if resolving it fails validation.
    """
    if not isinstance(selector, CoefficientSelector):
        if not block_matches_selector(block, selector):
            return ()
        return tuple(range(block.size))
    resolution = _selector_resolution_for_block(block, selector)
    if resolution is None:
        return ()
    if not resolution.indices:
        raise ValueError(f"selector did not select any coefficients in {block.label!r}")
    return resolution.indices


def _resolutions_for_selector(
    layout: ParameterLayout,
    selector: BlockSelector,
) -> dict[int, tuple[_SelectorResolution, ...]]:
    """Resolve one block or coefficient selector across a layout.

    Returns:
        A mapping from block index to a one-element tuple holding the
        resolution for that block. Blocks the selector does not apply to are
        absent.

    Raises:
        ValueError: If a ``CoefficientSelector`` resolves to nothing in the
            whole layout. A plain selector that matches nothing yields an empty
            mapping instead.
    """
    selected: dict[int, tuple[_SelectorResolution, ...]] = {}
    for block in layout.blocks:
        if not isinstance(selector, CoefficientSelector):
            indices = coefficient_indices_for_selector(block, selector)
            if indices:
                selected[int(block.index)] = (
                    _SelectorResolution(
                        indices=indices,
                        channel=None,
                        coefficient_slices=_full_slices(block.shape),
                        selector=selector,
                    ),
                )
            continue
        if not block_matches_selector(block, selector.block):
            continue
        resolution = _selector_resolution_for_block(block, selector)
        if resolution is not None:
            selected[int(block.index)] = (resolution,)
    if isinstance(selector, CoefficientSelector) and not selected:
        matched_blocks = tuple(
            block
            for block in layout.blocks
            if block_matches_selector(block, selector.block)
        )
        raise ValueError(
            _coefficient_selector_no_match_message(selector, matched_blocks)
        )
    return selected


def _union_resolutions(
    target: dict[int, list[_SelectorResolution]],
    source: dict[int, tuple[_SelectorResolution, ...]],
) -> None:
    """Append resolved selector metadata into a mutable block-index map."""
    for block_index, resolutions in source.items():
        target.setdefault(int(block_index), []).extend(resolutions)


def _union_indices(resolutions: Sequence[_SelectorResolution]) -> set[int]:
    """Return the union of selected indices from resolution metadata."""
    return {int(index) for resolution in resolutions for index in resolution.indices}


def _selection_shape(block: TermBlock, indices: tuple[int, ...]) -> tuple[int, ...]:
    """Return the solve shape used for one selected coefficient block."""
    # A full, in-order selection keeps the block's own shape; any partial
    # selection is flattened to one dimension.
    if indices == tuple(range(block.size)):
        return block.shape
    return (len(indices),)


def _implicit_block_resolution(block: TermBlock) -> _SelectorResolution:
    """Return a full-block resolution for default fit or frozen-term handling."""
    return _SelectorResolution(
        indices=tuple(range(block.size)),
        channel=None,
        coefficient_slices=_full_slices(block.shape),
        selector=None,
    )


def _selection_entry(
    block: TermBlock,
    *,
    source: Literal["fit", "freeze"],
    indices: tuple[int, ...],
    compact_slice: slice | None,
    resolutions: Sequence[_SelectorResolution],
) -> ResolvedCoefficientSelectionEntry:
    """Build one public summary entry from resolved selector metadata."""
    return ResolvedCoefficientSelectionEntry(
        block=block,
        source=source,
        original_indices=indices,
        layout_indices=tuple(block.start + int(index) for index in indices),
        compact_slice=compact_slice,
        block_shape=tuple(int(dim) for dim in block.shape),
        channels=tuple(resolution.channel for resolution in resolutions),
        coefficient_slices=tuple(
            resolution.coefficient_slices for resolution in resolutions
        ),
        selectors=tuple(resolution.selector for resolution in resolutions),
    )


def resolve_coefficient_selection_summary(
    layout: ParameterLayout,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
) -> ResolvedCoefficientSelection:
    """Resolve fit/freeze selectors into inspectable compact-selection metadata.

    Args:
        layout: Parameter layout whose ``blocks`` are resolved.
        fit_blocks: Selectors for coefficients to fit. ``None`` selects every
            block that is not marked ``frozen``.
        freeze_blocks: Selectors for coefficients to hold fixed. Blocks marked
            ``frozen`` in the layout are always frozen in full.

    Returns:
        The freeze/fit entries in layout-block order plus the compact selected
        blocks. Frozen coefficients take precedence: they are removed from the
        fit selection of the same block.

    Raises:
        ValueError: If a ``CoefficientSelector`` matches nothing or is invalid.
        TypeError: If a ``coeff_slice`` component has an unsupported type.
    """
    selected: dict[int, list[_SelectorResolution]] = {}
    if fit_blocks is None:
        for block in layout.blocks:
            if block.frozen:
                continue
            selected[int(block.index)] = [_implicit_block_resolution(block)]
    else:
        for selector in fit_blocks:
            _union_resolutions(selected, _resolutions_for_selector(layout, selector))

    frozen: dict[int, list[_SelectorResolution]] = {}
    for block in layout.blocks:
        if block.frozen:
            frozen[int(block.index)] = [_implicit_block_resolution(block)]
    for selector in freeze_blocks:
        _union_resolutions(frozen, _resolutions_for_selector(layout, selector))

    entries: list[ResolvedCoefficientSelectionEntry] = []
    blocks: list[SelectedCoefficientBlock] = []
    # Running offset of the next fit entry inside the compact solve vector.
    compact_offset = 0
    for block in layout.blocks:
        block_index = int(block.index)
        frozen_resolutions = tuple(frozen.get(block_index, ()))
        frozen_indices = _union_indices(frozen_resolutions)
        if frozen_indices:
            frozen_ordered = tuple(sorted(frozen_indices))
            entries.append(
                _selection_entry(
                    block,
                    source="freeze",
                    indices=frozen_ordered,
                    compact_slice=None,
                    resolutions=frozen_resolutions,
                )
            )
        selected_resolutions = tuple(selected.get(block_index, ()))
        selected_indices = _union_indices(selected_resolutions)
        if not selected_indices:
            continue
        # Freezing wins over fitting for coefficients named by both.
        selected_indices = selected_indices - frozen_indices
        if not selected_indices:
            continue
        selected_ordered = tuple(sorted(selected_indices))
        compact_slice = slice(compact_offset, compact_offset + len(selected_ordered))
        entries.append(
            _selection_entry(
                block,
                source="fit",
                indices=selected_ordered,
                compact_slice=compact_slice,
                resolutions=selected_resolutions,
            )
        )
        compact_offset += len(selected_ordered)
        blocks.append(
            SelectedCoefficientBlock(
                block=block,
                indices=selected_ordered,
                shape=_selection_shape(block, selected_ordered),
            )
        )
    return ResolvedCoefficientSelection(
        entries=tuple(entries),
        selected_blocks=tuple(blocks),
    )


def resolve_coefficient_selection(
    layout: ParameterLayout,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
) -> tuple[SelectedCoefficientBlock, ...]:
    """Resolve fit/freeze selectors into compact selected coefficient blocks."""
    return resolve_coefficient_selection_summary(
        layout,
        fit_blocks=fit_blocks,
        freeze_blocks=freeze_blocks,
    ).selected_blocks


def selected_block_indices(
    layout: ParameterLayout,
    *,
    fit_blocks: Sequence[BlockSelector] | None = None,
    freeze_blocks: Sequence[BlockSelector] = (),
) -> tuple[int, ...]:
    """Return layout block indices included in direct assembly and solving."""
    return tuple(
        int(selection.block.index)
        for selection in resolve_coefficient_selection(
            layout,
            fit_blocks=fit_blocks,
            freeze_blocks=freeze_blocks,
        )
    )


__all__ = [
    "BlockSelector",
    "CoefficientSelector",
    "ResolvedCoefficientSelection",
    "ResolvedCoefficientSelectionEntry",
    "SelectedCoefficientBlock",
    "block_matches_selector",
    "coefficient_indices_for_selector",
    "resolve_coefficient_selection",
    "resolve_coefficient_selection_summary",
    "selected_block_indices",
]