# Source code for ufp.workflows.regularization

"""Progressive regularization tuning for linear least-squares workflows."""

from __future__ import annotations

import itertools
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import cast

import numpy as np
import torch

from ufp.leastsquares import FitSample, LinearFitResult, LinearFitter
from ufp.terms import UFPModel
from ufp.workflows.data import SupervisedAtomsDataset
from ufp.workflows.predictions import (
    fit_samples_from_dataset,
    prediction_metrics_for_split,
)


# Canonical regularization groups used by the tuner. "ridge" maps to the
# candidate field ``ridge``; every other group maps to ``<group>_ridge``.
_RIDGE_GROUPS = ("ridge", "onebody", "twobody", "threebody")
# Metric keys looked up in validation metrics when scoring and ranking trials.
_ENERGY_RMSE_KEY = "rmse_energy_mev_per_atom"
_FORCE_RMSE_KEY = "rmse_force_mev_per_angstrom"


@dataclass(frozen=True)
class RidgeGroupEstimate:
    """Ridge suggestion for a single coefficient regularization group."""

    group: str
    n_parameters: int
    design_trace: float
    trace_per_parameter: float
    suggested_ridge: float

    def to_dict(self) -> dict[str, object]:
        """Serialize the estimate into plain JSON-compatible values."""
        payload: dict[str, object] = {"group": self.group}
        payload["n_parameters"] = int(self.n_parameters)
        # The remaining fields are all exported as plain floats.
        for name in ("design_trace", "trace_per_parameter", "suggested_ridge"):
            payload[name] = float(getattr(self, name))
        return payload
@dataclass(frozen=True)
class RidgeScaleEstimate:
    """Per-group ridge suggestions computed from least-squares design statistics."""

    groups: tuple[RidgeGroupEstimate, ...]
    alpha: float
    sample_count: int
    diagnostics: Mapping[str, float] = field(default_factory=dict)

    @property
    def by_group(self) -> dict[str, RidgeGroupEstimate]:
        """Index the group estimates by their group name."""
        lookup: dict[str, RidgeGroupEstimate] = {}
        for entry in self.groups:
            lookup[entry.group] = entry
        return lookup

    def candidate(self) -> "RegularizationCandidate":
        """Build the candidate that applies the suggested ridge values directly.

        A group without its own estimate inherits the global ``ridge``
        suggestion; when that is missing too, the value is ``0.0``.
        """
        lookup = self.by_group
        shared = lookup.get("ridge")
        default_value = 0.0 if shared is None else float(shared.suggested_ridge)

        def resolve(group: str) -> float:
            estimate = lookup.get(group, shared)
            if estimate is None:
                return default_value
            return float(estimate.suggested_ridge)

        return RegularizationCandidate(
            ridge=default_value,
            onebody_ridge=resolve("onebody"),
            twobody_ridge=resolve("twobody"),
            threebody_ridge=resolve("threebody"),
        )

    def to_dict(self) -> dict[str, object]:
        """Serialize the estimate metadata into JSON-compatible values."""
        payload: dict[str, object] = {}
        payload["alpha"] = float(self.alpha)
        payload["sample_count"] = int(self.sample_count)
        payload["diagnostics"] = {
            str(name): float(number) for name, number in self.diagnostics.items()
        }
        payload["groups"] = [entry.to_dict() for entry in self.groups]
        return payload
@dataclass(frozen=True)
class RegularizationCandidate:
    """Concrete ridge settings passed to ``LinearFitter``.

    Every field is coerced to ``float`` on construction and must be a
    non-negative number; NaN is rejected.
    """

    ridge: float = 0.0
    onebody_ridge: float = 0.0
    twobody_ridge: float = 0.0
    threebody_ridge: float = 0.0

    def __post_init__(self) -> None:
        """Validate candidate weights and normalize them to ``float``.

        Raises:
            ValueError: If any ridge value is negative or NaN.
        """
        for name in _RIDGE_GROUPS:
            attribute = name if name == "ridge" else f"{name}_ridge"
            value = float(getattr(self, attribute))
            # NaN compares False against every number, so ``value < 0.0``
            # alone would accept it and hand it to the fitter unchanged.
            if np.isnan(value):
                raise ValueError(f"`{name}` ridge value must not be NaN")
            if value < 0.0:
                raise ValueError(f"`{name}` ridge value must be non-negative")
            # The dataclass is frozen, so bypass its ``__setattr__`` guard.
            object.__setattr__(self, attribute, value)

    def as_fitter_kwargs(self) -> dict[str, float]:
        """Return ridge keyword arguments accepted by ``LinearFitter``."""
        return {
            "ridge": float(self.ridge),
            "onebody_ridge": float(self.onebody_ridge),
            "twobody_ridge": float(self.twobody_ridge),
            "threebody_ridge": float(self.threebody_ridge),
        }

    def group_value(self, group: str) -> float:
        """Return the ridge value for one canonical group.

        Raises:
            ValueError: If ``group`` is not one of the canonical groups.
        """
        values = self.as_fitter_kwargs()
        # Same group -> field mapping that ``with_group_value`` uses.
        key = "ridge" if group == "ridge" else f"{group}_ridge"
        if key not in values:
            raise ValueError(f"unsupported regularization group {group!r}")
        return values[key]

    def with_group_value(self, group: str, value: float) -> "RegularizationCandidate":
        """Return a copy of this candidate with one group's ridge replaced.

        Raises:
            ValueError: If ``group`` is not one of the canonical groups, or
                if ``value`` fails the constructor's validation.
        """
        kwargs = self.as_fitter_kwargs()
        key = "ridge" if group == "ridge" else f"{group}_ridge"
        if key not in kwargs:
            raise ValueError(f"unsupported regularization group {group!r}")
        kwargs[key] = float(value)
        return RegularizationCandidate(**kwargs)

    def total_ridge(self) -> float:
        """Return the sum of all ridge values, used for deterministic tie-breaking."""
        return (
            float(self.ridge)
            + float(self.onebody_ridge)
            + float(self.twobody_ridge)
            + float(self.threebody_ridge)
        )

    def to_dict(self) -> dict[str, float]:
        """Return a JSON-friendly representation (same keys as the fitter kwargs)."""
        return self.as_fitter_kwargs()
@dataclass(frozen=True)
class RegularizationSearchConfig:
    """Settings that control the progressive regularization search."""

    seed: int = 0
    alpha: float = 1.0e-6
    estimate_subset_size: int | None = 64
    stage_subset_sizes: tuple[int, ...] = (64, 256)
    candidate_multipliers: tuple[float, ...] = (
        1.0e-3,
        1.0e-2,
        1.0e-1,
        1.0,
        10.0,
        100.0,
        1.0e3,
    )
    refinement_multipliers: tuple[float, ...] = (1.0e-1, 1.0, 10.0)
    top_k_per_stage: int = 5
    validation_fraction: float = 0.2
    minimum_validation_size: int = 1
    energy_score_weight: float = 1.0
    force_score_weight: float = 1.0
    refit_full: bool = True
    cache_directory: Path | str | None = None
    cache_mode: str = "auto"
    dense_cache_parameter_limit: int = 20_000
    prediction_batch_size: int = 64
    progress: bool = False

    def __post_init__(self) -> None:
        """Coerce sequence options to tuples, then check every setting in turn."""
        # Frozen dataclass: normalized values go through ``object.__setattr__``.
        for option, convert in (
            ("stage_subset_sizes", int),
            ("candidate_multipliers", float),
            ("refinement_multipliers", float),
        ):
            normalized = tuple(convert(item) for item in getattr(self, option))
            object.__setattr__(self, option, normalized)

        def reject(invalid: bool, message: str) -> None:
            if invalid:
                raise ValueError(message)

        # Checks run strictly in this order; the first failure wins.
        reject(self.alpha < 0.0, "`alpha` must be non-negative")
        reject(
            self.estimate_subset_size is not None and self.estimate_subset_size <= 0,
            "`estimate_subset_size` must be positive",
        )
        reject(
            any(size <= 0 for size in self.stage_subset_sizes),
            "`stage_subset_sizes` entries must be positive",
        )
        reject(
            any(value < 0.0 for value in self.candidate_multipliers),
            "`candidate_multipliers` entries must be non-negative",
        )
        reject(
            any(value < 0.0 for value in self.refinement_multipliers),
            "`refinement_multipliers` entries must be non-negative",
        )
        reject(self.top_k_per_stage <= 0, "`top_k_per_stage` must be positive")
        reject(
            not 0.0 < self.validation_fraction < 1.0,
            "`validation_fraction` must be between 0 and 1",
        )
        reject(
            self.minimum_validation_size <= 0,
            "`minimum_validation_size` must be positive",
        )
        reject(
            self.energy_score_weight < 0.0 or self.force_score_weight < 0.0,
            "score weights must be non-negative",
        )
        reject(
            self.energy_score_weight == 0.0 and self.force_score_weight == 0.0,
            "at least one score weight must be positive",
        )
        reject(
            self.dense_cache_parameter_limit < 0,
            "`dense_cache_parameter_limit` must be non-negative",
        )
        reject(
            self.prediction_batch_size <= 0,
            "`prediction_batch_size` must be positive",
        )

    def to_dict(self) -> dict[str, object]:
        """Serialize the configuration into JSON-compatible values."""
        payload: dict[str, object] = {}
        payload["seed"] = int(self.seed)
        payload["alpha"] = float(self.alpha)
        payload["estimate_subset_size"] = self.estimate_subset_size
        payload["stage_subset_sizes"] = list(self.stage_subset_sizes)
        payload["candidate_multipliers"] = list(self.candidate_multipliers)
        payload["refinement_multipliers"] = list(self.refinement_multipliers)
        payload["top_k_per_stage"] = int(self.top_k_per_stage)
        payload["validation_fraction"] = float(self.validation_fraction)
        payload["minimum_validation_size"] = int(self.minimum_validation_size)
        payload["energy_score_weight"] = float(self.energy_score_weight)
        payload["force_score_weight"] = float(self.force_score_weight)
        payload["refit_full"] = bool(self.refit_full)
        if self.cache_directory is None:
            payload["cache_directory"] = None
        else:
            payload["cache_directory"] = str(self.cache_directory)
        payload["cache_mode"] = str(self.cache_mode)
        payload["dense_cache_parameter_limit"] = int(self.dense_cache_parameter_limit)
        payload["prediction_batch_size"] = int(self.prediction_batch_size)
        payload["progress"] = bool(self.progress)
        return payload
@dataclass(frozen=True)
class RegularizationTrial:
    """Record of fitting and validating a single candidate."""

    stage: str
    candidate: RegularizationCandidate
    training_size: int
    validation_size: int
    metrics: Mapping[str, float] = field(default_factory=dict)
    score: float | None = None
    status: str = "ok"
    error: str | None = None

    def to_dict(self) -> dict[str, object]:
        """Serialize the trial into JSON-compatible values."""
        payload: dict[str, object] = {
            "stage": self.stage,
            "candidate": self.candidate.to_dict(),
        }
        payload["training_size"] = int(self.training_size)
        payload["validation_size"] = int(self.validation_size)
        payload["metrics"] = {
            str(name): float(number) for name, number in self.metrics.items()
        }
        # A trial that has not been scored keeps ``None`` here.
        if self.score is None:
            payload["score"] = None
        else:
            payload["score"] = float(self.score)
        payload["status"] = self.status
        payload["error"] = self.error
        return payload
@dataclass(frozen=True)
class RegularizationSearchResult:
    """Outcome of one progressive regularization search."""

    estimates: RidgeScaleEstimate
    trials: tuple[RegularizationTrial, ...]
    best_candidate: RegularizationCandidate
    best_fitter_kwargs: Mapping[str, object]
    final_model: UFPModel | None = None
    final_fit_result: LinearFitResult | None = None

    @property
    def metadata(self) -> dict[str, object]:
        """Serialize the search outcome into JSON-compatible values."""
        summary: dict[str, object] = {}
        summary["estimates"] = self.estimates.to_dict()
        summary["trials"] = [entry.to_dict() for entry in self.trials]
        summary["best_candidate"] = self.best_candidate.to_dict()
        summary["best_fitter_kwargs"] = _json_friendly_mapping(self.best_fitter_kwargs)

        fit = self.final_fit_result
        if fit is None:
            # No final refit was recorded.
            summary["final_fit"] = None
        else:
            summary["final_fit"] = {
                "solver": fit.solver,
                "objective": float(fit.objective),
                "residual_norm": float(fit.residual_norm),
                "n_rows": int(fit.n_rows),
                "n_parameters": int(fit.n_parameters),
            }
        return summary
def _json_friendly_mapping(values: Mapping[str, object]) -> dict[str, object]:
    """Convert a mapping into a shallow dict with JSON-compatible values."""
    stringified = (torch.dtype, torch.device, Path)
    passthrough = (str, int, float, bool)
    result: dict[str, object] = {}
    for name, item in values.items():
        if isinstance(item, stringified):
            converted: object = str(item)
        elif item is None or isinstance(item, passthrough):
            converted = item
        else:
            # Anything else is recorded by its ``repr``.
            converted = repr(item)
        result[str(name)] = converted
    return result


def _canonical_group(group: str) -> str:
    """Map a parameter-layout regularization group onto a tuner group."""
    aliases = {
        "pair": "twobody",
        "twobody": "twobody",
        "onebody": "onebody",
        "threebody": "threebody",
    }
    # Unrecognized groups fall under the global ridge.
    return aliases.get(group, "ridge")


def _subset_samples(
    samples: Sequence[FitSample],
    *,
    subset_size: int | None,
    seed: int,
) -> tuple[FitSample, ...]:
    """Pick a seeded, order-preserving subset of the samples."""
    pool = tuple(samples)
    if subset_size is None:
        return pool
    wanted = int(subset_size)
    if wanted >= len(pool):
        return pool
    generator = np.random.default_rng(int(seed))
    drawn = generator.choice(len(pool), size=wanted, replace=False)
    # Sorting the drawn positions keeps the original sample order.
    return tuple(pool[position] for position in sorted(drawn.tolist()))


def _build_problem_kwargs(fit_kwargs: Mapping[str, object] | None) -> dict[str, object]:
    """Keep only the fit options forwarded to ``LinearFitter.build_problem``."""
    if fit_kwargs is None:
        return {}
    forwarded = {"batch_size", "progress", "cache_directory", "cache_mode"}
    return {name: option for name, option in dict(fit_kwargs).items() if name in forwarded}


def _numpy_pairs(pairs: object) -> np.ndarray:
    """Convert neighbor-list pairs to a CPU numpy array."""
    if not isinstance(pairs, torch.Tensor):
        return np.asarray(pairs)
    return pairs.detach().cpu().numpy()


def _sample_interaction_diagnostics(
    samples: Sequence[FitSample],
) -> dict[str, float]:
    """Count atoms and, where neighbor lists exist, pairs and centered triplets."""
    total_atoms = sum(len(sample.atoms) for sample in samples)
    report = {"atom_count": float(total_atoms)}

    n_with_lists = 0
    n_pairs = 0
    n_triplets = 0
    for sample in samples:
        neighbors = sample.neighbor_list
        if neighbors is None:
            continue
        n_with_lists += 1
        n_pairs += int(neighbors.n_pairs)
        pair_array = _numpy_pairs(neighbors.pairs)
        if pair_array.size == 0:
            continue
        # The first row of the pair array is used as the center-atom index.
        center_atoms = np.asarray(pair_array[0], dtype=int)
        neighbor_counts = np.bincount(center_atoms, minlength=len(sample.atoms))
        ordered_pairs = neighbor_counts * np.maximum(neighbor_counts - 1, 0)
        n_triplets += int(np.sum(ordered_pairs))

    if n_with_lists:
        report["samples_with_neighbor_lists"] = float(n_with_lists)
        report["explicit_neighbor_pairs"] = float(n_pairs)
        report["estimated_centered_triplets"] = float(n_triplets)
    return report
def estimate_linear_ridge_scales(
    model: UFPModel,
    samples: Sequence[FitSample],
    *,
    fitter_kwargs: Mapping[str, object] | None = None,
    fit_kwargs: Mapping[str, object] | None = None,
    subset_size: int | None = None,
    seed: int = 0,
    alpha: float = 1.0e-6,
) -> RidgeScaleEstimate:
    """
    Estimate block-scale ridge weights from weighted least-squares design traces.

    The suggested group weight is ``alpha * trace(A_g.T @ A_g) / n_params_g``.
    The ``"ridge"`` group aggregates every solve block exactly once; the
    ``onebody``/``twobody``/``threebody`` groups aggregate only the blocks
    whose layout regularization group maps onto them.

    Args:
        model: Model handed to ``LinearFitter``.
        samples: Fit samples the design statistics are computed from.
        fitter_kwargs: Extra keyword arguments for the ``LinearFitter``
            constructor.
        fit_kwargs: Fit options; only those accepted by
            ``LinearFitter.build_problem`` are forwarded.
        subset_size: Optional cap on the number of samples used.
        seed: Seed for the sample subset selection.
        alpha: Non-negative scale applied to the per-parameter trace.

    Returns:
        Estimates for every group that owns at least one parameter.

    Raises:
        ValueError: If ``alpha`` is negative or no samples remain.
    """
    if alpha < 0.0:
        raise ValueError("`alpha` must be non-negative")
    subset = _subset_samples(samples, subset_size=subset_size, seed=seed)
    if not subset:
        raise ValueError("`samples` must contain at least one FitSample")

    fitter = LinearFitter(model, **dict({} if fitter_kwargs is None else fitter_kwargs))
    problem = fitter.build_problem(subset, **_build_problem_kwargs(fit_kwargs))
    block_traces = problem.design_trace_by_block()

    # Per group: [parameter count, design trace].
    aggregates: dict[str, list[float]] = {group: [0.0, 0.0] for group in _RIDGE_GROUPS}
    for solve_block in problem.layout.blocks:
        trace = float(block_traces.get(solve_block.key, 0.0))
        size = int(solve_block.size)
        # The global ridge group spans every block.
        aggregates["ridge"][0] += size
        aggregates["ridge"][1] += trace
        layout_block = fitter.layout.block(int(solve_block.key))
        group = _canonical_group(layout_block.regularization_group)
        # Blocks without a dedicated group canonicalize to "ridge" and were
        # already counted above; adding them again would double-count their
        # size and trace in the global estimate.
        if group != "ridge":
            aggregates[group][0] += size
            aggregates[group][1] += trace

    estimates = []
    for group in _RIDGE_GROUPS:
        n_parameters = int(aggregates[group][0])
        if n_parameters <= 0:
            # Groups that own no parameters get no estimate.
            continue
        design_trace = float(aggregates[group][1])
        trace_per_parameter = design_trace / float(n_parameters)
        estimates.append(
            RidgeGroupEstimate(
                group=group,
                n_parameters=n_parameters,
                design_trace=design_trace,
                trace_per_parameter=trace_per_parameter,
                suggested_ridge=float(alpha) * trace_per_parameter,
            )
        )
    return RidgeScaleEstimate(
        groups=tuple(estimates),
        alpha=float(alpha),
        sample_count=len(subset),
        diagnostics=_sample_interaction_diagnostics(subset),
    )
def _regularization_values(
    seed: float,
    multipliers: Sequence[float],
) -> tuple[float, ...]:
    """Return candidate ridge values for a seed scale.

    ``0.0`` is always the first value. When ``seed`` is positive it is
    followed by ``seed * multiplier`` for every multiplier, with duplicates
    removed in first-seen order.
    """
    values = [0.0]
    if seed > 0.0:
        values.extend(float(seed) * float(multiplier) for multiplier in multipliers)
    # dict.fromkeys de-duplicates while preserving insertion order.
    return tuple(dict.fromkeys(float(value) for value in values))


def _active_groups(estimates: RidgeScaleEstimate) -> tuple[str, ...]:
    """Return groups with active parameter blocks, in ``_RIDGE_GROUPS`` order."""
    by_group = estimates.by_group
    return tuple(
        group
        for group in _RIDGE_GROUPS
        if group in by_group and int(by_group[group].n_parameters) > 0
    )


def _dedupe_candidates(
    candidates: Sequence[RegularizationCandidate],
) -> tuple[RegularizationCandidate, ...]:
    """Remove duplicate candidates while preserving order."""
    seen: set[tuple[float, float, float, float]] = set()
    unique = []
    for candidate in candidates:
        key = (
            float(candidate.ridge),
            float(candidate.onebody_ridge),
            float(candidate.twobody_ridge),
            float(candidate.threebody_ridge),
        )
        if key in seen:
            continue
        seen.add(key)
        unique.append(candidate)
    return tuple(unique)


def _initial_candidates(
    estimates: RidgeScaleEstimate,
    config: RegularizationSearchConfig,
) -> tuple[RegularizationCandidate, ...]:
    """Generate baseline and group-wise sweep candidates.

    The baseline comes first; each active group is then swept on its own
    while the other groups stay at their baseline values.
    """
    baseline = estimates.candidate()
    by_group = estimates.by_group
    candidates: list[RegularizationCandidate] = [baseline]
    for group in _active_groups(estimates):
        estimate = by_group[group]
        values = _regularization_values(
            estimate.suggested_ridge,
            config.candidate_multipliers,
        )
        for value in values:
            candidates.append(baseline.with_group_value(group, value))
    return _dedupe_candidates(candidates)


def _refinement_candidates(
    seeds: Sequence[RegularizationCandidate],
    groups: Sequence[str],
    config: RegularizationSearchConfig,
) -> tuple[RegularizationCandidate, ...]:
    """Generate local Cartesian refinement candidates around successful seeds.

    The seeds themselves come first, followed by every combination of
    per-group refinement values around each seed, de-duplicated in order.
    """
    candidates: list[RegularizationCandidate] = list(seeds)
    active_groups = tuple(groups)
    if not active_groups:
        return _dedupe_candidates(candidates)
    for seed in seeds:
        values_by_group = []
        for group in active_groups:
            value = seed.group_value(group)
            values_by_group.append(
                _regularization_values(value, config.refinement_multipliers)
            )
        for values in itertools.product(*values_by_group):
            candidate = seed
            for group, value in zip(active_groups, values, strict=True):
                candidate = candidate.with_group_value(group, value)
            candidates.append(candidate)
    return _dedupe_candidates(candidates)


def _as_index_array(indices: Sequence[int] | np.ndarray) -> np.ndarray:
    """Return a one-dimensional integer index array.

    Raises:
        ValueError: If ``indices`` is empty.
    """
    values = np.asarray(indices, dtype=int).reshape(-1)
    if values.size == 0:
        raise ValueError("index arrays must not be empty")
    return values


