Hello,

I have a question related to my research. I need to compute the gradients of individual layers, not of the whole model. I saw that jax.grad can only be applied to functions with scalar output (i.e. a whole model with a loss function at the end). However, the computation of this gradient must internally create all the objects needed to build gradients for functions with tensor output (the layers). How difficult would it be to extract them and build, say, the gradient of a convolution? Maybe there are internal functions that build what I need, which are not exposed in the public interface but which I could use? Any help would be highly appreciated.

Edit: would jax.jacfwd and jax.jacrev be what I'm looking for?

Edit 2: I just tried jax.jacrev on a (small) convolution and it fails: it seems to try to allocate the whole Jacobian matrix, which can't work even for small convolutions. What I really need is the function that computes gradients on the inputs from gradients on the outputs.
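To make the setup concrete, here is a minimal sketch of what I mean (the conv layer and the shapes are just an illustrative example):

```python
import jax
import jax.numpy as jnp

# A small 2-D convolution "layer" (NCHW input, OIHW kernel, the
# jax.lax.conv defaults); purely illustrative.
def conv(x, w):
    return jax.lax.conv(x, w, window_strides=(1, 1), padding="SAME")

key = jax.random.PRNGKey(0)
kx, kw = jax.random.split(key)
x = jax.random.normal(kx, (1, 3, 32, 32))  # batch, channels, H, W
w = jax.random.normal(kw, (8, 3, 3, 3))    # out_ch, in_ch, kH, kW

# jax.grad only applies once a loss reduces everything to a scalar:
loss = lambda w, x: jnp.sum(conv(x, w) ** 2)
dw = jax.grad(loss)(w, x)
print(dw.shape)  # (8, 3, 3, 3)

# By contrast, jax.jacrev(conv)(x, w) would materialize the full Jacobian
# with respect to x, of shape out.shape + x.shape =
# (1, 8, 32, 32, 1, 3, 32, 32): roughly 25 million entries even for
# this tiny layer, hence the allocation failure.
```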
Best regards,
Dumitru

Replies: 2 comments

-
I'm not sure what exactly you're looking for here. Can you put together a short example of what you want to compute?
-
I suspect you're looking for jax.vjp.
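jax.vjp(f, *primals) returns the forward output together with a function that maps a cotangent (a gradient with respect to the output) to cotangents on the inputs, without ever materializing the Jacobian. A minimal sketch on a small convolution (the layer and its shapes are just illustrative):

```python
import jax
import jax.numpy as jnp

# An illustrative conv layer (NCHW / OIHW, the jax.lax.conv defaults).
def conv(x, w):
    return jax.lax.conv(x, w, window_strides=(1, 1), padding="SAME")

key = jax.random.PRNGKey(0)
kx, kw = jax.random.split(key)
x = jax.random.normal(kx, (1, 3, 32, 32))
w = jax.random.normal(kw, (8, 3, 3, 3))

# jax.vjp gives the forward output plus a function mapping a cotangent on
# the output to cotangents on the inputs; no Jacobian is ever built.
out, vjp_fn = jax.vjp(conv, x, w)

g_out = jnp.ones_like(out)   # gradient arriving from the layer above
g_x, g_w = vjp_fn(g_out)     # gradients on the layer's input and weights

print(g_x.shape, g_w.shape)  # (1, 3, 32, 32) (8, 3, 3, 3)
```

If you only want the gradient on the activations, close over the weights, e.g. out, vjp_fn = jax.vjp(lambda x: conv(x, w), x); vjp_fn(g_out) then returns a one-element tuple (g_x,).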