Hi, I have a complicated function that I am taking a VJP over, and the resulting product ends up being quite large. To debug this, I would like to inspect the compute graph of the reverse-mode autodiff procedure (with the individual partials) and understand why this happens. What are the best practices for doing this in JAX? Is there a toolkit for logging these kinds of quantities beyond defining [...]?

Thank you!
Replies: 1 comment
I'm not sure that I have a great answer to the broad question here, but there are some options. One place to start is to try printing the jaxpr of your reverse pass:

    print(jax.make_jaxpr(jax.vjp(fun, *args)[1])(*ct))

You can also combine this with `jax.named_scope` to get a little bit more metadata. For example:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def fun(x):
        return f1(x) * f2(x)

    @jax.jit
    @jax.named_scope("f1")
    def f1(x):
        return jnp.sin(x)

    @jax.jit
    @jax.named_scope("f2")
    def f2(x):
        return jnp.exp(0.5 * x)

    # Trace the reverse pass (the VJP function) with a cotangent of ones.
    print(jax.make_jaxpr(jax.vjp(fun, jnp.ones(5))[1])(jnp.ones(5)))

Prints:

[...]
Printing the jaxpr doesn't really help with logging the specific values, but perhaps it can help identify the problematic parts of the computation. In terms of actually logging or otherwise intercepting the reverse-mode computations, I think that [...]

Perhaps you can see how far something like this gets you, but please feel free to post a more specific example function and more details about what you're looking for if it would be useful to discuss further.
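
As a rough sketch of one way to log values during the reverse pass: wrap an intermediate value in an identity function that has a `jax.custom_vjp` rule, and call `jax.debug.print` inside the backward function. This is an assumption about what might be useful here, not necessarily the approach the reply above had in mind. The helper name `log_cotangent` and the labels passed to it are hypothetical, not part of JAX.

```python
import jax
import jax.numpy as jnp


def log_cotangent(name):
    """Return an identity function whose backward pass prints its cotangent norm.

    Hypothetical helper for illustration; not a JAX API.
    """

    @jax.custom_vjp
    def tap(x):
        return x  # forward pass: identity

    def tap_fwd(x):
        return x, None  # no residuals are needed for the backward pass

    def tap_bwd(_, ct):
        # Log the norm of the cotangent flowing back through this point.
        jax.debug.print(name + ": |ct| = {n}", n=jnp.linalg.norm(ct))
        return (ct,)  # pass the cotangent through unchanged

    tap.defvjp(tap_fwd, tap_bwd)
    return tap


def f1(x):
    return jnp.sin(x)


def f2(x):
    return jnp.exp(0.5 * x)


def fun(x):
    # Tag the two intermediate values whose cotangents we want to see.
    a = log_cotangent("f1 output")(f1(x))
    b = log_cotangent("f2 output")(f2(x))
    return a * b


x = jnp.ones(5)
out, vjp_fn = jax.vjp(fun, x)
(grad_x,) = vjp_fn(jnp.ones_like(out))  # triggers the debug prints
```

Because each tap is an identity in both directions, the gradient itself should be unchanged. Only the logging is added, so taps can be moved around to narrow down where the cotangents grow.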