Passing array of random 'keys' in vmap #13924
I have a random array of numbers where each element is either -1., 0. or +1. I keep choosing a random index (never `indexup` itself) until the element at that index equals the element at an initially given index (`indexup` here). Then it prints that index:

```python
import jax
from jax import jit, lax, vmap
import jax.numpy as jnp

def choose_index(index, indexup, σ, key2):
    # Keep looping while the element at `index` differs from σ[indexup].
    def cond1(state):
        index, key2 = state
        jax.debug.print("IndexUP: {x}", x=indexup)
        return σ[index] != σ[indexup]

    # Draw a new random index from every position except indexup.
    def body1(state):
        index, key2 = state
        k1, junk = jax.random.split(key2)
        del junk
        res = jnp.arange(len(σ))
        resA = res[:-1]
        resB = res[1:]
        del_array = jnp.where(resA < indexup, resA, resB)
        index2 = jax.random.choice(k1, del_array, shape=(1,))
        jax.debug.print("Index 2: {x}", x=index2)
        del key2
        key2 = k1
        return index2[0], key2

    state = tuple([index, key2])
    indexdn, key2 = lax.while_loop(cond_fun=cond1, body_fun=body1, init_val=state)
    return indexdn

length = 16
key = jax.random.PRNGKey(411843)
key1, key2, key3, = jax.random.split(key, 3)
index = 0
indexup = 11
temp = jax.random.uniform(key3, shape=(length,))
σ = jnp.floor(3.0*temp) - 1.0  # each element is -1., 0. or +1.
print("Sigma : ", σ)
indexdn = choose_index(index, indexup, σ, key2)
print("Indexdn : ", indexdn)
```

This works fine. But now I want the same thing for a stack of arrays, a stack of indices and a stack of keys:

```python
index = jnp.array([[0], [0], [0]])
indexup = jnp.array([[11], [11], [11]])
σp = jnp.stack([σ, σ, σ])
key2p = jax.random.split(key2, 3)
```

But when I try to `vmap` the function above:
I get the following error:
How should I solve this problem?
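One way to see what `jax.vmap` hands to each call with stacked inputs like the ones above is to record the argument shapes while the function is traced. A minimal sketch: `show_shapes` and `recorded` are hypothetical names used only for illustration, and σ is replaced by zeros because only its shape matters here.

```python
import jax
import jax.numpy as jnp

recorded = []  # hypothetical: per-call shapes captured while vmap traces the function

def show_shapes(index, indexup, sigma, key):
    # Under vmap these are the shapes of ONE mapped slice (leading axis removed).
    recorded.append((index.shape, indexup.shape, sigma.shape))
    return index

sigma = jnp.zeros(16)                               # stand-in for σ; only its shape matters
sigmap = jnp.stack([sigma, sigma, sigma])           # shape (3, 16)
key2p = jax.random.split(jax.random.PRNGKey(0), 3)  # three keys, one per call

# Indices stacked as in the question: shape (3, 1).
jax.vmap(show_shapes)(jnp.array([[0], [0], [0]]), jnp.array([[11], [11], [11]]), sigmap, key2p)
# Indices stacked as 1D arrays: shape (3,).
jax.vmap(show_shapes)(jnp.array([0, 0, 0]), jnp.array([11, 11, 11]), sigmap, key2p)

for shapes in recorded:
    print(shapes)
# ((1,), (1,), (16,))
# ((), (), (16,))
```

`vmap` strips only the leading (batch) axis, so the `(3, 1)` index arrays arrive in each call as shape `(1,)` arrays, while `choose_index` was written for scalar indices.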
Replies: 1 comment
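It seems like you're passing 2D arrays where your vmapped function expects 1D arrays. `vmap` maps over the leading axis, so with `index` and `indexup` of shape `(3, 1)` each call receives a shape `(1,)` array rather than the scalar index that `choose_index` expects. Passing them as 1D arrays of shape `(3,)` works:

```python
index = jnp.array([0, 0, 0])
indexup = jnp.array([11, 11, 11])
σp = jnp.stack([σ, σ, σ])
key2p = jax.random.split(key2, 3)
indexdn = jax.vmap(choose_index)(index, indexup, σp, key2p)
print(indexdn)
# [15  8  8]
```

Alternatively, since it looks like you want to pass the same `index`, `indexup` and `σ` values for each key, you can equivalently use `in_axes` to specify that these arguments should not be mapped over (`None`) and that only the keys are mapped (axis `0`):

```python
jax.vmap(choose_index, in_axes=(None, None, None, 0))(0, 11, σ, key2p)
```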
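A small self-contained sketch to check that the two `vmap` forms agree, under stated assumptions: it uses a condensed copy of `choose_index` with the `jax.debug.print` calls removed, a fixed hypothetical σ (not the question's random one) in which σ[11] = 0 also appears at other positions, and an arbitrary seed. If σ[indexup] occurred nowhere else in σ, the `while_loop` would never terminate, in either the plain or the vmapped version.

```python
import jax
import jax.numpy as jnp
from jax import lax

def choose_index(index, indexup, σ, key2):
    # Condensed copy of the question's choose_index, debug prints removed.
    def cond1(state):
        index, _ = state
        return σ[index] != σ[indexup]  # keep looping until the values match

    def body1(state):
        _, key2 = state
        k1, _ = jax.random.split(key2)
        res = jnp.arange(len(σ))
        # Every index except indexup (same construction as in the question).
        del_array = jnp.where(res[:-1] < indexup, res[:-1], res[1:])
        index2 = jax.random.choice(k1, del_array, shape=(1,))
        return index2[0], k1

    indexdn, _ = lax.while_loop(cond1, body1, (index, key2))
    return indexdn

# Hypothetical fixed σ: σ[11] == 0 also occurs at indices 2, 5, 8 and 14,
# so the loop has other positions it can stop on.
σ = jnp.array([1., -1., 0.] * 5 + [1.])
key2p = jax.random.split(jax.random.PRNGKey(0), 3)

# Form 1: every argument stacked along a leading axis of size 3.
mapped_all = jax.vmap(choose_index)(
    jnp.array([0, 0, 0]), jnp.array([11, 11, 11]), jnp.stack([σ, σ, σ]), key2p)
# Form 2: only the keys are mapped; index, indexup and σ are shared.
mapped_keys = jax.vmap(choose_index, in_axes=(None, None, None, 0))(0, 11, σ, key2p)
# Reference: one un-vmapped call per key.
looped = jnp.array([choose_index(0, 11, σ, k) for k in key2p])

print(mapped_all, mapped_keys, looped)
```

Both vmapped forms should match the plain loop, since each call draws only from its own key and the shared `index`, `indexup` and σ are identical across calls; `in_axes=None` just saves building the stacked copies by hand.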