Does JAX use cache when JIT-compiling nested functions? Consider this example:

    import jax
    import jax.numpy as jnp

    def outer(a: jax.Array):
        def inner(x: jax.Array):
            print("compiling")
            return x * 2
        # print(f"id(inner) = {id(inner)}")
        # print(f"hash(inner) = {hash(inner)}")
        jitted_inner = jax.jit(inner)
        return jitted_inner(a)

    a = jax.random.normal(jax.random.key(0), (3, 4))
    outer(a)

When calling [...] Some sources say that [...] If Python re-defines nested functions on each call, how do I make JIT cache them?

Full output after 3 calls:
System info:
Replies: 1 comment 10 replies
As you've identified here, since `inner` is re-defined on each call to `outer`, every call hands `jax.jit` a new function object and the cache is missed. You could instead define `inner` at module level:

    def inner(x: jax.Array):
        print("compiling")
        return x * 2

    def outer(a: jax.Array):
        jitted_inner = jax.jit(inner)
        return jitted_inner(a)

which would work as you intend. To be completely explicit I would probably also move the `jit` onto `inner` itself:

    @jax.jit
    def inner(x: jax.Array):
        print("compiling")
        return x * 2

    def outer(a: jax.Array):
        return inner(a)

but the former seems to do the trick as well. In this simple example, I don't see any reason why you wouldn't want to refactor like this, but I'm not sure how easily this generalizes to your case. Either way, I hope it helps!
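
To check the difference between the nested and module-level definitions without relying on printed output, one option is to count how often each function body is traced. This is only a minimal sketch: the names `trace_counts`, `outer_nested`, `inner_module_level`, and `outer_module_level` are made up for illustration and do not appear in the thread. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it.

```python
import jax
import jax.numpy as jnp

# Hypothetical counters, incremented only when JAX traces a function body.
trace_counts = {"nested": 0, "module_level": 0}


def outer_nested(a):
    # `inner` is a brand-new function object on every call to `outer_nested`,
    # so jax.jit cannot find an earlier cache entry for it.
    def inner(x):
        trace_counts["nested"] += 1
        return x * 2

    return jax.jit(inner)(a)


def inner_module_level(x):
    trace_counts["module_level"] += 1
    return x * 2


def outer_module_level(a):
    # Re-wrapping the same function object lets JAX reuse the cached trace.
    return jax.jit(inner_module_level)(a)


a = jnp.ones((3, 4))
for _ in range(3):
    outer_nested(a)
    outer_module_level(a)

print(trace_counts)
```

With the same input shape and dtype on every call, the module-level variant should be traced once, while the nested variant is traced again on each call.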
That's what I expected! Yeah, like you say, the usual advice here would be to move the `jit` as high up the stack as you can. For example, in the flax examples that you link to, the `jit` is applied to the training step, in which case `inner` is only compiled once!

But there are cases where this won't necessarily work (e.g. long compile times, etc.). In that case, maybe you could try converting the closure into a compiled function (at the global level) which takes the relevant parameters as static arguments, which should also lead to a cache hit.
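
Both suggestions can be sketched roughly as follows. This is an illustration under assumptions, not code from the thread: `counts`, `outer_jitted`, `scaled`, `outer_static`, and the `scale` parameter are hypothetical names, with `scale` standing in for whatever value the closure would otherwise capture. Static arguments must be hashable, and each distinct static value triggers its own trace.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical counters, incremented only when JAX traces a function body.
counts = {"outer_jit": 0, "static": 0}


# Option 1: apply jit to the outermost function. The nested `inner` is
# inlined into the trace of `outer_jitted`, so it is traced together with it.
@jax.jit
def outer_jitted(a):
    def inner(x):
        counts["outer_jit"] += 1
        return x * 2

    return inner(a)


# Option 2: a module-level jitted function. The value a closure would have
# captured (the hypothetical `scale`) is passed as a static argument instead.
@partial(jax.jit, static_argnames="scale")
def scaled(x, scale):
    counts["static"] += 1
    return x * scale


def outer_static(a, scale):
    return scaled(a, scale=scale)


a = jnp.ones((3, 4))
for _ in range(3):
    outer_jitted(a)
    outer_static(a, scale=2)

# A new static value is a new cache key, so this call traces once more.
outer_static(a, scale=3)

print(counts)
```

The trade-off with static arguments is that the cache grows with the number of distinct static values, so this fits best when the captured parameters take only a few values.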