Source code for mythos.observables.triplet_angles

"""Observable for computing angles between atom triplets from a Martini trajectory."""

from collections.abc import Callable

import chex
import jax
import jax.numpy as jnp

from mythos.energy.martini.base import MartiniTopology, get_periodic
from mythos.energy.martini.m2.angle import compute_angle
from mythos.simulators.io import SimulatorTrajectory
from mythos.utils.types import Arr_N


def _triplet_angle(centers: jnp.ndarray, triplet: jnp.ndarray, displacement_fn: Callable) -> float:
    """Return the angle (in radians) subtended at the middle atom of an index triplet.

    For a triplet ``[i, j, k]`` the vertex of the angle is atom ``j``; the two
    arms connect ``j`` with ``i`` and ``j`` with ``k``.

    Args:
        centers: Positions of all atoms, shape ``(n_atoms, 3)``.
        triplet: The three atom indices ``[i, j, k]``.
        displacement_fn: Boundary-condition-aware displacement function that
            takes two positions.

    Returns:
        The angle in radians, as produced by ``compute_angle``.
    """
    first, vertex, last = triplet[0], triplet[1], triplet[2]
    vertex_pos = centers[vertex]

    # Both arms are evaluated with the vertex as the first argument, so they
    # share the same sign convention when handed to ``compute_angle``.
    arm_first = displacement_fn(vertex_pos, centers[first])
    arm_last = displacement_fn(vertex_pos, centers[last])
    return compute_angle(arm_first, arm_last)
@chex.dataclass(frozen=True, kw_only=True)
class TripletAngles:
    """Per-frame angles for every topology triplet that carries one angle name.

    The observable looks up all angle entries of a :class:`MartiniTopology`
    whose name equals ``angle_name`` and evaluates, for each trajectory state,
    the angle at the central atom of each of those triplets.

    Attributes:
        topology: Martini topology providing ``angle_names`` and ``angles``.
        angle_name: Name of the angle type to evaluate, of the form
            ``RESIDUE_BEAD1_BEAD2_BEAD3`` (e.g. ``"DMPC_NC3_PO4_GL1"``). Every
            topology angle with this name contributes one output column.
        displacement_fn: Factory called with a box size vector; it must return
            a displacement function that respects periodic boundary conditions.
    """

    topology: MartiniTopology
    angle_name: str
    displacement_fn: Callable = get_periodic

    def _matching_triplets(self) -> jnp.ndarray:
        """Select the topology angle rows named ``angle_name``.

        Returns:
            Array of shape ``(n_matching, 3)`` holding atom-index triplets.

        Raises:
            ValueError: If the topology has no angle with the requested name.
        """
        names = self.topology.angle_names
        hits = [row for row, name in enumerate(names) if name == self.angle_name]
        if hits:
            return self.topology.angles[jnp.array(hits)]

        # Nothing matched: list the names that do exist to make typos easy to spot.
        raise ValueError(
            f"No angles matching '{self.angle_name}' found in the topology. "
            f"Available angle names: {sorted(set(names))}"
        )

    def __call__(self, trajectory: SimulatorTrajectory) -> jnp.ndarray:
        """Evaluate the selected angles on every state of a trajectory.

        Args:
            trajectory: A :class:`SimulatorTrajectory` whose ``center`` has
                shape ``(n_states, n_atoms, 3)`` and whose ``box_size`` has
                shape ``(n_states, 3)``.

        Returns:
            Angles in radians, shape ``(n_states, n_matching_angles)``.

        Raises:
            ValueError: If the topology has no angle named ``angle_name``.
        """
        triplets = self._matching_triplets()  # (n_matching, 3)

        # Inner map: one angle per triplet; positions and the displacement
        # function are shared across all triplets of a state.
        angles_over_triplets = jax.vmap(_triplet_angle, in_axes=(None, 0, None))

        def _state_angles(centers: Arr_N, box: Arr_N) -> Arr_N:
            # The displacement function is rebuilt per state from that state's box.
            return angles_over_triplets(centers, triplets, self.displacement_fn(box))

        # Outer map over states:
        # (n_states, n_atoms, 3) x (n_states, 3) -> (n_states, n_matching)
        angles_over_states = jax.vmap(_state_angles)
        return angles_over_states(trajectory.center, trajectory.box_size)
@chex.dataclass(frozen=True, kw_only=True)
class TripletAnglesMapped:
    """Per-frame triplet angles for several angle names, returned as a mapping.

    For each entry of ``angle_names`` a :class:`TripletAngles` observable is
    built on the same topology and displacement factory and applied to the
    trajectory; the results are collected in a dictionary keyed by angle name.

    Attributes:
        topology: Martini topology containing the angle information.
        angle_names: Angle names to evaluate. Each has the form
            ``RESIDUE_BEAD1_BEAD2_BEAD3`` (e.g. ``"DMPC_NC3_PO4_GL1"``); all
            topology angles matching a name are included under that key.
        displacement_fn: Factory called with a box size vector; it must return
            a displacement function that respects periodic boundary conditions.
    """

    topology: MartiniTopology
    angle_names: tuple[str, ...]
    displacement_fn: Callable = get_periodic

    def __call__(self, trajectory: SimulatorTrajectory) -> dict[str, jnp.ndarray]:
        """Evaluate every requested angle name on a trajectory.

        Args:
            trajectory: A :class:`SimulatorTrajectory` whose ``center`` has
                shape ``(n_states, n_atoms, 3)`` and whose ``box_size`` has
                shape ``(n_states, 3)``.

        Returns:
            Dictionary mapping each angle name to its angles in radians, with
            shape ``(n_states, n_matching_angles)``.
        """
        angles_by_name: dict[str, jnp.ndarray] = {}
        for name in self.angle_names:
            # Delegate to the single-name observable so both share one code path.
            single_name_observable = TripletAngles(
                topology=self.topology,
                angle_name=name,
                displacement_fn=self.displacement_fn,
            )
            angles_by_name[name] = single_name_observable(trajectory)
        return angles_by_name