
Inconsistent batching behavior between MLPs and convnets #381

Closed
ericjang opened this issue Feb 15, 2019 · 14 comments
Assignees: mattjj
Labels: enhancement (New feature or request), question (Questions for the JAX team)

Comments

ericjang (Contributor) commented Feb 15, 2019

The following code snippet shows that stax MLPs can be defined w.r.t. unbatched examples (in_shape = (1,)), while convnets seem to require a batch dimension in the input shape (though its size can be -1). Is this intended behavior?

from jax.experimental import stax
from jax.experimental.stax import (Conv, Dense, Flatten, LogSoftmax,
                                   MaxPool, Relu)

# Works
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(10), LogSoftmax,
)

# Initialize parameters, not committing to a batch shape
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(in_shape)

# Works
net_init, net_apply = stax.serial(
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(1)
)
in_shape = (1,)
out_shape, net_params = net_init(in_shape)

# Doesn't Work
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(10), LogSoftmax,
)
in_shape = (28, 28, 1)
out_shape, net_params = net_init(in_shape)

The last one returns the following error:
IndexError                                Traceback (most recent call last)
in ()
      9 # Initialize parameters, not committing to a batch shape
     10 in_shape = (28, 28, 1)
---> 11 out_shape, net_params = net_init(in_shape)

google3/third_party/py/jax/experimental/stax.py in init_fun(input_shape)
    269   params = []
    270   for init_fun in init_funs:
--> 271     input_shape, param = init_fun(input_shape)
    272     params.append(param)
    273   return input_shape, params

google3/third_party/py/jax/experimental/stax.py in init_fun(input_shape)
    109   kernel_shape = [out_chan if c == 'O' else
    110                   input_shape[lhs_spec.index('C')] if c == 'I' else
--> 111                   next(filter_shape_iter) for c in rhs_spec]
    112   output_shape = lax.conv_general_shape_tuple(
    113       input_shape, kernel_shape, strides, padding, dimension_numbers)

IndexError: tuple index out of range

mattjj added the bug (Something isn't working) label on Feb 15, 2019
mattjj self-assigned this on Feb 15, 2019
mattjj added the question (Questions for the JAX team) and enhancement (New feature or request) labels and removed the bug label on Feb 15, 2019
mattjj (Collaborator) commented Feb 15, 2019

Actually, rereading the issue, I think this is originally-intended behavior, but could be revised.

stax.Conv (which perhaps should be called Conv2D) requires a batch dimension essentially because the underlying XLA HLO (and corresponding lax function) requires a batch dimension. Notice how lhs and rhs must have ranks n+2 for n spatial dimensions, +1 for a channel dimension and +1 for a batch dimension.

We could revise the stax layer and/or the underlying lax primitive to work without a batch dimension (probably the latter so that it's easier to inherit the behavior in the former). Would that be useful to you? I'm guessing that, in a vmap world, we should take seriously the fact that we can remove batch dimensions from all our library code, including stax.
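For concreteness, a minimal sketch of that rank requirement using lax.conv directly (the shapes here are illustrative, not from the issue):

import jax.numpy as jnp
from jax import lax

# For a 2D convolution (n = 2 spatial dims), both operands must be rank 4:
# batch + channel + two spatial dims on the lhs, and an OIHW-shaped rhs.
img = jnp.ones((1, 1, 28, 28))    # NCHW: the batch dim is mandatory, even if it is 1
kernel = jnp.ones((32, 1, 3, 3))  # OIHW

out = lax.conv(img, kernel, window_strides=(1, 1), padding='SAME')
print(out.shape)  # (1, 32, 28, 28)

# Dropping the batch dim leaves a rank-3 lhs, which the primitive rejects:
# lax.conv(jnp.ones((1, 28, 28)), kernel, (1, 1), 'SAME')  # -> shape error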

ericjang (Contributor, Author) commented Feb 15, 2019 via email

hawkinsp (Collaborator):

If we removed the batch dimension from stax, it's not obvious to me how to define batch norm.

ericjang (Contributor, Author) commented Feb 15, 2019 via email

hawkinsp (Collaborator):

The catch with allowing the batch dimension to be omitted in that way is that it becomes ambiguous once we support conv layers with different numbers of spatial dimensions: a (28, 28, 1) input could be an unbatched 2D image or a batch of 28 one-dimensional inputs. We could fix that by requiring, say, that the conv layer explicitly specify its spatial dimensionality (e.g., Conv2D instead of simply Conv).

mattjj (Collaborator) commented Feb 15, 2019

Btw stax.Conv already is really a Conv2D, with stax.GeneralConv allowing arbitrary spatial dimensions:

Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
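Under the stax API of that era, the two spellings below should therefore build the same layer (a sketch; the explicit dimension numbers are just the defaults written out):

from jax.experimental.stax import Conv, GeneralConv

# Conv fixes NHWC inputs, HWIO kernels, and NHWC outputs; GeneralConv lets you
# choose the layout (and with it the number of spatial dimensions) yourself.
layer_a = Conv(32, (3, 3), padding='SAME')
layer_b = GeneralConv(('NHWC', 'HWIO', 'NHWC'), 32, (3, 3), padding='SAME')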

ericjang (Contributor, Author) commented Feb 15, 2019 via email

mattjj (Collaborator) commented Feb 15, 2019

Just to clarify, you can vmap over functions that call lax.conv just fine; AIUI this is mainly a question of the (experimental) Stax API layer.
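For instance, something along these lines works (a sketch with made-up shapes): the function keeps the one mandatory batch dimension that lax.conv needs, and vmap layers a further batch dimension on top.

import jax
import jax.numpy as jnp
from jax import lax

def conv(img, kernel):
    # One already-batched NCHW example, as lax.conv requires.
    return lax.conv(img, kernel, window_strides=(1, 1), padding='SAME')

kernel = jnp.ones((32, 3, 3, 3))                   # OIHW
single = conv(jnp.ones((1, 3, 28, 28)), kernel)    # (1, 32, 28, 28)

# vmap adds an extra batch dimension on top of the mandatory one.
imgs = jnp.ones((8, 1, 3, 28, 28))
batched = jax.vmap(conv, in_axes=(0, None))(imgs, kernel)  # (8, 1, 32, 28, 28)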

mattjj (Collaborator) commented Feb 15, 2019

To put a finer point on it: we like to think in terms of expressing and transforming functions. Expressing means just writing Python+NumPy(+lax) code. Transformations are things like automatic differentiation, vmap, jit, etc.

This is actually an issue about expressing an unbatched conv: our stax.Conv layer currently requires a batch dimension (as does the underlying lax.conv function).

To contrast, it's not an issue about transforming (autobatching) convs: you can add batch dimensions to your heart's content. But you have to start with a minimum of one batch dimension. As you point out, that's something peculiar to conv and not shared by operations like dot.

To summarize, I'd revise your statement to say that JAX supports autobatching just fine (vmap transformations can add arbitrary batch dimensions), but our stax.Conv and lax.conv have this peculiarity inherited from XLA that you can't directly express a convolution with no batch dimensions.

In any case, I think we agree that we should figure out a way to tweak stax.Conv and/or lax.conv to enable expressing unbatched convolution operations.
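A sketch of that contrast (illustrative shapes): a dot-product "layer" can be expressed for a single unbatched example and batched purely with vmap, which is exactly what the conv primitive above does not allow.

import jax
import jax.numpy as jnp

def predict(w, x):
    # A single unbatched example: no batch dimension anywhere.
    return jnp.dot(w, x)

w = jnp.ones((10, 784))
x = jnp.ones((784,))
single = predict(w, x)                                    # (10,)

batch = jnp.ones((32, 784))
batched = jax.vmap(predict, in_axes=(None, 0))(w, batch)  # (32, 10)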

mattjj (Collaborator) commented Feb 15, 2019

We've got a plan! Will update this issue as we make progress.

rsepassi (Contributor):

+1 on making Stax operate on single examples.

matpalm (Contributor) commented Apr 16, 2019

+1. Just got confused by net_apply(net_params, batch_X) working but vmap(partial(net_apply, net_params))(batch_X) failing with a cryptic shape error for a simple conv net in stax...
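A sketch of why that combination fails, and one possible workaround, rebuilding a small conv net with the stax API of that era: vmap strips the leading batch axis, so each per-example call hands stax.Conv an unbatched (28, 28, 1) array; re-inserting a singleton batch dimension inside the mapped function sidesteps the error.

from functools import partial
import jax
import jax.numpy as jnp
from jax.experimental import stax
from jax.experimental.stax import (Conv, Dense, Flatten, LogSoftmax,
                                   MaxPool, Relu)

net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(10), LogSoftmax,
)
_, net_params = net_init((-1, 28, 28, 1))
batch_X = jnp.ones((8, 28, 28, 1))

out = net_apply(net_params, batch_X)                 # works: (8, 10)

# Fails: each per-example call sees a (28, 28, 1) input with no batch dim.
# jax.vmap(partial(net_apply, net_params))(batch_X)

# Workaround: add a singleton batch dim per example, then squeeze it back out.
out_vmapped = jax.vmap(lambda x: net_apply(net_params, x[None])[0])(batch_X)  # (8, 10)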


hamzamerzic (Contributor):

+1 looking forward to having JAX fully batch agnostic! :)

hawkinsp (Collaborator):

Closing since we aren't going to evolve stax any further. Others have built better and more fully-featured neural network libraries on top of JAX, such as Flax and Haiku.
