Closed
3 changes: 0 additions & 3 deletions .gitmodules
@@ -5,6 +5,3 @@
 	path = csrc/composable_kernel
 	url = https://github.com/ROCm/composable_kernel.git
 	branch = amd-master
-[submodule "third_party/aiter"]
-	path = third_party/aiter
-	url = https://github.com/ROCm/aiter.git
21 changes: 8 additions & 13 deletions README.md
@@ -147,27 +147,19 @@ FlashAttention-2 ROCm CK backend currently supports:
 #### Triton Backend
 The Triton implementation of [Flash Attention](https://tridao.me/publications/flash2/flash2.pdf) supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16, bf16, and fp32 datatypes. It provides forward and backward passes with causal masking, variable sequence lengths, arbitrary Q/KV sequence lengths and head sizes, MQA/GQA, dropout, rotary embeddings, ALiBi, paged attention, and FP8 (via the Flash Attention v3 interface). Sliding window attention is currently a work in progress.
 
-The Triton backend kernels are provided by the [aiter](https://github.com/ROCm/aiter) package, included as a git submodule at `third_party/aiter` and automatically installed during setup.
-
-To install, first get PyTorch for ROCm from https://pytorch.org/get-started/locally/, then install Flash Attention:
-```sh
-cd flash-attention
-FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .
-```
-
-To use a specific aiter commit (e.g., for testing or development):
+To install, first get PyTorch for ROCm from https://pytorch.org/get-started/locally/, then install Triton and Flash Attention:
 ```sh
+pip install triton==3.5.1
 cd flash-attention
-cd third_party/aiter && git fetch origin && git checkout <commit-sha> && cd ../..
-FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
 ```
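After the install step above, a lightweight way to confirm the package is visible to Python is to probe for it with `importlib` (a generic sketch; `flash_attn` will only be found on a machine where the install above actually succeeded):

```python
import importlib.util

# Return True if a package is importable, without actually importing it.
# Probing this way avoids the import cost (and any GPU initialization).
def is_installed(pkg: str) -> bool:
    return importlib.util.find_spec(pkg) is not None

# On a machine where the build above succeeded, this should print True.
print(is_installed("flash_attn"))
```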

 To run the tests (note: full suite takes hours):
 ```sh
 FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
 ```
 
-The Triton backend uses a default kernel configuration optimized for determinism and reasonable performance across workloads. For peak throughput, enable `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` to search for optimal settings, which incurs a one-time warmup cost.
+For better performance, enable autotune with `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"`.

 Alternatively, if _not_ autotuning, `FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON` may be used to set a single Triton config, overriding the hardcoded defaults for `attn_fwd`.
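For example, the variable could be set from Python before importing the library. The JSON keys below are illustrative guesses only: `num_warps` and `num_stages` are standard Triton launch parameters, but the exact keys `attn_fwd` accepts are kernel-specific and are not shown in this diff (the concrete example was collapsed out of this view):

```python
import json
import os

# Hypothetical override: block sizes and launch parameters for attn_fwd.
# These key names are assumptions for illustration, not a documented schema.
config = {"BLOCK_M": 128, "BLOCK_N": 64, "num_warps": 4, "num_stages": 1}
os.environ["FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON"] = json.dumps(config)
```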
@@ -180,10 +172,13 @@ FROM rocm/pytorch:latest

 WORKDIR /workspace
 
+# install triton
+RUN pip install triton==3.5.1
+
 # build flash attention with triton backend
 RUN git clone https://github.com/Dao-AILab/flash-attention &&\
 cd flash-attention &&\
-FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
 
 # set working dir
 WORKDIR /workspace/flash-attention
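The README hunk above lists MQA/GQA among the supported features. As a standalone illustration of the head-grouping idea behind grouped-query attention (a sketch, not this repository's kernel code; the function and parameter names are hypothetical):

```python
# Under GQA, n_q_heads query heads share n_kv_heads key/value heads.
# n_kv_heads == 1 recovers MQA; n_kv_heads == n_q_heads recovers
# ordinary multi-head attention.
def kv_head_for(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    assert n_q_heads % n_kv_heads == 0, "query heads must split evenly into groups"
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size

# 8 query heads over 2 KV heads: heads 0-3 share KV head 0, heads 4-7 share KV head 1.
mapping = [kv_head_for(h, n_q_heads=8, n_kv_heads=2) for h in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```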
2 changes: 1 addition & 1 deletion flash_attn/flash_attn_interface.py
@@ -18,7 +18,7 @@
 USE_TRITON_ROCM = True
 
 if USE_TRITON_ROCM:
-    from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu
+    from .flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu
 else:
     import flash_attn_2_cuda as flash_attn_gpu
 
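The one-line change above swaps the aiter import for the in-tree module, gated by `USE_TRITON_ROCM`. A self-contained sketch of the same environment-variable dispatch pattern (the parsing shown is an assumption for illustration, not necessarily what flash-attention itself does):

```python
import os

# Choose a backend at import time based on an environment variable,
# mirroring the USE_TRITON_ROCM gate in the hunk above.
USE_TRITON_ROCM = os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE", "").upper() == "TRUE"

def backend_name() -> str:
    # In the real code this choice decides which module is bound as flash_attn_gpu.
    return "triton_amd" if USE_TRITON_ROCM else "cuda"

print(backend_name())
```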
4 changes: 4 additions & 0 deletions flash_attn/flash_attn_triton_amd/__init__.py
@@ -0,0 +1,4 @@
+from . import interface_v2 as flash_attn_2
+from . import interface_v3 as flash_attn_3
+
+__all__ = ["flash_attn_2", "flash_attn_3"]