"""
Element-wise one-body energy term.
Use this module when the model needs fixed or trainable reference energies per
atomic species in addition to geometric interaction terms.
"""
from __future__ import annotations
from collections.abc import Sequence
import torch
from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.terms._base import LinearAssemblyOptions, OneBodyTerm
from ufp.terms._parameters import (
ParameterBlock,
ParameterBlockCacheChannel,
ParameterBlockCacheDescriptor,
copy_parameter_data,
)
class ElementOneBodyTerm(OneBodyTerm):
    """
    Per-element one-body energy contribution.

    Each atom contributes a reference energy selected by its atomic number;
    per-system energies are the sum over that system's atoms. The coefficients
    may be trainable or fixed, and the reported forces are identically zero.
    """

    def __init__(
        self,
        *,
        atomic_types: Sequence[int],
        values=None,
        trainable: bool = True,
        fittable: bool = True,
        frozen: bool = False,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store one reference energy coefficient per declared atomic type.

        Args:
            atomic_types: Atomic numbers covered by this term. Duplicates are
                removed and the types are stored in ascending order.
            values: Optional initial coefficients (anything accepted by
                ``torch.as_tensor``). When ``atomic_types`` has no duplicates,
                ``values[i]`` belongs to ``atomic_types[i]`` and is reordered
                to match the sorted storage order. When ``atomic_types``
                contains duplicates, ``values`` is interpreted in sorted
                unique-type order. A scalar is accepted when exactly one type
                is declared. Defaults to zeros. The input is copied, never
                aliased.
            trainable: Whether the coefficients require gradients; ignored
                when ``frozen`` is true.
            fittable: Stored and reported on the parameter block.
            frozen: Stored, reported on the parameter block, and disables
                gradients.
            dtype: Optional dtype for the coefficients. When omitted and
                ``values`` is integer-valued, the default floating dtype is
                used so the parameter can require gradients.

        Raises:
            ValueError: If ``atomic_types`` is empty or ``values`` does not
                have one entry per unique atomic type.
        """
        declared_types = [int(z) for z in atomic_types]
        normalized_atomic_types = tuple(sorted(set(declared_types)))
        if not normalized_atomic_types:
            raise ValueError("`atomic_types` must contain at least one element")
        super().__init__(cutoff=None, atomic_types=normalized_atomic_types)
        n_types = len(normalized_atomic_types)
        if values is None:
            value_tensor = torch.zeros(n_types, dtype=dtype)
        else:
            # Copy so the parameter never shares storage with the caller's
            # tensor/array (in-place parameter writes would otherwise leak out).
            value_tensor = torch.as_tensor(values, dtype=dtype).detach().clone()
            if dtype is None and not (
                value_tensor.is_floating_point() or value_tensor.is_complex()
            ):
                # Integer inputs such as ``[0, 0]`` cannot require gradients.
                value_tensor = value_tensor.to(torch.get_default_dtype())
            if value_tensor.ndim == 0 and n_types == 1:
                value_tensor = value_tensor.reshape(1)
            if tuple(value_tensor.shape) != (n_types,):
                raise ValueError(
                    "`values` must have shape "
                    f"({n_types},), got {tuple(value_tensor.shape)}"
                )
            if len(declared_types) == n_types:
                # No duplicates: values follow the caller's declaration order,
                # but storage (and atomic_category_indices) uses sorted order.
                position = {z: i for i, z in enumerate(declared_types)}
                order = [position[z] for z in normalized_atomic_types]
                value_tensor = value_tensor[order]
        self.fittable = bool(fittable)
        self.frozen = bool(frozen)
        self.values = torch.nn.Parameter(
            value_tensor,
            requires_grad=bool(trainable) and not self.frozen,
        )

    @property
    def provides_forces(self) -> bool:
        """Report that this term provides the expected zero-force contribution."""
        return True

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return the element-reference coefficient block.

        The single ``"values"`` block reads ``self.values`` and writes into it
        via ``copy_parameter_data``. Its cache descriptor carries one ``"Z"``
        channel per stored atomic type, covering the slice
        ``[index, index + 1)`` of the coefficient vector.
        """
        assert self.atomic_types is not None
        return (
            ParameterBlock(
                name="values",
                kind="onebody",
                shape=tuple(int(dim) for dim in self.values.shape),
                read=lambda: self.values,
                write=lambda values: copy_parameter_data(self.values, values),
                label=f"onebody[{self.atomic_types}]",
                regularization_group="onebody",
                fittable=self.fittable,
                frozen=self.frozen,
                assembler="onebody",
                cache_descriptor=ParameterBlockCacheDescriptor(
                    family={"kind": "onebody"},
                    channels=tuple(
                        ParameterBlockCacheChannel(
                            kind="Z",
                            values=(atomic_number,),
                            start=index,
                            stop=index + 1,
                        )
                        for index, atomic_number in enumerate(self.atomic_types)
                    ),
                ),
            ),
        )

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Assemble one-body least-squares blocks for this term.

        Returns a dict mapping ``block.index`` to the matrix produced by
        ``_assemble_onebody_block`` for each block in ``options.blocks``.
        Blocks for which the helper returns ``None`` are omitted, and the
        result is empty when ``options`` is ``None``.
        """
        # Imported lazily, presumably to avoid an import cycle with
        # ufp.leastsquares -- NOTE(review): confirm.
        from ufp.leastsquares._assemble import _assemble_onebody_block

        blocks = () if options is None else options.blocks
        return {
            block.index: matrix
            for block in blocks
            if (matrix := _assemble_onebody_block(block, batch.inputs, targets))
            is not None
        }

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Assign learned reference energies and sum by system.

        Atoms for which ``inputs.atomic_category_indices`` returns a negative
        index (presumably types outside ``self.atomic_types``) receive zero
        energy. Forces are returned as zeros of shape ``(n_atoms, 3)``.
        """
        assert self.atomic_types is not None
        values = self.values.to(device=inputs.device, dtype=inputs.dtype)
        value_indices = inputs.atomic_category_indices(self.atomic_types)
        covered_atoms = value_indices >= 0
        per_atom_energy = torch.zeros(
            inputs.n_atoms,
            dtype=inputs.dtype,
            device=inputs.device,
        )
        per_atom_energy[covered_atoms] = values[value_indices[covered_atoms]]
        energy = torch.zeros(
            inputs.n_systems,
            dtype=inputs.dtype,
            device=inputs.device,
        )
        # Scatter-add each atom's energy into its owning system.
        energy.index_add_(0, inputs.system_index, per_atom_energy)
        return UFPOutput(
            energy=energy,
            forces=torch.zeros(
                (inputs.n_atoms, 3),
                dtype=inputs.dtype,
                device=inputs.device,
            ),
            per_atom_energy=per_atom_energy,
        )
# Public API: the element term plus a re-export of its one-body base class.
__all__ = ["ElementOneBodyTerm", "OneBodyTerm"]