# Source code for ufp.terms._threebody_cache

"""Disk-backed dense three-body feature-cache helpers."""

from __future__ import annotations

import hashlib
import json
from collections.abc import Iterable, Sequence
from pathlib import Path
from typing import Literal

import numpy as np
import torch

from ufp.core._disk_cache import (
    atomic_write_json,
    atomic_write_npy_memmap,
    publish_staged_cache,
    settings_cache_dir,
    write_cache_settings_summary,
)
from ufp.core.input import UFPInput
from ufp.terms._threebody_dense import (
    DenseThreeBodyFeatureCache,
    DenseTripletFeatureBlock,
    MemmapDenseThreeBodyFeatureCache,
    MemmapDenseTripletFeatureBlock,
    _build_dense_feature_cache_from_feature_cache,
    _MemmapTensor,
)
from ufp.terms._threebody_eval import SplineKind
from ufp.terms._threebody_features import _build_feature_cache_from_buckets
from ufp.terms._threebody_ops import Buckets
from ufp.terms._threebody_runtime import ThreeBodyRuntimeConfig


_DENSE_REQUIRED_BLOCK_FIELDS = ("energy", "forces")
_DENSE_OPTIONAL_BLOCK_FIELDS = ("per_atom_energy",)
_DENSE_BLOCK_FIELDS = (*_DENSE_REQUIRED_BLOCK_FIELDS, *_DENSE_OPTIONAL_BLOCK_FIELDS)
_DENSE_FEATURE_CACHE_SCHEMA_VERSION = 2
FeatureCacheMode = Literal["auto", "read", "refresh"]


def _write_memmap_tensor(path: Path, tensor: torch.Tensor) -> _MemmapTensor:
    """Persist ``tensor`` at ``path`` as a ``.npy`` file and describe it.

    The tensor is detached from autograd and copied to host memory before it
    is handed to ``atomic_write_npy_memmap``. The returned descriptor records
    only the destination path (as a string).
    """
    host_array = tensor.detach().cpu().numpy()
    atomic_write_npy_memmap(path, host_array)
    return _MemmapTensor(str(path))


def _write_memmap_dense_triplet_feature_block(
    directory: Path,
    prefix: str,
    block_index: int,
    block: DenseTripletFeatureBlock,
) -> MemmapDenseTripletFeatureBlock:
    """Write each tensor of one dense triplet block into ``directory``.

    Files are named ``<prefix>_block<i>_triplet<t>_<field>.npy``. Optional
    tensors that are ``None`` are not written and stay ``None`` in the
    returned memmap-backed block; triplet index and coefficient layout are
    copied over unchanged.
    """
    stem = f"{prefix}_block{block_index}_triplet{block.triplet_index}"

    def store(field: str, tensor: torch.Tensor) -> _MemmapTensor:
        return _write_memmap_tensor(directory / f"{stem}_{field}.npy", tensor)

    def store_optional(
        field: str, tensor: torch.Tensor | None
    ) -> _MemmapTensor | None:
        # Absent optional tensors are passed through without touching disk.
        if tensor is None:
            return None
        return store(field, tensor)

    # Keyword arguments evaluate left to right, so files are written in the
    # order: energy, per-atom energy, per-atom indices, forces, force indices.
    return MemmapDenseTripletFeatureBlock(
        triplet_index=block.triplet_index,
        coeff_start=block.coeff_start,
        coeff_shape=block.coeff_shape,
        energy=store("energy", block.energy),
        per_atom_energy=store_optional("per_atom_energy", block.per_atom_energy),
        per_atom_indices=store_optional("per_atom_indices", block.per_atom_indices),
        forces=store("forces", block.forces),
        force_atom_indices=store_optional(
            "force_atom_indices", block.force_atom_indices
        ),
    )


def _write_memmap_dense_feature_cache(
    directory: Path,
    prefix: str,
    cache: DenseThreeBodyFeatureCache,
    *,
    metadata: dict[str, object] | None = None,
) -> MemmapDenseThreeBodyFeatureCache:
    """Write dense triplet blocks plus a JSON manifest into ``directory``.

    Every block's arrays are written first. Then ``<prefix>_manifest.json``
    records the schema version, the prefix, per-block triplet index and
    coefficient layout, the file name of each stored array, and ``metadata``
    (``{}`` when omitted). A settings summary is written alongside via
    ``write_cache_settings_summary``.

    Returns:
        A memmap-backed cache over the freshly written arrays.
    """
    directory.mkdir(parents=True, exist_ok=True)
    blocks = tuple(
        _write_memmap_dense_triplet_feature_block(directory, prefix, index, block)
        for index, block in enumerate(cache.blocks)
    )

    def file_name(tensor: _MemmapTensor) -> str:
        return Path(tensor.path).name

    block_entries = []
    for block_index, stored in enumerate(blocks):
        # Field order matters: the manifest is serialized as JSON as-is.
        fields: dict[str, object] = {
            "energy": file_name(stored.energy),
            "forces": file_name(stored.forces),
        }
        if stored.per_atom_energy is not None:
            fields["per_atom_energy"] = file_name(stored.per_atom_energy)
            # Indices are only recorded next to per-atom energy, possibly as None.
            fields["per_atom_indices"] = (
                None
                if stored.per_atom_indices is None
                else file_name(stored.per_atom_indices)
            )
        if stored.force_atom_indices is not None:
            fields["force_atom_indices"] = file_name(stored.force_atom_indices)
        block_entries.append(
            {
                "block_index": block_index,
                "triplet_index": int(stored.triplet_index),
                "coeff_start": list(stored.coeff_start),
                "coeff_shape": list(stored.coeff_shape),
                "fields": fields,
            }
        )

    settings = {} if metadata is None else metadata
    manifest_name = f"{prefix}_manifest.json"
    manifest = {
        "schema_version": _DENSE_FEATURE_CACHE_SCHEMA_VERSION,
        "prefix": prefix,
        "blocks": block_entries,
        "metadata": settings,
    }
    atomic_write_json(directory / manifest_name, manifest)
    write_cache_settings_summary(
        directory,
        cache_kind="threebody_dense_feature_cache",
        prefix=prefix,
        settings=settings,
        extra={
            "cache_schema_version": _DENSE_FEATURE_CACHE_SCHEMA_VERSION,
            "manifest": manifest_name,
        },
    )
    return MemmapDenseThreeBodyFeatureCache(blocks=blocks)


def _hash_tensor(hasher, tensor: torch.Tensor, *, dtype: torch.dtype | None = None):
    """Add a tensor's dtype, shape, and values to a cache signature."""
    value = tensor.detach().cpu()
    if dtype is not None:
        value = value.to(dtype=dtype)
    array = np.ascontiguousarray(value.numpy())
    hasher.update(str(array.dtype).encode("utf8"))
    hasher.update(np.asarray(array.shape, dtype=np.int64).tobytes())
    hasher.update(array.tobytes())


def _input_feature_cache_signature(inputs: UFPInput) -> str:
    """Fingerprint the input geometry that dense feature caches depend on.

    The SHA-256 digest covers positions and cell (hashed as float64), the
    ``pbc`` flags, atomic numbers, system indices, and, when present, the
    neighbor list: pairs, shifts, optional distances/vectors (as float64),
    and the ``full_list`` flag. Absent optional pieces contribute fixed
    marker bytes instead of tensor data.

    Returns:
        The hexadecimal SHA-256 digest.
    """
    digest = hashlib.sha256()
    for attribute, cast in (
        ("positions", torch.float64),
        ("cell", torch.float64),
        ("pbc", None),
        ("atomic_numbers", None),
        ("system_index", None),
    ):
        _hash_tensor(digest, getattr(inputs, attribute), dtype=cast)

    neighbors = inputs.neighbor_list
    if neighbors is None:
        digest.update(b"neighbor_list:none")
        return digest.hexdigest()

    digest.update(b"neighbor_list")
    _hash_tensor(digest, neighbors.pairs)
    _hash_tensor(digest, neighbors.shifts)
    for attribute, missing_marker in (
        ("distances", b"distances:none"),
        ("vectors", b"vectors:none"),
    ):
        optional = getattr(neighbors, attribute)
        if optional is None:
            digest.update(missing_marker)
        else:
            _hash_tensor(digest, optional, dtype=torch.float64)
    digest.update(str(bool(neighbors.full_list)).encode("utf8"))
    return digest.hexdigest()


def _dense_feature_cache_metadata(
    inputs: UFPInput,
    *,
    cache_key: str,
    atomic_types: Sequence[int],
    triplet_categories: Sequence[Sequence[int]],
    coeff_shape: Sequence[int],
    active_triplet_indices: Sequence[int],
    include_per_atom_energy: bool,
    spline: str,
    first_knot_xy: float,
    first_knot_z: float,
    knot_spacing_xy: float,
    knot_spacing_z: float,
    lower_support_xy: float,
    lower_support_z: float,
    eps: float,
) -> dict[str, object]:
    """Describe a dense feature-cache request for later reuse checks.

    The mapping pins the schema version, cache identity, an input-geometry
    signature, system sizes and dtype, the species/triplet layout, active
    triplets, per-atom-energy inclusion, and every spline/numerical setting.
    Numbers are coerced to plain Python ``int``/``float``/``bool`` so the
    mapping can be stored as JSON and compared with stored manifest metadata.
    Key insertion order is kept stable.
    """

    def as_int_list(values: Iterable[int]) -> list[int]:
        return [int(value) for value in values]

    metadata: dict[str, object] = {
        "schema_version": _DENSE_FEATURE_CACHE_SCHEMA_VERSION,
        "cache_family": "threebody_dense_feature_cache",
        "cache_kind": "memmap_dense_feature_blocks",
        "cache_key": cache_key,
        "input_signature": _input_feature_cache_signature(inputs),
        "n_atoms": int(inputs.n_atoms),
        "n_systems": int(inputs.n_systems),
        "dtype": str(inputs.dtype),
        "atomic_types": as_int_list(atomic_types),
        "triplet_categories": [as_int_list(triplet) for triplet in triplet_categories],
        "coeff_shape": as_int_list(coeff_shape),
        "active_triplet_indices": as_int_list(active_triplet_indices),
        "include_per_atom_energy": bool(include_per_atom_energy),
        "row_semantics": {
            "energy": "per_system",
            "forces": "per_atom_cartesian",
            "per_atom_energy": bool(include_per_atom_energy),
        },
        "row_indexed_dense_blocks": True,
        "spline": spline,
    }
    # Scalar settings are appended in this exact order after "spline".
    scalar_settings = (
        ("first_knot_xy", first_knot_xy),
        ("first_knot_z", first_knot_z),
        ("knot_spacing_xy", knot_spacing_xy),
        ("knot_spacing_z", knot_spacing_z),
        ("lower_support_xy", lower_support_xy),
        ("lower_support_z", lower_support_z),
        ("eps", eps),
    )
    for key, value in scalar_settings:
        metadata[key] = float(value)
    return metadata


def _cache_metadata_matches(
    stored_metadata: object,
    expected_metadata: dict[str, object],
    *,
    allow_active_triplet_superset: bool = False,
) -> bool:
    """Report whether stored disk-cache metadata describes the expected input."""
    if stored_metadata == expected_metadata:
        return True
    if not allow_active_triplet_superset or not isinstance(stored_metadata, dict):
        return False

    stored_without_key = dict(stored_metadata)
    expected_without_key = dict(expected_metadata)
    stored_active = {
        int(value) for value in stored_without_key.pop("active_triplet_indices", ())
    }
    expected_active = {
        int(value) for value in expected_without_key.pop("active_triplet_indices", ())
    }
    stored_without_key.pop("cache_key", None)
    expected_without_key.pop("cache_key", None)
    return expected_active.issubset(stored_active) and (
        stored_without_key == expected_without_key
    )


def _dense_feature_cache_dir(
    parent: Path,
    prefix: str,
    metadata: dict[str, object],
) -> Path:
    """Resolve the settings-named child directory for one dense feature cache.

    Delegates to ``settings_cache_dir`` with ``parent``, ``prefix`` and the
    cache ``metadata`` passed through unchanged.
    """
    resolved = settings_cache_dir(parent, prefix, metadata)
    return resolved


def _write_published_memmap_dense_feature_cache(
    parent: Path,
    prefix: str,
    cache: DenseThreeBodyFeatureCache,
    *,
    metadata: dict[str, object],
    overwrite: bool = False,
) -> MemmapDenseThreeBodyFeatureCache:
    """Persist a dense feature cache through a private staging directory.

    The cache is written into a staging directory and published under the
    settings-named child of ``parent`` via ``publish_staged_cache``. The
    published directory is then reloaded and validated against ``metadata``.
    If this call did not publish (an existing directory was kept) and that
    directory fails validation, publishing is retried once with
    ``overwrite=True``. The retry is bounded, so a cache that can never be
    validated raises instead of recursing forever.

    Raises:
        ValueError: If a freshly published cache cannot be loaded back, or if
            a forced overwrite still leaves an unusable cache in place.
    """
    final_dir = _dense_feature_cache_dir(parent, prefix, metadata)

    def write_staging(staging_dir: Path) -> None:
        _write_memmap_dense_feature_cache(
            staging_dir,
            prefix,
            cache,
            metadata=metadata,
        )

    # A caller-requested overwrite is attempted once; otherwise try to reuse
    # an existing directory first and force a replacement only if it is stale.
    attempts = (True,) if overwrite else (False, True)
    published_dir = final_dir
    for force in attempts:
        published_dir, _, published = publish_staged_cache(
            final_dir.parent,
            final_dir.name,
            write_staging,
            overwrite=force,
        )
        loaded = _load_memmap_dense_feature_cache(
            published_dir,
            prefix,
            expected_metadata=metadata,
        )
        if loaded is not None:
            return loaded
        if published:
            raise ValueError(f"failed to publish dense feature cache: {published_dir}")
    raise ValueError(f"failed to replace stale dense feature cache: {published_dir}")


