diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index 537dcae4e74b..ac6d54b7199b 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -21,12 +21,12 @@ get_tensor_model_parallel_world_size, ) from vllm.forward_context import set_forward_context +from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx from .mk_objects import ( diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index c9d425b5b990..d78e1947fac0 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -15,10 +15,10 @@ from tests.kernels.quant_utils import native_batched_masked_quant_matmul from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( invoke_moe_batched_triton_kernel, ) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.platforms import current_platform from vllm.triton_utils import tl from vllm.utils.torch_utils import set_random_seed diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 53a03f48e24e..63fbbfeec6ca 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -11,13 +11,15 @@ ) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe import ( + fused_experts, + fused_topk, +) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm_shape, deep_gemm_moe_fp8, ) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, modular_triton_fused_moe, ) from vllm.platforms import current_platform diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index bd10c3793e34..d5b1c2cf006a 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -10,6 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, @@ -19,7 +20,6 @@ CutlassExpertsFp8, run_cutlass_moe_fp8, ) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 1262eea70bab..ee23e7e6529d 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -12,6 +12,7 @@ from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from 
vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe, @@ -19,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( create_flashinfer_prepare_finalize, ) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py index f676cc4fee1b..2a974206d1d0 100644 --- a/tests/kernels/moe/test_grouped_topk.py +++ b/tests/kernels/moe/test_grouped_topk.py @@ -14,7 +14,7 @@ get_cached_compilation_config, set_current_vllm_config, ) -from vllm.model_executor.layers.fused_moe.fused_moe import ( +from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import ( GroupedTopk, fused_grouped_topk, ) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 07ced9769b00..551233169f59 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -24,6 +24,9 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import init_distributed_environment from vllm.forward_context import set_forward_context +from vllm.model_executor.layers.fused_moe import ( + fused_topk, +) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config, @@ -34,7 +37,6 @@ fused_marlin_moe, ) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, modular_triton_fused_moe, ) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 45127ce0ac63..0b3c435aa067 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -9,7 +9,7 @@ import pytest import torch -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.layer import determine_expert_map from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( moe_permute, diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 873d72117de7..4dd4223db373 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -13,11 +13,11 @@ from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassExpertsFp4, ) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 3a5801ae4996..a2ab94e370f5 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -8,9 +8,9 @@ from tests.kernels.utils import torch_experts from vllm import 
_custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8 -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv diff --git a/tests/kernels/moe/test_routing.py b/tests/kernels/moe/test_routing.py new file mode 100644 index 000000000000..93aa6aa5c5ca --- /dev/null +++ b/tests/kernels/moe/test_routing.py @@ -0,0 +1,499 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import pytest +import torch + +from vllm.distributed.eplb.eplb_state import EplbLayerState +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +from vllm.model_executor.layers.fused_moe.router.router_factory import ( + create_fused_moe_router, +) +from vllm.model_executor.models.llama4 import Llama4MoE + +# Test parameters +MK_S = [(32, 256), (64, 512)] +TOP_KS = [2, 4, 6] +NUM_EXPERTS = [8, 16, 64] + + +def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerState: + if not enable_eplb: + return EplbLayerState() + + # Initialize EPLB state with proper tensors for testing + # For testing purposes, we use a simple 1:1 mapping (no redundant experts) + # expert_load_view: tracks load on each expert (shape: num_experts) + expert_load_view = torch.zeros(global_num_experts, dtype=torch.int32, device="cuda") + + # logical_to_physical_map: maps logical experts to physical experts + # Shape: (num_logical_experts, max_slots) + # For testing, use simple 1:1 mapping with single slot per expert + logical_to_physical_map = torch.arange( + global_num_experts, dtype=torch.int64, device="cuda" + ).unsqueeze(-1) + + # logical_replica_count: number of replicas per logical expert + # Shape: (num_logical_experts,) + # For testing, each logical expert has exactly 1 replica + logical_replica_count = torch.ones( + global_num_experts, dtype=torch.int64, device="cuda" + ) + + return EplbLayerState( + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + +def make_test_data( + m: int, k: int, num_experts: int +) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = torch.randn((m, k), device="cuda") / 10 + logits = torch.randn((m, num_experts), device="cuda") + return hidden_states, logits + + +def make_e_score_correction_bias( + e_score_correction_bias_val: float, + num_experts: int, +) -> torch.Tensor: + # return torch.randn(num_experts, device="cuda") * e_score_correction_bias_val + return torch.full( + (num_experts,), e_score_correction_bias_val, device="cuda", dtype=torch.float32 + ) + + +def assert_routing_results_close( + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + baseline_weights: torch.Tensor, + baseline_ids: torch.Tensor, + rtol: float = 1e-3, + atol: float = 1e-3, +): + """ + Compare routing results, sorting by expert ID first to handle non-deterministic + ordering from sorted=False in topk. 
+ """ + # Sort both results by expert IDs for consistent comparison + sorted_indices_actual = torch.argsort(topk_ids, dim=-1) + sorted_indices_baseline = torch.argsort(baseline_ids.to(topk_ids.dtype), dim=-1) + + # Gather the sorted values + topk_ids_sorted = torch.gather(topk_ids, 1, sorted_indices_actual) + topk_weights_sorted = torch.gather(topk_weights, 1, sorted_indices_actual) + baseline_ids_sorted = torch.gather( + baseline_ids.to(topk_ids.dtype), 1, sorted_indices_baseline + ) + baseline_weights_sorted = torch.gather(baseline_weights, 1, sorted_indices_baseline) + + # Compare + torch.testing.assert_close(topk_ids_sorted, baseline_ids_sorted) + torch.testing.assert_close( + topk_weights_sorted, baseline_weights_sorted, rtol=rtol, atol=atol + ) + + +def baseline_fused_topk( + router_logits: torch.Tensor, top_k: int, renormalize: bool +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Baseline for standard fused top-k routing. + + Algorithm: + 1. Apply softmax to router logits + 2. Select top-k experts + 3. Optionally renormalize the weights + """ + scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32) + # Use sorted=False to match vllm implementation (vllm_is_batch_invariant + # defaults to False) + topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def baseline_fused_topk_bias( + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + e_score_correction_bias: torch.Tensor, + routed_scaling_factor: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Baseline for fused top-k with bias correction. + + Algorithm: + 1. Apply softmax to router logits + 2. Add bias to scores for expert selection + 3. Select top-k experts using biased scores + 4. Get weights from original (unbiased) scores + 5. Apply routed scaling factor + 6. Optionally renormalize the weights + """ + # Apply softmax to get scores + scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32) + + # Add bias for expert selection + scores_for_choice = scores + e_score_correction_bias.unsqueeze(0) + + # Select top-k using biased scores (sorted=False to match implementation) + topk_ids = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1] + + # Get weights from original scores (not biased) + topk_weights = scores.gather(1, topk_ids) + + # Renormalize if needed (BEFORE applying scaling factor) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + # Apply scaling factor (AFTER renormalization, if applicable) + if routed_scaling_factor != 1.0: + topk_weights *= routed_scaling_factor + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def baseline_grouped_topk( + router_logits: torch.Tensor, + top_k: int, + num_expert_group: int, + topk_group: int, + scoring_func: str, + renormalize: bool, + e_score_correction_bias: torch.Tensor | None, + routed_scaling_factor: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Baseline for grouped top-k routing (e.g., DeepSeek). + + Algorithm: + 1. Apply scoring function (softmax or sigmoid) + 2. Optionally add bias + 3. Select top-k groups based on max scores within each group + 4. Mask scores to only include selected groups + 5. Select top-k experts from masked scores + 6. Apply scaling factor + 7. 
Optionally renormalize + """ + num_token = router_logits.shape[0] + + # Apply scoring function + if scoring_func == "softmax": + scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32) + elif scoring_func == "sigmoid": + scores = torch.sigmoid(router_logits.float()) + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + # Handle bias correction + if e_score_correction_bias is not None: + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + # For bias case, use sum of top-2 scores in each group + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) + else: + # Use max score in each group + group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values + + # Select top-k groups + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] + + # Create mask for selected groups + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + + # Expand mask to all experts + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) + + # Mask scores (set non-selected to -inf) + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) + + # Select top-k experts + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False)[1] + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=top_k, dim=-1, sorted=False) + + # Renormalize if needed + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + # Apply scaling factor + if routed_scaling_factor != 1.0: + topk_weights *= routed_scaling_factor + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def baseline_custom_llama4( + router_logits: torch.Tensor, top_k: int +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Baseline for Llama4 custom routing. + + Algorithm: + 1. Select top-k expert indices (without softmax) + 2. 
Apply sigmoid to the selected scores + """ + router_scores, router_indices = torch.topk(router_logits, top_k, dim=-1) + router_scores = torch.sigmoid(router_scores.float()) + return router_scores.to(torch.float32), router_indices.to(torch.int32) + + +@pytest.mark.parametrize("m,k", MK_S) +@pytest.mark.parametrize("top_k", TOP_KS) +@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("renormalize", [False, True]) +@pytest.mark.parametrize("enable_eplb", [False, True]) +def test_fused_topk( + m: int, + k: int, + top_k: int, + global_num_experts: int, + renormalize: bool, + enable_eplb: bool, +): + if top_k > global_num_experts: + pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})") + + eplb_state = setup_eplb_state(enable_eplb, global_num_experts) + router = create_fused_moe_router( + top_k=top_k, + global_num_experts=global_num_experts, + renormalize=renormalize, + enable_eplb=enable_eplb, + eplb_state=eplb_state, + ) + + hidden_states, router_logits = make_test_data(m, k, global_num_experts) + + # Get router output + topk_weights, topk_ids = router.select_experts(hidden_states, router_logits) + + # Compute baseline + baseline_weights, baseline_ids = baseline_fused_topk( + router_logits, top_k, renormalize + ) + + # Compare results + assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids) + + +@pytest.mark.parametrize("m,k", MK_S) +@pytest.mark.parametrize("top_k", TOP_KS) +@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("renormalize", [False, True]) +@pytest.mark.parametrize("enable_eplb", [False, True]) +@pytest.mark.parametrize("e_score_correction_bias_val", [0.9]) +@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.1]) +def test_fused_topk_bias( + m: int, + k: int, + top_k: int, + global_num_experts: int, + renormalize: bool, + enable_eplb: bool, + e_score_correction_bias_val: float, + routed_scaling_factor: float, +): + if top_k > global_num_experts: + pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})") + + eplb_state = setup_eplb_state(enable_eplb, global_num_experts) + + e_score_correction_bias = make_e_score_correction_bias( + e_score_correction_bias_val, + global_num_experts, + ) + + router = create_fused_moe_router( + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + top_k=top_k, + global_num_experts=global_num_experts, + renormalize=renormalize, + enable_eplb=enable_eplb, + eplb_state=eplb_state, + ) + + hidden_states, router_logits = make_test_data(m, k, global_num_experts) + + # Get router output + topk_weights, topk_ids = router.select_experts(hidden_states, router_logits) + + # Compute baseline + baseline_weights, baseline_ids = baseline_fused_topk_bias( + router_logits, + top_k, + renormalize, + e_score_correction_bias, + routed_scaling_factor, + ) + + # Compare results + assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids) + + +@pytest.mark.parametrize("m,k", MK_S) +@pytest.mark.parametrize("top_k", TOP_KS) +@pytest.mark.parametrize( + "global_num_experts,num_expert_group,topk_group", + [ + (64, 8, 4), # 8 groups of 8 experts, select 4 groups + (32, 4, 2), # 4 groups of 8 experts, select 2 groups + ], +) +@pytest.mark.parametrize("renormalize", [False, True]) +@pytest.mark.parametrize("enable_eplb", [False, True]) +@pytest.mark.parametrize("e_score_correction_bias_val", [0.9]) +@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.1]) 
+@pytest.mark.parametrize("scoring_func", ["sigmoid", "softmax"]) +def test_grouped_topk( + m: int, + k: int, + top_k: int, + global_num_experts: int, + renormalize: bool, + enable_eplb: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str, + e_score_correction_bias_val: float, + routed_scaling_factor: float, +): + if top_k > global_num_experts: + pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})") + + eplb_state = setup_eplb_state(enable_eplb, global_num_experts) + + e_score_correction_bias = make_e_score_correction_bias( + e_score_correction_bias_val, + global_num_experts, + ) + + routing_method_type = None + if scoring_func == "llama4": + routing_method_type = RoutingMethodType.Llama4 + scoring_func = "sigmoid" + + router = create_fused_moe_router( + use_grouped_topk=True, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routing_method_type=routing_method_type, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + top_k=top_k, + global_num_experts=global_num_experts, + renormalize=renormalize, + enable_eplb=enable_eplb, + eplb_state=eplb_state, + ) + + hidden_states, router_logits = make_test_data(m, k, global_num_experts) + + # Get router output + topk_weights, topk_ids = router.select_experts(hidden_states, router_logits) + + # Compute baseline + baseline_weights, baseline_ids = baseline_grouped_topk( + router_logits, + top_k, + num_expert_group, + topk_group, + scoring_func, + renormalize, + e_score_correction_bias, + routed_scaling_factor, + ) + + # Compare results + assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids) + + +@pytest.mark.parametrize("m,k", MK_S) +@pytest.mark.parametrize("top_k", TOP_KS) +@pytest.mark.parametrize("global_num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("renormalize", [False, True]) +@pytest.mark.parametrize("enable_eplb", [False, True]) +@pytest.mark.parametrize("custom_routing_function", [Llama4MoE.custom_routing_function]) +def test_custom( + m: int, + k: int, + top_k: int, + global_num_experts: int, + renormalize: bool, + enable_eplb: bool, + custom_routing_function: Callable, +): + if top_k > global_num_experts: + pytest.skip(f"top_k ({top_k}) > global_num_experts ({global_num_experts})") + + eplb_state = setup_eplb_state(enable_eplb, global_num_experts) + + router = create_fused_moe_router( + top_k=top_k, + global_num_experts=global_num_experts, + custom_routing_function=custom_routing_function, + renormalize=renormalize, + enable_eplb=enable_eplb, + eplb_state=eplb_state, + ) + + hidden_states, router_logits = make_test_data(m, k, global_num_experts) + + # Get router output + topk_weights, topk_ids = router.select_experts(hidden_states, router_logits) + + # Compute baseline (Llama4 uses sigmoid) + baseline_weights, baseline_ids = baseline_custom_llama4(router_logits, top_k) + + # Compare results + assert_routing_results_close(topk_weights, topk_ids, baseline_weights, baseline_ids) + + +# TODO: is other test sufficient? 
+# # See tests/test_routing_simulatator.py +# @pytest.mark.parametrize("m,k", MK_S) +# @pytest.mark.parametrize("top_k", TOP_KS) +# @pytest.mark.parametrize("global_num_experts", NUM_EXPERTS) +# @pytest.mark.parametrize("renormalize", [False, True]) +# @pytest.mark.parametrize("enable_eplb", [False, True]) +# @pytest.mark.parameterize("strategy", ["uniform_random", "normal_routing"]) +# def test_simulated( +# m: int, +# k: int, +# top_k: int, +# global_num_experts: int, +# renormalize: bool, +# enable_eplb: bool, +# strategy: str, +# monkeypatch, +# ): +# eplb_state = setup_eplb_state(enable_eplb) + +# monkeypatch.setenv("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", strategy) +# router = create_fused_moe_router( +# top_k=top_k, +# global_num_experts=global_num_experts, +# enable_eplb=enable_eplb, +# eplb_state=eplb_state, +# ) + +# hidden_states, router_logits = make_test_data(m, k, global_num_experts) +# topk_weights, topk_ids = router.select_experts(hidden_states, router_logits) diff --git a/tests/test_routing_simulator.py b/tests/kernels/moe/test_routing_simulator.py similarity index 77% rename from tests/test_routing_simulator.py rename to tests/kernels/moe/test_routing_simulator.py index e37f30755663..c0c3a1e1da9e 100644 --- a/tests/test_routing_simulator.py +++ b/tests/kernels/moe/test_routing_simulator.py @@ -19,7 +19,7 @@ init_distributed_environment, initialize_model_parallel, ) -from vllm.model_executor.layers.fused_moe.routing_simulator import ( +from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import ( DistributionBasedRouting, RoutingSimulator, ) @@ -109,40 +109,44 @@ def test_routing_strategy_integration(monkeypatch, device): tensor_model_parallel_size=1, pipeline_model_parallel_size=1, ) - fused_moe = FusedMoE( - num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=0, - use_grouped_topk=False, - renormalize=True, - ) - - for strategy in strategies: - # Set environment variable - env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY" - monkeypatch.setenv(env_name, strategy) - # Force reload of environment variable - envs.environment_variables[env_name] = lambda s=strategy: s - - # Test the select_experts method - topk_weights, topk_ids = fused_moe.router.select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - ) - - # Verify output shapes - assert topk_weights.shape == (num_tokens, top_k), ( - f"Wrong weights shape for {strategy}" - ) - assert topk_ids.shape == (num_tokens, top_k), f"Wrong ids shape for {strategy}" - - # Verify expert IDs are valid - assert topk_ids.min() >= 0, f"Invalid expert ID (negative) for {strategy}" - assert topk_ids.max() < num_experts, ( - f"Invalid expert ID (too large) for {strategy}" - ) + for strategy in strategies: + fused_moe = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=0, + use_grouped_topk=False, + renormalize=True, + prefix=strategy, + ) + + # Set environment variable + env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY" + monkeypatch.setenv(env_name, strategy) + + # Force reload of environment variable + envs.environment_variables[env_name] = lambda s=strategy: s + + # Test the select_experts method + topk_weights, topk_ids = fused_moe.router.select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + ) + + # Verify output shapes + assert topk_weights.shape == (num_tokens, top_k), ( + f"Wrong weights shape for {strategy}" + ) + assert topk_ids.shape == (num_tokens, top_k), ( + f"Wrong ids shape 
for {strategy}" + ) + + # Verify expert IDs are valid + assert topk_ids.min() >= 0, f"Invalid expert ID (negative) for {strategy}" + assert topk_ids.max() < num_experts, ( + f"Invalid expert ID (too large) for {strategy}" + ) def test_distribution_based_routing_with_custom_strategy(): diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 7d95dcddca71..8ee1b1a37ca6 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -17,7 +17,7 @@ ReLUSquaredActivation, SiluAndMul, ) -from vllm.model_executor.layers.fused_moe.fused_moe import ( +from vllm.model_executor.layers.fused_moe.router.fused_topk_router import ( dispatch_topk_func, vllm_topk_softmax, ) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 424c2235cf8f..cc4b333f42cb 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -1158,6 +1158,15 @@ def _sync_load_pass(self) -> list[torch.Tensor]: return self._allreduce_list(load_pass_list) +@dataclass +class EplbLayerState: + """Runtime EPLB data stored in the MoE layer.""" + + expert_load_view: torch.Tensor | None = None + logical_to_physical_map: torch.Tensor | None = None + logical_replica_count: torch.Tensor | None = None + + def _node_count_with_rank_mapping( pg: ProcessGroup | StatelessProcessGroup, rank_mapping: dict[int, int], diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 5ba9e80fc8b8..3bbffa2a6cff 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -11,9 +11,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import ( - FusedMoERouter, -) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoeWeightScaleSupported, @@ -23,6 +20,9 @@ FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, ) +from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( + FusedMoERouter, +) from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( UnquantizedFusedMoEMethod, @@ -83,13 +83,17 @@ def get_config() -> dict[str, Any] | None: BatchedTritonExperts, ) from vllm.model_executor.layers.fused_moe.fused_moe import ( - GroupedTopk, TritonExperts, TritonWNA16Experts, fused_experts, - fused_topk, get_config_file_name, ) + from vllm.model_executor.layers.fused_moe.router.fused_topk_router import ( + fused_topk, + ) + from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import ( + GroupedTopk, + ) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, ) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 23b86fdca898..c8baefbd55fe 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -117,8 +117,12 @@ class RoutingMethodType(IntEnum): RenormalizeNaive = (4,) # TopK: TopK (no softmax) TopK = (5,) + # Custom + Custom = (6,) + # Simulated + Simulated = (7,) # Unspecified - Unspecified = 6.0 + Unspecified = 8.0 @dataclass diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 
f2d9463e9183..d069d81f5b99 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -13,9 +13,7 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops -from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger -from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) @@ -34,9 +32,6 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk, -) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) @@ -49,7 +44,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme -from vllm.model_executor.utils import maybe_disable_graph_partition from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used @@ -1318,375 +1312,6 @@ def try_get_optimal_moe_config( return config -def vllm_topk_softmax( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, -) -> tuple[torch.Tensor, ...]: - ops.topk_softmax( - topk_weights, - topk_indices, - token_expert_indices, - gating_output, - renormalize, - ) - - return topk_weights, topk_indices - - -def dispatch_topk_func( - use_rocm_aiter: bool = False, -) -> Callable[..., tuple[torch.Tensor, ...]]: - if use_rocm_aiter: - return rocm_aiter_ops.topk_softmax - return vllm_topk_softmax - - -def fused_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - indices_type: torch.dtype | None = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" - - M, _ = hidden_states.size() - - topk_weights = torch.empty( - M, topk, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty( - M, - topk, - dtype=torch.int32 if indices_type is None else indices_type, - device=hidden_states.device, - ) - token_expert_indices = torch.empty( - M, topk, dtype=torch.int32, device=hidden_states.device - ) - - topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()) - topk_weights, topk_ids = topk_func( - topk_weights, topk_ids, token_expert_indices, gating_output, renormalize - ) - - return topk_weights, topk_ids, token_expert_indices - - -def fused_topk_bias( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - e_score_correction_bias: torch.Tensor, - topk: int, - renormalize: bool, -): - n_routed_experts = gating_output.shape[-1] - scores = gating_output.softmax(dim=-1) - scores_for_choice = scores.view( - -1, n_routed_experts - ) + e_score_correction_bias.unsqueeze(0) - - # For batch invariance, use sorted=True to ensure deterministic expert selection - use_sorted = vllm_is_batch_invariant() - topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1] - topk_weights = scores.gather(1, topk_indices) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, 
keepdim=True) - return topk_weights.to(torch.float32), topk_indices.to(torch.int32) - - -# This is used by the Deepseek-V2 and Deepseek-V3 model -@torch.compile( - dynamic=True, - backend=current_platform.simple_compile_backend, - options=maybe_disable_graph_partition(current_platform.simple_compile_backend), -) -def grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, -) -> tuple[torch.Tensor, torch.Tensor]: - if ( - envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK - and current_platform.is_cuda() - and num_expert_group <= 32 - and topk <= 32 - and e_score_correction_bias is not None - ): - return fused_grouped_topk( - hidden_states=hidden_states, - gating_output=gating_output, - topk=topk, - renormalize=renormalize, - e_score_correction_bias=e_score_correction_bias, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - ) - - assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" - - if scoring_func == "softmax": - scores = torch.softmax(gating_output, dim=-1) - elif scoring_func == "sigmoid": - scores = gating_output.sigmoid() - else: - raise ValueError(f"Unsupported scoring function: {scoring_func}") - - num_token = scores.size(0) - if e_score_correction_bias is not None: - # Store original scores before applying correction bias. We use biased - # scores for expert selection but original scores for routing weights - original_scores = scores - scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) - ) - else: - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, n_group] - - # For batch invariance, use sorted=True to ensure deterministic expert selection - use_sorted = vllm_is_batch_invariant() - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[ - 1 - ] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group) - .reshape(num_token, -1) - ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] - - if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] - # Use original unbiased scores for the routing weights - topk_weights = original_scores.gather(1, topk_ids) - else: - topk_weights, topk_ids = torch.topk( - tmp_scores, k=topk, dim=-1, sorted=use_sorted - ) - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - if routed_scaling_factor != 1.0: - topk_weights = topk_weights * routed_scaling_factor - return topk_weights.to(torch.float32), topk_ids.to(torch.int32) - - -# --8<-- [start:grouped_topk] -@CustomOp.register("grouped_topk") -class GroupedTopk(CustomOp): - """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model.""" - - # --8<-- [end:grouped_topk] - - def __init__( - self, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - num_fused_shared_experts: int = 0, 
- ) -> None: - super().__init__() - self.native_impl = grouped_topk - self.topk = topk - self.renormalize = renormalize - self.num_expert_group = num_expert_group - self.topk_group = topk_group - self.scoring_func = scoring_func - self.routed_scaling_factor = routed_scaling_factor - self.num_fused_shared_experts = num_fused_shared_experts - - def forward_native( - self, - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - e_score_correction_bias: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - return self.native_impl( - hidden_states, - gating_output, - self.topk, - self.renormalize, - self.num_expert_group, - self.topk_group, - self.scoring_func, - self.routed_scaling_factor, - e_score_correction_bias, - ) - - def forward_cuda( - self, - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - e_score_correction_bias: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - return self.forward_native( - hidden_states, gating_output, e_score_correction_bias - ) - - def forward_hip( - self, - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - e_score_correction_bias: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if rocm_aiter_ops.is_fused_moe_enabled(): - if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): - assert self.num_fused_shared_experts == 0 - return rocm_aiter_grouped_topk( - hidden_states, - gating_output, - self.topk, - self.renormalize, - self.num_expert_group, - self.topk_group, - self.scoring_func, - self.routed_scaling_factor, - e_score_correction_bias, - self.num_fused_shared_experts, - ) - else: - return self.forward_native( - hidden_states, gating_output, e_score_correction_bias - ) - - -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) -def eplb_map_to_physical_and_record( - topk_ids: torch.Tensor, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, -) -> torch.Tensor: - """ - Map the logical expert ids to physical expert ids - and record the expert load metrics. - - This will select a pseudo-random replica for each logical expert. - Only used for EPLB. - - Args: - topk_ids: The logical expert ids. - expert_load_view: The expert load view. - logical_to_physical_map: The logical to physical map. - logical_replica_count: The logical replica count. - - Returns: - The physical expert ids. - """ - - # 1. Convert the logical expert ids to physical expert ids - # Directly select a random replica for each logical expert - - # In case `indices_type` is not `torch.long` or `torch.int`, - # e.g. `torch.uint32` as required by dispatch/combine kernels - topk_ids_long = topk_ids.long() - # Use (token position) modulo (replica count) - # to deterministically choose a replica - replica_count = logical_replica_count[topk_ids_long] - # Flatten-position based index, reshaped back to `topk_ids` shape - pos_indices = torch.arange( - topk_ids.numel(), device=topk_ids.device, dtype=torch.long - ).reshape_as(topk_ids) - # Compute pseudo-random indices by modulo - replica_indices = (pos_indices % replica_count).unsqueeze(-1) - physical_ids = ( - logical_to_physical_map[topk_ids_long].gather(-1, replica_indices).squeeze(-1) - ) - - topk_ids = physical_ids - - # 2. Record expert load metrics. - - # TODO(bowen): When using `FusedMoEModularKernel`, this - # can be done in a more unified way, since - # `FusedMoEPrepareAndFinalize` will return the expert - # token count, in some cases directly from the kernel. 
- # However, now there are many code paths not using - # the modular kernel, e.g. calling `fused_experts`, - # so we decide to keep the logic here. - # - # If later refactor moved all the MoE kernel calls - # to the modular kernel, we can move this logic there - # to achieve better efficiency. - - # `expert_load_view`: (num_physical_experts,) - - # `torch.bincount` is not compilable, so use `scatter_add_` instead. - topk_ids_flatten = topk_ids.flatten() - expert_load_view.scatter_add_( - dim=0, - index=topk_ids_flatten.long(), - src=torch.ones_like(topk_ids_flatten).to(expert_load_view), - ) - return topk_ids - - -def fused_grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - e_score_correction_bias: torch.Tensor, - num_expert_group: int = 0, - topk_group: int = 0, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" - - if scoring_func == "sigmoid": - # Fully fused kernel path for sigmoid - topk_values, topk_indices = ops.grouped_topk( - gating_output, # raw logits - num_expert_group, - topk_group, - topk, - renormalize, - routed_scaling_factor, - e_score_correction_bias, - 1, # scoring_func=1 for sigmoid - ) - elif scoring_func == "softmax": - # Apply softmax in Python, then use fused kernel - # TODO: Add support for softmax in kernel - scores = torch.softmax(gating_output, dim=-1) - topk_values, topk_indices = ops.grouped_topk( - scores, # pre-computed scores - num_expert_group, - topk_group, - topk, - renormalize, - routed_scaling_factor, - e_score_correction_bias, - 0, # scoring_func=0 (no activation, scores already computed) - ) - else: - raise ValueError(f"Unsupported scoring function: {scoring_func}") - - # Fused kernel outputs float32 values and int32 indices directly - return topk_values, topk_indices - - def inplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 389ccf358c56..9d1ec7fd8027 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -10,13 +10,13 @@ FusedMoEConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import ( - FusedMoERouter, -) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, ) +from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( + FusedMoERouter, +) from vllm.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 2d98433e4db0..64de30ce6085 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -12,11 +12,13 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel, FusedMoEPrepareAndFinalize, ) +from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( + FusedMoERouter, +) logger = 
init_logger(__name__) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 702052c9612e..e24d60150d60 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -21,7 +21,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.distributed.eplb.eplb_state import EplbState +from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp @@ -31,14 +31,24 @@ FusedMoEQuantConfig, RoutingMethodType, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( + FusedMoEModularMethod, +) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( init_aiter_topK_meta_data, ) from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( RoutedExpertsCapturer, ) -from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator +from vllm.model_executor.layers.fused_moe.router.router_factory import ( + create_fused_moe_router, +) +from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( + UnquantizedFusedMoEMethod, +) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, ) @@ -52,31 +62,6 @@ ) from vllm.v1.worker.ubatching import dbo_current_ubatch_id -if current_platform.is_cuda_alike(): - from .fused_moe import eplb_map_to_physical_and_record -else: - - def _eplb_map_to_physical_and_record( - topk_ids: torch.Tensor, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> torch.Tensor: - # CPU fallback: no EPLB so just return as is - return topk_ids - - eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record -from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk -from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( - FusedMoEMethodBase, -) -from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( - FusedMoEModularMethod, -) -from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( - UnquantizedFusedMoEMethod, -) - logger = init_logger(__name__) @@ -288,23 +273,6 @@ def maybe_roundup_hidden_size( return hidden_size -class FusedMoERouterImpl(FusedMoERouter): - def __init__(self, layer: "FusedMoE"): - super().__init__() - self.layer = layer - - @property - def routing_method_type(self) -> RoutingMethodType: - return self.layer.routing_method_type - - def select_experts( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - return self.layer._select_experts(hidden_states, router_logits) - - # --8<-- [start:fused_moe] @CustomOp.register("fused_moe") class FusedMoE(CustomOp): @@ -440,9 +408,7 @@ def __init__( self.layer_name = prefix self.enable_eplb = enable_eplb - self.expert_load_view: torch.Tensor | None = None - self.logical_to_physical_map: torch.Tensor | None = None - self.logical_replica_count: torch.Tensor | None = None + self.eplb_state = EplbLayerState() self.expert_placement_strategy: ExpertPlacementStrategy = ( vllm_config.parallel_config.expert_placement_strategy ) @@ -538,6 +504,8 
@@ def __init__( self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize + + # TODO(bnell): these attributes are only used by cpu/xpu/mxfp4 self.use_grouped_topk = use_grouped_topk if self.use_grouped_topk: assert num_expert_group is not None and topk_group is not None @@ -547,46 +515,11 @@ def __init__( self.scoring_func = scoring_func self.routed_scaling_factor = routed_scaling_factor self.e_score_correction_bias = e_score_correction_bias + # TODO(bnell): end attributes + self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation - self._grouped_topk_impl: GroupedTopk | None = None - if self.use_grouped_topk: - assert self.num_expert_group is not None - assert self.topk_group is not None - self._grouped_topk_impl = GroupedTopk( - topk=self.top_k, - renormalize=self.renormalize, - num_expert_group=self.num_expert_group, - topk_group=self.topk_group, - scoring_func=self.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, - num_fused_shared_experts=self.num_fused_shared_experts, - ) - - if self.scoring_func != "softmax" and not self.use_grouped_topk: - raise ValueError( - "Only softmax scoring function is supported for non-grouped topk." - ) - - # ToDo: Better logic to determine the routing method type - if routing_method_type is not None: - self.routing_method_type: RoutingMethodType = routing_method_type - else: - if scoring_func == "sigmoid": - if self.use_grouped_topk: - self.routing_method_type = RoutingMethodType.DeepSeekV3 - elif self.top_k == 1: - self.routing_method_type = RoutingMethodType.Llama4 - elif self.scoring_func == "softmax": - self.routing_method_type = ( - RoutingMethodType.Renormalize - if not self.renormalize - else RoutingMethodType.RenormalizeNaive - ) - else: - self.routing_method_type = RoutingMethodType.TopK - self.moe_config: FusedMoEConfig = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, @@ -637,8 +570,7 @@ def _get_quant_method() -> FusedMoEMethodBase: # If you plan to add support for more quantization methods, # please refer to the implementation in `Fp8MoEMethod`. raise NotImplementedError( - f"EPLB is not supported {self.quant_method.__class__.__name__}. " - "EPLB is only supported for FP8 quantization for now." + f"EPLB is not supported {self.quant_method.__class__.__name__}." ) moe_quant_params = { @@ -663,7 +595,38 @@ def _get_quant_method() -> FusedMoEMethodBase: self.batched_hidden_states: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None - self.router = FusedMoERouterImpl(self) + # TODO(bnell): in next PR move capture back to layer + capture: Callable[[torch.Tensor], None] | None = None + if ( + self.vllm_config.model_config is not None + and self.vllm_config.model_config.enable_return_routed_experts + ): + # In dummy runs, the capturer is not initialized. 
+ capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capture = lambda topk_ids: capturer.capture(self.layer_id, topk_ids) + + self.router = create_fused_moe_router( + top_k=top_k, + global_num_experts=self.global_num_experts, + eplb_state=self.eplb_state, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + num_fused_shared_experts=self.num_fused_shared_experts, + enable_eplb=enable_eplb, + # TODO(bnell): once we can construct the MK at init time, we + # can make this a value. + indices_type_getter=lambda: self.quant_method.topk_indices_dtype, + routing_method_type=routing_method_type, + capture=capture, + ) + self.routing_method_type: RoutingMethodType = self.router.routing_method_type # Note: maybe_init_modular_kernel should only be called by # prepare_communication_buffer_for_model. @@ -1492,9 +1455,9 @@ def set_eplb_state( This is used later in forward pass, where we get the expert mapping and record the load metrics in `expert_load_view`. """ - self.expert_load_view = expert_load_view[moe_layer_idx] - self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] - self.logical_replica_count = logical_replica_count[moe_layer_idx] + self.eplb_state.expert_load_view = expert_load_view[moe_layer_idx] + self.eplb_state.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] + self.eplb_state.logical_replica_count = logical_replica_count[moe_layer_idx] def ensure_moe_quant_config_init(self): if self.quant_method.moe_quant_config is None: @@ -1535,130 +1498,6 @@ def ensure_dp_chunking_init(self): device=torch.cuda.current_device(), ) - def _select_experts( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Route the input hidden states to the top-k experts based on the - router logits. - - Returns: - (topk_weights, topk_ids) - (tuple[torch.Tensor, torch.Tensor]): - The weights and expert ids. - - **Compatibility**: When EPLB is not enabled, the returned ids are - equivalent to global logical ids, so should be compatible with - plain MoE implementations without redundant experts. - """ - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, - fused_topk_bias, - ) - - if self.enable_eplb: - if self.quant_method.supports_eplb: - if self.expert_load_view is None: - raise ValueError( - "enable_eplb=True requiere expert_load_view != None" - ) - if self.logical_to_physical_map is None: - raise ValueError( - "enable_eplb=True requiere logical_to_physical_map != None" - ) - if self.logical_replica_count is None: - raise ValueError( - "enable_eplb=True requiere logical_replica_count != None" - ) - else: - raise NotImplementedError( - f"EPLB is not supported for {self.quant_method.method_name}." 
- ) - - def valid_grouping() -> bool: - # Check if num_experts is greater than num_expert_group - # and is divisible by num_expert_group - num_experts = router_logits.shape[-1] - if num_experts <= self.num_expert_group: - return False - return num_experts % self.num_expert_group == 0 - - indices_type = self.quant_method.topk_indices_dtype - - # Check if we should use a routing simulation strategy - routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY - if routing_strategy != "": - topk_weights, topk_ids = RoutingSimulator.simulate_routing( - hidden_states=hidden_states, - router_logits=router_logits, - strategy_name=routing_strategy, - top_k=self.top_k, - indices_type=indices_type, - ) - - # DeepSeekv2 uses grouped_top_k - elif self.use_grouped_topk and valid_grouping(): - assert self._grouped_topk_impl is not None - topk_weights, topk_ids = self._grouped_topk_impl( - hidden_states=hidden_states, - gating_output=router_logits, - e_score_correction_bias=self.e_score_correction_bias, - ) - elif self.e_score_correction_bias is not None: - topk_weights, topk_ids = fused_topk_bias( - hidden_states=hidden_states, - gating_output=router_logits, - e_score_correction_bias=self.e_score_correction_bias.data, - topk=self.top_k, - renormalize=self.renormalize, - ) - if self.routed_scaling_factor != 1.0: - topk_weights *= self.routed_scaling_factor - elif self.custom_routing_function is None: - topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.top_k, - renormalize=self.renormalize, - indices_type=indices_type, - ) - else: - topk_weights, topk_ids = self.custom_routing_function( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.top_k, - renormalize=self.renormalize, - ) - - if self.enable_eplb: - topk_ids = eplb_map_to_physical_and_record( - topk_ids=topk_ids, - expert_load_view=self.expert_load_view, - logical_to_physical_map=self.logical_to_physical_map, - logical_replica_count=self.logical_replica_count, - ) - - if (indices_type is not None) and topk_ids.dtype != indices_type: - topk_ids = topk_ids.to(dtype=indices_type) - - assert topk_ids.dtype == indices_type or indices_type is None - - if ( - self.vllm_config.model_config is not None - and self.vllm_config.model_config.enable_return_routed_experts - ): - # In dummy runs, the capturer is not initialized. - capturer = RoutedExpertsCapturer.get_instance() - if capturer is not None: # in dummmy_run may be None - capturer.capture( # noqa - layer_id=self.layer_id, - topk_ids=topk_ids, - ) - - return topk_weights, topk_ids - def must_reduce_shared_expert_outputs(self) -> bool: """ The shared_experts are typically computed using the RowParallelLinear @@ -1761,8 +1600,12 @@ def forward_impl_chunked( ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.batched_hidden_states is not None assert self.batched_router_logits is not None - assert self.batched_hidden_states.dtype == full_hidden_states.dtype - assert self.batched_router_logits.dtype == full_router_logits.dtype + assert self.batched_hidden_states.dtype == full_hidden_states.dtype, ( + f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}" + ) + assert self.batched_router_logits.dtype == full_router_logits.dtype, ( + f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}" + ) # Check size compatibility. 
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1) assert self.batched_router_logits.size(-1) == full_router_logits.size(-1) @@ -2080,15 +1923,8 @@ def extra_repr(self) -> str: f"tp_size={self.tp_size},\n" f"ep_size={self.ep_size}, " f"reduce_results={self.reduce_results}, " - f"renormalize={self.renormalize}, " - f"use_grouped_topk={self.use_grouped_topk}" ) - if self.use_grouped_topk: - s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501 - - s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501 - return s diff --git a/vllm/model_executor/layers/fused_moe/router/__init__.py b/vllm/model_executor/layers/fused_moe/router/__init__.py new file mode 100644 index 000000000000..208f01a7cb5e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/router/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py new file mode 100644 index 000000000000..683f7188c165 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/router/base_router.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import abstractmethod +from collections.abc import Callable + +import torch + +from vllm.distributed.eplb.eplb_state import EplbLayerState +from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( + FusedMoERouter, +) +from vllm.platforms import current_platform + +if current_platform.is_cuda_alike(): + + @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) + def eplb_map_to_physical_and_record( + topk_ids: torch.Tensor, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> torch.Tensor: + """ + Map the logical expert ids to physical expert ids + and record the expert load metrics. + + This will select a pseudo-random replica for each logical expert. + Only used for EPLB. + + Args: + topk_ids: The logical expert ids. + expert_load_view: The expert load view. + logical_to_physical_map: The logical to physical map. + logical_replica_count: The logical replica count. + + Returns: + The physical expert ids. + """ + + # 1. Convert the logical expert ids to physical expert ids + # Directly select a random replica for each logical expert + + # In case `indices_type` is not `torch.long` or `torch.int`, + # e.g. `torch.uint32` as required by dispatch/combine kernels + topk_ids_long = topk_ids.long() + # Use (token position) modulo (replica count) + # to deterministically choose a replica + replica_count = logical_replica_count[topk_ids_long] + # Flatten-position based index, reshaped back to `topk_ids` shape + pos_indices = torch.arange( + topk_ids.numel(), device=topk_ids.device, dtype=torch.long + ).reshape_as(topk_ids) + # Compute pseudo-random indices by modulo + replica_indices = (pos_indices % replica_count).unsqueeze(-1) + physical_ids = ( + logical_to_physical_map[topk_ids_long] + .gather(-1, replica_indices) + .squeeze(-1) + ) + + topk_ids = physical_ids + + # 2. Record expert load metrics. + + # TODO(bowen): When using `FusedMoEModularKernel`, this + # can be done in a more unified way, since + # `FusedMoEPrepareAndFinalize` will return the expert + # token count, in some cases directly from the kernel. 
+ # However, now there are many code paths not using + # the modular kernel, e.g. calling `fused_experts`, + # so we decide to keep the logic here. + # + # If a later refactor moves all the MoE kernel calls + # to the modular kernel, we can move this logic there + # to achieve better efficiency. + + # `expert_load_view`: (num_physical_experts,) + + # `torch.bincount` is not compilable, so use `scatter_add_` instead. + topk_ids_flatten = topk_ids.flatten() + expert_load_view.scatter_add_( + dim=0, + index=topk_ids_flatten.long(), + src=torch.ones_like(topk_ids_flatten).to(expert_load_view), + ) + return topk_ids +else: + + def eplb_map_to_physical_and_record( + topk_ids: torch.Tensor, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> torch.Tensor: + # CPU fallback: no EPLB, so just return as is + return topk_ids + + +class BaseRouter(FusedMoERouter): + """ + Base router class that provides common functionality for all router implementations. + + This class implements the template method pattern where select_experts() handles + common pre-processing and post-processing, delegating the actual routing logic + to the abstract _compute_routing() method. + """ + + def __init__( + self, + top_k: int, + global_num_experts: int, + eplb_state: EplbLayerState, + enable_eplb: bool = False, + # TODO(bnell): Once the MK is constructed at layer init time, we + # can make this a plain value instead of a callback. + indices_type_getter: Callable[[], torch.dtype | None] | None = None, + ): + """ + Note: the indices dtype might not be available at router construction + time, so we need to supply a callback to get it at runtime. This is + because the indices type is supplied by modular kernels which are + created after MoE layer/router construction.
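+
+ A minimal usage sketch (assuming a concrete subclass such as the
+ FusedTopKRouter defined alongside this base class, and a dtype that is
+ already known up front):
+
+ router = FusedTopKRouter(
+ top_k=2,
+ global_num_experts=8,
+ eplb_state=EplbLayerState(),
+ indices_type_getter=lambda: torch.int32,
+ )
+ topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)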
+ """ + super().__init__() + self.top_k = top_k + self.global_num_experts = global_num_experts + self.eplb_state = eplb_state + self.enable_eplb = enable_eplb + self.indices_type_getter = indices_type_getter + self.capture: Callable[[torch.tensor], None] | None = None + + def _validate_eplb_state(self) -> None: + """Validate that EPLB state is properly initialized if EPLB is enabled.""" + if self.enable_eplb: + if self.eplb_state.expert_load_view is None: + raise ValueError("enable_eplb=True requires expert_load_view != None") + if self.eplb_state.logical_to_physical_map is None: + raise ValueError( + "enable_eplb=True requires logical_to_physical_map != None" + ) + if self.eplb_state.logical_replica_count is None: + raise ValueError( + "enable_eplb=True requires logical_replica_count != None" + ) + + def _get_indices_type(self) -> torch.dtype | None: + """Get the desired indices dtype from the getter function.""" + return ( + self.indices_type_getter() if self.indices_type_getter is not None else None + ) + + def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor: + """Apply EPLB mapping to convert logical expert IDs to physical expert IDs.""" + if self.enable_eplb: + assert self.eplb_state.expert_load_view is not None + assert self.eplb_state.logical_to_physical_map is not None + assert self.eplb_state.logical_replica_count is not None + return eplb_map_to_physical_and_record( + topk_ids=topk_ids, + expert_load_view=self.eplb_state.expert_load_view, + logical_to_physical_map=self.eplb_state.logical_to_physical_map, + logical_replica_count=self.eplb_state.logical_replica_count, + ) + return topk_ids + + def _convert_indices_dtype( + self, topk_ids: torch.Tensor, indices_type: torch.dtype | None + ) -> torch.Tensor: + """Convert topk_ids to the desired dtype if needed.""" + if (indices_type is not None) and topk_ids.dtype != indices_type: + topk_ids = topk_ids.to(dtype=indices_type) + + assert topk_ids.dtype == indices_type or indices_type is None + return topk_ids + + @abstractmethod + def _compute_routing( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + indices_type: torch.dtype | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute the actual routing logic. + + This method must be implemented by subclasses to provide the specific + routing algorithm (e.g., grouped_topk, fused_topk, custom routing, etc.). + + Args: + hidden_states: Input hidden states + router_logits: Router logits for expert selection + indices_type: Desired dtype for expert indices (may be None) + + Returns: + tuple of (topk_weights, topk_ids) + """ + raise NotImplementedError + + def select_experts( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Route the input hidden states to the top-k experts based on the + router logits. + + This method implements the template method pattern: + 1. Validates EPLB state + 2. Gets indices type + 3. Calls _compute_routing() to get topk_weights and topk_ids + 4. Applies EPLB mapping if enabled + 5. Converts indices dtype if needed + + Returns: + (topk_weights, topk_ids) + (tuple[torch.Tensor, torch.Tensor]): + The weights and expert ids computation result. + + **Compatibility**: When EPLB is not enabled, the returned ids are + equivalent to global logical ids, so should be compatible with + plain MoE implementations without redundant experts. + """ + # Step 1: Validate EPLB state + self._validate_eplb_state() + + # Step 2: Get indices type. 
+ indices_type = self._get_indices_type() + + # Step 3: Compute routing (delegated to subclass) + topk_weights, topk_ids = self._compute_routing( + hidden_states, router_logits, indices_type + ) + + # Step 4: Apply EPLB mapping + topk_ids = self._apply_eplb_mapping(topk_ids) + + # Step 5: Convert indices dtype + topk_ids = self._convert_indices_dtype(topk_ids, indices_type) + + # TODO(bnell): temporary hack until select_experts is moved into FusedMoE + if self.capture is not None: + self.capture(topk_ids) + + return topk_weights, topk_ids diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py new file mode 100644 index 000000000000..a19dfb62b53e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch + +from vllm.distributed.eplb.eplb_state import EplbLayerState +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter + + +class CustomRoutingRouter(BaseRouter): + """Router using a custom user-provided routing function.""" + + def __init__( + self, + top_k: int, + global_num_experts: int, + eplb_state: EplbLayerState, + custom_routing_function: Callable, + renormalize: bool = True, + enable_eplb: bool = False, + indices_type_getter: Callable[[], torch.dtype | None] | None = None, + ): + super().__init__( + top_k=top_k, + global_num_experts=global_num_experts, + eplb_state=eplb_state, + enable_eplb=enable_eplb, + indices_type_getter=indices_type_getter, + ) + self.custom_routing_function = custom_routing_function + self.renormalize = renormalize + + @property + def routing_method_type(self) -> RoutingMethodType: + return RoutingMethodType.Custom + + def _compute_routing( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + indices_type: torch.dtype | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute routing using the custom routing function.""" + topk_weights, topk_ids = self.custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=self.top_k, + renormalize=self.renormalize, + ) + + return topk_weights.to(torch.float32), topk_ids diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_router.py b/vllm/model_executor/layers/fused_moe/router/fused_moe_router.py similarity index 100% rename from vllm/model_executor/layers/fused_moe/fused_moe_router.py rename to vllm/model_executor/layers/fused_moe/router/fused_moe_router.py diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py new file mode 100644 index 000000000000..460385ace46b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch + +from vllm.distributed.eplb.eplb_state import EplbLayerState +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter + + +def fused_topk_bias( + 
hidden_states: torch.Tensor, + gating_output: torch.Tensor, + e_score_correction_bias: torch.Tensor, + topk: int, + renormalize: bool, +): + n_routed_experts = gating_output.shape[-1] + scores = gating_output.softmax(dim=-1) + scores_for_choice = scores.view( + -1, n_routed_experts + ) + e_score_correction_bias.unsqueeze(0) + + # For batch invariance, use sorted=True to ensure deterministic expert selection + use_sorted = vllm_is_batch_invariant() + topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1] + topk_weights = scores.gather(1, topk_indices) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights.to(torch.float32), topk_indices.to(torch.int32) + + +class FusedTopKBiasRouter(BaseRouter): + """Router using fused top-k with e_score_correction_bias.""" + + def __init__( + self, + top_k: int, + global_num_experts: int, + eplb_state: EplbLayerState, + e_score_correction_bias: torch.Tensor, + renormalize: bool = True, + routed_scaling_factor: float = 1.0, + enable_eplb: bool = False, + indices_type_getter: Callable[[], torch.dtype | None] | None = None, + ): + super().__init__( + top_k=top_k, + global_num_experts=global_num_experts, + eplb_state=eplb_state, + enable_eplb=enable_eplb, + indices_type_getter=indices_type_getter, + ) + self.e_score_correction_bias = e_score_correction_bias + self.renormalize = renormalize + self.routed_scaling_factor = routed_scaling_factor + + @property + def routing_method_type(self) -> RoutingMethodType: + return ( + RoutingMethodType.Renormalize + if not self.renormalize + else RoutingMethodType.RenormalizeNaive + ) + + def _compute_routing( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + indices_type: torch.dtype | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute routing using fused top-k with bias.""" + topk_weights, topk_ids = fused_topk_bias( + hidden_states=hidden_states, + gating_output=router_logits, + e_score_correction_bias=self.e_score_correction_bias.data, + topk=self.top_k, + renormalize=self.renormalize, + ) + + if self.routed_scaling_factor != 1.0: + topk_weights *= self.routed_scaling_factor + + return topk_weights, topk_ids diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py new file mode 100644 index 000000000000..25b360c528dc --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch + +import vllm._custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops +from vllm.distributed.eplb.eplb_state import EplbLayerState +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter + + +def vllm_topk_softmax( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> tuple[torch.Tensor, ...]: + ops.topk_softmax( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + renormalize, + ) + + return topk_weights, topk_indices + + +def dispatch_topk_func( + use_rocm_aiter: bool = False, +) -> Callable[..., tuple[torch.Tensor, ...]]: + if use_rocm_aiter: + return rocm_aiter_ops.topk_softmax + return vllm_topk_softmax + + +def 
fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + indices_type: torch.dtype | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" + + M, _ = hidden_states.size() + + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty( + M, + topk, + dtype=torch.int32 if indices_type is None else indices_type, + device=hidden_states.device, + ) + token_expert_indices = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) + + topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()) + topk_weights, topk_ids = topk_func( + topk_weights, topk_ids, token_expert_indices, gating_output, renormalize + ) + + return topk_weights, topk_ids, token_expert_indices + + +class FusedTopKRouter(BaseRouter): + """Default router using standard fused top-k routing.""" + + def __init__( + self, + top_k: int, + global_num_experts: int, + eplb_state: EplbLayerState, + scoring_func: str = "softmax", + renormalize: bool = True, + enable_eplb: bool = False, + indices_type_getter: Callable[[], torch.dtype | None] | None = None, + ): + assert scoring_func == "softmax", "FusedTopKRouter only supports softmax." + super().__init__( + top_k=top_k, + global_num_experts=global_num_experts, + eplb_state=eplb_state, + enable_eplb=enable_eplb, + indices_type_getter=indices_type_getter, + ) + self.renormalize = renormalize + + @property + def routing_method_type(self) -> RoutingMethodType: + return ( + RoutingMethodType.Renormalize + if not self.renormalize + else RoutingMethodType.RenormalizeNaive + ) + + def _compute_routing( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + indices_type: torch.dtype | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute routing using standard fused top-k.""" + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=self.top_k, + renormalize=self.renormalize, + indices_type=indices_type, + ) + + return topk_weights, topk_ids diff --git a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py new file mode 100644 index 000000000000..e5b6de02f4ca --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py @@ -0,0 +1,353 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable +from functools import partial + +import torch + +from vllm import _custom_ops as ops +from vllm import envs as envs +from vllm._aiter_ops import rocm_aiter_ops +from vllm.distributed.eplb.eplb_state import EplbLayerState +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_grouped_topk, +) +from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter +from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( + fused_topk_bias, +) +from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk +from vllm.model_executor.utils import maybe_disable_graph_partition +from 
vllm.platforms import current_platform + + +def fused_grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + e_score_correction_bias: torch.Tensor, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" + + if scoring_func == "sigmoid": + # Fully fused kernel path for sigmoid + topk_values, topk_indices = ops.grouped_topk( + gating_output, # raw logits + num_expert_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + e_score_correction_bias, + 1, # scoring_func=1 for sigmoid + ) + elif scoring_func == "softmax": + # Apply softmax in Python, then use fused kernel + # TODO: Add support for softmax in kernel + scores = torch.softmax(gating_output, dim=-1) + topk_values, topk_indices = ops.grouped_topk( + scores, # pre-computed scores + num_expert_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + e_score_correction_bias, + 0, # scoring_func=0 (no activation, scores already computed) + ) + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + # Fused kernel outputs float32 values and int32 indices directly + return topk_values, topk_indices + + +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile( + dynamic=True, + backend=current_platform.simple_compile_backend, + options=maybe_disable_graph_partition(current_platform.simple_compile_backend), +) +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if ( + envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK + and current_platform.is_cuda() + and num_expert_group <= 32 + and topk <= 32 + and e_score_correction_bias is not None + ): + return fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + ) + + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.size(0) + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. 
We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) + else: + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + + # For batch invariance, use sorted=True to ensure deterministic expert selection + use_sorted = vllm_is_batch_invariant() + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk( + tmp_scores, k=topk, dim=-1, sorted=use_sorted + ) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +# --8<-- [start:grouped_topk] +@CustomOp.register("grouped_topk") +class GroupedTopk(CustomOp): + """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model.""" + + # --8<-- [end:grouped_topk] + + def __init__( + self, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + num_fused_shared_experts: int = 0, + ) -> None: + super().__init__() + self.native_impl = grouped_topk + self.topk = topk + self.renormalize = renormalize + self.num_expert_group = num_expert_group + self.topk_group = topk_group + self.scoring_func = scoring_func + self.routed_scaling_factor = routed_scaling_factor + self.num_fused_shared_experts = num_fused_shared_experts + + def forward_native( + self, + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + e_score_correction_bias: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.native_impl( + hidden_states, + gating_output, + self.topk, + self.renormalize, + self.num_expert_group, + self.topk_group, + self.scoring_func, + self.routed_scaling_factor, + e_score_correction_bias, + ) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + e_score_correction_bias: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.forward_native( + hidden_states, gating_output, e_score_correction_bias + ) + + def forward_hip( + self, + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + e_score_correction_bias: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if rocm_aiter_ops.is_fused_moe_enabled(): + if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): + assert self.num_fused_shared_experts == 0 + return rocm_aiter_grouped_topk( + hidden_states, + gating_output, + self.topk, + self.renormalize, + self.num_expert_group, + self.topk_group, + self.scoring_func, + self.routed_scaling_factor, + 
e_score_correction_bias, + self.num_fused_shared_experts, + ) + else: + return self.forward_native( + hidden_states, gating_output, e_score_correction_bias + ) + + +class GroupedTopKRouter(BaseRouter): + """Router using grouped top-k routing (e.g., DeepSeekV2/V3).""" + + def __init__( + self, + top_k: int, + global_num_experts: int, + eplb_state: EplbLayerState, + num_expert_group: int, + topk_group: int, + renormalize: bool = True, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + num_fused_shared_experts: int = 0, + enable_eplb: bool = False, + indices_type_getter: Callable[[], torch.dtype | None] | None = None, + routing_method_type: RoutingMethodType | None = None, + ): + super().__init__( + top_k=top_k, + global_num_experts=global_num_experts, + eplb_state=eplb_state, + enable_eplb=enable_eplb, + indices_type_getter=indices_type_getter, + ) + self.num_expert_group = num_expert_group + self.topk_group = topk_group + self.renormalize = renormalize + self.scoring_func = scoring_func + self.routed_scaling_factor = routed_scaling_factor + self.e_score_correction_bias = e_score_correction_bias + self.num_fused_shared_experts = num_fused_shared_experts + + # Determine routing method type + if routing_method_type is not None: + self._routing_method_type = routing_method_type + elif scoring_func == "sigmoid": + self._routing_method_type = RoutingMethodType.DeepSeekV3 + else: + self._routing_method_type = RoutingMethodType.TopK + + @property + def routing_method_type(self) -> RoutingMethodType: + return self._routing_method_type + + def _compute_routing( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + indices_type: torch.dtype | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute routing using grouped top-k.""" + + def valid_grouping() -> bool: + # Check if num_experts is greater than num_expert_group + # and is divisible by num_expert_group + num_experts = router_logits.shape[-1] + if num_experts <= self.num_expert_group: + return False + return num_experts % self.num_expert_group == 0 + + if not valid_grouping(): + if self.e_score_correction_bias is not None: + topk_weights, topk_ids = fused_topk_bias( + hidden_states=hidden_states, + gating_output=router_logits, + e_score_correction_bias=self.e_score_correction_bias.data, + topk=self.top_k, + renormalize=self.renormalize, + ) + if self.routed_scaling_factor != 1.0: + topk_weights *= self.routed_scaling_factor + else: + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=self.top_k, + renormalize=self.renormalize, + indices_type=indices_type, + ) + return topk_weights, topk_ids + + # Select grouped_topk implementation + if rocm_aiter_ops.is_fused_moe_enabled(): + if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled(): + assert self.num_fused_shared_experts == 0 + grouped_topk_impl = partial( + rocm_aiter_grouped_topk, + num_fused_shared_experts=self.num_fused_shared_experts, + ) + else: + grouped_topk_impl = grouped_topk + + topk_weights, topk_ids = grouped_topk_impl( + hidden_states=hidden_states, + gating_output=router_logits, + topk=self.top_k, + renormalize=self.renormalize, + num_expert_group=self.num_expert_group, + topk_group=self.topk_group, + scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + e_score_correction_bias=self.e_score_correction_bias, + ) + + return topk_weights, topk_ids diff --git 
a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py new file mode 100644 index 000000000000..8818373d8b0d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch + +import vllm.envs as envs +from vllm.distributed.eplb.eplb_state import EplbLayerState +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter +from vllm.model_executor.layers.fused_moe.router.custom_routing_router import ( + CustomRoutingRouter, +) +from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( + FusedMoERouter, +) +from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( + FusedTopKBiasRouter, +) +from vllm.model_executor.layers.fused_moe.router.fused_topk_router import ( + FusedTopKRouter, +) +from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import ( + GroupedTopKRouter, +) +from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import ( + RoutingSimulatorRouter, +) + +EMPTY_EPLB_STATE: EplbLayerState = EplbLayerState() + + +def create_fused_moe_router( + # common parameters + top_k: int, + global_num_experts: int, + renormalize: bool = True, + indices_type_getter: Callable[[], torch.dtype | None] | None = None, + routing_method_type: RoutingMethodType | None = None, + # grouped topk parameters + use_grouped_topk: bool = False, + num_expert_group: int | None = None, + topk_group: int | None = None, + scoring_func: str = "softmax", + num_fused_shared_experts: int = 0, + # grouped topk + fused topk bias parameters + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + # custom routing parameters + custom_routing_function: Callable | None = None, + # eplb parameters + enable_eplb: bool = False, + eplb_state: EplbLayerState = EMPTY_EPLB_STATE, + capture: Callable[[torch.tensor], None] | None = None, +) -> FusedMoERouter: + """ + Factory function to create the appropriate FusedMoERouter subclass based on + the provided parameters. + + The selection logic follows this priority order: + 1. RoutingSimulatorRouter - if VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set + 2. GroupedTopKRouter - if use_grouped_topk is True + 3. CustomRoutingRouter - if custom_routing_function is not None + 4. FusedTopKBiasRouter - if e_score_correction_bias is not None + 5. 
FusedTopKRouter - default fallback + + Common arguments: + top_k: Number of experts to select per token + global_num_experts: Total number of experts in the model + renormalize: Whether to renormalize the routing weights + indices_type_getter: Function to get the desired indices dtype + routing_method_type: Optional explicit routing method type + + Grouped topk arguments: + use_grouped_topk: Whether to use grouped top-k routing + num_expert_group: Number of expert groups (for grouped routing) + topk_group: Top-k within each group (for grouped routing) + scoring_func: Scoring function to use ("softmax" or "sigmoid") + num_fused_shared_experts: Number of fused shared experts (for ROCm AITER) + + Grouped topk and fused topk bias arguments: + routed_scaling_factor: Scaling factor for routed weights + e_score_correction_bias: Optional bias correction for expert scores + + Custom routing arguments: + custom_routing_function: Optional custom routing function + + EPLB arguments: + enable_eplb: Whether EPLB is enabled + eplb_state: EPLB (Expert Parallelism Load Balancing) state + + Returns: + An instance of the appropriate FusedMoERouter subclass + """ + router: BaseRouter + + routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY + if routing_strategy != "": + router = RoutingSimulatorRouter( + top_k=top_k, + global_num_experts=global_num_experts, + eplb_state=eplb_state, + enable_eplb=enable_eplb, + indices_type_getter=indices_type_getter, + ) + # TODO(bnell): this is temporary until select_experts is + # separated from apply. + router.capture = capture + return router + + if use_grouped_topk: + assert custom_routing_function is None + if num_expert_group is None or topk_group is None: + raise ValueError( + "num_expert_group and topk_group must be provided when " + "use_grouped_topk is True" + ) + router = GroupedTopKRouter( + top_k=top_k, + global_num_experts=global_num_experts, + eplb_state=eplb_state, + num_expert_group=num_expert_group, + topk_group=topk_group, + renormalize=renormalize, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + num_fused_shared_experts=num_fused_shared_experts, + enable_eplb=enable_eplb, + indices_type_getter=indices_type_getter, + routing_method_type=routing_method_type, + ) + router.capture = capture + return router + + if custom_routing_function is not None: + router = CustomRoutingRouter( + top_k=top_k, + global_num_experts=global_num_experts, + eplb_state=eplb_state, + custom_routing_function=custom_routing_function, + renormalize=renormalize, + enable_eplb=enable_eplb, + indices_type_getter=indices_type_getter, + ) + router.capture = capture + return router + + if scoring_func != "softmax": + raise ValueError( + "Only softmax scoring function is supported for non-grouped topk." 
+ ) + + if e_score_correction_bias is not None: + router = FusedTopKBiasRouter( + top_k=top_k, + global_num_experts=global_num_experts, + eplb_state=eplb_state, + e_score_correction_bias=e_score_correction_bias, + renormalize=renormalize, + routed_scaling_factor=routed_scaling_factor, + enable_eplb=enable_eplb, + indices_type_getter=indices_type_getter, + ) + router.capture = capture + return router + + router = FusedTopKRouter( + top_k=top_k, + global_num_experts=global_num_experts, + eplb_state=eplb_state, + renormalize=renormalize, + scoring_func=scoring_func, + enable_eplb=enable_eplb, + indices_type_getter=indices_type_getter, + ) + router.capture = capture + return router diff --git a/vllm/model_executor/layers/fused_moe/routing_simulator.py b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py similarity index 86% rename from vllm/model_executor/layers/fused_moe/routing_simulator.py rename to vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py index a01cdc4908b9..f8e46371841a 100644 --- a/vllm/model_executor/layers/fused_moe/routing_simulator.py +++ b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py @@ -1,20 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Token-to-Expert Routing Simulator - -This module provides a framework for simulating and testing different -token-to-expert routing strategies for Mixture of Experts (MoE) models. -It supports routing logic customization and includes example implementations -like uniform random routing. -""" - from abc import ABC, abstractmethod +from collections.abc import Callable from typing import Any import torch +import vllm.envs as envs +from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType +from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter logger = init_logger(__name__) @@ -308,3 +304,44 @@ def simulate_routing( top_k=top_k, indices_type=indices_type, ) + + +class RoutingSimulatorRouter(BaseRouter): + """Router that uses routing simulation strategies for testing/debugging.""" + + def __init__( + self, + top_k: int, + global_num_experts: int, + eplb_state: EplbLayerState, + enable_eplb: bool = False, + indices_type_getter: Callable[[], torch.dtype | None] | None = None, + ): + super().__init__( + top_k=top_k, + global_num_experts=global_num_experts, + eplb_state=eplb_state, + enable_eplb=enable_eplb, + indices_type_getter=indices_type_getter, + ) + + @property + def routing_method_type(self) -> RoutingMethodType: + return RoutingMethodType.Simulated + + def _compute_routing( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + indices_type: torch.dtype | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Use routing simulator to compute routing.""" + routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY + topk_weights, topk_ids = RoutingSimulator.simulate_routing( + hidden_states=hidden_states, + router_logits=router_logits, + strategy_name=routing_strategy, + top_k=self.top_k, + indices_type=indices_type, + ) + return topk_weights, topk_ids diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 351d631442c1..f8489ab06db1 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ 
b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -20,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, @@ -32,6 +31,9 @@ make_unquantized_moe_kernel, select_unquantized_moe_backend, ) +from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( + FusedMoERouter, +) from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum @@ -312,9 +314,9 @@ def forward_cpu( ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if ( layer.enable_eplb is not False - or layer.expert_load_view is not None - or layer.logical_to_physical_map is not None - or layer.logical_replica_count is not None + or layer.eplb_state.expert_load_view is not None + or layer.eplb_state.logical_to_physical_map is not None + or layer.eplb_state.logical_replica_count is not None ): raise NotImplementedError("Expert load balancing is not supported for CPU.") @@ -346,9 +348,9 @@ def forward_xpu( ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if ( layer.enable_eplb is not False - or layer.expert_load_view is not None - or layer.logical_to_physical_map is not None - or layer.logical_replica_count is not None + or layer.eplb_state.expert_load_view is not None + or layer.eplb_state.logical_to_physical_map is not None + or layer.eplb_state.logical_replica_count is not None ): raise NotImplementedError("Expert load balancing is not supported for XPU.") return layer.ipex_fusion( diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 829c08e9d7f2..b1fb67208e86 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -10,12 +10,12 @@ import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 1d2334f3933a..542a7281051d 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -6,11 +6,11 @@ import torch from packaging import version +from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 423062c6103c..13a123ba6026 100644 --- 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -22,6 +22,7 @@ FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, + FusedMoERouter, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod, ) @@ -40,7 +41,6 @@ MarlinExperts, fused_marlin_moe, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( Fp8MoeBackend, convert_to_fp8_moe_kernel_format, diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 37e6020cb2a9..c3e7a812e023 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -10,12 +10,12 @@ FusedMoE, FusedMoEConfig, FusedMoEMethodBase, + FusedMoERouter, ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, int8_w8a16_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 6c3412df8673..14ed28630680 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,13 +23,13 @@ FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, + FusedMoERouter, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, RoutingMethodType, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( Fp8MoeBackend, diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 1c03e5243a85..2b1537089edd 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -12,11 +12,11 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 8cb7b83b422f..00cd635b26a8 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -10,12 +10,12 @@ import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe -from vllm.model_executor.layers.fused_moe.fused_moe_router import 
FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 475bd853676e..9b2198b715c5 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -8,10 +8,10 @@ from torch.nn import Module from vllm._ipex_ops import ipex_ops as ops -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_moe_router import ( +from vllm.model_executor.layers.fused_moe import ( FusedMoERouter, ) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 5eac19a17e92..4c9fac39ca7e 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -13,11 +13,11 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention.layer import Attention from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index d5d94082587f..5d29fd01c072 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -6,12 +6,12 @@ import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, int4_w4a16_moe_quant_config, int8_w8a16_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEConfig, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 8e050b795f94..ecd13e5c715a 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -14,6 +14,7 @@ FusedMoE, FusedMoEConfig, FusedMoEMethodBase, + FusedMoERouter, ) from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( @@ -27,7 +28,6 @@ MarlinExperts, fused_marlin_moe, ) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( OAITritonExperts, UnfusedOAITritonExperts, @@ -936,9 +936,9 @@ def apply( layer.apply_router_weight_on_input, layer.scoring_func, layer.activation, - layer.expert_load_view, - layer.logical_to_physical_map, - layer.logical_replica_count, + layer.eplb_state.expert_load_view, + layer.eplb_state.logical_to_physical_map, + layer.eplb_state.logical_replica_count, ), "MXFP4 are not supported with this configuration." 
if ( diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 6b731314825a..76ecd055c95e 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -548,7 +548,7 @@ def apply( x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - topk_weights, topk_ids = layer.select_experts( + topk_weights, topk_ids = router.select_experts( hidden_states=x, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index 239adb384708..3544c2442298 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -10,12 +10,12 @@ from torch.nn.parameter import Parameter from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoERouter from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index 8c8cb73b8d6e..34da3e7c772c 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -201,6 +201,7 @@ def __init__( e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, + router_logits_dtype=torch.float32, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 75be587eedb2..2be22e0e3739 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -269,6 +269,7 @@ def __init__( quant_config=quant_config, e_score_correction_bias=self.e_score_correction_bias[0], prefix=f"{prefix}.text_experts", + router_logits_dtype=torch.float32, ) else: self.text_experts = Ernie4_5_VLMoeMLP( @@ -306,6 +307,7 @@ def __init__( quant_config=quant_config, e_score_correction_bias=self.e_score_correction_bias[1], prefix=f"{prefix}.vision_experts", + router_logits_dtype=torch.float32, ) else: self.vision_experts = Ernie4_5_VLMoeMLP(