-
Here's the code. It seems like JAX isn't using multithreading (not truly parallel behavior, apparently); I just watched the 'top' command on Linux. JAX seems to be about 16x slower than NumPy on my ~16-thread machine when "small" 100x100 matrices are multiplied, and that's not really a tiny computation! Since JAX dispatches asynchronously, shouldn't it be using multiple threads for this, just like NumPy does? The results are the same whether JAX_CPU = True or False. Can somebody please glance at my code and explain what's going on? This is the output:
Thank you.
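(The original code and output aren't shown in the thread. As a hedged sketch of the kind of benchmark being described, note that JAX's asynchronous dispatch means a fair timing must call block_until_ready() on the result; otherwise you measure only dispatch overhead. The sizes and repeat counts below are illustrative, not the poster's actual values.)

```python
import time
import numpy as np
import jax.numpy as jnp

n = 100  # the "small" matrix size discussed in the question
a_np = np.random.rand(n, n).astype(np.float32)
a_jnp = jnp.asarray(a_np)

# NumPy is synchronous: the result is ready when @ returns.
t0 = time.perf_counter()
for _ in range(100):
    r_np = a_np @ a_np
np_time = time.perf_counter() - t0

# JAX dispatches asynchronously: block_until_ready() forces the
# computation to finish before the clock is read.
t0 = time.perf_counter()
for _ in range(100):
    r_jnp = (a_jnp @ a_jnp).block_until_ready()
jax_time = time.perf_counter() - t0

print(f"NumPy: {np_time:.4f}s, JAX: {jax_time:.4f}s")
```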
Replies: 2 comments 5 replies
-
I'm almost leaning towards retracting the question. I do know that vmap and jit can be used in this particular example (vmap is super cool, by the way!), but of course that wasn't the point of the question. I'll leave it up in the spirit of perhaps getting some interesting discussion. I'm still interested in why JAX doesn't go multi-threaded at that scale; it picks up multithreading at 1000x1000 matrix size and becomes more competitive. I suppose some of the most interesting cases for multithreading are hampered by the fact that Python itself isn't multi-threaded, but if you need variable-sized arrays for some reason, that's an obvious case where vmap won't do. Perhaps in practice that's not a concern? The other thing is dynamic graphs. Anyway, I'll leave it here in case anyone wants to comment. Thanks.
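(For readers who haven't seen it: the vmap-plus-jit approach mentioned above turns a Python loop over fixed-size matmuls into a single batched, compiled operation. A minimal sketch, with illustrative shapes; note this requires all matrices in the batch to have the same size, which is exactly the limitation raised for variable-sized arrays.)

```python
import jax
import jax.numpy as jnp

# vmap maps a function over a leading batch axis; jit compiles the
# whole batched computation into one XLA program.
batched_matmul = jax.jit(jax.vmap(lambda a, b: a @ b))

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (8, 100, 100))  # batch of 8 matrices
b = jax.random.normal(key, (8, 100, 100))

out = batched_matmul(a, b)
print(out.shape)  # (8, 100, 100)
```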
-
Hi - thanks for the question! There's some general information about this kind of question in the FAQ: https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy In particular, when you run linear algebra operations on CPU, they are lowered to whatever BLAS/LAPACK libraries the package was compiled against, and this is true of both JAX and NumPy. If NumPy is significantly faster, I suspect you're using a NumPy installation built against a faster BLAS (for example, NumPy installed via conda is often built against MKL).
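(Not part of the original reply, but a quick way to act on this suggestion: NumPy can report which BLAS/LAPACK build it links against, so you can check whether you're on MKL, OpenBLAS, or something else.)

```python
import numpy as np

# Prints the BLAS/LAPACK libraries this NumPy build was compiled
# against (e.g. MKL vs OpenBLAS). Differences between these builds
# account for much of the CPU linear-algebra performance gap.
np.show_config()
```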