Is there any way to profile whether there is any unintended jit re-compilation happening? #24845
Hello all! I have some rather complex code that I am trying to make fully compatible with jit end to end (i.e. all the code lives inside a single function that I can then just call jit on). When I run this code, is it possible to profile it in some way so that I can see whether everything has been compiled as I intended, or whether I have written code in places that prompts jit to recompile unnecessarily and slow things down? Thank you in advance!
You can use `jax.log_compiles` to log when and why recompilations are happening. For example:

```
In [1]: import jax

In [2]: @jax.jit
   ...: def f(x):
   ...:     return x ** 2
   ...:

In [3]: with jax.log_compiles():
   ...:     f(1.0)
   ...:     f(1)
   ...:
WARNING:2024-11-11 15:44:37,372:jax._src.dispatch:181: Finished tracing + transforming f for pjit in 0.000458002 sec
WARNING:2024-11-11 15:44:37,411:jax._src.interpreters.pxla:1903: Compiling f with global shapes and types [ShapedArray(float32[], weak_type=True)]. Argument mapping: (UnspecifiedValue,).
WARNING:2024-11-11 15:44:37,421:jax._src.dispatch:181: Finished jaxpr to MLIR module conversion jit(f) in 0.010539055 sec
WARNING:2024-11-11 15:44:37,429:jax._src.dispatch:181: Finished XLA compilation of jit(f) in 0.006999016 sec
WARNING:2024-11-11 15:44:37,430:jax._src.dispatch:181: Finished tracing + transforming f for pjit in 0.000257015 sec
WARNING:2024-11-11 15:44:37,430:jax._src.interpreters.pxla:1903: Compiling f with global shapes and types [ShapedArray(int32[], weak_type=True)]. Argument mapping: (UnspecifiedValue,).
WARNING:2024-11-11 15:44:37,432:jax._src.dispatch:181: Finished jaxpr to MLIR module conversion jit(f) in 0.001645803 sec
WARNING:2024-11-11 15:44:37,439:jax._src.dispatch:181: Finished XLA compilation of jit(f) in 0.007171869 sec
```

This tells you that `f` was traced and compiled twice: once for a `float32` scalar argument (from `f(1.0)`) and once for an `int32` scalar argument (from `f(1)`), because the change in input dtype forced a recompilation.
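If you would rather check for retracing programmatically (for example in a unit test) than read log output, one common trick relies on the fact that Python side effects inside a jitted function only run while JAX is tracing it, not on cached calls. The sketch below counts traces with hypothetical global counters (`trace_count`, `g_traces`); it counts traces rather than XLA compilations, but in the log above each new trace is followed by a compile. The second half shows one way to avoid the `float32`/`int32` recompilation: casting inputs to a fixed dtype before calling the jitted function.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented only when JAX (re)traces `f`.
trace_count = 0


@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, skipped on cache hits
    return x ** 2


f(1.0)          # weak-typed float32 scalar: first trace
f(2.0)          # same shape and dtype as before: cache hit, no retrace
f(1)            # weak-typed int32 scalar: different dtype, retraces
f(jnp.ones(3))  # float32[3]: different shape, retraces
print(trace_count)  # 3

# Hypothetical second counter for a function whose inputs are normalized.
g_traces = 0


@jax.jit
def g(x):
    global g_traces
    g_traces += 1
    return x ** 2


for v in (1.0, 1, 2):
    # Casting to an explicit dtype gives every call the same abstract input
    # (a strongly typed float32 scalar), so `g` is traced only once.
    g(jnp.asarray(v, dtype=jnp.float32))
print(g_traces)  # 1
```

The same pattern exposes other sources of retracing, such as passing a different value for an argument listed in `static_argnums` on every call. If you want compilation logging for the whole program rather than inside a `with` block, `jax.config.update("jax_log_compiles", True)` turns on the same log messages globally.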