[MRG] Fix Gradient scaling in Partial GW solver #602

Merged: 7 commits into PythonOT:master on Jun 21, 2024

Conversation

yikun-baio
Contributor

@yikun-baio yikun-baio commented Feb 1, 2024

Types of changes

I modified the code of ot.partial.partial_gromov_wasserstein.

Motivation and context / Related issue

There seems to be an inconsistency between ot.partial.partial_gromov_wasserstein and the line search section in the paper [29]. I fixed this section. In addition, I have made a minor change to the initial guess in the partial-GW solver, since the original initial guess is np.outer(p, q), which might not be suitable for the unbalanced case, i.e. |p| \neq |q|.
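To illustrate the point about the initial guess, here is a minimal sketch in plain NumPy. It assumes the usual partial transport constraints (row sums at most p, column sums at most q, total mass equal to m, with m no larger than min(|p|, |q|)). The helper `is_admissible` and the rescaled plan are illustrative assumptions, not necessarily the code used in POT.

```python
import numpy as np


def is_admissible(G, p, q, m, tol=1e-9):
    """Check the partial transport constraints for a candidate plan G (hypothetical helper)."""
    return bool(
        np.all(G >= -tol)
        and np.all(G.sum(axis=1) <= p + tol)  # row marginals dominated by p
        and np.all(G.sum(axis=0) <= q + tol)  # column marginals dominated by q
        and abs(G.sum() - m) <= tol           # transported mass equals m
    )


# Unbalanced example: |p| = 1.5, |q| = 0.8, mass to transport m = 0.6.
p = np.array([0.5, 0.5, 0.5])
q = np.array([0.4, 0.4])
m = 0.6

# Plain outer product: column sums are q_j * |p| = 0.6 > q_j, total mass is |p||q| = 1.2.
G_naive = np.outer(p, q)

# Rescaled outer product: total mass is m, marginals are p * m/|p| and q * m/|q|.
G_scaled = np.outer(p, q) * m / (p.sum() * q.sum())

print(is_admissible(G_naive, p, q, m), is_admissible(G_scaled, p, q, m))
```

The rescaled plan stays admissible for any m up to min(|p|, |q|), because each marginal is the original one multiplied by a factor no larger than 1.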

N/A

N/A

How has this been tested (if it applies)

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@rflamary rflamary changed the title new file: ot/partial_gw.py [Fix] Gradien saling in Partial GW solver Feb 1, 2024
@rflamary
Collaborator

rflamary commented Feb 1, 2024

Hello @yikun-baio and thanks for the PR.

You should not propose a new file but instead implement your fix directly in the existing files. This is important because we have many tests that check that nothing else breaks, and we can easily compare your modification with the old implementation. We will do a code review with @lchapel when this is done.

@rflamary rflamary changed the title [Fix] Gradien saling in Partial GW solver [Fix] Gradien scaling in Partial GW solver Feb 12, 2024
@rflamary rflamary changed the title [Fix] Gradien scaling in Partial GW solver [Fix] Gradient scaling in Partial GW solver Feb 12, 2024
@rflamary rflamary changed the title [Fix] Gradient scaling in Partial GW solver [Fix] Gradient saling in Partial GW solver Feb 12, 2024
@yikun-baio
Contributor Author

Hello @rflamary,

Thank you for your feedback. I apologize for proposing a new file instead of implementing the fix directly in the existing file. This is my first time contributing to a public project via a pull request.

I just realized that my email settings were inadvertently blocking emails from GitHub, which is why I only saw your message late.

Not sure if it's still needed, but I'll implement the fix as suggested and make sure it's done in an existing file for @lchapel's direct code review. Thank you very much for your guidance and I look forward to contributing more effectively in the future.

Thank you for your understanding and patience.

Sincerest regards,
Yikun

@rflamary rflamary changed the title [Fix] Gradient saling in Partial GW solver [Fix] Gradient scaling in Partial GW solver Mar 4, 2024
@rflamary
Collaborator

Hello @yikun-baio, this is a friendly reminder to implement your fix directly in the code if possible. I used a rocket emoji on your previous question to say that we are interested, but maybe you did not receive a notification.


codecov bot commented Mar 29, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 96.68%. Comparing base (628a089) to head (b862248).

Additional details and impacted files
@@           Coverage Diff           @@
##           master     #602   +/-   ##
=======================================
  Coverage   96.68%   96.68%           
=======================================
  Files          85       85           
  Lines       16890    16890           
=======================================
  Hits        16330    16330           
  Misses        560      560           

@cedricvincentcuaz
Collaborator

cedricvincentcuaz commented Jun 12, 2024

Hello @yikun-baio,

Thank you for your PR. In order to proceed with the review, could you first implement @rflamary's corrections and make sure that all tests pass? Beyond that, I have some doubts about the current implementation.

Could you specify which inconsistency you spotted in the paper [29]?
According to Eq. 2 in the paper, the authors include a factor 1/2 in the GW cost, which is omitted in our implementation of GW (cf. docs). As long as the documentation of the PGW solver is clear and states the loss explicitly (which is currently lacking), I don't see a major problem with having these differences within POT.

This difference implies that the computation of the gradient in gwgrad_partial had to be adapted, and the implementation is correct. Same for gwloss_partial.
However, I also believe the current line-search implementation is wrong, but your fix does not seem to match:

M = gwgrad_partial(C1, C2, G0) # Here we want the 4D-tensor product to match calculus -> missing *0.5 
...
a = gwloss_partial(C1, C2, deltaG) # correct
b = 2 * np.sum(M * deltaG) # correct

I also agree that the initial transport plan should be admissible, such as the one you proposed. I would suggest storing p.sum() and q.sum() early in the function and removing the redundancies in the current implementation.

Best,
Cédric

@yikun-baio
Contributor Author

yikun-baio commented Jun 13, 2024

Hello, Cédric

Could you specify which inconsistency you spotted in the paper [29]?
According to Eq. 2 in the paper, the authors include a factor 1/2 in the GW cost, which is omitted in our implementation of GW (cf. docs). As long as the documentation of the PGW solver is clear and states the loss explicitly (which is currently lacking), I don't see a major problem with having these differences within POT.

Please refer to the attached PDF. In Part 1, I explain the line search problem and its solution. The solution I derived is consistent with [29]. In Part 2, I go through the code related to the line search section. I have highlighted the important parts in red for your review.

explaination of the linear search.pdf

M = gwgrad_partial(C1, C2, G0) # Here we want the 4D-tensor product to match calculus -> missing *0.5
...
a = gwloss_partial(C1, C2, deltaG) # correct
b = 2 * np.sum(M * deltaG) # correct

Based on the pdf, I think it should be changed to the following:

M = gwgrad_partial(C1, C2, G0)  # M = \mathcal{M} \circ G; there is no 1/2

old_a = gwloss_partial(C1, C2, deltaG)  # a = 1/2 <M \circ deltaG, deltaG>; there is a 1/2 term
old_b = 2 * np.sum(M * deltaG)  # b = 2 <M \circ G, deltaG>; a and b are not consistent

# Option 1 (no 1/2 term for either a or b)
a = 2 * gwloss_partial(C1, C2, deltaG)  # a = <M \circ deltaG, deltaG>
b = 2 * np.sum(M * deltaG)  # b = 2 <M \circ G, deltaG>

# Option 2 (apply the 1/2 term to both a and b)
a = gwloss_partial(C1, C2, deltaG)  # a = 1/2 <M \circ deltaG, deltaG>
b = np.sum(M * deltaG)  # b = <M \circ G, deltaG>
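The two options can be checked numerically on a small example. The sketch below builds the 4D tensor M[i, j, k, l] = (C1[i, k] - C2[j, l])^2 explicitly and verifies the quadratic expansion along a direction deltaG, with and without the 1/2 factor. The helpers `tensor_apply` and `quad_form` are hypothetical stand-ins written for this illustration; they are not POT's gwgrad_partial or gwloss_partial and make no claim about what those functions return.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_symmetric(n):
    A = rng.random((n, n))
    return (A + A.T) / 2


C1, C2 = random_symmetric(4), random_symmetric(3)

# Explicit tensor M[i, j, k, l] = (C1[i, k] - C2[j, l]) ** 2 (feasible for tiny sizes only).
M4 = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2


def tensor_apply(G):
    # (M ∘ G)[i, j] = sum over k, l of M[i, j, k, l] * G[k, l]
    return np.einsum("ijkl,kl->ij", M4, G)


def quad_form(G):
    # <M ∘ G, G>, deliberately written without the 1/2 factor
    return float(np.sum(tensor_apply(G) * G))


G = rng.random((4, 3))
deltaG = rng.random((4, 3)) - 0.5  # arbitrary direction, for the algebra only
MG = tensor_apply(G)

# Option 1: objective <M ∘ G, G>
a1, b1 = quad_form(deltaG), 2 * float(np.sum(MG * deltaG))
# Option 2: objective 1/2 <M ∘ G, G>
a2, b2 = 0.5 * quad_form(deltaG), float(np.sum(MG * deltaG))

gamma = 0.37
full = quad_form(G + gamma * deltaG)
print(np.isclose(full, quad_form(G) + gamma * b1 + gamma**2 * a1))
print(np.isclose(0.5 * full, 0.5 * quad_form(G) + gamma * b2 + gamma**2 * a2))
print(np.isclose(-b1 / (2 * a1), -b2 / (2 * a2)))
```

The expansion relies on the symmetry M[i, j, k, l] = M[k, l, i, j], which holds here because C1 and C2 are symmetric.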

Could you check whether my understanding is correct?

Thanks,
Yikun Bai

@cedricvincentcuaz
Collaborator

cedricvincentcuaz commented Jun 13, 2024

Thank you for these details.

I think we are saying the same thing: you leverage the fact that only the quotient of a and b matters in the end, so it makes no difference which of the two is rescaled.
It is just preferable to code the exact formula for these coefficients instead of relying on tricks that we might forget ;)
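As a small illustration of the quotient argument, the sketch below minimizes a * gamma^2 + b * gamma over [0, 1] in closed form. The function `quadratic_step` and the numbers are assumptions made for this example, not POT's line-search routine. It shows that doubling a or halving b repairs an inconsistent pair equally well, while the inconsistent pair itself gives a different step.

```python
def quadratic_step(a, b):
    """Minimizer of a * gamma**2 + b * gamma over gamma in [0, 1] (hypothetical helper)."""
    if a > 0:
        # Convex case: clip the unconstrained minimizer -b / (2a) to [0, 1].
        return min(1.0, max(0.0, -b / (2.0 * a)))
    # Concave or linear case: the minimum is at an endpoint; compare f(0) = 0 with f(1) = a + b.
    return 1.0 if a + b < 0 else 0.0


quad = 1.0    # stands for <M ∘ deltaG, deltaG>
lin = -0.25   # stands for <M ∘ G, deltaG>

step_inconsistent = quadratic_step(0.5 * quad, 2 * lin)  # a has the 1/2 factor, b does not
step_double_a = quadratic_step(quad, 2 * lin)            # repair by rescaling a
step_halve_b = quadratic_step(0.5 * quad, lin)           # repair by rescaling b

print(step_inconsistent, step_double_a, step_halve_b)
```

Writing both coefficients from the same objective, rather than rescaling one after the fact, keeps the relation between the code and the formula easy to audit.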

Cf. the image below:

[Image: Screenshot 2024-06-13 at 6 28 40 PM]

@cedricvincentcuaz
Collaborator

Hello @yikun-baio,
As we are rushing for a release to support numpy >= 2.0, I implemented the modifications mentioned above and some small modifications needed for tests to pass.
I will merge when tests are finished.

Thank you for your contribution to POT :)

@cedricvincentcuaz cedricvincentcuaz changed the title [Fix] Gradient scaling in Partial GW solver [MRG] Fix Gradient scaling in Partial GW solver Jun 21, 2024
@cedricvincentcuaz cedricvincentcuaz merged commit 14c08ba into PythonOT:master Jun 21, 2024
16 checks passed
3 participants