# Source code for ufp.terms.alchemical

"""
Alchemical coefficient sharing for related spline terms.

Use this module when several true coefficient blocks should be generated from a
smaller proxy basis with optional trainable mixing weights.
"""

from __future__ import annotations

import torch


class AlchemicalCoefficients(torch.nn.Module):
    """
    Shared linear projection from proxy coefficient tensors to true term coefficients.

    The module stores ``n_proxy_terms`` proxy coefficient blocks and, optionally,
    a mixing matrix of shape ``(n_true_terms, n_proxy_terms)``. True block ``i``
    is ``sum_j weights[i, j] * proxy_coeffs[j]``. Without a mixing matrix, the
    proxy blocks are used directly as the true blocks.
    """

    def __init__(
        self,
        *,
        proxy_coeffs,
        n_true_terms: int,
        weights=None,
        proxy_trainable: bool = True,
        weights_trainable: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Initialize proxy coefficients and optional mixing weights.

        Args:
            proxy_coeffs: Array-like or tensor of shape ``(n_proxy_terms, ...)``
                with at least two dimensions. The data is copied, so the module
                never shares storage with the caller's tensor.
            n_true_terms: Number of true coefficient blocks. Must be positive and
                at least ``n_proxy_terms``.
            weights: Optional mixing matrix of shape
                ``(n_true_terms, n_proxy_terms)``. It is required when the proxy
                and true term counts differ. It is copied and converted to the
                proxy coefficients' dtype and device.
            proxy_trainable: Whether the proxy coefficients require gradients.
            weights_trainable: If true, ``weights`` is registered as a trainable
                parameter; otherwise it is registered as a buffer. Ignored when
                ``weights`` is ``None``.
            dtype: Optional dtype for the proxy coefficients. If omitted and the
                input is neither floating point nor complex while gradients are
                requested, the default floating dtype is used. Integer and boolean
                tensors cannot require gradients.

        Raises:
            ValueError: Raised in any of these cases:
                - ``n_true_terms`` is not positive.
                - ``proxy_coeffs`` has fewer than two dimensions or contains no
                  proxy terms.
                - There are more proxy terms than true terms.
                - ``weights`` is missing while the term counts differ.
                - ``weights`` has the wrong shape.
        """
        super().__init__()
        if int(n_true_terms) <= 0:
            raise ValueError("`n_true_terms` must be positive")

        # Copy the input so the Parameter below does not alias the caller's
        # tensor. Otherwise optimizer steps would mutate it, and two modules
        # built from one tensor would silently share coefficients.
        proxy_tensor = torch.as_tensor(proxy_coeffs, dtype=dtype).detach().clone()
        if proxy_tensor.ndim < 2:
            raise ValueError("`proxy_coeffs` must have shape (n_proxy_terms, ...)")
        n_proxy_terms = int(proxy_tensor.shape[0])
        if n_proxy_terms <= 0:
            raise ValueError("`proxy_coeffs` must contain at least one proxy term")
        if n_proxy_terms > int(n_true_terms):
            raise ValueError("`proxy_coeffs.shape[0]` can not exceed `n_true_terms`")

        # Integer/bool tensors cannot require gradients. Without an explicit
        # dtype, promote them to the default float dtype, but only when a
        # trainable tensor will be created (previously that path raised).
        needs_grad = bool(proxy_trainable) or (
            weights is not None and bool(weights_trainable)
        )
        if (
            dtype is None
            and needs_grad
            and not (proxy_tensor.is_floating_point() or proxy_tensor.is_complex())
        ):
            proxy_tensor = proxy_tensor.to(torch.get_default_dtype())

        self.n_true_terms = int(n_true_terms)
        self.n_proxy_terms = n_proxy_terms
        self.coefficient_shape = tuple(int(dim) for dim in proxy_tensor.shape[1:])
        self.proxy_coeffs = torch.nn.Parameter(
            proxy_tensor,
            requires_grad=bool(proxy_trainable),
        )

        if weights is None:
            if self.n_proxy_terms != self.n_true_terms:
                raise ValueError(
                    "`weights` is required when the number of proxy and true "
                    "terms differ"
                )
            # Registered as a None buffer so `self.weights` is always defined.
            self.register_buffer("weights", None, persistent=False)
            return

        # Match the proxy dtype *and device* so the later tensordot works.
        # Copy for the same aliasing reason as above.
        weights_tensor = (
            torch.as_tensor(
                weights,
                dtype=proxy_tensor.dtype,
                device=proxy_tensor.device,
            )
            .detach()
            .clone()
        )
        expected_shape = (self.n_true_terms, self.n_proxy_terms)
        if weights_tensor.ndim != 2 or tuple(weights_tensor.shape) != expected_shape:
            raise ValueError(
                "`weights` must have shape "
                f"{expected_shape}, got {tuple(weights_tensor.shape)}"
            )
        if weights_trainable:
            self.weights = torch.nn.Parameter(weights_tensor)
        else:
            self.register_buffer("weights", weights_tensor)

    @property
    def uses_identity_weights(self) -> bool:
        """Report whether proxy coefficients are used directly as true coefficients."""
        return self.weights is None

    @property
    def true_coeffs(self) -> torch.Tensor:
        """
        Return the current true coefficient tensor after optional proxy mixing.

        The result has shape ``(n_true_terms, *coefficient_shape)``. With no
        mixing weights, this is the ``proxy_coeffs`` parameter itself, not a copy.
        """
        if self.weights is None:
            return self.proxy_coeffs
        # (n_true, n_proxy) x (n_proxy, ...) -> (n_true, ...)
        return torch.tensordot(self.weights, self.proxy_coeffs, dims=([1], [0]))

    def true_coeffs_for(self, index: int) -> torch.Tensor:
        """
        Return one true coefficient block by index.

        Args:
            index: Block index in ``[0, n_true_terms)``. Negative indices are
                rejected rather than wrapped.

        Returns:
            A tensor of shape ``coefficient_shape``.

        Raises:
            IndexError: If ``index`` is outside ``[0, n_true_terms)``.
        """
        index = int(index)
        if index < 0 or index >= self.n_true_terms:
            raise IndexError(
                f"coefficient index {index} is outside [0, {self.n_true_terms})"
            )
        if self.weights is None:
            return self.proxy_coeffs[index]
        # (n_proxy,) x (n_proxy, ...) -> (...)
        return torch.tensordot(self.weights[index], self.proxy_coeffs, dims=([0], [0]))
# Names exported by ``from ufp.terms.alchemical import *``.
__all__ = ["AlchemicalCoefficients"]