"""Force-component selection helpers for supervised training subsets."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
import numpy as np
@dataclass(frozen=True)
class ForceComponent:
    """One selected Cartesian force component in a source dataset.

    Instances are immutable records describing a single entry of an
    ``(n_atoms, 3)`` force array together with the index of the structure
    the array belongs to.
    """

    # Index of the structure within the source force sequence.
    dataset_index: int
    # Row (atom) of the component in that structure's force array.
    atom_index: int
    # Column of the component in the force array (0, 1 or 2).
    component_index: int
    # Signed value of the force component.
    force: float
    # Absolute value of ``force``; the quantity components are ranked by.
    abs_force: float
@dataclass(frozen=True)
class ForceComponentSelection:
    """Selected structures, force masks, and ranked force components.

    Equality is value-based: array fields are compared with
    ``numpy.array_equal`` so that ``==`` works for multi-element arrays
    instead of raising on an ambiguous truth value.
    """

    # Sorted, unique dataset indices of the structures that own at least one
    # selected component.
    selected_indices: np.ndarray
    # One boolean mask per entry of ``selected_indices`` (same order); each
    # mask has the shape of that structure's force array and is True at the
    # selected components.
    force_masks: tuple[np.ndarray, ...]
    # Selected components ordered by decreasing absolute force, ties broken
    # by dataset, atom, and component index.
    ranked_components: tuple[ForceComponent, ...]
    # The ``max_force`` threshold the selector was configured with.
    max_force: float
    # The configuration cap requested from the selector (None if unset).
    requested_configurations: int | None
    # The component cap requested from the selector (None if unset).
    requested_components: int | None

    @property
    def n_selected_configurations(self) -> int:
        """Return the number of selected structures."""
        return int(self.selected_indices.size)

    @property
    def n_selected_components(self) -> int:
        """Return the number of selected Cartesian force components."""
        return len(self.ranked_components)

    def __eq__(self, other: object) -> bool:
        """Compare two selections field by field.

        The dataclass-generated ``__eq__`` would apply ``==`` directly to the
        ndarray fields and then truth-test the element-wise result, which
        raises ``ValueError`` for arrays with more than one element.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        if len(self.force_masks) != len(other.force_masks):
            return False
        return bool(
            np.array_equal(self.selected_indices, other.selected_indices)
            and all(
                np.array_equal(mine, theirs)
                for mine, theirs in zip(self.force_masks, other.force_masks)
            )
            and self.ranked_components == other.ranked_components
            and self.max_force == other.max_force
            and self.requested_configurations == other.requested_configurations
            and self.requested_components == other.requested_components
        )
