
Implementation of base ViT model #105

Merged
17 commits merged into FluxML:master from the vit branch on Feb 11, 2022
Conversation

@theabhirath (Member)

This is an implementation of the base ViT model as detailed in the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. This is quite a finicky model to implement, and while I'm fairly sure I've done it correctly, I would appreciate it if someone proofread the code just in case 😅 There are additional deps in the form of

  1. TensorCast (once Implementation of MLPMixer #103 gets merged, this should not be a problem), and
  2. LinearAlgebra, because of a utility function I had to write for batched matrix multiplication (I poached it off a StackOverflow answer; I hope that's not plagiarism). The function might not be the fastest tool for the job, though, so that's an additional thing to discuss (see the sketch at the end of this comment).

cc: @DhairyaLGandhi for encouraging me to start working on a ViT model 😄
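
For context, here is a rough sketch of what such a batched matrix multiplication can look like using NNlib's batched_mul rather than a hand-rolled LinearAlgebra utility (illustrative only, with made-up array sizes; this is not the utility from the PR):

using NNlib: batched_mul, batched_transpose

# Batched matrix multiplication over 3-D arrays: for every batch index i,
# C[:, :, i] = A[:, :, i] * B[:, :, i], with no explicit loop.
A = rand(Float32, 4, 6, 32)
B = rand(Float32, 6, 5, 32)
C = batched_mul(A, B)                            # size(C) == (4, 5, 32)

# batched_transpose lazily transposes each slice, e.g. for qᵀk-style attention scores
scores = batched_mul(batched_transpose(A), A)    # size(scores) == (6, 6, 32)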

@DhairyaLGandhi (Member)

Great go! Really excited to see this come through. I was thinking Transformers.jl might be a nice home for this to keep things organized.

@darsnack (Member) left a comment

This is a great start. Why don't we tackle iterating the transformer layers first before moving on to ViT itself?

@CarloLucibello (Member)

It would be good to have the attention layer and the whole transformer model in Flux, since they are of general use. PyTorch added them a while ago.

@theabhirath (Member, Author)

So this code has been cleaned up a lot more, and it looks much tighter IMO (I shifted a lot of the reusable layers to layers.jl as discussed). I think the main ViT model itself looks perfect; of course, if there are any design changes you think can be made, do let me know 😅. But the main issue now is the MHAttention layer: it's still written in a fashion that is not AD-friendly and may well be slower than we want...

@theabhirath (Member, Author)

theabhirath commented Feb 4, 2022

It would be good to have the attention layer and the whole transformer model in Flux, since they are of general use. PyTorch added them a while ago.

I was trying to write this, but a look at some PyTorch code, especially transformer-based models, made me realise that since vanilla attention is not used as-is very often, people end up writing their own code anyway. There are a lot of fancy implementations of attention with changes in the MLP blocks, in the heads, etc., and writing a very general version that exposes a lot of options to the user would make it a little cluttered.

@theabhirath theabhirath changed the base branch from master to compathelper/new_version/2022-02-04-03-06-36-596-01515794626 February 4, 2022 10:20
@theabhirath theabhirath changed the base branch from compathelper/new_version/2022-02-04-03-06-36-596-01515794626 to master February 4, 2022 10:21
@theabhirath theabhirath requested a review from darsnack February 4, 2022 10:22

@darsnack (Member) left a comment
Awesome, this is indeed looking quite a bit cleaner. I think the only major design consideration left is MHA, as you mentioned. I haven't had the time to think about this, and I likely won't until this weekend has passed, since I have a paper deadline on Sunday.

@theabhirath (Member, Author)

Awesome, this is indeed looking quite a bit cleaner. I think the only major design consideration left is MHA, as you mentioned. I haven't had the time to think about this, and I likely won't until this weekend has passed, since I have a paper deadline on Sunday.

No issues; I'll look up possible approaches in the meanwhile. Good luck with your paper!

@darsnack (Member)

darsnack commented Feb 7, 2022

I had a chance to implement a Parallel-based approach, which looks much cleaner in my opinion. It also appears to be faster than the PR's MHAttention on my machine. We'd still want to test it on the GPU.

Parallel implementation

using Flux
using NNlib: batched_mul, batched_transpose
# `chunk` is the package's own utility for splitting an array along a dimension
# (see the note further down about upstreaming it to Flux)

struct Attention{T}
    qkv::T
end

# a single dense layer produces the concatenated query, key and value projections
Attention(in, out) = Attention(Dense(in, out * 3; bias = false))

@functor Attention

function (attn::Attention)(x::AbstractArray{T}) where T
    q, k, v = chunk(attn.qkv(x), 3; dim = 1)
    scale = convert(T, sqrt(size(q, 1)))
    score = softmax(batched_mul(batched_transpose(q), k) / scale)
    attention = batched_mul(v, score)

    return attention
end

struct MultiHead{T, S}
    heads::T
    projection::S
end

function MultiHead(in, out, nheads; dropout = 0.)
    inheads, outheads = chunk(1:in, nheads), chunk(1:out, nheads)
    heads = Parallel(vcat, [Attention(length(i), length(o)) for (i, o) in zip(inheads, outheads)]...)
    projection = Chain(Dense(out, out), Dropout(dropout))

    MultiHead(heads, projection)
end

@functor MultiHead

function (mha::MultiHead)(x)
    # split the feature dimension into one slice per attention head
    xhead = chunk(x, length(mha.heads.layers); dim = 1)

    return mha.projection(mha.heads(xhead...))
end

Benchmarks

This is the code I used to set up the benchmark:

using BenchmarkTools

dh = 64
nheads = 3
d, n, b = dh * nheads, 20, 32
x = rand(Float32, d, n, b)

mha = MultiHead(d, d, nheads)                           # Parallel-based approach above
mhapr = MHAttention(d; heads = nheads, headplanes = dh) # implementation from this PR

Here are the results

julia> @benchmark $(mha)($x)
BenchmarkTools.Trial: 303 samples with 1 evaluation.
 Range (min … max):  14.584 ms … 47.184 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     15.825 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   16.490 ms ±  2.918 ms  ┊ GC (mean ± σ):  0.43% ± 1.38%

   ▃▄▆█▄█▁▁   ▁                                                
  ▄████████▇▅██▄▄▄▅▅▅▃▃▃▁▂▃▂▁▃▁▁▂▁▁▁▁▁▁▁▂▂▂▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▂ ▃
  14.6 ms         Histogram: frequency by time        25.8 ms <

 Memory estimate: 5.26 MiB, allocs estimate: 438.

julia> @benchmark $(mhapr)($x)
BenchmarkTools.Trial: 85 samples with 1 evaluation.
 Range (min … max):  55.089 ms … 62.719 ms  ┊ GC (min … max): 0.00% … 3.66%
 Time  (median):     59.550 ms              ┊ GC (median):    3.79%
 Time  (mean ± σ):   59.256 ms ±  1.670 ms  ┊ GC (mean ± σ):  3.17% ± 1.71%

         ▁              ▁ ▁▃  ▁▁   ▃ ▁▆ ▃ ▃█▁                  
  ▄▁▁▁▄▄▄█▁▁▄▁▇▁▁▄▁▁▇▁▄▁█▄██▇▁██▄▇▇█▇██▇█▇███▁▇▁▁▇▄▇▁▁▁▄▄▁▁▄▄ ▁
  55.1 ms         Histogram: frequency by time        62.7 ms <

 Memory estimate: 50.50 MiB, allocs estimate: 5307.

Some other notes:

I think the implementation in the PR is wrong. It seems to be returning all zeros, which I think is due to the last @cast operation.

The chunk utility here should be a PR to Flux or maybe renamed (a rough sketch of what it does is included below).

I am using Flux#master.

Perhaps we can look at speeding up Parallel separately from this PR.
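
For readers unfamiliar with it, chunk simply splits an array into roughly equal pieces along a given dimension. A minimal sketch of the idea (hypothetical; the actual utility in the package, and the version proposed for Flux, may differ in the details):

# Split `x` into `n` roughly equal views along dimension `dim`.
function chunk(x::AbstractArray, n::Integer; dim::Integer = ndims(x))
    step = cld(size(x, dim), n)   # ceiling division so every element lands in a chunk
    return [selectdim(x, dim, idx) for idx in Iterators.partition(axes(x, dim), step)]
end

chunk(rand(Float32, 192, 20, 32), 3; dim = 1)   # three 64×20×32 views, e.g. q, k, v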

@darsnack (Member)

darsnack commented Feb 7, 2022

A few more numbers with (dh, nheads, n, b) == (64, 8, 100, 32).

julia> @benchmark $(mha)($x)
BenchmarkTools.Trial: 24 samples with 1 evaluation.
 Range (min … max):  210.888 ms … 217.349 ms  ┊ GC (min … max): 0.35% … 0.42%
 Time  (median):     213.661 ms               ┊ GC (median):    0.35%
 Time  (mean ± σ):   213.556 ms ±   1.737 ms  ┊ GC (mean ± σ):  0.41% ± 0.13%

                █ █            ▃                                 
  ▇▁▁▇▇▇▁▁▁▁▇▁▁▁█▁█▁▁▁▁▁▁▁▁▇▇▁▇█▁▇▁▁▁▇▁▇▁▁▇▁▇▇▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▇ ▁
  211 ms           Histogram: frequency by time          217 ms <

 Memory estimate: 92.31 MiB, allocs estimate: 1164.

julia> @benchmark $(mhapr)($x)
BenchmarkTools.Trial: 3 samples with 1 evaluation.
 Range (min … max):  1.881 s … 1.898 s  ┊ GC (min … max): 0.51% … 0.60%
 Time  (median):     1.886 s             ┊ GC (median):    0.51%
 Time  (mean ± σ):   1.888 s ± 8.711 ms  ┊ GC (mean ± σ):  0.54% ± 0.05%

  █               █                                      █  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  1.88 s        Histogram: frequency by time         1.9 s <

 Memory estimate: 307.26 MiB, allocs estimate: 89820.

@darsnack darsnack mentioned this pull request Feb 7, 2022

@theabhirath (Member, Author)

I think the implementation in the PR is wrong. It seems to be returning all zeros, which I think is due to the last @cast operation.

Yeah, there were quite a few issues because I was trying to implement MHA in a rather Pythonic way. Parallel seems like a very nice Julian way to solve the same problem; I made slight tweaks, but overall your suggestion fits very well. It also solves the problem of taking on additional deps.

The chunk utility here should be a PR to Flux or maybe renamed.

FluxML/Flux.jl#1841 😅

src/layers.jl
SkipConnection(prenorm(planes, mlpblock(planes, mlpplanes, dropout)), +))
for _ in 1:depth]

Chain(layers...)

On Flux master, this might be a good candidate for Chain(layers) to reduce load time.
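
For context, a small sketch of the difference being suggested, assuming a Flux version where Chain accepts a vector of layers (illustrative layer sizes, not the PR's code):

using Flux

layers = [SkipConnection(Chain(LayerNorm(64), Dense(64, 64)), +) for _ in 1:12]

m_splat  = Chain(layers...)   # tuple-backed: one type parameter per layer, heavier on the compiler
m_vector = Chain(layers)      # vector-backed: same forward pass, lighter compile/load times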


Nice that layers has a concrete eltype here.

@darsnack (Member) left a comment

Okay, I think we're at the point of cleaning up the final API. In addition to Michael's comments, I've left my own notes.

@theabhirath (Member, Author)

I think I've covered all the main suggestions. There are a couple involving adding types to arguments that I'm not sure about: while I'm all for it, the other model APIs don't do the same. Likewise with the @assert vs. throw(ArgumentError(...)) cases (a quick illustration of the two styles is below); I think separate PRs to deal with those issues make more sense?
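
For reference, a minimal illustration of the two validation styles being discussed (the function and the checks are hypothetical, not code from the PR):

# Hypothetical constructor-style check contrasting the two approaches.
function check_heads(planes::Integer, nheads::Integer)
    # user-facing validation: always throws a descriptive error
    nheads > 0 || throw(ArgumentError("number of heads must be positive, got $nheads"))
    # internal invariant: `@assert` is intended for developer checks and may be elided
    @assert planes % nheads == 0 "planes ($planes) must be divisible by nheads ($nheads)"
    return planes ÷ nheads
end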

@darsnack (Member) left a comment

Looks great, mostly docstrings that need updating. (Btw, all of these suggestions on GitHub can be committed through the web interface; that makes it easier to make sure nothing is missed from a review.)

@theabhirath theabhirath requested a review from darsnack February 10, 2022 01:45

@theabhirath (Member, Author)

I ran gradtest locally and it passes, quite a bit faster than for the other models, in fact 😂. This should be good to go now.

@theabhirath (Member, Author)

theabhirath commented Feb 11, 2022

Whoops. Both Flux and MLUtils are exporting flatten, and that's causing the problem.
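
For context, a minimal sketch of the clash and one way around it, namely qualifying the call instead of relying on the export (illustrative sizes):

using Flux, MLUtils   # both export `flatten`, so the bare name becomes ambiguous

x = rand(Float32, 4, 4, 3, 2)
y = Flux.flatten(x)   # qualify the call to pick the method you mean; size(y) == (48, 2)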

@darsnack darsnack merged commit dfc9a64 into FluxML:master Feb 11, 2022

@darsnack (Member)

Thank you for all the hard work and patience @theabhirath!

@theabhirath theabhirath deleted the vit branch February 12, 2022 01:02
@darsnack darsnack mentioned this pull request Mar 18, 2022