Attention, Transposed Convolutions, Embeddings, LayerNorm #38
patrick-kidger merged 47 commits into main
Conversation
ConvTranspose layers, MultiheadAttention, lookup embeddings, LayerNorm
I agree that your overhauled …
Excellent. I'm assuming you implemented … As for transposed convolutions: staring at the implementation, I'm not feeling convinced by it. These things jump out at me: …
WDYT? (EDIT: if you haven't seen it before, this is quite a nice visual for transposed convolutions, where you can see how …)
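To make the conv / transposed-conv relationship concrete, here is a purely illustrative JAX sketch (not the implementation in this PR): as a linear map, the transposed convolution is just the transpose of a strided convolution, which you can read off directly from `jax.vjp`. It also shows the shape ambiguity that makes `ConvTranspose` fiddly.

```python
import jax
import jax.numpy as jnp

# 1D input with layout (batch, channels, width) and a kernel with layout
# (out_channels, in_channels, width): the defaults for lax's convolutions.
x = jnp.arange(8.0).reshape(1, 1, 8)
kernel = jnp.array([[[1.0, 2.0, 3.0]]])


def conv(inputs):
    # An ordinary stride-2 convolution with no padding.
    return jax.lax.conv_general_dilated(
        inputs, kernel, window_strides=(2,), padding="VALID"
    )


y = conv(x)  # shape (1, 1, 3)

# The transposed convolution is (up to kernel-flipping conventions) the
# transpose of `conv` as a linear map: pulling an output-shaped cotangent
# back through the VJP gives an input-shaped array.
_, conv_vjp = jax.vjp(conv, x)
(xt,) = conv_vjp(jnp.ones_like(y))
print(xt.shape)  # (1, 1, 8)

# The classic subtlety: a stride-2 VALID convolution maps inputs of width 7
# *and* width 8 to outputs of width 3, so the transpose alone can't recover
# the input width unambiguously; hence knobs like PyTorch's output_padding.
```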
Tidied; simplified; generalised ConvTranspose implementation.
With the other branch merged in: let me know once you're satisfied that both …
Thanks both @lucidrains @andyehrenberg. I'll merge + release this PR tomorrow.
Merging #34 into master via this branch, so as to make some tweaks.
Updates relative to #34:
- `MultiheadAttention` and `LayerNorm`: …
- New `MultiheadAttention` implementation. The implementation in #34 ("ConvTranspose layers, MultiheadAttention, lookup embeddings, LayerNorm") was basically mimicking PyTorch's implementation, which has a very bad API. The new API is much more consistent, and much more general; a rough sketch of what I mean is below. (For reference, I found the Haiku implementation of `MultiheadAttention` to be the cleanest of all previous implementations.)
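Purely as an illustration of the kind of generality I mean (made-up names and shapes, not the code in this PR): query, key and value are separate inputs with their own sequence lengths and feature sizes, so self-attention and cross-attention fall out of the same call.

```python
import jax
import jax.numpy as jnp


def multihead_attention(wq, wk, wv, wo, query, key, value):
    """Sketch of a Haiku-style general attention call (illustrative shapes):
      query: (q_seq, q_size)     key: (kv_seq, k_size)      value: (kv_seq, v_size)
      wq: (q_size, heads, head)  wk: (k_size, heads, head)  wv: (v_size, heads, head)
      wo: (heads, head, out_size)
    """
    # Project each input into per-head query/key/value spaces.
    q = jnp.einsum("qd,dhk->qhk", query, wq)
    k = jnp.einsum("sd,dhk->shk", key, wk)
    v = jnp.einsum("sd,dhk->shk", value, wv)
    # Scaled dot-product attention, independently per head.
    logits = jnp.einsum("qhk,shk->hqs", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(logits, axis=-1)          # attend over the key sequence
    per_head = jnp.einsum("hqs,shk->qhk", weights, v)  # (q_seq, heads, head)
    return jnp.einsum("qhk,hko->qo", per_head, wo)     # final output projection


# Cross-attention with different sequence lengths and feature sizes.
kq, kk, kv, ko = jax.random.split(jax.random.PRNGKey(0), 4)
heads, head = 4, 8
wq = jax.random.normal(kq, (16, heads, head))
wk = jax.random.normal(kk, (32, heads, head))
wv = jax.random.normal(kv, (32, heads, head))
wo = jax.random.normal(ko, (heads, head, 16))
out = multihead_attention(
    wq, wk, wv, wo, jnp.ones((10, 16)), jnp.ones((20, 32)), jnp.ones((20, 32))
)
print(out.shape)  # (10, 16)
```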
Plus some other misc updates in support of these changes:

- … `p=0`.
- Updated `custom_types.Array` and `custom_types.PyTree` to follow the convention in Diffrax; namely, they're now subscriptable and the documentation handles this. (A short illustrative sketch of this convention is at the end of this description.)

CC @andyehrenberg (and @lucidrains ?) for any commentary on these changes prior to merging them in.
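(For anyone who hasn't seen the Diffrax convention referenced above: very roughly, the annotation types become subscriptable purely for documentation purposes, along the lines of the sketch below. The names and details here are illustrative, not necessarily the exact code in this PR.)

```python
import jax.numpy as jnp


class Array:
    """Documentation-only annotation: `Array["channels"]` is still just a plain
    array at runtime. The subscript records the intended shape/semantics for the
    reader and the docs; no runtime checking is performed."""

    def __class_getitem__(cls, item):
        return cls


def layer_norm(x: Array["channels"]) -> Array["channels"]:
    # Example of the annotation in a signature.
    return (x - jnp.mean(x)) / jnp.sqrt(jnp.var(x) + 1e-5)
```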