[Perf] Eliminate padding and slicing op for GPT-OSS with Flashinfer MXFP4 MXFP8 MoE #30647
Conversation
Code Review
This pull request introduces a performance optimization for Mixture-of-Experts (MoE) layers in GPT-OSS models using Flashinfer with MXFP4/MXFP8 quantization. The key changes involve eliminating explicit padding and slicing operations around the MoE computation. This is achieved by leveraging new capabilities in the Flashinfer library to handle padding within the quantization kernel and to write to an unpadded output buffer directly.
The main changes are:
- Elimination of Padding/Slicing: The `FusedMoE` layer no longer performs manual padding before the MoE kernel for supported backends. Instead, the padding is handled by `flashinfer::mxfp8_quantize`, and the subsequent slicing is effectively done by the MoE kernel writing to a smaller, pre-allocated output tensor. This change enables better fusion opportunities, as seen by the `all-reduce + norm` fusion now being possible.
- Code Refactoring: The logic for rounding up hidden sizes for MXFP4 quantization has been moved from the generic `fused_moe/layer.py` to the specific `quantization/mxfp4.py`, which is a more appropriate location. This removes duplicated code and improves modularity.
- Conditional Logic: The new behavior is controlled by a `support_padded_mxfp8_quant` flag, ensuring that it only applies to the `SM100_FI_MXFP4_MXFP8_TRTLLM` backend on Blackwell GPUs, maintaining compatibility with other configurations.
- Testing: New test cases have been added to `test_fusions_e2e.py` to validate the fusions and performance improvements for GPT-OSS models on Blackwell.
The changes are well-implemented and align with the stated goals of improving performance. The code is clean and the new logic is properly encapsulated. The performance benchmarks in the PR description show a significant 6% end-to-end improvement, which is a great result.
I have reviewed the code and found no critical or high-severity issues. The changes are correct and contribute to better performance and code structure.
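The pad-then-slice pattern being removed can be illustrated with a small sketch. The `moe_kernel` callable and its `out=` signature here are hypothetical stand-ins, not vLLM's actual API: the old path pads the activation up to the kernel's aligned hidden size and slices the result back, while the new path lets the quantization step absorb the padding and has the MoE kernel write directly into an unpadded output buffer.

```python
import torch


def moe_forward_old(x: torch.Tensor, padded_hidden: int, moe_kernel) -> torch.Tensor:
    # Old path: explicitly pad the activation up to the aligned hidden size...
    pad = padded_hidden - x.shape[-1]
    x_padded = torch.nn.functional.pad(x, (0, pad))
    out = moe_kernel(x_padded)
    # ...then slice the padding back off, producing an extra copy.
    return out[..., : x.shape[-1]].contiguous()


def moe_forward_new(x: torch.Tensor, padded_hidden: int, moe_kernel) -> torch.Tensor:
    # New path: padding is folded into the quantization kernel, and the MoE
    # kernel writes into a pre-allocated buffer of the unpadded hidden size,
    # so no explicit pad or slice appears in the forward pass.
    out = torch.empty_like(x)
    moe_kernel(x, out=out)  # hypothetical out-buffer signature
    return out
```

The old path materializes a padded copy of both the input and the output; the new path avoids both, which is also what makes the downstream all-reduce + norm fusion possible.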
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed 4fd26b1 to 1fdd5ec
Force-pushed 1fdd5ec to 3648f8a
💡 Codex Review
https://github.com/vllm-project/vllm/blob/3648f8ab8e1f75350586bd226d8a55778f7e3ebc/vllm/model_executor/layers/fused_moe/layer.py#L510-L514
Keep MoE config hidden size in sync with MXFP4 padding
Here the MoE config is built with whatever hidden_size was passed in, but the MXFP4 backend now rounds hidden_size up later in Mxfp4MoEMethod.create_weights (e.g., to 256-aligned for SM100 FlashInfer, see vllm/model_executor/layers/quantization/mxfp4.py around lines 298-309). Because moe_config.hidden_dim stays at the unpadded value, any DP+EP run that uses the all2all kernels will size dispatch buffers from the smaller hidden_dim (see maybe_make_prepare_finalize in all2all_utils.py), while the kernel operates on the larger padded hidden size, leading to under-sized buffers and potential memory corruption for models whose hidden size is not already aligned. Please update the config’s hidden_dim after padding or pad before creating the config.
Force-pushed 5c08ae1 to fc9b00f
Hi @elvischenv, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch. For future commits, the installed hooks will run automatically.
Force-pushed fc9b00f to ad3ff99
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed ad3ff99 to 938bf35
@elvischenv could you post the new benchmarking numbers once you have them?
Force-pushed 938bf35 to 02510d0
These are the perf numbers based on main ToT (main vs. this PR), along with the accuracy of the PR:
This pull request has merge conflicts that must be resolved before it can be merged.
```python
elif (
    current_platform.is_rocm()
    or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
    or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
```
I just ran into a situation where the mxfp4 marlin kernels require 256-element padding. Will this PR also address that, or is it premature to remove this function?
tests/entrypoints/openai/responses/test_harmony.py fails if marlin is used for mxfp4.
Or maybe this needs to be incorporated into maybe_roundup_layer_hidden_size?
If you look into `create_weights()` inside vllm/model_executor/layers/quantization/mxfp4.py, you will see logic completely duplicated with this function.
So the current padding logic inside the `FusedMoE` `__init__()` will first call `maybe_roundup_hidden_size()`, which includes a small part of the padding logic for mxfp4. Then it will call `get_quant_method()` and `quant_method.create_weights()`. `create_weights()` will go through the whole padding logic again if it is using `Mxfp4MoEMethod`. cc @robertgshaw2-redhat
Can we keep all the logic here in layer.py instead of having it in two places?
```python
moe_quant_params["intermediate_size_full"] = intermediate_size

self.quant_method.create_weights(layer=self, **moe_quant_params)
# hidden_size may be padded in create_weights
```
Can you point to where this happens?
Answered in the previous comment.
The calling order is: `maybe_roundup_hidden_size()` -> create the MoE config (`self.moe_config: FusedMoEConfig = FusedMoEConfig(...)`) -> `get_quant_method()` -> `quant_method.create_weights()`.
Some padding may happen inside `create_weights()`, so the MoE config needs to be updated afterwards.
The problem is where we should put the padding logic. Currently some of it is in vllm/model_executor/layers/fused_moe/layer.py and some is in vllm/model_executor/layers/quantization/mxfp4.py. And it looks like PR #29008 depends on the logic in layers/fused_moe/layer.py.
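One way to keep the config in sync is to propagate the possibly padded hidden size back after `create_weights()` returns. This is a sketch with stand-in `Fake*` dataclasses, not the actual vLLM types; the exact update site in `layer.py` is an assumption:

```python
from dataclasses import dataclass


@dataclass
class FakeMoEConfig:
    hidden_dim: int


@dataclass
class FakeLayer:
    hidden_size: int


def sync_config_after_create_weights(layer: FakeLayer, moe_config: FakeMoEConfig) -> None:
    # create_weights() may have rounded layer.hidden_size up; propagate the
    # padded value so dispatch buffers (e.g. for the DP+EP all2all path) are
    # sized for the hidden dim the kernel actually operates on.
    if layer.hidden_size != moe_config.hidden_dim:
        moe_config.hidden_dim = layer.hidden_size


cfg = FakeMoEConfig(hidden_dim=2880)
layer = FakeLayer(hidden_size=3072)  # as if create_weights() padded to 256-aligned
sync_config_after_create_weights(layer, cfg)
assert cfg.hidden_dim == 3072
```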
FYI, I have a (WIP) draft to refactor the roundup logic at #34285, moving the kernel-dependent rounding logic to the quant method; fused_moe/layer.py still needs to invoke it to update the sizes.
I think delegating the decision to the quant methods is the correct approach.
Force-pushed 02510d0 to c0a5ab3
Looks good to me, but I would want @mgoin or @robertgshaw2-redhat or @bnell to check the MoE code.
```python
# The padding in the forward pass can be skipped
self.skip_forward_padding = (
    hasattr(self.quant_method, "support_skip_forward_padding")
    and self.quant_method.support_skip_forward_padding
)
```
Can you make this a method on FusedMoEMethodBase instead of an attribute? It could default to False.
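The suggested shape of that change might look like the following sketch. `FusedMoEMethodBase` is the real vLLM base class, but the method name, the `Sketch` subclass, and the backend-string comparison are assumptions for illustration, not the merged code:

```python
class FusedMoEMethodBase:
    def supports_skip_forward_padding(self) -> bool:
        # Default: backends must opt in explicitly.
        return False


class Mxfp4MoEMethodSketch(FusedMoEMethodBase):
    def __init__(self, backend: str):
        self.backend = backend

    def supports_skip_forward_padding(self) -> bool:
        # Only the Blackwell FlashInfer MXFP4/MXFP8 path can skip padding.
        return self.backend == "SM100_FI_MXFP4_MXFP8_TRTLLM"


# Call sites then query the method instead of probing for an attribute:
#   self.skip_forward_padding = self.quant_method.supports_skip_forward_padding()
```

A method with a `False` default avoids the `hasattr` probe and gives every backend a single, overridable extension point.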
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Force-pushed c0a5ab3 to d195774
I am getting an error where the fusion pass triggers an assertion error when running …
[Perf] Eliminate padding and slicing op for GPT-OSS with Flashinfer MXFP4 MXFP8 MoE (vllm-project#30647) Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Purpose
Cleaned up the padding logic: for mxfp4 quant, the padded hidden size is calculated in `create_weights()`; the `maybe_roundup_hidden_size()` in vllm/model_executor/layers/fused_moe/layer.py seems like a dup.

Test Plan && Test Result (GPT-OSS-120b TP8)
Accuracy
PR:
main:
Kernel
PR:
main:
Perf (GPT-OSS-120b TP8 con8)
PR: 5% E2E improvement
main: