- Can you put together a reproducible example of what you're seeing?
- I am trying to create a custom JVP rule for a Newton-Raphson root-finding algorithm. Since I use linear solvers and a host of other external stuff inside newton_solver_fn(), I do the following:
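A minimal sketch of the setup being described, assuming a parameterized residual fn(x, theta) whose root is found with a while_loop-based Newton iteration. Only the names newton_solver_fn and _solve come from the post; the residual, the tolerances, and the damped-Richardson inner loop are placeholders for whatever external solver is actually used:

```python
import jax
import jax.numpy as jnp
from functools import partial


def _solve(matvec, b):
    # Stand-in for the external iterative linear solver: a lax.while_loop
    # with a data-dependent stopping criterion. The damped Richardson update
    # is a placeholder for a real CG/GMRES-style solver.
    def cond(state):
        _, r, k = state
        return (jnp.linalg.norm(r) > 1e-6) & (k < 1000)

    def body(state):
        x, r, k = state
        x_new = x + 0.1 * r
        return x_new, b - matvec(x_new), k + 1

    x0 = jnp.zeros_like(b)
    x, _, _ = jax.lax.while_loop(cond, body, (x0, b - matvec(x0), 0))
    return x


# nondiff_argnums keeps the residual function itself out of differentiation;
# the point of custom_jvp is to stop JAX from looking inside the solver body.
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def newton_solver_fn(fn, theta, x0):
    # Newton-Raphson: find x with fn(x, theta) = 0.
    def cond(state):
        _, err, k = state
        return (err > 1e-5) & (k < 20)

    def body(state):
        x, _, k = state
        J = jax.jacobian(fn)(x, theta)                    # d fn / d x
        x_new = x - _solve(lambda v: J @ v, fn(x, theta))
        return x_new, jnp.linalg.norm(fn(x_new, theta)), k + 1

    init = (x0, jnp.linalg.norm(fn(x0, theta)), 0)
    x_star, _, _ = jax.lax.while_loop(cond, body, init)
    return x_star


@newton_solver_fn.defjvp
def newton_solver_fn_jvp(fn, primals, tangents):
    theta, x0 = primals
    dtheta, _ = tangents           # the root does not depend on x0
    x_star = newton_solver_fn(fn, theta, x0)
    # Implicit function theorem: J dx* = -(d fn / d theta) dtheta, so the
    # incoming tangent becomes (part of) the RHS handed to _solve().
    J = jax.jacobian(fn)(x_star, theta)
    _, rhs = jax.jvp(lambda t: fn(x_star, t), (theta,), (dtheta,))
    dx_star = _solve(lambda v: J @ v, -rhs)
    return x_star, dx_star
```

The intent is that custom_jvp hides both while_loops from autodiff, with the JVP rule solving the tangent system given by the implicit function theorem.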
As per my understanding, JAX should not look inside _solve(). But when I try jax.grad(), it fails with:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values

which I believe is coming from inside _solve(). Why is this happening? Is there something that I am missing?
P.S. The error only happens when I create the JVP (which is an abstract tracer) and use that as the RHS. If I simply do it with some other RHS vector, the code works (of course, the gradient is then wrong).
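A possible explanation, offered as a hypothesis rather than anything confirmed in this thread: jax.grad is built as a JVP followed by transposition. So while custom_jvp does hide the primal newton_solver_fn, the JVP rule itself still gets linearized and then transposed, and every operation applied to the tangent, including the while_loop inside _solve acting on that RHS, needs a transpose rule; a while_loop with a data-dependent trip count has none. That would match jvp working while grad fails. Continuing the sketch above (fn, theta, and x0 are made-up toy values), with one candidate workaround via lax.custom_linear_solve, which carries its own transpose rule:

```python
fn = lambda x, theta: x**3 - theta               # toy residual (assumption)
theta = jnp.array([2.0])
x0 = jnp.array([1.0])

# Forward mode is fine: the tangent while_loop only ever runs forwards.
_, dx = jax.jvp(lambda t: newton_solver_fn(fn, t, x0),
                (theta,), (jnp.ones_like(theta),))

# Reverse mode fails at transpose time with the ValueError quoted above:
# jax.grad(lambda t: newton_solver_fn(fn, t, x0).sum())(theta)

# Candidate fix: wrap the tangent-side solve so it carries a transpose rule.
# symmetric=True reuses `solve` for the transposed system; a non-symmetric
# Jacobian needs an explicit transpose_solve instead.
def _solve_transposable(J, b):
    return jax.lax.custom_linear_solve(
        lambda v: J @ v, b,
        solve=lambda matvec, rhs: _solve(matvec, rhs),
        symmetric=True)
```

Swapping _solve for _solve_transposable in the JVP rule, or defining the rule with jax.custom_vjp instead of jax.custom_jvp, are the two usual escape hatches; which one fits depends on what the real _solve does.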