67 changes: 20 additions & 47 deletions README.md
@@ -128,74 +128,47 @@ FlashAttention-2 ROCm CK backend currently supports:
3. Head dimensions up to 256 in both the forward and backward passes.

#### Triton Backend
The Triton implementation of [Flash Attention](https://tridao.me/publications/flash2/flash2.pdf) supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16, bf16, and fp32 datatypes. It provides forward and backward passes with causal masking, variable sequence lengths, arbitrary Q/KV sequence lengths and head sizes, MQA/GQA, dropout, rotary embeddings, ALiBi, paged attention, and FP8 (via the Flash Attention v3 interface). Sliding window attention is currently a work in progress.
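As a reference for the semantics these kernels implement, here is a minimal pure-PyTorch sketch of causal attention with MQA/GQA head grouping. This is an illustration with a hypothetical helper name, not the library's kernel or API; it only shows what the math computes.

```python
import torch

def ref_attention(q, k, v, causal=True):
    """Reference semantics only (hypothetical helper, not the library API).

    q: (batch, seqlen_q, nheads_q, d); k, v: (batch, seqlen_k, nheads_kv, d).
    MQA/GQA: each group of nheads_q // nheads_kv query heads shares one KV head.
    """
    b, sq, hq, d = q.shape
    _, sk, hkv, _ = k.shape
    group = hq // hkv
    # expand KV heads so query head h reads KV head h // group
    k = k.repeat_interleave(group, dim=2)
    v = v.repeat_interleave(group, dim=2)
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / d ** 0.5
    if causal:
        # bottom-right-aligned causal mask (the flash-attn convention
        # when seqlen_q != seqlen_k)
        mask = torch.ones(sq, sk, dtype=torch.bool).tril(diagonal=sk - sq)
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)
```

The (batch, seqlen, nheads, headdim) layout above matches what the flash-attn functions expect; the real kernels compute the same result tile by tile without materializing the full score matrix.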


##### Getting Started
To get started with the Triton backend for AMD, follow the steps below.

To install, first get PyTorch for ROCm from https://pytorch.org/get-started/locally/, then install Triton and Flash Attention:
```sh
pip install triton==3.5.1
cd flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
```

To run the tests (note: full suite takes hours):
```sh
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
```
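The variable-sequence-length support mentioned above uses a packed token layout indexed by cumulative offsets. A minimal sketch of building the `cu_seqlens` offsets that the varlen interface (`flash_attn_varlen_func`) takes — the example lengths are made up:

```python
import torch

# Hypothetical per-sequence lengths for a batch of 3
seqlens = torch.tensor([3, 5, 2])

# cu_seqlens: int32 prefix offsets into the packed (total_tokens, nheads, d)
# tensor; entry i is where sequence i starts, and the last entry is the total
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
print(cu_seqlens.tolist())  # [0, 3, 8, 10]

max_seqlen = int(seqlens.max())  # also required by the varlen interface
```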

For better performance, enable autotune with `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"`:
```sh
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE
```
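Because `flash_attn_interface.py` reads `FLASH_ATTENTION_TRITON_AMD_ENABLE` at module import time to pick the backend, the variables can also be set from Python, as long as that happens before `flash_attn` is imported. A minimal sketch:

```python
import os

# Must run before `import flash_attn`: the backend is chosen when the
# module is first imported (see flash_attn_interface.py).
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "TRUE"

# Same check the interface performs internally:
use_triton = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
```

Setting the variables after `flash_attn` has already been imported has no effect, since the dispatch has already happened.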

###### Docker
For a quick start with Docker:
```dockerfile
FROM rocm/pytorch:latest

WORKDIR /workspace

# install triton
RUN pip install triton==3.5.1

# build flash attention with triton backend
RUN git clone https://github.com/Dao-AILab/flash-attention &&\
    cd flash-attention &&\
    FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install

# set working dir
WORKDIR /workspace/flash-attention

# set env variable to use triton backend
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
```

Build and run:
```sh
docker build -t flash-attn-triton .
docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri flash-attn-triton
```

## How to use FlashAttention
2 changes: 1 addition & 1 deletion flash_attn/flash_attn_interface.py
@@ -10,7 +10,7 @@
# We need to import the CUDA kernels after importing torch
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
from .flash_attn_triton_amd import interface_fa as flash_attn_gpu
from .flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu
else:
import flash_attn_2_cuda as flash_attn_gpu

17 changes: 0 additions & 17 deletions flash_attn/flash_attn_triton_amd/Dockerfile

This file was deleted.

113 changes: 0 additions & 113 deletions flash_attn/flash_attn_triton_amd/README.md

This file was deleted.

4 changes: 4 additions & 0 deletions flash_attn/flash_attn_triton_amd/__init__.py
@@ -0,0 +1,4 @@
from . import interface_v2 as flash_attn_2
from . import interface_v3 as flash_attn_3

__all__ = ["flash_attn_2", "flash_attn_3"]