
[MRG] Efficient Discrete Multi Marginal Optimal Transport #454

Merged: 41 commits into PythonOT:master on Aug 3, 2023

Conversation

@xzyu02 (Contributor) commented Apr 8, 2023

Types of changes

This PR introduces the DEMD module: a file demd.py containing all the vanilla modules for Efficient Discrete Multi-Marginal Optimal Transport regularization.
It also includes two examples, examples/others/plot_demd_1d.py and examples/others/plot_demd_gradient_minimize.py.

Motivation and context / Related issue

Adds new methods for Efficient Discrete Multi-Marginal Optimal Transport regularization, based on a paper published at ICLR 2023.

How has this been tested (if it applies)

The example plot_demd_1d.py uses two 1D Gaussian distributions as test data and compares computation time with the LP solver.
The example plot_demd_gradient_minimize.py compares the loss between DEMD minimization using gradient descent and the LP method.
The three functions are also tested separately in test_demd.py for sanity checks.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@xzyu02 xzyu02 changed the title Efficient Discrete Multi Marginal Optimal Transport Regularization [MRG] Efficient Discrete Multi Marginal Optimal Transport Regularization Apr 8, 2023
@rflamary (Collaborator) commented:

Thanks for the PR. We will do a code review as soon as possible.

@xzyu02 (Contributor, Author) commented Apr 12, 2023

> Thanks for the PR. We will do a code review as soon as possible.

Thank you for considering!

@xzyu02 (Contributor, Author) commented Apr 21, 2023

> Thanks for the PR. We will do a code review as soon as possible.

Dear POT Team,

I hope you're well. Just a quick reminder about our group's pull request (#454) from two weeks ago. I understand the team's been busy, but we'd appreciate your feedback when you have a moment. If you need clarification or have questions, feel free to reach out. We are eager to make any necessary adjustments!

Best regards,
Xizheng Yu

@rflamary rflamary changed the title [MRG] Efficient Discrete Multi Marginal Optimal Transport Regularization [WIP] Efficient Discrete Multi Marginal Optimal Transport Regularization Apr 24, 2023
@rflamary (Collaborator) left a comment:

Hello and thanks for the PR,

I did a quick code review. The PR needs major changes so that I can better understand what the new solvers are doing (you cannot ask people to read the paper to use a function; it should be described in the documentation). It is a bit unclear to me whether you provide an EMD, a multi-marginal solver, or a barycenter estimator (especially since your function takes only one array as input), so please clarify this in the code and maybe in the PR description.

What needs to be clarified and changed:

  • Move all solvers to a submodule with a clearer name, such as ot.lp.discrete_emd.
  • Please respect the POT API: use M for the ground loss, a/A for distributions (on the simplex), and the standard parameter names. Function names can be long but must describe precisely what the function does. For instance, if a function computes a barycenter then it should be named similarly to the other solvers and be called the same way.
  • Add to the documentation of all functions the optimization problems solved, in a math environment, using the same notation as other POT functions.
  • Take into account all the small comments below.

I know this is a lot of work, but we need to ensure that the code fits well into POT and is easy to use and understand by users/contributors.

@@ -19,6 +19,7 @@ API and modules
coot
da
datasets
demd
Collaborator:

I'm not OK with demd, which is too short and not clear enough for a module name. This is also the case for the main functions: not everyone has read the paper, and it should be clear from the function name what it is doing.

Due to historical reasons, ot.emd is the exact discrete OT solver, which is very general and not only for EMD, so we need to find another name for the new solvers in this PR.

Collaborator:

The new solvers should be in ot.lp.discrete_emd or something else more descriptive



def lp_1d_bary(data, M, n, d):
    A = np.vstack(data).T
Collaborator:

The example should not need to transpose the data; it means that the API of the implemented function is not good (it should return the same thing as ot.lp.barycenter).

print('')
print('D-EMD Algorithm:')
ot.tic()
demd_obj = ot.demd(np.vstack(data), n, d)
Collaborator:

Where is the barycenter? The objective is nice, but the barycenter should be returned.

return ns, lp_times, demd_times


ns, lp_times, demd_times = increasing_bins()
Collaborator:

Why only plot the time? Please also plot the barycenter.


# data, M = getData(n, d, 'uniform')
data, M = getData(n, d, 'skewedGauss')
data = np.vstack(data)
Collaborator:

Plot the data

ot/demd.py Outdated

dualobj = sum([_.dot(_d) for _, _d in zip(aa, dual)])

return {'x': xx, 'primal objective': obj,
Collaborator:

The output of the function should be of the same type as the input of the function.

ot/demd.py Outdated
except Exception:
    pass

dualobj = sum([_.dot(_d) for _, _d in zip(aa, dual)])
Collaborator:

You should use nx.sum and nx.dot to ensure that it will work across backends.

ot/demd.py Outdated

def demd(x, d, n, return_dual_vars=False):
    r"""
    Solver of our proposed method: d-Dimensional Earth Mover's Distance (DEMD).
Collaborator:

Describe more precisely what you are solving. Is it a barycenter?

ot/demd.py Outdated
'dual': dual, 'dual objective': dualobj}


def demd(x, d, n, return_dual_vars=False):
Collaborator:

Bad naming: use at least discrete_emd, or discrete_emd2 if the function returns the EMD without the plan. If you give the function an empirical distribution, then you should also put that in the name.

Finally, EMD is computed between two distributions, so why is there only one numpy array here?

ot/demd.py Outdated
`f(x, d, n, return_dual_vars=True) -> (float, ndarray, ...)`
x : ndarray, shape (d, n)
    The initial point for the optimization algorithm.
d : int
Collaborator:

Those shapes can be inferred from x; they should not be passed as parameters.

@xzyu02 (Contributor, Author) commented Apr 24, 2023

Hello and thanks for the PR,

I did a quick code review. The PR needs major changes so that I can better understand what the new solvers are doing (you cannot ask people to read the paper to use a function, it should be described in the documentation). It is a bit unclear to me if you provide an EMD, multimarginal solver, a barycenter estimator (especially since you function takes only one array as input) so please clarify this in the code and maybe the PR description.

What needs to be clarified and changed:

  • Move all solvers to a submodule better named more clear such as ot.lp.discrete_emd .
  • Please respect POT API: use M for ground loss, a,A for distributions (on the simplex) and all parameters names. All function names can be long but must describe precisely what the function does. For instance if a function computes a barycenter then it should be similarly named to other solvers and be called the same way.
  • Add to the documentation of all function, the optimization problems solved in math environment using same names as other POT functions.
  • Take into account all the small comments below.

I know this is a lot of work but we need to ensure that the code fits well in POT and is easy to use and understand by users/conributors.

Dear POT team,

Thank you for the thoughtful and detailed code review. We apologize for the current conflicts and problems. We will resolve them one by one and make sure the code fits well into POT for both users and contributors. We appreciate your time and effort. Thank you.

Best

@xzyu02 xzyu02 changed the title [WIP] Efficient Discrete Multi Marginal Optimal Transport Regularization [MRG] Efficient Discrete Multi Marginal Optimal Transport Regularization Jun 1, 2023
@xzyu02 xzyu02 requested a review from rflamary June 1, 2023 17:32
@xzyu02 (Contributor, Author) commented Jun 16, 2023

Dear POT team,

I hope this message finds you well. I have recently completed the bug fix for the workflow checks, and I believe the PR is now ready for review. I understand everyone has a busy schedule, but I would appreciate it if you could review this pull request and approve a workflow run at your convenience. Please feel free to raise any questions or concerns. Thank you in advance for your time and patience.

Best regards,
Xizheng Yu

@rflamary (Collaborator) commented:

We will do that; thanks for your work, it is appreciated.

@rflamary (Collaborator) left a comment:

Hello,

Thanks for all the work and the major modifications in your code; it is now, in my opinion, in much better shape and we are approaching a merge.

Still, I have a few important modification requests now that I better understand what your proposed solvers actually do. I know it can get tiresome, but POT is becoming very large and we need to be very careful when positioning new APIs and solvers.

My understanding is that these are MMOT solvers for marginal distributions supported on a regular 1D grid, using a specific ground metric (max-min, which has the Monge property) that corresponds to the absolute-value loss/EMD ground metric in the two-marginal case.

This is very interesting and I want it in POT, but it also means that we cannot use the very general names of the method that you chose in your published paper, because this would lead to much confusion for POT users. So I propose more precise names below that, I think, better describe what is solved and done. I also provide more specific comments below (for instance, having mathematical objects follow the names of the function parameters).

On a more scientific note, I am surprised by the shape of your "barycenter", which is actually, from my understanding, a convergence point of the minimization of your MMOT formulation. It seems to be very similar to the L2 (average) barycenter, especially compared to the LP solution. Do you have an intuition why?
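To make this description concrete, here is a minimal sketch (not from the PR) of the max-min ground cost on integer bin indexes of a regular 1D grid, which for two marginals reduces to the absolute-value ground metric:

```python
import itertools

def max_min_ground_cost(idx):
    # Ground cost with the (generalized) Monge property on a regular 1D grid:
    # the spread of a tuple of integer bin indexes, max(i) - min(i).
    return max(idx) - min(idx)

# With d = 2 marginals this is exactly the absolute-value (EMD) ground metric.
n = 5
for i, j in itertools.product(range(n), repeat=2):
    assert max_min_ground_cost((i, j)) == abs(i - j)
```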

# Compare Barycenters in both methods
# ---------
pl.figure(1, figsize=(6.4, 3))
for i in range(len(barys)):
Collaborator:

You should compare it to the L2 (np.mean) barycenter, because your barycenter looks very similar.
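A minimal sketch of that comparison, assuming histograms built with ot.datasets.make_1D_gauss; the d-MMOT result (barys from the solver shown below) would be overlaid on the same axes:

```python
import numpy as np
import matplotlib.pylab as pl
import ot

# A few 1D histograms on a shared regular grid of n bins, stacked as rows.
n = 50
A = np.vstack([ot.datasets.make_1D_gauss(n, m=m, s=5) for m in (10, 25, 40)])

# L2 (average) barycenter, to compare against the d-MMOT minimizer.
l2_bary = A.mean(axis=0)

pl.figure(figsize=(6.4, 3))
pl.plot(np.arange(n), l2_bary, 'k--', label='L2 (np.mean) barycenter')
# pl.plot(np.arange(n), barys[0], 'g-*', label='d-MMOT minimization')  # from the solver below
pl.legend()
pl.show()
```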

# dmmot_obj, log = ot.lp.discrete_mmot(A.T, n, d)
barys, log = ot.lp.discrete_mmot_converge(
    A, niters=3000, lr=0.000002, log=True)
dmmot_obj = log['primal objective']
Collaborator:

Both the objective value and the norm of the gradient increase at the end, which is very surprising since it is supposed to be a gradient descent, no?

# values cannot be compared.

# Perform gradient descent optimization using the d-MMOT method.
barys = ot.lp.discrete_mmot_converge(A, niters=9000, lr=0.00001)
Collaborator:

same here

ot/lp/dmmot.py Outdated
Parameters
----------
i : list
    The list for which the generalized EMD cost is to be computed.
Collaborator:

list of integer indexes...

ot/lp/dmmot.py Outdated
from ..backend import get_backend


def dist_monge(i):
Collaborator:

dist_monge is a very generic name for a very specific ground cost that does indeed have the Monge property. This is one MMOT ground cost with the Monge property, not the only one. dist_monge_max_min, for instance, is better.

ot/lp/dmmot.py Outdated
return obj


def discrete_mmot_converge(
Collaborator:

Same here: I would use monge_mmot_1dgrid_optimize to state clearly what the function does. I also need more discussion about why one would optimize all distributions together. You use it as some kind of "barycenter" since they all converge to a given distribution, but be clear that it is not a barycenter in the traditional OT sense.

return A.T, x


def test_discrete_mmot():
Collaborator:

Since you implemented the function with backends, you should run the tests on arrays from the other backends. You can do that by adding a parameter nx to the test function; the test will then automatically be run with all available backends.
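A minimal sketch of such a backend-parametrized test, assuming the renamed ot.lp.dmmot_monge_1dgrid_loss discussed later in this thread and a row-wise stacking of the d histograms (POT's test suite provides nx as a fixture iterating over the available backends; the merged signature may differ slightly):

```python
import numpy as np
import ot

def test_dmmot_loss_backends(nx):
    # Two histograms on a regular 1D grid of n bins.
    n = 10
    a1 = ot.datasets.make_1D_gauss(n, m=3, s=1)
    a2 = ot.datasets.make_1D_gauss(n, m=7, s=1)
    A = np.vstack((a1, a2))

    # Run the solver on numpy arrays and on the current backend's arrays.
    loss_np = ot.lp.dmmot_monge_1dgrid_loss(A)
    loss_nx = ot.lp.dmmot_monge_1dgrid_loss(nx.from_numpy(A))

    np.testing.assert_allclose(loss_np, nx.to_numpy(loss_nx), rtol=1e-5)
```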

pl.figure(1, figsize=(6.4, 3))
for i in range(len(barys)):
    if i == 0:
        pl.plot(x, barys[i], 'g-*', label='Discrete MMOT')
Collaborator:

It would be nice to see visually whether you converged, by plotting all the individual distributions (it seems like you did, given your "barycenter"). Maybe you could call it "Monge MMOT minimization" instead of "discrete MMOT"?


def discrete_mmot(A, verbose=False, log=False):
    r"""
    Compute the discrete multi-marginal optimal transport of distributions A.
Collaborator:

You should explain clearly here that the supports of the distributions are assumed to be integers on the real line. This will be suggested by the new function name, but it needs to be stated clearly.

return A.T, x


def test_discrete_mmot():
Collaborator:

It would be nice to have a test comparing the loss returned by your solver with two marginals to the exact OT solver with the absolute-value ground metric, since they should be equivalent, no?
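A minimal sketch of that equivalence check, again assuming the renamed ot.lp.dmmot_monge_1dgrid_loss and a row-wise stacking of the two histograms (exact conventions may differ in the merged code):

```python
import numpy as np
import ot

n = 20
a1 = ot.datasets.make_1D_gauss(n, m=5, s=2)
a2 = ot.datasets.make_1D_gauss(n, m=12, s=3)

# Exact OT loss with the absolute-value ground metric on the integer grid.
x = np.arange(n, dtype=np.float64)
M = np.abs(x[:, None] - x[None, :])
emd_loss = ot.emd2(a1, a2, M)

# The d-MMOT loss with two marginals should match up to numerical tolerance.
dmmot_loss = ot.lp.dmmot_monge_1dgrid_loss(np.vstack((a1, a2)))
np.testing.assert_allclose(emd_loss, dmmot_loss, rtol=1e-5)
```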

@rflamary rflamary changed the title [MRG] Efficient Discrete Multi Marginal Optimal Transport Regularization [MRG] Efficient Discrete Multi Marginal Optimal Transport Jul 28, 2023
@xzyu02 (Contributor, Author) commented Aug 1, 2023

Dear POT team,

I hope you are doing well! This pull request update contains changes and improvements based on the previous review. We appreciate your previous feedback, which guided these revisions. We look forward to your feedback on these updates.

Best Regards,
Xizheng Yu

In summary, here are the key points of the update:

  • Changed function names and answered the comments based on the suggestions.
  • Regarding the similarity of the shape of our "barycenter" to the L2 (average) barycenter: this is likely due to the Monge cost, for simple 1D examples. I do not believe the paper develops theory on the relationship between this Monge-cost-based barycenter and those obtained using the L2 cost. Our result has been verified against cvx's gradients and gives the same result (the cvx comparison is available in the paper's repo).

Comments resolved below:
Q) you should compare it to the l2 (np.mean) barycenter because your barycenter looks very similar
A) Changed the barycenter comparison from LP to L2 (not sure if we should keep both, but the L2 result is similar to ours).

Q) Both the objective value and the norm of the gradient increase at the end, which is very surprising since it is supposed to be a gradient descent, no? (Same remark on the second example.)
A) We added learning-rate decay in the dmmot_monge_1dgrid_optimize method to control the step size as we approach the minimized objective (see the sketch after these answers).

Q) list of integer indexes...
A) Fixed.
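A generic sketch of the learning-rate decay pattern mentioned above (illustrative only; the actual schedule and parameter names inside dmmot_monge_1dgrid_optimize may differ):

```python
import numpy as np

def gradient_descent_with_decay(grad_fn, x0, niters=1000, lr=1e-2, decay=0.999):
    # Plain gradient descent with a multiplicative learning-rate decay,
    # shrinking the step size as the iterate approaches the minimizer.
    x = np.array(x0, dtype=float)
    for _ in range(niters):
        x = x - lr * grad_fn(x)
        lr = lr * decay
    return x

# Example: minimize f(x) = ||x||^2, whose gradient is 2x.
x_opt = gradient_descent_with_decay(lambda x: 2.0 * x, x0=np.ones(3))
print(x_opt)  # close to zero
```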

Q) dist_monge is a very generic name for a very specific ground cost that does indeed have the Monge property. This is one MMOT ground cost with the Monge property, not the only one. dist_monge_max_min, for instance, is better.
A) Fixed the name.

Q) Again, this function name is too general: from the name it looks like a general MMOT solver, when in practice it applies only to regular 1D grids (for the marginals) and with a very specific ground metric. I suggest monge_mmot_1dgrid_loss, which describes the loss for MMOT with a Monge ground cost on a regular 1D grid (and why we only give the distribution weights to the function). It is a mouthful, but we need precise descriptions when creating new functions in a general-purpose OT toolbox.
A) Thanks for the detailed explanation of the naming suggestion. Since the method is a discrete MMOT problem with Monge costs, we would tentatively suggest the similar name dmmot_monge_1dgrid_loss. Since we have d distributions, the grid is d-dimensional, and our n is the discretization/number of bins.

Q) State the ground cost instead of using "generalized Monge", which requires reading your paper. I understand the OT plan is independent of which Monge cost is used, but you return the loss for a fixed one.
A) Fixed.

Q) Use alpha instead of p to make the link with the input A of the function.
A) Fixed in all mathematical descriptions.

Q) Do not use x for the OT plan; we use either \gamma or a bold matrix T for the plan in POT. x is already used for the support positions (which are the integers i in your case), and we need unified API/notation.
A) Fixed, now using \gamma.

Q) Here it would be nice to return a loss with the gradients defined properly, so that it can be used in PyTorch with standard gradient descent algorithms. To do that you can use the backend function set_gradients, which defines the forward/backward relations. An example of its use can be found here:
A) Added inside dmmot_monge_1dgrid_loss, following https://github.com/PythonOT/POT/blob/release0.9/ot/backend.py#L1689

Q) Same here: I would use monge_mmot_1dgrid_optimize to state clearly what the function does. I also need more discussion about why one would optimize all distributions together. You use it as some kind of "barycenter" since they all converge to a given distribution, but be clear that it is not a barycenter in the traditional OT sense.
A) We would like to suggest the similar name dmmot_monge_1dgrid_optimize for this method. Discussion: the main advantage here is that the computation cost grows exactly linearly, and can even be lower if some distributions are "interior" to the others. This means that in the worst case we compute roughly the same amount of work as barycenter approaches, but in the average and best cases we "skip" distributions that are not important to the computation of the d-dimensional cost. At each iteration, only the distributions on the "boundary" are moved. A figure illustrating the boundary moves is attached for reference (ref-mmot).

Q) Since you implemented the function with backends, you should run the tests on arrays from the other backends. You can do that by adding a parameter nx to the test function; the test will then automatically be run with all available backends.
A) Added the nx test. Since our algorithm involves multiple in-place tensor modifications, TensorFlow's immutable tensors conflict with our usage; meanwhile, several PyTorch methods require conversion from list to tensor. We decided to use a simple conversion between nx and np at the start and end of each method, and to use np for the algorithm's computation (see the sketch after this answer).
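A minimal sketch of that conversion pattern with a hypothetical solver body (the actual dmmot functions structure this differently):

```python
import numpy as np
from ot.backend import get_backend

def dmmot_like_solver(A):
    # Detect the backend of the input (numpy, torch, jax, ...).
    nx = get_backend(A)
    A0 = A  # keep a handle on the original backend array

    # Convert to numpy at the start; the core algorithm mutates arrays in place.
    A_np = nx.to_numpy(A0)
    result_np = A_np / A_np.sum(axis=-1, keepdims=True)  # placeholder computation

    # Convert back to the caller's backend (and dtype/device) at the end.
    return nx.from_numpy(result_np, type_as=A0)
```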

Q) It would be nice to see visually whether you converged, by plotting all the individual distributions (it seems like you did, given your "barycenter"). Maybe you could call it "Monge MMOT minimization" instead of "discrete MMOT"?
A) A plot comparing all the individual distributions has been added, but they are really close due to the method (as we stated, every distribution can be viewed as a "barycenter"). The tentative name is dmmot_monge_1dgrid_optimize.

Q) You should explain clearly here that the supports of the distributions are assumed to be integers on the real line. This will be suggested by the new function name, but it needs to be stated clearly.
A) Added to the documentation.

Q) It would be nice to have a test comparing the loss returned by your solver with two marginals to the exact OT solver with the absolute-value ground metric, since they should be equivalent, no?
A) Added to the tests.
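For reference, a hedged usage sketch of the two renamed functions as discussed in this thread; parameter names such as niters and lr follow the example snippets above and may differ slightly in the merged release:

```python
import numpy as np
import ot

# A few 1D histograms on a shared regular grid of n bins, stacked as rows.
n = 50
A = np.vstack([ot.datasets.make_1D_gauss(n, m=m, s=5) for m in (15, 25, 35)])

# d-MMOT loss with the max-min Monge ground cost on the 1D grid.
loss = ot.lp.dmmot_monge_1dgrid_loss(A)

# Jointly minimize the d-MMOT objective over all distributions;
# each row of barys converges towards a common distribution.
barys = ot.lp.dmmot_monge_1dgrid_optimize(A, niters=3000, lr=2e-6)
```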

@xzyu02 xzyu02 requested a review from rflamary August 1, 2023 20:39
@rflamary (Collaborator) left a comment:

Thank you @x12hengyu for all those changes. I know it was a lot of work, but it looks good and will be more maintainable in the future.

I think the contribution is nearly there but there remains a small problem with the backend line where the gradient is set (see below).

Once this is done we can merge. Good work!

I will wait for this and do a new release so this should be shortly available in the stable version.

ot/lp/dmmot.py Outdated
'dual objective': dualobj}

# define forward/backward relations for pytorch
obj = nx.set_gradients(obj, (nx.from_numpy(A)), (dual))
Collaborator:

Here you need to use the A given as input to the function (not a conversion from numpy), so that torch makes the link between this A and the objective. For instance, store A0 = A at the beginning of the function and then use A0 here in set_gradients.

This will be very nice because your loss will be usable and differentiable with .backward() in torch, opening the door to stochastic optimization and deep learning applications.
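A minimal sketch of that pattern with a hypothetical loss (the actual dmmot code computes the objective and dual variables differently):

```python
import numpy as np
from ot.backend import get_backend

def loss_with_gradients(A):
    nx = get_backend(A)
    A0 = A  # keep the original backend array so torch can track it

    # Core computation in numpy (placeholder: sum of squares, gradient 2*A).
    A_np = nx.to_numpy(A0)
    obj_np = np.asarray((A_np ** 2).sum())  # 0-d array so from_numpy handles it
    grad_np = 2.0 * A_np

    # Attach the forward/backward relation w.r.t. the original input A0, so that
    # obj.backward() works when A is a torch tensor with requires_grad=True.
    obj = nx.from_numpy(obj_np, type_as=A0)
    grad = nx.from_numpy(grad_np, type_as=A0)
    return nx.set_gradients(obj, (A0,), (grad,))
```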

@xzyu02 xzyu02 requested a review from rflamary August 2, 2023 20:40
@xzyu02 (Contributor, Author) commented Aug 2, 2023

Dear POT team,

I hope you are doing well! I have fixed the gradient problem. Thanks for your time and patience in reviewing our work over the past months. We really appreciate it!

Best Regards,
Xizheng Yu

ot/lp/dmmot.py Outdated (resolved)
Store input variable instead of copying it
@rflamary rflamary merged commit 5ead79b into PythonOT:master Aug 3, 2023