"""Disk-backed dense three-body feature-cache helpers."""
from __future__ import annotations
import hashlib
import json
from collections.abc import Iterable, Sequence
from pathlib import Path
from typing import Literal
import numpy as np
import torch
from ufp.core._disk_cache import (
atomic_write_json,
atomic_write_npy_memmap,
publish_staged_cache,
settings_cache_dir,
write_cache_settings_summary,
)
from ufp.core.input import UFPInput
from ufp.terms._threebody_dense import (
DenseThreeBodyFeatureCache,
DenseTripletFeatureBlock,
MemmapDenseThreeBodyFeatureCache,
MemmapDenseTripletFeatureBlock,
_build_dense_feature_cache_from_feature_cache,
_MemmapTensor,
)
from ufp.terms._threebody_eval import SplineKind
from ufp.terms._threebody_features import _build_feature_cache_from_buckets
from ufp.terms._threebody_ops import Buckets
from ufp.terms._threebody_runtime import ThreeBodyRuntimeConfig
# Array fields every persisted dense triplet block must reference in its manifest.
_DENSE_REQUIRED_BLOCK_FIELDS = ("energy", "forces")
# Array fields a persisted block may omit.
_DENSE_OPTIONAL_BLOCK_FIELDS = ("per_atom_energy",)
# Required fields first, then optional ones.
_DENSE_BLOCK_FIELDS = _DENSE_REQUIRED_BLOCK_FIELDS + _DENSE_OPTIONAL_BLOCK_FIELDS
# Written into each manifest; manifests with any other version are not loaded.
_DENSE_FEATURE_CACHE_SCHEMA_VERSION = 2
# Accepted values for the feature-cache mode setting.
FeatureCacheMode = Literal["auto", "read", "refresh"]
def _write_memmap_tensor(path: Path, tensor: torch.Tensor) -> _MemmapTensor:
    """Store ``tensor`` at ``path`` as a ``.npy`` file and return a descriptor for it.

    The tensor is detached and copied to CPU before conversion to NumPy, then
    handed to ``atomic_write_npy_memmap``. The returned descriptor only
    records the file path as a string.
    """
    host_array = tensor.detach().cpu().numpy()
    atomic_write_npy_memmap(path, host_array)
    return _MemmapTensor(str(path))
def _write_memmap_dense_triplet_feature_block(
    directory: Path,
    prefix: str,
    block_index: int,
    block: DenseTripletFeatureBlock,
) -> MemmapDenseTripletFeatureBlock:
    """Write one dense triplet block's tensors to ``.npy`` files in ``directory``.

    Each file is named ``{prefix}_block{block_index}_triplet{triplet_index}_{field}.npy``.
    ``energy`` and ``forces`` are always written. ``per_atom_energy``,
    ``per_atom_indices`` and ``force_atom_indices`` are written only when they
    are not None; otherwise the returned block stores None for them.

    Returns:
        A block descriptor whose tensor fields point at the written files.
    """
    stem = f"{prefix}_block{block_index}_triplet{block.triplet_index}"

    def target(field: str) -> Path:
        return directory / f"{stem}_{field}.npy"

    def maybe_write(field: str, tensor: torch.Tensor | None) -> _MemmapTensor | None:
        if tensor is None:
            return None
        return _write_memmap_tensor(target(field), tensor)

    # Keyword arguments are evaluated top to bottom, which fixes the order in
    # which files are written.
    return MemmapDenseTripletFeatureBlock(
        triplet_index=block.triplet_index,
        coeff_start=block.coeff_start,
        coeff_shape=block.coeff_shape,
        energy=_write_memmap_tensor(target("energy"), block.energy),
        per_atom_energy=maybe_write("per_atom_energy", block.per_atom_energy),
        per_atom_indices=maybe_write("per_atom_indices", block.per_atom_indices),
        forces=_write_memmap_tensor(target("forces"), block.forces),
        force_atom_indices=maybe_write("force_atom_indices", block.force_atom_indices),
    )
def _dense_block_manifest_entry(
    index: int,
    block: MemmapDenseTripletFeatureBlock,
) -> dict[str, object]:
    """Return the manifest record describing one persisted dense triplet block."""

    def file_name(tensor: _MemmapTensor) -> str:
        # Manifests store names relative to the cache directory.
        return Path(tensor.path).name

    fields: dict[str, str | None] = {
        "energy": file_name(block.energy),
        "forces": file_name(block.forces),
    }
    if block.per_atom_energy is not None:
        fields["per_atom_energy"] = file_name(block.per_atom_energy)
        # Recorded as an explicit null when per-atom energies exist without indices.
        fields["per_atom_indices"] = (
            None
            if block.per_atom_indices is None
            else file_name(block.per_atom_indices)
        )
    elif block.per_atom_indices is not None:
        # The indices file was written to disk; list it so a reload from the
        # manifest sees the same arrays as the descriptor returned here.
        fields["per_atom_indices"] = file_name(block.per_atom_indices)
    if block.force_atom_indices is not None:
        fields["force_atom_indices"] = file_name(block.force_atom_indices)
    return {
        "block_index": index,
        "triplet_index": int(block.triplet_index),
        "coeff_start": list(block.coeff_start),
        "coeff_shape": list(block.coeff_shape),
        "fields": fields,
    }


def _write_memmap_dense_feature_cache(
    directory: Path,
    prefix: str,
    cache: DenseThreeBodyFeatureCache,
    *,
    metadata: dict[str, object] | None = None,
) -> MemmapDenseThreeBodyFeatureCache:
    """Persist dense triplet feature blocks and a compact manifest.

    ``directory`` is created if needed. Every block's arrays are written as
    ``.npy`` files. ``{prefix}_manifest.json`` then records the schema version,
    the prefix, each block's layout and array file names, and ``metadata``.
    Finally a settings summary is written via ``write_cache_settings_summary``.

    Args:
        directory: Destination directory for arrays and the manifest.
        prefix: File-name prefix for every array and the manifest.
        cache: In-memory dense feature cache to persist.
        metadata: Settings stored in the manifest; ``None`` stores ``{}``.

    Returns:
        A memmap-backed cache whose blocks point at the written files.
    """
    directory.mkdir(parents=True, exist_ok=True)
    blocks = tuple(
        _write_memmap_dense_triplet_feature_block(directory, prefix, index, block)
        for index, block in enumerate(cache.blocks)
    )
    manifest_name = f"{prefix}_manifest.json"
    stored_metadata: dict[str, object] = {} if metadata is None else metadata
    manifest = {
        "schema_version": _DENSE_FEATURE_CACHE_SCHEMA_VERSION,
        "prefix": prefix,
        "blocks": [
            _dense_block_manifest_entry(index, block)
            for index, block in enumerate(blocks)
        ],
        "metadata": stored_metadata,
    }
    atomic_write_json(directory / manifest_name, manifest)
    write_cache_settings_summary(
        directory,
        cache_kind="threebody_dense_feature_cache",
        prefix=prefix,
        settings=stored_metadata,
        extra={
            "cache_schema_version": _DENSE_FEATURE_CACHE_SCHEMA_VERSION,
            "manifest": manifest_name,
        },
    )
    return MemmapDenseThreeBodyFeatureCache(blocks=blocks)
def _hash_tensor(hasher, tensor: torch.Tensor, *, dtype: torch.dtype | None = None):
    """Feed a tensor's NumPy dtype name, shape, and raw bytes into ``hasher``.

    The tensor is detached and copied to CPU first. If ``dtype`` is given it is
    cast to that dtype, so signatures do not depend on the caller's precision.
    The data is hashed as a C-contiguous NumPy array, with the shape encoded
    as int64 values.
    """
    host = tensor.detach().cpu()
    if dtype is not None:
        host = host.to(dtype=dtype)
    contiguous = np.ascontiguousarray(host.numpy())
    shape_bytes = np.asarray(contiguous.shape, dtype=np.int64).tobytes()
    for chunk in (
        str(contiguous.dtype).encode("utf8"),
        shape_bytes,
        contiguous.tobytes(),
    ):
        hasher.update(chunk)
def _input_feature_cache_signature(inputs: UFPInput) -> str:
    """Return a SHA-256 hex digest of the geometry that dense feature caches depend on.

    The digest covers positions and cell (cast to float64), pbc, atomic
    numbers and system index. It also covers the neighbor list when one is
    attached: pairs, shifts, optional distances/vectors (float64) and the
    ``full_list`` flag. Absent optional pieces contribute fixed tag bytes so
    they still affect the digest.
    """
    digest = hashlib.sha256()
    for tensor, cast in (
        (inputs.positions, torch.float64),
        (inputs.cell, torch.float64),
        (inputs.pbc, None),
        (inputs.atomic_numbers, None),
        (inputs.system_index, None),
    ):
        _hash_tensor(digest, tensor, dtype=cast)

    pair_list = inputs.neighbor_list
    if pair_list is None:
        digest.update(b"neighbor_list:none")
        return digest.hexdigest()

    digest.update(b"neighbor_list")
    _hash_tensor(digest, pair_list.pairs)
    _hash_tensor(digest, pair_list.shifts)
    for missing_tag, optional in (
        (b"distances:none", pair_list.distances),
        (b"vectors:none", pair_list.vectors),
    ):
        if optional is None:
            digest.update(missing_tag)
        else:
            _hash_tensor(digest, optional, dtype=torch.float64)
    digest.update(str(bool(pair_list.full_list)).encode("utf8"))
    return digest.hexdigest()
def _dense_feature_cache_metadata(
    inputs: UFPInput,
    *,
    cache_key: str,
    atomic_types: Sequence[int],
    triplet_categories: Sequence[Sequence[int]],
    coeff_shape: Sequence[int],
    active_triplet_indices: Sequence[int],
    include_per_atom_energy: bool,
    spline: str,
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    lower_support_xy: float,
    lower_support_z: float,
    eps: float,
) -> dict[str, object]:
    """Build the JSON-serializable metadata that identifies a resumable dense cache.

    The result combines fixed identifiers, a signature of the input geometry,
    the type/triplet layout, and the spline knot settings. Integers and floats
    are normalized to plain Python ``int``/``float``. Key insertion order is
    fixed.
    """
    per_atom = bool(include_per_atom_energy)
    metadata: dict[str, object] = {
        "schema_version": _DENSE_FEATURE_CACHE_SCHEMA_VERSION,
        "cache_family": "threebody_dense_feature_cache",
        "cache_kind": "memmap_dense_feature_blocks",
        "cache_key": cache_key,
        "input_signature": _input_feature_cache_signature(inputs),
        "n_atoms": int(inputs.n_atoms),
        "n_systems": int(inputs.n_systems),
        "dtype": str(inputs.dtype),
        "atomic_types": list(map(int, atomic_types)),
        "triplet_categories": [list(map(int, triplet)) for triplet in triplet_categories],
        "coeff_shape": list(map(int, coeff_shape)),
        "active_triplet_indices": list(map(int, active_triplet_indices)),
        "include_per_atom_energy": per_atom,
        "row_semantics": {
            "energy": "per_system",
            "forces": "per_atom_cartesian",
            "per_atom_energy": per_atom,
        },
        "row_indexed_dense_blocks": True,
        "spline": spline,
    }
    float_settings = (
        ("first_knot_xy", first_knot_xy),
        ("first_knot_z", first_knot_z),
        ("knot_spacing_xy", knot_spacing_xy),
        ("knot_spacing_z", knot_spacing_z),
        ("lower_support_xy", lower_support_xy),
        ("lower_support_z", lower_support_z),
        ("eps", eps),
    )
    metadata.update((name, float(value)) for name, value in float_settings)
    return metadata
def _cache_metadata_matches(
    stored_metadata: object,
    expected_metadata: dict[str, object],
    *,
    allow_active_triplet_superset: bool = False,
) -> bool:
    """Report whether stored disk-cache metadata describes the expected input.

    An exactly equal ``stored_metadata`` always matches. When
    ``allow_active_triplet_superset`` is set, a stored ``dict`` also matches
    if both of these hold:

    * after dropping ``cache_key`` and ``active_triplet_indices`` from both
      sides, the remaining entries are equal;
    * the stored active triplet indices include every expected one.

    Malformed ``active_triplet_indices`` (not iterable, or containing values
    ``int`` rejects) produce ``False`` rather than an exception. A corrupt
    manifest then counts as a non-match instead of aborting a cache search.
    """
    if stored_metadata == expected_metadata:
        return True
    if not allow_active_triplet_superset or not isinstance(stored_metadata, dict):
        return False
    stored_rest = dict(stored_metadata)
    expected_rest = dict(expected_metadata)
    # The cache key may differ when a cache built for more triplets is reused.
    stored_rest.pop("cache_key", None)
    expected_rest.pop("cache_key", None)
    try:
        stored_active = {
            int(value) for value in stored_rest.pop("active_triplet_indices", ())
        }
        expected_active = {
            int(value) for value in expected_rest.pop("active_triplet_indices", ())
        }
    except (TypeError, ValueError):
        return False
    return expected_active.issubset(stored_active) and stored_rest == expected_rest
def _dense_feature_cache_dir(
    parent: Path,
    prefix: str,
    metadata: dict[str, object],
) -> Path:
    """Resolve the directory under ``parent`` that holds the cache described by ``metadata``.

    Naming is delegated entirely to ``settings_cache_dir``.
    """
    resolved = settings_cache_dir(parent, prefix, metadata)
    return resolved
def _write_published_memmap_dense_feature_cache(
    parent: Path,
    prefix: str,
    cache: DenseThreeBodyFeatureCache,
    *,
    metadata: dict[str, object],
    overwrite: bool = False,
) -> MemmapDenseThreeBodyFeatureCache:
    """Persist a dense feature cache through a private staging directory.

    ``publish_staged_cache`` is handed a callback that writes the cache into a
    staging directory. It then publishes under the settings-named child of
    ``parent`` derived from ``metadata``. The published directory is reloaded
    and validated against ``metadata``.

    If validation fails and this call did not publish, the write is retried
    once with ``overwrite=True``. That situation presumably means an
    incompatible directory already occupied the target; confirm against
    ``publish_staged_cache``.

    Args:
        parent: Parent directory for the settings-named cache directory.
        prefix: Cache file-name prefix.
        cache: In-memory dense feature cache to persist.
        metadata: Settings used to name and validate the cache.
        overwrite: Forwarded to ``publish_staged_cache``.

    Returns:
        The validated memmap-backed cache.

    Raises:
        ValueError: If the published directory fails validation after this
            call published it, or while already running with ``overwrite=True``.
    """
    final_dir = _dense_feature_cache_dir(parent, prefix, metadata)

    def write_staging(staging_dir: Path) -> None:
        _write_memmap_dense_feature_cache(
            staging_dir,
            prefix,
            cache,
            metadata=metadata,
        )

    published_dir, _, published = publish_staged_cache(
        final_dir.parent,
        final_dir.name,
        write_staging,
        overwrite=overwrite,
    )
    loaded = _load_memmap_dense_feature_cache(
        published_dir,
        prefix,
        expected_metadata=metadata,
    )
    if loaded is not None:
        return loaded
    if published or overwrite:
        # Retrying cannot help here. Without the ``overwrite`` check, an
        # overwrite attempt that still does not publish would recurse forever.
        raise ValueError(f"failed to publish dense feature cache: {published_dir}")
    return _write_published_memmap_dense_feature_cache(
        parent,
        prefix,
        cache,
        metadata=metadata,
        overwrite=True,
    )
