flash-attention-like gpu kernel #23

Draft · chengchingwen wants to merge 2 commits into master

Conversation

chengchingwen (Owner) commented Jan 21, 2024

This is an initial attempt to adopt the technique used in FlashAttention. The implementation basically follows the pseudo-code in the FlashAttention-2 paper. The code is written in a CUDA WMMA style, so we should be able to opt in or out of WMMA instructions.

Currently, this is only a draft for testing. I'll see whether it can be merged with the existing attention interface; it should at least support arbitrary masks, dropout, and the backward function.
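
For reference, here is a minimal CPU sketch of the tiled online-softmax forward pass from the FlashAttention-2 pseudo-code that this kernel follows, written for a single (head_dim, seq_len) slice. It only illustrates the algorithm: the function name, block sizes, and plain-array types are made up here and are not the PR's WMMA kernel or its API.

# Minimal CPU reference of the tiled online-softmax forward pass (FlashAttention-2
# style). q, k, v are (head_dim, seq_len) matrices, matching the layout of the
# arrays in the benchmarks below. Br/Bc are illustrative block sizes.
function flash_forward_reference(q::AbstractMatrix, k::AbstractMatrix, v::AbstractMatrix; Br = 64, Bc = 64)
    d, N = size(q)
    o    = zeros(Float32, d, N)
    lse  = zeros(Float32, N)                    # log-sum-exp, kept for the backward pass
    scale = inv(sqrt(Float32(d)))
    for i in 1:Br:N                             # loop over query blocks
        ir = i:min(i + Br - 1, N)
        qi = q[:, ir] .* scale
        oi = zeros(Float32, d, length(ir))
        mi = fill(-Inf32, length(ir))           # running row maxima
        li = zeros(Float32, length(ir))         # running softmax denominators
        for j in 1:Bc:N                         # loop over key/value blocks
            jr = j:min(j + Bc - 1, N)
            s  = k[:, jr]' * qi                 # (Bc, Br) score tile
            mnew = max.(mi, vec(maximum(s, dims = 1)))
            p  = exp.(s .- mnew')
            c  = exp.(mi .- mnew)               # rescale the previous partial results
            li = li .* c .+ vec(sum(p, dims = 1))
            oi = oi .* c' .+ v[:, jr] * p
            mi = mnew
        end
        o[:, ir] = oi ./ li'
        lse[ir]  = mi .+ log.(li)
    end
    return o, lse
end

Judging from how flash_attention_backward is called in the benchmarks below, flash_attention_forward likewise returns the output together with per-query softmax statistics that the backward pass consumes.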

Some simple benchmarks (with RTX 3090 24GB):

julia> using NeuralAttentionlib, CUDA, BenchmarkTools; using NeuralAttentionlib: Flash

julia> x = CUDA.randn(Float32, 64, 512, 32); y = CUDA.randn(Float32, 64, 512, 32); z = CUDA.randn(Float32, 64, 512, 32);

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
  1.143 ms (129 allocations: 5.64 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
  1.222 ms (74 allocations: 1.95 KiB)

julia> x = CUDA.randn(Float32, 64, 1024, 128); y = CUDA.randn(Float32, 64, 1024, 128); z = CUDA.randn(Float32, 64, 1024, 128);

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
  39.035 ms (121 allocations: 5.23 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
  14.304 ms (75 allocations: 1.97 KiB)

julia> x = CUDA.randn(Float32, 32, 4096, 128); y = CUDA.randn(Float32, 32, 4096, 128); z = CUDA.randn(Float32, 32, 4096, 128);

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
  642.535 ms (121 allocations: 5.23 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
  116.808 ms (75 allocations: 1.97 KiB)

julia> x = CUDA.randn(Float32, 128, 4096, 128); y = CUDA.randn(Float32, 128, 4096, 128); z = CUDA.randn(Float32, 128, 4096, 128);

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
  653.187 ms (121 allocations: 5.23 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
  424.653 ms (75 allocations: 1.97 KiB)

julia>  CUDA.@time NeuralAttentionlib.naive_qkv_attention(x, y, z);
  0.676641 seconds (1.07 k CPU allocations: 45.891 KiB) (5 GPU allocations: 8.252 GiB, 0.01% memmgmt time)

julia> CUDA.@time Flash.flash_attention_forward(x, y, z);
  0.434390 seconds (76 CPU allocations: 2.000 KiB) (2 GPU allocations: 258.000 MiB, 0.00% memmgmt time)
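
The memory gap in the CUDA.@time numbers above is roughly what one would expect from never materializing the full N×N score matrix. A back-of-the-envelope accounting for the 128×4096×128 Float32 case (my own arithmetic, not output from the package):

d, N, B = 128, 4096, 128
N * N * B * sizeof(Float32) / 2^30   # ≈ 8.0 GiB of attention scores in the naive path
d * N * B * sizeof(Float32) / 2^20   # ≈ 256 MiB for the attention output
N * B * sizeof(Float32) / 2^20       # ≈ 2 MiB of per-query statistics (e.g. log-sum-exp)

This lines up with the ~8.252 GiB reported for naive_qkv_attention versus the 258 MiB for flash_attention_forward, assuming the latter allocates only the output plus the per-query statistics.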

julia> using Zygote

julia> x = CUDA.randn(Float32, 64, 512, 32); y = CUDA.randn(Float32, 64, 512, 32); z = CUDA.randn(Float32, 64, 512, 32); dO = CUDA.randn(Float32, 64, 512, 32);

julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO);
  2.269 ms (202 allocations: 8.53 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z);
  3.759 ms (72 allocations: 2.75 KiB)

julia> x = CUDA.randn(Float32, 64, 1024, 128); y = CUDA.randn(Float32, 64, 1024, 128); z = CUDA.randn(Float32, 64, 1024, 128); dO = CUDA.randn(Float32, 64, 1024, 128);

julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO);
  73.121 ms (189 allocations: 7.91 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z);
  53.824 ms (72 allocations: 2.75 KiB)

julia> CUDA.@time Zygote.pullback(NeuralAttentionlib.naive_qkv_attention, x, y, z)[2](dO);
  0.160582 seconds (5.83 k CPU allocations: 324.297 KiB) (11 GPU allocations: 1.626 GiB, 0.09% memmgmt time)

julia> CUDA.@time Flash.flash_attention_backward(dO, Flash.flash_attention_forward(x, y, z)..., x, y, z);
  0.081046 seconds (86 CPU allocations: 3.953 KiB) (5 GPU allocations: 128.500 MiB, 0.07% memmgmt time)

julia> x = CUDA.randn(Float32, 32, 4096, 128); y = CUDA.randn(Float32, 32, 4096, 128); z = CUDA.randn(Float32, 32, 4096, 128); dO = CUDA.randn(Float32, 32, 4096, 128);

julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO);
ERROR: Out of GPU memory ...

julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z);
  395.993 ms (72 allocations: 2.75 KiB)

julia> CUDA.@time Flash.flash_attention_backward(dO, Flash.flash_attention_forward(x, y, z)..., x, y, z);
  0.444120 seconds (76 CPU allocations: 2.859 KiB) (5 GPU allocations: 258.000 MiB, 0.01% memmgmt time)

(with A100 80GB):

julia> using NeuralAttentionlib, CUDA, BenchmarkTools; using NeuralAttentionlib: Flash

julia> x = CUDA.randn(Float32, 64, 512, 32); y = CUDA.randn(Float32, 64, 512, 32); z = CUDA.randn(Float32, 64, 512, 32);

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
  1.301 ms (212 allocations: 12.70 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
  1.021 ms (58 allocations: 3.02 KiB)

julia> x = CUDA.randn(Float32, 64, 1024, 128); y = CUDA.randn(Float32, 64, 1024, 128); z = CUDA.randn(Float32, 64, 1024, 128);

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
  8.954 ms (206 allocations: 12.30 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
  12.566 ms (59 allocations: 3.03 KiB)

julia> x = CUDA.randn(Float32, 32, 4096, 128); y = CUDA.randn(Float32, 32, 4096, 128); z = CUDA.randn(Float32, 32, 4096, 128);

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
  156.268 ms (208 allocations: 12.33 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
  94.749 ms (59 allocations: 3.03 KiB)

julia> x = CUDA.randn(Float32, 128, 4096, 128); y = CUDA.randn(Float32, 128, 4096, 128); z = CUDA.randn(Float32, 128, 4096, 128);

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
  177.827 ms (208 allocations: 12.33 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
  370.966 ms (59 allocations: 3.03 KiB)

julia> CUDA.@time NeuralAttentionlib.naive_qkv_attention(x, y, z);
  0.184264 seconds (6.77 k CPU allocations: 455.398 KiB) (5 GPU allocations: 8.252 GiB, 0.03% memmgmt time)

julia> CUDA.@time Flash.flash_attention_forward(x, y, z);
  0.371610 seconds (60 CPU allocations: 3.062 KiB) (2 GPU allocations: 258.000 MiB, 0.01% memmgmt time)

julia> using Zygote

julia> x = CUDA.randn(Float32, 64, 512, 32); y = CUDA.randn(Float32, 64, 512, 32); z = CUDA.randn(Float32, 64, 512, 32); dO = CUDA.randn(Float32, 64, 512, 32);

julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO);
  2.659 ms (329 allocations: 20.06 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z);
  3.362 ms (132 allocations: 7.28 KiB)

julia> x = CUDA.randn(Float32, 64, 1024, 128); y = CUDA.randn(Float32, 64, 1024, 128); z = CUDA.randn(Float32, 64, 1024, 128); dO = CUDA.randn(Float32, 64, 1024, 128);

julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO);
  20.461 ms (324 allocations: 19.52 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z);
  43.777 ms (132 allocations: 7.28 KiB)

julia> CUDA.@time Zygote.pullback(NeuralAttentionlib.naive_qkv_attention, x, y, z)[2](dO);
  0.065213 seconds (5.97 k CPU allocations: 336.094 KiB) (11 GPU allocations: 1.626 GiB, 0.07% memmgmt time)

julia> CUDA.@time Flash.flash_attention_backward(dO, Flash.flash_attention_forward(x, y, z)..., x, y, z);
  0.044100 seconds (136 CPU allocations: 7.391 KiB) (5 GPU allocations: 128.500 MiB, 0.05% memmgmt time)

julia> x = CUDA.randn(Float32, 32, 4096, 128); y = CUDA.randn(Float32, 32, 4096, 128); z = CUDA.randn(Float32, 32, 4096, 128); dO = CUDA.randn(Float32, 32, 4096, 128);

julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO);
  341.050 ms (327 allocations: 19.56 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z);
  307.269 ms (132 allocations: 7.28 KiB)

julia> CUDA.@time Zygote.pullback(NeuralAttentionlib.naive_qkv_attention, x, y, z)[2](dO);
  0.342542 seconds (386 CPU allocations: 30.094 KiB) (11 GPU allocations: 24.254 GiB, 0.02% memmgmt time)

julia> CUDA.@time Flash.flash_attention_backward(dO, Flash.flash_attention_forward(x, y, z)..., x, y, z);
  0.308509 seconds (136 CPU allocations: 7.391 KiB) (5 GPU allocations: 258.000 MiB, 0.01% memmgmt time)

chengchingwen mentioned this pull request Jan 21, 2024

codecov-commenter commented
Codecov Report

Attention: 1134 lines in your changes are missing coverage. Please review.

Comparison is base (40922f8) 74.51% compared to head (1c3e98a) 48.07%.

Files                        Patch %   Lines
src/flash/launch.jl            0.00%   264 Missing ⚠️
src/flash/backward.jl          0.00%   226 Missing ⚠️
src/flash/forward.jl           0.00%   195 Missing ⚠️
src/flash/forward_utils.jl     0.00%   170 Missing ⚠️
src/flash/utils.jl             0.00%   106 Missing ⚠️
src/flash/mma.jl               0.00%   102 Missing ⚠️
src/flash/backward_utils.jl    0.00%    52 Missing ⚠️
src/mask/indexer.jl           71.79%    11 Missing ⚠️
src/mask/mask.jl              66.66%     7 Missing ⚠️
src/mask/broadcast.jl         50.00%     1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master      #23       +/-   ##
===========================================
- Coverage   74.51%   48.07%   -26.45%     
===========================================
  Files          30       38        +8     
  Lines        2052     3162     +1110     
===========================================
- Hits         1529     1520        -9     
- Misses        523     1642     +1119     


FinnWeng commented Feb 5, 2024

(with RTX 4090):

julia> using NeuralAttentionlib, CUDA, BenchmarkTools; using NeuralAttentionlib: Flash

julia> x = CUDA.randn(Float32, 64, 512, 32); y = CUDA.randn(Float32, 64, 512, 32); z = CUDA.randn(Float32, 64, 512, 32);

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
  604.971 μs (211 allocations: 12.69 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
  578.523 μs (59 allocations: 3.03 KiB)


julia> x = CUDA.randn(Float32, 64, 1024, 128); y = CUDA.randn(Float32, 64, 1024, 128); z = CUDA.randn(Float32, 64, 1024, 128);

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z);
  8.371 ms (206 allocations: 12.30 KiB) 

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
  6.397 ms (59 allocations: 3.03 KiB)

julia> x = CUDA.randn(Float32, 32, 4096, 128); y = CUDA.randn(Float32, 32, 4096, 128); z = CUDA.randn(Float32, 32, 4096, 128); 

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z); 
  131.801 ms (208 allocations: 12.33 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z);
  52.394 ms (59 allocations: 3.03 KiB)


julia> x = CUDA.randn(Float32, 128, 4096, 128); y = CUDA.randn(Float32, 128, 4096, 128); z = CUDA.randn(Float32, 128, 4096, 128);

julia> @btime CUDA.@sync NeuralAttentionlib.naive_qkv_attention($x, $y, $z); 
  138.517 ms (208 allocations: 12.33 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_forward($x, $y, $z); 
  192.907 ms (59 allocations: 3.03 KiB)

julia>  CUDA.@time NeuralAttentionlib.naive_qkv_attention(x, y, z); 
  0.149221 seconds (6.77 k CPU allocations: 455.070 KiB) (5 GPU allocations: 8.252 GiB, 0.02% memmgmt time)

julia> CUDA.@time Flash.flash_attention_forward(x, y, z); 
  0.208802 seconds (60 CPU allocations: 3.062 KiB) (2 GPU allocations: 258.000 MiB, 0.01% memmgmt time)



julia> using Zygote


julia> x = CUDA.randn(Float32, 64, 512, 32); y = CUDA.randn(Float32, 64, 512, 32); z = CUDA.randn(Float32, 64, 512, 32); dO = CUDA.randn(Float32, 64, 512, 32); 

julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO);
   602.745 μs (209 allocations: 12.53 KiB)


julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z); 
  1.724 ms (132 allocations: 7.28 KiB)

julia> x = CUDA.randn(Float32, 64, 1024, 128); y = CUDA.randn(Float32, 64, 1024, 128); z = CUDA.randn(Float32, 64, 1024, 128); dO = CUDA.randn(Float32, 64, 1024, 128);

julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO); 
  16.974 ms (324 allocations: 19.52 KiB)

julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z);
  24.202 ms (132 allocations: 7.28 KiB)

julia> CUDA.@time Zygote.pullback(NeuralAttentionlib.naive_qkv_attention, x, y, z)[2](dO);  
  0.051180 seconds (5.96 k CPU allocations: 335.906 KiB) (11 GPU allocations: 1.626 GiB, 0.07% memmgmt time)

julia> CUDA.@time Flash.flash_attention_backward(dO, Flash.flash_attention_forward(x, y, z)..., x, y, z);
  0.033187 seconds (136 CPU allocations: 7.391 KiB) (5 GPU allocations: 128.500 MiB, 0.05% memmgmt time)

julia> x = CUDA.randn(Float32, 32, 4096, 128); y = CUDA.randn(Float32, 32, 4096, 128); z = CUDA.randn(Float32, 32, 4096, 128); dO = CUDA.randn(Float32, 32, 4096, 128);

julia> @btime CUDA.@sync Zygote.pullback($(NeuralAttentionlib.naive_qkv_attention), $x, $y, $z)[2]($dO); 
  ERROR: Out of GPU memory trying to allocate 8.000 GiB
  Effective GPU memory usage: 99.85% (23.607 GiB/23.642 GiB)
  Memory pool usage: 16.312 GiB (22.875 GiB reserved)


julia> @btime CUDA.@sync Flash.flash_attention_backward($dO, Flash.flash_attention_forward($x, $y, $z)..., $x, $y, $z);
  171.479 ms (132 allocations: 7.28 KiB)

julia> CUDA.@time Flash.flash_attention_backward(dO, Flash.flash_attention_forward(x, y, z)..., x, y, z);
  0.189891 seconds (136 CPU allocations: 7.391 KiB) (5 GPU allocations: 258.000 MiB, 0.01% memmgmt time)
