Optimization
mythos provides a structured optimization framework for fitting
coarse-grained force field parameters to experimental observables. This page
covers the optimization lifecycle, the available optimizer and objective
classes, and how to wire observables and loss functions into objectives.
For the catalog of available observables and the ObservableLossFn wrapper,
see Observables.
Optimization Lifecycle
The optimization framework is built around four abstractions:
Simulator: Runs a simulation and exposes one or more Observables. See Simulators for available backends.
Observable: Something produced by a Simulator, such as a trajectory, scalar, vector, or tensor. Observables are named and matched to objectives by convention.
Objective: Takes one or more Observables and computes the gradients of a loss function with respect to the parameters being optimized. See Objectives below.
Optimizer: Coordinates running Simulators to produce the Observables needed by Objectives, aggregates gradients, and applies parameter updates.
A single optimization step proceeds as:
The Optimizer runs the required Simulators (possibly in parallel).
Each Simulator produces its Observables.
The Optimizer passes Observables to the relevant Objectives.
Each Objective computes gradients with respect to the optimizable parameters.
The Optimizer aggregates gradients and applies an update rule (e.g., Adam via optax).
Optimizers
SimpleOptimizer
SimpleOptimizer pairs a single Simulator with a single Objective
and a gradient transformation (e.g., optax.adam). It is the right choice
when fitting parameters against one simulation and one loss function.
```python
import optax
from mythos.optimization.optimization import SimpleOptimizer

# my_objective, my_simulator, and params are assumed to be defined elsewhere
optimizer = SimpleOptimizer(
    objective=my_objective,
    simulator=my_simulator,
    optimizer=optax.adam(learning_rate=1e-4),
)
optimizer.run(params, n_steps=100)
```
See mythos.optimization.optimization for the full API.
Examples: simple optimizations notebook
RayOptimizer
RayOptimizer runs multiple Simulators and Objectives in parallel
using Ray. This is useful when:
You have multiple simulations at different conditions (e.g., multiple temperatures).
You have multiple objectives (e.g., fit structural properties and thermodynamic properties simultaneously).
You want to increase sample size by running multiple trajectories in parallel and combining in the objective.
The RayOptimizer requires a user-supplied aggregate_grad_fn to combine
gradients from multiple objectives into a single update:
```python
import jax.tree_util as tree_util
import optax
from mythos.optimization.optimization import RayOptimizer


def mean_grads(grads_list):
    # Element-wise mean of the gradient pytrees returned by each objective
    return tree_util.tree_map(
        lambda *gs: sum(gs) / len(gs), *grads_list
    )


# obj_1, obj_2, sim_1, sim_2, and params are assumed to be defined elsewhere
optimizer = RayOptimizer(
    objectives=[obj_1, obj_2],
    simulators=[sim_1, sim_2],
    optimizer=optax.adam(learning_rate=1e-4),
    aggregate_grad_fn=mean_grads,
)
optimizer.run(params, n_steps=100)
```
See mythos.optimization.optimization for the full API.
For detailed guidance on Ray session setup, resource hints, memory management, and gradient aggregation patterns, see Ray Optimizer.
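As one alternative to a plain mean, an aggregate_grad_fn can weight objectives differently, for example to emphasize a thermodynamic target over a structural one. The sketch below assumes that grads_list arrives in the same order as the objectives list passed to RayOptimizer; that ordering is an assumption to verify against the API docs, and make_weighted_grads is a hypothetical helper.

```python
import jax.numpy as jnp
import jax.tree_util as tree_util


def make_weighted_grads(objective_weights):
    """Hypothetical factory for an aggregate_grad_fn computing a weighted sum.

    Assumes grads_list is ordered like the `objectives` list.
    """
    def weighted_grads(grads_list):
        if len(grads_list) != len(objective_weights):
            raise ValueError("expected one weight per objective")
        # Weighted sum, leaf by leaf, across the gradient pytrees
        return tree_util.tree_map(
            lambda *gs: sum(w * g for w, g in zip(objective_weights, gs)),
            *grads_list,
        )
    return weighted_grads


agg = make_weighted_grads([0.25, 0.75])
combined = agg([{"eps": jnp.array(1.0)}, {"eps": jnp.array(3.0)}])
```

Here combined["eps"] is 0.25 * 1.0 + 0.75 * 3.0 = 2.5. Such a function would be passed as aggregate_grad_fn=make_weighted_grads([...]).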
Examples: advanced optimizations notebook
Objectives
An Objective takes observables and computes gradients. The base
Objective class accepts a grad_or_loss_fn callable that receives the
requested observables and returns (grads, aux).
DiffTRe (Differentiable Trajectory Reweighting)
For simulators that are not differentiable (oxDNA, GROMACS, LAMMPS),
mythos uses the DiffTRe algorithm to estimate gradients
via Boltzmann reweighting of reference trajectories.
The DiffTReObjective extends Objective with:
An energy_fn used to recompute energies under new parameters.
Boltzmann weight computation: \(w_i \propto \exp(-\beta (E_{\text{new},i} - E_{\text{ref},i}))\), where \(\beta = 1/k_B T\).
Effective sample size tracking (n_eff) to detect weight collapse.
Automatic reference parameter updates when n_eff drops too low.
See mythos.optimization.objective for the full API.
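The weight and n_eff bookkeeping can be sketched in a few lines of JAX. The n_eff formula below is the entropy-based estimate \(n_{\text{eff}} = \exp(-\sum_i w_i \ln w_i)\) used in the DiffTRe literature; mythos may compute it differently. The needs_new_reference check, which compares n_eff to a fraction of the frame count, is a hypothetical reading of min_n_eff_factor, not the library's code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy


def boltzmann_weights(e_new, e_ref, beta):
    # w_i ∝ exp(-beta * (E_new,i - E_ref,i)), normalized to sum to 1.
    # softmax of the log-weights avoids overflow for large energy gaps.
    return jax.nn.softmax(-beta * (e_new - e_ref))


def effective_sample_size(weights):
    # n_eff = exp(-sum_i w_i ln w_i); xlogy treats 0 * log(0) as 0.
    # Equals N for uniform weights and approaches 1 as weights collapse.
    return jnp.exp(-jnp.sum(xlogy(weights, weights)))


def needs_new_reference(weights, min_n_eff_factor):
    # Hypothetical trigger: rerun the reference simulation once n_eff
    # falls below a fixed fraction of the number of frames.
    return effective_sample_size(weights) < min_n_eff_factor * weights.shape[0]


e_ref = jnp.zeros(100)
w_uniform = boltzmann_weights(e_ref, e_ref, beta=1.0)
w_collapsed = boltzmann_weights(jnp.array([0.0] + [50.0] * 99), e_ref, beta=1.0)
```

Identical energies give n_eff = 100 (no collapse), while a single low-energy frame dominating the rest gives n_eff close to 1, which would trigger a new reference simulation.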
Defining a Loss Function for DiffTRe
Note
There is an important terminology distinction here. In the optimization framework, “observable” has two meanings:
1. Simulator observables: the raw outputs exposed by a Simulator (typically a SimulatorTrajectory, i.e. a sequence of rigid-body states). These are what the DiffTReObjective receives and passes to your loss function as ref_states and observables.
2. Observable API objects: classes like PropellerTwist, BondDistances, etc. (see Observables) that compute physical quantities from a trajectory.
The DiffTRe loss function operates on simulator trajectories (sense 1). Inside it, you can use Observable API objects (sense 2) to compute the physical quantities you want to optimize against.
The DiffTReObjective accepts a grad_or_loss_fn that is called with the
simulator’s trajectory observables. Internally, DiffTReObjective differentiates
this function with jax.value_and_grad and calls it with the following
positional arguments:
```python
from typing import Any


def loss_fn(
    ref_states,   # SimulatorTrajectory: the reference trajectory
    weights,      # Boltzmann reweighting weights (per frame)
    energy_fn,    # energy function with current parameters
    opt_params,   # current optimization parameters
    observables,  # list of simulator observables passed through
) -> tuple[float, tuple[Any, Any]]:
    """Return (loss, (measured_value, extra))."""
```
Here ref_states is a SimulatorTrajectory (rigid-body states from the
simulator), not an Observable API object. The observables list likewise
contains the raw simulator outputs matched by required_observables.
The weights array contains the Boltzmann reweighting factors — these
replace the manual equilibration masking used in the simple optimization case.
Your loss function should use these weights when aggregating values across
frames.
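To see why gradients reach the parameters through the weights, here is a toy, self-contained DiffTRe-style reweighted loss. The harmonic toy_energy_fn, the grid of ref_states, and beta = 1.0 are hypothetical, and this is not mythos's internal code; only the pattern (reweight frames from a fixed reference, aggregate with the weights, differentiate with jax.value_and_grad using has_aux=True to match the (loss, aux) return shape) follows the description above. Reference energies come from fixed ref_params, so the gradient with respect to opt_params flows only through the weights.

```python
import jax
import jax.numpy as jnp

beta = 1.0  # hypothetical inverse temperature 1 / (k_B T)


def toy_energy_fn(params, states):
    # Hypothetical per-frame harmonic energy: 0.5 * k * x^2
    return 0.5 * params["k"] * states ** 2


def reweighted_loss(opt_params, ref_params, ref_states, target):
    e_ref = toy_energy_fn(ref_params, ref_states)  # fixed reference energies
    e_new = toy_energy_fn(opt_params, ref_states)  # energies under new params
    weights = jax.nn.softmax(-beta * (e_new - e_ref))  # normalized Boltzmann weights
    measured = jnp.sum(weights * ref_states ** 2)  # reweighted <x^2>
    loss = (measured - target) ** 2
    return loss, (measured, weights)


ref_params = {"k": jnp.array(1.0)}
ref_states = jnp.linspace(-2.0, 2.0, 101)  # stand-in for a reference trajectory
target = jnp.array(0.5)

# Differentiate w.r.t. the first argument (opt_params) only.
(loss, (measured, weights)), grads = jax.value_and_grad(
    reweighted_loss, has_aux=True
)(ref_params, ref_params, ref_states, target)
```

At opt_params equal to ref_params the weights are uniform, and the negative gradient with respect to k says that stiffening the spring pulls the reweighted <x^2> down toward the target.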
Inside this loss function, you wire up Observable API objects to compute physical quantities from the trajectory, then compare to experimental targets:
```python
import jax.numpy as jnp
from mythos.observables.propeller import PropellerTwist

target_propeller_twist = jnp.array(0.6109)  # experimental target

# transform_fn and h_bonded_pairs are assumed to be defined earlier
prop_twist_obs = PropellerTwist(
    rigid_body_transform_fn=transform_fn,
    h_bonded_base_pairs=h_bonded_pairs,
)


def my_loss_fn(ref_states, weights, energy_fn, opt_params, observables):
    # Boltzmann-weighted average of the per-frame propeller twist
    measured = jnp.sum(prop_twist_obs(ref_states) * weights)
    # RMSE between the reweighted value and the experimental target
    loss = jnp.sqrt(jnp.mean((measured - target_propeller_twist) ** 2))
    return loss, (measured, None)
```
This loss function is then passed to the DiffTReObjective:
```python
from mythos.optimization.objective import DiffTReObjective

# my_energy_fn is assumed to be defined elsewhere
objective = DiffTReObjective(
    name="propeller_twist",
    required_observables=("trajectory.oxDNASimulator.sim1",),
    energy_fn=my_energy_fn,
    grad_or_loss_fn=my_loss_fn,
    n_equilibration_steps=1000,
    min_n_eff_factor=0.95,
)
```
For more complex optimizations — such as fitting to full distributions using
WassersteinDistance, or combining multiple observables in a single
objective — see the
advanced optimization examples.
For the full catalog of available observables and the ObservableLossFn
convenience wrapper, see Observables.