mythos.observables.wasserstein
Wasserstein distance observables.
Classes
WassersteinDistance | Compute the 1D Wasserstein distance between two distributions.
WassersteinDistanceMapped | Compute the 1D Wasserstein distance between two distributions, by key.
Functions
wasserstein_1d | Compute the 1D Wasserstein distance between two distributions u and v.
Module Contents
- mythos.observables.wasserstein.wasserstein_1d(u: mythos.utils.types.Arr_N, v: mythos.utils.types.Arr_N, u_weights: mythos.utils.types.Arr_N | None = None, v_weights: mythos.utils.types.Arr_N | None = None) -> mythos.utils.types.Scalar
Compute the 1D Wasserstein distance between two distributions u and v.
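For orientation, the 1D Wasserstein (earth mover's) distance between two weighted samples can be written as the integral of the absolute difference between their cumulative distribution functions. The sketch below illustrates that definition using only the Python standard library. It is not the mythos implementation: the name `wasserstein_1d_sketch` and the helper `_cdf_at` are hypothetical, and plain lists stand in for the `Arr_N` arrays.

```python
from bisect import bisect_right
from itertools import accumulate


def _cdf_at(values, weights, points):
    """Evaluate the weighted empirical CDF of `values` at each of `points`."""
    pairs = sorted(zip(values, weights))
    xs = [x for x, _ in pairs]
    # cum[i] is the total weight of the i smallest values.
    cum = [0.0] + list(accumulate(w for _, w in pairs))
    total = cum[-1]
    return [cum[bisect_right(xs, p)] / total for p in points]


def wasserstein_1d_sketch(u, v, u_weights=None, v_weights=None):
    """Hypothetical illustration: integrate |CDF_u - CDF_v| over the real line."""
    u, v = list(u), list(v)
    if not u or not v:
        raise ValueError("both samples must be non-empty")
    # Missing weights are treated as uniform (an assumption of this sketch).
    u_weights = [1.0] * len(u) if u_weights is None else list(u_weights)
    v_weights = [1.0] * len(v) if v_weights is None else list(v_weights)
    if len(u_weights) != len(u) or len(v_weights) != len(v):
        raise ValueError("weights must match their samples in length")

    # Both CDFs are step functions that only change at sample points, so the
    # integral reduces to a sum over intervals between consecutive points.
    grid = sorted(u + v)
    left = grid[:-1]
    u_cdf = _cdf_at(u, u_weights, left)
    v_cdf = _cdf_at(v, v_weights, left)
    return sum(
        abs(a - b) * (hi - lo)
        for a, b, lo, hi in zip(u_cdf, v_cdf, left, grid[1:])
    )
```

In this sketch the weights are normalized by their total, so they do not need to sum to 1. Shifting every sample by a constant, as in `wasserstein_1d_sketch([0, 1, 3], [5, 6, 8])`, gives a distance equal to that constant (here 5, up to floating-point rounding).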
- mythos.observables.wasserstein._compute_wasserstein_distance(obs_values: mythos.utils.types.Arr_N, v: mythos.utils.types.Arr_N, weights: mythos.utils.types.Arr_N | None = None, v_weights: mythos.utils.types.Arr_N | None = None) -> mythos.utils.types.Scalar
- class mythos.observables.wasserstein.WassersteinDistance
Compute the 1D Wasserstein distance between two distributions.
The U distribution is obtained by calling the supplied observable on the trajectory, and the V distribution is provided as a fixed reference distribution. Weights can optionally be provided for the V distribution as a property, and for the U distribution at call time.
The observable, when called on a trajectory, should return an (n_states, n_values) array, where n_states is the number of states in the trajectory. This array is flattened before it is passed to the Wasserstein distance computation.
The weights supplied to the call method are expected to correspond to states in the trajectory, and will apply to all values in the observable output distribution for that state.
- observable
The observable whose output distribution defines U.
- v_distribution
The fixed reference distribution V to compare against.
- v_weights
Optional weights for the V distribution (should sum to 1).
- observable: mythos.observables.base.BaseObservable
- v_distribution: mythos.utils.types.Arr_N
- v_weights: mythos.utils.types.Arr_N | None = None
- __call__(trajectory: mythos.simulators.io.SimulatorTrajectory, weights: mythos.utils.types.Arr_N | None = None) -> mythos.utils.types.Scalar
Compute the Wasserstein distance between observable and reference distributions.
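The flattening and per-state weighting described for this class can be pictured with a small standard-library sketch. This is only an illustration under stated assumptions, not the library's code: the function name `flatten_with_state_weights` is hypothetical, nested lists stand in for the (n_states, n_values) array, flattening is assumed to be row-major, and no renormalization is applied because the documentation does not say whether the class performs one.

```python
def flatten_with_state_weights(obs_values, state_weights=None):
    """Hypothetical helper: flatten (n_states, n_values) data and expand weights.

    Each state's weight is repeated once for every value that state
    contributes, so the flattened values and weights stay aligned.
    """
    # Row-major flattening (assumption): all values of state 0, then state 1, ...
    flat_values = [x for row in obs_values for x in row]
    if state_weights is None:
        return flat_values, None
    if len(state_weights) != len(obs_values):
        raise ValueError("expected exactly one weight per state")
    # Repeat each per-state weight for every value in that state's row.
    flat_weights = [w for row, w in zip(obs_values, state_weights) for _ in row]
    return flat_values, flat_weights
```

For two states with two values each and state weights 0.25 and 0.75, the flattened weights are 0.25, 0.25, 0.75, 0.75, one entry per flattened value.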
- class mythos.observables.wasserstein.WassersteinDistanceMapped
Compute the 1D Wasserstein distance between two distributions, by key.
This generalizes WassersteinDistance to compute distances for multiple observables and reference distributions at once, by key. The input observable is expected to return a dictionary mapping keys to observable outputs. The v_distribution_map (the per-key counterpart of v_distribution in :class:WassersteinDistance) and v_weights_map (the per-key counterpart of v_weights in :class:WassersteinDistance) should have keys matching those of the observable output.
See :class:WassersteinDistance for more information on inputs and calling.
- observable
The observable whose output distribution defines U, expected to return a dictionary mapping keys to observable outputs.
- v_distribution_map
Dictionary mapping keys to fixed reference distributions V to compare against.
- v_weights_map
Optional dictionary mapping keys to weights for the V distributions.
- observable: mythos.observables.base.BaseObservable
- v_distribution_map: dict[str, mythos.utils.types.Arr_N]
- v_weights_map: dict[str, mythos.utils.types.Arr_N | None]
- __call__(trajectory: mythos.simulators.io.SimulatorTrajectory, weights: mythos.utils.types.Arr_N | None = None) -> dict[str, mythos.utils.types.Scalar]
Compute the Wasserstein distance between all observable and reference distributions, by key.
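The by-key behaviour amounts to applying one distance computation per matching key and collecting the results in a dictionary. The sketch below shows that pattern with the standard library only. All names in it (`mapped_distances`, `equal_size_w1`) are hypothetical, and the stand-in distance covers only unweighted samples of equal size, where the 1D Wasserstein distance equals the mean absolute difference of the sorted samples. How the real class validates keys is an assumption here.

```python
def equal_size_w1(u, v, u_weights=None, v_weights=None):
    """Stand-in distance: 1D Wasserstein for unweighted, equal-size samples."""
    if u_weights is not None or v_weights is not None:
        raise NotImplementedError("this stand-in supports unweighted samples only")
    if len(u) != len(v) or len(u) == 0:
        raise ValueError("samples must be non-empty and of equal size")
    # Optimal 1D transport pairs the i-th smallest of u with the i-th smallest of v.
    return sum(abs(a - b) for a, b in zip(sorted(u), sorted(v))) / len(u)


def mapped_distances(u_map, v_distribution_map, v_weights_map,
                     u_weights=None, distance_fn=equal_size_w1):
    """Hypothetical helper: one distance per key, returned as a dict."""
    # Assumption: all three mappings must share exactly the same keys.
    if not (set(u_map) == set(v_distribution_map) == set(v_weights_map)):
        raise KeyError("observable output and reference maps must share keys")
    return {
        key: distance_fn(u_map[key], v_distribution_map[key],
                         u_weights, v_weights_map[key])
        for key in u_map
    }
```

Passing the distance as `distance_fn` keeps the per-key loop independent of how each distance is computed, so a weighted implementation can be swapped in without changing the mapping logic.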