Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 0 additions & 65 deletions .github/workflows/amd_tests.yml

This file was deleted.

14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` s

```
cd flash-attention
git checkout main_perf
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
```

Expand All @@ -184,16 +183,17 @@ WORKDIR /workspace
# install triton
RUN pip install triton==3.3.0

# install flash attention
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"

RUN git clone https://github.com/ROCm/flash-attention.git &&\
# build flash attention with triton backend
RUN git clone https://github.com/Dao-AILab/flash-attention &&\
cd flash-attention &&\
git checkout main_perf &&\
python setup.py install
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install

# set working dir
WORKDIR /workspace/flash-attention

# set env variable to use triton backend
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"

```

To build the docker file
Expand Down
2 changes: 0 additions & 2 deletions flash_attn/flash_attn_triton_amd/.gitignore

This file was deleted.

17 changes: 0 additions & 17 deletions flash_attn/flash_attn_triton_amd/Dockerfile

This file was deleted.

113 changes: 0 additions & 113 deletions flash_attn/flash_attn_triton_amd/README.md

This file was deleted.

8 changes: 7 additions & 1 deletion flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import triton
import triton.language as tl
from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors
from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors, DEBUG, is_fp8

from typing import Optional, Tuple

Expand Down Expand Up @@ -1503,11 +1503,17 @@ def attention_prefill_backward_triton_fused_atomics_impl(
descale_v: Optional[torch.Tensor] = None,
descale_do: Optional[torch.Tensor] = None,
fused: bool = False,
# seqused for FA v3 (currently ignored in this implementation)
seqused_q: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
):
IS_FP8 = is_fp8(q)
if IS_FP8:
FP8_MAX = torch.finfo(q.dtype).max
descale_strides = (descale_q.stride(0),descale_k.stride(0),descale_v.stride(0),descale_do.stride(0) )

if DEBUG:
print(f"FP8 path triggered in bwd_prefill_fused_atomics.py")
else:
FP8_MAX = None
stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_do_z = None
Expand Down
Loading