
[Cute,Fwd,Sm100] support irregular qhead / kvhead ratios #2186

Merged

tridao merged 1 commit into Dao-AILab:main from modal-labs:irregular-pack-gqa on Mar 20, 2026

Conversation

timmy-feng (Contributor) commented Jan 16, 2026

This PR enables the Pack GQA optimization for attention head group sizes that do not divide m_block_size, by using cp.async to load Q. Previously, pack_gqa was set to false in the SM100 interface whenever nheads / nheads_kv did not divide 128.
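The dispatch condition described above can be sketched as follows. This is a minimal illustration; the function name, the `m_block_size` default, and the `allow_irregular` flag are assumptions for exposition, not the actual SM100 interface:

```python
# Hypothetical sketch of the pack_gqa dispatch logic described above;
# names and defaults are illustrative, not flash-attention's real API.
def should_pack_gqa(nheads: int, nheads_kv: int, m_block_size: int = 128,
                    allow_irregular: bool = True) -> bool:
    qheads_per_kvhead = nheads // nheads_kv
    if m_block_size % qheads_per_kvhead == 0:
        # Regular case: head groups tile the m_block exactly.
        return True
    # Irregular case: previously the kernel fell back to pack_gqa=False;
    # with cp.async Q loads it can now stay enabled.
    return allow_irregular

# GLM-style shape from the benchmark below: 96 q-heads over 8 kv-heads,
# so the group size is 12 and 128 % 12 != 0.
print(should_pack_gqa(96, 8))                          # -> True (now packed)
print(should_pack_gqa(96, 8, allow_irregular=False))   # -> False (old behavior)
```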

This offers a significant speedup for autoregressive sampling in certain LLMs (e.g. GLM 4.7):

### hdim=128, causal=False, s_q=1, s_k=65536, b=4, nheads=96, nheads_kv=8, num_splits=1 ###
FA Python fwd:              1.816ms,    7.1 TFLOPS  correct=True  max_diff=0.000183
FA Python fwd pack_gqa:     0.621ms,   20.7 TFLOPS  correct=True  max_diff=0.000183
speedup (pack_gqa=True vs False): 2.92x
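As a quick sanity check on the figures above, the reported speedup is the ratio of wall-clock times (equivalently, of achieved TFLOPS):

```python
# Reproduce the reported 2.92x figure from the benchmark timings above.
baseline_ms = 1.816   # pack_gqa=False
packed_ms = 0.621     # pack_gqa=True
speedup = baseline_ms / packed_ms
print(f"speedup: {speedup:.2f}x")  # -> speedup: 2.92x
```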

Note: we disable paired CTAs when using cp.async to load Q, because there is no pre-MMA barrier construct to wait for both CTAs to finish loading before continuing.

Correctness

Updated test_flash_attn_fast.py to cover MQA, which invokes the irregular GQA path. All tests continue to pass.
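For instance, an MQA shape with nheads=6 and nheads_kv=1 (an illustrative choice) has a head-group size of 6, which does not divide 128, so it exercises the irregular path:

```python
# Illustrative MQA shape that lands on the irregular pack_gqa path.
nheads, nheads_kv = 6, 1      # MQA: all q-heads share a single kv-head
group = nheads // nheads_kv   # head-group size = 6
print(128 % group)            # -> 2, nonzero: 6 does not divide m_block=128
```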

Command: /usr/local/bin/python -m pytest -q /flash-attention/tests/cute/test_flash_attn_fast.py -x
{
  "test_flash_attn_fast.py": {
    "test_flash_attn_output": 96,
    "test_flash_attn_varlen_output": 36,
    "test_flash_attn_varlen_unpad_output": 96,
    "test_flash_attn_combine": 12
  }
}
........................................................................ [ 30%]
........................................................................ [ 60%]
........................................................................ [ 90%]
....................
=============================== warnings summary ===============================
cute/test_flash_attn_fast.py: 3389 warnings
  /usr/local/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:63: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

cute/test_flash_attn_fast.py: 10 warnings
  /flash-attention/flash_attn/cute/pack_gqa.py:122: DSLOptimizationWarning: This static loop has 64 iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.
    for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
240 passed, 3399 warnings in 796.92s (0:13:16)

Perf Regression

Running bench_sm90.py on a B200 shows that perf stays consistent across various problem sizes.

================================================================================
  BEFORE (main)
================================================================================

================================================================================
  SM100 FWD  (rep=30)
================================================================================
 hdim hdim_v causal batch seqlen       ms   TFLOPS   max_diff
--------------------------------------------------------------------------------
   64     64  False    16   2048    0.604    910.4   0.001953
   64     64   True    16   2048    0.455    604.8   0.015625
   64     64  False     8   4096    1.163    945.6   0.000977
   64     64   True     8   4096    0.746    737.1   0.015625
   64     64  False     4   8192    2.310    952.1   0.000977
   64     64   True     4   8192    1.321    832.0   0.015625
   96     96  False    16   2048    0.348   1184.1   0.001953
   96     96   True    16   2048    0.241    854.1   0.015625
   96     96  False     8   4096    0.663   1244.2   0.001953
   96     96   True     8   4096    0.390   1057.1   0.015625
   96     96  False     4   8192    1.292   1276.2   0.000488
   96     96   True     4   8192    0.712   1157.7   0.015625
  128    128  False    16   2048    0.385   1427.8   0.001953
  128    128   True    16   2048    0.271   1013.1   0.015625
  128    128  False     8   4096    0.727   1512.9   0.001953
  128    128   True     8   4096    0.462   1191.2   0.015625
  128    128  False     4   8192    1.420   1548.9   0.000488
  128    128   True     4   8192    0.911   1207.1   0.015625

================================================================================
  AFTER (PR: irregular-pack-gqa)
================================================================================

================================================================================
  SM100 FWD  (rep=30)
================================================================================
 hdim hdim_v causal batch seqlen       ms   TFLOPS   max_diff
--------------------------------------------------------------------------------
   64     64  False    16   2048    0.606    906.7   0.001953
   64     64   True    16   2048    0.456    602.8   0.015625
   64     64  False     8   4096    1.159    948.3   0.000977
   64     64   True     8   4096    0.742    740.7   0.015625
   64     64  False     4   8192    2.306    953.7   0.000977
   64     64   True     4   8192    1.335    823.4   0.015625
   96     96  False    16   2048    0.347   1188.8   0.001953
   96     96   True    16   2048    0.243    847.5   0.015625
   96     96  False     8   4096    0.661   1247.0   0.001953
   96     96   True     8   4096    0.413    999.0   0.015625
   96     96  False     4   8192    1.317   1252.6   0.000488
   96     96   True     4   8192    0.752   1097.0   0.015625
  128    128  False    16   2048    0.384   1432.4   0.001953
  128    128   True    16   2048    0.269   1023.3   0.015625
  128    128  False     8   4096    0.725   1516.2   0.001953
  128    128   True     8   4096    0.461   1192.9   0.015625
  128    128  False     4   8192    1.418   1550.3   0.000488
  128    128   True     4   8192    0.910   1208.1   0.015625

@timmy-feng timmy-feng marked this pull request as ready for review March 17, 2026 16:32
tridao (Member) commented Mar 17, 2026

Do we have tests that hit this code path? I guess we set nheads=6 and for mqa it should hit this code path?

tridao (Member) commented Mar 18, 2026

LGTM, we can merge when it's ready

@tridao tridao merged commit 3250081 into Dao-AILab:main Mar 20, 2026
tridao (Member) commented Mar 20, 2026

Thank you!

zhuochenKIDD pushed a commit to zhuochenKIDD/flash-attention that referenced this pull request Mar 25, 2026