"""Dense three-body feature-cache containers and evaluators."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal
import numpy as np
import torch
from ufp.terms._threebody_features import ThreeBodyFeatureBlock, ThreeBodyFeatureCache
from ufp.terms._threebody_kernels import (
call_native_threebody_dense_feature_cache,
requires_native_threebody_backend,
should_use_native_threebody_dense_feature_cache,
)
from ufp.terms._threebody_runtime import ThreeBodyRuntimeConfig
@dataclass(frozen=True)
class _MemmapTensor:
    """Lazy handle for a single disk-backed dense feature ``.npy`` array."""

    # Filesystem location of the ``.npy`` file holding the array.
    path: str

    def to_tensor(self) -> torch.Tensor:
        """Memory-map the backing ``.npy`` file and expose it as a tensor.

        The file is mapped with ``mmap_mode="r+"`` (read/write), and
        ``torch.as_tensor`` wraps the NumPy memmap without copying where it
        can, so the returned tensor is backed by the file contents.
        """
        mapped_array = np.load(self.path, mmap_mode="r+")
        return torch.as_tensor(mapped_array)
@dataclass(frozen=True)
class DenseTripletFeatureBlock:
    """Coalesced dense output features for one triplet category.

    The feature arrays cover only the compact coefficient window that starts
    at ``coeff_start`` and spans ``coeff_shape`` within one triplet category.
    """

    # Index of the triplet category this block belongs to.
    triplet_index: int
    # Start offset of the compact coefficient window along each axis.
    coeff_start: tuple[int, int, int]
    # Extent of the compact coefficient window along each axis.
    coeff_shape: tuple[int, int, int]
    # Per-system energy features, shaped ``(n_systems, *coeff_shape)`` by the
    # builders in this module.
    energy: torch.Tensor
    # Optional per-atom energy features; rows are addressed through
    # ``per_atom_indices`` when that is set.
    per_atom_energy: torch.Tensor | None
    per_atom_indices: torch.Tensor | None
    # Force features shaped ``(n_rows, 3, *coeff_shape)``; rows are addressed
    # through ``force_atom_indices`` when that is set.
    forces: torch.Tensor
    force_atom_indices: torch.Tensor | None

    def __bool__(self) -> bool:
        """Report whether this block contains any nonzero feature entries."""
        return bool(
            torch.any(self.energy)
            or (self.per_atom_energy is not None and torch.any(self.per_atom_energy))
            or torch.any(self.forces)
        )

    def to_input_device(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "DenseTripletFeatureBlock":
        """Return this block with its arrays on ``device``.

        Feature arrays (energy, per-atom energy, forces) are converted to
        ``dtype``. Atom-index arrays are moved to ``device`` as ``torch.int64``
        so they can be used directly with ``index_add_``/``searchsorted``
        against int64 atom indices. Coefficient metadata is carried over.

        Args:
            device: Target device for every tensor in the block.
            dtype: Target floating dtype for the feature arrays.

        Returns:
            A new :class:`DenseTripletFeatureBlock`; ``None`` fields stay
            ``None``.
        """

        def _features(tensor: torch.Tensor | None) -> torch.Tensor | None:
            return None if tensor is None else tensor.to(device=device, dtype=dtype)

        def _indices(tensor: torch.Tensor | None) -> torch.Tensor | None:
            return (
                None
                if tensor is None
                else tensor.to(device=device, dtype=torch.int64)
            )

        return DenseTripletFeatureBlock(
            triplet_index=self.triplet_index,
            coeff_start=self.coeff_start,
            coeff_shape=self.coeff_shape,
            energy=self.energy.to(device=device, dtype=dtype),
            per_atom_energy=_features(self.per_atom_energy),
            per_atom_indices=_indices(self.per_atom_indices),
            forces=self.forces.to(device=device, dtype=dtype),
            force_atom_indices=_indices(self.force_atom_indices),
        )
@dataclass(frozen=True)
class DenseThreeBodyFeatureCache:
    """Coalesced dense output feature blocks for one fixed input."""

    # Dense blocks, one per cached triplet-category coefficient window.
    blocks: tuple[DenseTripletFeatureBlock, ...]

    def __bool__(self) -> bool:
        """Return ``True`` when at least one dense feature block is present."""
        return len(self.blocks) > 0

    def to_input_device(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "DenseThreeBodyFeatureCache":
        """Convert every block to ``device``/``dtype`` and wrap them anew.

        Delegates to each block's ``to_input_device`` and collects the results,
        in order, into a new :class:`DenseThreeBodyFeatureCache`.
        """
        converted = [
            block.to_input_device(device=device, dtype=dtype)
            for block in self.blocks
        ]
        return DenseThreeBodyFeatureCache(blocks=tuple(converted))
@dataclass(frozen=True)
class MemmapDenseTripletFeatureBlock:
    """Disk-backed dense output features for one triplet category.

    Has the same fields as :class:`DenseTripletFeatureBlock`, but every array
    is a ``_MemmapTensor`` handle to a ``.npy`` file instead of a tensor.
    """

    # Index of the triplet category this block belongs to.
    triplet_index: int
    # Start offset of the compact coefficient window along each axis.
    coeff_start: tuple[int, int, int]
    # Extent of the compact coefficient window along each axis.
    coeff_shape: tuple[int, int, int]
    energy: _MemmapTensor
    per_atom_energy: _MemmapTensor | None
    per_atom_indices: _MemmapTensor | None
    forces: _MemmapTensor
    force_atom_indices: _MemmapTensor | None

    def to_input_device(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> DenseTripletFeatureBlock:
        """Open the backing ``.npy`` files and return a tensor-backed block.

        Feature arrays are converted to ``dtype`` on ``device``. Atom-index
        arrays are moved to ``device`` as ``torch.int64``. When no conversion
        is needed, ``Tensor.to`` returns the memory-mapped tensor itself, so
        the result may still be backed by the file.

        Args:
            device: Target device for every tensor in the block.
            dtype: Target floating dtype for the feature arrays.

        Returns:
            A :class:`DenseTripletFeatureBlock` with the same metadata.
        """

        def _features(handle: _MemmapTensor | None) -> torch.Tensor | None:
            if handle is None:
                return None
            return handle.to_tensor().to(device=device, dtype=dtype)

        def _indices(handle: _MemmapTensor | None) -> torch.Tensor | None:
            if handle is None:
                return None
            return handle.to_tensor().to(device=device, dtype=torch.int64)

        return DenseTripletFeatureBlock(
            triplet_index=int(self.triplet_index),
            coeff_start=tuple(int(value) for value in self.coeff_start),
            coeff_shape=tuple(int(value) for value in self.coeff_shape),
            energy=self.energy.to_tensor().to(device=device, dtype=dtype),
            per_atom_energy=_features(self.per_atom_energy),
            per_atom_indices=_indices(self.per_atom_indices),
            forces=self.forces.to_tensor().to(device=device, dtype=dtype),
            force_atom_indices=_indices(self.force_atom_indices),
        )
@dataclass(frozen=True)
class MemmapDenseThreeBodyFeatureCache:
    """Disk-backed dense output feature blocks for one fixed input."""

    # Memory-mapped counterparts of the dense per-triplet blocks.
    blocks: tuple[MemmapDenseTripletFeatureBlock, ...]

    def __bool__(self) -> bool:
        """Return ``True`` when at least one disk-backed block is present."""
        return len(self.blocks) > 0

    def to_input_device(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> DenseThreeBodyFeatureCache:
        """Materialize the disk-backed blocks as an in-memory dense cache.

        Delegates to each block's ``to_input_device`` and collects the results,
        in order, into a :class:`DenseThreeBodyFeatureCache`.
        """
        loaded = [
            block.to_input_device(device=device, dtype=dtype)
            for block in self.blocks
        ]
        return DenseThreeBodyFeatureCache(blocks=tuple(loaded))
@dataclass(frozen=True)
class ThreeBodyDenseAtomFeatures:
    """Dense coefficient-space energy and force feature rows for chosen atoms."""

    # Atom indices the rows below refer to, in row order.
    atom_indices: torch.Tensor
    # Per-atom energy feature rows.
    atomic_energy: torch.Tensor
    # Per-atom force feature rows, one tensor per x/y/z component.
    force_x: torch.Tensor
    force_y: torch.Tensor
    force_z: torch.Tensor

    @property
    def forces(self) -> torch.Tensor:
        """Stack the x, y, z force rows along a new axis 1.

        For ``(n_atoms, n_features)`` component rows this gives a tensor of
        shape ``(n_atoms, 3, n_features)``.
        """
        components = [self.force_x, self.force_y, self.force_z]
        return torch.stack(components, dim=1)
def _selected_atom_indices(
    n_nodes: int,
    atom_indices: Sequence[int] | torch.Tensor | None,
    *,
    device: torch.device,
) -> torch.Tensor:
    """Resolve the atoms requested for diagnostic feature extraction.

    Args:
        n_nodes: Number of atoms in the input.
        atom_indices: Requested atom indices, or ``None`` for every atom.
        device: Device for the returned index tensor.

    Returns:
        A one-dimensional ``torch.int64`` tensor on ``device``; ``0..n_nodes-1``
        when ``atom_indices`` is ``None``.

    Raises:
        ValueError: If the indices are not one-dimensional or any index lies
            outside ``[0, n_nodes)``.
    """
    if atom_indices is None:
        return torch.arange(n_nodes, dtype=torch.int64, device=device)
    chosen = torch.as_tensor(atom_indices, dtype=torch.int64, device=device)
    if chosen.ndim != 1:
        raise ValueError("`atom_indices` must be one-dimensional")
    upper_bound = int(n_nodes)
    out_of_range = (chosen < 0) | (chosen >= upper_bound)
    if torch.any(out_of_range):
        raise ValueError("`atom_indices` contains an atom index outside the input")
    return chosen
def _compact_feature_indices(
    coeff_start: tuple[int, int, int],
    compact_shape: tuple[int, int, int],
    coeff_shape: tuple[int, int, int],
    *,
    triplet_index: int,
    device: torch.device,
) -> torch.Tensor:
    """Map every cell of a compact block to its flat full-feature column.

    Cells are enumerated in C order over ``compact_shape``; each is offset by
    ``coeff_start`` inside the full ``coeff_shape`` grid of triplet category
    ``triplet_index``, whose grids are laid out back to back.

    Returns:
        A one-dimensional int64 tensor with ``prod(compact_shape)`` entries.
    """
    offset_x, offset_y, offset_z = coeff_start
    extent_x, extent_y, extent_z = compact_shape
    full_y = int(coeff_shape[1])
    full_z = int(coeff_shape[2])
    axis_x = torch.arange(extent_x, dtype=torch.int64, device=device) + int(offset_x)
    axis_y = torch.arange(extent_y, dtype=torch.int64, device=device) + int(offset_y)
    axis_z = torch.arange(extent_z, dtype=torch.int64, device=device) + int(offset_z)
    # Broadcasting the three axes reproduces an "ij"-indexed meshgrid.
    flat_grid = (
        axis_x[:, None, None] * full_y + axis_y[None, :, None]
    ) * full_z + axis_z[None, None, :]
    grid_volume = int(coeff_shape[0] * coeff_shape[1] * coeff_shape[2])
    return flat_grid.reshape(-1) + int(triplet_index) * grid_volume
def _match_sorted_atom_rows(
    sorted_atom_ids: torch.Tensor,
    query_atom_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Find which query atoms occur in an ascending atom-id tensor.

    Args:
        sorted_atom_ids: Ascending atom ids addressing a block's rows.
        query_atom_ids: Atom ids to look up.

    Returns:
        ``(query_rows, cached_rows)`` such that
        ``sorted_atom_ids[cached_rows] == query_atom_ids[query_rows]``; both
        are empty when no query atom is present.
    """
    positions = torch.searchsorted(sorted_atom_ids, query_atom_ids)
    # searchsorted yields len(sorted_atom_ids) for ids beyond the last entry.
    in_bounds = torch.nonzero(
        positions < int(sorted_atom_ids.numel()),
        as_tuple=False,
    ).reshape(-1)
    cached_rows = positions.index_select(0, in_bounds)
    # An in-bounds position is only an insertion point; keep exact hits.
    hits = sorted_atom_ids.index_select(0, cached_rows) == query_atom_ids.index_select(
        0, in_bounds
    )
    return in_bounds[hits], cached_rows[hits]


