
feat(mla): support nhead < 16 in MLA decode via transparent head padding #2585

Open

ChuanLi1101 wants to merge 6 commits into main from chuan/mla-nhead-lt16-support

Conversation

@ChuanLi1101
Contributor

Summary

Add transparent nhead < 16 support for the MLA decode path, enabling models like Kimi-Linear-48B-A3B with TP=8 (nhead=4) to use the ASM kernel without manual head padding in vLLM.

Problem

The MLA decode ASM kernel requires nhead >= 16. Models with smaller head counts (e.g. nhead=4 for Kimi-Linear with TP=8) would either fail with an assertion error or require redundant head padding at the vLLM integration level, causing up to 4x wasted computation.

Changes

  • aiter/mla.py: Automatically pad query heads from nhead to 16 via repeat_interleave before kernel dispatch, and strip the output back to the original nhead after decode (a minimal sketch follows this list). Also adds FP8 block-size entries for nhead=4 and nhead=8.
  • aiter/ops/attention.py: Use effective_num_head (padded to 16) for metadata buffer sizing when nhead < 16.
  • csrc/kernels/mla/metadata/v1_2_device.cuh: Add pad_to_qh16 path in C++ metadata generation for nhead < 16.
  • op_tests/test_mla.py, test_mla_persistent.py, test_mla_sparse.py: Add nhead=4 (and nhead=8) to default test configurations.
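
Conceptually, the padding works roughly as in the sketch below. This is a sketch only, assuming a [batch, nhead, head_dim] query layout and a placeholder decode_kernel entry point; it is not the actual aiter/mla.py code.

    # Minimal sketch of transparent head padding for the MLA decode path.
    # decode_kernel and the tensor layout are illustrative placeholders.
    import torch

    KERNEL_MIN_NHEAD = 16  # the ASM decode kernel requires at least 16 query heads

    def mla_decode_with_head_padding(q, kv_cache, decode_kernel, **kwargs):
        # q: [batch, nhead, head_dim]; nhead is assumed to divide 16 (e.g. 4 or 8)
        nhead = q.shape[1]
        if nhead >= KERNEL_MIN_NHEAD:
            return decode_kernel(q, kv_cache, **kwargs)

        # Repeat each query head so the kernel sees 16 heads (nhead=4 -> repeat 4x)
        repeats = KERNEL_MIN_NHEAD // nhead
        q_padded = q.repeat_interleave(repeats, dim=1)    # [batch, 16, head_dim]

        out_padded = decode_kernel(q_padded, kv_cache, **kwargs)

        # Strip the padding: each group of `repeats` output heads corresponds to
        # one original head, so keep the first head of every group.
        return out_padded[:, ::repeats, :].contiguous()   # [batch, nhead, head_dim]

Per the other items above, the real change also sizes the metadata buffers (aiter/ops/attention.py) and the C++ metadata generation (pad_to_qh16 in v1_2_device.cuh) against the padded head count.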

Test Plan

Validated on MI355X (gfx950) with:

  • nhead=4 and nhead=8
  • BF16 and FP8 KV caches
  • Context lengths: 1K, 2K, 4K, 8K, 16K, 32K, 64K, 128K
  • Batch sizes: 1, 4, 16, 64
  • Sparse MLA decode path (96 test cases, all passed)
  • Max error within expected tolerance (BF16: ~0, FP8: < 0.05)

Models like Kimi-Linear-48B-A3B with TP=8 produce nhead=4, which was
previously unsupported by the ASM kernel (requires nhead>=16). This adds
transparent query head padding from nhead to 16 via repeat_interleave,
with output stripping after decode, across Python, C++ metadata, and
test configurations.

Validated on MI355X with nhead=4 and nhead=8, BF16 and FP8 KV caches,
across context lengths 1K-128K and batch sizes 1-64 (96 test cases passed).

Made-with: Cursor
ChuanLi1101 requested a review from a team on April 1, 2026 20:10
@github-actions
Contributor

github-actions Bot commented Apr 1, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label          Tests
ci:triton-355  Run Triton tests on MI355 in addition to MI325
ci:sglang      SGLang integration tests
ci:atom        ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm        vLLM benchmark
ci:all         All of the above

Add labels via the sidebar or gh pr edit 2585 --add-label <label>

@ChuanLi1101
Contributor Author

cc @valarLip @carlushuang — Requesting expedited review. This is an alternative approach for nhead < 16 support in MLA decode via transparent head padding. Required for GLM-5 TP=8 on MI355X. Blocking vLLM-side PRs (vllm-project/vllm#36855, vllm-project/vllm#38665) for customer-facing GLM-5 inference.

valarLip previously approved these changes Apr 4, 2026

@valarLip
Collaborator

valarLip commented Apr 4, 2026

looks good, but please make sure all CI tests pass

ChuanLi1101 force-pushed the chuan/mla-nhead-lt16-support branch from 531bbbd to 793bfcc on April 4, 2026 16:14
@ChuanLi1101
Contributor Author

@valarLip All CI checks are now green (10/10 Standard Test shards passed on both MI325 and MI355). The merge conflict has been resolved, and a bug fix was added to strip the padded heads from the returned logits when nhead < 16. Ready for re-review and merge.

ChuanLi1101 requested a review from carlushuang on April 5, 2026 00:40
@ChuanLi1101
Contributor Author

@valarLip — Friendly follow-up. All 10 Standard Test shards passed on both MI325 and MI355 (the shard 2 MI355 failure was a transient infra issue, not a code regression).

This PR is now blocking multiple downstream deliverables:

  • vLLM PR #37353 (skip head repeat for BF16)
  • vLLM PR #38665 (sparse MLA head repeat fix)
  • Kimi-Linear-48B-A3B decode performance on MI355X (nhead=4 with TP=8 → 4x redundant compute without this fix)
  • MI355 64M-context release target

Also filed a related issue for the uint32 overflow at >8M seqlen: #2768

Could you re-review and merge at your earliest convenience? Happy to address any remaining concerns.

@ephremw

ephremw commented Apr 17, 2026

@ChuanLi1101, thanks for this PR. I have trouble using this PR to get Kimi to generate coherent text in graph mode. Eager mode is okay, though. Through many experiments, I narrowed down the failure mode:

  • vLLM 0.19
  • forced ROCM_AITER_MLA
  • Kimi-Linear-48B-A3B-Instruct
  • the ROCm MLA cudagraph path, especially decode

pr2585_text_mre.tar.gz

The strongest control result is now consistent in both the 16 heads/GPU control regime and the 4 heads/GPU supported small-head regime:

  • graph-mode baseline: incoherent
  • cudagraph_mode=NONE: coherent
  • --enforce-eager: coherent

Changing the model from Kimi to DeepSeek-V2-Lite-Chat on the same forced MLA
path does not reproduce the same failure pattern: DeepSeek stays coherent
in the supported small-head regime on v0.19 + PR2585 rebuilt.

I reduced this to a minimal MLA repro that does not depend on our benchmark
harness at runtime:

  • forced ROCM_AITER_MLA
  • frozen /v1/completions prompt-token requests (a request sketch follows below)
  • length keys 1024, 2048, 4096
  • only vLLM 0.18 and vLLM 0.19
  • only the representative AITER states:
    • bundled
    • PR2585 rebuilt
  • two models:
    • moonshotai/Kimi-Linear-48B-A3B-Instruct
    • deepseek-ai/DeepSeek-V2-Lite-Chat

The directory bundle is self-contained and can recreate the matrix without the
rest of helios-demo.
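
For concreteness, the frozen prompt-token requests take roughly the following shape. This is a sketch only: the URL, model name, token IDs, and sampling settings are placeholders rather than values from the bundle, and it assumes vLLM's OpenAI-compatible /v1/completions endpoint accepts a list of prompt token IDs in place of a text prompt.

    # Illustrative request shape for the frozen prompt-token repro requests.
    import json
    import urllib.request

    def send_frozen_prompt(prompt_token_ids, max_tokens=128,
                           url="http://localhost:8000/v1/completions",
                           model="moonshotai/Kimi-Linear-48B-A3B-Instruct"):
        payload = {
            "model": model,
            "prompt": prompt_token_ids,  # frozen token IDs (1024/2048/4096-token lists)
            "max_tokens": max_tokens,
            "temperature": 0.0,          # deterministic decode for coherence checks
        }
        req = urllib.request.Request(
            url,
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req) as resp:
            return json.loads(resp.read())["choices"][0]["text"]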

How To Read The Tables

  • Results Expected means expected from the exact inspected code path for that
    vLLM/AITER combination, not from the older v0.15.1 mental model.
  • 16 heads/GPU is the control path and does not require PR2585's new
    small-head repeat-to-16 logic.
  • 4 heads/GPU is the supported small-head path that does exercise that
    PR2585 logic.
  • Heads/GPU below 4 are unsupported and intentionally excluded from the main
    tables.

16 Heads/GPU Control

Model                         vLLM            AITER package at runtime  Heads/GPU  Result      Results Expected
DeepSeek-V2-Lite-Chat         0.18.0+rocm700  bundled                   16         coherent
DeepSeek-V2-Lite-Chat         0.18.0+rocm700  PR2585 rebuilt            16         coherent
DeepSeek-V2-Lite-Chat         0.19.0+rocm721  bundled                   16         coherent
DeepSeek-V2-Lite-Chat         0.19.0+rocm721  PR2585 rebuilt            16         coherent
Kimi-Linear-48B-A3B-Instruct  0.18.0+rocm700  bundled                   16         coherent
Kimi-Linear-48B-A3B-Instruct  0.18.0+rocm700  PR2585 rebuilt            16         coherent
Kimi-Linear-48B-A3B-Instruct  0.19.0+rocm721  bundled                   16         incoherent
Kimi-Linear-48B-A3B-Instruct  0.19.0+rocm721  PR2585 rebuilt            16         incoherent

4 Heads/GPU Supported Small-Head Path

Model                         vLLM            AITER package at runtime  Heads/GPU  Result      Results Expected
DeepSeek-V2-Lite-Chat         0.18.0+rocm700  bundled                   4          coherent
DeepSeek-V2-Lite-Chat         0.18.0+rocm700  PR2585 rebuilt            4          coherent
DeepSeek-V2-Lite-Chat         0.19.0+rocm721  bundled                   4          startup failure on old small-head assert
DeepSeek-V2-Lite-Chat         0.19.0+rocm721  PR2585 rebuilt            4          coherent
Kimi-Linear-48B-A3B-Instruct  0.18.0+rocm700  bundled                   4          incoherent
Kimi-Linear-48B-A3B-Instruct  0.18.0+rocm700  PR2585 rebuilt            4          incoherent
Kimi-Linear-48B-A3B-Instruct  0.19.0+rocm721  bundled                   4          startup failure on old small-head assert
Kimi-Linear-48B-A3B-Instruct  0.19.0+rocm721  PR2585 rebuilt            4          incoherent

Graph-Control Follow-Up On v0.19 Kimi PR2585 rebuilt At 4 Heads/GPU

I reran the same v0.19 + Kimi + PR2585 rebuilt + 4 heads/GPU case in two additional modes that remove cudagraph replay from the path (an illustrative mode-selection sketch follows the table):

Runtime mode                              Result
graph-mode baseline (FULL_AND_PIECEWISE)  incoherent
cudagraph_mode=NONE                       coherent
--enforce-eager                           coherent
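
For orientation, here is a sketch of how these three modes might be selected with the vLLM Python API, shown with the offline LLM class for brevity even though the repro runs a server. Every argument name here, especially compilation_config / cudagraph_mode, is an assumption about the vLLM version in use, not something copied from the repro bundle.

    # Illustrative mode selection only; argument names are assumptions, not the
    # actual launch configuration used for these runs.
    from vllm import LLM

    MODEL = "moonshotai/Kimi-Linear-48B-A3B-Instruct"

    def make_llm(mode: str) -> LLM:
        if mode == "graph_baseline":
            # default graph capture (FULL_AND_PIECEWISE in these runs)
            return LLM(model=MODEL, tensor_parallel_size=8)
        if mode == "cudagraph_none":
            # compiled path kept, cudagraph capture/replay disabled
            return LLM(model=MODEL, tensor_parallel_size=8,
                       compilation_config={"cudagraph_mode": "NONE"})
        if mode == "enforce_eager":
            # fully eager execution, i.e. the --enforce-eager equivalent
            return LLM(model=MODEL, tensor_parallel_size=8, enforce_eager=True)
        raise ValueError(f"unknown mode: {mode}")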

Concrete artifacts:

  • graph-mode baseline:
    • 1024: starts Thanks!!!!!!!!!!!!!!!!...
    • 2048: starts It!!!!!!!!!!!!!!!!...
    • 4096: starts Thanks!!!!!!!!!!!!!!!!...
  • cudagraph_mode=NONE:
    • 1024: starts Thanks for the context!...
    • 2048: starts It looks like you're referencing...
    • 4096: starts You're asking for a technical report...
  • --enforce-eager:
    • 1024: starts Thanks for the context!...
    • 2048: starts It looks like you're referencing...
    • 4096: starts You're asking for a technical report...

The full logs are saved for both follow-up runs:

  • ../recreated/v019_kimi_tp8_graph_controls_20260416T215131Z/v019_pr2585_rebuilt_tp8_cudagraph_none/vllm_server.log
  • ../recreated/v019_kimi_tp8_graph_controls_20260416T215131Z/v019_pr2585_rebuilt_tp8_enforce_eager/vllm_server.log

Graph-Control Follow-Up On v0.19 Kimi PR2585 rebuilt At 16 Heads/GPU

I reran the same v0.19 + Kimi + PR2585 rebuilt + 16 heads/GPU control-path
case in two additional modes that remove cudagraph replay from the path:

Runtime mode                              Result
graph-mode baseline (FULL_AND_PIECEWISE)  incoherent
cudagraph_mode=NONE                       coherent
--enforce-eager                           coherent

Concrete artifacts:

  • graph-mode baseline:
    • 1024: starts Thanks!!!!!!!!...
    • 2048: starts It!!!!!!!!...
    • 4096: starts It!!!!!!!!...
  • cudagraph_mode=NONE:
    • 1024: starts Thanks for the context!...
    • 2048: starts It looks like you're referencing...
    • 4096: starts It looks like you're referencing...
  • --enforce-eager:
    • 1024: starts Thanks for the context!...
    • 2048: starts It looks like you're referencing...
    • 4096: starts You're asking for a technical report...

The full logs are saved for both follow-up runs:

  • ../recreated/v019_kimi_tp2_graph_controls_20260417T005318Z/v019_pr2585_rebuilt_tp2_cudagraph_none/vllm_server.log
  • ../recreated/v019_kimi_tp2_graph_controls_20260417T005318Z/v019_pr2585_rebuilt_tp2_enforce_eager/vllm_server.log

Main Questions

  • Kimi is incoherent on v0.19 even in the 16 heads/GPU control path, while
    DeepSeek-V2-Lite is coherent there. Is this expected for Kimi on forced
    ROCM_AITER_MLA?
  • In the supported 4 heads/GPU regime, DeepSeek-V2-Lite is coherent on
    both v0.18 bundled/rebuilt and v0.19 rebuilt, but Kimi is incoherent on
    v0.18 bundled, v0.18 rebuilt, and v0.19 rebuilt. What Kimi-specific
    interaction should we inspect next?
  • v0.19 bundled still fails the old small-head assert for both models in the
    4 heads/GPU regime. Is that expected until the rebuilt PR2585 path is used?
  • The strongest new discriminator is that v0.19 + Kimi + PR2585 rebuilt is
    coherent when cudagraph replay is removed, but incoherent in the original
    graph-mode path at both 16 and 4 heads/GPU. Does that point to a known
    issue in the ROCm MLA cudagraph path for Kimi specifically?

Implementation Observation

  • One implementation detail that stood out to us, and that you may already have
    considered, is the score-scaling path inside the MLA decode kernel. If we are
    reading the code correctly, the score path appears to be equivalent to

    z = (QK) / sqrt(D_q)
    w = exp2((z - m) * log2(e))
    

    and therefore (with m_2 = m * log2(e)) algebraically equivalent to

    w = exp2((QK) * (log2(e) / sqrt(D_q)) - m_2)
    

    The reason this stood out to us is that, if we are understanding the kernel
    correctly, the folded form would appear to reduce the score-side elementwise
    multiplies from two to one. For these specialized kernels, D_q appears
    fixed by the kernel traits, so we wondered whether folding those factors into
    a single constant scale might be worth considering. We mention this only as a
    respectful observation in case it is useful, not as a claim that this is the
    source of the Kimi incoherence; a small numerical check of the equivalence is
    sketched below.
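
    If it helps, here is a tiny self-contained check of that algebraic equivalence.
    The head dimension and score count below are made-up illustrative values, not
    taken from the kernel.

        # Numerically verify exp2((z - m) * log2(e)) == exp2(QK * scale - m_2)
        # with z = QK / sqrt(D_q), scale = log2(e) / sqrt(D_q), m_2 = m * log2(e).
        import math
        import numpy as np

        rng = np.random.default_rng(0)
        D_q = 576            # illustrative head dim; the real one is fixed by kernel traits
        q = rng.standard_normal(D_q)
        k = rng.standard_normal((64, D_q))

        qk = k @ q           # raw scores QK, shape (64,)
        log2e = math.log2(math.e)

        # Unfolded form: scale by 1/sqrt(D_q), subtract the running max m, convert to base 2
        z = qk / math.sqrt(D_q)
        m = z.max()
        w_unfolded = np.exp2((z - m) * log2e)

        # Folded form: one constant scale and the rescaled max m_2 = m * log2(e)
        scale = log2e / math.sqrt(D_q)
        m2 = m * log2e
        w_folded = np.exp2(qk * scale - m2)

        assert np.allclose(w_unfolded, w_folded)
        print("max abs diff:", np.abs(w_unfolded - w_folded).max())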
