Source code for ufp.terms._parameters
"""Parameter-block contracts shared by model terms and fitters."""
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import Any
import torch
from ufp.terms.alchemical import AlchemicalCoefficients
# Zero-argument callback that returns a tensor; used as the type of
# ``ParameterBlock.read``.
TensorReader = Callable[[], torch.Tensor]
# Callback that accepts one tensor and returns nothing; used as the type of
# ``ParameterBlock.write``.
TensorWriter = Callable[[torch.Tensor], None]
[docs]
@dataclass(frozen=True)
class ParameterBlockCacheChannel:
    """Semantic channel metadata for reusable least-squares cache layouts.

    Instances are immutable (frozen dataclass) and compare equal field by
    field. All fields except ``key`` are required.
    """

    # Channel category label. The set of valid kinds is defined by the code
    # that builds these objects and is not visible in this module.
    kind: str
    # Integer values identifying the channel.
    # NOTE(review): their meaning presumably depends on ``kind`` -- confirm
    # with the producers of this descriptor.
    values: tuple[int, ...]
    # Presumably the first index this channel covers in the cache layout --
    # TODO confirm.
    start: int
    # Presumably the end of the covered index range; whether it is inclusive
    # or exclusive is not shown here -- TODO confirm.
    stop: int
    # Optional string key; defaults to ``None``.
    key: str | None = None
[docs]
@dataclass(frozen=True)
class ParameterBlockCacheDescriptor:
    """Semantic cache descriptor attached to a fittable parameter block.

    Instances are immutable (frozen dataclass) and compare equal field by
    field.
    """

    # Family identifier: either a plain string or a mapping from string keys
    # to arbitrary objects.
    # NOTE(review): the dataclass-generated ``__hash__`` hashes every field,
    # so passing an unhashable mapping (e.g. a ``dict``) here makes the
    # instance unhashable -- confirm whether descriptors are used as keys.
    family: Mapping[str, object] | str
    # Per-channel metadata entries, stored as an ordered tuple.
    channels: tuple[ParameterBlockCacheChannel, ...]
    # Reuse flag; ``True`` unless overridden.
    # NOTE(review): presumably tells consumers whether an assembled cache may
    # be reused -- confirm against the cache implementation.
    reusable: bool = True
[docs]
@dataclass(frozen=True)
class ParameterBlock:
    """
    Named parameter tensor exposed by a term for fitting or inspection.

    Blocks are resolved once when a fitting layout is created. Term forward
    paths do not use this object, which keeps the contract out of runtime hot
    loops.

    Instances are immutable (frozen dataclass). ``name``, ``kind``, ``shape``,
    ``read`` and ``write`` are required; every other field has a default.

    Attributes:
        name: Local parameter name within the owning term.
        kind: Term-specific block kind used by layout and solver code.
        shape: Logical tensor shape for this block.
        read: Callback returning the current tensor data.
        write: Callback that writes solved values back into the owning term.
        label: Optional human-readable label.
        coefficient_provider: Optional alchemical coefficient provider.
        coefficient_index: Optional provider index for this block.
        regularization_group: Optional group name used by regularization
            settings.
        fittable: Whether linear fitting should include this block.
        frozen: Whether this block should be treated as fixed.
        assembler: Optional custom block assembler.
        cache_descriptor: Optional semantic descriptor for reusable assembled
            caches.
    """

    name: str
    kind: str
    shape: tuple[int, ...]
    read: TensorReader
    write: TensorWriter
    label: str | None = None
    coefficient_provider: AlchemicalCoefficients | None = None
    coefficient_index: int | None = None
    regularization_group: str | None = None
    fittable: bool = True
    frozen: bool = False
    # Typed as ``Any``: the interface an assembler must provide is not
    # declared in this module.
    assembler: Any = None
    cache_descriptor: ParameterBlockCacheDescriptor | None = None
[docs]
def copy_parameter_data(parameter: torch.Tensor, values: torch.Tensor) -> None:
    """Copy reshaped values into a parameter without replacing the object.

    The destination tensor is updated in place, so its identity, dtype,
    device and ``requires_grad`` flag are preserved. No autograd graph is
    recorded for the copy.

    Args:
        parameter: Destination tensor, overwritten in place.
        values: Source tensor. It is reshaped to ``parameter.shape`` and
            converted to the destination's dtype and device during the copy.

    Raises:
        RuntimeError: If ``values`` cannot be reshaped to ``parameter.shape``.
    """
    # Write through the tensor itself under ``no_grad`` instead of through
    # ``parameter.data``. A ``.data`` write bypasses autograd's version
    # counter, so a parameter saved for a pending backward pass would yield
    # silently wrong gradients; here autograd can detect the modification.
    with torch.no_grad():
        # ``copy_`` already converts dtype and device, so no explicit
        # ``.to(parameter)`` intermediate is needed.
        parameter.copy_(values.reshape(parameter.shape))
# Public API of this module: the names exported by a star-import. Every entry
# is defined earlier in this file.
__all__ = [
    "ParameterBlock",
    "ParameterBlockCacheChannel",
    "ParameterBlockCacheDescriptor",
    "TensorReader",
    "TensorWriter",
    "copy_parameter_data",
]