
feat: Support unpadded output hidden size for trtllm_fp4_block_scale_moe#2217

Merged
yzh119 merged 3 commits into flashinfer-ai:main from elvischenv:elvischenv/support_moe_output_hidden_size
Dec 16, 2025

Conversation

@elvischenv
Contributor

@elvischenv elvischenv commented Dec 14, 2025

📌 Description

Support an unpadded output hidden size for trtllm_fp4_block_scale_moe by passing an output tensor with the unpadded hidden size, so that we no longer need to slice the output after MoE.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • Corrected output dimension handling in MoE operations to use actual output tensor shapes rather than assumed values.
    • Enhanced validation for pre-allocated output tensors with stricter shape and dtype checks.
    • Improved consistency between different code execution paths for MoE operations.


@coderabbitai
Contributor

coderabbitai bot commented Dec 14, 2025

Walkthrough

The changes modify the FP4 block-scale MoE kernel and Python wrapper to derive output dimensions directly from the actual output tensor shape rather than internal assumptions. Validation checks are added to pre-allocated output tensors to enforce shape and dtype consistency.

Changes

  • CUDA Kernel Update — csrc/trtllm_fused_moe_kernel_launcher.cu: Modified the FP4 block-scale MoE path to set hidden_size_output from output.size(1) instead of the internal hidden_size, aligning the reported output dimension with the actual tensor shape.
  • Python Wrapper Validation — flashinfer/fused_moe/core.py: Added check_shape_dtype_device validation for explicit output tensors in trtllm_fp4_block_scale_moe_op, enforcing shape constraints. Updated _fake_trtllm_fp4_block_scale_moe to derive hidden_size from output.shape[1] when output is provided, aligning the fake path with the actual allocation.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10–15 minutes

  • Verify that output.size(1) correctly reflects the intended output dimension in all calling contexts
  • Confirm that validation constraints in check_shape_dtype_device align with the kernel's actual requirements and don't introduce unexpected failures for valid use cases
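The new wrapper-side checks can be mirrored as a small pure-Python sketch. This is illustrative only: `validate_output` and its dtype-string argument are hypothetical stand-ins for the real `check_shape_dtype_device` helper plus the added asserts, not the FlashInfer API.

```python
def validate_output(output_shape, output_dtype, num_tokens, hidden_size,
                    expected_dtype="bfloat16"):
    """Sketch of the wrapper's checks on a pre-allocated output tensor:
    it must be 2-D [num_tokens, out_hidden], bf16, and out_hidden may be
    smaller than (or equal to) the padded hidden_size.
    Illustrative names, not the real FlashInfer helper."""
    assert output_dtype == expected_dtype, (
        f"output dtype must be {expected_dtype}, got {output_dtype}"
    )
    assert len(output_shape) == 2, "output must be 2D"
    assert output_shape[0] == num_tokens, (
        f"output.shape[0]={output_shape[0]} must be equal to {num_tokens}"
    )
    assert output_shape[1] <= hidden_size, (
        f"output.shape[1]={output_shape[1]} must be <= {hidden_size}"
    )
    return True

# An unpadded output (e.g. 4096 columns out of a padded 4608) passes:
validate_output((8, 4096), "bfloat16", num_tokens=8, hidden_size=4608)
```

An output wider than the padded hidden size would trip the `<=` assertion, which is exactly the failure mode the review asks to double-check against valid use cases.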

Poem

🐰 A rabbit hops through tensors bright,
Dimensions aligned, now perfect and right,
No hidden assumptions in the MoE way,
Output shapes dance in the light of day! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed — The title clearly summarizes the main change: adding support for unpadded output hidden size in trtllm_fp4_block_scale_moe, which aligns with the changeset modifications.
  • Description check ✅ Passed — The description explains what the PR does and why, though the Related Issues section is empty and the Tests section remains unchecked, indicating tests may not have been added.

@gemini-code-assist
Contributor

Summary of Changes

Hello @elvischenv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement to the trtllm_fp4_block_scale_moe operation by enabling direct support for unpadded output hidden sizes. This change streamlines the MoE workflow by allowing the kernel to produce an output tensor with the exact required dimensions, thereby removing the necessity for subsequent slicing operations and potentially improving overall performance and simplifying downstream processing.

Highlights

  • Unpadded Output Hidden Size Support: The trtllm_fp4_block_scale_moe function now directly supports an unpadded output hidden size, eliminating the need for manual slicing after the MoE operation.
  • Kernel Argument Update: The hidden_size_output argument for the MoE kernel is now dynamically set based on the provided output tensor's actual hidden dimension, rather than a potentially padded hidden_size.
  • Output Tensor Validation: Added validation checks for the output tensor's shape and data type when an unpadded output is utilized, ensuring its first dimension matches num_tokens and its second dimension is less than or equal to the full hidden_size.
  • Fake MoE Function Update: The _fake_trtllm_fp4_block_scale_moe function has been updated to correctly determine the hidden_size for its output based on whether an unpadded output tensor is provided.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for unpadded output hidden sizes in trtllm_fp4_block_scale_moe, which is a valuable optimization to avoid unnecessary slicing operations. The changes made to the C++ kernel launcher and the Python operator are well-implemented and align with the stated goal. I've identified a minor issue in the corresponding "fake" operator implementation, which could result in incorrect behavior under specific conditions. My detailed feedback is provided in the review comments.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
flashinfer/fused_moe/core.py (1)

1756-1765: Consider additional validation for edge cases.

The validation logic correctly ensures the output tensor has the right dtype, device, and shape constraints. However, consider adding checks for edge cases:

  1. Dimension check: Verify output.ndim == 2 before accessing output.shape[0] and output.shape[1] to provide a clearer error message if the wrong dimensionality is provided.

  2. Minimum size check: Consider validating that output.shape[1] > 0 to catch cases where a zero-sized output dimension is provided, which may not be meaningful for MoE operations.

Apply this diff to add dimension validation:

```diff
 else:
     check_shape_dtype_device(
         output, None, torch.bfloat16, hidden_states.device, "output"
     )
+    assert output.ndim == 2, (
+        f"output must be 2D, got {output.ndim}D"
+    )
     assert output.shape[0] == num_tokens, (
         f"output.shape[0]={output.shape[0]} must be equal to {num_tokens}"
     )
     assert output.shape[1] <= hidden_size, (
         f"output.shape[1]={output.shape[1]} must be less than or equal to {hidden_size}"
     )
+    assert output.shape[1] > 0, (
+        f"output.shape[1]={output.shape[1]} must be greater than 0"
+    )
```
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 1ac4e1d and 59bd2bb.

📒 Files selected for processing (2)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (1 hunks)
  • flashinfer/fused_moe/core.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/fused_moe/core.py (1)
flashinfer/utils.py (1)
  • check_shape_dtype_device (565-583)
🔇 Additional comments (2)
flashinfer/fused_moe/core.py (1)

1912-1912: LGTM! Fake implementation now consistent with actual behavior.

This change correctly derives hidden_size from the provided output tensor when available, ensuring the fake implementation (used for tracing/compilation) matches the actual kernel behavior. This maintains consistency with the C++ changes that use output.size(1) for hidden_size_output.
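The fake-path rule described above can be mirrored in isolation. This is a minimal sketch with an illustrative function name, assuming shapes are plain tuples rather than torch tensors:

```python
def fake_out_hidden(hidden_states_shape, output_shape=None):
    """Mirror of the fake op's new rule: when a pre-allocated output is
    provided, its width (output.shape[1]) wins; otherwise fall back to
    the width of hidden_states. Illustrative name, not the real op."""
    return hidden_states_shape[1] if output_shape is None else output_shape[1]

# No explicit output: the (padded) hidden_states width is used.
no_output = fake_out_hidden((8, 4608))        # -> 4608
# Unpadded output provided: its narrower width is used instead.
with_output = fake_out_hidden((8, 4608), (8, 4096))  # -> 4096
```

This keeps the shapes produced during tracing/compilation consistent with what the kernel actually writes.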

csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1748-1748: The kernel implementation correctly respects the hidden_size_output parameter. The finalize kernel uses hiddenDim (set from hidden_size_output) for output buffer calculations and loop bounds, separate from hiddenDimPadded (the internal padded dimension). Output writes are bounded by numElemsInCol = params.hiddenDim / FINALIZE_ELEM_PER_THREAD, ensuring no buffer overflows occur when output.size(1) < hidden_size.
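The bounded-write behavior described above can be modeled in a few lines of Python. This is a toy model, not the CUDA finalize kernel: intermediate rows are `hidden_dim_padded` wide, but writes into the output are bounded by `hidden_dim` (the stand-in for `hidden_size_output` set from `output.size(1)`), so a narrower output buffer is never overrun.

```python
def finalize_writes(padded_rows, hidden_dim, hidden_dim_padded):
    """Toy model of the finalize step: copy only the first hidden_dim
    columns of each padded row into an unpadded output buffer."""
    assert hidden_dim <= hidden_dim_padded
    assert all(len(row) == hidden_dim_padded for row in padded_rows)
    output = [[0] * hidden_dim for _ in padded_rows]  # unpadded buffer
    for r, row in enumerate(padded_rows):
        for c in range(hidden_dim):   # loop bound = hidden_dim, not padded width
            output[r][c] = row[c]     # columns >= hidden_dim are dropped
    return output

padded = [[10, 11, 12, 0, 0], [20, 21, 22, 0, 0]]  # padded width 5
out = finalize_writes(padded, hidden_dim=3, hidden_dim_padded=5)  # -> width-3 rows
```

The padding columns (here zeros) simply never reach the output, which is the effect the kernel achieves by looping over `hiddenDim` rather than `hiddenDimPadded`.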

@yzh119
Collaborator

yzh119 commented Dec 15, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !196 has been created, and the CI pipeline #40201089 is currently running. I'll report back once the pipeline job completes.

@yzh119
Collaborator

yzh119 commented Dec 15, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !196 has been updated with latest changes, and the CI pipeline #40207125 is currently running. I'll report back once the pipeline job completes.

```diff
 ):
     seq_len = hidden_states.shape[0]
-    hidden_size = hidden_states.shape[1]
+    hidden_size = hidden_states.shape[1] if output is None else output.shape[1]
```
Collaborator

Hi @elvischenv, just to make sure my understanding is right: does this mean hidden_states.shape[1] is the effective hidden dimension, while output.shape[1] could be the padded hidden dimension?

Contributor Author

I don't think so. hidden_states.shape[1] is already padded; the input should be padded before the MoE kernel. Before this PR, the API had no information about the original unpadded hidden dim, and the output would have the same shape as hidden_states. We then had to slice the output down to the original unpadded hidden size, which involves extra overhead.

This PR allows us to pass an output with the unpadded hidden dim; the kernel then writes the results within that dim by setting args->hidden_size_output = output.size(1);, so the slice is no longer needed.
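The before/after workflow described in this thread can be sketched in plain Python. The functions below are illustrative stand-ins for the kernel, not the real op: `moe_padded` models the old behavior (output always padded-width, sliced afterwards) and `moe_unpadded` models the new one (the kernel writes only `out_hidden` columns).

```python
def moe_padded(tokens, hidden_padded):
    """Stand-in for the old kernel: always produces padded-width rows."""
    return [[t * 100 + c for c in range(hidden_padded)] for t in tokens]

def moe_unpadded(tokens, hidden_padded, out_hidden):
    """Stand-in for the new behavior: the kernel writes only out_hidden
    columns (the role of hidden_size_output = output.size(1))."""
    full = moe_padded(tokens, hidden_padded)
    return [row[:out_hidden] for row in full]

tokens = [1, 2]
before = [row[:4] for row in moe_padded(tokens, 6)]  # old: slice after MoE
after = moe_unpadded(tokens, 6, out_hidden=4)        # new: no slice needed
```

Both paths produce the same values; the new path just skips the extra slicing kernel launch.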

Collaborator

Thank you for the explanation!

Comment on lines +1760 to +1765:

```python
assert output.shape[0] == num_tokens, (
    f"output.shape[0]={output.shape[0]} must be equal to {num_tokens}"
)
assert output.shape[1] <= hidden_size, (
    f"output.shape[1]={output.shape[1]} must be less than or equal to {hidden_size}"
)
```
Contributor Author

We add this assertion because the output may be unpadded and thus smaller than the (padded) hidden size.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #40207125: canceled

@yzh119 yzh119 left a comment

The failed UTs are not relevant; let's merge this first.

@yzh119 yzh119 merged commit 02b4c5a into flashinfer-ai:main Dec 16, 2025
4 checks passed
@elvischenv elvischenv deleted the elvischenv/support_moe_output_hidden_size branch December 17, 2025 09:51
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
…moe (flashinfer-ai#2217)

Co-authored-by: yzh119 <zihaoy@nvidia.com>