Attention, Transposed Convolutions, Embeddings, LayerNorm#38

Merged
patrick-kidger merged 47 commits into main from
attn-convt-layernorm
Mar 15, 2022

Conversation

@patrick-kidger
Owner

@patrick-kidger patrick-kidger commented Mar 10, 2022

Merging #34 into main via this branch, so as to make some tweaks first.

Updates relative to #34:

  • Added new layers to the documentation.
  • Added the mathematical details to the docstrings for MultiheadAttention and LayerNorm
  • Substantially overhauled the MultiheadAttention implementation. The implementation in #34 (ConvTranspose layers, MultiheadAttention, lookup embeddings, LayerNorm) was basically mimicking PyTorch's implementation, which has a very bad API. The new API is much more consistent, and much more general. (For reference, I found the Haiku implementation of MultiheadAttention to be the cleanest of all previous implementations.)
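
For readers skimming the diff, here is a rough sketch of the "more general" attention pattern being referred to: independent query/key/value/output projections plus scaled dot-product attention over heads. This is only an illustration under assumed shapes and hypothetical parameter names (`wq`, `wk`, `wv`, `wo`); it is not the actual Equinox `MultiheadAttention` API.

```python
import jax
import jax.numpy as jnp

def multihead_attention(params, query, key_, value, num_heads, mask=None):
    # Hypothetical illustration: `params` holds four projection matrices
    # ("wq", "wk", "wv", "wo"); query/key_/value have shape (seq, size).
    def split_heads(x, w):
        out = x @ w                                    # (seq, num_heads * head_size)
        return out.reshape(x.shape[0], num_heads, -1)  # (seq, num_heads, head_size)

    q = split_heads(query, params["wq"])
    k = split_heads(key_, params["wk"])
    v = split_heads(value, params["wv"])
    # Scaled dot-product attention, computed per head.
    logits = jnp.einsum("shd,Shd->hsS", q, k) / jnp.sqrt(q.shape[-1])
    if mask is not None:
        # mask broadcasts against (num_heads, query_seq, key_seq)
        logits = jnp.where(mask, logits, -jnp.inf)
    weights = jax.nn.softmax(logits, axis=-1)      # softmax over key positions
    attn = jnp.einsum("hsS,Shd->shd", weights, v)  # (seq, num_heads, head_size)
    return attn.reshape(attn.shape[0], -1) @ params["wo"]
```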

Plus some other misc updates in support of these changes:

  • Dropout now takes a fast path when p=0 (see the sketch after this list).
  • Enabled MathJax in the documentation.
  • Bumped version number.
  • Updated custom_types.Array and custom_types.PyTree to follow the convention in Diffrax; namely that they're now subscriptable and the documentation handles this.
  • Standardised on the spelling "normalisation" (rather than "normalization") just because that's the convention I happen to follow.
  • Other misc doc tweaks.
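
On the dropout fast path mentioned above: the idea is simply that p=0 (and inference mode) can skip mask generation entirely. A minimal functional sketch, not the actual equinox.nn.Dropout code:

```python
import jax.numpy as jnp
import jax.random as jr

def dropout(x, p, key, inference=False):
    # Fast path: with p == 0 (or at inference time) dropout is the identity,
    # so avoid generating a Bernoulli mask and rescaling.
    if inference or p == 0:
        return x
    keep = jr.bernoulli(key, 1 - p, x.shape)
    return jnp.where(keep, x / (1 - p), 0)
```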

CC @andyehrenberg (and @lucidrains ?) for any commentary on these changes prior to merging them in.

@andyehrenberg
Contributor

I agree that your overhauled MultiheadAttention has a better API. When I was implementing it, I was trying to make the Haiku implementation fit in with the PyTorch API, which I agree is pretty bad (with somewhat misleading arguments). I've been putting more thought into ConvTranspose lately and should soon have some extensions ready that take inspiration from how jax-ml/jax#5772 computes output sizes (mainly to handle more output_padding and padding cases). I'm getting consistent outputs between this new implementation and Haiku's ConvTranspose, but I want to test things a bit more.
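
For reference, the output-size relation under discussion is the standard transposed-convolution one (this is just the usual textbook formula, not the code from jax-ml/jax#5772 or from this branch):

```python
def conv_transpose_output_size(in_size, kernel_size, stride=1, padding=0,
                               output_padding=0, dilation=1):
    # Inverse of the usual convolution size formula; output_padding resolves
    # the ambiguity that arises when stride > 1.
    return ((in_size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

# e.g. inverting a stride-2, kernel-3, padding-1 convolution that produced length 16:
assert conv_transpose_output_size(16, 3, stride=2, padding=1, output_padding=1) == 32
```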

@patrick-kidger
Owner Author

patrick-kidger commented Mar 10, 2022

Excellent. I'm assuming you implemented MultiheadAttention on the basis that you need it for your own work? If you could give it a try and check that it looks like it's working for your use cases -- just to be sure -- then that'd be great.

As for transposed convolutions: staring at the implementation I'm not feeling convinced by it. These things jump out at me:

  1. This implementation only supports padding = 0 or padding = 1 (perhaps this is the issue you're describing above).
  2. self.padding is set to be a sequence-of-ints. But in the (untransposed) Conv it's set to be a sequence-of-(int, int).
  3. Would it be simpler to ignore lax.conv_transpose and work directly with lax.conv_general_dilated? Looking at the implementation of lax.conv_transpose, it's just a thin wrapper around lax.conv_general_dilated and I don't think we really hit any of the meaningful functionality that that wrapper provides.
  4. We can probably compute dimension_numbers for all dimensions, just by working directly with lax.ConvDimensionNumbers instead of the simplified string representations. (It doesn't look that tricky.)

WDYT?

(EDIT: if you haven't seen it before, this is quite a nice visual for transposed convolutions, where you can see how stride > 1 corresponds to "fractional strides", i.e. the lhs_dilation argument to lax.conv_general_dilated.)
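
To make the "fractional strides" point concrete: a stride-s transposed convolution can be written as an ordinary lax.conv_general_dilated call with lhs_dilation=s. The following is a hand-rolled 1D sketch of the adjoint of a VALID, stride-`stride` cross-correlation; it is illustrative only, under the stated shape conventions, and not the implementation proposed for Equinox:

```python
import jax.numpy as jnp
from jax import lax

def conv_transpose_1d(g, kernel, stride):
    # g: (out_channels, length) -- the signal to "up-convolve".
    # kernel: (out_channels, in_channels, k) -- the *forward* convolution's kernel.
    k = kernel.shape[-1]
    # The adjoint of a stride-`stride` VALID cross-correlation: dilate the input
    # by `stride` ("fractional strides"), pad by k - 1 on both sides, and
    # cross-correlate with the spatially flipped kernel, channel axes swapped.
    rhs = jnp.flip(kernel, axis=-1).swapaxes(0, 1)  # (in_channels, out_channels, k)
    out = lax.conv_general_dilated(
        g[None],                 # add a batch dimension: (1, out_channels, L)
        rhs,
        window_strides=(1,),
        padding=[(k - 1, k - 1)],
        lhs_dilation=(stride,),  # this is the "fractional stride"
    )
    return out[0]                # (in_channels, (L - 1) * stride + k)
```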

@patrick-kidger
Owner Author

With the other branch merged in: let me know once you're satisfied that both MultiheadAttention and ConvTranspose work for your use cases and I'll merge this branch + do a new release.

@patrick-kidger
Owner Author

Thanks, both @lucidrains and @andyehrenberg. I'll merge + release this PR tomorrow.

@patrick-kidger patrick-kidger merged commit 3343e84 into main Mar 15, 2022
@patrick-kidger patrick-kidger deleted the attn-convt-layernorm branch March 17, 2022 15:47