[NVIDIA] Fix breakage of using trtllm-gen fp8 moe#8773
zhyncs merged 2 commits into sgl-project:main
Conversation
Summary of Changes
Hello @kaixih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses a breakage in the trtllm-gen fp8 MoE functionality. I've adjusted the expected input to the `forward` methods within the relevant MoE layers to correctly handle the data format required by TRTLLM mode, ensuring the system operates as intended.
Highlights
- Method Signature Update: Modified the `forward` method signature in the `FlashInferEPMoE` and `FlashInferFusedMoE` classes to accept a `topk_output` tuple instead of `router_logits`.
- Input Validation and Extraction: Introduced a check that `topk_output` is a tuple of length 2, then extracts `router_logits` from it, aligning with TRTLLM mode's expectations.
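As a hedged, self-contained sketch of the new calling convention (the function name `unpack_topk_output` is a stand-in for illustration, not the actual sglang method; the real logic lives inside the `forward` methods above):

```python
# Minimal sketch of the updated calling convention. Illustrative only:
# in the real code this validation sits inside FlashInferEPMoE.forward
# and FlashInferFusedMoE.forward.

def unpack_topk_output(topk_output):
    """Validate the (TopK_config, router_logits) tuple and return router_logits."""
    # TRTLLM mode expects a 2-tuple: (TopK_config, router_logits)
    if not isinstance(topk_output, tuple) or len(topk_output) != 2:
        raise ValueError(
            f"expected (TopK_config, router_logits) tuple, got {type(topk_output)}"
        )
    # The first element (the top-k config) is unused here; only the
    # router logits are forwarded to the fused-MoE kernel.
    _, router_logits = topk_output
    return router_logits
```

The practical effect is that callers now pass a `(topk_config, router_logits)` pair rather than the raw `router_logits` tensor.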
Code Review
This pull request addresses a breakage in trtllm-gen fp8 MoE by updating the `forward` method signature in `FlashInferEPMoE` and `FlashInferFusedMoE`. The change correctly adapts to a new API where `topk_output` is passed as a tuple containing the router logits. My main feedback is to refactor the duplicated validation logic into a shared helper function to improve code reuse and maintainability.
```python
# TRTLLM mode expects (TopK_config, router_logits) tuple
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
    raise ValueError(
        f"FlashInferEPMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
    )
_, router_logits = topk_output
```
This logic for validating and unpacking topk_output is duplicated in python/sglang/srt/layers/moe/fused_moe_triton/layer.py. To improve maintainability and reduce code duplication, consider extracting this logic into a shared helper function.
For example, you could create a helper function like this in a utility file or at the module level in fused_moe_triton/layer.py:
```python
def _unpack_topk_output_for_trtllm(topk_output: tuple, class_name: str) -> torch.Tensor:
    """Unpacks topk_output tuple for TRTLLM mode and returns router_logits."""
    if not isinstance(topk_output, tuple) or len(topk_output) != 2:
        raise ValueError(
            f"{class_name} expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
        )
    _, router_logits = topk_output
    return router_logits
```

Then you could call it from both `FlashInferEPMoE.forward` and `FlashInferFusedMoE.forward`.
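For concreteness, here is a hedged sketch of what the call site could look like after such a refactor. The class body is heavily elided and `FlashInferEPMoE` appears only as a stand-in; the real class has many more members, and the kernel dispatch is omitted:

```python
def _unpack_topk_output_for_trtllm(topk_output, class_name):
    """Unpack the (TopK_config, router_logits) tuple for TRTLLM mode."""
    if not isinstance(topk_output, tuple) or len(topk_output) != 2:
        raise ValueError(
            f"{class_name} expects (TopK_config, router_logits) tuple, "
            f"got {type(topk_output)}"
        )
    _, router_logits = topk_output
    return router_logits


class FlashInferEPMoE:  # stand-in; illustration only
    def forward(self, hidden_states, topk_output):
        # Shared helper replaces the duplicated inline validation.
        router_logits = _unpack_topk_output_for_trtllm(
            topk_output, type(self).__name__
        )
        # ... in the real code, dispatch to the TRTLLM fp8 MoE kernel ...
        return router_logits  # placeholder return for this sketch
```

`FlashInferFusedMoE.forward` would call the same helper, so the error message stays consistent across both classes via `type(self).__name__`.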
```python
# TRTLLM mode expects (TopK_config, router_logits) tuple
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
    raise ValueError(
        f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
    )
_, router_logits = topk_output
```
This validation and unpacking logic for topk_output is also present in FlashInferEPMoE.forward. To improve code reuse and maintainability, you could extract this logic into a helper function.
For example, you could define this function in this file:
```python
def _unpack_topk_output_for_trtllm(topk_output: tuple, class_name: str) -> torch.Tensor:
    """Unpacks topk_output tuple for TRTLLM mode and returns router_logits."""
    if not isinstance(topk_output, tuple) or len(topk_output) != 2:
        raise ValueError(
            f"{class_name} expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
        )
    _, router_logits = topk_output
    return router_logits
```

Then, you can simplify the `forward` method here and in `FlashInferEPMoE` by calling this helper.
Talked with @ch-wan and learned that the

The accuracy tests were conducted as well:
Sure. Just ran the fp4 accuracy test with the repro commands from here. @zhyncs ^^
This PR fixes a breakage introduced by an earlier PR.
cc. @kushanam