lax.conv and vmap #4092
-
import jax
import numpy as onp
image = onp.random.randn(3,13,17)            # (channels, height, width): no batch dim
batched_image = onp.random.randn(7,3,13,17)  # (batch, channels, height, width)
weight = onp.random.randn(5,3,2,2)           # (out_channels, in_channels, kh, kw)
def myconv(im, we):
return jax.lax.conv(im, we, (1,1), 'SAME')
myconv_ = jax.jit(myconv)
# myconv_(image, weight).shape <= error TypeError: convolution requires lhs and rhs ndim to be equal, got 3 and 4.
x = myconv_(batched_image, weight)  # works: the batched input has the expected 4 dims
myconv_vmap = jax.vmap(myconv, (0, None))  # map over the leading axis of the image only
# myconv_vmap(image, weight).shape <= error TypeError: convolution requires lhs and rhs ndim to be equal, got 2 and 4.
y = myconv_vmap(onp.expand_dims(batched_image,1), weight)  # add a singleton "batch" dim so each vmapped slice is 4D
print('x.shape=', x.shape, 'y.shape=', y.shape, 'difference', ((x - y.squeeze()) ** 2).sum())

perhaps I'm missing some of the thoughts behind this design decision: why does lax.conv require an explicit batch dimension?
-
My understanding is that the API of lax.conv deliberately mirrors XLA's convolution operation, which expects explicit batch and feature dimensions. Have you seen the numpy/scipy-style convolution functions (e.g. jax.scipy.signal.convolve)?
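(For illustration, a minimal sketch of the scipy-style API, assuming jax.scipy.signal.convolve is the function meant above; it takes plain 2D arrays, with no batch or channel axes:)

import numpy as onp
import jax.scipy.signal as jss

# A single 2D image and a 2D kernel: no batch or channel dimensions required.
image2d = onp.random.randn(13, 17)
kernel2d = onp.random.randn(2, 2)
out = jss.convolve(image2d, kernel2d, mode='same')
print(out.shape)  # (13, 17)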
-
I think the numpy/scipy functions might not support channels (e.g. in how the singleton axes are added here).

Actually this is an old wart in JAX, and one we've never gotten around to fixing: see #381 (and probably several duplicates)! The right thing to do is adjust our conv primitive not to require a batch dimension. We've never done it, to our great shame, and that is an impediment to the vmap dream of liberating ourselves from batch dimensions. (Tangentially, another impediment was batch norm, but vmap recently gained support for naming the mapped axis and applying collectives, so that's no longer an issue!)

We should finally fix this conv batch issue once and for all. I'll mention it at our team meeting today and see if we can get any takers. Let's track in #381.
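(For illustration, a minimal sketch of the workaround this implies today: wrap lax.conv so it adds and strips a singleton batch axis, then let vmap supply the batching. conv_single is a hypothetical helper name, not JAX API:)

import jax
import numpy as onp

def conv_single(im, we):
    # Add a singleton batch axis for lax.conv, then strip it from the result.
    out = jax.lax.conv(im[None], we, (1, 1), 'SAME')  # (1, out_ch, H, W)
    return out[0]                                     # (out_ch, H, W)

weight = onp.random.randn(5, 3, 2, 2)
single = onp.random.randn(3, 13, 17)
batched = onp.random.randn(7, 3, 13, 17)

print(conv_single(single, weight).shape)                        # (5, 13, 17)
print(jax.vmap(conv_single, (0, None))(batched, weight).shape)  # (7, 5, 13, 17)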
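(And a minimal sketch of the named-axis collectives mentioned in the batch-norm tangent, using a toy mean-subtraction rather than real batch norm:)

import jax
import jax.numpy as jnp

def normalize(x):
    # pmean over the named axis averages across the vmapped batch dimension.
    return x - jax.lax.pmean(x, axis_name='batch')

xs = jnp.arange(6.0).reshape(3, 2)
print(jax.vmap(normalize, axis_name='batch')(xs))  # each column now sums to zero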
-
thanks, @jakevdp and @mattjj! it'll be great to have it fixed and realize the "dream of liberating ourselves from batch dimensions", which is a great dream.