Optimization
============

``mythos`` provides a structured optimization framework for fitting coarse-grained force field parameters to experimental observables. This page covers the optimization lifecycle, the available optimizer and objective classes, and how to wire observables and loss functions into objectives. For the catalog of available observables and the ``ObservableLossFn`` wrapper, see :doc:`observables`.

.. contents:: On this page
   :local:
   :depth: 2

.. _optimization-lifecycle:

Optimization Lifecycle
----------------------

The optimization framework is built around four abstractions:

.. image:: ../_static/mythos_opt_diagram.svg
   :align: center
   :width: 80%

- **Simulator**: Runs a simulation and exposes one or more **Observables**. See :doc:`simulators` for available backends.
- **Observable**: Something produced by a ``Simulator``, such as a trajectory, scalar, vector, or tensor. Observables are named and matched to objectives by convention.
- **Objective**: Takes one or more ``Observables`` and computes the **gradients** of a loss function with respect to the parameters being optimized. See :ref:`objectives` below.
- **Optimizer**: Coordinates running ``Simulators`` to produce the ``Observables`` needed by ``Objectives``, aggregates gradients, and applies parameter updates.

A single optimization step proceeds as follows:

1. The **Optimizer** runs the required **Simulators** (possibly in parallel).
2. Each **Simulator** produces its **Observables**.
3. The **Optimizer** passes the **Observables** to the relevant **Objectives**.
4. Each **Objective** computes gradients with respect to the optimizable parameters.
5. The **Optimizer** aggregates the gradients and applies an update rule (e.g., Adam via `optax `_).

.. _optimizers:

Optimizers
----------

SimpleOptimizer
^^^^^^^^^^^^^^^

``SimpleOptimizer`` pairs a single ``Simulator`` with a single ``Objective`` and a gradient transformation (e.g., ``optax.adam``).
It is the right choice when fitting parameters against one simulation and one loss function.

.. code-block:: python

    import optax

    from mythos.optimization.optimization import SimpleOptimizer

    # my_objective, my_simulator, and params are defined elsewhere for your system.
    optimizer = SimpleOptimizer(
        objective=my_objective,
        simulator=my_simulator,
        optimizer=optax.adam(learning_rate=1e-4),
    )
    optimizer.run(params, n_steps=100)

See :doc:`autoapi/mythos/optimization/optimization/index` for the full API.

**Examples:** `simple optimizations notebook `_

.. _ray-optimizer:

RayOptimizer
^^^^^^^^^^^^

``RayOptimizer`` runs multiple ``Simulators`` and ``Objectives`` in parallel using `Ray `_. This is useful when:

- You have **multiple simulations** at different conditions (e.g., multiple temperatures).
- You have **multiple objectives** (e.g., fit structural properties *and* thermodynamic properties simultaneously).
- You want to increase the sample size by running multiple trajectories in parallel and combining them in the objective.

``RayOptimizer`` requires a user-supplied ``aggregate_grad_fn`` that combines the gradients from multiple objectives into a single update:

.. code-block:: python

    import jax.tree_util as tree_util
    import optax

    from mythos.optimization.optimization import RayOptimizer

    def mean_grads(grads_list):
        # Average each leaf across the per-objective gradient pytrees.
        return tree_util.tree_map(
            lambda *gs: sum(gs) / len(gs), *grads_list
        )

    # obj_1, obj_2, sim_1, sim_2, and params are defined elsewhere.
    optimizer = RayOptimizer(
        objectives=[obj_1, obj_2],
        simulators=[sim_1, sim_2],
        optimizer=optax.adam(learning_rate=1e-4),
        aggregate_grad_fn=mean_grads,
    )
    optimizer.run(params, n_steps=100)

See :doc:`autoapi/mythos/optimization/optimization/index` for the full API. For detailed guidance on Ray session setup, resource hints, memory management, and gradient aggregation patterns, see :doc:`ray_optimizer`.

**Examples:** `advanced optimizations notebook `_

.. _objectives:

Objectives
----------

An ``Objective`` takes observables and computes gradients. The base ``Objective`` class accepts a ``grad_or_loss_fn`` callable that receives the requested observables and returns ``(grads, aux)``.

.. _difftre:

DiffTRe (Differentiable Trajectory Reweighting)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For simulators that are **not differentiable** (oxDNA, GROMACS, LAMMPS), ``mythos`` uses the `DiffTRe algorithm `_ to estimate gradients via Boltzmann reweighting of reference trajectories. ``DiffTReObjective`` extends ``Objective`` with:

- An ``energy_fn`` used to recompute energies under new parameters
- Boltzmann weight computation: :math:`w_i \propto \exp(-\beta (E_{\text{new},i} - E_{\text{ref},i}))`
- Effective sample size tracking (``n_eff``) to detect weight collapse
- Automatic reference parameter updates when ``n_eff`` drops too low

See :doc:`autoapi/mythos/optimization/objective/index` for the full API.

Defining a Loss Function for DiffTRe
""""""""""""""""""""""""""""""""""""

.. note::

   There is an important terminology distinction here. In the optimization framework, **"observable"** has two meanings:

   1. **Simulator observables**: the raw outputs exposed by a ``Simulator`` (typically a ``SimulatorTrajectory``, i.e., a sequence of rigid-body states). These are what ``DiffTReObjective`` receives and passes to your loss function as ``ref_states`` and ``observables``.
   2. **Observable API objects**: classes like ``PropellerTwist``, ``BondDistances``, etc. (see :doc:`observables`) that compute physical quantities *from* a trajectory.

   The DiffTRe loss function operates on simulator trajectories (sense 1). Inside it, you can use Observable API objects (sense 2) to compute the physical quantities you want to optimize against.

``DiffTReObjective`` accepts a ``grad_or_loss_fn`` that is called with the simulator's trajectory observables. Internally, DiffTRe differentiates this function with ``jax.value_and_grad`` and calls it with positional arguments in the following order:

.. code-block:: python

    from typing import Any

    def loss_fn(
        ref_states,   # SimulatorTrajectory: the reference trajectory
        weights,      # Boltzmann reweighting weights (one per frame)
        energy_fn,    # energy function with the current parameters
        opt_params,   # current optimization parameters
        observables,  # list of simulator observables passed through
    ) -> tuple[float, tuple[Any, Any]]:
        """Return (loss, (measured_value, extra))."""

Here ``ref_states`` is a ``SimulatorTrajectory`` (rigid-body states from the simulator), **not** an Observable API object. The ``observables`` list likewise contains the raw simulator outputs matched by ``required_observables``. The ``weights`` array contains the Boltzmann reweighting factors; these replace the manual equilibration masking used in the simple optimization case. Your loss function should use these weights when aggregating per-frame values across the trajectory.

Inside this loss function, you wire up Observable API objects to compute physical quantities from the trajectory and then compare them to experimental targets:

.. code-block:: python

    import jax.numpy as jnp

    from mythos.observables.propeller import PropellerTwist

    target_propeller_twist = jnp.array(0.6109)  # experimental target

    # transform_fn and h_bonded_pairs are defined elsewhere for your system.
    prop_twist_obs = PropellerTwist(
        rigid_body_transform_fn=transform_fn,
        h_bonded_base_pairs=h_bonded_pairs,
    )

    def my_loss_fn(ref_states, weights, energy_fn, opt_params, observables):
        # Boltzmann-weighted combination of the per-frame propeller twist.
        measured = jnp.sum(prop_twist_obs(ref_states) * weights)
        loss = jnp.sqrt(jnp.mean((measured - target_propeller_twist) ** 2))
        return loss, (measured, None)

This loss function is then passed to ``DiffTReObjective``:

.. code-block:: python

    from mythos.optimization.objective import DiffTReObjective

    # my_energy_fn is the energy function used to recompute frame energies.
    objective = DiffTReObjective(
        name="propeller_twist",
        required_observables=("trajectory.oxDNASimulator.sim1",),
        energy_fn=my_energy_fn,
        grad_or_loss_fn=my_loss_fn,
        n_equilibration_steps=1000,
        min_n_eff_factor=0.95,
    )

For more complex optimizations, such as fitting to full distributions using ``WassersteinDistance`` or combining multiple observables in a single objective, see the `advanced optimization examples `_. For the full catalog of available observables and the ``ObservableLossFn`` convenience wrapper, see :doc:`observables`.
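
The Boltzmann weights and effective sample size listed in the DiffTRe section above can be illustrated with a few lines of JAX. This is a minimal sketch under stated assumptions, not the ``DiffTReObjective`` implementation: the helpers ``boltzmann_weights``, ``effective_sample_size``, and ``needs_new_reference`` are hypothetical, ``n_eff`` is assumed to be the entropy-based estimate :math:`n_{\text{eff}} = \exp\left(-\sum_i w_i \ln w_i\right)`, and the reference-update threshold is assumed to be ``min_n_eff_factor`` times the number of reference frames. Consult the API reference for how ``mythos`` actually defines both.

.. code-block:: python

    import jax.numpy as jnp
    from jax.nn import softmax
    from jax.scipy.special import xlogy


    def boltzmann_weights(energies_new, energies_ref, beta):
        """Normalized reweighting factors, one per reference frame (hypothetical helper).

        w_i is proportional to exp(-beta * (E_new_i - E_ref_i)); softmax
        normalizes the weights to sum to one without overflowing exp().
        """
        return softmax(-beta * (energies_new - energies_ref))


    def effective_sample_size(weights):
        """Entropy-based estimate n_eff = exp(-sum_i w_i * ln(w_i)) (assumed definition)."""
        # xlogy(w, w) evaluates w * ln(w) and returns 0 where w == 0.
        return jnp.exp(-jnp.sum(xlogy(weights, weights)))


    def needs_new_reference(weights, min_n_eff_factor):
        """True when n_eff drops below min_n_eff_factor * n_frames (assumed rule)."""
        return effective_sample_size(weights) < min_n_eff_factor * weights.shape[0]


    # Usage: identical energies give uniform weights, so n_eff equals the frame count.
    weights = boltzmann_weights(jnp.zeros(4), jnp.zeros(4), beta=1.0)
    print(effective_sample_size(weights))  # approximately 4.0

As one frame's relative energy :math:`E_{\text{new},i} - E_{\text{ref},i}` falls below the others, its weight dominates and ``n_eff`` approaches 1; this is the weight collapse that ``n_eff`` tracking is meant to catch. Because the weights sum to one, a weighted sum such as ``jnp.sum(prop_twist_obs(ref_states) * weights)`` in ``my_loss_fn`` acts as a weighted average over frames.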