
Add Mxfp4 trtllm-gen moe unit tests #1399

Merged
yzh119 merged 2 commits into flashinfer-ai:main from IwakuraRein:mxfp4-unit-tests
Aug 12, 2025

Conversation

@IwakuraRein
Collaborator

@IwakuraRein IwakuraRein commented Aug 6, 2025

📌 Description

  • add mxfp4 quantization unit test
  • add mxfp4 x mxfp8 and mxfp4 x bf16 moe unit test

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @IwakuraRein, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've implemented new unit tests and extended existing infrastructure to support MxFP4 quantization within the trtllm-gen Mixture-of-Experts (MoE) system. This work ensures that the system can correctly process and verify models using the MxFP4 format, enhancing its versatility and precision for various quantization schemes. The changes involve adapting quantization logic, reference implementations, and test cases to accommodate the unique characteristics of MxFP4.

Highlights

  • New MxFP4 Unit Tests: I've added comprehensive unit tests for the new MxFP4 quantization mode within the trtllm-gen Mixture-of-Experts (MoE) framework. This ensures the correctness and accuracy of MxFP4 operations, particularly concerning hidden states and weights.
  • Dual-Mode FP4Moe Implementation: I've extended the FP4Moe class to dynamically support both NvFP4 and MxFP4 quantization. This involved introducing an is_mxfp4 flag that intelligently adjusts scaling factor vector sizes, global scale calculations, and the specific quantization/dequantization functions used for inputs and weights.
  • Enhanced Reference Dequantization: I've updated the core reference implementations for MoE dequantization (run_moe_dequant and run_moe_reference_fp4) to correctly handle the new mxe4m3 quantization mode. This ensures that the reference outputs accurately reflect the behavior of MxFP4, allowing for precise comparison with kernel outputs.
  • Adaptive FP4 Utility Functions: I've modified several utility functions, including calculate_fp4_global_scale_factor and quant_fp4, to be aware of the use_ue8m0 flag. This allows them to adapt their calculations and block sizes based on whether the UE8M0 format is employed for scaling factors, which is crucial for MxFP4.
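For reviewers unfamiliar with the formats in play: MxFP4 packs FP4 (E2M1) element values into 32-element blocks, each sharing one power-of-two (UE8M0) scale, whereas NvFP4 uses 16-element blocks with FP8 (E4M3) scales plus a global scale factor. The per-block round trip can be sketched without any dependencies as below. This is illustrative only; the scale-selection and rounding rules here are one simple choice, not the exact behavior of the trtllm-gen kernels or `quant_fp4`.

```python
import math

# Positive magnitudes representable in FP4 E2M1.
FP4_E2M1_MAGS = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_mxfp4_block(block):
    """Quantize one block (MxFP4 uses 32 elements) to a shared
    power-of-two scale plus E2M1 magnitudes. The scale simply maps
    the block amax onto the E2M1 maximum (6.0)."""
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 1.0, [0.0] * len(block)
    scale = 2.0 ** math.ceil(math.log2(amax / 6.0))  # UE8M0: a power of two
    quant = []
    for x in block:
        mag = min(abs(x) / scale, 6.0)
        nearest = min(FP4_E2M1_MAGS, key=lambda m: abs(m - mag))
        quant.append(math.copysign(nearest, x))
    return scale, quant

def dequantize_block(scale, quant):
    return [scale * q for q in quant]

scale, q = quantize_mxfp4_block([12.0, -3.0, 0.75, 6.0])
print(scale)                       # → 2.0
print(dequantize_block(scale, q))  # → [12.0, -3.0, 1.0, 6.0]
```

Values that land exactly on a scaled E2M1 magnitude (12.0, -3.0, 6.0 above) survive the round trip unchanged; everything else is rounded to the nearest representable point, which is what the tolerance checks in the unit tests account for.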

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds unit tests for MxFP4 quantization in Mixture of Experts (MoE) layers. The changes are comprehensive, introducing a new quantization mode and parameterizing existing tests to cover it. The code is generally well-structured, but there are several instances of leftover debugging print statements that should be removed. Additionally, there's an unconventional use of docstrings within conditional blocks that should be refactored for better readability.

Comment on lines +177 to +198
        # print(
        #     f"input_tensor.shape: {self.input_tensor.shape}, dtype: {self.input_tensor.dtype}"
        # )
        # print(
        #     f'hidden_states.shape: {input_quantized["hidden_states"].shape}, dtype: {input_quantized["hidden_states"].dtype}'
        # )
        # print(
        #     f'hidden_states_scale.shape: {input_quantized["hidden_states_scale"].shape}, dtype: {input_quantized["hidden_states_scale"].dtype}'
        # )
        # print(
        #     f'gemm1_weights.shape: {self.static_data["gemm1_weights_fp4_shuffled"].shape}, dtype: {self.static_data["gemm1_weights_fp4_shuffled"].dtype}'
        # )
        # print(
        #     f'gemm1_scales.shape: {self.static_data["gemm1_scales_fp4_shuffled"].shape}, dtype: {self.static_data["gemm1_scales_fp4_shuffled"].dtype}'
        # )
        # print(
        #     f'gemm2_weights.shape: {self.static_data["gemm2_weights_fp4_shuffled"].shape}, dtype: {self.static_data["gemm2_weights_fp4_shuffled"].dtype}'
        # )
        # print(
        #     f'gemm2_scales.shape: {self.static_data["gemm2_scales_fp4_shuffled"].shape}, dtype: {self.static_data["gemm2_scales_fp4_shuffled"].dtype}'
        # )
        # print(f'routing_method_type: {self.config["routing_method_type"]}')
Contributor


medium

This block of commented-out print statements appears to be for debugging. It should be removed to keep the code clean.

Comment on lines +362 to +394
        if self.is_mxfp4:
            """Quantize hidden states to MxFP8 format."""
            hidden_states_quant, hidden_states_scale = mxfp8_quantize(
                hidden_states, is_swizzling
            )
            hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
                -1
            )
            print(
                f"hidden_states.shape: {hidden_states_quant.shape}, dtype: {hidden_states_quant.dtype}"
            )
            print(
                f"hidden_states_scale.shape: {hidden_states_scale.shape}, dtype: {hidden_states_scale.dtype}"
            )
            return {
                "hidden_states": hidden_states_quant,
                "hidden_states_scale": hidden_states_scale,
            }
        else:
            """Quantize hidden states to NvFP4 format using pre-computed global scale."""
            (
                hidden_states_fp4_bytes,
                hidden_states_scale_fp4_bytes,
                _,
            ) = quant_fp4(
                hidden_states, hidden_states_scale_global, False, is_swizzling
            )

            # Quantize hidden states using pre-computed global scale factor
            (
                hidden_states_fp4_bytes,
                hidden_states_scale_fp4_bytes,
                _,
            ) = quant_fp4(hidden_states, hidden_states_scale_global, use_ue8m0, True)

            return {
                "hidden_states": hidden_states_fp4_bytes,
                "hidden_states_scale": hidden_states_scale_fp4_bytes,
            }
        return {
            "hidden_states": hidden_states_fp4_bytes,
            "hidden_states_scale": hidden_states_scale_fp4_bytes,
        }
Contributor


medium

This method contains docstrings within if/else blocks, which is not standard Python practice and harms readability. These should be converted to regular comments. Additionally, there are active print statements that appear to be for debugging and should be removed.

        if self.is_mxfp4:
            # Quantize hidden states to MxFP8 format.
            hidden_states_quant, hidden_states_scale = mxfp8_quantize(
                hidden_states, is_swizzling
            )
            hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
                -1
            )
            return {
                "hidden_states": hidden_states_quant,
                "hidden_states_scale": hidden_states_scale,
            }
        else:
            # Quantize hidden states to NvFP4 format using pre-computed global scale.
            (
                hidden_states_fp4_bytes,
                hidden_states_scale_fp4_bytes,
                _,
            ) = quant_fp4(
                hidden_states, hidden_states_scale_global, False, is_swizzling
            )

            return {
                "hidden_states": hidden_states_fp4_bytes,
                "hidden_states_scale": hidden_states_scale_fp4_bytes,
            }
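The reviewer's point about branch-level docstrings is worth spelling out: a bare string literal inside an `if`/`else` body is not a docstring at all; it is an ordinary expression that Python evaluates and immediately discards. A small standalone illustration (hypothetical function, for demonstration only):

```python
def pick(flag):
    if flag:
        """Not a docstring: just a string expression that is
        evaluated and thrown away every time this branch runs."""
        return 1
    else:
        # A plain comment conveys the same intent at zero runtime cost.
        return 2

# Only a string literal placed directly after the `def` line becomes __doc__.
print(pick.__doc__)  # → None
print(pick(True), pick(False))  # → 1 2
```

Converting those strings to `#` comments, as the suggested rewrite above does, removes the needless runtime expression and matches idiomatic style.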

        scale_c_fc2 = (1.0 / args_dequant.c_global_sf) * (
            1.0 / args.gemm2_scales_global
        )
        print(f"gemm1_scales_global: {scale_c_fc1}, gemm2_scales_global: {scale_c_fc2}")
Contributor


medium

This print statement appears to be for debugging and should be removed.

Comment on lines +2023 to +2024
        print(f"ref: {output_dequant_reference[:10, :10]}")
        print(f"actual: {output_dequant_actual[:10, :10]}")
Contributor


medium

These print statements appear to be for debugging and should be removed from the test case.

@IwakuraRein IwakuraRein marked this pull request as ready for review August 7, 2025 00:58
@yzh119
Collaborator

yzh119 commented Aug 9, 2025

Hi @IwakuraRein, would you mind resolving the conflict with the main branch? (brought by #1412)

Signed-off-by: siyuanf <siyuanf@nvidia.com>
Collaborator

@yzh119 yzh119 left a comment


LGTM

@yzh119 yzh119 merged commit fe442a2 into flashinfer-ai:main Aug 12, 2025
2 checks passed