Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers #431

Merged
merged 39 commits into from
Mar 9, 2023

Conversation

cedricvincentcuaz
Copy link
Collaborator

@cedricvincentcuaz cedricvincentcuaz commented Feb 2, 2023

Types of changes

Split the ot.gromov.py file into the subdirectory ot/gromov/ with new files:

  • ot.gromov.__init__.py
  • ot.gromov._utils.py (common functions shared by gw-based solvers)
  • ot.gromov._gw.py (gw solvers)
  • ot.gromov._bregman.py (entropic gw solvers)
  • ot.gromov._estimators.py (gw estimators)
  • ot.gromov._dictionary.py (gw dictionary learning)
  • ot.gromov._semirelaxed.py (new semi-relaxed gw solvers)

Refactoring and new functions in optim.py file:

  • semirelaxed_cg : Solve the general regularized and semi-relaxed OT problem with cg.
  • generic_conditional_gradient: new generic cg solver fed with lp_solver (emd, sinkhorn, semi-relaxed) as parameter; and line_search solver as parameter. Wrapping existing cg and gcg solvers, plus new semirelaxed_cg solver.
  • cg, gcg : now call generic_conditional_gradient + support line_search solver as parameter.
  • solve_1d_linesearch_quad: change solver to avoid dependency to the constant term of the quadratic function that implied an overhead + change operations to avoid type errors (e.g solve Issue An Issue with solve_1d_linesearch_quad Function #442 ).
  • (moved and renamed) solve_linesearch : moved to ot.gromov._gw.py as the new function solve_gromov_linesearch+ factor and speed up the previous function used e.g in (f)gw solvers.

Modifications of existing functions in ot.gromov.py file and moved into adequate files of the subdirectory ot/gromov/:

in ot.gromov._utils.py :

  • (new parameter) init_matrix, tensor_product, gwloss, gwggrad : added backend nx parameter allowing to avoid repeated calls to ot.backend.get_backend. The parameter is set by default to None, implying a backend test.

  • init_matrix_semirelaxed : constant tensors for semi-relaxed (F)GW fast computation. Only support square_loss cost function for now, for now raise an issue if kl_loss is provided.

in ot.gromov._gw.py :

  • gromov_wasserstein(2), fused_gromov_wasserstein(2), gromov_barycenters, fgw_barycenters : add symmetric (bool) parameter to handle symmetric/asymmetric structure matrices. Default is None implying symmetry tests, can be set to True/False to skip tests. Solvers are corrected to support both cases. + Correct existing feature i.e when 'kl_loss' is given, use the armijo line search function instead of exact one for 'square_loss' cost function + Add new feature to control stopping criterion max_iter, tol_rel and tol_abs past to ot.optim.cg solver.

In ot.gromov._bregman.py :

  • entropic_gw(2) : add symmetric (bool) parameter to handle symmetric/asymmetric structure matrices. Default is None implying symmetry tests, can be set to True/False to skip tests. Solvers are corrected to support both cases. + Correct existing feature i.e when 'kl_loss' is given, use the armijo line search function instead of exact one for 'square_loss' cost function. + Add note in the doc on constraint feasibility issues related to Issue Negative Gromov-Wasserstein distance #406

In ot.gromov._dictionary.py :

  • gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_dictionary_learning: adapted to support last new feature of (f)gw solvers. symmetric parameter of these solvers is deduced from the projection parameter of DL solvers.

In ot.gromov._semirelaxed.py :

  • semirelaxed_gromov_wasserstein(2), semirelaxed_fused_gromov_wasserstein(2): cg solvers for the semi-relaxed (F)GW problem. Armijo line search is left aside for these solvers.
  • solve_semirelaxed_gromov_linesearch: line search for semi-relaxed (fused) gromov-wasserstein (new) solvers.

Motivation and context / Related issue

  • Checked that existing (F)GW-based tests still worked.
  • Add backend parameter tests.
  • Add symmetry/ asymmetry tests.
  • Add tests for semi-relaxed (F)GW problems.
  • Speed up (f)gw solvers: To perform a small benchmark of gromov_wasserstein between POT 0.8.2 and this PR, we match 50 pairs of graphs, with random euclidean distance matrices as structures, and the same number of nodes varying in {10, 50, 100, 250, 500}. The averaged runtimes for each pair of graphs are reported in the next Table (computed on a Intel(R) Core(TM) i7-4510U CPU @ 2.00GHz):
graph sizes 10 50 100 250 500
new cg (ms) 0.8 2.6 8.9 96 1338
old cg (ms) 1.1 3.4 10.8 124 1901
|new-old|/old (%) 27.3 23.5 17.6 22.6 29.6

In these settings, the new version of the gromov_wasserstein function goes 17.6% to 29.6% faster than its previous version.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@codecov
Copy link

codecov bot commented Feb 6, 2023

Codecov Report

Merging #431 (fb86e46) into master (263a36f) will increase coverage by 0.09%.
The diff coverage is 95.99%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #431      +/-   ##
==========================================
+ Coverage   94.70%   94.80%   +0.09%     
==========================================
  Files          24       30       +6     
  Lines        6608     6752     +144     
==========================================
+ Hits         6258     6401     +143     
- Misses        350      351       +1     

@rflamary rflamary changed the title Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers [WIP] Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers Feb 15, 2023
Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank @cedricvincentcuaz

This is an awesome PR, but it still needs some work ;).

Also I would be interested in seeing a smal benchmark fo computational time for POT 0.8.2 and this PR (in order to check if this is more efficient)

examples/gromov/plot_semirelaxed_fgw.py Outdated Show resolved Hide resolved
ot/optim.py Outdated Show resolved Hide resolved
ot/optim.py Outdated Show resolved Hide resolved
ot/optim.py Outdated Show resolved Hide resolved
ot/optim.py Outdated Show resolved Hide resolved
ot/gromov.py Outdated Show resolved Hide resolved
ot/gromov.py Outdated Show resolved Hide resolved
ot/gromov.py Outdated Show resolved Hide resolved
ot/gromov.py Outdated Show resolved Hide resolved
ot/gromov.py Outdated Show resolved Hide resolved
@rflamary rflamary changed the title [WIP] Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers [MRG] Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers Mar 9, 2023
Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, all tests pass and the refactorization works well. Will merge shortly. @cedricvincentcuaz coudl you put in teh description some computational times for GW before and after the merge?

@rflamary rflamary merged commit a5930d3 into PythonOT:master Mar 9, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants