"""
Base potential interfaces and batch wrappers for UFP models.
Subclass ``UFPotential`` for model logic; wrap it in ``BatchedUFPotential``
when callers need padded batch-shaped atomwise outputs.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import replace
from typing import Optional, Sequence, Union
import ase
import torch
from ufp.core._execution import (
GeometryNeighborListCache,
derive_batched_forces_from_energy,
derive_forces_from_energy,
prepare_ase_input,
validate_output,
)
from ufp.core.input import UFPInput, UFPInputState
from ufp.core.output import UFPBatchOutput, UFPOutput
from ufp.neighbors._data import NeighborListData
from ufp.neighbors._neighbors import NeighborListBackend
class UFPotential(torch.nn.Module, ABC):
    """
    Base class for UFP interatomic potentials.

    Concrete potentials implement :meth:`forward`, which consumes a torch-native
    :class:`UFPInput` bundle and returns a :class:`UFPOutput`. Around that, the
    base class layers ASE conversion, optional neighbor-list construction,
    output validation, and force derivation from differentiable energies for
    slower, standalone workflows.

    Args:
        cutoff: Default cutoff used when neighbor lists are built internally.
            Stored as ``float``; ``None`` leaves it unset.
        neighbor_backend: Default neighbor-list backend, given either as a
            :class:`NeighborListBackend` member or as a string accepted by it.
    """

    def __init__(
        self,
        cutoff: Optional[float] = None,
        neighbor_backend: Union[str, NeighborListBackend] = NeighborListBackend.AUTO,
    ) -> None:
        """Record the default cutoff and neighbor-list backend for this potential."""
        super().__init__()
        if cutoff is not None:
            cutoff = float(cutoff)
        self.cutoff = cutoff
        # Normalize strings (and members) to a NeighborListBackend member.
        self.neighbor_backend = NeighborListBackend(neighbor_backend)

    @abstractmethod
    def forward(self, inputs: UFPInput) -> UFPOutput:
        """
        Evaluate the potential on one or more systems.

        Args:
            inputs: Normalized input bundle.

        Returns:
            Model predictions for every system in ``inputs``.
        """

    def preferred_dtype(self) -> torch.dtype:
        """
        Return the dtype of the first floating-point parameter or buffer.

        Parameters are scanned before buffers. If neither holds a
        floating-point tensor, ``torch.get_default_dtype()`` is returned.
        """
        tensors = (*self.parameters(), *self.buffers())
        return next(
            (tensor.dtype for tensor in tensors if tensor.is_floating_point()),
            torch.get_default_dtype(),
        )

    def provides_forces(self) -> bool:
        """
        Tell callers whether :meth:`forward` already returns forces.

        The base implementation answers ``False``. Subclasses that predict
        forces directly are expected to override it.
        """
        return False

    def compute(
        self,
        atoms: Union[ase.Atoms, Sequence[ase.Atoms]],
        neighbor_list: Optional[
            Union[NeighborListData, Sequence[NeighborListData]]
        ] = None,
        backend: Optional[Union[str, NeighborListBackend]] = None,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        derive_forces: bool = False,
        neighbor_list_cache: Optional[GeometryNeighborListCache] = None,
        state: Optional[UFPInputState] = None,
        extract_state: bool = True,
    ) -> UFPOutput:
        """
        Evaluate the potential on ASE structures.

        The structures are first converted with :meth:`prepare_input`, and the
        resulting bundle is handed to :meth:`compute_input`.

        Args:
            atoms: One ASE structure or a sequence of structures.
            neighbor_list: Optional precomputed neighbor list(s), forwarded to
                :meth:`prepare_input`.
            backend: Optional backend override, forwarded to
                :meth:`prepare_input`.
            device: Forwarded to :meth:`prepare_input`.
            dtype: Forwarded to :meth:`prepare_input`.
            derive_forces: Forwarded to :meth:`compute_input`. It also controls
                whether the prepared inputs request gradients (see below).
            neighbor_list_cache: Forwarded to :meth:`prepare_input`.
            state: Forwarded to :meth:`prepare_input`.
            extract_state: Forwarded to :meth:`prepare_input`.

        Returns:
            Whatever :meth:`compute_input` returns for the prepared inputs.
        """
        # Input gradients are only needed when forces must come from autograd,
        # i.e. they are requested but the model does not return them itself.
        needs_input_grad = derive_forces and not self.provides_forces()
        prepared = self.prepare_input(
            atoms,
            neighbor_list=neighbor_list,
            backend=backend,
            device=device,
            dtype=dtype,
            requires_grad=needs_input_grad,
            neighbor_list_cache=neighbor_list_cache,
            state=state,
            extract_state=extract_state,
        )
        return self.compute_input(prepared, derive_forces=derive_forces)
class BatchedUFPotential(UFPotential):
    """
    Batch-oriented wrapper around :class:`UFPotential`.

    ``UFPotential`` already supports evaluating a list of structures internally,
    but it returns flat per-atom arrays over the concatenated atoms. This wrapper
    keeps the same model logic and exposes a dedicated batch API. That API returns
    per-system energies and padded per-atom outputs with an explicit batch axis.

    When forces are derived from differentiable energies, this wrapper
    differentiates each system energy separately. It then checks that no gradient
    leaks onto atoms from any other system. This makes the
    concatenated-neighbor-list batching strategy explicit and safe for
    multi-structure evaluation.

    Args:
        potential: Wrapped UFP potential.
        independence_tolerance: Non-negative tolerance used when checking that
            one system's energy does not depend on another system's atoms.

    Raises:
        ValueError: If ``independence_tolerance`` is negative or NaN.
    """

    def __init__(
        self,
        potential: UFPotential,
        *,
        independence_tolerance: float = 1.0e-12,
    ) -> None:
        """Wrap one potential while preserving its cutoff and backend defaults."""
        tolerance = float(independence_tolerance)
        # ``not x >= 0`` is also True for NaN. A NaN tolerance compares False
        # against everything, and a negative one cannot bound any magnitude, so
        # either would make the cross-system independence check meaningless.
        if not tolerance >= 0.0:
            raise ValueError(
                "independence_tolerance must be a non-negative number, "
                f"got {independence_tolerance!r}"
            )
        # The cutoff and backend are copied once here. Later changes to the
        # wrapped potential's attributes are not reflected on this wrapper.
        super().__init__(
            cutoff=potential.cutoff,
            neighbor_backend=potential.neighbor_backend,
        )
        # Assigning an nn.Module attribute registers it as a submodule, so its
        # parameters and buffers are visible through this wrapper.
        self.potential = potential
        self.independence_tolerance = tolerance

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Delegate the flat model evaluation to the wrapped potential."""
        return self.potential(inputs)

    def provides_forces(self) -> bool:
        """Mirror force availability from the wrapped potential."""
        return self.potential.provides_forces()

    def compute_batch(
        self,
        atoms: Union[ase.Atoms, Sequence[ase.Atoms]],
        neighbor_list: Optional[
            Union[NeighborListData, Sequence[NeighborListData]]
        ] = None,
        backend: Optional[Union[str, NeighborListBackend]] = None,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        derive_forces: bool = False,
        neighbor_list_cache: Optional[GeometryNeighborListCache] = None,
        state: Optional[UFPInputState] = None,
        extract_state: bool = True,
    ) -> UFPBatchOutput:
        """
        Evaluate ASE structures and return batch-shaped outputs.

        The structures are converted with :meth:`prepare_input`, and the bundle
        is evaluated by :meth:`compute_batch_input`.

        Args:
            atoms: One ASE structure or a sequence of structures.
            neighbor_list: Optional precomputed neighbor list(s), forwarded to
                :meth:`prepare_input`.
            backend: Optional backend override, forwarded to
                :meth:`prepare_input`.
            device: Forwarded to :meth:`prepare_input`.
            dtype: Forwarded to :meth:`prepare_input`.
            derive_forces: Forwarded to :meth:`compute_batch_input`. It also
                controls whether the prepared inputs request gradients.
            neighbor_list_cache: Forwarded to :meth:`prepare_input`.
            state: Forwarded to :meth:`prepare_input`.
            extract_state: Forwarded to :meth:`prepare_input`.

        Returns:
            The :class:`UFPBatchOutput` produced by :meth:`compute_batch_input`.
        """
        # Only request input gradients when forces must be derived by autograd.
        needs_input_grad = derive_forces and not self.provides_forces()
        inputs = self.prepare_input(
            atoms,
            neighbor_list=neighbor_list,
            backend=backend,
            device=device,
            dtype=dtype,
            requires_grad=needs_input_grad,
            neighbor_list_cache=neighbor_list_cache,
            state=state,
            extract_state=extract_state,
        )
        return self.compute_batch_input(inputs, derive_forces=derive_forces)
# Public API of this module. ``GeometryNeighborListCache`` is re-exported from
# ``ufp.core._execution`` because ``compute``/``compute_batch`` accept one.
__all__ = ["BatchedUFPotential", "GeometryNeighborListCache", "UFPotential"]