def _dense_atom_features_from_feature_cache(
    feature_cache: DenseThreeBodyFeatureCache | MemmapDenseThreeBodyFeatureCache,
    atom_indices: torch.Tensor,
    *,
    n_triplet_categories: int,
    coeff_shape: tuple[int, int, int],
    dtype: torch.dtype,
) -> ThreeBodyDenseAtomFeatures:
    """Densify per-atom energy and force rows from compact feature blocks.

    Each compact block is scattered into its columns of the full
    ``n_triplet_categories * prod(coeff_shape)`` feature space. Atoms that a
    block does not cover keep zero features in that block's columns.

    Args:
        feature_cache: Dense cache, or disk-backed cache that is first
            materialized on ``atom_indices.device`` with ``dtype``.
        atom_indices: Atoms whose rows are extracted, in output row order.
        n_triplet_categories: Number of triplet categories in the full space.
        coeff_shape: Full per-category coefficient grid shape.
        dtype: Dtype used for a disk-backed or empty cache. A non-empty
            in-memory cache keeps the dtype and device of its first block's
            per-atom energy features.

    Returns:
        Dense per-atom energy and x/y/z force rows, with ``atom_indices``
        returned as int64 on the output device.

    Raises:
        ValueError: If any block lacks cached per-atom energy features.
    """
    if isinstance(feature_cache, MemmapDenseThreeBodyFeatureCache):
        feature_cache = feature_cache.to_input_device(
            device=atom_indices.device,
            dtype=dtype,
        )
    coeff_volume = int(coeff_shape[0] * coeff_shape[1] * coeff_shape[2])
    n_features = int(n_triplet_categories * coeff_volume)
    blocks = feature_cache.blocks
    if blocks:
        first_per_atom_energy = blocks[0].per_atom_energy
        if first_per_atom_energy is None:
            raise ValueError(
                "dense atom features require cached per-atom energy features"
            )
        dtype = first_per_atom_energy.dtype
        device = first_per_atom_energy.device
    else:
        device = atom_indices.device
    atom_indices = atom_indices.to(device=device, dtype=torch.int64)
    n_rows = int(atom_indices.numel())
    atomic_energy = torch.zeros((n_rows, n_features), dtype=dtype, device=device)
    dense_forces = torch.zeros((n_rows, 3, n_features), dtype=dtype, device=device)
    components = torch.arange(3, dtype=torch.int64, device=device)
    for block in blocks:
        if block.per_atom_energy is None:
            raise ValueError(
                "dense atom features require cached per-atom energy features"
            )
        feature_indices = _compact_feature_indices(
            block.coeff_start,
            block.coeff_shape,
            coeff_shape,
            triplet_index=block.triplet_index,
            device=device,
        )
        # Coefficient axes are flattened explicitly (not reshape(n, -1)):
        # reshape cannot infer -1 for zero selected rows and would raise.
        if block.per_atom_indices is None:
            atomic_energy[:, feature_indices] = block.per_atom_energy.index_select(
                0, atom_indices
            ).flatten(start_dim=1)
        else:
            query_rows, cached_rows = _match_sorted_atom_rows(
                block.per_atom_indices, atom_indices
            )
            if query_rows.numel():
                atomic_energy[query_rows[:, None], feature_indices] = (
                    block.per_atom_energy.index_select(0, cached_rows).flatten(
                        start_dim=1
                    )
                )
        if block.force_atom_indices is None:
            dense_forces[:, :, feature_indices] = block.forces.index_select(
                0, atom_indices
            ).flatten(start_dim=2)
        else:
            query_rows, cached_rows = _match_sorted_atom_rows(
                block.force_atom_indices, atom_indices
            )
            if query_rows.numel():
                dense_forces[
                    query_rows[:, None, None],
                    components[None, :, None],
                    feature_indices[None, None, :],
                ] = block.forces.index_select(0, cached_rows).flatten(start_dim=2)
    return ThreeBodyDenseAtomFeatures(
        atom_indices=atom_indices,
        atomic_energy=atomic_energy,
        force_x=dense_forces[:, 0, :],
        force_y=dense_forces[:, 1, :],
        force_z=dense_forces[:, 2, :],
    )
