"""Reusable workflow helpers for examples and small supervised UFP studies."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence
import ase
import numpy as np
[docs]
@dataclass(frozen=True, eq=False)
class DatasetSplits:
    """Train, validation, and holdout indices for supervised examples.

    Two instances compare equal when all three index arrays match element-wise
    (same length, same values, same order). Instances are unhashable because
    the NumPy arrays they hold are mutable.
    """

    # Integer index arrays; ``make_split_indices`` and
    # ``validate_split_indices`` build them as one-dimensional ``int`` arrays.
    training_indices: np.ndarray
    validation_indices: np.ndarray
    holdout_indices: np.ndarray

    # ``eq=False`` stops dataclasses from generating ``__eq__``/``__hash__``.
    # Hashing the ndarray fields would raise ``TypeError`` anyway, so the type
    # is made explicitly unhashable.
    __hash__ = None  # type: ignore[assignment]

    @property
    def testing_indices(self) -> np.ndarray:
        """Return the holdout split under the older testing alias."""
        return self.holdout_indices

    def __eq__(self, other: object) -> bool:
        """Return True when every split holds identical indices.

        The dataclass-generated ``__eq__`` compares field tuples. That forces
        NumPy to reduce an element-wise comparison to one truth value, which
        raises ``ValueError`` for arrays with more than one element.
        """
        if not isinstance(other, DatasetSplits):
            return NotImplemented
        return (
            bool(np.array_equal(self.training_indices, other.training_indices))
            and bool(
                np.array_equal(self.validation_indices, other.validation_indices)
            )
            and bool(np.array_equal(self.holdout_indices, other.holdout_indices))
        )
[docs]
@dataclass(frozen=True, eq=False)
class SupervisedAtomsDataset:
    """ASE structures with total-energy/force labels and canonical splits.

    Equality compares every field: array fields element-wise, ``forces``
    array by array, and ``frames`` with each frame's own ``==`` operator.
    Instances are unhashable because they hold mutable NumPy arrays.
    """

    # Per-frame data. ``make_supervised_atoms_dataset`` checks that there is
    # one energy and one force array per frame.
    frames: tuple[ase.Atoms, ...]
    energies: np.ndarray
    forces: tuple[np.ndarray, ...]
    sizes: np.ndarray
    # Split index arrays into ``frames``.
    training_indices: np.ndarray
    validation_indices: np.ndarray
    holdout_indices: np.ndarray

    # ``eq=False`` stops dataclasses from generating ``__eq__``/``__hash__``.
    # Hashing the ndarray fields would raise ``TypeError`` anyway, so the type
    # is made explicitly unhashable.
    __hash__ = None  # type: ignore[assignment]

    @property
    def testing_indices(self) -> np.ndarray:
        """Return the holdout split under the older testing alias."""
        return self.holdout_indices

    def __eq__(self, other: object) -> bool:
        """Return True when all labels, frames, and splits are identical.

        The dataclass-generated ``__eq__`` compares field tuples, which raises
        ``ValueError`` as soon as a multi-element ndarray field is reached.
        """
        if not isinstance(other, SupervisedAtomsDataset):
            return NotImplemented
        if len(self.frames) != len(other.frames):
            return False
        if len(self.forces) != len(other.forces):
            return False
        array_fields = (
            "energies",
            "sizes",
            "training_indices",
            "validation_indices",
            "holdout_indices",
        )
        if not all(
            np.array_equal(getattr(self, name), getattr(other, name))
            for name in array_fields
        ):
            return False
        if not all(
            np.array_equal(mine, theirs)
            for mine, theirs in zip(self.forces, other.forces)
        ):
            return False
        return all(
            bool(mine == theirs) for mine, theirs in zip(self.frames, other.frames)
        )
[docs]
def make_split_indices(
    n_frames: int,
    *,
    seed: int,
    n_train: int,
    n_validation: int,
    sort: bool = True,
) -> DatasetSplits:
    """Create deterministic train, validation, and holdout split indices.

    A permutation of ``range(n_frames)`` drawn from ``default_rng(seed)`` is
    cut into three consecutive pieces: the first ``n_train`` entries, the next
    ``n_validation`` entries, and everything that remains (possibly empty).

    Raises:
        ValueError: if any size is non-positive, or if there are fewer frames
            than ``n_train + n_validation``.
    """
    for label, count in (
        ("n_frames", n_frames),
        ("n_train", n_train),
        ("n_validation", n_validation),
    ):
        if count <= 0:
            raise ValueError(f"`{label}` must be positive")

    cutoff = int(n_train) + int(n_validation)
    if n_frames < cutoff:
        raise ValueError(f"need at least {cutoff} frames, found {n_frames}")

    shuffled = np.random.default_rng(seed).permutation(n_frames)
    pieces = [shuffled[:n_train], shuffled[n_train:cutoff], shuffled[cutoff:]]
    if sort:
        # Sorting only reorders each split; membership is fixed by the seed.
        pieces = [np.sort(piece) for piece in pieces]
    train_part, validation_part, holdout_part = (
        np.asarray(piece, dtype=int) for piece in pieces
    )
    return DatasetSplits(
        training_indices=train_part,
        validation_indices=validation_part,
        holdout_indices=holdout_part,
    )
[docs]
def validate_split_indices(
    *,
    n_frames: int,
    training_indices: Sequence[int] | np.ndarray,
    validation_indices: Sequence[int] | np.ndarray,
    holdout_indices: Sequence[int] | np.ndarray,
    n_train: int | None = None,
    n_validation: int | None = None,
    allow_empty_validation: bool = True,
    allow_empty_holdout: bool = True,
) -> DatasetSplits:
    """Check split indices against a dataset of ``n_frames`` structures.

    Each split is converted to an ``int`` array and must be one-dimensional,
    in ``[0, n_frames)``, and free of duplicates. Only the training split must
    always be non-empty. Together the splits must be disjoint and cover every
    frame. Optional ``n_train`` / ``n_validation`` pin the split sizes.

    Returns:
        A ``DatasetSplits`` holding the normalized integer arrays.

    Raises:
        ValueError: on the first violated constraint.
    """
    normalized = DatasetSplits(
        training_indices=np.asarray(training_indices, dtype=int),
        validation_indices=np.asarray(validation_indices, dtype=int),
        holdout_indices=np.asarray(holdout_indices, dtype=int),
    )
    empty_allowed = {
        "training": False,
        "validation": allow_empty_validation,
        "holdout": allow_empty_holdout,
    }
    named_splits = (
        ("training", normalized.training_indices),
        ("validation", normalized.validation_indices),
        ("holdout", normalized.holdout_indices),
    )

    for label, indices in named_splits:
        if indices.ndim != 1:
            raise ValueError(f"`{label}` split indices must be one-dimensional")
        if not indices.size and not empty_allowed[label]:
            raise ValueError(f"`{label}` split indices can not be empty")
        if np.any(indices < 0) or np.any(indices >= n_frames):
            raise ValueError(f"`{label}` split indices are outside the dataset")
        if np.unique(indices).size < indices.size:
            raise ValueError(f"`{label}` split indices contain duplicates")

    every_index = np.concatenate([indices for _, indices in named_splits])
    if np.unique(every_index).size < every_index.size:
        raise ValueError("training, validation, and holdout splits overlap")
    # Splits are disjoint and in range, so matching the count means full coverage.
    if every_index.size != n_frames:
        raise ValueError(
            "training, validation, and holdout splits do not cover all frames"
        )

    if n_train is not None:
        expected_train = int(n_train)
        if len(normalized.training_indices) != expected_train:
            raise ValueError(f"training split must contain {expected_train} structures")
    if n_validation is not None:
        expected_validation = int(n_validation)
        if len(normalized.validation_indices) != expected_validation:
            raise ValueError(
                f"validation split must contain {expected_validation} structures"
            )
    return normalized
[docs]
def write_split_indices(
    split_path: Path,
    *,
    dataset: str,
    n_frames: int,
    seed: int,
    splits: DatasetSplits,
) -> None:
    """Persist split indices as indented JSON so every example script reuses them.

    The holdout split is stored under both ``"holdout"`` and the legacy
    ``"testing"`` key. The file ends with a trailing newline.
    """

    def as_int_list(indices: np.ndarray) -> list[int]:
        # Plain Python ints so ``json`` can serialize them.
        return [int(value) for value in indices.tolist()]

    record = {}
    record["dataset"] = str(dataset)
    record["n_frames"] = int(n_frames)
    record["seed"] = int(seed)
    record["training"] = as_int_list(splits.training_indices)
    record["validation"] = as_int_list(splits.validation_indices)
    record["holdout"] = as_int_list(splits.holdout_indices)
    record["testing"] = as_int_list(splits.holdout_indices)
    serialized = json.dumps(record, indent=2)
    split_path.write_text(f"{serialized}\n")
[docs]
def read_split_indices(split_path: Path) -> DatasetSplits:
    """Read train, validation, and holdout indices from a split JSON file.

    The canonical holdout key is ``"holdout"``. Files that only carry the
    legacy ``"testing"`` key are still accepted.

    Raises:
        ValueError: if the file does not contain a JSON object, or if a
            required split key is missing.
    """
    payload = json.loads(split_path.read_text())
    if not isinstance(payload, dict):
        # Indexing a list/str payload would raise TypeError; keep the
        # documented ValueError contract for malformed split files.
        raise ValueError(f"{split_path} must contain a JSON object of split indices")
    if "holdout" in payload:
        holdout = payload["holdout"]
    elif "testing" in payload:
        holdout = payload["testing"]
    else:
        # Name the canonical key, not the legacy alias, when both are absent.
        raise ValueError(f"{split_path} is missing split key 'holdout'")
    for key in ("training", "validation"):
        if key not in payload:
            raise ValueError(f"{split_path} is missing split key {key!r}")
    return DatasetSplits(
        training_indices=np.asarray(payload["training"], dtype=int),
        validation_indices=np.asarray(payload["validation"], dtype=int),
        holdout_indices=np.asarray(holdout, dtype=int),
    )
[docs]
def load_or_create_split_indices(
    split_path: Path,
    *,
    dataset: str,
    n_frames: int,
    seed: int,
    n_train: int,
    n_validation: int,
    sort: bool = True,
) -> DatasetSplits:
    """Return validated split indices, generating and saving them on first use.

    If ``split_path`` already exists, its contents are authoritative and
    ``seed``/``sort`` are not consulted. Otherwise new splits are drawn with
    ``make_split_indices`` and written with ``write_split_indices``. In both
    cases the result is checked by ``validate_split_indices`` against
    ``n_frames``, ``n_train`` and ``n_validation``.
    """
    if not split_path.exists():
        fresh = make_split_indices(
            n_frames,
            seed=seed,
            n_train=n_train,
            n_validation=n_validation,
            sort=sort,
        )
        write_split_indices(
            split_path,
            dataset=dataset,
            n_frames=n_frames,
            seed=seed,
            splits=fresh,
        )
        candidate = fresh
    else:
        candidate = read_split_indices(split_path)
    return validate_split_indices(
        n_frames=n_frames,
        training_indices=candidate.training_indices,
        validation_indices=candidate.validation_indices,
        holdout_indices=candidate.holdout_indices,
        n_train=n_train,
        n_validation=n_validation,
    )
[docs]
def make_supervised_atoms_dataset(
    frames: Sequence[ase.Atoms],
    energies: Sequence[float],
    forces: Sequence[object],
    *,
    splits: DatasetSplits,
) -> SupervisedAtomsDataset:
    """Assemble a labeled dataset from structures, labels, and split indices.

    Each frame is copied. Energies become a float vector that must have
    exactly one entry per frame. Forces become a tuple of float arrays, one
    per frame. The split indices are re-validated against the frame count.

    Raises:
        ValueError: on a label/frame count mismatch or invalid splits.
    """
    copied_frames = tuple(atoms.copy() for atoms in frames)
    frame_count = len(copied_frames)
    energy_vector = np.asarray(energies, dtype=float)
    force_arrays = tuple(np.asarray(per_frame, dtype=float) for per_frame in forces)

    if energy_vector.shape != (frame_count,):
        raise ValueError("`energies` must contain one scalar per frame")
    if len(force_arrays) != frame_count:
        raise ValueError("`forces` must contain one array per frame")

    atom_counts = np.fromiter(
        (len(atoms) for atoms in copied_frames), dtype=int, count=frame_count
    )
    checked = validate_split_indices(
        n_frames=frame_count,
        training_indices=splits.training_indices,
        validation_indices=splits.validation_indices,
        holdout_indices=splits.holdout_indices,
    )
    return SupervisedAtomsDataset(
        frames=copied_frames,
        energies=energy_vector,
        forces=force_arrays,
        sizes=atom_counts,
        training_indices=checked.training_indices,
        validation_indices=checked.validation_indices,
        holdout_indices=checked.holdout_indices,
    )
def _canonical_pair(first: int, second: int, *, symmetric: bool) -> tuple[int, int]:
"""Return a canonical pair key for workflow-level pair bookkeeping."""
first = int(first)
second = int(second)
if symmetric and second < first:
return second, first
return first, second
def _pair_categories(
atomic_types: Sequence[int],
*,
symmetric: bool,
) -> tuple[tuple[int, int], ...]:
"""Return all pair categories implied by an atomic-type list."""
values = tuple(int(value) for value in atomic_types)
if symmetric:
return tuple(
(first, second) for i, first in enumerate(values) for second in values[i:]
)
return tuple((first, second) for first in values for second in values)
def _dataset_atoms(
    dataset,
    *,
    indices: Sequence[int] | np.ndarray | None,
) -> tuple[ase.Atoms, ...]:
    """Collect ASE structures from the dataset inputs the workflow accepts.

    Supported inputs:
    * ``SupervisedAtomsDataset``: frames at ``indices``, defaulting to the
      training split.
    * a single ``ase.Atoms``: returned alone; ``indices`` must be ``None``.
    * any object with ``__len__`` and ``__getitem__``: entries at ``indices``
      (all entries by default). Each entry must be an ``ase.Atoms`` or expose
      one as its ``atoms`` attribute.

    Raises:
        ValueError: if ``indices`` is given together with a single Atoms.
        TypeError: for unsupported datasets or entries.
    """
    if isinstance(dataset, SupervisedAtomsDataset):
        chosen = dataset.training_indices if indices is None else np.asarray(indices)
        return tuple(dataset.frames[int(position)] for position in chosen)

    if isinstance(dataset, ase.Atoms):
        if indices is not None:
            raise ValueError("`indices` can not be used with a single Atoms object")
        return (dataset,)

    is_sequence_like = hasattr(dataset, "__len__") and hasattr(dataset, "__getitem__")
    if not is_sequence_like:
        raise TypeError(
            "`dataset` must be an ASE Atoms object or an ASE-backed dataset"
        )

    if indices is None:
        positions = range(len(dataset))
    else:
        positions = tuple(int(i) for i in indices)

    collected = []
    for position in positions:
        entry = dataset[position]
        structure = entry.atoms if hasattr(entry, "atoms") else entry
        if not isinstance(structure, ase.Atoms):
            raise TypeError("dataset entries must be ASE Atoms or have an `atoms` field")
        collected.append(structure)
    return tuple(collected)
[docs]
def minimum_pair_distances_from_dataset(
    dataset,
    *,
    atomic_types: Sequence[int] | None = None,
    active_pairs: Sequence[tuple[int, int]] | None = None,
    symmetric: bool = True,
    indices: Sequence[int] | np.ndarray | None = None,
) -> dict[tuple[int, int], float]:
    """Return shortest observed pair distance for each requested pair channel.

    For ``SupervisedAtomsDataset`` inputs, the training split is used by default.
    For ``ASEAtomsDataset`` or a sequence-like dataset, all entries are used
    unless ``indices`` is provided.

    Args:
        dataset: Structures to scan (see above for supported inputs).
        atomic_types: Atomic numbers whose pair categories are requested.
            Ignored when ``active_pairs`` is given.
        active_pairs: Explicit ``(first, second)`` channels to report.
        symmetric: When true, ``(a, b)`` and ``(b, a)`` share one channel
            keyed ``(min, max)`` and each atom pair is visited once. When
            false, channels are ordered and both ``(i, j)`` and ``(j, i)`` are
            visited.
        indices: Optional subset of dataset entries.

    Returns:
        Mapping from channel to its shortest strictly positive distance,
        ordered by channel key. Distances use the minimum-image convention
        whenever any periodic-boundary flag of a structure is set.

    Raises:
        ValueError: if neither ``atomic_types`` nor ``active_pairs`` is
            given, if the dataset yields no structures, or if a requested
            channel is never observed.
    """
    if active_pairs is not None:
        requested_pairs = {
            _canonical_pair(first, second, symmetric=symmetric)
            for first, second in active_pairs
        }
    elif atomic_types is not None:
        requested_pairs = set(_pair_categories(atomic_types, symmetric=symmetric))
    else:
        raise ValueError("provide `atomic_types` or `active_pairs`")

    minima = {pair: np.inf for pair in requested_pairs}
    atoms_tuple = _dataset_atoms(dataset, indices=indices)
    if not atoms_tuple:
        raise ValueError("dataset contains no structures")

    for atoms in atoms_tuple:
        numbers = np.asarray(atoms.numbers, dtype=int)
        n_atoms = numbers.size
        if n_atoms < 2:
            continue
        distances = np.asarray(
            atoms.get_all_distances(mic=bool(np.any(atoms.pbc))), dtype=float
        )
        # Vectorized pair enumeration replaces the per-pair Python loop:
        # the strict upper triangle for symmetric channels, otherwise every
        # off-diagonal (ordered) atom pair.
        if symmetric:
            rows, cols = np.triu_indices(n_atoms, k=1)
        else:
            rows, cols = np.nonzero(~np.eye(n_atoms, dtype=bool))
        pair_distances = distances[rows, cols]
        # Non-positive (coincident) distances are skipped. NaNs also fail this
        # test, matching the previous scalar comparison that never let them
        # win a minimum.
        keep = pair_distances > 0.0
        pair_distances = pair_distances[keep]
        first = numbers[rows[keep]]
        second = numbers[cols[keep]]
        if symmetric:
            # Same canonicalization as ``_canonical_pair``: (min, max).
            first, second = np.minimum(first, second), np.maximum(first, second)
        for pair in minima:
            mask = (first == pair[0]) & (second == pair[1])
            if mask.any():
                minima[pair] = min(minima[pair], float(pair_distances[mask].min()))

    missing = [pair for pair, distance in minima.items() if not np.isfinite(distance)]
    if missing:
        raise ValueError(f"no observed distances for requested pairs: {missing}")
    return {pair: float(distance) for pair, distance in sorted(minima.items())}