[NVIDIA] Enable TRTLLM BF16 MoE on Blackwell GPUs #13798
Fridge003 merged 19 commits into sgl-project:main
Conversation
Summary of Changes

Hello @samuellees, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces significant updates to enable BF16 (bfloat16) support for Mixture-of-Experts (MoE) layers, leveraging the NVIDIA TensorRT-LLM (TRTLLM) backend with FlashInfer, particularly optimized for Blackwell GPUs. The changes involve refining the MoE layer selection mechanism, enhancing the unquantized MoE method to properly process and utilize BF16 weights, and adjusting server configuration arguments to accommodate this new precision. A new test case has also been added to ensure the correctness of the BF16 MoE implementation.
Code Review
This pull request enables TRTLLM BF16 MoE on Blackwell GPUs, which is a valuable addition. The changes to support bf16 in the MoE layer selection, weight loading, and forward pass logic are well-implemented. The inclusion of a new test case for bf16 is also a great way to ensure correctness. I have a few suggestions to enhance code quality and address a potential bug.
```python
w13_weights_bf16_shuffled = (
    torch.stack(w13_weights_bf16_shuffled).view(torch.bfloat16).contiguous()
)
w2_weights_bf16_shuffled = (
    torch.stack(w2_weights_bf16_shuffled).view(torch.bfloat16).contiguous()
)
```
The `.view(torch.bfloat16)` calls are redundant: `w13_weights_bf16_shuffled` and `w2_weights_bf16_shuffled` are lists of bfloat16 tensors, so `torch.stack` will already produce a bfloat16 tensor. Removing the unnecessary `.view()` call makes the code cleaner:
```python
w13_weights_bf16_shuffled = torch.stack(w13_weights_bf16_shuffled).contiguous()
w2_weights_bf16_shuffled = torch.stack(w2_weights_bf16_shuffled).contiguous()
```
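The dtype-preserving behavior the reviewer relies on can be checked with a small standalone snippet. NumPy is used here as an analog for `torch.stack` (the behavior is the same: stacking same-dtype arrays yields that dtype, so no follow-up dtype view is needed); `np.float16` stands in for bfloat16, which NumPy lacks.

```python
import numpy as np

# Stacking a list of same-dtype arrays already produces an array of
# that dtype, so a subsequent dtype reinterpretation is redundant.
experts = [np.ones((2, 2), dtype=np.float16) for _ in range(4)]
stacked = np.stack(experts)

print(stacked.dtype)  # float16
print(stacked.shape)  # (4, 2, 2)
```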
cc @yizhang2077, since this also supports Qwen3/Qwen3-Next models.
```python
w13_weights_bf16_shuffled.append(tmp_weights1.view(torch.bfloat16))
w2_weights_bf16_shuffled.append(tmp_weights2.view(torch.bfloat16))

# Stack weights for all experts
```
Converting the layout of all experts and then stacking may double the memory usage, which can cause OOM when loading weights.
This makes sense. Fixed by converting in place.
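A minimal sketch of the in-place pattern this fix refers to (all names here are hypothetical, not the PR's actual code): instead of building a second list of converted weights, which keeps both copies alive and can double peak memory, each entry is converted and immediately overwrites the original, so at most one extra tensor exists at a time.

```python
def shuffle_for_trtllm(weight):
    # Placeholder for the real layout conversion (e.g. a shuffle kernel).
    return list(reversed(weight))

expert_weights = [[1, 2, 3], [4, 5, 6]]

# In-place: overwrite each expert's entry as soon as it is converted,
# rather than appending converted copies to a second list.
for i in range(len(expert_weights)):
    expert_weights[i] = shuffle_for_trtllm(expert_weights[i])

print(expert_weights)  # [[3, 2, 1], [6, 5, 4]]
```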
```python
pass
return FlashInferFP4MoE
elif (quant_config is None) or (
    quant_config is not None and quant_config.get_name() == "fp8"
```
Should `quant_config.get_name()` accept both `fp8` and `modelopt_fp8` here?
Yes, it should. Fixed.
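A hedged sketch of the corrected selection condition: accept both quantization names with a membership test instead of a single equality check. The `QuantConfig` stub and `use_trtllm_moe` helper below are illustrative stand-ins, not the repository's actual classes.

```python
class QuantConfig:
    """Stub for the real quantization config object."""

    def __init__(self, name):
        self._name = name

    def get_name(self):
        return self._name


def use_trtllm_moe(quant_config):
    # Accept unquantized models as well as both fp8 variants.
    return quant_config is None or quant_config.get_name() in ("fp8", "modelopt_fp8")


print(use_trtllm_moe(None))                         # True
print(use_trtllm_moe(QuantConfig("fp8")))           # True
print(use_trtllm_moe(QuantConfig("modelopt_fp8")))  # True
print(use_trtllm_moe(QuantConfig("awq")))           # False
```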
/tag-and-rerun-ci
@samuellees I think this CI failure might be related.
Fixed by setting `TopKOutputFormat` correctly in this case.
Can you also make sure that this test can pass locally?
Done, it works well.
@samuellees Hi, can you look at https://sgl-fru7574.slack.com/archives/C09NG5Q0LEP/p1765528804581049, thanks.
Sure, I'm taking a look at the issue. Two quick comments:
A quick fix.
Dependency
flashinfer-python >= 0.5.3
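A minimal way to satisfy this version requirement, assuming pip is the installer in use:

```shell
pip install "flashinfer-python>=0.5.3"
```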
[Merged] This PR could be merged after PR14350
Motivation
Enable TRTLLM BF16 MoE on Blackwell GPUs
Accuracy Tests
TRTLLM MoE
Triton MoE
Benchmarking and Profiling
B200, ISL=1k, OSL=8k, TEP4, DP OFF, MTP OFF, TRTLLM-MHA
Triton
Flashinfer TRTLLM-GEN-MoE
Modifications
Checklist