feat(mla): support nhead < 16 in MLA decode via transparent head padding #2585
ChuanLi1101 wants to merge 6 commits into
Conversation
Models like Kimi-Linear-48B-A3B with TP=8 produce nhead=4, which the ASM kernel previously did not support (it requires nhead >= 16). This PR adds transparent query head padding from nhead to 16 via repeat_interleave, with the output stripped back to the original nhead after decode, across the Python wrapper, the C++ metadata path, and the test configurations. Validated on MI355X with nhead=4 and nhead=8, BF16 and FP8 KV caches, context lengths 1K-128K, and batch sizes 1-64 (96 test cases passed).

Made-with: Cursor
cc @valarLip @carlushuang — Requesting expedited review. This is an alternative approach for nhead < 16 support in MLA decode via transparent head padding. Required for GLM-5 TP=8 on MI355X. Blocking vLLM-side PRs (vllm-project/vllm#36855, vllm-project/vllm#38665) for customer-facing GLM-5 inference.
looks good, but please make sure it passes all CI tests
@valarLip All CI checks are now green (10/10 Standard Test shards passed on both MI325 and MI35X). The merge conflict has been resolved, and a bug fix was added to strip padded heads from the returned logits when nhead < 16. Ready for re-review and merge.
@valarLip — Friendly follow-up. All 10 Standard Test shards passed on both MI325 and MI35X (the shard 2 MI35X failure was a transient infra issue, not a code regression). This PR is now blocking multiple downstream deliverables:

I also filed a related issue for the uint32 overflow at >8M seqlen: #2768. Could you re-review and merge at your earliest convenience? Happy to address any remaining concerns.
@ChuanLi1101, thanks for this PR. I have trouble using this PR to get Kimi to generate coherent text in graph mode. Eager mode is okay, though. Through many experiments, I narrowed down the failure mode:

The strongest control is now consistent in both the
Changing the model from Kimi to

I reduced this to a minimal MLA repro that does not depend on our benchmark. The directory bundle is self-contained and can recreate the matrix without the

How To Read The Tables
16 Heads/GPU Control
4 Heads/GPU Supported Small-Head Path
Graph-Control Follow-Up On v0.19 Kimi PR2585 rebuilt At 4 Heads/GPU
| Runtime mode | Result |
|---|---|
| graph-mode baseline (`FULL_AND_PIECEWISE`) | incoherent |
| `cudagraph_mode=NONE` | coherent |
| `--enforce-eager` | coherent |
Concrete artifacts:
- graph-mode baseline: 1024 starts `Thanks!!!!!!!!!!!!!!!!...`, 2048 starts `It!!!!!!!!!!!!!!!!...`, 4096 starts `Thanks!!!!!!!!!!!!!!!!...`
- `cudagraph_mode=NONE`: 1024 starts `Thanks for the context!...`, 2048 starts `It looks like you're referencing...`, 4096 starts `You're asking for a technical report...`
- `--enforce-eager`: 1024 starts `Thanks for the context!...`, 2048 starts `It looks like you're referencing...`, 4096 starts `You're asking for a technical report...`

The full logs are saved for both follow-up runs:
- `../recreated/v019_kimi_tp8_graph_controls_20260416T215131Z/v019_pr2585_rebuilt_tp8_cudagraph_none/vllm_server.log`
- `../recreated/v019_kimi_tp8_graph_controls_20260416T215131Z/v019_pr2585_rebuilt_tp8_enforce_eager/vllm_server.log`
Graph-Control Follow-Up On v0.19 Kimi PR2585 rebuilt At 16 Heads/GPU
I reran the same v0.19 + Kimi + PR2585 rebuilt + 16 heads/GPU control-path case in two additional modes that remove cudagraph replay from the path:
| Runtime mode | Result |
|---|---|
| graph-mode baseline (`FULL_AND_PIECEWISE`) | incoherent |
| `cudagraph_mode=NONE` | coherent |
| `--enforce-eager` | coherent |
Concrete artifacts:
- graph-mode baseline: 1024 starts `Thanks!!!!!!!!...`, 2048 starts `It!!!!!!!!...`, 4096 starts `It!!!!!!!!...`
- `cudagraph_mode=NONE`: 1024 starts `Thanks for the context!...`, 2048 starts `It looks like you're referencing...`, 4096 starts `It looks like you're referencing...`
- `--enforce-eager`: 1024 starts `Thanks for the context!...`, 2048 starts `It looks like you're referencing...`, 4096 starts `You're asking for a technical report...`

The full logs are saved for both follow-up runs:
- `../recreated/v019_kimi_tp2_graph_controls_20260417T005318Z/v019_pr2585_rebuilt_tp2_cudagraph_none/vllm_server.log`
- `../recreated/v019_kimi_tp2_graph_controls_20260417T005318Z/v019_pr2585_rebuilt_tp2_enforce_eager/vllm_server.log`
Main Questions
- Kimi is incoherent on `v0.19` even in the 16 heads/GPU control path, while DeepSeek-V2-Lite is coherent there. Is this expected for Kimi on forced `ROCM_AITER_MLA`?
- In the supported 4 heads/GPU regime, DeepSeek-V2-Lite is coherent on both `v0.18` bundled/rebuilt and `v0.19` rebuilt, but Kimi is incoherent on `v0.18` bundled, `v0.18` rebuilt, and `v0.19` rebuilt. What Kimi-specific interaction should we inspect next?
- `v0.19` bundled still fails the old small-head assert for both models in the 4 heads/GPU regime. Is that expected until the rebuilt PR2585 path is used?
- The strongest new discriminator is that `v0.19` + Kimi + PR2585 rebuilt is coherent when cudagraph replay is removed, but incoherent in the original graph-mode path at both 16 and 4 heads/GPU. Does that point to a known issue in the ROCm MLA cudagraph path for Kimi specifically?
Implementation Observation
One implementation detail that stood out to us, and that you may already have considered, is the score-scaling path inside the MLA decode kernel. If we are reading the code correctly, the score path appears to be equivalent to `z = (QK) / sqrt(D_q)` followed by `w = exp2((z - m) * log2(e))`, and therefore algebraically equivalent to `w = exp2((QK) * (log2(e) / sqrt(D_q)) - m_2)` with `m_2 = m * log2(e)`. The reason this stood out to us is that, if we are understanding the kernel correctly, the folded form would appear to reduce the score-side elementwise multiplies from two to one. For these specialized kernels, `D_q` appears fixed by the kernel traits, so we wondered whether folding those factors into a single constant scale might be worth considering. We mention this only as a respectful observation in case it is useful, not as a claim that this is the source of the Kimi incoherence.
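For reference, here is a quick numerical sanity check of the algebraic equivalence described above. The tensor shapes, the `D_q` value, and all variable names are illustrative only; nothing here is taken from the kernel itself.

```python
import math
import torch

torch.manual_seed(0)
D_q = 576                      # illustrative head dim; the real value comes from kernel traits
qk = torch.randn(8, 128)       # illustrative raw QK scores

# Two-multiply form: scale by 1/sqrt(D_q), then convert to a base-2 exponent.
z = qk / math.sqrt(D_q)
m = z.max(dim=-1, keepdim=True).values
w_two_mul = torch.exp2((z - m) * math.log2(math.e))

# Folded form: one constant scale, with the running max kept in base-2 units.
scale = math.log2(math.e) / math.sqrt(D_q)
m_2 = (qk * scale).max(dim=-1, keepdim=True).values   # equals m * log2(e)
w_folded = torch.exp2(qk * scale - m_2)

assert torch.allclose(w_two_mul, w_folded, atol=1e-6)
```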
Summary
Add transparent nhead < 16 support for MLA decode path, enabling models like Kimi-Linear-48B-A3B with TP=8 (nhead=4) to use the ASM kernel without manual head padding in vLLM.
Problem
The MLA decode ASM kernel requires nhead >= 16. Models with smaller head counts (e.g. nhead=4 for Kimi-Linear with TP=8) would either fail with an assertion error or require redundant head padding at the vLLM integration level, causing up to 4x wasted computation.
Changes
- `aiter/mla.py`: Automatically pad query heads from nhead to 16 via `repeat_interleave` before kernel dispatch, and strip the output back to the original nhead after decode (see the sketch below). Also add FP8 block size entries for nhead=4 and nhead=8.
- `aiter/ops/attention.py`: Use `effective_num_head` (padded to 16) for metadata buffer sizing when nhead < 16.
- `csrc/kernels/mla/metadata/v1_2_device.cuh`: Add a `pad_to_qh16` path in C++ metadata generation for nhead < 16.
- `op_tests/test_mla.py`, `test_mla_persistent.py`, `test_mla_sparse.py`: Add nhead=4 (and nhead=8) to the default test configurations.
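A minimal sketch of the padding-and-stripping idea, assuming a `[batch, nhead, head_dim]` query layout. The function names, shapes, and `MIN_NHEAD` constant are illustrative only; the actual integration lives in `aiter/mla.py` as described above.

```python
import torch

MIN_NHEAD = 16  # illustrative: the ASM kernel's minimum query-head count

def pad_query_heads(q: torch.Tensor, nhead: int):
    """Duplicate query heads so the kernel sees at least MIN_NHEAD heads.

    Assumes q is [batch, nhead, head_dim] and that nhead divides 16
    (e.g. 4 or 8). Returns the padded query plus the repeat factor
    needed to strip the output afterwards.
    """
    if nhead >= MIN_NHEAD:
        return q, 1
    repeat = MIN_NHEAD // nhead          # nhead=4 -> 4x, nhead=8 -> 2x
    return q.repeat_interleave(repeat, dim=1), repeat

def strip_padded_heads(out: torch.Tensor, repeat: int):
    """Keep one output row per original head after decode.

    With repeat_interleave, the copies of head i sit at positions
    [i*repeat, (i+1)*repeat) and attend to the same KV, so taking
    every repeat-th head recovers the original layout (the same
    stripping applies to returned logits).
    """
    return out if repeat == 1 else out[:, ::repeat]

# Example: nhead=4 padded to 16, then stripped back.
q = torch.randn(2, 4, 576)
q_padded, rep = pad_query_heads(q, nhead=4)
assert q_padded.shape[1] == MIN_NHEAD
assert strip_padded_heads(q_padded, rep).shape == q.shape
```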
Test Plan
Validated on MI355X (gfx950) with: