-
It's not clear to me what you mean when you say "jit outside the for loop". Do you mean something like this?

    compute_ccv = jax.jit(vmap(vmap(vmap(vmap(u(V,W,X,Y,Z, vf_array)...)
    for period in reversed(range(periods)):
        ccv = compute_ccv(V,W,X,Y,Z, vf_array)
        vf_array = ccv.max(axis=(0,1,2))

Or do you mean something like this?

    @jax.jit
    def myfunc():
        for period in reversed(range(periods)):
            compute_ccv = vmap(vmap(vmap(vmap(vmap(u(V,W,X,Y,Z, vf_array)...)  # Nested vmap with the vectors that span the grid
            ccv = compute_ccv(V,W,X,Y,Z, vf_array)
            vf_array = ccv.max(axis=(0,1,2))

And for "jit inside the for loop", do you mean something like this?

    for period in reversed(range(periods)):
        compute_ccv = vmap(vmap(vmap(vmap(vmap(u(V,W,X,Y,Z, vf_array)...)  # Nested vmap with the vectors that span the grid
        ccv = jax.jit(compute_ccv)(V,W,X,Y,Z, vf_array)
        vf_array = ccv.max(axis=(0,1,2))
-
I am working on a project that solves user-specified economic models on the GPU. To do that, I need a for loop that evaluates a function on a grid and then calculates a maximum along the axes of the grid. The maxima found are then the input for the next iteration of the loop. I can't really provide a working example of the code, but it generally looks like this:
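A minimal sketch of such a loop, under stated assumptions: the utility function `u`, the helper `make_compute_ccv`, the grid sizes, and the number of periods below are all hypothetical placeholders, not the code from the actual project. It only illustrates the pattern described above (nested `vmap` over five grid vectors, then a maximum over the first three axes that feeds the next iteration).

```python
import jax
import jax.numpy as jnp
from jax import vmap


# Hypothetical utility function (an assumption, not the project's model):
# takes one scalar per grid dimension and looks up a continuation value
# in vf_array, which is indexed by the last two grid dimensions.
def u(v, w, x, y, z, vf_array):
    return -(v - w) ** 2 - 0.1 * x + 0.9 * vf_array[y, z]


def make_compute_ccv(fn):
    # Nest five vmaps, one per grid vector. The innermost vmap maps over Z,
    # the outermost over V, so the result has shape [V, W, X, Y, Z].
    f = vmap(fn, in_axes=(None, None, None, None, 0, None))  # over Z
    f = vmap(f, in_axes=(None, None, None, 0, None, None))   # over Y
    f = vmap(f, in_axes=(None, None, 0, None, None, None))   # over X
    f = vmap(f, in_axes=(None, 0, None, None, None, None))   # over W
    f = vmap(f, in_axes=(0, None, None, None, None, None))   # over V
    return f


# Small toy grids; Y and Z are integer indices into vf_array.
V = jnp.linspace(0.0, 1.0, 4)
W = jnp.linspace(0.0, 1.0, 3)
X = jnp.linspace(0.0, 1.0, 2)
Y = jnp.arange(5)
Z = jnp.arange(6)
periods = 3


def solve(compute_ccv, vf_array):
    # Backward iteration: each period's maxima become the next input.
    ccv = None
    for period in reversed(range(periods)):
        ccv = compute_ccv(V, W, X, Y, Z, vf_array)  # shape [V, W, X, Y, Z]
        vf_array = ccv.max(axis=(0, 1, 2))          # shape [Y, Z]
    return ccv, vf_array


# Here jit wraps only the grid evaluation; the Python loop stays outside.
compute_ccv = jax.jit(make_compute_ccv(u))
ccv, vf_array = solve(compute_ccv, jnp.zeros((Y.shape[0], Z.shape[0])))
```

In this sketch the reduction `ccv.max(...)` runs outside the jitted function, so the full five-dimensional array is returned from `compute_ccv` before it is reduced.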
When I don't use jit, the nested vmap creates a big array in memory of shape [VxWxXxYxZ], as I would expect.
When I use jit inside the for loop, like this, the memory consumption is suddenly very low; some kind of optimization seems to be happening.
When I put the jit outside the for loop, the memory consumption is again the same as when I don't jit anything.
Because I was interested in what optimization the XLA compiler does, I started looking at the output of compute_ccv.lower().compile().as_text(). In both cases, jit inside the for loop and jit outside the for loop, the compiled code seems to work with the huge array I would expect. Here are some snippets of the output.
Jit outside of the for loop:
Jit inside of the for loop:
I know it's probably not possible to help me with my specific problem, as I can't provide a short working example. But maybe someone could tell me whether the output from compute_ccv.lower().compile().as_text() is actually the fully optimized code, or why the huge array is only allocated on the GPU in one case even though both compiled functions seem to work with it?
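For reference, a minimal sketch of how the text before and after XLA's optimization passes can be obtained through JAX's ahead-of-time lowering API. The function `grid_max` and its inputs are hypothetical stand-ins for the real `compute_ccv`; note that `.lower(...)` needs example arguments (or matching shape/dtype placeholders) to trace with.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for compute_ccv: builds a full 2-D grid of values
# and then reduces it along the first axis.
def grid_max(x, y):
    grid = jnp.sin(x)[:, None] * jnp.cos(y)[None, :]  # shape [len(x), len(y)]
    return grid.max(axis=0)                           # shape [len(y)]


x = jnp.linspace(0.0, 1.0, 8)
y = jnp.linspace(0.0, 1.0, 16)

lowered = jax.jit(grid_max).lower(x, y)  # traced and lowered, not yet optimized
compiled = lowered.compile()             # compiled by XLA for the current backend

pre_opt_text = lowered.as_text()    # IR handed to XLA, before optimization
post_opt_text = compiled.as_text()  # HLO module after XLA's optimization passes

result = compiled(x, y)             # the compiled object is directly callable
```

Comparing `pre_opt_text` with `post_opt_text` shows what the compiler changed. As far as I understand, a large shape that appears inside a fused computation in the optimized text describes an intermediate value and does not by itself mean a separate buffer of that size is allocated, so the text alone may not settle the memory question.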