"""
Fit-sample normalization and batch preparation for least-squares workflows.
Use this module to convert labeled structures into weighted target rows and the
shared ``UFPInput`` batches consumed by the assembly layer.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Sequence
import ase
import numpy as np
import torch
from ufp.core.input import UFPInput
from ufp.core.potential import UFPotential
from ufp.neighbors._data import NeighborListData, concatenate_neighbor_lists
[docs]
@dataclass
class FitSample:
    """One labeled structure for least-squares fitting over fixed geometry.

    Every label is optional; a sample contributes only the target types it
    actually carries. Scalar fields are coerced to ``float`` and array-like
    labels are shape-checked against ``len(atoms)`` in ``__post_init__``.
    """

    # Structure the labels belong to; its length fixes the expected label shapes.
    atoms: ase.Atoms
    # Optional precomputed neighbor list; ``None`` when none is supplied.
    neighbor_list: NeighborListData | None = None
    # Scalar energy label, coerced to ``float`` when given.
    energy: float | None = None
    # Force labels; must have shape ``(n_atoms, 3)`` when given.
    forces: np.ndarray | torch.Tensor | None = None
    # Per-atom energy labels; must hold exactly ``n_atoms`` values when given.
    per_atom_energy: np.ndarray | torch.Tensor | None = None
    # Non-negative least-squares weights, one per target type.
    energy_weight: float = 1.0
    force_weight: float = 1.0
    per_atom_weight: float = 1.0

    def __post_init__(self) -> None:
        """Coerce scalar fields to ``float`` and validate label shapes.

        Raises:
            TypeError: If ``atoms`` is not an ``ase.Atoms`` instance.
            ValueError: If any weight is negative, if ``forces`` does not have
                shape ``(n_atoms, 3)``, or if ``per_atom_energy`` does not
                hold exactly ``n_atoms`` values.
        """
        if not isinstance(self.atoms, ase.Atoms):
            raise TypeError("`atoms` must be an ase.Atoms instance")
        if self.energy is not None:
            self.energy = float(self.energy)

        # Coerce every weight first so the sign checks compare plain floats.
        self.energy_weight = float(self.energy_weight)
        self.force_weight = float(self.force_weight)
        self.per_atom_weight = float(self.per_atom_weight)
        if self.energy_weight < 0.0:
            raise ValueError("`energy_weight` must be non-negative")
        if self.force_weight < 0.0:
            raise ValueError("`force_weight` must be non-negative")
        if self.per_atom_weight < 0.0:
            raise ValueError("`per_atom_weight` must be non-negative")

        n_atoms = len(self.atoms)

        if self.forces is not None:
            # Read the shape straight off tensors: ``np.asarray`` raises for
            # tensors that require grad or live on a non-CPU device, which
            # would reject otherwise valid labels.
            if isinstance(self.forces, torch.Tensor):
                forces_shape = tuple(self.forces.shape)
            else:
                forces_shape = np.asarray(self.forces).shape
            if forces_shape != (n_atoms, 3):
                raise ValueError(
                    f"`forces` must have shape ({n_atoms}, 3), got {forces_shape}"
                )

        if self.per_atom_energy is not None:
            # Only the element count matters here: any layout that flattens
            # to ``n_atoms`` values is accepted.
            if isinstance(self.per_atom_energy, torch.Tensor):
                n_values = self.per_atom_energy.numel()
            else:
                n_values = np.asarray(self.per_atom_energy).size
            if n_values != n_atoms:
                raise ValueError(
                    "`per_atom_energy` must have shape "
                    f"({n_atoms},), got ({n_values},)"
                )

    def has_targets(
        self,
        *,
        fit_energy: bool,
        fit_forces: bool,
        fit_per_atom_energy: bool,
    ) -> bool:
        """Report whether this sample contributes to the enabled target types.

        Args:
            fit_energy: Whether energy targets are enabled.
            fit_forces: Whether force targets are enabled.
            fit_per_atom_energy: Whether per-atom energy targets are enabled.

        Returns:
            ``True`` when at least one enabled target type has a label on
            this sample, ``False`` otherwise.
        """
        return bool(
            (fit_energy and self.energy is not None)
            or (fit_forces and self.forces is not None)
            or (fit_per_atom_energy and self.per_atom_energy is not None)
        )
[docs]
@dataclass
class BatchTargets:
    """Flattened, weighted target rows plus row lookups for one fit batch."""

    # Target value of every active row, concatenated along dimension 0.
    values: torch.Tensor
    # Square root of the least-squares weight attached to each row.
    sqrt_weights: torch.Tensor
    # Per-row scale factor (``1 / n_atoms`` for energy rows, ``1`` otherwise).
    row_scales: torch.Tensor
    # Row index of each system's energy target, ``-1`` where absent.
    energy_rows: torch.Tensor
    # Row index of each atom's three force components, ``-1`` where absent.
    force_rows: torch.Tensor
    # Row index of each atom's per-atom energy target, ``-1`` where absent.
    per_atom_rows: torch.Tensor

    @property
    def n_rows(self) -> int:
        """Number of active target rows held by this batch."""
        leading_dim = self.values.size(0)
        return int(leading_dim)

    @property
    def weighted_values(self) -> torch.Tensor:
        """Target values multiplied element-wise by their sqrt-weights."""
        root_weights = self.sqrt_weights
        raw_values = self.values
        return root_weights * raw_values
[docs]
@dataclass
class PreparedBatch:
    """Pairs one batch of input geometry with its weighted fit targets."""

    # The fit samples that make up this batch, in batch order.
    samples: tuple[FitSample, ...]
    # Model input built for the structures of ``samples``.
    inputs: UFPInput
    # Flattened target rows and row lookups matching ``inputs``.
    targets: BatchTargets
def _chunked(
items: Sequence[FitSample],
batch_size: int,
) -> Iterable[tuple[FitSample, ...]]:
"""Yield consecutive sample chunks of the requested batch size."""
if batch_size <= 0:
raise ValueError("`batch_size` must be positive")
for start in range(0, len(items), batch_size):
yield tuple(items[start : start + batch_size])
def _prepare_targets(
    samples: Sequence[FitSample],
    inputs: UFPInput,
    *,
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
) -> BatchTargets:
    """Flatten the active labels of ``samples`` into weighted target rows.

    Rows are emitted sample by sample; within one sample the order is energy
    (one row), forces (``3 * n_atoms`` rows), then per-atom energies
    (``n_atoms`` rows). Alongside the row tensors, three lookup tensors record
    which row each system or atom maps to, with ``-1`` marking "no row".

    Args:
        samples: Samples of the batch, in the same order as the systems
            packed into ``inputs``.
        inputs: Batched model input; supplies ``device``, ``dtype``,
            ``n_systems`` and ``n_atoms``.
        fit_energy: Emit energy rows for samples that carry an energy.
        fit_forces: Emit force rows for samples that carry forces.
        fit_per_atom_energy: Emit per-atom rows for samples that carry them.

    Returns:
        The concatenated target values, sqrt-weights, row scales and the
        three row-lookup tensors.

    Raises:
        ValueError: If no sample produced any row.
    """
    device = inputs.device
    dtype = inputs.dtype

    def _unassigned(*shape: int) -> torch.Tensor:
        # Lookup tensor pre-filled with -1, meaning "no target row here".
        return torch.full(shape, fill_value=-1, dtype=torch.int64, device=device)

    def _constant(length: int, value: float) -> torch.Tensor:
        # Length-``length`` tensor holding one repeated floating-point value.
        return torch.full((length,), fill_value=value, dtype=dtype, device=device)

    def _row_ids(first: int, count: int) -> torch.Tensor:
        # Consecutive row indices ``first .. first + count - 1``.
        return torch.arange(first, first + count, dtype=torch.int64, device=device)

    energy_rows = _unassigned(inputs.n_systems)
    force_rows = _unassigned(inputs.n_atoms, 3)
    per_atom_rows = _unassigned(inputs.n_atoms)

    value_parts: list[torch.Tensor] = []
    weight_parts: list[torch.Tensor] = []
    scale_parts: list[torch.Tensor] = []

    next_row = 0  # index of the next row to be emitted
    first_atom = 0  # offset of the current sample's atoms within the batch

    for system_index, sample in enumerate(samples):
        n_atoms = len(sample.atoms)
        atom_span = slice(first_atom, first_atom + n_atoms)

        if fit_energy and sample.energy is not None:
            energy_rows[system_index] = next_row
            # Energies are fitted per atom, so the target and the row scale
            # both carry the 1 / n_atoms factor.
            per_atom_value = float(sample.energy) / float(n_atoms)
            value_parts.append(
                torch.tensor([per_atom_value], dtype=dtype, device=device)
            )
            weight_parts.append(_constant(1, float(sample.energy_weight) ** 0.5))
            scale_parts.append(_constant(1, 1.0 / float(n_atoms)))
            next_row += 1

        if fit_forces and sample.forces is not None:
            force_values = torch.as_tensor(
                sample.forces,
                dtype=dtype,
                device=device,
            )
            n_components = n_atoms * 3
            # One row per Cartesian component, laid out atom-major.
            force_rows[atom_span] = _row_ids(next_row, n_components).reshape(
                n_atoms, 3
            )
            value_parts.append(force_values.reshape(-1))
            weight_parts.append(
                _constant(n_components, float(sample.force_weight) ** 0.5)
            )
            scale_parts.append(torch.ones(n_components, dtype=dtype, device=device))
            next_row += n_components

        if fit_per_atom_energy and sample.per_atom_energy is not None:
            site_energies = torch.as_tensor(
                sample.per_atom_energy,
                dtype=dtype,
                device=device,
            ).reshape(n_atoms)
            per_atom_rows[atom_span] = _row_ids(next_row, n_atoms)
            value_parts.append(site_energies)
            weight_parts.append(
                _constant(n_atoms, float(sample.per_atom_weight) ** 0.5)
            )
            scale_parts.append(torch.ones(n_atoms, dtype=dtype, device=device))
            next_row += n_atoms

        first_atom += n_atoms

    if not value_parts:
        raise ValueError("no active targets were found in the batch")

    return BatchTargets(
        values=torch.cat(value_parts, dim=0),
        sqrt_weights=torch.cat(weight_parts, dim=0),
        row_scales=torch.cat(scale_parts, dim=0),
        energy_rows=energy_rows,
        force_rows=force_rows,
        per_atom_rows=per_atom_rows,
    )
[docs]
def prepare_batches(
    model: UFPotential,
    samples: Sequence[FitSample],
    *,
    batch_size: int,
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
    dtype: torch.dtype,
    device: torch.device | None,
) -> tuple[PreparedBatch, ...]:
    """Turn fit samples into prepared batches ready for matrix assembly.

    Samples without any label for the enabled target types are dropped; the
    remainder is split into consecutive chunks of ``batch_size``. For each
    chunk the model builds its input and the labels are flattened into
    weighted target rows.

    Args:
        model: Potential whose ``prepare_input`` builds the batched input.
        samples: Candidate samples; must not be empty.
        batch_size: Maximum number of samples per batch; must be positive.
        fit_energy: Enable energy targets.
        fit_forces: Enable force targets.
        fit_per_atom_energy: Enable per-atom energy targets.
        dtype: Floating-point dtype forwarded to ``prepare_input``.
        device: Device forwarded to ``prepare_input``.

    Returns:
        One ``PreparedBatch`` per chunk, in input order.

    Raises:
        ValueError: If ``samples`` is empty, if no sample carries a requested
            target, if ``batch_size`` is not positive, or if a mixed batch
            contains a sample for which the model yields no neighbor list.
    """
    all_samples = tuple(samples)
    if not all_samples:
        raise ValueError("`samples` must contain at least one FitSample")

    target_flags = {
        "fit_energy": fit_energy,
        "fit_forces": fit_forces,
        "fit_per_atom_energy": fit_per_atom_energy,
    }
    usable = [
        candidate for candidate in all_samples if candidate.has_targets(**target_flags)
    ]
    if not usable:
        raise ValueError("no samples contain the requested targets")

    # Keyword arguments shared by every ``prepare_input`` call below.
    input_options = {"device": device, "dtype": dtype, "requires_grad": False}

    def _resolved_neighbor_list(sample: FitSample) -> NeighborListData:
        # Run the sample through the model alone to obtain its neighbor list.
        solo_input = model.prepare_input(
            sample.atoms,
            neighbor_list=sample.neighbor_list,
            **input_options,
        )
        if solo_input.neighbor_list is None:
            raise ValueError(
                "mixed explicit and implicit neighbor-list batches require "
                "the model to produce a neighbor list for every sample"
            )
        return solo_input.neighbor_list

    batches: list[PreparedBatch] = []
    for group in _chunked(usable, batch_size):
        structures = [member.atoms for member in group]
        supplied = [member.neighbor_list for member in group]
        n_missing = sum(entry is None for entry in supplied)

        if n_missing == len(supplied):
            # Nobody brought a neighbor list: pass none to the model.
            batch_neighbor_list = None
        elif n_missing == 0:
            # Everybody brought one: hand the per-sample lists over as-is.
            batch_neighbor_list = supplied  # type: ignore[assignment]
        else:
            # Mixed batch: resolve one list per sample, then merge them.
            batch_neighbor_list = concatenate_neighbor_lists(
                [_resolved_neighbor_list(member) for member in group]
            )

        inputs = model.prepare_input(
            structures,
            neighbor_list=batch_neighbor_list,
            **input_options,
        )
        targets = _prepare_targets(group, inputs, **target_flags)
        batches.append(PreparedBatch(samples=group, inputs=inputs, targets=targets))

    return tuple(batches)
# Public API of this module; underscore-prefixed helpers stay private.
__all__ = ["BatchTargets", "FitSample", "PreparedBatch", "prepare_batches"]