
[NVIDIA] Fix breakage of using trtllm-gen fp8 moe#8773

Merged
zhyncs merged 2 commits into sgl-project:main from kaixih:fix_trtllm_moe_bs_fp8
Aug 4, 2025

Conversation

Collaborator

@kaixih kaixih commented Aug 4, 2025

This PR fixes a breakage introduced by a previous PR.

cc. @kushanam

Contributor

@gemini-code-assist bot left a comment


Summary of Changes

Hello @kaixih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a breakage in the trtllm-gen fp8 MoE functionality. It adjusts the expected input to the `forward` methods of the relevant MoE layers to correctly handle the data format required by TRTLLM mode, ensuring the system operates as intended.

Highlights

  • Method Signature Update: Modified the `forward` method signature in the `FlashInferEPMoE` and `FlashInferFusedMoE` classes to accept a `topk_output` tuple instead of `router_logits`.
  • Input Validation and Extraction: Introduced a check that `topk_output` is a tuple of length 2 and extracts `router_logits` from it, aligning with TRTLLM mode's expectations.
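The signature change above can be sketched at the call-site level. This is a minimal, hypothetical illustration (the layer internals, stub return value, and `TopK_config` placeholder are assumptions; only `forward`, `topk_output`, and `router_logits` come from the PR):

```python
# Illustrative sketch of the new forward() contract: callers now pass a
# (TopK_config, router_logits) tuple instead of raw router logits.

class FlashInferFusedMoE:
    def forward(self, hidden_states, topk_output):
        # TRTLLM mode expects (TopK_config, router_logits) tuple
        if not isinstance(topk_output, tuple) or len(topk_output) != 2:
            raise ValueError(
                f"FlashInferFusedMoE expects (TopK_config, router_logits) "
                f"tuple, got {type(topk_output)}"
            )
        _, router_logits = topk_output
        # Stub: the real implementation dispatches to the trtllm-gen MoE
        # kernel; we return the extracted logits just to show the unpacking.
        return router_logits


layer = FlashInferFusedMoE()
# Old call style (now rejected): layer.forward(hidden, router_logits)
# New call style:
logits = layer.forward(hidden_states=None, topk_output=("topk_config", [0.1, 0.9]))
```

Passing anything other than a 2-tuple raises a `ValueError`, which surfaces the old call style as an explicit error instead of a silent shape mismatch.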

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request addresses a breakage in trtllm-gen fp8 moe by updating the `forward` method signature in `FlashInferEPMoE` and `FlashInferFusedMoE`. The change correctly adapts to a new API where `topk_output` is passed as a tuple containing router logits. My main feedback is to refactor the duplicated validation logic into a shared helper function to improve code reuse and maintainability.

Comment on lines +707 to +712
```python
# TRTLLM mode expects (TopK_config, router_logits) tuple
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
    raise ValueError(
        f"FlashInferEPMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
    )
_, router_logits = topk_output
```
Contributor


Severity: medium

This logic for validating and unpacking topk_output is duplicated in python/sglang/srt/layers/moe/fused_moe_triton/layer.py. To improve maintainability and reduce code duplication, consider extracting this logic into a shared helper function.

For example, you could create a helper function like this in a utility file or at the module level in fused_moe_triton/layer.py:

```python
def _unpack_topk_output_for_trtllm(topk_output: tuple, class_name: str) -> torch.Tensor:
    """Unpacks topk_output tuple for TRTLLM mode and returns router_logits."""
    if not isinstance(topk_output, tuple) or len(topk_output) != 2:
        raise ValueError(
            f"{class_name} expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
        )
    _, router_logits = topk_output
    return router_logits
```

Then you could call it from both FlashInferEPMoE.forward and FlashInferFusedMoE.forward.

Comment on lines +776 to +781
```python
# TRTLLM mode expects (TopK_config, router_logits) tuple
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
    raise ValueError(
        f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
    )
_, router_logits = topk_output
```
Contributor


Severity: medium

This validation and unpacking logic for topk_output is also present in FlashInferEPMoE.forward. To improve code reuse and maintainability, you could extract this logic into a helper function.

For example, you could define this function in this file:

```python
def _unpack_topk_output_for_trtllm(topk_output: tuple, class_name: str) -> torch.Tensor:
    """Unpacks topk_output tuple for TRTLLM mode and returns router_logits."""
    if not isinstance(topk_output, tuple) or len(topk_output) != 2:
        raise ValueError(
            f"{class_name} expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
        )
    _, router_logits = topk_output
    return router_logits
```

Then, you can simplify the forward method here and in FlashInferEPMoE by calling this helper.

@zhyncs zhyncs added bug Something isn't working high priority labels Aug 4, 2025
Collaborator Author

kaixih commented Aug 4, 2025

Talked with @ch-wan and learned that EPMoE and FusedMoE now share backends. So I removed FlashInferEPMoE, and everything (TP or EP MoE) should now use FlashInferFusedMoE for the trtllm-gen backends.

Collaborator Author

kaixih commented Aug 4, 2025

Accuracy tests were conducted as well:

For TP:

```
sglang (pretrained=/model/models--deepseek-ai--DeepSeek-R1-0528/snapshots/4236a6af538feda4548eca9ab308586007567f52/,trust_remote_code=True,tp_size=8,max_model_len=32768,add_bos_token=True,enable_flashinfer_trtllm_moe=True,disable_shared_experts_fusion=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 512
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9575|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9553|±  |0.0057|
```

For EP:

```
sglang (pretrained=/model/models--deepseek-ai--DeepSeek-R1-0528/snapshots/4236a6af538feda4548eca9ab308586007567f52/,trust_remote_code=True,tp_size=8,ep_size=8,max_model_len=32768,add_bos_token=True,enable_flashinfer_trtllm_moe=True,disable_shared_experts_fusion=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: 512
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9545|±  |0.0057|
|     |       |strict-match    |     5|exact_match|↑  |0.9522|±  |0.0059|
```

Collaborator

zhyncs commented Aug 4, 2025

Thanks @kaixih! Could you also help test fp4 checkpoint accuracy, like in #8552?

Collaborator Author

kaixih commented Aug 4, 2025

Sure. I just ran the fp4 accuracy test with the repro commands referenced there.

```shell
python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-FP4 --trust-remote-code --tp-size 4 --quantization modelopt_fp4 --enable-flashinfer-trtllm-moe --disable-shared-experts-fusion &
echo $! > sglang_server.pid
echo "Server PID: $(cat sglang_server.pid)"
python3 benchmark/gsm8k/bench_sglang.py \
  --num-questions 900 \
  --parallel 32 \
  --num-shots 8
```

Output:

```
Downloading from https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl to /tmp/test.jsonl
/tmp/test.jsonl: 732kB [00:00, 26.8MB/s]
100%|███████████████████████████████████████████████████████| 900/900 [01:55<00:00,  7.78it/s]
Accuracy: 0.957
Invalid: 0.000
Latency: 116.863 s
Output throughput: 786.992 token/s
```

@zhyncs ^^

@zhyncs zhyncs merged commit 6d0646d into sgl-project:main Aug 4, 2025
63 of 66 checks passed
narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 17, 2025
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025