Is there any way to profile whether there is any unintended jit re-compilation happening? #24845

Answered by jakevdp
Chulabhaya asked this question in Q&A

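You can use the jax.log_compiles() context manager to log every compilation. Each message includes the argument shapes and dtypes, which usually tells you why a recompilation was triggered. For example: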

In [1]: import jax

In [2]: @jax.jit
   ...: def f(x):
   ...:     return x ** 2
   ...: 

In [3]: with jax.log_compiles():
   ...:     f(1.0)
   ...:     f(1)
   ...: 
WARNING:2024-11-11 15:44:37,372:jax._src.dispatch:181: Finished tracing + transforming f for pjit in 0.000458002 sec
WARNING:2024-11-11 15:44:37,411:jax._src.interpreters.pxla:1903: Compiling f with global shapes and types [ShapedArray(float32[], weak_type=True)]. Argument mapping: (UnspecifiedValue,).
WARNING:2024-11-11 15:44:37,421:jax._src.dispatch:181: Finished jaxpr to MLIR module conversion jit(f) in 0.010539055 sec
WARNING:2024-
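
If you want a check you can assert on in a test instead of reading logs, one common trick is to count traces with a Python side effect. Python code inside a jitted function only runs while JAX traces it, which happens on a jit cache miss, so a counter that goes up means the function was retraced and, in practice, recompiled. The sketch below assumes that proxy is good enough for your purposes; `count_traces` is a hypothetical helper, not part of the JAX API.

```python
import jax
import jax.numpy as jnp


def count_traces(fn):
    """Hypothetical helper (not a JAX API): jit `fn` and count its traces."""
    counter = {"traces": 0}

    def wrapped(*args, **kwargs):
        # This Python side effect runs only while JAX traces the function,
        # i.e. on a jit cache miss. Cached calls skip it entirely.
        counter["traces"] += 1
        return fn(*args, **kwargs)

    return jax.jit(wrapped), counter


f, counter = count_traces(lambda x: x ** 2)

f(1.0)           # first call: weakly typed float32 scalar, traced
f(2.0)           # same shape and dtype, cache hit (values don't matter)
f(1)             # Python int -> weakly typed int32, retraced
f(jnp.ones(3))   # new shape float32[3], retraced
print(counter["traces"])  # prints 3
```

Only the shapes, dtypes and weak-type flags of the arguments (plus any static arguments) are part of the cache key, which is why f(2.0) reuses the trace from f(1.0) while f(1) does not. Wrapping the same calls in jax.log_compiles() shows the corresponding compile messages.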

Answer selected by Chulabhaya