def _representative_indices(
    dataset: SupervisedAtomsDataset,
    indices: Sequence[int] | np.ndarray,
    *,
    max_size: int | None,
    seed: int,
) -> np.ndarray:
    """Choose deterministic atom-count/energy-stratified indices.

    All indices are returned (sorted) when ``max_size`` is ``None`` or not
    smaller than the pool. Datasets without ``sizes`` and ``energies``
    attributes fall back to a seeded random draw without replacement.
    """
    values = _as_index_array(indices)
    if max_size is None or int(max_size) >= int(values.size):
        return np.sort(values)
    requested = int(max_size)
    rng = np.random.default_rng(int(seed))
    if not hasattr(dataset, "sizes") or not hasattr(dataset, "energies"):
        return np.sort(rng.choice(values, size=requested, replace=False))

    sizes = np.asarray(dataset.sizes, dtype=float)[values]
    energies = np.asarray(dataset.energies, dtype=float)[values]
    # Clamp sizes to at least 1 so the division below never divides by zero.
    safe_sizes = np.maximum(sizes, 1.0)
    energy_per_atom = energies / safe_sizes
    n_bins = max(1, min(6, int(np.sqrt(requested))))
    size_bins = _quantile_bins(sizes, n_bins)
    energy_bins = _quantile_bins(energy_per_atom, n_bins)

    buckets: dict[tuple[int, int], list[int]] = {}
    for position, index in enumerate(values):
        bucket = (int(size_bins[position]), int(energy_bins[position]))
        buckets.setdefault(bucket, []).append(int(index))
    # Sorted bucket order plus a seeded shuffle keeps the draw reproducible.
    bucket_items = sorted(buckets.items())
    for _, bucket_values in bucket_items:
        rng.shuffle(bucket_values)

    # Round-robin over the buckets, dropping each bucket once it is empty.
    selected: list[int] = []
    while len(selected) < requested and bucket_items:
        next_items = []
        for bucket, bucket_values in bucket_items:
            if bucket_values and len(selected) < requested:
                selected.append(bucket_values.pop())
            if bucket_values:
                next_items.append((bucket, bucket_values))
        bucket_items = next_items
    return np.sort(np.asarray(selected[:requested], dtype=int))


def _quantile_bins(values: np.ndarray, n_bins: int) -> np.ndarray:
    """Return stable quantile-bin labels for a one-dimensional array.

    Empty input yields an empty array; a single bin or constant input puts
    every value in bin ``0``.
    """
    if values.size == 0:
        return np.zeros(0, dtype=int)
    if n_bins <= 1 or np.all(values == values[0]):
        return np.zeros(values.shape, dtype=int)
    # Interior quantiles only: the 0 and 1 quantiles are not bin boundaries.
    quantiles = np.linspace(0.0, 1.0, int(n_bins) + 1)[1:-1]
    boundaries = np.quantile(values, quantiles)
    return np.searchsorted(boundaries, values, side="right").astype(int)


def _resolve_training_validation_indices(
    dataset: SupervisedAtomsDataset,
    *,
    training_indices: Sequence[int] | np.ndarray | None,
    validation_indices: Sequence[int] | np.ndarray | None,
    config: RegularizationSearchConfig,
) -> tuple[np.ndarray, np.ndarray]:
    """Resolve tuning train/validation indices, carving validation if needed.

    Precedence for the validation split: explicit ``validation_indices``,
    then the dataset's own validation indices, then a representative subset
    carved out of the training indices (and removed from them).
    """
    training = _as_index_array(
        dataset.training_indices if training_indices is None else training_indices
    )
    if validation_indices is not None:
        return np.sort(training), np.sort(_as_index_array(validation_indices))
    dataset_validation = np.asarray(dataset.validation_indices, dtype=int).reshape(-1)
    if dataset_validation.size:
        return np.sort(training), np.sort(dataset_validation)
    if training.size == 1:
        # Nothing can be carved out of a single sample; it serves as both splits.
        return np.sort(training), np.sort(training)
    validation_size = max(
        int(config.minimum_validation_size),
        int(round(float(training.size) * float(config.validation_fraction))),
    )
    # Always leave at least one sample for training.
    validation_size = min(validation_size, int(training.size) - 1)
    validation = _representative_indices(
        dataset,
        training,
        max_size=validation_size,
        seed=int(config.seed) + 17,
    )
    validation_set = {int(index) for index in validation.tolist()}
    retained_training = np.asarray(
        [int(index) for index in training.tolist() if int(index) not in validation_set],
        dtype=int,
    )
    return np.sort(retained_training), np.sort(validation)


def _candidate_fitter_kwargs(
    base: Mapping[str, object] | None,
    candidate: RegularizationCandidate,
) -> dict[str, object]:
    """Merge caller fitter kwargs with candidate ridge settings.

    The candidate's ridge values override same-named entries in ``base``.
    """
    kwargs = {} if base is None else dict(base)
    kwargs.update(candidate.as_fitter_kwargs())
    return kwargs


def _target_weights_are_uniform(samples: Sequence[FitSample]) -> bool:
    """Return whether normal-equation caching can reuse split target weights."""
    for weight_name in ("energy_weight", "force_weight", "per_atom_weight"):
        weights = [float(getattr(sample, weight_name)) for sample in samples]
        if weights and any(not np.isclose(weight, weights[0]) for weight in weights):
            return False
    return True


