Ray Optimizer
The RayOptimizer runs multiple simulators and objectives in parallel (on a
local machine or a cluster of remote machines) using Ray.
This page covers practical recommendations for getting the most out of it.
For the basic API and gradient aggregation example, see RayOptimizer on the Optimization page.
Initializing a Ray Session
Ray will automatically connect to an existing cluster or start a local session
when it encounters the first remote call. To pass options, your script should
call ray.init() as appropriate prior to running the optimizer (or any remote
calls).
JAX environment variables
Ray workers are separate processes. Any jax.config calls you make in the
driver process (e.g., jax.config.update("jax_enable_x64", True)) have no
effect inside workers. You must either start your workers with the desired
environment variables set, or pass them through runtime_env environment
variables:
import ray

ray.init(
    runtime_env={
        "env_vars": {
            "JAX_ENABLE_X64": "True",     # 64-bit precision
            "JAX_PLATFORM_NAME": "cpu",   # force CPU (optional)
        },
    },
)
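Because every worker needs the same settings, it can help to assemble the env_vars mapping in one place. The following is a minimal sketch: jax_runtime_env is a hypothetical helper (not part of mythos or Ray) that only builds the dictionary passed as runtime_env. Values are kept as strings because Ray expects environment variables as string-to-string mappings.

```python
def jax_runtime_env(enable_x64=True, platform=None):
    """Build a runtime_env dict carrying JAX settings to Ray workers.

    Hypothetical helper for illustration; not part of mythos or Ray.
    """
    # Environment variable values must be strings, not booleans.
    env_vars = {"JAX_ENABLE_X64": "True" if enable_x64 else "False"}
    if platform is not None:
        env_vars["JAX_PLATFORM_NAME"] = platform
    return {"env_vars": env_vars}


runtime_env = jax_runtime_env(enable_x64=True, platform="cpu")
# Would then be used as: ray.init(runtime_env=runtime_env)
```

Keeping this in a single function makes it harder for the driver and the workers to drift apart in precision settings.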
Other useful ray.init options
ray.init(
    num_cpus=32,               # limit visible CPUs
    ignore_reinit_error=True,  # safe to call ray.init() multiple times
    log_to_driver=True,        # forward worker logs to the driver
)
Resource Hints
Both Simulator and Objective inherit from SchedulerUnit, which
accepts an optional generic scheduler_hints parameter. The RayOptimizer
translates these hints into Ray task options when dispatching remote calls.
from mythos.utils.scheduler import SchedulerHints
simulator = oxdna.oxDNASimulator(
    ...,
    scheduler_hints=SchedulerHints(
        num_cpus=4,     # Sets OMP_NUM_THREADS=4, a commonly used environment variable for controlling CPU parallelism
        num_gpus=0,     # For NVIDIA devices, sets CUDA_VISIBLE_DEVICES to partition GPUs for fulfilling requests
        mem_mb=8192,    # 8 GB, converted to bytes for Ray
        max_retries=2,
    ),
)
Hints are generic, and the RayOptimizer translates several common fields
into Ray options. Ray-specific options that are not listed below can be passed
via the custom field, specifying “ray” as the engine:
| Field | Type | Description |
|---|---|---|
| num_cpus | | Number of CPUs to reserve for this task. |
| num_gpus | | Number of GPUs. Fractional values (e.g. |
| mem_mb | | Memory reservation in megabytes. Converted to bytes for Ray. |
| max_retries | | Number of retry attempts on task failure. |
| custom | | Engine-specific options. For Ray, use |
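The translation described above can be pictured with a small sketch. This is not the mythos implementation: ExampleHints is a stand-in for SchedulerHints with only the common fields, to_ray_options is a hypothetical function, and the megabyte conversion factor (1024 * 1024) is an assumption. The target keys (num_cpus, num_gpus, memory, max_retries) are standard Ray task options, with memory expressed in bytes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleHints:
    """Stand-in for SchedulerHints with the common fields (illustrative only)."""
    num_cpus: Optional[int] = None
    num_gpus: Optional[float] = None
    mem_mb: Optional[int] = None
    max_retries: Optional[int] = None


def to_ray_options(hints):
    """Map generic hints onto Ray task option names, skipping unset fields."""
    options = {}
    if hints.num_cpus is not None:
        options["num_cpus"] = hints.num_cpus
    if hints.num_gpus is not None:
        options["num_gpus"] = hints.num_gpus
    if hints.mem_mb is not None:
        # Ray's `memory` option is expressed in bytes.
        # Assumption: 1 MB is treated as 1024 * 1024 bytes.
        options["memory"] = hints.mem_mb * 1024 * 1024
    if hints.max_retries is not None:
        options["max_retries"] = hints.max_retries
    return options


ray_options = to_ray_options(
    ExampleHints(num_cpus=4, num_gpus=0, mem_mb=8192, max_retries=2)
)
```

Unset fields are left out entirely so that they do not mask any baseline options configured elsewhere.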
remote_options_default on the RayOptimizer itself sets baseline Ray
options that apply to all tasks. Per-unit scheduler_hints override
these defaults:
optimizer = RayOptimizer(
...,
remote_options_default={"num_cpus": 2}, # default for all tasks
)
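The override rule amounts to a dictionary merge in which per-unit values win. The sketch below illustrates that precedence only; resolve_remote_options is a hypothetical name, and the real optimizer may resolve options differently in detail.

```python
def resolve_remote_options(defaults, unit_options):
    """Merge baseline options with per-unit options; per-unit values win."""
    merged = dict(defaults)
    # Ignore unset (None) per-unit values so they do not erase a default.
    merged.update({k: v for k, v in unit_options.items() if v is not None})
    return merged


defaults = {"num_cpus": 2, "max_retries": 1}

# A unit without hints inherits the defaults unchanged.
light_task = resolve_remote_options(defaults, {})

# A unit with hints overrides num_cpus and adds a memory reservation.
heavy_task = resolve_remote_options(
    defaults, {"num_cpus": 4, "memory": 4096 * 1024 * 1024}
)
```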
Creating Multiple Simulator Instances
Use Simulator.create_n() to create multiple instances of the same
simulator with unique names. This is the typical pattern for running
multiple independent trajectories:
simulators = oxdna.oxDNASimulator.create_n(
    n=8,
    name="oxdna-sim",
    input_dir="data/templates/simple-helix",
    sim_type=jdna_types.oxDNASimulatorType.DNA1,
    energy_configs=energy_fn_configs,
    scheduler_hints=SchedulerHints(num_cpus=4, mem_mb=4096),
)
Each simulator gets a name like oxdna-sim.0, oxdna-sim.1, etc., and
each exposes its own uniquely qualified observables
(e.g., trajectory.oxDNASimulator.oxdna-sim.0). Wire these into your
objectives via required_observables.
Note
Names across simulators and objectives must be unique. These names are used for routing observables and for tagging tasks in Ray for observability. The optimizer will raise an error if it detects any naming conflicts.
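To make the naming scheme concrete, here is a small sketch that reproduces the pattern shown above and checks for conflicts before anything is submitted. All three helpers are hypothetical, and the qualified key format is inferred from the single example in this section rather than taken from the library.

```python
from collections import Counter


def instance_names(base, n):
    """Names following the pattern described above: base.0, base.1, ..."""
    return [f"{base}.{i}" for i in range(n)]


def trajectory_keys(names, simulator_class="oxDNASimulator"):
    """Qualified trajectory observable names, one per simulator instance."""
    return [f"trajectory.{simulator_class}.{name}" for name in names]


def find_duplicates(names):
    """Return names that occur more than once, in sorted order."""
    return sorted(name for name, count in Counter(names).items() if count > 1)


names = instance_names("oxdna-sim", 3)
required_observables = trajectory_keys(names)

# A second component accidentally reusing "oxdna-sim.1" is reported.
conflicts = find_duplicates(names + ["loss", "oxdna-sim.1"])
```

A list built this way could be handed to an objective's required_observables.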
Memory Considerations
When combining trajectories from many parallel simulators into a single objective, memory can become a bottleneck. Here are some strategies:
Trajectory concatenation
DiffTReObjective concatenates all required trajectories into a single
array using SimulatorTrajectory.concat() before computing reference
energies. If you have N simulators each producing T frames of M
particles, the concatenated trajectory has shape (N × T, M, ...). This
can easily exceed available memory for large systems.
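A back-of-the-envelope estimate helps when choosing mem_mb. The sketch below assumes three 64-bit values per particle per frame (positions only, with JAX_ENABLE_X64 on); real trajectories typically carry more per-particle data, so treat the result as a lower bound. The function name and the example sizes are illustrative.

```python
def concat_trajectory_bytes(n_simulators, n_frames, n_particles,
                            values_per_particle=3, bytes_per_value=8):
    """Rough size in bytes of one (N * T, M, values_per_particle) float array."""
    return (n_simulators * n_frames * n_particles
            * values_per_particle * bytes_per_value)


full = concat_trajectory_bytes(n_simulators=8, n_frames=10_000, n_particles=5_000)
halved = concat_trajectory_bytes(n_simulators=8, n_frames=5_000, n_particles=5_000)

print(f"full: {full / 1024**3:.2f} GiB, halved: {halved / 1024**3:.2f} GiB")
```

The size scales linearly in each factor, so halving the frame count per simulator halves the concatenated array.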
Recommendations:
Reduce frame count per simulator rather than reducing the number of simulators. Shorter trajectories are cheaper to concatenate and reweight, while still benefiting from parallel sampling.
Set mem_mb on both simulators and objectives via scheduler_hints so Ray can schedule tasks onto nodes with sufficient memory.
Gradient computation
DiffTRe computes gradients by evaluating the energy function on every frame of the concatenated trajectory. For large systems, this is the most memory-intensive step because JAX must hold the full computation graph for autodiff.
If you encounter out-of-memory errors during gradient computation:
Use fewer frames (shorter simulations or sub-sampling).
Set JAX_PLATFORM_NAME=cpu to avoid GPU memory limits for the reweighting step, while still using GPU for the simulation binary.
Consider splitting objectives across more simulators with fewer frames each, and aggregating gradients via aggregate_grad_fn.
Set map_checkpoint to True and tune map_batch_size on the energy function to enable gradient checkpointing during the traced loss computation.
Gradient Aggregation
aggregate_grad_fn receives a list of gradient pytrees — one per objective,
in the same order as the objectives list — and must return a single pytree
of the same structure.
Common patterns:
Mean (equal weighting):
import operator
import jax
def tree_mean(grads):
    # Sum matching leaves across all gradient pytrees, then divide by the count.
    summed = jax.tree.map(operator.add, *grads)
    return jax.tree.map(lambda x: x / len(grads), summed)
Weighted sum:
def weighted_grads(grads, weights=(0.7, 0.3)):
    # Scale every leaf of each objective's gradient by that objective's weight.
    scaled = [jax.tree.map(lambda leaf: leaf * w, g) for g, w in zip(grads, weights)]
    return jax.tree.map(operator.add, *scaled)
Single objective (passthrough):
aggregate_grad_fn = lambda grads: grads[0]
Execution Model
Understanding how RayOptimizer dispatches work can help with debugging
and performance tuning.
Simulators and objectives are dispatched as Ray tasks (stateless remote function calls), not Ray actors.
Each simulator task returns one ObjectRef per exposed observable, plus one for the simulator state. These refs are passed by reference to objective tasks; data is not copied through the driver.
Objectives resolve their input refs inside the worker via ray.get().
The optimizer uses a reactive scheduling loop: it checks which objectives have all their required observables available, dispatches those that are ready, and waits for any result to come back before checking again.
An objective may signal that it is not ready (e.g., it needs fresh simulation data). The optimizer will re-run the necessary simulators and retry, up to a maximum of 2 attempts per step.
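The reactive loop described above can be modelled without Ray. The following toy sketch is not the optimizer's code: run_step is a hypothetical function, simulators "finish" in a fixed order, and the not-ready retry path is left out. It only shows the readiness check (an objective is dispatched once its required observables are a subset of what is available).

```python
def run_step(simulators, objectives):
    """Toy model of a reactive scheduling loop (no Ray involved).

    simulators: dict mapping simulator name -> list of observables it exposes
    objectives: dict mapping objective name -> set of required observables
    Returns the order in which objectives were dispatched.
    """
    available = set()
    pending = dict(objectives)
    in_flight = list(simulators.items())  # stands in for running Ray tasks
    dispatched = []

    while pending:
        # Dispatch every objective whose inputs are all available.
        ready = sorted(n for n, needs in pending.items() if needs <= available)
        for name in ready:
            dispatched.append(name)
            del pending[name]
        if not pending:
            break
        if not in_flight:
            raise RuntimeError(f"unsatisfiable objectives: {sorted(pending)}")
        # "Wait" for one result: here, simply take the next simulator to finish.
        _, observables = in_flight.pop(0)
        available.update(observables)

    return dispatched


order = run_step(
    simulators={"sim.0": ["traj.sim.0"], "sim.1": ["traj.sim.1"]},
    objectives={"obj_a": {"traj.sim.0"}, "obj_b": {"traj.sim.0", "traj.sim.1"}},
)
```

In this model obj_a runs as soon as sim.0 finishes, while obj_b has to wait for both simulators, which is why objectives with fewer inputs tend to start earlier.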
Callbacks
optimizer.run() accepts an optional callback function that is called
after each step:
def my_callback(optimizer_output, step):
    for component, obs in optimizer_output.observables.items():
        for name, value in obs.items():
            print(f"step {step} | {component}.{name} = {value}")
    continue_running = True
    return None, continue_running  # (modified_output | None, keep_going)
optimizer.run(params, n_steps=100, callback=my_callback)
Return (None, False) to trigger early stopping.
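As an illustration of early stopping, the sketch below builds a callback that stops once a chosen observable falls to a threshold. It assumes observables is a dict of dicts (as the loop above suggests) and that the value converts to float; the component name "my_objective" and observable name "loss" are placeholders, and SimpleNamespace stands in for the real optimizer output.

```python
from types import SimpleNamespace


def make_early_stopping_callback(component, name, threshold):
    """Return a callback that stops once observables[component][name] <= threshold.

    The component and observable names are placeholders chosen by the caller.
    """
    def callback(optimizer_output, step):
        value = optimizer_output.observables.get(component, {}).get(name)
        # Keep running if the observable is missing or still above the threshold.
        keep_going = value is None or float(value) > threshold
        return None, keep_going  # (no modified output, continue flag)

    return callback


# Stand-in for the optimizer output, used only to exercise the callback.
fake_output = SimpleNamespace(observables={"my_objective": {"loss": 0.05}})
stop_below_tenth = make_early_stopping_callback("my_objective", "loss", threshold=0.1)
result = stop_below_tenth(fake_output, step=3)
```

The returned function has the same (optimizer_output, step) signature as my_callback, so it could be passed as callback= to optimizer.run().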
NaN detection
The optimizer automatically checks for NaN/Inf values in gradients after
each step and raises RuntimeError if any are found. This typically
indicates a learning rate that is too high or a numerical precision issue
(see the JAX_ENABLE_X64 note above).
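When the error fires, it is often useful to know which parameter went bad. The sketch below is a pure-Python stand-in that walks a nested dict/list of floats and reports the paths of non-finite leaves; it is not the optimizer's own check, it does not handle JAX arrays, and the parameter names are made up for the example.

```python
import math


def find_non_finite(tree, path=""):
    """Return paths of NaN/Inf leaves in a nested dict/list of Python floats."""
    if isinstance(tree, dict):
        items = [(f"{path}/{key}", value) for key, value in tree.items()]
    elif isinstance(tree, (list, tuple)):
        items = [(f"{path}/{i}", value) for i, value in enumerate(tree)]
    else:
        # Leaf: report its path only if it is NaN or +/-Inf.
        return [] if math.isfinite(tree) else [path]
    bad = []
    for child_path, child in items:
        bad.extend(find_non_finite(child, child_path))
    return bad


# Hypothetical gradient structure with one NaN and one Inf entry.
grads = {"stacking": {"eps": 0.12, "a": float("nan")}, "fene": [1.0, float("inf")]}
bad_paths = find_non_finite(grads)
```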
Observability and Debugging
Ray’s dashboard is a powerful tool for monitoring task execution, resource
usage, and debugging. mythos declares a dependency on ray, which doesn’t
automatically install the dashboard components. Use ray[default] or
ray[dashboard] to get the full experience:
pip install "ray[default]"
When using the dashboard, it is also important to start the cluster separately:
ray start --head
If your application starts on the same machine, it will connect to the running
cluster automatically. If it runs on a different machine, specify the address
of the head node in ray.init():
ray.init(address="<head-node-ip>:6379")
For more details on using the dashboard and configuring Ray clusters, see the Ray documentation.
When running on a Slurm cluster (see also Running on Slurm HPC Systems) with ray
symmetric-run, the head node is started automatically — no need to run ray
start --head separately. All that is required is having ray[default]
installed. The dashboard will be available on port 8265 of the head node.
Since HPC login nodes typically don’t expose compute-node ports directly, use SSH port forwarding to access the dashboard from your local machine:
# From your local machine, forward through the login node to the head node:
ssh -L 8265:<head-node>:8265 <user>@<login-node>
Replace <head-node> with the hostname of the first node in your Slurm
allocation (the one printed as “IP Head” in the sbatch script). Once
connected, open http://localhost:8265 in your browser.
See Running on Slurm HPC Systems for the full Slurm setup guide.