Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] add the sparsity-constrained optimal transport funtionality and example #459

Merged
merged 14 commits into from
Apr 25, 2023

Conversation

liutianlin0121
Copy link
Contributor

@liutianlin0121 liutianlin0121 commented Apr 14, 2023

Types of changes

Add a new optimal transport functionality and an example.

Motivation and context / Related issue

"Sparsity-constrained optimal transport" is a variant of the optimal transport problem that provides direct control over the number of non-zero values allowed in the optimal plan. This formulation may be of interest to users of POT. For more information, please refer to the following paper: https://openreview.net/forum?id=yHY9NbQJ5BP.

How has this been tested (if it applies)

The tests check that the marginal constraints of the optimal plans are approximately satisfied, and that the sparsity constraints are satisfied.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@liutianlin0121
Copy link
Contributor Author

Question: In the ot/smooth.py module, it seems that only the dual formulation smooth_ot_dual supports multiple backends while the semi-dual formulation smooth_ot_semi_dual does not. I follow this convention in my pull request---only the dual formulation in ot/sparse.py calls get_backend. However, is there an obstacle to making the semi-dual formulation compatible with different backends too?

ot/sparse.py Outdated Show resolved Hide resolved
test/test_sparse.py Outdated Show resolved Hide resolved
ot/sparse.py Outdated Show resolved Hide resolved
@liutianlin0121
Copy link
Contributor Author

Thanks! I updated the pull request.

@codecov
Copy link

codecov bot commented Apr 18, 2023

Codecov Report

Merging #459 (e8bb4e0) into master (2bbfbbb) will decrease coverage by 0.02%.
The diff coverage is 93.65%.

❗ Current head e8bb4e0 differs from pull request most recent head 11e07aa. Consider uploading reports for the commit 11e07aa to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #459      +/-   ##
==========================================
- Coverage   94.92%   94.91%   -0.02%     
==========================================
  Files          31       32       +1     
  Lines        6879     6942      +63     
==========================================
+ Hits         6530     6589      +59     
- Misses        349      353       +4     

@rflamary
Copy link
Collaborator

Hello thanks for the PR, we will do a code reveiw ASAP!

Copy link

@mblondel mblondel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Tian Lin! Two other comments below.

ot/sparse.py Outdated Show resolved Hide resolved
ot/sparse.py Outdated Show resolved Hide resolved
@rflamary
Copy link
Collaborator

Also small comment, since you are using functions from ot.smooth and your paper is clearly related I would put the new solvers in ot.smooth and the function projection_sparse_simplex in ot.utils (we might need it somewhere else ;) ).

@liutianlin0121
Copy link
Contributor Author

Hey both, thanks! I incorporated the suggestions.

Note that I only added the check_grad test for the SparsityConstrained regularization. Similar grad checks can be added to other regularization classes, like NegEntropy and SquaredL2. But to do that for SquaredL2, I think I'll need to modify the function projection_simplex so that it can accept 1-dim array. The reason is that the objective function used by check_grad has to be a scalar-valued one. This means that we need to supply a one-dim array X as the input of max_Omega, which currently raises an error because projection_simplex only works for 2-dim arrays. If you think it is necessary, I can modify projection_simplex in a way similar to projection_sparse_simplex and then add gradient checks on SquaredL2.

ot/smooth.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hello @liutianlin0121 ,

Thanks so much for the PR. I still have a few comments before merging but we are neraly there.

examples/plot_OT_1D_smooth.py Show resolved Hide resolved
ot/smooth.py Show resolved Hide resolved
ot/smooth.py Outdated Show resolved Hide resolved
ot/smooth.py Outdated Show resolved Hide resolved
ot/utils.py Show resolved Hide resolved
ot/utils.py Show resolved Hide resolved
Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. @liutianlin0121 could you please add you PR to the RELEASES.md file with a short description of the new feature? I thank after that we can merge.

@rflamary rflamary merged commit 42a62c1 into PythonOT:master Apr 25, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants