# Source code for ufp.training.force_selection

"""Force-component selection helpers for supervised training subsets."""

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ForceComponent:
    """
    Immutable record describing one selected Cartesian force component.

    Each instance pins down a single scalar entry of one structure's force
    array and keeps both its signed value and its magnitude.
    """

    # Position of the owning structure in the source force sequence.
    dataset_index: int
    # Row of the entry inside that structure's ``(n_atoms, 3)`` force array.
    atom_index: int
    # Column (Cartesian axis) of the entry inside that force array.
    component_index: int
    # Signed force value of the entry.
    force: float
    # Magnitude of ``force``; this is the quantity components are ranked by.
    abs_force: float
@dataclass(frozen=True)
class ForceComponentSelection:
    """
    Outcome of a force-component selection.

    Bundles the chosen structure indices, one boolean mask per chosen
    structure, the chosen components in ranked order, and the limits that
    were requested when the selection was made.
    """

    # Indices of the structures that contribute at least one component.
    selected_indices: np.ndarray
    # One boolean array per entry of ``selected_indices``, in the same order.
    force_masks: tuple[np.ndarray, ...]
    # The chosen components, ordered by rank.
    ranked_components: tuple[ForceComponent, ...]
    # Upper bound on the force magnitude that was eligible for selection.
    max_force: float
    # Requested number of structures, or ``None`` when no limit was given.
    requested_configurations: int | None
    # Requested number of components, or ``None`` when no limit was given.
    requested_components: int | None

    @property
    def n_selected_configurations(self) -> int:
        """Number of structures present in ``selected_indices``."""
        return int(self.selected_indices.size)

    @property
    def n_selected_components(self) -> int:
        """Number of entries in ``ranked_components``."""
        return len(self.ranked_components)

    def as_metadata(self, *, prefix: str = "force_selection") -> dict[str, object]:
        """
        Summarise the selection as a flat dictionary of plain Python values.

        Parameters
        ----------
        prefix
            Text placed before every key, joined with an underscore. An
            empty string produces unprefixed keys.

        Returns
        -------
        dict[str, object]
            The requested limits, the selected counts, the selected indices
            as a list of ``int``, and the smallest and largest selected
            force magnitude (both ``None`` when nothing was selected).
        """
        magnitudes = [item.abs_force for item in self.ranked_components]
        if magnitudes:
            smallest: float | None = float(min(magnitudes))
            largest: float | None = float(max(magnitudes))
        else:
            smallest = None
            largest = None

        # Insertion order here fixes the key order of the returned mapping.
        payload: dict[str, object] = {
            "max_force": float(self.max_force),
            "requested_configurations": self.requested_configurations,
            "requested_components": self.requested_components,
            "selected_configurations": self.n_selected_configurations,
            "selected_components": self.n_selected_components,
            "selected_indices": list(map(int, self.selected_indices.tolist())),
            "min_abs_force": smallest,
            "max_abs_force": largest,
        }

        key_prefix = f"{prefix}_" if prefix != "" else ""
        return {f"{key_prefix}{name}": value for name, value in payload.items()}
@dataclass(frozen=True)
class _CandidateComponent:
    """
    Private bookkeeping record for a force component under consideration.

    Holds the same information as the public ``ForceComponent`` plus the
    position of the owning force array in the selector's working list.
    """

    # Position of the owning structure in the source force sequence.
    dataset_index: int
    # Position of the owning array in the selector's internal array list.
    force_array_index: int
    # Row of the entry inside the force array.
    atom_index: int
    # Column of the entry inside the force array.
    component_index: int
    # Signed force value of the entry.
    force: float
    # Magnitude of ``force``.
    abs_force: float

    def public(self) -> ForceComponent:
        """Build the public ``ForceComponent``, omitting ``force_array_index``."""
        return ForceComponent(
            dataset_index=self.dataset_index,
            atom_index=self.atom_index,
            component_index=self.component_index,
            force=self.force,
            abs_force=self.abs_force,
        )
