Source code for mythos.observables.bond_distances

"""Observable for computing bond distances from a Martini trajectory."""

from collections.abc import Callable

import chex
import jax
import jax.numpy as jnp
from jax_md import space

from mythos.energy.martini.base import MartiniTopology, get_periodic
from mythos.simulators.io import SimulatorTrajectory
from mythos.utils.types import Arr_N


[docs] def _bond_distance(centers: jnp.ndarray, pair: jnp.ndarray, displacement_fn: Callable) -> float: """Compute the distance between a bonded pair of atoms.""" return space.distance(displacement_fn(centers[pair[0]], centers[pair[1]]))
[docs] @chex.dataclass(frozen=True, kw_only=True) class BondDistances: """Observable for computing bond distances for a single bond name. Given a :class:`MartiniTopology` and a bond name, this observable computes pairwise distances for all matching bonds across the trajectory. Attributes: topology: The Martini topology containing bond information. bond_name: Bond name string to compute distances for. Has the form ``RESIDUE_BEAD1_BEAD2`` (e.g. ``"DMPC_GL1_GL2"``). All bonds in the topology matching this name will be included in the output. displacement_fn: Factory that, given a box size vector, returns a displacement function respecting periodic boundary conditions. """ topology: MartiniTopology bond_name: str displacement_fn: Callable = get_periodic
[docs] def _matching_pairs(self) -> jnp.ndarray: all_names = self.topology.bond_names indices = [i for i, name in enumerate(all_names) if name == self.bond_name] if not indices: raise ValueError( f"No bonds matching '{self.bond_name}' found in the topology. " f"Available bond names: {sorted(set(all_names))}" ) return self.topology.bonded_neighbors[jnp.array(indices)]
[docs] def __call__(self, trajectory: SimulatorTrajectory) -> jnp.ndarray: """Compute bond distances for the requested bond name. Args: trajectory: A :class:`SimulatorTrajectory` whose ``center`` has shape ``(n_states, n_atoms, 3)`` and ``box_size`` has shape ``(n_states, 3)``. Returns: Distance array of shape ``(n_states, n_matching_bonds)``. """ pairs = self._matching_pairs() def _distances_for_state(centers: Arr_N, box: Arr_N, _pairs: Arr_N = pairs) -> Arr_N: disp_fn = self.displacement_fn(box) return jax.vmap(_bond_distance, in_axes=(None, 0, None))(centers, _pairs, disp_fn) # vmap over states: (n_states, n_atoms, 3) x (n_states, 3) -> (n_states, n_matching) return jax.vmap(_distances_for_state)(trajectory.center, trajectory.box_size)
[docs] @chex.dataclass(frozen=True, kw_only=True) class BondDistancesMapped: """Observable for computing bond distances for multiple bond names. Given a :class:`MartiniTopology` and a set of bond names, this observable computes pairwise distances for all matching bonds across the trajectory, returning a dictionary keyed by bond name. Attributes: topology: The Martini topology containing bond information. bond_names: Tuple of bond name strings to compute distances for. Each name has the form ``RESIDUE_BEAD1_BEAD2`` (e.g. ``"DMPC_GL1_GL2"``). All bonds in the topology matching a given name will be included in the output. displacement_fn: Factory that, given a box size vector, returns a displacement function respecting periodic boundary conditions. """ topology: MartiniTopology bond_names: tuple[str, ...] displacement_fn: Callable = get_periodic
[docs] def __call__(self, trajectory: SimulatorTrajectory) -> dict[str, jnp.ndarray]: """Compute bond distances for each requested bond name. Args: trajectory: A :class:`SimulatorTrajectory` whose ``center`` has shape ``(n_states, n_atoms, 3)`` and ``box_size`` has shape ``(n_states, 3)``. Returns: Dictionary mapping each bond name to a distance array of shape ``(n_states, n_matching_bonds)``. """ return { bond_name: BondDistances( topology=self.topology, bond_name=bond_name, displacement_fn=self.displacement_fn, )(trajectory) for bond_name in self.bond_names }