Why does a double jax.vmap work like two nested loops, and why does a single jax.vmap also work sometimes? #14057
Well, here is a snippet to discuss a little:

```python
import jax
import jax.numpy as jnp

class Position():
    def __init__(self, x, y):
        self.x = jnp.asarray(x)
        self.y = jnp.asarray(y)

class Obj():
    def __init__(self, sigma, flux):
        self.flux = flux
        self.sigma = sigma
        self._sigsq = self.sigma**2

    def _kValue(self, kpos):
        ksq = (kpos.x**2 + kpos.y**2) * self._sigsq
        return self.flux * jnp.exp(-0.5 * ksq)
```

Make a list with two Objects:

```python
a_list = [Obj(sigma=1., flux=1.), Obj(sigma=2., flux=0.1)]
```

Then define a class that manipulates a list of Objects:

```python
class A():
    def __init__(self, a):
        self.a_list = a

    def f(self, kpos):
        k_list = [obj._kValue(kpos) for obj in self.a_list]
        return jnp.prod(jnp.array(k_list))  # *jnp.ones_like(kpos.x) (see another post)

    def g(self, kpos):
        return self.a_list[0]._kValue(kpos)
```

Define also an array of positions:

```python
coords = jax.random.normal(jax.random.PRNGKey(10), shape=(5, 5, 2)) * 0.2
```

Let us predict what a double for loop mimicking the `A.f` function would do:

```python
for i in range(coords.shape[0]):
    for j in range(coords.shape[1]):
        print(i, j, jnp.prod(jnp.array([obj._kValue(Position(coords[i, j, 0], coords[i, j, 1])) for obj in a_list])))
```

One gets
Of course, this double for loop looks horrible for JAX, so let us try `vmap` with the `A.f` function:

```python
# instance of the A class with our list of Objects
tmpA = A(a_list)
# define a nested double vmap
vf = jax.vmap(
    jax.vmap(lambda i, j: tmpA.f(Position(coords[i, j, 0], coords[i, j, 1])), in_axes=(None, 0)),
    in_axes=(0, None),
)
# then apply it as
vf(jnp.arange(5), jnp.arange(5))
```

Then one gets
which looks pretty good. The point is that if one does

```python
jax.vmap(lambda *args: tmpA.f(Position(*args)))(
    coords[..., 0], coords[..., 1]
)
```

the result is completely wrong.
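One way to see why the single `jax.vmap` goes wrong for `A.f` is to print the shape the mapped function actually receives. A single `jax.vmap` strips only the leading axis, so inside the lambda each coordinate is a row of shape `(5,)`; `jnp.array(k_list)` then has shape `(2, 5)`, and `jnp.prod` with no `axis` collapses it to one scalar per row. The sketch below re-creates the question's classes in compact form; `show_shape`, the module-level `f`, and `f_fixed` are hypothetical helpers, and reducing with `axis=0` is an assumed fix, not something proposed in the thread.

```python
import jax
import jax.numpy as jnp

class Position:
    def __init__(self, x, y):
        self.x = jnp.asarray(x)
        self.y = jnp.asarray(y)

class Obj:
    def __init__(self, sigma, flux):
        self.flux = flux
        self._sigsq = sigma**2

    def _kValue(self, kpos):
        ksq = (kpos.x**2 + kpos.y**2) * self._sigsq
        return self.flux * jnp.exp(-0.5 * ksq)

a_list = [Obj(sigma=1.0, flux=1.0), Obj(sigma=2.0, flux=0.1)]
coords = jax.random.normal(jax.random.PRNGKey(10), shape=(5, 5, 2)) * 0.2
X, Y = coords[..., 0], coords[..., 1]

def f(kpos):
    # Same reduction as A.f: product over ALL axes of the (n_obj, ...) stack.
    return jnp.prod(jnp.array([obj._kValue(kpos) for obj in a_list]))

def f_fixed(kpos):
    # Hypothetical fix: multiply only across objects (axis 0), keep position axes.
    return jnp.prod(jnp.array([obj._kValue(kpos) for obj in a_list]), axis=0)

def show_shape(x, y):
    print("inside single vmap:", x.shape)  # printed once, at trace time: (5,)
    return f(Position(x, y))

wrong = jax.vmap(show_shape)(X, Y)                            # one scalar per row
fixed = jax.vmap(lambda x, y: f_fixed(Position(x, y)))(X, Y)  # one value per position

# Reference: explicit double loop, one position at a time.
ref = jnp.array([[f(Position(X[i, j], Y[i, j])) for j in range(5)] for i in range(5)])
print(wrong.shape, fixed.shape, jnp.allclose(fixed, ref))
```

Under these assumptions, each entry of `wrong` is the product of the five correct values in its row, which is why the result looks nothing like the double loop.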
But here comes (maybe) the interesting part. If you look at the `A.g` function, you will see that it uses only the first Object of `a_list`. Now

```python
vg = jax.vmap(
    jax.vmap(lambda i, j: tmpA.g(Position(coords[i, j, 0], coords[i, j, 1])), in_axes=(None, 0)),
    in_axes=(0, None),
)
vg(jnp.arange(5), jnp.arange(5))
```

gives
AND, INTERESTINGLY,

```python
jax.vmap(lambda *args: tmpA.g(Position(*args)))(
    coords[..., 0], coords[..., 1]
)
```

gives the same correct result. So, what is puzzling me is the difference of behavior between
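The `A.g` case holds up because `_kValue` is purely elementwise: it returns an array with the same shape as its inputs, whatever that shape is, and there is no reduction for an unmapped axis to leak into. A minimal check, assuming a hypothetical `g(x, y)` that stands in for `tmpA.g(Position(x, y))` (first object: sigma=1, flux=1):

```python
import jax
import jax.numpy as jnp

def g(x, y, sigma=1.0, flux=1.0):
    # Hypothetical stand-in for tmpA.g(Position(x, y)): elementwise, no reduction.
    return flux * jnp.exp(-0.5 * (x**2 + y**2) * sigma**2)

coords = jax.random.normal(jax.random.PRNGKey(10), shape=(5, 5, 2)) * 0.2
X, Y = coords[..., 0], coords[..., 1]

no_vmap = g(X, Y)                     # plain broadcasting over the (5, 5) grid
single = jax.vmap(g)(X, Y)            # maps rows; each call sees shape (5,)
double = jax.vmap(jax.vmap(g))(X, Y)  # maps every element; each call sees shape ()
print(jnp.allclose(no_vmap, single), jnp.allclose(no_vmap, double))
```

For an elementwise function, how many axes `vmap` peels off does not change the answer; it only matters once the function reduces over an axis, as `A.f` does.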
Replies: 1 comment 6 replies
Ah, I see the problem.
Outer-like:

```python
# postulate
a.shape == (m,)
b.shape == (n,)
vmap(vmap(f, in_axes=(None, 0)), in_axes=(0, None))(a, b)
```

which gives

```
[ f(a0, b0), f(a0, b1), ..., f(a0, bn);
  f(a1, b0), f(a1, b1), ..., f(a1, bn);
  ...
  f(am, b0), f(am, b1), ..., f(am, bn);
]
```
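A toy check of the outer-like pattern, using a hypothetical `f(a, b) = 10 * a + b` so that each entry shows which `a[i]` and `b[j]` produced it:

```python
import jax
import jax.numpy as jnp

def f(a, b):
    # Hypothetical toy function: the result encodes which inputs it received.
    return 10 * a + b

a = jnp.arange(3)  # m = 3
b = jnp.arange(4)  # n = 4

# Outer vmap maps over a (b passed whole); inner vmap maps over b (a[i] held fixed).
outer = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))(a, b)

# outer[i, j] == f(a[i], b[j]), the same as broadcasting a[:, None] against b[None, :]
print(outer.shape, bool((outer == f(a[:, None], b[None, :])).all()))
```

The result has shape `(m, n)` even though `a` and `b` are 1-D: every `a[i]` meets every `b[j]`, exactly like two nested loops.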
Vectorize-like:

```python
# postulate
a.shape == b.shape == (m, n)
vmap(vmap(f, in_axes=(0, 0)), in_axes=(0, 0))(a, b)
```

which gives

```
[ f(a00, b00), f(a01, b01), ..., f(a0n, b0n);
  f(a10, b10), f(a11, b11), ..., f(a1n, b1n);
  ...
  f(am0, bm0), f(am1, bm1), ..., f(amn, bmn);
]
```

Now go back to your code:

```python
from jax import vmap

# f(x, y) here stands for tmpA.f(Position(x, y)) from the question
f = lambda x, y: tmpA.f(Position(x, y))
X, Y = coords[..., 0], coords[..., 1]

# scenario 1, correct
f1 = vmap(vmap(lambda i, j: f(X[i, j], Y[i, j]),
               in_axes=(None, 0)), in_axes=(0, None))
f1(jnp.arange(5), jnp.arange(5))

# scenario 2, incorrect
f2 = vmap(vmap(lambda x, y: f(x, y),
               in_axes=(None, 0)), in_axes=(0, None))
f2(X, Y)

# scenario 3, correct
f3 = vmap(vmap(lambda x, y: f(x, y),
               in_axes=(0, 0)), in_axes=(0, 0))
f3(X, Y)
```

Your scenario 2 is essentially

```
[ f(X[0,:], Y[0,:]), f(X[0,:], Y[1,:]), ..., f(X[0,:], Y[4,:]);
  f(X[1,:], Y[0,:]), f(X[1,:], Y[1,:]), ..., f(X[1,:], Y[4,:]);
  ...
  f(X[4,:], Y[0,:]), f(X[4,:], Y[1,:]), ..., f(X[4,:], Y[4,:]);
]
```

It happens to run without raising a shape error only because
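To reproduce the three scenarios end to end, here is a self-contained sketch. `h(x, y)` is a hypothetical stand-in for `tmpA.f(Position(x, y))` written with plain arrays (the same two objects, sigma=1/flux=1 and sigma=2/flux=0.1, and the same `jnp.prod` without an `axis`). Scenarios 1 and 3 (the vectorize-like pattern) agree with a per-element loop; scenario 2 also returns a `(5, 5)` array, but its entry `[i, j]` is `h` evaluated on the whole rows `X[i]` and `Y[j]`.

```python
import jax
import jax.numpy as jnp
from jax import vmap

PARAMS = [(1.0, 1.0), (2.0, 0.1)]  # (sigma, flux) pairs, mirroring a_list

def h(x, y):
    # Hypothetical stand-in for tmpA.f(Position(x, y)): product over ALL axes, like A.f.
    k = jnp.array([flux * jnp.exp(-0.5 * (x**2 + y**2) * sigma**2) for sigma, flux in PARAMS])
    return jnp.prod(k)

coords = jax.random.normal(jax.random.PRNGKey(10), shape=(5, 5, 2)) * 0.2
X, Y = coords[..., 0], coords[..., 1]

# scenario 1: outer-like over index vectors, indexing inside -> h sees scalars
f1 = vmap(vmap(lambda i, j: h(X[i, j], Y[i, j]), in_axes=(None, 0)), in_axes=(0, None))
r1 = f1(jnp.arange(5), jnp.arange(5))

# scenario 2: outer-like over the data itself -> pairs whole rows X[i] with Y[j]
f2 = vmap(vmap(h, in_axes=(None, 0)), in_axes=(0, None))
r2 = f2(X, Y)

# scenario 3: vectorize-like -> element (i, j) of X paired with element (i, j) of Y
f3 = vmap(vmap(h, in_axes=(0, 0)), in_axes=(0, 0))
r3 = f3(X, Y)

ref = jnp.array([[h(X[i, j], Y[i, j]) for j in range(5)] for i in range(5)])
rows = jnp.array([[h(X[i], Y[j]) for j in range(5)] for i in range(5)])
print(jnp.allclose(r1, ref), jnp.allclose(r3, ref), jnp.allclose(r2, rows))
```

Scenario 2 raises no error here because `X[i]` and `Y[j]` both have shape `(5,)`, so the elementwise arithmetic broadcasts cleanly and the unqualified `jnp.prod` reduces the `(2, 5)` stack to a scalar.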