
ConvTranspose layers, MultiheadAttention, lookup embeddings, LayerNorm #34

Merged

patrick-kidger merged 19 commits into patrick-kidger:attn-convt-layernorm from andyehrenberg:main on Mar 9, 2022

Conversation

@andyehrenberg
Contributor

Let me know if these should be put into separate PRs!

  • For the ConvTranspose layers, I'm using the PyTorch-style output_padding argument instead of haiku's output_size. Haiku has the limitation that padding can only be 'same' or 'valid', though other options seem much less frequently useful. My current implementation shares this limitation; I haven't thought deeply enough about the best way to handle arbitrary combinations of padding and output_padding, so this probably needs some fixing.

  • The other 3 modules feel like pretty typical, barebones implementations

  • When playing around with reimplementing grokking using equinox and these new layers, I found that initializing an optax optimizer with an eqx.nn.MLP via optim.init(model) throws a TypeError: optax.transform.init_fn requires all arguments to jnp.zeros_like to be arrays or scalars. My workaround is params, static = eqx.partition(model, eqx.is_inexact_array) followed by opt_state = optim.init(params) (see the sketch after this list) - is this sort of thing unavoidable? Most other modules don't suffer from this, so following along with your demo usually works.
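A minimal sketch of that workaround, assuming current equinox and optax APIs (the model and optimiser hyperparameters here are purely illustrative):

```python
import equinox as eqx
import jax.random as jrandom
import optax

key = jrandom.PRNGKey(0)
model = eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=2, key=key)
optim = optax.adam(learning_rate=3e-4)

# optim.init(model) fails because the model pytree contains non-array
# leaves (e.g. activation functions), so split out the array leaves first:
params, static = eqx.partition(model, eqx.is_inexact_array)
opt_state = optim.init(params)
```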

@patrick-kidger
Owner

patrick-kidger commented Feb 25, 2022

Okay, this is amazing! I'll go through and add comments in-line in a bit, but glancing over now this looks very well done.
(I'm not too fussed about making these separate PRs.)

CC @lucidrains -- I recall you saying on Reddit that you were thinking of implementing some of these, so please feel free to express an opinion / leave review comments if you want.

On initialising optimisers: yep, this is something I'm aware of, c.f. #15. I've been meaning to document it. I think what you're doing is the correct way to handle it.

Owner

@patrick-kidger left a comment


Right, done making comments! There are quite a lot, but that's only because this is such a big PR, and they're all pretty small, nitpicky stuff anyway. This is a great PR.

I also haven't thought too hard about padding for transposed convolutional layers. I think having just same/valid padding is probably fine for now, as long as you think it'll be easy to adjust without introducing any backward compatibility concerns, if need be?
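If the PyTorch convention is followed, the per-dimension output size works out as below (this is the standard relation from the PyTorch ConvTranspose docs, included for reference only, not a claim about the final equinox implementation):

```latex
\[
L_{\text{out}} = (L_{\text{in}} - 1)\cdot\text{stride}
  \;-\; 2\cdot\text{padding}
  \;+\; \text{dilation}\cdot(\text{kernel\_size} - 1)
  \;+\; \text{output\_padding} \;+\; 1
\]
```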

By the way, do you want to (a) bump the version number, and (b) add these new layers to the documentation here?

class MultiheadAttention(Module):
    """
    Multihead Attention layer from 'Attention Is All You Need' (https://arxiv.org/abs/1706.03762)
    """
Owner


Would it be possible to add the mathematical formulation here?

PyTorch kind of does this for attention here, but I'm imagining something more precise, a la here.

The documentation generation supports LaTeX: stuff $\alpha$ morestuff.
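For concreteness, the standard formulation from the paper looks roughly like this (a sketch of what's being asked for; the exact notation and wording that ended up in the docstring may differ):

```latex
\[
\mathrm{Attention}(Q, K, V)
  = \mathrm{softmax}\!\left(\frac{Q K^\intercal}{\sqrt{d_k}}\right) V
\]
\[
\mathrm{MultiHead}(Q, K, V)
  = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O,
\qquad
\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)
\]
```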

Contributor Author


Added documentation in the latest commit - though I had to skip using flake8 because it's complaining about the '\i' and '\s' in '\intercal' and '\sqrt'. How can this be ignored by the checks in this repo?

Owner


This is because Python/flake8 is trying to interpret those as escape codes, like \n for a newline. Parsing of escape codes can be disabled by prefixing an r before the string, i.e.: r""" ... stuff ... """
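For example, a hypothetical docstring using the raw-string prefix (the function and wording here are illustrative, not the PR's actual code):

```python
def scaled_dot_product_attention(query, key, value):
    r"""Computes $\mathrm{softmax}\left(\frac{QK^\intercal}{\sqrt{d_k}}\right)V$.

    The leading ``r`` stops Python from treating ``\i`` and ``\s`` (in
    ``\intercal`` and ``\sqrt``) as escape sequences, so flake8's W605
    warning no longer fires.
    """
    ...
```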



class LayerNorm(Module):
    """Layer Normalization as described in https://arxiv.org/abs/1607.06450"""
Owner


I think it'd be great to have the precise mathematical description here, c.f. here, including the precise meaning of elementwise_affine. (By all means just copy what they've written verbatim if you wish.)

Owner


(And the meaning of normalized_shape, actually.)
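If the reference here is the PyTorch LayerNorm docs, the formulation there is roughly the following (symbols are illustrative):

```latex
\[
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \odot \gamma + \beta
\]
```

where the mean and variance are computed over the trailing dimensions given by normalized_shape, and the learnable scale $\gamma$ and offset $\beta$ are only present when elementwise_affine=True.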

@patrick-kidger
Owner

@andyehrenberg Just following up on your plans for this? (I'd be happy to help out on this PR if you want.)

@andyehrenberg
Contributor Author

@patrick-kidger Starting to work through your suggestions/corrections - got busy earlier this week. Thanks for the feedback!

@patrick-kidger
Owner

patrick-kidger commented Mar 4, 2022

Excellent! Let me know when this is ready for me to look at again.

@patrick-kidger changed the base branch from main to attn-convt-layernorm on Mar 9, 2022, 20:33
@patrick-kidger merged commit 1330439 into patrick-kidger:attn-convt-layernorm on Mar 9, 2022
