Dear JAX community, I found that converting a JAX array to a NumPy array is slower after calling a JAX function. The slowdown persists even if that JAX function has nothing to do with the array conversion, and it scales with the computational complexity of the JAX function. Without calling the JAX function, everything is normal.
This issue appears on both CPU and GPU; here is an example for reproducing it. I would like to ask if there are any fixes.
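The original repro snippet is not preserved in this thread; the following is a hedged sketch of the kind of measurement being described, with `compute`, `in_data`, and the array shapes all being assumptions rather than the poster's actual code:

```python
import time

import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def compute(x):
    # Hypothetical stand-in for the poster's JAX function; it is
    # deliberately unrelated to the array conversion timed below.
    return jnp.sin(x) @ jnp.cos(x).T


in_data = jnp.ones((500, 500))
to_convert = jnp.arange(1000.0)

# Baseline: time the JAX-to-NumPy conversion on its own.
t0 = time.perf_counter()
arr = np.asarray(to_convert)
baseline = time.perf_counter() - t0

# Now call the (unrelated) JAX function, then time the same conversion again.
compute(in_data)  # returns immediately due to asynchronous dispatch
t0 = time.perf_counter()
arr = np.asarray(to_convert)
after_call = time.perf_counter() - t0

print(f"baseline: {baseline:.6f} s, after JAX call: {after_call:.6f} s")
```

On some backends the second conversion can appear slower, since the device-to-host transfer may queue behind the still-running `compute` work.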
Replies: 1 comment
I'm not able to reproduce this with JAX v0.4.33 on either a Colab CPU or GPU runtime. That said, I have a guess as to why this may be happening for you: I suspect your timings are being fooled by JAX's asynchronous dispatch: when you call a JAX operation, the Python function returns before the result is actually computed. So your first call to `compute` essentially just queues up the computations, which begin running in the background. By the time you get to your second call, the queue is full, and so the Python function must wait for the previous iterations to finish before it can enqueue its computations.

If I'm right, then adding this to your first line should make all runs take the same amount of time: `jax.block_until_ready(compute(in_data))` (This is assuming that …)

More broadly, if you're interested in microbenchmarks of JAX code, I'd suggest starting here to make sure you're measuring what you think you're measuring: FAQ – benchmarking JAX code.
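A minimal sketch of the suggested fix, assuming `compute` and `in_data` stand in for the poster's actual function and input (neither appears in the thread):

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def compute(x):
    # Guess at the shape of the original workload: some jitted computation.
    return (jnp.sin(x) @ jnp.cos(x).T).sum()


in_data = jnp.ones((500, 500))

compute(in_data)  # first call triggers tracing/compilation; exclude it

# Without blocking: this mostly measures dispatch time, because the call
# returns as soon as the work is enqueued on the device.
t0 = time.perf_counter()
compute(in_data)
dispatch_only = time.perf_counter() - t0

# With blocking: jax.block_until_ready waits for the result to actually
# be computed, so the timing reflects the real cost of the computation.
t0 = time.perf_counter()
result = jax.block_until_ready(compute(in_data))
full_compute = time.perf_counter() - t0

print(f"dispatch only: {dispatch_only:.6f} s, blocked: {full_compute:.6f} s")
```

With the blocking call placed on the first line of the timed region, every run pays for its own computation up front, so the runs should take roughly the same amount of time.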