Source code for mythos.observables.membrane_melting_temp

r"""Membrane melting temperature observable.

Computes the melting temperature (Tm) of a lipid membrane by fitting a sigmoid
to area-per-lipid (APL) vs. temperature data, following the approach from
jax-martini. The sigmoid model is:

.. math::
    \text{APL}(T) = \text{apl}_0 + c_{pg} \cdot T
        + \frac{\Delta\text{APL}}{1 + \exp(-k (T - T_m))}

The five fit parameters are ``[apl0, c_p_g, dAPL, k, Tm]``.

The module provides both standalone functions for sigmoid fitting and a
:class:`MembraneMeltingTemp` observable class that takes a
:class:`~mythos.simulators.io.SimulatorTrajectory` as input.
"""

import chex
import jax.numpy as jnp
import MDAnalysis
from jaxopt import LevenbergMarquardt

from mythos.observables.area_per_lipid import AreaPerLipid
from mythos.simulators.io import SimulatorTrajectory


[docs] def calculate_apl( t: jnp.ndarray, apl0: float, c_p_g: float, dAPL: float, # noqa: N803 — matches jax-martini naming k: float, Tm: float, # noqa: N803 — matches jax-martini naming ) -> jnp.ndarray: """Evaluate the APL sigmoid model at temperature(s) *t*. Args: t: Temperature(s) in Kelvin. apl0: Baseline APL (gel phase). c_p_g: Linear temperature coefficient. dAPL: APL jump across the transition. k: Steepness of the sigmoid. Tm: Melting temperature in Kelvin. Returns: Predicted APL value(s). """ return apl0 + c_p_g * t + dAPL / (1 + jnp.exp(-k * (t - Tm)))
[docs] def apl_residual( coeffs: jnp.ndarray, data: tuple[jnp.ndarray, jnp.ndarray], ) -> jnp.ndarray: """Residual function for least-squares sigmoid fitting. Follows the ``residual_fun(params, *args)`` convention expected by :class:`jaxopt.LevenbergMarquardt`. The data arguments are packed into a single tuple to ensure compatibility with jaxopt's implicit differentiation. Args: coeffs: Parameter vector ``[apl0, c_p_g, dAPL, k, Tm]``. data: Tuple of ``(sim_apls, sim_temps)`` where *sim_apls* are the observed APL values and *sim_temps* the corresponding temperatures, both of shape ``(n_temps,)``. Returns: Element-wise residual ``sim_apls - predicted_apls``. """ sim_apls, sim_temps = data apl0, c_p_g, dAPL, k, Tm = coeffs # noqa: N806 return sim_apls - calculate_apl(sim_temps, apl0, c_p_g, dAPL, k, Tm)
[docs] def get_initial_guess(sim_apls: jnp.ndarray, sim_temps: jnp.ndarray) -> jnp.ndarray: """Heuristic initial guess for the sigmoid parameters. Args: sim_apls: Observed APL values, shape ``(n_temps,)``. sim_temps: Corresponding temperatures, shape ``(n_temps,)``. Returns: Parameter vector ``[apl0, c_p_g, dAPL, k, Tm]``. """ apl0 = jnp.min(sim_apls) - 0.0001 * 276 c_p_g = 1e-4 dAPL = jnp.max(sim_apls) - jnp.min(sim_apls) # noqa: N806 k = 1.0 Tm = jnp.median(sim_temps) # noqa: N806 return jnp.array([apl0, c_p_g, dAPL, k, Tm])
[docs] def fit_apl_sigmoid( sim_apls: jnp.ndarray, sim_temps: jnp.ndarray, *, implicit_diff: bool = True, maxiter: int = 5000, ) -> jnp.ndarray: """Fit the sigmoid model to APL-vs-temperature data via nonlinear least squares. Uses Levenberg-Marquardt, which is more robust than Gauss-Newton for the strongly nonlinear sigmoid model. Args: sim_apls: Observed (or reweighted) APL values, shape ``(n_temps,)``. sim_temps: Corresponding temperatures in Kelvin, shape ``(n_temps,)``. implicit_diff: Whether to use implicit differentiation through the solver, allowing JAX to back-propagate gradients. maxiter: Maximum number of solver iterations. Returns: Fitted parameter vector ``[apl0, c_p_g, dAPL, k, Tm]``. """ init_guess = get_initial_guess(sim_apls, sim_temps) lm = LevenbergMarquardt( residual_fun=apl_residual, implicit_diff=implicit_diff, maxiter=maxiter, ) res = lm.run(init_guess, (sim_apls, sim_temps)) return res.params
[docs] def compute_membrane_tm( sim_apls: jnp.ndarray, sim_temps: jnp.ndarray, *, implicit_diff: bool = True, ) -> float: """Compute the membrane melting temperature from APL-vs-temperature data. Convenience wrapper around :func:`fit_apl_sigmoid` that returns just Tm. Args: sim_apls: Observed (or reweighted) APL values, shape ``(n_temps,)``. sim_temps: Temperatures in Kelvin, shape ``(n_temps,)``. implicit_diff: Whether to use implicit differentiation. Returns: Melting temperature in Kelvin. """ params = fit_apl_sigmoid(sim_apls, sim_temps, implicit_diff=implicit_diff) return params[4]
[docs] @chex.dataclass(frozen=True, kw_only=True) class MembraneMeltingTemp: """Observable that computes lipid membrane melting temperature. Given a concatenated :class:`SimulatorTrajectory` containing frames from simulations at multiple temperatures (identified via per-frame metadata), this observable: 1. Computes per-frame area-per-lipid using :class:`AreaPerLipid`. 2. Groups frames by temperature using ``trajectory.temperature``. 3. Computes the weighted expected APL at each temperature (weighted by optional DiffTRe importance-sampling weights). 4. Fits a sigmoid to APL vs. temperature and returns the melting temperature :math:`T_m`. Attributes: topology: MDAnalysis Universe describing the system topology. lipid_sel: MDAnalysis selection string for lipid tail atoms (e.g. ``"name GL1 GL2"``). temperatures: Array of simulation temperatures in Kelvin to fit over. implicit_diff: Whether to use implicit differentiation through the least-squares solver. temp_rtol: Relative tolerance for grouping frames by temperature. Frames with temperature within this relative tolerance are considered to belong to the same temperature group. Default is 1e-3 (0.1%). """ topology: MDAnalysis.Universe lipid_sel: str temperatures: jnp.ndarray implicit_diff: bool = True temp_rtol: float = 1e-3
[docs] def __call__( self, trajectory: SimulatorTrajectory, weights: jnp.ndarray | None = None, ) -> float: """Compute the membrane melting temperature. Args: trajectory: Concatenated trajectory with per-frame temperature weights: Optional per-frame importance-sampling weights, shape ``(N,)``. When ``None``, uniform weights are used (equivalent to an unweighted mean per temperature). Returns: Melting temperature in Kelvin. """ # Group frames by temperature and compute APL one temperature at a # time so that only one temperature's worth of frames is in memory. if weights is None: weights = jnp.ones(trajectory.length()) apl_fn = AreaPerLipid(topology=self.topology, lipid_sel=self.lipid_sel) expected_apls = [] for temp in self.temperatures: indices = jnp.where(jnp.abs(trajectory.temperature - temp) < self.temp_rtol * jnp.abs(temp))[0] if indices.size == 0: raise ValueError(f"No frames found for temperature {temp} within relative tolerance {self.temp_rtol}.") batch_apls = apl_fn(trajectory.slice(indices)) batch_weights = weights[indices] weight_sum = jnp.sum(batch_weights) if weight_sum == 0: raise ValueError(f"Sum of weights is zero for temperature {temp}. Cannot compute weighted average APL.") expected_apls.append( jnp.sum(batch_weights * batch_apls) / weight_sum, ) expected_apls = jnp.stack(expected_apls) return compute_membrane_tm(expected_apls, self.temperatures, implicit_diff=self.implicit_diff)