jax.vmap output: an explanation is required to get a 2D output array. #14054
jecampagne asked this question in Q&A
-
Roughly, the shape signatures of the functions are:

```python
import jax.numpy as jnp

# postulate
# X: tuple[int, ...]
# n: int
# x, y: ndarray
# x.shape == y.shape

# X, X -> ()
# (n, *X), (n, *X) -> (n,) under vmap
def fext1(x, y):
    return jnp.asarray(-1.)

# X, X -> X
# (n, *X), (n, *X) -> (n, *X) under vmap
def fext2(x, y):
    return -jnp.ones_like(x)

# X, X -> X  (x and y are carried as kpos.x and kpos.y)
# (n, *X), (n, *X) -> (n, *X) under vmap
def fext3(kpos):
    return kpos.x + kpos.y
```
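A minimal sketch that checks these signatures under `jax.vmap`. It assumes `kpos` is a `NamedTuple` with fields `x` and `y`; the `KPos` type and the sizes `n = 10`, `X = (10,)` are hypothetical choices for illustration. The key point is that `vmap` prepends the mapped axis to whatever shape each call returns, so a scalar per call becomes `(n,)`, not `(n, *X)`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class KPos(NamedTuple):
    # Hypothetical container; jax.vmap treats a NamedTuple as a pytree
    x: jnp.ndarray
    y: jnp.ndarray


def fext1(x, y):
    return jnp.asarray(-1.)  # per-call output shape ()


def fext2(x, y):
    return -jnp.ones_like(x)  # per-call output shape == x.shape


def fext3(kpos):
    return kpos.x + kpos.y  # per-call output shape == kpos.x.shape


n, X = 10, (10,)
x = jnp.zeros((n, *X))
y = jnp.zeros((n, *X))

out1 = jax.vmap(fext1)(x, y)        # (n,): one scalar per mapped row
out2 = jax.vmap(fext2)(x, y)        # (n, *X)
out3 = jax.vmap(fext3)(KPos(x, y))  # (n, *X): both leaves mapped on axis 0

print(out1.shape, out2.shape, out3.shape)  # (10,) (10, 10) (10, 10)
```

Because `fext1` does not depend on its inputs, `vmap` simply broadcasts its constant along the mapped axis.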
-
So, if I want f(x, y) to return a single number computed from kpos.x and kpos.y, should I multiply it by jnp.ones_like(kpos.x) to get the right output shape?
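A sketch comparing the two ways to give a scalar result the shape of `kpos.x`: multiplying by `jnp.ones_like(kpos.x)`, and `jnp.broadcast_to`, which states the intent directly without an extra multiply. `KPos` and `fscalar` are hypothetical stand-ins for the real container and computation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class KPos(NamedTuple):
    # Hypothetical container with fields x and y
    x: jnp.ndarray
    y: jnp.ndarray


def fscalar(kpos):
    # Hypothetical: a single number computed from kpos.x and kpos.y
    return jnp.sum(kpos.x * kpos.y)


def f_ones(kpos):
    # Option 1: broadcast by multiplying with ones of x's shape
    return fscalar(kpos) * jnp.ones_like(kpos.x)


def f_bcast(kpos):
    # Option 2: broadcast explicitly to x's (per-example) shape
    return jnp.broadcast_to(fscalar(kpos), kpos.x.shape)


n, X = 10, (10,)
x = jnp.arange(n * X[0], dtype=jnp.float32).reshape((n, *X))
y = jnp.ones((n, *X))
kpos = KPos(x, y)

a = jax.vmap(fscalar)(kpos)  # (n,): one number per mapped row
b = jax.vmap(f_ones)(kpos)   # (n, *X)
c = jax.vmap(f_bcast)(kpos)  # (n, *X)
print(a.shape, b.shape, c.shape)  # (10,) (10, 10) (10, 10)
```

Inside `vmap`, `kpos.x.shape` is the per-example shape `X`, so both options return `(n, *X)` after mapping. Each row simply repeats that row's scalar.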
-
Hello,
Here is a snippet that exhibits different vmap outputs, including one that I do not expect and cannot manage to get right. So I need help. Thanks.
Then with
I get
which I do not expect, while with
I get what I was expecting: a 10x10 array,
and fext3 is OK too.
So, after some other trials, it seems that the output of the function fext(kpos) must have a certain type or shape, but I cannot manage to get the correct one. Any idea is welcome.
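One way to see which output shape a function will produce, before and after mapping, is `jax.eval_shape`, which traces the function on abstract inputs without computing anything. A sketch, assuming `n = 10` and `X = (10,)` (hypothetical sizes) and reusing the `fext1`/`fext2` definitions from the thread:

```python
import jax
import jax.numpy as jnp


def fext1(x, y):
    return jnp.asarray(-1.)


def fext2(x, y):
    return -jnp.ones_like(x)


# Abstract inputs: shape and dtype only, no data
row = jax.ShapeDtypeStruct((10,), jnp.float32)        # one example, shape X
batch = jax.ShapeDtypeStruct((10, 10), jnp.float32)   # n examples, shape (n, *X)

shapes = {}
for f in (fext1, fext2):
    per_example = jax.eval_shape(f, row, row).shape
    mapped = jax.eval_shape(jax.vmap(f), batch, batch).shape
    # vmap prepends the mapped axis: mapped == (10,) + per_example
    shapes[f.__name__] = (per_example, mapped)
    print(f.__name__, per_example, mapped)
```

If the per-example shape is `()`, the mapped result is 1D. To get a 2D `(n, *X)` array, the function must return an array of shape `X` for each example.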