def _load_memmap_dense_feature_cache(
    directory: Path,
    prefix: str,
    *,
    expected_metadata: dict[str, object],
    allow_active_triplet_superset: bool = False,
    required_triplet_indices: Iterable[int] | None = None,
) -> MemmapDenseThreeBodyFeatureCache | None:
    """Load one validated dense feature cache manifest if all files are present.

    Every reason the cache cannot be used is reported as ``None`` (a cache
    miss) rather than an exception. These reasons are:

    * a missing, unreadable, non-JSON or non-object manifest;
    * a malformed schema version or block entry;
    * a schema-version or prefix mismatch;
    * metadata that does not match ``expected_metadata``;
    * a block that references an array file that does not exist.

    ``_write_published_memmap_dense_feature_cache`` treats ``None`` as the
    signal to replace a stale cache.

    Args:
        directory: Directory holding ``<prefix>_manifest.json`` and its arrays.
        prefix: Cache prefix; selects the manifest file and must equal the
            manifest's own ``prefix`` entry.
        expected_metadata: Metadata the stored manifest must describe.
        allow_active_triplet_superset: Accept stored metadata whose active
            triplets are a superset of the expected ones.
        required_triplet_indices: When given, only blocks for these triplet
            indices are loaded; other blocks are skipped.

    Returns:
        A memmap-backed cache over the selected blocks, or ``None``.
    """
    manifest_path = directory / f"{prefix}_manifest.json"
    if not manifest_path.is_file():
        return None
    try:
        with manifest_path.open("r", encoding="utf8") as handle:
            manifest = json.load(handle)
    except (OSError, json.JSONDecodeError):
        # A truncated or unreadable manifest is a stale cache, not a fatal error.
        return None
    if not isinstance(manifest, dict):
        return None
    try:
        schema_version = int(manifest.get("schema_version", -1))
    except (TypeError, ValueError):
        return None
    if schema_version != _DENSE_FEATURE_CACHE_SCHEMA_VERSION:
        return None
    if str(manifest.get("prefix")) != prefix:
        return None
    if not _cache_metadata_matches(
        manifest.get("metadata"),
        expected_metadata,
        allow_active_triplet_superset=allow_active_triplet_superset,
    ):
        return None

    required_triplets = (
        None
        if required_triplet_indices is None
        else {int(value) for value in required_triplet_indices}
    )
    optional_fields = ("per_atom_energy", "per_atom_indices", "force_atom_indices")

    # First pass: validate manifest entries and resolve array paths. A
    # malformed entry (missing key, wrong type) makes the whole cache unusable.
    parsed_blocks = []
    try:
        for entry in manifest.get("blocks", ()):
            triplet_index = int(entry["triplet_index"])
            if required_triplets is not None and triplet_index not in required_triplets:
                continue
            fields = entry.get("fields", {})
            paths: dict[str, Path | None] = {}
            for field in (*_DENSE_REQUIRED_BLOCK_FIELDS, *optional_fields):
                file_name = fields.get(field)
                if file_name is None:
                    if field in _DENSE_REQUIRED_BLOCK_FIELDS:
                        return None
                    paths[field] = None
                    continue
                file_path = directory / str(file_name)
                if not file_path.is_file():
                    return None
                paths[field] = file_path
            parsed_blocks.append(
                (
                    triplet_index,
                    tuple(int(value) for value in entry["coeff_start"]),
                    tuple(int(value) for value in entry["coeff_shape"]),
                    paths,
                )
            )
    except (AttributeError, KeyError, TypeError, ValueError):
        return None

    def memmap(path: Path | None) -> _MemmapTensor | None:
        return None if path is None else _MemmapTensor(str(path))

    # Second pass: build descriptors only once every referenced file exists.
    blocks = tuple(
        MemmapDenseTripletFeatureBlock(
            triplet_index=triplet_index,
            coeff_start=coeff_start,
            coeff_shape=coeff_shape,
            energy=memmap(paths["energy"]),
            per_atom_energy=memmap(paths["per_atom_energy"]),
            per_atom_indices=memmap(paths["per_atom_indices"]),
            forces=memmap(paths["forces"]),
            force_atom_indices=memmap(paths["force_atom_indices"]),
        )
        for triplet_index, coeff_start, coeff_shape, paths in parsed_blocks
    )
    return MemmapDenseThreeBodyFeatureCache(blocks=blocks)


def _iter_dense_feature_cache_manifests(directory: Path):
    """Yield dense feature-cache manifests in stable search order."""
    yield from sorted(directory.glob("*_manifest.json"))
    for child in sorted(path for path in directory.iterdir() if path.is_dir()):
        yield from sorted(child.glob("*_manifest.json"))


def _find_compatible_memmap_dense_feature_cache(
    directory: Path,
    *,
    expected_metadata: dict[str, object],
    required_triplet_indices: Iterable[int],
) -> MemmapDenseThreeBodyFeatureCache | None:
    """Find a V2 disk cache whose active triplets cover the expected request.

    Manifests are visited in the order given by
    ``_iter_dense_feature_cache_manifests``. Unreadable, non-JSON,
    non-object, or prefix-less manifests are skipped. The first candidate
    that loads with ``allow_active_triplet_superset=True`` is returned,
    restricted to ``required_triplet_indices``.

    Returns:
        The first compatible cache, or ``None`` when ``directory`` does not
        exist or no manifest matches.
    """
    if not directory.is_dir():
        return None
    # Materialize once: the indices are reused for every candidate manifest,
    # and a one-shot iterator would otherwise be exhausted by the first one,
    # letting later candidates "match" with zero blocks.
    required = tuple(int(value) for value in required_triplet_indices)
    for manifest_path in _iter_dense_feature_cache_manifests(directory):
        try:
            with manifest_path.open("r", encoding="utf8") as handle:
                manifest = json.load(handle)
        except (OSError, json.JSONDecodeError):
            continue
        if not isinstance(manifest, dict):
            continue
        prefix = str(manifest.get("prefix", ""))
        if not prefix:
            continue
        cached = _load_memmap_dense_feature_cache(
            manifest_path.parent,
            prefix,
            expected_metadata=expected_metadata,
            allow_active_triplet_superset=True,
            required_triplet_indices=required,
        )
        if cached is not None:
            return cached
    return None


def load_memmap_threebody_feature_cache(
    directory: Path | str,
    *,
    prefix: str | None = None,
    mmap_mode: Literal["r", "r+", "c"] | None = "r",
    copy: bool = False,
) -> dict[str, list[dict[str, object]]]:
    """
    Load disk-backed dense three-body feature blocks for debugging.

    The returned dictionary is grouped as ``cache[prefix]``. Each value is a
    list of block dictionaries with ``triplet_index``, ``coeff_start``,
    ``coeff_shape``, ``energy``, optional ``per_atom_energy``, optional
    atom-index arrays, and ``forces`` entries.

    Manifests are searched in ``directory`` and its immediate subdirectories.
    Settings-named cache directories reuse the same prefix. When several
    manifests share a prefix, their blocks are concatenated in sorted
    manifest-path order; later manifests no longer replace earlier ones.

    Args:
        directory: Directory containing dense feature block manifests and
            ``.npy`` arrays.
        prefix: Optional cache prefix to filter on.
        mmap_mode: Mode passed to ``numpy.load``. Use ``None`` to load regular
            arrays directly.
        copy: Whether to copy loaded arrays into normal ``numpy.ndarray``
            instances.

    Returns:
        Dense feature blocks grouped by cache prefix.

    Raises:
        ValueError: If ``directory`` is not an existing directory.
    """
    cache_dir = Path(directory)
    if not cache_dir.is_dir():
        raise ValueError(f"`directory` must be an existing directory: {cache_dir}")

    manifest_paths = list(cache_dir.glob("*_manifest.json"))
    for child in sorted(path for path in cache_dir.iterdir() if path.is_dir()):
        manifest_paths.extend(child.glob("*_manifest.json"))

    def load_array(manifest_dir: Path, file_name: object):
        # A missing or null field entry means the array was never written.
        if file_name is None:
            return None
        array = np.load(manifest_dir / str(file_name), mmap_mode=mmap_mode)
        return np.array(array) if copy else array

    loaded: dict[str, list[dict[str, object]]] = {}
    for manifest_path in sorted(manifest_paths):
        manifest_dir = manifest_path.parent
        manifest = json.loads(manifest_path.read_text(encoding="utf8"))
        cache_prefix = str(manifest["prefix"])
        if prefix is not None and cache_prefix != prefix:
            continue
        # Accumulate rather than assign so same-prefix manifests in different
        # settings directories do not silently overwrite each other.
        blocks = loaded.setdefault(cache_prefix, [])
        for block in manifest["blocks"]:
            fields = block["fields"]
            loaded_block: dict[str, object] = {
                "triplet_index": int(block["triplet_index"]),
                "coeff_start": tuple(int(value) for value in block["coeff_start"]),
                "coeff_shape": tuple(int(value) for value in block["coeff_shape"]),
            }
            for field in (*_DENSE_BLOCK_FIELDS, "per_atom_indices", "force_atom_indices"):
                loaded_block[field] = load_array(manifest_dir, fields.get(field))
            blocks.append(loaded_block)
    return loaded
def _build_dense_feature_cache_from_buckets(
    buckets: Buckets,
    system_index: torch.Tensor,
    coeff_shape: tuple[int, int, int],
    *,
    spline: SplineKind = "cubic",
    active_triplet_mask: torch.Tensor | None = None,
    n_cat: int = 10,
    first_knot_xy: float = 0.0,
    first_knot_z: float = 0.0,
    knot_spacing_xy: float = 0.25,
    knot_spacing_z: float = 0.25,
    lower_support_xy: float = 0.0,
    lower_support_z: float = 0.0,
    eps: float = 1.0e-12,
    storage: Literal["cpu", "disk"] = "cpu",
    cache_dir: Path | None = None,
    cache_prefix: str = "threebody",
    metadata: dict[str, object] | None = None,
    overwrite: bool = False,
    include_per_atom_energy: bool = True,
    runtime_config: ThreeBodyRuntimeConfig | None = None,
) -> DenseThreeBodyFeatureCache | MemmapDenseThreeBodyFeatureCache:
    """Build coalesced dense feature blocks for fixed triplet geometry.

    Feature blocks are first built from ``buckets`` with the given spline
    settings. They are then coalesced into dense blocks keyed by
    ``system_index``.

    * ``storage="cpu"`` returns the in-memory dense cache.
    * ``storage="disk"`` writes it as ``.npy`` arrays under ``cache_dir``:
      directly when ``metadata`` is ``None``, otherwise through a staged,
      settings-named directory that honours ``overwrite``.

    Raises:
        ValueError: If ``storage`` is not ``"cpu"`` or ``"disk"``, or if
            ``storage="disk"`` is requested without ``cache_dir``. Both are
            checked before any feature construction so bad arguments fail
            fast.
    """
    # Validate storage options up front: the feature construction below can
    # be expensive, and these errors do not depend on its result.
    if storage not in ("cpu", "disk"):
        raise ValueError("`storage` must be 'cpu' or 'disk'")
    if storage == "disk" and cache_dir is None:
        raise ValueError("`cache_dir` is required when `storage='disk'`")

    feature_cache = _build_feature_cache_from_buckets(
        buckets,
        coeff_shape,
        spline=spline,
        active_triplet_mask=active_triplet_mask,
        n_cat=n_cat,
        first_knot_xy=first_knot_xy,
        first_knot_z=first_knot_z,
        knot_spacing_xy=knot_spacing_xy,
        knot_spacing_z=knot_spacing_z,
        lower_support_xy=lower_support_xy,
        lower_support_z=lower_support_z,
        eps=eps,
        runtime_config=runtime_config,
    )
    dense_cache = _build_dense_feature_cache_from_feature_cache(
        feature_cache,
        system_index,
        coeff_shape=coeff_shape,
        include_per_atom_energy=include_per_atom_energy,
        runtime_config=runtime_config,
    )
    if storage == "cpu":
        return dense_cache
    if metadata is None:
        # The published path names its directory from metadata; without it,
        # write straight into ``cache_dir`` (``overwrite`` does not apply).
        return _write_memmap_dense_feature_cache(
            Path(cache_dir),
            cache_prefix,
            dense_cache,
            metadata=metadata,
        )
    return _write_published_memmap_dense_feature_cache(
        Path(cache_dir),
        cache_prefix,
        dense_cache,
        metadata=metadata,
        overwrite=overwrite,
    )


__all__ = [
    "FeatureCacheMode",
    "load_memmap_threebody_feature_cache",
]