
Jax vmap, in_axes doesn't work if keyword argument is passed #13836

Answered by jakevdp
Amith225 asked this question in Q&A

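Yes, it's true that vmap's in_axes only applies to positional arguments; arguments passed by keyword are always mapped over their leading axis (axis 0). If you want a vmapped function that can also be called with keyword arguments, the best option currently is probably to use a wrapper function that forwards them to the vmapped function positionally. For example: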

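from jax import vmap

def _foo(a, b, c):
    return a * b + c

# Accepts a, b, c by position or keyword, then forwards them to the
# vmapped function positionally so that in_axes applies to all three.
def foo(a, b, c):
    return vmap(_foo, in_axes=(0, 0, None))(a, b, c)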
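A minimal runnable sketch of the difference, assuming jax is installed; the array values are made up for illustration. The wrapper can be called with keywords and still treats c as unmapped, whereas calling the vmapped function directly with c= causes c to be mapped over axis 0, which fails here because its size (2) does not match the mapped size of a and b (3).

```python
import jax.numpy as jnp
from jax import vmap


def _foo(a, b, c):
    # Per-example computation: a and b are mapped, c is shared.
    return a * b + c


def foo(a, b, c):
    # Forward every argument positionally so in_axes=(0, 0, None) applies.
    return vmap(_foo, in_axes=(0, 0, None))(a, b, c)


a = jnp.arange(3.0)          # shape (3,), mapped over axis 0
b = jnp.ones(3)              # shape (3,), mapped over axis 0
c = jnp.array([10.0, 20.0])  # shape (2,), meant to be shared, not mapped

# Calling the wrapper with keywords works: c is treated as unmapped.
out = foo(a=a, b=b, c=c)
print(out.shape)  # (3, 2)

# Calling the vmapped function directly with c as a keyword does not:
# in_axes only covers the positional a and b, and the keyword c is mapped
# over its leading axis, whose size (2) disagrees with a and b (3).
keyword_was_mapped = False
try:
    vmap(_foo, in_axes=(0, 0))(a, b, c=c)
except ValueError:
    keyword_was_mapped = True
print(keyword_was_mapped)  # True
```

Each row of out is a[i] * b[i] + c, so the shared c is broadcast into every mapped example.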

Answer selected by Amith225
2 participants