-
-
Notifications
You must be signed in to change notification settings - Fork 15.2k
fix(mxfp4): Disable monolithic path for TRITON backend with EP #34270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
vllm-bot
merged 5 commits into
vllm-project:main
from
elizabetht:fix/mxfp4-triton-ep-crash
Feb 25, 2026
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
b8d4938
fix(mxfp4): Apply expert_map remapping in triton_kernel_moe_forward f…
elizabetht 68fc33e
Fix EP expert_map indexing dtype
elizabetht 64b7795
Fix mxfp4 triton EP test to run on CUDA devices
elizabetht d5f5aea
Add CPU fallback for device selection in EP test
elizabetht cacd228
Merge branch 'main' into fix/mxfp4-triton-ep-crash
mgoin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,194 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """ | ||
| Tests that triton_kernel_moe_forward correctly applies expert_map | ||
| remapping when expert parallelism (EP) is enabled. | ||
|
|
||
| Previously, legacy_routing was always used and it produced routing data | ||
| with global expert IDs that didn't correspond to local weight indices, | ||
| causing illegal memory access with EP. The fix splits routing: when | ||
| expert_map is provided, topk selection is performed first, expert_map is | ||
| applied to remap global→local IDs, and make_routing_data builds routing | ||
| structures from the local IDs. | ||
| """ | ||
|
|
||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm.model_executor.layers.quantization.mxfp4 import ( | ||
| Mxfp4Backend, | ||
| Mxfp4MoEMethod, | ||
| ) | ||
|
|
||
|
|
||
| def _make_mock_moe_config(ep_size: int = 1) -> MagicMock: | ||
| """Create a mock FusedMoEConfig with the given EP size.""" | ||
| parallel_config = MagicMock() | ||
| parallel_config.ep_size = ep_size | ||
|
|
||
| moe_config = MagicMock() | ||
| moe_config.ep_size = ep_size | ||
| moe_config.is_lora_enabled = False | ||
| moe_config.moe_parallel_config = parallel_config | ||
| return moe_config | ||
|
|
||
|
|
||
class TestMxfp4TritonIsMonolithic:
    """Verify that is_monolithic is always True for the TRITON backend,
    regardless of EP size, since triton_kernel_moe_forward now handles
    expert_map remapping internally."""

    @pytest.mark.parametrize(
        "backend,ep_size,expected_monolithic",
        [
            # TRITON is always monolithic (handles EP via expert_map remapping)
            (Mxfp4Backend.TRITON, 1, True),
            (Mxfp4Backend.TRITON, 2, True),
            (Mxfp4Backend.TRITON, 4, True),
            # SM100 backends are always monolithic
            (Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 1, True),
            (Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 2, True),
            (Mxfp4Backend.SM100_FI_MXFP4_BF16, 1, True),
            (Mxfp4Backend.SM100_FI_MXFP4_BF16, 2, True),
            # MARLIN is never monolithic
            (Mxfp4Backend.MARLIN, 1, False),
            (Mxfp4Backend.MARLIN, 2, False),
        ],
        # ids must stay in one-to-one order with the tuples above.
        ids=[
            "triton-no-ep",
            "triton-ep2",
            "triton-ep4",
            "sm100-trtllm-no-ep",
            "sm100-trtllm-ep2",
            "sm100-bf16-no-ep",
            "sm100-bf16-ep2",
            "marlin-no-ep",
            "marlin-ep2",
        ],
    )
    # NOTE: patch decorators inject mocks bottom-up, so the innermost
    # decorator (get_current_vllm_config) maps to the first mock argument
    # (mock_get_config) and the outer one to mock_get_backend.
    @patch(
        "vllm.model_executor.layers.quantization.mxfp4.get_mxfp4_backend",
    )
    @patch(
        "vllm.model_executor.layers.quantization.mxfp4.get_current_vllm_config",
    )
    def test_is_monolithic(
        self,
        mock_get_config,
        mock_get_backend,
        backend,
        ep_size,
        expected_monolithic,
    ):
        """is_monolithic should be True for TRITON regardless of EP size."""
        # Force the backend under test instead of probing the hardware.
        mock_get_backend.return_value = backend

        # Mxfp4MoEMethod.__init__ reads max_cudagraph_capture_size from the
        # current vllm config's compilation config; supply a plausible value.
        mock_compilation_config = MagicMock()
        mock_compilation_config.max_cudagraph_capture_size = 1024
        mock_vllm_config = MagicMock()
        mock_vllm_config.compilation_config = mock_compilation_config
        mock_get_config.return_value = mock_vllm_config

        moe_config = _make_mock_moe_config(ep_size=ep_size)
        method = Mxfp4MoEMethod(moe_config)

        assert method.is_monolithic == expected_monolithic, (
            f"Expected is_monolithic={expected_monolithic} for "
            f"backend={backend.name}, ep_size={ep_size}, "
            f"but got {method.is_monolithic}."
        )
