# Source code for ufp.terms._threebody_dense

"""Dense three-body feature-cache containers and evaluators."""

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal

import numpy as np
import torch

from ufp.terms._threebody_features import ThreeBodyFeatureBlock, ThreeBodyFeatureCache
from ufp.terms._threebody_kernels import (
    call_native_threebody_dense_feature_cache,
    requires_native_threebody_backend,
    should_use_native_threebody_dense_feature_cache,
)
from ufp.terms._threebody_runtime import ThreeBodyRuntimeConfig


@dataclass(frozen=True)
class _MemmapTensor:
    """Descriptor for one disk-backed dense feature array.

    Only the path of the ``.npy`` file is stored; the array itself is mapped
    lazily each time :meth:`to_tensor` is called.
    """

    # Filesystem path of the ``.npy`` file holding the array.
    path: str

    def to_tensor(self) -> torch.Tensor:
        """Open the ``.npy`` file as a memory-mapped torch tensor.

        Returns:
            A tensor sharing memory with a copy-on-write mapping of the file.
        """
        # Copy-on-write ("c") gives torch the writable buffer it expects
        # without requiring write permission on the cache file, and keeps
        # any accidental in-place tensor op from modifying the data on disk.
        array = np.load(self.path, mmap_mode="c")
        return torch.as_tensor(array)


