-
Notifications
You must be signed in to change notification settings - Fork 506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers #431
[MRG] Semi-relaxed (fused) gromov-wasserstein divergence and improvements of gromov-wasserstein solvers #431
Conversation
…cuaz/POT into semirelaxed_gromov
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## master #431 +/- ##
==========================================
+ Coverage 94.70% 94.80% +0.09%
==========================================
Files 24 30 +6
Lines 6608 6752 +144
==========================================
+ Hits 6258 6401 +143
- Misses 350 351 +1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thank @cedricvincentcuaz
This is an awesome PR, but it still needs some work ;).
Also I would be interested in seeing a smal benchmark fo computational time for POT 0.8.2 and this PR (in order to check if this is more efficient)
…cuaz/POT into semirelaxed_gromov
…stopping criterions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, all tests pass and the refactorization works well. Will merge shortly. @cedricvincentcuaz coudl you put in teh description some computational times for GW before and after the merge?
Types of changes
Split the ot.gromov.py file into the subdirectory ot/gromov/ with new files:
ot.gromov.__init__.py
ot.gromov._utils.py
(common functions shared by gw-based solvers)ot.gromov._gw.py
(gw solvers)ot.gromov._bregman.py
(entropic gw solvers)ot.gromov._estimators.py
(gw estimators)ot.gromov._dictionary.py
(gw dictionary learning)ot.gromov._semirelaxed.py
(new semi-relaxed gw solvers)Refactoring and new functions in optim.py file:
semirelaxed_cg
: Solve the general regularized and semi-relaxed OT problem with cg.generic_conditional_gradient
: new generic cg solver fed withlp_solver
(emd, sinkhorn, semi-relaxed) as parameter; andline_search
solver as parameter. Wrapping existingcg
andgcg
solvers, plus newsemirelaxed_cg
solver.cg
,gcg
: now callgeneric_conditional_gradient
+ supportline_search
solver as parameter.solve_1d_linesearch_quad
: change solver to avoid dependency to the constant term of the quadratic function that implied an overhead + change operations to avoid type errors (e.g solve Issue An Issue with solve_1d_linesearch_quad Function #442 ).solve_linesearch
: moved to ot.gromov._gw.py as the new functionsolve_gromov_linesearch
+ factor and speed up the previous function used e.g in (f)gw solvers.Modifications of existing functions in ot.gromov.py file and moved into adequate files of the subdirectory ot/gromov/:
in ot.gromov._utils.py :
(new parameter)
init_matrix
,tensor_product
,gwloss, gwggrad
: added backendnx
parameter allowing to avoid repeated calls toot.backend.get_backend
. The parameter is set by default to None, implying a backend test.init_matrix_semirelaxed
: constant tensors for semi-relaxed (F)GW fast computation. Only support square_loss cost function for now, for now raise an issue ifkl_loss
is provided.in ot.gromov._gw.py :
gromov_wasserstein(2)
,fused_gromov_wasserstein(2)
,gromov_barycenters
,fgw_barycenters
: add symmetric (bool) parameter to handle symmetric/asymmetric structure matrices. Default is None implying symmetry tests, can be set to True/False to skip tests. Solvers are corrected to support both cases. + Correct existing feature i.e when 'kl_loss' is given, use the armijo line search function instead of exact one for 'square_loss' cost function + Add new feature to control stopping criterionmax_iter
,tol_rel
andtol_abs
past toot.optim.cg
solver.In ot.gromov._bregman.py :
entropic_gw(2)
: add symmetric (bool) parameter to handle symmetric/asymmetric structure matrices. Default is None implying symmetry tests, can be set to True/False to skip tests. Solvers are corrected to support both cases. + Correct existing feature i.e when 'kl_loss' is given, use the armijo line search function instead of exact one for 'square_loss' cost function. + Add note in the doc on constraint feasibility issues related to Issue Negative Gromov-Wasserstein distance #406In ot.gromov._dictionary.py :
gromov_wasserstein_dictionary_learning
,fused_gromov_wasserstein_dictionary_learning
: adapted to support last new feature of (f)gw solvers. symmetric parameter of these solvers is deduced from the projection parameter of DL solvers.In ot.gromov._semirelaxed.py :
semirelaxed_gromov_wasserstein(2)
,semirelaxed_fused_gromov_wasserstein(2)
: cg solvers for the semi-relaxed (F)GW problem. Armijo line search is left aside for these solvers.solve_semirelaxed_gromov_linesearch
: line search for semi-relaxed (fused) gromov-wasserstein (new) solvers.Motivation and context / Related issue
gromov_wasserstein
between POT 0.8.2 and this PR, we match 50 pairs of graphs, with random euclidean distance matrices as structures, and the same number of nodes varying in {10, 50, 100, 250, 500}. The averaged runtimes for each pair of graphs are reported in the next Table (computed on a Intel(R) Core(TM) i7-4510U CPU @ 2.00GHz):In these settings, the new version of the
gromov_wasserstein
function goes 17.6% to 29.6% faster than its previous version.PR checklist