[Kernel] Support Flashinfer trtllm fused MoE non-gated FP8 & NVFP4 #33506
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request adds support for non-gated Mixture of Experts (MoE) models using FlashInfer with FP8 and NVFP4 quantization. The changes are comprehensive, including updates to tests, support checks, activation handling, and weight preparation logic. Overall, the changes are well-aligned with the PR's objective. However, I've identified a critical bug in the FP4 MoE weight preparation logic that incorrectly calculates shapes for gated activations, which could lead to runtime errors or incorrect results. I have provided specific suggestions to address this issue.
```diff
 gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
-    num_experts, hidden_size, intermediate_size // 2
+    num_experts, hidden_size, actual_intermediate_size // 2
 )  # packed fp4
 gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
     torch.float8_e4m3fn
-).reshape(num_experts, hidden_size, intermediate_size // 16)  # fp8 scaling factors
+).reshape(
+    num_experts, hidden_size, actual_intermediate_size // 16
+)  # fp8 scaling factors
```
The calculation for gemm2_weights_fp4 and gemm2_scales_linear_fp4 shapes is incorrect for gated activations. actual_intermediate_size is derived from w13's shape, which differs for gated and non-gated models. However, the down-projection (gemm2) should have a consistent intermediate dimension. This change introduces mlp_ffn_dim to correctly calculate the shapes for both gated and non-gated cases, fixing a bug for gated activations.
Suggested change:

```diff
+mlp_ffn_dim = intermediate_size if is_gated_activation else 2 * intermediate_size
 gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
-    num_experts, hidden_size, actual_intermediate_size // 2
+    num_experts, hidden_size, mlp_ffn_dim // 2
 )  # packed fp4
 gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
     torch.float8_e4m3fn
 ).reshape(
-    num_experts, hidden_size, actual_intermediate_size // 16
+    num_experts, hidden_size, mlp_ffn_dim // 16
 )  # fp8 scaling factors
```
```diff
 gemm2_scales_fp4_shuffled = (
     torch.stack(gemm2_scales_fp4_shuffled)
     .view(torch.float8_e4m3fn)
-    .reshape(num_experts, hidden_size, intermediate_size // 16)
+    .reshape(num_experts, hidden_size, actual_intermediate_size // 16)
 )
```
Similar to the previous comment, the reshape dimension for gemm2_scales_fp4_shuffled is incorrect for gated activations. It should use the mlp_ffn_dim variable (defined in the suggested fix for the previous issue) to ensure the correct shape.
Suggested change:

```diff
 gemm2_scales_fp4_shuffled = (
     torch.stack(gemm2_scales_fp4_shuffled)
     .view(torch.float8_e4m3fn)
-    .reshape(num_experts, hidden_size, actual_intermediate_size // 16)
+    .reshape(num_experts, hidden_size, mlp_ffn_dim // 16)
 )
```
```python
use_routing_scales_on_input: bool,
routing_method_type: int,
routed_scaling_factor: float = 1.0,
activation_type: int = 3,  # Swiglu
```
Let's remove the default value to always be explicit
```python
def is_gated_activation(activation: str) -> bool:
    return not activation.lower().endswith("_no_mul")


def activation_str_to_int(activation: str) -> int:
    from flashinfer.fused_moe.core import ActivationType

    # silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
    ACTIVATION_TO_FI_ACTIVATION = {
        "silu_no_mul": ActivationType.Silu,
        "gelu_no_mul": ActivationType.Gelu,
        "silu": ActivationType.Swiglu,
        "gelu": ActivationType.Geglu,
        "relu2_no_mul": ActivationType.Relu2,
    }
    return ACTIVATION_TO_FI_ACTIVATION[activation.lower()].value
```
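For context, a rough sketch of the semantic difference between the gated and `_no_mul` variants. This is pure-Python reference math for illustration only, not the FlashInfer kernels; the function names are hypothetical.

```python
import math

def swiglu(gate: float, up: float) -> float:
    # Gated: silu(gate) * up — the fused w13 weight holds two projections.
    return gate / (1.0 + math.exp(-gate)) * up

def relu2_no_mul(x: float) -> float:
    # Non-gated: squared ReLU over a single projection (Nemotron-style MLP).
    return max(x, 0.0) ** 2
```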
Would be nice if we could have this use the MoEActivation refactor, hopefully landing soon #33843
Nice, I definitely agree that refactor is necessary!
Regarding the order, I think it depends on when the refactor PR is merged.
```python
# for the gate-up proj. Pad the weights to respect this.
is_gated = is_gated_activation(layer.activation)
if not block_quant:
    min_alignment = 16 if is_gated else 128
```
Is there some justification for 128 we can reference?
That's what the current Flashinfer kernels require, otherwise it doesn't find a suitable kernel.
For example, Nemotron 3 Nano TP=1 would fail unless it's set to 128 here:
```
(EngineCore_DP0 pid=3184059)   File "/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py", line 2258, in trtllm_fp8_per_tensor_scale_moe
(EngineCore_DP0 pid=3184059)     return get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe(
(EngineCore_DP0 pid=3184059)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=3184059)   File "/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py", line 1488, in trtllm_fp8_per_tensor_scale_moe_op
(EngineCore_DP0 pid=3184059)     result = moe_op.trtllm_fp8_per_tensor_scale_moe(
(EngineCore_DP0 pid=3184059)              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=3184059)   File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
(EngineCore_DP0 pid=3184059) RuntimeError: Error in function 'getValidConfigIndices' at /usr/local/lib/python3.12/dist-packages/flashinfer/data/csrc/trtllm_batched_gemm_runner.cu:416: No valid config found for the given problem shape
```
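To make the alignment discussion concrete, a hedged sketch of the padding computation. The 16/128 values are taken from this thread (an empirical FlashInfer kernel requirement, not a documented contract), and `pad_intermediate_size` is a hypothetical helper.

```python
def pad_intermediate_size(intermediate_size: int, is_gated: bool) -> int:
    # Per this thread, current FlashInfer trtllm kernels fail to find a
    # valid config unless the non-gated intermediate size is a multiple
    # of 128; gated layouts only need 16. Round up to that multiple.
    min_alignment = 16 if is_gated else 128
    return -(-intermediate_size // min_alignment) * min_alignment
```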
```python
block_quant = (
    hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
)
```
If we are in NVFP4, why would we expect weight_block_size in any case?
It was copied from the FP8 flow; removing it.
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
…ted MoE and rel2_no_mul activation, support DeepSeek routing in FP8 per-tensor, fix prepare_static_weights_for_trtllm_fp4_moe
…m/utils/flashinfer.py
…gated, otherwise use min_alignment=16
Hi @amitz-nv, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
mgoin left a comment:

LGTM, nice work! Will manually trigger MoE refactor tests.
```diff
 # time in the oracle rather than here.
-assert layer.activation == MoEActivation.SILU, (
-    f"Expected 'silu' activation but got {layer.activation}"
+SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
+assert layer.activation in SUPPORTED_ACTIVATIONS, (
+    f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
+    f"TRTLLM FP4 MoE, {layer.activation} found instead."
 )
 assert not layer.renormalize
```
Note: we need to update the compressed-tensors side too; can do in a follow-up PR.
Purpose
Add support for Flashinfer trtllm fused MoE non-gated activation for FP8 and for NVFP4.
Changes:

- Add an `activation_type` argument to FlashInfer trtllm fused MoE FP8 and NVFP4.
- Use `min_alignment=128` (padding) for non-gated activation in FlashInfer trtllm fused MoE.
- Refactor `tests/kernels/moe/test_flashinfer.py` and expand it to also test the `relu2_no_mul` activation for both cutlass and trtllm kernels.
- Run `lm_eval` on Nemotron 3 Nano FP8.

Outputs:
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist

- Update `supported_models.md` and `examples` for a new model.