Passing array of random 'keys' in vmap #13924

Answered by jakevdp
abhiroop513 asked this question in Q&A

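It seems like you're passing 2D arrays where your vmapped function expects 1D arrays. With the default in_axes=0, vmap maps over the leading axis of every argument, so each argument needs a leading batch axis of the same size (here 3):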

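import jax
import jax.numpy as jnp

# choose_index, σ and key2 come from the original question.
# Every mapped argument gets a leading batch axis of size 3.
index = jnp.array([0, 0, 0])
indexup = jnp.array([11, 11, 11])
σp = jnp.stack([σ, σ, σ])
key2p = jax.random.split(key2, 3)  # one key per batch element

indexdn = jax.vmap(choose_index)(index, indexup, σp, key2p)
print(indexdn)
# [15  8  8]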
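Since choose_index itself isn't shown in this thread, here is a self-contained sketch of the same batching pattern. The body of choose_index below is a hypothetical stand-in (Gaussian noise of scale sigma around the midpoint of [lo, hi], rounded and clipped), and the seed and sigma value are assumptions; only the shapes of the inputs mirror the answer above.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the question's choose_index (not shown in the thread):
# perturb the midpoint of [lo, hi] by Gaussian noise of scale sigma,
# then round and clip the result back into [lo, hi].
def choose_index(lo, hi, sigma, key):
    mid = (lo + hi) / 2
    noisy = mid + sigma * jax.random.normal(key)
    return jnp.clip(jnp.round(noisy), lo, hi).astype(jnp.int32)

key2 = jax.random.PRNGKey(0)  # assumed seed
sigma = jnp.array(2.0)        # assumed noise scale

# With the default in_axes=0, every argument needs a leading batch axis of size 3.
index = jnp.array([0, 0, 0])
indexup = jnp.array([11, 11, 11])
sigmap = jnp.stack([sigma, sigma, sigma])
key2p = jax.random.split(key2, 3)  # shape (3, 2): one legacy key per row

# Each call sees scalars for lo, hi, sigma and a single key of shape (2,).
indexdn = jax.vmap(choose_index)(index, indexup, sigmap, key2p)
print(indexdn.shape)  # (3,)
```

Because the mapped axis has size 3 for all four arguments, vmap runs the scalar function once per key and stacks the three results into a 1D array.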

Alternatively, since it looks like you want to pass the same index and sigma values with every key, you can use in_axes to tell vmap not to map over those arguments (None) and to map only over the leading axis of the keys (0):

jax.vmap(choose_index, in_axes=(None, None, None, 0))(0, 11, σ, key2p)
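As a rough check that the two approaches agree, the sketch below reuses the same hypothetical stand-in for choose_index (again an assumption, not the function from the question) with an assumed seed, and compares the in_axes=None version against the explicitly stacked version. Because both calls receive the same keys and the same values, they should return identical results.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the question's choose_index (not shown in the thread).
def choose_index(lo, hi, sigma, key):
    mid = (lo + hi) / 2
    noisy = mid + sigma * jax.random.normal(key)
    return jnp.clip(jnp.round(noisy), lo, hi).astype(jnp.int32)

key2p = jax.random.split(jax.random.PRNGKey(0), 3)  # assumed seed
sigma = jnp.array(2.0)                              # assumed noise scale

# in_axes: None passes the same value to every call, 0 maps over the leading axis.
shared = jax.vmap(choose_index, in_axes=(None, None, None, 0))(0, 11, sigma, key2p)

# Same computation with the shared inputs repeated explicitly along axis 0.
stacked = jax.vmap(choose_index)(
    jnp.array([0, 0, 0]), jnp.array([11, 11, 11]), jnp.stack([sigma] * 3), key2p
)

print(jnp.array_equal(shared, stacked))  # True
```

The in_axes form avoids materializing the repeated index and sigma arrays, which matters more as the batch grows.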

Answer selected by abhiroop513