-
The docs say that JAX does async dispatch on the CPU. How does JAX get around the GIL to accomplish this? Is there a separate process executing the computation while the Python code runs ahead? In other frameworks (e.g. PyTorch), async dispatch only works for computation on an accelerator like a CUDA GPU, where operations can be queued up on a CUDA stream and functions just return a future. In PyTorch, however, CPU operations are always synchronous. Since JAX implements async dispatch for computation on the CPU, what stops PyTorch from doing the same? Thanks in advance!
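For readers unfamiliar with the future-based model the question refers to, here is a minimal pure-Python sketch of the idea (not JAX's or PyTorch's actual implementation, whose runtimes are in C++): an operation is submitted to a background executor, the caller immediately gets back a future, and blocking only happens when the result is actually demanded. The `async_dispatch` helper and `big_op` kernel below are hypothetical names for illustration.

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical single-threaded "device queue"; a real runtime would
# enqueue compiled kernels on a stream rather than Python callables.
_executor = ThreadPoolExecutor(max_workers=1)

def async_dispatch(op, *args):
    """Enqueue op on the background queue and return immediately with a future."""
    return _executor.submit(op, *args)

def big_op(n):
    # Stand-in for an expensive kernel.
    return sum(i * i for i in range(n))

fut = async_dispatch(big_op, 100_000)  # returns right away
# ... the Python thread can run ahead here, queuing up more work ...
result = fut.result()  # blocks only when the value is needed
```

The key property is that the cost of `big_op` is paid on the background thread, while the Python interpreter is free to keep dispatching.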
Replies: 2 comments
-
Now, I don't know exactly how this is implemented in JAX, but my guess is that it offloads work to some background thread that never acquires the GIL.

About PyTorch: there's no big technical reason why it doesn't do async dispatch on CPU, but one should understand that it's a tradeoff. The upside to async dispatch is obviously that when you have big ops, your Python thread can run ahead and you don't have to pay anything for Python interpretation. The downsides are twofold. First, async dispatch generally increases implementation complexity. Second, if you actually have lots of small ops, then async dispatch only adds overhead on top of the Python interpretation. People often run small workloads on the CPU and large ones on the GPU, so not using async dispatch in those situations is a viable option.
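The small-op overhead mentioned above can be sketched concretely. In this illustrative benchmark (pure stdlib, not either framework's dispatch path), every tiny op pays a submit/future round-trip through a background executor, whereas the synchronous version just calls the function; on small ops the dispatch machinery typically dominates the actual compute:

```python
import time
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

def tiny_op(x):
    # Stand-in for a cheap kernel where dispatch cost matters.
    return x + 1

# Synchronous dispatch: just call the op directly.
t0 = time.perf_counter()
sync_result = 0
for _ in range(1000):
    sync_result = tiny_op(sync_result)
sync_time = time.perf_counter() - t0

# "Async" dispatch: each tiny op pays submit + future overhead.
t0 = time.perf_counter()
async_result = 0
for _ in range(1000):
    async_result = executor.submit(tiny_op, async_result).result()
async_time = time.perf_counter() - t0

assert sync_result == async_result == 1000
# On ops this small, async_time is normally much larger than sync_time.
```

This is the tradeoff in miniature: the executor only pays off when the enqueued work is big enough to overlap usefully with the Python thread running ahead.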
-
It doesn't work on CPU, at least for cheap computations (see #9895). The example in the documentation is actually running on GPU (the output of …