The parameter `in_axes` of `vmap` doesn't seem to work with keyword arguments:

```python
from jax import vmap
import numpy as np

def foo(a, b, c):
    return a * b + c

foo = vmap(foo, in_axes=(0, 0, None))
aj, bj = np.random.rand(2, 100, 1)
foo(aj, bj, 10)    # works
foo(aj, bj, c=10)  # throws an error
```
How would one go about running `foo` as …?
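The `jax.vmap` docstring states that `in_axes` describes the positional argument tuple and that arguments passed as keywords are always mapped over their leading axis, so `in_axes` gives no way to mark a keyword argument as unmapped. The sketch below reproduces the failure and shows one alternative that is not from this thread: binding `c` with `functools.partial` before vmapping, so it is closed over rather than mapped. The names `vfoo`, `vfoo_bound` and `keyword_failed` are illustrative only.

```python
import functools

import numpy as np
from jax import vmap


def foo(a, b, c):
    return a * b + c


vfoo = vmap(foo, in_axes=(0, 0, None))
aj, bj = np.random.rand(2, 100, 1)

# Positional call: in_axes=None applies to c, so the scalar is broadcast.
out_pos = vfoo(aj, bj, 10)

# Keyword call: in_axes lists three positional arguments, but only two are
# passed positionally and c cannot be given an in_axes entry; JAX rejects
# the call with a ValueError.
try:
    vfoo(aj, bj, c=10)
    keyword_failed = False
except ValueError:
    keyword_failed = True

# Alternative: fix c with functools.partial *before* vmapping, so only the
# remaining positional arguments a and b are mapped.
vfoo_bound = vmap(functools.partial(foo, c=10), in_axes=(0, 0))
out_bound = vfoo_bound(aj, bj)
```

This only helps when `c` is known before the function is vmapped; if `c` must vary from call to call, a positional wrapper is the more flexible route.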
Answered by jakevdp, Dec 31, 2022
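Yes, it's true that `vmap`'s `in_axes` only works for positional arguments. If you want to make a more general vmapped function, the best option currently is probably to use a wrapper function. For example:

```python
from jax import vmap

def _foo(a, b, c):
    return a * b + c

def foo(a, b, c):
    # Forward the arguments positionally so in_axes applies to all three.
    return vmap(_foo, in_axes=(0, 0, None))(a, b, c)
```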
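If many functions need this treatment, the same wrapper idea can be generalized. The following is a minimal sketch, not part of the original answer: `vmap_with_kwargs` is a hypothetical helper that uses `inspect.signature` to bind keyword arguments to their positions and then always calls the vmapped function positionally. It assumes `fun` has only ordinary positional-or-keyword parameters (no `*args`, `**kwargs`, or keyword-only parameters) and that `in_axes` has exactly one entry per parameter.

```python
import inspect
from functools import wraps

import numpy as np
from jax import vmap


def vmap_with_kwargs(fun, in_axes):
    """Hypothetical helper: vmap `fun` so it may also be called with keywords.

    Assumes `fun` has only positional-or-keyword parameters and that
    `in_axes` is a tuple with one entry per parameter.
    """
    sig = inspect.signature(fun)
    batched = vmap(fun, in_axes=in_axes)

    @wraps(fun)
    def wrapper(*args, **kwargs):
        # Map every argument, however it was passed, onto its position.
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # Call positionally so each in_axes entry lines up with its argument.
        return batched(*bound.args)

    return wrapper


def foo(a, b, c):
    return a * b + c


vfoo = vmap_with_kwargs(foo, in_axes=(0, 0, None))
aj, bj = np.random.rand(2, 100, 1)
out1 = vfoo(aj, bj, 10)
out2 = vfoo(aj, bj, c=10)
out3 = vfoo(a=aj, b=bj, c=10)
```

All three calls route through the same positional `vmap`, so `c` is broadcast rather than mapped regardless of how it is passed.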
Answer selected by Amith225