Source code for ufp.terms._parameters

"""Parameter-block contracts shared by model terms and fitters."""

from __future__ import annotations

from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import Any

import torch

from ufp.terms.alchemical import AlchemicalCoefficients


# Zero-argument callback that returns a tensor; this is the type of
# ``ParameterBlock.read``.
TensorReader = Callable[[], torch.Tensor]
# Callback that accepts a tensor and returns nothing; this is the type of
# ``ParameterBlock.write``.
TensorWriter = Callable[[torch.Tensor], None]


@dataclass(frozen=True)
class ParameterBlockCacheChannel:
    """Semantic channel metadata for reusable least-squares cache layouts.

    Instances are immutable (``frozen=True``) and compare by value. When the
    fields hold values of their annotated types, instances are also hashable.

    Attributes:
        kind: Channel kind identifier.
        values: Integer values associated with this channel.
        start: Start index of the channel's range.
        stop: Stop index of the channel's range.
        key: Optional string key for the channel.

    NOTE(review): what ``values`` enumerates and whether ``stop`` is
    exclusive are not visible in this module -- confirm against the code
    that builds cache layouts.
    """

    kind: str
    values: tuple[int, ...]
    start: int
    stop: int
    key: str | None = None


@dataclass(frozen=True)
class ParameterBlockCacheDescriptor:
    """Semantic cache descriptor attached to a fittable parameter block.

    Instances are immutable (``frozen=True``) and compare by value. The
    generated ``__hash__`` hashes every field, so hashing raises
    ``TypeError`` when ``family`` is an unhashable mapping such as ``dict``.

    Attributes:
        family: Either a string or a string-keyed mapping identifying the
            descriptor's family.
        channels: Channel metadata entries, in order.
        reusable: Reuse flag for the described cache; defaults to ``True``.

    NOTE(review): how consumers interpret ``family`` and ``reusable`` is not
    visible in this module -- confirm against the cache implementation.
    """

    family: Mapping[str, object] | str
    channels: tuple[ParameterBlockCacheChannel, ...]
    reusable: bool = True


@dataclass(frozen=True)
class ParameterBlock:
    """Named parameter tensor exposed by a term for fitting or inspection.

    Blocks are resolved once when a fitting layout is created. Term forward
    paths do not use this object, which keeps the contract out of runtime hot
    loops.

    The dataclass itself is immutable (``frozen=True``); this is unrelated to
    the ``frozen`` *field*, which is a plain flag carried on the block.

    Attributes:
        name: Local parameter name within the owning term.
        kind: Term-specific block kind used by layout and solver code.
        shape: Logical tensor shape for this block.
        read: Callback returning the current tensor data.
        write: Callback that writes solved values back into the owning term.
        label: Optional human-readable label.
        coefficient_provider: Optional alchemical coefficient provider.
        coefficient_index: Optional provider index for this block.
        regularization_group: Optional group name used by regularization
            settings.
        fittable: Whether linear fitting should include this block.
        frozen: Whether this block should be treated as fixed.
        assembler: Optional custom block assembler.
        cache_descriptor: Optional semantic descriptor for reusable assembled
            caches.
    """

    name: str
    kind: str
    shape: tuple[int, ...]
    read: TensorReader
    write: TensorWriter
    label: str | None = None
    coefficient_provider: AlchemicalCoefficients | None = None
    coefficient_index: int | None = None
    regularization_group: str | None = None
    fittable: bool = True
    frozen: bool = False
    assembler: Any = None
    cache_descriptor: ParameterBlockCacheDescriptor | None = None


def copy_parameter_data(parameter: torch.Tensor, values: torch.Tensor) -> None:
    """Copy reshaped values into a parameter without replacing the object.

    The tensor object bound to ``parameter`` is kept, so references held
    elsewhere keep pointing at the updated storage.

    Args:
        parameter: Destination tensor, overwritten in place.
        values: Source tensor. It is reshaped to ``parameter.shape``; dtype
            and device conversion to match ``parameter`` is done by
            ``Tensor.copy_``.

    Raises:
        RuntimeError: If ``values`` cannot be reshaped to ``parameter.shape``
            (raised by ``Tensor.reshape``).
    """
    # Write under ``no_grad`` instead of through ``parameter.data``: the
    # ``.data`` alias has its own version counter, so autograd cannot notice
    # that a tensor saved for backward was overwritten. ``no_grad`` still
    # keeps the write out of the graph but bumps ``parameter``'s version.
    with torch.no_grad():
        # ``copy_`` converts dtype/device itself, so no ``.to(parameter)``
        # temporary is needed.
        parameter.copy_(values.reshape(parameter.shape))


# Public names exported by a star-import of this module.
__all__ = [ "ParameterBlock", "ParameterBlockCacheChannel", "ParameterBlockCacheDescriptor", "TensorReader", "TensorWriter", "copy_parameter_data", ]