def _fit_kwargs_for_stage(
    base: Mapping[str, object] | None,
    *,
    config: RegularizationSearchConfig,
    stage: str,
    fitter: LinearFitter,
    samples: Sequence[FitSample],
) -> dict[str, object]:
    """Return candidate-fit kwargs, adding safe cache reuse when configured.

    Caller-supplied values win: every addition uses ``setdefault``.
    """
    kwargs = {} if base is None else dict(base)
    # Checkpoint/resume options are always stripped from candidate fits.
    kwargs.pop("cg_checkpoint_path", None)
    kwargs.pop("cg_resume", None)
    if config.progress:
        kwargs.setdefault("progress", True)
    if config.cache_directory is None:
        return kwargs
    cache_directory = Path(config.cache_directory) / stage
    if (
        fitter.solver == "normal_equation_direct"
        and fitter._selected_size() <= int(config.dense_cache_parameter_limit)
        and _target_weights_are_uniform(samples)
    ):
        kwargs.setdefault("normal_equation_cache", True)
        kwargs.setdefault(
            "normal_equation_cache_directory",
            cache_directory / "normal_equations",
        )
        kwargs.setdefault("normal_equation_cache_mode", config.cache_mode)
    else:
        kwargs.setdefault("cache_directory", cache_directory / "assembled")
        kwargs.setdefault("cache_mode", config.cache_mode)
    return kwargs


def _scalar_metrics(metrics: Mapping[str, object]) -> dict[str, float]:
    """Keep only scalar validation metrics, converted to ``float``."""
    scalars: dict[str, float] = {}
    for key, value in metrics.items():
        array = np.asarray(value)
        if array.ndim == 0:
            scalars[str(key)] = float(array.item())
    return scalars


def _evaluate_candidate(
    *,
    stage: str,
    candidate: RegularizationCandidate,
    model_factory: Callable[[], UFPModel],
    dataset: SupervisedAtomsDataset,
    train_samples: Sequence[FitSample],
    validation_indices: np.ndarray,
    config: RegularizationSearchConfig,
    fitter_kwargs: Mapping[str, object] | None,
    fit_kwargs: Mapping[str, object] | None,
) -> RegularizationTrial:
    """Fit and validate one regularization candidate.

    Any exception raised while building, fitting or validating is captured
    and returned as a trial with ``status="error"``, so one failing
    candidate does not abort the whole stage.
    """
    try:
        model = model_factory()
        if not isinstance(model, UFPModel):
            raise TypeError("`model_factory` must return a UFPModel")
        merged_fitter_kwargs = _candidate_fitter_kwargs(fitter_kwargs, candidate)
        fitter = LinearFitter(model, **merged_fitter_kwargs)
        merged_fit_kwargs = _fit_kwargs_for_stage(
            fit_kwargs,
            config=config,
            stage=stage,
            fitter=fitter,
            samples=train_samples,
        )
        fitter.fit(train_samples, **merged_fit_kwargs)
        device = cast(torch.device | str | None, merged_fitter_kwargs.get("device"))
        dtype = cast(torch.dtype | None, merged_fitter_kwargs.get("dtype"))
        metrics = prediction_metrics_for_split(
            model,
            dataset,
            validation_indices,
            batch_size=int(config.prediction_batch_size),
            device=device,
            dtype=dtype,
            progress=bool(config.progress),
        )
        return RegularizationTrial(
            stage=stage,
            candidate=candidate,
            training_size=len(train_samples),
            validation_size=int(validation_indices.size),
            metrics=_scalar_metrics(metrics),
            status="ok",
        )
    except Exception as exc:  # pragma: no cover - exercised by behavior tests.
        # Deliberate best-effort: record the failure instead of propagating it.
        return RegularizationTrial(
            stage=stage,
            candidate=candidate,
            training_size=len(train_samples),
            validation_size=int(validation_indices.size),
            status="error",
            error=f"{type(exc).__name__}: {exc}",
        )


def _metric_normalizers(
    trials: Sequence[RegularizationTrial],
    baseline: RegularizationCandidate,
) -> dict[str, float]:
    """Return per-metric baseline normalizers for one stage.

    The baseline candidate's successful trial is used when present,
    otherwise the first successful trial. Returns an empty dict when no
    trial succeeded.
    """
    successful = [trial for trial in trials if trial.status == "ok"]
    if not successful:
        return {}
    baseline_trial = next(
        (trial for trial in successful if trial.candidate == baseline),
        successful[0],
    )
    normalizers = {}
    for key in (_ENERGY_RMSE_KEY, _FORCE_RMSE_KEY):
        value = baseline_trial.metrics.get(key)
        if value is not None:
            # Floor the normalizer so a zero baseline metric cannot divide by zero.
            normalizers[key] = max(abs(float(value)), 1.0e-12)
    return normalizers


def _score_metrics(
    metrics: Mapping[str, float],
    normalizers: Mapping[str, float],
    config: RegularizationSearchConfig,
) -> float:
    """Return normalized scalar score for validation metrics.

    The score is the weighted mean of ``metric / normalizer`` over the
    energy and force RMSE keys that are present and have a positive weight;
    ``inf`` when none qualifies.
    """
    weighted = 0.0
    total_weight = 0.0
    metric_weights = {
        _ENERGY_RMSE_KEY: float(config.energy_score_weight),
        _FORCE_RMSE_KEY: float(config.force_score_weight),
    }
    for key, weight in metric_weights.items():
        if weight <= 0.0 or key not in metrics:
            continue
        normalizer = max(float(normalizers.get(key, 1.0)), 1.0e-12)
        weighted += weight * float(metrics[key]) / normalizer
        total_weight += weight
    if total_weight == 0.0:
        return float("inf")
    return weighted / total_weight


def _score_stage_trials(
    trials: Sequence[RegularizationTrial],
    *,
    baseline: RegularizationCandidate,
    config: RegularizationSearchConfig,
) -> tuple[RegularizationTrial, ...]:
    """Attach normalized stage scores to successful trials.

    Failed trials are passed through unchanged and keep their position.
    """
    normalizers = _metric_normalizers(trials, baseline)
    scored = []
    for trial in trials:
        if trial.status != "ok":
            scored.append(trial)
            continue
        scored.append(
            replace(
                trial,
                score=_score_metrics(trial.metrics, normalizers, config),
            )
        )
    return tuple(scored)