class LargestForceComponentSelector:
    """
    Select high-force Cartesian components for focused force training.

    The selector ranks absolute force components from largest to smallest,
    excludes components above ``max_force``, and returns boolean masks aligned
    with the selected structures.

    Which components are kept depends on the limits that were supplied:

    * ``n_components`` only: the ``n_components`` highest-ranked eligible
      components, regardless of which structure they belong to.
    * ``n_configurations`` only: the single highest-ranked eligible component
      of each of the first ``n_configurations`` distinct structures met in
      rank order (one component per structure).
    * both: one component per structure is seeded as above for up to
      ``min(n_configurations, n_components)`` structures, then further
      components *from those structures only* are added in rank order until
      ``n_components`` components are selected.

    Ties in absolute force are broken by dataset index, then atom index, then
    component index, all ascending.
    """

    def __init__(
        self,
        *,
        max_force: float,
        n_configurations: int | None = None,
        n_components: int | None = None,
        allow_fewer: bool = False,
    ) -> None:
        """
        Store and validate the selection limits.

        Parameters
        ----------
        max_force
            Largest absolute force value that is still eligible (inclusive).
        n_configurations
            Number of distinct structures to select from, or ``None``.
        n_components
            Number of force components to select, or ``None``.
        allow_fewer
            When true, a selection that falls short of the requested counts
            is returned instead of raising.

        Raises
        ------
        ValueError
            If ``max_force`` is not finite and positive, if both limits are
            ``None``, or if a supplied limit is not positive.
        """
        self.max_force = float(max_force)
        self.n_configurations = (
            None if n_configurations is None else int(n_configurations)
        )
        self.n_components = None if n_components is None else int(n_components)
        self.allow_fewer = bool(allow_fewer)

        if not np.isfinite(self.max_force) or self.max_force <= 0.0:
            raise ValueError("`max_force` must be finite and positive")
        if self.n_configurations is None and self.n_components is None:
            raise ValueError(
                "at least one of `n_configurations` or `n_components` is required"
            )
        if self.n_configurations is not None and self.n_configurations <= 0:
            raise ValueError("`n_configurations` must be positive")
        if self.n_components is not None and self.n_components <= 0:
            raise ValueError("`n_components` must be positive")

    def select(
        self,
        forces: Sequence[object],
        *,
        indices: Sequence[int] | np.ndarray | None = None,
    ) -> ForceComponentSelection:
        """
        Select high-force components from one force-array sequence.

        Parameters
        ----------
        forces
            Sequence of per-structure force arrays. Each visited entry is
            converted with ``np.asarray(..., dtype=float)`` and must then have
            shape ``(n_atoms, 3)`` and contain only finite values.
        indices
            Optional one-dimensional, duplicate-free positions into ``forces``
            restricting which structures are considered. All structures are
            considered when omitted.

        Returns
        -------
        ForceComponentSelection
            ``selected_indices`` is sorted ascending, ``force_masks`` holds
            one boolean mask per selected index in the same order, and
            ``ranked_components`` lists the selected components in rank order.

        Raises
        ------
        ValueError
            If ``indices`` is malformed or out of range, if a force array has
            the wrong shape or non-finite values, if nothing is eligible, or
            if the requested counts cannot be met and ``allow_fewer`` is
            false.
        """
        source_indices = (
            np.arange(len(forces), dtype=int)
            if indices is None
            else np.asarray(indices, dtype=int)
        )
        if source_indices.ndim != 1:
            raise ValueError("`indices` must be one-dimensional")
        if np.unique(source_indices).size != source_indices.size:
            raise ValueError("`indices` must not contain duplicates")
        if np.any(source_indices < 0) or np.any(source_indices >= len(forces)):
            raise ValueError("`indices` contains values outside `forces`")

        force_arrays: list[np.ndarray] = []
        candidates: list[_CandidateComponent] = []
        for force_array_index, dataset_index in enumerate(source_indices.tolist()):
            values = np.asarray(forces[int(dataset_index)], dtype=float)
            if values.ndim != 2 or values.shape[1] != 3:
                raise ValueError("each force array must have shape (n_atoms, 3)")
            if not np.all(np.isfinite(values)):
                raise ValueError("force arrays must contain only finite values")
            force_arrays.append(values)

            # ``tolist`` yields plain Python ints, so no per-element casts
            # are needed inside the loop.
            eligible = np.argwhere(np.abs(values) <= self.max_force).tolist()
            for atom_index, component_index in eligible:
                force = float(values[atom_index, component_index])
                candidates.append(
                    _CandidateComponent(
                        dataset_index=int(dataset_index),
                        force_array_index=int(force_array_index),
                        atom_index=atom_index,
                        component_index=component_index,
                        force=force,
                        abs_force=abs(force),
                    )
                )

        ranked = tuple(sorted(candidates, key=self._rank_key))
        selected = self._select_ranked(ranked)
        # An empty selection is rejected even when ``allow_fewer`` is set.
        if not selected:
            raise ValueError("no force components satisfy the selection criteria")
        self._validate_counts(selected)

        selected_indices, masks = self._build_masks(selected, force_arrays)
        # ``_select_ranked`` may return seeds before fill-ins, so the final
        # public ordering is re-established here.
        ranked_selected = tuple(
            component.public() for component in sorted(selected, key=self._rank_key)
        )
        return ForceComponentSelection(
            selected_indices=selected_indices,
            force_masks=masks,
            ranked_components=ranked_selected,
            max_force=self.max_force,
            requested_configurations=self.n_configurations,
            requested_components=self.n_components,
        )

    @staticmethod
    def _rank_key(component: _CandidateComponent) -> tuple[float, int, int, int]:
        """Return the sort key: largest magnitude first, then index order."""
        return (
            -component.abs_force,
            component.dataset_index,
            component.atom_index,
            component.component_index,
        )

    @staticmethod
    def _build_masks(
        selected: Sequence[_CandidateComponent],
        force_arrays: Sequence[np.ndarray],
    ) -> tuple[np.ndarray, tuple[np.ndarray, ...]]:
        """
        Build sorted structure indices and their boolean selection masks.

        Each candidate already records which entry of ``force_arrays`` it came
        from, so every mask is filled by direct assignment per selected
        component instead of scanning every atom of the structure.
        """
        masks_by_index: dict[int, np.ndarray] = {}
        for component in selected:
            mask = masks_by_index.get(component.dataset_index)
            if mask is None:
                shape = force_arrays[component.force_array_index].shape
                mask = np.zeros(shape, dtype=bool)
                masks_by_index[component.dataset_index] = mask
            mask[component.atom_index, component.component_index] = True

        ordered = sorted(masks_by_index)
        return (
            np.asarray(ordered, dtype=int),
            tuple(masks_by_index[dataset_index] for dataset_index in ordered),
        )

    def _select_ranked(
        self,
        ranked: Sequence[_CandidateComponent],
    ) -> list[_CandidateComponent]:
        """
        Return selected candidates while satisfying component/configuration caps.

        ``ranked`` must already be in rank order. The returned list is in
        insertion order (per-structure seeds first, then fill-ins), which is
        not necessarily rank order.
        """
        selected: list[_CandidateComponent] = []
        selected_keys: set[tuple[int, int, int]] = set()
        selected_configurations: set[int] = set()

        if self.n_configurations is not None:
            # Seed pass: take the best component of each new structure. The
            # target is capped by ``n_components`` so the seeds alone never
            # exceed the component budget.
            target_configurations = self.n_configurations
            if self.n_components is not None:
                target_configurations = min(target_configurations, self.n_components)
            for component in ranked:
                if component.dataset_index in selected_configurations:
                    continue
                self._add_component(
                    component,
                    selected=selected,
                    selected_keys=selected_keys,
                    selected_configurations=selected_configurations,
                )
                if len(selected_configurations) >= target_configurations:
                    break

        if self.n_components is not None:
            # Fill pass: top up to ``n_components`` in rank order. When a
            # structure limit is active, only seeded structures may
            # contribute further components.
            for component in ranked:
                if len(selected) >= self.n_components:
                    break
                if self.n_configurations is not None and (
                    component.dataset_index not in selected_configurations
                ):
                    continue
                self._add_component(
                    component,
                    selected=selected,
                    selected_keys=selected_keys,
                    selected_configurations=selected_configurations,
                )
        elif self.n_configurations is None:
            # Both limits are ``None``. ``__init__`` rejects that combination,
            # so this only runs if the attributes were cleared afterwards; in
            # that case every candidate is kept.
            for component in ranked:
                self._add_component(
                    component,
                    selected=selected,
                    selected_keys=selected_keys,
                    selected_configurations=selected_configurations,
                )
        return selected

    @staticmethod
    def _add_component(
        component: _CandidateComponent,
        *,
        selected: list[_CandidateComponent],
        selected_keys: set[tuple[int, int, int]],
        selected_configurations: set[int],
    ) -> None:
        """
        Add one component if it was not selected already.

        The three containers are updated in place; a component whose
        ``(dataset_index, atom_index, component_index)`` key is already
        present is ignored.
        """
        key = (
            component.dataset_index,
            component.atom_index,
            component.component_index,
        )
        if key in selected_keys:
            return
        selected.append(component)
        selected_keys.add(key)
        selected_configurations.add(component.dataset_index)

    def _validate_counts(self, selected: Sequence[_CandidateComponent]) -> None:
        """
        Fail when required counts could not be reached.

        Does nothing when ``allow_fewer`` is set. The structure requirement
        is capped by ``n_components``, mirroring the seed pass of
        ``_select_ranked``.

        Raises
        ------
        ValueError
            If fewer components or fewer distinct structures were selected
            than requested.
        """
        if self.allow_fewer:
            return
        if self.n_components is not None and len(selected) < self.n_components:
            raise ValueError(
                "insufficient eligible force components: "
                f"requested {self.n_components}, found {len(selected)}"
            )
        if self.n_configurations is not None:
            selected_configurations = {
                component.dataset_index for component in selected
            }
            requested = self.n_configurations
            if self.n_components is not None:
                requested = min(requested, self.n_components)
            if len(selected_configurations) < requested:
                raise ValueError(
                    "insufficient eligible configurations: "
                    f"requested {requested}, found {len(selected_configurations)}"
                )
# Names exported by ``from <module> import *``. The private
# ``_CandidateComponent`` helper is not listed.
__all__ = [
    "ForceComponent",
    "ForceComponentSelection",
    "LargestForceComponentSelector",
]