Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Fix Bug binary_search_circle on GPU and Gradients #457

Merged
merged 22 commits into from
Apr 18, 2023

Conversation

clbonet
Copy link
Contributor

@clbonet clbonet commented Apr 11, 2023

Types of changes

  • Fix bug in ot.lp.solver_1d.roll_cols to run on GPU
  • Fix bug in ot.lp.solver_1d.binary_search_circle to have gradients different from NaN
  • Added in backend.py the method detach() which detach a tensor from the graph (when available)
  • Added a test in test_sliced to assert that the gradient is not NaN when using pytorch

Motivation and context / Related issue

The functions ot.binary_search_circle could no be run on GPU. And gradients on pytorch using ot.sliced_wasserstein_sphere return NaN.

How has this been tested (if it applies)

  • A test to check whether the gradient is a NaN for ot.sliced_wasserstein_sphere has been added.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@clbonet clbonet changed the title [MRG] Fix Bug binary_search_circle on GPU and Gradients [WIP] Fix Bug binary_search_circle on GPU and Gradients Apr 11, 2023
@codecov
Copy link

codecov bot commented Apr 11, 2023

Codecov Report

Merging #457 (0c0fde5) into master (2bbfbbb) will increase coverage by 0.00%.
The diff coverage is 100.00%.

❗ Current head 0c0fde5 differs from pull request most recent head bccc436. Consider uploading reports for the commit bccc436 to get more accurate results

Additional details and impacted files
@@           Coverage Diff           @@
##           master     #457   +/-   ##
=======================================
  Coverage   94.92%   94.93%           
=======================================
  Files          31       31           
  Lines        6879     6890   +11     
=======================================
+ Hits         6530     6541   +11     
  Misses        349      349           

@clbonet clbonet changed the title [WIP] Fix Bug binary_search_circle on GPU and Gradients [MRG] Fix Bug binary_search_circle on GPU and Gradients Apr 18, 2023
@rflamary rflamary merged commit 9aa96c8 into PythonOT:master Apr 18, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants