Source code for mythos.observables.wasserstein

"""Wasserstein distance observables."""

import math
from dataclasses import field

import chex
import jax.numpy as jnp

from mythos.observables.base import BaseObservable
from mythos.simulators.io import SimulatorTrajectory
from mythos.utils.types import Arr_N, Scalar


[docs] def wasserstein_1d(u: Arr_N, v: Arr_N, u_weights: Arr_N | None = None, v_weights: Arr_N | None = None) -> Scalar: """Compute the 1D Wasserstein distance between two distributions u and v.""" u = jnp.asarray(u, dtype=jnp.float64) v = jnp.asarray(v, dtype=jnp.float64) if u_weights is None: u_weights = jnp.full(u.shape, 1.0 / u.size, dtype=jnp.float64) else: u_weights = jnp.asarray(u_weights, dtype=jnp.float64) if v_weights is None: v_weights = jnp.full(v.shape, 1.0 / v.size, dtype=jnp.float64) else: v_weights = jnp.asarray(v_weights, dtype=jnp.float64) if u_weights.shape != u.shape: raise ValueError(f"u_weights must have the same shape as u; got {u_weights.shape} and {u.shape}.") if v_weights.shape != v.shape: raise ValueError(f"v_weights must have the same shape as v; got {v_weights.shape} and {v.shape}.") # Validate that total masses match (within numerical tolerance) if not jnp.isclose(jnp.sum(u_weights), jnp.sum(v_weights), rtol=1e-5, atol=1e-5): raise ValueError( "u_weights and v_weights must sum to the same total mass; " f"got {jnp.sum(u_weights)} and {jnp.sum(v_weights)}." ) # Sort u and v with their weights u_sort_idx = jnp.argsort(u) v_sort_idx = jnp.argsort(v) u = u[u_sort_idx] v = v[v_sort_idx] u_weights = u_weights[u_sort_idx] v_weights = v_weights[v_sort_idx] # Merge all support points all_vals = jnp.concatenate([u, v]) all_weights = jnp.concatenate([u_weights, -v_weights]) sort_idx = jnp.argsort(all_vals) all_vals = all_vals[sort_idx] all_weights = all_weights[sort_idx] # Compute CDF difference over each interval diffs = jnp.cumsum(all_weights) dx = all_vals[1:] - all_vals[:-1] avg_heights = jnp.abs(diffs[:-1]) return jnp.sum(dx * avg_heights)
def _compute_wasserstein_distance(
    obs_values: Arr_N, v: Arr_N, weights: Arr_N | None = None, v_weights: Arr_N | None = None
) -> Scalar:
    """Flatten an observable's output and compute its 1D Wasserstein distance to ``v``.

    Args:
        obs_values: Observable output. The leading axis is treated as the state
            axis (one entry per trajectory state); any trailing axes hold the
            values produced for that state. The array is flattened before the
            distance is computed.
        v: Samples of the reference distribution V.
        weights: Optional per-state weights, one per entry along the leading
            axis of ``obs_values``. Each state's weight is split evenly across
            all values that state produced, so the total mass is preserved.
        v_weights: Optional weights for ``v``.

    Returns:
        The result of ``wasserstein_1d`` on the flattened observable values.

    Raises:
        ValueError: If ``weights`` does not have exactly one entry per state,
            or for any validation failure raised by ``wasserstein_1d``.
    """
    obs_shape = obs_values.shape
    # Flatten the observable output into a single 1D sample set.
    obs_values = obs_values.flatten()

    if weights is not None:
        # ``jnp.repeat`` without ``axis`` flattens its input, so ravel explicitly
        # and validate the count up front; otherwise a mismatch is only reported
        # later as an opaque shape error on the expanded weights.
        weights = jnp.ravel(jnp.asarray(weights))
        n_states = math.prod(obs_shape[:1])
        if weights.size != n_states:
            raise ValueError(f"weights must have one entry per state ({n_states}); got {weights.size}.")
        # Copy each state's weight onto every value that state produced,
        # dividing so the state's total contribution stays unchanged.
        n_per_weight = math.prod(obs_shape[1:])
        weights = jnp.repeat(weights, n_per_weight) / n_per_weight

    return wasserstein_1d(obs_values, v, u_weights=weights, v_weights=v_weights)
@chex.dataclass(frozen=True, kw_only=True)
class WassersteinDistance:
    """Compute the 1D Wasserstein distance between two distributions.

    U is produced by evaluating ``observable`` on a trajectory. V is a fixed
    reference sample set stored on the instance. V may carry optional weights
    (set as an attribute). U may carry optional per-state weights, supplied
    when the instance is called.

    The observable is expected to return a ``(n_states, n_values)`` array for a
    trajectory with ``n_states`` states. That array is flattened before the
    distance is computed. Each per-state weight passed at call time is shared
    across all values the observable produced for that state.

    Attributes:
        observable: Observable whose output distribution defines U.
        v_distribution: Fixed reference distribution V to compare against.
        v_weights: Optional weights for the V distribution (should sum to 1).
    """

    observable: BaseObservable
    v_distribution: Arr_N
    v_weights: Arr_N | None = None

    def __call__(self, trajectory: SimulatorTrajectory, weights: Arr_N | None = None) -> Scalar:
        """Return the Wasserstein distance between the observable's output on ``trajectory`` and V.

        Args:
            trajectory: Trajectory on which the observable is evaluated.
            weights: Optional per-state weights for the U distribution.
        """
        u_samples = self.observable(trajectory)
        reference, reference_weights = self.v_distribution, self.v_weights
        return _compute_wasserstein_distance(
            u_samples,
            reference,
            weights=weights,
            v_weights=reference_weights,
        )
@chex.dataclass(frozen=True, kw_only=True)
class WassersteinDistanceMapped:
    """Compute the 1D Wasserstein distance between two distributions, by key.

    This generalizes :class:`WassersteinDistance` to compute distances for
    several observable outputs and reference distributions at once, keyed by
    name. The observable is expected to return a dictionary mapping keys to
    observable outputs.

    - Each value in ``v_distribution_map`` plays the role of ``v_distribution``
      in :class:`WassersteinDistance`.
    - Each value in ``v_weights_map`` plays the role of ``v_weights``.

    A distance is computed for every key in ``v_distribution_map``, and that
    key must be present in the observable's output. A key missing from
    ``v_weights_map`` is evaluated with ``v_weights=None``. Keys that appear
    only in ``v_weights_map`` are ignored.

    See :class:`WassersteinDistance` for details on the expected observable
    output shape and the per-state ``weights`` accepted at call time.

    Attributes:
        observable: Observable whose output distributions define U; expected
            to return a dictionary mapping keys to observable outputs.
        v_distribution_map: Dictionary mapping keys to the fixed reference
            distributions V to compare against.
        v_weights_map: Optional dictionary mapping keys to weights for the V
            distributions. Defaults to an empty dict.
    """

    observable: BaseObservable
    v_distribution_map: dict[str, Arr_N]
    v_weights_map: dict[str, Arr_N | None] = field(default_factory=dict)

    def __call__(self, trajectory: SimulatorTrajectory, weights: Arr_N | None = None) -> dict[str, Scalar]:
        """Compute the Wasserstein distance between all observable and reference distributions, by key.

        Args:
            trajectory: Trajectory on which the observable is evaluated once.
            weights: Optional per-state weights for the U distributions, shared
                by every key.

        Returns:
            Dictionary mapping each key of ``v_distribution_map`` to its
            Wasserstein distance.
        """
        # Evaluate the observable once and reuse its output for every key.
        obs_values = self.observable(trajectory)
        return {
            key: _compute_wasserstein_distance(
                obs_values[key], self.v_distribution_map[key], weights=weights, v_weights=self.v_weights_map.get(key)
            )
            for key in self.v_distribution_map
        }