def _nan_as_inf(value: float) -> float:
    """Return ``value`` as a float, mapping NaN to ``+inf`` so it sorts last."""
    number = float(value)
    return float("inf") if np.isnan(number) else number


def _trial_rank_key(trial: RegularizationTrial) -> tuple[float, float, float, float]:
    """Return deterministic candidate ranking key.

    The key is ``(score, force RMSE, energy RMSE, total ridge)``. Missing
    values rank as ``inf``. NaN values are mapped to ``inf`` as well: NaN
    compares False against everything, so leaving it in the key would make
    the sort order inconsistent and could rank a NaN-scored trial first.
    """
    score = float("inf") if trial.score is None else _nan_as_inf(trial.score)
    force = _nan_as_inf(trial.metrics.get(_FORCE_RMSE_KEY, float("inf")))
    energy = _nan_as_inf(trial.metrics.get(_ENERGY_RMSE_KEY, float("inf")))
    return score, force, energy, _nan_as_inf(trial.candidate.total_ridge())


def _successful_trials(
    trials: Sequence[RegularizationTrial],
) -> tuple[RegularizationTrial, ...]:
    """Return successful trials sorted best-first by ``_trial_rank_key``."""
    return tuple(
        sorted(
            (trial for trial in trials if trial.status == "ok"),
            key=_trial_rank_key,
        )
    )


def _evaluate_stage(
    *,
    stage: str,
    candidates: Sequence[RegularizationCandidate],
    baseline: RegularizationCandidate,
    model_factory: Callable[[], UFPModel],
    dataset: SupervisedAtomsDataset,
    training_indices: np.ndarray,
    validation_indices: np.ndarray,
    sample_weights: dict[int, float] | None,
    energy_weight: float,
    force_weight: float,
    config: RegularizationSearchConfig,
    fitter_kwargs: Mapping[str, object] | None,
    fit_kwargs: Mapping[str, object] | None,
) -> tuple[RegularizationTrial, ...]:
    """Evaluate one progressive tuning stage.

    The training samples are built once and shared by every candidate; the
    returned trials are in candidate order with scores attached.
    """
    train_samples = fit_samples_from_dataset(
        dataset,
        indices=training_indices,
        sample_weights=sample_weights,
        energy_weight=energy_weight,
        force_weight=force_weight,
    )
    trials = [
        _evaluate_candidate(
            stage=stage,
            candidate=candidate,
            model_factory=model_factory,
            dataset=dataset,
            train_samples=train_samples,
            validation_indices=validation_indices,
            config=config,
            fitter_kwargs=fitter_kwargs,
            fit_kwargs=fit_kwargs,
        )
        for candidate in candidates
    ]
    return _score_stage_trials(trials, baseline=baseline, config=config)


def _best_trials_for_next_stage(
    trials: Sequence[RegularizationTrial],
    *,
    top_k: int,
) -> tuple[RegularizationTrial, ...]:
    """Return the ``top_k`` best successful trials retained for the next stage."""
    return _successful_trials(trials)[: int(top_k)]


def _final_fit(
    *,
    candidate: RegularizationCandidate,
    model_factory: Callable[[], UFPModel],
    dataset: SupervisedAtomsDataset,
    training_indices: np.ndarray,
    sample_weights: dict[int, float] | None,
    energy_weight: float,
    force_weight: float,
    fitter_kwargs: Mapping[str, object] | None,
    fit_kwargs: Mapping[str, object] | None,
    config: RegularizationSearchConfig,
) -> tuple[UFPModel, LinearFitResult]:
    """Fit the selected candidate on the full tuning-training split.

    Unlike candidate evaluation, errors raised here propagate to the caller.

    Raises:
        TypeError: If ``model_factory`` does not return a ``UFPModel``.
    """
    model = model_factory()
    if not isinstance(model, UFPModel):
        raise TypeError("`model_factory` must return a UFPModel")
    samples = fit_samples_from_dataset(
        dataset,
        indices=training_indices,
        sample_weights=sample_weights,
        energy_weight=energy_weight,
        force_weight=force_weight,
    )
    merged_fitter_kwargs = _candidate_fitter_kwargs(fitter_kwargs, candidate)
    fitter = LinearFitter(model, **merged_fitter_kwargs)
    merged_fit_kwargs = _fit_kwargs_for_stage(
        fit_kwargs,
        config=config,
        stage="final_refit",
        fitter=fitter,
        samples=samples,
    )
    return model, fitter.fit(samples, **merged_fit_kwargs)
