[Feat] Add num_heads < 128 support for mla decode kernel#3309
[Feat] Add num_heads < 128 support for mla decode kernel#3309Observer007 wants to merge 9 commits into
Conversation
…te_dsl_impl= (AI-assisted) PR flashinfer-ai#2805 refactored the monolithic CuTe-DSL MLA decode kernel into a modular structure and removed the original implementation. The original authors want it kept available because the modular path is still maturing. Restore it under the cute-dsl backend (no new backend name) and let the user pick: flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( ..., backend="cute-dsl", cute_dsl_impl="auto" | "modular" | "monolithic") Layout: flashinfer/cute_dsl/attention/ monolithic/ - restored kernels (verbatim from before flashinfer-ai#2805, relocated to live next to the modular code). Includes the H<128 / Kimi K2.5 fix from flashinfer-ai#3235 backported (workspace pads H to max(H, 128) when split_kv != 1; can_implement no longer rejects H<128). wrappers/ - existing modular standalone + wrapper. mla_dispatch.py - new dispatcher in front of both impls. Dispatcher contract: - "auto" (default): monolithic, but auto-promotes to modular when a modular-only feature is requested (currently sinks). - "modular" : strict, always modular. - "monolithic" : strict, raises ValueError if a modular-only feature is requested rather than silently substituting. The dispatcher strips modular-only kwargs (sinks=None) before forwarding to monolithic, so callers can pass sinks= unconditionally without breaking the monolithic path. Sinks support: - trtllm_batch_decode_with_kv_cache_mla(sinks=...) on backend= "cute-dsl" now constructs an AttentionWithSink variant inside the modular standalone, instead of being rejected at the API boundary. - AttentionWithSink gains value-based __hash__/__eq__ keyed on (data_ptr, shape, dtype) so @functools.cache on _compile_mla_kernel correctly reuses compiled kernels across invocations with the same sinks tensor. Without this, a fresh variant per call hashed by object identity, JIT-recompiled the kernel on every iteration, and made cuda-graph + sinks bench measurements appear to hang. Tests: - test_cute_dsl_mla_decode.py: existing standalone and public-API tests now parametrize over modular/monolithic via a cute_dsl_impl fixture; new minimal sinks tests pin the auto/modular dispatch branches and the monolithic+sinks ValueError contract. Wrapper sinks numerics remain covered by the pre-existing test_cute_dsl_mla_decode_attention_sink. - test_trtllm_gen_mla.py: comment near the cute-dsl skip refreshed to reflect the dispatcher's cute_dsl_impl behaviour. Bench: - bench_trtllm_gen_mla.py grows a focused 6-cell with_sinks=True sub-sweep (B in {1,16,128} x S in {1024,8192} at q_len=1, page=64, bf16) on top of the existing main sweep, instead of doubling the full grid. Argument list deduplicated into a common_kwargs dict so warmup and benchmark calls cannot drift.
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
📝 WalkthroughWalkthroughAdds fold-S_q packing and fold-aware causal masking to monolithic FP16/FP8 MLA kernels; converts LSE outputs to natural-log; exposes optional LSE buffers via wrapper/API with sinks and cute_dsl_impl plumbing; reworks FP8 dual-group softmax; parametrizes tests; fixes tile-scheduler index and benchmark q_lens grid. ChangesMonolithic Kernel Folding, Masking, LSE, and Test Updates
Sequence Diagram(s): sequenceDiagram
participant Caller as API Caller
participant Wrapper as cute_dsl_mla_decode (batch)
participant Monolithic as Monolithic Wrapper
participant Kernel as FP16/FP8 Kernel
Caller->>Wrapper: call with sinks, lse, return_lse, cute_dsl_impl
Wrapper->>Monolithic: validate/reshape/provide lse (modular: raise NotImplemented)
Monolithic->>Kernel: compile/run with num_heads, seq_len_q, fold_sq, lse buffer
Kernel-->>Monolithic: compute outputs, write LSE (natural-log)
Monolithic-->>Wrapper: return out or (out, lse)
Wrapper-->>Caller: return to caller
Estimated code review effort🎯 4 (Complex) | ⏱️ ~75 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces a dispatcher for the cute-dsl backend to select between modular and monolithic MLA decode implementations. It adds support for sinks in the cute-dsl backend via automatic promotion to the modular implementation and restores monolithic kernels to support fold_sq and spec-decoding causal masks. Feedback was provided regarding a leftover debug print statement in the monolithic implementation.
There was a problem hiding this comment.
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer/cute_dsl/attention/monolithic/mla_decode.py`:
- Line 456: The debug print statement in the MLA decoding hot path
(print(f"is_persistent: {is_persistent}")) should be removed to eliminate noisy
stdout and runtime overhead; locate the print in mla_decode.py (function/method
around mla_decode / the decode hot path) and delete it or replace it with a
conditional debug log using an existing logger (e.g., logger.debug) guarded by
log level checks so no per-call I/O occurs in production.
- Around line 396-399: The code assumes kv_cache becomes 3D but currently
squeezes dim 1 unconditionally when kv_cache.dim() == 4; update the check in the
mla_decode logic to validate the 4D layout before removing dim 1: only call
squeeze(1) if kv_cache.dim() == 4 and kv_cache.shape[1] == 1, otherwise raise a
clear ValueError (or reshape/transpose as appropriate) explaining the unexpected
kv_cache shape so downstream code that uses page_size = kv_cache.shape[1] and
other 3D assumptions fails fast and clearly; reference the kv_cache variable and
the block around the squeeze and page_size computation.
In `@flashinfer/cute_dsl/attention/wrappers/batch_mla.py`:
- Around line 838-854: The code currently only checks sinks.is_contiguous()
before creating AttentionWithSink, but you must also validate sinks.shape
matches the expected head count (e.g., 1-D and length == H) to avoid invalid
head indexing; before instantiating AttentionWithSink (and before calling
variant.extra_params.to(...)), assert sinks.dim() == 1 and sinks.size(0) == H
(or the variable holding the number of heads) and raise a ValueError that prints
sinks.shape and the expected H if the check fails, then proceed with creating
AttentionWithSink and converting variant.extra_params to torch.float32 on
query.device.
In `@flashinfer/mla/_core.py`:
- Around line 677-681: Update the docstring for the sinks parameter to reflect
actual backend support: state that sinks are supported by the cute-dsl backend
(requires modular implementation via cute_dsl_impl="auto" promoting to modular
or manual modular selection) but are rejected by the XQA execution path (which
raises a ValueError); ensure the wording mentions the sinks parameter name,
cute_dsl_impl, and the ValueError raised by the XQA path so the docstring aligns
with real behavior.
In `@tests/attention/test_cute_dsl_mla_decode.py`:
- Line 1130: The regex pattern passed to pytest.raises in the
pytest.raises(ValueError, match="monolithic.*sinks.*modular") call should be a
raw string to avoid unintended escaping; update the match argument to use a raw
string literal (r"monolithic.*sinks.*modular") in the pytest.raises(...)
invocation so the regex metacharacters are interpreted correctly and the RUF043
warning is resolved.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 39130b48-8f1e-4bb4-9335-f7f71e99355f
📒 Files selected for processing (13)
benchmarks/bench_trtllm_gen_mla.pyflashinfer/cute_dsl/attention/__init__.pyflashinfer/cute_dsl/attention/fusion/variant.pyflashinfer/cute_dsl/attention/mla_dispatch.pyflashinfer/cute_dsl/attention/monolithic/__init__.pyflashinfer/cute_dsl/attention/monolithic/mla_decode.pyflashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.pyflashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.pyflashinfer/cute_dsl/attention/monolithic/mla_helpers.pyflashinfer/cute_dsl/attention/wrappers/batch_mla.pyflashinfer/mla/_core.pytests/attention/test_cute_dsl_mla_decode.pytests/attention/test_trtllm_gen_mla.py
|
SGLang team is verifying the perf of this PR. |
Edit: the bug is actually due to shared workspace by accident from integration side, not the kernel. This PR LGTM. |
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py (1)
103-120:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winUpdate the constraints block to match the new folded-head path.
This block still documents a fixed 128-head requirement, but the kernel now explicitly supports
num_heads < 128via folding/padding. Leaving the old invariant here will mislead the next person debuggingcan_implement()or the fold logic.📝 Suggested doc update
Constraints: * Data type requirements: - Input/output: Float16 - Accumulation and LSE: Float32 * Fixed architecture parameters: - - Number of attention heads: 128 - Latent dimension: 512 - RoPE dimension: 64 +* Head count may be smaller than the 128-row MMA-M tile; when enabled, + `seq_len_q` can be folded into the head dimension to populate the tile. * Input query modes should be (NumHeads, LatentDim/RopeDim, SeqLenQ, BatchSize) * Input kv latent/rope modes should be (SeqLenK, LatentDim/RopeDim, BatchSize)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py` around lines 103 - 120, The constraints comment in mla_decode_fp16.py still states a fixed 128-head requirement which is outdated; update that block to document that the kernel supports folded/padded heads (num_heads <= 128) and may accept num_heads < 128 via folding/padding, and clarify how RopeDim/LatentDim mapping and sequence-length limits (1-4) interact with folding. Edit the constraints nearby the class/definition for the monolithic FP16 kernel (the MLA monolithic implementation) and mention the folding/padding behavior referenced by can_implement() and the fold logic so readers know the kernel will pad/fold heads rather than requiring exactly 128 heads.flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py (1)
3754-3765:⚠️ Potential issue | 🔴 Critical | ⚡ Quick winUpdate the pipeline participant counts for the second compute warp group.
After adding
second_compute_warp_ids, both softmax groups participate inmma_s_pipeline.consumer_wait()and produce intop_mma/p_cor, but these cooperative groups are still sized fromself.compute_warp_idsonly. That leaves half of the active compute warps outside the barrier accounting and can release stages early.Suggested fix
consumer_thread_size = ( self.threads_per_warp - * len(self.compute_warp_ids) + * self.num_total_compute_warps * self.cluster_shape_mnk[0] ) @@ producer_thread_size = ( self.threads_per_warp - * len(self.compute_warp_ids) + * self.num_total_compute_warps * self.cluster_shape_mnk[0] ) @@ - producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids) + producer_thread_size = self.threads_per_warp * self.num_total_compute_warpsAlso applies to: 3789-3836
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py` around lines 3754 - 3765, The cooperative group sizes for the MMA softmax pipeline are still built using only self.compute_warp_ids, which leaves the warps in second_compute_warp_ids out of the barrier; update construction of mma_s_producer_group and mma_s_consumer_group to account for both compute sets (e.g., use len(self.compute_warp_ids) + len(self.second_compute_warp_ids) or the combined list length) so that all threads participating in mma_s_pipeline.consumer_wait(), all producers writing to p_mma/p_cor, and all consumer threads are included in the CooperativeGroup counts (adjust any other similar groups in the region around mma_s_* and the block at 3789-3836).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer/cute_dsl/attention/monolithic/mla_decode.py`:
- Around line 465-479: The code currently accepts lse by shape/dtype only but
may misinterpret non-compact memory; update the lse handling to require
contiguous C-order buffers and reject non-compact inputs: when lse is provided,
check lse.is_contiguous() (and lse.dim() in {2,3}) and raise a ValueError if not
contiguous (or not 2D/3D), then proceed to assign lse_k (either use lse directly
when shape==(B,q_len,H) or view when shape==(B*q_len,H)); reference lse, lse_k,
and B/q_len/H and use .is_contiguous() in the validation to ensure callers pass
compact storage.
---
Outside diff comments:
In `@flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py`:
- Around line 103-120: The constraints comment in mla_decode_fp16.py still
states a fixed 128-head requirement which is outdated; update that block to
document that the kernel supports folded/padded heads (num_heads <= 128) and may
accept num_heads < 128 via folding/padding, and clarify how RopeDim/LatentDim
mapping and sequence-length limits (1-4) interact with folding. Edit the
constraints nearby the class/definition for the monolithic FP16 kernel (the MLA
monolithic implementation) and mention the folding/padding behavior referenced
by can_implement() and the fold logic so readers know the kernel will pad/fold
heads rather than requiring exactly 128 heads.
In `@flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py`:
- Around line 3754-3765: The cooperative group sizes for the MMA softmax
pipeline are still built using only self.compute_warp_ids, which leaves the
warps in second_compute_warp_ids out of the barrier; update construction of
mma_s_producer_group and mma_s_consumer_group to account for both compute sets
(e.g., use len(self.compute_warp_ids) + len(self.second_compute_warp_ids) or the
combined list length) so that all threads participating in
mma_s_pipeline.consumer_wait(), all producers writing to p_mma/p_cor, and all
consumer threads are included in the CooperativeGroup counts (adjust any other
similar groups in the region around mma_s_* and the block at 3789-3836).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 830ef233-a6a0-4228-945d-132638b8a222
📒 Files selected for processing (8)
benchmarks/bench_trtllm_gen_mla.pyflashinfer/cute_dsl/attention/monolithic/mla_decode.pyflashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.pyflashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.pyflashinfer/cute_dsl/attention/monolithic/mla_helpers.pyflashinfer/cute_dsl/attention/wrappers/batch_mla.pyflashinfer/mla/_core.pytests/attention/test_cute_dsl_mla_decode.py
💤 Files with no reviewable changes (3)
- flashinfer/mla/_core.py
- tests/attention/test_cute_dsl_mla_decode.py
- flashinfer/cute_dsl/attention/wrappers/batch_mla.py
|
/bot run |
|
@Observer007 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww |
|
/bot run |
📌 Description
Features addition:
Performance enhancement:
Performance collection before & after this commit with MTP==4.
Blackwell B200, no freq lock. seqlen_k = 80k.
🔍 Related Issues
#3287
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
New Features
Tests
Chores