Efficiently re-use vjp_funcs using jax loops #23180
misterguick asked this question in Q&A
Sorry for the lack of response here. The loop carry can only contain arrays, not functions, so you should avoid returning a function as part of your loop carry.
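A minimal sketch of what this advice can look like in practice, assuming a hypothetical recurrent step `step(W, h, x) = W @ tanh(h) + x` and a toy objective: for each t, the gradient of 0.5·||h_{t+1}||² with respect to h_{t-1}, i.e. backpropagating through two steps. Instead of carrying the previous step's `vjp_func`, the `jax.lax.scan` carry holds only the arrays needed to rebuild it, and `jax.vjp` is called again inside the loop body. The names `step` and `two_step_grads` are illustrative, not from the thread.

```python
import jax
import jax.numpy as jnp


def step(W, h, x):
    # Hypothetical recurrent step, a stand-in for the real step function.
    return W @ jnp.tanh(h) + x


@jax.jit
def two_step_grads(W, h0, xs):
    """For t = 1..T-1, return d(0.5 * ||h_{t+1}||^2) / d h_{t-1}.

    The carry holds only arrays: the previous step's inputs (h_{t-1}, x_{t-1})
    and the current state h_t. The previous step's vjp is rebuilt from those
    arrays inside the body instead of being carried as a function.
    """
    h1 = step(W, h0, xs[0])

    def body(carry, x):
        h_before, x_before, h_cur = carry
        # Forward one step; this vjp is used within this iteration only.
        h_new, vjp_new = jax.vjp(lambda h: step(W, h, x), h_cur)
        # The cotangent of 0.5 * ||h_new||^2 with respect to h_new is h_new.
        (ct_cur,) = vjp_new(h_new)
        # Rebuild the previous step's vjp from carried arrays (one extra forward).
        _, vjp_prev = jax.vjp(lambda h: step(W, h, x_before), h_before)
        (ct_before,) = vjp_prev(ct_cur)
        return (h_cur, x, h_new), ct_before

    _, grads = jax.lax.scan(body, (h0, xs[0], h1), xs[1:])
    return grads
```

The trade-off is one extra forward evaluation of the step per iteration in exchange for a carry whose type never changes. A longer lookback could, under the same assumptions, carry a fixed-size stacked window of step inputs.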
Hi all!
I'm writing an experiment that involves a variant of TBPTT (truncated backpropagation through time). In #22725 I learnt how to re-use a gradient computation. My issue is that I need to do this over long sequences inside a loop, which leads to exploding compilation times.
Here is a basic example where we cache some vjp_funcs to re-use them at a later iteration.
Compilation time grows dramatically for non-trivial step functions. I keep the step simple here, so compilation time doesn't explode, but it gives an idea.
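The compile-time growth is consistent with how tracing handles Python loops: under `jax.jit`, a Python `for` loop (for example one that appends a `vjp_func` to a list at each iteration) is unrolled, so the traced program grows with the sequence length, whereas `jax.lax.scan` traces its body once. A minimal sketch, with a hypothetical `step` function, that counts top-level jaxpr equations as a rough proxy for program size:

```python
import jax
import jax.numpy as jnp


def step(W, h, x):
    # Hypothetical recurrent step, a stand-in for the real step function.
    return W @ jnp.tanh(h) + x


def unrolled(W, h0, xs):
    # A Python loop is unrolled at trace time: one copy of `step` per element.
    h = h0
    for t in range(xs.shape[0]):
        h = step(W, h, xs[t])
    return h


def scanned(W, h0, xs):
    # lax.scan traces the body once, whatever the sequence length.
    h, _ = jax.lax.scan(lambda h, x: (step(W, h, x), None), h0, xs)
    return h


def num_eqns(fn, T, n=4):
    # Number of top-level equations in the traced program for length-T inputs.
    W, h0, xs = jnp.zeros((n, n)), jnp.zeros(n), jnp.zeros((T, n))
    return len(jax.make_jaxpr(fn)(W, h0, xs).jaxpr.eqns)
```

Comparing `num_eqns(unrolled, 5)` with `num_eqns(unrolled, 10)` shows the unrolled program growing with the length, while the `scanned` count stays the same.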
I see two solutions.
The first one is not very satisfactory, since it means we can never run this type of computation in a jitted context.
The second solution (minimal example below) leads to an error because the carry changes between iterations.
I understand that we shouldn't be able to change the function completely from one iteration to another, because the computation must be static. However, in this case it seems to me that each new vjp_func involves the same computation. It looks pretty pure to me: it is the same function, just with different numbers (for lack of a better term).
My question: is there absolutely no way to do this? I'm open to any idea, as long as it keeps the efficiency of caching the vjp and doesn't lead to exploding compilation times.
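When the backward pass of the step can be written by hand, one option (swapped in here, not something proposed in the thread) is to carry the "different numbers" themselves, i.e. the residuals, as arrays. The sketch below splits a hypothetical step `W @ tanh(h) + x` into a forward that returns residual arrays and a backward that uses only those residuals, mirroring the fwd/bwd convention of `jax.custom_vjp`. The toy objective is the gradient of 0.5·||h_{t+1}||² with respect to h_{t-1}; `step_fwd`, `step_bwd` and `two_step_grads` are illustrative names.

```python
import jax
import jax.numpy as jnp


def step_fwd(W, h, x):
    # Hypothetical step h_new = W @ tanh(h) + x, written custom_vjp-style:
    # return the output plus the residuals (plain arrays) the backward needs.
    a = jnp.tanh(h)
    return W @ a + x, (1.0 - a ** 2,)


def step_bwd(W, residuals, ct):
    # Pull a cotangent on h_new back to h, using only stored residuals.
    (dtanh,) = residuals
    return dtanh * (W.T @ ct)


@jax.jit
def two_step_grads(W, h0, xs):
    """For t = 1..T-1, return d(0.5 * ||h_{t+1}||^2) / d h_{t-1}.

    The previous step's "vjp" is stored as its residual arrays in the carry,
    so no forward recomputation is needed and the carry type never changes.
    """
    h1, res0 = step_fwd(W, h0, xs[0])

    def body(carry, x):
        res_prev, h_cur = carry
        h_new, res_cur = step_fwd(W, h_cur, x)
        # d/dh_cur of 0.5 * ||h_new||^2 (its cotangent on h_new is h_new).
        ct_cur = step_bwd(W, res_cur, h_new)
        # Continue through the previous step using its carried residuals.
        ct_before = step_bwd(W, res_prev, ct_cur)
        return (res_cur, h_new), ct_before

    _, grads = jax.lax.scan(body, (res0, h1), xs[1:])
    return grads
```

The cost of this design is that `step_bwd` has to be kept in sync with `step_fwd` manually; checking it against `jax.vjp` or `jax.grad` on random inputs is a cheap safeguard.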
Thank you very much in advance!