def _symmetrize_dense_feature_rows(
    rows: torch.Tensor,
    same_neighbor_triplet_mask: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
) -> torch.Tensor:
    """Tie x/y coefficient pairs for triplet categories with identical neighbors.

    For every category flagged in ``same_neighbor_triplet_mask``, its
    coefficient grid in each row is replaced by the mean of the grid and its
    transpose over the first two coefficient axes. Other categories are left
    untouched.

    Returns:
        ``rows`` itself when it is empty or no category is flagged; otherwise a
        new tensor shaped like ``rows``.
    """
    if rows.numel() == 0:
        return rows
    tied = same_neighbor_triplet_mask.to(device=rows.device, dtype=torch.bool)
    if tied.numel() == 0 or not bool(tied.any()):
        return rows
    # (n_rows, n_categories, *coeff_shape); clone so the input is not mutated.
    grid = rows.reshape(rows.shape[0], tied.numel(), *coeff_shape).clone()
    flagged = grid[:, tied]
    grid[:, tied] = 0.5 * (flagged + flagged.transpose(2, 3))
    return grid.reshape_as(rows)
def _symmetrize_dense_atom_features(
    features: ThreeBodyDenseAtomFeatures,
    same_neighbor_triplet_mask: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
) -> ThreeBodyDenseAtomFeatures:
    """Apply same-neighbor coefficient tying to every dense feature row set.

    ``atom_indices`` is carried over unchanged. The energy rows and the three
    force-component row sets are each passed through
    ``_symmetrize_dense_feature_rows`` with the same mask and grid shape.
    """

    def tie(feature_rows: torch.Tensor) -> torch.Tensor:
        return _symmetrize_dense_feature_rows(
            feature_rows,
            same_neighbor_triplet_mask,
            coeff_shape=coeff_shape,
        )

    return ThreeBodyDenseAtomFeatures(
        atom_indices=features.atom_indices,
        atomic_energy=tie(features.atomic_energy),
        force_x=tie(features.force_x),
        force_y=tie(features.force_y),
        force_z=tie(features.force_z),
    )
