
vmap support with stax #815

Closed
samuela opened this issue Jun 4, 2019 · 2 comments

@samuela
Contributor

samuela commented Jun 4, 2019

So vmap looks awesome and I'd like to use it everywhere I possibly can. However, it doesn't seem to be compatible with stax, in particular with convolutions.

from functools import partial

import jax.numpy as np
from jax import random, vmap
from jax.experimental.stax import Conv, GeneralConv

net_init, net_apply = Conv(32, (3, 3), padding='SAME')

in_shape = (28, 28, 1)
out_shape, net_params = net_init(random.PRNGKey(0), in_shape)

# Apply network to dummy inputs
inputs = np.zeros((128, 28, 28, 1))
predictions = vmap(partial(net_apply, net_params))(inputs)
print(predictions.shape)

gives an error

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
~/nu/skainswo/research/gan_with_the_wind/vae.py in <module>
     13 
     14 in_shape = (28, 28, 1)
---> 15 out_shape, net_params = net_init(random.PRNGKey(0), in_shape)
     16 
     17 # Apply network to dummy inputs

~/.local/share/virtualenvs/research-OGGq2tNy/lib/python3.7/site-packages/jax/experimental/stax.py in init_fun(rng, input_shape)
    112     kernel_shape = [out_chan if c == 'O' else
    113                     input_shape[lhs_spec.index('C')] if c == 'I' else
--> 114                     next(filter_shape_iter) for c in rhs_spec]
    115     output_shape = lax.conv_general_shape_tuple(
    116         input_shape, kernel_shape, strides, padding, dimension_numbers)

~/.local/share/virtualenvs/research-OGGq2tNy/lib/python3.7/site-packages/jax/experimental/stax.py in <listcomp>(.0)
    112     kernel_shape = [out_chan if c == 'O' else
    113                     input_shape[lhs_spec.index('C')] if c == 'I' else
--> 114                     next(filter_shape_iter) for c in rhs_spec]
    115     output_shape = lax.conv_general_shape_tuple(
    116         input_shape, kernel_shape, strides, padding, dimension_numbers)

IndexError: tuple index out of range

which makes some sense based on the dimension_numbers in the definition here: https://github.com/google/jax/blob/master/jax/experimental/stax.py#L127. But using

net_init, net_apply = GeneralConv(("HWC", "HWIO", "HWC"), 32, (3, 3))

doesn't work either:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~/nu/skainswo/research/gan_with_the_wind/vae.py in <module>
     13 
     14 in_shape = (28, 28, 1)
---> 15 out_shape, net_params = net_init(random.PRNGKey(0), in_shape)
     16 
     17 # Apply network to dummy inputs

~/.local/share/virtualenvs/research-OGGq2tNy/lib/python3.7/site-packages/jax/experimental/stax.py in init_fun(rng, input_shape)
    114                     next(filter_shape_iter) for c in rhs_spec]
    115     output_shape = lax.conv_general_shape_tuple(
--> 116         input_shape, kernel_shape, strides, padding, dimension_numbers)
    117     bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    118     bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))

~/.local/share/virtualenvs/research-OGGq2tNy/lib/python3.7/site-packages/jax/lax/lax.py in conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding, dimension_numbers)
   3858 def conv_general_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,
   3859                              dimension_numbers):
-> 3860   lhs_perm, rhs_perm, out_perm = conv_general_permutations(dimension_numbers)
   3861   lhs_trans = onp.take(lhs_shape, lhs_perm)
   3862   rhs_trans = onp.take(rhs_shape, rhs_perm)

~/.local/share/virtualenvs/research-OGGq2tNy/lib/python3.7/site-packages/jax/lax/lax.py in conv_general_permutations(dimension_numbers)
   4002       msg = ("convolution dimension_numbers[{}] must contain the characters "
   4003              "'{}' and '{}' exatly once, got {}.")
-> 4004       raise TypeError(msg.format(i, a, b, dimension_numbers[i]))
   4005     if len(dimension_numbers[i]) != len(set(dimension_numbers[i])):
   4006       msg = ("convolution dimension_numbers[{}] cannot have duplicate "

TypeError: convolution dimension_numbers[0] must contain the characters 'N' and 'C' exatly once, got HWC.
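Both tracebacks come from the same constraint: lax's convolution needs the input spec to contain both 'N' and 'C', and stax's Conv is GeneralConv with dimension numbers ('NHWC', 'HWIO', 'NHWC'), so it cannot be initialized on an unbatched (28, 28, 1) shape. A minimal sketch of the usual stax pattern follows. It assumes a recent JAX where stax lives in jax.example_libraries (it was jax.experimental.stax when this issue was filed). It skips vmap entirely: initialize with a batch-agnostic shape using -1 for N, then apply the layer to the whole batch.

```python
import jax.numpy as jnp
from jax import random
# Assumption: a JAX release where stax has moved from jax.experimental.stax
# to jax.example_libraries.stax.
from jax.example_libraries.stax import Conv

# Conv is GeneralConv with dimension numbers ('NHWC', 'HWIO', 'NHWC'),
# so every input shape it sees must include a leading batch axis N.
net_init, net_apply = Conv(32, (3, 3), padding='SAME')

# -1 leaves the batch size unspecified at initialization time.
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(random.PRNGKey(0), in_shape)

# The layer already handles the batch axis, so no vmap is needed.
inputs = jnp.zeros((128, 28, 28, 1))
predictions = net_apply(net_params, inputs)
print(predictions.shape)  # (128, 28, 28, 32)
```

Here the batch dimension is part of the layer's contract rather than something added by a transformation, which is how stax layers are generally meant to be used.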
@joaogui1
Contributor

I believe this is related to #381 and #931.

@hawkinsp
Collaborator

Closing since we aren't going to evolve stax any further. Others have built better and more fully-featured neural network libraries on top of JAX, such as Flax and Haiku.
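Since stax itself won't gain native support for unbatched inputs, one workaround is to keep a per-example function and vmap over it. Inside the function, add a singleton batch axis. This is a sketch under the same jax.example_libraries assumption. apply_single is a hypothetical helper name, not part of stax.

```python
import jax.numpy as jnp
from jax import random, vmap
from jax.example_libraries.stax import Conv

net_init, net_apply = Conv(32, (3, 3), padding='SAME')
# Initialize as if for a batch containing a single example.
_, net_params = net_init(random.PRNGKey(0), (1, 28, 28, 1))

def apply_single(params, x):
    """Hypothetical helper: apply the conv to one unbatched (H, W, C) example."""
    # Add a size-1 N axis, run the batched layer, then drop that axis again.
    return net_apply(params, x[None])[0]

inputs = random.normal(random.PRNGKey(1), (128, 28, 28, 1))
# in_axes=(None, 0): share params across examples, map over axis 0 of inputs.
predictions = vmap(apply_single, in_axes=(None, 0))(net_params, inputs)
print(predictions.shape)  # (128, 28, 28, 32)
```

Under vmap, JAX's convolution batching rule folds the mapped axis into the convolution's batch axis. So this should compute the same values as calling net_apply on the whole batch directly. The direct call is simpler when the model is stax-only, and the wrapper is mainly useful when the per-example function is composed with other vmapped code.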
