
[NVIDIA] Enable TRTLLM BF16 MoE on Blackwell GPUs#13798

Merged
Fridge003 merged 19 commits into sgl-project:main from samuellees:trtllm-moe-bf16
Dec 12, 2025

Conversation

@samuellees
Contributor

@samuellees samuellees commented Nov 23, 2025

Dependency

flashinfer-python >= 0.5.3

[Merged] This PR could be merged after PR14350
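The version floor above can be enforced at startup. A minimal sketch (the helper names and tuple-based parser below are illustrative, not sglang's actual check):

```python
# Hedged sketch: enforce the flashinfer-python >= 0.5.3 floor at startup.
# The helper names here are illustrative, not sglang's actual check.
def parse_version(v: str) -> tuple:
    """Turn '0.5.3' into (0, 5, 3) for lexicographic comparison."""
    return tuple(int(part) for part in v.split(".")[:3])

def meets_flashinfer_floor(installed: str, required: str = "0.5.3") -> bool:
    return parse_version(installed) >= parse_version(required)

print(meets_flashinfer_floor("0.5.3"))  # True
print(meets_flashinfer_floor("0.5.2"))  # False
print(meets_flashinfer_floor("0.6.0"))  # True
```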

Motivation

Enable TRTLLM BF16 MoE on Blackwell GPUs

Accuracy Tests

TRTLLM MoE

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 8 | exact_match | 0.9598 | ± 0.0054 |
| gsm8k | 3 | strict-match | 8 | exact_match | 0.8362 | ± 0.0102 |

Triton MoE

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 8 | exact_match | 0.9575 | ± 0.0056 |
| gsm8k | 3 | strict-match | 8 | exact_match | 0.8347 | ± 0.0102 |
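The ± columns in both tables are consistent with the standard error of a binomial proportion. Assuming the usual 1319-question gsm8k test split (an assumption, not stated in the PR), the values can be reproduced:

```python
import math

def binomial_stderr(p: float, n: int) -> float:
    """Standard error of a sample proportion: sqrt(p * (1 - p) / n)."""
    return math.sqrt(p * (1.0 - p) / n)

n = 1319  # assumption: the standard gsm8k test split size
print(round(binomial_stderr(0.9598, n), 4))  # 0.0054 (TRTLLM flexible-extract)
print(round(binomial_stderr(0.8362, n), 4))  # 0.0102 (TRTLLM strict-match)
```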

Benchmarking and Profiling

B200, ISL=1k, OSL=8k, TEP4, DP OFF, MTP OFF, TRTLLM-MHA

Triton

| Concurrency | Mean TPOT (ms) | Output token throughput (tok/s) |
|-------------|----------------|---------------------------------|
| 1 | 5.49 | 181.75 |
| 4 | 6.16 | 646.34 |
| 16 | 7.60 | 2090.84 |
| 64 | 11.27 | 5634.14 |
| 256 | 20.31 | 12466.59 |
| 512 | 29.64 | 16998.63 |

Flashinfer TRTLLM-GEN-MoE

| Concurrency | Mean TPOT (ms) | Output token throughput (tok/s) |
|-------------|----------------|---------------------------------|
| 1 | 4.14 | 240.81 |
| 4 | 4.94 | 806.11 |
| 16 | 6.33 | 2512.84 |
| 64 | 9.38 | 6776.53 |
| 256 | 16.57 | 15286.36 |
| 512 | 24.42 | 20639.92 |
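As a sanity check on the two tables: in steady-state decode, output throughput is roughly concurrency × 1000 / TPOT(ms), since each concurrent stream emits one token per TPOT interval. The estimates land within a few percent of the reported numbers:

```python
def approx_throughput(concurrency: int, tpot_ms: float) -> float:
    """Each of `concurrency` streams emits one token every `tpot_ms` ms."""
    return concurrency * 1000.0 / tpot_ms

# (concurrency, mean TPOT ms, reported tok/s) from the TRTLLM-GEN-MoE table
rows = [(1, 4.14, 240.81), (16, 6.33, 2512.84), (512, 24.42, 20639.92)]
for c, tpot, reported in rows:
    est = approx_throughput(c, tpot)
    print(f"concurrency={c}: estimated {est:.0f} vs reported {reported} tok/s")
```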

Modifications

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @samuellees, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant updates to enable BF16 (bfloat16) quantization for Mixture-of-Experts (MoE) layers, leveraging the NVIDIA TensorRT-LLM (TRTLLM) backend with FlashInfer, particularly optimized for Blackwell GPUs. The changes involve refining the MoE layer selection mechanism, enhancing the unquantized MoE method to properly process and utilize BF16 weights, and adjusting server configuration arguments to accommodate this new precision. A new test case has also been added to ensure the correctness of the BF16 MoE implementation.

Highlights

  • TRTLLM BF16 MoE Support: Enabled support for BF16 quantization in Mixture-of-Experts (MoE) layers when using the NVIDIA TensorRT-LLM (TRTLLM) backend with FlashInfer, specifically targeting Blackwell GPUs.
  • MoE Layer Selection Logic: Refactored the get_moe_impl_class function to correctly select FlashInferFusedMoE for BF16 (unquantized) and FP8, and FlashInferFP4MoE for FP4 quantization when the flashinfer_trtllm backend is active.
  • Unquantized MoE Method Enhancements: Extended the UnquantizedFusedMoEMethod to handle FlashInfer TRTLLM MoE, including specific weight reordering and block layout conversion for BF16 weights during loading, and a dedicated forward path utilizing flashinfer.fused_moe.trtllm_bf16_moe.
  • Relaxed Quantization Constraints: Removed the strict requirement for FP4 or FP8 quantization when using the flashinfer_trtllm MoE backend in server_args.py, allowing for BF16 operation.
  • New BF16 Test Case: Added a new nightly test, TestFlashinferTrtllmGenMoeBackendBF16, to validate the functionality and accuracy of BF16 MoE with the flashinfer_trtllm backend.
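The selection logic in the second highlight can be sketched as follows. This is a simplification: the real `get_moe_impl_class` reads the quant_config object and server args, the class bodies here are stubs, and the FP4 quant name `"modelopt_fp4"` is assumed for illustration.

```python
# Hedged sketch of the refactored MoE backend selection described above.
class FusedMoE: ...            # default Triton path
class FlashInferFusedMoE: ...  # BF16 / FP8 path
class FlashInferFP4MoE: ...    # FP4 path

def get_moe_impl_class(moe_backend: str, quant_name):
    """BF16 (no quant_config) and FP8 route to FlashInferFusedMoE;
    FP4 routes to FlashInferFP4MoE; everything else falls back."""
    if moe_backend == "flashinfer_trtllm":
        if quant_name == "modelopt_fp4":
            return FlashInferFP4MoE
        if quant_name is None or quant_name in ("fp8", "modelopt_fp8"):
            return FlashInferFusedMoE
    return FusedMoE

print(get_moe_impl_class("flashinfer_trtllm", None).__name__)  # FlashInferFusedMoE
print(get_moe_impl_class("triton", None).__name__)             # FusedMoE
```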

@samuellees
Contributor Author

cc @kaixih @b8zhong

@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This pull request enables TRTLLM BF16 MoE on Blackwell GPUs, which is a valuable addition. The changes to support bf16 in the MoE layer selection, weight loading, and forward pass logic are well-implemented. The inclusion of a new test case for bf16 is also a great way to ensure correctness. I have a few suggestions to enhance code quality and address a potential bug.

Comment on lines +271 to +276:

```python
w13_weights_bf16_shuffled = (
    torch.stack(w13_weights_bf16_shuffled).view(torch.bfloat16).contiguous()
)
w2_weights_bf16_shuffled = (
    torch.stack(w2_weights_bf16_shuffled).view(torch.bfloat16).contiguous()
)
```

medium

The `.view(torch.bfloat16)` calls are redundant. `w13_weights_bf16_shuffled` and `w2_weights_bf16_shuffled` are lists of bfloat16 tensors, so `torch.stack` will already produce a bfloat16 tensor. Removing the unnecessary `.view()` calls makes the code cleaner:

```python
w13_weights_bf16_shuffled = torch.stack(w13_weights_bf16_shuffled).contiguous()
w2_weights_bf16_shuffled = torch.stack(w2_weights_bf16_shuffled).contiguous()
```
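The reviewer's point — that stacking same-dtype tensors already preserves the dtype — holds generally. A NumPy analogy (float16 standing in for bfloat16, which NumPy lacks; the PR itself uses torch):

```python
import numpy as np

# Stacking arrays of one dtype yields that same dtype, so a trailing
# reinterpret-view to the identical dtype is a no-op -- the reviewer's point.
parts = [np.ones((2, 2), dtype=np.float16) for _ in range(3)]
stacked = np.stack(parts)
print(stacked.dtype)                   # float16
print(stacked.view(np.float16).dtype)  # float16 (the view changed nothing)
```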

@samuellees samuellees marked this pull request as ready for review November 24, 2025 11:30
@samuellees
Contributor Author

cc @yizhang2077 Because this also supports Qwen3/Qwen3-Next models

@b8zhong b8zhong added the run-ci label Nov 29, 2025
```python
w13_weights_bf16_shuffled.append(tmp_weights1.view(torch.bfloat16))
w2_weights_bf16_shuffled.append(tmp_weights2.view(torch.bfloat16))

# Stack weights for all experts
```

Converting every expert's layout and then stacking may double the memory usage, which can cause an OOM when loading weights.

Contributor Author

This makes sense. Fixed by converting in place.
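A minimal sketch of the in-place approach, using NumPy and a stand-in reversal for flashinfer's actual block-layout conversion (both are assumptions for illustration):

```python
import numpy as np

def shuffle_experts_inplace(expert_weights: list) -> None:
    """Convert each expert's layout one at a time, replacing the source entry
    so at most one expert-sized temporary is alive -- instead of materializing
    a second full copy of all experts before stacking."""
    for i, w in enumerate(expert_weights):
        converted = w[::-1].copy()     # stand-in for the real layout shuffle
        expert_weights[i] = converted  # old buffer becomes collectable here

weights = [np.arange(4, dtype=np.float32) for _ in range(2)]
shuffle_experts_inplace(weights)
print(weights[0])  # [3. 2. 1. 0.]
```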

```python
pass
return FlashInferFP4MoE
elif (quant_config is None) or (
    quant_config is not None and quant_config.get_name() == "fp8"
```

Collaborator

Should `quant_config.get_name()` also accept `modelopt_fp8` in addition to `fp8`?

@samuellees (Contributor Author) commented Dec 3, 2025

Yes, it should be. Fixed.

@b8zhong
Collaborator

b8zhong commented Dec 5, 2025

/tag-and-rerun-ci
one more time

@b8zhong
Collaborator

b8zhong commented Dec 9, 2025


```
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
  return forward_call(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/actions-runner/_work/sglang/sglang/python/sglang/srt/models/deepseek_v2.py", line 785, in forward
  return self.forward_normal_dual_stream(
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/actions-runner/_work/sglang/sglang/python/sglang/srt/models/deepseek_v2.py", line 819, in forward_normal_dual_stream
  final_hidden_states = self.experts(hidden_states, topk_output)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
  return forward_call(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/actions-runner/_work/sglang/sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 1046, in forward
  topk_output.topk_config.renormalize
  ^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'StandardTopKOutput' object has no attribute 'topk_config'
```

@samuellees I think this CI failure might be related

@samuellees
Contributor Author

samuellees commented Dec 9, 2025


> AttributeError: 'StandardTopKOutput' object has no attribute 'topk_config'
>
> @samuellees I think this CI failure might be related

Fixed by setting the TopKOutputFormat correctly in this case.
test/srt/test_deepseek_v3_fp4_4gpu.py passes in my local environment.
@b8zhong
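The failure mode and the shape of a fix can be sketched like this. Class and attribute names mirror the traceback, but the fields themselves are guesses, and the actual fix sets the output format correctly rather than probing with `getattr`:

```python
from dataclasses import dataclass, field

@dataclass
class TopKConfig:
    renormalize: bool = False

@dataclass
class StandardTopKOutput:  # variant from the traceback: no topk_config field
    topk_weights: list = field(default_factory=list)
    topk_ids: list = field(default_factory=list)

@dataclass
class TritonTopKOutput:    # hypothetical variant that carries a config
    topk_config: TopKConfig = field(default_factory=TopKConfig)

def needs_renormalize(topk_output) -> bool:
    """Tolerate both variants instead of assuming topk_config exists --
    the unconditional attribute access is what raised the AttributeError."""
    config = getattr(topk_output, "topk_config", None)
    return config.renormalize if config is not None else False

print(needs_renormalize(StandardTopKOutput()))                # False
print(needs_renormalize(TritonTopKOutput(TopKConfig(True))))  # True
```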

@b8zhong
Collaborator

b8zhong commented Dec 10, 2025

Can you also make sure that this test passes locally?

```
python test/nightly/test_deepseek_v3_fp4_cutlass_moe.py
```

@samuellees
Contributor Author

samuellees commented Dec 10, 2025

> Can you also make sure that this test passes locally?
>
> `python test/nightly/test_deepseek_v3_fp4_cutlass_moe.py`

Done, it works well.

```
----------------------------------------------------------------------
Ran 1 test in 806.469s

OK
Accuracy: 0.948
Invalid: 0.000
Latency: 24.852 s
Output throughput: 5605.046 token/s
metrics={'accuracy': np.float64(0.9476876421531463), 'invalid': np.float64(0.0), 'latency': 24.852428324000357, 'output_throughput': 5605.045840348603}
[W1210 01:55:43.703721092 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
```

@Fridge003 Fridge003 merged commit d7ed8a8 into sgl-project:main Dec 12, 2025
164 of 185 checks passed
@b8zhong
Collaborator

b8zhong commented Dec 12, 2025

@samuellees Hi, can you look at https://sgl-fru7574.slack.com/archives/C09NG5Q0LEP/p1765528804581049, thanks.

@samuellees
Contributor Author

samuellees commented Dec 12, 2025

@samuellees Hi, can you look at https://sgl-fru7574.slack.com/archives/C09NG5Q0LEP/p1765528804581049, thanks.

Sure, I'm taking a look at the issue. Two quick comments:

  1. This might break the accuracy of the MoE/Dense (if present) layers of the DeepSeek MTP module. I'll try to find the cause and fix it.
  2. A backup solution is to keep the DeepSeek MTP MoE layer on the Triton backend.

I'll make a decision today.

@samuellees
Contributor Author

A quick fix: PR15002.

BenYao21 pushed a commit to minleminzui/sglang that referenced this pull request Dec 13, 2025
Prozac614 pushed a commit to Prozac614/sglang that referenced this pull request Dec 17, 2025
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026
@samuellees samuellees mentioned this pull request Feb 11, 2026