
Support MLA decode with nhead < 16 by transparent pad-to-16#2577

Open
ChuanLi1101 wants to merge 1 commit into main from chuan/mla-nhead-pad-to-16

Conversation

@ChuanLi1101
Contributor

Summary

  • For MLA models with small query head counts (e.g., Kimi-Linear-48B-A3B with TP=8 giving nhead=4), AITER's ASM kernel has no pre-compiled support for gqa_ratio < 16, causing decode failures.
  • This PR adds transparent head padding within AITER: when nhead < 16 and evenly divides 16, Q is padded to 16 heads via repeat_interleave, the nhead=16 ASM kernel runs, and the output is then un-padded (see the sketch after this list).
  • Adjusts C++ persistent metadata generation and Python metadata sizing to accept nhead < 16.
  • Adds nhead=4 test configurations to test_mla.py and test_mla_persistent.py.
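A minimal sketch of the pad-to-16 idea described above, assuming a hypothetical `run_asm_decode` callable and simplified tensor shapes; the actual implementation lives in aiter/mla.py and covers both the non-persistent and persistent paths.

```python
import torch

ASM_MIN_HEADS = 16  # smallest gqa_ratio the pre-compiled ASM kernel supports

def decode_with_head_pad(q: torch.Tensor, kv_cache, nhead: int, run_asm_decode):
    """Hypothetical wrapper: pad query heads to 16, run the kernel, un-pad."""
    if nhead >= ASM_MIN_HEADS or ASM_MIN_HEADS % nhead != 0:
        # Already supported (or not evenly divisible): no padding applied.
        return run_asm_decode(q, kv_cache)

    repeat = ASM_MIN_HEADS // nhead  # e.g. 16 // 4 == 4 for nhead=4 (TP=8)
    # q: [batch, nhead, head_dim] -> [batch, 16, head_dim],
    # duplicating each head `repeat` times contiguously.
    q_padded = q.repeat_interleave(repeat, dim=1)
    out_padded = run_asm_decode(q_padded, kv_cache)
    # Duplicated heads see identical queries, so keep one copy per original head.
    return out_padded[:, ::repeat, :].contiguous()
```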

Changes

  • aiter/mla.py: Pad Q heads to 16 when nhead < 16 (both non-persistent and persistent paths) and un-pad the output before return; add safe entries in get_block_n_fp8.
  • csrc/kernels/mla/metadata/v1_2_device.cuh: Add pad_to_qh16 logic for nhead < 16 in persistent metadata generation.
  • aiter/ops/attention.py: Relax the num_head_qo % 16 assertion in get_mla_metadata_info_v1 to allow nhead < 16 with effective_num_head = 16 (see the sketch after this list).
  • op_tests/test_mla.py: Add nhead=(4,1) to default test configurations.
  • op_tests/test_mla_persistent.py: Add nhead=(4,1) to default test configurations.
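A hedged sketch of the relaxed metadata sizing mentioned for aiter/ops/attention.py; the helper name below is illustrative only, and the real check lives inside get_mla_metadata_info_v1.

```python
def effective_num_head_qo(num_head_qo: int) -> int:
    # Illustrative only: map small head counts onto the padded 16-head launch.
    if num_head_qo >= 16:
        assert num_head_qo % 16 == 0, "num_head_qo must be a multiple of 16"
        return num_head_qo
    assert 16 % num_head_qo == 0, "nhead < 16 must divide 16 evenly"
    return 16  # persistent metadata is sized for the padded 16-head kernel
```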

Test plan

  • nhead=4 BF16 decode test on MI355X (gfx950): all checkAllclose passed
  • nhead=16 BF16 regression test: all passed, no regressions
  • nhead=4 FP8 decode test (future work - needs ASM kernel support)
  • CI pipeline validation

@ChuanLi1101 ChuanLi1101 requested a review from a team April 1, 2026 10:29

github-actions Bot commented Apr 1, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

  • ci:triton-355: Run Triton tests on MI355 in addition to MI325
  • ci:sglang: SGLang integration tests
  • ci:atom: ATOM benchmark (DeepSeek-R1 + GPT-OSS)
  • ci:vllm: vLLM benchmark
  • ci:all: All of the above

Add labels via the sidebar or gh pr edit 2577 --add-label <label>

@ChuanLi1101
Contributor Author

cc @valarLip @carlushuang — Requesting expedited review. This enables MLA decode with nhead < 16 (required for GLM-5 TP=8 on MI355X). This is blocking vLLM-side PRs (vllm-project/vllm#36855, vllm-project/vllm#38665) for customer-facing GLM-5 inference. See also #2563 for the upstream issue.
