Equivalent of torch's retain_graph #22725
Thanks for the question! I think perhaps you just want `jax.vjp`:

```python
import jax
import jax.numpy as jnp

# Define the forward function
def forward(x):
    return jnp.sin(x) * jnp.cos(x)

# Compute the forward pass and store the result
x = jnp.ones(5)
y, f_vjp = jax.vjp(forward, x)

# Take the derivative with respect to the input
grad_outputs = jnp.ones_like(y)
grad1, = f_vjp(grad_outputs)

print("Forward output:", y)
print("Gradient 1:", grad1)

# Apply another function to the output
modified_y, g_vjp = jax.vjp(lambda y: y * 2, y)

# Evaluate the derivative of the function composition g . forward
grad2, = f_vjp(g_vjp(jnp.ones_like(modified_y))[0])

print("Modified forward output:", modified_y)
print("Gradient 2:", grad2)
```

Just think of the returned `f_vjp` as the retained graph: it closes over the residuals of the forward pass, so you can call it as many times as you like without rerunning `forward`. What do you think?
Sorry for the late reply. This led me to completely rethink (or maybe think through for the first time) how auto-diff works. As a follow-up to this issue I would like to point to #23180. Thank you for your help!
Hi all!
I've been having a blast learning JAX recently. However, coming from torch, there is one type of operation that is very easy to do there but that I can't wrap my head around how to do in JAX's functional paradigm.
I basically want to differentiate twice through the same forward pass. More specifically, I want to reuse the computation of a single forward pass for two distinct backward passes (not take second-order derivatives). This requirement is a deal-breaker for my research, where the forward pass takes a very long time (I would have to go back to torch, which I really don't want to do). With torch's computation graph and autograd's retain_graph, this is very easy to do.
One motivation is Truncated Backpropagation Through Time (TBPTT)-style scenarios, where we might enlarge the forward pass's computation graph slightly at every step but do not want to recompute the whole window. In torch, the computation graph lets me organize the code so that the two backward passes can live in completely different places (just passing the outputs around). Here is a very simplified version of what I mean in torch.
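(A toy sketch with made-up names; the important bit is that `retain_graph=True` lets me call `backward` twice through the same saved forward pass.)

```python
import torch

# Toy parameters and input, just to illustrate the pattern.
w = torch.ones(5, requires_grad=True)
x = torch.ones(5)

# One (expensive) forward pass.
y = torch.sin(x * w).sum()

# First backward pass; retain_graph=True keeps the saved intermediates alive.
y.backward(retain_graph=True)
grad1 = w.grad.clone()

# ...somewhere completely different in the code, with only y passed around...
w.grad = None
loss2 = 2 * y       # slightly extend the existing graph
loss2.backward()    # second backward pass, no recomputation of the forward
grad2 = w.grad
```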
EDIT: In the simple example above and in the case of TBPTT, this could be achieved by applying the chain rule by hand. But I'm wondering whether there is an automatic way to do it for arbitrarily complex functions (multi-output, multi-input). That would be much less error-prone and maybe more efficient (?).
I realize this goes against the whole philosophy of the framework as I understand it, because it would require maintaining state as a side effect.
I have two questions:
Thank you very much in advance!