Conversation
Pull request overview
Adds a selectable “dao_ai” MHA forward implementation (backed by flash_attn_triton_amd / flash_attn_2) and wires it into the existing Triton MHA benchmark script and the flash-attention integration workflow so CI can produce benchmark artifacts.
Changes:
- Add `-impl {default,dao_ai}` to `bench_mha.py`, expand benchmark config grids (including GQA + causal axis), and annotate provider labels with the impl.
- Introduce a global MHA forward-impl switch in `aiter/ops/triton/attention/mha.py` that routes forward to `flash_attn_2` when selected.
- Extend `.github/workflows/flash_attention_integration.yaml` to run and upload dao_ai kernel + model benchmarks (and rename existing benchmark log artifacts).
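The global forward-impl switch described above could look, in spirit, like the following minimal sketch. Only `mha_set_impl`, the impl names `default`/`dao_ai`, and the routing to a `flash_attn_2` codepath come from this PR's description; every other name is invented for illustration and is not the actual aiter code.

```python
# Hypothetical sketch of a module-level implementation switch; not the
# actual aiter code. Only mha_set_impl and the impl names come from the PR.

_VALID_IMPLS = ("default", "dao_ai")
_current_impl = "default"


def mha_set_impl(name: str) -> None:
    """Select the MHA forward implementation used by subsequent calls."""
    global _current_impl
    if name not in _VALID_IMPLS:
        raise ValueError(f"unknown impl {name!r}; expected one of {_VALID_IMPLS}")
    _current_impl = name


def _forward_default(q):
    # Stand-in for the default Triton kernel path.
    return ("default", q)


def _forward_flash_attn_2(q):
    # Stand-in for the flash_attn_2 codepath used when "dao_ai" is selected.
    return ("flash_attn_2", q)


def mha_forward(q):
    # Route based on the module-level switch.
    if _current_impl == "dao_ai":
        return _forward_flash_attn_2(q)
    return _forward_default(q)
```

A module-level switch like this lets benchmark and test code flip implementations without threading an argument through every call site, which matches how the PR wires the flag into `bench_mha.py`.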
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `op_tests/op_benchmarks/triton/bench_mha.py` | Adds `-impl` flag, includes causal as a benchmark dimension, expands head/seq config coverage, and updates provider labeling. |
| `aiter/ops/triton/attention/mha.py` | Adds global impl selection and routes forward to `flash_attn_2` for `dao_ai`. |
| `.github/workflows/flash_attention_integration.yaml` | Runs dao_ai benchmarks in CI and uploads logs/CSVs as artifacts. |
brunomazzottiamd left a comment:
LGTM! Let's wait for the CI and then merge it.
The following failures on Triton Tests (MI35X) / Shard 0 are expected. Everything else has passed; we're good to merge.
The bench results on MI5x and RDNA3 are at https://github.com/ROCm/aiter/actions/runs/24152934159/job/70485243064. I have posted the data below.

Here are the MI5X numbers:

Here are the RDNA3 numbers:
Motivation
This PR benchmarks the code in
`flash_attention_triton_amd`, which is used in upstream flash attention, using `bench_mha.py`. It adds `mha_set_impl("dao_ai")` to dispatch forward/backward through the `flash_attn_2` codepath, and refactors `bench_mha.py`. Tests in `test_mha.py` now parametrize over both the `default` and `dao_ai` implementations.

Technical Details
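Running the same test body under both implementations, as described above for `test_mha.py`, reduces to a plain loop over the impl names. This self-contained sketch mirrors that pattern; `mha_set_impl` and the two impl names come from the PR, while the stand-in kernel and helper are invented for illustration.

```python
# Hedged sketch: run the same check under both impls, mirroring (in plain
# Python) the test parametrization the PR describes for test_mha.py.
# mha_set_impl and the impl names are from the PR; everything else is a stub.

_impl = "default"


def mha_set_impl(name: str) -> None:
    """Stub of the PR's switch: record which implementation is selected."""
    global _impl
    _impl = name


def fake_mha_forward(x):
    # Stand-in kernel: report which implementation handled the call.
    return (_impl, x)


def run_under_both_impls(x):
    """Invented helper: exercise the forward path under each impl."""
    results = []
    for impl in ("default", "dao_ai"):
        mha_set_impl(impl)
        out = fake_mha_forward(x)
        assert out[0] == impl  # the selected impl actually handled the call
        results.append(out)
    return results
```

In the real test suite this loop would instead be expressed as a pytest parametrization over the two impl names, so each implementation shows up as a separate test case in CI output.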
Test Plan
Test Result
Submission Checklist