Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add set_gradients method for JAX backend. #278

Merged
merged 3 commits into from
Oct 22, 2021

Conversation

AdrienCorenflos
Copy link
Contributor

@AdrienCorenflos AdrienCorenflos commented Sep 8, 2021

Types of changes

  • Docs change / refactoring / dependency upgrade
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Motivation and context / Related issue

The set_gradient is possible in JAX.

#277

How has this been tested (if it applies)

Added a modified unittest for JAX.

Checklist

  • [ X ] The documentation is up-to-date with the changes I made.
  • [ X ] I have read the CONTRIBUTING document.
  • [ X ] All tests passed, and additional code has been covered with new tests.

@codecov
Copy link

codecov bot commented Sep 8, 2021

Codecov Report

Merging #278 (eb8d0cf) into master (14c30d4) will increase coverage by 0.01%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #278      +/-   ##
==========================================
+ Coverage   92.64%   92.66%   +0.01%     
==========================================
  Files          19       19              
  Lines        3754     3761       +7     
==========================================
+ Hits         3478     3485       +7     
  Misses        276      276              

@rflamary rflamary changed the title Add set_gradients method for JAX backend. [MRG] Add set_gradients method for JAX backend. Sep 9, 2021
@rflamary
Copy link
Collaborator

rflamary commented Sep 9, 2021

This is awesome! that's why you need people who know the framework for multiple backend.

Could you please also check the you can compute gradients for emd2? by updating this test?
https://github.com/PythonOT/POT/blob/master/test/test_ot.py#L84

@AdrienCorenflos
Copy link
Contributor Author

Well, no you won't be able to. You are casting your tensors to numpy, and JAX tracing is not compatible with this... In order to use JAX you would need to dispatch the call to the host by using a callback instead (see https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html), and only after this set the gradients. In practice this means that the backends should essentially have a kind of pycall method.

@AdrienCorenflos
Copy link
Contributor Author

practice this means that the backends should essentially have a kind of pycall method.

Note that this is feasible, but it's a bit more work than simply setting gradients, and requires some preliminary design thinking.

@rflamary
Copy link
Collaborator

OK i'm sorry i vnever understoiod the subtleties of jax. But if what you implemented did not define the gradient for a given variable, what is its use?

@AdrienCorenflos
Copy link
Contributor Author

It does replace the gradient within JAX code, but it does not allow to bypass the tracing. The problem of emd2 is not with the gradient, it's with the fact that you have an operation that JAX can't trace in the middle.

@rflamary
Copy link
Collaborator

Note that We use the set_gradient in emd2 exactly to bypass the need for numpy arrays and it works for torch. This is why i tried to define a new function with @custom_jvp (and call it) in set gradients but i wasn't a good strategy.

@rflamary rflamary merged commit d50d814 into PythonOT:master Oct 22, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants