Source code for mythos.optimization.optimization

"""Runs an optimization loop using Ray actors for objectives and simulators."""

import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import field
from typing import Any

import chex
import jax
import jax.numpy as jnp
import optax
import ray
from ray import ObjectRef as RayRef
from ray.remote_function import RemoteFunction
from typing_extensions import override

from mythos.optimization.objective import Objective, ObjectiveOutput
from mythos.simulators.base import Simulator
from mythos.ui.loggers import logger as jdna_logger
from mythos.utils.helpers import try_to_float
from mythos.utils.scheduler import SchedulerUnit
from mythos.utils.types import Grads, Params

ERR_MISSING_OBJECTIVES = "At least one objective is required."
ERR_MISSING_SIMULATORS = "At least one simulator is required."
ERR_MISSING_AGG_GRAD_FN = "An aggregate gradient function is required."
ERR_MISSING_OPTIMIZER = "An optimizer is required."
# to prevent infinite unresolvable loops in step. The first call may use cached
# observables, so may required rerun of sims. After this, we don't expect any
# new information, so not-ready state after this is an error.
OBJECTIVE_PER_STEP_CALL_LIMIT = 2
LOGGER = logging.getLogger(__name__)


[docs] @chex.dataclass(frozen=True, kw_only=True) class OptimizerState: """State container for optimization loops. This dataclass stores all mutable state needed during optimization, allowing optimizers to work with frozen objective dataclasses. Attributes: observables: Current observable values from simulators. state: Per-objective/simulator state (keyed by name). Both objective and simulator state share this namespace, so names should be unique across objectives and simulators. optimizer_state: Current optax optimizer state. """ observables: dict[str, Any] = field(default_factory=dict) component_state: dict[str, dict[str, Any]] = field(default_factory=dict) optimizer_state: Any | None = None # optax.OptState
[docs] @chex.dataclass(frozen=True, kw_only=True) class OptimizerOutput: """Output container for optimization steps. Attributes: grads: The computed (aggregate) gradients from the optimization step. opt_params: The updated parameters after the optimization step. state: The updated optimizer state after the optimization step. This data structure should be passed back into the next call to step. observables: The logged observables from the optimization step. These are keyed by component name (e.g. objective) and each value should itself be a dict of observable name to value. """ grads: Grads opt_params: Params state: OptimizerState observables: dict[str, dict[str, Any]] = field(default_factory=dict)
[docs] @chex.dataclass(frozen=True, kw_only=True) class Optimizer(ABC): """Abstract base class for optimizers.""" logger: jdna_logger.Logger = field(default_factory=jdna_logger.NullLogger)
[docs] @abstractmethod def step(self, params: Params, state: OptimizerState | None = None) -> OptimizerOutput: """Perform a single optimization step. Args: params: The current parameters. state: The current optimization state. If None, an empty state is initialized. Returns: An optimizer output including params, new state, grads, and observables. """
[docs] def run(self, params: Params, n_steps: int, callback: Callable | None = None) -> OptimizerOutput: """Run the optimization loop for a given number of steps. The callback function, if provided, is called at the end of each step with the signature `callback(optimizer_output: OptimizerOutput, step: int)`. The callback must return a tuple of (OptimizerOutput | None, bool) where the bool indicates whether to continue iteration. If False, the optimization loop will stop early. If the returned OptimizerOutput is not None, it will replace the output of the step function (prior to logging observables). Additionally, if at the end of a step, any gradient contains NaN or Inf values, a RuntimeError is raised to prevent silent failures. Users can either use the callback mechanism or compose their optimizer in such a way to handle NaN/Inf values as desired. Args: params: The initial parameters for optimization. n_steps: The number of optimization steps to run. callback: An optional function to call at the end of each step. Returns: The final optimizer output after running the optimization loop, potentially stopped early based on the callback signal. """ state = None if n_steps < 1: raise ValueError("n_steps must be at least 1.") for step in range(n_steps): output = self.step(params, state) if callback is not None: cb_output, keep_going = callback(optimizer_output=output, step=step) output = cb_output if cb_output is not None else output else: keep_going = True for component, obs in output.observables.items(): for obs_name, value in obs.items(): if (value := try_to_float(value)) is not None: self.logger.log_metric(f"{component}.{obs_name}", value, step=step) if not keep_going: LOGGER.info("Early stopping optimization at step %s based on callback signal.", step) break grad_leaves = jax.tree.leaves(output.grads) if any(jnp.any(~jnp.isfinite(leaf)) for leaf in grad_leaves): raise RuntimeError(f"NaN or Inf detected in gradients at step {step}. Is your learning rate too high?") params = output.opt_params state = output.state return output
# Helper functions for running remote tasks @ray.remote def _simulator_run_fn(simulator: Simulator, params: Params, state: dict[str, Any]) -> tuple[list[RayRef], RayRef]: output = simulator.run(opt_params=params, **state) return *output.observables, output.state @ray.remote def _objective_compute_fn( objective: Objective, obs: dict[str, RayRef], params: Params, state: dict[str, Any] ) -> ObjectiveOutput: # Since ray won't unpack nested refs automatically, we do so since we can # guarantee the shape of the object obs = {k: ray.get(v) for k, v in obs.items()} return objective.calculate(observables=obs, opt_params=params, **state) # This indirection helps with mocking in testing
[docs] def _create_and_run_remote(fun: RemoteFunction, ray_options: dict, *args) -> RayRef | list[RayRef]: return fun.options(**ray_options).remote(*args)
[docs] @chex.dataclass(frozen=True, kw_only=True) class RayOptimizer(Optimizer): """Optimization of a list of objectives using a list of simulators. Parameters: objectives: A list of objectives to optimize. simulators: A list of simulators to use for the optimization. aggregate_grad_fn: A function that takes a list of gradients and aggregates them into a single gradient. The gradients are provided in the same order as the list of objectives. optimizer: An optax optimizer. optimizer_state: The state of the optimizer. logger: A logger to use for the optimization. remote_options_default: Default Ray options to apply to all remote calls. """ objectives: list[Objective] simulators: list[Simulator] aggregate_grad_fn: Callable[[list[Grads]], Grads] optimizer: optax.GradientTransformation remote_options_default: dict[str, Any] = field(default_factory=dict)
[docs] def __post_init__(self) -> None: """Validate the initialization of the Optimization.""" if not self.objectives: raise ValueError(ERR_MISSING_OBJECTIVES) if not self.simulators: raise ValueError(ERR_MISSING_SIMULATORS) if self.aggregate_grad_fn is None: raise ValueError(ERR_MISSING_AGG_GRAD_FN) if self.optimizer is None: raise ValueError(ERR_MISSING_OPTIMIZER) # Check for conflicts in global namespaces that we use for coordination all_names = ( [obj.name for obj in self.objectives] + [sim.name for sim in self.simulators] + [exp for sim in self.simulators for exp in sim.exposes()] ) if len(all_names) != len(set(all_names)): raise ValueError("All objective, simulator, and exposes names must be unique")
[docs] def _get_ray_options(self, unit: SchedulerUnit) -> dict[str, Any]: options = {} if unit_hints := getattr(unit, "scheduler_hints", None): options = unit_hints.to_dict(engine="ray", rewrite_options={"mem_mb": "memory"}) if "memory" in options: options["memory"] = int(options["memory"] * 1024 * 1024) # Ray expects bytes return {**self.remote_options_default, **options}
[docs] def _run_simulator(self, simulator: Simulator, params: Params, **state) -> tuple[list[RayRef], RayRef]: ray_opts = { **self._get_ray_options(simulator), "name": "simulator_run:" + simulator.name, "num_returns": 1 + len(simulator.exposes()), } refs = _create_and_run_remote(_simulator_run_fn, ray_opts, simulator, params, state) return refs[:-1], refs[-1] # observables as a list, state
[docs] def _run_objective(self, objective: Objective, observables: dict[str, RayRef], params: Params, **state) -> RayRef: ray_opts = { **self._get_ray_options(objective), "name": "objective_compute:" + objective.name, } return _create_and_run_remote(_objective_compute_fn, ray_opts, objective, observables, params, state)
[docs] def _wait_remotes(self, refs: list[RayRef]) -> list[RayRef]: ref_list = list(refs) ray.wait(ref_list, fetch_local=False, num_returns=1) # The below is to maximize our chance of getting multiple at once # (for example multiple observables and state from a simulator) ready, _ = ray.wait(ref_list, fetch_local=False, timeout=0.1) return ready
[docs] @override def step(self, params: Params, state: OptimizerState | None = None) -> OptimizerOutput: # noqa: C901, PLR0912 state = state or OptimizerState() state_observables, component_state = state.observables.copy(), state.component_state.copy() obj_lookup = {obj.name: obj for obj in self.objectives} call_count = {obj.name: 0 for obj in self.objectives} sim_lookup = {sim.name: sim for sim in self.simulators} expose_lookup = {exp: sim for sim in self.simulators for exp in sim.exposes()} ref_map, grads_completed, output_observables = {}, {}, {} # schedule all objectives that already have their observables in state while (needed_objectives := set(obj_lookup) - set(grads_completed)) or ref_map: for obj_name in needed_objectives: objective = obj_lookup[obj_name] # skip if we are currently running it if objective.name in ref_map.values(): continue # It is an unresolvable state if we have called the objective # more than twice if call_count[objective.name] > OBJECTIVE_PER_STEP_CALL_LIMIT: raise RuntimeError(f"Objective {objective.name} could not be resolved after multiple attempts.") # If we have all the observables in state, we make an attempt at # the objective. This may return a not ready signal, in which # case observables will be cleared to trigger this logic again. if set(objective.required_observables).issubset(state_observables): obj_observables = {k: state_observables[k] for k in objective.required_observables} obj_state = component_state.get(objective.name, {}) ref = self._run_objective(objective, obj_observables, params, **obj_state) ref_map[ref] = objective.name call_count[objective.name] += 1 # there are simulators running that provide some of what we # need, so we have gone through the scheduling step elif set(objective.required_observables).intersection(ref_map.values()): continue else: needed_sims = {expose_lookup[exp].name for exp in objective.required_observables} # filter out sims we know are running based on state ref for sim_name in needed_sims - set(ref_map.values()): sim = sim_lookup[sim_name] # make sure we aren't waiting on any of the observables # this provides. It may be possible that the ref of # state or some observables become available separately if set(sim.exposes()).intersection(ref_map.values()): continue sim_state = component_state.get(sim.name, {}) refs, md_ref = self._run_simulator(sim, params, **sim_state) for r, exp in zip(refs, sim.exposes(), strict=True): ref_map[r] = exp # noqa: PERF403 dict is instantiated and used elsewhere ref_map[md_ref] = sim.name # wait for anything to finish. We do a second wait without num # returns but non-blocking to gather as many as we can at once ready = self._wait_remotes(ref_map.keys()) for ref in ready: producer = ref_map.pop(ref) if producer in obj_lookup: output = ray.get(ref) component_state[producer] = output.state if output.is_ready: grads_completed[producer] = output.grads output_observables[producer] = output.observables else: # remove the needs from the state observables so the # above loop check will schedule the providing simulator state_observables = {k: v for k, v in state.observables.items() if k not in output.needs_update} elif producer in expose_lookup: state_observables[producer] = ref else: # finally it must be simulator state component_state[producer] = ray.get(ref) # return the grads of the objectives in the order they were provided grads = self.aggregate_grad_fn([grads_completed[obj.name] for obj in self.objectives]) opt_state = state.optimizer_state or self.optimizer.init(params) updates, opt_state = self.optimizer.update(grads, opt_state, params) new_params = optax.apply_updates(params, updates) return OptimizerOutput( opt_params=new_params, state=state.replace( optimizer_state=opt_state, component_state=component_state, observables=state_observables, ), grads=grads, observables=output_observables, )
[docs] @chex.dataclass(frozen=True) class SimpleOptimizer(Optimizer): """A simple optimizer that uses a single objective and simulator. This optimizer manages the state for a frozen Objective dataclass, passing observables and state through the compute method. State is managed via OptimizationState which is passed in and out. """ objective: Objective simulator: Simulator optimizer: optax.GradientTransformation
[docs] @override def step(self, params: Params, state: OptimizerState | None = None) -> OptimizerOutput: state = state or OptimizerState() obj_state = state.component_state.get(self.objective.name, {}) sim_state = state.component_state.get(self.simulator.name, {}) obj_output = None if state.observables: obj_output = self.objective.calculate(state.observables, opt_params=params, **obj_state) obj_state = obj_output.state if obj_output is None or not obj_output.is_ready: sim_output = self.simulator.run(params, **sim_state) sim_state = sim_output.state exposes = self.simulator.exposes() state = state.replace(observables=dict(zip(exposes, sim_output.observables, strict=True))) # Try again with updated observables obj_output = self.objective.calculate(state.observables, opt_params=params, **obj_state) obj_state = obj_output.state if not obj_output.is_ready: # this should be an impossible state, could end up in infinite loop raise ValueError("Objective readiness check failed after simulation run.") grads = obj_output.grads opt_state = state.optimizer_state or self.optimizer.init(params) updates, opt_state = self.optimizer.update(grads, opt_state, params) new_params = optax.apply_updates(params, updates) # should this be filtered? also be allowed to return filtered from # simulators? output_observables = {self.objective.name: obj_output.observables} return OptimizerOutput( opt_params=new_params, state=state.replace( optimizer_state=opt_state, component_state={ **state.component_state, self.objective.name: obj_state, self.simulator.name: sim_state, }, ), grads=grads, observables=output_observables, )