Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] New API ot.solve_sample #563

Merged
merged 32 commits into from
Nov 17, 2023
Merged

[MRG] New API ot.solve_sample #563

merged 32 commits into from
Nov 17, 2023

Conversation

rflamary
Copy link
Collaborator

@rflamary rflamary commented Nov 7, 2023

Types of changes

Implement the first shot as the ot.solve_sample API.

For standard parameters (the same as ot.solve) it is a simple wrapper that pr-compute the cost matrix using the metric parameter and return the same solutions. But this function also provides other approximated OT solvers that can be selected with method and large scale lazy solvers that avoid computation of the full cost matrix with lazy=True

Some examples of use below:

import numpy as np
import ot

n = 100
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
x2 = rng.randn(n//2, 2)+5

#%% ot.solve_sample is a wrapper for ot.solve when lazy=False (and method==None)

M = ot.dist(x, x2, metric='sqeuclidean')
sol0  = ot.solve(M)

sol = ot.solve_sample(x,x2)
print(sol.value)

# use anothe metric
sol2 = ot.solve_sample(x,x2, metric='cityblock')
print(sol2.value)
# sol == sol0 for all parameters in ot.solve (juset a wrapper)

#%% other methods

# solve 1D wasserstein in paralel for all dimensions
sol = ot.solve_sample(x,x2, method='1d')
print(sol)
# sol.value return the wassretsein for each dimensions

# Compute the empirical squared Bures wassrestein distance
sol = ot.solve_sample(x,x2, method='gaussian')
print(sol)

# compute factored OT
sol = ot.solve_sample(x,x2, method='factored', rank=10)
print(sol) # sol.plan is returned if lazy=False
print(sol.lazy_plan) # the low rank lazy tensor is always available for factored OT

#%%  Sinkhorn and Lazy Sinkhorn

# Sinkhorn solution
sol0 = ot.solve_sample(x,x2, reg=1)
print(sol0)

# Lazy sinkhorn (O(n) in memory)
sol = ot.solve_sample(x,x2, reg=1, lazy=True)
print(sol)
print(sol.lazy_plan) # the low rank lazy tensor is always available for lazy sinkhorn

Motivation and context / Related issue

How has this been tested (if it applies)

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

Copy link

codecov bot commented Nov 7, 2023

Codecov Report

Merging #563 (d3f5bf3) into master (6f4a40d) will increase coverage by 0.07%.
The diff coverage is 99.42%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #563      +/-   ##
==========================================
+ Coverage   96.54%   96.61%   +0.07%     
==========================================
  Files          74       74              
  Lines       14870    15036     +166     
==========================================
+ Hits        14356    14527     +171     
+ Misses        514      509       -5     

@rflamary rflamary changed the title [WIP] New API ot.solve_sample [MRG] New API ot.solve_sample Nov 15, 2023
Copy link
Collaborator

@cedricvincentcuaz cedricvincentcuaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really cool PR @rflamary, one lazy function to rule them all.
I've made a few comments to help you conclude this PR.

ot/bregman/_empirical.py Outdated Show resolved Hide resolved
ot/bregman/_empirical.py Outdated Show resolved Hide resolved
ot/bregman/_empirical.py Show resolved Hide resolved


def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
unbalanced_type='KL', n_threads=1, max_iter=None, plan_init=None,
unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

method missing in the doc.

ot/solvers.py Show resolved Hide resolved
@rflamary rflamary merged commit ef6c3c1 into master Nov 17, 2023
15 of 16 checks passed
@rflamary rflamary deleted the newapi_solve_sample branch February 29, 2024 11:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants