Merged

Bench #135

91 commits
823ef7d
FP8 Bench work
micmelesse Mar 18, 2025
a8de016
fp8 seems slower
micmelesse Mar 21, 2025
0ee7031
clean up newer benching code. fp8 is slower
micmelesse Mar 24, 2025
a697a69
output markdown and multiple types
micmelesse Mar 24, 2025
e84be04
bench all supported_dtypes for function by default
micmelesse Mar 24, 2025
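The default dtype sweep this commit describes can be sketched roughly as follows (a hypothetical helper; the names `bench_all_dtypes` and the dtype list are illustrative, not taken from bench.py):

```python
# Hypothetical sketch: bench every supported dtype for a function by default.
SUPPORTED_DTYPES = ["fp16", "bf16", "fp8"]  # illustrative; the real list lives in bench.py

def bench_all_dtypes(bench_fn, dtypes=None):
    """Run bench_fn once per dtype and collect results keyed by dtype."""
    dtypes = SUPPORTED_DTYPES if dtypes is None else dtypes
    return {dtype: bench_fn(dtype) for dtype in dtypes}

# Example: a dummy "benchmark" that just echoes the dtype name.
results = bench_all_dtypes(lambda dtype: dtype.upper())
```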
2149fdf
add dockerignore
micmelesse Mar 25, 2025
1855bb3
need the .git for submodule update
micmelesse Mar 25, 2025
605faae
ignore training data
micmelesse Mar 25, 2025
d929a41
get ready for ck
micmelesse Mar 25, 2025
b39a1de
forward ck bench working
micmelesse Mar 25, 2025
e6f4a46
triton versus ck works
micmelesse Mar 25, 2025
e984f1e
tuned triton perf comp
micmelesse Mar 26, 2025
6aef88c
collect env flags
micmelesse Mar 26, 2025
7cd64d8
bench varlen and kvcache
micmelesse Mar 26, 2025
a15f433
function configs
micmelesse Mar 27, 2025
2cd03e4
show relative percentage diff
micmelesse Mar 27, 2025
22c668e
positive means triton faster, negative means ck is faster
micmelesse Mar 27, 2025
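The sign convention from the two commits above ("show relative percentage diff", positive favors Triton) can be sketched as a hypothetical helper; the actual formula in bench.py may differ:

```python
def relative_diff_pct(triton_ms: float, ck_ms: float) -> float:
    """Percent difference between Triton and CK timings.

    Positive => Triton is faster, negative => CK is faster
    (hypothetical formula matching the stated sign convention).
    """
    return (ck_ms - triton_ms) / ck_ms * 100.0

# Triton at 0.8 ms vs CK at 1.0 ms -> roughly +20%, i.e. Triton faster.
diff = relative_diff_pct(0.8, 1.0)
```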
0524f86
save
micmelesse Mar 27, 2025
5c6a100
add new decode impl with switch flag
micmelesse Mar 28, 2025
27b28b0
batch 1 and nheads 1 seems to work
micmelesse Mar 28, 2025
29e9296
autotune by default
micmelesse Mar 28, 2025
72266d5
simple stride calc in old impl
micmelesse Mar 28, 2025
3998016
fixed bug due to strides are bhsd
micmelesse Mar 28, 2025
78f0612
rename the dim_k
micmelesse Apr 2, 2025
5ab0441
clean up
micmelesse Apr 2, 2025
de6c9e9
old path works
micmelesse Apr 3, 2025
84722a7
rm block ptrs for q
micmelesse Apr 3, 2025
be65b1e
rm block_ptrs for k
micmelesse Apr 3, 2025
51f6650
rm block_ptrs for v
micmelesse Apr 3, 2025
1f09c12
rm block_ptrs from o
micmelesse Apr 3, 2025
7fd4adf
disable debug on bench
micmelesse Apr 3, 2025
1aa1a95
clean up
micmelesse Apr 3, 2025
b551715
clean up names
micmelesse Apr 3, 2025
a65c31d
compute offs_k properly
micmelesse Apr 3, 2025
5c0cb60
pass padded head to reduce kernel
micmelesse Apr 3, 2025
3f840fc
fix o_mask bug
micmelesse Apr 3, 2025
99f8fa1
rm old impl
micmelesse Apr 4, 2025
1d1a908
lambda grid
micmelesse Apr 4, 2025
14342a9
save final
micmelesse Apr 4, 2025
519f026
ignore git stuff
micmelesse Apr 7, 2025
a17da06
add inference params to prefill
micmelesse Apr 7, 2025
da22125
cache seqlens working
micmelesse Apr 7, 2025
497ee9a
most cases work except newkv
micmelesse Apr 8, 2025
16384ed
fix minor bugs when running fwd and bwd
micmelesse Apr 8, 2025
456a68e
check for backend
micmelesse Apr 9, 2025
1e42662
don't ignore .git
micmelesse Apr 9, 2025
53c2ed3
add modes
micmelesse Apr 9, 2025
437b215
bench bwd
micmelesse Apr 9, 2025
923e077
add llama configs
micmelesse Apr 9, 2025
f73d3bd
test fwd impl
micmelesse Apr 9, 2025
3529475
run bwd_impl
micmelesse Apr 9, 2025
528feb7
move fp8 code
micmelesse Apr 10, 2025
1e57dcc
use Decode kernel for kvcache
micmelesse Apr 10, 2025
ec7d089
fix fp8 import bug
micmelesse Apr 10, 2025
c15892b
fix bug
micmelesse Apr 11, 2025
f379dfa
add arch in report
micmelesse Apr 11, 2025
d397d8a
clean up test suite
micmelesse Apr 11, 2025
0af8590
fix fp8 typos
micmelesse Apr 11, 2025
b5fed8f
run ci
micmelesse Apr 11, 2025
2c598b4
add fused kernel
micmelesse Apr 11, 2025
df7ad2f
add one kernel
micmelesse Apr 11, 2025
43fbc22
update ci and readme
micmelesse Apr 11, 2025
202d8ec
report ratios and remove split impl test expand bwd impl test
micmelesse Apr 11, 2025
af74966
use split kernel
micmelesse Apr 11, 2025
009b9e4
get one kernel working
micmelesse Apr 11, 2025
34e85b7
use flag to switch bwd mode
micmelesse Apr 11, 2025
e17cb36
clean up test_ir
micmelesse Apr 14, 2025
85cc686
one kernel has its own copy of the bwd kernels
micmelesse Apr 14, 2025
3f609da
autotune stub
micmelesse Apr 14, 2025
6badc44
pass og metaparams by default
micmelesse Apr 14, 2025
52366a4
add autotune configs
micmelesse Apr 14, 2025
15b1569
add tuning configs
micmelesse Apr 14, 2025
98843cd
update fused kernel code
micmelesse Apr 14, 2025
a75f665
use jingning
micmelesse Apr 15, 2025
a50deff
no auto tune for bwd
micmelesse Apr 15, 2025
b572b5c
simpler varlen branching
micmelesse Apr 15, 2025
011ad15
fix constexpr bug
micmelesse Apr 15, 2025
9fcd9c9
fix varlen fp8
micmelesse Apr 15, 2025
861f1f5
qkv fp8 working
micmelesse Apr 15, 2025
47dd5a0
fp8 qkv varlen green
micmelesse Apr 15, 2025
25b406a
fix bench functions
micmelesse Apr 15, 2025
fa157cc
pick bench functions
micmelesse Apr 15, 2025
05775d2
bench defaults set
micmelesse Apr 16, 2025
44dc9be
fix bug
micmelesse Apr 16, 2025
da3f440
add bench deps
micmelesse Apr 16, 2025
a2472a9
bench env variations
micmelesse Apr 16, 2025
96cb0cd
per backend env configs
micmelesse Apr 16, 2025
7b1f6f7
fix bug
micmelesse Apr 16, 2025
ab28894
add improved fused kernel
micmelesse Apr 16, 2025
03b8df4
fix bug
micmelesse Apr 17, 2025
996f62b
final clean up
micmelesse Apr 17, 2025
14 changes: 14 additions & 0 deletions .dockerignore
@@ -0,0 +1,14 @@
+.eggs
+.gitignore
+build
+dist
+flash_attn.egg-info
+log
+scripts
+training/data
+gpucore*
+bench*.md
+*.pth
+*.html
+*.png
+*.csv
22 changes: 7 additions & 15 deletions .github/workflows/amd_nightly.yml
@@ -50,28 +50,20 @@ jobs:

       - name: Install dependencies for bench and misc
         run: |
-          pip install matplotlib pandas pytest
+          pip install numpy==1.24 matplotlib pandas tabulate

       # FIXME: run the full suite
       - name: AMD Internal Tests
         run: |
-          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest flash_attn/flash_attn_triton_amd/test.py::test_fp8
+          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest flash_attn/flash_attn_triton_amd/test.py

-      - name: AMD Bench
-        if: False
+      - name: Flash Attention Tests
         run: |
-          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1 python flash_attn/flash_attn_triton_amd/bench.py
+          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest tests/test_flash_attn_triton_amd.py

-      # run big test suites
-      - name: Flash Attention Tests using Pytorch reference implementation
-        if: False
+      - name: AMD Bench
         run: |
-          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_REF=1 pytest tests/test_flash_attn_triton_amd.py
+          python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_func flash_attn_varlen_func flash_attn_with_kvcache

-      - name: Flash Attention Tests
-        run: |
-          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py

   Nightly-RDNA-AMD:
     runs-on: ${{ matrix.runner }}
     strategy:
@@ -110,4 +102,4 @@ jobs:

       - name: Flash Attention Tests
         run: |
-          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py::test_flash_attn_output
+          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest tests/test_flash_attn_triton_amd.py::test_flash_attn_output
20 changes: 6 additions & 14 deletions .github/workflows/amd_tests.yml
@@ -48,24 +48,16 @@ jobs:

       - name: Install dependencies for bench and misc
         run: |
-          pip install matplotlib pandas pytest
+          pip install numpy==1.24 matplotlib pandas tabulate

       # FIXME: run the full suite
       - name: AMD Internal Tests
         run: |
-          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest flash_attn/flash_attn_triton_amd/test.py::test_fp8
+          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest flash_attn/flash_attn_triton_amd/test.py

-      - name: AMD Bench
-        if: False
-        run: |
-          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=1 python flash_attn/flash_attn_triton_amd/bench.py
-
-      # run big test suites
-      - name: Flash Attention Tests using Pytorch reference implementation
-        if: False
+      - name: Flash Attention Tests
         run: |
-          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_REF=1 pytest tests/test_flash_attn_triton_amd.py
+          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest tests/test_flash_attn_triton_amd.py

-      - name: Flash Attention Tests
+      - name: AMD Bench
         run: |
-          FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
+          python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_func flash_attn_varlen_func flash_attn_with_kvcache
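The CI steps above can be reproduced locally. This is a sketch using only the commands visible in the diff, assuming this PR's repo layout and a ROCm machine with the package installed:

```shell
# Run the test suite with autotune off, then the default bench, as in CI.
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest tests/test_flash_attn_triton_amd.py
python flash_attn/flash_attn_triton_amd/bench.py \
  -benchmark_fn flash_attn_func flash_attn_varlen_func flash_attn_with_kvcache
```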
5 changes: 5 additions & 0 deletions .gitignore
@@ -33,6 +33,7 @@ venv
 scripts
 csrc/flash_attn_ck
 .eggs
+log
 *.log
 core.*
 gpucore.*
@@ -42,5 +43,9 @@ gpucore.*
 *.json
 *.txt
 *.pth
 *.md
 training/logs
+training/data
+# ck modules
+csrc/composable_kernel
+csrc/cutlass
5 changes: 5 additions & 0 deletions README.md
@@ -170,6 +170,11 @@ To test that things are working, you can run our tests. These tests take hours s
 FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
 ```
 
+You can enable autotuning for better performance by setting the `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` flag:
+```
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE
+```
+
 ###### Docker
 You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image.
 ```