Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] New API for OT solver (with pre-computed ground cost matrix) #388

Merged
merged 33 commits into from
Dec 15, 2022

Conversation

rflamary
Copy link
Collaborator

@rflamary rflamary commented Jul 18, 2022

Types of changes

I implement the new POT API for general OT solvers. It comes with the new function ot.solve that can be used to solve exact, regularized and unbalanced OT depending on the parameter it receives and returns a new OTResult class that contains all information that can be useful to the user (OT value, OT plan, OT marginals and OT dual potentials).

  • Implement base OTResult class
  • Implement ot.solve function with exact OT and exact unbalanced OT
  • Implement call to all regularized a solvers (sinkhorn, L2)
  • Add TV as type of unbalanced OT and solvers for cross divergence unbalanced (KL + L2)
  • backjend version of ot.partial and ot.smooth
  • Write the documentation for the solve function
  • Add all tests to ensure good code coverage

Motivation and context / Related issue

The API has been discussed with @agramfort @jeanfeydy @hichamjanati and @ncourty and aim at providing a general solver mechanism for the most common OT problems.

import numpy as np
import ot

#%% Data

np.random.seed(42)

xs = np.random.randn(5,2)
xt = np.random.randn(6,2)

M = ot.dist(xs,xt)

a = ot.unif(5)
b = ot.unif(6)


#Solve  exact ot
sol = ot.solve(M)

# get the results
G = sol.plan # OT plan
ot_loss = sol.value # OT objective fucntion value
ot_loss_linear = sol.value_linear # OT value for linera term np.sum(sol.plan*M)
alpha, beta = sol.potentials # dual potentials

# direct plan and loss computation
G = ot.solve(M).plan
ot_loss = ot.solve(M).value

# OT exact with marginals a/b
sol2 = ot.solve(M, a, b)

# regularized OT
sol_rkl = ot.solve(M, a, b, reg=1) # KL regularization
sol_rentropy = ot.solve(M, a, b, reg=1, reg_type='entropy') # enropic reg (Sinkhorn paper) only change the loss
sol_rl2 = ot.solve(M, a, b, reg=1, reg_type='L2')



# Exact unbalanced OT with diferent penalizations
sol_utv = ot.solve(M, a, b, unbalanced=10, unbalanced_type='TV')
sol_ul2 = ot.solve(M, a, b, unbalanced=10, unbalanced_type='L2')
sol_ukl = ot.solve(M, a, b, unbalanced=10, unbalanced_type='KL')


# Unbalanced and regularized OT 
sol_rkl_ukl = ot.solve(M, a, b, reg=10, unbalanced=10) # KL + KL
sol_rl2_ul2 = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='L2', unbalanced_type='L2') # L2 + L2
sol_rkl_ul2 = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='KL', unbalanced_type='L2') # KL + L2
sol_rl2_ukl = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='L2', unbalanced_type='KL') # KL + L2

sol_rentropy_ul2 = ot.solve(M, a, b, reg=10, unbalanced=10, reg_type='entropy', unbalanced_type='L2') # KL + L2

The code has been written in part by @jeanfeydy in an HackMD file used during the discussion.

How has this been tested (if it applies)

New tests for all possible product of parameters

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@codecov
Copy link

codecov bot commented Jul 18, 2022

Codecov Report

Merging #388 (06e19ed) into master (8490196) will increase coverage by 0.23%.
The diff coverage is 97.95%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #388      +/-   ##
==========================================
+ Coverage   94.02%   94.26%   +0.23%     
==========================================
  Files          22       23       +1     
  Lines        5924     6203     +279     
==========================================
+ Hits         5570     5847     +277     
- Misses        354      356       +2     

@jeanfeydy
Copy link

Hi @rflamary,

Fantastic :-)
I have also started working on implementing this API last month (in the ot_api branch of GeomLoss), and this will be the project of the summer alongside a clean benchmarking platform for OT solvers that follows the structure of ann-benchmarks.

I'd be happy to come and visit you in Saclay in September to synchronize all of this - and we can have a visio call more or less anytime in August if you're not offline.

In any case, have a good summer and see you soon!
Jean

@rflamary
Copy link
Collaborator Author

@jeanfeydy this is great I will also work on that during the summer I think and we definitely want to talk. Especially since I changed the API a little bit and I don't really like having the same OTResult class for traditional OT problem (value+plan+...) and for OT barycenter when the result is a distribution (masses+ support position). August is fine for a virtual meeting and of course we should meetup in saclay at the beginning of the academic year (end of September the beginning will be hectic for me)

@jeanfeydy
Copy link

Ok perfect, see you soon!
(And I agree with the barycenter change, for sure.)

@rflamary rflamary changed the title [WIP] New API for OT solver (with pre-computed ground cost matrix) [WIP] New alpha API for OT solver (with pre-computed ground cost matrix) Jul 20, 2022
@rflamary rflamary changed the title [WIP] New alpha API for OT solver (with pre-computed ground cost matrix) [MRG] New alpha API for OT solver (with pre-computed ground cost matrix) Dec 3, 2022
@rflamary
Copy link
Collaborator Author

rflamary commented Dec 6, 2022

I think its OK now (doc for fyunction is done exemples and example update will be done later when doing the full API V2 release).

@agramfort care for a quick code reveiw?

ot/solvers.py Outdated Show resolved Hide resolved
ot/solvers.py Outdated Show resolved Hide resolved
ot/solvers.py Outdated Show resolved Hide resolved
ot/solvers.py Outdated Show resolved Hide resolved
ot/solvers.py Outdated Show resolved Hide resolved
ot/solvers.py Outdated Show resolved Hide resolved
test/test_solvers.py Show resolved Hide resolved
@rflamary rflamary changed the title [MRG] New alpha API for OT solver (with pre-computed ground cost matrix) [MRG] New API for OT solver (with pre-computed ground cost matrix) Dec 9, 2022
@rflamary rflamary merged commit 0411ea2 into master Dec 15, 2022
@agramfort
Copy link
Collaborator

it would be really cool to have this new API showed early in this page https://pythonot.github.io/quickstart.html

@rflamary
Copy link
Collaborator Author

I agree but it is not ready yet because we need to implement also GW, OT on sample and OT on grid which is a lot of work in addition to the doc...

I am going for a feature/bug 8.3 release shortly (we have many bugs in 8.2) where teh new API is not yet promoted (or only with beta status) and then we go twoard POT 1.0 with a big documentation and exemple revamp.

@rflamary rflamary deleted the api_solve branch March 3, 2023 13:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants