Replies: 1 comment 1 reply
-
Thanks for the question. However, I think we fundamentally need to compute `f(x)` in the forward pass. For example, say we compose your `f` with some outer function `g`:

```python
def gf(x):
    return g(f(x))
```

Then the value-and-gradient is:

```python
def gf_vjp(x):
    # forward pass
    y, res = f_fwd(x)
    z, g_vjp = jax.vjp(g, y)  # !!!
    # backward pass
    v, = g_vjp(1.0)
    bwd_val, = f_bwd(res, v)
    return z, bwd_val
```

On the line marked `!!!`, we already need `y` (the value of `f(x)`) just to set up `g`'s VJP, before the backward pass can even begin. What do you think? Or did I misunderstand?
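To make this concrete, here is a runnable sketch with stand-in choices `f = sin` and `g = exp` (my assumptions, not from the thread), checking the hand-rolled composition above against `jax.grad`:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)   # stand-in for the expensive function

def g(y):
    return jnp.exp(y)   # stand-in outer function

def f_fwd(x):
    y = f(x)
    return y, (x,)      # residuals for the backward pass

def f_bwd(res, v):
    (x,) = res
    return (v * jnp.cos(x),)

def gf_vjp(x):
    # forward pass: y = f(x) must exist before jax.vjp(g, y) can run
    y, res = f_fwd(x)
    z, g_vjp = jax.vjp(g, y)
    # backward pass
    v, = g_vjp(1.0)
    bwd_val, = f_bwd(res, v)
    return z, bwd_val

# sanity check against autodiff of the plain composition
z, dz = gf_vjp(0.3)
assert jnp.allclose(dz, jax.grad(lambda t: g(f(t)))(0.3))
```

The point of the sketch is the ordering: `jax.vjp(g, y)` consumes `y` as a concrete value, so `f(x)` has to be evaluated before any backward work on `g` can start.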
-
I have an application where my function `f(x)` is very expensive and must be computed in both the forward and backward pass of the VJP. I could benefit from computing the return value of `f_fwd` in the `bwd` function instead, since that would let me exploit parallelism when computing the fwd and bwd values. However, I don't see any way to do this within the constraints of `jax.custom_vjp`. Does anyone have any clever ideas for avoiding the computation of `f(x)` in `f_fwd` and instead computing it in `f_bwd`, as in the contrived example below? Thanks!

A few notes:

- The `bwd` return value depends on the VJP input `v` in a non-trivial way and therefore can't be computed in `f_fwd`.

I'd appreciate it if anyone has any thoughts!
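The poster's contrived example didn't survive in this excerpt. As a hedged illustration only (stand-in functions and names, not the original code), the `jax.custom_vjp` pattern being described, where the expensive `f(x)` is forced to run inside `f_fwd` and its residuals are consumed by `f_bwd`, looks roughly like this:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)   # stand-in for the expensive computation

def f_fwd(x):
    y = jnp.sin(x)      # the expensive f(x), computed in the forward pass
    return y, (x,)      # residuals saved for the backward pass

def f_bwd(res, v):
    (x,) = res
    # the cotangent v enters non-trivially here,
    # so this value can't be precomputed in f_fwd
    return (v * jnp.cos(x),)

f.defvjp(f_fwd, f_bwd)
```

Under this API the forward rule always runs before the backward rule, which is exactly the constraint the question is asking how to work around.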