Merged

27 commits
- `e91d348` Enable Fwd and Backward (micmelesse, Jun 19, 2024)
- `79a9303` Enable sequence_parallel in bwd (#89) (micmelesse, Oct 30, 2024)
- `cd39393` Autotune off by default (#90) (micmelesse, Oct 31, 2024)
- `9f08bc6` Update Triton Version (#91) (micmelesse, Nov 1, 2024)
- `b6cc484` update Triton commit readme (#92) (micmelesse, Nov 1, 2024)
- `4f993cd` Fix README (#96) (micmelesse, Nov 8, 2024)
- `d846165` Enable MQA/GQA in backward (#100) (micmelesse, Nov 15, 2024)
- `8ca377e` Added Support for Rotary Positional Embeddings (#99) (alexkranias-amd, Nov 20, 2024)
- `73661a1` add RDNA CI (#105) (micmelesse, Dec 4, 2024)
- `83d9397` Dropout (#101) (micmelesse, Dec 6, 2024)
- `f6e7220` fp8 forward (#116) (micmelesse, Jan 24, 2025)
- `1236016` Update readme (micmelesse, Jan 24, 2025)
- `f337dd9` Minor fixes (#107) (micmelesse, Jan 29, 2025)
- `ded9323` Performant backward Triton implementation with separated dkdv and dq … (jtang10, Feb 4, 2025)
- `c58c4d3` Quick Fixes (#124) (micmelesse, Feb 6, 2025)
- `0bcfd0f` reenable gfx1100 ci (#121) (micmelesse, Feb 12, 2025)
- `bd405ca` update triton commit (#128) (micmelesse, Feb 14, 2025)
- `2ce0b96` update base docker image (#129) (micmelesse, Feb 18, 2025)
- `a9f4ff2` Rebase to v2.7.4.post1 (micmelesse, Feb 20, 2025)
- `866b7dd` Clean up README (#131) (micmelesse, Feb 21, 2025)
- `7b07032` use triton==3.2.0 (#132) (micmelesse, Feb 21, 2025)
- `d1acdff` Update README.md (#134) (micmelesse, Feb 25, 2025)
- `70bd847` fp8 backward (#119) (micmelesse, Mar 5, 2025)
- `ca07a46` Casting Kernel (#130) (micmelesse, Mar 21, 2025)
- `120cf24` Bench (#135) (micmelesse, Apr 17, 2025)
- `23d24d1` Enable Alibi (#138) (micmelesse, Apr 22, 2025)
- `bb502c6` min diff (micmelesse, Apr 22, 2025)
72 changes: 54 additions & 18 deletions README.md
````diff
@@ -137,38 +137,74 @@ These features are supported in Fwd and Bwd
 2) Variable sequence lengths
 3) Arbitrary Q and KV sequence lengths
 4) Arbitrary head sizes
+5) Multi and grouped query attention
+6) Dropout
+7) Rotary embeddings
+8) ALiBi
 
-These features are supported in Fwd for now. We will add them to backward soon.
-1) Multi and grouped query attention
-2) ALiBi and matrix bias
-
-These features are in development
+We are working on the following:
 1) Paged Attention
 2) Sliding Window
-3) Rotary embeddings
-4) Dropout
-5) Performance Improvements
+3) FP8
+4) Performance Improvements
 
-#### Getting Started
+##### Getting Started
 To get started with the triton backend for AMD, follow the steps below.
 
-First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4).
+First install the recommended Triton version:
 
 ```
-git clone https://github.com/triton-lang/triton
-cd triton
-git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
-pip install --verbose -e python
+pip install triton==3.2.0
 ```
-Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.
+Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.
 
 ```
-export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
 cd flash-attention
-python setup.py install
-pytest tests/test_flash_attn.py
+git checkout main_perf
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
 ```
+
+To check that things are working, you can run our tests. The full suite takes hours, so you do not need to run everything.
+```
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
+```
+
+You can enable autotuning for better performance with the flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"`:
+```
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE
+```
+
+###### Docker
+You can also use the Dockerfile below, which performs the steps above on top of the latest rocm/pytorch image.
+```
+FROM rocm/pytorch:latest
+
+WORKDIR /workspace
+
+# install triton
+RUN pip install triton==3.2.0
+
+# install flash attention
+ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+
+RUN git clone https://github.com/ROCm/flash-attention.git &&\
+    cd flash-attention &&\
+    git checkout main_perf &&\
+    python setup.py install
+
+# set working dir
+WORKDIR /workspace/flash-attention
+```
+
+To build the Docker image
+```
+docker build -t fa_triton .
+```
+
+To run the Docker image
+```
+docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton
+```
 
 ## How to use FlashAttention
````
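The feature list above now includes multi- and grouped-query attention in both forward and backward. As a plain-Python illustration of what GQA computes (a naive reference, not the repository's Triton kernel; all shapes here are made up), each group of query heads shares one KV head:

```python
import numpy as np

def gqa_attention(q, k, v):
    """Naive grouped-query attention. nheads_q must be a multiple of
    nheads_kv; each group of query heads shares one KV head.
    Shapes: q (seqlen_q, nheads_q, d), k/v (seqlen_k, nheads_kv, d)."""
    seqlen_q, nheads_q, d = q.shape
    seqlen_k, nheads_kv, _ = k.shape
    assert nheads_q % nheads_kv == 0
    group = nheads_q // nheads_kv
    # repeat each KV head so it lines up with its group of query heads
    k_rep = np.repeat(k, group, axis=1)
    v_rep = np.repeat(v, group, axis=1)
    scores = np.einsum("qhd,khd->hqk", q, k_rep) / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)            # softmax over keys
    return np.einsum("hqk,khd->qhd", p, v_rep)

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8, 16))   # 8 query heads
k = rng.standard_normal((6, 2, 16))   # 2 KV heads -> groups of 4
v = rng.standard_normal((6, 2, 16))
out = gqa_attention(q, k, v)
print(out.shape)  # (4, 8, 16)
```

MQA is the special case `nheads_kv == 1`; the fused kernels avoid materializing the repeated KV heads, which this sketch does explicitly.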
17 changes: 17 additions & 0 deletions flash_attn/flash_attn_triton_amd/Dockerfile
```diff
@@ -0,0 +1,17 @@
+FROM rocm/pytorch:latest
+
+WORKDIR /workspace
+
+# install triton
+RUN pip install triton==3.2.0
+
+# install flash attention
+ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+
+RUN git clone https://github.com/ROCm/flash-attention.git &&\
+    cd flash-attention &&\
+    git checkout main_perf &&\
+    python setup.py install
+
+# set working dir
+WORKDIR /workspace/flash-attention
```
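The Dockerfile bakes `FLASH_ATTENTION_TRITON_AMD_ENABLE` into the image with `ENV`, so anything run inside the container picks the Triton backend up automatically. Purely as a sketch of the pattern (a hypothetical helper, not the repository's actual dispatch code), an environment-flag gate looks like:

```python
import os

def use_triton_amd_backend() -> bool:
    """Hypothetical check mirroring how a flag like
    FLASH_ATTENTION_TRITON_AMD_ENABLE can gate a backend choice.
    Unset or anything other than TRUE (case-insensitive) means off."""
    value = os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE")
    return value.upper() == "TRUE"

os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
print(use_triton_amd_backend())  # True
```

This is why the README shows the flag both as a one-off prefix (`FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python …`) and as a persistent `ENV` in the image: either way, the process sees the same environment variable.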
102 changes: 83 additions & 19 deletions flash_attn/flash_attn_triton_amd/README.md
````diff
@@ -11,39 +11,103 @@ These features are supported in Fwd and Bwd
 2) Variable sequence lengths
 3) Arbitrary Q and KV sequence lengths
 4) Arbitrary head sizes
+5) Multi and grouped query attention
+6) Dropout
+7) Rotary embeddings
+8) ALiBi
 
-These features are supported in Fwd for now. We will add them to backward soon.
-1) Multi and grouped query attention
-2) ALiBi and matrix bias
-
-These features are in development
+We are working on the following:
 1) Paged Attention
 2) Sliding Window
-3) Rotary embeddings
-4) Dropout
-5) Performance Improvements
+3) FP8
+4) Performance Improvements
 
-#### Getting Started
+##### Getting Started
 To get started with the triton backend for AMD, follow the steps below.
 
-First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4).
+First install the recommended Triton version:
 
 ```
-git clone https://github.com/triton-lang/triton
-cd triton
-git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4
-pip install --verbose -e python
+pip install triton==3.2.0
 ```
-Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.
+Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.
 
 ```
-export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
 cd flash-attention
-python setup.py install
-pytest tests/test_flash_attn.py
+git checkout main_perf
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
 ```
+
+To check that things are working, you can run our tests. The full suite takes hours, so you do not need to run everything.
+```
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
+```
+
+You can enable autotuning for better performance with the flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"`:
+```
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE
+```
+
+###### Docker
+You can also use the Dockerfile below, which performs the steps above on top of the latest rocm/pytorch image.
+```
+FROM rocm/pytorch:latest
+
+WORKDIR /workspace
+
+# install triton
+RUN pip install triton==3.2.0
+
+# install flash attention
+ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+
+RUN git clone https://github.com/ROCm/flash-attention.git &&\
+    cd flash-attention &&\
+    git checkout main_perf &&\
+    python setup.py install
+
+# set working dir
+WORKDIR /workspace/flash-attention
+```
+
-#### Credits
+To build the Docker image
+```
+docker build -t fa_triton .
+```
+
+To run the Docker image
+```
+docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton
+```
+
+###### FP8
+In our fork, we have added the following API functions that compute attention in fp8: `flash_attn_fp8_func`, `flash_attn_varlen_fp8_func`, `flash_attn_qkvpacked_fp8_func` and `flash_attn_varlen_qkvpacked_fp8_func`. Call them like the other API functions; the casting is handled internally. For example
+
+```
+import torch
+from flash_attn import flash_attn_qkvpacked_fp8_func
+
+# forward pass
+out, lse, S_dmask = flash_attn_qkvpacked_fp8_func(
+    qkv,
+    dropout_p,
+    causal=causal,
+    window_size=window_size,
+    softcap=softcap,
+    alibi_slopes=alibi_slopes,
+    deterministic=deterministic,
+    return_attn_probs=True,
+)
+
+# backward pass
+do = torch.randn_like(out)
+dqkv = torch.autograd.grad(out, (qkv), do)
+```
+
+The other API functions can be used in the same way.
+
+##### Credits
 AMD Triton kernels team
 
 OpenAI kernel team
````
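The fp8 entry points cast their inputs internally before the matmuls; the repository's actual scheme lives in its casting kernel. The following is only a rough numpy sketch of the general idea (per-tensor scaling into the e4m3 dynamic range plus coarse mantissa rounding; the constants and the rounding proxy are illustrative, not the kernel's arithmetic):

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value in the fp8 e4m3 format

def cast_to_fp8_and_back(x):
    """Simulate a per-tensor fp8 round trip: scale into the e4m3
    range, quantize the mantissa coarsely, then rescale. Real kernels
    keep the scale and fold it back into the matmul outputs."""
    scale = E4M3_MAX / max(np.abs(x).max(), 1e-12)
    y = np.clip(x * scale, -E4M3_MAX, E4M3_MAX)
    m, e = np.frexp(y)               # y = m * 2**e with 0.5 <= |m| < 1
    m = np.round(m * 16.0) / 16.0    # keep ~4 mantissa steps per octave
    return np.ldexp(m, e) / scale

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4)).astype(np.float32)
x_rt = cast_to_fp8_and_back(x)
print(np.max(np.abs(x - x_rt)))  # small round-trip error
```

The practical point for users is only that the fp8 functions accept ordinary fp16/bf16 tensors and handle the scaling themselves, so call sites look identical to the non-fp8 API.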