"""
Predictive uncertainty helpers for coefficient-linear UFP models.
This module builds dense Bayesian posteriors from the same weighted design
matrices used by :class:`ufp.leastsquares.LinearFitter`, then evaluates
epistemic variances from sparse prediction rows. Alchemical providers are
handled by freezing the learned mixing weights and forming one fixed-weight
linear problem over direct coefficients and proxy coefficients.
"""
from __future__ import annotations
import hashlib
import json
import math
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
import ase
import numpy as np
import torch
from ufp.core.output import UFPOutput
from ufp.leastsquares._block import (
BlockMatrix,
BlockProblemLayout,
BlockSolveBatch,
SolveBlock,
_block_matrix_values_on_rows,
_materialize_block_matrix,
)
from ufp.leastsquares._layout import ParameterLayout, ProviderGroup
from ufp.leastsquares._problem import BlockLinearProblem
from ufp.leastsquares.alchemical import AlchemicalALSFitter
from ufp.leastsquares.dataset import FitSample, prepare_batches
from ufp.leastsquares.linear import (
AssembledBatchCacheMode,
LinearFitter,
_twobody_shape_regularization_rows,
)
from ufp.leastsquares.regularization import _make_block_regularization
from ufp.neighbors._data import NeighborListData
from ufp.splines.representation import uniform_stencil_1d, uniform_support_parameters
from ufp.terms.model import UFPModel
from ufp.workflows.checkpoints import (
WorkflowCheckpointError,
coefficient_layout_metadata,
fixed_coefficient_hashes,
normalize_checkpoint_metadata,
selector_metadata,
validate_fixed_coefficient_hashes,
)
from ufp.workflows.models import load_model_from_checkpoint, model_schema
# Posterior checkpoint schema: written into posterior metadata by
# ``_posterior_metadata`` and required to match exactly when a posterior is
# loaded or constructed with a ``schema`` entry.
POSTERIOR_SCHEMA_NAME = "ufp.bayesian_linear_posterior"
POSTERIOR_SCHEMA_VERSION = 1
# Solve-layout schema used by ``_layout_to_payload`` / ``_layout_from_payload``.
POSTERIOR_LAYOUT_SCHEMA_NAME = "ufp.posterior_layout"
POSTERIOR_LAYOUT_SCHEMA_VERSION = 1
# Uncertainty prediction bundle schema (not referenced in this part of the
# module; presumably used by the bundle writer/loader — verify).
UNCERTAINTY_BUNDLE_SCHEMA_NAME = "ufp.uncertainty_prediction_bundle"
UNCERTAINTY_BUNDLE_SCHEMA_VERSION = 1
# Schema for a single spline aleatoric noise head (presumably consumed by
# ``SplineAleatoricNoiseModel`` payloads — verify).
ALEATORIC_SPLINE_SCHEMA_NAME = "ufp.spline_aleatoric_noise_model"
ALEATORIC_SPLINE_SCHEMA_VERSION = 1
# Schema for ``SplineAleatoricNoiseBundle.to_payload`` / ``from_payload``.
ALEATORIC_BUNDLE_SCHEMA_NAME = "ufp.spline_aleatoric_noise_bundle"
ALEATORIC_BUNDLE_SCHEMA_VERSION = 1
# Schema for the bundle's ``calibration/energy_variance_scale.json`` file.
ENERGY_VARIANCE_SCALE_SCHEMA_NAME = "ufp.energy_variance_scale"
ENERGY_VARIANCE_SCALE_SCHEMA_VERSION = 1
# Aleatoric target kinds: accepted by ``SplineAleatoricNoiseBundle.head_for_kind``
# and used as the keys of the serialized ``heads`` mapping.
ALEATORIC_ENERGY_KIND = "energy_per_atom"
ALEATORIC_PER_ATOM_KIND = "per_atom_energy"
ALEATORIC_FORCE_KIND = "force_component"
def _torch_load(path: Path | str, *, map_location: Any = "cpu") -> object:
"""Load one torch checkpoint across supported torch versions."""
try:
return torch.load(path, map_location=map_location, weights_only=False)
except TypeError:
return torch.load(path, map_location=map_location)
def _hash_tensor_values(tensor: torch.Tensor) -> str:
"""Return a stable SHA256 digest for one tensor's dtype, shape, and values."""
contiguous = tensor.detach().cpu().contiguous()
hasher = hashlib.sha256()
hasher.update(str(contiguous.dtype).encode("utf8"))
hasher.update(torch.tensor(contiguous.shape, dtype=torch.int64).numpy().tobytes())
hasher.update(contiguous.numpy().tobytes())
return hasher.hexdigest()
def _provider_weight_metadata(model: UFPModel) -> list[dict[str, object]]:
    """Describe every alchemical provider's mixing weights for posterior checks.

    A provider whose ``weights`` is ``None`` is recorded as identity-weighted.
    For every other provider the entry records:

    - the weight dtype and shape,
    - a SHA256 digest of the values (via :func:`_hash_tensor_values`), and
    - whether the weights require gradients.

    This lets a later model be compared against the weights that were frozen
    when the posterior was built.
    """

    def describe(position: int, raw_weights: object) -> dict[str, object]:
        if raw_weights is None:
            return {
                "index": int(position),
                "uses_identity_weights": True,
                "weights": None,
            }
        as_tensor = torch.as_tensor(raw_weights)
        host_copy = as_tensor.detach().cpu()
        summary = {
            "dtype": str(host_copy.dtype),
            "shape": [int(extent) for extent in host_copy.shape],
            "values_hash": _hash_tensor_values(host_copy),
            "requires_grad": bool(getattr(as_tensor, "requires_grad", False)),
        }
        return {
            "index": int(position),
            "uses_identity_weights": False,
            "weights": summary,
        }

    return [
        describe(position, provider.weights)
        for position, provider in enumerate(model.alchemical_coefficients)
    ]
def _file_sha256(path: Path | str) -> str:
"""Return the SHA256 digest for one file."""
hasher = hashlib.sha256()
with Path(path).open("rb") as handle:
for chunk in iter(lambda: handle.read(1024 * 1024), b""):
hasher.update(chunk)
return hasher.hexdigest()
def _infer_checkpoint_dtype(checkpoint: Mapping[str, object]) -> torch.dtype:
"""Infer a floating dtype from a checkpoint state dict."""
state_dict = checkpoint.get("state_dict")
if isinstance(state_dict, Mapping):
for value in state_dict.values():
if isinstance(value, torch.Tensor) and value.is_floating_point():
return value.dtype
return torch.float64
def _serialize_solve_key(key: object) -> dict[str, object]:
"""Return a JSON representation of one solve-layout key."""
if isinstance(key, int):
return {"kind": "direct", "block_index": int(key)}
if (
isinstance(key, tuple)
and len(key) == 3
and key[0] == "proxy"
and isinstance(key[1], int)
and isinstance(key[2], int)
):
return {
"kind": "proxy",
"provider_index": int(key[1]),
"proxy_index": int(key[2]),
}
raise ValueError(f"unsupported posterior layout key: {key!r}")
def _deserialize_solve_key(payload: object) -> object:
"""Return a solve-layout key from its JSON representation."""
if not isinstance(payload, Mapping):
raise ValueError("posterior layout block key must be a mapping")
kind = payload.get("kind")
if kind == "direct":
block_index = payload.get("block_index")
if not isinstance(block_index, int):
raise ValueError("direct posterior layout key needs integer block_index")
return int(block_index)
if kind == "proxy":
provider_index = payload.get("provider_index")
proxy_index = payload.get("proxy_index")
if not isinstance(provider_index, int) or not isinstance(proxy_index, int):
raise ValueError(
"proxy posterior layout key needs integer provider/proxy indices"
)
return _provider_proxy_key(int(provider_index), int(proxy_index))
raise ValueError(f"unsupported posterior layout key kind: {kind!r}")
def _layout_to_payload(layout: BlockProblemLayout) -> dict[str, object]:
    """Encode a solve layout as schema-tagged, JSON-friendly metadata.

    Each block records its serialized key, label, size and its
    ``[start, stop)`` span in the solve vector. The output is the format
    accepted by :func:`_layout_from_payload`.
    """

    def describe(block: SolveBlock) -> dict[str, object]:
        encoded_key = _serialize_solve_key(block.key)
        span = layout.theta_slice(block.key)
        return {
            "key": encoded_key,
            "label": block.label,
            "size": int(block.size),
            "start": int(span.start),
            "stop": int(span.stop),
        }

    header = {
        "name": POSTERIOR_LAYOUT_SCHEMA_NAME,
        "version": POSTERIOR_LAYOUT_SCHEMA_VERSION,
    }
    total_size = int(layout.size)
    return {
        "schema": header,
        "size": total_size,
        "blocks": [describe(block) for block in layout.blocks],
    }
def _layout_from_payload(payload: object) -> BlockProblemLayout:
    """Rebuild a solve layout from :func:`_layout_to_payload` output.

    The payload must carry the posterior-layout schema header. Its
    ``blocks`` must tile ``[0, size)`` contiguously, in order, with unique
    keys. Restored blocks get ``regularization=None`` because the serialized
    form does not record regularization.

    Raises:
        ValueError: on any schema, structure, or slice inconsistency.
    """
    if not isinstance(payload, Mapping):
        raise ValueError("posterior layout must be a mapping")
    header = payload.get("schema")
    if not isinstance(header, Mapping):
        raise ValueError("posterior layout is missing schema metadata")
    if header.get("name") != POSTERIOR_LAYOUT_SCHEMA_NAME:
        raise ValueError("unsupported posterior layout schema")
    if header.get("version") != POSTERIOR_LAYOUT_SCHEMA_VERSION:
        raise ValueError("unsupported posterior layout schema version")
    entries = payload.get("blocks")
    if isinstance(entries, (str, bytes)) or not isinstance(entries, Sequence):
        raise ValueError("posterior layout blocks must be a sequence")

    offset = 0
    known_keys: set[object] = set()
    rebuilt: list[SolveBlock] = []
    for entry in entries:
        if not isinstance(entry, Mapping):
            raise ValueError("posterior layout block must be a mapping")
        key = _deserialize_solve_key(entry.get("key"))
        if key in known_keys:
            raise ValueError("posterior layout contains duplicate block keys")
        known_keys.add(key)
        size = int(entry.get("size", -1))
        start = int(entry.get("start", -1))
        stop = int(entry.get("stop", -1))
        if size < 0:
            raise ValueError("posterior layout block sizes must be non-negative")
        # Blocks must appear in solve-vector order with no gaps or overlaps.
        if (start, stop) != (offset, offset + size):
            raise ValueError("posterior layout block slices must be contiguous")
        rebuilt.append(
            SolveBlock(
                key=key,
                size=size,
                label=str(entry.get("label", "")),
                regularization=None,
            )
        )
        offset = stop

    declared_size = int(payload.get("size", -1))
    if declared_size != offset:
        raise ValueError("posterior layout size does not match block slices")
    layout = BlockProblemLayout.from_blocks(rebuilt)
    if int(layout.size) != declared_size:
        raise ValueError("posterior layout size mismatch")
    return layout
def _solve_layout_metadata(layout: BlockProblemLayout) -> dict[str, object]:
"""Return JSON-friendly metadata for a solve-time block layout."""
return {
"size": int(layout.size),
"blocks": [
{
"key": repr(block.key),
"label": block.label,
"size": int(block.size),
"start": int(layout.theta_slice(block.key).start),
"stop": int(layout.theta_slice(block.key).stop),
}
for block in layout.blocks
],
}
def _posterior_metadata(
    *,
    model: UFPModel,
    fitter: LinearFitter,
    problem: BlockLinearProblem,
    kind: str,
    extra: Mapping[str, object] | None = None,
) -> dict[str, object]:
    """Assemble the validation metadata persisted beside a posterior covariance.

    The result records:

    - the posterior schema header and ``kind``,
    - the solve problem's dtype and size,
    - the fitter's coefficient layout, block selection and fixed-coefficient
      hashes,
    - the frozen alchemical weights, and
    - a summary of the solve layout.

    If ``extra`` is given, its normalized entries are merged on top and can
    override any of the keys above.
    """
    coefficient_layout = fitter.layout
    selection = {
        "fit_blocks": fitter.fit_blocks,
        "freeze_blocks": fitter.freeze_blocks,
    }
    metadata: dict[str, object] = {
        "schema": {
            "name": POSTERIOR_SCHEMA_NAME,
            "version": POSTERIOR_SCHEMA_VERSION,
        },
        "kind": kind,
        "dtype": str(problem.dtype),
        "n_parameters": int(problem.layout.size),
        "coefficient_layout": coefficient_layout_metadata(coefficient_layout),
        "selector_metadata": selector_metadata(coefficient_layout, **selection),
        "fixed_coefficient_hashes": fixed_coefficient_hashes(
            coefficient_layout, **selection
        ),
        "alchemical_weights": _provider_weight_metadata(model),
        "solve_layout": _solve_layout_metadata(problem.layout),
    }
    if extra is None:
        return metadata
    metadata.update(normalize_checkpoint_metadata(dict(extra)))
    return metadata
def _metadata_mismatch(expected: object, actual: object, *, path: str) -> str | None:
"""Return the first mismatch path when comparing two metadata payloads."""
if isinstance(expected, Mapping) and isinstance(actual, Mapping):
for key in sorted(expected):
if key not in actual:
return f"{path}.{key} missing"
mismatch = _metadata_mismatch(
expected[key],
actual[key],
path=f"{path}.{key}",
)
if mismatch is not None:
return mismatch
return None
if isinstance(expected, Sequence) and not isinstance(expected, (str, bytes)):
if not isinstance(actual, Sequence) or isinstance(actual, (str, bytes)):
return f"{path} type mismatch"
if len(expected) != len(actual):
return f"{path} length mismatch"
for index, (expected_item, actual_item) in enumerate(
zip(expected, actual, strict=True)
):
mismatch = _metadata_mismatch(
expected_item,
actual_item,
path=f"{path}[{index}]",
)
if mismatch is not None:
return mismatch
return None
if expected != actual:
return f"{path} mismatch"
return None
def _validate_expected_metadata(
    metadata: Mapping[str, object],
    expected_metadata: Mapping[str, object],
) -> None:
    """Require ``metadata`` to contain every value listed in ``expected_metadata``.

    Raises:
        ValueError: naming the first mismatching path, as reported by
            :func:`_metadata_mismatch`.
    """
    problem = _metadata_mismatch(expected_metadata, metadata, path="metadata")
    if problem is None:
        return
    raise ValueError(f"posterior metadata does not match: {problem}")
