Add Mxfp4 trtllm-gen moe unit tests #1399
Conversation
Summary of Changes
Hello @IwakuraRein, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've implemented new unit tests and extended existing infrastructure to support MxFP4 quantization within the trtllm-gen Mixture-of-Experts (MoE) system. This work ensures that the system can correctly process and verify models using the MxFP4 format, enhancing its versatility and precision for various quantization schemes. The changes involve adapting quantization logic, reference implementations, and test cases to accommodate the unique characteristics of MxFP4.
Highlights
- New MxFP4 Unit Tests: I've added comprehensive unit tests for the new MxFP4 quantization mode within the `trtllm-gen` Mixture-of-Experts (MoE) framework. This ensures the correctness and accuracy of MxFP4 operations, particularly concerning hidden states and weights.
- Dual-Mode FP4Moe Implementation: I've extended the `FP4Moe` class to dynamically support both NvFP4 and MxFP4 quantization. This involved introducing an `is_mxfp4` flag that adjusts scaling factor vector sizes, global scale calculations, and the specific quantization/dequantization functions used for inputs and weights.
- Enhanced Reference Dequantization: I've updated the core reference implementations for MoE dequantization (`run_moe_dequant` and `run_moe_reference_fp4`) to correctly handle the new `mxe4m3` quantization mode. This ensures that the reference outputs accurately reflect the behavior of MxFP4, allowing for precise comparison with kernel outputs.
- Adaptive FP4 Utility Functions: I've modified several utility functions, including `calculate_fp4_global_scale_factor` and `quant_fp4`, to be aware of the `use_ue8m0` flag. This allows them to adapt their calculations and block sizes based on whether the UE8M0 format is employed for scaling factors, which is crucial for MxFP4.
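For readers new to the MX formats mentioned above: MxFP4 groups values into 32-element blocks that share a single power-of-two (UE8M0) scale factor. The sketch below illustrates how such block scales can be derived; it is illustrative only, not the repository's `quant_fp4`, and the 6.0 bound is FP4 e2m1's maximum representable magnitude.

```python
import torch

def mx_block_scales(x: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    """Compute per-block power-of-two (UE8M0-style) scale factors.

    Illustrative only: a real MxFP4 kernel also rounds the mantissas
    to FP4 (e2m1) and packs two 4-bit values per byte.
    """
    blocks = x.reshape(-1, block_size)
    amax = blocks.abs().amax(dim=1)  # per-block max magnitude
    # UE8M0 scales are pure powers of two; 6.0 is the max |value| of FP4 e2m1
    exp = torch.ceil(torch.log2(amax.clamp(min=1e-30) / 6.0))
    return torch.pow(2.0, exp)

x = torch.randn(2, 64)
scales = mx_block_scales(x)
assert scales.shape == (4,)  # one scale per 32-element block
```

Dividing each block by its scale maps the block's largest element into FP4's representable range, which is why the tests must track both the quantized payload and the per-block scale tensor.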
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
¹ Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
Code Review
This pull request adds unit tests for MxFP4 quantization in Mixture of Experts (MoE) layers. The changes are comprehensive, introducing a new quantization mode and parameterizing existing tests to cover it. The code is generally well-structured, but there are several instances of leftover debugging print statements that should be removed. Additionally, there's an unconventional use of docstrings within conditional blocks that should be refactored for better readability.
tests/test_trtllm_gen_fused_moe.py
Outdated
# print(
#     f"input_tensor.shape: {self.input_tensor.shape}, dtype: {self.input_tensor.dtype}"
# )
# print(
#     f'hidden_states.shape: {input_quantized["hidden_states"].shape}, dtype: {input_quantized["hidden_states"].dtype}'
# )
# print(
#     f'hidden_states_scale.shape: {input_quantized["hidden_states_scale"].shape}, dtype: {input_quantized["hidden_states_scale"].dtype}'
# )
# print(
#     f'gemm1_weights.shape: {self.static_data["gemm1_weights_fp4_shuffled"].shape}, dtype: {self.static_data["gemm1_weights_fp4_shuffled"].dtype}'
# )
# print(
#     f'gemm1_scales.shape: {self.static_data["gemm1_scales_fp4_shuffled"].shape}, dtype: {self.static_data["gemm1_scales_fp4_shuffled"].dtype}'
# )
# print(
#     f'gemm2_weights.shape: {self.static_data["gemm2_weights_fp4_shuffled"].shape}, dtype: {self.static_data["gemm2_weights_fp4_shuffled"].dtype}'
# )
# print(
#     f'gemm2_scales.shape: {self.static_data["gemm2_scales_fp4_shuffled"].shape}, dtype: {self.static_data["gemm2_scales_fp4_shuffled"].dtype}'
# )
# print(f'routing_method_type: {self.config["routing_method_type"]}')
tests/test_trtllm_gen_fused_moe.py
Outdated
if self.is_mxfp4:
    """Quantize hidden states to MxFP8 format."""
    hidden_states_quant, hidden_states_scale = mxfp8_quantize(
        hidden_states, is_swizzling
    )
    hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
        -1
    )
    print(
        f"hidden_states.shape: {hidden_states_quant.shape}, dtype: {hidden_states_quant.dtype}"
    )
    print(
        f"hidden_states_scale.shape: {hidden_states_scale.shape}, dtype: {hidden_states_scale.dtype}"
    )
    return {
        "hidden_states": hidden_states_quant,
        "hidden_states_scale": hidden_states_scale,
    }
else:
    """Quantize hidden states to NvFP4 format using pre-computed global scale."""
    (
        hidden_states_fp4_bytes,
        hidden_states_scale_fp4_bytes,
        _,
    ) = quant_fp4(
        hidden_states, hidden_states_scale_global, False, is_swizzling
    )

    # Quantize hidden states using pre-computed global scale factor
    (
        hidden_states_fp4_bytes,
        hidden_states_scale_fp4_bytes,
        _,
    ) = quant_fp4(hidden_states, hidden_states_scale_global, use_ue8m0, True)

    return {
        "hidden_states": hidden_states_fp4_bytes,
        "hidden_states_scale": hidden_states_scale_fp4_bytes,
    }
return {
    "hidden_states": hidden_states_fp4_bytes,
    "hidden_states_scale": hidden_states_scale_fp4_bytes,
}
This method contains docstrings within if/else blocks, which is not standard Python practice and harms readability. These should be converted to regular comments. Additionally, there are active print statements that appear to be for debugging and should be removed.
if self.is_mxfp4:
    # Quantize hidden states to MxFP8 format.
    hidden_states_quant, hidden_states_scale = mxfp8_quantize(
        hidden_states, is_swizzling
    )
    hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
        -1
    )
    return {
        "hidden_states": hidden_states_quant,
        "hidden_states_scale": hidden_states_scale,
    }
else:
    # Quantize hidden states to NvFP4 format using pre-computed global scale.
    (
        hidden_states_fp4_bytes,
        hidden_states_scale_fp4_bytes,
        _,
    ) = quant_fp4(
        hidden_states, hidden_states_scale_global, False, is_swizzling
    )
    return {
        "hidden_states": hidden_states_fp4_bytes,
        "hidden_states_scale": hidden_states_scale_fp4_bytes,
    }
tests/test_trtllm_gen_fused_moe.py
Outdated
scale_c_fc2 = (1.0 / args_dequant.c_global_sf) * (
    1.0 / args.gemm2_scales_global
)
print(f"gemm1_scales_global: {scale_c_fc1}, gemm2_scales_global: {scale_c_fc2}")
tests/test_trtllm_gen_fused_moe.py
Outdated
print(f"ref: {output_dequant_reference[:10, :10]}")
print(f"actual: {output_dequant_actual[:10, :10]}")
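Instead of printing tensor slices for manual inspection, the reference-vs-kernel comparison can be made explicit with `torch.testing.assert_close`. A sketch; the tensor names and the loose tolerances appropriate for low-precision FP4 paths are illustrative, not the test's actual values.

```python
import torch

def check_moe_outputs(reference: torch.Tensor, actual: torch.Tensor,
                      rtol: float = 0.1, atol: float = 0.1) -> None:
    """Compare dequantized MoE outputs with loose, FP4-appropriate tolerances.

    Raises AssertionError with a mismatch summary if tensors diverge.
    """
    torch.testing.assert_close(actual, reference, rtol=rtol, atol=atol)

ref = torch.randn(8, 16)
act = ref + 0.01 * torch.randn_like(ref)
check_moe_outputs(ref, act)  # passes within tolerance
```

On failure, `assert_close` reports the number and magnitude of mismatched elements, which is far more useful in CI logs than a printed 10x10 corner of each tensor.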
0bb5e22 to c4362ac
Hi @IwakuraRein, would you mind resolving the conflict with the main branch? (brought by #1412)
Signed-off-by: siyuanf <siyuanf@nvidia.com>
c4362ac to be3299c
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes