"""Progressive regularization tuning for linear least-squares workflows."""
from __future__ import annotations
import itertools
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import cast
import numpy as np
import torch
from ufp.leastsquares import FitSample, LinearFitResult, LinearFitter
from ufp.terms import UFPModel
from ufp.workflows.data import SupervisedAtomsDataset
from ufp.workflows.predictions import (
fit_samples_from_dataset,
prediction_metrics_for_split,
)
# Canonical regularization groups handled by the tuner, in a fixed order.
# "ridge" corresponds to ``RegularizationCandidate.ridge``; every other entry
# ``<group>`` corresponds to the ``<group>_ridge`` field of that class.
_RIDGE_GROUPS = ("ridge", "onebody", "twobody", "threebody")
# Validation-metric keys used when scoring and ranking candidate fits.
_ENERGY_RMSE_KEY = "rmse_energy_mev_per_atom"
_FORCE_RMSE_KEY = "rmse_force_mev_per_angstrom"
[docs]
@dataclass(frozen=True)
class RidgeGroupEstimate:
    """Ridge suggestion for a single coefficient regularization group.

    The fields are the group name, the number of parameters in the group, the
    summed design-matrix trace over the group's blocks, that trace divided by
    the parameter count, and the resulting scaled ridge suggestion.
    """

    group: str
    n_parameters: int
    design_trace: float
    trace_per_parameter: float
    suggested_ridge: float

    def to_dict(self) -> dict[str, object]:
        """Serialize the estimate using plain ``str``/``int``/``float`` values."""
        payload: dict[str, object] = {"group": self.group}
        payload["n_parameters"] = int(self.n_parameters)
        for name in ("design_trace", "trace_per_parameter", "suggested_ridge"):
            payload[name] = float(getattr(self, name))
        return payload
[docs]
@dataclass(frozen=True)
class RidgeScaleEstimate:
    """Per-group ridge suggestions derived from least-squares design statistics.

    ``groups`` holds one ``RidgeGroupEstimate`` per populated group, ``alpha``
    is the factor that turns trace-per-parameter into a ridge weight,
    ``sample_count`` is the number of samples the design was built from, and
    ``diagnostics`` carries optional scalar counts about those samples.
    """

    groups: tuple[RidgeGroupEstimate, ...]
    alpha: float
    sample_count: int
    diagnostics: Mapping[str, float] = field(default_factory=dict)

    @property
    def by_group(self) -> dict[str, RidgeGroupEstimate]:
        """Map each group name to its estimate (a later duplicate wins)."""
        lookup: dict[str, RidgeGroupEstimate] = {}
        for estimate in self.groups:
            lookup[estimate.group] = estimate
        return lookup

    def candidate(self) -> "RegularizationCandidate":
        """Build the ``RegularizationCandidate`` implied by these suggestions.

        The ``"ridge"`` estimate supplies ``ridge`` (``0.0`` when it is absent).
        It is also the fallback for any one-, two- or three-body group that has
        no estimate of its own.
        """
        lookup = self.by_group
        global_estimate = lookup.get("ridge")
        global_ridge = (
            float(global_estimate.suggested_ridge)
            if global_estimate is not None
            else 0.0
        )

        def _resolve(group: str) -> float:
            estimate = lookup.get(group, global_estimate)
            if estimate is None:
                return global_ridge
            return float(estimate.suggested_ridge)

        return RegularizationCandidate(
            ridge=global_ridge,
            onebody_ridge=_resolve("onebody"),
            twobody_ridge=_resolve("twobody"),
            threebody_ridge=_resolve("threebody"),
        )

    def to_dict(self) -> dict[str, object]:
        """Serialize alpha, sample count, diagnostics and per-group estimates."""
        return {
            "alpha": float(self.alpha),
            "sample_count": int(self.sample_count),
            "diagnostics": {
                str(name): float(count) for name, count in self.diagnostics.items()
            },
            "groups": [estimate.to_dict() for estimate in self.groups],
        }
[docs]
@dataclass(frozen=True)
class RegularizationCandidate:
    """Concrete ridge settings passed to ``LinearFitter``.

    ``ridge`` is the global weight. ``onebody_ridge``, ``twobody_ridge`` and
    ``threebody_ridge`` are the per-group weights. Together they form the
    keyword arguments returned by ``as_fitter_kwargs``. Every value is coerced
    to ``float`` on construction and must be non-negative; NaN is rejected too.
    """

    ridge: float = 0.0
    onebody_ridge: float = 0.0
    twobody_ridge: float = 0.0
    threebody_ridge: float = 0.0

    def __post_init__(self) -> None:
        """Coerce each ridge weight to ``float`` and validate it.

        Raises:
            ValueError: if any weight is negative or NaN.
        """
        for name in _RIDGE_GROUPS:
            attribute = name if name == "ridge" else f"{name}_ridge"
            value = float(getattr(self, attribute))
            # ``not value >= 0.0`` (instead of ``value < 0.0``) also rejects NaN.
            # NaN compares false to everything, so it would otherwise reach the
            # fitter and can defeat candidate de-duplication (``nan != nan``).
            if not value >= 0.0:
                raise ValueError(f"`{name}` ridge value must be non-negative")
            # The dataclass is frozen, so bypass ``__setattr__`` to store the
            # normalized float.
            object.__setattr__(self, attribute, value)

    def as_fitter_kwargs(self) -> dict[str, float]:
        """Return ridge keyword arguments accepted by ``LinearFitter``."""
        return {
            "ridge": float(self.ridge),
            "onebody_ridge": float(self.onebody_ridge),
            "twobody_ridge": float(self.twobody_ridge),
            "threebody_ridge": float(self.threebody_ridge),
        }

    def group_value(self, group: str) -> float:
        """Return the ridge value for one canonical group.

        Raises:
            ValueError: if *group* is not ``ridge``/``onebody``/``twobody``/
                ``threebody``.
        """
        if group == "ridge":
            return float(self.ridge)
        if group == "onebody":
            return float(self.onebody_ridge)
        if group == "twobody":
            return float(self.twobody_ridge)
        if group == "threebody":
            return float(self.threebody_ridge)
        raise ValueError(f"unsupported regularization group {group!r}")

    def with_group_value(self, group: str, value: float) -> "RegularizationCandidate":
        """Return a new candidate with one group's ridge replaced by *value*.

        Raises:
            ValueError: if *group* is unsupported, or *value* is negative or NaN.
        """
        kwargs = self.as_fitter_kwargs()
        key = "ridge" if group == "ridge" else f"{group}_ridge"
        if key not in kwargs:
            raise ValueError(f"unsupported regularization group {group!r}")
        kwargs[key] = float(value)
        return RegularizationCandidate(**kwargs)

    def total_ridge(self) -> float:
        """Return the sum of all four weights, used for deterministic tie-breaking."""
        return (
            float(self.ridge)
            + float(self.onebody_ridge)
            + float(self.twobody_ridge)
            + float(self.threebody_ridge)
        )

    def to_dict(self) -> dict[str, float]:
        """Return a JSON-friendly representation (same as ``as_fitter_kwargs``)."""
        return self.as_fitter_kwargs()
[docs]
@dataclass(frozen=True)
class RegularizationSearchConfig:
    """Options for progressive regularization tuning.

    Attributes:
        seed: base seed for the deterministic subset selections.
        alpha: scale applied to the trace-based ridge estimate.
        estimate_subset_size: samples used for the ridge-scale estimate
            (``None`` uses the whole training split).
        stage_subset_sizes: maximum training/validation subset size per stage.
        candidate_multipliers: factors applied to each group's ridge scale in
            the initial sweep.
        refinement_multipliers: factors used for the local refinement grid.
        top_k_per_stage: successful candidates kept after each stage.
        validation_fraction, minimum_validation_size: size of the validation
            split carved from training when none is supplied.
        energy_score_weight, force_score_weight: metric weights in the score.
        refit_full: refit the best candidate on the full training split.
        cache_directory, cache_mode, dense_cache_parameter_limit: optional
            on-disk caching of assembled problems or normal equations.
        prediction_batch_size: batch size for validation predictions.
        progress: forward progress reporting to fitting and prediction.
    """

    seed: int = 0
    alpha: float = 1.0e-6
    estimate_subset_size: int | None = 64
    stage_subset_sizes: tuple[int, ...] = (64, 256)
    candidate_multipliers: tuple[float, ...] = (
        1.0e-3,
        1.0e-2,
        1.0e-1,
        1.0,
        10.0,
        100.0,
        1.0e3,
    )
    refinement_multipliers: tuple[float, ...] = (1.0e-1, 1.0, 10.0)
    top_k_per_stage: int = 5
    validation_fraction: float = 0.2
    minimum_validation_size: int = 1
    energy_score_weight: float = 1.0
    force_score_weight: float = 1.0
    refit_full: bool = True
    cache_directory: Path | str | None = None
    cache_mode: str = "auto"
    dense_cache_parameter_limit: int = 20_000
    prediction_batch_size: int = 64
    progress: bool = False

    def __post_init__(self) -> None:
        """Normalize tuple-like options and validate every setting.

        Non-negativity checks are written as ``not value >= 0.0``, so NaN
        (which compares false to everything) is rejected rather than accepted.

        Raises:
            ValueError: if any option is out of range.
        """
        object.__setattr__(
            self,
            "stage_subset_sizes",
            tuple(int(size) for size in self.stage_subset_sizes),
        )
        object.__setattr__(
            self,
            "candidate_multipliers",
            tuple(float(value) for value in self.candidate_multipliers),
        )
        object.__setattr__(
            self,
            "refinement_multipliers",
            tuple(float(value) for value in self.refinement_multipliers),
        )
        if not self.alpha >= 0.0:
            raise ValueError("`alpha` must be non-negative")
        if self.estimate_subset_size is not None and self.estimate_subset_size <= 0:
            raise ValueError("`estimate_subset_size` must be positive")
        if any(size <= 0 for size in self.stage_subset_sizes):
            raise ValueError("`stage_subset_sizes` entries must be positive")
        if any(not value >= 0.0 for value in self.candidate_multipliers):
            raise ValueError("`candidate_multipliers` entries must be non-negative")
        if any(not value >= 0.0 for value in self.refinement_multipliers):
            raise ValueError("`refinement_multipliers` entries must be non-negative")
        if self.top_k_per_stage <= 0:
            raise ValueError("`top_k_per_stage` must be positive")
        if not 0.0 < self.validation_fraction < 1.0:
            raise ValueError("`validation_fraction` must be between 0 and 1")
        if self.minimum_validation_size <= 0:
            raise ValueError("`minimum_validation_size` must be positive")
        if not (self.energy_score_weight >= 0.0 and self.force_score_weight >= 0.0):
            raise ValueError("score weights must be non-negative")
        if self.energy_score_weight == 0.0 and self.force_score_weight == 0.0:
            raise ValueError("at least one score weight must be positive")
        if self.dense_cache_parameter_limit < 0:
            raise ValueError("`dense_cache_parameter_limit` must be non-negative")
        if self.prediction_batch_size <= 0:
            raise ValueError("`prediction_batch_size` must be positive")

    def to_dict(self) -> dict[str, object]:
        """Return JSON-friendly config metadata (plain Python scalars only)."""
        return {
            "seed": int(self.seed),
            "alpha": float(self.alpha),
            # Cast so NumPy integers do not leak into the JSON payload.
            "estimate_subset_size": (
                None
                if self.estimate_subset_size is None
                else int(self.estimate_subset_size)
            ),
            "stage_subset_sizes": list(self.stage_subset_sizes),
            "candidate_multipliers": list(self.candidate_multipliers),
            "refinement_multipliers": list(self.refinement_multipliers),
            "top_k_per_stage": int(self.top_k_per_stage),
            "validation_fraction": float(self.validation_fraction),
            "minimum_validation_size": int(self.minimum_validation_size),
            "energy_score_weight": float(self.energy_score_weight),
            "force_score_weight": float(self.force_score_weight),
            "refit_full": bool(self.refit_full),
            "cache_directory": (
                None if self.cache_directory is None else str(self.cache_directory)
            ),
            "cache_mode": str(self.cache_mode),
            "dense_cache_parameter_limit": int(self.dense_cache_parameter_limit),
            "prediction_batch_size": int(self.prediction_batch_size),
            "progress": bool(self.progress),
        }
[docs]
@dataclass(frozen=True)
class RegularizationTrial:
    """Outcome of fitting and validating one candidate in one stage.

    ``status`` is ``"ok"`` for a completed fit; ``metrics`` is then populated and
    ``score`` is attached after stage scoring. ``status`` is ``"error"`` when
    fitting or validation raised, in which case ``error`` holds
    ``"<ExceptionType>: <message>"``.
    """

    stage: str
    candidate: RegularizationCandidate
    training_size: int
    validation_size: int
    metrics: Mapping[str, float] = field(default_factory=dict)
    score: float | None = None
    status: str = "ok"
    error: str | None = None

    def to_dict(self) -> dict[str, object]:
        """Serialize the trial, converting metrics and score to plain floats."""
        serialized_metrics: dict[str, float] = {}
        for name, value in self.metrics.items():
            serialized_metrics[str(name)] = float(value)
        serialized_score = float(self.score) if self.score is not None else None
        return {
            "stage": self.stage,
            "candidate": self.candidate.to_dict(),
            "training_size": int(self.training_size),
            "validation_size": int(self.validation_size),
            "metrics": serialized_metrics,
            "score": serialized_score,
            "status": self.status,
            "error": self.error,
        }
[docs]
@dataclass(frozen=True)
class RegularizationSearchResult:
    """Outcome of a progressive regularization search.

    Holds the ridge-scale estimate, every trial run, the winning candidate and
    the merged fitter kwargs for it. When a final refit was requested, it also
    holds the refitted model and its fit result.
    """

    estimates: RidgeScaleEstimate
    trials: tuple[RegularizationTrial, ...]
    best_candidate: RegularizationCandidate
    best_fitter_kwargs: Mapping[str, object]
    final_model: UFPModel | None = None
    final_fit_result: LinearFitResult | None = None

    @property
    def metadata(self) -> dict[str, object]:
        """Summarize the search as JSON-friendly metadata.

        ``final_fit`` is ``None`` unless a final fit result is attached. In that
        case it records the solver name, objective, residual norm and problem
        size.
        """
        fit = self.final_fit_result
        final_fit: dict[str, object] | None
        if fit is None:
            final_fit = None
        else:
            final_fit = {
                "solver": fit.solver,
                "objective": float(fit.objective),
                "residual_norm": float(fit.residual_norm),
                "n_rows": int(fit.n_rows),
                "n_parameters": int(fit.n_parameters),
            }
        return {
            "estimates": self.estimates.to_dict(),
            "trials": [trial.to_dict() for trial in self.trials],
            "best_candidate": self.best_candidate.to_dict(),
            "best_fitter_kwargs": _json_friendly_mapping(self.best_fitter_kwargs),
            "final_fit": final_fit,
        }
def _json_friendly_mapping(values: Mapping[str, object]) -> dict[str, object]:
"""Return a JSON-friendly shallow mapping."""
normalized: dict[str, object] = {}
for key, value in values.items():
if isinstance(value, torch.dtype):
normalized[str(key)] = str(value)
elif isinstance(value, torch.device):
normalized[str(key)] = str(value)
elif isinstance(value, Path):
normalized[str(key)] = str(value)
elif isinstance(value, (str, int, float, bool)) or value is None:
normalized[str(key)] = value
else:
normalized[str(key)] = repr(value)
return normalized
def _canonical_group(group: str) -> str:
"""Normalize parameter-layout regularization groups to tuner groups."""
if group in {"pair", "twobody"}:
return "twobody"
if group in {"onebody", "threebody"}:
return group
return "ridge"
def _subset_samples(
samples: Sequence[FitSample],
*,
subset_size: int | None,
seed: int,
) -> tuple[FitSample, ...]:
"""Return a deterministic sample subset."""
items = tuple(samples)
if subset_size is None or int(subset_size) >= len(items):
return items
rng = np.random.default_rng(int(seed))
indices = np.sort(rng.choice(len(items), size=int(subset_size), replace=False))
return tuple(items[int(index)] for index in indices)
def _build_problem_kwargs(fit_kwargs: Mapping[str, object] | None) -> dict[str, object]:
"""Extract ``LinearFitter.build_problem`` keyword arguments."""
supplied = {} if fit_kwargs is None else dict(fit_kwargs)
allowed = {"batch_size", "progress", "cache_directory", "cache_mode"}
return {key: value for key, value in supplied.items() if key in allowed}
def _numpy_pairs(pairs: object) -> np.ndarray:
"""Return neighbor-list pairs as a CPU numpy array."""
if isinstance(pairs, torch.Tensor):
return pairs.detach().cpu().numpy()
return np.asarray(pairs)
def _sample_interaction_diagnostics(
samples: Sequence[FitSample],
) -> dict[str, float]:
"""Return raw interaction-count diagnostics when samples carry them."""
diagnostics = {
"atom_count": float(sum(len(sample.atoms) for sample in samples)),
}
pair_count = 0
centered_triplet_count = 0
samples_with_neighbor_lists = 0
for sample in samples:
neighbor_list = sample.neighbor_list
if neighbor_list is None:
continue
samples_with_neighbor_lists += 1
pair_count += int(neighbor_list.n_pairs)
pairs = _numpy_pairs(neighbor_list.pairs)
if pairs.size == 0:
continue
centers = np.asarray(pairs[0], dtype=int)
degrees = np.bincount(centers, minlength=len(sample.atoms))
centered_triplet_count += int(np.sum(degrees * np.maximum(degrees - 1, 0)))
if samples_with_neighbor_lists:
diagnostics["samples_with_neighbor_lists"] = float(samples_with_neighbor_lists)
diagnostics["explicit_neighbor_pairs"] = float(pair_count)
diagnostics["estimated_centered_triplets"] = float(centered_triplet_count)
return diagnostics
[docs]
def estimate_linear_ridge_scales(
    model: UFPModel,
    samples: Sequence[FitSample],
    *,
    fitter_kwargs: Mapping[str, object] | None = None,
    fit_kwargs: Mapping[str, object] | None = None,
    subset_size: int | None = None,
    seed: int = 0,
    alpha: float = 1.0e-6,
) -> RidgeScaleEstimate:
    """
    Estimate block-scale ridge weights from weighted least-squares design traces.

    For each group ``g`` the suggested weight is
    ``alpha * trace(A_g.T @ A_g) / n_params_g``. The ``"ridge"`` group covers
    every parameter block exactly once. The ``"onebody"``, ``"twobody"``
    (layout alias ``"pair"``) and ``"threebody"`` groups cover only the blocks
    tagged with that regularization group. Groups without parameters are
    omitted.

    Args:
        model: model whose ``LinearFitter`` design is assembled.
        samples: fit samples, optionally subsampled via *subset_size*/*seed*.
        fitter_kwargs: keyword arguments forwarded to ``LinearFitter``.
        fit_kwargs: fit options; only ``build_problem``-compatible keys are used.
        subset_size: maximum number of samples to use (``None`` uses all).
        seed: RNG seed for the subset selection.
        alpha: non-negative factor applied to the trace-per-parameter.

    Raises:
        ValueError: if *alpha* is negative or NaN, or no samples remain.
    """
    # ``not alpha >= 0`` also rejects NaN, which a ``< 0`` test lets through.
    if not alpha >= 0.0:
        raise ValueError("`alpha` must be non-negative")
    subset = _subset_samples(samples, subset_size=subset_size, seed=seed)
    if not subset:
        raise ValueError("`samples` must contain at least one FitSample")
    fitter = LinearFitter(model, **dict({} if fitter_kwargs is None else fitter_kwargs))
    problem = fitter.build_problem(subset, **_build_problem_kwargs(fit_kwargs))
    block_traces = problem.design_trace_by_block()
    # Per canonical group: [parameter count, summed design trace].
    aggregates: dict[str, list[float]] = {group: [0.0, 0.0] for group in _RIDGE_GROUPS}
    for solve_block in problem.layout.blocks:
        trace = float(block_traces.get(solve_block.key, 0.0))
        size = int(solve_block.size)
        # Every block contributes to the global "ridge" group exactly once.
        aggregates["ridge"][0] += size
        aggregates["ridge"][1] += trace
        layout_block = fitter.layout.block(int(solve_block.key))
        group = _canonical_group(layout_block.regularization_group)
        if group != "ridge":
            # Blocks without a dedicated group were already counted above.
            # Adding them again double-counted their size and trace in "ridge".
            aggregates[group][0] += size
            aggregates[group][1] += trace
    estimates = []
    for group in _RIDGE_GROUPS:
        n_parameters = int(aggregates[group][0])
        if n_parameters <= 0:
            continue
        design_trace = float(aggregates[group][1])
        trace_per_parameter = design_trace / float(n_parameters)
        estimates.append(
            RidgeGroupEstimate(
                group=group,
                n_parameters=n_parameters,
                design_trace=design_trace,
                trace_per_parameter=trace_per_parameter,
                suggested_ridge=float(alpha) * trace_per_parameter,
            )
        )
    return RidgeScaleEstimate(
        groups=tuple(estimates),
        alpha=float(alpha),
        sample_count=len(subset),
        diagnostics=_sample_interaction_diagnostics(subset),
    )
def _regularization_values(
seed: float,
multipliers: Sequence[float],
) -> tuple[float, ...]:
"""Return candidate ridge values for a seed scale."""
values = [0.0]
if seed > 0.0:
values.extend(float(seed) * float(multiplier) for multiplier in multipliers)
return tuple(dict.fromkeys(float(value) for value in values))
def _active_groups(estimates: RidgeScaleEstimate) -> tuple[str, ...]:
    """Return, in canonical order, the groups with at least one parameter."""
    lookup = estimates.by_group
    active: list[str] = []
    for group in _RIDGE_GROUPS:
        if group in lookup and int(lookup[group].n_parameters) > 0:
            active.append(group)
    return tuple(active)
def _dedupe_candidates(
candidates: Sequence[RegularizationCandidate],
) -> tuple[RegularizationCandidate, ...]:
"""Remove duplicate candidates while preserving order."""
seen: set[tuple[float, float, float, float]] = set()
unique = []
for candidate in candidates:
key = (
float(candidate.ridge),
float(candidate.onebody_ridge),
float(candidate.twobody_ridge),
float(candidate.threebody_ridge),
)
if key in seen:
continue
seen.add(key)
unique.append(candidate)
return tuple(unique)
def _initial_candidates(
    estimates: RidgeScaleEstimate,
    config: RegularizationSearchConfig,
) -> tuple[RegularizationCandidate, ...]:
    """Build the first-stage candidate list.

    Starts from the estimate's baseline candidate. Then, for every active group,
    it sweeps that one group over ``0.0`` and ``suggested_ridge * m`` for each
    ``config.candidate_multipliers`` entry, keeping the other groups at the
    baseline. Duplicates are removed, keeping the first occurrence.
    """
    baseline = estimates.candidate()
    lookup = estimates.by_group
    sweep = [
        baseline.with_group_value(group, value)
        for group in _active_groups(estimates)
        for value in _regularization_values(
            lookup[group].suggested_ridge,
            config.candidate_multipliers,
        )
    ]
    return _dedupe_candidates([baseline, *sweep])
def _seed_neighbourhood(
    seed: RegularizationCandidate,
    groups: tuple[str, ...],
    multipliers: Sequence[float],
) -> list[RegularizationCandidate]:
    """Return the Cartesian refinement grid around one seed, nearest moves first.

    For every group the grid uses ``0.0`` plus ``seed_value * m`` for each
    multiplier ``m``. Candidates are stably sorted by how many groups differ
    from the seed, so single-group perturbations come before multi-group ones.
    """
    values_by_group = [
        _regularization_values(seed.group_value(group), multipliers)
        for group in groups
    ]
    neighbours: list[RegularizationCandidate] = []
    for values in itertools.product(*values_by_group):
        candidate = seed
        for group, value in zip(groups, values, strict=True):
            candidate = candidate.with_group_value(group, value)
        neighbours.append(candidate)
    seed_values = [seed.group_value(group) for group in groups]
    neighbours.sort(
        key=lambda candidate: sum(
            candidate.group_value(group) != original
            for group, original in zip(groups, seed_values, strict=True)
        )
    )
    return neighbours


def _refinement_candidates(
    seeds: Sequence[RegularizationCandidate],
    groups: Sequence[str],
    config: RegularizationSearchConfig,
) -> tuple[RegularizationCandidate, ...]:
    """Generate local Cartesian refinement candidates around successful seeds.

    The seeds come first, in rank order. Each seed's neighbourhood is the grid
    from ``_seed_neighbourhood``, built with ``config.refinement_multipliers``
    and ordered nearest-first. The neighbourhoods are interleaved round-robin
    across seeds. Without this, a caller that truncates the result would keep
    only the first seed's grid, starting at its all-zero-ridge corner.
    Duplicates are removed, keeping the first occurrence.
    """
    candidates: list[RegularizationCandidate] = list(seeds)
    active_groups = tuple(groups)
    if not active_groups:
        return _dedupe_candidates(candidates)
    neighbourhoods = [
        _seed_neighbourhood(seed, active_groups, config.refinement_multipliers)
        for seed in seeds
    ]
    # Take the i-th neighbour of every seed before any seed's (i+1)-th.
    for layer in itertools.zip_longest(*neighbourhoods):
        candidates.extend(candidate for candidate in layer if candidate is not None)
    return _dedupe_candidates(candidates)
def _as_index_array(indices: Sequence[int] | np.ndarray) -> np.ndarray:
"""Return a one-dimensional integer index array."""
values = np.asarray(indices, dtype=int).reshape(-1)
if values.size == 0:
raise ValueError("index arrays must not be empty")
return values
def _representative_indices(
    dataset: SupervisedAtomsDataset,
    indices: Sequence[int] | np.ndarray,
    *,
    max_size: int | None,
    seed: int,
) -> np.ndarray:
    """Choose deterministic atom-count/energy-stratified indices.

    Args:
        dataset: source of the optional ``sizes`` and ``energies`` arrays,
            which are indexed by the entries of *indices*.
        indices: candidate dataset indices (flattened; must be non-empty).
        max_size: number of indices to keep. ``None``, or a value at least the
            number of candidates, keeps all of them.
        seed: seed for the NumPy generator that drives the selection.

    Returns:
        The selected indices, sorted ascending.

    Raises:
        ValueError: if *indices* is empty.
    """
    values = _as_index_array(indices)
    if max_size is None or int(max_size) >= int(values.size):
        return np.sort(values)
    requested = int(max_size)
    rng = np.random.default_rng(int(seed))
    # Without size/energy arrays, fall back to a plain uniform sample drawn
    # without replacement.
    if not hasattr(dataset, "sizes") or not hasattr(dataset, "energies"):
        return np.sort(rng.choice(values, size=requested, replace=False))
    sizes = np.asarray(dataset.sizes, dtype=float)[values]
    energies = np.asarray(dataset.energies, dtype=float)[values]
    # Floor sizes at 1 so the per-atom division cannot divide by zero.
    safe_sizes = np.maximum(sizes, 1.0)
    energy_per_atom = energies / safe_sizes
    # floor(sqrt(requested)) quantile bins per axis, clamped to [1, 6]. For
    # requested >= 1 this gives at most ``requested`` buckets, so every
    # non-empty bucket contributes at least one index.
    n_bins = max(1, min(6, int(np.sqrt(requested))))
    size_bins = _quantile_bins(sizes, n_bins)
    energy_bins = _quantile_bins(energy_per_atom, n_bins)
    # Group candidate indices by their (size bin, energy-per-atom bin) pair.
    buckets: dict[tuple[int, int], list[int]] = {}
    for position, index in enumerate(values):
        bucket = (int(size_bins[position]), int(energy_bins[position]))
        buckets.setdefault(bucket, []).append(int(index))
    # Shuffle each bucket in place, in sorted bucket order. This order of RNG
    # calls fixes the result for a given seed.
    bucket_items = sorted(buckets.items())
    for _, bucket_values in bucket_items:
        rng.shuffle(bucket_values)
    # Round-robin: each pass takes one index (popped from the end of the
    # shuffled list) from every non-empty bucket until ``requested`` are
    # selected.
    selected: list[int] = []
    while len(selected) < requested and bucket_items:
        next_items = []
        for bucket, bucket_values in bucket_items:
            if bucket_values and len(selected) < requested:
                selected.append(bucket_values.pop())
            if bucket_values:
                next_items.append((bucket, bucket_values))
        bucket_items = next_items
    return np.sort(np.asarray(selected[:requested], dtype=int))
def _quantile_bins(values: np.ndarray, n_bins: int) -> np.ndarray:
"""Return stable quantile-bin labels for a one-dimensional array."""
if values.size == 0:
return np.zeros(0, dtype=int)
if n_bins <= 1 or np.all(values == values[0]):
return np.zeros(values.shape, dtype=int)
quantiles = np.linspace(0.0, 1.0, int(n_bins) + 1)[1:-1]
boundaries = np.quantile(values, quantiles)
return np.searchsorted(boundaries, values, side="right").astype(int)
def _resolve_training_validation_indices(
    dataset: SupervisedAtomsDataset,
    *,
    training_indices: Sequence[int] | np.ndarray | None,
    validation_indices: Sequence[int] | np.ndarray | None,
    config: RegularizationSearchConfig,
) -> tuple[np.ndarray, np.ndarray]:
    """Resolve the tuning train/validation split; both arrays are returned sorted.

    Validation indices are chosen in this order:

    1. Explicit *validation_indices*.
    2. The dataset's own ``validation_indices``, when non-empty.
    3. For a single training index, that same index.
    4. Otherwise, a stratified subset carved out of the training indices and
       removed from them. Its size is
       ``max(minimum_validation_size, round(n_train * validation_fraction))``,
       capped at ``n_train - 1``.

    Explicit or dataset-provided validation indices are not removed from
    training.
    """
    source = dataset.training_indices if training_indices is None else training_indices
    training = _as_index_array(source)
    if validation_indices is not None:
        return np.sort(training), np.sort(_as_index_array(validation_indices))
    provided = np.asarray(dataset.validation_indices, dtype=int).reshape(-1)
    if provided.size:
        return np.sort(training), np.sort(provided)
    n_train = int(training.size)
    if n_train == 1:
        return np.sort(training), np.sort(training)
    target = max(
        int(config.minimum_validation_size),
        int(round(float(n_train) * float(config.validation_fraction))),
    )
    target = min(target, n_train - 1)
    carved = _representative_indices(
        dataset,
        training,
        max_size=target,
        seed=int(config.seed) + 17,
    )
    held_out = set(carved.tolist())
    remaining = np.asarray(
        [index for index in training.tolist() if index not in held_out],
        dtype=int,
    )
    return np.sort(remaining), np.sort(carved)
def _candidate_fitter_kwargs(
base: Mapping[str, object] | None,
candidate: RegularizationCandidate,
) -> dict[str, object]:
"""Merge caller fitter kwargs with candidate ridge settings."""
kwargs = {} if base is None else dict(base)
kwargs.update(candidate.as_fitter_kwargs())
return kwargs
def _target_weights_are_uniform(samples: Sequence[FitSample]) -> bool:
"""Return whether normal-equation caching can reuse split target weights."""
for weight_name in ("energy_weight", "force_weight", "per_atom_weight"):
weights = [float(getattr(sample, weight_name)) for sample in samples]
if weights and any(not np.isclose(weight, weights[0]) for weight in weights):
return False
return True
def _fit_kwargs_for_stage(
base: Mapping[str, object] | None,
*,
config: RegularizationSearchConfig,
stage: str,
fitter: LinearFitter,
samples: Sequence[FitSample],
) -> dict[str, object]:
"""Return candidate-fit kwargs, adding safe cache reuse when configured."""
kwargs = {} if base is None else dict(base)
kwargs.pop("cg_checkpoint_path", None)
kwargs.pop("cg_resume", None)
if config.progress:
kwargs.setdefault("progress", True)
if config.cache_directory is None:
return kwargs
cache_directory = Path(config.cache_directory) / stage
if (
fitter.solver == "normal_equation_direct"
and fitter._selected_size() <= int(config.dense_cache_parameter_limit)
and _target_weights_are_uniform(samples)
):
kwargs.setdefault("normal_equation_cache", True)
kwargs.setdefault(
"normal_equation_cache_directory",
cache_directory / "normal_equations",
)
kwargs.setdefault("normal_equation_cache_mode", config.cache_mode)
else:
kwargs.setdefault("cache_directory", cache_directory / "assembled")
kwargs.setdefault("cache_mode", config.cache_mode)
return kwargs
def _scalar_metrics(metrics: Mapping[str, object]) -> dict[str, float]:
"""Keep only scalar validation metrics."""
scalars: dict[str, float] = {}
for key, value in metrics.items():
array = np.asarray(value)
if array.ndim == 0:
scalars[str(key)] = float(array.item())
return scalars
def _evaluate_candidate(
    *,
    stage: str,
    candidate: RegularizationCandidate,
    model_factory: Callable[[], UFPModel],
    dataset: SupervisedAtomsDataset,
    train_samples: Sequence[FitSample],
    validation_indices: np.ndarray,
    config: RegularizationSearchConfig,
    fitter_kwargs: Mapping[str, object] | None,
    fit_kwargs: Mapping[str, object] | None,
) -> RegularizationTrial:
    """Fit *candidate* on *train_samples* and validate it on *validation_indices*.

    A fresh model from *model_factory* is fitted with the caller's fitter
    kwargs merged with the candidate's ridge values. Validation metrics are
    then computed with the fitter's ``device``/``dtype`` kwargs, if given.
    Any exception (including a non-``UFPModel`` factory result) is captured
    and returned as an ``"error"`` trial, so one failing candidate does not
    abort the search.
    """
    n_train = len(train_samples)
    n_validation = int(validation_indices.size)
    try:
        model = model_factory()
        if not isinstance(model, UFPModel):
            raise TypeError("`model_factory` must return a UFPModel")
        fitter_options = _candidate_fitter_kwargs(fitter_kwargs, candidate)
        fitter = LinearFitter(model, **fitter_options)
        stage_fit_options = _fit_kwargs_for_stage(
            fit_kwargs,
            config=config,
            stage=stage,
            fitter=fitter,
            samples=train_samples,
        )
        fitter.fit(train_samples, **stage_fit_options)
        raw_metrics = prediction_metrics_for_split(
            model,
            dataset,
            validation_indices,
            batch_size=int(config.prediction_batch_size),
            device=cast(torch.device | str | None, fitter_options.get("device")),
            dtype=cast(torch.dtype | None, fitter_options.get("dtype")),
            progress=bool(config.progress),
        )
        return RegularizationTrial(
            stage=stage,
            candidate=candidate,
            training_size=n_train,
            validation_size=n_validation,
            metrics=_scalar_metrics(raw_metrics),
            status="ok",
        )
    except Exception as exc:  # pragma: no cover - exercised by behavior tests.
        return RegularizationTrial(
            stage=stage,
            candidate=candidate,
            training_size=n_train,
            validation_size=n_validation,
            status="error",
            error=f"{type(exc).__name__}: {exc}",
        )
def _metric_normalizers(
    trials: Sequence[RegularizationTrial],
    baseline: RegularizationCandidate,
) -> dict[str, float]:
    """Return per-metric baseline normalizers for one stage.

    For each of the energy and force RMSE keys, the value is taken from the
    successful trial whose candidate equals *baseline* (or, failing that,
    the first successful trial). Its magnitude is floored at ``1e-12``.

    If that value is missing or non-finite (NaN/inf from a diverged fit), the
    remaining successful trials are tried in order. This keeps one bad value
    from turning every score in the stage into NaN. A metric with no finite
    value anywhere gets no normalizer.
    """
    successful = [trial for trial in trials if trial.status == "ok"]
    if not successful:
        return {}
    baseline_trial = next(
        (trial for trial in successful if trial.candidate == baseline),
        successful[0],
    )
    # Prefer the baseline trial, then the other successful trials in order.
    reference_order = [
        baseline_trial,
        *(trial for trial in successful if trial is not baseline_trial),
    ]
    normalizers: dict[str, float] = {}
    for key in (_ENERGY_RMSE_KEY, _FORCE_RMSE_KEY):
        for trial in reference_order:
            value = trial.metrics.get(key)
            if value is not None and np.isfinite(float(value)):
                normalizers[key] = max(abs(float(value)), 1.0e-12)
                break
    return normalizers
def _score_metrics(
    metrics: Mapping[str, float],
    normalizers: Mapping[str, float],
    config: RegularizationSearchConfig,
) -> float:
    """Return the weighted mean of normalized energy/force RMSEs (lower is better).

    Each metric that is present and has a positive configured weight is divided
    by its normalizer (default ``1.0``, floored at ``1e-12``). Returns ``inf``
    when no weighted metric contributes.
    """
    weighted_keys = (
        (_ENERGY_RMSE_KEY, float(config.energy_score_weight)),
        (_FORCE_RMSE_KEY, float(config.force_score_weight)),
    )
    numerator = 0.0
    denominator = 0.0
    for key, weight in weighted_keys:
        if weight <= 0.0 or key not in metrics:
            continue
        scale = max(float(normalizers.get(key, 1.0)), 1.0e-12)
        numerator += weight * float(metrics[key]) / scale
        denominator += weight
    return numerator / denominator if denominator != 0.0 else float("inf")
def _score_stage_trials(
    trials: Sequence[RegularizationTrial],
    *,
    baseline: RegularizationCandidate,
    config: RegularizationSearchConfig,
) -> tuple[RegularizationTrial, ...]:
    """Attach stage-normalized scores to ``"ok"`` trials; pass failures through."""
    normalizers = _metric_normalizers(trials, baseline)
    return tuple(
        replace(trial, score=_score_metrics(trial.metrics, normalizers, config))
        if trial.status == "ok"
        else trial
        for trial in trials
    )
def _trial_rank_key(trial: RegularizationTrial) -> tuple[float, float, float, float]:
    """Return a deterministic ascending sort key for a trial.

    Trials are ordered by score, then force RMSE, then energy RMSE, then total
    ridge, so ties prefer weaker regularization. Missing or NaN values sort as
    ``inf``. NaN compares false against everything, so leaving it in the key
    would make ``sorted`` order undefined and could rank a diverged fit first.
    """

    def _nan_as_inf(value: object) -> float:
        number = float(value)
        return float("inf") if np.isnan(number) else number

    score = float("inf") if trial.score is None else _nan_as_inf(trial.score)
    force = _nan_as_inf(trial.metrics.get(_FORCE_RMSE_KEY, float("inf")))
    energy = _nan_as_inf(trial.metrics.get(_ENERGY_RMSE_KEY, float("inf")))
    return score, force, energy, _nan_as_inf(trial.candidate.total_ridge())
def _successful_trials(
    trials: Sequence[RegularizationTrial],
) -> tuple[RegularizationTrial, ...]:
    """Return the ``"ok"`` trials ordered best-first by ``_trial_rank_key``."""
    ok_trials = [trial for trial in trials if trial.status == "ok"]
    ok_trials.sort(key=_trial_rank_key)
    return tuple(ok_trials)
def _evaluate_stage(
    *,
    stage: str,
    candidates: Sequence[RegularizationCandidate],
    baseline: RegularizationCandidate,
    model_factory: Callable[[], UFPModel],
    dataset: SupervisedAtomsDataset,
    training_indices: np.ndarray,
    validation_indices: np.ndarray,
    sample_weights: dict[int, float] | None,
    energy_weight: float,
    force_weight: float,
    config: RegularizationSearchConfig,
    fitter_kwargs: Mapping[str, object] | None,
    fit_kwargs: Mapping[str, object] | None,
) -> tuple[RegularizationTrial, ...]:
    """Run every candidate of one progressive stage and score the results.

    The training samples are built once from *training_indices* and shared by
    all candidates. Each candidate is fitted and validated independently, and
    the trials are scored relative to *baseline*.
    """
    shared_samples = fit_samples_from_dataset(
        dataset,
        indices=training_indices,
        sample_weights=sample_weights,
        energy_weight=energy_weight,
        force_weight=force_weight,
    )
    stage_trials: list[RegularizationTrial] = []
    for candidate in candidates:
        stage_trials.append(
            _evaluate_candidate(
                stage=stage,
                candidate=candidate,
                model_factory=model_factory,
                dataset=dataset,
                train_samples=shared_samples,
                validation_indices=validation_indices,
                config=config,
                fitter_kwargs=fitter_kwargs,
                fit_kwargs=fit_kwargs,
            )
        )
    return _score_stage_trials(stage_trials, baseline=baseline, config=config)
def _best_trials_for_next_stage(
    trials: Sequence[RegularizationTrial],
    *,
    top_k: int,
) -> tuple[RegularizationTrial, ...]:
    """Return at most *top_k* successful trials, best first."""
    ranked = _successful_trials(trials)
    limit = int(top_k)
    return ranked[:limit]
def _final_fit(
    *,
    candidate: RegularizationCandidate,
    model_factory: Callable[[], UFPModel],
    dataset: SupervisedAtomsDataset,
    training_indices: np.ndarray,
    sample_weights: dict[int, float] | None,
    energy_weight: float,
    force_weight: float,
    fitter_kwargs: Mapping[str, object] | None,
    fit_kwargs: Mapping[str, object] | None,
    config: RegularizationSearchConfig,
) -> tuple[UFPModel, LinearFitResult]:
    """Refit the chosen candidate on the full tuning-training split.

    Returns the freshly built model that the ``LinearFitter`` was constructed
    around, together with the ``fit`` result. Cache options come from the
    ``"final_refit"`` stage.

    Raises:
        TypeError: if *model_factory* does not return a ``UFPModel``.
    """
    model = model_factory()
    if not isinstance(model, UFPModel):
        raise TypeError("`model_factory` must return a UFPModel")
    full_samples = fit_samples_from_dataset(
        dataset,
        indices=training_indices,
        sample_weights=sample_weights,
        energy_weight=energy_weight,
        force_weight=force_weight,
    )
    fitter = LinearFitter(model, **_candidate_fitter_kwargs(fitter_kwargs, candidate))
    refit_options = _fit_kwargs_for_stage(
        fit_kwargs,
        config=config,
        stage="final_refit",
        fitter=fitter,
        samples=full_samples,
    )
    result = fitter.fit(full_samples, **refit_options)
    return model, result