def tune_linear_regularization(
    model_factory: Callable[[], UFPModel],
    dataset: SupervisedAtomsDataset,
    *,
    training_indices: Sequence[int] | np.ndarray | None = None,
    validation_indices: Sequence[int] | np.ndarray | None = None,
    sample_weights: dict[int, float] | None = None,
    energy_weight: float = 1.0,
    force_weight: float = 1.0,
    config: RegularizationSearchConfig | None = None,
    fitter_kwargs: Mapping[str, object] | None = None,
    fit_kwargs: Mapping[str, object] | None = None,
) -> RegularizationSearchResult:
    """Tune linear ridge weights with deterministic progressive subsets.

    The search proceeds in four steps:

    1. Estimate data-scale ridge values on a representative training subset.
    2. For each configured subset size, evaluate the current candidates and
       refine around the best ``top_k_per_stage`` of them.
    3. Re-evaluate the winners of the most recent successful stage on the
       full training and validation splits.
    4. Optionally refit the best candidate on the full training split.

    Args:
        model_factory: Zero-argument callable returning a fresh ``UFPModel``.
        dataset: Dataset supplying samples and default index splits.
        training_indices: Optional override for the training indices.
        validation_indices: Optional override for the validation indices.
        sample_weights: Optional per-sample weights forwarded to sample
            construction.
        energy_weight: Energy target weight forwarded to sample construction.
        force_weight: Force target weight forwarded to sample construction.
        config: Search settings; defaults to ``RegularizationSearchConfig()``.
        fitter_kwargs: Extra ``LinearFitter`` constructor arguments; the
            candidate ridge values override same-named entries.
        fit_kwargs: Extra fit keyword arguments.

    Returns:
        The estimates, every trial that was run, the best candidate and,
        when ``config.refit_full`` is set, the refitted model and fit result.

    Raises:
        RuntimeError: If no trial in any stage succeeded.
    """
    resolved_config = RegularizationSearchConfig() if config is None else config
    train_indices, validation_pool = _resolve_training_validation_indices(
        dataset,
        training_indices=training_indices,
        validation_indices=validation_indices,
        config=resolved_config,
    )

    estimate_training = _representative_indices(
        dataset,
        train_indices,
        max_size=resolved_config.estimate_subset_size,
        seed=int(resolved_config.seed),
    )
    estimate_samples = fit_samples_from_dataset(
        dataset,
        indices=estimate_training,
        sample_weights=sample_weights,
        energy_weight=energy_weight,
        force_weight=force_weight,
    )
    estimates = estimate_linear_ridge_scales(
        model_factory(),
        estimate_samples,
        fitter_kwargs=fitter_kwargs,
        fit_kwargs=fit_kwargs,
        # The samples are already a representative subset; do not subsample again.
        subset_size=None,
        seed=int(resolved_config.seed),
        alpha=float(resolved_config.alpha),
    )
    baseline = estimates.candidate()
    active_groups = _active_groups(estimates)
    candidates = _initial_candidates(estimates, resolved_config)

    # Cap on the refinement pool carried into the next stage (loop-invariant).
    refinement_limit = max(
        int(resolved_config.top_k_per_stage) * 8,
        int(resolved_config.top_k_per_stage),
    )

    all_trials: list[RegularizationTrial] = []
    retained = candidates
    last_successes: tuple[RegularizationTrial, ...] = ()
    for stage_number, subset_size in enumerate(resolved_config.stage_subset_sizes, 1):
        stage = f"subset_{int(subset_size)}"
        # Distinct seed offsets give each stage its own train/validation draw.
        stage_training = _representative_indices(
            dataset,
            train_indices,
            max_size=int(subset_size),
            seed=int(resolved_config.seed) + stage_number,
        )
        stage_validation = _representative_indices(
            dataset,
            validation_pool,
            max_size=int(subset_size),
            seed=int(resolved_config.seed) + 100 + stage_number,
        )
        trials = _evaluate_stage(
            stage=stage,
            candidates=retained,
            baseline=baseline,
            model_factory=model_factory,
            dataset=dataset,
            training_indices=stage_training,
            validation_indices=stage_validation,
            sample_weights=sample_weights,
            energy_weight=energy_weight,
            force_weight=force_weight,
            config=resolved_config,
            fitter_kwargs=fitter_kwargs,
            fit_kwargs=fit_kwargs,
        )
        all_trials.extend(trials)
        stage_successes = _best_trials_for_next_stage(
            trials,
            top_k=resolved_config.top_k_per_stage,
        )
        if not stage_successes:
            # A fully failed stage must not erase the winners of an earlier
            # stage: keep both the previous successes and the candidate pool.
            continue
        last_successes = stage_successes
        retained_candidates = tuple(trial.candidate for trial in last_successes)
        retained = _refinement_candidates(
            retained_candidates,
            active_groups,
            resolved_config,
        )[:refinement_limit]

    if last_successes:
        final_candidates = tuple(trial.candidate for trial in last_successes)
    else:
        # No stage produced a successful trial (or no stages were configured).
        final_candidates = candidates
    final_trials = _evaluate_stage(
        stage="full_validation",
        candidates=final_candidates,
        baseline=baseline,
        model_factory=model_factory,
        dataset=dataset,
        training_indices=train_indices,
        validation_indices=validation_pool,
        sample_weights=sample_weights,
        energy_weight=energy_weight,
        force_weight=force_weight,
        config=resolved_config,
        fitter_kwargs=fitter_kwargs,
        fit_kwargs=fit_kwargs,
    )
    all_trials.extend(final_trials)

    final_successes = _successful_trials(final_trials)
    # Prefer full-validation results, then the latest successful stage, then
    # anything that succeeded at all.
    best_trial_pool = (
        final_successes or last_successes or _successful_trials(all_trials)
    )
    if not best_trial_pool:
        raise RuntimeError("regularization search did not produce a successful trial")
    best_candidate = best_trial_pool[0].candidate
    best_fitter_kwargs = _candidate_fitter_kwargs(fitter_kwargs, best_candidate)

    final_model = None
    final_fit_result = None
    if resolved_config.refit_full:
        final_model, final_fit_result = _final_fit(
            candidate=best_candidate,
            model_factory=model_factory,
            dataset=dataset,
            training_indices=train_indices,
            sample_weights=sample_weights,
            energy_weight=energy_weight,
            force_weight=force_weight,
            fitter_kwargs=fitter_kwargs,
            fit_kwargs=fit_kwargs,
            config=resolved_config,
        )
    return RegularizationSearchResult(
        estimates=estimates,
        trials=tuple(all_trials),
        best_candidate=best_candidate,
        best_fitter_kwargs=best_fitter_kwargs,
        final_model=final_model,
        final_fit_result=final_fit_result,
    )
# Public names exported by this module; underscore-prefixed helpers are not listed.
__all__ = [ "RegularizationCandidate", "RegularizationSearchConfig", "RegularizationSearchResult", "RegularizationTrial", "RidgeGroupEstimate", "RidgeScaleEstimate", "estimate_linear_ridge_scales", "tune_linear_regularization", ]