@dataclass(frozen=True)
class DenseTripletFeatureBlock:
    """Dense feature tensors gathered for a single triplet category."""

    triplet_index: int
    coeff_start: tuple[int, int, int]
    coeff_shape: tuple[int, int, int]
    energy: torch.Tensor
    per_atom_energy: torch.Tensor | None
    per_atom_indices: torch.Tensor | None
    forces: torch.Tensor
    force_atom_indices: torch.Tensor | None

    def __bool__(self) -> bool:
        """Tell whether any energy, per-atom energy or force entry is nonzero."""
        if bool(torch.any(self.energy)):
            return True
        if self.per_atom_energy is not None and bool(
            torch.any(self.per_atom_energy)
        ):
            return True
        return bool(torch.any(self.forces))

    def to_input_device(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "DenseTripletFeatureBlock":
        """Build a copy of this block living on ``device``.

        Feature tensors are converted to ``dtype``; index tensors are always
        converted to ``torch.int64``. Optional tensors that are ``None`` stay
        ``None``.
        """

        def move_values(tensor: torch.Tensor | None) -> torch.Tensor | None:
            # Floating-point feature payloads follow the requested dtype.
            if tensor is None:
                return None
            return tensor.to(device=device, dtype=dtype, non_blocking=True)

        def move_indices(tensor: torch.Tensor | None) -> torch.Tensor | None:
            # Atom index tensors are kept as 64-bit integers.
            if tensor is None:
                return None
            return tensor.to(device=device, dtype=torch.int64, non_blocking=True)

        return DenseTripletFeatureBlock(
            triplet_index=self.triplet_index,
            coeff_start=self.coeff_start,
            coeff_shape=self.coeff_shape,
            energy=move_values(self.energy),
            per_atom_energy=move_values(self.per_atom_energy),
            per_atom_indices=move_indices(self.per_atom_indices),
            forces=move_values(self.forces),
            force_atom_indices=move_indices(self.force_atom_indices),
        )
@dataclass(frozen=True)
class DenseThreeBodyFeatureCache:
    """Collection of dense triplet feature blocks for one fixed input."""

    blocks: tuple[DenseTripletFeatureBlock, ...]

    def __bool__(self) -> bool:
        """Tell whether at least one dense feature block is stored."""
        return len(self.blocks) > 0

    def to_input_device(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "DenseThreeBodyFeatureCache":
        """Build a copy of this cache with every block moved to ``device``/``dtype``."""
        moved_blocks = []
        for dense_block in self.blocks:
            moved_blocks.append(
                dense_block.to_input_device(device=device, dtype=dtype)
            )
        return DenseThreeBodyFeatureCache(blocks=tuple(moved_blocks))
@dataclass(frozen=True)
class MemmapDenseTripletFeatureBlock:
    """File-backed counterpart of a dense block for one triplet category."""

    triplet_index: int
    coeff_start: tuple[int, int, int]
    coeff_shape: tuple[int, int, int]
    energy: _MemmapTensor
    per_atom_energy: _MemmapTensor | None
    per_atom_indices: _MemmapTensor | None
    forces: _MemmapTensor
    force_atom_indices: _MemmapTensor | None

    def to_input_device(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> DenseTripletFeatureBlock:
        """Load every stored array and return an in-memory block on ``device``.

        Feature arrays are converted to ``dtype``; index arrays are converted
        to ``torch.int64``. Optional descriptors that are ``None`` stay ``None``.
        """

        def load_values(descriptor: _MemmapTensor | None) -> torch.Tensor | None:
            # Map the file, then copy/convert it onto the requested device.
            if descriptor is None:
                return None
            return descriptor.to_tensor().to(device=device, dtype=dtype)

        def load_indices(descriptor: _MemmapTensor | None) -> torch.Tensor | None:
            # Index arrays always end up as 64-bit integers.
            if descriptor is None:
                return None
            return descriptor.to_tensor().to(device=device, dtype=torch.int64)

        return DenseTripletFeatureBlock(
            triplet_index=self.triplet_index,
            coeff_start=self.coeff_start,
            coeff_shape=self.coeff_shape,
            energy=load_values(self.energy),
            per_atom_energy=load_values(self.per_atom_energy),
            per_atom_indices=load_indices(self.per_atom_indices),
            forces=load_values(self.forces),
            force_atom_indices=load_indices(self.force_atom_indices),
        )
@dataclass(frozen=True)
class MemmapDenseThreeBodyFeatureCache:
    """Collection of file-backed dense feature blocks for one fixed input."""

    blocks: tuple[MemmapDenseTripletFeatureBlock, ...]

    def __bool__(self) -> bool:
        """Tell whether at least one file-backed block is stored."""
        return len(self.blocks) > 0

    def to_input_device(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> DenseThreeBodyFeatureCache:
        """Load every block and return an in-memory cache on ``device``/``dtype``."""
        loaded_blocks = []
        for memmap_block in self.blocks:
            loaded_blocks.append(
                memmap_block.to_input_device(device=device, dtype=dtype)
            )
        return DenseThreeBodyFeatureCache(blocks=tuple(loaded_blocks))
@dataclass(frozen=True)
class ThreeBodyDenseAtomFeatures:
    """Per-atom dense feature rows for energy and the three force components."""

    atom_indices: torch.Tensor
    atomic_energy: torch.Tensor
    force_x: torch.Tensor
    force_y: torch.Tensor
    force_z: torch.Tensor

    @property
    def forces(self) -> torch.Tensor:
        """Stack the x/y/z force rows along a new axis at position 1."""
        component_rows = [self.force_x, self.force_y, self.force_z]
        return torch.stack(component_rows, dim=1)
def _selected_atom_indices(
    n_nodes: int,
    atom_indices: Sequence[int] | torch.Tensor | None,
    *,
    device: torch.device,
) -> torch.Tensor:
    """Normalize selected atom indices for diagnostic feature extraction.

    Args:
        n_nodes: Number of atoms in the input; valid indices are ``[0, n_nodes)``.
        atom_indices: Requested atoms, or ``None`` to select every atom.
        device: Device of the returned index tensor.

    Returns:
        A one-dimensional ``int64`` tensor of atom indices.

    Raises:
        ValueError: If the indices are not one-dimensional or fall outside
            ``[0, n_nodes)``.
    """
    if atom_indices is None:
        return torch.arange(n_nodes, dtype=torch.int64, device=device)
    selected = torch.as_tensor(atom_indices, dtype=torch.int64, device=device)
    if selected.ndim != 1:
        raise ValueError("`atom_indices` must be one-dimensional")
    if torch.any(selected < 0) or torch.any(selected >= int(n_nodes)):
        raise ValueError("`atom_indices` contains an atom index outside the input")
    return selected


def _compact_feature_indices(
    coeff_start: tuple[int, int, int],
    compact_shape: tuple[int, int, int],
    coeff_shape: tuple[int, int, int],
    *,
    triplet_index: int,
    device: torch.device,
) -> torch.Tensor:
    """Return flattened full-feature indices addressed by one compact block.

    Every cell of the compact window (``compact_shape`` cells starting at
    ``coeff_start``) is mapped to its row-major position in the full
    ``coeff_shape`` grid, offset by ``triplet_index`` full grids. The result
    is flat and ordered row-major over the compact window.
    """
    start_x, start_y, start_z = coeff_start
    size_x, size_y, size_z = compact_shape
    full_y, full_z = coeff_shape[1], coeff_shape[2]
    local_x = torch.arange(size_x, dtype=torch.int64, device=device)
    local_y = torch.arange(size_y, dtype=torch.int64, device=device)
    local_z = torch.arange(size_z, dtype=torch.int64, device=device)
    x, y, z = torch.meshgrid(local_x, local_y, local_z, indexing="ij")
    coeff_indices = (
        (x + int(start_x)) * int(full_y) * int(full_z)
        + (y + int(start_y)) * int(full_z)
        + (z + int(start_z))
    )
    coeff_volume = int(coeff_shape[0] * coeff_shape[1] * coeff_shape[2])
    return coeff_indices.reshape(-1) + int(triplet_index) * coeff_volume


def _match_sorted_atom_rows(
    stored_indices: torch.Tensor,
    atom_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pair requested atoms with their rows in a sorted stored-index tensor.

    Returns ``(request_rows, stored_rows)``: positions in ``atom_indices``
    whose atom is present in ``stored_indices`` and, aligned with them, the
    matching positions in ``stored_indices``. Both are empty when nothing
    matches. ``stored_indices`` must be sorted ascending, as required by
    ``torch.searchsorted``.
    """
    positions = torch.searchsorted(stored_indices, atom_indices)
    # A position equal to the stored length means "past the end": not stored.
    in_range = positions < int(stored_indices.numel())
    request_rows = torch.nonzero(in_range, as_tuple=False).reshape(-1)
    stored_rows = positions.index_select(0, request_rows)
    # searchsorted only yields an insertion point; keep exact matches only.
    found = stored_indices.index_select(0, stored_rows) == atom_indices.index_select(
        0, request_rows
    )
    return request_rows[found], stored_rows[found]


def _dense_atom_features_from_feature_cache(
    feature_cache: DenseThreeBodyFeatureCache | MemmapDenseThreeBodyFeatureCache,
    atom_indices: torch.Tensor,
    *,
    n_triplet_categories: int,
    coeff_shape: tuple[int, int, int],
    dtype: torch.dtype,
) -> ThreeBodyDenseAtomFeatures:
    """Densify per-atom energy and force rows from compact feature blocks.

    Args:
        feature_cache: Dense cache; a disk-backed cache is materialized on the
            device of ``atom_indices`` with ``dtype`` first.
        atom_indices: Atoms whose rows are extracted.
        n_triplet_categories: Number of triplet categories in the full
            feature layout.
        coeff_shape: Full coefficient grid shape of one category.
        dtype: Dtype used for materialization and for the all-zero result of
            an empty cache. With a non-empty cache the output follows the
            dtype and device of the first block's per-atom energy tensor.

    Returns:
        Rows of width ``n_triplet_categories * prod(coeff_shape)`` per
        selected atom; entries not covered by any block stay zero.

    Raises:
        ValueError: If a block has no cached per-atom energy features.
    """
    if isinstance(feature_cache, MemmapDenseThreeBodyFeatureCache):
        feature_cache = feature_cache.to_input_device(
            device=atom_indices.device,
            dtype=dtype,
        )
    coeff_volume = int(coeff_shape[0] * coeff_shape[1] * coeff_shape[2])
    n_features = int(n_triplet_categories * coeff_volume)
    if not feature_cache.blocks:
        device = atom_indices.device
        row_shape = (int(atom_indices.numel()), n_features)
        # Four independent tensors so callers never alias energy and forces.
        atomic_energy, force_x, force_y, force_z = (
            torch.zeros(row_shape, dtype=dtype, device=device) for _ in range(4)
        )
        return ThreeBodyDenseAtomFeatures(
            atom_indices=atom_indices,
            atomic_energy=atomic_energy,
            force_x=force_x,
            force_y=force_y,
            force_z=force_z,
        )

    first_block = feature_cache.blocks[0]
    if first_block.per_atom_energy is None:
        raise ValueError("dense atom features require cached per-atom energy features")
    dtype = first_block.per_atom_energy.dtype
    device = first_block.per_atom_energy.device
    atom_indices = atom_indices.to(device=device, dtype=torch.int64)
    n_selected = int(atom_indices.numel())
    atomic_energy = torch.zeros(
        (n_selected, n_features),
        dtype=dtype,
        device=device,
    )
    dense_forces = torch.zeros(
        (n_selected, 3, n_features),
        dtype=dtype,
        device=device,
    )
    force_components = torch.arange(3, dtype=torch.int64, device=device)

    for block in feature_cache.blocks:
        feature_indices = _compact_feature_indices(
            block.coeff_start,
            block.coeff_shape,
            coeff_shape,
            triplet_index=block.triplet_index,
            device=device,
        )
        # Explicit row width: ``reshape(n, -1)`` is rejected by torch when the
        # gathered tensor is empty because ``-1`` cannot be inferred.
        n_block_features = int(feature_indices.numel())
        if block.per_atom_energy is None:
            raise ValueError(
                "dense atom features require cached per-atom energy features"
            )
        if n_selected == 0:
            # Nothing to fill; the zero-row outputs are already complete.
            continue

        if block.per_atom_indices is None:
            # Rows are stored for every atom: gather by atom index directly.
            atomic_energy[:, feature_indices] = block.per_atom_energy.index_select(
                0,
                atom_indices,
            ).reshape(n_selected, n_block_features)
        else:
            request_rows, stored_rows = _match_sorted_atom_rows(
                block.per_atom_indices,
                atom_indices,
            )
            if request_rows.numel():
                atomic_energy[request_rows[:, None], feature_indices] = (
                    block.per_atom_energy.index_select(
                        0,
                        stored_rows,
                    ).reshape(int(request_rows.numel()), n_block_features)
                )

        if block.force_atom_indices is None:
            dense_forces[:, :, feature_indices] = block.forces.index_select(
                0,
                atom_indices,
            ).reshape(n_selected, 3, n_block_features)
        else:
            request_rows, stored_rows = _match_sorted_atom_rows(
                block.force_atom_indices,
                atom_indices,
            )
            if request_rows.numel():
                dense_forces[
                    request_rows[:, None, None],
                    force_components[None, :, None],
                    feature_indices[None, None, :],
                ] = block.forces.index_select(
                    0,
                    stored_rows,
                ).reshape(
                    int(request_rows.numel()),
                    3,
                    n_block_features,
                )

    return ThreeBodyDenseAtomFeatures(
        atom_indices=atom_indices,
        atomic_energy=atomic_energy,
        force_x=dense_forces[:, 0, :],
        force_y=dense_forces[:, 1, :],
        force_z=dense_forces[:, 2, :],
    )


def _symmetrize_dense_feature_rows(
    rows: torch.Tensor,
    same_neighbor_triplet_mask: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
) -> torch.Tensor:
    """Apply same-neighbor x/y coefficient tying to dense feature rows.

    ``rows`` is viewed as ``(n_rows, n_categories, *coeff_shape)``; for every
    category flagged in the mask the coefficient block is replaced by the
    average of itself and its transpose over the first two coefficient axes.
    The input is returned unchanged (same object) when it is empty or when no
    category is flagged; otherwise a new tensor is returned.
    """
    if rows.numel() == 0:
        return rows
    same_mask = same_neighbor_triplet_mask.to(device=rows.device, dtype=torch.bool)
    if same_mask.numel() == 0 or not bool(torch.any(same_mask)):
        return rows
    # Clone so the caller's rows are never modified in place.
    shaped = rows.reshape(rows.shape[0], same_mask.numel(), *coeff_shape).clone()
    same_rows = shaped[:, same_mask]
    shaped[:, same_mask] = 0.5 * (same_rows + same_rows.transpose(2, 3))
    return shaped.reshape_as(rows)


def _symmetrize_dense_atom_features(
    features: ThreeBodyDenseAtomFeatures,
    same_neighbor_triplet_mask: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
) -> ThreeBodyDenseAtomFeatures:
    """Apply term coefficient tying to all dense atom feature rows.

    The energy rows and each force-component row set are symmetrized
    independently with :func:`_symmetrize_dense_feature_rows`.
    """
    return ThreeBodyDenseAtomFeatures(
        atom_indices=features.atom_indices,
        atomic_energy=_symmetrize_dense_feature_rows(
            features.atomic_energy,
            same_neighbor_triplet_mask,
            coeff_shape=coeff_shape,
        ),
        force_x=_symmetrize_dense_feature_rows(
            features.force_x,
            same_neighbor_triplet_mask,
            coeff_shape=coeff_shape,
        ),
        force_y=_symmetrize_dense_feature_rows(
            features.force_y,
            same_neighbor_triplet_mask,
            coeff_shape=coeff_shape,
        ),
        force_z=_symmetrize_dense_feature_rows(
            features.force_z,
            same_neighbor_triplet_mask,
            coeff_shape=coeff_shape,
        ),
    )


def _compact_bounds_from_indices(
    indices: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
) -> tuple[tuple[int, int, int], tuple[int, int, int], torch.Tensor]:
    """Return compact coefficient bounds and local flat indices.

    ``indices`` are row-major flat positions in a ``coeff_shape`` grid and
    must be non-empty (``min``/``max`` are taken over them).

    Returns:
        ``(start, compact_shape, local_indices)``: the lowest touched cell,
        the extent of the tight bounding window, and every input index
        re-expressed as a row-major flat position inside that window.
    """
    _, ny, nz = coeff_shape
    ix = torch.div(indices, ny * nz, rounding_mode="floor")
    rem = indices - ix * ny * nz
    iy = torch.div(rem, nz, rounding_mode="floor")
    iz = rem - iy * nz
    start = (int(ix.min().item()), int(iy.min().item()), int(iz.min().item()))
    stop = (
        int(ix.max().item()) + 1,
        int(iy.max().item()) + 1,
        int(iz.max().item()) + 1,
    )
    compact_shape = (
        stop[0] - start[0],
        stop[1] - start[1],
        stop[2] - start[2],
    )
    local_indices = (
        (ix - start[0]) * compact_shape[1] * compact_shape[2]
        + (iy - start[1]) * compact_shape[2]
        + (iz - start[2])
    )
    return start, compact_shape, local_indices


def _local_indices_from_compact_bounds(
    indices: torch.Tensor,
    *,
    coeff_start: tuple[int, int, int],
    compact_shape: tuple[int, int, int],
    coeff_shape: tuple[int, int, int],
) -> torch.Tensor:
    """Map flattened full coefficient indices into one compact block.

    The result has the same shape as ``indices``; each entry is the row-major
    flat position of that cell inside the window described by ``coeff_start``
    and ``compact_shape``.
    """
    ny, nz = coeff_shape[1], coeff_shape[2]
    ix = torch.div(indices, ny * nz, rounding_mode="floor")
    rem = indices - ix * ny * nz
    iy = torch.div(rem, nz, rounding_mode="floor")
    iz = rem - iy * nz
    return (
        (ix - coeff_start[0]) * compact_shape[1] * compact_shape[2]
        + (iy - coeff_start[1]) * compact_shape[2]
        + (iz - coeff_start[2])
    )


def _index_add_dense_features(
    target: torch.Tensor,
    rows: torch.Tensor,
    cols: torch.Tensor,
    values: torch.Tensor,
) -> None:
    """Accumulate feature values into a dense row-feature matrix.

    ``rows``, ``cols`` and ``values`` are broadcast against each other and
    every ``values`` entry is added to ``target[row, col]``; repeated
    ``(row, col)`` pairs accumulate. ``target`` is modified in place.
    """
    rows, cols, values = torch.broadcast_tensors(rows, cols, values)
    n_cols = target.shape[1]
    flat_indices = rows.reshape(-1) * n_cols + cols.reshape(-1)
    # NOTE: relies on ``target`` being contiguous so that ``reshape`` yields a
    # view; for a non-contiguous target the update would land in a copy.
    target.reshape(-1).index_add_(0, flat_indices, values.reshape(-1))


def _feature_cache_sparse_tensors(
    feature_cache: ThreeBodyFeatureCache,
) -> tuple[torch.Tensor, ...]:
    """Return concatenated sparse feature tensors from a feature cache.

    The twelve per-block tensors are returned in a fixed order. With a single
    block its tensors are returned as-is (no copy); otherwise each field is
    concatenated across blocks along dimension 0.

    Raises:
        ValueError: If the cache holds no blocks.
    """
    blocks = feature_cache.blocks
    if not blocks:
        raise ValueError("cannot concatenate an empty three-body feature cache")
    field_names = (
        "src_ids",
        "dst_j",
        "dst_k",
        "triplet_index",
        "stencil_indices",
        "values",
        "grad_x",
        "grad_y",
        "grad_z",
        "unit_x",
        "unit_y",
        "unit_z",
    )
    if len(blocks) == 1:
        return tuple(getattr(blocks[0], name) for name in field_names)
    return tuple(
        torch.cat([getattr(block, name) for block in blocks], dim=0)
        for name in field_names
    )


def _dense_feature_cache_from_native_tensors(
    tensors: tuple[torch.Tensor, ...],
    *,
    n_systems: int,
) -> DenseThreeBodyFeatureCache:
    """Wrap native dense-cache tensors in Python dense feature blocks.

    The native operator returns thirteen tensors: per-block metadata, five
    pointer tensors delimiting each block's slice of the flat payloads, and
    the flat payloads themselves. Block ``i`` owns ``ptr[i]:ptr[i + 1]`` of
    the corresponding payload.

    Raises:
        RuntimeError: If the number of tensors is not thirteen.
    """
    if len(tensors) != 13:
        raise RuntimeError(
            "native dense three-body feature-cache operator returned an unexpected "
            f"number of tensors: {len(tensors)}"
        )
    (
        triplet_indices,
        coeff_start,
        compact_shape,
        energy_ptr,
        per_atom_ptr,
        force_ptr,
        per_atom_index_ptr,
        force_atom_index_ptr,
        energy_values,
        per_atom_values,
        force_values,
        per_atom_indices,
        force_atom_indices,
    ) = tensors
    blocks: list[DenseTripletFeatureBlock] = []
    for block_index in range(int(triplet_indices.numel())):
        shape = tuple(int(value) for value in compact_shape[block_index].tolist())
        energy_start = int(energy_ptr[block_index])
        energy_stop = int(energy_ptr[block_index + 1])
        per_atom_start = int(per_atom_ptr[block_index])
        per_atom_stop = int(per_atom_ptr[block_index + 1])
        force_start = int(force_ptr[block_index])
        force_stop = int(force_ptr[block_index + 1])
        per_atom_index_start = int(per_atom_index_ptr[block_index])
        per_atom_index_stop = int(per_atom_index_ptr[block_index + 1])
        force_index_start = int(force_atom_index_ptr[block_index])
        force_index_stop = int(force_atom_index_ptr[block_index + 1])
        block_per_atom_indices = per_atom_indices[
            per_atom_index_start:per_atom_index_stop
        ]
        block_force_atom_indices = force_atom_indices[
            force_index_start:force_index_stop
        ]
        n_per_atom = int(block_per_atom_indices.numel())
        n_force_atom = int(block_force_atom_indices.numel())
        # An empty per-atom slice means the block carries no per-atom energy
        # rows; both the values and their indices are then reported as None.
        per_atom_energy = None
        if per_atom_stop > per_atom_start:
            per_atom_energy = per_atom_values[per_atom_start:per_atom_stop].reshape(
                n_per_atom,
                *shape,
            )
        blocks.append(
            DenseTripletFeatureBlock(
                triplet_index=int(triplet_indices[block_index]),
                coeff_start=tuple(
                    int(value) for value in coeff_start[block_index].tolist()
                ),
                coeff_shape=shape,
                energy=energy_values[energy_start:energy_stop].reshape(
                    n_systems,
                    *shape,
                ),
                per_atom_energy=per_atom_energy,
                per_atom_indices=(
                    block_per_atom_indices if per_atom_energy is not None else None
                ),
                forces=force_values[force_start:force_stop].reshape(
                    n_force_atom,
                    3,
                    *shape,
                ),
                force_atom_indices=block_force_atom_indices,
            )
        )
    return DenseThreeBodyFeatureCache(blocks=tuple(blocks))


def _build_dense_feature_cache_from_feature_cache_torch(
    feature_cache: ThreeBodyFeatureCache,
    system_index: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    force_scope: Literal["output", "source"] = "output",
    include_per_atom_energy: bool = True,
) -> DenseThreeBodyFeatureCache:
    """Accumulate triplet features into dense per-category feature blocks.

    Pure-torch implementation. For every triplet category present in the
    cache, one compact block is built that covers the bounding window of the
    stencil indices touched by that category.

    Args:
        feature_cache: Sparse per-triplet features.
        system_index: System id of every atom; its length is the atom count
            and ``max + 1`` the system count.
        coeff_shape: Full coefficient grid shape of one category.
        force_scope: ``"source"`` accumulates forces on source atoms only;
            ``"output"`` also accumulates on the ``dst_j``/``dst_k`` atoms.
        include_per_atom_energy: Whether to build per-atom energy rows.

    Raises:
        ValueError: If ``force_scope`` is not ``"output"`` or ``"source"``.
    """
    if force_scope not in {"output", "source"}:
        raise ValueError("`force_scope` must be 'output' or 'source'")
    if not feature_cache.blocks:
        return DenseThreeBodyFeatureCache(blocks=())
    device = feature_cache.blocks[0].values.device
    dtype = feature_cache.blocks[0].values.dtype
    n_nodes = int(system_index.numel())
    n_systems = int(system_index.max().item()) + 1 if n_nodes else 0
    components = torch.arange(3, dtype=torch.int64, device=device)
    system_index = system_index.to(device=device, dtype=torch.int64)
    triplet_indices = sorted(
        {
            int(index.item())
            for block in feature_cache.blocks
            for index in torch.unique(block.triplet_index).cpu()
        }
    )
    dense_blocks: list[DenseTripletFeatureBlock] = []
    for triplet_index in triplet_indices:
        # First pass: find which sparse blocks contribute to this category and
        # collect the coefficient and atom indices they touch.
        selected_blocks: list[tuple[ThreeBodyFeatureBlock, torch.Tensor]] = []
        all_indices: list[torch.Tensor] = []
        per_atom_index_chunks: list[torch.Tensor] = []
        force_atom_index_chunks: list[torch.Tensor] = []
        for block in feature_cache.blocks:
            mask = block.triplet_index == triplet_index
            if not torch.any(mask):
                continue
            selected_blocks.append((block, mask))
            all_indices.append(block.stencil_indices[mask].reshape(-1))
            src_ids = block.src_ids[mask]
            per_atom_index_chunks.append(src_ids)
            force_atom_index_chunks.append(src_ids)
            if force_scope == "output":
                force_atom_index_chunks.append(block.dst_j[mask])
                force_atom_index_chunks.append(block.dst_k[mask])
        if not all_indices:
            continue
        coeff_start, compact_shape, _ = _compact_bounds_from_indices(
            torch.cat(all_indices, dim=0),
            coeff_shape=coeff_shape,
        )
        compact_volume = int(compact_shape[0] * compact_shape[1] * compact_shape[2])
        # Lookup tables translate a global atom id into its row in the compact
        # (sorted, unique) atom list; -1 marks atoms without a row.
        force_atom_indices = torch.unique(torch.cat(force_atom_index_chunks, dim=0))
        force_atom_lookup = torch.full(
            (n_nodes,),
            -1,
            dtype=torch.int64,
            device=device,
        )
        force_atom_lookup[force_atom_indices] = torch.arange(
            force_atom_indices.numel(),
            dtype=torch.int64,
            device=device,
        )
        per_atom_indices = None
        per_atom_lookup = None
        if include_per_atom_energy:
            per_atom_indices = torch.unique(torch.cat(per_atom_index_chunks, dim=0))
            per_atom_lookup = torch.full(
                (n_nodes,),
                -1,
                dtype=torch.int64,
                device=device,
            )
            per_atom_lookup[per_atom_indices] = torch.arange(
                per_atom_indices.numel(),
                dtype=torch.int64,
                device=device,
            )
        energy = torch.zeros(
            (n_systems, compact_volume),
            dtype=dtype,
            device=device,
        )
        per_atom_energy = (
            torch.zeros(
                (int(per_atom_indices.numel()), compact_volume),
                dtype=dtype,
                device=device,
            )
            if include_per_atom_energy
            else None
        )
        # Forces are accumulated with rows laid out as ``atom_row * 3 +
        # component`` and reshaped to (atoms, 3, ...) once complete.
        forces = torch.zeros(
            (int(force_atom_indices.numel()) * 3, compact_volume),
            dtype=dtype,
            device=device,
        )
        # Second pass: scatter-add every contributing triplet into the block.
        for block, mask in selected_blocks:
            local_indices = _local_indices_from_compact_bounds(
                block.stencil_indices[mask],
                coeff_start=coeff_start,
                compact_shape=compact_shape,
                coeff_shape=coeff_shape,
            )
            src_ids = block.src_ids[mask]
            dst_j = block.dst_j[mask]
            dst_k = block.dst_k[mask]
            values = block.values[mask]
            system_rows = system_index.index_select(0, src_ids)
            _index_add_dense_features(
                energy, system_rows[:, None], local_indices, values
            )
            if per_atom_energy is not None:
                assert per_atom_lookup is not None
                per_atom_rows = per_atom_lookup.index_select(0, src_ids)
                _index_add_dense_features(
                    per_atom_energy,
                    per_atom_rows[:, None],
                    local_indices,
                    values,
                )
            # Outer products of per-stencil gradients with direction vectors.
            # NOTE(review): grad_x/grad_y/grad_z and unit_x/unit_y/unit_z
            # presumably belong to the three pair distances of a triplet;
            # confirm the sign convention against _threebody_features.
            term_x = block.grad_x[mask][:, :, None] * block.unit_x[mask][:, None, :]
            term_y = block.grad_y[mask][:, :, None] * block.unit_y[mask][:, None, :]
            term_z = block.grad_z[mask][:, :, None] * block.unit_z[mask][:, None, :]
            force_src = term_x + term_y
            force_j = -(term_x + term_z)
            force_k = -term_y + term_z
            _index_add_dense_features(
                forces,
                force_atom_lookup.index_select(0, src_ids)[:, None, None] * 3
                + components[None, None, :],
                local_indices[:, :, None],
                force_src,
            )
            if force_scope == "output":
                _index_add_dense_features(
                    forces,
                    force_atom_lookup.index_select(0, dst_j)[:, None, None] * 3
                    + components[None, None, :],
                    local_indices[:, :, None],
                    force_j,
                )
                _index_add_dense_features(
                    forces,
                    force_atom_lookup.index_select(0, dst_k)[:, None, None] * 3
                    + components[None, None, :],
                    local_indices[:, :, None],
                    force_k,
                )
        dense_blocks.append(
            DenseTripletFeatureBlock(
                triplet_index=triplet_index,
                coeff_start=coeff_start,
                coeff_shape=compact_shape,
                energy=energy.reshape(n_systems, *compact_shape),
                per_atom_energy=(
                    None
                    if per_atom_energy is None
                    else per_atom_energy.reshape(
                        int(per_atom_indices.numel()),
                        *compact_shape,
                    )
                ),
                per_atom_indices=per_atom_indices,
                forces=forces.reshape(
                    int(force_atom_indices.numel()),
                    3,
                    *compact_shape,
                ),
                force_atom_indices=force_atom_indices,
            )
        )
    return DenseThreeBodyFeatureCache(blocks=tuple(dense_blocks))


def _build_dense_feature_cache_from_feature_cache(
    feature_cache: ThreeBodyFeatureCache,
    system_index: torch.Tensor,
    *,
    coeff_shape: tuple[int, int, int],
    force_scope: Literal["output", "source"] = "output",
    include_per_atom_energy: bool = True,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> DenseThreeBodyFeatureCache:
    """Accumulate triplet features into dense per-category feature blocks.

    Dispatches to the native dense-cache operator when
    ``should_use_native_threebody_dense_feature_cache`` allows it and to the
    pure-torch implementation otherwise. A ``RuntimeError`` from the native
    path falls back to torch unless ``runtime_config`` requires the native
    backend, in which case the error is re-raised.

    Raises:
        ValueError: If ``force_scope`` is not ``"output"`` or ``"source"``.
        RuntimeError: If the native path fails and the native backend is
            required.
    """
    if force_scope not in {"output", "source"}:
        raise ValueError("`force_scope` must be 'output' or 'source'")
    if not feature_cache.blocks:
        return DenseThreeBodyFeatureCache(blocks=())
    device = feature_cache.blocks[0].values.device
    dtype = feature_cache.blocks[0].values.dtype
    # Passed to the native-path decision so it can account for autograd use.
    requires_grad = any(
        tensor.requires_grad
        for block in feature_cache.blocks
        for tensor in (
            block.values,
            block.grad_x,
            block.grad_y,
            block.grad_z,
            block.unit_x,
            block.unit_y,
            block.unit_z,
        )
    )
    if should_use_native_threebody_dense_feature_cache(
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        force_scope=force_scope,
        runtime_config=runtime_config,
    ):
        try:
            sparse_tensors = _feature_cache_sparse_tensors(feature_cache)
            native_tensors = call_native_threebody_dense_feature_cache(
                sparse_tensors,
                system_index,
                coeff_shape=coeff_shape,
                force_scope=force_scope,
                include_per_atom_energy=include_per_atom_energy,
            )
            n_nodes = int(system_index.numel())
            n_systems = int(system_index.max().item()) + 1 if n_nodes else 0
            return _dense_feature_cache_from_native_tensors(
                native_tensors,
                n_systems=n_systems,
            )
        except RuntimeError:
            # Deliberate best-effort: only surface the failure when the
            # configuration insists on the native backend.
            if requires_native_threebody_backend(runtime_config):
                raise
    return _build_dense_feature_cache_from_feature_cache_torch(
        feature_cache,
        system_index,
        coeff_shape=coeff_shape,
        force_scope=force_scope,
        include_per_atom_energy=include_per_atom_energy,
    )


def _coefficient_window(
    coeffs_by_triplet: torch.Tensor,
    block: DenseTripletFeatureBlock,
) -> torch.Tensor:
    """Return the coefficient view addressed by one dense feature block.

    Slices the block's triplet category and its compact window out of
    ``coeffs_by_triplet`` (indexed as ``[category, x, y, z]``).
    """
    start_x, start_y, start_z = block.coeff_start
    size_x, size_y, size_z = block.coeff_shape
    return coeffs_by_triplet[
        block.triplet_index,
        start_x : start_x + size_x,
        start_y : start_y + size_y,
        start_z : start_z + size_z,
    ]


def _evaluate_dense_feature_cache_energy_forces(
    feature_cache: DenseThreeBodyFeatureCache | MemmapDenseThreeBodyFeatureCache,
    coeffs_by_triplet: torch.Tensor,
    *,
    n_nodes: int,
    n_systems: int,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """Evaluate coalesced dense feature blocks against current coefficients.

    Args:
        feature_cache: Dense cache; disk-backed caches and caches on another
            device/dtype are first converted to match ``coeffs_by_triplet``.
        coeffs_by_triplet: Coefficients indexed as ``[category, x, y, z]``.
        n_nodes: Number of atoms (rows of the force output).
        n_systems: Number of systems (length of the energy output).

    Returns:
        ``(energy, per_atom_energy, forces)`` with shapes ``(n_systems,)``,
        ``(n_nodes,)`` or ``None`` when no block carries per-atom energy, and
        ``(n_nodes, 3)``.
    """
    if isinstance(feature_cache, MemmapDenseThreeBodyFeatureCache):
        feature_cache = feature_cache.to_input_device(
            device=coeffs_by_triplet.device,
            dtype=coeffs_by_triplet.dtype,
        )
    elif feature_cache.blocks and (
        feature_cache.blocks[0].energy.device != coeffs_by_triplet.device
        or feature_cache.blocks[0].energy.dtype != coeffs_by_triplet.dtype
    ):
        feature_cache = feature_cache.to_input_device(
            device=coeffs_by_triplet.device,
            dtype=coeffs_by_triplet.dtype,
        )
    energy = coeffs_by_triplet.new_zeros((n_systems,))
    per_atom_energy = None
    forces = coeffs_by_triplet.new_zeros((n_nodes, 3))
    for block in feature_cache.blocks:
        coeffs = _coefficient_window(coeffs_by_triplet, block)
        # Contract the three trailing coefficient axes of every feature tensor.
        energy = energy + torch.tensordot(block.energy, coeffs, dims=3)
        if block.per_atom_energy is not None:
            block_per_atom_energy = torch.tensordot(
                block.per_atom_energy,
                coeffs,
                dims=3,
            )
            if block.per_atom_indices is not None:
                # Expand the compact per-atom rows to one entry per atom.
                full_per_atom_energy = coeffs_by_triplet.new_zeros((n_nodes,))
                full_per_atom_energy.index_add_(
                    0,
                    block.per_atom_indices,
                    block_per_atom_energy,
                )
                block_per_atom_energy = full_per_atom_energy
            per_atom_energy = (
                block_per_atom_energy
                if per_atom_energy is None
                else per_atom_energy + block_per_atom_energy
            )
        block_forces = torch.tensordot(block.forces, coeffs, dims=3)
        if block.force_atom_indices is None:
            forces = forces + block_forces
        else:
            forces.index_add_(0, block.force_atom_indices, block_forces)
    return energy, per_atom_energy, forces