[MoE Refactor] Convert mxfp4 moe quant method into oracle #34983
zyongye wants to merge 46 commits into vllm-project:main
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces a significant and well-designed refactoring by creating an "oracle" for MXFP4 MoE backend selection and configuration. This greatly improves the structure by centralizing backend-specific logic, which was previously scattered across multiple files. The changes to the testing framework to use configuration files are also a welcome improvement for maintainability and scalability.
However, I've found a few critical bugs that need to be addressed:
- An incorrect condition in the backend selection logic.
- Typos in weight loading that will cause runtime errors.
- A duplicated enum value in a conditional check.
Additionally, there's a recurring use of hardcoded magic numbers specific to one model, which should be made configurable to improve maintainability and support for other models.
After these issues are fixed, this PR will be a great step forward for the codebase.
```python
# n = hidden_size
# k = intermediate_size_per_partition_after_pad
intermediate_size = round_up(intermediate_size, 128)
if backend == current_platform.is_xpu():
```
The condition `if backend == current_platform.is_xpu():` is incorrect. `backend` is an enum member, while `current_platform.is_xpu()` returns a boolean, so the comparison will always evaluate to `False` and the padding logic for XPU devices will never run. The check should be `if current_platform.is_xpu():`.
```diff
- if backend == current_platform.is_xpu():
+ if current_platform.is_xpu():
```
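A minimal standalone sketch of why the buggy comparison is a dead branch (the `Mxfp4MoeBackend` members here are illustrative stand-ins, not the real enum):

```python
from enum import Enum

class Mxfp4MoeBackend(Enum):
    MARLIN = "MARLIN"
    TRITON = "TRITON"

backend = Mxfp4MoeBackend.MARLIN
is_xpu = True  # pretend current_platform.is_xpu() returned True

# Buggy form: an Enum member never compares equal to a bool,
# so this branch can never be taken, on any platform.
buggy_taken = backend == is_xpu

# Intended form: check the platform predicate directly.
fixed_taken = is_xpu
```

Because `Enum.__eq__` falls back to identity, `buggy_taken` is `False` even when the platform really is XPU.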
```python
w2_scale = layer.w2_weight
w13_bias = layer.w1_bias
w2_bias = layer.w2_bias
```
There are a few typos here that will likely cause an `AttributeError` at runtime:

- `w2_scale` is incorrectly assigned from `layer.w2_weight`. It should be `layer.w2_weight_scale`.
- `w13_bias` is assigned from `layer.w1_bias`, but the parameter is named `w13_bias`.
- `w2_bias` is assigned from `layer.w2_bias`, but the parameter is named `w2_bias`.
Since biases are optional, it's safer to use `getattr` with a default value of `None`.
```diff
- w2_scale = layer.w2_weight
- w13_bias = layer.w1_bias
- w2_bias = layer.w2_bias
+ w2_scale = layer.w2_weight_scale
+ w13_bias = getattr(layer, "w13_bias", None)
+ w2_bias = getattr(layer, "w2_bias", None)
```
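A small sketch of why `getattr` with a `None` default is the safer pattern for optional bias parameters (`SimpleNamespace` stands in for the MoE layer object; attribute names mirror the PR):

```python
from types import SimpleNamespace

# A layer whose model registered no bias parameters at all.
layer = SimpleNamespace(w2_weight_scale="scale_tensor")

# Direct attribute access raises when the bias is absent.
try:
    _ = layer.w13_bias
    raised = False
except AttributeError:
    raised = True

# getattr degrades gracefully to None, which downstream kernels
# can treat as "no bias".
w13_bias = getattr(layer, "w13_bias", None)
w2_bias = getattr(layer, "w2_bias", None)
```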
```python
self.gemm1_alpha = torch.tensor(
    [1.702] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_beta = torch.tensor(
    [1.0] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_clamp_limit = torch.tensor(
    [7.0] * self.num_experts, dtype=torch.float32, device=self.device
)
```
The values for `gemm1_alpha` (1.702) and `gemm1_clamp_limit` (7.0) are hardcoded. The comment indicates these are specific to gpt-oss. Hardcoding model-specific parameters makes the code less maintainable and harder to extend to other models. These values should be passed in through the model configuration rather than being hardcoded in the kernel implementation. This issue is also present in `trtllm_moe.py` and the new `oracle/mxfp4.py`.
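One way to lift these constants out of the kernel (illustrative only; `SwigluActivationParams` and `expand_per_expert` are hypothetical names, not part of the PR) is a small config object whose defaults match the current gpt-oss values, from which the kernel builds its per-expert tensors:

```python
from dataclasses import dataclass

@dataclass
class SwigluActivationParams:
    # Defaults match the values currently hardcoded for gpt-oss.
    alpha: float = 1.702
    beta: float = 1.0
    clamp_limit: float = 7.0

def expand_per_expert(params: SwigluActivationParams, num_experts: int):
    # The kernel would wrap these lists in tensors, e.g.
    # torch.tensor(alphas, dtype=torch.float32, device=self.device).
    return (
        [params.alpha] * num_experts,
        [params.beta] * num_experts,
        [params.clamp_limit] * num_experts,
    )

# A different model could simply pass its own params object.
alphas, betas, clamps = expand_per_expert(SwigluActivationParams(), num_experts=4)
```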
```python
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
```
There's a duplicate `Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16` in the `elif` condition. This is likely a copy-paste error and should probably be `Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8` to cover both CUTLASS backends.
```diff
  Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
- Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
+ Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
```
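A quick sketch of how the duplicated member silently narrows the membership check (the enum values here are stand-ins, mirroring only the PR's member names):

```python
from enum import Enum

class Mxfp4MoeBackend(Enum):
    FLASHINFER_CUTLASS_MXFP4_BF16 = "BF16"
    FLASHINFER_CUTLASS_MXFP4_MXFP8 = "MXFP8"

buggy = (
    Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
    Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,  # duplicate
)
fixed = (
    Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
    Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
)

mxfp8 = Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8
in_buggy = mxfp8 in buggy   # the MXFP8 backend never matches the elif
in_fixed = mxfp8 in fixed
```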
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
392a502 to 6abf521
/gemini review
Code Review
This pull request is a significant and valuable refactoring that introduces an oracle for MXFP4 MoE quantization, centralizing backend selection and processing logic. This improves the architecture by making it more modular and extensible, similar to how FP8 and NVFP4 are handled. The addition of comprehensive tests for various backends and hardware is also a great enhancement. My review focuses on a few areas to improve the robustness and maintainability of the new oracle, including fixing a logic bug, correcting a typo in an enum, and removing hardcoded model-specific values from generic kernels.
```python
# n = hidden_size
# k = intermediate_size_per_partition_after_pad
intermediate_size = round_up(intermediate_size, 128)
if backend == current_platform.is_xpu():
```
The condition `if backend == current_platform.is_xpu():` is incorrect. `backend` is an enum member, while `current_platform.is_xpu()` returns a boolean. This comparison will always evaluate to `False`, leading to incorrect behavior on XPU platforms. The condition should likely be `if current_platform.is_xpu():` to check the platform type directly.
```diff
- if backend == current_platform.is_xpu():
+ if current_platform.is_xpu():
```

```python
self.gemm1_alpha = torch.tensor(
    [1.702] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_beta = torch.tensor(
    [1.0] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_clamp_limit = torch.tensor(
    [7.0] * self.num_experts, dtype=torch.float32, device=self.device
)
```
The values for `gemm1_alpha`, `gemm1_beta`, and `gemm1_clamp_limit` are hardcoded. The comment indicates these are specific to the gpt-oss model. Hardcoding model-specific parameters within a general-purpose kernel reduces its reusability and makes it harder to maintain. These values should be made configurable and passed in, for example, through the model's configuration, to make the kernel more generic.
```python
FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC = (
    "FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC"
)
FLASHINFER_CUTLASS_MXFP4_MXFP8 = "FLASHINFER_MXFP4_MXFP8_CUTLASS"
FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_MXFP4_BF16"
FLASHINFER_TRTLLM_MXFP4_BF16_MONOLOTHIC = "FLASHINFER_MXFP4_BF16_MONOLOTHIC"
```
There's a recurring typo `MONOLOTHIC` which should be `MONOLITHIC`. This appears in both the enum member names and their string values. While it's used consistently, it's best to correct it for clarity and to prevent future confusion. This typo is present on lines 55, 56, 60, and several other places in this file and others where these enum members are used.
```diff
  FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC = (
      "FLASHINFER_TRTLLM_MXFP4_MXFP8_MONOLITHIC"
  )
  FLASHINFER_CUTLASS_MXFP4_MXFP8 = "FLASHINFER_MXFP4_MXFP8_CUTLASS"
  FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_MXFP4_BF16"
- FLASHINFER_TRTLLM_MXFP4_BF16_MONOLOTHIC = "FLASHINFER_MXFP4_BF16_MONOLOTHIC"
+ FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC = "FLASHINFER_MXFP4_BF16_MONOLITHIC"
```
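A small sketch of why the fix has to update the member name and the string value together (the enum here is a stand-in that mirrors the PR's naming): the string values drive by-value lookups, so any caller still passing the misspelled string will start raising `ValueError` once the value is corrected.

```python
from enum import Enum

class Mxfp4MoeBackend(Enum):
    FLASHINFER_TRTLLM_MXFP4_BF16_MONOLITHIC = "FLASHINFER_MXFP4_BF16_MONOLITHIC"

# By-value lookup with the corrected spelling resolves to the member.
member = Mxfp4MoeBackend("FLASHINFER_MXFP4_BF16_MONOLITHIC")

# The old misspelled value no longer resolves once the value is fixed,
# so every call site must be updated in the same change.
try:
    Mxfp4MoeBackend("FLASHINFER_MXFP4_BF16_MONOLOTHIC")
    stale_lookup_ok = True
except ValueError:
    stale_lookup_ok = False
```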
```python
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
):
```
```python
self.gemm1_alpha = torch.tensor(
    [1.702] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_beta = torch.tensor(
    [1.0] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_clamp_limit = torch.tensor(
    [7.0] * self.num_experts, dtype=torch.float32, device=self.device
)
```
The values for `gemm1_alpha`, `gemm1_beta`, and `gemm1_clamp_limit` are hardcoded here. As noted in a similar file, these values appear to be specific to a particular model (gpt-oss). To improve the reusability and generality of this kernel, these parameters should be made configurable and passed in, rather than being hardcoded.
981ea82 to 98cd346
close for #37128
Purpose
Ongoing MXFP4 MoE refactor
This PR can be greatly improved once #32564 is merged.
Test Plan
gpt-oss-20b GPQA score with medium reasoning effort.
Blackwell (gb200):
Hopper (h200):
All of them are tested with TP=2 and DP/EP=2.
Test Result
Blackwell:
Hopper
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.