feat: add pdl support for cute dsl mla decode kernel #2901

yzh119 merged 4 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

Threads a new runtime `enable_pdl` flag through the cute-dsl MLA decode path.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant API as trtllm_batch_decode_with_kv_cache_mla
    participant DeviceCheck as device_support_pdl
    participant Wrapper as cute_dsl_mla_decode
    participant Kernel as MLA Kernel (FP16/FP8)
    participant GPU as Device
    API->>DeviceCheck: if enable_pdl is None -> query device
    DeviceCheck-->>API: boolean enable_pdl
    API->>Wrapper: call cute_dsl_mla_decode(..., enable_pdl)
    Wrapper->>Kernel: construct/lookup kernel(enable_pdl)
    Kernel->>GPU: launch kernels with use_pdl=enable_pdl
    GPU-->>Kernel: execute (PDL sync paths gated by enable_pdl)
    Kernel-->>Wrapper: outputs
    Wrapper-->>API: decoded output
```
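The auto-detection step in the diagram can be sketched in plain Python. This is a hedged stand-in: the real flashinfer `device_support_pdl` queries the actual CUDA device, whereas this sketch hardcodes a capability for illustration, and the wrapper/kernel bodies are stubs.

```python
# Hedged sketch of the dispatch above. `device_support_pdl` here is a
# stand-in: the real flashinfer helper queries the CUDA device, while this
# one hardcodes a Hopper-class capability for illustration.
def device_support_pdl() -> bool:
    major, minor = 9, 0  # placeholder for torch.cuda.get_device_capability()
    return (major, minor) >= (9, 0)

def cute_dsl_mla_decode(enable_pdl: bool, **kwargs):
    # Stub for the wrapper that builds/looks up the kernel with use_pdl set.
    return {"use_pdl": enable_pdl, **kwargs}

def trtllm_batch_decode_with_kv_cache_mla(enable_pdl=None, **kwargs):
    if enable_pdl is None:  # caller left it unset: auto-detect from the device
        enable_pdl = device_support_pdl()
    return cute_dsl_mla_decode(enable_pdl, **kwargs)

print(trtllm_batch_decode_with_kv_cache_mla()["use_pdl"])                  # True
print(trtllm_batch_decode_with_kv_cache_mla(enable_pdl=False)["use_pdl"])  # False
```

The `Optional[bool] = None` default lets callers opt out explicitly while untouched call sites pick up PDL automatically on capable devices.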
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs
🚥 Pre-merge checks | ✅ Passed checks (3 passed)
Code Review
This pull request introduces support for Programmatic Dependency Launch (PDL) within the cute-dsl backend for MLA decode kernels, covering both FP16 and FP8 implementations. The changes include updating the core API to handle the enable_pdl flag, integrating grid dependency controls (griddepcontrol_wait and griddepcontrol_launch_dependents) into the CUDA kernels, and updating the test suite to validate these new paths. Feedback focuses on critical race conditions identified in the placement of dependency launch calls within the kernels, which should be moved to the correction warp blocks to ensure data integrity. Additionally, further parameterization of existing tests is recommended to ensure full coverage of the PDL feature.
I am having trouble creating individual review comments; my feedback is consolidated below.
flashinfer/mla/cute_dsl/mla_decode_fp16.py (1199-1200)
The griddepcontrol_launch_dependents() call is misplaced. It is currently in the mma_warp block, but the reduction_kernel depends on the output of the correction warps (mAccO and mAccLSE). Launching dependents from the MMA warp can lead to a race condition where the reduction_kernel starts before the correction warps have finished writing their data.
This call should be moved to the end of the correction_warp_ids block, after all work tiles have been processed, and should be conditional on mAccO is not None (i.e., when split_kv > 1 and the reduction kernel is actually launched).
Please remove these lines and add the following code at the end of the if warp_idx >= self.correction_warp_ids[0] and warp_idx <= self.correction_warp_ids[-1]: block (around line 1332):
```python
if cutlass.const_expr(self.enable_pdl):
    if mAccO is not None:
        cute.arch.griddepcontrol_launch_dependents()
```

flashinfer/mla/cute_dsl/mla_decode_fp8.py (1263-1264)
The griddepcontrol_launch_dependents() call is misplaced, which can cause a race condition. It should be called by the correction warps, which produce the data for the dependent reduction_kernel, not by the mma_warp.
This should be moved to the end of the correction_warp_ids block, after all work tiles have been processed, and should be conditional on mAccO is not None (i.e., when split_kv > 1 and the reduction kernel is actually launched). A similar issue exists in mla_decode_fp16.py.
Please remove these lines and add the corresponding logic at the end of the correction_warp_ids block.
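The race described in the two comments above can be modeled on the CPU with a one-shot event. This is an analogy only — `griddepcontrol_wait`/`griddepcontrol_launch_dependents` are hardware grid-dependency primitives, not Python — but it shows why the signal must come after the producer's final writes.

```python
import threading

# griddepcontrol_launch_dependents() behaves (roughly) like setting a one-shot
# event, and griddepcontrol_wait() in the dependent kernel blocks on it.
# Signaling before the producer's final stores is what opens the race window.
ready = threading.Event()
acc = []            # stands in for mAccO / mAccLSE
out = []

def producer():
    acc.extend([1, 2, 3])  # correction-warp writes
    ready.set()            # signal dependents only AFTER all writes (the fix)

def consumer():
    ready.wait()           # reduction kernel's griddepcontrol_wait()
    out.append(sum(acc))

t_consumer = threading.Thread(target=consumer)
t_producer = threading.Thread(target=producer)
t_consumer.start()
t_producer.start()
t_consumer.join()
t_producer.join()
print(out)  # [6]: the consumer only ever sees fully written data
```

With the signal moved after the writes, the `Event` establishes a happens-before edge, mirroring the suggested placement at the end of the correction-warp block.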
tests/attention/test_cute_dsl_mla_decode.py (192)
To ensure full test coverage for the new enable_pdl feature, this test should also be parameterized for both True and False values of enable_pdl, similar to other tests in this file. Currently, it only runs with enable_pdl=False.
Please add the following decorator before this function:

```python
@pytest.mark.parametrize("enable_pdl", [True, False])
def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, enable_pdl, page_size=128):
```
tests/attention/test_cute_dsl_mla_decode.py (274)
To ensure full test coverage for the new enable_pdl feature, this test should also be parameterized for both True and False values of enable_pdl, similar to other tests in this file. Currently, it only runs with enable_pdl=False.
Please add the following decorator before this function:

```python
@pytest.mark.parametrize("enable_pdl", [True, False])
def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, enable_pdl, page_size=128):
```
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/attention/test_cute_dsl_mla_decode.py (1)
332-389: ⚠️ Potential issue | 🟡 Minor

`enable_pdl` is parametrized here but never reaches the cute-dsl call. Both parametrized cases currently execute the same backend configuration, so this parity test does not actually cover the new PDL path and just duplicates runtime.
🐛 Forward the flag into the cute-dsl backend
```diff
 out_cute_dsl = trtllm_batch_decode_with_kv_cache_mla(
-    **common_args, backend="cute-dsl", is_var_seq=False
+    **common_args, backend="cute-dsl", is_var_seq=False, enable_pdl=enable_pdl
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_cute_dsl_mla_decode.py` around lines 332 - 389, The parametrized enable_pdl flag isn't forwarded into the cute-dsl call, so both branches run identical configs; update the trtllm_batch_decode_with_kv_cache_mla invocation that produces out_cute_dsl to pass the flag (e.g. add enable_pdl=enable_pdl or the correct parameter name expected by trtllm_batch_decode_with_kv_cache_mla) while leaving the trtllm-gen call unchanged; locate the two calls to trtllm_batch_decode_with_kv_cache_mla and add the enable_pdl argument to the cute-dsl call.
🧹 Nitpick comments (1)
tests/attention/test_cute_dsl_mla_decode.py (1)
190-246: These two tests still only exercise `enable_pdl=False`.

Adding a defaulted parameter does not create another pytest case, so the variable-seq and API flows never hit the new PDL branch. Parametrize `enable_pdl` here the same way as the FP16/FP8 tests.

♻️ Suggested test expansion

```diff
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("seq_len_k", [128, 512])
-def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=128, enable_pdl=False):
+@pytest.mark.parametrize("enable_pdl", [False, True])
+def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, enable_pdl, page_size=128):
@@
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("seq_len_k", [128, 512])
-def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128, enable_pdl=False):
+@pytest.mark.parametrize("enable_pdl", [False, True])
+def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, enable_pdl, page_size=128):
```

Also applies to: 272-325
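The collection arithmetic behind this suggestion can be checked directly: pytest expands stacked `parametrize` decorators into the Cartesian product of their values, while a defaulted keyword argument contributes nothing. A standalone sketch (the value lists mirror the tests above):

```python
import itertools

# pytest expands stacked parametrize decorators into the Cartesian product of
# their value lists; a defaulted keyword argument adds no cases at all.
batch_sizes = [1, 4]
seq_lens_k = [128, 512]
pdl_flags = [True, False]

cases_defaulted = list(itertools.product(batch_sizes, seq_lens_k))
cases_parametrized = list(itertools.product(batch_sizes, seq_lens_k, pdl_flags))
print(len(cases_defaulted), len(cases_parametrized))  # 4 8
```

So without the extra decorator each test runs 4 times, all with `enable_pdl=False`; with it, 8 times, covering both branches.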
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_cute_dsl_mla_decode.py` around lines 190 - 246, The test test_cute_dsl_mla_decode_variable_seq_len currently takes enable_pdl as a default argument so pytest only runs the False case; change the decorator list to parametrize enable_pdl (e.g. add `@pytest.mark.parametrize`("enable_pdl",[False, True]) alongside the existing batch_size and seq_len_k parametrizations) so both PDL branches run, and make the same update to the other test covering lines 272-325 (the companion MLA decode test) to parametrize enable_pdl there as well; keep the argument name enable_pdl in the function signature so the runner injects the values.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/mla/cute_dsl/mla_decode_fp16.py`:
- Around line 716-717: The PDL (dependent launch) is being triggered too early
(right after tmem.free()) causing reduction_kernel() to consume
partially-written accumulators mAccLSE/mAccO before epilogue() finishes; locate
the producer-side dependent-launch trigger(s) associated with
use_pdl/self.enable_pdl (occurrences around the reduction launch sites
referenced) and move each trigger to after the final global stores performed by
the epilogue/correction path (or alternatively replace the trigger with an
explicit signal emitted from epilogue() after the final stores); ensure changes
touch the reduction launch locations, the tmem.free() site, and the
epilogue()/correction path so the reduction_kernel() only starts after
mAccLSE/mAccO are fully written (apply same fix to the other occurrences noted).
- Line 173: The new required parameter enable_pdl in
BlackwellMultiHeadLatentAttentionForwardFP16 breaks callers; either make
enable_pdl optional with a default (e.g., default to False) on the
BlackwellMultiHeadLatentAttentionForwardFP16 constructor signature or propagate
the flag from the module entrypoint by adding an enable_pdl parameter to
run()/the CLI and passing it into the
BlackwellMultiHeadLatentAttentionForwardFP16(...) instantiation (update its call
site around line where run() constructs the object). Ensure the symbol
BlackwellMultiHeadLatentAttentionForwardFP16 and the run()/CLI entrypoint are
updated consistently so existing local examples that construct the object
without enable_pdl continue to work.
In `@flashinfer/mla/cute_dsl/mla_decode_fp8.py`:
- Around line 169-170: The new required parameter enable_pdl breaks callers
(notably run() which constructs BlackwellMultiHeadLatentAttentionForwardFP8);
make enable_pdl backward-compatible by giving it a default (e.g., change the
constructor/signature to enable_pdl: bool = False) so existing instantiations
still work, or alternatively propagate a PDL flag from run() and the CLI into
the BlackwellMultiHeadLatentAttentionForwardFP8 constructor so run() (the code
that builds the module) supplies the parameter; update whichever symbol you
choose: the constructor of BlackwellMultiHeadLatentAttentionForwardFP8 and/or
the run() call site that constructs it.
---
Outside diff comments:
In `@tests/attention/test_cute_dsl_mla_decode.py`:
- Around line 332-389: The parametrized enable_pdl flag isn't forwarded into the
cute-dsl call, so both branches run identical configs; update the
trtllm_batch_decode_with_kv_cache_mla invocation that produces out_cute_dsl to
pass the flag (e.g. add enable_pdl=enable_pdl or the correct parameter name
expected by trtllm_batch_decode_with_kv_cache_mla) while leaving the trtllm-gen
call unchanged; locate the two calls to trtllm_batch_decode_with_kv_cache_mla
and add the enable_pdl argument to the cute-dsl call.
---
Nitpick comments:
In `@tests/attention/test_cute_dsl_mla_decode.py`:
- Around line 190-246: The test test_cute_dsl_mla_decode_variable_seq_len
currently takes enable_pdl as a default argument so pytest only runs the False
case; change the decorator list to parametrize enable_pdl (e.g. add
`@pytest.mark.parametrize`("enable_pdl",[False, True]) alongside the existing
batch_size and seq_len_k parametrizations) so both PDL branches run, and make
the same update to the other test covering lines 272-325 (the companion MLA
decode test) to parametrize enable_pdl there as well; keep the argument name
enable_pdl in the function signature so the runner injects the values.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6a987e2e-932b-4512-bad6-099c4304bee7
📒 Files selected for processing (5)
- flashinfer/mla/_core.py
- flashinfer/mla/cute_dsl/mla_decode.py
- flashinfer/mla/cute_dsl/mla_decode_fp16.py
- flashinfer/mla/cute_dsl/mla_decode_fp8.py
- tests/attention/test_cute_dsl_mla_decode.py
/bot run

[FAILED] Pipeline #47106799: 11/20 passed
The BlackwellMultiHeadLatentAttentionForward{FP8,FP16} constructors
require enable_pdl but the run() harnesses weren't passing it (mypy
call-arg error). Also follow public API convention in cute_dsl_mla_decode
by defaulting to Optional[bool] = None with device_support_pdl() auto-detection.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
These standalone test/benchmark harnesses are superseded by the public API in mla_decode.py and proper tests in tests/. Also removes the now-unused argparse import. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
/bot run
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/mla/cute_dsl/mla_decode_fp16.py (1)
709-717: Only enable PDL when the reduction kernel is actually present.

`acc_o is None` is the single-kernel path, so `split_kv_kernel` has no dependent launch to unlock. Gating `use_pdl` with `acc_o is not None` avoids paying the PDL bookends on the no-workspace case.

♻️ Minimal change

```diff
-    use_pdl=self.enable_pdl,
+    use_pdl=self.enable_pdl and acc_o is not None,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/mla/cute_dsl/mla_decode_fp16.py` around lines 709 - 717, The launch call for split_kv_kernel should only enable PDL when the reduction kernel exists: check the accumulator variable acc_o and set use_pdl to True only if acc_o is not None; update the launch invocation for split_kv_kernel (the .launch(...) call) to gate use_pdl with acc_o is not None so the single-kernel (acc_o is None) path does not pay PDL bookends.
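The gating suggested above can be sketched as a plain helper. The helper name is hypothetical — the actual suggestion inlines the expression in the `.launch(...)` call — but the truth table is the point:

```python
# Hypothetical helper name; the actual suggestion inlines this expression in
# the .launch(...) call for split_kv_kernel.
def resolve_use_pdl(enable_pdl: bool, acc_o) -> bool:
    # PDL only pays off when a dependent reduction kernel exists (split-KV path).
    return enable_pdl and acc_o is not None

print(resolve_use_pdl(True, None))       # False: single-kernel path, no dependents
print(resolve_use_pdl(True, object()))   # True: split-KV path with a reduction
print(resolve_use_pdl(False, object()))  # False: PDL disabled by the caller
```

This keeps the no-workspace launch free of PDL setup/teardown cost without changing behavior when PDL is disabled.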
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/mla/cute_dsl/mla_decode_fp8.py`:
- Around line 3553-3554: The function in mla_decode_fp8.py now accepts an
enable_pdl: bool = False parameter but its docstring lacks an entry; add a short
docstring line for enable_pdl (similar style to the existing use_cold_l2 entry)
describing that it is a boolean flag, default False, and explaining what
enabling PDL does (e.g., toggles PDL-specific decoding/optimizations) and any
behavioral impact; place this new docstring entry with the other parameter
descriptions in the function's docstring so it’s consistent and easy to
discover.
---
Nitpick comments:
In `@flashinfer/mla/cute_dsl/mla_decode_fp16.py`:
- Around line 709-717: The launch call for split_kv_kernel should only enable
PDL when the reduction kernel exists: check the accumulator variable acc_o and
set use_pdl to True only if acc_o is not None; update the launch invocation for
split_kv_kernel (the .launch(...) call) to gate use_pdl with acc_o is not None
so the single-kernel (acc_o is None) path does not pay PDL bookends.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 58ecae4a-8ef0-4c6e-9c73-a8bb9cad3fe9
📒 Files selected for processing (3)
- flashinfer/mla/cute_dsl/mla_decode.py
- flashinfer/mla/cute_dsl/mla_decode_fp16.py
- flashinfer/mla/cute_dsl/mla_decode_fp8.py
♻️ Duplicate comments (2)
flashinfer/mla/cute_dsl/mla_decode_fp16.py (1)
1197-1200: ⚠️ Potential issue | 🔴 Critical

PDL dependents are still released before split-kernel outputs are finalized. Line 1199 triggers dependents from the MMA warp immediately after TMEM free, but reduction waits (Line 1380) and then consumes `mAccLSE`/`mAccO`, which are completed by correction/epilogue later. This can race when `enable_pdl=True`.

💡 Suggested fix (move trigger to post-epilogue completion path)

```diff
@@
     tmem.relinquish_alloc_permit()
     tmem.free(tmem_ptr)
-    if cutlass.const_expr(self.enable_pdl):
-        cute.arch.griddepcontrol_launch_dependents()
@@
     if (
         warp_idx >= self.correction_warp_ids[0]
         and warp_idx <= self.correction_warp_ids[-1]
     ):
@@
         tile_sched.advance_to_next_work()
         work_tile = tile_sched.get_current_work()
+
+        # Release dependent reduction only after correction/epilogue writes are done.
+        if cutlass.const_expr(self.enable_pdl):
+            cute.arch.griddepcontrol_launch_dependents()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/mla/cute_dsl/mla_decode_fp16.py` around lines 1197 - 1200, The PDL dependent trigger (cute.arch.griddepcontrol_launch_dependents()) is fired too early right after tmem.free() while reduction/epilogue still produce/consume mAccLSE and mAccO; move the call out of the immediate TMEM cleanup path and into the post-epilogue completion path (after the reduction/correction that finalizes mAccLSE/mAccO) so dependents are launched only when outputs are fully finalized; keep the conditional cutlass.const_expr(self.enable_pdl) check and ensure the call is placed after the routine that performs the reduction/epilogue completion (the code that consumes mAccLSE/mAccO).

flashinfer/mla/cute_dsl/mla_decode_fp8.py (1)
3553-3554: ⚠️ Potential issue | 🟡 Minor

Add a docstring entry for the `enable_pdl` parameter. The new `enable_pdl` parameter is missing from the function's docstring (lines 3556-3616). For consistency with other parameters like `use_cold_l2`, please add documentation.

📝 Proposed docstring addition (insert after line 3612)

```diff
     :param use_cold_l2: Whether to use cold L2 cache
     :type use_cold_l2: bool
+    :param enable_pdl: Whether to enable Programmatic Dependent Launch (PDL) for kernel scheduling
+    :type enable_pdl: bool
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/mla/cute_dsl/mla_decode_fp8.py` around lines 3553 - 3554, The function signature includes a new boolean parameter enable_pdl but it is missing from the function docstring; add a short docstring entry for enable_pdl (mirroring style used for use_cold_l2) describing its type, default and behavior (e.g., "enable_pdl (bool, optional): Whether to enable PDL decoding. Defaults to False.") and place it with the other parameter descriptions in the function's docstring so it appears alongside use_cold_l2 and the other params.
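For illustration, a hypothetical wrapper showing the reST field-list docstring style the comment asks for — the function name, signature, and body are placeholders, not the real flashinfer API:

```python
def decode_sketch(query, kv_cache, use_cold_l2: bool = False, enable_pdl: bool = False):
    """Hypothetical wrapper illustrating the reST field-list docstring style.

    :param use_cold_l2: Whether to use cold L2 cache
    :type use_cold_l2: bool
    :param enable_pdl: Whether to enable Programmatic Dependent Launch (PDL)
        for kernel scheduling; defaults to ``False``.
    :type enable_pdl: bool
    """
    # Placeholder body: just echo the flags so the sketch is runnable.
    return {"use_cold_l2": use_cold_l2, "enable_pdl": enable_pdl}

print(decode_sketch(None, None, enable_pdl=True))
```

Sphinx renders each `:param:`/`:type:` pair as a documented argument, so keeping `enable_pdl` next to `use_cold_l2` makes it discoverable alongside the other options.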
🧹 Nitpick comments (1)
flashinfer/mla/cute_dsl/mla_decode_fp8.py (1)
1263-1264: PDL `griddepcontrol_launch_dependents()` timing is suboptimal but not a functional correctness issue.

The call at line 1264 signals dependent kernels before the correction warp finishes writing to `mAccO` and `mAccLSE`. While this could cause the `reduction_kernel` to launch earlier than necessary, `griddepcontrol_launch_dependents()` is documented as a performance hint only ("doesn't impact the functionality"), and the actual synchronization is enforced by the GPU memory model and the `reduction_kernel`'s explicit `griddepcontrol_wait()` call.

For better performance, consider moving this call to after the correction warp completes its writes, or adding explicit synchronization before signaling. This will prevent the dependent kernel from launching prematurely and competing for resources while the output data is still being written.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/mla/cute_dsl/mla_decode_fp8.py` around lines 1263 - 1264, The PDL hint call cute.arch.griddepcontrol_launch_dependents() is placed before the correction warp finishes writing outputs (mAccO and mAccLSE), which can let reduction_kernel start too early; move the griddepcontrol_launch_dependents() call to immediately after the correction-warp write sequence (or insert an explicit synchronization/fence right before the launch call) so the correction warp's writes complete before signaling dependents; update the branch that checks cutlass.const_expr(self.enable_pdl) to perform the launch only after the correction-warp completion (or add a gridfence/memory barrier prior to cute.arch.griddepcontrol_launch_dependents()).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/mla/cute_dsl/mla_decode_fp16.py`:
- Around line 1197-1200: The PDL dependent trigger
(cute.arch.griddepcontrol_launch_dependents()) is fired too early right after
tmem.free() while reduction/epilogue still produce/consume mAccLSE and mAccO;
move the call out of the immediate TMEM cleanup path and into the post-epilogue
completion path (after the reduction/correction that finalizes mAccLSE/mAccO) so
dependents are launched only when outputs are fully finalized; keep the
conditional cutlass.const_expr(self.enable_pdl) check and ensure the call is
placed after the routine that performs the reduction/epilogue completion (the
code that consumes mAccLSE/mAccO).
In `@flashinfer/mla/cute_dsl/mla_decode_fp8.py`:
- Around line 3553-3554: The function signature includes a new boolean parameter
enable_pdl but it is missing from the function docstring; add a short docstring
entry for enable_pdl (mirroring style used for use_cold_l2) describing its type,
default and behavior (e.g., "enable_pdl (bool, optional): Whether to enable PDL
decoding. Defaults to False.") and place it with the other parameter
descriptions in the function's docstring so it appears alongside use_cold_l2 and
the other params.
---
Nitpick comments:
In `@flashinfer/mla/cute_dsl/mla_decode_fp8.py`:
- Around line 1263-1264: The PDL hint call
cute.arch.griddepcontrol_launch_dependents() is placed before the correction
warp finishes writing outputs (mAccO and mAccLSE), which can let
reduction_kernel start too early; move the griddepcontrol_launch_dependents()
call to immediately after the correction-warp write sequence (or insert an
explicit synchronization/fence right before the launch call) so the correction
warp's writes complete before signaling dependents; update the branch that
checks cutlass.const_expr(self.enable_pdl) to perform the launch only after the
correction-warp completion (or add a gridfence/memory barrier prior to
cute.arch.griddepcontrol_launch_dependents()).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e7e74280-b0e8-4a28-91a6-4d7c2b0265e2
📒 Files selected for processing (2)
- flashinfer/mla/cute_dsl/mla_decode_fp16.py
- flashinfer/mla/cute_dsl/mla_decode_fp8.py
[FAILED] Pipeline #47135584: 13/20 passed
📌 Description
Add PDL support for the cute-dsl MLA decode kernels.

Before adding PDL: (profiler screenshot omitted)

After adding PDL: (profiler screenshot omitted)
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (e.g., unittest, etc.).

Reviewer Notes