ConvTranspose layers, MultiheadAttention, lookup embeddings, LayerNorm #34
Conversation
Okay, this is amazing! I'll go through and add comments in-line in a bit, but glancing over now this looks very well done. CC @lucidrains -- I recall you saying on Reddit that you were thinking of implementing some of these, so please feel free to express an opinion / leave review comments if you want. On initialising optimisers: yep, this is something I'm aware of, c.f. #15. I've been meaning to document it. I think what you're doing is the correct way to handle it.
patrick-kidger left a comment:
Right, done making comments! There's quite a lot, but that's only because this is such a big PR, and they're all pretty small, nitpicky things anyway. This is a great PR.
I also haven't thought too hard about padding for transposed convolutional layers. I think having just same/valid padding is probably fine for now, as long as you think it'll be easy to adjust without introducing any backward compatibility concerns, if need be?
By the way, do you want to (a) bump the version number, and (b) add these new layers to the documentation here?
class MultiheadAttention(Module):
    """
    Multihead Attention layer from 'Attention Is All You Need' (https://arxiv.org/abs/1706.03762)
    """
Added documentation in the latest commit - though I had to skip using flake8 because it's complaining about the '\i' and '\s' in '\intercal' and '\sqrt'. How can this be ignored by the checks in this repo?
This is because Python/flake8 is trying to interpret those as escape codes, like \n for a newline. Parsing of escape codes can be disabled by prefixing an r before the string, i.e.: r""" ... stuff ... """
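For example, a minimal sketch (not the actual docstring from the PR) of what the raw-string prefix looks like:

```python
from equinox import Module


class MultiheadAttention(Module):
    r"""Multihead Attention layer from 'Attention Is All You Need'
    (https://arxiv.org/abs/1706.03762).

    Computes $\mathrm{softmax}\left(\frac{Q K^\intercal}{\sqrt{d_k}}\right) V$.

    The leading ``r`` makes this a raw docstring, so ``\intercal`` and
    ``\sqrt`` are kept verbatim rather than being interpreted as escape codes.
    """
```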
class LayerNorm(Module):
    """Layer Normalization as described in https://arxiv.org/abs/1607.06450"""
I think it'd be great to have the precise mathematical description here, c.f. here, including the precise meaning of elementwise_affine. (By all means just copy what they've written verbatim if you wish.)
(And the meaning of normalized_shape, actually.)
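For reference, a sketch of the standard formulation (following the usual PyTorch-style description, not text taken from this PR): the statistics are computed over the trailing dimensions given by normalized_shape, and the affine parameters gamma and beta exist only when elementwise_affine=True.

```latex
% Mean and variance are taken over the trailing normalized_shape dimensions;
% \gamma and \beta are elementwise parameters of shape normalized_shape.
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \odot \gamma + \beta
```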
@andyehrenberg Just following up on your plans for this? (I'd be happy to help out on this PR if you want.)

@patrick-kidger Starting to work through your suggestions/corrections - got busy earlier this week. Thanks for the feedback!

Excellent! Let me know when this is ready for me to look at again.
Let me know if these should be put into separate PRs!
For ConvTranspose layers, I'm using the PyTorch-style `output_padding` argument instead of Haiku's `output_size`. Haiku has the limitation that `padding` can only be `'same'` or `'valid'`, though other options seem to be much less frequently useful. My current implementation shares this limitation - I haven't thought deeply enough about the best way to handle arbitrary combinations of `padding` and `output_padding`, so that probably needs some fixing.
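For context, the PyTorch convention relates the input and output lengths of a 1D transposed convolution as follows (this is the general PyTorch formula, not something quoted from this PR):

```latex
% Output length of a 1D transposed convolution (PyTorch convention):
L_{\mathrm{out}} = (L_{\mathrm{in}} - 1) \cdot \mathrm{stride}
                 - 2 \cdot \mathrm{padding}
                 + \mathrm{dilation} \cdot (\mathrm{kernel\_size} - 1)
                 + \mathrm{output\_padding} + 1
```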
The other three modules feel like pretty typical, barebones implementations.
When playing around with trying to reimplement grokking using equinox/these new layers, I found that trying to initialize an optax optimizer with an `eqx.nn.MLP` via `optim.init(model)` throws a TypeError about `optax.transform.init_fn` requiring all arguments to `jnp.zeros_like` to be arrays or scalars. My workaround is just using `params, static = eqx.partition(model, eqx.is_inexact_array)` and then `opt_state = optim.init(params)` - is this sort of thing unavoidable? Most other modules aren't suffering from this, so following along with your demo usually works.
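A sketch of that workaround (the MLP hyperparameters and learning rate below are made up for illustration; the partition/init pattern is the one described above):

```python
import equinox as eqx
import jax.random as jr
import optax

key = jr.PRNGKey(0)
# Hypothetical model sizes, just for illustration.
model = eqx.nn.MLP(in_size=2, out_size=1, width_size=64, depth=2, key=key)

optim = optax.adam(1e-3)

# optim.init(model) fails because the Module pytree contains non-array leaves
# (e.g. the activation function), which jnp.zeros_like cannot handle.
# Partition off the inexact-array leaves and initialise on those instead.
params, static = eqx.partition(model, eqx.is_inexact_array)
opt_state = optim.init(params)
```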