Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] New API for gromov solvers #536

Merged
merged 16 commits into from
Oct 24, 2023
Merged

[MRG] New API for gromov solvers #536

merged 16 commits into from
Oct 24, 2023

Conversation

rflamary
Copy link
Collaborator

@rflamary rflamary commented Oct 18, 2023

Types of changes

in this PR I propose the implementation of the new API for Gromov-Wasserstsein solvers. Similarly to ot.solve we now have a function ot.solve_gromov that can be used as follows

# Gromov Wassrerstein (GW) wity L2 loss
res = ot.solve_gromov(Ca, Cb) # uniform weights
res = ot.solve_gromov(Ca, Cb, a=a, b=b) # given weights

# GW with KL loss
res = ot.solve_gromov(Ca, Cb, loss='KL') # uniform weights

# Fused Gromov-Wassertsein (FGW)
res = ot.solve_gromov(Ca, Cb, M, alpha=0.5) # uniform weights
res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1) # given weights

# GW and FGW with entropy reg
res = ot.solve_gromov(Ca, Cb, reg=1) # GW
res = ot.solve_gromov(Ca, Cb, M, reg=1, alpha=0.5) # FGW

# Semirelaxed GW and FGW
res=ot.solve_gromov(Ca, Cb, unbalanced_type='semirelaxed') # GW
res=ot.solve_gromov(Ca, Cb, M, unbalanced_type='semirelaxed') # FGW

# partial GW
res=ot.solve_gromov(Ca, Cb,unbalanced=0.8, unbalanced_type='partial') # partial GW with m=0.8 of mass transported

#results can be obtained from OTResult() as
res.plan # OT plan
res.value # optimal loss (all objective with regularization is present)
res.value_quad # quadratic (GW) part of the loss
res.value_linear # linear (W) part of the loss for FGW

Motivation and context / Related issue

How has this been tested (if it applies)

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@rflamary rflamary changed the title [WIP] New API for gromo solvers [WIP] New API for gromov solvers Oct 18, 2023
@codecov
Copy link

codecov bot commented Oct 18, 2023

Codecov Report

Merging #536 (5d78f60) into master (57eda61) will increase coverage by 0.06%.
The diff coverage is 100.00%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #536      +/-   ##
==========================================
+ Coverage   96.31%   96.37%   +0.06%     
==========================================
  Files          67       67              
  Lines       14136    14389     +253     
==========================================
+ Hits        13615    13868     +253     
  Misses        521      521              

@rflamary rflamary changed the title [WIP] New API for gromov solvers [MRG] New API for gromov solvers Oct 20, 2023
Copy link
Collaborator

@cedricvincentcuaz cedricvincentcuaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @rflamary,
thank you for this wonderful API for gromov solvers ! ;)
I've made a few comments to help you with the final details.

T, logv = entropic_fused_gromov_wasserstein(
M, C1, C2, p, q, loss_fun, epsilon, symmetric, alpha, G0, max_iter,
tol, solver, warmstart, verbose, log=True, **kwargs)

logv['T'] = T

lin_term = nx.sum(T * M)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could avoid this overhead by defining both logv['quad_loss'] and logv['lin_loss'] directly in entropic_fused_gromov_wasserstein if log=True

ot/solvers.py Outdated Show resolved Hide resolved
ot/solvers.py Outdated Show resolved Hide resolved
ot/solvers.py Outdated Show resolved Hide resolved
ot/solvers.py Outdated Show resolved Hide resolved
ot/solvers.py Show resolved Hide resolved
ot/solvers.py Show resolved Hide resolved
ot/solvers.py Show resolved Hide resolved
ot/solvers.py Show resolved Hide resolved
ot/gromov/_gw.py Outdated Show resolved Hide resolved
@rflamary rflamary merged commit a9de7a0 into master Oct 24, 2023
15 of 16 checks passed
@rflamary rflamary deleted the api_gromov branch November 23, 2023 09:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants