Inconsistent batching behavior between MLPs and convnets #381
Actually, rereading the issue, I think this is originally-intended behavior, but could be revised.

stax.Conv (which perhaps should be called Conv2D) requires a batch dimension essentially because the underlying XLA HLO (and corresponding lax function) requires a batch dimension: https://www.tensorflow.org/xla/operation_semantics#conv_convolution. Notice how lhs and rhs must have rank n+2 for n spatial dimensions: +1 for a channel dimension and +1 for a batch dimension.

We could revise the stax layer and/or the underlying lax primitive to work without a batch dimension (probably the latter, so that it's easier to inherit the behavior in the former). Would that be useful to you? I'm guessing that, in a vmap world, we should take seriously the fact that we can remove batch dimensions from all our library code, including stax.
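A minimal sketch of that rank requirement (the shapes here are just for illustration, not from the issue): the lax conv expects lhs/rhs of rank n+2, so a rank-3 HWC image has to pick up a leading batch axis before it can be convolved.

```python
import jax.numpy as jnp
from jax import lax

img = jnp.zeros((28, 28, 1))       # HWC image, no batch dimension
kernel = jnp.zeros((3, 3, 1, 32))  # HWIO kernel

# Adding a leading singleton batch dimension gives the rank-4 NHWC input
# that the conv primitive expects.
out = lax.conv_general_dilated(
    img[None], kernel, window_strides=(1, 1), padding='SAME',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
print(out.shape)  # (1, 28, 28, 32)
```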
I am fine with refactoring my code to use batched convs (that's reasonable, given that the underlying primitive is vectorized for efficiency). But yeah, this was surprising, because I was operating under the assumption that the "vmap world" you mention should support the ability to define nets that don't take the batch dimension into account.
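For what it's worth, a minimal sketch of that "write a per-example function and let vmap supply the batch dimension" style (names and shapes are made up for illustration):

```python
import jax.numpy as jnp
from jax import vmap

def predict(params, x):
    # x is a single example of shape (features,)
    w, b = params
    return jnp.dot(x, w) + b

params = (jnp.ones((3, 2)), jnp.zeros(2))
xs = jnp.ones((5, 3))  # a batch of 5 examples

# vmap maps predict over the leading axis of xs, adding the batch dimension.
ys = vmap(predict, in_axes=(None, 0))(params, xs)
print(ys.shape)  # (5, 2)
```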
If we removed the batch dimension from stax, it's not obvious to me how to define batch norm.
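To spell out the point (a hedged sketch, not stax's actual BatchNorm implementation): batch norm's statistics are reductions over the batch axis, so it isn't clear what the layer should do when handed a single example.

```python
import jax.numpy as jnp

def batch_norm(x, eps=1e-5):
    # x: (batch, features) -- the mean and variance are computed across the
    # batch axis, which is what makes an unbatched definition non-obvious.
    mean = jnp.mean(x, axis=0)
    var = jnp.var(x, axis=0)
    return (x - mean) / jnp.sqrt(var + eps)
```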
Agreed, there should definitely be the flexibility to define batched nets (for models whose forward pass requires minibatches). But given the behavior of the MLP, a user can easily suspect that Conv would automatically assume a singleton batch dimension under the hood if ndims == 3. It is confusing if some primitives assume batching while others do not.
The catch with allowing a batch dimension to be omitted in that way would be that we would then have an ambiguity once we support conv layers with different numbers of spatial dimensions. We could fix that by requiring, say, the conv layer to have an explicitly specified number of spatial dimensions (e.g., Conv2D instead of simply Conv).
Btw, `Conv = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))`.
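In the same spirit, dimension-specific names could just pin different dimension numbers on GeneralConv; `Conv2D` and `Conv1D` below are hypothetical names for illustration, not existing stax layers (this uses the jax.experimental.stax API as discussed in this thread):

```python
import functools
from jax.experimental import stax

# Hypothetical dimension-specific conv layers built on GeneralConv.
Conv2D = functools.partial(stax.GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
Conv1D = functools.partial(stax.GeneralConv, ('NWC', 'WIO', 'NWC'))
```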
I don't have an opinion on GeneralConv vs. Conv2D etc., but in the short term, some kind of error message indicating why Conv requires a batch dimension would be helpful for debugging purposes. Some libraries (DyNet) *do* support autobatching, and given vmap I assumed that this was also the case in JAX.
Just to clarify, you can …
To put a finer point on it: we like to think in terms of expressing and transforming functions. Expressing means just writing Python+NumPy(+lax) code. Transformations are things like automatic differentiation, vmap, jit, etc.

This is actually an issue about expressing an unbatched conv: our stax.Conv (and the underlying lax conv) gives you no way to write one down without a batch dimension.

To contrast, it's not an issue about transforming (autobatching) convs: you can add batch dimensions to your heart's content. But you have to start with a minimum of one batch dimension. As you point out, that's something peculiar to conv and not shared by operations like dot.

To summarize, I'd revise your statement to say that JAX supports autobatching just fine (vmap transformations can add arbitrary batch dimensions), but our conv layer requires at least one batch dimension to be expressed in the first place.

In any case, I think we agree that we should figure out a way to tweak stax and/or the underlying lax primitive here.
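A hedged illustration of that "autobatching works, but you start from at least one batch dimension" distinction (shapes are for illustration only): vmap happily adds further batch axes on top of an already-batched conv.

```python
import jax.numpy as jnp
from jax import lax, vmap

def conv(lhs, rhs):
    # A batched NHWC conv -- the minimum the primitive lets you express.
    return lax.conv_general_dilated(lhs, rhs, (1, 1), 'SAME',
                                    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

lhs = jnp.zeros((7, 4, 28, 28, 1))   # an extra leading axis on top of NHWC
rhs = jnp.zeros((3, 3, 1, 32))

# vmap adds one more batch dimension over the already-batched conv.
out = vmap(conv, in_axes=(0, None))(lhs, rhs)
print(out.shape)  # (7, 4, 28, 28, 32)
```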
We've got a plan! Will update this issue as we make progress.
+1 on making Stax operate on single examples.
+1, just got confused by this as well.
+1, looking forward to having JAX fully batch-agnostic! :)
Closing, since we aren't going to evolve stax in this direction.
The following code snippet shows that Stax MLPs can be defined w.r.t. unbatched examples (input_size = (1,)) while Convnets seem to require a batch size (though it can be -1). Is this intended behavior?
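A minimal sketch of the kind of snippet described (the original code isn't reproduced in this thread, so the layer sizes are guesses), using the single-argument init_fun signature shown in the traceback below:

```python
from jax.experimental import stax

# MLP: init_fun accepts an unbatched input shape.
mlp_init, mlp_apply = stax.serial(stax.Dense(128), stax.Relu, stax.Dense(10))
out_shape, mlp_params = mlp_init((1,))             # works

# Convnet: init_fun expects NHWC, i.e. a batch dimension up front.
net_init, net_apply = stax.serial(stax.Conv(32, (3, 3)), stax.Relu,
                                  stax.Flatten, stax.Dense(10))
out_shape, net_params = net_init((-1, 28, 28, 1))  # works (batch size may be -1)
out_shape, net_params = net_init((28, 28, 1))      # IndexError (see below)
```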
The last one returns the following error:
IndexError Traceback (most recent call last)
in ()
9 # Initialize parameters, not committing to a batch shape
10 in_shape = (28, 28, 1)
---> 11 out_shape, net_params = net_init(in_shape)
google3/third_party/py/jax/experimental/stax.py in init_fun(input_shape)
269 params = []
270 for init_fun in init_funs:
--> 271 input_shape, param = init_fun(input_shape)
272 params.append(param)
273 return input_shape, params
google3/third_party/py/jax/experimental/stax.py in init_fun(input_shape)
109 kernel_shape = [out_chan if c == 'O' else
110 input_shape[lhs_spec.index('C')] if c == 'I' else
--> 111 next(filter_shape_iter) for c in rhs_spec]
112 output_shape = lax.conv_general_shape_tuple(
113 input_shape, kernel_shape, strides, padding, dimension_numbers)
IndexError: tuple index out of range
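A hedged reading of the traceback (not an authoritative diagnosis): with NHWC dimension numbers, the init function looks up the channel entry at `lhs_spec.index('C') == 3`, which is out of range for a rank-3 shape like (28, 28, 1). A minimal reproduction of just that lookup:

```python
lhs_spec = 'NHWC'
input_shape = (28, 28, 1)          # no batch dimension, so only indices 0..2
input_shape[lhs_spec.index('C')]   # IndexError: tuple index out of range
```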