-
Notifications
You must be signed in to change notification settings - Fork 508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Efficient Discrete Multi Marginal Optimal Transport #454
Conversation
… ot, build failed need to fix
Thanks for the PR. We will do a code reveiw s soon as possible. |
Thank you for considering! |
Dear POT Team, I hope you're well. Just a quick reminder about our group's pull request (#454) from two weeks ago. I understand the team's been busy, but we'd appreciate your feedback when you have a moment. If you need clarification or have questions, feel free to reach out. We are eager to make any necessary adjustments! Best regards, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello and thanks for the PR,
I did a quick code review. The PR needs major changes so that I can better understand what the new solvers are doing (you cannot ask people to read the paper to use a function, it should be described in the documentation). It is a bit unclear to me if you provide an EMD, multimarginal solver, a barycenter estimator (especially since you function takes only one array as input) so please clarify this in the code and maybe the PR description.
What needs to be clarified and changed:
- Move all solvers to a submodule better named more clear such as
ot.lp.discrete_emd
. - Please respect POT API: use M for ground loss, a,A for distributions (on the simplex) and all parameters names. All function names can be long but must describe precisely what the function does. For instance if a function computes a barycenter then it should be similarly named to other solvers and be called the same way.
- Add to the documentation of all function, the optimization problems solved in math environment using same names as other POT functions.
- Take into account all the small comments below.
I know this is a lot of work but we need to ensure that the code fits well in POT and is easy to use and understand by users/conributors.
docs/source/all.rst
Outdated
@@ -19,6 +19,7 @@ API and modules | |||
coot | |||
da | |||
datasets | |||
demd |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not OK with demd
wich is too short and not clear enough for a module name. This is also the case with the main functions, not everyone has reda the paper and it sould be clear from the function name what it is doing.
Due to historical reasons ot.emd
is the exact disrete Ot solver thaat is very general and not only for EMD so we need to find another name for the new solvers in this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new solvers should be in ot.lp.discrete_emd
or something else more descriptive
examples/others/plot_demd_1d.py
Outdated
|
||
|
||
def lp_1d_bary(data, M, n, d): | ||
A = np.vstack(data).T |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the example should not need to transpose the data. it means that the API for the implemented function is not good (it should retrun the smae thing as ot.lp.barycenter)
examples/others/plot_demd_1d.py
Outdated
print('') | ||
print('D-EMD Algorithm:') | ||
ot.tic() | ||
demd_obj = ot.demd(np.vstack(data), n, d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where is the barycener? Objective is nice but the barycenetr shoudl be rtruned
examples/others/plot_demd_1d.py
Outdated
return ns, lp_times, demd_times | ||
|
||
|
||
ns, lp_times, demd_times = increasing_bins() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why only plot thr time? pleade also plot the barycenter.
|
||
# data, M = getData(n, d, 'uniform') | ||
data, M = getData(n, d, 'skewedGauss') | ||
data = np.vstack(data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Plot the data
ot/demd.py
Outdated
|
||
dualobj = sum([_.dot(_d) for _, _d in zip(aa, dual)]) | ||
|
||
return {'x': xx, 'primal objective': obj, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the output of the function should be of the same type as the input of the function
ot/demd.py
Outdated
except Exception: | ||
pass | ||
|
||
dualobj = sum([_.dot(_d) for _, _d in zip(aa, dual)]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should use
ot/demd.py
Outdated
|
||
def demd(x, d, n, return_dual_vars=False): | ||
r""" | ||
Solver of our proposed method: d−Dimensional Earch Mover’s Distance (DEMD). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
desctibe mroe precisely what you are solkving. is it a barycenter?
ot/demd.py
Outdated
'dual': dual, 'dual objective': dualobj} | ||
|
||
|
||
def demd(x, d, n, return_dual_vars=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bad naming, at least discrete_emd
or discrete_emd2
if teh function return emd without the plan . if you give the function an empirical distribution then you shoud also put it in the name.
Finally emd is computed beteen twoi distibutions so why is there only on naumpy arrya here?
ot/demd.py
Outdated
`f(x, d, n, return_dual_vars=True) -> (float, ndarray, ...)` | ||
x : ndarray, shape (d, n) | ||
The initial point for the optimization algorithm. | ||
d : int |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
those shapes can be infered from x they should not be passed as parameters
Dear POT team, Thank you for thoughtful and detailed code review. We apologize for the conflicts and problems right now. We will resolve them one by one and make sure it fits well in POT as well as contributors. We appreciate your time and effort. Thank you. Best |
Dear POT team, I hope this message finds you well. I have recently completed work on the bug fix for workflow checks, and I believe it is now ready for review. I understand everyone has busy schedules, but I would appreciate if you could review this pull request and approve a workflow check at your convenience. Please feel free to raise any questions or concerns. Thank you in advance for your time and patient. Best regards, |
We will do that, thanks for your work it is appreciated. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello,
Thanks for all the work and the major modifications in your code, it is now in my opinion in a much better shape and we are approaching a merge.
Still I have a few important modifications now that I better understand what your proposed solvers actually do. I know it can get tiresome but POT is becoming very large and we need to be very careful with positioning new API and solvers.
My understanding is that they are MMOT solvers for marginal distribution that have a support on a regular 1D grid and using a specific ground metric (max-min that has monge property) that corresponds to the absolute value loss/EMD ground metric with two marginals.
This is very interesting and I want it in POT but it also means that we cannot use the very general names of the method that you chose in your published paper because it would lead to much confusion for POT users. So I proposed more precise names below that describe I think better what is solved and done. I provide more specific comments below (for instance having mathematical object follow the name of function parameters).
On a more scientific question I am surprised by the shape or your "barycenter" that is actually from my understanding a convergence point for the minimization of your MMOT formulation. It seems to be very similar to L2 (average) barycenters especially compared to the LP solution. Do you have an intuition why?
examples/others/plot_d-mmot.py
Outdated
# Compare Barycenters in both methods | ||
# --------- | ||
pl.figure(1, figsize=(6.4, 3)) | ||
for i in range(len(barys)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you shoumd compare it to the l2 (np.mean) barycenter because your barycenter looks very similar
examples/others/plot_d-mmot.py
Outdated
# dmmot_obj, log = ot.lp.discrete_mmot(A.T, n, d) | ||
barys, log = ot.lp.discrete_mmot_converge( | ||
A, niters=3000, lr=0.000002, log=True) | ||
dmmot_obj = log['primal objective'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
both the objective value ad the norm of the graient increase at the end which is very surprising since it is supposed to be a gardient decsnet no?
examples/others/plot_d-mmot.py
Outdated
# values cannot be compared. | ||
|
||
# Perform gradient descent optimization using the d-MMOT method. | ||
barys = ot.lp.discrete_mmot_converge(A, niters=9000, lr=0.00001) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
ot/lp/dmmot.py
Outdated
Parameters | ||
---------- | ||
i : list | ||
The list for which the generalized EMD cost is to be computed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
list of integer indexes...
ot/lp/dmmot.py
Outdated
from ..backend import get_backend | ||
|
||
|
||
def dist_monge(i): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dist_monge is a very generic for a very specific ground cost that indeed has monge property. This one MMOT ground cost with monge property not the only one. dist_monge_max_min
for instanec is better
ot/lp/dmmot.py
Outdated
return obj | ||
|
||
|
||
def discrete_mmot_converge( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here I would use monge_mmot_1dgrid_optimize
to stet clearly what the function does. i also need more discussion about why one would optimize all distributions together, you use is as some kind of "barycenter" since they all converge to a given distribution but be clear that it is not a barycenetr in the traditional OT sens.
test/test_dmmot.py
Outdated
return A.T, x | ||
|
||
|
||
def test_discrete_mmot(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since you implemented the function with backends, you should run the tests on arrays from other backends, you can do that by adding a parameter nx to the test function that will be automatically run with all available backends.
examples/others/plot_d-mmot.py
Outdated
pl.figure(1, figsize=(6.4, 3)) | ||
for i in range(len(barys)): | ||
if i == 0: | ||
pl.plot(x, barys[i], 'g-*', label='Discrete MMOT') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it would be nice to see visually if you converged by plotting all the individual distributions (seems like you did because your "barycenter" ). maybe you could call it "Monge MMOT minimization" instead of discrete MMOT?
|
||
def discrete_mmot(A, verbose=False, log=False): | ||
r""" | ||
Compute the discrete multi-marginal optimal transport of distributions A. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should explain clearly ere that you suppose that the support of the distributions are supposed integers on the real line, this will be suggested by the new function name but it needs to be stated clearly.
test/test_dmmot.py
Outdated
return A.T, x | ||
|
||
|
||
def test_discrete_mmot(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to have a test comparing the loss returned by your solver with two marginals et the exact OT solver with absolute ground metric since they should be equivalent no?.
Dear POT team, I hope you are doing well! This pull request update contains changes and improvements based on the previous review. We appreciate your previous feedback, which guided these revisions. We look forward to your feedback on these updates. Best Regards, In summary, here are the key points of the update:
Comments resolved below: Q) both the objective value ad the norm of the gradient increase at the end which is very surprising since it is supposed to be a gradient descent no?
Q) dist_monge is a very generic for a very specific ground cost that indeed has monge property. This one MMOT ground cost with monge property not the only one. dist_monge_max_min for instanec is better Q) again this function name is too general, frm the name it looks like a general MMOT solver when in practice it applies only on regular 1D grids (for the marginals) and with a very specific ground metric. I suggest Q) state the ground cost instead of using "generalized Monge" that requires to read your paper. I understand the OT plan is indeednet of which Monge cost but you return the loss for a fixed one. Q) use alpha instead of p to maje the link with the input A of the function Q) Do not use x for the OT plan, we use either or a bold matrix T for plan in POT. x is already used for support position (that are integers i in your ase) and we need unified API/notations Q) here it would be nice to retrun a loss withe the gtradeints defined properly so that it can be used in pytorch with standard gradient decsnet algorithms. To do that you can use the backend function set_gradients that define the forward/backward relations . An example of its use can be fnd here: Q) same here I would use monge_mmot_1dgrid_optimizeto stet clearly what the function does. i also need more discussion about why one would optimize all distributions together, you use is as some kind of "barycenter" since they all converge to a given distribution but be clear that it is not a barycenetr in the traditional OT sens. Q) since you implemented the function with backends, you should run the tests on arrays from other backends, you can do that by adding a parameter nx to the test function that will be automatically run with all available backends. Q) it would be nice to see visually if you converged by plotting all the individual distributions (seems like you did because your "barycenter" ). maybe you could call it "Monge MMOT minimization" instead of discrete MMOT? Q) you should explain clearly ere that you suppose that the support of the distributions are supposed integers on the real line, this will be suggested by the new function name but it needs to be stated clearly. Q) It would be nice to have a test comparing the loss returned by your solver with two marginals et the exact OT solver with absolute ground metric since they should be equivalent no?. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @x12hengyu for all those changes. I know it was a lot of work but it looks good and more maintainable in the future.
I think the contribution is nearly there but there remains a small problem with the backend line where the gradient is set (see below).
Once this is done we can merge , Good work
I will wait for this and do a new release so this should be shortly available in the stable version.
ot/lp/dmmot.py
Outdated
'dual objective': dualobj} | ||
|
||
# define forward/backward relations for pytorch | ||
obj = nx.set_gradients(obj, (nx.from_numpy(A)), (dual)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here you need to use the A in input of the function (not a conversion from numpy) so that torch makes the link between this A and the objective. For instance store A0=A at the begining of the function and the use A0 here in set_gradient.
This will be very nice because your loss will be usable and differentiable with .backward()
in torch, opening the door to stochasti optimization and deep learning applications.
Dear POT team, I hope you are doing well! I have fixed the gradient problem. Thanks for your time and patient in reviewing our work in past months. We really appreciate! Best Regards, |
Store input variable instead of copying it
Types of changes
This introduce DEMD modules, a file
demd.py
contains all the vanilla modules of Efficient Discrete Multi Marginal Optimal Transport Regularization.Also includes two examples,
examples/others/plot_demd_1d.py
andexamples/others/plot_demd_gradient_minimize.py
Motivation and context / Related issue
Add new methods for Efficient Discrete Multi Marginal Optimal Transport Regularization, paper on ICLR 2023.
How has this been tested (if it applies)
Example
plot_demd_1d.py
uses two 1d Gaussian Distribution as test data and compares computing time with LP.Example
plot_demd_gradient_minimize.py
compares loss between demd minimizes using gradient decent and lp method.Also tested three functions separately in
test_demd.py
for sanity checks.PR checklist