# Source code for mythos.energy.martini.m2.lj

"""Lennard-Jones potential energy function for Martini 2."""

from collections.abc import Callable

import chex
import jax
import jax.numpy as jnp
from jax_md import space
from typing_extensions import override

from mythos.energy.martini.base import MartiniEnergyConfiguration, MartiniEnergyFunction
from mythos.simulators.io import SimulatorTrajectory
from mythos.utils.types import Arr_N, Arr_States_3, MatrixSq

LJ_SIGMA_PREFIX = "lj_sigma_"
LJ_EPSILON_PREFIX = "lj_epsilon_"


class LJConfiguration(MartiniEnergyConfiguration):
    """Configuration for Martini Lennard-Jones energy function.

    All parameters provided must be of the form "lj_sigma_A_B" or
    "lj_epsilon_A_B", where A and B are bead types. Pair order is ignored
    unless both orderings are provided. Sigma and epsilon parameters are
    required for every pair of bead types named in the parameters (and so for
    any bead type pairs present in the system). Couplings are supported (see
    :class:`MartiniEnergyConfiguration` for details).
    """

    @override
    def __post_init__(self) -> None:
        """Validate parameter names and build the sigma/epsilon lookup tables.

        Sets ``bead_types`` (sorted tuple of every bead type named in the
        parameters) and ``sigmas`` / ``epsilons`` (square matrices whose rows
        and columns follow ``bead_types`` order).

        Raises:
            ValueError: if a parameter is not of the form ``lj_sigma_A_B`` /
                ``lj_epsilon_A_B``, or if a sigma/epsilon value is missing for
                some pair of the configured bead types.
        """
        bead_types = set()
        for param in self.params:
            if not param.startswith((LJ_SIGMA_PREFIX, LJ_EPSILON_PREFIX)):
                raise ValueError(f"Unexpected parameter {param} for LJConfiguration")
            # "lj_<kind>_<A>_<B>".split("_") -> ["lj", kind, A, B]. Reject any
            # other shape (extra/missing components, empty bead names) instead
            # of silently truncating it, so typos surface immediately.
            pair = param.split("_")[2:]
            if len(pair) != 2 or not all(pair):
                raise ValueError(
                    f"Malformed LJ parameter {param}; expected lj_sigma_A_B or lj_epsilon_A_B"
                )
            bead_types.update(pair)
        self.bead_types = tuple(sorted(bead_types))

        # Construct lookup tables for the values for use in vmapped energy
        # calculations. These should be symmetric matrices, but we do not
        # explicitly enforce that. At least one of the pair orderings must exist
        # or an exception is raised.
        def get_param(prefix: str, a: str, b: str) -> float:
            # Prefer the a_b ordering; fall back to b_a.
            param = self.params.get(f"lj_{prefix}_{a}_{b}", self.params.get(f"lj_{prefix}_{b}_{a}"))
            if param is None:
                raise ValueError(f"Missing LJ {prefix} parameter for pair {a}_{b} ({b}_{a})")
            return param

        self.sigmas: MatrixSq = jnp.array(
            [[get_param("sigma", i, j) for j in self.bead_types] for i in self.bead_types]
        )
        self.epsilons: MatrixSq = jnp.array(
            [[get_param("epsilon", i, j) for j in self.bead_types] for i in self.bead_types]
        )
def lennard_jones(r: float, eps: float, sigma: float, cutoff: float = 1.1) -> float:
    """Calculate the shifted Lennard-Jones potential at distance ``r``.

    The standard 12-6 potential
    ``V(r) = 4 * eps * ((sigma / r) ** 12 - (sigma / r) ** 6)`` is shifted so it
    reaches zero at the cutoff, and is truncated beyond it::

        V_s(r) = V(r) - V(cutoff)   if r < cutoff
        V_s(r) = 0                  otherwise

    Args:
        r: inter-particle distance (scalar or array; evaluated element-wise).
        eps: well depth epsilon.
        sigma: distance at which the unshifted potential crosses zero.
        cutoff: truncation/shift distance, in the same units as ``r``.
            Defaults to 1.1.

    Returns:
        The shifted potential; 0 wherever ``r >= cutoff``.
    """

    def _unshifted(dist):
        # (sigma/d)**12 is the square of (sigma/d)**6, so compute the power once.
        sr6 = (sigma / dist) ** 6
        return 4 * eps * (sr6 * sr6 - sr6)

    # Subtract V(r_c) so the potential is continuous (zero) at the cutoff.
    return jnp.where(r < cutoff, _unshifted(r) - _unshifted(cutoff), 0.0)
def pair_lj(
    centers: Arr_States_3,
    i: int,
    j: int,
    bonded_mask: MatrixSq,
    sigmas: MatrixSq,
    epsilons: MatrixSq,
    types: Arr_N,
    displacement_fn: Callable,
) -> float:
    """Calculate LJ energy for a given pair of particles.

    Args:
        centers: particle positions, indexed by particle.
        i: index of the first particle.
        j: index of the second particle.
        bonded_mask: per-pair multiplier; a False/0 entry at ``[i, j]``
            zeroes this pair's energy (used to exclude bonded pairs).
        sigmas: LJ sigma table indexed by ``[type_i, type_j]``.
        epsilons: LJ epsilon table indexed by ``[type_i, type_j]``.
        types: per-particle type index into ``sigmas`` / ``epsilons``.
        displacement_fn: maps two positions to their displacement; the result
            is reduced to a scalar distance with ``space.distance``.

    Returns:
        ``lennard_jones`` evaluated at the pair distance, multiplied by
        ``bonded_mask[i, j]``.
    """
    i_type = types[i]
    j_type = types[j]
    sigma = sigmas[i_type, j_type]
    eps = epsilons[i_type, j_type]
    r = space.distance(displacement_fn(centers[i], centers[j]))
    return lennard_jones(r, eps, sigma) * bonded_mask[i, j]  # Mask out bonded pairs
@chex.dataclass(frozen=True, kw_only=True)
class LJ(MartiniEnergyFunction):
    """Lennard-Jones potential energy function for Martini 2."""

    params: LJConfiguration

    @override
    def __post_init__(self, topology: None = None) -> None:
        """Run base initialization and cache each atom's LJ type index.

        ``_atom_type_map[k]`` is the row/column index of atom ``k``'s bead type
        in ``params.sigmas`` / ``params.epsilons``.

        Raises:
            KeyError: if an atom's bead type has no LJ parameters configured.
        """
        MartiniEnergyFunction.__post_init__(self)
        type_index = {bead: idx for idx, bead in enumerate(self.params.bead_types)}
        try:
            type_indices = [type_index[t] for t in self.atom_types]
        except KeyError as err:
            # Same exception type as before, but say which type is missing.
            raise KeyError(
                f"No LJ parameters configured for bead type {err.args[0]!r}; "
                f"configured bead types are {self.params.bead_types}"
            ) from err
        atom_type_map = jnp.array(type_indices)
        # The dataclass is frozen, so bypass the generated __setattr__.
        object.__setattr__(self, "_atom_type_map", atom_type_map)

    def _build_pair_info(self) -> tuple[Arr_N, Arr_N, MatrixSq]:
        """Return ``(triu_i, triu_j, bonded_mask)`` for the pairwise energy sum.

        ``triu_i`` / ``triu_j`` enumerate every unordered pair of distinct atoms
        (strict upper triangle, ``i < j``). ``bonded_mask`` is an ``(N, N)``
        boolean matrix. Despite its name, it is ``True`` for pairs that
        contribute, and ``False`` for each bonded-neighbor pair (both orderings).
        """
        n_atoms = len(self.atom_types)
        triu_i, triu_j = jnp.triu_indices(n_atoms, k=1)
        bonded_mask = jnp.ones((n_atoms, n_atoms), dtype=bool)
        bn_i, bn_j = self.bonded_neighbors[:, 0], self.bonded_neighbors[:, 1]
        bonded_mask = bonded_mask.at[bn_i, bn_j].set(False)
        bonded_mask = bonded_mask.at[bn_j, bn_i].set(False)
        return triu_i, triu_j, bonded_mask

    @override
    def map(self, body_sequence: SimulatorTrajectory) -> jnp.ndarray:
        """Compute the LJ energy of every state in ``body_sequence``.

        Overridden so the pair info, which does not depend on the trajectory
        states, is built once and reused for every state. It is built here
        rather than in ``__post_init__`` because the data structure could be
        large, and we want to avoid potential serialization.
        """
        bonds_info = self._build_pair_info()

        def map_fn(trajectory: SimulatorTrajectory) -> float:
            # Apply any configured transform_fn before computing energy,
            # while still reusing the precomputed bonds_info.
            if self.transform_fn is not None:
                trajectory = self.transform_fn(trajectory)
            return self.compute_energy(trajectory, _bonds_info=bonds_info)

        inner_fun = jax.checkpoint(map_fn) if self.map_checkpoint else map_fn
        return jax.lax.map(inner_fun, body_sequence, batch_size=self.map_batch_size)

    @override
    def compute_energy(
        self, trajectory: SimulatorTrajectory, _bonds_info: tuple[Arr_N, Arr_N, MatrixSq] | None = None
    ) -> float:
        """Sum the LJ energy over all unordered, non-bonded atom pairs of a state.

        Args:
            trajectory: a single state; ``center`` and ``box_size`` are read.
            _bonds_info: output of ``_build_pair_info``; built on the fly when
                ``None``.

        Returns:
            The total LJ energy of the state.
        """
        displacement_fn = self.displacement_fn(trajectory.box_size)
        # use precomputed pair info, or create for this state if not provided.
        if _bonds_info is None:
            _bonds_info = self._build_pair_info()
        triu_i, triu_j, bonded_mask = _bonds_info

        # Vectorize over pair indices only; all other arguments are shared.
        ljmap = jax.vmap(pair_lj, in_axes=(None, 0, 0, None, None, None, None, None))
        return ljmap(
            trajectory.center,
            triu_i,
            triu_j,
            bonded_mask,
            self.params.sigmas,
            self.params.epsilons,
            self._atom_type_map,
            displacement_fn,
        ).sum()