Source code for ufp.leastsquares.dataset

"""
Fit-sample normalization and batch preparation for least-squares workflows.

Use this module to convert labeled structures into weighted target rows and the
shared ``UFPInput`` batches consumed by the assembly layer.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Sequence

import ase
import numpy as np
import torch

from ufp.core.input import UFPInput
from ufp.core.potential import UFPotential
from ufp.neighbors._data import NeighborListData, concatenate_neighbor_lists


@dataclass
class FitSample:
    """One labeled structure for least-squares fitting over fixed geometry.

    Every target (``energy``, ``forces``, ``per_atom_energy``) is optional; a
    target left as ``None`` simply contributes no rows. Each target type has
    its own non-negative weight.

    Raises:
        TypeError: if ``atoms`` is not an ``ase.Atoms`` instance.
        ValueError: if a weight is negative or NaN, if ``forces`` does not
            have shape ``(len(atoms), 3)``, or if ``per_atom_energy`` does not
            hold exactly ``len(atoms)`` values.
    """

    # Structure whose geometry the targets refer to.
    atoms: ase.Atoms
    # Optional precomputed neighbor list for this structure.
    neighbor_list: NeighborListData | None = None
    # Total energy target; coerced to ``float`` when given.
    energy: float | None = None
    # Force targets, validated to have shape ``(len(atoms), 3)``.
    forces: np.ndarray | torch.Tensor | None = None
    # Per-atom energy targets, validated to hold ``len(atoms)`` values.
    per_atom_energy: np.ndarray | torch.Tensor | None = None
    # Per-target weights; coerced to ``float`` and required to be >= 0.
    energy_weight: float = 1.0
    force_weight: float = 1.0
    per_atom_weight: float = 1.0

    def __post_init__(self) -> None:
        """Validate target shapes and normalize per-target weights."""
        if not isinstance(self.atoms, ase.Atoms):
            raise TypeError("`atoms` must be an ase.Atoms instance")
        if self.energy is not None:
            self.energy = float(self.energy)

        weight_names = ("energy_weight", "force_weight", "per_atom_weight")
        # Coerce every weight first so a non-numeric weight fails on
        # conversion before any range check runs.
        for name in weight_names:
            setattr(self, name, float(getattr(self, name)))
        for name in weight_names:
            # `not (w >= 0)` also rejects NaN, which `w < 0` would let through
            # and which would contaminate every weighted row downstream.
            if not getattr(self, name) >= 0.0:
                raise ValueError(f"`{name}` must be non-negative")

        n_atoms = len(self.atoms)
        if self.forces is not None:
            forces_shape = self._target_shape(self.forces)
            if forces_shape != (n_atoms, 3):
                raise ValueError(
                    f"`forces` must have shape ({n_atoms}, 3), got {forces_shape}"
                )
        if self.per_atom_energy is not None:
            # Any layout is accepted as long as it flattens to `n_atoms`
            # values (e.g. a column vector).
            n_values = 1
            for extent in self._target_shape(self.per_atom_energy):
                n_values *= int(extent)
            flat_shape = (n_values,)
            if flat_shape != (n_atoms,):
                raise ValueError(
                    "`per_atom_energy` must have shape "
                    f"({n_atoms},), got {flat_shape}"
                )

    @staticmethod
    def _target_shape(values) -> tuple[int, ...]:
        """Return the shape of an array-like target without copying tensors.

        Torch tensors are inspected directly: converting them through NumPy
        fails for tensors that live on an accelerator or require gradients,
        even though such tensors are valid targets.
        """
        if isinstance(values, torch.Tensor):
            return tuple(values.shape)
        return tuple(np.asarray(values).shape)

    def has_targets(
        self,
        *,
        fit_energy: bool,
        fit_forces: bool,
        fit_per_atom_energy: bool,
    ) -> bool:
        """Report whether this sample contributes to the enabled target types.

        Returns ``True`` when at least one target type is both enabled by its
        flag and present (not ``None``) on this sample.
        """
        return bool(
            (fit_energy and self.energy is not None)
            or (fit_forces and self.forces is not None)
            or (fit_per_atom_energy and self.per_atom_energy is not None)
        )
@dataclass
class BatchTargets:
    """Weighted target rows and row mappings for one prepared fit batch.

    The three ``*_rows`` tensors map systems/atoms to indices into the flat
    row tensors. As built by ``_prepare_targets``, an entry of ``-1`` marks a
    system or atom that has no row for that target type.
    """

    # Flat target values, one entry per active row.
    values: torch.Tensor
    # Square root of the per-row weight, aligned with `values`.
    sqrt_weights: torch.Tensor
    # Per-row scale factor, aligned with `values`.
    row_scales: torch.Tensor
    # Row index of each system's energy target.
    energy_rows: torch.Tensor
    # Row index of each atom's force components.
    force_rows: torch.Tensor
    # Row index of each atom's per-atom energy target.
    per_atom_rows: torch.Tensor

    @property
    def n_rows(self) -> int:
        """Return the number of active target rows in this batch."""
        leading_extent = self.values.shape[0]
        return int(leading_extent)

    @property
    def weighted_values(self) -> torch.Tensor:
        """Return target values after square-root weighting is applied."""
        weights = self.sqrt_weights
        targets = self.values
        return weights * targets
@dataclass
class PreparedBatch:
    """One batch coupling input geometry with weighted least-squares targets."""

    # The fit samples that make up this batch, in batch order.
    samples: tuple[FitSample, ...]
    # Batched model input built from the samples' structures.
    inputs: UFPInput
    # Weighted target rows and row lookups for the same samples.
    targets: BatchTargets
def _chunked(
    items: Sequence[FitSample],
    batch_size: int,
) -> Iterable[tuple[FitSample, ...]]:
    """Yield consecutive sample chunks of the requested batch size.

    ``batch_size`` is validated when this function is called, not when the
    result is first iterated, so a bad value is reported at the call site.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("`batch_size` must be positive")
    # A generator expression (rather than `yield`) keeps the check above eager.
    return (
        tuple(items[start : start + batch_size])
        for start in range(0, len(items), batch_size)
    )


def _prepare_targets(
    samples: Sequence[FitSample],
    inputs: UFPInput,
    *,
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
) -> BatchTargets:
    """Flatten active sample targets into weighted row tensors and row lookups.

    Rows are appended sample by sample. Within one sample the order is: the
    energy row, then the force rows (atom-major, three components per atom),
    then the per-atom energy rows. Energy targets are stored per atom (divided
    by the atom count) and their ``row_scales`` entry is ``1 / n_atoms``;
    force and per-atom rows have a scale of one.

    The returned ``energy_rows``, ``force_rows`` and ``per_atom_rows`` tensors
    hold the row index of each system/atom target, or ``-1`` where no row was
    created.

    Raises:
        ValueError: if no target is active in the batch, or if an energy
            target is attached to a structure with zero atoms.
    """
    device = inputs.device
    dtype = inputs.dtype

    def _unassigned(*shape: int) -> torch.Tensor:
        """Build an index tensor in which every entry is the ``-1`` sentinel."""
        return torch.full(shape, fill_value=-1, dtype=torch.int64, device=device)

    def _constant(n_rows: int, value: float) -> torch.Tensor:
        """Build a length-``n_rows`` tensor filled with ``value``."""
        return torch.full((n_rows,), fill_value=value, dtype=dtype, device=device)

    def _row_range(first_row: int, n_rows: int) -> torch.Tensor:
        """Build the consecutive row indices starting at ``first_row``."""
        return torch.arange(
            first_row,
            first_row + n_rows,
            dtype=torch.int64,
            device=device,
        )

    energy_rows = _unassigned(inputs.n_systems)
    force_rows = _unassigned(inputs.n_atoms, 3)
    per_atom_rows = _unassigned(inputs.n_atoms)

    target_values: list[torch.Tensor] = []
    sqrt_weights: list[torch.Tensor] = []
    row_scales: list[torch.Tensor] = []
    row_count = 0
    atom_offset = 0

    for system_i, sample in enumerate(samples):
        n_atoms = len(sample.atoms)
        atom_slice = slice(atom_offset, atom_offset + n_atoms)

        if fit_energy and sample.energy is not None:
            if n_atoms == 0:
                # The per-atom normalization below would otherwise fail with
                # an uninformative ZeroDivisionError.
                raise ValueError(
                    "cannot normalize the energy of a structure with zero atoms"
                )
            energy_rows[system_i] = row_count
            target_values.append(
                torch.tensor(
                    [float(sample.energy) / float(n_atoms)],
                    dtype=dtype,
                    device=device,
                )
            )
            sqrt_weights.append(_constant(1, float(sample.energy_weight) ** 0.5))
            row_scales.append(_constant(1, 1.0 / float(n_atoms)))
            row_count += 1

        if fit_forces and sample.forces is not None:
            forces = torch.as_tensor(sample.forces, dtype=dtype, device=device)
            n_force_rows = n_atoms * 3
            # Row indices are laid out like the flattened forces: one
            # (x, y, z) triple of consecutive rows per atom.
            force_rows[atom_slice] = _row_range(row_count, n_force_rows).reshape(
                n_atoms, 3
            )
            target_values.append(forces.reshape(-1))
            sqrt_weights.append(
                _constant(n_force_rows, float(sample.force_weight) ** 0.5)
            )
            row_scales.append(_constant(n_force_rows, 1.0))
            row_count += n_force_rows

        if fit_per_atom_energy and sample.per_atom_energy is not None:
            per_atom_energy = torch.as_tensor(
                sample.per_atom_energy,
                dtype=dtype,
                device=device,
            ).reshape(n_atoms)
            per_atom_rows[atom_slice] = _row_range(row_count, n_atoms)
            target_values.append(per_atom_energy)
            sqrt_weights.append(
                _constant(n_atoms, float(sample.per_atom_weight) ** 0.5)
            )
            row_scales.append(_constant(n_atoms, 1.0))
            row_count += n_atoms

        # Advance even when the sample added no rows, so later samples map to
        # their own atoms in the batched input.
        atom_offset += n_atoms

    if not target_values:
        raise ValueError("no active targets were found in the batch")

    return BatchTargets(
        values=torch.cat(target_values, dim=0),
        sqrt_weights=torch.cat(sqrt_weights, dim=0),
        row_scales=torch.cat(row_scales, dim=0),
        energy_rows=energy_rows,
        force_rows=force_rows,
        per_atom_rows=per_atom_rows,
    )
def prepare_batches(
    model: UFPotential,
    samples: Sequence[FitSample],
    *,
    batch_size: int,
    fit_energy: bool,
    fit_forces: bool,
    fit_per_atom_energy: bool,
    dtype: torch.dtype,
    device: torch.device | None,
) -> tuple[PreparedBatch, ...]:
    """Convert fit samples into prepared batches ready for matrix assembly.

    Samples that carry none of the enabled target types are dropped; the
    remaining ones are grouped, in order, into chunks of ``batch_size``. For
    each chunk the model builds one batched input and the targets are
    flattened into weighted rows.

    Neighbor lists are handled per chunk:

    * no sample has one -> the model is called with ``neighbor_list=None``;
    * every sample has one -> the list of neighbor lists is passed through;
    * mixed -> each sample is prepared on its own to obtain a neighbor list,
      and the concatenation of those lists is passed for the batch.

    Raises:
        ValueError: if ``samples`` is empty, if no sample has a requested
            target, or if, in a mixed chunk, the model yields no neighbor
            list for some sample.
    """
    all_samples = tuple(samples)
    if len(all_samples) == 0:
        raise ValueError("`samples` must contain at least one FitSample")

    target_flags = {
        "fit_energy": fit_energy,
        "fit_forces": fit_forces,
        "fit_per_atom_energy": fit_per_atom_energy,
    }
    input_options = {"device": device, "dtype": dtype, "requires_grad": False}

    usable = []
    for candidate in all_samples:
        if candidate.has_targets(**target_flags):
            usable.append(candidate)
    if len(usable) == 0:
        raise ValueError("no samples contain the requested targets")

    batches: list[PreparedBatch] = []
    for chunk in _chunked(usable, batch_size):
        structures = [member.atoms for member in chunk]
        given_lists = [member.neighbor_list for member in chunk]
        n_missing = sum(1 for entry in given_lists if entry is None)

        if n_missing == len(given_lists):
            # Nothing precomputed: leave neighbor construction to the model.
            batch_inputs = model.prepare_input(
                structures,
                neighbor_list=None,
                **input_options,
            )
        elif n_missing == 0:
            # Everything precomputed: hand the per-sample lists over as-is.
            batch_inputs = model.prepare_input(
                structures,
                neighbor_list=given_lists,  # type: ignore[arg-type]
                **input_options,
            )
        else:
            # Mixed chunk: obtain one neighbor list per sample, then merge.
            resolved: list[NeighborListData] = []
            for member in chunk:
                solo_inputs = model.prepare_input(
                    member.atoms,
                    neighbor_list=member.neighbor_list,
                    **input_options,
                )
                if solo_inputs.neighbor_list is None:
                    raise ValueError(
                        "mixed explicit and implicit neighbor-list batches require "
                        "the model to produce a neighbor list for every sample"
                    )
                resolved.append(solo_inputs.neighbor_list)
            batch_inputs = model.prepare_input(
                structures,
                neighbor_list=concatenate_neighbor_lists(resolved),
                **input_options,
            )

        batch_targets = _prepare_targets(chunk, batch_inputs, **target_flags)
        batches.append(
            PreparedBatch(
                samples=chunk,
                inputs=batch_inputs,
                targets=batch_targets,
            )
        )

    return tuple(batches)
# Public API of this module; the underscore-prefixed helpers stay private.
__all__ = [
    "BatchTargets",
    "FitSample",
    "PreparedBatch",
    "prepare_batches",
]