"""Lennard-Jones potential energy function for Martini 2."""
import chex
import jax
import jax.numpy as jnp
from jax_md import space
from typing_extensions import override
from mythos.energy.martini.base import MartiniEnergyConfiguration, MartiniEnergyFunction
from mythos.simulators.io import SimulatorTrajectory
from mythos.utils.types import Arr_N, Arr_States_3, MatrixSq
# Parameter-name prefixes recognised by LJConfiguration. Full parameter names
# have the form "<prefix><A>_<B>", where A and B are bead types.
LJ_SIGMA_PREFIX = "lj_sigma_"
LJ_EPSILON_PREFIX = "lj_epsilon_"
[docs]
class LJConfiguration(MartiniEnergyConfiguration):
    """Configuration for Martini Lennard-Jones energy function.

    All parameters provided must be of the form "lj_sigma_A_B" or "lj_epsilon_A_B",
    where A and B are bead types (bead type names therefore must not contain
    underscores). Pair order is ignored unless both orderings are provided, in
    which case "A_B" is used for the (A, B) matrix entry and "B_A" for (B, A).
    Sigma and epsilon values are required for every pair of bead types that
    appears in any parameter name, including each bead type paired with itself.

    Couplings are supported (see :class:`MartiniEnergyConfiguration` for details).
    """

    @override
    def __post_init__(self) -> None:
        """Validate parameter names and build the sigma/epsilon lookup tables.

        Sets ``bead_types`` (sorted tuple of all bead types named by any
        parameter) and the square matrices ``sigmas`` / ``epsilons``, whose
        rows and columns follow the order of ``bead_types``.

        Raises:
            ValueError: If a parameter does not start with a known LJ prefix,
                if its name does not contain exactly two non-empty bead types,
                or if a sigma/epsilon value is missing for some bead type pair.
        """
        # NOTE(review): MartiniEnergyConfiguration.__post_init__ is not called
        # here; confirm the base class needs no initialisation of its own.
        bead_types = set()
        for param in self.params:
            if not param.startswith((LJ_SIGMA_PREFIX, LJ_EPSILON_PREFIX)):
                raise ValueError(f"Unexpected parameter {param} for LJConfiguration")
            # Both prefixes are exactly two "_"-separated tokens ("lj" and
            # "sigma"/"epsilon"), so everything after the second "_" is "A_B".
            pair = param.split("_", 2)[2].split("_")
            if len(pair) != 2 or not all(pair):
                raise ValueError(
                    f"Malformed parameter {param} for LJConfiguration; expected "
                    "the form lj_<sigma|epsilon>_<A>_<B> with non-empty bead types "
                    "A and B that contain no underscores"
                )
            bead_types.update(pair)
        self.bead_types = tuple(sorted(bead_types))

        # Construct lookup tables for the values for use in vmapped energy
        # calculations. These should be symmetric matrices, but we do not
        # explicitly enforce that. At least one of the pair orderings must exist
        # or an exception is raised.
        def get_param(prefix: str, a: str, b: str) -> float:
            # Prefer the "a_b" ordering; fall back to "b_a".
            param = self.params.get(f"lj_{prefix}_{a}_{b}", self.params.get(f"lj_{prefix}_{b}_{a}"))
            if param is None:
                raise ValueError(f"Missing LJ {prefix} parameter for pair {a}_{b} ({b}_{a})")
            return param

        self.sigmas: MatrixSq = jnp.array(
            [[get_param("sigma", i, j) for j in self.bead_types] for i in self.bead_types]
        )
        self.epsilons: MatrixSq = jnp.array(
            [[get_param("epsilon", i, j) for j in self.bead_types] for i in self.bead_types]
        )
[docs]
def lennard_jones(r: float, eps: float, sigma: float, cutoff: float = 1.1) -> float:
    """Calculate the shifted Lennard-Jones potential.

    The standard 12-6 potential ``V(r) = 4 eps [(sigma/r)^12 - (sigma/r)^6]`` is
    shifted so that it vanishes at the cutoff:
    ``V_s(r) = V(r) - V(r_c)`` for ``r < r_c`` and ``0`` otherwise.

    Args:
        r: Inter-particle distance (scalar or array; combined element-wise).
        eps: Well depth epsilon.
        sigma: Distance at which the unshifted potential crosses zero.
        cutoff: Cutoff distance r_c at and beyond which the energy is zero.
            Defaults to 1.1, the value that was previously hard-coded here.
            The units are presumably nm, per the Martini 2 convention; confirm
            against the simulation setup.

    Returns:
        The shifted potential energy as a JAX value.
    """
    # (sigma/r)^12 == ((sigma/r)^6)^2, so compute the 6th power once.
    sr6 = (sigma / r) ** 6
    v = 4 * eps * (sr6 * sr6 - sr6)
    # Value of the unshifted potential at the cutoff, used as the shift.
    src6 = (sigma / cutoff) ** 6
    v_c = 4 * eps * (src6 * src6 - src6)
    # Inside the cutoff: shift by subtracting V(r_c). Outside: zero.
    return jnp.where(r < cutoff, v - v_c, 0.0)
[docs]
def pair_lj(
    centers: Arr_States_3,
    i: int,
    j: int,
    bonded_mask: MatrixSq,
    sigmas: MatrixSq,
    epsilons: MatrixSq,
    types: Arr_N,
    displacement_fn: callable,
) -> float:
    """Return the masked Lennard-Jones energy of the particle pair (i, j).

    The pair's sigma and epsilon come from the ``sigmas`` / ``epsilons``
    matrices, indexed by the type indices ``types[i]`` and ``types[j]``. The
    distance is taken from ``displacement_fn`` applied to the two centers. The
    energy is multiplied by ``bonded_mask[i, j]``, so pairs whose mask entry
    is False (bonded pairs) contribute zero.
    """
    # Tuple indexing: matrix[(a, b)] is the same as matrix[a, b].
    type_pair = (types[i], types[j])
    pair_sigma = sigmas[type_pair]
    pair_eps = epsilons[type_pair]
    separation = space.distance(displacement_fn(centers[i], centers[j]))
    energy = lennard_jones(separation, pair_eps, pair_sigma)
    return energy * bonded_mask[i, j]
[docs]
@chex.dataclass(frozen=True, kw_only=True)
class LJ(MartiniEnergyFunction):
    """Lennard-Jones potential energy function for Martini 2.

    The energy of a state is the sum of the shifted LJ potential over every
    unordered pair of distinct atoms. Pairs listed in ``bonded_neighbors`` are
    masked out. Per-pair sigma/epsilon values are looked up from the
    :class:`LJConfiguration` matrices using each atom's bead type.
    """

    params: LJConfiguration

    @override
    def __post_init__(self, topology: None = None) -> None:
        """Run base initialisation, then cache each atom's bead-type index."""
        # Base-class initialisation runs first. It is presumably what makes
        # fields such as ``atom_types`` available; confirm in
        # MartiniEnergyFunction.
        MartiniEnergyFunction.__post_init__(self)
        # Map bead-type name -> row/column index in params.sigmas/epsilons.
        index_of_type = {name: idx for idx, name in enumerate(self.params.bead_types)}
        # The dataclass is frozen, so bypass its __setattr__ to cache the array.
        object.__setattr__(
            self,
            "_atom_type_map",
            jnp.array([index_of_type[name] for name in self.atom_types]),
        )

    def _build_pair_info(self) -> tuple[Arr_N, Arr_N, MatrixSq]:
        """Return ``(pair_i, pair_j, not_bonded)`` for all atom pairs.

        ``pair_i`` and ``pair_j`` enumerate every unordered pair with i < j.
        ``not_bonded`` is an N x N boolean matrix. It is True everywhere except
        at (a, b) and (b, a) for each bonded-neighbour pair, so multiplying a
        pair energy by it zeroes bonded interactions. Index arrays are used
        instead of a materialised array of pair tuples.
        """
        n_atoms = len(self.atom_types)
        pair_i, pair_j = jnp.triu_indices(n_atoms, k=1)
        bond_a = self.bonded_neighbors[:, 0]
        bond_b = self.bonded_neighbors[:, 1]
        # Clear both orientations so the mask is symmetric for bonded pairs.
        not_bonded = (
            jnp.ones((n_atoms, n_atoms), dtype=bool)
            .at[bond_a, bond_b].set(False)
            .at[bond_b, bond_a].set(False)
        )
        return pair_i, pair_j, not_bonded

    @override
    def map(self, body_sequence: SimulatorTrajectory) -> jnp.ndarray:
        """Compute the energy of every state in ``body_sequence``.

        Overridden so that the pair indices and bonded mask are built once per
        call and shared by all states, since they do not depend on the
        trajectory. They are built here rather than in ``__post_init__`` to
        avoid storing, and potentially serializing, a large N x N mask on the
        instance.
        """
        pair_info = self._build_pair_info()

        def energy_of(state: SimulatorTrajectory) -> float:
            # Apply any configured transform_fn before computing energy,
            # while still reusing the shared pair info.
            if self.transform_fn is not None:
                state = self.transform_fn(state)
            return self.compute_energy(state, _bonds_info=pair_info)

        if self.map_checkpoint:
            energy_of = jax.checkpoint(energy_of)
        return jax.lax.map(energy_of, body_sequence, batch_size=self.map_batch_size)

    @override
    def compute_energy(
        self, trajectory: SimulatorTrajectory, _bonds_info: tuple[Arr_N, Arr_N, MatrixSq] | None = None
    ) -> float:
        """Return the total LJ energy of a single state.

        Args:
            trajectory: State whose ``center`` positions and ``box_size`` are used.
            _bonds_info: Optional precomputed result of ``_build_pair_info``.
                It is built on demand when omitted.
        """
        displacement_fn = self.displacement_fn(trajectory.box_size)
        pair_i, pair_j, not_bonded = (
            self._build_pair_info() if _bonds_info is None else _bonds_info
        )
        # Vectorise over the two pair-index arrays only; all other arguments
        # are shared across pairs.
        pair_axes = (None, 0, 0, None, None, None, None, None)
        batched_pair_lj = jax.vmap(pair_lj, in_axes=pair_axes)
        per_pair_energy = batched_pair_lj(
            trajectory.center,
            pair_i,
            pair_j,
            not_bonded,
            self.params.sigmas,
            self.params.epsilons,
            self._atom_type_map,
            displacement_fn,
        )
        return per_pair_energy.sum()