"""
Composite UFP model built from additive term objects.
Use this module to combine one-body, pair, and three-body contributions while
keeping atomic-type checks and output summation centralized.
"""
from __future__ import annotations
from typing import Optional, Sequence
import torch
from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput, sum_outputs
from ufp.core.potential import UFPotential
from ufp.neighbors._neighbors import NeighborListBackend
from ufp.terms._base import OneBodyTerm, PairTerm, ThreeBodyTerm, UFPTerm
from ufp.terms.alchemical import AlchemicalCoefficients
def _max_cutoff(terms: Sequence[UFPTerm]) -> Optional[float]:
"""Return the maximum cutoff declared by the provided terms."""
cutoffs = [term.cutoff for term in terms if term.cutoff is not None]
if not cutoffs:
return None
return max(cutoffs)
def _union_atomic_types(terms: Sequence[UFPTerm]) -> Optional[tuple[int, ...]]:
"""Collect the sorted union of atomic types declared by the provided terms."""
atomic_types = sorted(
{
atomic_number
for term in terms
if term.atomic_types is not None
for atomic_number in term.atomic_types
}
)
if not atomic_types:
return None
return tuple(atomic_types)
def _collect_alchemical_coefficients(
terms: Sequence[UFPTerm],
) -> tuple[AlchemicalCoefficients, ...]:
"""Collect unique alchemical providers referenced by the provided terms."""
coefficient_modules: list[AlchemicalCoefficients] = []
seen: set[int] = set()
for term in terms:
coefficient_provider = getattr(term, "coefficient_provider", None)
if coefficient_provider is None:
continue
provider_id = id(coefficient_provider)
if provider_id in seen:
continue
seen.add(provider_id)
coefficient_modules.append(coefficient_provider)
return tuple(coefficient_modules)
class UFPModel(UFPotential):
    """
    Composite UFP model that sums the outputs of additive term objects.

    Terms can be supplied already sorted by kind (``onebody_terms``,
    ``pair_terms``, ``threebody_terms``) or as a mixed sequence (``terms``).
    Every term is stored in one of four ``torch.nn.ModuleList`` containers
    (``onebody_terms``, ``pair_terms``, ``threebody_terms``, ``other_terms``),
    and the order in which the terms were supplied is recorded separately so
    that :attr:`terms` can return them in that order.
    """

    def __init__(
        self,
        *,
        terms: Sequence[UFPTerm] | None = None,
        onebody_terms: Sequence[OneBodyTerm] = (),
        pair_terms: Sequence[PairTerm] = (),
        threebody_terms: Sequence[ThreeBodyTerm] = (),
        atomic_types: Optional[Sequence[int]] = None,
        neighbor_backend: str | NeighborListBackend = NeighborListBackend.AUTO,
    ) -> None:
        """Initialize the composite model from additive terms.

        Args:
            terms: Mixed sequence of terms; each entry must be a ``UFPTerm``.
                Entries are sorted into the one-body / pair / three-body /
                other containers by ``isinstance`` checks.
            onebody_terms: Terms that must be ``OneBodyTerm`` instances.
            pair_terms: Terms that must be ``PairTerm`` instances.
            threebody_terms: Terms that must be ``ThreeBodyTerm`` instances.
            atomic_types: Atomic numbers accepted by the model. When ``None``
                the sorted union of the terms' own ``atomic_types`` is used,
                which may itself be ``None`` if no term declares any.
            neighbor_backend: Forwarded unchanged to the base-class
                constructor.

        Raises:
            ValueError: If no term is supplied through any argument.
            TypeError: If an entry does not have the type required by the
                argument it was passed through.
        """
        provided_terms = () if terms is None else tuple(terms)
        onebody_terms = tuple(onebody_terms)
        pair_terms = tuple(pair_terms)
        threebody_terms = tuple(threebody_terms)
        # Evaluation order: mixed `terms` first, then the typed arguments.
        all_terms = provided_terms + onebody_terms + pair_terms + threebody_terms
        if not all_terms:
            raise ValueError(
                "`terms`, `onebody_terms`, `pair_terms`, and `threebody_terms` "
                "can not all be empty"
            )

        # Check each argument against the type it promises. The table is in
        # argument order, so the first offending argument is the one reported.
        type_requirements = (
            (
                provided_terms,
                UFPTerm,
                "all `terms` entries must be UFPTerm instances",
            ),
            (
                onebody_terms,
                OneBodyTerm,
                "all `onebody_terms` entries must be OneBodyTerm instances",
            ),
            (
                pair_terms,
                PairTerm,
                "all `pair_terms` entries must be PairTerm instances",
            ),
            (
                threebody_terms,
                ThreeBodyTerm,
                "all `threebody_terms` entries must be ThreeBodyTerm instances",
            ),
        )
        for candidates, required_type, message in type_requirements:
            if not all(isinstance(term, required_type) for term in candidates):
                raise TypeError(message)

        super().__init__(
            cutoff=_max_cutoff(all_terms),
            neighbor_backend=neighbor_backend,
        )

        # Sort every term into its container and remember, per term, which
        # container it went to and at which index. The first matching type
        # wins; anything that matches none of them lands in "other".
        grouped: dict[str, list[UFPTerm]] = {
            "onebody": [],
            "pair": [],
            "threebody": [],
            "other": [],
        }
        typed_groups = (
            ("onebody", OneBodyTerm),
            ("pair", PairTerm),
            ("threebody", ThreeBodyTerm),
        )
        order: list[tuple[str, int]] = []
        for term in all_terms:
            group_name = next(
                (
                    name
                    for name, term_type in typed_groups
                    if isinstance(term, term_type)
                ),
                "other",
            )
            bucket = grouped[group_name]
            order.append((group_name, len(bucket)))
            bucket.append(term)

        self.onebody_terms = torch.nn.ModuleList(grouped["onebody"])
        self.pair_terms = torch.nn.ModuleList(grouped["pair"])
        self.threebody_terms = torch.nn.ModuleList(grouped["threebody"])
        self.other_terms = torch.nn.ModuleList(grouped["other"])
        # Stored through object.__setattr__, i.e. bypassing any custom
        # __setattr__ defined on the class hierarchy.
        object.__setattr__(self, "_term_order", tuple(order))

        self.alchemical_coefficients = torch.nn.ModuleList(
            _collect_alchemical_coefficients(all_terms)
        )

        if atomic_types is None:
            self.atomic_types = _union_atomic_types(all_terms)
        else:
            # Normalize the explicit list: ints, de-duplicated, ascending.
            self.atomic_types = tuple(sorted({int(z) for z in atomic_types}))

        if self.atomic_types is None:
            atomic_types_tensor = torch.empty(0, dtype=torch.int64)
        else:
            atomic_types_tensor = torch.tensor(self.atomic_types, dtype=torch.int64)
        # Non-persistent: the buffer is left out of the module's state_dict.
        self.register_buffer(
            "_atomic_types_tensor",
            atomic_types_tensor,
            persistent=False,
        )

    @property
    def terms(self) -> tuple[UFPTerm, ...]:
        """Return all terms in evaluation order.

        The order is the one recorded at construction time: entries of
        ``terms`` first, then ``onebody_terms``, ``pair_terms`` and
        ``threebody_terms``.
        """
        containers = {
            "onebody": self.onebody_terms,
            "pair": self.pair_terms,
            "threebody": self.threebody_terms,
            "other": self.other_terms,
        }
        return tuple(
            containers[group_name][position]
            for group_name, position in self._term_order
        )

    def provides_forces(self) -> bool:
        """Return ``True`` only if ``provides_forces`` is truthy on every term."""
        # NOTE(review): `term.provides_forces` is read as an attribute here,
        # whereas this class exposes `provides_forces()` as a method. If a
        # term defines it as a method, the bound method is always truthy --
        # confirm against the UFPTerm definition.
        return all(term.provides_forces for term in self.terms)

    def _validate_input_atomic_types(self, inputs: UFPInput) -> None:
        """Reject inputs containing atomic numbers outside ``self.atomic_types``.

        Does nothing when the model declares no atomic types.

        Raises:
            RuntimeError: Listing, in ascending order, every distinct atomic
                number in ``inputs`` whose category index is negative.
        """
        if self.atomic_types is None:
            return
        # Entries with a negative category index are treated as unsupported.
        unsupported_mask = inputs.atomic_category_indices(self.atomic_types) < 0
        if not bool(torch.any(unsupported_mask)):
            return
        unsupported_types = torch.unique(inputs.atomic_numbers[unsupported_mask])
        unsupported_list = sorted(
            int(value) for value in unsupported_types.detach().cpu().tolist()
        )
        raise RuntimeError(
            "encountered unsupported atomic numbers: "
            + ", ".join(str(z) for z in unsupported_list)
        )

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Validate atomic types, then sum child term outputs into one model result.

        Every term is called with the same ``inputs``; the per-term outputs
        are combined by ``sum_outputs``.
        """
        self._validate_input_atomic_types(inputs)
        outputs = [term(inputs) for term in self.terms]
        return sum_outputs(outputs, inputs)
# Public API of this module; the helper functions above stay private.
__all__ = ["UFPModel"]