You'll find it faster if you jit-compile the sparsified code: `func_sp = jax.jit(sparse.sparsify(mul_jnp))` (a minimal sketch of this pattern appears at the end of this reply). Also, please see https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code for tips on writing effective microbenchmarks for JAX code, and for notes on when and where you can expect JAX to be faster or slower than similar code written in numpy/scipy.

Out of curiosity, I tried a more careful benchmark of the JAX and scipy versions of this operation:

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse as jsparse
import numpy as np
from scipy import sparse
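
# dense 1000x1000 matrix with roughly half its entries zeroed out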
mat = np.random.rand(1000, 1000)
mat[mat < 0.5] = 0
mat_coo = sparse.coo_matrix(mat)
vec = np.random.rand(1000)
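
# JAX equivalents: a batched-COO (BCOO) sparse matrix and a device array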
mat_bcoo = jsparse.BCOO.fromdense(mat)
vec_jax = jnp.array(vec)
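
# the same function works on the scipy COO matrix and the JAX BCOO matrix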
def func(M, v):
    return 2 * (M @ v)
func_jax = jax.jit(func)
# trigger JIT compilation
_ = func_jax(mat_bcoo, vec_jax).block_until_ready()
print('jax operation')
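# block_until_ready() ensures we time actual execution rather than JAX's async dispatch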
%timeit func_jax(mat_bcoo, vec_jax).block_until_ready()
print()
print('scipy operation')
%timeit func(mat_coo, vec)
```

On a Colab CPU runtime the scipy version comes out ahead, while on a Colab GPU runtime the JAX version is faster.
As discussed in the FAQ entry I linked to, it's not surprising that for a small, self-contained operation like this run on CPU, the numpy ecosystem is more performant than JAX: this is exactly the domain that numpy/scipy operations have been optimized for! But you'll often find that JAX performance surpasses that of numpy-related tools when you're working on accelerators, and/or JIT-compiling longer programs.
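For completeness, here's a minimal sketch of the sparsify-plus-jit pattern mentioned at the top. The `mul_jnp` below is a hypothetical stand-in for the function from the original question; I'm assuming it's an ordinary function written against dense `jnp` arrays:

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse as jsparse

# Hypothetical stand-in for mul_jnp from the original question:
# an ordinary function written for dense jnp arrays.
def mul_jnp(M, v):
    return 2 * (M @ v)

# sparsify() rewrites the function so it accepts sparse (BCOO) inputs;
# wrapping the result in jit compiles those sparse rules with XLA.
func_sp = jax.jit(jsparse.sparsify(mul_jnp))

M = jsparse.BCOO.fromdense(jnp.eye(1000))
v = jnp.ones(1000)
result = func_sp(M, v).block_until_ready()
```

The point of composing the two is that `sparsify` alone only gives you a sparse-aware version of the function; it's the `jit` wrapper that fuses the whole computation and avoids per-op dispatch overhead.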