def _compact_bounds_from_indices(
    indices: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
) -> tuple[tuple[int, int, int], tuple[int, int, int], torch.Tensor]:
    """Find the smallest coefficient box that encloses a set of flat indices.

    Args:
        indices: Flat C-order indices into a grid of shape ``coeff_shape``.
        coeff_shape: Full coefficient grid shape.

    Returns:
        ``(start, compact_shape, local_indices)``: the box's lower corner and
        extent per axis, and each index re-expressed as a flat C-order index
        inside that box.
    """
    _, ny, nz = coeff_shape
    ix = torch.div(indices, ny * nz, rounding_mode="floor")
    remainder = indices - ix * ny * nz
    iy = torch.div(remainder, nz, rounding_mode="floor")
    iz = remainder - iy * nz
    coordinates = (ix, iy, iz)
    start = tuple(int(axis.min().item()) for axis in coordinates)
    stop = tuple(int(axis.max().item()) + 1 for axis in coordinates)
    compact_shape = tuple(upper - lower for lower, upper in zip(start, stop))
    _, box_y, box_z = compact_shape
    local_indices = (
        (ix - start[0]) * box_y * box_z
        + (iy - start[1]) * box_z
        + (iz - start[2])
    )
    return start, compact_shape, local_indices
def _local_indices_from_compact_bounds(
    indices: torch.Tensor,
    *,
    coeff_start: tuple[int, int, int],
    compact_shape: tuple[int, int, int],
    coeff_shape: tuple[int, int, int],
) -> torch.Tensor:
    """Re-express flat full-grid coefficient indices inside a compact block.

    ``indices`` are flat C-order indices into ``coeff_shape``. The result holds
    flat C-order indices into the box that starts at ``coeff_start`` and has
    extent ``compact_shape``.
    """
    full_y, full_z = coeff_shape[1], coeff_shape[2]
    layer = torch.div(indices, full_y * full_z, rounding_mode="floor")
    within_layer = indices - layer * full_y * full_z
    row = torch.div(within_layer, full_z, rounding_mode="floor")
    column = within_layer - row * full_z
    box_y, box_z = compact_shape[1], compact_shape[2]
    shifted_x = layer - coeff_start[0]
    shifted_y = row - coeff_start[1]
    shifted_z = column - coeff_start[2]
    return shifted_x * box_y * box_z + shifted_y * box_z + shifted_z
