Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions .github/workflows/flash_attention_integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
branches: [main]
paths:
- 'aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/**'
- 'aiter/ops/ck/flash_attn_ck_amd/**'
- 'aiter/ops/mha.py'
- 'csrc/py_itfs_ck/mha_*'
- 'csrc/py_itfs_ck/attention_kernels.cu'
Expand All @@ -14,6 +15,7 @@ on:
branches: [main]
paths:
- 'aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/**'
- 'aiter/ops/ck/flash_attn_ck_amd/**'
- 'aiter/ops/mha.py'
- 'csrc/py_itfs_ck/mha_*'
- 'csrc/py_itfs_ck/attention_kernels.cu'
Expand All @@ -27,7 +29,7 @@ concurrency:

env:
# TODO: Switch to Dao-AILab/flash-attention main
FA_BRANCH: micmelesse/aiter_migration
FA_BRANCH: micmelesse/aiter_migration_ck
FA_REPOSITORY_URL: https://github.com/ROCm/flash-attention.git
GPU_ARCH: gfx950
BASE_IMAGE: rocm/pytorch:latest@sha256:683765a52c61341e1674fe730ab3be861a444a45a36c0a8caae7653a08a0e208
Expand Down Expand Up @@ -67,6 +69,7 @@ jobs:
triton:
- 'aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/**'
ck:
- 'aiter/ops/ck/flash_attn_ck_amd/**'
- 'aiter/ops/mha.py'
- 'csrc/py_itfs_ck/mha_*'
- 'csrc/py_itfs_ck/attention_kernels.cu'
Expand Down Expand Up @@ -117,7 +120,7 @@ jobs:

# Clone flash-attention and override aiter submodule with local checkout
COPY . /aiter
RUN git clone -b ${{ env.FA_BRANCH }} ${{ env.FA_REPOSITORY_URL }} /flash-attention && \
RUN git clone --depth 1 -b ${{ env.FA_BRANCH }} ${{ env.FA_REPOSITORY_URL }} /flash-attention && \
rm -rf /flash-attention/${{ env.AITER_SUBMODULE_PATH }} && \
cp -a /aiter /flash-attention/${{ env.AITER_SUBMODULE_PATH }} && \
cd /flash-attention && \
Expand Down Expand Up @@ -207,8 +210,7 @@ jobs:
# CK Backend
# =============================================================================
flash_attention_ck:
if: false # Disabled until CK tests are ready
# if: ${{ needs.prechecks.outputs.run_ck == 'true' }}
if: ${{ needs.prechecks.outputs.run_ck == 'true' }}
name: Flash Attention - CK (1 GPU)
needs: [check-signal, prechecks]
runs-on: linux-aiter-mi355-1
Expand All @@ -218,6 +220,7 @@ jobs:
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha || github.sha }}
submodules: recursive

- name: Docker login
run: docker login -u rocmshared -p ${{ secrets.DOCKER_PASSWORD }} || true
Expand All @@ -237,7 +240,7 @@ jobs:
# Clone and install flash-attention (CK backend)
# Override aiter submodule with local checkout
COPY . /aiter
RUN git clone -b ${{ env.FA_BRANCH }} ${{ env.FA_REPOSITORY_URL }} /flash-attention && \
RUN git clone --depth 1 -b ${{ env.FA_BRANCH }} ${{ env.FA_REPOSITORY_URL }} /flash-attention && \
rm -rf /flash-attention/${{ env.AITER_SUBMODULE_PATH }} && \
cp -a /aiter /flash-attention/${{ env.AITER_SUBMODULE_PATH }} && \
cd /flash-attention && \
Expand Down Expand Up @@ -279,26 +282,38 @@ jobs:
- name: Run correctness tests
timeout-minutes: 360
run: |
echo "CK tests not yet implemented - to be enabled by ChunYu Lai"
# docker exec fa_ck_test bash -c "
# cd /flash-attention
# pytest -v --reruns 2 --timeout=120 tests/test_flash_attn_ck.py
# "
if [ "${{ github.event_name }}" = "push" ]; then
# Post-merge: full test suite
docker exec fa_ck_test bash -c "
cd /flash-attention
pytest -v --reruns 2 --timeout=120 \
tests/test_flash_attn_ck.py
"
else
# PR: core API subset
docker exec fa_ck_test bash -c "
cd /flash-attention
pytest -v --reruns 2 --timeout=120 \
tests/test_flash_attn_ck.py \
-k 'test_flash_attn_output or test_flash_attn_kvcache'
"
fi

- name: Run benchmarks
timeout-minutes: 30
run: |
echo "CK benchmarks not yet implemented - to be enabled by ChunYu Lai"
# set -o pipefail
# docker exec fa_ck_test bash -c "
# cd /flash-attention
# python benchmarks/benchmark_flash_attention.py
# " |& tee benchmark_ck.log
set -o pipefail
docker exec fa_ck_test bash -c "
cd /flash-attention
python benchmarks/benchmark_flash_attention.py
" |& tee benchmark_ck.log

- name: Upload benchmark results
if: success()
uses: actions/upload-artifact@v4
with:
name: flash-attention-ck-benchmark
path: benchmark_ck.log

- name: Clean Up
if: always()
Expand Down
7 changes: 5 additions & 2 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,9 +759,12 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources):
if blob_gen_cmd:
blob_dir = f"{op_dir}/blob/"
os.makedirs(blob_dir, exist_ok=True)
cmd = blob_gen_cmd.format(blob_dir)
# Explicitly set PYTHONPATH to script's directory for sibling imports
script_dir = os.path.dirname(os.path.abspath(cmd.split()[0]))
if AITER_LOG_MORE:
logger.info(f"exec_blob ---> {PY} {blob_gen_cmd.format(blob_dir)}")
os.system(f"{PY} {blob_gen_cmd.format(blob_dir)}")
logger.info(f"exec_blob ---> {PY} {cmd}")
os.system(f"PYTHONPATH={script_dir}:$PYTHONPATH {PY} {cmd}")
Comment on lines +762 to +767
sources += rename_cpp_to_cu([blob_dir], src_dir, hipify, recursive=True)
return sources

Expand Down
Empty file added aiter/ops/ck/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions aiter/ops/ck/flash_attn_ck_amd/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from . import interface_v2 as flash_attn_2

__all__ = ["flash_attn_2"]
Loading
Loading