"""Charge and collinear-spin state terms for UFP models."""
from __future__ import annotations
import math
from collections.abc import Sequence
import torch
from ufp.core.input import UFPInput
from ufp.core.output import UFPOutput
from ufp.splines.representation import spline_support_mask_1d, uniform_stencil_1d
from ufp.terms._base import (
LinearAssemblyOptions,
OneBodyTerm,
PairTerm,
TermInputRequirements,
)
from ufp.terms._parameters import (
ParameterBlock,
ParameterBlockCacheChannel,
ParameterBlockCacheDescriptor,
copy_parameter_data,
)
from ufp.terms._shared import empty_atomwise_output, pair_weight
from ufp.terms.categories import active_pair_mask as _active_pair_mask
from ufp.terms.categories import pair_categories as _pair_categories
from ufp.terms.cutoffs import CutoffEnvelope, normalize_cutoff_envelope
from ufp.terms.twobody import SplineTwoBodyTerm
# Coulomb constant e**2 / (4 * pi * epsilon_0) expressed in eV * Angstrom.
# Multiplying it by q_i * q_j / r gives an energy in eV when the charges are
# in units of the elementary charge and r is in Angstrom.
COULOMB_CONSTANT_EV_ANGSTROM = 14.3996454784255
def _normalized_atomic_types(atomic_types: Sequence[int]) -> tuple[int, ...]:
"""Return sorted unique atomic numbers and reject empty specifications."""
normalized = tuple(sorted(set(int(value) for value in atomic_types)))
if not normalized:
raise ValueError("`atomic_types` must contain at least one element")
return normalized
def _element_parameter(
value,
*,
name: str,
shape: tuple[int, ...],
dtype: torch.dtype | None,
) -> torch.Tensor:
"""Normalize a one-dimensional per-element initializer."""
if value is None:
return torch.zeros(shape, dtype=dtype)
tensor = torch.as_tensor(value, dtype=dtype)
if tensor.ndim == 0 and shape == (1,):
tensor = tensor.reshape(1)
if tuple(int(dim) for dim in tensor.shape) != shape:
raise ValueError(f"`{name}` must have shape {shape}")
return tensor.detach().clone()
def _empty_block_matrix(
targets,
block,
*,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""Create an unweighted least-squares block matrix."""
return torch.zeros(
(targets.n_rows, block.size),
dtype=dtype,
device=device,
)
def _add_entries(
matrix: torch.Tensor,
rows: torch.Tensor,
cols: torch.Tensor,
values: torch.Tensor,
) -> None:
"""Accumulate broadcastable row/column/value entries into a dense matrix."""
if rows.numel() == 0 or cols.numel() == 0 or values.numel() == 0:
return
rows, cols, values = torch.broadcast_tensors(rows, cols, values)
valid = rows >= 0
if not torch.any(valid):
return
width = int(matrix.shape[1])
flat_rows = rows[valid].reshape(-1)
flat_cols = cols[valid].reshape(-1)
flat_values = values[valid].reshape(-1)
matrix.reshape(-1).index_add_(0, flat_rows * width + flat_cols, flat_values)
def _inactive_aware_pair_mask(
inputs: UFPInput,
*,
atomic_types: Sequence[int],
symmetric: bool,
active_pair_mask: torch.Tensor,
active_pair_indices: tuple[int, ...],
n_pair_categories: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Return pair categories and a mask for configured active pair channels."""
pair_category = inputs.pair_category_indices(
atomic_types,
symmetric=symmetric,
)
handled_mask = pair_category >= 0
if len(active_pair_indices) != n_pair_categories:
active_mask = active_pair_mask.to(device=inputs.device)
active_handled = torch.zeros_like(handled_mask)
active_handled[handled_mask] = active_mask.index_select(
0,
pair_category[handled_mask],
)
handled_mask = active_handled
return pair_category, handled_mask
class ChargeSelfEnergyTerm(OneBodyTerm):
    """Per-element local charge electronegativity and hardness energy.

    An atom of a configured element carrying the fixed charge ``q``
    contributes ``chi * q + 0.5 * eta * q**2``. Here ``chi`` and ``eta`` are
    that element's electronegativity and hardness coefficients. Atoms of
    unconfigured elements contribute nothing. The derivative
    ``dE/dq = chi + eta * q`` is exported as the ``"charge_potential"``
    feature, and forces are explicitly zero.
    """

    def __init__(
        self,
        *,
        atomic_types: Sequence[int],
        electronegativities=None,
        hardnesses=None,
        trainable: bool = True,
        fittable: bool = True,
        frozen: bool = False,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store one electronegativity and hardness coefficient per element.

        Args:
            atomic_types: Covered atomic numbers. They are de-duplicated and
                sorted, and that sorted order indexes the coefficients.
            electronegativities: Initial ``chi`` per sorted element, or
                ``None`` for zeros.
            hardnesses: Initial ``eta`` per sorted element, or ``None`` for
                zeros.
            trainable: Whether the parameters require gradients; ignored when
                ``frozen`` is true.
            fittable: Forwarded to the exported parameter blocks.
            frozen: Forwarded to the parameter blocks and disables gradients.
            dtype: Optional parameter dtype.

        Raises:
            ValueError: If ``atomic_types`` is empty or an initializer has the
                wrong shape.
        """
        normalized_atomic_types = _normalized_atomic_types(atomic_types)
        super().__init__(cutoff=None, atomic_types=normalized_atomic_types)
        shape = (len(normalized_atomic_types),)
        self.fittable = bool(fittable)
        self.frozen = bool(frozen)
        requires_grad = bool(trainable) and not self.frozen
        self.electronegativities = torch.nn.Parameter(
            _element_parameter(
                electronegativities,
                name="electronegativities",
                shape=shape,
                dtype=dtype,
            ),
            requires_grad=requires_grad,
        )
        self.hardnesses = torch.nn.Parameter(
            _element_parameter(
                hardnesses,
                name="hardnesses",
                shape=shape,
                dtype=dtype,
            ),
            requires_grad=requires_grad,
        )

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Require fixed local charge state."""
        return TermInputRequirements(state_fields=("atomic_charges",))

    @property
    def provides_forces(self) -> bool:
        """Report that this term contributes explicit zero forces."""
        return True

    @property
    def optimizer_group(self) -> str | None:
        """Group trainable state-term parameters for workflow optimizers."""
        return "charge_spin"

    def _parameter_block(
        self,
        *,
        name: str,
        kind: str,
        parameter: torch.nn.Parameter,
    ) -> ParameterBlock:
        """Describe one per-element coefficient vector as a linear block.

        Each element gets its own one-coefficient cache channel keyed by its
        atomic number.
        """
        assert self.atomic_types is not None
        return ParameterBlock(
            name=name,
            kind=kind,
            shape=tuple(int(dim) for dim in parameter.shape),
            read=lambda: parameter,
            write=lambda values: copy_parameter_data(parameter, values),
            label=f"{kind}[{self.atomic_types}]",
            regularization_group="charge_spin",
            fittable=self.fittable,
            frozen=self.frozen,
            assembler=self._assemble_block,
            cache_descriptor=ParameterBlockCacheDescriptor(
                family={"kind": kind},
                channels=tuple(
                    ParameterBlockCacheChannel(
                        kind="Z",
                        values=(atomic_number,),
                        start=index,
                        stop=index + 1,
                    )
                    for index, atomic_number in enumerate(self.atomic_types)
                ),
                reusable=False,
            ),
        )

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return linear charge self-energy coefficient blocks."""
        return (
            self._parameter_block(
                name="electronegativities",
                kind="charge_self_chi",
                parameter=self.electronegativities,
            ),
            self._parameter_block(
                name="hardnesses",
                kind="charge_self_eta",
                parameter=self.hardnesses,
            ),
        )

    def _assemble_block(self, block, inputs: UFPInput, targets) -> torch.Tensor | None:
        """Assemble one charge self-energy block for fixed charges.

        The design-matrix factor is ``q`` for the electronegativity block and
        ``0.5 * q**2`` for the hardness block, so ``matrix @ coeffs``
        reproduces the energy computed by ``forward``. Factors are scattered
        into each atom's per-atom row and its system's energy row. Rows with a
        negative index are skipped; presumably that marks "no target".

        Returns:
            The dense block matrix, or ``None`` when no configured atom is
            present, the block name is unknown, or every entry is zero.
        """
        self.validate_inputs(inputs)
        assert self.atomic_types is not None
        assert inputs.atomic_charges is not None
        value_indices = inputs.atomic_category_indices(self.atomic_types)
        covered_atoms = value_indices >= 0
        if not torch.any(covered_atoms):
            return None
        charges = inputs.atomic_charges.to(device=inputs.device, dtype=inputs.dtype)
        if block.name == "electronegativities":
            factors = charges
        elif block.name == "hardnesses":
            factors = 0.5 * charges.square()
        else:
            return None
        matrix = _empty_block_matrix(
            targets,
            block,
            device=inputs.device,
            dtype=inputs.dtype,
        )
        valid_per_atom = covered_atoms & (targets.per_atom_rows >= 0)
        _add_entries(
            matrix,
            targets.per_atom_rows[valid_per_atom],
            value_indices[valid_per_atom],
            factors[valid_per_atom],
        )
        # Each atom also contributes to the total-energy row of its system.
        energy_rows = targets.energy_rows.index_select(0, inputs.system_index)
        valid_energy = covered_atoms & (energy_rows >= 0)
        _add_entries(
            matrix,
            energy_rows[valid_energy],
            value_indices[valid_energy],
            factors[valid_energy],
        )
        return None if torch.count_nonzero(matrix) == 0 else matrix

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Assemble all requested charge self-energy blocks.

        Returns:
            A mapping from ``block.index`` to the assembled matrix; blocks that
            assemble to ``None`` are omitted. Without ``options`` nothing is
            requested and the mapping is empty.
        """
        blocks = () if options is None else options.blocks
        return {
            block.index: matrix
            for block in blocks
            if (matrix := self._assemble_block(block, batch.inputs, targets))
            is not None
        }

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate local charge self energy and charge potential."""
        self.validate_inputs(inputs)
        assert self.atomic_types is not None
        assert inputs.atomic_charges is not None
        charges = inputs.atomic_charges.to(device=inputs.device, dtype=inputs.dtype)
        chi = self.electronegativities.to(device=inputs.device, dtype=inputs.dtype)
        eta = self.hardnesses.to(device=inputs.device, dtype=inputs.dtype)
        value_indices = inputs.atomic_category_indices(self.atomic_types)
        covered_atoms = value_indices >= 0
        per_atom_energy = torch.zeros(
            inputs.n_atoms,
            dtype=inputs.dtype,
            device=inputs.device,
        )
        charge_potential = torch.zeros_like(per_atom_energy)
        if torch.any(covered_atoms):
            local_chi = chi[value_indices[covered_atoms]]
            local_eta = eta[value_indices[covered_atoms]]
            local_q = charges[covered_atoms]
            per_atom_energy[covered_atoms] = local_chi * local_q + (
                0.5 * local_eta * local_q.square()
            )
            # dE/dq of the expression above.
            charge_potential[covered_atoms] = local_chi + local_eta * local_q
        energy = torch.zeros(
            inputs.n_systems,
            dtype=inputs.dtype,
            device=inputs.device,
        )
        energy.index_add_(0, inputs.system_index, per_atom_energy)
        return UFPOutput(
            energy=energy,
            forces=torch.zeros(
                (inputs.n_atoms, 3),
                dtype=inputs.dtype,
                device=inputs.device,
            ),
            per_atom_energy=per_atom_energy,
            features={"charge_potential": charge_potential},
        )
class CollinearSpinLandauTerm(OneBodyTerm):
    """Per-element onsite Landau energy for fixed scalar spin moments.

    An atom of a configured element with the fixed moment ``m`` contributes
    ``a * m**2 + b * m**4``, where ``a``/``b`` are that element's quadratic
    and quartic coefficients. Atoms of unconfigured elements contribute
    nothing. The effective field ``-dE/dm = -(2 a m + 4 b m**3)`` is exported
    as the ``"spin_effective_field"`` feature, and forces are explicitly zero.
    """

    def __init__(
        self,
        *,
        atomic_types: Sequence[int],
        quadratic=None,
        quartic=None,
        trainable: bool = True,
        fittable: bool = True,
        frozen: bool = False,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Store quadratic and quartic coefficients per element.

        Args:
            atomic_types: Covered atomic numbers. They are de-duplicated and
                sorted, and that sorted order indexes the coefficients.
            quadratic: Initial ``a`` per sorted element, or ``None`` for zeros.
            quartic: Initial ``b`` per sorted element, or ``None`` for zeros.
            trainable: Whether the parameters require gradients; ignored when
                ``frozen`` is true.
            fittable: Forwarded to the exported parameter blocks.
            frozen: Forwarded to the parameter blocks and disables gradients.
            dtype: Optional parameter dtype.

        Raises:
            ValueError: If ``atomic_types`` is empty or an initializer has the
                wrong shape.
        """
        normalized_atomic_types = _normalized_atomic_types(atomic_types)
        super().__init__(cutoff=None, atomic_types=normalized_atomic_types)
        shape = (len(normalized_atomic_types),)
        self.fittable = bool(fittable)
        self.frozen = bool(frozen)
        requires_grad = bool(trainable) and not self.frozen
        self.quadratic = torch.nn.Parameter(
            _element_parameter(
                quadratic,
                name="quadratic",
                shape=shape,
                dtype=dtype,
            ),
            requires_grad=requires_grad,
        )
        self.quartic = torch.nn.Parameter(
            _element_parameter(
                quartic,
                name="quartic",
                shape=shape,
                dtype=dtype,
            ),
            requires_grad=requires_grad,
        )

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Require fixed local collinear spin moments."""
        return TermInputRequirements(state_fields=("atomic_spin_moments",))

    @property
    def provides_forces(self) -> bool:
        """Report that this term contributes explicit zero forces."""
        return True

    @property
    def optimizer_group(self) -> str | None:
        """Group trainable state-term parameters for workflow optimizers."""
        return "charge_spin"

    def _parameter_block(
        self,
        *,
        name: str,
        kind: str,
        parameter: torch.nn.Parameter,
    ) -> ParameterBlock:
        """Describe one per-element coefficient vector as a linear block.

        Each element gets its own one-coefficient cache channel keyed by its
        atomic number.
        """
        assert self.atomic_types is not None
        return ParameterBlock(
            name=name,
            kind=kind,
            shape=tuple(int(dim) for dim in parameter.shape),
            read=lambda: parameter,
            write=lambda values: copy_parameter_data(parameter, values),
            label=f"{kind}[{self.atomic_types}]",
            regularization_group="charge_spin",
            fittable=self.fittable,
            frozen=self.frozen,
            assembler=self._assemble_block,
            cache_descriptor=ParameterBlockCacheDescriptor(
                family={"kind": kind},
                channels=tuple(
                    ParameterBlockCacheChannel(
                        kind="Z",
                        values=(atomic_number,),
                        start=index,
                        stop=index + 1,
                    )
                    for index, atomic_number in enumerate(self.atomic_types)
                ),
                reusable=False,
            ),
        )

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return linear spin Landau coefficient blocks."""
        return (
            self._parameter_block(
                name="quadratic",
                kind="spin_landau_quadratic",
                parameter=self.quadratic,
            ),
            self._parameter_block(
                name="quartic",
                kind="spin_landau_quartic",
                parameter=self.quartic,
            ),
        )

    def _assemble_block(self, block, inputs: UFPInput, targets) -> torch.Tensor | None:
        """Assemble one Landau block for fixed spin moments.

        The design-matrix factor is ``m**2`` for the quadratic block and
        ``m**4`` for the quartic block, so ``matrix @ coeffs`` reproduces the
        energy computed by ``forward``. Factors are scattered into each atom's
        per-atom row and its system's energy row. Rows with a negative index
        are skipped.

        Returns:
            The dense block matrix, or ``None`` when no configured atom is
            present, the block name is unknown, or every entry is zero.
        """
        self.validate_inputs(inputs)
        assert self.atomic_types is not None
        assert inputs.atomic_spin_moments is not None
        value_indices = inputs.atomic_category_indices(self.atomic_types)
        covered_atoms = value_indices >= 0
        if not torch.any(covered_atoms):
            return None
        spins = inputs.atomic_spin_moments.to(
            device=inputs.device,
            dtype=inputs.dtype,
        )
        if block.name == "quadratic":
            factors = spins.square()
        elif block.name == "quartic":
            factors = spins.pow(4)
        else:
            return None
        matrix = _empty_block_matrix(
            targets,
            block,
            device=inputs.device,
            dtype=inputs.dtype,
        )
        valid_per_atom = covered_atoms & (targets.per_atom_rows >= 0)
        _add_entries(
            matrix,
            targets.per_atom_rows[valid_per_atom],
            value_indices[valid_per_atom],
            factors[valid_per_atom],
        )
        # Each atom also contributes to the total-energy row of its system.
        energy_rows = targets.energy_rows.index_select(0, inputs.system_index)
        valid_energy = covered_atoms & (energy_rows >= 0)
        _add_entries(
            matrix,
            energy_rows[valid_energy],
            value_indices[valid_energy],
            factors[valid_energy],
        )
        return None if torch.count_nonzero(matrix) == 0 else matrix

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Assemble all requested Landau blocks.

        Returns:
            A mapping from ``block.index`` to the assembled matrix; blocks that
            assemble to ``None`` are omitted. Without ``options`` nothing is
            requested and the mapping is empty.
        """
        blocks = () if options is None else options.blocks
        return {
            block.index: matrix
            for block in blocks
            if (matrix := self._assemble_block(block, batch.inputs, targets))
            is not None
        }

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate onsite spin energy and effective field."""
        self.validate_inputs(inputs)
        assert self.atomic_types is not None
        assert inputs.atomic_spin_moments is not None
        spins = inputs.atomic_spin_moments.to(
            device=inputs.device,
            dtype=inputs.dtype,
        )
        quadratic = self.quadratic.to(device=inputs.device, dtype=inputs.dtype)
        quartic = self.quartic.to(device=inputs.device, dtype=inputs.dtype)
        value_indices = inputs.atomic_category_indices(self.atomic_types)
        covered_atoms = value_indices >= 0
        per_atom_energy = torch.zeros(
            inputs.n_atoms,
            dtype=inputs.dtype,
            device=inputs.device,
        )
        spin_effective_field = torch.zeros_like(per_atom_energy)
        if torch.any(covered_atoms):
            local_a = quadratic[value_indices[covered_atoms]]
            local_b = quartic[value_indices[covered_atoms]]
            local_m = spins[covered_atoms]
            per_atom_energy[covered_atoms] = (
                local_a * local_m.square() + local_b * local_m.pow(4)
            )
            # Effective field is the negative derivative -dE/dm.
            spin_effective_field[covered_atoms] = -(
                2.0 * local_a * local_m + 4.0 * local_b * local_m.pow(3)
            )
        energy = torch.zeros(
            inputs.n_systems,
            dtype=inputs.dtype,
            device=inputs.device,
        )
        energy.index_add_(0, inputs.system_index, per_atom_energy)
        return UFPOutput(
            energy=energy,
            forces=torch.zeros(
                (inputs.n_atoms, 3),
                dtype=inputs.dtype,
                device=inputs.device,
            ),
            per_atom_energy=per_atom_energy,
            features={"spin_effective_field": spin_effective_field},
        )
class _StateScaledSplinePairTerm(SplineTwoBodyTerm):
    """Common implementation for state-scaled spline pair terms.

    Each handled pair ``(i, j)`` contributes ``w * s_i * s_j * f(r_ij)``:

    - ``s`` is the per-atom state field named by ``_state_field``.
    - ``f`` is the pair-category spline evaluated from ``true_coeffs_by_pair``.
    - ``w`` is ``pair_weight(inputs)``.

    Half of every pair energy is credited to each atom's per-atom energy.
    ``_feature_derivative_sign * dE/ds`` is accumulated per atom into the
    feature named ``_feature_name``. Subclasses only set the class attributes
    below.
    """

    # Name of the ``UFPInput`` attribute holding the per-atom state values.
    _state_field: str
    # Key under which the per-atom state derivative is stored in ``features``.
    _feature_name: str
    # Multiplier applied to dE/ds before it is stored as the feature
    # (e.g. -1.0 stores the negative derivative).
    _feature_derivative_sign: float
    # ``kind`` of the linear parameter block and of its cache family.
    _block_kind: str
    # Prefix of the human-readable parameter-block label.
    _label_prefix: str

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Require a neighbor list and the configured atomwise state field."""
        return TermInputRequirements(
            neighbor_list=True,
            state_fields=(self._state_field,),
        )

    @property
    def optimizer_group(self) -> str | None:
        """Group trainable state-term parameters for workflow optimizers."""
        return "charge_spin"

    def parameter_blocks(self) -> tuple[ParameterBlock, ...]:
        """Return the state-scaled pair spline coefficient block.

        The block exposes ``true_coeffs_by_pair`` as one linear block. Its
        cache descriptor records the spline geometry and one channel per
        active pair category.
        """
        return (
            ParameterBlock(
                name="coeffs_by_pair",
                kind=self._block_kind,
                shape=tuple(int(dim) for dim in self.true_coeffs_by_pair.shape),
                read=lambda: self.true_coeffs_by_pair,
                write=self._write_true_coeffs_by_pair,
                label=f"{self._label_prefix}[{self.atomic_types}]",
                coefficient_provider=self.coefficient_provider,
                coefficient_index=self.coefficient_index,
                regularization_group="charge_spin",
                fittable=self.fittable,
                frozen=self.frozen,
                assembler=self._assemble_block,
                cache_descriptor=ParameterBlockCacheDescriptor(
                    family={
                        "kind": self._block_kind,
                        "symmetric": bool(self.symmetric),
                        "spline": str(self.spline),
                        "first_knot": float(self.first_knot),
                        "knot_spacing": float(self.knot_spacing),
                        "coeff_size": int(self.true_coeffs_by_pair.shape[1]),
                        "eps": float(self.eps),
                    },
                    # Each active pair category owns the flattened column slice
                    # [pair_index * coeff_size, (pair_index + 1) * coeff_size),
                    # matching the ``cols`` layout used in ``_assemble_block``.
                    channels=tuple(
                        ParameterBlockCacheChannel(
                            kind="pair",
                            values=self.pair_categories[pair_index],
                            start=int(pair_index)
                            * int(self.true_coeffs_by_pair.shape[1]),
                            stop=(int(pair_index) + 1)
                            * int(self.true_coeffs_by_pair.shape[1]),
                        )
                        for pair_index in self._active_pair_indices
                    ),
                    reusable=False,
                ),
            ),
        )

    def _state_values(self, inputs: UFPInput) -> torch.Tensor:
        """Return the configured state field cast to the input device/dtype."""
        values = getattr(inputs, self._state_field)
        assert values is not None
        return values.to(device=inputs.device, dtype=inputs.dtype)

    def _check_pair_inputs(self, inputs: UFPInput) -> None:
        """Validate inputs; asymmetric pair channels require a full neighbor list."""
        self.validate_inputs(inputs)
        if not self.symmetric and not inputs.neighbor_list.full_list:
            raise RuntimeError(
                "asymmetric state-scaled spline pair terms require a full neighbor list"
            )

    def _assemble_block(self, block, inputs: UFPInput, targets) -> torch.Tensor | None:
        """Assemble one state-scaled spline pair block.

        Fills energy, per-atom-energy and force rows so that
        ``matrix @ coeffs.flatten()`` reproduces ``forward``'s outputs.

        Returns:
            The dense block matrix, or ``None`` when no handled pair lies
            inside the spline support or every entry is zero.
        """
        self._check_pair_inputs(inputs)
        assert self.atomic_types is not None
        _, n_knots = block.shape
        pair_category, handled_mask = _inactive_aware_pair_mask(
            inputs,
            atomic_types=self.atomic_types,
            symmetric=self.symmetric,
            active_pair_mask=self.active_pair_mask,
            active_pair_indices=self._active_pair_indices,
            n_pair_categories=len(self.pair_categories),
        )
        if not torch.any(handled_mask):
            return None
        pair_distances = inputs.pair_distances(handled_mask)
        support_mask = spline_support_mask_1d(
            pair_distances,
            coeff_size=int(n_knots),
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        if not torch.any(support_mask):
            return None
        # From here on, every per-pair tensor is restricted to handled pairs
        # inside the spline support.
        pair_distances = pair_distances[support_mask]
        handled_pair_category = pair_category[handled_mask][support_mask]
        pair_vectors = inputs.pair_vectors(handled_mask)[support_mask]
        first_atom, second_atom = inputs.pair_indices(handled_mask)
        pair_system_index = inputs.pair_system_index(handled_mask)
        first_atom = first_atom[support_mask]
        second_atom = second_atom[support_mask]
        pair_system_index = pair_system_index[support_mask]
        state = self._state_values(inputs)
        state_product = state.index_select(0, first_atom) * state.index_select(
            0,
            second_atom,
        )
        stencil = uniform_stencil_1d(
            pair_distances,
            coeff_size=int(n_knots),
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        # Column of coefficient (category, knot) in the flattened block.
        cols = stencil.indices + handled_pair_category[:, None] * int(n_knots)
        scale = pair_weight(inputs)
        values = scale * state_product[:, None] * stencil.values
        grads = scale * state_product[:, None] * stencil.grads
        matrix = _empty_block_matrix(
            targets,
            block,
            device=inputs.device,
            dtype=inputs.dtype,
        )
        # Full pair energy goes to the system's energy row ...
        _add_entries(
            matrix,
            targets.energy_rows.index_select(0, pair_system_index)[:, None],
            cols,
            values,
        )
        # ... and half of it to each partner's per-atom row.
        half_values = 0.5 * values
        _add_entries(
            matrix,
            targets.per_atom_rows.index_select(0, first_atom)[:, None],
            cols,
            half_values,
        )
        _add_entries(
            matrix,
            targets.per_atom_rows.index_select(0, second_atom)[:, None],
            cols,
            half_values,
        )
        inv_r = torch.where(
            pair_distances > self.eps,
            pair_distances.reciprocal(),
            torch.zeros_like(pair_distances),
        )
        direction = pair_vectors * inv_r[:, None]
        # Per-coefficient force contributions, shaped (pair, stencil, xyz);
        # equal and opposite on the two partners, same sign convention as
        # ``forward``.
        force_second = -(grads[:, :, None] * direction[:, None, :])
        force_first = -force_second
        # NOTE(review): assumes ``targets.force_rows`` holds one row index per
        # (atom, Cartesian component), i.e. shape (n_atoms, 3) -- confirm.
        _add_entries(
            matrix,
            targets.force_rows.index_select(0, first_atom)[:, :, None],
            cols[:, None, :],
            force_first.permute(0, 2, 1),
        )
        _add_entries(
            matrix,
            targets.force_rows.index_select(0, second_atom)[:, :, None],
            cols[:, None, :],
            force_second.permute(0, 2, 1),
        )
        return None if torch.count_nonzero(matrix) == 0 else matrix

    def assemble_linear_blocks(
        self,
        batch,
        targets,
        options: LinearAssemblyOptions | None = None,
    ):
        """Assemble requested state-scaled pair blocks.

        Returns a mapping from ``block.index`` to the assembled matrix,
        omitting blocks that assemble to ``None``.
        """
        blocks = () if options is None else options.blocks
        return {
            block.index: matrix
            for block in blocks
            if (matrix := self._assemble_block(block, batch.inputs, targets))
            is not None
        }

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate state-scaled spline pair energy, forces, and state features."""
        self._check_pair_inputs(inputs)
        assert self.atomic_types is not None
        output = empty_atomwise_output(inputs, forces=True)
        # The feature is always present, even when no pair contributes.
        feature = torch.zeros(
            inputs.n_atoms,
            dtype=inputs.dtype,
            device=inputs.device,
        )
        output.features[self._feature_name] = feature
        if not self._active_pair_indices:
            return output
        pair_category, handled_mask = _inactive_aware_pair_mask(
            inputs,
            atomic_types=self.atomic_types,
            symmetric=self.symmetric,
            active_pair_mask=self.active_pair_mask,
            active_pair_indices=self._active_pair_indices,
            n_pair_categories=len(self.pair_categories),
        )
        if not torch.any(handled_mask):
            return output
        coeffs_by_pair = self.true_coeffs_by_pair.to(
            device=inputs.device,
            dtype=inputs.dtype,
        )
        pair_distances = inputs.pair_distances(handled_mask)
        support_mask = spline_support_mask_1d(
            pair_distances,
            coeff_size=int(coeffs_by_pair.shape[1]),
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        if not torch.any(support_mask):
            return output
        pair_distances = pair_distances[support_mask]
        handled_pair_category = pair_category[handled_mask][support_mask]
        stencil = uniform_stencil_1d(
            pair_distances,
            coeff_size=int(coeffs_by_pair.shape[1]),
            first_knot=self.first_knot,
            knot_spacing=self.knot_spacing,
            spline=self.spline,
        )
        # Gather the stencil's coefficients for each pair's category, then
        # contract to get the spline value f(r) and its radial derivative.
        coeff_window = coeffs_by_pair[handled_pair_category[:, None], stencil.indices]
        pair_value = (stencil.values * coeff_window).sum(dim=1)
        pair_grad = (stencil.grads * coeff_window).sum(dim=1)
        first_atom, second_atom = inputs.pair_indices(handled_mask)
        pair_system_index = inputs.pair_system_index(handled_mask)
        pair_vectors = inputs.pair_vectors(handled_mask)
        first_atom = first_atom[support_mask]
        second_atom = second_atom[support_mask]
        pair_system_index = pair_system_index[support_mask]
        pair_vectors = pair_vectors[support_mask]
        state = self._state_values(inputs)
        first_state = state.index_select(0, first_atom)
        second_state = state.index_select(0, second_atom)
        state_product = first_state * second_state
        scale = pair_weight(inputs)
        weighted_pair_energy = scale * state_product * pair_value
        weighted_pair_grad = scale * state_product * pair_grad
        assert output.energy is not None
        assert output.forces is not None
        assert output.per_atom_energy is not None
        output.energy.index_add_(0, pair_system_index, weighted_pair_energy)
        per_atom_contribution = 0.5 * weighted_pair_energy
        output.per_atom_energy.index_add_(0, first_atom, per_atom_contribution)
        output.per_atom_energy.index_add_(0, second_atom, per_atom_contribution)
        # Guard 1/r for (near-)coincident pairs: their force direction is zero.
        inv_r = torch.where(
            pair_distances > self.eps,
            pair_distances.reciprocal(),
            torch.zeros_like(pair_distances),
        )
        force_on_second = -weighted_pair_grad[:, None] * pair_vectors * inv_r[:, None]
        output.forces.index_add_(0, first_atom, -force_on_second)
        output.forces.index_add_(0, second_atom, force_on_second)
        # dE/ds_i picks up w * s_j * f(r) from each partner j (and vice versa),
        # optionally negated via ``_feature_derivative_sign``.
        sign = float(self._feature_derivative_sign)
        feature.index_add_(0, first_atom, sign * scale * second_state * pair_value)
        feature.index_add_(0, second_atom, sign * scale * first_state * pair_value)
        return output
class ChargeScaledSplinePairTerm(_StateScaledSplinePairTerm):
    """Short-range spline pair correction scaled by fixed local charges.

    Each handled pair contributes ``w * q_i * q_j * f(r_ij)``. The charge
    derivative ``dE/dq`` is accumulated per atom into the
    ``"charge_potential"`` feature.
    """

    # Per-atom state multiplying each pair's spline value.
    _state_field = "atomic_charges"
    # Feature receiving the (unsigned) charge derivative dE/dq.
    _feature_name = "charge_potential"
    _feature_derivative_sign = 1.0
    # Parameter-block kind and label prefix.
    _block_kind = "charge_twobody"
    _label_prefix = "charge_twobody"
class CollinearSpinExchangeTerm(_StateScaledSplinePairTerm):
    """Pairwise collinear exchange spline scaled by fixed spin moments.

    Each handled pair contributes ``w * m_i * m_j * J(r_ij)``. The effective
    field ``-dE/dm`` is accumulated per atom into the
    ``"spin_effective_field"`` feature.
    """

    # Per-atom state multiplying each pair's spline value.
    _state_field = "atomic_spin_moments"
    # Feature receiving the negated spin derivative -dE/dm.
    _feature_name = "spin_effective_field"
    _feature_derivative_sign = -1.0
    # Parameter-block kind and label prefix.
    _block_kind = "spin_exchange"
    _label_prefix = "spin_exchange"
class LocalChargeCoulombTerm(PairTerm):
    """Finite-cutoff softened Coulomb interaction for fixed local charges.

    Each handled pair with ``r < cutoff`` contributes

        ``w * k * scale * q_i * q_j * envelope(r) / sqrt(r**2 + softening**2)``

    where:

    - ``k`` is ``COULOMB_CONSTANT_EV_ANGSTROM``.
    - ``scale`` is the user-supplied strength multiplier.
    - ``w`` is ``pair_weight(inputs)``.
    - ``envelope`` is the configured cutoff envelope.

    The square-root argument is clamped from below by ``eps**2``. The
    derivative with respect to each atom's charge is exported as the
    ``"charge_potential"`` feature.
    """

    def __init__(
        self,
        *,
        cutoff: float,
        atomic_types: Sequence[int],
        active_pairs: Sequence[tuple[int, int]] | None = None,
        symmetric: bool = True,
        softening: float = 1.0e-6,
        scale: float = 1.0,
        cutoff_envelope: CutoffEnvelope | str | None = None,
        eps: float = 1.0e-12,
    ) -> None:
        """Store local-charge Coulomb cutoff, screening, and active pair metadata.

        Raises:
            ValueError: If ``cutoff`` or ``eps`` is not finite and positive,
                ``softening`` is negative or not finite, ``scale`` is not
                finite, or ``atomic_types`` is empty.
        """
        cutoff = float(cutoff)
        if not math.isfinite(cutoff) or cutoff <= 0.0:
            raise ValueError("`cutoff` must be a finite positive value")
        softening = float(softening)
        if not math.isfinite(softening) or softening < 0.0:
            raise ValueError("`softening` must be finite and non-negative")
        scale = float(scale)
        if not math.isfinite(scale):
            raise ValueError("`scale` must be finite")
        eps = float(eps)
        if not math.isfinite(eps) or eps <= 0.0:
            raise ValueError("`eps` must be a finite positive value")
        normalized_atomic_types = _normalized_atomic_types(atomic_types)
        super().__init__(cutoff=cutoff, atomic_types=normalized_atomic_types)
        self.symmetric = bool(symmetric)
        self.softening = softening
        self.scale = scale
        self.eps = eps
        self.cutoff_envelope = normalize_cutoff_envelope(
            cutoff_envelope,
            cutoff=cutoff,
            default_kind="none",
        )
        pair_categories = _pair_categories(
            normalized_atomic_types,
            symmetric=self.symmetric,
        )
        # Bypass ``torch.nn.Module.__setattr__`` for plain metadata tuples.
        object.__setattr__(self, "_pair_categories", pair_categories)
        active_mask = _active_pair_mask(
            pair_categories,
            active_pairs=active_pairs,
            symmetric=self.symmetric,
        )
        self.register_buffer("active_pair_mask", active_mask, persistent=False)
        object.__setattr__(
            self,
            "_active_pair_indices",
            tuple(
                index for index, enabled in enumerate(active_mask.tolist()) if enabled
            ),
        )

    @property
    def input_requirements(self) -> TermInputRequirements:
        """Require a neighbor list and fixed local charge state."""
        return TermInputRequirements(
            neighbor_list=True,
            state_fields=("atomic_charges",),
        )

    @property
    def provides_forces(self) -> bool:
        """Report that this term provides analytic pair forces."""
        return True

    @property
    def optimizer_group(self) -> str | None:
        """Return the common charge/spin optimizer group name."""
        return "charge_spin"

    @property
    def pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return configured pair categories."""
        return self._pair_categories

    @property
    def active_pair_categories(self) -> tuple[tuple[int, int], ...]:
        """Return active pair categories."""
        return tuple(self.pair_categories[index] for index in self._active_pair_indices)

    def _check_pair_inputs(self, inputs: UFPInput) -> None:
        """Validate inputs; asymmetric pair channels require a full neighbor list."""
        self.validate_inputs(inputs)
        if not self.symmetric and not inputs.neighbor_list.full_list:
            raise RuntimeError(
                "asymmetric local charge Coulomb terms require a full neighbor list"
            )

    def forward(self, inputs: UFPInput) -> UFPOutput:
        """Evaluate softened Coulomb energy, forces, and charge potential."""
        self._check_pair_inputs(inputs)
        assert self.atomic_types is not None
        assert inputs.atomic_charges is not None
        output = empty_atomwise_output(inputs, forces=True)
        # The feature is always present, even when no pair contributes.
        charge_potential = torch.zeros(
            inputs.n_atoms,
            dtype=inputs.dtype,
            device=inputs.device,
        )
        output.features["charge_potential"] = charge_potential
        if not self._active_pair_indices:
            return output
        # The Coulomb kernel is category-independent; only the active mask
        # matters here.
        _, handled_mask = _inactive_aware_pair_mask(
            inputs,
            atomic_types=self.atomic_types,
            symmetric=self.symmetric,
            active_pair_mask=self.active_pair_mask,
            active_pair_indices=self._active_pair_indices,
            n_pair_categories=len(self.pair_categories),
        )
        if not torch.any(handled_mask):
            return output
        pair_distances = inputs.pair_distances(handled_mask)
        support_mask = pair_distances < float(self.cutoff)
        if not torch.any(support_mask):
            return output
        pair_distances = pair_distances[support_mask]
        pair_vectors = inputs.pair_vectors(handled_mask)[support_mask]
        first_atom, second_atom = inputs.pair_indices(handled_mask)
        first_atom = first_atom[support_mask]
        second_atom = second_atom[support_mask]
        pair_system_index = inputs.pair_system_index(handled_mask)[support_mask]

        charges = inputs.atomic_charges.to(device=inputs.device, dtype=inputs.dtype)
        first_charge = charges.index_select(0, first_atom)
        second_charge = charges.index_select(0, second_atom)
        charge_product = first_charge * second_charge

        # 1 / sqrt(r^2 + a^2), with the argument clamped to eps^2 so that zero
        # softening and coincident atoms cannot divide by zero.
        softening_sq = float(self.softening) * float(self.softening)
        min_softened_sq = float(self.eps) * float(self.eps)
        safe_softened_sq = (pair_distances.square() + softening_sq).clamp_min(
            min_softened_sq
        )
        inverse_softened = torch.rsqrt(safe_softened_sq)
        inverse_softened_cubed = safe_softened_sq.pow(-1.5)
        # d/dr [1 / sqrt(r^2 + a^2)] = -r / (r^2 + a^2)^(3/2)
        softened_grad = -pair_distances * inverse_softened_cubed

        envelope = self.cutoff_envelope.values(pair_distances).to(
            device=inputs.device,
            dtype=inputs.dtype,
        )
        envelope_grad = self.cutoff_envelope.derivatives(pair_distances).to(
            device=inputs.device,
            dtype=inputs.dtype,
        )

        # ``pair_scale`` is the neighbor-list pair weight; ``self.scale`` is
        # the user-supplied interaction strength.
        pair_scale = pair_weight(inputs)
        prefactor = pair_scale * COULOMB_CONSTANT_EV_ANGSTROM * float(self.scale)
        # Weighted potential one atom feels per unit charge of its partner.
        potential_kernel = prefactor * inverse_softened * envelope
        pair_energy = potential_kernel * charge_product
        # Product rule on kernel(r) * envelope(r).
        pair_grad = prefactor * charge_product * (
            softened_grad * envelope + inverse_softened * envelope_grad
        )

        assert output.energy is not None
        assert output.forces is not None
        assert output.per_atom_energy is not None
        output.energy.index_add_(0, pair_system_index, pair_energy)
        per_atom_contribution = 0.5 * pair_energy
        output.per_atom_energy.index_add_(0, first_atom, per_atom_contribution)
        output.per_atom_energy.index_add_(0, second_atom, per_atom_contribution)
        # Guard 1/r for (near-)coincident pairs: their force direction is zero.
        inv_r = torch.where(
            pair_distances > self.eps,
            pair_distances.reciprocal(),
            torch.zeros_like(pair_distances),
        )
        force_on_second = -pair_grad[:, None] * pair_vectors * inv_r[:, None]
        output.forces.index_add_(0, first_atom, -force_on_second)
        output.forces.index_add_(0, second_atom, force_on_second)
        # dE/dq_i = sum over partners j of kernel_ij * q_j.
        charge_potential.index_add_(0, first_atom, potential_kernel * second_charge)
        charge_potential.index_add_(0, second_atom, potential_kernel * first_charge)
        return output
# Public API of this module; private helpers and the shared
# ``_StateScaledSplinePairTerm`` base class are deliberately not exported.
__all__ = [
    "COULOMB_CONSTANT_EV_ANGSTROM",
    "ChargeScaledSplinePairTerm",
    "ChargeSelfEnergyTerm",
    "CollinearSpinExchangeTerm",
    "CollinearSpinLandauTerm",
    "LocalChargeCoulombTerm",
]