
CuteDSL MoE: fix redundant output buffer zeroing #2811

Merged
nv-yunzheq merged 4 commits into flashinfer-ai:main from leejnau:cutedsl-fix-redundant-zero
Mar 19, 2026

Conversation

@leejnau
Contributor

@leejnau leejnau commented Mar 18, 2026

📌 Description

The CuteDSL MoE pipeline redundantly zeroed the entire max_num_tokens output buffer before each GEMM2 scatter-add, costing ~3.7 ms/fwd across 61 layers in DeepSeek R1. This PR replaces it with a dense zero of only the active [:num_tokens] slice, overlapped with GEMM1 on an auxiliary stream — matching TRT-LLM's original zeroing strategy.
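As a rough illustration of the contract this relies on (names and shapes invented here, not the actual FlashInfer code), GEMM2's finalize scatter-adds expert contributions into output rows, so only the rows that will actually receive contributions need to start at zero:

```python
# Hypothetical sketch: a preallocated max_num_tokens output buffer holds
# stale values; scatter-add accumulation only needs the active rows zeroed.
MAX_NUM_TOKENS = 8
HIDDEN = 4

# Pre-allocated buffer sized for the CUDA-graph maximum; stale values remain.
out = [[9.0] * HIDDEN for _ in range(MAX_NUM_TOKENS)]

num_tokens = 3  # active tokens this step

# Old behavior: zero all MAX_NUM_TOKENS rows (redundant work).
# New behavior: zero only the active [:num_tokens] slice.
for row in out[:num_tokens]:
    for j in range(HIDDEN):
        row[j] = 0.0

# Scatter-add (token_id, contribution) pairs, standing in for the atomic
# accumulates that GEMM2 finalize performs.
contributions = [(0, 1.0), (1, 2.0), (0, 0.5), (2, 3.0)]
for token_id, val in contributions:
    for j in range(HIDDEN):
        out[token_id][j] += val

assert out[0][0] == 1.5  # accumulated correctly because the slice was zeroed
assert out[3][0] == 9.0  # rows beyond num_tokens are never touched
```

Had the active rows kept their stale values, every accumulated result would be offset by the residual, which is the silent-corruption risk the docstring changes call out.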

🔍 Related Issues

feat: cuteDSL fp4 moe for better DSR1 performance.

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor
    • Improved memory management and per-batch output slicing for CUDA-graph execution; allocation/zeroing now targets only the active output slice.
  • Bug Fixes / Reliability
    • Added validation of provided output buffers; callers must supply zero-initialized preallocated outputs to avoid corruption (runtime no longer implicitly zeroes caller buffers).
  • Documentation
    • Clarified output ownership, zero-initialization contract, and atomic-accumulate behavior in user-facing docs.

@leejnau leejnau requested a review from nv-yunzheq as a code owner March 18, 2026 00:26
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly optimizes the CuteDSL Mixture-of-Experts (MoE) pipeline by addressing an inefficiency in output buffer management. The change focuses on refining how output buffers are zeroed, moving from a broad, redundant operation to a targeted, efficient approach that zeros only the necessary portion of the buffer. This optimization is crucial for improving the forward pass performance of large language models by reducing computational overhead.

Highlights

  • Performance Optimization: Eliminated redundant zeroing of the entire output buffer in the CuteDSL MoE pipeline, replacing it with a dense zero of only the active slice, which is overlapped with GEMM1 on an auxiliary stream. This change is expected to yield significant performance improvements, specifically reducing processing time by approximately 3.7 ms/fwd across 61 layers in DeepSeek R1.
  • Refactored Output Buffer Handling: Modified the blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4 function to no longer zero the output buffer internally if it's provided by the caller. Instead, the responsibility for zero-initialization is explicitly shifted to the caller, allowing for more precise control and optimization.
  • CUDA Graph Compatibility and Efficiency: Enhanced CUDA graph compatibility by ensuring that when a pre-allocated output buffer is used, only the active num_tokens slice is passed to the core MoE implementation. This prevents unnecessary zeroing of the entire max_num_tokens buffer, aligning with the optimized zeroing strategy.
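The overlap described above can be sketched on the CPU with a worker thread standing in for the auxiliary CUDA stream (illustrative only; the real pipeline uses a CUDA side stream with event synchronization, and all names below are invented):

```python
import threading

# CPU stand-in for the aux-stream overlap: zero the active output slice
# while "GEMM1" runs, then synchronize before "GEMM2" scatter-adds.
NUM_TOKENS, HIDDEN = 3, 4
out = [[7.0] * HIDDEN for _ in range(8)]  # preallocated max_num_tokens buffer

zero_done = threading.Event()

def zero_active_slice():
    # "Aux stream": zero only the rows GEMM2 will scatter-add into.
    for row in out[:NUM_TOKENS]:
        for j in range(HIDDEN):
            row[j] = 0.0
    zero_done.set()

aux = threading.Thread(target=zero_active_slice)
aux.start()

# Meanwhile on the main "stream": GEMM1 produces its result (faked here).
gemm1_result = [[1.0] * HIDDEN for _ in range(NUM_TOKENS)]

zero_done.wait()  # analogous to a stream-wait-event before GEMM2
aux.join()

# "GEMM2" finalize: accumulate into the freshly zeroed active slice.
for t in range(NUM_TOKENS):
    for j in range(HIDDEN):
        out[t][j] += gemm1_result[t][j]

assert all(v == 1.0 for v in out[0])           # zeroed, then accumulated
assert all(v == 7.0 for v in out[NUM_TOKENS])  # inactive rows untouched
```

The point of the overlap is that the zeroing cost hides behind GEMM1's runtime instead of adding to the critical path.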


@coderabbitai
Contributor

coderabbitai bot commented Mar 18, 2026

📝 Walkthrough

Walkthrough

Changes shift zero-initialization responsibility for MoE outputs from internal code to the caller, add validation that provided output views are sliced to num_tokens, and adjust buffer-slicing behavior for CUDA-graph and async zeroing paths; comments and docstrings are expanded to cover atomic accumulation and corruption risks.

Changes

  • MoE Finalize Fusion Buffer Contract — flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
    Docstring and comments updated to state that if out is provided it must be zero-initialized by the caller; removed code that forcibly zeroed a provided out. Expanded notes on atomic-accumulate behavior and the corruption risk from non-zero residuals.
  • MoE Wrapper & Buffer Management — flashinfer/fused_moe/cute_dsl/fused_moe.py
    Removed the moe_output_memset import and usage; replaced it with explicit moe_output.zero_() calls. Added an assert that a provided moe_output is sliced to num_tokens. Adjusted handling to use _moe_output[:num_tokens] for CUDA-graph paths and to zero only the active slice (async and non-async variants).
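The caller-side shape contract described above can be sketched as a small helper (hypothetical, not the actual fused_moe.py code; the row-count and hidden-size checks mirror the validation discussed in the review below):

```python
def prepare_moe_output(buf_shape, num_tokens, hidden_size):
    """Validate a preallocated output buffer and return the active slice bounds.

    Hypothetical helper illustrating the contract: callers may hand in a
    buffer sized for max_num_tokens; only the [:num_tokens] view is used
    (and must be zeroed) for this batch.
    """
    rows, cols = buf_shape
    if rows < num_tokens or cols != hidden_size:
        raise ValueError(
            f"moe_output must have shape [>= {num_tokens}, {hidden_size}], "
            f"got {buf_shape}"
        )
    return slice(0, num_tokens)

# A max-sized (8, 4096) buffer is acceptable for a 3-token batch:
s = prepare_moe_output((8, 4096), num_tokens=3, hidden_size=4096)
assert (s.start, s.stop) == (0, 3)
```

A too-small buffer or a mismatched hidden dimension fails loudly at this point rather than surfacing later as a CUDA-side error in finalize.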

Sequence Diagram(s)

sequenceDiagram
    participant Caller as Caller
    participant Wrapper as FusedMoE
    participant Finalize as FinalizeFusion
    participant Stream as AsyncStream

    rect rgba(200,230,255,0.5)
    Caller->>Wrapper: forward(inputs, optional moe_output)
    end

    rect rgba(200,255,200,0.5)
    Wrapper->>Wrapper: assert moe_output is sliced to num_tokens (if provided)
    Wrapper->>Wrapper: use _moe_output[:num_tokens] for this batch
    Wrapper->>Finalize: call finalize with moe_output slice (or None)
    end

    alt moe_output provided and async zero path
        Wrapper->>Stream: schedule zero on active slice (async)
        Stream-->>Finalize: zeroing happens on aux stream before finalize
    else moe_output provided and non-async
        Wrapper->>Finalize: call moe_output.zero_() on active slice
    end

    Finalize->>moe_output: atomic accumulates into moe_output slice
    Note right of Finalize: Non-zero residuals in provided moe_output corrupt results

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

v0.6.3

Suggested reviewers

  • wenscarl
  • djmmoss
  • cyx-6
  • yzh119

🐰 I used to hop and zero all night,
Now callers bring candles — neat and bright.
Slice me small, keep residues at bay,
Atomic hops land where zeros lay.
A tiny change, a clearer way.

🚥 Pre-merge checks — 3 passed
  • Title check ✅ Passed — The title clearly describes the main change: replacing redundant full-buffer zeroing with selective zeroing of active tokens in the CuteDSL MoE pipeline.
  • Description check ✅ Passed — The description is complete with details on the performance issue, the proposed solution, related issues, and all checklist items properly marked.
  • Docstring Coverage ✅ Passed — Docstring coverage is 100.00%, which meets the required threshold of 80.00%.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request effectively optimizes the CuteDSL MoE pipeline by addressing a redundant buffer zeroing operation. The change to zero only the active [:num_tokens] slice of the output buffer, rather than the entire max_num_tokens buffer, is a significant performance improvement. The responsibility for zeroing is correctly shifted to the caller of the low-level GEMM kernel, and the higher-level wrappers and functional APIs are updated accordingly with new assertions and buffer slicing logic. The accompanying documentation changes are clear and accurately reflect the new API contract. I have one suggestion to enhance an assertion's error message for better debuggability.

Comment on lines +195 to +197
assert moe_output.size(0) == num_tokens, (
    "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
)
Contributor

Severity: medium

This assertion is a valuable safeguard. To improve debuggability when this assertion fails, I recommend enhancing the error message to include the actual and expected tensor sizes. This provides immediate, actionable context to the developer, reducing debugging time.

Suggested change
-assert moe_output.size(0) == num_tokens, (
-    "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
-)
+assert moe_output.size(0) == num_tokens, (
+    f"moe_output has {moe_output.size(0)} rows, but expected {num_tokens}. "
+    "It must be sliced to num_tokens rows before calling _moe_core_impl."
+)

@nv-yunzheq
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !422 has been created, and the CI pipeline #46387257 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 301-305: The public API function
blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4 currently removes
zero-initialization on caller-supplied out buffers, turning legitimate out=
usage into a silent accumulation bug; restore the original overwrite semantics
so when callers pass a non-None out buffer it is always zeroed (or explicitly
documented to be required zeroed), and move the zero-free fast path into an
internal helper (or gate it behind an explicit opt-in flag) used only by
internal callers; update the implementations referenced around the other similar
sites (the same pattern at the blocks around lines 321-325 and 410-420) to
follow the same approach so public entry points preserve overwrite semantics
while internal optimized paths can bypass zeroing when explicitly opted-in.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Around line 185-197: The assert in the _moe_core_impl input handling is too
strict: allow callers to pass a larger preallocated moe_output buffer by
checking moe_output.size(0) >= num_tokens and moe_output.size(1) == hidden_size
(validate hidden dimension and dtype/device if desired), then locally slice
moe_output = moe_output[:num_tokens] before using it; keep the existing
allocation path when moe_output is None (using torch.empty((num_tokens,
hidden_size), dtype=output_dtype, device=x.device)) so the fast path and
CUDA-graph slice semantics remain intact.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5ccca6eb-7414-43bc-8b62-689d2a2f662f

📥 Commits

Reviewing files that changed from the base of the PR and between f7322d9 and 5b8b5e8.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py

Comment on lines +301 to +305
This tensor is used for atomic accumulation. If `out` is
provided, it must already be zero-initialized by the caller.
If `out` is None, this function allocates a zero-initialized
output tensor. Passing a non-zeroed `out` buffer will silently
produce incorrect results.
Contributor

⚠️ Potential issue | 🟠 Major

This turns the public out= path into a silent accumulation trap.

blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4 is still a flashinfer_api, so removing the internal zero on caller-supplied buffers breaks existing out= call sites with wrong answers rather than a loud failure. Please keep the zero-free fast path internal, or gate it behind an explicit opt-in, and preserve overwrite semantics on the public entry point.

Also applies to: 321-325, 410-420


Comment on lines +185 to +197
# Allocate output if not provided. The caller (wrapper or functional
# API) should pass a [:num_tokens] slice of the pre-allocated buffer
# when using CUDA graphs. The buffer is zeroed in Step 3 below.
if moe_output is None:
    moe_output = torch.empty(
        (num_tokens, hidden_size),
        dtype=output_dtype,
        device=x.device,
    )
else:
    assert moe_output.size(0) == num_tokens, (
        "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
    )
Contributor

⚠️ Potential issue | 🟠 Major

Don't require an exact-row moe_output here.

The new assert breaks callers that reuse a larger pre-allocated output buffer, even though the optimization only needs a [:num_tokens] view. Accept size(0) >= num_tokens, validate the hidden dimension, and slice locally so the fast path stays intact without changing the public contract.

Proposed fix
     if moe_output is None:
         moe_output = torch.empty(
             (num_tokens, hidden_size),
             dtype=output_dtype,
             device=x.device,
         )
     else:
-        assert moe_output.size(0) == num_tokens, (
-            "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
-        )
+        if moe_output.size(0) < num_tokens or moe_output.size(1) != hidden_size:
+            raise ValueError(
+                "moe_output must have shape [>= num_tokens, hidden_size]"
+            )
+        moe_output = moe_output[:num_tokens]
+        if not moe_output.is_contiguous():
+            raise ValueError("moe_output[:num_tokens] must be contiguous")

Contributor

@coderabbitai coderabbitai bot left a comment

♻️ Duplicate comments (1)
flashinfer/fused_moe/cute_dsl/fused_moe.py (1)

185-197: ⚠️ Potential issue | 🟠 Major

Validate full moe_output shape before finalize.

Line 195 currently validates only the row count. Please also validate the hidden dimension (size(1) == hidden_size) before launching finalize to avoid late CUDA-side failures with mismatched output buffers.

Suggested patch
     else:
-        assert moe_output.size(0) == num_tokens, (
-            "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
-        )
+        if moe_output.size(0) != num_tokens or moe_output.size(1) != hidden_size:
+            raise ValueError(
+                f"moe_output must have shape [{num_tokens}, {hidden_size}], "
+                f"got {tuple(moe_output.shape)}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 185 - 197, The
current check for moe_output only validates row count; before calling
_moe_core_impl/finalize you must also validate the column dimension matches
hidden_size. Update the assert on moe_output (in the allocation branch) to
require moe_output.size(0) == num_tokens and moe_output.size(1) == hidden_size
(or raise a clear error) so mismatched hidden dimensions are caught early and
avoid CUDA failures in finalize.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4394e3c0-8b29-49fa-b3db-8851f98020fb

📥 Commits

Reviewing files that changed from the base of the PR and between 5b8b5e8 and 7c68930.

📒 Files selected for processing (1)
  • flashinfer/fused_moe/cute_dsl/fused_moe.py

Contributor

@coderabbitai coderabbitai bot left a comment

♻️ Duplicate comments (1)
flashinfer/fused_moe/cute_dsl/fused_moe.py (1)

194-198: ⚠️ Potential issue | 🟠 Major

Relax and harden moe_output validation in _moe_core_impl.

Using assert here is fragile for runtime validation, and the exact-row check rejects valid larger preallocated buffers. Validate [>= num_tokens, hidden_size], slice locally, and raise ValueError for invalid inputs.

Suggested fix
     else:
-        assert moe_output.size(0) == num_tokens, (
-            f"moe_output must be sliced to num_tokens rows before calling "
-            f"_moe_core_impl (got {moe_output.size(0)}, expected {num_tokens})"
-        )
+        if moe_output.size(0) < num_tokens or moe_output.size(1) != hidden_size:
+            raise ValueError(
+                f"moe_output must have shape [>= {num_tokens}, {hidden_size}], "
+                f"got {tuple(moe_output.shape)}"
+            )
+        moe_output = moe_output[:num_tokens]
+        if not moe_output.is_contiguous():
+            raise ValueError("moe_output[:num_tokens] must be contiguous")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 194 - 198, Replace
the fragile assert in _moe_core_impl that checks moe_output.size(0) ==
num_tokens with robust runtime validation: verify moe_output is a 2D tensor
whose first dimension is >= num_tokens and whose second dimension equals
hidden_size, slice a local view (e.g., moe_output_sliced =
moe_output[:num_tokens]) for subsequent use, and raise a ValueError with a clear
message if the shapes don't match (include actual shapes and expected
num_tokens/hidden_size). Ensure you only change validation and create a local
sliced tensor so callers with larger preallocated buffers continue to work.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5e3308b6-df79-4f2c-b1b0-ea7c7b0976d4

📥 Commits

Reviewing files that changed from the base of the PR and between 7c68930 and bb49b24.

📒 Files selected for processing (1)
  • flashinfer/fused_moe/cute_dsl/fused_moe.py

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46387257: 13/20 passed

@nv-yunzheq nv-yunzheq enabled auto-merge (squash) March 18, 2026 22:49
@nv-yunzheq nv-yunzheq merged commit 27b9dc7 into flashinfer-ai:main Mar 19, 2026
31 of 36 checks passed
