Source code for ufp.terms._base

"""
Abstract base classes for UFP terms.

This module defines the common interface used by one-body, pair, three-body,
and generic term implementations inside ``UFPModel``.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Sequence

import torch

from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.terms._parameters import ParameterBlock


@dataclass(frozen=True)
class TermCacheOptions:
    """Settings handed to a term's optional ``cache_input`` warmer.

    Instances are immutable (``frozen=True``). Every field carries a default,
    so ``TermCacheOptions()`` on its own is a complete configuration.
    """

    feature_cache_storage: str = "cpu"
    feature_cache_mode: str = "auto"
    feature_cache_dir: object = None
    cache_prefix: str = "batch"
    legacy_cache_prefixes: tuple[str, ...] = ()
    include_per_atom_energy: bool = True


@dataclass(frozen=True)
class LinearAssemblyOptions:
    """Settings handed to a term's optional least-squares block assembler.

    Every field defaults to "unspecified": ``blocks`` is an empty tuple and
    each three-body setting is ``None``. Instances are immutable.
    """

    blocks: tuple[object, ...] = ()
    threebody_lstsq_backend: str | None = None
    threebody_bucket_backend: str | None = None
    threebody_runtime_config: object | None = None


@dataclass(frozen=True)
class TermInputRequirements:
    """Declarative input requirements for a term's forward or assembly path.

    Attributes
    ----------
    neighbor_list:
        Require ``inputs.neighbor_list`` to be present.
    full_neighbor_list:
        Require a neighbor list whose ``full_list`` flag is truthy. This also
        implies the plain ``neighbor_list`` requirement.
    state_fields:
        Names of input state fields that must be present. Any iterable of
        names is accepted and normalized to a tuple, so the frozen instance
        stays hashable and compares equal regardless of the container passed.
        A bare string is treated as a single field name, not as a sequence of
        characters; ``None`` is treated as "no fields".
    """

    neighbor_list: bool = False
    full_neighbor_list: bool = False
    state_fields: tuple[str, ...] = ()

    def __post_init__(self) -> None:
        """Normalize ``state_fields`` to a tuple of field names."""
        fields = self.state_fields
        if fields is None:
            fields = ()
        elif isinstance(fields, str):
            # ``("charge")`` is just the string "charge"; iterating it would
            # turn every character into a separate "field name".
            fields = (fields,)
        # Frozen dataclasses reject ordinary assignment, even in __post_init__.
        object.__setattr__(self, "state_fields", tuple(fields))

    def validate(self, inputs: UFPInput, *, term_name: str) -> None:
        """Raise clear errors for a missing neighbor list or state fields.

        Parameters
        ----------
        inputs:
            Input to check. Only ``inputs.neighbor_list`` (and its
            ``full_list`` flag) and ``inputs.missing_state_fields`` are used.
        term_name:
            Name used as the prefix of every error message.

        Raises
        ------
        RuntimeError
            If a required (full) neighbor list is absent, or if
            ``inputs.missing_state_fields`` reports any missing field.
        """
        if self.full_neighbor_list:
            if inputs.neighbor_list is None:
                raise RuntimeError(f"{term_name} requires a neighbor list")
            if not inputs.neighbor_list.full_list:
                raise RuntimeError(f"{term_name} requires a full neighbor list")
        elif self.neighbor_list and inputs.neighbor_list is None:
            raise RuntimeError(f"{term_name} requires a neighbor list")

        missing_state = inputs.missing_state_fields(self.state_fields)
        if missing_state:
            fields = ", ".join(f"`{field}`" for field in missing_state)
            raise RuntimeError(f"{term_name} requires input state fields: {fields}")


class UFPTerm(torch.nn.Module, ABC):
    """
    Abstract root of every pair, one-body, and higher-order interaction term.

    Subclasses must implement :meth:`forward`. The other hooks ship with
    conservative defaults that subclasses may override:

    * no analytic forces,
    * no fittable parameter blocks,
    * no input requirements and no optimizer group,
    * no-op input caching and least-squares assembly.
    """

    def __init__(
        self,
        *,
        cutoff: Optional[float] = None,
        atomic_types: Optional[Sequence[int]] = None,
    ) -> None:
        """Record the optional cutoff and the atomic numbers this term supports.

        Parameters
        ----------
        cutoff:
            Optional cutoff value, stored as ``float`` (or kept as ``None``).
        atomic_types:
            Optional atomic numbers handled by the term. Each entry is
            converted with ``int``, duplicates are dropped, and the result is
            stored as an ascending tuple. ``None`` is kept as ``None``.
        """
        super().__init__()
        self.cutoff = float(cutoff) if cutoff is not None else None
        if atomic_types is None:
            self.atomic_types = None
        else:
            distinct_numbers = {int(number) for number in atomic_types}
            self.atomic_types = tuple(sorted(distinct_numbers))

    @abstractmethod
    def forward(self, inputs: UFPInput) -> UFPOutput:
        """
        Return this term's contribution to the total model output.

        Concrete subclasses must override this method.
        """

    @property
    def provides_forces(self) -> bool:
        """Whether this term returns forces analytically (``False`` here)."""
        return False

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return the fittable parameter blocks of this term (none by default)."""
        return ()

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Input requirements declared by this term; the default demands nothing."""
        return TermInputRequirements()

    @property
    def optimizer_group(self) -> str | None:
        """Optional workflow optimizer-group name (``None`` by default)."""
        return None

    def validate_inputs(self, inputs: UFPInput) -> None:
        """Check ``inputs`` against :attr:`input_requirements`.

        The concrete class name is passed as ``term_name`` so error messages
        identify which term rejected the input.
        """
        requirements = self.input_requirements
        requirements.validate(inputs, term_name=type(self).__name__)

    def cache_input(
        self,
        inputs: UFPInput,
        options: TermCacheOptions | None = None,
        **kwargs,
    ) -> None:
        """Optionally precompute reusable metadata for one input.

        The base implementation does nothing; subclasses with cacheable work
        override it.
        """
        # Nothing to precompute here; discard the arguments explicitly so
        # linters do not report them as unused.
        del inputs, options, kwargs

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Optionally assemble this term's least-squares blocks in one call.

        The base implementation assembles nothing and returns ``None``.
        """
        del batch, targets, options
        return None


class PairTerm(UFPTerm):
    """Base class for pair interaction terms.

    It adds no behavior beyond :class:`UFPTerm`; concrete subclasses still
    implement ``forward``.
    """


class OneBodyTerm(UFPTerm):
    """Base class for one-body contribution terms.

    It adds no behavior beyond :class:`UFPTerm`; concrete subclasses still
    implement ``forward``.
    """


class ThreeBodyTerm(UFPTerm):
    """Base class for three-body interaction terms.

    It adds no behavior beyond :class:`UFPTerm`; concrete subclasses still
    implement ``forward``.
    """


__all__ = [
    # Term base classes.
    "OneBodyTerm",
    "PairTerm",
    "ThreeBodyTerm",
    # Frozen option / requirement dataclasses.
    "LinearAssemblyOptions",
    "TermCacheOptions",
    "TermInputRequirements",
    # Abstract root of all terms.
    "UFPTerm",
]