
JAX matmul (100x100) seems ~16x slower than numpy because no asynchronous behavior?! #13960

Answered by jakevdp
RezaRob asked this question in Q&A

Hi - thanks for the question! There's some general information about this kind of question in the FAQ: https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy
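
Below is a minimal timing sketch in the spirit of the JAX FAQ's benchmarking guidance. It assumes a CPU backend and float32 inputs. The sketch runs one warm-up call so that JIT compilation time is excluded. It also calls `block_until_ready()` on JAX results, because JAX dispatches work asynchronously and a call can otherwise return before the computation has finished. The helper `time_call` is hypothetical and written only for this illustration.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def time_call(fn, *args, repeats=100):
    """Hypothetical helper: mean seconds per call of fn(*args) after one warm-up.

    block_until_ready() is applied when the result has it (JAX arrays), so
    asynchronous dispatch does not make JAX look artificially fast.
    """
    def run():
        out = fn(*args)
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()  # wait for the computation to actually finish
        return out

    run()  # warm-up: triggers JIT compilation and any one-time setup
    start = time.perf_counter()
    for _ in range(repeats):
        run()
    return (time.perf_counter() - start) / repeats


# 100x100 float32 inputs, matching the size in the question
x_np = np.random.default_rng(0).standard_normal((100, 100), dtype=np.float32)
x_jax = jnp.asarray(x_np)  # move the data to the JAX device once, outside the timing loop
matmul_jit = jax.jit(jnp.matmul)

np_time = time_call(np.matmul, x_np, x_np)
jax_time = time_call(matmul_jit, x_jax, x_jax)
print(f"numpy: {np_time * 1e6:.1f} us/call, jax.jit: {jax_time * 1e6:.1f} us/call")
```

At 100x100 the arithmetic is small, so a fixed per-call dispatch overhead may account for much of any remaining gap. The actual numbers depend on the machine and on how each library was built.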

In particular, when you run linear algebra operations on CPU, they are lowered to whatever BLAS/LAPACK libraries the package was compiled against, and this holds for both JAX and NumPy. If NumPy is significantly faster, I suspect your NumPy installation is built against a faster BLAS (for example, NumPy installed via conda is often built against MKL).
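
To see which BLAS/LAPACK libraries a given NumPy installation was built against, you can print its build configuration with `np.show_config()`. The exact output format varies between NumPy versions.

```python
import numpy as np

# Prints this NumPy installation's build configuration, including the
# BLAS/LAPACK libraries it was linked against (e.g. OpenBLAS or MKL).
np.show_config()
```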

Replies: 2 comments, 5 replies

Answer selected by RezaRob
Category
Q&A
2 participants