Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] CO-Optimal Transport solver #447

Merged
merged 39 commits into from
Mar 22, 2023
Merged

[MRG] CO-Optimal Transport solver #447

merged 39 commits into from
Mar 22, 2023

Conversation

6Ulm
Copy link
Collaborator

@6Ulm 6Ulm commented Mar 16, 2023

Types of changes

Implementation of CO-Optimal Transport (COOT) [47] in the new file coot.py, containing two methods:

  • co_optimal_transport: returns the sample and feature couplings.
  • co_optimal_transport2: returns the COOT distance.

Both methods can handle:

  • Unregularized COOT ($\varepsilon = 0$). In this case, ot.lp.emd solver will be used.
  • Entropic regularized COOT ($\varepsilon > 0$). In this case, ot.sinkhorn solver will be used.
  • COOT with linear terms for both sample and feature couplings, in the same spirit as fused Gromov-Wasserstein (in which only linear term for sample coupling exists)

The COOT distance outputed by co_optimal_transport2 is also sub-differentiable with respect to the input matrices and marginal distributions.

Motivation and context / Related issue

CO-Optimal Transport is not yet available on POT.

How has this been tested (if it applies)

  • The implementation is tested sucessfully on toy examples. The tests are reported in test_coot.py.
  • Tests on more complex examples are reported in examples.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@codecov
Copy link

codecov bot commented Mar 16, 2023

Codecov Report

Merging #447 (23ae989) into master (b9ed7b1) will increase coverage by 0.06%.
The diff coverage is 99.08%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #447      +/-   ##
==========================================
+ Coverage   94.85%   94.92%   +0.06%     
==========================================
  Files          30       31       +1     
  Lines        6770     6879     +109     
==========================================
+ Hits         6422     6530     +108     
- Misses        348      349       +1     

Copy link
Collaborator

@agramfort agramfort left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pushed a few fixes to the Doc

@agramfort
Copy link
Collaborator

no idea what's happening with the doc here...

@rflamary
Copy link
Collaborator

CrcleCI ususally runs in 20 minutes but sometimes when it is busy it is very slow and the doc buid is canceled (here not during an example from this PR). I cannot ask for a new run but the dame doc is running and building on other PR right now....

Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @6Ulm we are nearly there, could you please upadte the tests as discussed above?

README.md Outdated
@@ -303,6 +303,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer

[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations.

[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787.
[47] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add your ref at the end and use the same number in the doc of the function

# Main function

if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]:
raise ValueError(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make tets that check that the error is raised

eps_samp, eps_feat = epsilon, epsilon
else:
if len(epsilon) != 2:
raise ValueError("Epsilon must be either a scalar or an indexable object of length 2.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again test

alpha_samp, alpha_feat = alpha, alpha
else:
if len(alpha) != 2:
raise ValueError("Alpha must be either a scalar or an indexable object of length 2.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

ny_feat, type_as=Y)) # shape nx_feat, ny_feat
else:
pi_samp, pi_feat = warmstart["pi_sample"], warmstart["pi_feature"]
duals_samp, duals_feat = warmstart["duals_sample"], warmstart["duals_feature"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also add a warmstart in one of teh tests

break

if verbose:
print(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tets also with verbose

(vx_samp, vx_feat, vy_samp, vy_feat, gradX, gradY))

if log:
return coot, dict_log
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test both

@rflamary rflamary changed the title [WIP] CO-Optimal Transport solver [MRG] CO-Optimal Transport solver Mar 21, 2023
@rflamary rflamary merged commit 897026e into PythonOT:master Mar 22, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants