Source code for ufp.terms.cutoffs

"""Reusable cutoff envelope helpers for analytic pair priors."""

from __future__ import annotations

import math
from dataclasses import dataclass

import torch


CutoffEnvelopeKind = str

_SUPPORTED_KINDS = {"none", "cosine", "smoothstep"}


def _check_cutoff(cutoff: float) -> float:
    """Validate and normalize one cutoff distance."""
    cutoff = float(cutoff)
    if not math.isfinite(cutoff) or cutoff <= 0.0:
        raise ValueError("`cutoff` must be a finite positive value")
    return cutoff


def _check_onset(onset: float, *, cutoff: float) -> float:
    """Validate and normalize one cutoff-envelope onset."""
    onset = float(onset)
    if not math.isfinite(onset) or onset < 0.0:
        raise ValueError("`onset` must be a finite non-negative value")
    if onset >= cutoff:
        raise ValueError("`onset` must be smaller than `cutoff`")
    return onset


def _check_kind(kind: str) -> str:
    """Coerce ``kind`` to ``str`` and require it to be a supported envelope kind.

    Raises:
        ValueError: if the name is not in ``_SUPPORTED_KINDS``; the message
            lists the supported names in sorted order.
    """
    name = str(kind)
    if name in _SUPPORTED_KINDS:
        return name
    options = ", ".join(sorted(_SUPPORTED_KINDS))
    raise ValueError(f"`kind` must be one of: {options}")


def _as_distance_tensor(distances: object) -> torch.Tensor:
    """Return distances as a floating-point tensor without changing device."""
    tensor = torch.as_tensor(distances)
    if not tensor.is_floating_point():
        tensor = tensor.to(dtype=torch.get_default_dtype())
    return tensor


@dataclass(frozen=True)
class CutoffEnvelope:
    """Smooth multiplicative envelope that tapers pair energies to zero.

    Fields are normalized in ``__post_init__``. ``cutoff`` and ``onset`` are
    coerced to ``float`` and validated (``cutoff`` finite and positive,
    ``0 <= onset < cutoff``). ``kind`` is coerced to ``str`` and must be one
    of the supported kinds.
    """

    # Distance at and beyond which the envelope evaluates to zero.
    cutoff: float
    # Start of the taper region; must satisfy 0 <= onset < cutoff.
    onset: float = 0.0
    # Envelope shape: "none", "cosine" or "smoothstep".
    kind: CutoffEnvelopeKind = "smoothstep"

    def __post_init__(self) -> None:
        """Validate the fields and store their normalized values.

        Raises:
            ValueError: if ``cutoff``, ``onset`` or ``kind`` is invalid
                (checked in that order).
        """
        checked_cutoff = _check_cutoff(self.cutoff)
        normalized = {
            "cutoff": checked_cutoff,
            "onset": _check_onset(self.onset, cutoff=checked_cutoff),
            "kind": _check_kind(self.kind),
        }
        # The dataclass is frozen, so go around its __setattr__ guard.
        for field_name, field_value in normalized.items():
            object.__setattr__(self, field_name, field_value)

    def values(self, distances: object) -> torch.Tensor:
        """Return the envelope evaluated at the given pair distances."""
        return cutoff_envelope_values(
            distances,
            kind=self.kind,
            cutoff=self.cutoff,
            onset=self.onset,
        )

    def derivatives(self, distances: object) -> torch.Tensor:
        """Return ``d envelope / d distance`` at the given pair distances."""
        return cutoff_envelope_derivatives(
            distances,
            kind=self.kind,
            cutoff=self.cutoff,
            onset=self.onset,
        )

    def to_dict(self) -> dict[str, float | str]:
        """Return ``{"kind", "cutoff", "onset"}`` metadata for diagnostics or checkpoints."""
        return dict(kind=self.kind, cutoff=self.cutoff, onset=self.onset)


def normalize_cutoff_envelope(
    envelope: CutoffEnvelope | str | None,
    *,
    cutoff: float,
    default_kind: str = "smoothstep",
    default_onset_fraction: float = 0.8,
) -> CutoffEnvelope:
    """Normalize a cutoff-envelope specification into a ``CutoffEnvelope``.

    Args:
        envelope: An existing ``CutoffEnvelope``, a kind name, or ``None``.
            An existing envelope is returned unchanged once its cutoff has
            been checked against ``cutoff``. A kind name (or ``default_kind``
            when ``None``) builds a new envelope.
        cutoff: Term cutoff; must be finite and positive.
        default_kind: Kind used when ``envelope`` is ``None``.
        default_onset_fraction: For kinds other than ``"none"``, the new
            envelope's onset is ``default_onset_fraction * cutoff``. For
            ``"none"`` the onset is ``0.0`` and this value is unused.

    Returns:
        A validated ``CutoffEnvelope`` whose cutoff matches ``cutoff``.

    Raises:
        ValueError: if ``cutoff`` is invalid, an existing envelope's cutoff
            does not match ``cutoff``, the kind is unsupported, or the derived
            onset is not in ``[0, cutoff)``.
    """
    cutoff = _check_cutoff(cutoff)
    if isinstance(envelope, CutoffEnvelope):
        # Accept round-off in absolute OR relative terms. A purely absolute
        # tolerance rejects large cutoffs that differ only by a few ulps.
        if not math.isclose(
            envelope.cutoff,
            cutoff,
            rel_tol=1.0e-12,
            abs_tol=1.0e-12,
        ):
            raise ValueError("`envelope.cutoff` must match the term cutoff")
        return envelope
    kind = _check_kind(default_kind if envelope is None else str(envelope))
    onset = 0.0 if kind == "none" else float(default_onset_fraction) * cutoff
    return CutoffEnvelope(cutoff=cutoff, onset=onset, kind=kind)


