diff --git a/tests/kernels/quantization/test_mxfp4_triton_ep.py b/tests/kernels/quantization/test_mxfp4_triton_ep.py
new file mode 100644
index 000000000000..d4eb91058906
--- /dev/null
+++ b/tests/kernels/quantization/test_mxfp4_triton_ep.py
@@ -0,0 +1,204 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Tests that triton_kernel_moe_forward correctly applies expert_map
+remapping when expert parallelism (EP) is enabled.
+
+Previously, legacy_routing was always used and it produced routing data
+with global expert IDs that didn't correspond to local weight indices,
+causing illegal memory access with EP. The fix splits routing: when
+expert_map is provided, topk selection is performed first, expert_map is
+applied to remap global→local IDs, and make_routing_data builds routing
+structures from the local IDs.
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.quantization.mxfp4 import (
+    Mxfp4Backend,
+    Mxfp4MoEMethod,
+)
+
+
+def _make_mock_moe_config(ep_size: int = 1) -> MagicMock:
+    """Create a mock FusedMoEConfig with the given EP size."""
+    parallel_config = MagicMock()
+    parallel_config.ep_size = ep_size
+
+    moe_config = MagicMock()
+    moe_config.ep_size = ep_size
+    moe_config.is_lora_enabled = False
+    moe_config.moe_parallel_config = parallel_config
+    return moe_config
+
+
+class TestMxfp4TritonIsMonolithic:
+    """Verify that is_monolithic is always True for the TRITON backend,
+    regardless of EP size, since triton_kernel_moe_forward now handles
+    expert_map remapping internally."""
+
+    @pytest.mark.parametrize(
+        "backend,ep_size,expected_monolithic",
+        [
+            # TRITON is always monolithic (handles EP via expert_map remapping)
+            (Mxfp4Backend.TRITON, 1, True),
+            (Mxfp4Backend.TRITON, 2, True),
+            (Mxfp4Backend.TRITON, 4, True),
+            # SM100 backends are always monolithic
+            (Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 1, True),
+            (Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 2, True),
+            (Mxfp4Backend.SM100_FI_MXFP4_BF16, 1, True),
+            (Mxfp4Backend.SM100_FI_MXFP4_BF16, 2, True),
+            # MARLIN is never monolithic
+            (Mxfp4Backend.MARLIN, 1, False),
+            (Mxfp4Backend.MARLIN, 2, False),
+        ],
+        ids=[
+            "triton-no-ep",
+            "triton-ep2",
+            "triton-ep4",
+            "sm100-trtllm-no-ep",
+            "sm100-trtllm-ep2",
+            "sm100-bf16-no-ep",
+            "sm100-bf16-ep2",
+            "marlin-no-ep",
+            "marlin-ep2",
+        ],
+    )
+    @patch(
+        "vllm.model_executor.layers.quantization.mxfp4.get_mxfp4_backend",
+    )
+    @patch(
+        "vllm.model_executor.layers.quantization.mxfp4.get_current_vllm_config",
+    )
+    def test_is_monolithic(
+        self,
+        mock_get_config,
+        mock_get_backend,
+        backend,
+        ep_size,
+        expected_monolithic,
+    ):
+        """is_monolithic should be True for TRITON regardless of EP size."""
+        mock_get_backend.return_value = backend
+
+        mock_compilation_config = MagicMock()
+        mock_compilation_config.max_cudagraph_capture_size = 1024
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.compilation_config = mock_compilation_config
+        mock_get_config.return_value = mock_vllm_config
+
+        moe_config = _make_mock_moe_config(ep_size=ep_size)
+        method = Mxfp4MoEMethod(moe_config)
+
+        assert method.is_monolithic == expected_monolithic, (
+            f"Expected is_monolithic={expected_monolithic} for "
+            f"backend={backend.name}, ep_size={ep_size}, "
+            f"but got {method.is_monolithic}."
+        )
+
+
+class TestTritonMoeForwardExpertMap:
+    """Test that triton_kernel_moe_forward applies expert_map remapping
+    when expert_map is provided (EP active)."""
+
+    @pytest.mark.parametrize("expert_map_present", [False, True])
+    def test_routing_path_selection(self, expert_map_present):
+        """Verify that the EP-aware routing path is taken when expert_map
+        is present, and the legacy_routing path is taken otherwise."""
+        # patch() has to import triton_kernels.topk to replace topk, so
+        # skip (rather than error) where triton_kernels is unavailable.
+        pytest.importorskip("triton_kernels.topk")
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        # This is a structural test: we mock the routing functions to
+        # verify the correct path is exercised.
+        mock_expert_map = (
+            torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None
+        )
+
+        with (
+            patch(
+                "vllm.model_executor.layers.fused_moe."
+                "gpt_oss_triton_kernels_moe.legacy_routing"
+            ) as mock_legacy,
+            patch("triton_kernels.topk.topk") as mock_topk,
+            patch(
+                "vllm.model_executor.layers.fused_moe."
+                "gpt_oss_triton_kernels_moe.make_routing_data"
+            ) as mock_make_routing,
+            patch(
+                "vllm.model_executor.layers.fused_moe."
+                "gpt_oss_triton_kernels_moe.triton_kernel_fused_experts"
+            ) as mock_fused_experts,
+        ):
+            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
+                triton_kernel_moe_forward,
+            )
+
+            # Set up return values
+            mock_routing_data = MagicMock()
+            mock_gather = MagicMock()
+            mock_scatter = MagicMock()
+
+            if expert_map_present:
+                # Global expert IDs 0 and 2 are the ones expert_map keeps local.
+                sparse_result = MagicMock()
+                sparse_result.indx = torch.tensor(
+                    [[0, 2]], dtype=torch.int32, device=device
+                )
+                sparse_result.vals = torch.tensor([[0.6, 0.4]], device=device)
+                mock_topk.return_value = sparse_result
+                mock_make_routing.return_value = (
+                    mock_routing_data,
+                    mock_gather,
+                    mock_scatter,
+                )
+            else:
+                mock_legacy.return_value = (
+                    mock_routing_data,
+                    mock_gather,
+                    mock_scatter,
+                )
+
+            mock_fused_experts.return_value = torch.zeros((1, 8), device=device)
+
+            hidden = torch.randn((1, 8), device=device)
+            w1 = torch.randn((2, 8, 16), device=device)
+            w2 = torch.randn((2, 8, 8), device=device)
+            logits = torch.randn((1, 4), device=device)
+
+            triton_kernel_moe_forward(
+                hidden_states=hidden,
+                w1=w1,
+                w2=w2,
+                gating_output=logits,
+                topk=2,
+                renormalize=True,
+                expert_map=mock_expert_map,
+            )
+
+            mock_fused_experts.assert_called_once()
+            if expert_map_present:
+                # EP path: should use topk + make_routing_data, NOT
+                # legacy_routing
+                mock_topk.assert_called_once()
+                mock_make_routing.assert_called_once()
+                mock_legacy.assert_not_called()
+                # make_routing_data must receive local expert IDs:
+                # expert_map sends global IDs [0, 2] to local [0, 1].
+                routed_ids = mock_make_routing.call_args.args[0]
+                assert routed_ids.tolist() == [[0, 1]]
+                # expert_map was already applied, so fused experts must
+                # get no expert_map and the local expert count.
+                fused_kwargs = mock_fused_experts.call_args.kwargs
+                assert fused_kwargs["expert_map"] is None
+                assert fused_kwargs["global_num_experts"] == w1.size(0)
+            else:
+                # Non-EP path: should use legacy_routing
+                mock_legacy.assert_called_once()
+                mock_topk.assert_not_called()
+                mock_make_routing.assert_not_called()
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index 70d11f44f43b..5617156bf2fc 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -179,9 +179,35 @@ def triton_kernel_moe_forward(
     global_num_experts: int = -1,
     expert_map: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    routing_data, gather_idx, scatter_idx = legacy_routing(
-        gating_output, topk, sm_first=not renormalize
-    )
+    if expert_map is not None:
+        # With expert parallelism, legacy_routing produces routing data
+        # using global expert IDs which don't correspond to local weight
+        # indices. Split the routing into topk selection + expert_map
+        # remapping + local routing data construction (matching the
+        # approach used by OAITritonExperts.apply).
+        from triton_kernels.topk import topk as topk_fn
+
+        sm_first = not renormalize
+        logits = gating_output
+        if sm_first:
+            logits = torch.softmax(logits, dim=-1)
+        sparse_logits = topk_fn(logits, topk, apply_softmax=not sm_first)
+        # sparse_logits.indx contains global expert IDs – remap to local.
+        topk_ids = expert_map[sparse_logits.indx.to(torch.long)]
+        topk_weights = sparse_logits.vals
+        local_num_experts = w1.size(0)
+        routing_data, gather_idx, scatter_idx = make_routing_data(
+            topk_ids, topk_weights, local_num_experts
+        )
+        # expert_map already applied; pass None downstream.
+        effective_expert_map = None
+        effective_global_num_experts = local_num_experts
+    else:
+        routing_data, gather_idx, scatter_idx = legacy_routing(
+            gating_output, topk, sm_first=not renormalize
+        )
+        effective_expert_map = expert_map
+        effective_global_num_experts = global_num_experts
 
     output = torch.empty_like(hidden_states)
 
@@ -197,8 +223,8 @@ def triton_kernel_moe_forward(
         activation=activation,
         quant_config=quant_config,
         apply_router_weight_on_input=apply_router_weight_on_input,
-        global_num_experts=global_num_experts,
-        expert_map=expert_map,
+        global_num_experts=effective_global_num_experts,
+        expert_map=effective_expert_map,
     )
 
 