For some reason I cannot vectorize `jnp.ravel_multi_index`:

```python
import jax.numpy as jnp
from jax import vmap

indices = jnp.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
my_ravel = lambda x: jnp.ravel_multi_index(x, (3, 3, 3))
print(my_ravel(indices[1]))
print(vmap(my_ravel)(indices))
```

This code throws an error. However, raveling the index by hand:

```python
import jax.numpy as jnp
from jax import vmap

def ravel_by_hands(multi_index, shapes):
    # Row-major (C-order) flattening: fold each index into the running total.
    raveled_index = 0
    for dim, index in zip(shapes, multi_index):
        raveled_index = raveled_index * dim + index
    return raveled_index

indices = jnp.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
my_ravel = lambda x: ravel_by_hands(x, (3, 3, 3))
print(my_ravel(indices[1]))
print(vmap(my_ravel)(indices))
```

works perfectly fine. Is there a bug in the `jax.numpy.ravel_multi_index` function, or am I misunderstanding something?
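For reference, the hand-written loop computes the C-order (row-major) flat index, which for dims `(3, 3, 3)` is `9*i0 + 3*i1 + i2`. The same computation can be written as a dot product with precomputed strides; it vmaps for the same reason the loop does (only arithmetic on the traced values, no value-dependent checks). This is a minimal sketch: `row_major_strides` and `ravel_with_strides` are hypothetical helper names, and like `ravel_by_hands` it performs no bounds checking.

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

def row_major_strides(dims):
    # Stride of each axis in C (row-major) order: product of all later dims.
    strides = []
    acc = 1
    for d in reversed(dims):
        strides.append(acc)
        acc *= d
    return tuple(reversed(strides))

def ravel_with_strides(multi_index, dims):
    # Flat index = dot product of the multi-index with the strides.
    # No bounds checking: out-of-range entries give a meaningless result.
    return jnp.dot(multi_index, jnp.array(row_major_strides(dims)))

indices = jnp.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
dims = (3, 3, 3)

batched = vmap(lambda x: ravel_with_strides(x, dims))(indices)
print(batched)  # [ 0 13 26]

# Cross-check against NumPy, which takes one index array per dimension.
expected = np.ravel_multi_index(tuple(np.asarray(indices).T), dims)
print(expected)  # [ 0 13 26]
```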
The error looks like this:

The issue is that the default `mode='raise'` is not compatible with JAX transforms like `jit` or `vmap`, and as mentioned by the error, you should use a different `mode`, such as `mode='clip'`:

```python
import jax.numpy as jnp
from jax import vmap

indices = jnp.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
my_ravel = lambda x: jnp.ravel_multi_index(x, (3, 3, 3), mode='clip')
print(my_ravel(indices[1]))
# 13
print(vmap(my_ravel)(indices))
# [ 0 13 26]
```

You might ask why we chose a default parameter value that is incompatible with JAX transforms. It's because this is a JAX wrapper of a NumPy API whose default semantics are incompatible with JAX transforms, so our options were either (1) change the default semantics (which may be confusing) or (2) raise an informative error message if the default case is transformed. We opted for (2). Does that make sense?
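One consequence worth keeping in mind: switching away from `mode='raise'` means out-of-range indices no longer produce an error. As in NumPy, `'clip'` clamps each index to `[0, d - 1]` and `'wrap'` takes it modulo `d`. The sketch below illustrates both on deliberately invalid rows; the eager NumPy check at the end is just one possible way (an assumption, not an official recommendation) to keep "raise"-like validation by checking concrete values before entering `vmap`/`jit`.

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

dims = (3, 3, 3)
# First row is valid; the other two contain out-of-range entries (3 and -1).
indices = jnp.array([[1, 1, 1], [3, 0, 0], [-1, 0, 0]])

def ravel(mode):
    # Vectorized ravel over rows with the given out-of-range handling mode.
    return vmap(lambda x: jnp.ravel_multi_index(x, dims, mode=mode))(indices)

print(ravel('clip'))  # [13 18  0]  (3 -> 2, -1 -> 0)
print(ravel('wrap'))  # [13  0 18]  (3 -> 0, -1 -> 2)

# Optional eager validation on concrete values, outside any JAX transform.
idx_np = np.asarray(indices)
valid = bool(np.all((idx_np >= 0) & (idx_np < np.array(dims))))
print(valid)  # False
```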