Source code for ufp.terms.state

"""Charge and collinear-spin state terms for UFP models."""

from __future__ import annotations

import math
from collections.abc import Sequence

import torch

from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.splines.representation import spline_support_mask_1d, uniform_stencil_1d
from ufp.terms._base import (
    LinearAssemblyOptions,
    OneBodyTerm,
    PairTerm,
    TermInputRequirements,
)
from ufp.terms._parameters import (
    ParameterBlock,
    ParameterBlockCacheChannel,
    ParameterBlockCacheDescriptor,
    copy_parameter_data,
)
from ufp.terms._shared import empty_atomwise_output, pair_weight
from ufp.terms.categories import active_pair_mask as _active_pair_mask
from ufp.terms.categories import pair_categories as _pair_categories
from ufp.terms.cutoffs import CutoffEnvelope, normalize_cutoff_envelope
from ufp.terms.twobody import SplineTwoBodyTerm


# Coulomb constant e**2 / (4 * pi * epsilon_0) in eV * Angstrom: with charges in
# units of the elementary charge and distances in Angstrom, ``k * q_i * q_j / r``
# is an energy in eV.
COULOMB_CONSTANT_EV_ANGSTROM: float = 14.3996454784255


def _normalized_atomic_types(atomic_types: Sequence[int]) -> tuple[int, ...]:
    """Return sorted unique atomic numbers and reject empty specifications."""
    normalized = tuple(sorted(set(int(value) for value in atomic_types)))
    if not normalized:
        raise ValueError("`atomic_types` must contain at least one element")
    return normalized


def _element_parameter(
    value,
    *,
    name: str,
    shape: tuple[int, ...],
    dtype: torch.dtype | None,
) -> torch.Tensor:
    """Normalize a one-dimensional per-element initializer."""
    if value is None:
        return torch.zeros(shape, dtype=dtype)
    tensor = torch.as_tensor(value, dtype=dtype)
    if tensor.ndim == 0 and shape == (1,):
        tensor = tensor.reshape(1)
    if tuple(int(dim) for dim in tensor.shape) != shape:
        raise ValueError(f"`{name}` must have shape {shape}")
    return tensor.detach().clone()


def _empty_block_matrix(
    targets,
    block,
    *,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Create an unweighted least-squares block matrix."""
    return torch.zeros(
        (targets.n_rows, block.size),
        dtype=dtype,
        device=device,
    )


def _add_entries(
    matrix: torch.Tensor,
    rows: torch.Tensor,
    cols: torch.Tensor,
    values: torch.Tensor,
) -> None:
    """Accumulate broadcastable row/column/value entries into a dense matrix."""
    if rows.numel() == 0 or cols.numel() == 0 or values.numel() == 0:
        return

    rows, cols, values = torch.broadcast_tensors(rows, cols, values)
    valid = rows >= 0
    if not torch.any(valid):
        return

    width = int(matrix.shape[1])
    flat_rows = rows[valid].reshape(-1)
    flat_cols = cols[valid].reshape(-1)
    flat_values = values[valid].reshape(-1)
    matrix.reshape(-1).index_add_(0, flat_rows * width + flat_cols, flat_values)


def _inactive_aware_pair_mask(
    inputs: UFPInput,
    *,
    atomic_types: Sequence[int],
    symmetric: bool,
    active_pair_mask: torch.Tensor,
    active_pair_indices: tuple[int, ...],
    n_pair_categories: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return pair categories and a mask for configured active pair channels."""
    pair_category = inputs.pair_category_indices(
        atomic_types,
        symmetric=symmetric,
    )
    handled_mask = pair_category >= 0
    if len(active_pair_indices) != n_pair_categories:
        active_mask = active_pair_mask.to(device=inputs.device)
        active_handled = torch.zeros_like(handled_mask)
        active_handled[handled_mask] = active_mask.index_select(
            0,
            pair_category[handled_mask],
        )
        handled_mask = active_handled
    return pair_category, handled_mask


class ChargeSelfEnergyTerm(OneBodyTerm):
    """Per-element local charge electronegativity and hardness energy.

    Each atom whose element is configured contributes
    ``chi_Z * q + 0.5 * eta_Z * q**2``, with ``q`` read from
    ``inputs.atomic_charges``. The ``charge_potential`` feature is the
    derivative ``chi_Z + eta_Z * q``. Atoms of other elements contribute
    nothing.
    """

    def __init__(
        self,
        *,
        atomic_types: Sequence[int],
        electronegativities=None,
        hardnesses=None,
        trainable: bool = True,
        fittable: bool = True,
        frozen: bool = False,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store one electronegativity and hardness coefficient per element."""
        element_types = _normalized_atomic_types(atomic_types)
        super().__init__(cutoff=None, atomic_types=element_types)
        self.fittable = bool(fittable)
        self.frozen = bool(frozen)
        per_element_shape = (len(element_types),)
        wants_grad = bool(trainable) and not self.frozen

        def make(initial, label: str) -> torch.nn.Parameter:
            return torch.nn.Parameter(
                _element_parameter(
                    initial,
                    name=label,
                    shape=per_element_shape,
                    dtype=dtype,
                ),
                requires_grad=wants_grad,
            )

        # Registration order (electronegativities, then hardnesses) is kept.
        self.electronegativities = make(electronegativities, "electronegativities")
        self.hardnesses = make(hardnesses, "hardnesses")

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Require fixed local charge state."""
        return TermInputRequirements(state_fields=("atomic_charges",))

    @property
    def provides_forces(self) -> bool:
        """Report that this term contributes explicit zero forces."""
        return True

    @property
    def optimizer_group(self) -> str | None:
        """Group trainable state-term parameters for workflow optimizers."""
        return "charge_spin"

    def _parameter_block(
        self,
        *,
        name: str,
        kind: str,
        parameter: torch.nn.Parameter,
    ) -> ParameterBlock:
        """Describe one per-element coefficient vector as a linear parameter block."""
        assert self.atomic_types is not None
        # One cache channel per element: coefficient slot ``slot`` belongs to ``z``.
        element_channels = tuple(
            ParameterBlockCacheChannel(
                kind="Z",
                values=(z,),
                start=slot,
                stop=slot + 1,
            )
            for slot, z in enumerate(self.atomic_types)
        )
        descriptor = ParameterBlockCacheDescriptor(
            family={"kind": kind},
            channels=element_channels,
            reusable=False,
        )
        return ParameterBlock(
            name=name,
            kind=kind,
            shape=tuple(int(extent) for extent in parameter.shape),
            read=lambda: parameter,
            write=lambda values: copy_parameter_data(parameter, values),
            label=f"{kind}[{self.atomic_types}]",
            regularization_group="charge_spin",
            fittable=self.fittable,
            frozen=self.frozen,
            assembler=self._assemble_block,
            cache_descriptor=descriptor,
        )

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return linear charge self-energy coefficient blocks."""
        specs = (
            ("electronegativities", "charge_self_chi", self.electronegativities),
            ("hardnesses", "charge_self_eta", self.hardnesses),
        )
        return tuple(
            self._parameter_block(name=block_name, kind=block_kind, parameter=param)
            for block_name, block_kind, param in specs
        )

    def _assemble_block(self, block, inputs: UFPInput, targets) -> torch.Tensor | None:
        """Assemble one charge self-energy block for fixed charges.

        Returns ``None`` when no atom is covered, the block name is unknown,
        or every assembled entry is zero.
        """
        self.validate_inputs(inputs)
        assert self.atomic_types is not None
        assert inputs.atomic_charges is not None
        slots = inputs.atomic_category_indices(self.atomic_types)
        covered = slots >= 0
        if not torch.any(covered):
            return None
        q = inputs.atomic_charges.to(device=inputs.device, dtype=inputs.dtype)
        # d(energy)/d(coefficient) for each supported block.
        factor_of = {
            "electronegativities": lambda charge: charge,
            "hardnesses": lambda charge: 0.5 * charge.square(),
        }.get(block.name)
        if factor_of is None:
            return None
        design = factor_of(q)
        matrix = _empty_block_matrix(
            targets,
            block,
            device=inputs.device,
            dtype=inputs.dtype,
        )
        # Each covered atom feeds both its own per-atom row and its system's energy row.
        system_rows = targets.energy_rows.index_select(0, inputs.system_index)
        for atom_rows in (targets.per_atom_rows, system_rows):
            keep = covered & (atom_rows >= 0)
            _add_entries(matrix, atom_rows[keep], slots[keep], design[keep])
        return matrix if torch.count_nonzero(matrix) else None

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Assemble all requested charge self-energy blocks."""
        requested = () if options is None else options.blocks
        assembled = {}
        for block in requested:
            matrix = self._assemble_block(block, batch.inputs, targets)
            if matrix is not None:
                assembled[block.index] = matrix
        return assembled

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate local charge self energy and charge potential."""
        self.validate_inputs(inputs)
        assert self.atomic_types is not None
        assert inputs.atomic_charges is not None
        device, dtype = inputs.device, inputs.dtype
        q = inputs.atomic_charges.to(device=device, dtype=dtype)
        chi = self.electronegativities.to(device=device, dtype=dtype)
        eta = self.hardnesses.to(device=device, dtype=dtype)
        slots = inputs.atomic_category_indices(self.atomic_types)
        covered = slots >= 0
        per_atom_energy = torch.zeros(inputs.n_atoms, dtype=dtype, device=device)
        charge_potential = torch.zeros_like(per_atom_energy)
        if torch.any(covered):
            element = slots[covered]
            chi_i = chi[element]
            eta_i = eta[element]
            q_i = q[covered]
            per_atom_energy[covered] = chi_i * q_i + (0.5 * eta_i * q_i.square())
            charge_potential[covered] = chi_i + eta_i * q_i
        energy = torch.zeros(inputs.n_systems, dtype=dtype, device=device)
        energy.index_add_(0, inputs.system_index, per_atom_energy)
        zero_forces = torch.zeros((inputs.n_atoms, 3), dtype=dtype, device=device)
        return UFPOutput(
            energy=energy,
            forces=zero_forces,
            per_atom_energy=per_atom_energy,
            features={"charge_potential": charge_potential},
        )
class CollinearSpinLandauTerm(OneBodyTerm):
    """Per-element onsite Landau energy for fixed scalar spin moments.

    Each atom whose element is configured contributes
    ``a_Z * m**2 + b_Z * m**4``, with ``m`` read from
    ``inputs.atomic_spin_moments``. The ``spin_effective_field`` feature is
    ``-(2 * a_Z * m + 4 * b_Z * m**3)``, i.e. ``-dE/dm``.
    """

    def __init__(
        self,
        *,
        atomic_types: Sequence[int],
        quadratic=None,
        quartic=None,
        trainable: bool = True,
        fittable: bool = True,
        frozen: bool = False,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store quadratic and quartic coefficients per element."""
        element_types = _normalized_atomic_types(atomic_types)
        super().__init__(cutoff=None, atomic_types=element_types)
        self.fittable = bool(fittable)
        self.frozen = bool(frozen)
        per_element_shape = (len(element_types),)
        wants_grad = bool(trainable) and not self.frozen

        def make(initial, label: str) -> torch.nn.Parameter:
            return torch.nn.Parameter(
                _element_parameter(
                    initial,
                    name=label,
                    shape=per_element_shape,
                    dtype=dtype,
                ),
                requires_grad=wants_grad,
            )

        # Registration order (quadratic, then quartic) is kept.
        self.quadratic = make(quadratic, "quadratic")
        self.quartic = make(quartic, "quartic")

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Require fixed local collinear spin moments."""
        return TermInputRequirements(state_fields=("atomic_spin_moments",))

    @property
    def provides_forces(self) -> bool:
        """Report that this term contributes explicit zero forces."""
        return True

    @property
    def optimizer_group(self) -> str | None:
        """Group trainable state-term parameters for workflow optimizers."""
        return "charge_spin"

    def _parameter_block(
        self,
        *,
        name: str,
        kind: str,
        parameter: torch.nn.Parameter,
    ) -> ParameterBlock:
        """Describe one per-element coefficient vector as a linear parameter block."""
        assert self.atomic_types is not None
        # One cache channel per element: coefficient slot ``slot`` belongs to ``z``.
        element_channels = tuple(
            ParameterBlockCacheChannel(
                kind="Z",
                values=(z,),
                start=slot,
                stop=slot + 1,
            )
            for slot, z in enumerate(self.atomic_types)
        )
        descriptor = ParameterBlockCacheDescriptor(
            family={"kind": kind},
            channels=element_channels,
            reusable=False,
        )
        return ParameterBlock(
            name=name,
            kind=kind,
            shape=tuple(int(extent) for extent in parameter.shape),
            read=lambda: parameter,
            write=lambda values: copy_parameter_data(parameter, values),
            label=f"{kind}[{self.atomic_types}]",
            regularization_group="charge_spin",
            fittable=self.fittable,
            frozen=self.frozen,
            assembler=self._assemble_block,
            cache_descriptor=descriptor,
        )

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return linear spin Landau coefficient blocks."""
        specs = (
            ("quadratic", "spin_landau_quadratic", self.quadratic),
            ("quartic", "spin_landau_quartic", self.quartic),
        )
        return tuple(
            self._parameter_block(name=block_name, kind=block_kind, parameter=param)
            for block_name, block_kind, param in specs
        )

    def _assemble_block(self, block, inputs: UFPInput, targets) -> torch.Tensor | None:
        """Assemble one Landau block for fixed spin moments.

        Returns ``None`` when no atom is covered, the block name is unknown,
        or every assembled entry is zero.
        """
        self.validate_inputs(inputs)
        assert self.atomic_types is not None
        assert inputs.atomic_spin_moments is not None
        slots = inputs.atomic_category_indices(self.atomic_types)
        covered = slots >= 0
        if not torch.any(covered):
            return None
        m = inputs.atomic_spin_moments.to(
            device=inputs.device,
            dtype=inputs.dtype,
        )
        # d(energy)/d(coefficient) for each supported block.
        factor_of = {
            "quadratic": lambda moment: moment.square(),
            "quartic": lambda moment: moment.pow(4),
        }.get(block.name)
        if factor_of is None:
            return None
        design = factor_of(m)
        matrix = _empty_block_matrix(
            targets,
            block,
            device=inputs.device,
            dtype=inputs.dtype,
        )
        # Each covered atom feeds both its own per-atom row and its system's energy row.
        system_rows = targets.energy_rows.index_select(0, inputs.system_index)
        for atom_rows in (targets.per_atom_rows, system_rows):
            keep = covered & (atom_rows >= 0)
            _add_entries(matrix, atom_rows[keep], slots[keep], design[keep])
        return matrix if torch.count_nonzero(matrix) else None

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Assemble all requested Landau blocks."""
        requested = () if options is None else options.blocks
        assembled = {}
        for block in requested:
            matrix = self._assemble_block(block, batch.inputs, targets)
            if matrix is not None:
                assembled[block.index] = matrix
        return assembled

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate onsite spin energy and effective field."""
        self.validate_inputs(inputs)
        assert self.atomic_types is not None
        assert inputs.atomic_spin_moments is not None
        device, dtype = inputs.device, inputs.dtype
        m = inputs.atomic_spin_moments.to(
            device=device,
            dtype=dtype,
        )
        quad = self.quadratic.to(device=device, dtype=dtype)
        quart = self.quartic.to(device=device, dtype=dtype)
        slots = inputs.atomic_category_indices(self.atomic_types)
        covered = slots >= 0
        per_atom_energy = torch.zeros(inputs.n_atoms, dtype=dtype, device=device)
        spin_effective_field = torch.zeros_like(per_atom_energy)
        if torch.any(covered):
            element = slots[covered]
            a_i = quad[element]
            b_i = quart[element]
            m_i = m[covered]
            per_atom_energy[covered] = a_i * m_i.square() + b_i * m_i.pow(4)
            spin_effective_field[covered] = -(
                2.0 * a_i * m_i + 4.0 * b_i * m_i.pow(3)
            )
        energy = torch.zeros(inputs.n_systems, dtype=dtype, device=device)
        energy.index_add_(0, inputs.system_index, per_atom_energy)
        zero_forces = torch.zeros((inputs.n_atoms, 3), dtype=dtype, device=device)
        return UFPOutput(
            energy=energy,
            forces=zero_forces,
            per_atom_energy=per_atom_energy,
            features={"spin_effective_field": spin_effective_field},
        )
class _StateScaledSplinePairTerm(SplineTwoBodyTerm):
    """Common implementation for state-scaled spline pair terms.

    For each handled, spline-supported pair ``(i, j)`` the energy is
    ``w * s_i * s_j * V_c(r_ij)``, where:

    - ``s`` is the atomwise state field named by ``_state_field``;
    - ``V_c`` is the cubic/uniform spline of the pair's category ``c``
      (coefficients ``true_coeffs_by_pair[c]``);
    - ``w = pair_weight(inputs)``.

    The feature named ``_feature_name`` accumulates
    ``_feature_derivative_sign * w * s_j * V_c(r_ij)`` on atom ``i`` (and
    symmetrically on ``j``), i.e. the sign times ``dE/ds``.

    Subclasses configure the behaviour through the five class attributes below.
    """

    # Name of the UFPInput attribute holding the per-atom state values.
    _state_field: str
    # Key under which the per-atom state derivative is stored in ``features``.
    _feature_name: str
    # Multiplier applied to dE/d(state) before storing it as the feature.
    _feature_derivative_sign: float
    # ParameterBlock ``kind`` and cache-family tag.
    _block_kind: str
    # Prefix of the human-readable ParameterBlock label.
    _label_prefix: str

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Require a neighbor list and the configured atomwise state field."""
        return TermInputRequirements(
            neighbor_list=True,
            state_fields=(self._state_field,),
        )

    @property
    def optimizer_group(self) -> str | None:
        """Group trainable state-term parameters for workflow optimizers."""
        return "charge_spin"

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return the state-scaled pair spline coefficient block.

        A single block covers ``true_coeffs_by_pair``. Its cache descriptor
        exposes one channel per active pair category ``p``, spanning flattened
        columns ``[p * n_coeff, (p + 1) * n_coeff)``.
        """
        return (
            ParameterBlock(
                name="coeffs_by_pair",
                kind=self._block_kind,
                shape=tuple(int(dim) for dim in self.true_coeffs_by_pair.shape),
                read=lambda: self.true_coeffs_by_pair,
                write=self._write_true_coeffs_by_pair,
                label=f"{self._label_prefix}[{self.atomic_types}]",
                coefficient_provider=self.coefficient_provider,
                coefficient_index=self.coefficient_index,
                regularization_group="charge_spin",
                fittable=self.fittable,
                frozen=self.frozen,
                assembler=self._assemble_block,
                cache_descriptor=ParameterBlockCacheDescriptor(
                    family={
                        "kind": self._block_kind,
                        "symmetric": bool(self.symmetric),
                        "spline": str(self.spline),
                        "first_knot": float(self.first_knot),
                        "knot_spacing": float(self.knot_spacing),
                        "coeff_size": int(self.true_coeffs_by_pair.shape[1]),
                        "eps": float(self.eps),
                    },
                    channels=tuple(
                        ParameterBlockCacheChannel(
                            kind="pair",
                            values=self.pair_categories[pair_index],
                            start=int(pair_index) * int(self.true_coeffs_by_pair.shape[1]),
                            stop=(int(pair_index) + 1) * int(self.true_coeffs_by_pair.shape[1]),
                        )
                        for pair_index in self._active_pair_indices
                    ),
                    reusable=False,
                ),
            ),
        )

    def _state_values(self, inputs: UFPInput) -> torch.Tensor:
        """Return the configured state field cast to the input device and dtype."""
        values = getattr(inputs, self._state_field)
        assert values is not None
        return values.to(device=inputs.device, dtype=inputs.dtype)

    def _check_pair_inputs(self, inputs: UFPInput) -> None:
        """Validate inputs; asymmetric pair channels need a full neighbor list.

        Raises:
            RuntimeError: If the term is asymmetric and the neighbor list is
                not a full list.
        """
        self.validate_inputs(inputs)
        if not self.symmetric and not inputs.neighbor_list.full_list:
            raise RuntimeError(
                "asymmetric state-scaled spline pair terms require a full neighbor list"
            )

    def _assemble_block(self, block, inputs: UFPInput, targets) -> torch.Tensor | None:
        """Assemble one state-scaled spline pair block.

        Columns are laid out as ``category * n_knots + knot``, matching the
        row-major flattening of ``true_coeffs_by_pair``. Each supported pair
        writes to three kinds of rows:

        - its system's energy row (full value);
        - each endpoint's per-atom row (half value each);
        - each endpoint's force rows.

        Returns ``None`` when no pair contributes or all entries are zero.
        """
        self._check_pair_inputs(inputs)
        assert self.atomic_types is not None
        # block.shape mirrors true_coeffs_by_pair.shape: (n_pair_categories, n_knots).
        _, n_knots = block.shape
        pair_category, handled_mask = _inactive_aware_pair_mask(
            inputs,
            atomic_types=self.atomic_types,
            symmetric=self.symmetric,
            active_pair_mask=self.active_pair_mask,
            active_pair_indices=self._active_pair_indices,
            n_pair_categories=len(self.pair_categories),
        )
        if not torch.any(handled_mask):
            return None
        pair_distances = inputs.pair_distances(handled_mask)
        # Drop handled pairs whose distance falls outside the spline's support.
        support_mask = spline_support_mask_1d(
            pair_distances,
            coeff_size=int(n_knots),
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        if not torch.any(support_mask):
            return None
        # Everything below is indexed over supported pairs only.
        pair_distances = pair_distances[support_mask]
        handled_pair_category = pair_category[handled_mask][support_mask]
        pair_vectors = inputs.pair_vectors(handled_mask)[support_mask]
        first_atom, second_atom = inputs.pair_indices(handled_mask)
        pair_system_index = inputs.pair_system_index(handled_mask)
        first_atom = first_atom[support_mask]
        second_atom = second_atom[support_mask]
        pair_system_index = pair_system_index[support_mask]
        state = self._state_values(inputs)
        state_product = state.index_select(0, first_atom) * state.index_select(
            0,
            second_atom,
        )
        # Per-pair stencil: the nonzero basis functions, their values and radial derivatives.
        stencil = uniform_stencil_1d(
            pair_distances,
            coeff_size=int(n_knots),
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        # Offset local knot indices into the pair category's column range.
        cols = stencil.indices + handled_pair_category[:, None] * int(n_knots)
        # NOTE(review): pair_weight presumably compensates double counting on
        # full neighbor lists; confirm in ufp.terms._shared.
        scale = pair_weight(inputs)
        values = scale * state_product[:, None] * stencil.values
        grads = scale * state_product[:, None] * stencil.grads
        matrix = _empty_block_matrix(
            targets,
            block,
            device=inputs.device,
            dtype=inputs.dtype,
        )
        _add_entries(
            matrix,
            targets.energy_rows.index_select(0, pair_system_index)[:, None],
            cols,
            values,
        )
        # Pair energy is split evenly between the two atoms for per-atom targets.
        half_values = 0.5 * values
        _add_entries(
            matrix,
            targets.per_atom_rows.index_select(0, first_atom)[:, None],
            cols,
            half_values,
        )
        _add_entries(
            matrix,
            targets.per_atom_rows.index_select(0, second_atom)[:, None],
            cols,
            half_values,
        )
        # Guard against division by ~0 for coincident atoms.
        inv_r = torch.where(
            pair_distances > self.eps,
            pair_distances.reciprocal(),
            torch.zeros_like(pair_distances),
        )
        direction = pair_vectors * inv_r[:, None]
        # Assumes pair_vectors points from first_atom to second_atom, as in
        # forward() — TODO confirm against UFPInput. Shapes: (pairs, knots, xyz).
        force_second = -(grads[:, :, None] * direction[:, None, :])
        force_first = -force_second
        # Assumes targets.force_rows is (n_atoms, 3): rows broadcast as
        # (pairs, 3, 1) against cols (pairs, 1, knots) and values (pairs, 3, knots).
        _add_entries(
            matrix,
            targets.force_rows.index_select(0, first_atom)[:, :, None],
            cols[:, None, :],
            force_first.permute(0, 2, 1),
        )
        _add_entries(
            matrix,
            targets.force_rows.index_select(0, second_atom)[:, :, None],
            cols[:, None, :],
            force_second.permute(0, 2, 1),
        )
        return None if torch.count_nonzero(matrix) == 0 else matrix

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Assemble requested state-scaled pair blocks.

        Returns a dict mapping ``block.index`` to its design matrix, omitting
        blocks that produced no entries.
        """
        blocks = () if options is None else options.blocks
        return {
            block.index: matrix
            for block in blocks
            if (matrix := self._assemble_block(block, batch.inputs, targets)) is not None
        }

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate state-scaled spline pair energy, forces, and state features."""
        self._check_pair_inputs(inputs)
        assert self.atomic_types is not None
        output = empty_atomwise_output(inputs, forces=True)
        # The feature tensor is registered up front so early returns still expose it.
        feature = torch.zeros(
            inputs.n_atoms,
            dtype=inputs.dtype,
            device=inputs.device,
        )
        output.features[self._feature_name] = feature
        if not self._active_pair_indices:
            return output
        pair_category, handled_mask = _inactive_aware_pair_mask(
            inputs,
            atomic_types=self.atomic_types,
            symmetric=self.symmetric,
            active_pair_mask=self.active_pair_mask,
            active_pair_indices=self._active_pair_indices,
            n_pair_categories=len(self.pair_categories),
        )
        if not torch.any(handled_mask):
            return output
        coeffs_by_pair = self.true_coeffs_by_pair.to(
            device=inputs.device,
            dtype=inputs.dtype,
        )
        pair_distances = inputs.pair_distances(handled_mask)
        support_mask = spline_support_mask_1d(
            pair_distances,
            coeff_size=int(coeffs_by_pair.shape[1]),
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        if not torch.any(support_mask):
            return output
        pair_distances = pair_distances[support_mask]
        handled_pair_category = pair_category[handled_mask][support_mask]
        stencil = uniform_stencil_1d(
            pair_distances,
            coeff_size=int(coeffs_by_pair.shape[1]),
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        # Gather the coefficients touched by each pair's stencil: (pairs, stencil width).
        coeff_window = coeffs_by_pair[handled_pair_category[:, None], stencil.indices]
        pair_value = (stencil.values * coeff_window).sum(dim=1)
        pair_grad = (stencil.grads * coeff_window).sum(dim=1)
        first_atom, second_atom = inputs.pair_indices(handled_mask)
        pair_system_index = inputs.pair_system_index(handled_mask)
        pair_vectors = inputs.pair_vectors(handled_mask)
        first_atom = first_atom[support_mask]
        second_atom = second_atom[support_mask]
        pair_system_index = pair_system_index[support_mask]
        pair_vectors = pair_vectors[support_mask]
        state = self._state_values(inputs)
        first_state = state.index_select(0, first_atom)
        second_state = state.index_select(0, second_atom)
        state_product = first_state * second_state
        scale = pair_weight(inputs)
        weighted_pair_energy = scale * state_product * pair_value
        weighted_pair_grad = scale * state_product * pair_grad
        assert output.energy is not None
        assert output.forces is not None
        assert output.per_atom_energy is not None
        output.energy.index_add_(0, pair_system_index, weighted_pair_energy)
        # Split each pair's energy evenly between its two atoms.
        per_atom_contribution = 0.5 * weighted_pair_energy
        output.per_atom_energy.index_add_(0, first_atom, per_atom_contribution)
        output.per_atom_energy.index_add_(0, second_atom, per_atom_contribution)
        inv_r = torch.where(
            pair_distances > self.eps,
            pair_distances.reciprocal(),
            torch.zeros_like(pair_distances),
        )
        # -dE/dr along the unit pair vector; equal and opposite on the two atoms.
        force_on_second = -weighted_pair_grad[:, None] * pair_vectors * inv_r[:, None]
        output.forces.index_add_(0, first_atom, -force_on_second)
        output.forces.index_add_(0, second_atom, force_on_second)
        # d(w * s_i * s_j * V)/ds_i = w * s_j * V (and symmetrically for j), times the sign.
        sign = float(self._feature_derivative_sign)
        feature.index_add_(0, first_atom, sign * scale * second_state * pair_value)
        feature.index_add_(0, second_atom, sign * scale * first_state * pair_value)
        return output
class ChargeScaledSplinePairTerm(_StateScaledSplinePairTerm):
    """Short-range spline pair correction scaled by fixed local charges.

    The pair energy is multiplied by ``q_i * q_j`` (from ``atomic_charges``).
    With a ``+1`` derivative sign, the ``charge_potential`` feature holds the
    energy derivative with respect to each atom's charge.
    """

    _label_prefix = "charge_twobody"
    _block_kind = "charge_twobody"
    _feature_derivative_sign = 1.0
    _feature_name = "charge_potential"
    _state_field = "atomic_charges"
class CollinearSpinExchangeTerm(_StateScaledSplinePairTerm):
    """Pairwise collinear exchange spline scaled by fixed spin moments.

    The pair energy is multiplied by ``m_i * m_j`` (from
    ``atomic_spin_moments``). With a ``-1`` derivative sign, the
    ``spin_effective_field`` feature holds ``-dE/dm`` per atom.
    """

    _label_prefix = "spin_exchange"
    _block_kind = "spin_exchange"
    _feature_derivative_sign = -1.0
    _feature_name = "spin_effective_field"
    _state_field = "atomic_spin_moments"
class LocalChargeCoulombTerm(PairTerm):
    """Finite-cutoff softened Coulomb interaction for fixed local charges.

    Each active pair closer than ``cutoff`` contributes
    ``w * k * scale * q_i * q_j * f(r) / sqrt(r**2 + softening**2)``, where:

    - ``k`` is ``COULOMB_CONSTANT_EV_ANGSTROM``;
    - ``f`` is the cutoff envelope;
    - ``w = pair_weight(inputs)``.

    The ``charge_potential`` feature holds the derivative of this energy with
    respect to each atom's charge.
    """

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        softening: float = 1.0e-6,
        scale: float = 1.0,
        cutoff_envelope: CutoffEnvelope | str | None = None,
        eps: float = 1.0e-12,
    ) -> None:
        """Store local-charge Coulomb cutoff, screening, and active pair metadata.

        Raises:
            ValueError: If ``cutoff`` or ``eps`` is not finite and positive,
                ``softening`` is not finite and non-negative, or ``scale`` is
                not finite.
        """
        cutoff = float(cutoff)
        if not (math.isfinite(cutoff) and cutoff > 0.0):
            raise ValueError("`cutoff` must be a finite positive value")
        softening = float(softening)
        if not (math.isfinite(softening) and softening >= 0.0):
            raise ValueError("`softening` must be finite and non-negative")
        scale = float(scale)
        if not math.isfinite(scale):
            raise ValueError("`scale` must be finite")
        eps = float(eps)
        if not (math.isfinite(eps) and eps > 0.0):
            raise ValueError("`eps` must be a finite positive value")

        element_types = _normalized_atomic_types(atomic_types)
        super().__init__(cutoff=cutoff, atomic_types=element_types)
        self.symmetric = bool(symmetric)
        self.softening = softening
        self.scale = scale
        self.eps = eps
        self.cutoff_envelope = normalize_cutoff_envelope(
            cutoff_envelope,
            cutoff=cutoff,
            default_kind="none",
        )

        categories = _pair_categories(element_types, symmetric=self.symmetric)
        # Plain tuples are stored without nn.Module attribute registration.
        object.__setattr__(self, "_pair_categories", categories)
        enabled = _active_pair_mask(
            categories,
            active_pairs=active_pairs,
            symmetric=self.symmetric,
        )
        self.register_buffer("active_pair_mask", enabled, persistent=False)
        active_indices = tuple(
            position for position, flag in enumerate(enabled.tolist()) if flag
        )
        object.__setattr__(self, "_active_pair_indices", active_indices)

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Require a neighbor list and fixed local charge state."""
        return TermInputRequirements(
            neighbor_list=True,
            state_fields=("atomic_charges",),
        )

    @property
    def provides_forces(self) -> bool:
        """Report that this term provides analytic pair forces."""
        return True

    @property
    def optimizer_group(self) -> str | None:
        """Return the common charge/spin optimizer group name."""
        return "charge_spin"

    @property
    def pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return configured pair categories."""
        return self._pair_categories

    @property
    def active_pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return active pair categories."""
        categories = self.pair_categories
        return tuple(categories[position] for position in self._active_pair_indices)

    def _check_pair_inputs(self, inputs: UFPInput) -> None:
        """Validate inputs; asymmetric terms need a full neighbor list."""
        self.validate_inputs(inputs)
        if self.symmetric or inputs.neighbor_list.full_list:
            return
        raise RuntimeError(
            "asymmetric local charge Coulomb terms require a full neighbor list"
        )

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate softened Coulomb energy, forces, and charge potential."""
        self._check_pair_inputs(inputs)
        assert self.atomic_types is not None
        assert inputs.atomic_charges is not None
        output = empty_atomwise_output(inputs, forces=True)
        potential = torch.zeros(
            inputs.n_atoms,
            dtype=inputs.dtype,
            device=inputs.device,
        )
        output.features["charge_potential"] = potential
        if not self._active_pair_indices:
            return output

        _, active = _inactive_aware_pair_mask(
            inputs,
            atomic_types=self.atomic_types,
            symmetric=self.symmetric,
            active_pair_mask=self.active_pair_mask,
            active_pair_indices=self._active_pair_indices,
            n_pair_categories=len(self.pair_categories),
        )
        if not torch.any(active):
            return output

        r_all = inputs.pair_distances(active)
        inside = r_all < float(self.cutoff)
        if not torch.any(inside):
            return output

        # Restrict every per-pair quantity to pairs inside the cutoff.
        r = r_all[inside]
        vec = inputs.pair_vectors(active)[inside]
        i_all, j_all = inputs.pair_indices(active)
        system_all = inputs.pair_system_index(active)
        i_idx = i_all[inside]
        j_idx = j_all[inside]
        system_idx = system_all[inside]

        charges = inputs.atomic_charges.to(device=inputs.device, dtype=inputs.dtype)
        q_i = charges.index_select(0, i_idx)
        q_j = charges.index_select(0, j_idx)
        qq = q_i * q_j

        # 1/sqrt(r^2 + s^2), with the radicand floored at eps^2 to stay finite.
        soft_sq = float(self.softening) * float(self.softening)
        radicand = r.square() + soft_sq
        floor_sq = float(self.eps) * float(self.eps)
        radicand = radicand.clamp_min(floor_sq)
        inv_d = torch.rsqrt(radicand)
        inv_d_cubed = radicand.pow(-1.5)

        prefactor = COULOMB_CONSTANT_EV_ANGSTROM * float(self.scale) * qq
        env = self.cutoff_envelope.values(r).to(
            device=inputs.device,
            dtype=inputs.dtype,
        )
        env_grad = self.cutoff_envelope.derivatives(r).to(
            device=inputs.device,
            dtype=inputs.dtype,
        )
        w = pair_weight(inputs)
        pair_energy = w * prefactor * inv_d * env
        d_inv_d = -r * inv_d_cubed
        # Product rule: d/dr [inv_d(r) * env(r)].
        pair_grad = w * prefactor * (d_inv_d * env + inv_d * env_grad)

        assert output.energy is not None
        assert output.forces is not None
        assert output.per_atom_energy is not None
        output.energy.index_add_(0, system_idx, pair_energy)
        half_energy = 0.5 * pair_energy
        output.per_atom_energy.index_add_(0, i_idx, half_energy)
        output.per_atom_energy.index_add_(0, j_idx, half_energy)

        inv_r = torch.where(r > self.eps, r.reciprocal(), torch.zeros_like(r))
        force_j = -pair_grad[:, None] * vec * inv_r[:, None]
        output.forces.index_add_(0, i_idx, -force_j)
        output.forces.index_add_(0, j_idx, force_j)

        # dE/dq_i = kernel * q_j and dE/dq_j = kernel * q_i.
        kernel = (
            w
            * COULOMB_CONSTANT_EV_ANGSTROM
            * float(self.scale)
            * inv_d
            * env
        )
        potential.index_add_(0, i_idx, kernel * q_j)
        potential.index_add_(0, j_idx, kernel * q_i)
        return output
# Public API of this module (also controls ``from ... import *``).
__all__ = [
    "COULOMB_CONSTANT_EV_ANGSTROM",
    "ChargeScaledSplinePairTerm",
    "ChargeSelfEnergyTerm",
    "CollinearSpinExchangeTerm",
    "CollinearSpinLandauTerm",
    "LocalChargeCoulombTerm",
]