# Source code for ufp.terms.alchemical
"""
Alchemical coefficient sharing for related spline terms.
Use this module when several true coefficient blocks should be generated from a
smaller proxy basis with optional trainable mixing weights.
"""
from __future__ import annotations
import torch
[docs]
class AlchemicalCoefficients(torch.nn.Module):
"""
Shared linear projection from proxy coefficient tensors to true term coefficients.
"""
def __init__(
self,
*,
proxy_coeffs,
n_true_terms: int,
weights=None,
proxy_trainable: bool = True,
weights_trainable: bool = True,
dtype: torch.dtype | None = None,
) -> None:
"""Initialize proxy coefficients and optional mixing weights."""
super().__init__()
if int(n_true_terms) <= 0:
raise ValueError("`n_true_terms` must be positive")
proxy_tensor = torch.as_tensor(
proxy_coeffs,
dtype=dtype,
)
if proxy_tensor.ndim < 2:
raise ValueError("`proxy_coeffs` must have shape (n_proxy_terms, ...)")
n_proxy_terms = int(proxy_tensor.shape[0])
if n_proxy_terms <= 0:
raise ValueError("`proxy_coeffs` must contain at least one proxy term")
if n_proxy_terms > int(n_true_terms):
raise ValueError("`proxy_coeffs.shape[0]` can not exceed `n_true_terms`")
self.n_true_terms = int(n_true_terms)
self.n_proxy_terms = n_proxy_terms
self.coefficient_shape = tuple(int(dim) for dim in proxy_tensor.shape[1:])
self.proxy_coeffs = torch.nn.Parameter(
proxy_tensor,
requires_grad=bool(proxy_trainable),
)
if weights is None:
if self.n_proxy_terms != self.n_true_terms:
raise ValueError(
"`weights` is required when the number of proxy and true "
"terms differ"
)
self.register_buffer("weights", None, persistent=False)
else:
weights_tensor = torch.as_tensor(
weights,
dtype=proxy_tensor.dtype,
)
expected_shape = (self.n_true_terms, self.n_proxy_terms)
if (
weights_tensor.ndim != 2
or tuple(weights_tensor.shape) != expected_shape
):
raise ValueError(
"`weights` must have shape "
f"{expected_shape}, got {tuple(weights_tensor.shape)}"
)
if weights_trainable:
self.weights = torch.nn.Parameter(weights_tensor)
else:
self.register_buffer("weights", weights_tensor)
@property
def uses_identity_weights(self) -> bool:
"""Report whether proxy coefficients are used directly as true coefficients."""
return self.weights is None
@property
def true_coeffs(self) -> torch.Tensor:
"""Return the current true coefficient tensor after optional proxy mixing."""
if self.weights is None:
return self.proxy_coeffs
return torch.tensordot(self.weights, self.proxy_coeffs, dims=([1], [0]))
[docs]
def true_coeffs_for(self, index: int) -> torch.Tensor:
"""Return one true coefficient block by index."""
index = int(index)
if index < 0 or index >= self.n_true_terms:
raise IndexError(
f"coefficient index {index} is outside [0, {self.n_true_terms})"
)
if self.weights is None:
return self.proxy_coeffs[index]
return torch.tensordot(self.weights[index], self.proxy_coeffs, dims=([0], [0]))
# Names exported by ``from ufp.terms.alchemical import *``.
__all__ = ["AlchemicalCoefficients"]