def _index_add_dense_features(
    target: torch.Tensor,
    rows: torch.Tensor,
    cols: torch.Tensor,
    values: torch.Tensor,
) -> None:
    """Accumulate ``values`` into ``target[rows, cols]`` in place.

    ``rows``, ``cols`` and ``values`` are broadcast against each other first.
    Repeated ``(row, col)`` pairs are summed.

    Args:
        target: Two-dimensional row-feature matrix updated in place.
        rows: Integer row indices into ``target``.
        cols: Integer column indices into ``target``.
        values: Values to add; must match ``target``'s dtype.
    """
    rows, cols, values = torch.broadcast_tensors(rows, cols, values)
    flat_rows = rows.reshape(-1)
    flat_cols = cols.reshape(-1)
    flat_values = values.reshape(-1)
    if target.is_contiguous():
        # A contiguous matrix has a flat view, so one index_add_ over the
        # linearized (row, col) offsets writes straight into its storage.
        flat_indices = flat_rows * target.shape[1] + flat_cols
        target.view(-1).index_add_(0, flat_indices, flat_values)
    else:
        # reshape(-1) would silently copy a strided target and the additions
        # would be lost; index_put_ accumulates through the strided view.
        target.index_put_((flat_rows, flat_cols), flat_values, accumulate=True)
def _feature_cache_sparse_tensors(
    feature_cache: ThreeBodyFeatureCache,
) -> tuple[torch.Tensor, ...]:
    """Gather the sparse per-triplet tensors of a feature cache in fixed order.

    The order is ``src_ids, dst_j, dst_k, triplet_index, stencil_indices,
    values, grad_x, grad_y, grad_z, unit_x, unit_y, unit_z``. A single-block
    cache returns that block's own tensors; otherwise each field is
    concatenated across blocks along dimension 0.

    Raises:
        ValueError: If the cache has no blocks.
    """
    cache_blocks = feature_cache.blocks
    if not cache_blocks:
        raise ValueError("cannot concatenate an empty three-body feature cache")
    field_names = (
        "src_ids",
        "dst_j",
        "dst_k",
        "triplet_index",
        "stencil_indices",
        "values",
        "grad_x",
        "grad_y",
        "grad_z",
        "unit_x",
        "unit_y",
        "unit_z",
    )
    if len(cache_blocks) == 1:
        only_block = cache_blocks[0]
        return tuple(getattr(only_block, name) for name in field_names)
    return tuple(
        torch.cat([getattr(block, name) for block in cache_blocks], dim=0)
        for name in field_names
    )
def _dense_feature_cache_from_native_tensors(
    tensors: tuple[torch.Tensor, ...],
    *,
    n_systems: int,
) -> DenseThreeBodyFeatureCache:
    """Wrap native dense-cache tensors in Python dense feature blocks.

    The native operator returns 13 tensors:
    - per-block metadata (triplet index, coefficient start, compact shape);
    - five offset arrays (``*_ptr``) with one more entry than there are blocks,
      so block ``i`` owns ``[ptr[i], ptr[i + 1])`` of the matching flat buffer;
    - the flat energy, per-atom energy, force, per-atom-index and
      force-atom-index buffers.

    Args:
        tensors: The 13 tensors produced by the native operator.
        n_systems: Number of systems, used to shape the energy features.

    Returns:
        Dense feature cache with one block per native block.

    Raises:
        RuntimeError: If the tensor count is wrong or the per-block metadata
            is too short for the number of blocks.
    """
    if len(tensors) != 13:
        raise RuntimeError(
            "native dense three-body feature-cache operator returned an unexpected "
            f"number of tensors: {len(tensors)}"
        )
    (
        triplet_indices,
        coeff_start,
        compact_shape,
        energy_ptr,
        per_atom_ptr,
        force_ptr,
        per_atom_index_ptr,
        force_atom_index_ptr,
        energy_values,
        per_atom_values,
        force_values,
        per_atom_indices,
        force_atom_indices,
    ) = tensors

    def host_ints(tensor: torch.Tensor) -> list[int]:
        return [int(value) for value in tensor.reshape(-1).tolist()]

    # Copy the small per-block metadata to the host once: indexing device
    # tensors element by element would force a synchronization per read.
    block_triplets = host_ints(triplet_indices)
    n_blocks = len(block_triplets)
    block_starts = [tuple(int(v) for v in row) for row in coeff_start.tolist()]
    block_shapes = [tuple(int(v) for v in row) for row in compact_shape.tolist()]
    energy_offsets = host_ints(energy_ptr)
    per_atom_offsets = host_ints(per_atom_ptr)
    force_offsets = host_ints(force_ptr)
    per_atom_index_offsets = host_ints(per_atom_index_ptr)
    force_index_offsets = host_ints(force_atom_index_ptr)
    all_offsets = (
        energy_offsets,
        per_atom_offsets,
        force_offsets,
        per_atom_index_offsets,
        force_index_offsets,
    )
    if n_blocks and (
        len(block_starts) < n_blocks
        or len(block_shapes) < n_blocks
        or any(len(offsets) < n_blocks + 1 for offsets in all_offsets)
    ):
        raise RuntimeError(
            "native dense three-body feature-cache operator returned "
            "inconsistent per-block metadata"
        )
    blocks: list[DenseTripletFeatureBlock] = []
    for block_index in range(n_blocks):
        shape = block_shapes[block_index]
        next_index = block_index + 1
        block_per_atom_indices = per_atom_indices[
            per_atom_index_offsets[block_index] : per_atom_index_offsets[next_index]
        ]
        block_force_atom_indices = force_atom_indices[
            force_index_offsets[block_index] : force_index_offsets[next_index]
        ]
        per_atom_start = per_atom_offsets[block_index]
        per_atom_stop = per_atom_offsets[next_index]
        per_atom_energy = None
        if per_atom_stop > per_atom_start:
            per_atom_energy = per_atom_values[per_atom_start:per_atom_stop].reshape(
                int(block_per_atom_indices.numel()),
                *shape,
            )
        blocks.append(
            DenseTripletFeatureBlock(
                triplet_index=block_triplets[block_index],
                coeff_start=block_starts[block_index],
                coeff_shape=shape,
                energy=energy_values[
                    energy_offsets[block_index] : energy_offsets[next_index]
                ].reshape(n_systems, *shape),
                per_atom_energy=per_atom_energy,
                per_atom_indices=(
                    block_per_atom_indices if per_atom_energy is not None else None
                ),
                forces=force_values[
                    force_offsets[block_index] : force_offsets[next_index]
                ].reshape(int(block_force_atom_indices.numel()), 3, *shape),
                force_atom_indices=block_force_atom_indices,
            )
        )
    return DenseThreeBodyFeatureCache(blocks=tuple(blocks))
