# Source code for ufp.terms._base
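"""
Abstract base classes for UFP terms.

This module defines the common interface used by one-body, pair, three-body,
and generic term implementations inside ``UFPModel``, together with the frozen
option dataclasses passed to optional term-level cache warmers and
least-squares assemblers.
"""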
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Sequence
import torch
from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.terms._parameters import ParameterBlock
@dataclass(frozen=True)
class TermCacheOptions:
    """Options passed to optional term-level input cache warmers.

    Instances are frozen. ``legacy_cache_prefixes`` is normalized to a tuple
    in ``__post_init__``. Passing a list or another iterable therefore still
    yields an immutable, hashable instance, as the ``tuple[str, ...]``
    annotation promises. A single bare string is treated as one prefix
    rather than being split into characters, and ``None`` becomes ``()``.
    """

    # Storage selector for cached features; accepted values are defined by
    # the consuming cache warmer.
    feature_cache_storage: str = "cpu"
    # Cache mode selector; accepted values are defined by the consuming warmer.
    feature_cache_mode: str = "auto"
    # Optional cache directory; ``None`` when none is supplied.
    feature_cache_dir: object = None
    # Prefix for cache entries (presumably used to name them — confirm).
    cache_prefix: str = "batch"
    # Additional older prefixes; always stored as a tuple (see __post_init__).
    legacy_cache_prefixes: tuple[str, ...] = ()
    # Flag forwarded to the warmer controlling per-atom energy handling.
    include_per_atom_energy: bool = True

    def __post_init__(self) -> None:
        """Coerce ``legacy_cache_prefixes`` to an immutable tuple of prefixes."""
        prefixes = self.legacy_cache_prefixes
        if prefixes is None:
            normalized: tuple[str, ...] = ()
        elif isinstance(prefixes, str):
            # A bare string is one prefix, not a sequence of characters.
            normalized = (prefixes,)
        else:
            normalized = tuple(prefixes)
        # Frozen dataclasses reject normal assignment, so bypass it here.
        object.__setattr__(self, "legacy_cache_prefixes", normalized)
@dataclass(frozen=True)
class LinearAssemblyOptions:
    """Frozen bundle of settings for optional term-level least-squares assemblers.

    An instance is what ``UFPTerm.assemble_linear_blocks`` receives as its
    ``options`` argument. Every field defaults to "unspecified".
    """

    # Blocks handed to the assembler; empty tuple by default.
    blocks: tuple[object, ...] = ()
    # Name of the three-body least-squares backend; ``None`` when unspecified.
    threebody_lstsq_backend: Optional[str] = None
    # Name of the three-body bucketing backend; ``None`` when unspecified.
    threebody_bucket_backend: Optional[str] = None
    # Opaque three-body runtime configuration object; ``None`` when unspecified.
    threebody_runtime_config: Optional[object] = None
class UFPTerm(torch.nn.Module, ABC):
    """
    Abstract base for UFP interaction terms (pair and higher order).

    Subclasses must implement :meth:`forward`. Every other hook has a
    conservative default that subclasses may override:

    - no analytic forces;
    - no fittable parameter blocks;
    - default input requirements;
    - no optimizer group;
    - no custom least-squares assembly.
    """

    def __init__(
        self,
        *,
        cutoff: Optional[float] = None,
        atomic_types: Optional[Sequence[int]] = None,
    ) -> None:
        """Record the term's optional cutoff and supported atomic numbers.

        Args:
            cutoff: Optional cutoff, stored as ``float`` (or ``None``).
            atomic_types: Optional atomic numbers, stored as a sorted tuple of
                unique ``int`` values (or ``None``).
        """
        super().__init__()

        if cutoff is None:
            self.cutoff = None
        else:
            self.cutoff = float(cutoff)

        if atomic_types is None:
            self.atomic_types = None
        else:
            # Deduplicate and sort so equivalent inputs yield the same tuple.
            unique_types = {int(z) for z in atomic_types}
            self.atomic_types = tuple(sorted(unique_types))

    @abstractmethod
    def forward(self, inputs: UFPInput) -> UFPOutput:
        """
        Return this term's contribution to the total model output.
        """

    @property
    def provides_forces(self) -> bool:
        """Whether the term supplies forces analytically (``False`` here)."""
        return False

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Fittable parameter blocks exposed by the term (none here)."""
        return ()

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Optional input requirements declared by the term (defaults here)."""
        # NOTE(review): ``TermInputRequirements`` is neither defined nor
        # imported in the visible source of this module, although it is listed
        # in ``__all__``. Confirm it exists here; otherwise this raises NameError.
        return TermInputRequirements()

    @property
    def optimizer_group(self) -> str | None:
        """Optional workflow optimizer group name (``None`` here)."""
        return None

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Optionally assemble this term's least-squares blocks in one call.

        The base implementation opts out by returning ``None``; subclasses
        may override it.
        """
        # The arguments are intentionally unused by the default implementation.
        del batch, targets, options
        return None
class PairTerm(UFPTerm):
    """Base class for pair (two-body) interaction terms.

    Adds no members beyond ``UFPTerm``; concrete subclasses still have to
    implement ``forward``.
    """
class OneBodyTerm(UFPTerm):
    """Base class for one-body contribution terms.

    Adds no members beyond ``UFPTerm``; concrete subclasses still have to
    implement ``forward``.
    """
class ThreeBodyTerm(UFPTerm):
    """Base class for three-body interaction terms.

    Adds no members beyond ``UFPTerm``; concrete subclasses still have to
    implement ``forward``.
    """
# Public API of this module (order preserved from the original list).
__all__ = [
    # Specialized term base classes.
    "OneBodyTerm", "PairTerm", "ThreeBodyTerm",
    # Option / requirement containers.
    # NOTE(review): "TermInputRequirements" is not defined or imported in the
    # visible source; confirm it exists, or ``import *`` will fail.
    "LinearAssemblyOptions", "TermCacheOptions", "TermInputRequirements",
    # Common abstract base.
    "UFPTerm",
]