|
|
||
|
|
||
class TestTritonMoeForwardExpertMap:
    """Test that triton_kernel_moe_forward applies expert_map remapping
    when expert_map is provided (EP active)."""

    @pytest.mark.parametrize("expert_map_present", [False, True])
    def test_routing_path_selection(self, expert_map_present):
        """Verify that the EP-aware routing path is taken when expert_map
        is present, and the legacy_routing path is taken otherwise."""

        # Fall back to CPU so the structural assertions still run on
        # machines without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # This is a structural test: we mock the routing functions to
        # verify the correct path is exercised.
        mock_expert_map = (
            torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None
        )

        with (
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.legacy_routing"
            ) as mock_legacy,
            # NOTE(review): topk is patched at its source module, unlike
            # the others which are patched in the consuming namespace —
            # presumably the module references it fully qualified; verify
            # against gpt_oss_triton_kernels_moe if this ever misfires.
            patch("triton_kernels.topk.topk") as mock_topk,
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.make_routing_data"
            ) as mock_make_routing,
            patch(
                "vllm.model_executor.layers.fused_moe."
                "gpt_oss_triton_kernels_moe.triton_kernel_fused_experts"
            ) as mock_fused_experts,
        ):
            # Import inside the patch context so the patched names are in
            # effect when the function resolves them.
            from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
                triton_kernel_moe_forward,
            )

            # Set up return values
            mock_routing_data = MagicMock()
            mock_gather = MagicMock()
            mock_scatter = MagicMock()

            if expert_map_present:
                # topk selects global expert IDs 0 and 2; mock_expert_map
                # remaps them to local IDs 0 and 1 respectively.
                sparse_result = MagicMock()
                sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
                sparse_result.vals = torch.tensor([[0.6, 0.4]])
                mock_topk.return_value = sparse_result
                mock_make_routing.return_value = (
                    mock_routing_data,
                    mock_gather,
                    mock_scatter,
                )
            else:
                mock_legacy.return_value = (
                    mock_routing_data,
                    mock_gather,
                    mock_scatter,
                )

            mock_fused_experts.return_value = torch.zeros((1, 8), device=device)

            # Minimal tensors: 1 token, hidden size 8, 2 local experts,
            # 4 global experts in the gating output.
            hidden = torch.randn((1, 8), device=device)
            w1 = torch.randn((2, 8, 16), device=device)
            w2 = torch.randn((2, 8, 8), device=device)
            logits = torch.randn((1, 4), device=device)

            triton_kernel_moe_forward(
                hidden_states=hidden,
                w1=w1,
                w2=w2,
                gating_output=logits,
                topk=2,
                renormalize=True,
                expert_map=mock_expert_map,
            )

            if expert_map_present:
                # EP path: should use topk + make_routing_data, NOT
                # legacy_routing
                mock_topk.assert_called_once()
                mock_make_routing.assert_called_once()
                mock_legacy.assert_not_called()
                # expert_map must not be forwarded to the fused-experts
                # kernel — it is consumed during routing. The previous
                # check (`... is None or len(call_kwargs[0]) > 0`) passed
                # vacuously whenever any positional argument was present,
                # so the property was never actually verified.
                call = mock_fused_experts.call_args
                assert call.kwargs.get("expert_map") is None, (
                    "expert_map should already be applied and must not be "
                    "forwarded to triton_kernel_fused_experts"
                )
            else:
                # Non-EP path: should use legacy_routing
                mock_legacy.assert_called_once()
                mock_topk.assert_not_called()
                mock_make_routing.assert_not_called()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@elizabetht This test doesn't pass for me on H100. Can you fix it by
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bump
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Working on it now!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@varun-sundar-rabindranath Could you check now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Had a minor tweak to conditionally set device as "cuda" if torch.cuda is available.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Honestly I don't know how useful all this mocked logic is, especially since we are actively refactoring the mxfp4 moe. It would be better to just have an e2e gsm8k eval in CI with EPDP used