diff --git a/tests/kernels/moe/test_flashinfer_b12x_moe.py b/tests/kernels/moe/test_flashinfer_b12x_moe.py index ec0a9594fe12..13ea924ac282 100644 --- a/tests/kernels/moe/test_flashinfer_b12x_moe.py +++ b/tests/kernels/moe/test_flashinfer_b12x_moe.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from types import SimpleNamespace + import pytest import torch @@ -8,8 +10,7 @@ if not current_platform.is_device_capability_family(120): pytest.skip( - reason="FlashInfer CuteDSL SM12x MoE requires SM120 " - "(RTX Pro 6000 / DGX Spark).", + reason="FlashInfer B12x MoE requires SM120 (RTX Pro 6000 / DGX Spark).", allow_module_level=True, ) @@ -18,8 +19,8 @@ if not has_flashinfer_b12x_moe(): pytest.skip( reason=( - "FlashInfer cute_dsl_fused_moe_nvfp4 / convert_sf_to_mma_layout " - "not available in installed FlashInfer (needs PRs #3051 and #3066)." + "FlashInfer B12xMoEWrapper not available in installed " + "FlashInfer (needs PR #3080)." ), allow_module_level=True, ) @@ -40,7 +41,6 @@ from vllm.model_executor.layers.fused_moe.experts.flashinfer_b12x_moe import ( FlashInferB12xExperts, ) -from vllm.utils.flashinfer import flashinfer_convert_sf_to_mma_layout from vllm.utils.torch_utils import set_random_seed # Dimensions chosen to satisfy FP4 alignment requirements (k multiple of 256, @@ -59,7 +59,7 @@ def _reorder_gate_up_to_up_gate( ) -> tuple[torch.Tensor, torch.Tensor]: """Swap gate and up-projection halves along dim=1 to [up, gate] order. - The SM12x kernel expects weights in [up (w3), gate (w1)] order while the + The B12x kernel expects weights in [up (w3), gate (w1)] order while the BF16 reference uses [gate (w1), up (w3)]. This replicates the reordering done at model-load time by ``prepare_nvfp4_moe_layer_for_fi_or_cutlass``. """ @@ -70,6 +70,22 @@ def _reorder_gate_up_to_up_gate( ) +def _process_b12x_weights( + experts: FlashInferB12xExperts, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_scale_2: torch.Tensor, + w2_scale_2: torch.Tensor, +) -> None: + layer = SimpleNamespace( + w13_weight_scale=w1_scale, + w13_weight_scale_2=w1_scale_2, + w2_weight_scale=w2_scale, + w2_weight_scale_2=w2_scale_2, + ) + experts.process_weights_after_loading(layer) + + @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", [8, 16]) @pytest.mark.parametrize("topk", [1, 2, 4]) @@ -174,22 +190,12 @@ def test_flashinfer_b12x_moe( moe_config=moe_config, quant_config=quant_config, ) - # In production, process_weights_after_loading computes these after - # normalizing block scales. In the test the scales are already in final - # form (global_scale=1.0), so we compute the MMA layouts directly. - num_experts_w1, m1, k1_sf = w1_blockscale.shape - experts.w1_sf_mma = flashinfer_convert_sf_to_mma_layout( - w1_blockscale.reshape(num_experts_w1 * m1, k1_sf), - m=m1, - k=k1_sf * 16, - num_groups=num_experts_w1, - ) - num_experts_w2, m2, k2_sf = w2_blockscale.shape - experts.w2_sf_mma = flashinfer_convert_sf_to_mma_layout( - w2_blockscale.reshape(num_experts_w2 * m2, k2_sf), - m=m2, - k=k2_sf * 16, - num_groups=num_experts_w2, + _process_b12x_weights( + experts, + w1_blockscale, + w2_blockscale, + ones_e, + ones_e, ) kernel = mk.FusedMoEKernel( @@ -225,5 +231,135 @@ def test_flashinfer_b12x_moe( torch.testing.assert_close(sm12x_output, torch_output, atol=2e-1, rtol=2e-1) +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", [8, 16]) +@pytest.mark.parametrize("topk", [1, 2, 4]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.inference_mode() +def test_flashinfer_b12x_moe_relu2( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + workspace_init, +): + """Test FlashInferB12xExperts with ReLU2 (non-gated) activation. + + ReLU2 is used by Nemotron-H style models. Unlike the gated SiLU + path, w1 has shape [E, N, K] (not [E, 2N, K]) and the activation + is relu(x)^2 without a gate/up split. + """ + set_random_seed(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + + # Non-gated: w1 shape is (e, n, k), not (e, 2n, k). + w1_bf16 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 15 + w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 15 + + gs = torch.ones(1, device="cuda", dtype=torch.float32) + sf_vec_size = 16 + + # W1: no gate/up reordering for non-gated. + w1_flat = w1_bf16.reshape(e * n, k) + w1_q_flat, w1_sf_flat = fp4_quantize( + w1_flat, + global_scale=gs, + sf_vec_size=sf_vec_size, + is_sf_swizzled_layout=True, + ) + w1_q = w1_q_flat.view(e, n, k // 2) + w1_blockscale = w1_sf_flat.view(e, n, w1_sf_flat.shape[1]) + + w2_flat = w2_bf16.reshape(e * k, n) + w2_q_flat, w2_sf_flat = fp4_quantize( + w2_flat, + global_scale=gs, + sf_vec_size=sf_vec_size, + is_sf_swizzled_layout=True, + ) + w2_q = w2_q_flat.view(e, k, n // 2) + w2_blockscale = w2_sf_flat.view(e, k, w2_sf_flat.shape[1]) + + ones_e = torch.ones(e, device="cuda", dtype=torch.float32) + + quant_config = nvfp4_moe_quant_config( + g1_alphas=ones_e, + g2_alphas=ones_e, + a1_gscale=ones_e, + a2_gscale=ones_e, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + ) + + moe_config = make_dummy_moe_config( + num_experts=e, + experts_per_token=topk, + hidden_dim=k, + intermediate_size_per_partition=n, + in_dtype=dtype, + activation=MoEActivation.RELU2_NO_MUL, + is_act_and_mul=False, + ) + + experts = FlashInferB12xExperts( + moe_config=moe_config, + quant_config=quant_config, + ) + _process_b12x_weights( + experts, + w1_blockscale, + w2_blockscale, + ones_e, + ones_e, + ) + + kernel = mk.FusedMoEKernel( + maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), + experts, + inplace=False, + ) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) + + b12x_output = kernel.apply( + hidden_states=a, + w1=w1_q, + w2=w2_q, + topk_weights=topk_weights, + topk_ids=topk_ids, + global_num_experts=e, + activation=MoEActivation.RELU2_NO_MUL, + apply_router_weight_on_input=False, + expert_map=None, + ) + + torch_output = torch_moe( + a, + w1_bf16, + w2_bf16, + score, + topk, + activation=MoEActivation.RELU2_NO_MUL, + ) + + torch.testing.assert_close( + b12x_output, + torch_output, + atol=2e-1, + rtol=2e-1, + ) + + if __name__ == "__main__": test_flashinfer_b12x_moe(16, 128, 256, 8, 2, torch.bfloat16) diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index acb2c21b3896..9141586c0d08 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -53,6 +53,8 @@ def make_dummy_moe_config( hidden_dim: int = 1, intermediate_size_per_partition: int = 1, in_dtype: torch.dtype = torch.bfloat16, + activation: MoEActivation = MoEActivation.SILU, + is_act_and_mul: bool = True, ) -> FusedMoEConfig: """ This is a dummy config for the mk constructor interface @@ -69,7 +71,8 @@ def make_dummy_moe_config( num_local_experts=num_experts, num_logical_experts=num_experts, moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), - activation=MoEActivation.SILU, + activation=activation, + is_act_and_mul=is_act_and_mul, in_dtype=in_dtype, device="cuda", routing_method=RoutingMethodType.TopK, diff --git a/vllm/envs.py b/vllm/envs.py index 047141a9c0e2..5adf641e7599 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -188,6 +188,7 @@ VLLM_FLASHINFER_AUTOTUNE_CACHE_DIR: str | None = None VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto" VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024 + VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD: int = 0 VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False @@ -1499,6 +1500,13 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int( os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024)) ), + # When >0 and num_tokens >= threshold, the B12x SM12x MoE wrapper routes + # to cutlass_fused_moe (prefill) instead of the b12x kernels (decode). + # 0 (default) keeps pure b12x dispatch. Requires a FlashInfer build that + # exposes the `cutlass_prefill_threshold` kwarg on B12xMoEWrapper. + "VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD": lambda: int( + os.getenv("VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD", "0") + ), # Control the maximum number of tokens per expert supported by the # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for # the blockscale tensor of activations NVFP4 Quantization. diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py index 6481434f2e78..31d8d06ef1de 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import inspect + import torch +import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( @@ -20,7 +23,6 @@ ) from vllm.platforms import current_platform from vllm.utils.flashinfer import ( - flashinfer_b12x_fused_moe, flashinfer_convert_sf_to_mma_layout, has_flashinfer_b12x_moe, ) @@ -42,6 +44,11 @@ class FlashInferB12xExperts(mk.FusedMoEExpertsModular): Only NVFP4 (kNvfp4Static/kNvfp4Dynamic) quantization is supported. """ + _ACTIVATION_MAP: dict[MoEActivation, str] = { + MoEActivation.SILU: "silu", + MoEActivation.RELU2_NO_MUL: "relu2", + } + def __init__( self, moe_config: FusedMoEConfig, @@ -55,7 +62,117 @@ def __init__( self.num_local_experts = moe_config.num_local_experts self.ep_rank = moe_config.moe_parallel_config.ep_rank + # Shape params for B12xMoEWrapper construction. + self.global_num_experts = moe_config.num_experts + self.topk = moe_config.experts_per_token + self.hidden_dim = moe_config.hidden_dim + self.intermediate_size_per_partition = ( + moe_config.intermediate_size_per_partition + ) + self.max_num_tokens = moe_config.max_num_tokens + self.local_expert_offset = self.ep_rank * self.num_local_experts + + activation = moe_config.activation + if activation not in self._ACTIVATION_MAP: + raise ValueError( + f"FlashInferB12xExperts does not support " + f"activation {activation!r}. " + f"Supported: {list(self._ACTIVATION_MAP.keys())}" + ) + self._activation_str = self._ACTIVATION_MAP[activation] + + # Hybrid CUTLASS-prefill / B12x-decode dispatch. When > 0, the + # wrapper routes batches with num_tokens >= threshold through + # cutlass_fused_moe; see register_cutlass_prefill_weights() in + # _ensure_wrapper. Requires a FlashInfer build exposing the + # `cutlass_prefill_threshold` kwarg on B12xMoEWrapper. + self.cutlass_prefill_threshold = ( + envs.VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD + ) + + # Lazily created on first apply() call. + self._wrapper: Any | None = None + self._cutlass_registered: bool = False + self.w1_sf_mma: torch.Tensor | None = None + self.w2_sf_mma: torch.Tensor | None = None + + # CUTLASS-format scales saved before the in-place B12x rewrite in + # process_weights_after_loading. Only populated when + # cutlass_prefill_threshold > 0. + self._cutlass_w13_scale: torch.Tensor | None = None + self._cutlass_w2_scale: torch.Tensor | None = None + self._cutlass_a1_gscale: torch.Tensor | None = None + self._cutlass_a2_gscale: torch.Tensor | None = None + self._cutlass_g1_alphas: torch.Tensor | None = None + self._cutlass_g2_alphas: torch.Tensor | None = None + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # When hybrid CUTLASS prefill is enabled, save copies of the + # CUTLASS-format scales BEFORE the in-place B12x rewrite below + # destroys them. The FP4 weight bytes themselves are reusable — + # prepare_nvfp4_moe_layer_for_fi_or_cutlass produces the same + # [w3, w1] reorder + swizzled SF for both FLASHINFER_CUTLASS and + # FLASHINFER_B12X — so we only need to clone the scales. + # + # g_alphas: B12x leaves g1_alphas = 1/w_gs (does NOT fold + # a_input_scale). CUTLASS wants 1/(a_gs * w_gs) = (1/w_gs) / a_gs, + # hence the division. + # + # The clones are registered as nn.Parameter on the layer so + # FusedMoE.get_expert_weights picks them up and EPLB rearranges + # them in lockstep with the live b12x scales. + if self.cutlass_prefill_threshold > 0: + assert layer.w13_weight_scale.dtype == torch.float8_e4m3fn, ( + "Expected swizzled FP8 SF before B12x rewrite, got " + f"{layer.w13_weight_scale.dtype}" + ) + cutlass_w13_scale = layer.w13_weight_scale.clone() + cutlass_w2_scale = layer.w2_weight_scale.clone() + cutlass_a1_gscale = self.a1_gscale.clone() + cutlass_a2_gscale = self.a2_gscale.clone() + cutlass_g1_alphas = ( + self.g1_alphas.float() / self.a1_gscale + ).contiguous() + cutlass_g2_alphas = ( + self.g2_alphas.float() / self.a2_gscale + ).contiguous() + + layer.register_parameter( + "w13_cutlass_weight_scale", + torch.nn.Parameter(cutlass_w13_scale, requires_grad=False), + ) + layer.register_parameter( + "w2_cutlass_weight_scale", + torch.nn.Parameter(cutlass_w2_scale, requires_grad=False), + ) + layer.register_parameter( + "w13_cutlass_a_gscale", + torch.nn.Parameter(cutlass_a1_gscale, requires_grad=False), + ) + layer.register_parameter( + "w2_cutlass_a_gscale", + torch.nn.Parameter(cutlass_a2_gscale, requires_grad=False), + ) + layer.register_parameter( + "w13_cutlass_g_alphas", + torch.nn.Parameter(cutlass_g1_alphas, requires_grad=False), + ) + layer.register_parameter( + "w2_cutlass_g_alphas", + torch.nn.Parameter(cutlass_g2_alphas, requires_grad=False), + ) + + # Hold references on the experts class so _ensure_wrapper can + # build the quant_scales list without re-fetching from layer. + # These alias the registered Parameters' storage, so EPLB + # rearrangement of the parameters is observed here too. + self._cutlass_w13_scale = layer.w13_cutlass_weight_scale.data + self._cutlass_w2_scale = layer.w2_cutlass_weight_scale.data + self._cutlass_a1_gscale = layer.w13_cutlass_a_gscale.data + self._cutlass_a2_gscale = layer.w2_cutlass_a_gscale.data + self._cutlass_g1_alphas = layer.w13_cutlass_g_alphas.data + self._cutlass_g2_alphas = layer.w2_cutlass_g_alphas.data + # Normalise block scales to absorb the per-expert weight global scale # (w_gs). vLLM's NVFP4 convention stores: # block_scale = max_abs * w_gs / fp4_max, g1_alphas = 1/w_gs @@ -124,7 +241,7 @@ def _supports_current_device() -> bool: @staticmethod def _supports_no_act_and_mul() -> bool: - return False + return True @staticmethod def _supports_quant_scheme( @@ -135,7 +252,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: MoEActivation) -> bool: - return activation == MoEActivation.SILU + return activation in (MoEActivation.SILU, MoEActivation.RELU2_NO_MUL) @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: @@ -167,13 +284,73 @@ def workspace_shapes( @property def expects_unquantized_inputs(self) -> bool: - # b12x_fused_moe expects BF16 hidden states and performs its own FP4 + # B12xMoEWrapper expects BF16 hidden states and performs its own FP4 # quantization internally. Returning True prevents the modular kernel - # from pre-quantizing activations, which would produce an FP4-packed - # tensor with size(-1)=k//2 and break the scale-factor conversion that - # expects size(-1)=k. + # from pre-quantizing activations. return True + def _ensure_wrapper(self, w1: torch.Tensor, w2: torch.Tensor) -> None: + """Lazily create B12xMoEWrapper on first use. + + Also registers CUTLASS-format prefill weights when hybrid dispatch + is enabled; the FP4 byte tensors are shared with the b12x decode + path (only scales differ — saved in process_weights_after_loading). + """ + if self._wrapper is None: + from flashinfer.fused_moe import B12xMoEWrapper + + b12x_kwargs = dict( + num_experts=self.global_num_experts, + top_k=self.topk, + hidden_size=self.hidden_dim, + intermediate_size=self.intermediate_size_per_partition, + use_cuda_graph=True, + max_num_tokens=self.max_num_tokens, + num_local_experts=self.num_local_experts, + activation=self._activation_str, + ) + # cutlass_prefill_threshold is gated on a FlashInfer build that + # exposes the kwarg. Skip silently if absent and threshold is 0; + # error cleanly if the user is asking for the hybrid path. + if "cutlass_prefill_threshold" in inspect.signature( + B12xMoEWrapper.__init__ + ).parameters: + b12x_kwargs["cutlass_prefill_threshold"] = ( + self.cutlass_prefill_threshold + ) + elif self.cutlass_prefill_threshold > 0: + raise RuntimeError( + "VLLM_FLASHINFER_B12X_CUTLASS_PREFILL_THRESHOLD > 0 " + "requires a FlashInfer build that exposes the " + "`cutlass_prefill_threshold` kwarg on B12xMoEWrapper; " + "current FlashInfer does not." + ) + self._wrapper = B12xMoEWrapper(**b12x_kwargs) + + if self.cutlass_prefill_threshold > 0 and not self._cutlass_registered: + assert self._cutlass_w13_scale is not None, ( + "cutlass_prefill_threshold > 0 but CUTLASS scales were " + "not saved in process_weights_after_loading" + ) + # quant_scales order matches FlashInferExperts (NVFP4 mode): + # [a1_gs, w1_blockscale_int32, 1/(a1_gs*w1_gs), + # a2_gs, w2_blockscale_int32, 1/(a2_gs*w2_gs)]. + # register_cutlass_prefill_weights does .contiguous().view(long) + # on w*_q internally — pass uint8 directly. + self._wrapper.register_cutlass_prefill_weights( + w1_q=w1, + w2_q=w2, + quant_scales=[ + self._cutlass_a1_gscale, + self._cutlass_w13_scale.view(torch.int32), + self._cutlass_g1_alphas, + self._cutlass_a2_gscale, + self._cutlass_w2_scale.view(torch.int32), + self._cutlass_g2_alphas, + ], + ) + self._cutlass_registered = True + def apply( self, output: torch.Tensor, @@ -201,13 +378,14 @@ def apply( assert self.a2_gscale is not None, ( "a2_gscale must not be None for FlashInferB12xExperts" ) + assert self.w1_sf_mma is not None and self.w2_sf_mma is not None, ( + "process_weights_after_loading must run before FlashInferB12xExperts.apply" + ) - top_k = topk_ids.shape[1] + self._ensure_wrapper(w1, w2) - flashinfer_b12x_fused_moe( + result = self._wrapper.run( x=hidden_states, - token_selected_experts=topk_ids.to(torch.int32), - token_final_scales=topk_weights, w1_weight=w1, w1_weight_sf=self.w1_sf_mma, w1_alpha=self.g1_alphas, @@ -215,9 +393,7 @@ def apply( w2_weight=w2, w2_weight_sf=self.w2_sf_mma, w2_alpha=self.g2_alphas, - num_experts=global_num_experts, - top_k=top_k, - num_local_experts=self.num_local_experts, - output_dtype=self.out_dtype, - output=output, + token_selected_experts=topk_ids.to(torch.int32), + token_final_scales=topk_weights, ) + output.copy_(result)