
Equivalent of torch's retain_graph #22725

Closed

misterguick opened this issue Jul 29, 2024 · 3 comments

Labels: question (Questions for the JAX team)

misterguick commented Jul 29, 2024

Hi all!

I've been having a blast learning JAX recently. However, coming from torch, there is one type of operation that is very easy to do in torch but that I can't wrap my head around how to do in JAX's functional paradigm.

I basically want to differentiate twice through the same forward pass. More specifically, I want to reuse the computation of a forward pass for two distinct backward passes (not to take second-order derivatives). This requirement is a deal-breaker for my research, where the forward pass takes a very long time (I would have to go back to torch, which I really don't want to do). With torch's computation graph and autograd's retain_graph this is very easy to do.

One motivation is Truncated Backpropagation Through Time (TBPTT)-style scenarios, where we might enlarge the forward pass's computation graph slightly at every step but do not want to recompute the whole window. In torch, the computation graph lets you organize the code so that the two backward passes can live in completely different locations (you just pass the outputs around). Here is a very simplified version of what I mean in torch.

EDIT: In the simple example below, and in the case of TBPTT, this could be achieved by applying the chain rule by hand (see the small check after the code below). But I'm wondering if there is an automatic way to do it for arbitrarily complex functions (multi-output, multi-input). That would be much less error-prone and maybe more efficient (?)

import torch

# Define the forward function
def forward(x):
    return torch.sin(x) * torch.cos(x)

# Compute the forward pass and store the result
x = torch.ones(5, requires_grad=True)
y = forward(x)

# Take the derivative with respect to the input
grad1 = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=torch.ones_like(y), create_graph=True)[0]

print("Forward output:", y)
print("Gradient 1:", grad1)

# Perform some modification on the output
modified_y = y * 2

# Take the derivative again with respect to the input using the modified output
grad2 = torch.autograd.grad(outputs=modified_y, inputs=x, grad_outputs=torch.ones_like(modified_y))[0]

print("Modified forward output:", modified_y)
print("Gradient 2:", grad2)

I understand that this goes against the framework's functional philosophy as I understand it, since it would require keeping state around as a side effect.

I have two questions:

  1. Is there a function transform in the API that wraps the forward pass and returns something like a tuple of (the output, a reusable representation of the computation), where the second element could be used like a virtual function call (same output, same dependencies, but not re-executed explicitly)?
  2. Is this simply doable in a different way that is less obvious to someone coming from torch?

Thank you very much in advance !

misterguick added the enhancement (New feature or request) label on Jul 29, 2024
@PhilipVinc (Contributor)

@mattjj (Collaborator) commented Jul 29, 2024

Is there a function transform in the API that wraps the forward pass and returns something like a tuple of (the output, a reusable representation of the computation), where the second element could be used like a virtual function call (same output, same dependencies, but not re-executed explicitly)?

Thanks for the question!

I think perhaps you just want jax.vjp:

import jax
import jax.numpy as jnp

# Define the forward function
def forward(x):
    return jnp.sin(x) * jnp.cos(x)

# Compute the forward pass and store the result
x = jnp.ones(5)
y, f_vjp = jax.vjp(forward, x)

# Take the derivative with respect to the input
grad_outputs = jnp.ones_like(y)
grad1, = f_vjp(grad_outputs)

print("Forward output:", y)
print("Gradient 1:", grad1)

# Apply another function to the output
modified_y, g_vjp = jax.vjp(lambda y: y * 2, y)

# Evaluate the derivative of the function composition g . forward
grad2, = f_vjp(g_vjp(jnp.ones_like(modified_y))[0])

print("Modified forward output:", modified_y)
print("Gradient 2:", grad2)

Just think of the vjp functions as pulling back gradients from output space to input space. In other words, they're "backward-pass functions". For more, see the autodiff cookbook if you haven't already.
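
To connect this to the TBPTT-style motivation: f_vjp is an ordinary closure over the residuals saved during the forward pass, so you can pass it around and call it again from a completely different place in the code, much like retaining the graph in torch. A rough sketch of that pattern (the function names and the two cotangents here are purely illustrative):

import jax
import jax.numpy as jnp

def expensive_forward(x):
    # Stands in for a forward pass that is too slow to run twice.
    return jnp.sin(x) * jnp.cos(x)

def run_forward(x):
    # Returns the output together with the "retained graph": the vjp closure.
    return jax.vjp(expensive_forward, x)

def backward_elsewhere(f_vjp, cotangent):
    # Can live in a completely different part of the code base; calling
    # f_vjp replays only the backward pass, not the forward computation.
    grad, = f_vjp(cotangent)
    return grad

x = jnp.ones(5)
y, f_vjp = run_forward(x)
grad_a = backward_elsewhere(f_vjp, jnp.ones_like(y))        # first backward pass
grad_b = backward_elsewhere(f_vjp, 2.0 * jnp.ones_like(y))  # second backward pass, no re-forward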

What do you think?

mattjj self-assigned this on Jul 29, 2024
mattjj added the question (Questions for the JAX team) label and removed the enhancement (New feature or request) label on Jul 29, 2024
@misterguick (Author)

Sorry for the late reply. This led me to completely rethink (or think about for the first time, maybe) how autodiff works. As a follow-up to this issue, I would like to point to #23180.

Thank you for your help!