def _dense_row_lookup(
    row_atom_ids: torch.Tensor,
    n_nodes: int,
    *,
    device: torch.device,
) -> torch.Tensor:
    """Return an ``(n_nodes,)`` table mapping atom id -> dense row, ``-1`` if absent."""
    lookup = torch.full((n_nodes,), -1, dtype=torch.int64, device=device)
    lookup[row_atom_ids] = torch.arange(
        row_atom_ids.numel(),
        dtype=torch.int64,
        device=device,
    )
    return lookup


def _force_feature_rows(
    force_atom_lookup: torch.Tensor,
    atom_ids: torch.Tensor,
    components: torch.Tensor,
) -> torch.Tensor:
    """Return flattened force rows ``3 * dense_row + component``, shape ``(n, 1, len(components))``."""
    return (
        force_atom_lookup.index_select(0, atom_ids)[:, None, None] * 3
        + components[None, None, :]
    )


def _triplet_force_features(
    block: ThreeBodyFeatureBlock,
    mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the force feature contributions on the source, j and k atoms.

    Each gradient/unit-vector product is formed once and reused. The three
    returned tensors broadcast to ``(n_masked, n_stencil, n_components)`` and
    algebraically sum to zero for every triplet.
    """
    grad_unit_x = block.grad_x[mask, :, None] * block.unit_x[mask, None, :]
    grad_unit_y = block.grad_y[mask, :, None] * block.unit_y[mask, None, :]
    grad_unit_z = block.grad_z[mask, :, None] * block.unit_z[mask, None, :]
    force_src = grad_unit_x + grad_unit_y
    force_j = -(grad_unit_x + grad_unit_z)
    force_k = -grad_unit_y + grad_unit_z
    return force_src, force_j, force_k


def _build_dense_feature_cache_from_feature_cache_torch(
    feature_cache: ThreeBodyFeatureCache,
    system_index: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    force_scope: Literal["output", "source"] = "output",
    include_per_atom_energy: bool = True,
) -> DenseThreeBodyFeatureCache:
    """Accumulate triplet features into dense per-category feature blocks.

    Pure-torch implementation. For every triplet category present in
    ``feature_cache`` one :class:`DenseTripletFeatureBlock` is built over the
    smallest coefficient box its stencil indices touch. Each block holds
    per-system energy rows, optional per-source-atom energy rows, and
    per-atom force rows.

    Args:
        feature_cache: Sparse per-triplet features to coalesce.
        system_index: System id of every node; ``n_nodes`` is its length.
        coeff_shape: Full per-category coefficient grid shape.
        force_scope: ``"output"`` accumulates forces on the source and both
            neighbor atoms; ``"source"`` only on the source atom.
        include_per_atom_energy: Whether to build per-atom energy rows.

    Returns:
        Dense feature cache with one block per triplet category (empty when
        ``feature_cache`` has no blocks).

    Raises:
        ValueError: If ``force_scope`` is not ``"output"`` or ``"source"``.
    """
    if force_scope not in {"output", "source"}:
        raise ValueError("`force_scope` must be 'output' or 'source'")
    if not feature_cache.blocks:
        return DenseThreeBodyFeatureCache(blocks=())
    device = feature_cache.blocks[0].values.device
    dtype = feature_cache.blocks[0].values.dtype
    n_nodes = int(system_index.numel())
    n_systems = int(system_index.max().item()) + 1 if n_nodes else 0
    components = torch.arange(3, dtype=torch.int64, device=device)
    system_index = system_index.to(device=device, dtype=torch.int64)
    include_neighbor_forces = force_scope == "output"
    triplet_indices = sorted(
        {
            int(index.item())
            for block in feature_cache.blocks
            for index in torch.unique(block.triplet_index).cpu()
        }
    )
    dense_blocks: list[DenseTripletFeatureBlock] = []
    for triplet_index in triplet_indices:
        selected_blocks: list[tuple[ThreeBodyFeatureBlock, torch.Tensor]] = []
        stencil_chunks: list[torch.Tensor] = []
        source_chunks: list[torch.Tensor] = []
        force_atom_chunks: list[torch.Tensor] = []
        for block in feature_cache.blocks:
            mask = block.triplet_index == triplet_index
            if not torch.any(mask):
                continue
            selected_blocks.append((block, mask))
            stencil_chunks.append(block.stencil_indices[mask].reshape(-1))
            src_ids = block.src_ids[mask]
            source_chunks.append(src_ids)
            force_atom_chunks.append(src_ids)
            if include_neighbor_forces:
                force_atom_chunks.append(block.dst_j[mask])
                force_atom_chunks.append(block.dst_k[mask])
        if not selected_blocks:
            continue
        coeff_start, compact_shape, _ = _compact_bounds_from_indices(
            torch.cat(stencil_chunks, dim=0),
            coeff_shape=coeff_shape,
        )
        compact_volume = int(compact_shape[0] * compact_shape[1] * compact_shape[2])
        force_atom_indices = torch.unique(torch.cat(force_atom_chunks, dim=0))
        force_atom_lookup = _dense_row_lookup(force_atom_indices, n_nodes, device=device)
        per_atom_indices = None
        per_atom_lookup = None
        if include_per_atom_energy:
            per_atom_indices = torch.unique(torch.cat(source_chunks, dim=0))
            per_atom_lookup = _dense_row_lookup(per_atom_indices, n_nodes, device=device)
        energy = torch.zeros((n_systems, compact_volume), dtype=dtype, device=device)
        per_atom_energy = (
            None
            if per_atom_indices is None
            else torch.zeros(
                (int(per_atom_indices.numel()), compact_volume),
                dtype=dtype,
                device=device,
            )
        )
        # Force rows are flattened as (atom_row * 3 + component).
        forces = torch.zeros(
            (int(force_atom_indices.numel()) * 3, compact_volume),
            dtype=dtype,
            device=device,
        )
        for block, mask in selected_blocks:
            local_indices = _local_indices_from_compact_bounds(
                block.stencil_indices[mask],
                coeff_start=coeff_start,
                compact_shape=compact_shape,
                coeff_shape=coeff_shape,
            )
            src_ids = block.src_ids[mask]
            values = block.values[mask]
            _index_add_dense_features(
                energy,
                system_index.index_select(0, src_ids)[:, None],
                local_indices,
                values,
            )
            if per_atom_energy is not None:
                assert per_atom_lookup is not None
                _index_add_dense_features(
                    per_atom_energy,
                    per_atom_lookup.index_select(0, src_ids)[:, None],
                    local_indices,
                    values,
                )
            force_src, force_j, force_k = _triplet_force_features(block, mask)
            force_cols = local_indices[:, :, None]
            _index_add_dense_features(
                forces,
                _force_feature_rows(force_atom_lookup, src_ids, components),
                force_cols,
                force_src,
            )
            if include_neighbor_forces:
                _index_add_dense_features(
                    forces,
                    _force_feature_rows(force_atom_lookup, block.dst_j[mask], components),
                    force_cols,
                    force_j,
                )
                _index_add_dense_features(
                    forces,
                    _force_feature_rows(force_atom_lookup, block.dst_k[mask], components),
                    force_cols,
                    force_k,
                )
        dense_blocks.append(
            DenseTripletFeatureBlock(
                triplet_index=triplet_index,
                coeff_start=coeff_start,
                coeff_shape=compact_shape,
                energy=energy.reshape(n_systems, *compact_shape),
                per_atom_energy=(
                    None
                    if per_atom_energy is None
                    else per_atom_energy.reshape(
                        int(per_atom_indices.numel()),
                        *compact_shape,
                    )
                ),
                per_atom_indices=per_atom_indices,
                forces=forces.reshape(
                    int(force_atom_indices.numel()),
                    3,
                    *compact_shape,
                ),
                force_atom_indices=force_atom_indices,
            )
        )
    return DenseThreeBodyFeatureCache(blocks=tuple(dense_blocks))
def _build_dense_feature_cache_from_feature_cache(
    feature_cache: ThreeBodyFeatureCache,
    system_index: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    force_scope: Literal["output", "source"] = "output",
    include_per_atom_energy: bool = True,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> DenseThreeBodyFeatureCache:
    """Build dense per-category feature blocks, preferring the native backend.

    The native operator is attempted when
    ``should_use_native_threebody_dense_feature_cache`` approves it for this
    device, dtype, gradient requirement and force scope. A ``RuntimeError``
    from that path re-raises when ``requires_native_threebody_backend`` says
    the native backend is mandatory. Otherwise the pure-torch builder is
    used, as it is whenever the native path is not selected.

    Raises:
        ValueError: If ``force_scope`` is not ``"output"`` or ``"source"``.
    """
    if force_scope not in {"output", "source"}:
        raise ValueError("`force_scope` must be 'output' or 'source'")
    cache_blocks = feature_cache.blocks
    if not cache_blocks:
        return DenseThreeBodyFeatureCache(blocks=())
    reference_values = cache_blocks[0].values
    differentiable_fields = (
        "values",
        "grad_x",
        "grad_y",
        "grad_z",
        "unit_x",
        "unit_y",
        "unit_z",
    )
    requires_grad = any(
        getattr(block, field).requires_grad
        for block in cache_blocks
        for field in differentiable_fields
    )
    native_allowed = should_use_native_threebody_dense_feature_cache(
        device=reference_values.device,
        dtype=reference_values.dtype,
        requires_grad=requires_grad,
        force_scope=force_scope,
        runtime_config=runtime_config,
    )
    if native_allowed:
        try:
            native_tensors = call_native_threebody_dense_feature_cache(
                _feature_cache_sparse_tensors(feature_cache),
                system_index,
                coeff_shape=coeff_shape,
                force_scope=force_scope,
                include_per_atom_energy=include_per_atom_energy,
            )
            node_count = int(system_index.numel())
            system_count = int(system_index.max().item()) + 1 if node_count else 0
            return _dense_feature_cache_from_native_tensors(
                native_tensors,
                n_systems=system_count,
            )
        except RuntimeError:
            if requires_native_threebody_backend(runtime_config):
                raise
    return _build_dense_feature_cache_from_feature_cache_torch(
        feature_cache,
        system_index,
        coeff_shape=coeff_shape,
        force_scope=force_scope,
        include_per_atom_energy=include_per_atom_energy,
    )
def _coefficient_window(
    coeffs_by_triplet: torch.Tensor,
    block: DenseTripletFeatureBlock,
) -> torch.Tensor:
    """Slice the coefficient sub-grid that one dense feature block addresses.

    Returns a view of ``coeffs_by_triplet[block.triplet_index]`` restricted to
    ``[coeff_start, coeff_start + coeff_shape)`` along each of the three
    coefficient axes.
    """
    start_x, start_y, start_z = block.coeff_start
    size_x, size_y, size_z = block.coeff_shape
    window = (
        block.triplet_index,
        slice(start_x, start_x + size_x),
        slice(start_y, start_y + size_y),
        slice(start_z, start_z + size_z),
    )
    return coeffs_by_triplet[window]
def _evaluate_dense_feature_cache_energy_forces(
    feature_cache: DenseThreeBodyFeatureCache | MemmapDenseThreeBodyFeatureCache,
    coeffs_by_triplet: torch.Tensor,
    *,
    n_nodes: int,
    n_systems: int,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """Contract dense feature blocks with the current coefficients.

    The cache is first moved to the device and dtype of ``coeffs_by_triplet``
    when it is disk-backed, or when its first block's energy features differ
    in device or dtype.

    Args:
        feature_cache: Dense or disk-backed dense feature cache.
        coeffs_by_triplet: Coefficients indexed by triplet category then the
            three coefficient axes.
        n_nodes: Number of atoms (rows of the force output).
        n_systems: Number of systems (length of the energy output).

    Returns:
        ``(energy, per_atom_energy, forces)``: an ``(n_systems,)`` energy
        vector, the summed per-atom energies (``None`` when no block carries
        per-atom features), and an ``(n_nodes, 3)`` force array.
    """
    target_device = coeffs_by_triplet.device
    target_dtype = coeffs_by_triplet.dtype
    if isinstance(feature_cache, MemmapDenseThreeBodyFeatureCache):
        must_convert = True
    elif feature_cache.blocks:
        leading_energy = feature_cache.blocks[0].energy
        must_convert = (
            leading_energy.device != target_device
            or leading_energy.dtype != target_dtype
        )
    else:
        must_convert = False
    if must_convert:
        feature_cache = feature_cache.to_input_device(
            device=target_device,
            dtype=target_dtype,
        )

    total_energy = coeffs_by_triplet.new_zeros((n_systems,))
    atom_energy: torch.Tensor | None = None
    total_forces = coeffs_by_triplet.new_zeros((n_nodes, 3))
    for dense_block in feature_cache.blocks:
        window = _coefficient_window(coeffs_by_triplet, dense_block)
        total_energy = total_energy + torch.tensordot(
            dense_block.energy, window, dims=3
        )
        if dense_block.per_atom_energy is not None:
            contribution = torch.tensordot(
                dense_block.per_atom_energy,
                window,
                dims=3,
            )
            if dense_block.per_atom_indices is not None:
                # Scatter the compact per-atom rows back onto all n_nodes atoms.
                scattered = coeffs_by_triplet.new_zeros((n_nodes,))
                scattered.index_add_(0, dense_block.per_atom_indices, contribution)
                contribution = scattered
            if atom_energy is None:
                atom_energy = contribution
            else:
                atom_energy = atom_energy + contribution
        force_contribution = torch.tensordot(dense_block.forces, window, dims=3)
        if dense_block.force_atom_indices is None:
            total_forces = total_forces + force_contribution
        else:
            total_forces.index_add_(
                0, dense_block.force_atom_indices, force_contribution
            )
    return total_energy, atom_energy, total_forces