
feat: add pdl support for cute dsl mla decode kernel support #2901

Merged
yzh119 merged 4 commits into flashinfer-ai:main from Observer007:feature/mla_pdl
Mar 29, 2026

Conversation

Contributor

@Observer007 Observer007 commented Mar 27, 2026

📌 Description

Add PDL (Programmatic Dependent Launch) support for the CuTe DSL MLA decode kernels.

Before adding PDL: [profiler screenshot]

After adding PDL: [profiler screenshot]

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added an optional PDL (Programmatic Dependent Launch) toggle for MLA decoding with a dynamic default that auto-selects per device; APIs now accept an explicit enable/disable option.
  • Tests

    • Expanded tests to parametrize and validate PDL enabled/disabled across multiple decoding scenarios and backends.
  • Chores

    • Removed some module-level CLI entrypoints in favor of programmatic API usage.

Contributor

coderabbitai bot commented Mar 27, 2026

📝 Walkthrough

Walkthrough

Threads a new runtime enable_pdl flag from the public API into the CuTe DSL decode wrapper and MLA kernels; when unspecified it is resolved per-device. FP16/FP8 kernels gain PDL-control guards and tests are updated to exercise both flag values.
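The per-device default resolution described above can be sketched as follows. This is a minimal stand-in for illustration: the capability check inside device_support_pdl and the SM90 threshold are assumptions, not the exact flashinfer implementation.

```python
from typing import Optional

def device_support_pdl(compute_capability_major: int) -> bool:
    # Hypothetical stand-in for the real helper: Programmatic Dependent
    # Launch is assumed here to require Hopper-class (SM90) or newer GPUs.
    return compute_capability_major >= 9

def resolve_enable_pdl(enable_pdl: Optional[bool],
                       compute_capability_major: int) -> bool:
    # None means "auto": pick based on what the device supports;
    # an explicit True/False from the caller always wins.
    if enable_pdl is None:
        return device_support_pdl(compute_capability_major)
    return enable_pdl
```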

Changes

  • Core API (flashinfer/mla/_core.py):
    Resolve enable_pdl when None via device_support_pdl(query.device) and forward enable_pdl= into cute_dsl_mla_decode.
  • CuTe DSL Wrapper (flashinfer/mla/cute_dsl/mla_decode.py):
    Add enable_pdl: Optional[bool] = None to cute_dsl_mla_decode; resolve the runtime value and pass enable_pdl into _get_compiled_mla_kernel; add enable_pdl: bool to _get_compiled_mla_kernel so the kernel cache key includes the flag.
  • Kernel Implementations, FP16/FP8 (flashinfer/mla/cute_dsl/mla_decode_fp16.py, flashinfer/mla/cute_dsl/mla_decode_fp8.py):
    Add an enable_pdl constructor parameter stored on instances; pass use_pdl=self.enable_pdl into kernel launches; insert cute.arch.griddepcontrol_wait() / cute.arch.griddepcontrol_launch_dependents() guarded by cutlass.const_expr(self.enable_pdl) around the split-KV and reduction regions; remove module CLI entrypoints and argparse imports; expose enable_pdl through the public run(...).
  • Tests (tests/attention/test_cute_dsl_mla_decode.py):
    Parametrize tests over enable_pdl (True/False) and propagate the flag into the cute_dsl_mla_decode and trtllm_batch_decode_with_kv_cache_mla calls.

Sequence Diagram(s)

sequenceDiagram
    participant API as trtllm_batch_decode_with_kv_cache_mla
    participant DeviceCheck as device_support_pdl
    participant Wrapper as cute_dsl_mla_decode
    participant Kernel as MLA Kernel (FP16/FP8)
    participant GPU as Device

    API->>DeviceCheck: if enable_pdl is None -> query device
    DeviceCheck-->>API: boolean enable_pdl
    API->>Wrapper: call cute_dsl_mla_decode(..., enable_pdl)
    Wrapper->>Kernel: construct/lookup kernel(enable_pdl)
    Kernel->>GPU: launch kernels with use_pdl=enable_pdl
    GPU-->>Kernel: execute (PDL sync paths gated by enable_pdl)
    Kernel-->>Wrapper: outputs
    Wrapper-->>API: decoded output

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers

  • cyx-6
  • bkryu
  • aleozlx

Poem

🐰 I hopped from API down to kernel light,
threaded enable_pdl snug and tight,
waits and launches now align,
kernels hum in proper time,
a little rabbit's engineering delight.

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check: ✅ Passed. The title clearly describes the main change: adding PDL support for the CuTe DSL MLA decode kernels, which aligns with the file modifications across the codebase.
  • Description check: ✅ Passed. The PR description follows the template with a clear description section (including performance images), a completed pre-commit checklist, and confirmed test completion, though the related-issues and reviewer-notes sections remain unfilled.
  • Docstring coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for Programmatic Dependent Launch (PDL) within the cute-dsl backend for MLA decode kernels, covering both FP16 and FP8 implementations. The changes include updating the core API to handle the enable_pdl flag, integrating grid dependency controls (griddepcontrol_wait and griddepcontrol_launch_dependents) into the CUDA kernels, and updating the test suite to validate these new paths. Feedback focuses on critical race conditions identified in the placement of dependency launch calls within the kernels, which should be moved to the correction warp blocks to ensure data integrity. Additionally, further parameterization of the existing tests is recommended to ensure full coverage of the PDL feature.

I am having trouble creating individual review comments, so my feedback is listed below.

flashinfer/mla/cute_dsl/mla_decode_fp16.py (1199-1200)

critical

The griddepcontrol_launch_dependents() call is misplaced. It is currently in the mma_warp block, but the reduction_kernel depends on the output of the correction warps (mAccO and mAccLSE). Launching dependents from the MMA warp can lead to a race condition where the reduction_kernel starts before the correction warps have finished writing their data.

This call should be moved to the end of the correction_warp_ids block, after all work tiles have been processed, and should be conditional on mAccO is not None (i.e., when split_kv > 1 and the reduction kernel is actually launched).

Please remove these lines and add the following code at the end of the if warp_idx >= self.correction_warp_ids[0] and warp_idx <= self.correction_warp_ids[-1]: block (around line 1332):

if cutlass.const_expr(self.enable_pdl):
    if mAccO is not None:
        cute.arch.griddepcontrol_launch_dependents()

flashinfer/mla/cute_dsl/mla_decode_fp8.py (1263-1264)

critical

The griddepcontrol_launch_dependents() call is misplaced, which can cause a race condition. It should be called by the correction warps, which produce the data for the dependent reduction_kernel, not by the mma_warp.

This should be moved to the end of the correction_warp_ids block, after all work tiles have been processed, and should be conditional on mAccO is not None (i.e., when split_kv > 1 and the reduction kernel is actually launched). A similar issue exists in mla_decode_fp16.py.

Please remove these lines and add the corresponding logic at the end of the correction_warp_ids block.

tests/attention/test_cute_dsl_mla_decode.py (192)

medium

To ensure full test coverage for the new enable_pdl feature, this test should also be parameterized for both True and False values of enable_pdl, similar to other tests in this file. Currently, it only runs with enable_pdl=False.

Please add the following decorator before this function:

@pytest.mark.parametrize("enable_pdl", [True, False])
def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, enable_pdl, page_size=128):

tests/attention/test_cute_dsl_mla_decode.py (274)

medium

To ensure full test coverage for the new enable_pdl feature, this test should also be parameterized for both True and False values of enable_pdl, similar to other tests in this file. Currently, it only runs with enable_pdl=False.

Please add the following decorator before this function:

@pytest.mark.parametrize("enable_pdl", [True, False])
def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, enable_pdl, page_size=128):

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/attention/test_cute_dsl_mla_decode.py (1)

332-389: ⚠️ Potential issue | 🟡 Minor

enable_pdl is parametrized here but never reaches the cute-dsl call.

Both parametrized cases currently execute the same backend configuration, so this parity test does not actually cover the new PDL path and just duplicates runtime.

🐛 Forward the flag into the cute-dsl backend
     out_cute_dsl = trtllm_batch_decode_with_kv_cache_mla(
-        **common_args, backend="cute-dsl", is_var_seq=False
+        **common_args, backend="cute-dsl", is_var_seq=False, enable_pdl=enable_pdl
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_cute_dsl_mla_decode.py` around lines 332 - 389, The
parametrized enable_pdl flag isn't forwarded into the cute-dsl call, so both
branches run identical configs; update the trtllm_batch_decode_with_kv_cache_mla
invocation that produces out_cute_dsl to pass the flag (e.g. add
enable_pdl=enable_pdl or the correct parameter name expected by
trtllm_batch_decode_with_kv_cache_mla) while leaving the trtllm-gen call
unchanged; locate the two calls to trtllm_batch_decode_with_kv_cache_mla and add
the enable_pdl argument to the cute-dsl call.
🧹 Nitpick comments (1)
tests/attention/test_cute_dsl_mla_decode.py (1)

190-246: These two tests still only exercise enable_pdl=False.

Adding a defaulted parameter does not create another pytest case, so the variable-seq and API flows never hit the new PDL branch. Parametrize enable_pdl here the same way as the FP16/FP8 tests.

♻️ Suggested test expansion
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("seq_len_k", [128, 512])
-def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=128, enable_pdl=False):
+@pytest.mark.parametrize("enable_pdl", [False, True])
+def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, enable_pdl, page_size=128):
@@
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("seq_len_k", [128, 512])
-def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128, enable_pdl=False):
+@pytest.mark.parametrize("enable_pdl", [False, True])
+def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, enable_pdl, page_size=128):

Also applies to: 272-325

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_cute_dsl_mla_decode.py` around lines 190 - 246, The test
test_cute_dsl_mla_decode_variable_seq_len currently takes enable_pdl as a
default argument so pytest only runs the False case; change the decorator list
to parametrize enable_pdl (e.g. add
@pytest.mark.parametrize("enable_pdl", [False, True]) alongside the existing
batch_size and seq_len_k parametrizations) so both PDL branches run, and make
the same update to the other test covering lines 272-325 (the companion MLA
decode test) to parametrize enable_pdl there as well; keep the argument name
enable_pdl in the function signature so the runner injects the values.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/mla/cute_dsl/mla_decode_fp16.py`:
- Around line 716-717: The PDL (dependent launch) is being triggered too early
(right after tmem.free()) causing reduction_kernel() to consume
partially-written accumulators mAccLSE/mAccO before epilogue() finishes; locate
the producer-side dependent-launch trigger(s) associated with
use_pdl/self.enable_pdl (occurrences around the reduction launch sites
referenced) and move each trigger to after the final global stores performed by
the epilogue/correction path (or alternatively replace the trigger with an
explicit signal emitted from epilogue() after the final stores); ensure changes
touch the reduction launch locations, the tmem.free() site, and the
epilogue()/correction path so the reduction_kernel() only starts after
mAccLSE/mAccO are fully written (apply same fix to the other occurrences noted).
- Line 173: The new required parameter enable_pdl in
BlackwellMultiHeadLatentAttentionForwardFP16 breaks callers; either make
enable_pdl optional with a default (e.g., default to False) on the
BlackwellMultiHeadLatentAttentionForwardFP16 constructor signature or propagate
the flag from the module entrypoint by adding an enable_pdl parameter to
run()/the CLI and passing it into the
BlackwellMultiHeadLatentAttentionForwardFP16(...) instantiation (update its call
site around line where run() constructs the object). Ensure the symbol
BlackwellMultiHeadLatentAttentionForwardFP16 and the run()/CLI entrypoint are
updated consistently so existing local examples that construct the object
without enable_pdl continue to work.

In `@flashinfer/mla/cute_dsl/mla_decode_fp8.py`:
- Around line 169-170: The new required parameter enable_pdl breaks callers
(notably run() which constructs BlackwellMultiHeadLatentAttentionForwardFP8);
make enable_pdl backward-compatible by giving it a default (e.g., change the
constructor/signature to enable_pdl: bool = False) so existing instantiations
still work, or alternatively propagate a PDL flag from run() and the CLI into
the BlackwellMultiHeadLatentAttentionForwardFP8 constructor so run() (the code
that builds the module) supplies the parameter; update whichever symbol you
choose: the constructor of BlackwellMultiHeadLatentAttentionForwardFP8 and/or
the run() call site that constructs it.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6a987e2e-932b-4512-bad6-099c4304bee7

📥 Commits

Reviewing files that changed from the base of the PR and between 31b63bc and f4af5db.

📒 Files selected for processing (5)
  • flashinfer/mla/_core.py
  • flashinfer/mla/cute_dsl/mla_decode.py
  • flashinfer/mla/cute_dsl/mla_decode_fp16.py
  • flashinfer/mla/cute_dsl/mla_decode_fp8.py
  • tests/attention/test_cute_dsl_mla_decode.py

@bkryu bkryu added the run-ci label Mar 27, 2026
@bkryu
Collaborator

bkryu commented Mar 27, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !466 has been created, and the CI pipeline #47106799 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47106799: 11/20 passed

yzh119 and others added 2 commits March 27, 2026 23:43
The BlackwellMultiHeadLatentAttentionForward{FP8,FP16} constructors
require enable_pdl but the run() harnesses weren't passing it (mypy
call-arg error). Also follow public API convention in cute_dsl_mla_decode
by defaulting to Optional[bool] = None with device_support_pdl() auto-detection.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
These standalone test/benchmark harnesses are superseded by the public
API in mla_decode.py and proper tests in tests/. Also removes the
now-unused argparse import.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@yzh119
Collaborator

yzh119 commented Mar 28, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !466 has been updated with latest changes, and the CI pipeline #47135584 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/mla/cute_dsl/mla_decode_fp16.py (1)

709-717: Only enable PDL when the reduction kernel is actually present.

When acc_o is None the decode takes the single-kernel path, so split_kv_kernel has no dependent launch to unlock. Gating use_pdl with acc_o is not None avoids paying the PDL bookends in the no-workspace case.

♻️ Minimal change
-            use_pdl=self.enable_pdl,
+            use_pdl=self.enable_pdl and acc_o is not None,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla/cute_dsl/mla_decode_fp16.py` around lines 709 - 717, The
launch call for split_kv_kernel should only enable PDL when the reduction kernel
exists: check the accumulator variable acc_o and set use_pdl to True only if
acc_o is not None; update the launch invocation for split_kv_kernel (the
.launch(...) call) to gate use_pdl with acc_o is not None so the single-kernel
(acc_o is None) path does not pay PDL bookends.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/mla/cute_dsl/mla_decode_fp8.py`:
- Around line 3553-3554: The function in mla_decode_fp8.py now accepts an
enable_pdl: bool = False parameter but its docstring lacks an entry; add a short
docstring line for enable_pdl (similar style to the existing use_cold_l2 entry)
describing that it is a boolean flag, default False, and explaining what
enabling PDL does (e.g., toggles PDL-specific decoding/optimizations) and any
behavioral impact; place this new docstring entry with the other parameter
descriptions in the function's docstring so it’s consistent and easy to
discover.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 58ecae4a-8ef0-4c6e-9c73-a8bb9cad3fe9

📥 Commits

Reviewing files that changed from the base of the PR and between c66d55a and d732387.

📒 Files selected for processing (3)
  • flashinfer/mla/cute_dsl/mla_decode.py
  • flashinfer/mla/cute_dsl/mla_decode_fp16.py
  • flashinfer/mla/cute_dsl/mla_decode_fp8.py

Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (2)
flashinfer/mla/cute_dsl/mla_decode_fp16.py (1)

1197-1200: ⚠️ Potential issue | 🔴 Critical

PDL dependents are still released before split-kernel outputs are finalized.

Line 1199 triggers dependents from the MMA warp immediately after TMEM free, but reduction waits (Line 1380) and then consumes mAccLSE/mAccO that are completed by correction/epilogue later. This can race when enable_pdl=True.

💡 Suggested fix (move trigger to post-epilogue completion path)
@@
             tmem.relinquish_alloc_permit()
             tmem.free(tmem_ptr)
-            if cutlass.const_expr(self.enable_pdl):
-                cute.arch.griddepcontrol_launch_dependents()
@@
         if (
             warp_idx >= self.correction_warp_ids[0]
             and warp_idx <= self.correction_warp_ids[-1]
         ):
@@
                 tile_sched.advance_to_next_work()
                 work_tile = tile_sched.get_current_work()
+
+            # Release dependent reduction only after correction/epilogue writes are done.
+            if cutlass.const_expr(self.enable_pdl):
+                cute.arch.griddepcontrol_launch_dependents()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla/cute_dsl/mla_decode_fp16.py` around lines 1197 - 1200, The PDL
dependent trigger (cute.arch.griddepcontrol_launch_dependents()) is fired too
early right after tmem.free() while reduction/epilogue still produce/consume
mAccLSE and mAccO; move the call out of the immediate TMEM cleanup path and into
the post-epilogue completion path (after the reduction/correction that finalizes
mAccLSE/mAccO) so dependents are launched only when outputs are fully finalized;
keep the conditional cutlass.const_expr(self.enable_pdl) check and ensure the
call is placed after the routine that performs the reduction/epilogue completion
(the code that consumes mAccLSE/mAccO).
flashinfer/mla/cute_dsl/mla_decode_fp8.py (1)

3553-3554: ⚠️ Potential issue | 🟡 Minor

Add docstring for enable_pdl parameter.

The new enable_pdl parameter is missing from the function's docstring (lines 3556-3616). For consistency with other parameters like use_cold_l2, please add documentation.

📝 Proposed docstring addition (insert after line 3612)
     :param use_cold_l2: Whether to use cold L2 cache
     :type use_cold_l2: bool
+    :param enable_pdl: Whether to enable Programmatic Dependent Launch (PDL) for kernel scheduling
+    :type enable_pdl: bool
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla/cute_dsl/mla_decode_fp8.py` around lines 3553 - 3554, The
function signature includes a new boolean parameter enable_pdl but it is missing
from the function docstring; add a short docstring entry for enable_pdl
(mirroring style used for use_cold_l2) describing its type, default and behavior
(e.g., "enable_pdl (bool, optional): Whether to enable PDL decoding. Defaults to
False.") and place it with the other parameter descriptions in the function's
docstring so it appears alongside use_cold_l2 and the other params.
🧹 Nitpick comments (1)
flashinfer/mla/cute_dsl/mla_decode_fp8.py (1)

1263-1264: PDL griddepcontrol_launch_dependents() timing is suboptimal but not a functional correctness issue.

The call at line 1264 signals dependent kernels before the correction warp finishes writing to mAccO and mAccLSE. While this could cause the reduction_kernel to launch earlier than necessary, griddepcontrol_launch_dependents() is documented as a performance hint only ("doesn't impact the functionality"), and the actual synchronization is enforced by the GPU memory model and the reduction_kernel's explicit griddepcontrol_wait() call.

For better performance, consider moving this call to after the correction warp completes its writes, or adding explicit synchronization before signaling. This will prevent the dependent kernel from launching prematurely and competing for resources while the output data is still being written.
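The ordering concern above can be illustrated with a host-side analogy: the consumer's griddepcontrol_wait() blocks until the producer signals, so the signal should come only after the producer's final writes. The threading sketch below is an analogy only, not GPU code; the class and method names are invented to mirror the wait/launch_dependents pair.

```python
import threading

class PdlHandshake:
    """Host-side analogy of the PDL wait / launch-dependents handshake."""

    def __init__(self) -> None:
        self._ready = threading.Event()
        self.log: list[str] = []

    def launch_dependents(self) -> None:
        # Producer side: signal only after outputs are fully written,
        # mirroring griddepcontrol_launch_dependents() placed after the
        # correction warp's last stores.
        self.log.append("producer: outputs written")
        self._ready.set()

    def wait(self) -> None:
        # Consumer side: block until the producer has signaled,
        # mirroring griddepcontrol_wait() at the top of the reduction.
        self._ready.wait()
        self.log.append("consumer: safe to read")
```

Signaling before the final writes would let the consumer observe partial data in this analogy; on the GPU the wait still enforces correctness, but an early signal wastes the overlap the hint is meant to buy.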

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla/cute_dsl/mla_decode_fp8.py` around lines 1263 - 1264, The PDL
hint call cute.arch.griddepcontrol_launch_dependents() is placed before the
correction warp finishes writing outputs (mAccO and mAccLSE), which can let
reduction_kernel start too early; move the griddepcontrol_launch_dependents()
call to immediately after the correction-warp write sequence (or insert an
explicit synchronization/fence right before the launch call) so the correction
warp's writes complete before signaling dependents; update the branch that
checks cutlass.const_expr(self.enable_pdl) to perform the launch only after the
correction-warp completion (or add a gridfence/memory barrier prior to
cute.arch.griddepcontrol_launch_dependents()).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e7e74280-b0e8-4a28-91a6-4d7c2b0265e2

📥 Commits

Reviewing files that changed from the base of the PR and between d732387 and eaed46b.

📒 Files selected for processing (2)
  • flashinfer/mla/cute_dsl/mla_decode_fp16.py
  • flashinfer/mla/cute_dsl/mla_decode_fp8.py

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47135584: 13/20 passed

@yzh119 yzh119 enabled auto-merge (squash) March 29, 2026 04:26
@yzh119 yzh119 merged commit 779c24d into flashinfer-ai:main Mar 29, 2026
47 of 59 checks passed
