Examples for pytorch backend throwing error if run with GPU #389

Closed
sahahner opened this issue Jul 29, 2022 · 0 comments · Fixed by #391

Comments

@sahahner

In the example scripts that demonstrate using PyTorch with POT, torch.Generator() is always created on the CPU, which leads to an error when the script runs on an available GPU.

To reproduce the error, run the example script 'plot_sliced_wass_grad_flow_pytorch.py' on a machine with an available GPU.

This is easily fixed by changing
gen = torch.Generator()
to
gen = torch.Generator(device=device)
in lines 77 and 139 of the script.
The same change is needed in the example Jupyter notebook.
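For illustration, here is a minimal sketch of the device-matched generator pattern, separate from the example script itself. The helper name make_generator, the seed value, and the tensor shape are hypothetical and chosen only for this sketch; the actual call that consumes the generator in the script may differ.

```python
import torch


def make_generator(device, seed=0):
    """Hypothetical helper: build a seeded torch.Generator on `device`."""
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)
    return gen


# Pick the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

gen = make_generator(device, seed=42)

# The generator lives on the same device as the tensor being sampled,
# so this call works on both CPU-only and GPU machines.
x = torch.randn(5, 2, device=device, generator=gen)
```

Because the generator is built from the same device variable as the tensors, the code runs unchanged whether or not a GPU is present.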