@dataclass(frozen=True)
class _CandidateComponent:
"""Internal ranked component candidate."""
dataset_index: int
force_array_index: int
atom_index: int
component_index: int
force: float
abs_force: float
def public(self) -> ForceComponent:
"""Return the public representation for this selected component."""
return ForceComponent(
dataset_index=self.dataset_index,
atom_index=self.atom_index,
component_index=self.component_index,
force=self.force,
abs_force=self.abs_force,
)
class LargestForceComponentSelector:
    """
    Select high-force Cartesian components for focused force training.

    The selector ranks absolute force components from largest to smallest,
    excludes components above ``max_force``, and returns boolean masks aligned
    with the selected structures.

    Selection caps
    --------------
    * Only ``n_components``: the top-ranked eligible components are taken,
      regardless of which structure they belong to.
    * Only ``n_configurations``: the single top-ranked eligible component of
      each of the first ``n_configurations`` distinct structures (in rank
      order) is taken.
    * Both: structures are chosen as above (at most ``n_components`` of
      them), then further components are added in rank order, restricted to
      those structures, until ``n_components`` is reached.
    """

    def __init__(
        self,
        *,
        max_force: float,
        n_configurations: int | None = None,
        n_components: int | None = None,
        allow_fewer: bool = False,
    ) -> None:
        """Store selection limits.

        Parameters
        ----------
        max_force
            Largest absolute component value that is still eligible; must be
            finite and strictly positive.
        n_configurations
            Cap on the number of distinct structures, or None.
        n_components
            Cap on the number of selected components, or None.
        allow_fewer
            When True, falling short of the requested counts is not an error.

        Raises
        ------
        ValueError
            If ``max_force`` is not finite and positive, if both caps are
            None, or if a given cap is not positive.
        """
        self.max_force = float(max_force)
        self.n_configurations = (
            None if n_configurations is None else int(n_configurations)
        )
        self.n_components = None if n_components is None else int(n_components)
        self.allow_fewer = bool(allow_fewer)

        if not np.isfinite(self.max_force) or self.max_force <= 0.0:
            raise ValueError("`max_force` must be finite and positive")
        if self.n_configurations is None and self.n_components is None:
            raise ValueError(
                "at least one of `n_configurations` or `n_components` is required"
            )
        if self.n_configurations is not None and self.n_configurations <= 0:
            raise ValueError("`n_configurations` must be positive")
        if self.n_components is not None and self.n_components <= 0:
            raise ValueError("`n_components` must be positive")

    def select(
        self,
        forces: Sequence[object],
        *,
        indices: Sequence[int] | np.ndarray | None = None,
    ) -> ForceComponentSelection:
        """Select high-force components from one force-array sequence.

        Parameters
        ----------
        forces
            Sequence of array-likes, each convertible to a finite float array
            of shape ``(n_atoms, 3)``.
        indices
            Optional one-dimensional collection of unique positions in
            ``forces`` to consider; all entries are used when None.

        Returns
        -------
        ForceComponentSelection
            Sorted selected dataset indices, one boolean mask per selected
            structure, and the selected components in rank order.

        Raises
        ------
        ValueError
            If ``indices`` is not one-dimensional, contains duplicates or
            out-of-range values; if a force array has the wrong shape or
            non-finite entries; if nothing is eligible; or if the requested
            counts cannot be met while ``allow_fewer`` is False.
        """
        source_indices = (
            np.arange(len(forces), dtype=int)
            if indices is None
            else np.asarray(indices, dtype=int)
        )
        if source_indices.ndim != 1:
            raise ValueError("`indices` must be one-dimensional")
        if np.unique(source_indices).size != source_indices.size:
            raise ValueError("`indices` must not contain duplicates")
        if np.any(source_indices < 0) or np.any(source_indices >= len(forces)):
            raise ValueError("`indices` contains values outside `forces`")

        force_arrays: list[np.ndarray] = []
        candidates: list[_CandidateComponent] = []
        for force_array_index, dataset_index in enumerate(source_indices.tolist()):
            values = np.asarray(forces[dataset_index], dtype=float)
            if values.ndim != 2 or values.shape[1] != 3:
                raise ValueError("each force array must have shape (n_atoms, 3)")
            if not np.all(np.isfinite(values)):
                raise ValueError("force arrays must contain only finite values")
            force_arrays.append(values)

            # Components whose magnitude exceeds the threshold never become
            # candidates.
            eligible = np.argwhere(np.abs(values) <= self.max_force)
            for atom_index, component_index in eligible.tolist():
                force = float(values[atom_index, component_index])
                candidates.append(
                    _CandidateComponent(
                        dataset_index=dataset_index,
                        force_array_index=force_array_index,
                        atom_index=atom_index,
                        component_index=component_index,
                        force=force,
                        abs_force=abs(force),
                    )
                )

        ranked = sorted(candidates, key=self._rank_key)
        selected = self._select_ranked(ranked)
        if not selected:
            raise ValueError("no force components satisfy the selection criteria")
        self._validate_counts(selected)

        # Mark only the selected entries instead of scanning every atom and
        # component of every selected structure.
        masks_by_dataset: dict[int, np.ndarray] = {}
        for component in selected:
            mask = masks_by_dataset.get(component.dataset_index)
            if mask is None:
                shape = force_arrays[component.force_array_index].shape
                mask = np.zeros(shape, dtype=bool)
                masks_by_dataset[component.dataset_index] = mask
            mask[component.atom_index, component.component_index] = True

        selected_indices = np.asarray(sorted(masks_by_dataset), dtype=int)
        force_masks = tuple(
            masks_by_dataset[dataset_index]
            for dataset_index in selected_indices.tolist()
        )

        # `_select_ranked` may return components out of rank order (one per
        # structure first, then the fill-up pass), so rank them again here.
        ranked_selected = tuple(
            component.public()
            for component in sorted(selected, key=self._rank_key)
        )
        return ForceComponentSelection(
            selected_indices=selected_indices,
            force_masks=force_masks,
            ranked_components=ranked_selected,
            max_force=self.max_force,
            requested_configurations=self.n_configurations,
            requested_components=self.n_components,
        )

    @staticmethod
    def _rank_key(component: _CandidateComponent) -> tuple[float, int, int, int]:
        """Return the sort key: largest magnitude first, then stable index order."""
        return (
            -component.abs_force,
            component.dataset_index,
            component.atom_index,
            component.component_index,
        )

    def _select_ranked(
        self,
        ranked: Sequence[_CandidateComponent],
    ) -> list[_CandidateComponent]:
        """Return selected candidates while satisfying component/configuration caps.

        ``ranked`` must already be ordered from most to least preferred. The
        returned list is in selection order, not necessarily rank order.
        """
        selected: list[_CandidateComponent] = []
        selected_keys: set[tuple[int, int, int]] = set()
        selected_configurations: set[int] = set()

        if self.n_configurations is not None:
            # Pass 1: pick the best component of each new structure until the
            # configuration target (never more than the component cap) is met.
            target_configurations = self.n_configurations
            if self.n_components is not None:
                target_configurations = min(target_configurations, self.n_components)
            for component in ranked:
                if component.dataset_index in selected_configurations:
                    continue
                self._add_component(
                    component,
                    selected=selected,
                    selected_keys=selected_keys,
                    selected_configurations=selected_configurations,
                )
                if len(selected_configurations) >= target_configurations:
                    break

        if self.n_components is not None:
            # Pass 2: fill up to the component cap in rank order. When a
            # configuration cap is active, only structures chosen in pass 1
            # may contribute.
            for component in ranked:
                if len(selected) >= self.n_components:
                    break
                if self.n_configurations is not None and (
                    component.dataset_index not in selected_configurations
                ):
                    continue
                self._add_component(
                    component,
                    selected=selected,
                    selected_keys=selected_keys,
                    selected_configurations=selected_configurations,
                )
        elif self.n_configurations is None:
            # Defensive fallback: `__init__` rejects having no cap at all, but
            # the attributes can be reassigned afterwards; take everything.
            for component in ranked:
                self._add_component(
                    component,
                    selected=selected,
                    selected_keys=selected_keys,
                    selected_configurations=selected_configurations,
                )
        return selected

    @staticmethod
    def _add_component(
        component: _CandidateComponent,
        *,
        selected: list[_CandidateComponent],
        selected_keys: set[tuple[int, int, int]],
        selected_configurations: set[int],
    ) -> None:
        """Add one component if it was not selected already.

        All three containers are updated in place; a component whose
        (dataset, atom, component) key is already present is ignored.
        """
        key = (
            component.dataset_index,
            component.atom_index,
            component.component_index,
        )
        if key in selected_keys:
            return
        selected.append(component)
        selected_keys.add(key)
        selected_configurations.add(component.dataset_index)

    def _validate_counts(self, selected: Sequence[_CandidateComponent]) -> None:
        """Fail when required counts could not be reached.

        Does nothing when ``allow_fewer`` is set. Otherwise raises
        ``ValueError`` if fewer components than ``n_components`` were
        selected, or fewer structures than ``min(n_configurations,
        n_components)`` (just ``n_configurations`` when no component cap is
        set).
        """
        if self.allow_fewer:
            return
        if self.n_components is not None and len(selected) < self.n_components:
            raise ValueError(
                "insufficient eligible force components: "
                f"requested {self.n_components}, found {len(selected)}"
            )
        if self.n_configurations is not None:
            selected_configurations = {
                component.dataset_index for component in selected
            }
            requested = self.n_configurations
            if self.n_components is not None:
                requested = min(requested, self.n_components)
            if len(selected_configurations) < requested:
                raise ValueError(
                    "insufficient eligible configurations: "
                    f"requested {requested}, found {len(selected_configurations)}"
                )
# Names exported by ``from <module> import *``; the underscore-prefixed
# ``_CandidateComponent`` helper is not part of this list.
__all__ = [
"ForceComponent",
"ForceComponentSelection",
"LargestForceComponentSelector",
]