Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Update documentation "Why OT" section #220

Merged
merged 13 commits into from
Dec 22, 2020
5 changes: 5 additions & 0 deletions docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ html:
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."

html-noplot:
$(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."

.PHONY: dirhtml
dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
Expand Down
448 changes: 304 additions & 144 deletions docs/source/quickstart.rst

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions examples/barycenters/plot_free_support_barycenter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
====================================================
========================================================
2D free support Wasserstein barycenters of distributions
====================================================
========================================================

Illustration of 2D Wasserstein barycenters if distributions are weighted
sum of diracs.
Expand Down
4 changes: 2 additions & 2 deletions examples/domain-adaptation/plot_otda_jcpot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
========================
================================
OT for multi-source target shift
========================
================================

This example introduces a target shift problem with two 2D source and 1 target domain.

Expand Down
2 changes: 1 addition & 1 deletion examples/gromov/plot_barycenter_fgw.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
=================================
Plot graphs' barycenter using FGW
Plot graphs barycenter using FGW
=================================

This example illustrates the computation barycenter of labeled graphs using
Expand Down
10 changes: 5 additions & 5 deletions examples/gromov/plot_fgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

##############################################################################
# Generate data
# ---------
# -------------

#%% parameters
# We create two 1D random measures
Expand Down Expand Up @@ -76,7 +76,7 @@

##############################################################################
# Create structure matrices and across-feature distance matrix
# ---------
# ------------------------------------------------------------

#%% Structure matrices and across-features distance matrix
C1 = ot.dist(xs)
Expand All @@ -88,7 +88,7 @@

##############################################################################
# Plot matrices
# ---------
# -------------

#%%
cmap = 'Reds'
Expand Down Expand Up @@ -131,7 +131,7 @@

##############################################################################
# Compute FGW/GW
# ---------
# --------------

#%% Computing FGW and GW
alpha = 1e-3
Expand All @@ -145,7 +145,7 @@

##############################################################################
# Visualize transport matrices
# ---------
# ----------------------------

#%% visu OT matrix
cmap = 'Blues'
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_OT_1D_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@

##############################################################################
# Solve Smooth OT
# --------------
# ---------------


#%% Smooth OT with KL regularization
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_OT_2D_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@

##############################################################################
# Emprirical Sinkhorn
# ----------------
# -------------------

#%% sinkhorn

Expand Down
16 changes: 9 additions & 7 deletions examples/sliced-wasserstein/plot_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
2D Sliced Wasserstein Distance
==============================

This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31].
This example illustrates the computation of the sliced Wasserstein Distance as
proposed in [31].

[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of
measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45

"""

Expand Down Expand Up @@ -50,9 +52,9 @@
pl.legend(loc=0)
pl.title('Source and target distributions')

###################################################################################
# Compute Sliced Wasserstein distance for different seeds and number of projections
# -----------
###############################################################################
# Sliced Wasserstein distance for different seeds and number of projections
# -------------------------------------------------------------------------

n_seed = 50
n_projections_arr = np.logspace(0, 3, 25, dtype=int)
Expand All @@ -66,9 +68,9 @@
res_mean = np.mean(res, axis=0)
res_std = np.std(res, axis=0)

###################################################################################
###############################################################################
# Plot Sliced Wasserstein Distance
# -----------
# --------------------------------

pl.figure(2)
pl.plot(n_projections_arr, res_mean, label="SWD")
Expand Down
3 changes: 1 addition & 2 deletions examples/unbalanced-partial/plot_UOT_1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@

##############################################################################
# Solve Unbalanced Sinkhorn
# --------------

# -------------------------

# Sinkhorn

Expand Down
10 changes: 6 additions & 4 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
from .bregman import sinkhorn, sinkhorn2, barycenter
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
from .sliced import sliced_wasserstein_distance

Expand All @@ -46,9 +47,10 @@

__version__ = "0.7.0"

__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets',
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
'emd_1d', 'emd2_1d', 'wasserstein_1d',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance']
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
'smooth', 'stochastic', 'unbalanced', 'partial']
30 changes: 30 additions & 0 deletions ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True

**Choosing a Sinkhorn solver**

By default and when using a regularization parameter that is not too small
the default sinkhorn solver should be enough. If you need to use a small
regularization to get sharper OT matrices, you should use the
:any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
errors. This last solver can be very slow in practice and might not even
converge to a reasonable OT matrix in a finite time. This is why
:any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value
of the regularization (and using warm start) sometimes leads to better
solutions. Note that the greedy version of the sinkhorn
:any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
fast approximation of the Sinkhorn problem.


Returns
-------
Expand Down Expand Up @@ -175,6 +190,21 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True

**Choosing a Sinkhorn solver**

By default and when using a regularization parameter that is not too small
the default sinkhorn solver should be enough. If you need to use a small
regularization to get sharper OT matrices, you should use the
:any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
errors. This last solver can be very slow in practice and might not even
converge to a reasonable OT matrix in a finite time. This is why
:any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value
of the regularization (and using warm start) sometimes leads to better
solutions. Note that the greedy version of the sinkhorn
:any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
fast approximation of the Sinkhorn problem.

Returns
-------
W : (n_hists) ndarray or float
Expand Down