mythos.observables.wasserstein

Wasserstein distance observables.

Classes

WassersteinDistance

Compute the 1D Wasserstein distance between two distributions.

WassersteinDistanceMapped

Compute the 1D Wasserstein distance between two distributions, by key.

Functions

wasserstein_1d(→ mythos.utils.types.Scalar)

Compute the 1D Wasserstein distance between two distributions u and v.

_compute_wasserstein_distance(→ mythos.utils.types.Scalar)

Module Contents

mythos.observables.wasserstein.wasserstein_1d(u: mythos.utils.types.Arr_N, v: mythos.utils.types.Arr_N, u_weights: mythos.utils.types.Arr_N | None = None, v_weights: mythos.utils.types.Arr_N | None = None) → mythos.utils.types.Scalar

Compute the 1D Wasserstein distance between two distributions u and v.
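For orientation, here is a minimal pure-Python sketch of a weighted 1D Wasserstein distance. It assumes the order-1 distance, computed as the integral of the absolute difference between the two empirical CDFs. This page does not state the order or the algorithm that wasserstein_1d uses, so wasserstein_1d_sketch and _cdf_at are hypothetical names and this is not the library's implementation.

```python
from bisect import bisect_right
from itertools import accumulate


def _cdf_at(points, samples, weights):
    """Evaluate the weighted empirical CDF of `samples` at each of `points`."""
    if weights is None:
        weights = [1.0] * len(samples)
    # Sort the samples and carry their weights along.
    order = sorted(range(len(samples)), key=samples.__getitem__)
    xs = [samples[i] for i in order]
    # cum[k] is the total weight of the k smallest samples.
    cum = [0.0, *accumulate(weights[i] for i in order)]
    total = cum[-1]  # normalizing here means weights need not sum to 1
    return [cum[bisect_right(xs, p)] / total for p in points]


def wasserstein_1d_sketch(u, v, u_weights=None, v_weights=None):
    """Order-1 Wasserstein distance between two weighted 1D samples."""
    grid = sorted([*u, *v])
    # Both CDFs are piecewise constant between consecutive grid points.
    cdf_u = _cdf_at(grid[:-1], u, u_weights)
    cdf_v = _cdf_at(grid[:-1], v, v_weights)
    widths = [b - a for a, b in zip(grid, grid[1:])]
    return sum(abs(fu - fv) * w for fu, fv, w in zip(cdf_u, cdf_v, widths))
```

Under these assumptions, wasserstein_1d_sketch([0.0, 1.0, 3.0], [5.0, 6.0, 8.0]) is approximately 5.0, because the second sample is the first one shifted by 5.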

mythos.observables.wasserstein._compute_wasserstein_distance(obs_values: mythos.utils.types.Arr_N, v: mythos.utils.types.Arr_N, weights: mythos.utils.types.Arr_N | None = None, v_weights: mythos.utils.types.Arr_N | None = None) → mythos.utils.types.Scalar

class mythos.observables.wasserstein.WassersteinDistance

Compute the 1D Wasserstein distance between two distributions.

The U distribution is obtained by calling the supplied observable on the trajectory, and the V distribution is provided as a fixed reference distribution. Weights can optionally be provided for the V distribution as a property, and for the U distribution at call time.

The observable, when called on a trajectory, should return a (n_states, n_values) array, where n_states is the number of states in the trajectory. This will be flattened on its way into the Wasserstein distance computation.

The weights supplied to the call method are expected to correspond to states in the trajectory, and will apply to all values in the observable output distribution for that state.
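The flattening and per-state weighting described above can be pictured with the following sketch. It uses plain Python lists in place of arrays, and flatten_with_state_weights is a hypothetical helper, not part of mythos. Whether the library renormalizes the repeated weights is not stated here, so the sketch leaves them as they are.

```python
def flatten_with_state_weights(obs_values, state_weights=None):
    """Flatten (n_states, n_values) rows; repeat each state's weight per value."""
    flat_values = [x for row in obs_values for x in row]
    if state_weights is None:
        return flat_values, None
    # Every value produced by a state inherits that state's weight.
    flat_weights = [w for w, row in zip(state_weights, obs_values) for _ in row]
    return flat_values, flat_weights


# Illustrative data: 3 states, 2 observable values per state.
obs_values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
state_weights = [0.5, 0.3, 0.2]
flat_values, flat_weights = flatten_with_state_weights(obs_values, state_weights)
```

Here flat_values holds the six values in row order, and flat_weights repeats each state weight twice, once per value of that state.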

observable

The observable whose output distribution defines U.

v_distribution

The fixed reference distribution V to compare against.

v_weights

Optional weights for the V distribution (should sum to 1).

observable: mythos.observables.base.BaseObservable
v_distribution: mythos.utils.types.Arr_N
v_weights: mythos.utils.types.Arr_N | None = None
__call__(trajectory: mythos.simulators.io.SimulatorTrajectory, weights: mythos.utils.types.Arr_N | None = None) → mythos.utils.types.Scalar

Compute the Wasserstein distance between observable and reference distributions.

class mythos.observables.wasserstein.WassersteinDistanceMapped

Compute the 1D Wasserstein distance between two distributions, by key.

This is a generalization of WassersteinDistance that computes distances for multiple observables and reference distributions at once, by key. The input observable is expected to return a dictionary mapping keys to observable outputs. The v_distribution_map (each value corresponding to v_distribution of WassersteinDistance) and the v_weights_map (each value corresponding to v_weights of WassersteinDistance) should use the same keys.

See WassersteinDistance for more information on inputs and calling.
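The by-key behaviour can be sketched as a dictionary comprehension over matching keys. To keep the focus on the key mapping, the sketch omits weights and uses a stand-in distance that is only valid for unweighted samples of equal size. The names distances_by_key and mean_abs_sorted_gap, the example keys, and the KeyError on mismatched keys are all assumptions made for illustration, not documented mythos behaviour.

```python
def mean_abs_sorted_gap(u, v):
    """Stand-in: order-1 Wasserstein distance for unweighted, equal-size samples."""
    if len(u) != len(v):
        raise ValueError("this stand-in requires samples of equal size")
    return sum(abs(a - b) for a, b in zip(sorted(u), sorted(v))) / len(u)


def distances_by_key(u_map, v_distribution_map, distance=mean_abs_sorted_gap):
    """Apply `distance` to each pair of distributions that share a key."""
    mismatched = set(u_map) ^ set(v_distribution_map)
    if mismatched:
        # Assumed behaviour: refuse to guess when the key sets differ.
        raise KeyError(f"keys without a counterpart: {sorted(mismatched)}")
    return {key: distance(u_map[key], v_distribution_map[key]) for key in u_map}


# Hypothetical keys and values standing in for observable outputs.
u_map = {"bond": [0.0, 1.0, 3.0], "angle": [1.0, 2.0]}
v_distribution_map = {"bond": [5.0, 6.0, 8.0], "angle": [2.0, 1.0]}
result = distances_by_key(u_map, v_distribution_map)
```

The result has one scalar per key, mirroring the dict[str, Scalar] return type of the mapped __call__.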

observable

The observable whose output distribution defines U, expected to return a dictionary mapping keys to observable outputs.

v_distribution_map

Dictionary mapping keys to fixed reference distributions V to compare against.

v_weights_map

Optional dictionary mapping keys to weights for the V distributions.

observable: mythos.observables.base.BaseObservable
v_distribution_map: dict[str, mythos.utils.types.Arr_N]
v_weights_map: dict[str, mythos.utils.types.Arr_N | None]
__call__(trajectory: mythos.simulators.io.SimulatorTrajectory, weights: mythos.utils.types.Arr_N | None = None) → dict[str, mythos.utils.types.Scalar]

Compute the Wasserstein distance between all observable and reference distributions, by key.