CuteDSL MoE fix redundant output buffer zeroing #2811
nv-yunzheq merged 4 commits into flashinfer-ai:main
Conversation
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request optimizes the CuteDSL Mixture-of-Experts (MoE) pipeline by addressing an inefficiency in output buffer management: instead of a broad, redundant zeroing operation, only the necessary portion of the output buffer is now zeroed. This reduces computational overhead in the forward pass of large language models.
📝 Walkthrough
Changes shift zero-initialization responsibility for MoE outputs from internal code to the caller, add validation that provided output views are sliced to num_tokens, and adjust buffer-slicing behavior for CUDA-graph and async zeroing paths; comments/docstrings are expanded to cover atomic accumulation and corruption risks.
Sequence Diagram(s)sequenceDiagram
participant Caller as Caller
participant Wrapper as FusedMoE
participant Finalize as FinalizeFusion
participant Stream as AsyncStream
rect rgba(200,230,255,0.5)
Caller->>Wrapper: forward(inputs, optional moe_output)
end
rect rgba(200,255,200,0.5)
Wrapper->>Wrapper: assert moe_output is sliced to num_tokens (if provided)
Wrapper->>Wrapper: use _moe_output[:num_tokens] for this batch
Wrapper->>Finalize: call finalize with moe_output slice (or None)
end
alt moe_output provided and async zero path
Wrapper->>Stream: schedule zero on active slice (async)
Stream-->>Finalize: zeroing happens on aux stream before finalize
else moe_output provided and non-async
Wrapper->>Finalize: call moe_output.zero_() on active slice
end
Finalize->>moe_output: atomic accumulates into moe_output slice
Note right of Finalize: Non-zero residuals in provided moe_output corrupt results
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Suggested labels
Suggested reviewers
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Code Review
This pull request effectively optimizes the CuteDSL MoE pipeline by addressing a redundant buffer zeroing operation. The change to zero only the active [:num_tokens] slice of the output buffer, rather than the entire max_num_tokens buffer, is a significant performance improvement. The responsibility for zeroing is correctly shifted to the caller of the low-level GEMM kernel, and the higher-level wrappers and functional APIs are updated accordingly with new assertions and buffer slicing logic. The accompanying documentation changes are clear and accurately reflect the new API contract. I have one suggestion to enhance an assertion's error message for better debuggability.
```python
assert moe_output.size(0) == num_tokens, (
    "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
)
```
There was a problem hiding this comment.
This assertion is a valuable safeguard. To improve debuggability when this assertion fails, I recommend enhancing the error message to include the actual and expected tensor sizes. This provides immediate, actionable context to the developer, reducing debugging time.
```diff
-assert moe_output.size(0) == num_tokens, (
-    "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
-)
+assert moe_output.size(0) == num_tokens, (
+    f"moe_output has {moe_output.size(0)} rows, but expected {num_tokens}. "
+    "It must be sliced to num_tokens rows before calling _moe_core_impl."
+)
```
/bot run
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 301-305: The public API function
blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4 currently removes
zero-initialization on caller-supplied out buffers, turning legitimate out=
usage into a silent accumulation bug; restore the original overwrite semantics
so when callers pass a non-None out buffer it is always zeroed (or explicitly
documented to be required zeroed), and move the zero-free fast path into an
internal helper (or gate it behind an explicit opt-in flag) used only by
internal callers; update the implementations referenced around the other similar
sites (the same pattern at the blocks around lines 321-325 and 410-420) to
follow the same approach so public entry points preserve overwrite semantics
while internal optimized paths can bypass zeroing when explicitly opted-in.
In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Around line 185-197: The assert in the _moe_core_impl input handling is too
strict: allow callers to pass a larger preallocated moe_output buffer by
checking moe_output.size(0) >= num_tokens and moe_output.size(1) == hidden_size
(validate hidden dimension and dtype/device if desired), then locally slice
moe_output = moe_output[:num_tokens] before using it; keep the existing
allocation path when moe_output is None (using torch.empty((num_tokens,
hidden_size), dtype=output_dtype, device=x.device)) so the fast path and
CUDA-graph slice semantics remain intact.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5ccca6eb-7414-43bc-8b62-689d2a2f662f
📒 Files selected for processing (2)
flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
flashinfer/fused_moe/cute_dsl/fused_moe.py
> This tensor is used for atomic accumulation. If `out` is provided, it must already be zero-initialized by the caller. If `out` is None, this function allocates a zero-initialized output tensor. Passing a non-zeroed `out` buffer will silently produce incorrect results.
This turns the public out= path into a silent accumulation trap.
blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4 is still a flashinfer_api, so removing the internal zero on caller-supplied buffers breaks existing out= call sites with wrong answers rather than a loud failure. Please keep the zero-free fast path internal, or gate it behind an explicit opt-in, and preserve overwrite semantics on the public entry point.
Also applies to: 321-325, 410-420
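The corruption mode the reviewer describes is easy to reproduce with plain-Python accumulation. The sketch below is a hypothetical stand-in for the finalize kernel's atomic adds (the function name and data layout are illustrative, not flashinfer APIs):

```python
def finalize_accumulate(out, contributions):
    # Mimic finalize's accumulate-into-out semantics: each expert
    # contribution is added on top of whatever is already in the buffer
    # (the real kernel does this with GPU atomics).
    for row, value in contributions:
        out[row] += value
    return out

contribs = [(0, 1.0), (0, 2.0), (1, 5.0)]

# Caller supplies a zeroed buffer: result is correct.
print(finalize_accumulate([0.0, 0.0], contribs))   # [3.0, 5.0]

# Caller reuses a buffer with stale residuals: silently wrong, no error raised.
print(finalize_accumulate([7.0, -1.0], contribs))  # [10.0, 4.0]
```

This is why the comment argues that a public `out=` path which stops zeroing caller buffers fails silently rather than loudly.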
```python
# Allocate output if not provided. The caller (wrapper or functional
# API) should pass a [:num_tokens] slice of the pre-allocated buffer
# when using CUDA graphs. The buffer is zeroed in Step 3 below.
if moe_output is None:
    moe_output = torch.empty(
        (num_tokens, hidden_size),
        dtype=output_dtype,
        device=x.device,
    )
else:
    assert moe_output.size(0) == num_tokens, (
        "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
    )
```
Don't require an exact-row moe_output here.
The new assert breaks callers that reuse a larger pre-allocated output buffer, even though the optimization only needs a [:num_tokens] view. Accept size(0) >= num_tokens, validate the hidden dimension, and slice locally so the fast path stays intact without changing the public contract.
Proposed fix:

```diff
 if moe_output is None:
     moe_output = torch.empty(
         (num_tokens, hidden_size),
         dtype=output_dtype,
         device=x.device,
     )
 else:
-    assert moe_output.size(0) == num_tokens, (
-        "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
-    )
+    if moe_output.size(0) < num_tokens or moe_output.size(1) != hidden_size:
+        raise ValueError(
+            "moe_output must have shape [>= num_tokens, hidden_size]"
+        )
+    moe_output = moe_output[:num_tokens]
+    if not moe_output.is_contiguous():
+        raise ValueError("moe_output[:num_tokens] must be contiguous")
```
♻️ Duplicate comments (1)
flashinfer/fused_moe/cute_dsl/fused_moe.py (1)
185-197: ⚠️ Potential issue | 🟠 Major — Validate full `moe_output` shape before finalize.
Line 195 currently validates only the row count. Please also validate the hidden dimension (`size(1) == hidden_size`) before launching finalize, to avoid late CUDA-side failures with mismatched output buffers.
Suggested patch:

```diff
 else:
-    assert moe_output.size(0) == num_tokens, (
-        "moe_output must be sliced to num_tokens rows before calling _moe_core_impl"
-    )
+    if moe_output.size(0) != num_tokens or moe_output.size(1) != hidden_size:
+        raise ValueError(
+            f"moe_output must have shape [{num_tokens}, {hidden_size}], "
+            f"got {tuple(moe_output.shape)}"
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 185 - 197, The current check for moe_output only validates row count; before calling _moe_core_impl/finalize you must also validate the column dimension matches hidden_size. Update the assert on moe_output (in the allocation branch) to require moe_output.size(0) == num_tokens and moe_output.size(1) == hidden_size (or raise a clear error) so mismatched hidden dimensions are caught early and avoid CUDA failures in finalize.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4394e3c0-8b29-49fa-b3db-8851f98020fb
📒 Files selected for processing (1)
flashinfer/fused_moe/cute_dsl/fused_moe.py
♻️ Duplicate comments (1)
flashinfer/fused_moe/cute_dsl/fused_moe.py (1)
194-198: ⚠️ Potential issue | 🟠 Major — Relax and harden `moe_output` validation in `_moe_core_impl`.
Using `assert` here is fragile for runtime validation, and the exact-row check rejects valid larger preallocated buffers. Validate `[>= num_tokens, hidden_size]`, slice locally, and raise `ValueError` for invalid inputs.
Suggested fix:

```diff
 else:
-    assert moe_output.size(0) == num_tokens, (
-        f"moe_output must be sliced to num_tokens rows before calling "
-        f"_moe_core_impl (got {moe_output.size(0)}, expected {num_tokens})"
-    )
+    if moe_output.size(0) < num_tokens or moe_output.size(1) != hidden_size:
+        raise ValueError(
+            f"moe_output must have shape [>= {num_tokens}, {hidden_size}], "
+            f"got {tuple(moe_output.shape)}"
+        )
+    moe_output = moe_output[:num_tokens]
+    if not moe_output.is_contiguous():
+        raise ValueError("moe_output[:num_tokens] must be contiguous")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 194 - 198, Replace the fragile assert in _moe_core_impl that checks moe_output.size(0) == num_tokens with robust runtime validation: verify moe_output is a 2D tensor whose first dimension is >= num_tokens and whose second dimension equals hidden_size, slice a local view (e.g., moe_output_sliced = moe_output[:num_tokens]) for subsequent use, and raise a ValueError with a clear message if the shapes don't match (include actual shapes and expected num_tokens/hidden_size). Ensure you only change validation and create a local sliced tensor so callers with larger preallocated buffers continue to work.
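The relaxed shape check proposed here can be exercised without torch. Below is a hedged sketch of that logic against a minimal 2-D tensor stand-in; `FakeTensor` and `validate_moe_output` are hypothetical illustration names, not flashinfer code:

```python
class FakeTensor:
    # Minimal stand-in for a 2-D torch.Tensor: just enough surface
    # (size, shape, row slicing) to exercise the validation logic.
    def __init__(self, rows, cols):
        self._shape = (rows, cols)

    def size(self, dim):
        return self._shape[dim]

    @property
    def shape(self):
        return self._shape

    def __getitem__(self, key):
        # Emulate row slicing like tensor[:num_tokens].
        rows = range(*key.indices(self._shape[0]))
        return FakeTensor(len(rows), self._shape[1])

def validate_moe_output(moe_output, num_tokens, hidden_size):
    # Accept any buffer with >= num_tokens rows and a matching hidden
    # dimension, then hand back a view of just the live rows.
    if moe_output.size(0) < num_tokens or moe_output.size(1) != hidden_size:
        raise ValueError(
            f"moe_output must have shape [>= {num_tokens}, {hidden_size}], "
            f"got {tuple(moe_output.shape)}"
        )
    return moe_output[:num_tokens]

# A larger preallocated buffer is accepted and sliced to the live rows.
view = validate_moe_output(FakeTensor(8192, 7168), 64, 7168)
print(view.shape)  # (64, 7168)
```

A mismatched hidden dimension (e.g. `FakeTensor(8192, 4096)` against `hidden_size=7168`) raises `ValueError` eagerly instead of failing later inside the kernel launch.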
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5e3308b6-df79-4f2c-b1b0-ea7c7b0976d4
📒 Files selected for processing (1)
flashinfer/fused_moe/cute_dsl/fused_moe.py
[SUCCESS] Pipeline #46387257: 13/20 passed
📌 Description
The CuteDSL MoE pipeline redundantly zeroed the entire max_num_tokens output buffer before each GEMM2 scatter-add, costing ~3.7 ms/fwd across 61 layers in DeepSeek R1. This PR replaces it with a dense zero of only the active [:num_tokens] slice, overlapped with GEMM1 on an auxiliary stream — matching TRT-LLM's original zeroing strategy.
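As a rough back-of-the-envelope check of why this matters, zeroing only the active slice makes the memset cost scale with the live batch instead of the CUDA-graph maximum. The shapes and dtype width below are illustrative assumptions, not measurements from this PR:

```python
def zeroed_bytes(num_tokens, hidden_size, dtype_bytes=2):
    # Bytes cleared when zeroing a [num_tokens, hidden_size] buffer
    # of a 2-byte dtype (e.g. bf16) -- an assumed width for illustration.
    return num_tokens * hidden_size * dtype_bytes

# Hypothetical decode-step shapes: small live batch, large graph buffer.
max_num_tokens, num_tokens, hidden_size = 8192, 64, 7168

full_clear = zeroed_bytes(max_num_tokens, hidden_size)  # old: whole buffer
slice_clear = zeroed_bytes(num_tokens, hidden_size)     # new: live rows only
print(full_clear // slice_clear)  # 128x less data to memset per forward
```

On top of the smaller clear, the PR also overlaps the memset with GEMM1 on an auxiliary stream, so much of even this reduced cost is hidden.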
🔍 Related Issues
feat: cuteDSL fp4 moe for better DSR1 performance.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran the hooks before submitting the PR with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit