diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index 9c31d9325962..d524b5667047 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -71,7 +71,8 @@ def quant_fp8_per_tensor_batches(a): for i in range(num_batches): a_fp8, a_global_sf = input_to_float8(a[i]) - a_global_sf = 1.0 / a_global_sf + if a_global_sf.numel() == 1: + a_global_sf = a_global_sf.view(1, 1) a_quant.append(a_fp8) a_scales.append(a_global_sf) @@ -81,6 +82,20 @@ def quant_fp8_per_tensor_batches(a): return result_a_quant, result_a_scales +def check_accuracy(ref_output, actual_output, atol=0.1, rtol=0.85, percent=0.925): + close = torch.isclose(ref_output, actual_output, atol=atol, rtol=rtol) + match_ratio = close.float().mean() + assert match_ratio >= percent, ( + f"Match ratio {match_ratio:.4f} is below the threshold {percent:.4f}" + ) + + mismatch_percent = 1.0 - match_ratio.item() + assert mismatch_percent <= 1 - percent, ( + f"Mismatch percentage {mismatch_percent:.4f} is above the threshold " + f"{1 - percent:.4f}" + ) + + @dataclass class TestData: hidden_states: torch.Tensor @@ -104,14 +119,16 @@ def make_moe_tensors_8bit( is_gated = activation.is_gated hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 - w13 = torch.randn( - (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16 + w13 = ( + torch.randn( + (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16 + ) + / 10 ) - w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) + w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10 # Scale to fp8 _, a1_scale = input_to_float8(hidden_states) - a1_scale = 1.0 / a1_scale a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32) w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13) w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2) @@ -124,14 +141,16 @@ def make_moe_tensors_8bit( layer.w2_input_scale = a2_scale layer.w13_weight_scale = w13_weight_scale layer.w2_weight_scale = w2_weight_scale + layer.activation = activation # Setup dummy config. layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel() # flashinfer expects swapped rows for w13 - layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) + if is_gated: + layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) if is_trtllm: rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( - layer.w13_weight, layer.w2_weight + layer.w13_weight, layer.w2_weight, is_gated ) register_scales_for_trtllm_fp8_per_tensor_moe( layer, @@ -162,12 +181,14 @@ def make_moe_tensors_8bit( @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]) def test_flashinfer_per_tensor_moe_fp8_no_graph( m: int, n: int, k: int, e: int, topk: int, + activation: MoEActivation, monkeypatch, ): if not current_platform.has_device_capability(100): @@ -175,7 +196,9 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( set_random_seed(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): - td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True) + td = TestData.make_moe_tensors_8bit( + m, k, n, e, is_trtllm=True, activation=activation + ) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) topk_weights, topk_ids = Llama4MoE.custom_routing_function( @@ -200,7 +223,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - activation=MoEActivation.SILU, + activation=activation, global_num_experts=e, expert_map=None, apply_router_weight_on_input=True, @@ -219,7 +242,13 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( apply_router_weight_on_input=True, ) - torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2) + check_accuracy( + ref_output=output, + actual_output=flashinfer_output, + atol=0.1, + rtol=0.85, + percent=0.925, + ) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @@ -320,8 +349,13 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig: expert_map=None, apply_router_weight_on_input=True, ) - torch.testing.assert_close( - output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2 + + check_accuracy( + ref_output=output, + actual_output=flashinfer_cutlass_output, + atol=0.1, + rtol=0.85, + percent=0.925, ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index a50ad6722078..b2d571dd8fff 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -35,8 +35,8 @@ def _supports_current_device() -> bool: def _supports_no_act_and_mul() -> bool: - """Does not support non-gated MoE (i.e. Nanotron-Mini).""" - return False + """Supports non-gated MoE.""" + return True def _supports_quant_scheme( @@ -52,8 +52,7 @@ def _supports_quant_scheme( def _supports_activation(activation: MoEActivation) -> bool: - """Supports silu activation only.""" - return activation == MoEActivation.SILU + return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] def _supports_routing_method( @@ -74,6 +73,7 @@ def _supports_routing_method( elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): # NOTE(dbari): as above, potentially allow others here. return routing_method in [ + RoutingMethodType.DeepSeekV3, RoutingMethodType.Llama4, RoutingMethodType.Renormalize, RoutingMethodType.RenormalizeNaive, @@ -291,6 +291,7 @@ def fi_trtllm_fp8_per_tensor_moe( local_num_experts: int, use_routing_scales_on_input: bool, routing_method_type: int, + activation_type: int, routed_scaling_factor: float = 1.0, ) -> torch.Tensor: num_expert_group = num_expert_group if num_expert_group is not None else 0 @@ -326,9 +327,9 @@ def fi_trtllm_fp8_per_tensor_moe( routed_scaling_factor=routed_scaling_factor, use_routing_scales_on_input=use_routing_scales_on_input, routing_method_type=routing_method_type, - # TODO: Required for flashinfer==0.6.3, remove with update + # TODO: enum type Required for flashinfer==0.6.3, remove with update # https://github.com/flashinfer-ai/flashinfer/pull/2508 - activation_type=ActivationType.Swiglu, + activation_type=ActivationType(activation_type), ) @@ -351,6 +352,7 @@ def fi_trtllm_fp8_per_tensor_moe_fake( local_num_experts: int, use_routing_scales_on_input: bool, routing_method_type: int, + activation_type: int, routed_scaling_factor: float = 1.0, ) -> torch.Tensor: return torch.empty_like(hidden_states) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index e0322a46f01a..9af815ee9e9a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -937,10 +937,11 @@ def apply_monolithic( ) # TODO(rob): this validation should happen at kernel selection # time in the oracle rather than here. - assert layer.activation == MoEActivation.SILU, ( - f"Expected 'silu' activation but got {layer.activation}" + SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + assert layer.activation in SUPPORTED_ACTIVATIONS, ( + f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer " + f"TRTLLM FP4 MoE, {layer.activation} found instead." ) - assert not layer.renormalize return apply_fi_trtllm_fp8_per_tensor_moe( layer=layer, hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 9d9fd31ad09d..ea84406ba90f 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -15,6 +15,10 @@ FusedMoEParallelConfig, RoutingMethodType, ) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + activation_to_flashinfer_int, + align_fp4_moe_weights_for_fi, +) from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( swizzle_blockscale, ) @@ -50,8 +54,8 @@ def _supports_current_device() -> bool: def _supports_no_act_and_mul() -> bool: - """Does not support non-gated MoE (i.e. Nemotron-Nano).""" - return False + """Supports non-gated MoE.""" + return True def _supports_quant_scheme( @@ -66,8 +70,7 @@ def _supports_quant_scheme( def _supports_activation(activation: MoEActivation) -> bool: - """Supports silu activation only.""" - return activation in [MoEActivation.SILU] + return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] def _supports_routing_method( @@ -150,6 +153,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( hidden_size, intermediate_size, num_experts, + is_gated_activation: bool, ): from flashinfer import nvfp4_block_scale_interleave from flashinfer.fused_moe.core import ( @@ -160,15 +164,18 @@ def prepare_static_weights_for_trtllm_fp4_moe( _cache_permute_indices: dict[torch.Size, torch.Tensor] = {} """Prepare quantized weights for kernel (done offline with weights).""" epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + gemm1_intermediate_size = ( + 2 * intermediate_size if is_gated_activation else intermediate_size + ) # Convert quantized weights to proper formats gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape( - num_experts, 2 * intermediate_size, hidden_size // 2 + num_experts, gemm1_intermediate_size, hidden_size // 2 ) # packed fp4 gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( torch.float8_e4m3fn ).reshape( - num_experts, 2 * intermediate_size, hidden_size // 16 + num_experts, gemm1_intermediate_size, hidden_size // 16 ) # fp8 scaling factors gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape( @@ -191,6 +198,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( _cache_permute_indices, gemm1_weights_fp4[i].view(torch.uint8), epilogue_tile_m, + is_gated_act_gemm=is_gated_activation, ) gemm1_weights_fp4_shuffled.append( gemm1_weights_fp4[i] @@ -203,6 +211,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( gemm1_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m, num_elts_per_sf=16, + is_gated_act_gemm=is_gated_activation, ) gemm1_scales_fp4_shuffled.append( nvfp4_block_scale_interleave( @@ -246,7 +255,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( gemm1_scales_fp4_shuffled = ( torch.stack(gemm1_scales_fp4_shuffled) .view(torch.float8_e4m3fn) - .reshape(num_experts, 2 * intermediate_size, hidden_size // 16) + .reshape(num_experts, gemm1_intermediate_size, hidden_size // 16) ) gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled) @@ -297,10 +306,10 @@ def flashinfer_trtllm_fp4_moe( from vllm.model_executor.models.llama4 import Llama4MoE - # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404 - assert activation == MoEActivation.SILU, ( - "Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. " - f"{activation} found instead." + SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + assert activation in SUPPORTED_ACTIVATIONS, ( + f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer " + f"TRTLLM FP4 MoE, {activation} found instead." ) # Quantize input to FP4 @@ -325,6 +334,9 @@ def flashinfer_trtllm_fp4_moe( else router_logits ) + # Determine activation type + activation_type = activation_to_flashinfer_int(layer.activation) + # Call TRT-LLM FP4 block-scale MoE kernel out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( routing_logits=router_logits, @@ -355,6 +367,7 @@ def flashinfer_trtllm_fp4_moe( routed_scaling_factor=None, routing_method_type=routing_method_type, do_finalize=True, + activation_type=activation_type, )[0] return out @@ -479,10 +492,16 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( ] # Reorder [w1, w3] to [w3, w1] for FI NVFP4 MoE kernels. - if is_act_and_mul and backend in [ - NvFp4MoeBackend.FLASHINFER_CUTLASS, - NvFp4MoeBackend.FLASHINFER_TRTLLM, - ]: + is_gated = layer.activation.is_gated + if ( + is_gated + and is_act_and_mul + and backend + in [ + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.FLASHINFER_TRTLLM, + ] + ): w13, w13_scale = reorder_w1w3_to_w3w1(w13, w13_scale) # For some FI kernels, the input scales are shared by all experts. @@ -495,19 +514,32 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( # Shuffle weights and scales for FI TRTLLM NVFP4 MoE kernels. if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + # Align weights for FI NVFP4 MoE kernels. + min_alignment = 16 if is_gated else 128 + w13, w13_scale, w2, w2_scale, padded_intermediate = ( + align_fp4_moe_weights_for_fi( + w13, w13_scale, w2, w2_scale, is_act_and_mul, min_alignment + ) + ) + layer.intermediate_size_per_partition = padded_intermediate + w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe( w13, w2, w13_scale, w2_scale, - w2.size(-2), # hidden_size - w13.size(-2) // 2, # intermediate_size - w13.size(0), # num_experts + hidden_size=w2.size(-2), + intermediate_size=w13.size(-2) // 2 if is_gated else w13.size(-2), + num_experts=w13.size(0), + is_gated_activation=is_gated, ) # We do not need to make this a parameter, because # it is not used during the weight (re)-loading process. - layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale + if is_gated: + layer.g1_scale_c = a13_scale * w13_scale_2 / a2_scale + else: + layer.g1_scale_c = torch.ones_like(a13_scale) / a2_scale layer.a1_gscale = 1.0 / a13_scale layer.g1_alphas = a13_scale * w13_scale_2 layer.g2_alphas = a2_scale * w2_scale_2 diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 56c90aa86426..42fae9ee9327 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -6,6 +6,7 @@ from vllm import envs from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.platforms import current_platform from vllm.utils.math_utils import round_up @@ -18,6 +19,20 @@ class FlashinferMoeBackend(Enum): CUTEDSL = "CUTEDSL" +def activation_to_flashinfer_int(activation: MoEActivation) -> int: + from flashinfer.fused_moe.core import ActivationType + + # silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively + ACTIVATION_TO_FI_ACTIVATION = { + MoEActivation.SILU_NO_MUL: ActivationType.Silu, + MoEActivation.GELU_NO_MUL: ActivationType.Gelu, + MoEActivation.SILU: ActivationType.Swiglu, + MoEActivation.GELU: ActivationType.Geglu, + MoEActivation.RELU2_NO_MUL: ActivationType.Relu2, + } + return ACTIVATION_TO_FI_ACTIVATION[activation].value + + def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: return ( x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape) @@ -25,7 +40,7 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( - gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor + gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor, is_gated_activation: bool ): """Shuffle weights for for FI TRT-LLM Format""" from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a @@ -40,6 +55,8 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( for i in range(num_experts): gemm1_weights_fp8_interleaved.append( reorder_rows_for_gated_act_gemm(gemm1_weights[i]) + if is_gated_activation + else gemm1_weights[i] ) # Stack weights and scales for all experts @@ -86,7 +103,13 @@ def register_scales_for_trtllm_fp8_per_tensor_moe( ) layer.w2_input_scale_inv = 1.0 / w2_input_scale layer.output1_scales_gate_scalar = g1_alphas - layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv + + if layer.activation.is_gated: + layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv + else: + layer.output1_scales_scalar = ( + torch.ones_like(g1_alphas) * layer.w2_input_scale_inv + ) layer.output2_scales_scalar = g2_alphas @@ -125,6 +148,7 @@ def apply_fi_trtllm_fp8_per_tensor_moe( assert layer.custom_routing_function is None, ( "Custom routing function is only supported for Llama4" ) + activation_type = activation_to_flashinfer_int(layer.activation) return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe( routing_logits=router_logits, @@ -145,6 +169,7 @@ def apply_fi_trtllm_fp8_per_tensor_moe( local_num_experts=layer.local_num_experts, use_routing_scales_on_input=apply_router_weight_on_input, routing_method_type=layer.routing_method_type, + activation_type=activation_type, ) @@ -274,8 +299,64 @@ def convert_moe_weights_to_flashinfer_trtllm_block_layout( return w13_weights_shuffled_tensor, w2_weights_shuffled_tensor +def align_fp4_moe_weights_for_fi( + w13: torch.Tensor, + w13_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + is_act_and_mul: bool, + min_alignment: int = 16, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Pad intermediate size so FlashInfer kernels' alignment constraints hold. + + Some FlashInfer FP4 MoE kernels require the intermediate size + used for GEMM to be divisible by a small alignment value. When this is + not satisfied (e.g. with certain tensor-parallel sizes), we pad the + gate/up and down projection weights along the intermediate dim. + """ + + # Current local intermediate size (per partition) is the K dimension of + # the down projection. + num_experts, hidden_size, intermediate = w2.shape + intermediate *= 2 # because of packed FP4 + + padded_intermediate = round_up(intermediate, min_alignment) + + if padded_intermediate == intermediate: + return w13, w13_scale, w2, w2_scale, intermediate + + logger.info_once( + "Padding intermediate size from %d to %d for up/down projection weights.", + intermediate, + padded_intermediate, + scope="local", + ) + + up_mult = 2 if is_act_and_mul else 1 + padded_gate_up_dim = up_mult * padded_intermediate + + # Pad w13 and w2 along its intermediate dimension. + padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size // 2)) + padded_w13[:, : w13.shape[1], :] = w13 + + padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate // 2)) + padded_w2[:, :, : w2.shape[2]] = w2 + + padded_w13_scale = w13_scale.new_zeros( + (num_experts, padded_gate_up_dim, hidden_size // 16) + ) + padded_w13_scale[:, : w13_scale.shape[1], :] = w13_scale + + padded_w2_scale = w2_scale.new_zeros( + (num_experts, hidden_size, padded_intermediate // 16) + ) + padded_w2_scale[:, :, : w2_scale.shape[2]] = w2_scale + + return padded_w13, padded_w13_scale, padded_w2, padded_w2_scale, padded_intermediate + + def align_fp8_moe_weights_for_fi( - w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool + w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool, min_alignment: int = 16 ) -> tuple[torch.Tensor, torch.Tensor, int]: """Pad intermediate size so FlashInfer kernels' alignment constraints hold. @@ -289,7 +370,6 @@ def align_fp8_moe_weights_for_fi( # the down projection. num_experts, hidden_size, intermediate = w2.shape - min_alignment = 16 padded_intermediate = round_up(intermediate, min_alignment) if padded_intermediate == intermediate: @@ -342,11 +422,14 @@ def prepare_fp8_moe_layer_for_fi( # Some FI MoE kernels require internal alignment of 16 # for the gate-up proj. Pad the weights to respect this. + is_gated = layer.activation.is_gated if not block_quant: + min_alignment = 16 if is_gated else 128 w13, w2, new_intermediate = align_fp8_moe_weights_for_fi( w13, w2, layer.moe_config.is_act_and_mul, + min_alignment, ) layer.intermediate_size_per_partition = new_intermediate @@ -363,7 +446,7 @@ def prepare_fp8_moe_layer_for_fi( assert w13_input_scale is not None assert w2_input_scale is not None - rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2) + rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated) register_scales_for_trtllm_fp8_per_tensor_moe( layer, w13_scale=w13_scale,