@dataclass(frozen=True)
class SparseLinearRow:
    """One sparse row in a coefficient solve layout.

    ``indices`` are positions into a solve vector of length ``size`` and
    ``values`` holds the matching coefficients. During validation both are
    flattened to 1-D, and ``indices`` is cast to ``torch.int64``.
    """

    indices: torch.Tensor
    values: torch.Tensor
    size: int

    def __post_init__(self) -> None:
        """Flatten and bounds-check the row, then store the normalized fields.

        Raises:
            ValueError: if the lengths differ, ``size`` is negative, or any
                index falls outside ``[0, size)``.
        """
        flat_indices = torch.as_tensor(self.indices, dtype=torch.int64).reshape(-1)
        flat_values = torch.as_tensor(self.values).reshape(-1)
        if flat_values.numel() != flat_indices.numel():
            raise ValueError("`indices` and `values` must have the same length")
        length = int(self.size)
        if length < 0:
            raise ValueError("`size` must be non-negative")
        if flat_indices.numel() > 0:
            out_of_range = (flat_indices < 0) | (flat_indices >= length)
            if bool(out_of_range.any()):
                raise ValueError("row indices are outside the solve-vector size")
        # The dataclass is frozen, so normalized values are written through
        # ``object.__setattr__``.
        for name, normalized in (
            ("indices", flat_indices),
            ("values", flat_values),
            ("size", length),
        ):
            object.__setattr__(self, name, normalized)

    @property
    def dtype(self) -> torch.dtype:
        """Return the dtype of ``values``."""
        return self.values.dtype

    @property
    def device(self) -> torch.device:
        """Return the device of ``values``."""
        return self.values.device

    def to_dense(
        self,
        *,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """Scatter the row into a dense length-``size`` vector.

        ``dtype`` and ``device`` default to those of ``values``. Repeated
        indices are summed by ``index_add_``.
        """
        target_dtype = dtype if dtype is not None else self.values.dtype
        target_device = device if device is not None else self.values.device
        dense = torch.zeros(
            (int(self.size),),
            dtype=target_dtype,
            device=target_device,
        )
        if self.indices.numel() == 0:
            return dense
        dense.index_add_(
            0,
            self.indices.to(device=target_device),
            self.values.to(device=target_device, dtype=target_dtype),
        )
        return dense

    def matvec(self, theta: torch.Tensor) -> torch.Tensor:
        """Return the 0-d dot product of this row with coefficient vector ``theta``.

        ``theta`` is reshaped to ``(size,)``. The row values are cast to its
        dtype and device.
        """
        flat_theta = theta.reshape(self.size)
        if self.indices.numel() == 0:
            return flat_theta.new_zeros(())
        gathered = flat_theta.index_select(0, self.indices.to(device=flat_theta.device))
        weights = self.values.to(device=flat_theta.device, dtype=flat_theta.dtype)
        return torch.dot(gathered, weights)
def _row_from_dense(dense: torch.Tensor, *, size: int) -> SparseLinearRow:
    """Convert a dense length-``size`` vector into a :class:`SparseLinearRow`.

    Only entries that compare unequal to zero are kept, so NaNs survive.
    """
    length = int(size)
    flat = dense.reshape(length)
    support = (flat != 0).nonzero(as_tuple=False).flatten()
    return SparseLinearRow(
        indices=support,
        values=flat.index_select(0, support),
        size=length,
    )
def combine_total_energy_rows(
    rows: Sequence[tuple[SparseLinearRow, float]],
) -> SparseLinearRow:
    """Form the sparse linear combination ``sum_k c_k * row_k``.

    Behaviour:

    - Every row must share the first row's ``size``.
    - Pairs with a zero coefficient, or with an empty row, are skipped.
    - Values are cast to the first row's dtype and device.
    - Contributions to the same index are summed.
    - Entries that cancel to exactly zero are dropped.

    Raises:
        ValueError: if ``rows`` is empty or the row sizes disagree.
    """
    entries = tuple(rows)
    if not entries:
        raise ValueError("`rows` must contain at least one row")
    reference = entries[0][0]
    size = reference.size
    dtype = reference.dtype
    device = reference.device

    index_parts: list[torch.Tensor] = []
    value_parts: list[torch.Tensor] = []
    for row, coefficient in entries:
        if row.size != size:
            raise ValueError("all rows must use the same solve-vector size")
        scale = float(coefficient)
        if scale == 0.0 or row.indices.numel() == 0:
            continue
        index_parts.append(row.indices.to(device=device))
        value_parts.append(row.values.to(device=device, dtype=dtype) * scale)

    if not index_parts:
        return SparseLinearRow(
            indices=torch.empty(0, dtype=torch.int64, device=device),
            values=torch.empty(0, dtype=dtype, device=device),
            size=size,
        )

    # Sort so equal indices sit next to each other, then sum each run.
    merged_indices = torch.cat(index_parts, dim=0)
    merged_values = torch.cat(value_parts, dim=0)
    order = torch.argsort(merged_indices)
    merged_indices = merged_indices[order]
    merged_values = merged_values[order]
    unique_indices, run_ids = torch.unique_consecutive(
        merged_indices, return_inverse=True
    )
    sums = merged_values.new_zeros(int(unique_indices.numel()))
    sums.index_add_(0, run_ids, merged_values)
    keep = sums != 0
    return SparseLinearRow(indices=unique_indices[keep], values=sums[keep], size=size)
@dataclass(frozen=True)
class SparsePredictionRows:
    """Sparse rows needed for diagonal predictive variances.

    The rows are:

    - one row per atom for atomic energies;
    - one row for the total energy;
    - optionally, one ``(x, y, z)`` triple of rows per atom for forces.
    """

    atomic_energy_rows: tuple[SparseLinearRow, ...]
    total_energy_row: SparseLinearRow
    force_rows: (
        tuple[tuple[SparseLinearRow, SparseLinearRow, SparseLinearRow], ...] | None
    ) = None

    @property
    def n_atoms(self) -> int:
        """Return how many atomic energy rows are stored."""
        return len(self.atomic_energy_rows)

    @property
    def force_component_rows(self) -> tuple[SparseLinearRow, ...]:
        """Return the force rows flattened atom-major, or ``()`` if absent."""
        triples = self.force_rows
        if triples is None:
            return ()
        flattened: list[SparseLinearRow] = []
        for triple in triples:
            flattened.extend(triple)
        return tuple(flattened)
@dataclass(frozen=True)
class AleatoricFeatureSpec:
    """Prediction-time feature contract for aleatoric spline heads.

    The only supported ``kind`` is ``"log_num_atoms"``, i.e.
    ``log1p(len(atoms))``.
    """

    kind: str = "log_num_atoms"

    def feature_for_atoms(self, atoms: ase.Atoms) -> float:
        """Compute this spec's scalar feature for one structure.

        Raises:
            ValueError: if ``kind`` is not supported.
        """
        if self.kind != "log_num_atoms":
            raise ValueError(f"unsupported aleatoric feature kind: {self.kind!r}")
        return float(math.log1p(len(atoms)))

    def to_payload(self) -> dict[str, object]:
        """Return the JSON-compatible form read by :meth:`from_payload`."""
        return {"kind": self.kind}

    @classmethod
    def from_payload(cls, payload: object) -> "AleatoricFeatureSpec":
        """Rebuild a spec. ``None`` yields the default spec.

        Raises:
            ValueError: if ``payload`` is neither ``None`` nor a mapping.
        """
        if payload is None:
            return cls()
        if isinstance(payload, Mapping):
            return cls(kind=str(payload.get("kind", "log_num_atoms")))
        raise ValueError("aleatoric feature spec must be a mapping")
@dataclass
class SplineAleatoricNoiseBundle:
    """Optional spline variance heads for UFP prediction targets.

    There is one optional head per target kind (``ALEATORIC_ENERGY_KIND``,
    ``ALEATORIC_PER_ATOM_KIND``, ``ALEATORIC_FORCE_KIND``); a ``None`` head
    means that target has no aleatoric model. Every head is evaluated on the
    scalar feature produced by ``feature_spec``.
    """

    feature_spec: AleatoricFeatureSpec = field(default_factory=AleatoricFeatureSpec)
    energy_per_atom: "SplineAleatoricNoiseModel | None" = None
    per_atom_energy: "SplineAleatoricNoiseModel | None" = None
    force_component: "SplineAleatoricNoiseModel | None" = None

    def head_for_kind(self, kind: str) -> "SplineAleatoricNoiseModel | None":
        """Look up the variance head registered for one target kind.

        Raises:
            ValueError: if ``kind`` is not one of the three known kinds.
        """
        candidates = (
            (ALEATORIC_ENERGY_KIND, self.energy_per_atom),
            (ALEATORIC_PER_ATOM_KIND, self.per_atom_energy),
            (ALEATORIC_FORCE_KIND, self.force_component),
        )
        for candidate_kind, head in candidates:
            if kind == candidate_kind:
                return head
        raise ValueError(f"unsupported aleatoric target kind: {kind!r}")

    def has_heads(self) -> bool:
        """Report whether at least one head is set."""
        heads = (self.energy_per_atom, self.per_atom_energy, self.force_component)
        return not all(head is None for head in heads)

    def predict_for_atoms(
        self,
        atoms: ase.Atoms,
        *,
        include_forces: bool,
        dtype: torch.dtype,
        device: torch.device,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
        """Evaluate every configured head for one structure.

        Returns ``(energy_variance, per_atom_variance, force_variance)``:

        - ``energy_variance`` has shape ``(1,)``;
        - ``per_atom_variance`` holds the head output for ``n_atoms``
          repeated features;
        - ``force_variance`` is reshaped to ``(n_atoms, 3)`` and is only
          computed when ``include_forces`` is true.

        Any entry is ``None`` when its head is absent. Outputs are cast to
        ``dtype``/``device``.
        """

        def finalize(tensor: torch.Tensor) -> torch.Tensor:
            return tensor.to(dtype=dtype, device=device)

        scalar_feature = torch.as_tensor(
            self.feature_spec.feature_for_atoms(atoms),
            dtype=dtype,
            device=device,
        )
        atom_count = len(atoms)

        energy_variance = None
        if self.energy_per_atom is not None:
            head_value = finalize(self.energy_per_atom(scalar_feature))
            # Scaled by N^2; presumably Var(N * e) for a per-atom energy e —
            # confirm against how the head is fitted.
            energy_variance = (
                head_value.reshape(()) * float(atom_count * atom_count)
            ).reshape(1)

        per_atom_variance = None
        if self.per_atom_energy is not None:
            atom_features = scalar_feature.reshape(()).expand(atom_count)
            per_atom_variance = finalize(self.per_atom_energy(atom_features))

        force_variance = None
        if include_forces and self.force_component is not None:
            component_features = scalar_feature.reshape(()).expand(atom_count * 3)
            force_variance = finalize(
                self.force_component(component_features)
            ).reshape(atom_count, 3)

        return energy_variance, per_atom_variance, force_variance

    def to_payload(self) -> dict[str, object]:
        """Return the torch-serializable form read by :meth:`from_payload`."""

        def dump(head: "SplineAleatoricNoiseModel | None") -> object:
            return None if head is None else head.to_payload()

        heads = {
            ALEATORIC_ENERGY_KIND: dump(self.energy_per_atom),
            ALEATORIC_PER_ATOM_KIND: dump(self.per_atom_energy),
            ALEATORIC_FORCE_KIND: dump(self.force_component),
        }
        return {
            "schema": {
                "name": ALEATORIC_BUNDLE_SCHEMA_NAME,
                "version": ALEATORIC_BUNDLE_SCHEMA_VERSION,
            },
            "feature_spec": self.feature_spec.to_payload(),
            "heads": heads,
        }

    @classmethod
    def from_payload(
        cls,
        payload: object,
        *,
        dtype: torch.dtype | None = None,
    ) -> "SplineAleatoricNoiseBundle":
        """Rebuild a bundle from :meth:`to_payload` output.

        ``dtype`` is forwarded to each head's ``from_payload``.

        Raises:
            ValueError: if the payload is not a mapping, its schema header is
                missing or unsupported, or ``heads`` is not a mapping.
        """
        if not isinstance(payload, Mapping):
            raise ValueError("aleatoric noise bundle payload must be a mapping")
        header = payload.get("schema")
        if not isinstance(header, Mapping):
            raise ValueError("aleatoric noise bundle payload is missing schema")
        if header.get("name") != ALEATORIC_BUNDLE_SCHEMA_NAME:
            raise ValueError("unsupported aleatoric noise bundle schema")
        if header.get("version") != ALEATORIC_BUNDLE_SCHEMA_VERSION:
            raise ValueError("unsupported aleatoric noise bundle schema version")
        heads = payload.get("heads")
        if not isinstance(heads, Mapping):
            raise ValueError("aleatoric noise bundle payload is missing heads")

        feature_spec = AleatoricFeatureSpec.from_payload(payload.get("feature_spec"))
        restored: dict[str, object] = {}
        for attribute, kind in (
            ("energy_per_atom", ALEATORIC_ENERGY_KIND),
            ("per_atom_energy", ALEATORIC_PER_ATOM_KIND),
            ("force_component", ALEATORIC_FORCE_KIND),
        ):
            head_payload = heads.get(kind)
            restored[attribute] = (
                None
                if head_payload is None
                else SplineAleatoricNoiseModel.from_payload(head_payload, dtype=dtype)
            )
        return cls(feature_spec=feature_spec, **restored)
@dataclass
class BayesianLinearPosterior:
    """Dense coefficient posterior for a coefficient-linear UFP solve layout.

    Fields:

    - ``theta_mean`` is flattened to 1-D.
    - ``Sigma_theta`` must be the matching square covariance, held either as
      a torch tensor or as a NumPy array (``load_memmap`` supplies a
      copy-on-write memory map).
    - ``layout``, when present, describes how the solve vector is split into
      blocks. Both :meth:`save` and :meth:`save_memmap` persist it.
    - ``metadata`` carrying a ``schema`` entry must name this module's
      posterior schema.
    """

    theta_mean: torch.Tensor
    Sigma_theta: torch.Tensor | np.ndarray
    metadata: dict[str, object] = field(default_factory=dict)
    layout: BlockProblemLayout | None = None
    # Torch view of a NumPy ``Sigma_theta``, built lazily by ``covariance_tensor``.
    _covariance_tensor_cache: torch.Tensor | None = field(
        default=None,
        init=False,
        repr=False,
        compare=False,
    )
    # The ``Sigma_theta`` object the cache was built from. An identity
    # mismatch means ``Sigma_theta`` was reassigned and the cache is stale.
    _covariance_cache_source: object | None = field(
        default=None,
        init=False,
        repr=False,
        compare=False,
    )

    def __post_init__(self) -> None:
        """Flatten the mean and validate covariance, layout, and schema.

        Raises:
            ValueError: if ``Sigma_theta`` is not ``(n, n)`` for
                ``n = theta_mean.numel()``, the layout size differs from
                ``n``, or the metadata ``schema`` is malformed or unsupported.
        """
        self.theta_mean = torch.as_tensor(self.theta_mean).reshape(-1)
        if isinstance(self.Sigma_theta, torch.Tensor):
            sigma_shape = tuple(int(dim) for dim in self.Sigma_theta.shape)
        else:
            sigma_shape = tuple(int(dim) for dim in np.asarray(self.Sigma_theta).shape)
        expected = (int(self.theta_mean.numel()), int(self.theta_mean.numel()))
        if sigma_shape != expected:
            raise ValueError(
                f"`Sigma_theta` must have shape {expected}, got {sigma_shape}"
            )
        if self.layout is not None and int(self.layout.size) != int(
            self.theta_mean.numel()
        ):
            raise ValueError("posterior layout size must match parameter count")
        schema = self.metadata.get("schema")
        if schema is not None:
            if not isinstance(schema, Mapping):
                raise ValueError("posterior metadata `schema` must be a mapping")
            if schema.get("name") != POSTERIOR_SCHEMA_NAME:
                raise ValueError("unsupported posterior schema")
            if schema.get("version") != POSTERIOR_SCHEMA_VERSION:
                raise ValueError("unsupported posterior schema version")

    @property
    def n_parameters(self) -> int:
        """Return the posterior parameter dimension."""
        return int(self.theta_mean.numel())

    @property
    def dtype(self) -> torch.dtype:
        """Return the posterior mean dtype."""
        return self.theta_mean.dtype

    @property
    def device(self) -> torch.device:
        """Return the posterior mean device."""
        return self.theta_mean.device

    def covariance_tensor(
        self,
        *,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """Return ``Sigma_theta`` as a torch tensor, optionally cast/moved.

        A NumPy covariance is converted once and cached. The cache is rebuilt
        if ``Sigma_theta`` has since been replaced by a different object.
        """
        sigma = self.Sigma_theta
        if isinstance(sigma, torch.Tensor):
            tensor = sigma
        else:
            tensor = self._covariance_tensor_cache
            if tensor is None or self._covariance_cache_source is not sigma:
                array = np.asarray(sigma)
                if not array.flags.writeable:
                    # Copy read-only arrays so the tensor does not alias
                    # non-writable memory.
                    array = array.copy()
                tensor = torch.as_tensor(array)
                self._covariance_tensor_cache = tensor
                self._covariance_cache_source = sigma
        return tensor.to(
            dtype=tensor.dtype if dtype is None else dtype,
            device=tensor.device if device is None else device,
        )

    def save(self, path: Path | str) -> None:
        """Save this posterior to a torch checkpoint file.

        The checkpoint stores:

        - ``theta_mean``, ``Sigma_theta`` and normalized ``metadata``;
        - the solve ``layout`` when one is set, so that :meth:`load`
          restores it.

        Parent directories are created as needed.
        """
        payload: dict[str, object] = {
            "theta_mean": self.theta_mean.detach().cpu(),
            "Sigma_theta": self.covariance_tensor().detach().cpu(),
            "metadata": normalize_checkpoint_metadata(self.metadata),
        }
        if self.layout is not None:
            payload["layout"] = _layout_to_payload(self.layout)
        posterior_path = Path(path)
        posterior_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(payload, posterior_path)

    @classmethod
    def load(
        cls,
        path: Path | str,
        *,
        expected_metadata: Mapping[str, object] | None = None,
        map_location: object = "cpu",
    ) -> "BayesianLinearPosterior":
        """Load a posterior written by :meth:`save` and validate its metadata.

        Checkpoints without a ``layout`` entry (written before layouts were
        persisted) load with ``layout=None``.

        Raises:
            ValueError: if the checkpoint or its metadata is not a mapping,
                the schema is missing or unsupported, ``expected_metadata``
                does not match, or a stored layout is invalid.
        """
        payload = _torch_load(path, map_location=map_location)
        if not isinstance(payload, Mapping):
            raise ValueError("posterior checkpoint must contain a mapping")
        metadata = payload.get("metadata", {})
        if not isinstance(metadata, Mapping):
            raise ValueError("posterior checkpoint metadata must be a mapping")
        schema = metadata.get("schema")
        if not isinstance(schema, Mapping):
            raise ValueError("posterior checkpoint metadata is missing `schema`")
        if schema.get("name") != POSTERIOR_SCHEMA_NAME:
            raise ValueError("unsupported posterior schema")
        if schema.get("version") != POSTERIOR_SCHEMA_VERSION:
            raise ValueError("unsupported posterior schema version")
        if expected_metadata is not None:
            _validate_expected_metadata(metadata, expected_metadata)
        raw_layout = payload.get("layout")
        layout = None if raw_layout is None else _layout_from_payload(raw_layout)
        return cls(
            theta_mean=torch.as_tensor(payload["theta_mean"]),
            Sigma_theta=torch.as_tensor(payload["Sigma_theta"]),
            metadata=dict(metadata),
            layout=layout,
        )

    def save_memmap(self, directory: Path | str) -> None:
        """Save the posterior as ``.npy`` arrays plus JSON sidecar files.

        Files written into ``directory``:

        - ``theta_mean.npy``;
        - ``Sigma_theta.npy``, written through a NumPy memory map;
        - ``metadata.json``;
        - ``layout.json``, only when a layout is set.
        """
        output_dir = Path(directory)
        output_dir.mkdir(parents=True, exist_ok=True)
        np.save(output_dir / "theta_mean.npy", self.theta_mean.detach().cpu().numpy())
        sigma = self.covariance_tensor().detach().cpu().numpy()
        mmap = np.lib.format.open_memmap(
            output_dir / "Sigma_theta.npy",
            mode="w+",
            dtype=sigma.dtype,
            shape=sigma.shape,
        )
        mmap[:] = sigma
        mmap.flush()
        (output_dir / "metadata.json").write_text(
            json.dumps(normalize_checkpoint_metadata(self.metadata), indent=2) + "\n",
            encoding="utf8",
        )
        if self.layout is not None:
            (output_dir / "layout.json").write_text(
                json.dumps(_layout_to_payload(self.layout), indent=2) + "\n",
                encoding="utf8",
            )

    @classmethod
    def load_memmap(
        cls,
        directory: Path | str,
        *,
        expected_metadata: Mapping[str, object] | None = None,
    ) -> "BayesianLinearPosterior":
        """Load a posterior saved by :meth:`save_memmap`.

        ``Sigma_theta`` is opened as a copy-on-write memory map
        (``mmap_mode="c"``). ``layout.json`` is optional.

        Raises:
            ValueError: if ``metadata.json`` is not a JSON object,
                ``expected_metadata`` does not match, or the layout is invalid.
        """
        input_dir = Path(directory)
        metadata = json.loads((input_dir / "metadata.json").read_text(encoding="utf8"))
        if not isinstance(metadata, Mapping):
            raise ValueError("posterior checkpoint metadata must be a mapping")
        if expected_metadata is not None:
            _validate_expected_metadata(metadata, expected_metadata)
        layout_path = input_dir / "layout.json"
        layout = (
            _layout_from_payload(json.loads(layout_path.read_text(encoding="utf8")))
            if layout_path.is_file()
            else None
        )
        return cls(
            theta_mean=torch.as_tensor(np.load(input_dir / "theta_mean.npy")),
            Sigma_theta=np.load(input_dir / "Sigma_theta.npy", mmap_mode="c"),
            metadata=dict(metadata),
            layout=layout,
        )
@dataclass(frozen=True)
class UncertaintyPredictionBundle:
    """Model, posterior, and metadata loaded from an uncertainty bundle.

    Aleatoric noise can be supplied in three ways:

    - a constant ``aleatoric_variance``;
    - a single spline head evaluated at ``aleatoric_prediction_feature``;
    - a structure-dependent ``aleatoric_noise_bundle``.
    """

    model: UFPModel
    posterior: BayesianLinearPosterior
    aleatoric_variance: float | None = None
    aleatoric_noise_model: "SplineAleatoricNoiseModel | None" = None
    aleatoric_prediction_feature: float | None = None
    aleatoric_noise_bundle: SplineAleatoricNoiseBundle | None = None
    energy_variance_scale: float = 1.0
    manifest: dict[str, object] = field(default_factory=dict)

    def prediction_aleatoric_variance(self) -> float | None:
        """Return one scalar aleatoric variance for generic predictions.

        A constant ``aleatoric_variance`` wins. Otherwise the spline head is
        evaluated, without gradients, at ``aleatoric_prediction_feature``.
        The result is ``None`` if either of those two is missing.
        """
        constant = self.aleatoric_variance
        if constant is not None:
            return float(constant)
        head = self.aleatoric_noise_model
        location = self.aleatoric_prediction_feature
        if head is None or location is None:
            return None
        reference = head.raw_values
        feature = torch.as_tensor(
            location,
            dtype=reference.dtype,
            device=reference.device,
        )
        with torch.no_grad():
            variance = head(feature)
        return float(variance.detach().cpu().reshape(-1)[0].item())

    def prediction_aleatoric_noise_bundle(
        self,
    ) -> SplineAleatoricNoiseBundle | None:
        """Return the structure-dependent noise bundle, or ``None`` if unset."""
        return self.aleatoric_noise_bundle
def _selected_indices_by_block(
metadata: Mapping[str, object],
) -> dict[int, tuple[int, ...]]:
"""Return selected original coefficient indices by block index."""
selector = metadata.get("selector_metadata")
if not isinstance(selector, Mapping):
return {}
entries = selector.get("entries")
if not isinstance(entries, Sequence) or isinstance(entries, (str, bytes)):
return {}
indices_by_block: dict[int, list[int]] = {}
for entry in entries:
if not isinstance(entry, Mapping):
continue
block_index = entry.get("block_index")
original_indices = entry.get("original_indices")
if not isinstance(block_index, int):
continue
if not isinstance(original_indices, Sequence) or isinstance(
original_indices,
(str, bytes),
):
continue
indices_by_block.setdefault(int(block_index), []).extend(
int(index) for index in original_indices
)
return {
block_index: tuple(indices) for block_index, indices in indices_by_block.items()
}
def _assert_close_vector(
actual: torch.Tensor,
expected: torch.Tensor,
*,
label: str,
) -> None:
"""Raise a clear error if two model/posterior vectors differ."""
actual_vector = actual.detach().cpu().reshape(-1).to(dtype=expected.dtype)
expected_vector = expected.detach().cpu().reshape(-1)
if actual_vector.shape != expected_vector.shape:
raise ValueError(
f"{label} shape does not match posterior theta_mean "
f"({tuple(actual_vector.shape)} != {tuple(expected_vector.shape)})"
)
if not torch.allclose(
actual_vector,
expected_vector,
rtol=1.0e-6,
atol=1.0e-8,
):
raise ValueError(f"{label} does not match posterior theta_mean")
def _direct_block_values_for_posterior(
parameter_layout: ParameterLayout,
block_index: int,
*,
expected_size: int,
indices_by_block: Mapping[int, tuple[int, ...]],
) -> torch.Tensor:
"""Return current model values for one direct posterior solve block."""
block = parameter_layout.block(int(block_index))
values = block.read().detach().reshape(-1)
selected_indices = indices_by_block.get(int(block_index))
if selected_indices is None:
if int(values.numel()) != int(expected_size):
raise ValueError(f"posterior block {block.label!r} needs selector metadata")
return values
if len(selected_indices) != int(expected_size):
raise ValueError(
f"posterior block {block.label!r} selector size does not match layout"
)
index = torch.tensor(selected_indices, dtype=torch.long)
return values.index_select(0, index)
def _proxy_block_values_for_posterior(
provider_groups: Sequence[ProviderGroup],
provider_index: int,
proxy_index: int,
*,
expected_size: int,
) -> torch.Tensor:
"""Return current model values for one fixed-weight proxy posterior block."""
if provider_index < 0 or provider_index >= len(provider_groups):
raise ValueError("posterior proxy provider index is out of range")
provider_group = provider_groups[int(provider_index)]
n_proxy_terms = int(provider_group.n_proxy_terms)
block_size = int(provider_group.block_size)
if proxy_index < 0 or proxy_index >= n_proxy_terms:
raise ValueError("posterior proxy index is out of range")
if block_size != int(expected_size):
raise ValueError("posterior proxy block size does not match provider")
provider = provider_group.provider
proxy_coeffs = provider.proxy_coeffs.detach().reshape(n_proxy_terms, block_size)
return proxy_coeffs[int(proxy_index)].reshape(-1)
def _validate_posterior_model_compatibility(
    model: UFPModel,
    posterior: BayesianLinearPosterior,
) -> None:
    """Check that a model checkpoint matches a posterior bundle.

    The checks run in this order:

    1. the posterior carries a solve layout;
    2. the model's coefficient layout matches the stored
       ``coefficient_layout`` metadata;
    3. the fixed-coefficient hashes still validate;
    4. the alchemical mixing weights match;
    5. every solve block's current model values are close to its slice of
       ``posterior.theta_mean``.

    Raises:
        ValueError: on the first failed check.
    """
    solve_layout = posterior.layout
    if solve_layout is None:
        raise ValueError("uncertainty prediction bundles require posterior layout")
    stored = posterior.metadata
    parameter_layout = ParameterLayout.from_model(model, include_frozen=True)
    observed_layout = coefficient_layout_metadata(parameter_layout)

    stored_layout = stored.get("coefficient_layout")
    if stored_layout is not None:
        problem = _metadata_mismatch(
            stored_layout,
            observed_layout,
            path="coefficient_layout",
        )
        if problem is not None:
            raise ValueError(
                f"posterior coefficient layout does not match model: {problem}"
            )

    stored_hashes = stored.get("fixed_coefficient_hashes")
    if isinstance(stored_hashes, Mapping):
        try:
            validate_fixed_coefficient_hashes(model, stored_hashes)
        except WorkflowCheckpointError as exc:
            raise ValueError(str(exc)) from exc

    stored_weights = stored.get("alchemical_weights")
    if stored_weights is not None:
        problem = _metadata_mismatch(
            stored_weights,
            _provider_weight_metadata(model),
            path="alchemical_weights",
        )
        if problem is not None:
            raise ValueError(
                f"posterior alchemical weights do not match model: {problem}"
            )

    indices_by_block = _selected_indices_by_block(stored)
    provider_groups = parameter_layout.non_identity_providers()
    for block in solve_layout.blocks:
        key = block.key
        posterior_values = posterior.theta_mean[solve_layout.theta_slice(key)]
        if isinstance(key, int):
            model_values = _direct_block_values_for_posterior(
                parameter_layout,
                int(key),
                expected_size=int(block.size),
                indices_by_block=indices_by_block,
            )
            description = f"model coefficient block {block.label!r}"
        elif isinstance(key, tuple) and len(key) == 3 and key[0] == "proxy":
            model_values = _proxy_block_values_for_posterior(
                provider_groups,
                int(key[1]),
                int(key[2]),
                expected_size=int(block.size),
            )
            description = f"model proxy coefficient block {block.label!r}"
        else:
            raise ValueError(f"unsupported posterior layout key: {key!r}")
        _assert_close_vector(model_values, posterior_values, label=description)
def _relative_bundle_path(directory: Path, path: Path) -> str:
"""Return a manifest path relative to the bundle directory."""
try:
return str(path.relative_to(directory))
except ValueError:
return str(path)
def _posterior_file_hashes(posterior_directory: Path) -> dict[str, str]:
"""Return hashes for the persisted posterior files that are present."""
hashes = {}
for name in ("theta_mean.npy", "Sigma_theta.npy", "metadata.json", "layout.json"):
path = posterior_directory / name
if path.is_file():
hashes[name] = _file_sha256(path)
return hashes
def _write_energy_variance_scale_file(
directory: Path,
*,
scale: float,
metadata: Mapping[str, object] | None = None,
) -> Path:
"""Write the bundle calibration file for an energy variance scale."""
if float(scale) <= 0.0:
raise ValueError("energy variance scale must be positive")
calibration_dir = directory / "calibration"
calibration_dir.mkdir(parents=True, exist_ok=True)
path = calibration_dir / "energy_variance_scale.json"
payload = {
"schema": {
"name": ENERGY_VARIANCE_SCALE_SCHEMA_NAME,
"version": ENERGY_VARIANCE_SCALE_SCHEMA_VERSION,
},
"scale": float(scale),
"metadata": normalize_checkpoint_metadata(
{} if metadata is None else dict(metadata)
),
}
path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf8")
return path
def _read_energy_variance_scale_file(path: Path) -> float:
"""Read and validate an energy variance scale file."""
payload = json.loads(path.read_text(encoding="utf8"))
if not isinstance(payload, Mapping):
raise ValueError("energy variance scale file must contain a mapping")
schema = payload.get("schema")
if not isinstance(schema, Mapping):
raise ValueError("energy variance scale file is missing schema")
if schema.get("name") != ENERGY_VARIANCE_SCALE_SCHEMA_NAME:
raise ValueError("unsupported energy variance scale schema")
if schema.get("version") != ENERGY_VARIANCE_SCALE_SCHEMA_VERSION:
raise ValueError("unsupported energy variance scale schema version")
scale = payload.get("scale")
if not isinstance(scale, (int, float)) or float(scale) <= 0.0:
raise ValueError("energy variance scale must be positive")
return float(scale)
def _validate_file_hashes(
directory: Path,
files: object,
*,
label: str,
) -> None:
"""Validate manifest file hashes when an additive hash table is present."""
if files is None:
return
if not isinstance(files, Mapping):
raise ValueError(f"uncertainty bundle {label} file hashes must be a mapping")
for relative_path, expected in files.items():
if not isinstance(relative_path, str) or not isinstance(expected, str):
raise ValueError(
f"uncertainty bundle {label} file hashes must map strings to strings"
)
path = directory / relative_path
if not path.is_file():
raise ValueError(
f"uncertainty bundle {label} file is missing: {relative_path}"
)
if _file_sha256(path) != expected:
raise ValueError(
f"uncertainty bundle {label} hash mismatch: {relative_path}"
)
def _bundle_manifest(
    *,
    directory: Path,
    model_checkpoint_path: Path,
    posterior: BayesianLinearPosterior,
    posterior_directory: Path,
    source_checkpoint: Path | str | None,
    aleatoric_variance: float | None,
    aleatoric_noise_model_path: Path | None,
    aleatoric_prediction_feature: float | None,
    aleatoric_noise_bundle_path: Path | None,
    energy_variance_scale_path: Path | None,
    metadata: Mapping[str, object] | None,
) -> dict[str, object]:
    """Assemble the JSON-serializable manifest of an uncertainty bundle.

    Every referenced file is recorded with its SHA-256 digest. Optional
    components that were not saved are recorded as ``None``. The posterior
    entry also lists per-file hashes from ``_posterior_file_hashes``.
    """

    def as_optional_float(value: float | None) -> float | None:
        return None if value is None else float(value)

    def hashed_entry(
        path: Path | None,
        schema_name: object,
        schema_version: object,
    ) -> dict[str, object] | None:
        # Shared shape for optional payload files stored inside the bundle.
        if path is None:
            return None
        return {
            "path": _relative_bundle_path(directory, path),
            "sha256": _file_sha256(path),
            "schema": {"name": schema_name, "version": schema_version},
        }

    if source_checkpoint is None:
        source_entry = None
    else:
        source_path = Path(source_checkpoint)
        source_entry = {
            "path": str(source_path),
            "sha256": _file_sha256(source_path),
        }

    schema_entry = {
        "name": UNCERTAINTY_BUNDLE_SCHEMA_NAME,
        "version": UNCERTAINTY_BUNDLE_SCHEMA_VERSION,
    }
    model_entry = {
        "path": model_checkpoint_path.name,
        "sha256": _file_sha256(model_checkpoint_path),
    }
    posterior_entry = {
        "path": "posterior",
        "schema": posterior.metadata.get("schema"),
        "kind": posterior.metadata.get("kind"),
        "n_parameters": int(posterior.n_parameters),
        "files": _posterior_file_hashes(posterior_directory),
    }
    variance_entry = as_optional_float(aleatoric_variance)
    noise_model_entry = hashed_entry(
        aleatoric_noise_model_path,
        ALEATORIC_SPLINE_SCHEMA_NAME,
        ALEATORIC_SPLINE_SCHEMA_VERSION,
    )
    feature_entry = as_optional_float(aleatoric_prediction_feature)
    noise_bundle_entry = hashed_entry(
        aleatoric_noise_bundle_path,
        ALEATORIC_BUNDLE_SCHEMA_NAME,
        ALEATORIC_BUNDLE_SCHEMA_VERSION,
    )
    scale_entry = hashed_entry(
        energy_variance_scale_path,
        ENERGY_VARIANCE_SCALE_SCHEMA_NAME,
        ENERGY_VARIANCE_SCALE_SCHEMA_VERSION,
    )
    metadata_entry = normalize_checkpoint_metadata(
        {} if metadata is None else dict(metadata)
    )
    return {
        "schema": schema_entry,
        "model_checkpoint": model_entry,
        "posterior": posterior_entry,
        "source_checkpoint": source_entry,
        "aleatoric_variance": variance_entry,
        "aleatoric_noise_model": noise_model_entry,
        "aleatoric_prediction_feature": feature_entry,
        "aleatoric_noise_bundle": noise_bundle_entry,
        "energy_variance_scale": scale_entry,
        "metadata": metadata_entry,
        "directory": str(directory),
    }
[docs]
def save_uncertainty_prediction_bundle(
    directory: Path | str,
    *,
    model: UFPModel,
    posterior: BayesianLinearPosterior,
    source_checkpoint: Path | str | None = None,
    aleatoric_variance: float | None = None,
    aleatoric_noise_model: "SplineAleatoricNoiseModel | None" = None,
    aleatoric_prediction_feature: float | None = None,
    aleatoric_noise_bundle: SplineAleatoricNoiseBundle | None = None,
    energy_variance_scale: float | None = None,
    calibration_metadata: Mapping[str, object] | None = None,
    metadata: Mapping[str, object] | None = None,
) -> dict[str, object]:
    """Save a reusable model/posterior bundle for predictive uncertainties.

    Layout written under ``directory``:

    * ``model_checkpoint.pt``: model state dict and schema;
    * ``posterior/``: written by ``posterior.save_memmap``;
    * ``aleatoric/spline_noise_model.pt`` and/or
      ``aleatoric/spline_noise_bundle.pt``, when given;
    * ``calibration/energy_variance_scale.json``, when a scale is given;
    * ``manifest.json``: written last, with hashes of the files above.

    Args:
        directory: Output bundle directory (created if needed).
        model: Model whose coefficients the posterior describes.
        posterior: Posterior to persist; must be compatible with ``model``.
        source_checkpoint: Optional checkpoint recorded (path and hash) in the
            manifest.
        aleatoric_variance: Optional scalar aleatoric variance.
        aleatoric_noise_model: Optional legacy spline noise model.
        aleatoric_prediction_feature: Optional scalar feature for the noise model.
        aleatoric_noise_bundle: Optional multi-head spline noise bundle.
        energy_variance_scale: Optional post-hoc energy variance scale.
        calibration_metadata: Metadata stored with the scale file.
        metadata: Metadata stored in the manifest.

    Returns:
        The manifest dictionary written to ``manifest.json``.

    Raises:
        ValueError: If ``posterior`` fails the model compatibility check, or
            ``energy_variance_scale`` is not positive and finite.
        TypeError: If an aleatoric argument has the wrong type.
    """
    _validate_posterior_model_compatibility(model, posterior)
    # Validate every optional argument before touching the filesystem so a bad
    # argument cannot leave a partially written bundle behind.
    if aleatoric_noise_model is not None and not isinstance(
        aleatoric_noise_model, SplineAleatoricNoiseModel
    ):
        raise TypeError("`aleatoric_noise_model` must be a SplineAleatoricNoiseModel")
    if aleatoric_noise_bundle is not None and not isinstance(
        aleatoric_noise_bundle, SplineAleatoricNoiseBundle
    ):
        raise TypeError(
            "`aleatoric_noise_bundle` must be a SplineAleatoricNoiseBundle"
        )
    scale_value = None
    if energy_variance_scale is not None:
        scale_value = float(energy_variance_scale)
        # ``not finite`` also catches NaN, which a plain ``<= 0`` test misses.
        if not math.isfinite(scale_value) or scale_value <= 0.0:
            raise ValueError("energy variance scale must be positive and finite")

    output_dir = Path(directory)
    output_dir.mkdir(parents=True, exist_ok=True)
    model_checkpoint_path = output_dir / "model_checkpoint.pt"
    torch.save(
        {
            "state_dict": model.state_dict(),
            "model_schema": model_schema(model),
            "metadata": {
                "kind": "uncertainty_prediction_bundle_model",
            },
        },
        model_checkpoint_path,
    )
    posterior_directory = output_dir / "posterior"
    posterior.save_memmap(posterior_directory)

    aleatoric_noise_model_path = None
    if aleatoric_noise_model is not None:
        aleatoric_directory = output_dir / "aleatoric"
        aleatoric_directory.mkdir(parents=True, exist_ok=True)
        aleatoric_noise_model_path = aleatoric_directory / "spline_noise_model.pt"
        torch.save(aleatoric_noise_model.to_payload(), aleatoric_noise_model_path)

    aleatoric_noise_bundle_path = None
    if aleatoric_noise_bundle is not None:
        aleatoric_directory = output_dir / "aleatoric"
        aleatoric_directory.mkdir(parents=True, exist_ok=True)
        aleatoric_noise_bundle_path = aleatoric_directory / "spline_noise_bundle.pt"
        torch.save(aleatoric_noise_bundle.to_payload(), aleatoric_noise_bundle_path)

    energy_variance_scale_path = None
    if scale_value is not None:
        energy_variance_scale_path = _write_energy_variance_scale_file(
            output_dir,
            scale=scale_value,
            metadata=calibration_metadata,
        )

    # The manifest is written last so it only ever references files that exist.
    manifest = _bundle_manifest(
        directory=output_dir,
        model_checkpoint_path=model_checkpoint_path,
        posterior=posterior,
        posterior_directory=posterior_directory,
        source_checkpoint=source_checkpoint,
        aleatoric_variance=aleatoric_variance,
        aleatoric_noise_model_path=aleatoric_noise_model_path,
        aleatoric_prediction_feature=aleatoric_prediction_feature,
        aleatoric_noise_bundle_path=aleatoric_noise_bundle_path,
        energy_variance_scale_path=energy_variance_scale_path,
        metadata=metadata,
    )
    (output_dir / "manifest.json").write_text(
        json.dumps(manifest, indent=2) + "\n",
        encoding="utf8",
    )
    return manifest
def _validate_bundle_manifest(manifest: object) -> Mapping[str, object]:
"""Validate the outer uncertainty bundle manifest schema."""
if not isinstance(manifest, Mapping):
raise ValueError("uncertainty bundle manifest must be a mapping")
schema = manifest.get("schema")
if not isinstance(schema, Mapping):
raise ValueError("uncertainty bundle manifest is missing schema metadata")
if schema.get("name") != UNCERTAINTY_BUNDLE_SCHEMA_NAME:
raise ValueError("unsupported uncertainty bundle schema")
if schema.get("version") != UNCERTAINTY_BUNDLE_SCHEMA_VERSION:
raise ValueError("unsupported uncertainty bundle schema version")
return manifest
[docs]
def load_uncertainty_prediction_bundle(
    directory: Path | str,
    *,
    dtype: torch.dtype | None = None,
    validate: bool = True,
) -> UncertaintyPredictionBundle:
    """Load a reusable model/posterior bundle for predictive uncertainties.

    Args:
        directory: Bundle directory containing ``manifest.json``.
        dtype: Dtype for the model and aleatoric payloads. When ``None`` it is
            inferred from the stored model checkpoint.
        validate: When true, verify every SHA-256 recorded in the manifest and
            run the posterior/model compatibility check.

    Returns:
        The loaded :class:`UncertaintyPredictionBundle`.

    Raises:
        ValueError: If the manifest or one of its entries is malformed, a
            recorded hash does not match, the posterior has no layout or the
            wrong parameter count, a numeric manifest field is not numeric
            (booleans are rejected), or the energy variance scale is not
            positive and finite.

    NOTE(review): model and aleatoric payloads are read via ``_torch_load``;
    if that unpickles arbitrary objects, only load bundles from trusted
    sources. Manifest paths are joined onto ``directory`` without being
    confined to it.
    """
    input_dir = Path(directory)
    manifest = dict(
        _validate_bundle_manifest(
            json.loads((input_dir / "manifest.json").read_text(encoding="utf8"))
        )
    )

    model_info = manifest.get("model_checkpoint")
    if not isinstance(model_info, Mapping):
        raise ValueError("uncertainty bundle manifest is missing model checkpoint")
    model_path = input_dir / str(model_info.get("path", "model_checkpoint.pt"))
    if validate and "sha256" in model_info:
        if _file_sha256(model_path) != model_info["sha256"]:
            raise ValueError("uncertainty bundle model checkpoint hash mismatch")
    checkpoint = _torch_load(model_path, map_location="cpu")
    if not isinstance(checkpoint, Mapping):
        raise ValueError("uncertainty bundle model checkpoint must be a mapping")
    resolved_dtype = _infer_checkpoint_dtype(checkpoint) if dtype is None else dtype
    model = load_model_from_checkpoint(checkpoint, dtype=resolved_dtype)

    posterior_info = manifest.get("posterior")
    if not isinstance(posterior_info, Mapping):
        raise ValueError("uncertainty bundle manifest is missing posterior metadata")
    posterior_path = input_dir / str(posterior_info.get("path", "posterior"))
    if validate:
        # ``_validate_file_hashes`` accepts a missing (``None``) hash table.
        _validate_file_hashes(
            posterior_path,
            posterior_info.get("files"),
            label="posterior",
        )
    posterior = BayesianLinearPosterior.load_memmap(posterior_path)
    if posterior.layout is None:
        raise ValueError("uncertainty bundle posterior is missing layout")
    expected_n_parameters = posterior_info.get("n_parameters")
    if expected_n_parameters is not None and int(expected_n_parameters) != int(
        posterior.n_parameters
    ):
        raise ValueError("uncertainty bundle posterior parameter count mismatch")
    if validate:
        _validate_posterior_model_compatibility(model, posterior)

    aleatoric_variance = _manifest_optional_float(manifest, "aleatoric_variance")
    aleatoric_noise_model = _load_bundle_aleatoric_payload(
        input_dir,
        manifest.get("aleatoric_noise_model"),
        key="aleatoric_noise_model",
        description="aleatoric noise model",
        schema_name=ALEATORIC_SPLINE_SCHEMA_NAME,
        schema_version=ALEATORIC_SPLINE_SCHEMA_VERSION,
        default_path="aleatoric/spline_noise_model.pt",
        payload_type=SplineAleatoricNoiseModel,
        dtype=resolved_dtype,
        validate=validate,
    )
    aleatoric_prediction_feature = _manifest_optional_float(
        manifest, "aleatoric_prediction_feature"
    )
    aleatoric_noise_bundle = _load_bundle_aleatoric_payload(
        input_dir,
        manifest.get("aleatoric_noise_bundle"),
        key="aleatoric_noise_bundle",
        description="aleatoric noise bundle",
        schema_name=ALEATORIC_BUNDLE_SCHEMA_NAME,
        schema_version=ALEATORIC_BUNDLE_SCHEMA_VERSION,
        default_path="aleatoric/spline_noise_bundle.pt",
        payload_type=SplineAleatoricNoiseBundle,
        dtype=resolved_dtype,
        validate=validate,
    )
    energy_scale = _manifest_energy_variance_scale(
        input_dir,
        manifest.get("energy_variance_scale"),
        validate=validate,
    )
    return UncertaintyPredictionBundle(
        model=model,
        posterior=posterior,
        aleatoric_variance=aleatoric_variance,
        aleatoric_noise_model=aleatoric_noise_model,
        aleatoric_prediction_feature=aleatoric_prediction_feature,
        aleatoric_noise_bundle=aleatoric_noise_bundle,
        energy_variance_scale=float(energy_scale),
        manifest=manifest,
    )


def _manifest_optional_float(manifest: Mapping[str, object], key: str) -> float | None:
    """Return an optional numeric manifest entry as ``float`` (``None`` if absent)."""
    value = manifest.get(key)
    if value is None:
        return None
    # ``bool`` is an ``int`` subclass; a JSON ``true`` is not a meaningful number.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ValueError(f"uncertainty bundle {key} must be numeric")
    return float(value)


def _load_bundle_aleatoric_payload(
    input_dir: Path,
    info: object,
    *,
    key: str,
    description: str,
    schema_name: object,
    schema_version: object,
    default_path: str,
    payload_type: Any,
    dtype: torch.dtype,
    validate: bool,
) -> Any:
    """Validate one optional aleatoric manifest entry and load its payload."""
    if info is None:
        return None
    if not isinstance(info, Mapping):
        raise ValueError(f"uncertainty bundle {key} must be a mapping")
    schema = info.get("schema")
    if not isinstance(schema, Mapping):
        raise ValueError(f"uncertainty bundle {key} is missing schema")
    if schema.get("name") != schema_name:
        raise ValueError(f"unsupported {description} schema")
    if schema.get("version") != schema_version:
        raise ValueError(f"unsupported {description} schema version")
    payload_path = input_dir / str(info.get("path", default_path))
    if validate and "sha256" in info:
        if _file_sha256(payload_path) != info["sha256"]:
            raise ValueError(f"uncertainty bundle {description} hash mismatch")
    return payload_type.from_payload(
        _torch_load(payload_path, map_location="cpu"),
        dtype=dtype,
    )


def _manifest_energy_variance_scale(
    input_dir: Path,
    scale_info: object,
    *,
    validate: bool,
) -> float:
    """Resolve the manifest ``energy_variance_scale`` entry to a positive float."""
    if scale_info is None:
        scale = 1.0
    elif isinstance(scale_info, bool):
        raise ValueError("uncertainty bundle energy_variance_scale is invalid")
    elif isinstance(scale_info, (int, float)):
        scale = float(scale_info)
    elif isinstance(scale_info, Mapping):
        scale_path = input_dir / str(
            scale_info.get("path", "calibration/energy_variance_scale.json")
        )
        if validate and "sha256" in scale_info:
            if _file_sha256(scale_path) != scale_info["sha256"]:
                raise ValueError("uncertainty bundle calibration hash mismatch")
        scale = float(_read_energy_variance_scale_file(scale_path))
    else:
        raise ValueError("uncertainty bundle energy_variance_scale is invalid")
    # ``<= 0`` alone would let NaN through.
    if not math.isfinite(scale) or scale <= 0.0:
        raise ValueError(
            "uncertainty bundle energy variance scale must be positive and finite"
        )
    return scale
[docs]
def save_energy_variance_scale_to_bundle(
    directory: Path | str,
    scale: float,
    *,
    metadata: Mapping[str, object] | None = None,
) -> dict[str, object]:
    """Write (or overwrite) a bundle's energy variance scale and update its manifest.

    The existing manifest is validated first. The calibration file is then
    rewritten, and the manifest's ``energy_variance_scale`` entry is replaced
    with the new file's relative path, SHA-256 and schema.

    Returns:
        The updated manifest, as written back to ``manifest.json``.
    """
    root = Path(directory)
    manifest_file = root / "manifest.json"
    raw_manifest = json.loads(manifest_file.read_text(encoding="utf8"))
    updated = dict(_validate_bundle_manifest(raw_manifest))
    calibration_file = _write_energy_variance_scale_file(
        root,
        scale=float(scale),
        metadata=metadata,
    )
    scale_entry = {
        "path": _relative_bundle_path(root, calibration_file),
        "sha256": _file_sha256(calibration_file),
        "schema": {
            "name": ENERGY_VARIANCE_SCALE_SCHEMA_NAME,
            "version": ENERGY_VARIANCE_SCALE_SCHEMA_VERSION,
        },
    }
    updated["energy_variance_scale"] = scale_entry
    manifest_file.write_text(json.dumps(updated, indent=2) + "\n", encoding="utf8")
    return updated
[docs]
def validate_uncertainty_prediction_bundle(
    directory: Path | str,
    *,
    dtype: torch.dtype | None = None,
) -> bool:
    """Fully load the bundle at ``directory`` with validation enabled.

    Returns:
        ``True`` once :func:`load_uncertainty_prediction_bundle` succeeds with
        ``validate=True``. Validation failures propagate as exceptions, so this
        function never returns ``False``.
    """
    load_uncertainty_prediction_bundle(directory, validate=True, dtype=dtype)
    return True
[docs]
def describe_uncertainty_prediction_bundle(
    directory: Path | str,
    *,
    dtype: torch.dtype | None = None,
    validate: bool = True,
) -> dict[str, object]:
    """Return a compact human-readable summary of an uncertainty bundle.

    The summary always has ``directory``, ``schema``, ``valid`` and
    ``validation_error``. Failing to read or parse ``manifest.json``, or any
    exception from :func:`load_uncertainty_prediction_bundle`, is reported as
    ``valid=False`` with the error text rather than raised. On success the
    summary adds:

    * posterior kind, parameter count and block count;
    * flags for each aleatoric component;
    * the effective energy variance scale;
    * the ``source_checkpoint`` and ``model_checkpoint`` manifest entries.
    """
    bundle_dir = Path(directory)
    summary: dict[str, object] = {
        "directory": str(bundle_dir),
        "schema": None,
        "valid": False,
        "validation_error": None,
    }
    try:
        manifest = json.loads(
            (bundle_dir / "manifest.json").read_text(encoding="utf8")
        )
    except (OSError, ValueError) as exc:
        # Missing/unreadable or malformed manifest (JSONDecodeError and
        # UnicodeDecodeError are ValueErrors): report it like any load failure.
        summary["validation_error"] = str(exc)
        return summary
    if isinstance(manifest, Mapping):
        summary["schema"] = manifest.get("schema")
    try:
        bundle = load_uncertainty_prediction_bundle(
            bundle_dir,
            dtype=dtype,
            validate=validate,
        )
    except Exception as exc:  # diagnostic summary: report instead of raising
        summary["validation_error"] = str(exc)
        return summary
    posterior = bundle.posterior
    layout = posterior.layout
    summary.update(
        {
            "valid": True,
            "posterior_kind": posterior.metadata.get("kind"),
            "posterior_parameters": int(posterior.n_parameters),
            "posterior_blocks": 0 if layout is None else len(layout.blocks),
            "has_scalar_aleatoric_variance": bundle.aleatoric_variance is not None,
            "has_legacy_aleatoric_noise_model": (
                bundle.aleatoric_noise_model is not None
            ),
            "has_aleatoric_noise_bundle": bundle.aleatoric_noise_bundle is not None,
            "has_energy_variance_scale": (
                manifest.get("energy_variance_scale") is not None
            ),
            "energy_variance_scale": float(bundle.energy_variance_scale),
            "source_checkpoint": manifest.get("source_checkpoint"),
            "model_checkpoint": manifest.get("model_checkpoint"),
        }
    )
    noise_bundle = bundle.aleatoric_noise_bundle
    if noise_bundle is not None:
        summary["aleatoric_heads"] = {
            ALEATORIC_ENERGY_KIND: noise_bundle.energy_per_atom is not None,
            ALEATORIC_PER_ATOM_KIND: noise_bundle.per_atom_energy is not None,
            ALEATORIC_FORCE_KIND: noise_bundle.force_component is not None,
        }
        summary["aleatoric_feature_kind"] = noise_bundle.feature_spec.kind
    return summary
def _row_variances_for_batch(
row_variances: torch.Tensor | None,
*,
offset: int,
n_rows: int,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor | None:
"""Return row variances for one batch or ``None`` for unit noise."""
if row_variances is None:
return None
if row_variances.numel() == 1:
return row_variances.to(dtype=dtype, device=device).expand(n_rows)
return row_variances[offset : offset + n_rows].to(dtype=dtype, device=device)
def _normalize_row_variances(
row_variances: object,
*,
n_rows: int,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor | None:
"""Validate optional observation-noise variances."""
if row_variances is None:
return None
variances = torch.as_tensor(row_variances, dtype=dtype, device=device).reshape(-1)
if variances.numel() not in {1, int(n_rows)}:
raise ValueError(
"`row_variances` must be scalar or have one entry per target row"
)
if bool(torch.any(variances <= 0)):
raise ValueError("`row_variances` must be strictly positive")
return variances
def _accumulate_precision_and_rhs(
    problem: BlockLinearProblem,
    *,
    row_variances: object = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Accumulate ``A.T R^-1 A + Lambda`` and the matching RHS.

    For every batch the block design matrices are materialized densely and
    the row-weighted normal equations are accumulated block by block:

    * ``precision[i, j] += A_i.T diag(1 / r) A_j``, with off-diagonal pairs
      added symmetrically;
    * ``rhs[i] += A_i.T (y / r)``.

    ``r`` are the batch row variances (``None`` means unit weights). Each
    layout block's ``regularization`` then adds its materialized matrix to
    the diagonal block and its ``rhs`` to the RHS.

    Args:
        problem: Block problem supplying batches, targets and the layout.
        row_variances: ``None``, a scalar, or one variance per target row.

    Returns:
        ``(precision, rhs)`` as dense tensors in ``problem.dtype`` on
        ``problem.device``.
    """
    all_variances = _normalize_row_variances(
        row_variances,
        n_rows=problem.n_rows,
        dtype=problem.dtype,
        device=problem.device,
    )
    dtype = problem.dtype
    device = problem.device
    layout = problem.layout
    n_params = layout.size
    precision = torch.zeros((n_params, n_params), dtype=dtype, device=device)
    rhs = torch.zeros((n_params,), dtype=dtype, device=device)
    row_start = 0
    for batch in problem.batches:
        batch_variances = _row_variances_for_batch(
            all_variances,
            offset=row_start,
            n_rows=batch.n_rows,
            dtype=dtype,
            device=device,
        )
        target = batch.target
        weights = None if batch_variances is None else batch_variances.reciprocal()
        dense = {
            key: _materialize_block_matrix(block_matrix).to(dtype=dtype, device=device)
            for key, block_matrix in batch.matrices.items()
        }
        keys = tuple(dense)
        if weights is None:
            weighted_target = target
        else:
            weighted_target = target * weights.to(
                dtype=target.dtype, device=target.device
            )
        for key in keys:
            rhs[layout.theta_slice(key)] += dense[key].T @ weighted_target
        for position, key_i in enumerate(keys):
            slice_i = layout.theta_slice(key_i)
            matrix_i = dense[key_i]
            matrix_i_weighted = (
                matrix_i if weights is None else matrix_i * weights[:, None]
            )
            precision[slice_i, slice_i] += matrix_i.T @ matrix_i_weighted
            # Upper-triangular block pairs; mirror each into the lower triangle.
            for key_j in keys[position + 1 :]:
                slice_j = layout.theta_slice(key_j)
                matrix_j = dense[key_j]
                matrix_j_weighted = (
                    matrix_j if weights is None else matrix_j * weights[:, None]
                )
                cross = matrix_i.T @ matrix_j_weighted
                precision[slice_i, slice_j] += cross
                precision[slice_j, slice_i] += cross.T
        row_start += batch.n_rows
    for block in layout.blocks:
        regularizer = block.regularization
        if regularizer is None:
            continue
        block_slice = layout.theta_slice(block.key)
        precision[block_slice, block_slice] += regularizer.materialize(
            dtype=dtype,
            device=device,
        )
        rhs[block_slice] += regularizer.rhs(dtype=dtype, device=device)
    return precision, rhs
def _solve_dense_system(matrix: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
"""Solve a dense system, falling back to least squares if needed."""
try:
return torch.linalg.solve(matrix, rhs)
except RuntimeError:
return torch.linalg.lstsq(matrix, rhs).solution
def _invert_precision(precision: torch.Tensor, *, jitter: float = 0.0) -> torch.Tensor:
"""Invert a dense precision matrix with a pseudo-inverse fallback."""
if float(jitter) != 0.0:
eye = torch.eye(
precision.shape[0],
dtype=precision.dtype,
device=precision.device,
)
precision = precision + float(jitter) * eye
eye = torch.eye(
precision.shape[0],
dtype=precision.dtype,
device=precision.device,
)
chol, info = torch.linalg.cholesky_ex(precision)
if int(info.item()) == 0:
return torch.cholesky_solve(eye, chol)
return torch.linalg.pinv(precision)
def _store_covariance(
sigma: torch.Tensor,
covariance_path: Path | str | None,
) -> torch.Tensor | np.ndarray:
"""Return dense covariance or write it to an ``.npy`` memmap."""
if covariance_path is None:
return sigma
path = Path(covariance_path)
path.parent.mkdir(parents=True, exist_ok=True)
array = sigma.detach().cpu().numpy()
mmap = np.lib.format.open_memmap(
path,
mode="w+",
dtype=array.dtype,
shape=array.shape,
)
mmap[:] = array
mmap.flush()
return np.load(path, mmap_mode="r+")
def _effective_row_variances(
row_variances: object,
observation_noise_variance: float,
) -> object:
"""Return explicit row variances or a scalar observation variance."""
if row_variances is not None:
return row_variances
if float(observation_noise_variance) != 1.0:
return float(observation_noise_variance)
return None
def _fit_aleatoric_row_variances(
problem: BlockLinearProblem,
theta: torch.Tensor,
*,
noise_model: SplineAleatoricNoiseModel,
features: object,
steps: int,
lr: float,
) -> torch.Tensor:
"""Train/evaluate a spline noise model on current residual rows."""
feature_tensor = torch.as_tensor(
features,
dtype=problem.dtype,
device=problem.device,
).reshape(-1)
if int(feature_tensor.numel()) != problem.n_rows:
raise ValueError("`aleatoric_features` must have one entry per target row")
residuals = (problem.matvec(theta) - problem.target_vector()).detach()
if int(steps) > 0:
noise_model.fit_residuals(
residuals,
feature_tensor,
steps=int(steps),
lr=float(lr),
)
return noise_model(feature_tensor).detach()
def _fit_aleatoric_noise_bundle_row_variances(
problem: BlockLinearProblem,
theta: torch.Tensor,
*,
noise_bundle: SplineAleatoricNoiseBundle,
row_metadata: AleatoricRowMetadata,
steps: int,
lr: float,
fallback_variance: float,
) -> torch.Tensor:
"""Train/evaluate separate aleatoric heads and return row variances."""
features = row_metadata.features.to(dtype=problem.dtype, device=problem.device)
if len(row_metadata.kinds) != problem.n_rows:
raise ValueError("aleatoric row metadata must have one entry per target row")
residuals = (problem.matvec(theta) - problem.target_vector()).detach()
variances = torch.full(
(problem.n_rows,),
fill_value=float(fallback_variance),
dtype=problem.dtype,
device=problem.device,
)
if bool(torch.any(variances <= 0.0)):
raise ValueError("fallback aleatoric variance must be positive")
kind_array = np.asarray(row_metadata.kinds, dtype=object)
for kind in (ALEATORIC_ENERGY_KIND, ALEATORIC_PER_ATOM_KIND, ALEATORIC_FORCE_KIND):
head = noise_bundle.head_for_kind(kind)
if head is None:
continue
positions_np = np.nonzero(kind_array == kind)[0]
if positions_np.size == 0:
continue
positions = torch.as_tensor(
positions_np,
dtype=torch.int64,
device=problem.device,
)
kind_features = features.index_select(0, positions)
kind_residuals = residuals.index_select(0, positions)
if int(steps) > 0:
head.fit_residuals(
kind_residuals,
kind_features,
steps=int(steps),
lr=float(lr),
)
variances.index_copy_(0, positions, head(kind_features).detach())
return variances
[docs]
def fit_linear_uncertainty_model(
    model: UFPModel,
    samples: Sequence[FitSample],
    *,
    fitter: LinearFitter | None = None,
    problem: BlockLinearProblem | None = None,
    refit_mean: bool = True,
    write_back: bool = True,
    row_variances: object = None,
    observation_noise_variance: float = 1.0,
    aleatoric_noise_model: SplineAleatoricNoiseModel | None = None,
    aleatoric_features: object = None,
    aleatoric_noise_bundle: SplineAleatoricNoiseBundle | None = None,
    aleatoric_row_metadata: AleatoricRowMetadata | None = None,
    aleatoric_steps: int = 0,
    aleatoric_lr: float = 1.0e-2,
    batch_size: int = 32,
    progress: bool = False,
    cache_directory: Path | str | None = None,
    cache_mode: AssembledBatchCacheMode = "auto",
    covariance_path: Path | str | None = None,
    jitter: float = 0.0,
    **fitter_kwargs: Any,
) -> BayesianLinearPosterior:
    """Fit or load the mean and build a dense linear Bayesian posterior.

    The precision ``A.T R^-1 A + Lambda`` is accumulated from ``problem``
    (built from ``samples`` by the fitter when not given). The mean is either
    re-solved (``refit_mean``) or read from the model's current coefficients.
    The covariance is the inverse of the final precision (with optional
    diagonal ``jitter``), kept in memory or written to ``covariance_path``.

    Observation noise comes from at most one of:

    * ``row_variances``: scalar or one entry per target row;
    * ``aleatoric_noise_model``, evaluated on ``aleatoric_features``;
    * ``aleatoric_noise_bundle``, evaluated on ``aleatoric_row_metadata``
      (built from ``samples`` when not given).

    When none is given, the scalar ``observation_noise_variance`` is used.
    With an aleatoric source, the first pass uses homoscedastic noise. The
    aleatoric model is then trained for ``aleatoric_steps`` on the resulting
    residuals, and the precision (and, with ``refit_mean``, the mean) is
    rebuilt from its row variances.

    Returns:
        The posterior with mean, covariance, metadata and the problem layout.

    Raises:
        ValueError: If conflicting noise options are given,
            ``aleatoric_features`` is missing for ``aleatoric_noise_model``,
            or a supplied ``problem`` has a layout size different from the
            fitter's selected coefficient layout. Argument combinations are
            checked before any problem is assembled.
    """
    if row_variances is not None and (
        aleatoric_noise_model is not None or aleatoric_noise_bundle is not None
    ):
        raise ValueError(
            "use only one of `row_variances`, `aleatoric_noise_model`, "
            "and `aleatoric_noise_bundle`"
        )
    if aleatoric_noise_model is not None and aleatoric_noise_bundle is not None:
        raise ValueError(
            "use only one of `aleatoric_noise_model` and `aleatoric_noise_bundle`"
        )
    if aleatoric_noise_model is not None and aleatoric_features is None:
        raise ValueError(
            "`aleatoric_features` is required with `aleatoric_noise_model`"
        )

    resolved_fitter = LinearFitter(model, **fitter_kwargs) if fitter is None else fitter
    if problem is None:
        problem = resolved_fitter.build_problem(
            samples,
            batch_size=batch_size,
            progress=progress,
            cache_directory=cache_directory,
            cache_mode=cache_mode,
        )
    elif int(problem.layout.size) != int(resolved_fitter._selected_size()):
        raise ValueError(
            "`problem` layout size does not match this fitter's selected "
            "coefficient layout"
        )

    # First pass: explicit row variances or the scalar observation variance.
    precision, rhs = _accumulate_precision_and_rhs(
        problem,
        row_variances=_effective_row_variances(
            row_variances,
            observation_noise_variance,
        ),
    )
    if refit_mean:
        theta_mean = _solve_dense_system(precision, rhs)
    else:
        theta_mean = resolved_fitter._current_selected_vector(
            dtype=problem.dtype,
            device=problem.device,
        )

    # Optional second pass: learn per-row variances from first-pass residuals.
    aleatoric_variances = None
    if aleatoric_noise_model is not None:
        aleatoric_variances = _fit_aleatoric_row_variances(
            problem,
            theta_mean,
            noise_model=aleatoric_noise_model,
            features=aleatoric_features,
            steps=aleatoric_steps,
            lr=aleatoric_lr,
        )
    elif aleatoric_noise_bundle is not None:
        row_metadata = aleatoric_row_metadata
        if row_metadata is None:
            row_metadata = build_aleatoric_row_metadata(
                samples,
                fit_energy=resolved_fitter.fit_energy,
                fit_forces=resolved_fitter.fit_forces,
                fit_per_atom_energy=resolved_fitter.fit_per_atom_energy,
                dtype=problem.dtype,
                device=problem.device,
                feature_spec=aleatoric_noise_bundle.feature_spec,
            )
        aleatoric_variances = _fit_aleatoric_noise_bundle_row_variances(
            problem,
            theta_mean,
            noise_bundle=aleatoric_noise_bundle,
            row_metadata=row_metadata,
            steps=aleatoric_steps,
            lr=aleatoric_lr,
            fallback_variance=float(observation_noise_variance),
        )
    if aleatoric_variances is not None:
        precision, rhs = _accumulate_precision_and_rhs(
            problem,
            row_variances=aleatoric_variances,
        )
        if refit_mean:
            theta_mean = _solve_dense_system(precision, rhs)

    if refit_mean and write_back:
        resolved_fitter.write_back(theta_mean)
    sigma = _invert_precision(precision, jitter=jitter)
    metadata = _posterior_metadata(
        model=model,
        fitter=resolved_fitter,
        problem=problem,
        kind="linear",
        extra=_linear_uncertainty_extra_metadata(
            refit_mean=refit_mean,
            observation_noise_variance=observation_noise_variance,
            row_variances=row_variances,
            aleatoric_noise_model=aleatoric_noise_model,
            aleatoric_features=aleatoric_features,
            aleatoric_noise_bundle=aleatoric_noise_bundle,
            aleatoric_steps=aleatoric_steps,
        ),
    )
    return BayesianLinearPosterior(
        theta_mean=theta_mean.detach().clone(),
        Sigma_theta=_store_covariance(sigma, covariance_path),
        metadata=metadata,
        layout=problem.layout,
    )


def _linear_uncertainty_extra_metadata(
    *,
    refit_mean: bool,
    observation_noise_variance: float,
    row_variances: object,
    aleatoric_noise_model: object,
    aleatoric_features: object,
    aleatoric_noise_bundle: object,
    aleatoric_steps: int,
) -> dict[str, object]:
    """Build the noise-related ``extra`` metadata stored with a linear posterior."""
    row_variance_info = None
    if row_variances is not None:
        flat_variances = torch.as_tensor(row_variances).reshape(-1)
        row_variance_info = {
            "shape": list(flat_variances.shape),
            "values_hash": _hash_tensor_values(flat_variances),
        }
    noise_model_info = None
    if aleatoric_noise_model is not None:
        flat_features = torch.as_tensor(aleatoric_features).reshape(-1)
        noise_model_info = {
            "class": type(aleatoric_noise_model).__name__,
            "trained_steps": int(aleatoric_steps),
            "feature_shape": list(flat_features.shape),
            "feature_hash": _hash_tensor_values(flat_features),
        }
    noise_bundle_info = None
    if aleatoric_noise_bundle is not None:
        noise_bundle_info = {
            "class": type(aleatoric_noise_bundle).__name__,
            "trained_steps": int(aleatoric_steps),
            "feature_kind": aleatoric_noise_bundle.feature_spec.kind,
            "heads": {
                ALEATORIC_ENERGY_KIND: aleatoric_noise_bundle.energy_per_atom is not None,
                ALEATORIC_PER_ATOM_KIND: (
                    aleatoric_noise_bundle.per_atom_energy is not None
                ),
                ALEATORIC_FORCE_KIND: (
                    aleatoric_noise_bundle.force_component is not None
                ),
            },
        }
    return {
        "refit_mean": bool(refit_mean),
        "observation_noise_variance": float(observation_noise_variance),
        "row_variances": row_variance_info,
        "aleatoric_noise_model": noise_model_info,
        "aleatoric_noise_bundle": noise_bundle_info,
    }
def _provider_proxy_key(provider_index: int, proxy_index: int) -> tuple[str, int, int]:
"""Build a stable solve key for a fixed-weight proxy block."""
return ("proxy", int(provider_index), int(proxy_index))
def _scale_block_matrix(matrix: BlockMatrix, scalar: torch.Tensor) -> BlockMatrix:
    """Return ``scalar * matrix`` as a dense tensor.

    The block matrix is always materialized first; ``scalar`` is converted to
    the dense matrix's dtype and device before the multiplication.
    """
    dense = _materialize_block_matrix(matrix)
    factor = scalar.to(dtype=dense.dtype, device=dense.device)
    return dense * factor
def _make_fixed_weight_alchemical_problem(
    als_fitter: AlchemicalALSFitter,
    true_problem: BlockLinearProblem,
) -> BlockLinearProblem:
    """Build one joint direct/proxy problem with alchemical weights frozen.

    Solve blocks are, in order:

    * one per active direct block, keeping its key, size, label and
      ridge/third-difference regularization;
    * one per ``(provider, proxy term)`` of every non-identity provider,
      keyed by ``_provider_proxy_key``.

    For every batch, direct-block matrices are copied through unchanged. For a
    provider-backed block, its matrix is scaled by each nonzero mixing weight
    ``weights[coefficient_index, proxy_index]`` and summed into the matching
    proxy block. Blocks that are neither active direct blocks nor backed by a
    provider with weights are left out of the joint problem.
    """
    layout = als_fitter.layout
    # Materialized once: iterated to build the solve blocks, then
    # membership-tested for every matrix of every batch (hence the frozenset).
    direct_block_indices = tuple(als_fitter._active_direct_blocks())
    direct_block_set = frozenset(direct_block_indices)
    provider_groups = layout.non_identity_providers()
    provider_index_by_id = {
        id(group.provider): index for index, group in enumerate(provider_groups)
    }
    solve_blocks: list[SolveBlock] = []
    for block_index in direct_block_indices:
        block = layout.block(block_index)
        solve_blocks.append(
            SolveBlock(
                key=block_index,
                size=block.size,
                label=block.label,
                regularization=_make_block_regularization(
                    block.shape,
                    ridge=als_fitter._ridge_for_block_index(block_index),
                    third_difference_penalty=(
                        als_fitter._third_difference_for_block_index(block_index)
                    ),
                    active_rows=_twobody_shape_regularization_rows(block),
                ),
            )
        )
    for provider_index, provider_group in enumerate(provider_groups):
        for proxy_index in range(provider_group.n_proxy_terms):
            solve_blocks.append(
                SolveBlock(
                    key=_provider_proxy_key(provider_index, proxy_index),
                    size=provider_group.block_size,
                    label=f"proxy[{provider_index},{proxy_index}]",
                    regularization=_make_block_regularization(
                        provider_group.coefficient_shape,
                        ridge=als_fitter._ridge_for_provider(provider_group),
                        third_difference_penalty=(
                            als_fitter._third_difference_for_provider(provider_group)
                        ),
                        active_rows=als_fitter._provider_twobody_active_rows(
                            provider_group
                        ),
                    ),
                )
            )
    # Frozen weights converted to the problem dtype/device, cached per provider
    # so the detach/convert happens once rather than per block per batch.
    weights_by_provider: dict[int, torch.Tensor] = {}
    batches: list[BlockSolveBatch] = []
    for batch in true_problem.batches:
        matrices: dict[object, BlockMatrix] = {}
        for block_index, matrix in batch.matrices.items():
            block = layout.block(int(block_index))
            provider = block.coefficient_provider
            if block_index in direct_block_set:
                matrices[block_index] = matrix
                continue
            if provider is None or provider.weights is None:
                continue
            provider_index = provider_index_by_id[id(provider)]
            assert block.coefficient_index is not None
            provider_weights = weights_by_provider.get(id(provider))
            if provider_weights is None:
                provider_weights = provider.weights.detach().to(
                    dtype=true_problem.dtype,
                    device=true_problem.device,
                )
                weights_by_provider[id(provider)] = provider_weights
            weights = provider_weights[block.coefficient_index]
            for proxy_index in torch.nonzero(weights != 0, as_tuple=False).reshape(-1):
                key = _provider_proxy_key(provider_index, int(proxy_index.item()))
                contribution = _scale_block_matrix(matrix, weights[proxy_index])
                if key in matrices:
                    matrices[key] = _materialize_block_matrix(
                        matrices[key]
                    ) + _materialize_block_matrix(contribution)
                else:
                    matrices[key] = contribution
        batches.append(BlockSolveBatch(target=batch.target, matrices=matrices))
    return BlockLinearProblem(
        layout=BlockProblemLayout.from_blocks(tuple(solve_blocks)),
        batches=tuple(batches),
    )
def _current_fixed_weight_vector(
    als_fitter: AlchemicalALSFitter,
    problem: BlockLinearProblem,
) -> torch.Tensor:
    """Gather the model's present coefficients into the fixed-weight layout.

    Direct blocks are copied out of the fitter's current true coefficient
    vector; every proxy term of every non-identity provider fills the slot
    keyed by ``_provider_proxy_key`` in ``problem.layout``.

    Args:
        als_fitter: Fitter whose layout owns the direct blocks and providers.
        problem: Fixed-weight problem defining the target layout, dtype, and
            device.

    Returns:
        A flat tensor of length ``problem.layout.size``.
    """
    target_layout = problem.layout
    placement = {"dtype": problem.dtype, "device": problem.device}
    vector = torch.zeros((target_layout.size,), **placement)
    true_values = als_fitter.layout.current_true_vector(**placement)
    for direct_index in als_fitter._active_direct_blocks():
        source = als_fitter.layout.block(direct_index).theta_slice
        vector[target_layout.theta_slice(direct_index)] = true_values[source]
    providers = als_fitter.layout.non_identity_providers()
    for group_index, group in enumerate(providers):
        # One row per proxy term, flattened to the block width.
        stacked = (
            group.provider.proxy_coeffs.detach()
            .reshape(group.n_proxy_terms, group.block_size)
            .to(**placement)
        )
        for term_index, term_values in enumerate(stacked):
            key = _provider_proxy_key(group_index, term_index)
            vector[target_layout.theta_slice(key)] = term_values
    return vector
def _write_fixed_weight_vector(
    als_fitter: AlchemicalALSFitter,
    problem: BlockLinearProblem,
    theta: torch.Tensor,
) -> None:
    """Scatter a fixed-weight direct/proxy vector back onto the live model.

    Direct block slices are handed to ``als_fitter.layout.write_block_vector``.
    Each proxy-term slice is reshaped to the provider's coefficient shape and
    copied in place into the matching row of ``proxy_coeffs``.

    Args:
        als_fitter: Fitter whose layout owns the direct blocks and providers.
        problem: Fixed-weight problem whose layout locates every slice.
        theta: Flat vector laid out like ``problem.layout``.
    """
    target_layout = problem.layout
    for direct_index in als_fitter._active_direct_blocks():
        block_values = theta[target_layout.theta_slice(direct_index)]
        als_fitter.layout.write_block_vector(direct_index, block_values)
    providers = als_fitter.layout.non_identity_providers()
    for group_index, group in enumerate(providers):
        destination = group.provider.proxy_coeffs
        for term_index in range(group.n_proxy_terms):
            key = _provider_proxy_key(group_index, term_index)
            # Match the parameter's dtype/device before the in-place copy.
            term_values = (
                theta[target_layout.theta_slice(key)]
                .reshape(group.coefficient_shape)
                .to(destination)
            )
            destination.data[term_index].copy_(term_values)
def _load_alchemical_checkpoint(model: UFPModel, checkpoint_path: Path | str) -> None:
    """Restore the weights of a fitted alchemical checkpoint into ``model``.

    The checkpoint is deserialized onto the CPU and must be a mapping that
    stores the weights under ``model_state_dict`` (checked first) or
    ``state_dict``.

    Raises:
        ValueError: If the payload is not a mapping or holds no state dict.
    """
    payload = _torch_load(checkpoint_path, map_location="cpu")
    if not isinstance(payload, Mapping):
        raise ValueError("alchemical checkpoint must contain a mapping")
    fallback = payload.get("state_dict")
    state_dict = payload.get("model_state_dict", fallback)
    if isinstance(state_dict, Mapping):
        model.load_state_dict(state_dict)
        return
    raise ValueError(
        "alchemical checkpoint must contain `model_state_dict` or `state_dict`"
    )
[docs]
def fit_alchemical_uncertainty_model(
    model: UFPModel,
    samples: Sequence[FitSample],
    *,
    mean_source: Literal["checkpoint", "als"] = "checkpoint",
    checkpoint_path: Path | str | None = None,
    fixed_weight_refit: bool = False,
    row_variances: object = None,
    observation_noise_variance: float = 1.0,
    aleatoric_noise_model: SplineAleatoricNoiseModel | None = None,
    aleatoric_features: object = None,
    aleatoric_noise_bundle: SplineAleatoricNoiseBundle | None = None,
    aleatoric_row_metadata: AleatoricRowMetadata | None = None,
    aleatoric_steps: int = 0,
    aleatoric_lr: float = 1.0e-2,
    batch_size: int = 32,
    progress: bool = False,
    cache_directory: Path | str | None = None,
    cache_mode: AssembledBatchCacheMode = "auto",
    covariance_path: Path | str | None = None,
    jitter: float = 0.0,
    als_fitter: AlchemicalALSFitter | None = None,
    **fitter_kwargs: Any,
) -> BayesianLinearPosterior:
    """Build a fixed-weight Bayesian posterior for an alchemical UFP model.

    The learned provider mixing weights are frozen so that direct coefficients
    and proxy coefficients form one linear problem. A dense posterior is then
    built from that problem's precision matrix.

    Args:
        model: Alchemical model to analyse. It is modified when a checkpoint
            is loaded, when ``mean_source == "als"`` runs a fit, or when
            ``fixed_weight_refit`` writes the refit mean back.
        samples: Labeled samples that define the design rows.
        mean_source: ``"checkpoint"`` builds the true problem from the model's
            current coefficients; ``"als"`` runs ``AlchemicalALSFitter.fit``
            first and reuses its problem.
        checkpoint_path: Optional checkpoint loaded into ``model`` first.
        fixed_weight_refit: When true, the mean is the solution of the
            fixed-weight system and is written back into the model. Otherwise
            the model's current direct/proxy coefficients are used.
        row_variances: Optional explicit row variances, combined with
            ``observation_noise_variance`` by ``_effective_row_variances``.
        observation_noise_variance: Default noise variance. It is also the
            fallback variance for ``aleatoric_noise_bundle``.
        aleatoric_noise_model: Optional spline noise head used to re-derive
            row variances at the initial mean. Requires ``aleatoric_features``.
        aleatoric_features: Features fed to ``aleatoric_noise_model``.
        aleatoric_noise_bundle: Optional noise bundle used to re-derive row
            variances at the initial mean.
        aleatoric_row_metadata: Row metadata for the bundle. It is built from
            ``samples`` when omitted.
        aleatoric_steps: Optimization steps for the noise head or bundle.
        aleatoric_lr: Learning rate for the noise head or bundle.
        batch_size: Assembly batch size.
        progress: Progress flag forwarded to problem assembly. It is only
            used when ``mean_source == "checkpoint"``.
        cache_directory: Optional assembled-batch cache directory.
        cache_mode: Assembled-batch cache mode.
        covariance_path: Optional destination handed to ``_store_covariance``.
        jitter: Diagonal jitter handed to ``_invert_precision``.
        als_fitter: Optional prebuilt fitter. When given, ``fitter_kwargs``
            are not used.
        **fitter_kwargs: Keyword arguments for a new ``AlchemicalALSFitter``.

    Returns:
        The fixed-weight posterior (mean, covariance, metadata, layout).

    Raises:
        ValueError: For an unknown ``mean_source``, for conflicting noise
            options, or when ``aleatoric_features`` is missing for
            ``aleatoric_noise_model``. These errors are raised before the
            model is touched.
    """
    # Validate every argument combination up front. Loading a checkpoint or
    # running an ALS fit mutates ``model`` and can be expensive, so a bad call
    # must fail before either happens.
    if mean_source not in {"checkpoint", "als"}:
        raise ValueError("`mean_source` must be 'checkpoint' or 'als'")
    if row_variances is not None and (
        aleatoric_noise_model is not None or aleatoric_noise_bundle is not None
    ):
        raise ValueError(
            "use only one of `row_variances`, `aleatoric_noise_model`, "
            "and `aleatoric_noise_bundle`"
        )
    if aleatoric_noise_model is not None and aleatoric_noise_bundle is not None:
        raise ValueError(
            "use only one of `aleatoric_noise_model` and `aleatoric_noise_bundle`"
        )
    if aleatoric_noise_model is not None and aleatoric_features is None:
        raise ValueError(
            "`aleatoric_features` is required with `aleatoric_noise_model`"
        )

    if checkpoint_path is not None:
        _load_alchemical_checkpoint(model, checkpoint_path)
    resolved_als = (
        AlchemicalALSFitter(model, **fitter_kwargs)
        if als_fitter is None
        else als_fitter
    )
    if mean_source == "als":
        result = resolved_als.fit(
            samples,
            batch_size=batch_size,
            cache_directory=cache_directory,
            cache_mode=cache_mode,
        )
        true_problem = result.problem
    else:
        true_problem = resolved_als.linear_fitter.build_problem(
            samples,
            batch_size=batch_size,
            progress=progress,
            cache_directory=cache_directory,
            cache_mode=cache_mode,
        )
    problem = _make_fixed_weight_alchemical_problem(resolved_als, true_problem)

    fits_aleatoric = (
        aleatoric_noise_model is not None or aleatoric_noise_bundle is not None
    )
    effective_variances = _effective_row_variances(
        row_variances,
        observation_noise_variance,
    )
    # Initial mean: either refit under the effective variances or read the
    # model's current direct/proxy coefficients.
    if fixed_weight_refit:
        precision, rhs = _accumulate_precision_and_rhs(
            problem,
            row_variances=effective_variances,
        )
        theta_mean = _solve_dense_system(precision, rhs)
    else:
        theta_mean = _current_fixed_weight_vector(resolved_als, problem)

    # Re-derive the row variances at the initial mean when a noise head or
    # bundle is supplied (the checks above guarantee at most one).
    if aleatoric_noise_model is not None:
        effective_variances = _fit_aleatoric_row_variances(
            problem,
            theta_mean,
            noise_model=aleatoric_noise_model,
            features=aleatoric_features,
            steps=aleatoric_steps,
            lr=aleatoric_lr,
        )
    elif aleatoric_noise_bundle is not None:
        resolved_row_metadata = (
            build_aleatoric_row_metadata(
                samples,
                fit_energy=resolved_als.linear_fitter.fit_energy,
                fit_forces=resolved_als.linear_fitter.fit_forces,
                fit_per_atom_energy=resolved_als.linear_fitter.fit_per_atom_energy,
                dtype=problem.dtype,
                device=problem.device,
                feature_spec=aleatoric_noise_bundle.feature_spec,
            )
            if aleatoric_row_metadata is None
            else aleatoric_row_metadata
        )
        effective_variances = _fit_aleatoric_noise_bundle_row_variances(
            problem,
            theta_mean,
            noise_bundle=aleatoric_noise_bundle,
            row_metadata=resolved_row_metadata,
            steps=aleatoric_steps,
            lr=aleatoric_lr,
            fallback_variance=float(observation_noise_variance),
        )

    # (Re)build the precision only when the row variances changed or when no
    # system has been accumulated yet; the previous version also accumulated
    # a throwaway system when it neither refit nor kept the variances.
    if fits_aleatoric or not fixed_weight_refit:
        precision, rhs = _accumulate_precision_and_rhs(
            problem,
            row_variances=effective_variances,
        )
        if fixed_weight_refit:
            theta_mean = _solve_dense_system(precision, rhs)
    if fixed_weight_refit:
        _write_fixed_weight_vector(resolved_als, problem, theta_mean)
    sigma = _invert_precision(precision, jitter=jitter)

    noise_model_metadata = None
    if aleatoric_noise_model is not None:
        flat_features = torch.as_tensor(aleatoric_features).reshape(-1)
        noise_model_metadata = {
            "class": type(aleatoric_noise_model).__name__,
            "trained_steps": int(aleatoric_steps),
            "feature_shape": list(flat_features.shape),
            "feature_hash": _hash_tensor_values(flat_features),
        }
    noise_bundle_metadata = None
    if aleatoric_noise_bundle is not None:
        noise_bundle_metadata = {
            "class": type(aleatoric_noise_bundle).__name__,
            "trained_steps": int(aleatoric_steps),
            "feature_kind": aleatoric_noise_bundle.feature_spec.kind,
            "heads": {
                ALEATORIC_ENERGY_KIND: (
                    aleatoric_noise_bundle.energy_per_atom is not None
                ),
                ALEATORIC_PER_ATOM_KIND: (
                    aleatoric_noise_bundle.per_atom_energy is not None
                ),
                ALEATORIC_FORCE_KIND: (
                    aleatoric_noise_bundle.force_component is not None
                ),
            },
        }
    metadata = _posterior_metadata(
        model=model,
        fitter=resolved_als.linear_fitter,
        problem=problem,
        kind="alchemical_fixed_weight",
        extra={
            "mean_source": mean_source,
            "fixed_weight_refit": bool(fixed_weight_refit),
            "observation_noise_variance": float(observation_noise_variance),
            "aleatoric_noise_model": noise_model_metadata,
            "aleatoric_noise_bundle": noise_bundle_metadata,
        },
    )
    return BayesianLinearPosterior(
        theta_mean=theta_mean.detach().clone(),
        Sigma_theta=_store_covariance(sigma, covariance_path),
        metadata=metadata,
        layout=problem.layout,
    )
def _zero_fit_sample(
    atoms: ase.Atoms,
    *,
    neighbor_list: NeighborListData | None,
    include_forces: bool,
) -> FitSample:
    """Wrap ``atoms`` in an all-zero labeled sample for row assembly only.

    The zero labels are placeholders. Zero per-atom energies are always
    attached. Zero forces of shape ``(n_atoms, 3)`` are attached only when
    ``include_forces`` is true; otherwise ``forces`` is ``None``.
    """
    n_atoms = len(atoms)
    zero_forces = None
    if include_forces:
        zero_forces = np.zeros((n_atoms, 3), dtype=float)
    return FitSample(
        atoms=atoms,
        neighbor_list=neighbor_list,
        forces=zero_forces,
        per_atom_energy=np.zeros(n_atoms, dtype=float),
    )
def _sparse_rows_from_batch(
    batch: BlockSolveBatch,
    layout: BlockProblemLayout,
    rows: torch.Tensor,
) -> tuple[SparseLinearRow, ...]:
    """Extract sparse global rows from one assembled block batch.

    Args:
        batch: Assembled batch. Each ``matrices`` entry is keyed by a layout
            key and contributes its values to that key's column slice.
        layout: Global parameter layout used to place each block's columns.
        rows: Row indices into the batch, flattened before use. Negative
            entries yield an empty sparse row.

    Returns:
        One sparse row of width ``layout.size`` per entry of ``rows``, in
        order. Empty rows are allocated on ``batch.target``'s device, the
        same device as the dense buffer non-empty rows are built from.
    """
    dtype = batch.target.dtype
    device = batch.target.device
    output: list[SparseLinearRow] = []
    for row in rows.reshape(-1).to(dtype=torch.int64):
        if int(row.item()) < 0:
            # Keep empty rows on the batch device as well, so callers that
            # combine rows never mix devices.
            output.append(
                SparseLinearRow(
                    indices=torch.empty(0, dtype=torch.int64, device=device),
                    values=torch.empty(0, dtype=dtype, device=device),
                    size=layout.size,
                )
            )
            continue
        dense = torch.zeros((layout.size,), dtype=dtype, device=device)
        one_row = row.reshape(1)
        for key, matrix in batch.matrices.items():
            theta_slice = layout.theta_slice(key)
            values = _block_matrix_values_on_rows(matrix, one_row).reshape(-1)
            dense[theta_slice] += values.to(dtype=dtype, device=device)
        output.append(_row_from_dense(dense, size=layout.size))
    return tuple(output)
def _prediction_batch_for_linear(
    fitter: LinearFitter,
    atoms: ase.Atoms,
    *,
    neighbor_list: NeighborListData | None,
    include_forces: bool,
) -> tuple[BlockSolveBatch, torch.Tensor, torch.Tensor | None]:
    """Assemble prediction rows for one structure with a linear fitter.

    ``atoms`` is wrapped in a zero-labeled sample and pushed through the same
    batch preparation and block assembly used for fitting. The energy target
    is disabled, per-atom energies are enabled, and forces are enabled only
    when ``include_forces`` is true.

    Returns:
        ``(batch, per_atom_rows, force_rows)``:

        * ``batch`` holds the assembled (storage-converted) block matrices.
        * ``per_atom_rows`` indexes the per-atom energy rows in ``batch``.
        * ``force_rows`` indexes the force rows, or is ``None`` when
          ``include_forces`` is false.
    """
    sample = _zero_fit_sample(
        atoms,
        neighbor_list=neighbor_list,
        include_forces=include_forces,
    )
    prepared_batches = prepare_batches(
        fitter.model,
        [sample],
        batch_size=1,
        fit_energy=False,
        fit_forces=include_forces,
        fit_per_atom_energy=True,
        dtype=fitter.dtype,
        device=fitter.device,
    )
    prepared = prepared_batches[0]
    assembled = fitter._assemble_true_blocks(prepared)
    target = assembled.target
    matrices = fitter._apply_matrix_storage(assembled).block_matrices
    problem_batch = BlockSolveBatch(target=target, matrices=matrices)
    targets = prepared.targets
    if not include_forces:
        return problem_batch, targets.per_atom_rows, None
    return problem_batch, targets.per_atom_rows, targets.force_rows
def _prediction_batch_for_fixed_alchemical(
    model: UFPModel,
    posterior_layout: BlockProblemLayout,
    atoms: ase.Atoms,
    *,
    neighbor_list: NeighborListData | None,
    include_forces: bool,
    dtype: torch.dtype,
    device: torch.device,
) -> tuple[BlockSolveBatch, torch.Tensor, torch.Tensor | None]:
    """Assemble prediction rows in a fixed-weight alchemical posterior layout.

    Rows are first assembled in the model's true block layout. The matrices
    are then regrouped as follows:

    * Active direct blocks keep their own key.
    * Each provider-backed block is scaled by every non-zero frozen mixing
      weight for its coefficient index and summed under the matching proxy
      key.
    * Keys absent from ``posterior_layout.slices`` are dropped.

    Returns:
        ``(batch, per_atom_rows, force_rows)`` in the posterior layout.
    """
    als_fitter = AlchemicalALSFitter(
        model,
        fit_energy=False,
        fit_forces=include_forces,
        fit_per_atom_energy=True,
        dtype=dtype,
        device=device,
    )
    true_batch, per_atom_rows, force_rows = _prediction_batch_for_linear(
        als_fitter.linear_fitter,
        atoms,
        neighbor_list=neighbor_list,
        include_forces=include_forces,
    )
    target = true_batch.target
    known_keys = posterior_layout.slices
    group_position = {
        id(group.provider): position
        for position, group in enumerate(als_fitter.layout.non_identity_providers())
    }
    direct_blocks = set(als_fitter._active_direct_blocks())
    mixed: dict[object, BlockMatrix] = {}

    def _add(key: object, contribution: BlockMatrix) -> None:
        """Sum ``contribution`` into ``mixed[key]``, densifying on collision."""
        if key not in mixed:
            mixed[key] = contribution
            return
        mixed[key] = _materialize_block_matrix(
            mixed[key]
        ) + _materialize_block_matrix(contribution)

    for block_index, matrix in true_batch.matrices.items():
        block = als_fitter.layout.block(int(block_index))
        provider = block.coefficient_provider
        if block_index in direct_blocks:
            if block_index in known_keys:
                mixed[block_index] = matrix
            continue
        if provider is None or provider.weights is None:
            continue
        group_index = group_position[id(provider)]
        assert block.coefficient_index is not None
        block_weights = provider.weights.detach().to(
            dtype=target.dtype,
            device=target.device,
        )[block.coefficient_index]
        active_terms = torch.nonzero(block_weights != 0, as_tuple=False).reshape(-1)
        for term in active_terms:
            key = _provider_proxy_key(group_index, int(term.item()))
            if key in known_keys:
                _add(key, _scale_block_matrix(matrix, block_weights[term]))
    return (
        BlockSolveBatch(target=target, matrices=mixed),
        per_atom_rows,
        force_rows,
    )
[docs]
def build_prediction_rows(
    model: UFPModel,
    atoms: ase.Atoms,
    *,
    posterior: BayesianLinearPosterior | None = None,
    fitter: LinearFitter | None = None,
    neighbor_list: NeighborListData | None = None,
    include_forces: bool = True,
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
) -> SparsePredictionRows:
    """Build sparse atomic, total-energy, and optional force prediction rows.

    Rows use ``posterior.layout`` when the posterior carries one. Posteriors
    whose metadata ``kind`` is ``"alchemical_fixed_weight"`` go through the
    fixed-weight alchemical assembly. Otherwise rows use the direct-block
    layout of the linear fitter.

    Args:
        model: Model whose design rows are assembled.
        atoms: Structure to predict.
        posterior: Optional posterior. It fixes the row layout and, unless
            overridden, the assembly dtype/device.
        fitter: Optional prebuilt linear fitter. It is ignored for
            fixed-weight alchemical posteriors.
        neighbor_list: Optional precomputed neighbor list for ``atoms``.
        include_forces: Whether to also build force component rows.
        dtype: Assembly dtype. Defaults to ``posterior.dtype`` when a
            posterior is given.
        device: Assembly device. Defaults to ``posterior.device`` when a
            posterior is given.

    Returns:
        Per-atom energy rows, their summed total-energy row, and (when forces
        are requested and assembled) one ``(x, y, z)`` triple of force rows
        per force row group.

    Raises:
        ValueError: If the assembled force rows do not come in triples.
    """
    # Any posterior supplies the default precision/placement, including one
    # without a stored layout, so every branch assembles consistently.
    if posterior is not None:
        dtype = posterior.dtype if dtype is None else dtype
        device = posterior.device if device is None else device
    posterior_layout = None if posterior is None else posterior.layout
    if (
        posterior is not None
        and posterior_layout is not None
        and posterior.metadata.get("kind") == "alchemical_fixed_weight"
    ):
        batch, per_atom_rows, force_rows = _prediction_batch_for_fixed_alchemical(
            model,
            posterior_layout,
            atoms,
            neighbor_list=neighbor_list,
            include_forces=include_forces,
            dtype=dtype,
            device=device,
        )
        layout = posterior_layout
    else:
        resolved_fitter = (
            LinearFitter(
                model,
                fit_energy=False,
                fit_forces=include_forces,
                fit_per_atom_energy=True,
                dtype=dtype,
                device=device,
            )
            if fitter is None
            else fitter
        )
        batch, per_atom_rows, force_rows = _prediction_batch_for_linear(
            resolved_fitter,
            atoms,
            neighbor_list=neighbor_list,
            include_forces=include_forces,
        )
        layout = (
            posterior_layout
            if posterior_layout is not None
            else BlockProblemLayout.from_blocks(resolved_fitter._direct_blocks())
        )
    atomic_rows = _sparse_rows_from_batch(batch, layout, per_atom_rows)
    total_row = combine_total_energy_rows([(row, 1.0) for row in atomic_rows])
    force_row_tuple = None
    if include_forces and force_rows is not None:
        flattened = _sparse_rows_from_batch(batch, layout, force_rows.reshape(-1))
        # Group by the force rows themselves, not by the number of energy
        # rows, so a mismatch cannot silently drop or misalign components.
        if len(flattened) % 3:
            raise ValueError("force rows must come in (x, y, z) triples")
        force_row_tuple = tuple(
            (flattened[start], flattened[start + 1], flattened[start + 2])
            for start in range(0, len(flattened), 3)
        )
    return SparsePredictionRows(
        atomic_energy_rows=atomic_rows,
        total_energy_row=total_row,
        force_rows=force_row_tuple,
    )
[docs]
def variance_for_energy_row(
    row: SparseLinearRow | torch.Tensor,
    posterior: BayesianLinearPosterior,
) -> torch.Tensor:
    """Return ``row @ Sigma_theta @ row.T`` for one sparse or dense row.

    The covariance is materialized in the dtype and on the device of
    ``posterior.theta_mean`` before the quadratic form is evaluated.
    """
    mean = posterior.theta_mean
    covariance = posterior.covariance_tensor(dtype=mean.dtype, device=mean.device)
    return _variance_for_row_with_sigma(row, posterior, covariance)
def _variance_for_row_with_sigma(
    row: SparseLinearRow | torch.Tensor,
    posterior: BayesianLinearPosterior,
    sigma: torch.Tensor,
) -> torch.Tensor:
    """Evaluate ``row @ sigma @ row`` against an already materialized covariance.

    Sparse rows only read the sub-block of ``sigma`` on their own indices, and
    a sparse row with no entries yields an exact zero. Any other row is
    converted to a flat dense tensor.

    Raises:
        ValueError: If the row width differs from ``posterior.n_parameters``.
    """
    placement = {"dtype": sigma.dtype, "device": sigma.device}
    if not isinstance(row, SparseLinearRow):
        dense = torch.as_tensor(row, **placement).reshape(-1)
        if int(dense.numel()) != posterior.n_parameters:
            raise ValueError("row size does not match posterior parameter count")
        return dense @ (sigma @ dense)
    if row.size != posterior.n_parameters:
        raise ValueError("row size does not match posterior parameter count")
    if row.indices.numel() == 0:
        return torch.zeros((), **placement)
    support = row.indices.to(device=sigma.device)
    coefficients = row.values.to(**placement)
    block = sigma.index_select(0, support).index_select(1, support)
    return coefficients @ (block @ coefficients)
[docs]
def _sparse_row_chunk_variances(
    chunk: Sequence[SparseLinearRow],
    sigma: torch.Tensor,
    n_parameters: int,
) -> torch.Tensor:
    """Return ``diag(A @ sigma @ A.T)`` for one chunk of sparse rows.

    The rows are compacted onto the sorted union of their indices, so only
    that sub-block of ``sigma`` is read. Repeated indices inside one row are
    summed, matching the single-row form ``values @ sigma[idx][:, idx] @ values``.
    """
    if any(row.size != n_parameters for row in chunk):
        raise ValueError("row size does not match posterior parameter count")
    nonempty_indices = [
        row.indices.to(device=sigma.device) for row in chunk if row.indices.numel()
    ]
    if not nonempty_indices:
        return torch.zeros((len(chunk),), dtype=sigma.dtype, device=sigma.device)
    unique_indices, inverse = torch.unique(
        torch.cat(nonempty_indices),
        sorted=True,
        return_inverse=True,
    )
    compact_rows = torch.zeros(
        (len(chunk), int(unique_indices.numel())),
        dtype=sigma.dtype,
        device=sigma.device,
    )
    cursor = 0
    for row_index, row in enumerate(chunk):
        width = int(row.indices.numel())
        if width == 0:
            continue
        # ``index_add_`` accumulates repeated columns; plain indexed
        # assignment would keep only one value per repeated index.
        compact_rows[row_index].index_add_(
            0,
            inverse[cursor : cursor + width],
            row.values.reshape(-1).to(dtype=sigma.dtype, device=sigma.device),
        )
        cursor += width
    sigma_chunk = sigma.index_select(0, unique_indices).index_select(
        1,
        unique_indices,
    )
    return torch.sum((compact_rows @ sigma_chunk) * compact_rows, dim=1)


def _dense_row_chunk_variances(
    chunk: Sequence[SparseLinearRow | torch.Tensor],
    sigma: torch.Tensor,
    n_parameters: int,
) -> torch.Tensor:
    """Return ``diag(A @ sigma @ A.T)`` after densifying every row of a chunk."""
    dense_rows: list[torch.Tensor] = []
    for row in chunk:
        if isinstance(row, SparseLinearRow):
            dense = row.to_dense(dtype=sigma.dtype, device=sigma.device)
        else:
            dense = torch.as_tensor(row, dtype=sigma.dtype, device=sigma.device)
        dense = dense.reshape(-1)
        # Check each row before stacking so a bad width raises the documented
        # ValueError instead of a ``torch.stack`` shape error.
        if int(dense.numel()) != n_parameters:
            raise ValueError("row size does not match posterior parameter count")
        dense_rows.append(dense)
    stacked = torch.stack(dense_rows, dim=0)
    return torch.sum((stacked @ sigma) * stacked, dim=1)


def variance_for_sparse_rows(
    rows: Sequence[SparseLinearRow | torch.Tensor],
    posterior: BayesianLinearPosterior,
    *,
    chunk_size: int = 256,
) -> torch.Tensor:
    """Return ``diag(A @ Sigma_theta @ A.T)`` for sparse/dense prediction rows.

    Rows are processed in chunks of ``chunk_size``. A chunk made only of
    sparse rows is evaluated on the compact union of its indices; any other
    chunk is densified.

    Args:
        rows: Sparse rows and/or dense row tensors of width
            ``posterior.n_parameters``.
        posterior: Posterior providing the covariance, dtype, and device.
        chunk_size: Maximum number of rows evaluated together. Must be
            positive.

    Returns:
        A 1-D tensor with one variance per row, in the dtype and on the
        device of ``posterior.theta_mean``.

    Raises:
        ValueError: If ``chunk_size`` is not positive or a row width does not
            match the posterior.
    """
    step = int(chunk_size)
    if step <= 0:
        raise ValueError("`chunk_size` must be positive")
    mean = posterior.theta_mean
    if not rows:
        return torch.empty(0, dtype=mean.dtype, device=mean.device)
    sigma = posterior.covariance_tensor(dtype=mean.dtype, device=mean.device)
    n_parameters = posterior.n_parameters
    outputs: list[torch.Tensor] = []
    for start in range(0, len(rows), step):
        chunk = rows[start : start + step]
        sparse_chunk = [row for row in chunk if isinstance(row, SparseLinearRow)]
        if len(sparse_chunk) == len(chunk):
            outputs.append(
                _sparse_row_chunk_variances(sparse_chunk, sigma, n_parameters)
            )
        else:
            outputs.append(_dense_row_chunk_variances(chunk, sigma, n_parameters))
    return torch.cat(outputs, dim=0)
[docs]
@dataclass(frozen=True)
class UFPUncertaintyOutput:
    """UFP predictions with epistemic, aleatoric, and total variances.

    Every field except ``means`` is optional and defaults to ``None``. As
    filled in by ``predict_with_uncertainty``, each ``*_total_variance`` is
    the sum of the matching epistemic and aleatoric variances.
    """
    # Mean predictions returned by the model's ``compute`` call.
    means: UFPOutput
    # Epistemic variances from the parameter posterior covariance: total
    # energy, per-atom energies, and force components.
    energy_epistemic_variance: torch.Tensor | None = None
    per_atom_energy_epistemic_variance: torch.Tensor | None = None
    force_epistemic_variance: torch.Tensor | None = None
    # Aleatoric variances for the same three quantities.
    energy_aleatoric_variance: torch.Tensor | None = None
    per_atom_energy_aleatoric_variance: torch.Tensor | None = None
    force_aleatoric_variance: torch.Tensor | None = None
    # Epistemic plus aleatoric variances.
    energy_total_variance: torch.Tensor | None = None
    per_atom_energy_total_variance: torch.Tensor | None = None
    force_total_variance: torch.Tensor | None = None
    # Sparse prediction rows; ``predict_with_uncertainty`` only sets this
    # when called with ``return_rows=True``.
    rows: SparsePredictionRows | None = None
def _zero_like_optional(value: object) -> torch.Tensor | None:
"""Return a zero tensor matching ``value`` when present."""
if value is None:
return None
tensor = torch.as_tensor(value)
return torch.zeros_like(tensor)
[docs]
def _resolve_prediction_aleatoric_variances(
    atoms: ase.Atoms,
    posterior: BayesianLinearPosterior,
    *,
    aleatoric_variance: float | torch.Tensor | None,
    aleatoric_noise_bundle: SplineAleatoricNoiseBundle | None,
    include_forces: bool,
    energy_epistemic: torch.Tensor,
    per_atom_epistemic: torch.Tensor,
    force_epistemic: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Return ``(energy, per_atom, force)`` aleatoric variances for one structure.

    Shapes follow the matching epistemic tensors, and the force entry is
    ``None`` exactly when ``force_epistemic`` is ``None``. Callers ensure at
    most one of ``aleatoric_variance`` / ``aleatoric_noise_bundle`` is set.
    """
    if aleatoric_noise_bundle is not None:
        bundle_energy, bundle_per_atom, bundle_force = (
            aleatoric_noise_bundle.predict_for_atoms(
                atoms,
                include_forces=include_forces,
                dtype=posterior.dtype,
                device=posterior.device,
            )
        )
        energy = (
            torch.zeros_like(energy_epistemic)
            if bundle_energy is None
            else bundle_energy.reshape_as(energy_epistemic)
        )
        per_atom = (
            torch.zeros_like(per_atom_epistemic)
            if bundle_per_atom is None
            else bundle_per_atom.reshape_as(per_atom_epistemic)
        )
        if force_epistemic is None:
            force = None
        elif bundle_force is None:
            force = torch.zeros_like(force_epistemic)
        else:
            force = bundle_force.reshape_as(force_epistemic)
        return energy, per_atom, force
    if aleatoric_variance is None:
        force = None if force_epistemic is None else torch.zeros_like(force_epistemic)
        return (
            torch.zeros_like(energy_epistemic),
            torch.zeros_like(per_atom_epistemic),
            force,
        )
    # ``torch.as_tensor`` may hand back the caller's own tensor, and
    # ``expand_as`` / ``reshape_as`` return views of it. Clone so the output
    # owns its storage and supports in-place updates.
    variance = torch.as_tensor(
        aleatoric_variance,
        dtype=posterior.dtype,
        device=posterior.device,
    )
    if variance.numel() == 1:
        # One scalar variance applied to every atom and force component.
        scalar = variance.reshape(())
        per_atom = scalar.expand_as(per_atom_epistemic).clone()
        force = (
            None
            if force_epistemic is None
            else scalar.expand_as(force_epistemic).clone()
        )
    else:
        # One variance per atom; no force noise is implied.
        per_atom = variance.reshape_as(per_atom_epistemic).clone()
        force = None if force_epistemic is None else torch.zeros_like(force_epistemic)
    # The total-energy aleatoric variance is the sum of the per-atom ones.
    energy = torch.sum(per_atom).reshape(1)
    return energy, per_atom, force


def predict_with_uncertainty(
    model: UFPModel,
    atoms: ase.Atoms,
    posterior: BayesianLinearPosterior,
    *,
    rows: SparsePredictionRows | None = None,
    fitter: LinearFitter | None = None,
    neighbor_list: NeighborListData | None = None,
    include_forces: bool = True,
    aleatoric_variance: float | torch.Tensor | None = None,
    aleatoric_noise_bundle: SplineAleatoricNoiseBundle | None = None,
    variance_chunk_size: int = 256,
    return_rows: bool = False,
) -> UFPUncertaintyOutput:
    """Predict means and diagonal epistemic/total variances for one structure.

    Epistemic variances are quadratic forms of the sparse prediction rows with
    the posterior covariance. Aleatoric variances come from at most one
    source:

    * ``aleatoric_variance``, either a scalar applied to every atom and force
      component, or one value per atom with zero force noise;
    * ``aleatoric_noise_bundle``.

    With neither, aleatoric variances are zero.

    Args:
        model: Model used for the mean prediction and, if needed, row assembly.
        atoms: Structure to predict.
        posterior: Fitted posterior providing dtype, device, and covariance.
        rows: Precomputed prediction rows. They are built with
            ``build_prediction_rows`` when omitted.
        fitter: Optional linear fitter forwarded to row assembly.
        neighbor_list: Optional precomputed neighbor list.
        include_forces: Whether to predict forces and their variances.
        aleatoric_variance: Optional fixed aleatoric variance (see above).
        aleatoric_noise_bundle: Optional fitted aleatoric noise bundle.
        variance_chunk_size: Row chunk size for batched variance evaluation.
        return_rows: Whether to attach the prediction rows to the output.

    Returns:
        Means plus epistemic, aleatoric, and total (their sum) variances.
        Aleatoric tensors derived from ``aleatoric_variance`` never alias the
        caller's tensor.

    Raises:
        ValueError: If both ``aleatoric_variance`` and
            ``aleatoric_noise_bundle`` are given.
    """
    if aleatoric_variance is not None and aleatoric_noise_bundle is not None:
        raise ValueError(
            "use only one of `aleatoric_variance` and `aleatoric_noise_bundle`"
        )
    output = model.compute(
        atoms,
        neighbor_list=neighbor_list,
        dtype=posterior.dtype,
        device=posterior.device,
        derive_forces=include_forces,
    )
    prediction_rows = rows
    if prediction_rows is None:
        prediction_rows = build_prediction_rows(
            model,
            atoms,
            posterior=posterior,
            fitter=fitter,
            neighbor_list=neighbor_list,
            include_forces=include_forces,
        )
    # Materialize the covariance once and reuse it for every row kind.
    sigma = posterior.covariance_tensor(
        dtype=posterior.theta_mean.dtype,
        device=posterior.theta_mean.device,
    )
    energy_epistemic = _variance_for_row_with_sigma(
        prediction_rows.total_energy_row,
        posterior,
        sigma,
    ).reshape(1)
    per_atom_epistemic = variance_for_sparse_rows(
        prediction_rows.atomic_energy_rows,
        posterior,
        chunk_size=variance_chunk_size,
    )
    force_epistemic = None
    if include_forces and prediction_rows.force_rows is not None:
        force_epistemic = variance_for_sparse_rows(
            prediction_rows.force_component_rows,
            posterior,
            chunk_size=variance_chunk_size,
        ).reshape(len(prediction_rows.force_rows), 3)
    energy_aleatoric, per_atom_aleatoric, force_aleatoric = (
        _resolve_prediction_aleatoric_variances(
            atoms,
            posterior,
            aleatoric_variance=aleatoric_variance,
            aleatoric_noise_bundle=aleatoric_noise_bundle,
            include_forces=include_forces,
            energy_epistemic=energy_epistemic,
            per_atom_epistemic=per_atom_epistemic,
            force_epistemic=force_epistemic,
        )
    )
    return UFPUncertaintyOutput(
        means=output,
        energy_epistemic_variance=energy_epistemic,
        per_atom_energy_epistemic_variance=per_atom_epistemic,
        force_epistemic_variance=force_epistemic,
        energy_aleatoric_variance=energy_aleatoric,
        per_atom_energy_aleatoric_variance=per_atom_aleatoric,
        force_aleatoric_variance=force_aleatoric,
        energy_total_variance=energy_epistemic + energy_aleatoric,
        per_atom_energy_total_variance=per_atom_epistemic + per_atom_aleatoric,
        force_total_variance=None
        if force_epistemic is None or force_aleatoric is None
        else force_epistemic + force_aleatoric,
        rows=prediction_rows if return_rows else None,
    )
[docs]
class SplineAleatoricNoiseModel(torch.nn.Module):
    """Positive 1D spline variance head using ``softplus(raw) + floor``.

    ``n_coefficients`` learnable raw values define a uniform 1D spline over a
    scalar feature. :meth:`raw` evaluates that spline after clamping features
    into the full-support interval. :meth:`forward` maps it to a strictly
    positive variance with ``softplus`` plus ``variance_floor``.
    """

    def __init__(
        self,
        *,
        n_coefficients: int,
        lower_full_support: float,
        upper_full_support: float,
        spline: str = "cubic",
        variance_floor: float = 1.0e-12,
        initial_raw: float | torch.Tensor = -6.0,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Initialize the spline variance head.

        Args:
            n_coefficients: Number of spline coefficients.
            lower_full_support: Lower end of the spline's full support.
            upper_full_support: Upper end of the spline's full support.
            spline: Spline family name forwarded to the spline helpers.
            variance_floor: Non-negative constant added after ``softplus``.
            initial_raw: Scalar broadcast to every coefficient, or a tensor of
                shape ``(n_coefficients,)``. It is copied, so training never
                mutates a tensor passed in here.
            dtype: Parameter dtype. When ``None``, the dtype inferred by
                ``torch.as_tensor`` is kept, except that integer/bool values
                are promoted to ``torch.get_default_dtype()``.

        Raises:
            ValueError: If ``variance_floor`` is negative or ``initial_raw``
                has the wrong shape.
        """
        super().__init__()
        n_coeffs = int(n_coefficients)
        floor = float(variance_floor)
        if floor < 0.0:
            raise ValueError("`variance_floor` must be non-negative")
        first_knot, knot_spacing = uniform_support_parameters(
            coeff_size=n_coeffs,
            lower_full_support=float(lower_full_support),
            upper_full_support=float(upper_full_support),
            spline=spline,
        )
        raw = torch.as_tensor(initial_raw, dtype=dtype).detach()
        if dtype is None and not (raw.is_floating_point() or raw.is_complex()):
            # Integer tensors cannot back a trainable (grad-requiring) parameter.
            raw = raw.to(torch.get_default_dtype())
        if raw.ndim == 0:
            raw = raw.expand(n_coeffs)
        if tuple(raw.shape) != (n_coeffs,):
            raise ValueError(
                f"`initial_raw` must be scalar or have shape ({n_coeffs},)"
            )
        # Clone so the parameter owns its storage: optimizer steps update it
        # in place and must not write through to the caller's tensor.
        self.raw_values = torch.nn.Parameter(raw.clone())
        self.n_coefficients = n_coeffs
        self.lower_full_support = float(lower_full_support)
        self.upper_full_support = float(upper_full_support)
        self.first_knot = float(first_knot)
        self.knot_spacing = float(knot_spacing)
        self.spline = str(spline)
        self.variance_floor = floor

    def _clamp_inputs(self, x: torch.Tensor) -> torch.Tensor:
        """Clamp features into ``[lower_full_support, upper_full_support)``.

        The upper bound is reduced by ``eps * max(|upper|, 1)`` so clamped
        values stay strictly below ``upper_full_support`` (presumably to keep
        the spline stencil inside the coefficient range).
        """
        upper = torch.as_tensor(
            self.upper_full_support,
            dtype=x.dtype,
            device=x.device,
        )
        lower = torch.as_tensor(
            self.lower_full_support,
            dtype=x.dtype,
            device=x.device,
        )
        eps = torch.finfo(x.dtype).eps * torch.clamp(torch.abs(upper), min=1.0)
        return torch.clamp(
            x,
            min=lower,
            max=upper - eps,
        )

    def raw(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the unconstrained spline head at scalar features ``x``.

        ``x`` may have any shape and the result has the same shape. Features
        are cast to the parameter dtype, clamped, and moved to the parameter
        device before the stencil is evaluated.
        """
        features = self._clamp_inputs(torch.as_tensor(x, dtype=self.raw_values.dtype))
        original_shape = features.shape
        flat = features.reshape(-1).to(device=self.raw_values.device)
        stencil = uniform_stencil_1d(
            flat,
            coeff_size=self.n_coefficients,
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        raw = torch.sum(
            self.raw_values.index_select(0, stencil.indices.reshape(-1)).reshape(
                stencil.indices.shape
            )
            * stencil.values,
            dim=1,
        )
        return raw.reshape(original_shape)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return positive variances ``softplus(raw(x)) + variance_floor``."""
        return torch.nn.functional.softplus(self.raw(x)) + self.variance_floor

    def to_payload(self) -> dict[str, object]:
        """Return a torch-serializable payload (schema, config, CPU state dict)."""
        return {
            "schema": {
                "name": ALEATORIC_SPLINE_SCHEMA_NAME,
                "version": ALEATORIC_SPLINE_SCHEMA_VERSION,
            },
            "config": {
                "n_coefficients": int(self.n_coefficients),
                "lower_full_support": float(self.lower_full_support),
                "upper_full_support": float(self.upper_full_support),
                "spline": str(self.spline),
                "variance_floor": float(self.variance_floor),
            },
            "state_dict": {
                key: value.detach().cpu() for key, value in self.state_dict().items()
            },
        }

    @classmethod
    def from_payload(
        cls,
        payload: object,
        *,
        dtype: torch.dtype | None = None,
    ) -> "SplineAleatoricNoiseModel":
        """Load a fitted noise head from :meth:`to_payload` output.

        When ``dtype`` is ``None``, the dtype of the stored floating-point
        ``raw_values`` is kept. The values are no longer silently cast to the
        default dtype, which could lose precision.

        Raises:
            ValueError: If the payload, its schema, config, or state dict is
                missing or unsupported.
        """
        if not isinstance(payload, Mapping):
            raise ValueError("aleatoric noise model payload must be a mapping")
        schema = payload.get("schema")
        if not isinstance(schema, Mapping):
            raise ValueError("aleatoric noise model payload is missing schema")
        if schema.get("name") != ALEATORIC_SPLINE_SCHEMA_NAME:
            raise ValueError("unsupported aleatoric noise model schema")
        if schema.get("version") != ALEATORIC_SPLINE_SCHEMA_VERSION:
            raise ValueError("unsupported aleatoric noise model schema version")
        config = payload.get("config")
        if not isinstance(config, Mapping):
            raise ValueError("aleatoric noise model payload is missing config")
        state_dict = payload.get("state_dict")
        if not isinstance(state_dict, Mapping):
            raise ValueError("aleatoric noise model payload is missing state_dict")
        resolved_dtype = dtype
        stored_raw = state_dict.get("raw_values")
        if (
            resolved_dtype is None
            and isinstance(stored_raw, torch.Tensor)
            and stored_raw.is_floating_point()
        ):
            resolved_dtype = stored_raw.dtype
        model = cls(
            n_coefficients=int(config["n_coefficients"]),
            lower_full_support=float(config["lower_full_support"]),
            upper_full_support=float(config["upper_full_support"]),
            spline=str(config.get("spline", "cubic")),
            variance_floor=float(config.get("variance_floor", 1.0e-12)),
            initial_raw=0.0,
            dtype=resolved_dtype,
        )
        model.load_state_dict(
            {
                str(key): torch.as_tensor(value, dtype=resolved_dtype)
                for key, value in state_dict.items()
            }
        )
        return model

    def gaussian_nll(
        self,
        residuals: torch.Tensor,
        x: torch.Tensor,
        *,
        reduction: Literal["mean", "sum", "none"] = "mean",
        include_constant: bool = False,
    ) -> torch.Tensor:
        """Return the Gaussian negative log likelihood of ``residuals``.

        Per element, the NLL is ``0.5 * (log v + r**2 / v)`` with
        ``v = forward(x)`` reshaped like ``residuals``. When
        ``include_constant`` is true, ``0.5 * log(2 * pi)`` is added.

        Raises:
            ValueError: If ``reduction`` is not ``'mean'``, ``'sum'``, or
                ``'none'``. The check runs before any computation.
        """
        if reduction not in ("mean", "sum", "none"):
            raise ValueError("`reduction` must be 'mean', 'sum', or 'none'")
        residuals = torch.as_tensor(
            residuals,
            dtype=self.raw_values.dtype,
            device=self.raw_values.device,
        )
        variances = self.forward(x).reshape_as(residuals)
        nll = 0.5 * (torch.log(variances) + residuals * residuals / variances)
        if include_constant:
            nll = nll + 0.5 * math.log(2.0 * math.pi)
        if reduction == "sum":
            return torch.sum(nll)
        if reduction == "mean":
            return torch.mean(nll)
        return nll

    def fit_residuals(
        self,
        residuals: torch.Tensor,
        x: torch.Tensor,
        *,
        steps: int = 500,
        lr: float = 1.0e-2,
    ) -> tuple[float, ...]:
        """Optimize this variance head against residuals with Adam.

        ``residuals`` and ``x`` are converted once and detached. Gradients
        therefore reach only this head's parameters, and inputs that carry an
        autograd graph no longer fail on the second ``backward``.

        Returns:
            The mean NLL evaluated at the start of each step, before that
            step's update.
        """
        targets = torch.as_tensor(
            residuals,
            dtype=self.raw_values.dtype,
            device=self.raw_values.device,
        ).detach()
        features = torch.as_tensor(x, dtype=self.raw_values.dtype).detach()
        optimizer = torch.optim.Adam(self.parameters(), lr=float(lr))
        history: list[float] = []
        for _ in range(int(steps)):
            optimizer.zero_grad()
            loss = self.gaussian_nll(targets, features)
            loss.backward()
            optimizer.step()
            history.append(float(loss.detach().item()))
        return tuple(history)
# Public API, kept in sorted order (classes first, then functions).
__all__ = [
    "AleatoricFeatureSpec",
    "AleatoricRowMetadata",
    "BayesianLinearPosterior",
    "SparseLinearRow",
    "SparsePredictionRows",
    "SplineAleatoricNoiseBundle",
    "SplineAleatoricNoiseModel",
    "UFPUncertaintyOutput",
    "UncertaintyPredictionBundle",
    "build_aleatoric_row_metadata",
    "build_prediction_rows",
    "combine_total_energy_rows",
    "describe_uncertainty_prediction_bundle",
    "fit_alchemical_uncertainty_model",
    "fit_linear_uncertainty_model",
    "load_uncertainty_prediction_bundle",
    "predict_with_uncertainty",
    "save_energy_variance_scale_to_bundle",
    "save_uncertainty_prediction_bundle",
    "validate_uncertainty_prediction_bundle",
    "variance_for_energy_row",
    "variance_for_sparse_rows",
]