Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 54 additions & 34 deletions .github/workflows/flash_attention_integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,19 @@ jobs:
RUN pip install --upgrade pip
RUN pip install pytest pytest-rerunfailures pytest-timeout einops

# Clone flash-attention and override aiter submodule with local checkout
COPY . /aiter
# Install FA (setup.py handles submodule init and installs its pinned aiter)
RUN git clone -b ${{ env.FA_BRANCH }} ${{ env.FA_REPOSITORY_URL }} /flash-attention && \
rm -rf /flash-attention/${{ env.AITER_SUBMODULE_PATH }} && \
cp -a /aiter /flash-attention/${{ env.AITER_SUBMODULE_PATH }} && \
cd /flash-attention && \
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
pip install --no-build-isolation .

RUN echo "=== Installed versions ===" && \
pip show flash-attn && \
pip show amd-aiter
# Override with PR aiter
COPY . /aiter
RUN pip uninstall -y amd-aiter && pip install --no-build-isolation -e /aiter

# Verify aiter is installed from PR branch
RUN pip show flash-attn && pip show amd-aiter && pip show triton && \
python -c "from importlib.metadata import distribution; d = distribution('amd-aiter'); loc = d.read_text('direct_url.json'); assert '/aiter' in loc, f'aiter not from PR branch: {loc}'"

WORKDIR /flash-attention
EOF
Expand Down Expand Up @@ -171,44 +172,63 @@ jobs:
--name fa_triton_test \
fa_triton_test:ci

- name: Run correctness tests
timeout-minutes: 360
run: |
if [ "${{ github.event_name }}" = "push" ]; then
# Post-merge: full test suite
docker exec fa_triton_test bash -c "
cd /flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
pytest -v --reruns 2 --timeout=120 \
tests/test_flash_attn_triton_amd.py
"
else
# PR: core API subset (~1 hour)
docker exec fa_triton_test bash -c "
cd /flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
pytest -v --reruns 2 --timeout=120 \
tests/test_flash_attn_triton_amd.py \
-k 'test_flash_attn_output or test_flash_attn_kvcache'
"
fi

- name: Run benchmarks
timeout-minutes: 30
if: ${{ github.event_name != 'push' }}
timeout-minutes: 60
run: |
set -o pipefail
docker exec fa_triton_test bash -c "
cd /flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
python benchmarks/benchmark_flash_attention.py
" |& tee benchmark_triton_${{ matrix.label }}.log
" |& tee benchmark_fa_triton_${{ matrix.label }}.log
docker exec fa_triton_test bash -c "
cd /aiter
python op_tests/op_benchmarks/triton/bench_mha.py \
-impl dao_ai -metric throughput -o /tmp/bench_dao_ai
" |& tee benchmark_mha_daoai_${{ matrix.label }}.log
docker cp fa_triton_test:/tmp/bench_dao_ai/ bench_dao_ai_${{ matrix.label }}/ 2>/dev/null || true

- name: PR correctness tests
if: ${{ github.event_name != 'push' }}
timeout-minutes: 90
run: |
docker exec fa_triton_test bash -c "
cd /flash-attention
export FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE
export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0
pytest -v --reruns 2 --timeout=120 \
tests/test_flash_attn_triton_amd.py::test_flash_attn_output \
-k '2048-2048' && \
pytest -v --reruns 2 --timeout=120 \
tests/test_flash_attn_triton_amd.py::test_flash_attn_varlen_output \
-k '1024-1024' && \
pytest -v --reruns 2 --timeout=120 \
tests/test_flash_attn_triton_amd.py::test_flash_attn_kvcache \
-k '16-20000'
"

- name: Post-merge correctness tests
if: ${{ github.event_name == 'push' && matrix.label == 'MI35X' }}
timeout-minutes: 360
run: |
docker exec fa_triton_test bash -c "
cd /flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 \
pytest -v --reruns 2 --timeout=120 \
tests/test_flash_attn_triton_amd.py
"

- name: Upload benchmark results
if: success()
if: ${{ success() && github.event_name != 'push' }}
uses: actions/upload-artifact@v4
with:
name: flash-attention-triton-benchmark-${{ matrix.label }}
path: benchmark_triton_${{ matrix.label }}.log
path: |
benchmark_fa_triton_${{ matrix.label }}.log
benchmark_mha_daoai_${{ matrix.label }}.log
bench_dao_ai_${{ matrix.label }}/

- name: Clean Up
if: always()
Expand Down Expand Up @@ -360,7 +380,7 @@ jobs:
# docker exec fa_ck_test bash -c "
# cd /flash-attention
# python benchmarks/benchmark_flash_attention.py
# " |& tee benchmark_ck.log
# " |& tee benchmark_fa_ck.log

- name: Upload benchmark results
if: success()
Expand Down
Loading
Loading