Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/ulrgw #410

Merged
merged 49 commits into from
Sep 8, 2023
Merged

Feature/ulrgw #410

merged 49 commits into from
Sep 8, 2023

Conversation

michalk8
Copy link
Collaborator

@michalk8 michalk8 commented Aug 3, 2023

Add unbalanced low-rank (fused) GW.

TODOs:

  • verify the impl.
  • check initialization
  • fix cost value in the fused case
  • [ ] decide on how to handle/whether to deprecate LR option in GW (would say yes) will be done in a future PR
  • update the docs (e.g., remove mentions of LRSinkhorn)
  • fix primal_cost
  • check balanced case with the new LRGromovWasserstein class
  • update notebooks
  • tests

@michalk8 michalk8 added the enhancement New feature or request label Aug 3, 2023
@michalk8 michalk8 self-assigned this Aug 3, 2023
@codecov
Copy link

codecov bot commented Aug 3, 2023

Codecov Report

Merging #410 (f04757f) into main (21d3627) will decrease coverage by 0.93%.
The diff coverage is 83.33%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #410      +/-   ##
==========================================
- Coverage   91.40%   90.48%   -0.93%     
==========================================
  Files          55       56       +1     
  Lines        5943     6251     +308     
  Branches      866      887      +21     
==========================================
+ Hits         5432     5656     +224     
- Misses        375      453      +78     
- Partials      136      142       +6     
Files Changed Coverage Δ
src/ott/initializers/linear/initializers_lr.py 91.30% <40.00%> (-1.08%) ⬇️
src/ott/solvers/quadratic/gromov_wasserstein_lr.py 81.47% <81.47%> (ø)
src/ott/initializers/quadratic/initializers.py 88.88% <100.00%> (+2.00%) ⬆️
src/ott/math/utils.py 94.54% <100.00%> (ø)
src/ott/solvers/linear/lr_utils.py 100.00% <100.00%> (ø)
src/ott/solvers/linear/sinkhorn_lr.py 97.33% <100.00%> (-1.04%) ⬇️
src/ott/solvers/quadratic/gromov_wasserstein.py 84.32% <100.00%> (-3.18%) ⬇️

... and 2 files with indirect coverage changes

📢 Have feedback on the report? [Share it here](https://about.codecov.io/codecov-pr-comment-feedback/?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=ott-jax).

Copy link
Contributor

@marcocuturi marcocuturi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Michal! let's chat :)

src/ott/solvers/quadratic/gromov_wasserstein_lr.py Outdated Show resolved Hide resolved
src/ott/solvers/quadratic/gromov_wasserstein_lr.py Outdated Show resolved Hide resolved
src/ott/solvers/quadratic/gromov_wasserstein_lr.py Outdated Show resolved Hide resolved
src/ott/solvers/quadratic/gromov_wasserstein_lr.py Outdated Show resolved Hide resolved
src/ott/solvers/quadratic/gromov_wasserstein_lr.py Outdated Show resolved Hide resolved


@jax.tree_util.register_pytree_node_class
class LRGromovWasserstein(sinkhorn.Sinkhorn):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this supposed to inherite from WasSolver? It's not clear either why there should be a link to Sinkhorn (I would understand LRSinkhorn)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, I think it makes sense, since if we were to use WasSolver with LRSinkhornSolver, we'd have:

  • 3 loops (outermost GW loop, outer loop in LR Sinkhorn, Dykstra loop)
  • LRSinkhornSolver would have to be modified to accept a quadratic problem (or have a specific inner solver for unbalanced LR GW, which is more complicated than this)

We definitely need a abstraction for the solvers, this is just temporary; will come up with a better solution to this in the future.

I also think that our current LR GW (balanced) is a bit wrong when it comes to the convergence criterion (we exit the outermost GW loop when the successive costs are close; in the paper it should exist when the errors between succcessive Q/R/g are close [the convergence criterion for the linearized objective]).
This will be fixed in a future PR once the LR GW (balanced and unbalanced) are unified in this class.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conceptually, I am not sure exactly what we're inheriting from Sinkhorn? It seems Sinkhorn is used as a simpler type of WassersteinSolver?

src/ott/solvers/quadratic/gromov_wasserstein_lr.py Outdated Show resolved Hide resolved
@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@@ -1,7 +1,6 @@
{
Copy link
Contributor

@marcocuturi marcocuturi Sep 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you are using the primal cost, then it's better to rename the titles to "Primal cost of LR Solution" and " ... of Entropic Solution" to avoid the ambiguity. I am also surprised that the plot of low rank is looks so much like a rank 1 or 2 matrix, can you double check?


Reply via ReviewNB

Copy link
Contributor

@marcocuturi marcocuturi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This LGTM, but I'd like to chat about the inheritance of LRGromovWasserstein because I am not sure where this is headed! thanks!



@jax.tree_util.register_pytree_node_class
class LRGromovWasserstein(sinkhorn.Sinkhorn):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conceptually, I am not sure exactly what we're inheriting from Sinkhorn? It seems Sinkhorn is used as a simpler type of WassersteinSolver?

@michalk8
Copy link
Collaborator Author

michalk8 commented Sep 5, 2023

This LGTM, but I'd like to chat about the inheritance of LRGromovWasserstein because I am not sure where this is headed! thanks!

Agreed, we sorely need a better solver hierarchy, let's talk about it! Will just write some tests and merge.

@michalk8 michalk8 merged commit a18c16c into ott-jax:main Sep 8, 2023
11 checks passed
@michalk8 michalk8 deleted the feature/ulrgw branch September 8, 2023 07:46
michalk8 added a commit that referenced this pull request Jun 27, 2024
* Remove low-rank from GromovWasserstein solver

* First skeleton loop

* Add LRGW implementation

* Add ULFGW

* Revert change

* Add a TODO

* Fix `grad_g` in the fused case

* Update docs

* Remove duplicate citation

* Fix cost for the fused case

* Fix bugs in TI

* Remove unused import

* Change way array extraction in LR init works

* Disallow LR in the old GW solver

* Disallow LR in old GW class

* Remove `is_entropic` property

* Use `jnp.linalg.norm`

* Simplify initializers in GW

* Simplify initializer creation for low-rank

* Remove temporary name

* Fix norms

* Fix linkcheck

* Remove old initializers test

* Fix more initializer tests

* Remove `LRQuadraticInitializer`, `reg_ot_cost -> reg_gw_cost`

* `host_callback` -> `io_callback`

* Fix more initializers tests

* Fix more tests

* Remove initializer mention from the docs

* Remove mention of LR initializer

* Start incorporating GWLoss

* Simplify reg GW cost computation

* Finish `primal_cost`

* Don't calculate unbal. grads in balanced case

* Fix `primal_cost` in balanced case

* Update GW LR notebook

* Convert quad problem to LR if possible

* Convert quad problem to LR if possible

* Regenerate GWLR Sinkhorn

* Regenerate `LRSinkhorn`

* [ci skip] Fix linter

* Fix convergence metric

* Undo TODO

* Fix factor

* Regenerate notebooks

* Add tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants