Is it possible to index a python list with a jax array? #14804
Unanswered
pbsmoothie4life asked this question in Q&A
-
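Thanks for the question! How about just calling int(n) on the index before passing it to fun?

    from jax import jit
    from functools import partial

    @jit
    def ind(n):
        return n

    # n is a static argument, so it must be a hashable Python value
    @partial(jit, static_argnums=(1,))
    def fun(x, n):
        return x[n]

    n = ind(3)           # concrete JAX array returned by the jitted function
    x = list(range(10))
    fun(x, int(n))  # CHANGE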
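To make the trade-off of the int(n) approach concrete, here is a minimal sketch. The body of ind, the contents of x, and the trace_count counter are all hypothetical stand-ins, not taken from the original post. It shows that the output of a jitted function is a concrete array outside of jit (so int() works on it), and that each new static index value causes fun to be traced again.

```python
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def ind(n):
    # Hypothetical stand-in for an expensive jitted index computation.
    return n * 2


trace_count = 0  # illustrative counter, incremented only while tracing


@partial(jax.jit, static_argnums=(1,))
def fun(x, n):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cache hits
    return x[n]       # n is a plain Python int here, so list indexing works


# Because x stays a list, its items may have different shapes.
x = [jnp.ones(2), jnp.zeros(3), jnp.arange(4)]

i = int(ind(1))          # concrete array -> Python int, done outside jit
a = fun(x, i)            # first call with n=2: traces and compiles
b = fun(x, i)            # same static value: reuses the cached compilation
c = fun(x, int(ind(0)))  # new static value n=0: traces again
```

The cost is one compilation per distinct index value, and the int() conversion has to happen outside of any jitted code, so this fits best when the index takes only a few values.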
-
One way you could index into a list with a traced index is to use lax.switch:

    from jax import jit, lax

    @jit
    def fun(x, n):
        # return x[n]  # fails with TracerIntegerConversionError
        return lax.switch(n, [lambda xi=xi: xi for xi in x])

    x = list(range(10))
    fun(x, 3)

The mental model of what's going on here is that JAX creates ten different functions with constant outputs, pushes them all to the compiler, and then at runtime the compiler chooses which function to run based on the index value. This works, but becomes pretty inefficient as the size of the list grows. In general, it's much more efficient to convert the list to an array and index that directly:

    import jax.numpy as jnp

    @jit
    def fun2(x, n):
        return x[n]

    fun2(jnp.array(x), 3)
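When the list elements are pytrees rather than scalars, the same "convert to an array" idea can be sketched by stacking matching leaves into one pytree of arrays and indexing every leaf with the traced index. This is only a sketch under an assumption: every element of the list must share the same tree structure and leaf shapes. The names pick and stacked and the example data are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def pick(stacked, n):
    # n is traced; indexing an array with a traced integer is allowed under jit.
    return jax.tree_util.tree_map(lambda leaf: leaf[n], stacked)


# Hypothetical list of pytrees with identical structure and leaf shapes.
x = [{"w": jnp.full((2,), float(i)), "b": jnp.array(float(i))} for i in range(10)]

# Stack corresponding leaves along a new leading axis, giving one pytree
# whose leaves have shape (10, ...).
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *x)

# The index can itself be the output of another jitted function.
out = pick(stacked, jnp.array(3))
```

If the elements differ in structure or shape, stacking is not possible and the lax.switch approach (or a static index) would be needed instead.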
-
I'm wondering if there's a way to index a Python list with a JAX array. I know that one can index a list with a Python int by marking that argument as static via static_argnums in jit, but what about the case where the index comes from a jitted function? A minimal working example showing the issue is the following....
I understand that in this specific example I can turn x into a jnp.array and the example will work, but I'm interested in the case where x is a general pytree, so I need x to remain a list or some other nestable object. I also understand that if I don't jit the ind function this example will work, but in my case ind is computationally expensive, so I'd like to jit it. I've noticed that @jakevdp answers a lot of questions of this type. I feel like I might be able to use jax.lax.switch to do this, but I also feel like there may be a more direct solution that I'm missing. Thanks in advance for any help!