fix the sign of gradient for kl gromov #610

Merged

Conversation

KrzakalaPaul
Contributor

@KrzakalaPaul commented Mar 1, 2024

Types of changes

Corrects a sign error in the gradient set through `nx.set_gradients` for Gromov-Wasserstein (and fused Gromov-Wasserstein) when `loss_fun = 'kl_loss'`.
The correct formula is:

gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)

instead of

gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
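As a quick sanity check of this expression (a minimal sketch, not part of the PR diff; the tensor names A, B, T, p, q are illustrative), the closed form can be compared against autograd applied to the explicit KL Gromov-Wasserstein energy E(A, B, T) = sum_ijkl KL(A_ij, B_kl) T_ik T_jl for a fixed coupling T:

import torch

torch.manual_seed(0)
n, m = 3, 4
A = torch.rand(n, n) + 0.1                     # plays the role of C1 (strictly positive)
B = (torch.rand(m, m) + 0.1).requires_grad_()  # plays the role of C2
p = torch.full((n,), 1.0 / n)                  # uniform source weights
q = torch.full((m,), 1.0 / m)                  # uniform target weights
T = torch.outer(p, q)                          # a valid coupling with marginals p and q

# KL ground loss: KL(a, b) = a*log(a) - a*log(b) - a + b
kl = (A[:, :, None, None] * torch.log(A[:, :, None, None])
      - A[:, :, None, None] * torch.log(B[None, None, :, :])
      - A[:, :, None, None] + B[None, None, :, :])

# E(A, B, T) = sum_{ijkl} KL(A_ij, B_kl) T_ik T_jl
E = torch.einsum('ijkl,ik,jl->', kl, T, T)
autograd_grad = torch.autograd.grad(E, B)[0]

# corrected closed form: gC2 = - T.T @ C1 @ T / C2 + outer(q, q)
closed_form = -(T.T @ A @ T) / B.detach() + torch.outer(q, q)
print(torch.allclose(autograd_grad, closed_form, atol=1e-6))  # expected: True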

How has this been tested (if it applies)

You can run this code:

from ot import gromov_wasserstein2, unif
import torch

C1 = torch.rand((2, 2), requires_grad=False)
C2 = 1 - C1
C2.requires_grad = True

eta = 1e-1

for step in range(100):

    loss = gromov_wasserstein2(C1=C1, C2=C2, p=unif(2, type_as=C2),
                               q=unif(2, type_as=C2), loss_fun='square_loss')
    grad = torch.autograd.grad(loss, C2)[0]
    # plain projected gradient step on C2
    C2 = C2 - eta * grad
    C2 = torch.clip(C2, 0, 1)

    print(loss)

You will see that the gradient descent diverges. It converges when we fix the sign error.
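Since the fix targets loss_fun = 'kl_loss' while the snippet above passes 'square_loss' (see the comment below), here is a hedged variant of the same gradient-descent check for the KL case. It is a sketch rather than part of the PR; the cost matrices are kept strictly positive because the KL loss takes logarithms of their entries.

from ot import gromov_wasserstein2, unif
import torch

torch.manual_seed(0)
C1 = torch.rand((2, 2)) + 0.1      # strictly positive source cost matrix
C2 = (1.2 - C1).requires_grad_()   # strictly positive target cost matrix

eta = 1e-1

for step in range(100):
    loss = gromov_wasserstein2(C1=C1, C2=C2,
                               p=unif(2, type_as=C2), q=unif(2, type_as=C2),
                               loss_fun='kl_loss')
    grad = torch.autograd.grad(loss, C2)[0]
    with torch.no_grad():
        C2 = torch.clip(C2 - eta * grad, 0.1, 1.2)  # projected gradient step
    C2.requires_grad_()

    print(float(loss))

# With the corrected sign the printed loss decreases; with the old sign it grows.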

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.


codecov bot commented Mar 1, 2024

Codecov Report

Merging #610 (7cecd40) into master (0573eba) will not change coverage.
The diff coverage is 100.00%.

Additional details and impacted files
@@           Coverage Diff           @@
##           master     #610   +/-   ##
=======================================
  Coverage   96.78%   96.78%           
=======================================
  Files          77       77           
  Lines       16027    16027           
=======================================
  Hits        15511    15511           
  Misses        516      516           

@cedricvincentcuaz
Collaborator

Hello @KrzakalaPaul, indeed good catch.

In your example you do not use `loss_fun = 'kl_loss'`. So, doing a quick check to be sure: with `loss_fun = 'kl_loss'` the GW loss reads as
$E(\mathbf{A}, \mathbf{B}, \mathbf{T}) = \sum_{ijkl} KL(A_{ij}, B_{kl})T_{ik}T_{jl}$
with $KL(A_{ij}, B_{kl}) = A_{ij} \log(A_{ij}) - A_{ij}\log(B_{kl}) - A_{ij} + B_{kl}$.

So $\frac{\partial E}{\partial B_{pq}} = \sum_{ij} \left(- \frac{A_{ij}}{B_{pq}} + 1\right) T_{ip}T_{jq}$, and we indeed forgot a minus sign in the current POT implementation.
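For reference, writing this derivative in matrix form with the marginal constraints $\sum_i T_{ip} = q_p$ and $\sum_j T_{jq} = q_q$ recovers exactly the corrected line of the diff:

$\frac{\partial E}{\partial B_{pq}} = \sum_{ij} \left(1 - \frac{A_{ij}}{B_{pq}}\right) T_{ip} T_{jq} = q_p q_q - \frac{(\mathbf{T}^\top \mathbf{A} \mathbf{T})_{pq}}{B_{pq}},$

i.e. `gC2 = - nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)`.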

@cedricvincentcuaz merged commit 3e05385 into PythonOT:master Mar 4, 2024
15 checks passed