ot.gromov.gromov_wasserstein2 loss does not perform backprop with torch CUDA tensor #351

Closed · tbng opened this issue Mar 1, 2022 · 1 comment · Fixed by #352

tbng (Contributor) commented Mar 1, 2022

As in the title: ot.gromov.gromov_wasserstein2 does not backpropagate when given torch CUDA tensors. The following short snippet reproduces the error.

import numpy as np
import ot
import torch
from ot.gromov import gromov_wasserstein2


def gw_pytorch_exam(C1, C2, a1, a2, device, n_iter=1000, lr=1e-2):
    # Move the inputs to the requested device; only C1 is optimized.
    C1_torch = torch.tensor(C1, device=device, requires_grad=True)
    C2_torch = torch.tensor(C2, device=device)
    a1_torch = torch.tensor(a1, device=device)
    a2_torch = torch.tensor(a2, device=device)

    # Plain gradient descent on C1 using the GW loss.
    for i in range(n_iter):
        loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
        loss.backward()
        with torch.no_grad():
            grad = C1_torch.grad
            C1_torch -= grad * lr
            C1_torch.grad.zero_()
            C1_torch.data = torch.clamp(C1_torch, 0, 1)

    return C1_torch


if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")  # maybe should disable this to force GPU usage

n = 10
C1 = np.eye(n)
C2 = np.random.randn(n, n)
a = ot.unif(n)
C1 = gw_pytorch_exam(C1, C2, a, a, device)

Running this code raises the following RuntimeError:

     36 a = ot.unif(n)
---> 37 C1 = gw_pytorch_exam(C1, C2, a, a, device)

<ipython-input-3-afc0d26d5054> in gw_pytorch_exam(C1, C2, a1, a2, device, n_iter, lr)
     16     for i in range(n_iter):
     17         loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
---> 18         loss.backward()
     19         with torch.no_grad():
     20             grad = C1_torch.grad

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222 
    223     def register_hook(self, hook):

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    128         retain_graph = create_graph
    129 
--> 130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
    132         allow_unreachable=True)  # allow_unreachable flag

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/autograd/function.py in apply(self, *args)
     87     def apply(self, *args):
     88         # _forward_cls is defined by derived class
---> 89         return self._forward_cls.backward(self, *args)  # type: ignore
     90 
     91 

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/ot/backend.py in backward(ctx, grad_output)
   1381             def backward(ctx, grad_output):
   1382                 # the gradients are grad
-> 1383                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
   1384 
   1385         self.ValFunction = ValFunction

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/ot/backend.py in <genexpr>(.0)
   1381             def backward(ctx, grad_output):
   1382                 # the gradients are grad
-> 1383                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
   1384 
   1385         self.ValFunction = ValFunction

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

PyTorch: 1.7.0
POT: 0.8.1
CUDA: 10.1 on NVIDIA Tesla P100
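
A possible interim workaround (my own suggestion, not part of this report, and only a sketch): until the backend handles devices correctly, the GW loss can be evaluated on CPU copies of the tensors inside the loop, and the resulting gradient moved back to the GPU by hand:

# Workaround sketch (not from this thread): keep every tensor seen by
# gromov_wasserstein2 on the CPU, then move the gradient back to `device`
# to update the CUDA copy of C1.
C1_cpu = C1_torch.detach().cpu().requires_grad_(True)
loss = gromov_wasserstein2(C1_cpu, C2_torch.cpu(), a1_torch.cpu(), a2_torch.cpu())
loss.backward()
with torch.no_grad():
    C1_torch -= C1_cpu.grad.to(device) * lr
    C1_torch.data = torch.clamp(C1_torch, 0, 1)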

ncassereau (Contributor) commented:

Hello, I have reproduced your issue. It comes from the fact that part of the computation is done in NumPy; when the result is converted back to the correct backend, the device is lost along the way. I will make a PR to correct that.
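
To illustrate the kind of fix described above (a sketch only, under my own assumptions; the actual change is whatever lands in #352): when an intermediate NumPy result is converted back to torch, the conversion can reuse the dtype and device of a reference tensor instead of defaulting to the CPU.

import numpy as np
import torch

def to_torch_like(arr, type_as):
    # Hypothetical helper (not POT's actual API): rebuild a torch tensor from a
    # NumPy intermediate, reusing the dtype and device of a reference tensor so
    # that results computed from CUDA inputs do not silently come back on the CPU.
    return torch.as_tensor(arr, dtype=type_as.dtype, device=type_as.device)

With such a helper, the gradients stored during the forward pass would live on the same device as grad_output, and the product in the custom backward would no longer mix cuda:0 and cpu tensors.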