def cutoff_envelope_values(
    distances: object,
    *,
    cutoff: float,
    onset: float = 0.0,
    kind: CutoffEnvelopeKind = "smoothstep",
) -> torch.Tensor:
    """Evaluate a supported cutoff envelope at ``distances``.

    With ``x = (r - onset) / (cutoff - onset)`` clamped to ``[0, 1]``:

    - ``"none"``: 1 where ``r < cutoff``, else 0 (``onset`` is validated
      but unused).
    - ``"cosine"``: ``0.5 * (1 + cos(pi * x))``.
    - ``"smoothstep"``: ``1 - (10 x^3 - 15 x^4 + 6 x^5)``.

    For the tapered kinds the result is exactly 1 for ``r <= onset`` and
    exactly 0 for ``r >= cutoff``.

    Returns:
        A tensor shaped like ``distances``. Its dtype is that of
        ``distances`` when floating point, otherwise the default dtype.

    Raises:
        ValueError: if ``cutoff``, ``onset`` or ``kind`` is invalid.
    """
    cutoff = _check_cutoff(cutoff)
    onset = _check_onset(onset, cutoff=cutoff)
    kind = _check_kind(kind)
    r = _as_distance_tensor(distances)
    one = torch.ones_like(r)
    zero = torch.zeros_like(r)

    if kind == "none":
        return torch.where(r < cutoff, one, zero)

    x = ((r - onset) / (cutoff - onset)).clamp(0.0, 1.0)
    if kind == "cosine":
        taper = 0.5 * (1.0 + torch.cos(math.pi * x))
    else:
        taper = 1.0 - x.pow(3) * (10.0 - 15.0 * x + 6.0 * x.square())

    # Pin the region ends explicitly rather than relying on the clamp.
    beyond_or_taper = torch.where(r >= cutoff, zero, taper)
    return torch.where(r <= onset, one, beyond_or_taper)


def cutoff_envelope_derivatives(
    distances: object,
    *,
    cutoff: float,
    onset: float = 0.0,
    kind: CutoffEnvelopeKind = "smoothstep",
) -> torch.Tensor:
    """Evaluate ``d envelope / d distance`` for a supported cutoff envelope.

    With ``w = cutoff - onset`` and ``x = (r - onset) / w`` clamped to
    ``[0, 1]``, the derivative inside ``onset < r < cutoff`` is:

    - ``"cosine"``: ``-0.5 * pi * sin(pi * x) / w``.
    - ``"smoothstep"``: ``(-30 x^2 + 60 x^3 - 30 x^4) / w``.

    Outside that open interval the result is 0. For ``"none"`` the result is
    zero everywhere; the step at ``cutoff`` is not represented.

    Raises:
        ValueError: if ``cutoff``, ``onset`` or ``kind`` is invalid.
    """
    cutoff = _check_cutoff(cutoff)
    onset = _check_onset(onset, cutoff=cutoff)
    kind = _check_kind(kind)
    r = _as_distance_tensor(distances)
    zero = torch.zeros_like(r)

    if kind == "none":
        return zero

    width = cutoff - onset
    x = ((r - onset) / width).clamp(0.0, 1.0)
    if kind == "cosine":
        slope = -0.5 * math.pi * torch.sin(math.pi * x) / width
    else:
        slope = (-30.0 * x.square() + 60.0 * x.pow(3) - 30.0 * x.pow(4)) / width

    inside_taper = (r > onset) & (r < cutoff)
    return torch.where(inside_taper, slope, zero)


def apply_cutoff_envelope(
    values: torch.Tensor,
    distances: object,
    envelope: CutoffEnvelope,
) -> torch.Tensor:
    """Multiply ``values`` by ``envelope`` evaluated at the same ``distances``.

    The envelope weights are moved to ``values.device``. When ``values`` is
    floating point or complex, the weights are also cast to ``values.dtype``,
    so the result keeps that dtype. For integer or bool ``values`` the
    weights stay floating and type promotion yields a floating-point result.
    Otherwise fractional weights would be truncated (e.g. 0.5 -> 0).

    Args:
        values: Per-pair quantities to taper.
        distances: Pair distances passed to ``envelope.values``.
        envelope: Envelope providing a ``values(distances)`` method.

    Returns:
        ``values * weights`` (broadcast by torch rules).
    """
    weights = envelope.values(distances)
    if values.is_floating_point() or values.is_complex():
        weights = weights.to(dtype=values.dtype, device=values.device)
    else:
        # Never cast fractional weights into an integer/bool dtype.
        weights = weights.to(device=values.device)
    return values * weights


# Public API of this module.
__all__ = [
    "CutoffEnvelope",
    "CutoffEnvelopeKind",
    "apply_cutoff_envelope",
    "cutoff_envelope_derivatives",
    "cutoff_envelope_values",
    "normalize_cutoff_envelope",
]