def _load_memmap_dense_feature_cache(
    directory: Path,
    prefix: str,
    *,
    expected_metadata: dict[str, object],
    allow_active_triplet_superset: bool = False,
    required_triplet_indices: Iterable[int] | None = None,
) -> MemmapDenseThreeBodyFeatureCache | None:
    """Load one validated dense feature cache manifest if all files are present.

    Every validation failure is reported as a cache miss (``None``) rather than
    an exception. Failures include:

    * the manifest is missing, unreadable, not valid JSON, or not a JSON object;
    * the schema version or prefix differs;
    * the stored metadata does not match ``expected_metadata``;
    * a block lacks a required field entry;
    * any referenced array file does not exist.

    Args:
        directory: Directory containing ``{prefix}_manifest.json`` and its arrays.
        prefix: Prefix the manifest must declare.
        expected_metadata: Metadata the stored manifest must match.
        allow_active_triplet_superset: Also accept stored caches whose active
            triplet set contains the expected one (see ``_cache_metadata_matches``).
        required_triplet_indices: When given, only blocks whose
            ``triplet_index`` is in this collection are loaded.

    Returns:
        The memmap-backed cache, or ``None`` on any cache miss.
    """
    manifest_path = directory / f"{prefix}_manifest.json"
    if not manifest_path.is_file():
        return None
    try:
        with manifest_path.open("r", encoding="utf8") as handle:
            manifest = json.load(handle)
    except (OSError, json.JSONDecodeError):
        return None
    if not isinstance(manifest, dict):
        return None
    try:
        schema_version = int(manifest.get("schema_version", -1))
    except (TypeError, ValueError):
        return None
    if schema_version != _DENSE_FEATURE_CACHE_SCHEMA_VERSION:
        return None
    if str(manifest.get("prefix")) != prefix:
        return None
    if not _cache_metadata_matches(
        manifest.get("metadata"),
        expected_metadata,
        allow_active_triplet_superset=allow_active_triplet_superset,
    ):
        return None
    required_triplets = (
        None
        if required_triplet_indices is None
        else {int(value) for value in required_triplet_indices}
    )
    blocks = []
    for block in manifest.get("blocks", ()):
        triplet_index = int(block["triplet_index"])
        if required_triplets is not None and triplet_index not in required_triplets:
            continue
        fields = block.get("fields", {})
        # Required arrays: a missing manifest entry is a miss, just like a missing file.
        required_files: dict[str, Path] = {}
        for field in _DENSE_REQUIRED_BLOCK_FIELDS:
            relative = fields.get(field)
            if relative is None:
                return None
            field_file = directory / str(relative)
            if not field_file.is_file():
                return None
            required_files[field] = field_file
        # Optional arrays: absent or null entries load as None, but a listed
        # file that does not exist is still a miss.
        optional: dict[str, _MemmapTensor | None] = {}
        for field in ("per_atom_energy", "per_atom_indices", "force_atom_indices"):
            relative = fields.get(field)
            if relative is None:
                optional[field] = None
                continue
            field_file = directory / str(relative)
            if not field_file.is_file():
                return None
            optional[field] = _MemmapTensor(str(field_file))
        blocks.append(
            MemmapDenseTripletFeatureBlock(
                triplet_index=triplet_index,
                coeff_start=tuple(int(value) for value in block["coeff_start"]),
                coeff_shape=tuple(int(value) for value in block["coeff_shape"]),
                energy=_MemmapTensor(str(required_files["energy"])),
                per_atom_energy=optional["per_atom_energy"],
                per_atom_indices=optional["per_atom_indices"],
                forces=_MemmapTensor(str(required_files["forces"])),
                force_atom_indices=optional["force_atom_indices"],
            )
        )
    return MemmapDenseThreeBodyFeatureCache(blocks=tuple(blocks))
def _iter_dense_feature_cache_manifests(directory: Path):
    """Lazily yield ``*_manifest.json`` paths under ``directory`` in a stable order.

    Matches directly inside ``directory`` come first, sorted by path. After
    them come matches in each immediate subdirectory, with subdirectories
    visited in sorted order and each one's matches sorted. Deeper levels are
    not searched.
    """
    pattern = "*_manifest.json"
    for top_level in sorted(directory.glob(pattern)):
        yield top_level
    subdirectories = sorted(entry for entry in directory.iterdir() if entry.is_dir())
    for subdirectory in subdirectories:
        yield from sorted(subdirectory.glob(pattern))
def _find_compatible_memmap_dense_feature_cache(
    directory: Path,
    *,
    expected_metadata: dict[str, object],
    required_triplet_indices: Iterable[int],
) -> MemmapDenseThreeBodyFeatureCache | None:
    """Find a V2 disk cache whose active triplets cover the expected request.

    Manifests are visited in ``_iter_dense_feature_cache_manifests`` order.
    The first one that loads and validates is returned; a stored active
    triplet set that is a superset of the expected one is accepted.
    Manifests that cannot be read, are not valid JSON, are not JSON objects,
    or have no prefix are skipped.

    Args:
        directory: Root directory to search; a non-directory yields ``None``.
        expected_metadata: Metadata each candidate is validated against.
        required_triplet_indices: Triplet indices whose blocks should be loaded.

    Returns:
        The first compatible memmap-backed cache, or ``None``.
    """
    if not directory.is_dir():
        return None
    # Materialize once. The loader is called per candidate manifest, and a
    # one-shot iterator would be empty for every candidate after the first
    # one that consumed it.
    required = tuple(required_triplet_indices)
    for manifest_path in _iter_dense_feature_cache_manifests(directory):
        try:
            with manifest_path.open("r", encoding="utf8") as handle:
                manifest = json.load(handle)
        except (OSError, json.JSONDecodeError):
            continue
        if not isinstance(manifest, dict):
            continue
        prefix = str(manifest.get("prefix", ""))
        if not prefix:
            continue
        cached = _load_memmap_dense_feature_cache(
            manifest_path.parent,
            prefix,
            expected_metadata=expected_metadata,
            allow_active_triplet_superset=True,
            required_triplet_indices=required,
        )
        if cached is not None:
            return cached
    return None
def load_memmap_threebody_feature_cache(
    directory: Path | str,
    *,
    prefix: str | None = None,
    mmap_mode: Literal["r", "r+", "c"] | None = "r",
    copy: bool = False,
) -> dict[str, list[dict[str, object]]]:
    """
    Load disk-backed dense three-body feature blocks for debugging.

    The returned dictionary is grouped as ``cache[prefix]``. Each value is a list
    of block dictionaries with ``triplet_index``, ``coeff_start``, ``coeff_shape``,
    ``energy``, optional ``per_atom_energy``, optional atom-index arrays, and
    ``forces`` entries. Array entries that are absent or null in the manifest
    are returned as ``None``.

    Manifests (``*_manifest.json``) are read from ``directory`` itself and from
    its immediate subdirectories, in sorted path order. If several manifests
    declare the same prefix, the one whose path sorts last replaces the
    earlier ones.

    Args:
        directory: Directory containing dense feature block manifests and ``.npy``
            arrays.
        prefix: Optional cache prefix to filter on.
        mmap_mode: Mode passed to ``numpy.load``. Use ``None`` to load regular
            arrays directly.
        copy: Whether to copy loaded arrays into normal ``numpy.ndarray`` instances.

    Returns:
        Dense feature blocks grouped by cache prefix.

    Raises:
        ValueError: If ``directory`` is not an existing directory.
        KeyError: If a manifest lacks ``prefix``, ``blocks``, or per-block keys.
    """
    cache_dir = Path(directory)
    if not cache_dir.is_dir():
        raise ValueError(f"`directory` must be an existing directory: {cache_dir}")
    manifest_paths = list(cache_dir.glob("*_manifest.json"))
    for child in sorted(path for path in cache_dir.iterdir() if path.is_dir()):
        manifest_paths.extend(child.glob("*_manifest.json"))

    def read_array(manifest_dir: Path, relative: object) -> np.ndarray | None:
        # Missing and null manifest entries both mean "no array stored".
        if relative is None:
            return None
        array = np.load(manifest_dir / str(relative), mmap_mode=mmap_mode)
        return np.array(array) if copy else array

    array_fields = (*_DENSE_BLOCK_FIELDS, "per_atom_indices", "force_atom_indices")
    loaded: dict[str, list[dict[str, object]]] = {}
    for manifest_path in sorted(manifest_paths):
        manifest = json.loads(manifest_path.read_text(encoding="utf8"))
        cache_prefix = str(manifest["prefix"])
        if prefix is not None and cache_prefix != prefix:
            continue
        blocks: list[dict[str, object]] = []
        for block in manifest["blocks"]:
            fields = block["fields"]
            loaded_block: dict[str, object] = {
                "triplet_index": int(block["triplet_index"]),
                "coeff_start": tuple(int(value) for value in block["coeff_start"]),
                "coeff_shape": tuple(int(value) for value in block["coeff_shape"]),
            }
            for field in array_fields:
                loaded_block[field] = read_array(manifest_path.parent, fields.get(field))
            blocks.append(loaded_block)
        loaded[cache_prefix] = blocks
    return loaded
def _build_dense_feature_cache_from_buckets(
    buckets: Buckets,
    system_index: torch.Tensor,
    coeff_shape: tuple[int, int, int],
    *,
    spline: SplineKind = "cubic",
    active_triplet_mask: torch.Tensor | None = None,
    n_cat: int = 10,
    first_knot_xy: float = 0.0,
    first_knot_z: float = 0.0,
    knot_spacing_xy: float = 0.25,
    knot_spacing_z: float = 0.25,
    lower_support_xy: float = 0.0,
    lower_support_z: float = 0.0,
    eps: float = 1.0e-12,
    storage: Literal["cpu", "disk"] = "cpu",
    cache_dir: Path | None = None,
    cache_prefix: str = "threebody",
    metadata: dict[str, object] | None = None,
    overwrite: bool = False,
    include_per_atom_energy: bool = True,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> DenseThreeBodyFeatureCache | MemmapDenseThreeBodyFeatureCache:
    """Build coalesced dense feature blocks for fixed triplet geometry.

    Per-triplet features are built from ``buckets`` and then coalesced into
    dense blocks using ``system_index``.

    With ``storage="cpu"`` the in-memory dense cache is returned. With
    ``storage="disk"`` it is written under ``cache_dir``:

    * when ``metadata`` is None, directly into ``cache_dir``;
    * otherwise into a settings-named child directory, published via
      ``_write_published_memmap_dense_feature_cache`` (which honours
      ``overwrite``).

    Raises:
        ValueError: If ``storage`` is not ``"cpu"`` or ``"disk"``, or if
            ``storage="disk"`` and ``cache_dir`` is None.
    """
    # Validate storage options before the (potentially expensive) feature
    # construction so bad arguments fail immediately.
    if storage not in ("cpu", "disk"):
        raise ValueError("`storage` must be 'cpu' or 'disk'")
    if storage == "disk" and cache_dir is None:
        raise ValueError("`cache_dir` is required when `storage='disk'`")
    feature_cache = _build_feature_cache_from_buckets(
        buckets,
        coeff_shape,
        spline=spline,
        active_triplet_mask=active_triplet_mask,
        n_cat=n_cat,
        first_knot_xy=first_knot_xy,
        first_knot_z=first_knot_z,
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        lower_support_xy=lower_support_xy,
        lower_support_z=lower_support_z,
        eps=eps,
        runtime_config=runtime_config,
    )
    dense_cache = _build_dense_feature_cache_from_feature_cache(
        feature_cache,
        system_index,
        coeff_shape=coeff_shape,
        include_per_atom_energy=include_per_atom_energy,
        runtime_config=runtime_config,
    )
    if storage == "cpu":
        return dense_cache
    disk_dir = Path(cache_dir)
    if metadata is None:
        # No metadata: write directly into cache_dir; ``overwrite`` is not used here.
        return _write_memmap_dense_feature_cache(
            disk_dir,
            cache_prefix,
            dense_cache,
            metadata=None,
        )
    return _write_published_memmap_dense_feature_cache(
        disk_dir,
        cache_prefix,
        dense_cache,
        metadata=metadata,
        overwrite=overwrite,
    )
# Names exported by ``from ... import *``; everything else here is private.
__all__ = ["FeatureCacheMode", "load_memmap_threebody_feature_cache"]