-
Notifications
You must be signed in to change notification settings - Fork 506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Add set_gradients method for JAX backend. #278
Conversation
Codecov Report
@@ Coverage Diff @@
## master #278 +/- ##
==========================================
+ Coverage 92.64% 92.66% +0.01%
==========================================
Files 19 19
Lines 3754 3761 +7
==========================================
+ Hits 3478 3485 +7
Misses 276 276 |
This is awesome! that's why you need people who know the framework for multiple backend. Could you please also check the you can compute gradients for emd2? by updating this test? |
Well, no you won't be able to. You are casting your tensors to numpy, and JAX tracing is not compatible with this... In order to use JAX you would need to dispatch the call to the host by using a callback instead (see https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html), and only after this set the gradients. In practice this means that the backends should essentially have a kind of pycall method. |
Note that this is feasible, but it's a bit more work than simply setting gradients, and requires some preliminary design thinking. |
OK i'm sorry i vnever understoiod the subtleties of jax. But if what you implemented did not define the gradient for a given variable, what is its use? |
It does replace the gradient within JAX code, but it does not allow to bypass the tracing. The problem of emd2 is not with the gradient, it's with the fact that you have an operation that JAX can't trace in the middle. |
Note that We use the set_gradient in emd2 exactly to bypass the need for numpy arrays and it works for torch. This is why i tried to define a new function with @custom_jvp (and call it) in set gradients but i wasn't a good strategy. |
Types of changes
Motivation and context / Related issue
The
set_gradient
is possible in JAX.#277
How has this been tested (if it applies)
Added a modified unittest for JAX.
Checklist