-
Notifications
You must be signed in to change notification settings - Fork 506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] CO-Optimal Transport solver #447
Conversation
This reverts commit f3d36b2.
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## master #447 +/- ##
==========================================
+ Coverage 94.85% 94.92% +0.06%
==========================================
Files 30 31 +1
Lines 6770 6879 +109
==========================================
+ Hits 6422 6530 +108
- Misses 348 349 +1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pushed a few fixes to the Doc
no idea what's happening with the doc here... |
CrcleCI ususally runs in 20 minutes but sometimes when it is busy it is very slow and the doc buid is canceled (here not during an example from this PR). I cannot ask for a new run but the dame doc is running and building on other PR right now.... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello @6Ulm we are nearly there, could you please upadte the tests as discussed above?
README.md
Outdated
@@ -303,6 +303,8 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer | |||
|
|||
[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. | |||
|
|||
[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787. | |||
[47] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add your ref at the end and use the same number in the doc of the function
# Main function | ||
|
||
if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]: | ||
raise ValueError( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make tets that check that the error is raised
eps_samp, eps_feat = epsilon, epsilon | ||
else: | ||
if len(epsilon) != 2: | ||
raise ValueError("Epsilon must be either a scalar or an indexable object of length 2.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
again test
alpha_samp, alpha_feat = alpha, alpha | ||
else: | ||
if len(alpha) != 2: | ||
raise ValueError("Alpha must be either a scalar or an indexable object of length 2.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
ny_feat, type_as=Y)) # shape nx_feat, ny_feat | ||
else: | ||
pi_samp, pi_feat = warmstart["pi_sample"], warmstart["pi_feature"] | ||
duals_samp, duals_feat = warmstart["duals_sample"], warmstart["duals_feature"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also add a warmstart in one of teh tests
break | ||
|
||
if verbose: | ||
print( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tets also with verbose
(vx_samp, vx_feat, vy_samp, vy_feat, gradX, gradY)) | ||
|
||
if log: | ||
return coot, dict_log |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test both
Types of changes
Implementation of CO-Optimal Transport (COOT) [47] in the new file
coot.py
, containing two methods:co_optimal_transport
: returns the sample and feature couplings.co_optimal_transport2
: returns the COOT distance.Both methods can handle:
ot.lp.emd
solver will be used.ot.sinkhorn
solver will be used.The COOT distance outputed by
co_optimal_transport2
is also sub-differentiable with respect to the input matrices and marginal distributions.Motivation and context / Related issue
CO-Optimal Transport is not yet available on POT.
How has this been tested (if it applies)
test_coot.py
.PR checklist