Closing since we aren't going to evolve stax any further. Others have built better and more fully-featured neural network libraries on top of JAX, such as Flax and Haiku.
So `vmap` looks awesome and I'd like to use it everywhere I possibly can. However, it looks like it's not really compatible with stax, and in particular with convolutions. Applying `vmap` to a stax convolution gives an error, which makes some sense given the `dimension_numbers` in the definition here: https://github.com/google/jax/blob/master/jax/experimental/stax.py#L127. But the alternative I tried doesn't work either:
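One plausible cause is a rank mismatch. Under `vmap`, the mapped function sees a single image of shape `(H, W, C)`. The stax `Conv` layer, however, uses default `dimension_numbers` of `('NHWC', 'HWIO', 'NHWC')`, so it expects a leading batch axis. The sketch below shows one common workaround: add a singleton batch axis inside the mapped function and strip it afterwards.

It assumes a recent JAX release, where stax lives at `jax.example_libraries.stax` rather than `jax.experimental.stax`. The wrapper name `apply_single` and the shapes are illustrative choices, not part of the original report.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax  # assumption: recent JAX, stax moved here

# One 2D convolution; stax.Conv defaults to NHWC inputs and HWIO kernels.
init_fun, apply_fun = stax.Conv(out_chan=8, filter_shape=(3, 3), padding="SAME")

rng = jax.random.PRNGKey(0)
# Parameter shapes do not depend on the batch size, so any N works here.
_, params = init_fun(rng, (4, 32, 32, 3))


def apply_single(params, image):
    """Hypothetical per-example wrapper: image has shape (H, W, C)."""
    # Add the leading 'N' axis that the NHWC dimension_numbers require.
    out = apply_fun(params, image[None, ...])
    # Drop the singleton batch axis again so vmap can stack the results.
    return out[0]


# Share params across examples (None) and map over axis 0 of the images.
per_example = jax.vmap(apply_single, in_axes=(None, 0))

images = jnp.ones((4, 32, 32, 3))
out = per_example(params, images)
print(out.shape)  # (4, 32, 32, 8)
```

The vmapped result should match calling `apply_fun(params, images)` directly on the batch. The per-example form is mainly useful when the mapped function does something batching would otherwise obscure, such as computing per-example gradients.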