# Source code for ufp.core.potential

"""
Base potential interfaces and batch wrappers for UFP models.

Subclass ``UFPotential`` for model logic; wrap it in ``BatchedUFPotential``
when callers need padded batch-shaped atomwise outputs.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import replace
from typing import Optional, Sequence, Union

import ase
import torch

from ufp.core._execution import (
    GeometryNeighborListCache,
    derive_batched_forces_from_energy,
    derive_forces_from_energy,
    prepare_ase_input,
    validate_output,
)
from ufp.core.input import UFPInput, UFPInputState
from ufp.core.output import UFPBatchOutput, UFPOutput
from ufp.neighbors._data import NeighborListData
from ufp.neighbors._neighbors import NeighborListBackend


class UFPotential(torch.nn.Module, ABC):
    """
    Base class for UFP interatomic potentials.

    Subclasses receive a torch-native :class:`UFPInput` bundle and should
    return :class:`UFPOutput`. The base class provides ASE conversion,
    optional neighbor-list construction, output validation, and automatic
    force derivation from differentiable energies for slow standalone
    workflows.

    Args:
        cutoff: Default cutoff used when building neighbor lists internally.
            Stored as a ``float``, or ``None`` when not given.
        neighbor_backend: Default backend used for neighbor-list
            construction. Normalized through the ``NeighborListBackend``
            constructor, so either a string or a backend member is accepted.
    """

    def __init__(
        self,
        cutoff: Optional[float] = None,
        neighbor_backend: Union[str, NeighborListBackend] = NeighborListBackend.AUTO,
    ) -> None:
        """Store the default cutoff and neighbor-list backend."""
        super().__init__()
        self.cutoff = None if cutoff is None else float(cutoff)
        # Normalize str / member input to a single NeighborListBackend value.
        self.neighbor_backend = NeighborListBackend(neighbor_backend)

    @abstractmethod
    def forward(self, inputs: UFPInput) -> UFPOutput:
        """
        Run the potential on one or more systems.

        Args:
            inputs: Normalized input bundle.

        Returns:
            Model predictions for the input systems.
        """

    def preferred_dtype(self) -> torch.dtype:
        """
        Return the dtype of the first floating-point parameter or buffer.

        Parameters are searched before buffers. When the module holds no
        floating-point tensor at all, ``torch.get_default_dtype()`` is
        returned instead.
        """
        # Walk the generators lazily rather than materializing two full lists
        # just to inspect (usually) the very first tensor.
        for tensors in (self.parameters(), self.buffers()):
            for tensor in tensors:
                if tensor.is_floating_point():
                    return tensor.dtype
        return torch.get_default_dtype()

    def provides_forces(self) -> bool:
        """
        Report whether the subclass returns forces directly.

        The base implementation returns ``False``. Subclasses whose
        ``forward`` fills ``UFPOutput.forces`` should override this so that
        ``compute`` does not request gradient tracking on the inputs.
        """
        return False

    def prepare_input(
        self,
        atoms: Union[ase.Atoms, Sequence[ase.Atoms]],
        neighbor_list: Optional[
            Union[NeighborListData, Sequence[NeighborListData]]
        ] = None,
        backend: Optional[Union[str, NeighborListBackend]] = None,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        requires_grad: bool = False,
        neighbor_list_cache: Optional[GeometryNeighborListCache] = None,
        state: Optional[UFPInputState] = None,
        extract_state: bool = True,
    ) -> UFPInput:
        """
        Convert ASE structures into ``UFPInput`` before evaluation.

        Args:
            atoms: One structure or a sequence of structures.
            neighbor_list: Optional precomputed neighbor list(s), forwarded
                unchanged to ``prepare_ase_input``.
            backend: Optional per-call backend override, forwarded to
                ``prepare_ase_input`` alongside this potential's default
                backend.
            device: Target device for the created tensors.
            dtype: Target floating dtype. Defaults to ``preferred_dtype()``.
            requires_grad: Forwarded to ``prepare_ase_input``. ``compute``
                sets it when forces must be derived from energies.
            neighbor_list_cache: Optional cache forwarded to
                ``prepare_ase_input``.
            state: Optional input state forwarded to ``prepare_ase_input``.
            extract_state: Forwarded to ``prepare_ase_input``.

        Returns:
            The input bundle built by ``prepare_ase_input``.
        """
        resolved_dtype = self.preferred_dtype() if dtype is None else dtype
        return prepare_ase_input(
            atoms,
            cutoff=self.cutoff,
            default_backend=self.neighbor_backend,
            neighbor_list=neighbor_list,
            backend=backend,
            device=device,
            dtype=resolved_dtype,
            requires_grad=requires_grad,
            neighbor_list_cache=neighbor_list_cache,
            state=state,
            extract_state=extract_state,
        )

    def compute_input(
        self,
        inputs: UFPInput,
        *,
        derive_forces: bool = False,
    ) -> UFPOutput:
        """
        Run ``forward``, validate the result, and derive forces when requested.

        Args:
            inputs: Prepared input bundle.
            derive_forces: When true and ``forward`` returned no forces, fill
                ``forces`` via ``derive_forces_from_energy``.

        Returns:
            The validated model output, with derived forces if requested.
        """
        output = self.forward(inputs)
        validate_output(output, inputs)
        if derive_forces and output.forces is None:
            output = replace(
                output,
                forces=derive_forces_from_energy(
                    output,
                    inputs,
                    training=self.training,
                ),
            )
            # Re-validate: the derived forces are new data on the output.
            validate_output(output, inputs)
        return output

    def compute(
        self,
        atoms: Union[ase.Atoms, Sequence[ase.Atoms]],
        neighbor_list: Optional[
            Union[NeighborListData, Sequence[NeighborListData]]
        ] = None,
        backend: Optional[Union[str, NeighborListBackend]] = None,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        derive_forces: bool = False,
        neighbor_list_cache: Optional[GeometryNeighborListCache] = None,
        state: Optional[UFPInputState] = None,
        extract_state: bool = True,
    ) -> UFPOutput:
        """
        Prepare ASE input first, then delegate to ``compute_input``.

        Arguments mirror ``prepare_input``. ``derive_forces`` is forwarded to
        ``compute_input``.
        """
        inputs = self.prepare_input(
            atoms,
            neighbor_list=neighbor_list,
            backend=backend,
            device=device,
            dtype=dtype,
            # Gradient tracking is only needed when forces may have to be
            # derived from energies, i.e. the model does not supply them.
            requires_grad=derive_forces and not self.provides_forces(),
            neighbor_list_cache=neighbor_list_cache,
            state=state,
            extract_state=extract_state,
        )
        return self.compute_input(inputs, derive_forces=derive_forces)
class BatchedUFPotential(UFPotential):
    """
    Batch-oriented wrapper around :class:`UFPotential`.

    ``UFPotential`` already supports evaluating a list of structures
    internally, but it returns flat per-atom arrays over the concatenated
    atoms. This wrapper keeps the same model logic and exposes a dedicated
    batch API. That API returns per-system energies and padded per-atom
    outputs with an explicit batch axis.

    When forces are derived from differentiable energies, this wrapper
    differentiates each system energy separately. It then checks that no
    gradient leaks onto atoms from any other system. This makes the
    concatenated-neighbor-list batching strategy explicit and safe for
    multi-structure evaluation.

    Args:
        potential: Wrapped UFP potential. Its ``cutoff`` and
            ``neighbor_backend`` become this wrapper's defaults.
        independence_tolerance: Tolerance used when checking that one
            system's energy does not depend on another system's atoms.
            Must be a non-negative number.

    Raises:
        ValueError: If ``independence_tolerance`` is negative or NaN.
    """

    def __init__(
        self,
        potential: UFPotential,
        *,
        independence_tolerance: float = 1.0e-12,
    ) -> None:
        """Wrap one potential while preserving its cutoff and backend defaults."""
        tolerance = float(independence_tolerance)
        # ``not (x >= 0)`` also rejects NaN, for which every comparison is
        # False and which would otherwise reach the independence check.
        if not (tolerance >= 0.0):
            raise ValueError(
                "independence_tolerance must be a non-negative number, "
                f"got {independence_tolerance!r}"
            )
        super().__init__(
            cutoff=potential.cutoff,
            neighbor_backend=potential.neighbor_backend,
        )
        # Assigned after Module.__init__ so it is registered as a submodule.
        self.potential = potential
        self.independence_tolerance = tolerance

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Delegate the flat model evaluation to the wrapped potential."""
        return self.potential(inputs)

    def provides_forces(self) -> bool:
        """Mirror force availability from the wrapped potential."""
        return self.potential.provides_forces()

    def compute_batch_input(
        self,
        inputs: UFPInput,
        *,
        derive_forces: bool = False,
    ) -> UFPBatchOutput:
        """
        Run the wrapped model and repack the result into ``UFPBatchOutput``.

        Args:
            inputs: Prepared input bundle.
            derive_forces: Forwarded to ``compute_input``.

        Returns:
            The flat output converted with ``UFPBatchOutput.from_output``.
        """
        output = self.compute_input(inputs, derive_forces=derive_forces)
        return UFPBatchOutput.from_output(output, inputs)

    def compute_batch(
        self,
        atoms: Union[ase.Atoms, Sequence[ase.Atoms]],
        neighbor_list: Optional[
            Union[NeighborListData, Sequence[NeighborListData]]
        ] = None,
        backend: Optional[Union[str, NeighborListBackend]] = None,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        derive_forces: bool = False,
        neighbor_list_cache: Optional[GeometryNeighborListCache] = None,
        state: Optional[UFPInputState] = None,
        extract_state: bool = True,
    ) -> UFPBatchOutput:
        """
        Prepare batched ASE input, then delegate to ``compute_batch_input``.

        Arguments mirror ``prepare_input``. ``derive_forces`` is forwarded to
        ``compute_batch_input``.
        """
        inputs = self.prepare_input(
            atoms,
            neighbor_list=neighbor_list,
            backend=backend,
            device=device,
            dtype=dtype,
            # Gradient tracking is only needed when forces may have to be
            # derived from energies, i.e. the model does not supply them.
            requires_grad=derive_forces and not self.provides_forces(),
            neighbor_list_cache=neighbor_list_cache,
            state=state,
            extract_state=extract_state,
        )
        return self.compute_batch_input(inputs, derive_forces=derive_forces)

    def compute_input(
        self,
        inputs: UFPInput,
        *,
        derive_forces: bool = False,
    ) -> UFPOutput:
        """
        Run ``forward`` and derive forces per system when requested.

        Unlike the base implementation, missing forces are derived with
        ``derive_batched_forces_from_energy``, which also receives
        ``independence_tolerance``.

        Args:
            inputs: Prepared input bundle.
            derive_forces: When true and ``forward`` returned no forces,
                derive them system by system.

        Returns:
            The validated flat model output.
        """
        output = self.forward(inputs)
        validate_output(output, inputs)
        if derive_forces and output.forces is None:
            output = replace(
                output,
                forces=derive_batched_forces_from_energy(
                    output,
                    inputs,
                    training=self.training,
                    independence_tolerance=self.independence_tolerance,
                ),
            )
            # Re-validate: the derived forces are new data on the output.
            validate_output(output, inputs)
        return output
# Public API of this module. GeometryNeighborListCache is re-exported from
# ufp.core._execution so callers can build caches for ``compute``.
__all__ = [
    "BatchedUFPotential",
    "GeometryNeighborListCache",
    "UFPotential",
]