[docs]
def tune_linear_regularization(
    model_factory: Callable[[], UFPModel],
    dataset: SupervisedAtomsDataset,
    *,
    training_indices: Sequence[int] | np.ndarray | None = None,
    validation_indices: Sequence[int] | np.ndarray | None = None,
    sample_weights: dict[int, float] | None = None,
    energy_weight: float = 1.0,
    force_weight: float = 1.0,
    config: RegularizationSearchConfig | None = None,
    fitter_kwargs: Mapping[str, object] | None = None,
    fit_kwargs: Mapping[str, object] | None = None,
) -> RegularizationSearchResult:
    """Tune linear ridge weights with deterministic progressive subsets.

    Steps:
        1. Resolve train/validation indices. Validation is carved out of
           training when neither the caller nor the dataset provides it.
        2. Estimate per-group ridge scales from a representative training
           subset.
        3. For each ``config.stage_subset_sizes`` entry, fit every retained
           candidate on a stratified training subset and score it on a
           validation subset with the same size cap. The best
           ``top_k_per_stage`` trials seed a local refinement grid, capped at
           ``8 * top_k_per_stage`` candidates, for the next stage. A stage in
           which every candidate fails keeps the previous candidates and
           winners.
        4. Re-evaluate the most recent stage winners on the full training split
           and validation pool. If no stage succeeded, the initial candidates
           are used instead.
        5. Optionally (``config.refit_full``) refit the best candidate on the
           full training split.

    Raises:
        RuntimeError: if no trial succeeds.
    """
    resolved_config = RegularizationSearchConfig() if config is None else config
    # Cap on candidates carried into the next stage (seeds plus refinements).
    refinement_budget = 8 * int(resolved_config.top_k_per_stage)
    train_indices, validation_pool = _resolve_training_validation_indices(
        dataset,
        training_indices=training_indices,
        validation_indices=validation_indices,
        config=resolved_config,
    )
    estimate_training = _representative_indices(
        dataset,
        train_indices,
        max_size=resolved_config.estimate_subset_size,
        seed=int(resolved_config.seed),
    )
    estimate_samples = fit_samples_from_dataset(
        dataset,
        indices=estimate_training,
        sample_weights=sample_weights,
        energy_weight=energy_weight,
        force_weight=force_weight,
    )
    estimates = estimate_linear_ridge_scales(
        model_factory(),
        estimate_samples,
        fitter_kwargs=fitter_kwargs,
        fit_kwargs=fit_kwargs,
        subset_size=None,
        seed=int(resolved_config.seed),
        alpha=float(resolved_config.alpha),
    )
    baseline = estimates.candidate()
    active_groups = _active_groups(estimates)
    candidates = _initial_candidates(estimates, resolved_config)
    all_trials: list[RegularizationTrial] = []
    retained = candidates
    # Winners of the most recent stage that produced any successful trial.
    latest_successes: tuple[RegularizationTrial, ...] = ()
    for stage_number, subset_size in enumerate(resolved_config.stage_subset_sizes, 1):
        stage = f"subset_{int(subset_size)}"
        stage_training = _representative_indices(
            dataset,
            train_indices,
            max_size=int(subset_size),
            seed=int(resolved_config.seed) + stage_number,
        )
        stage_validation = _representative_indices(
            dataset,
            validation_pool,
            max_size=int(subset_size),
            seed=int(resolved_config.seed) + 100 + stage_number,
        )
        trials = _evaluate_stage(
            stage=stage,
            candidates=retained,
            baseline=baseline,
            model_factory=model_factory,
            dataset=dataset,
            training_indices=stage_training,
            validation_indices=stage_validation,
            sample_weights=sample_weights,
            energy_weight=energy_weight,
            force_weight=force_weight,
            config=resolved_config,
            fitter_kwargs=fitter_kwargs,
            fit_kwargs=fit_kwargs,
        )
        all_trials.extend(trials)
        stage_winners = _best_trials_for_next_stage(
            trials,
            top_k=resolved_config.top_k_per_stage,
        )
        if not stage_winners:
            # Every candidate failed on this subset. Retry the same candidates
            # at the next size, and keep the previous stage's winners rather
            # than discarding them as the fallback for final validation.
            continue
        latest_successes = stage_winners
        retained = _refinement_candidates(
            tuple(trial.candidate for trial in stage_winners),
            active_groups,
            resolved_config,
        )[:refinement_budget]
    final_candidates = (
        tuple(trial.candidate for trial in latest_successes)
        if latest_successes
        else candidates
    )
    final_trials = _evaluate_stage(
        stage="full_validation",
        candidates=final_candidates,
        baseline=baseline,
        model_factory=model_factory,
        dataset=dataset,
        training_indices=train_indices,
        validation_indices=validation_pool,
        sample_weights=sample_weights,
        energy_weight=energy_weight,
        force_weight=force_weight,
        config=resolved_config,
        fitter_kwargs=fitter_kwargs,
        fit_kwargs=fit_kwargs,
    )
    all_trials.extend(final_trials)
    best_trial_pool = (
        _successful_trials(final_trials)
        or latest_successes
        or _successful_trials(all_trials)
    )
    if not best_trial_pool:
        raise RuntimeError("regularization search did not produce a successful trial")
    best_candidate = best_trial_pool[0].candidate
    best_fitter_kwargs = _candidate_fitter_kwargs(fitter_kwargs, best_candidate)
    final_model = None
    final_fit_result = None
    if resolved_config.refit_full:
        final_model, final_fit_result = _final_fit(
            candidate=best_candidate,
            model_factory=model_factory,
            dataset=dataset,
            training_indices=train_indices,
            sample_weights=sample_weights,
            energy_weight=energy_weight,
            force_weight=force_weight,
            fitter_kwargs=fitter_kwargs,
            fit_kwargs=fit_kwargs,
            config=resolved_config,
        )
    return RegularizationSearchResult(
        estimates=estimates,
        trials=tuple(all_trials),
        best_candidate=best_candidate,
        best_fitter_kwargs=best_fitter_kwargs,
        final_model=final_model,
        final_fit_result=final_fit_result,
    )
# Names exported by ``from <module> import *``; the underscore-prefixed helpers
# in this module are not part of the public API.
__all__ = [
    "RegularizationCandidate",
    "RegularizationSearchConfig",
    "RegularizationSearchResult",
    "RegularizationTrial",
    "RidgeGroupEstimate",
    "RidgeScaleEstimate",
    "estimate_linear_ridge_scales",
    "tune_linear_regularization",
]