"""Reusable cutoff envelope helpers for analytic pair priors."""
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
# Envelope kinds are plain strings at the type level; the runtime check
# against ``_SUPPORTED_KINDS`` is done by ``_check_kind``.
CutoffEnvelopeKind = str
# Every envelope shape understood by the evaluators in this module.
_SUPPORTED_KINDS = {"smoothstep", "cosine", "none"}
def _check_cutoff(cutoff: float) -> float:
"""Validate and normalize one cutoff distance."""
cutoff = float(cutoff)
if not math.isfinite(cutoff) or cutoff <= 0.0:
raise ValueError("`cutoff` must be a finite positive value")
return cutoff
def _check_onset(onset: float, *, cutoff: float) -> float:
"""Validate and normalize one cutoff-envelope onset."""
onset = float(onset)
if not math.isfinite(onset) or onset < 0.0:
raise ValueError("`onset` must be a finite non-negative value")
if onset >= cutoff:
raise ValueError("`onset` must be smaller than `cutoff`")
return onset
def _check_kind(kind: str) -> str:
    """Coerce ``kind`` to ``str`` and require it to name a supported envelope.

    Raises:
        ValueError: listing the supported kinds (sorted) when ``kind`` is
            not one of them.
    """
    name = str(kind)
    if name in _SUPPORTED_KINDS:
        return name
    options = ", ".join(sorted(_SUPPORTED_KINDS))
    raise ValueError(f"`kind` must be one of: {options}")
def _as_distance_tensor(distances: object) -> torch.Tensor:
"""Return distances as a floating-point tensor without changing device."""
tensor = torch.as_tensor(distances)
if not tensor.is_floating_point():
tensor = tensor.to(dtype=torch.get_default_dtype())
return tensor
@dataclass(frozen=True)
class CutoffEnvelope:
    """Smooth multiplicative envelope that tapers pair energies to zero.

    Instances are immutable and validated on construction:

    * ``cutoff`` must be finite and positive.
    * ``onset`` must be finite, non-negative and strictly below ``cutoff``.
    * ``kind`` must be one of ``"none"``, ``"cosine"`` or ``"smoothstep"``.

    Numeric fields are stored as ``float``.
    """

    # Distance at and beyond which the envelope is zero.
    cutoff: float
    # Distance up to which the envelope is one; tapering happens between
    # ``onset`` and ``cutoff``. The ``"none"`` kind is a hard step at
    # ``cutoff`` and does not use it.
    onset: float = 0.0
    # Shape of the taper.
    kind: CutoffEnvelopeKind = "smoothstep"

    def __post_init__(self) -> None:
        """Validate the fields and store their normalized values.

        The dataclass is frozen, so normalized values are written back with
        ``object.__setattr__``.

        Raises:
            ValueError: if a field is out of range or ``kind`` is unsupported.
        """
        cutoff = _check_cutoff(self.cutoff)
        object.__setattr__(self, "cutoff", cutoff)
        # Validate onset against the already-normalized float cutoff.
        object.__setattr__(self, "onset", _check_onset(self.onset, cutoff=cutoff))
        object.__setattr__(self, "kind", _check_kind(self.kind))

    def values(self, distances: object) -> torch.Tensor:
        """Evaluate envelope values at pair ``distances``."""
        return cutoff_envelope_values(
            distances,
            cutoff=self.cutoff,
            onset=self.onset,
            kind=self.kind,
        )

    def derivatives(self, distances: object) -> torch.Tensor:
        """Evaluate derivatives ``d envelope / d distance`` at ``distances``."""
        return cutoff_envelope_derivatives(
            distances,
            cutoff=self.cutoff,
            onset=self.onset,
            kind=self.kind,
        )

    def to_dict(self) -> dict[str, float | str]:
        """Return ``kind``/``cutoff``/``onset`` metadata as a plain dict.

        Intended for projection diagnostics or checkpoints; the keys match
        the constructor arguments.
        """
        return {
            "kind": self.kind,
            "cutoff": self.cutoff,
            "onset": self.onset,
        }
def normalize_cutoff_envelope(
    envelope: CutoffEnvelope | str | None,
    *,
    cutoff: float,
    default_kind: str = "smoothstep",
    default_onset_fraction: float = 0.8,
) -> CutoffEnvelope:
    """Normalize a cutoff-envelope specification into a ``CutoffEnvelope``.

    Args:
        envelope: One of three forms:
            * an existing envelope, returned unchanged once its cutoff is
              confirmed to match ``cutoff``;
            * a kind name;
            * ``None``, which selects ``default_kind``.
        cutoff: Cutoff distance of the term the envelope is attached to.
        default_kind: Kind used when ``envelope`` is ``None``.
        default_onset_fraction: Taper onset as a fraction of ``cutoff`` for
            the tapering kinds. Must be finite and lie in ``[0, 1)``. It is
            ignored for ``"none"``, whose onset is always ``0``.

    Returns:
        A validated ``CutoffEnvelope`` whose cutoff matches ``cutoff``.

    Raises:
        ValueError: in any of these cases:
            * ``cutoff`` or the kind is invalid;
            * an explicit envelope's cutoff differs from ``cutoff``;
            * ``default_onset_fraction`` is outside ``[0, 1)`` for a
              tapering kind.
    """
    cutoff = _check_cutoff(cutoff)
    if isinstance(envelope, CutoffEnvelope):
        # An explicit envelope is returned as-is, but it must describe the
        # same cutoff as the term it will be attached to.
        if not math.isclose(
            envelope.cutoff,
            cutoff,
            rel_tol=0.0,
            abs_tol=1.0e-12,
        ):
            raise ValueError("`envelope.cutoff` must match the term cutoff")
        return envelope
    # _check_kind stringifies its argument, so non-str specs behave as before.
    kind = _check_kind(default_kind if envelope is None else envelope)
    if kind == "none":
        # A hard step has no taper region. The onset fraction is unused and
        # intentionally not validated, so any value keeps working here.
        return CutoffEnvelope(cutoff=cutoff, onset=0.0, kind=kind)
    fraction = float(default_onset_fraction)
    if not math.isfinite(fraction) or not 0.0 <= fraction < 1.0:
        # Reject here so the error names the argument the caller supplied,
        # instead of surfacing as an `onset` error from the dataclass.
        raise ValueError(
            "`default_onset_fraction` must be a finite value in [0, 1)"
        )
    return CutoffEnvelope(cutoff=cutoff, onset=fraction * cutoff, kind=kind)
def cutoff_envelope_values(
    distances: object,
    *,
    cutoff: float,
    onset: float = 0.0,
    kind: CutoffEnvelopeKind = "smoothstep",
) -> torch.Tensor:
    """Evaluate a supported cutoff envelope at ``distances``.

    With ``x = (r - onset) / (cutoff - onset)`` clamped to ``[0, 1]``:

    * ``"none"``: ``1`` for ``r < cutoff``, else ``0``. ``onset`` is only
      validated.
    * ``"cosine"``: ``0.5 * (1 + cos(pi * x))``.
    * ``"smoothstep"``: ``1 - (10 x^3 - 15 x^4 + 6 x^5)``.

    For the tapering kinds the result is exactly ``1`` for ``r <= onset`` and
    exactly ``0`` for ``r >= cutoff``.

    Args:
        distances: Pair distances; anything ``torch.as_tensor`` accepts.
        cutoff: Finite positive cutoff distance.
        onset: Finite non-negative start of the taper, below ``cutoff``.
        kind: One of ``"none"``, ``"cosine"``, ``"smoothstep"``.

    Returns:
        Tensor shaped like the distance tensor, with its dtype and device.
        Non-floating inputs are promoted to the default dtype.

    Raises:
        ValueError: on an invalid ``cutoff``, ``onset`` or ``kind``.
    """
    cutoff = _check_cutoff(cutoff)
    onset = _check_onset(onset, cutoff=cutoff)
    kind = _check_kind(kind)
    r = _as_distance_tensor(distances)
    if kind == "none":
        # Hard step: keep only pairs strictly inside the cutoff.
        return torch.where(r < cutoff, torch.ones_like(r), torch.zeros_like(r))
    width = cutoff - onset
    x = ((r - onset) / width).clamp(0.0, 1.0)
    if kind == "cosine":
        middle_values = 0.5 * (1.0 + torch.cos(math.pi * x))
    else:
        # Quintic smoothstep s(x) = 10x^3 - 15x^4 + 6x^5 in Horner-like form.
        smoothstep = x.pow(3) * (10.0 - 15.0 * x + 6.0 * x.square())
        middle_values = 1.0 - smoothstep
    # Force exact plateau values outside the taper, independent of rounding
    # in the formulas above.
    return torch.where(
        r <= onset,
        torch.ones_like(r),
        torch.where(r >= cutoff, torch.zeros_like(r), middle_values),
    )
def cutoff_envelope_derivatives(
    distances: object,
    *,
    cutoff: float,
    onset: float = 0.0,
    kind: CutoffEnvelopeKind = "smoothstep",
) -> torch.Tensor:
    """Evaluate ``d envelope / d distance`` for a supported cutoff envelope.

    Let ``width = cutoff - onset`` and let ``x = (r - onset) / width``,
    clamped to ``[0, 1]``. Inside the taper region
    ``onset < r < cutoff`` the derivative is:

    * ``"cosine"``: ``-0.5 * pi * sin(pi * x) / width``;
    * ``"smoothstep"``: ``-30 * x**2 * (1 - x)**2 / width``.

    Elsewhere the derivative is exactly ``0``. ``"none"`` returns zeros
    everywhere, so the jump at ``cutoff`` is not represented.

    Args:
        distances: Pair distances; anything ``torch.as_tensor`` accepts.
        cutoff: Finite positive cutoff distance.
        onset: Finite non-negative start of the taper, below ``cutoff``.
        kind: One of ``"none"``, ``"cosine"``, ``"smoothstep"``.

    Returns:
        Tensor shaped like the distance tensor, with its dtype and device.
        Non-floating inputs are promoted to the default dtype.

    Raises:
        ValueError: on an invalid ``cutoff``, ``onset`` or ``kind``.
    """
    cutoff = _check_cutoff(cutoff)
    onset = _check_onset(onset, cutoff=cutoff)
    kind = _check_kind(kind)
    r = _as_distance_tensor(distances)
    zeros = torch.zeros_like(r)
    if kind == "none":
        return zeros
    width = cutoff - onset
    x = ((r - onset) / width).clamp(0.0, 1.0)
    if kind == "cosine":
        middle_derivative = -0.5 * math.pi * torch.sin(math.pi * x) / width
    else:
        # d/dx (10x^3 - 15x^4 + 6x^5) = 30 x^2 (1 - x)^2.
        # The expanded polynomial -30x^2 + 60x^3 - 30x^4 cancels
        # catastrophically as x -> 1, and in low-precision dtypes it can even
        # flip sign. The factored product is always <= 0 and stays accurate
        # near the cutoff.
        middle_derivative = -30.0 * x.square() * (1.0 - x).square() / width
    middle_mask = (r > onset) & (r < cutoff)
    return torch.where(middle_mask, middle_derivative, zeros)
def apply_cutoff_envelope(
    values: torch.Tensor,
    distances: object,
    envelope: CutoffEnvelope,
) -> torch.Tensor:
    """Multiply ``values`` by ``envelope`` evaluated at the same ``distances``.

    The envelope is cast to the dtype and device of ``values`` before the
    product. ``distances`` may therefore be a CPU tensor or a Python
    sequence even when ``values`` lives on another device.

    Args:
        values: Per-pair quantities (for example pair energies) to taper.
        distances: Pair distances, combined with ``values`` elementwise under
            standard broadcasting.
        envelope: Envelope whose ``values`` method is evaluated.

    Returns:
        ``values * envelope.values(distances)``, in the dtype and device of
        ``values``.
    """
    weights = envelope.values(distances).to(
        dtype=values.dtype, device=values.device
    )
    return values * weights
# Names exported by ``from ... import *``; the underscore-prefixed helpers
# and constants in this module are internal.
__all__ = [
    # Envelope object and the alias used for its ``kind`` field.
    "CutoffEnvelope",
    "CutoffEnvelopeKind",
    # Functional API.
    "apply_cutoff_envelope",
    "cutoff_envelope_derivatives",
    "cutoff_envelope_values",
    "normalize_cutoff_envelope",
]