Implementation of base ViT model #105
Conversation
Great go! Really excited to see this come through. I was thinking Transformers.jl might be a nice home for this to keep things organized.
This is a great start. Why don't we tackle iterating the transformer layers first before moving on to ViT itself?
It would be good to have the attention layer and the whole transformer model in Flux, since they are of general use. PyTorch added them a while ago.
So this code has been cleaned up a lot more - it looks much tighter IMO (I shifted a lot of the reusable layers to the
I was trying to write this, but then a look at some PyTorch code, especially transformer-based code, made me realise that since vanilla attention is rarely used as-is, people end up writing their own code anyway. There are a lot of fancy implementations of attention with changes in the MLP blocks, in the heads, etc., but writing a very general version that exposes a lot of options for the user would make it a little cluttered.
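To illustrate why people often just rewrite it, "vanilla" scaled dot-product attention is only a few lines in Flux. A minimal single-head sketch (the function name and shape conventions here are my own assumptions, not code from this PR):

```julia
using Flux  # provides softmax (re-exported from NNlib)

# Minimal single-head scaled dot-product attention (illustrative sketch).
# q, k, v are (features, sequence length) matrices.
function attention(q, k, v)
    dk = size(k, 1)
    scores = (k' * q) ./ sqrt(Float32(dk))  # pairwise key-query similarities
    v * softmax(scores; dims = 1)           # softmax over keys, then weight the values
end
```

Any variation (extra heads, gating, modified MLP blocks) ends up touching these few lines directly, which is part of why a single highly general implementation tends to accumulate options and clutter.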
Awesome, this is indeed looking quite a bit cleaner. I think the only major design consideration left is MHA, as you mentioned. I haven't had the time to think about this, and I likely won't until this weekend has passed, since I have a paper deadline on Sunday.
No issues, I'll look up possible approaches in the meanwhile. Good luck with your paper!
I had a chance to implement a
A few more numbers:

```julia
julia> @benchmark $(mha)($x)
BenchmarkTools.Trial: 24 samples with 1 evaluation.
 Range (min … max):  210.888 ms … 217.349 ms  ┊ GC (min … max): 0.35% … 0.42%
 Time  (median):     213.661 ms               ┊ GC (median):    0.35%
 Time  (mean ± σ):   213.556 ms ±   1.737 ms  ┊ GC (mean ± σ):  0.41% ± 0.13%

                 █ █            ▃
  ▇▁▁▇▇▇▁▁▁▁▇▁▁▁█▁█▁▁▁▁▁▁▁▁▇▇▁▇█▁▇▁▁▁▇▁▇▁▁▇▁▇▇▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▇ ▁
  211 ms          Histogram: frequency by time          217 ms <

 Memory estimate: 92.31 MiB, allocs estimate: 1164.

julia> @benchmark $(mhapr)($x)
BenchmarkTools.Trial: 3 samples with 1 evaluation.
 Range (min … max):  1.881 s …   1.898 s  ┊ GC (min … max): 0.51% … 0.60%
 Time  (median):     1.886 s              ┊ GC (median):    0.51%
 Time  (mean ± σ):   1.888 s ±  8.711 ms  ┊ GC (mean ± σ):  0.54% ± 0.05%

  █               █                                       █
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  1.88 s         Histogram: frequency by time         1.9 s <

 Memory estimate: 307.26 MiB, allocs estimate: 89820.
```
Yeah, there were quite a few issues because I was trying to implement MHA in a rather Pythonic way.
```julia
SkipConnection(prenorm(planes, mlpblock(planes, mlpplanes, dropout)), +))
          for _ in 1:depth]

Chain(layers...)
```
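For readers without the diff context, the pattern above builds a pre-norm transformer encoder: each block wraps its sub-layers in a `SkipConnection` for the residual path. A rough reconstruction under my own assumptions (`prenorm`, `mlpblock`, and the constructor signature are hypothetical, not the PR's exact code):

```julia
using Flux

# Hypothetical helpers in the spirit of the snippet above:
prenorm(planes, fn) = Chain(LayerNorm(planes), fn)  # normalize before the sub-block
mlpblock(planes, hidden, dropout) =
    Chain(Dense(planes, hidden, gelu), Dropout(dropout),
          Dense(hidden, planes), Dropout(dropout))

# `makeattn(planes)` should return a fresh attention layer per depth,
# so weights are not shared across blocks.
function transformer(makeattn, planes, depth, mlpplanes; dropout = 0.1)
    layers = [Chain(SkipConnection(prenorm(planes, makeattn(planes)), +),
                    SkipConnection(prenorm(planes, mlpblock(planes, mlpplanes, dropout)), +))
              for _ in 1:depth]
    Chain(layers...)
end
```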
On Flux master, this might be a good candidate for `Chain(layers)` to reduce load time.
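A small sketch of the difference being suggested (assuming a Flux version where `Chain` accepts a vector of layers):

```julia
using Flux

layers = [Dense(8, 8, relu) for _ in 1:4]  # Vector{Dense} with a concrete eltype

# Splatting builds a tuple-backed Chain: the tuple type grows with depth,
# so every distinct depth triggers fresh compilation.
model_splat = Chain(layers...)

# Passing the vector keeps one concrete Vector type regardless of depth,
# reducing compile/load time at some cost to per-call type information.
model_vec = Chain(layers)
```

Both forms produce the same outputs for the same input; the trade-off is purely about compilation and dispatch.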
Nice that `layers` has a concrete eltype here.
Okay I think we're at the point of cleaning up the final API. In addition to Michael's comments, I've left my own notes.
I think I've covered all the main suggestions. There are a couple involving adding types to arguments that I'm not sure about - while I'm all for it, the other model APIs don't do the same. Likewise with the
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, mostly docstrings that need updating. (Btw, all of these suggestions on GitHub can be committed through the web interface, which makes it easier to ensure nothing is missed from a review.)
Co-authored-by: Kyle Daruwalla <[email protected]>
I ran
Whoops. Both Flux and MLUtils are exporting
Thank you for all the hard work and patience @theabhirath!
This is an implementation of the base ViT model as detailed in the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. This is quite the finicky model to implement, and while I'm fairly sure I've done it correctly, I would appreciate someone proofreading the code just in case 😅 There are additional deps in the form of
cc: @DhairyaLGandhi for encouraging me to start working on a ViT model 😄