mythos.observables.membrane_melting_temp

Membrane melting temperature observable.

Computes the melting temperature (Tm) of a lipid membrane by fitting a sigmoid to area-per-lipid (APL) vs. temperature data, following the approach from jax-martini. The sigmoid model is:

\[\text{APL}(T) = \text{apl}_0 + c_{pg} \cdot T + \frac{\Delta\text{APL}}{1 + \exp(-k (T - T_m))}\]

The five fit parameters are [apl0, c_p_g, dAPL, k, Tm].
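Two consequences of the model help when reading fitted parameters. Both follow from differentiating the logistic term, using \(\sigma(x) = 1/(1+e^{-x})\) and \(\sigma'(x) = \sigma(x)\,(1-\sigma(x))\), with \(\sigma(0) = 1/2\):

\[\text{APL}(T_m) = \text{apl}_0 + c_{pg} T_m + \tfrac{1}{2}\Delta\text{APL}, \qquad \left.\frac{d\,\text{APL}}{dT}\right|_{T=T_m} = c_{pg} + \frac{k\,\Delta\text{APL}}{4}\]

Far below \(T_m\) the curve approaches the gel baseline \(\text{apl}_0 + c_{pg} T\), and far above it approaches \(\text{apl}_0 + \Delta\text{APL} + c_{pg} T\). The logistic term rises from 10% to 90% of \(\Delta\text{APL}\) over a temperature window of width \(2\ln 9 / k \approx 4.39/k\).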

The module provides both standalone functions for sigmoid fitting and a MembraneMeltingTemp observable class that takes a SimulatorTrajectory as input.

Classes

MembraneMeltingTemp

Observable that computes lipid membrane melting temperature.

Functions

calculate_apl(→ jax.numpy.ndarray)

Evaluate the APL sigmoid model at temperature(s) t.

apl_residual(→ jax.numpy.ndarray)

Residual function for least-squares sigmoid fitting.

get_initial_guess(→ jax.numpy.ndarray)

Heuristic initial guess for the sigmoid parameters.

fit_apl_sigmoid(→ jax.numpy.ndarray)

Fit the sigmoid model to APL-vs-temperature data via nonlinear least squares.

compute_membrane_tm(→ float)

Compute the membrane melting temperature from APL-vs-temperature data.

Module Contents

mythos.observables.membrane_melting_temp.calculate_apl(t: jax.numpy.ndarray, apl0: float, c_p_g: float, dAPL: float, k: float, Tm: float) → jax.numpy.ndarray

Evaluate the APL sigmoid model at temperature(s) t.

Parameters:
  • t – Temperature(s) in Kelvin.

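  • apl0 – Intercept of the gel-phase baseline; the baseline APL at temperature t is apl0 + c_p_g * t.

  • c_p_g – Linear temperature coefficient of the baseline (APL units per Kelvin).

  • dAPL – APL jump across the transition.

  • k – Steepness of the sigmoid, in inverse Kelvin.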

  • Tm – Melting temperature in Kelvin.

Returns:

Predicted APL value(s).
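As a minimal sketch (written from the formula above, not taken from the module), the model is a one-liner in JAX. It uses jax.nn.sigmoid, which equals 1/(1 + exp(-x)) but is evaluated in a numerically stable way. The parameter values are made up for illustration (APL in nm², T in K).

```python
import jax
import jax.numpy as jnp


def calculate_apl(t, apl0, c_p_g, dAPL, k, Tm):
    """Sketch of the APL sigmoid: linear baseline plus a logistic step of height dAPL at Tm."""
    t = jnp.asarray(t)
    # jax.nn.sigmoid(x) == 1 / (1 + exp(-x)), computed stably for large |x|
    return apl0 + c_p_g * t + dAPL * jax.nn.sigmoid(k * (t - Tm))


# Works for a scalar temperature or an array of temperatures (broadcasting).
temps = jnp.array([300.0, 315.0, 330.0])
apls = calculate_apl(temps, apl0=0.40, c_p_g=5e-4, dAPL=0.12, k=0.5, Tm=315.0)
```

At t = Tm the logistic term contributes exactly dAPL / 2, so the middle value is 0.40 + 0.1575 + 0.06.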

mythos.observables.membrane_melting_temp.apl_residual(coeffs: jax.numpy.ndarray, data: tuple[jax.numpy.ndarray, jax.numpy.ndarray]) → jax.numpy.ndarray

Residual function for least-squares sigmoid fitting.

Follows the residual_fun(params, *args) convention expected by jaxopt.LevenbergMarquardt. The data arguments are packed into a single tuple to ensure compatibility with jaxopt’s implicit differentiation.

Parameters:
  • coeffs – Parameter vector [apl0, c_p_g, dAPL, k, Tm].

  • data – Tuple of (sim_apls, sim_temps) where sim_apls are the observed APL values and sim_temps the corresponding temperatures, both of shape (n_temps,).

Returns:

Element-wise residual sim_apls - predicted_apls.
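The tuple-packed residual can be sketched as below. The block repeats the model so it runs on its own and uses invented synthetic parameters. It also shows the (n_temps, 5) Jacobian that jax.jacfwd produces from the residual, which is the quantity a Levenberg-Marquardt solver linearizes around.

```python
import jax
import jax.numpy as jnp


def calculate_apl(t, apl0, c_p_g, dAPL, k, Tm):
    return apl0 + c_p_g * t + dAPL * jax.nn.sigmoid(k * (t - Tm))


def apl_residual(coeffs, data):
    """Residual sim_apls - model(sim_temps; coeffs), with both data arrays packed in one tuple."""
    sim_apls, sim_temps = data
    apl0, c_p_g, dAPL, k, Tm = coeffs  # unpack the length-5 parameter vector
    return sim_apls - calculate_apl(sim_temps, apl0, c_p_g, dAPL, k, Tm)


temps = jnp.linspace(290.0, 340.0, 11)
true_coeffs = jnp.array([0.40, 5e-4, 0.12, 0.5, 315.0])
apls = calculate_apl(temps, *true_coeffs)  # noiseless synthetic data

r = apl_residual(true_coeffs, (apls, temps))               # ~0 at the generating parameters
J = jax.jacfwd(apl_residual)(true_coeffs, (apls, temps))   # d r / d coeffs, shape (11, 5)
```

Since the residual is sim_apls minus the model, the Jacobian columns for apl0 and c_p_g are simply -1 and -sim_temps.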

mythos.observables.membrane_melting_temp.get_initial_guess(sim_apls: jax.numpy.ndarray, sim_temps: jax.numpy.ndarray) → jax.numpy.ndarray

Heuristic initial guess for the sigmoid parameters.

Parameters:
  • sim_apls – Observed APL values, shape (n_temps,).

  • sim_temps – Corresponding temperatures, shape (n_temps,).

Returns:

Parameter vector [apl0, c_p_g, dAPL, k, Tm].
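The docstring does not say which heuristic is used. The following is one plausible heuristic, and it is entirely hypothetical: initial_guess_sketch is not the module's function. It places Tm in the temperature interval with the steepest finite-difference slope. It then estimates k by inverting the slope of the model at its midpoint, APL'(Tm) = c_p_g + k * dAPL / 4, with c_p_g started at 0.

```python
import jax.numpy as jnp


def initial_guess_sketch(sim_apls, sim_temps):
    """Hypothetical heuristic for [apl0, c_p_g, dAPL, k, Tm]; not necessarily the module's."""
    order = jnp.argsort(sim_temps)
    t, a = sim_temps[order], sim_apls[order]

    slopes = jnp.diff(a) / jnp.diff(t)          # finite-difference dAPL/dT between neighbours
    i = jnp.argmax(slopes)                      # steepest interval brackets the transition
    Tm0 = 0.5 * (t[i] + t[i + 1])

    dAPL0 = jnp.maximum(jnp.max(a) - jnp.min(a), 1e-8)  # total APL rise, guarded against 0
    apl00 = jnp.min(a)                          # flat baseline, since c_p_g starts at 0
    k0 = 4.0 * slopes[i] / dAPL0                # from APL'(Tm) = c_p_g + k * dAPL / 4
    return jnp.array([apl00, 0.0, dAPL0, k0, Tm0])


temps = jnp.linspace(290.0, 340.0, 11)
apls = 0.40 + 5e-4 * temps + 0.12 / (1.0 + jnp.exp(-0.5 * (temps - 315.0)))  # synthetic data
guess = initial_guess_sketch(apls, temps)
```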

mythos.observables.membrane_melting_temp.fit_apl_sigmoid(sim_apls: jax.numpy.ndarray, sim_temps: jax.numpy.ndarray, *, implicit_diff: bool = True, maxiter: int = 5000) → jax.numpy.ndarray

Fit the sigmoid model to APL-vs-temperature data via nonlinear least squares.

Uses Levenberg-Marquardt, which is more robust than Gauss-Newton for the strongly nonlinear sigmoid model.

Parameters:
  • sim_apls – Observed (or reweighted) APL values, shape (n_temps,).

  • sim_temps – Corresponding temperatures in Kelvin, shape (n_temps,).

  • implicit_diff – Whether to differentiate through the solver implicitly (via its optimality conditions) when JAX back-propagates gradients to sim_apls and sim_temps.

  • maxiter – Maximum number of solver iterations.

Returns:

Fitted parameter vector [apl0, c_p_g, dAPL, k, Tm].
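The fitting step can be sketched with a hand-written Levenberg-Marquardt loop in plain JAX. This is a dependency-free stand-in for jaxopt.LevenbergMarquardt, which the module itself uses according to the apl_residual docstring. It is not the module's code. The damping schedule (×0.3 after an accepted step, ×10 after a rejected one), lam0, maxiter = 200 and the starting point are arbitrary illustrative choices. Because the loop runs a fixed number of jax.lax.scan steps, jax.grad would differentiate it by unrolling the iterations. That is the alternative to implicit differentiation.

```python
import jax
import jax.numpy as jnp

# Float64 helps: parameter scales differ widely (Tm ~ 300 K, c_p_g ~ 1e-4).
jax.config.update("jax_enable_x64", True)


def sigmoid_apl(t, coeffs):
    apl0, c_p_g, dAPL, k, Tm = coeffs
    return apl0 + c_p_g * t + dAPL * jax.nn.sigmoid(k * (t - Tm))


def apl_residual(coeffs, data):
    sim_apls, sim_temps = data
    return sim_apls - sigmoid_apl(sim_temps, coeffs)


def levenberg_marquardt(residual_fun, x0, data, maxiter=200, lam0=1e-3):
    """Minimal Levenberg-Marquardt sketch with Marquardt (diagonal) damping."""
    x0 = jnp.asarray(x0, dtype=jnp.float64)

    def cost(x):
        r = residual_fun(x, data)
        return 0.5 * jnp.dot(r, r)

    def step(carry, _):
        x, lam = carry
        r = residual_fun(x, data)
        J = jax.jacfwd(residual_fun)(x, data)           # (n_temps, n_params)
        JtJ = J.T @ J
        # Damped normal equations: (J^T J + lam * diag(J^T J)) dx = -J^T r
        dx = jnp.linalg.solve(JtJ + lam * jnp.diag(jnp.diag(JtJ)), -J.T @ r)
        x_new = x + dx
        accept = cost(x_new) < cost(x)                  # a NaN cost compares False -> rejected
        x = jnp.where(accept, x_new, x)
        lam = jnp.where(accept, lam * 0.3, lam * 10.0)  # closer to Gauss-Newton after success
        return (x, lam), None

    init = (x0, jnp.asarray(lam0, dtype=x0.dtype))
    (x, _), _ = jax.lax.scan(step, init, None, length=maxiter)
    return x


temps = jnp.linspace(290.0, 340.0, 11)
true_coeffs = jnp.array([0.40, 5e-4, 0.12, 0.5, 315.0])
apls = sigmoid_apl(temps, true_coeffs)                  # noiseless synthetic data
x_start = jnp.array([0.5, 1e-4, 0.1, 0.4, 313.0])
fitted = levenberg_marquardt(apl_residual, x_start, (apls, temps))
Tm_fit = fitted[4]
```

The accept/reject rule makes the damping adaptive. Small lam gives near-Gauss-Newton steps, and large lam gives short, gradient-like steps. That adaptivity is what makes Levenberg-Marquardt tolerant of a rough starting point.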

mythos.observables.membrane_melting_temp.compute_membrane_tm(sim_apls: jax.numpy.ndarray, sim_temps: jax.numpy.ndarray, *, implicit_diff: bool = True) → float

Compute the membrane melting temperature from APL-vs-temperature data.

Convenience wrapper around fit_apl_sigmoid() that returns only the fitted Tm (the last entry of the parameter vector).

Parameters:
  • sim_apls – Observed (or reweighted) APL values, shape (n_temps,).

  • sim_temps – Temperatures in Kelvin, shape (n_temps,).

  • implicit_diff – Whether to use implicit differentiation.

Returns:

Melting temperature in Kelvin.
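The implicit_diff option matters when Tm is differentiated with respect to the (reweighted) APLs. By the implicit function theorem, the fitted parameters \(\theta^*(y)\) satisfy \(\nabla_\theta L(\theta^*, y) = 0\) with \(L = \tfrac{1}{2}\lVert r \rVert^2\). Hence \(d\theta^*/dy = -[\nabla^2_{\theta\theta} L]^{-1}\,\nabla^2_{\theta y} L\), and dTm/dy is its last row. The sketch below evaluates that formula directly with JAX at a known optimum. It uses noiseless synthetic data, where the generating parameters fit exactly, so no solver is needed. It illustrates the idea only and makes no claim about how jaxopt implements it. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the 5x5 Hessian is badly conditioned in float32


def residual(coeffs, apls, temps):
    apl0, c_p_g, dAPL, k, Tm = coeffs
    return apls - (apl0 + c_p_g * temps + dAPL * jax.nn.sigmoid(k * (temps - Tm)))


def loss(coeffs, apls, temps):
    r = residual(coeffs, apls, temps)
    return 0.5 * jnp.dot(r, r)


def dTm_dapls(coeffs_opt, apls, temps):
    """Implicit-function-theorem gradient of the fitted Tm with respect to the APL data."""
    grad_theta = jax.grad(loss, argnums=0)                           # optimality condition g = 0
    H = jax.jacfwd(grad_theta, argnums=0)(coeffs_opt, apls, temps)   # dg/dtheta, shape (5, 5)
    B = jax.jacfwd(grad_theta, argnums=1)(coeffs_opt, apls, temps)   # dg/dy, shape (5, n_temps)
    dtheta_dy = -jnp.linalg.solve(H, B)                              # shape (5, n_temps)
    return dtheta_dy[4]                                              # Tm is the last parameter


temps = jnp.linspace(290.0, 340.0, 11)
true_coeffs = jnp.array([0.40, 5e-4, 0.12, 0.5, 315.0])
apls = true_coeffs[0] + true_coeffs[1] * temps + true_coeffs[2] * jax.nn.sigmoid(
    true_coeffs[3] * (temps - true_coeffs[4]))
g = dTm_dapls(true_coeffs, apls, temps)
```

apl0 absorbs a uniform shift of all APLs, and c_p_g absorbs a shift proportional to temperature, so neither moves Tm. As a result, g sums to (numerically) zero and is orthogonal to temps, which makes a convenient sanity check.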

class mythos.observables.membrane_melting_temp.MembraneMeltingTemp

Observable that computes lipid membrane melting temperature.

Given a concatenated SimulatorTrajectory containing frames from simulations at multiple temperatures (identified via per-frame metadata), this observable:

  1. Computes per-frame area-per-lipid using AreaPerLipid.

  2. Groups frames by temperature using trajectory.temperature.

  3. Computes the expected APL at each temperature, weighting frames by the optional DiffTRe importance-sampling weights.

  4. Fits a sigmoid to APL vs. temperature and returns the melting temperature \(T_m\).
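Steps 2 and 3 can be sketched as a masked, weighted average. The helper name, the input layout, and the membership rule are assumptions for illustration. The rule used here is jnp.isclose with rtol = temp_rtol and atol = 0, i.e. |T_frame - T_j| <= temp_rtol * |T_j|. The class's actual grouping code may differ.

```python
import jax.numpy as jnp


def expected_apl_per_temperature(frame_apls, frame_temps, temperatures, weights=None, temp_rtol=1e-3):
    """Hypothetical helper: weighted mean APL of the frames in each temperature group."""
    if weights is None:
        weights = jnp.ones_like(frame_apls)  # uniform weights -> unweighted group means
    # mask[i, j] is True when frame i lies within temp_rtol (relative) of temperatures[j]
    mask = jnp.isclose(frame_temps[:, None], temperatures[None, :], rtol=temp_rtol, atol=0.0)
    w = weights[:, None] * mask              # (n_frames, n_temps); zero outside each group
    # Self-normalized weighted mean per column; an empty group yields 0/0 = NaN
    return jnp.sum(w * frame_apls[:, None], axis=0) / jnp.sum(w, axis=0)


frame_apls = jnp.array([0.60, 0.62, 0.70, 0.74])
frame_temps = jnp.array([300.0, 300.1, 330.0, 329.9])
temperatures = jnp.array([300.0, 330.0])

uniform = expected_apl_per_temperature(frame_apls, frame_temps, temperatures)
weighted = expected_apl_per_temperature(
    frame_apls, frame_temps, temperatures, weights=jnp.array([3.0, 1.0, 1.0, 1.0]))
```

Normalizing within each group means the weights only need to be correct up to a per-group constant factor.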

topology

MDAnalysis Universe describing the system topology.

lipid_sel

MDAnalysis selection string for the lipid atoms used in the area-per-lipid calculation (e.g. "name GL1 GL2", the glycerol beads of Martini lipids).

temperatures

Array of simulation temperatures in Kelvin to fit over.

implicit_diff

Whether to use implicit differentiation through the least-squares solver.

temp_rtol

Relative tolerance for grouping frames by temperature. Frames with temperature within this relative tolerance are considered to belong to the same temperature group. Default is 1e-3 (0.1%).

topology: MDAnalysis.Universe
lipid_sel: str
temperatures: jax.numpy.ndarray
implicit_diff: bool = True
temp_rtol: float = 0.001

__call__(trajectory: mythos.simulators.io.SimulatorTrajectory, weights: jax.numpy.ndarray | None = None) → float

Compute the membrane melting temperature.

Parameters:
  • trajectory – Concatenated trajectory with per-frame temperature metadata (trajectory.temperature).

  • weights – Optional per-frame importance-sampling weights, shape (n_frames,). When None, uniform weights are used (equivalent to an unweighted mean per temperature).

Returns:

Melting temperature in Kelvin.