diff --git a/tests/distributed/test_expert_placement.py b/tests/distributed/test_expert_placement.py index 8b3a64b9c134..46f63408f467 100644 --- a/tests/distributed/test_expert_placement.py +++ b/tests/distributed/test_expert_placement.py @@ -3,7 +3,9 @@ import pytest -from vllm.model_executor.layers.fused_moe.layer import determine_expert_map +from vllm.model_executor.layers.fused_moe.expert_map_manager import ( + determine_expert_map, +) def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts): diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 544106282585..c4efe0a04796 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -1588,7 +1588,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend( layer.apply_router_weight_on_input = False layer.routed_scaling_factor = None layer.shared_experts = None - layer._maybe_init_expert_routing_tables = lambda: None + layer._expert_routing_tables = lambda: None quant_method.process_weights_after_loading(layer) diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 92126171a17b..5aafb89589fd 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -10,7 +10,9 @@ import torch from vllm.model_executor.layers.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.layer import determine_expert_map +from vllm.model_executor.layers.fused_moe.expert_map_manager import ( + determine_expert_map, +) from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( moe_permute, moe_permute_unpermute_supported, diff --git a/vllm/model_executor/layers/fused_moe/expert_map_manager.py b/vllm/model_executor/layers/fused_moe/expert_map_manager.py new file mode 100644 index 000000000000..71f2186ea4dd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/expert_map_manager.py @@ -0,0 +1,516 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Expert Map Manager for MoE layers. + +This module contains the ExpertMapManager class which manages expert ID +mappings and placement strategies for Expert Parallelism in MoE models. +""" + +import torch + +from vllm.config.parallel import ExpertPlacementStrategy +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.experts.rocm_aiter_moe import ( + init_aiter_topK_meta_data, +) + +logger = init_logger(__name__) + + +def determine_expert_map( + ep_size: int, + ep_rank: int, + global_num_experts: int, + expert_placement_strategy: ExpertPlacementStrategy = "linear", + num_fused_shared_experts: int = 0, + return_expert_mask: bool = False, +) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: + """ + Calculates how many experts should be assigned to each rank for EP and + creates a mapping from global to local expert index. Experts are + distributed evenly across ranks. Any remaining are assigned to the + last rank. + + Args: + ep_size: The size of the expert parallel group + ep_rank: The rank of the current process in the expert parallel + group + global_num_experts: The total number of experts in the model. + expert_placement_strategy: The expert placement strategy. + num_fused_shared_experts: Number of fused shared experts (for AITER) + return_expert_mask: Whether to return expert mask for AITER + + Returns: + tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple containing: + - local_num_experts (int): The number of experts assigned + to the current rank. + - expert_map (Optional[torch.Tensor]): A tensor of shape + (global_num_experts,) mapping from global to local index. + Contains -1 for experts not assigned to the current rank. + Returns None if ep_size is 1. + - expert_mask (Optional[torch.Tensor]): A tensor of shape + (global_num_experts + num_fused_shared_experts + 1,) + containing 1 for experts assigned to the current rank + and 0 for sentinel. + Returns None if ep_size is 1. + Used only when AITER MOE is enabled. + """ + from typing import get_args + + assert ep_size > 0 + if ep_size == 1: + return (global_num_experts, None, None) + + # Distribute experts as evenly as possible to each rank. + base_experts = global_num_experts // ep_size + remainder = global_num_experts % ep_size + local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts + + # Create a tensor of size num_experts filled with -1 + expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32) + + # Create an expert map for the local experts + if expert_placement_strategy == "linear": + start_idx = ep_rank * base_experts + min(ep_rank, remainder) + expert_map[start_idx : start_idx + local_num_experts] = torch.arange( + 0, local_num_experts, dtype=torch.int32 + ) + elif expert_placement_strategy == "round_robin": + local_log_experts = torch.arange( + ep_rank, global_num_experts, ep_size, dtype=torch.int32 + ) + + expert_map[local_log_experts] = torch.arange( + 0, local_num_experts, dtype=torch.int32 + ) + else: + raise ValueError( + "Unsupported expert placement strategy " + f"'{expert_placement_strategy}', expected one of " + f"{get_args(ExpertPlacementStrategy)}" + ) + + expert_mask = None + if return_expert_mask: + expert_mask = torch.ones( + (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32 + ) + expert_mask[-1] = 0 + expert_mask[:global_num_experts] = expert_map > -1 + expert_map = torch.cat( + ( + expert_map, + torch.tensor( + [local_num_experts + i for i in range(num_fused_shared_experts)], + dtype=torch.int32, + ), + ), + dim=0, + ) + + return (local_num_experts, expert_map, expert_mask) + + +def determine_expert_placement_strategy( + expert_placement_strategy: ExpertPlacementStrategy, + moe_parallel_config: FusedMoEParallelConfig, + num_expert_group: int | None, + num_redundant_experts: int, + enable_eplb: bool, +) -> ExpertPlacementStrategy: + if expert_placement_strategy == "round_robin": + round_robin_supported = ( + (num_expert_group is not None and num_expert_group > 1) + and num_redundant_experts == 0 + and not enable_eplb + ) + + if not round_robin_supported: + logger.warning( + "Round-robin expert placement is only supported for " + "models with multiple expert groups and no redundant " + "experts. Falling back to linear expert placement." + ) + return "linear" + if ( + moe_parallel_config.use_all2all_kernels + and not moe_parallel_config.needs_round_robin_routing_tables + ): + logger.warning( + "Round-robin expert placement currently only supports " + "the DeepEP low-latency or NIXL EP backend, but '%s' was configured. " + "Falling back to linear expert placement.", + moe_parallel_config.all2all_backend, + ) + return "linear" + + return expert_placement_strategy + + +class ExpertMapManager: + """ + Manages expert ID mappings and placement for Expert Parallelism. + + Responsibilities: + - Calculate local vs global expert counts + - Map between global, local, and physical expert IDs + - Manage placement strategies (linear, round_robin) + - Maintain routing tables for round-robin placement + - Support dynamic reconfiguration of EP topology + + When expert_map is required: + - Expert Parallelism (EP) is enabled, i.e., when ep_size > 1 + - EP disabled (ep_size == 1): expert_map is None + * All experts are local to the current rank + * No mapping is needed + - EP enabled (ep_size > 1): expert_map is created + * Maps global expert IDs to local expert IDs + * Shape: (global_num_experts,) + * Contains the local expert index for experts on this rank, -1 for experts + on other ranks + * Used by kernels to handle distributed expert execution + - Kernel support varies: + * Supports expert_map: fused_moe, fused_marlin_moe, fused_humming_moe, + rocm_aiter_fused_moe, deep_gemm_moe, xpu_moe, gpt_oss_triton_kernels_moe + * Does not support: flashinfer_cutlass_moe, fused_batched_moe, most cutlass_moe + variants, trtllm_* kernels + * When kernel doesn't support expert_map: The modular kernel method sets + expert_map=None even if EP is enabled + """ + + def __init__( + self, + max_num_batched_tokens: int, + top_k: int, + global_num_experts: int, + num_redundant_experts: int, + num_expert_group: int | None, + moe_parallel_config: FusedMoEParallelConfig, + placement_strategy: ExpertPlacementStrategy, + enable_eplb: bool, + num_fused_shared_experts: int = 0, + rocm_aiter_enabled: bool = False, + ): + """ + Initialize expert map manager. + + Args: + global_num_experts: Total number of experts across all ranks + moe_parallel_config: MoE parallel configuration (contains ep_size, + ep_rank, backend flags) + placement_strategy: Strategy for placing experts ('linear' or 'round_robin') + num_fused_shared_experts: Number of fused shared experts (for AITER) + rocm_aiter_enabled: Whether ROCm AITER fusion is enabled + """ + self.global_num_experts = global_num_experts + self.moe_parallel_config = moe_parallel_config + self.num_fused_shared_experts = num_fused_shared_experts + self.rocm_aiter_enabled = rocm_aiter_enabled + self.top_k = top_k + self.max_num_batched_tokens = max_num_batched_tokens + + if moe_parallel_config.use_ep: + # Determine expert placement strategy before creating manager + placement_strategy = determine_expert_placement_strategy( + expert_placement_strategy=placement_strategy, + moe_parallel_config=moe_parallel_config, + num_expert_group=num_expert_group, + num_redundant_experts=num_redundant_experts, + enable_eplb=enable_eplb, + ) + + # Determine effective placement strategy + self._placement_strategy = self._determine_placement_strategy( + placement_strategy + ) + + # Calculate expert mappings + self._calculate_expert_maps() + + # Initialize routing tables if needed + self._routing_tables = self._init_routing_tables() + + self._init_aiter_shared_experts_topK_buffer() + + if self.use_ep and self.rocm_aiter_enabled: + expert_mask = self.expert_mask + assert expert_mask is None or torch.all( + (expert_mask == 0) | (expert_mask == 1) + ), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s." + + # Log EP configuration + if self.use_ep: + logger.info_once( + "[EP Rank %s/%s] Expert parallelism is enabled. Expert " + "placement strategy: %s. Local/global" + " number of experts: %s/%s. Experts local to global index map:" + " %s.", + self.ep_rank, + self.ep_size, + self.placement_strategy, + self.local_num_experts, + self.global_num_experts, + self.get_compressed_map_string(), + ) + + def _init_aiter_shared_experts_topK_buffer(self): + if self.num_fused_shared_experts > 0: + dp_size = self.moe_parallel_config.dp_size + init_aiter_topK_meta_data( + n_routed_experts=self.global_num_experts, + n_shared_experts=self.num_fused_shared_experts, + top_k=self.top_k, + tp_rank=self.ep_rank if self.use_ep else self.tp_rank, + tp_size=self.ep_size if self.use_ep else self.tp_size, + shared_experts_score=1.0, + max_num_tokens=self.max_num_batched_tokens * dp_size, + is_EP=self.use_ep, + ) + + @property + def use_ep(self) -> int: + return self.moe_parallel_config.use_ep + + @property + def ep_size(self) -> int: + return self.moe_parallel_config.ep_size + + @property + def ep_rank(self) -> int: + return self.moe_parallel_config.ep_rank + + @property + def tp_size(self) -> int: + return self.moe_parallel_config.tp_size + + @property + def tp_rank(self) -> int: + return self.moe_parallel_config.tp_rank + + @property + def local_num_experts(self) -> int: + return self._local_num_experts + + @property + def expert_map(self) -> torch.Tensor | None: + """ + Mapping from global expert ID to local expert ID. + + Returns tensor of shape (global_num_experts,) where: + - expert_map[global_id] = local_id if expert is on this rank + - expert_map[global_id] = -1 if expert is not on this rank + + Returns None if EP is not enabled (ep_size == 1). + """ + return self._expert_map + + @property + def expert_mask(self) -> torch.Tensor | None: + """ + Expert mask for AITER fusion (ROCm-specific). + + Returns tensor of shape (global_num_experts + num_fused_shared + 1,) + where 1 indicates expert is on this rank, 0 otherwise. + """ + return self._expert_mask + + @property + def placement_strategy(self) -> ExpertPlacementStrategy: + """Expert placement strategy ('linear' or 'round_robin').""" + return self._placement_strategy + + @property + def routing_tables( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None: + """ + Routing tables for round-robin placement. + + Returns (global_to_physical, physical_to_global, local_to_global) + or None if not using round-robin or tables not needed. + """ + return self._routing_tables + + def map_global_to_local(self, global_id: int) -> int: + """ + Map global expert ID to local expert ID. + + Args: + global_id: Global expert ID (0 to global_num_experts - 1) + + Returns: + Local expert ID (0 to local_num_experts - 1) + + Raises: + ValueError: If expert is not on this rank + """ + if self._expert_map is None: + return global_id + + return self._expert_map[global_id].item() + + def is_local_expert(self, global_id: int) -> bool: + """Check if expert is assigned to this rank.""" + if self._expert_map is None: + return True + return self._expert_map[global_id] != -1 + + def get_local_expert_ids(self) -> list[int]: + """Get list of global IDs for experts on this rank.""" + if self._expert_map is None: + return list(range(self.global_num_experts)) + + return torch.where(self._expert_map != -1)[0].tolist() + + def update( + self, + moe_parallel_config: FusedMoEParallelConfig, + global_num_experts: int, + ) -> None: + """ + Update expert mappings for new EP configuration. + + Used during dynamic reconfiguration (e.g., elastic scaling). + + Args: + global_num_experts: New total number of experts across all ranks + moe_parallel_config: New MoE parallel configuration (contains ep_size, + ep_rank, backend flags) + """ + self.moe_parallel_config = moe_parallel_config + self.global_num_experts = global_num_experts + + if self._expert_map is not None: + device = self._expert_map.device + elif self._expert_mask is not None: + device = self._expert_mask.device + else: + raise AssertionError("_expert_map or _expert_mask must be present.") + + with device: + self._calculate_expert_maps() + self._routing_tables = self._init_routing_tables() + + # Reinitialize AITER buffer if needed and parameters provided + self._init_aiter_shared_experts_topK_buffer() + + def get_compressed_map_string(self) -> str: + """ + Get compressed string representation of expert map for logging. + + Returns string mapping local to global expert IDs. + """ + if self._expert_map is None: + return f"[0..{self.global_num_experts - 1}]" + + global_indices = torch.where(self._expert_map != -1)[0] + local_indices = self._expert_map[global_indices] + return ", ".join( + f"{local_index.item()}->{global_index.item()}" + for local_index, global_index in zip(local_indices, global_indices) + ) + + # Private methods + + def _determine_placement_strategy( + self, requested_strategy: ExpertPlacementStrategy + ) -> ExpertPlacementStrategy: + """Determine effective placement strategy based on config.""" + if requested_strategy != "round_robin": + return requested_strategy + + # Round-robin requires specific conditions + if self.ep_size == 1: + return "linear" + + if ( + self.moe_parallel_config.use_all2all_kernels + and not self.moe_parallel_config.needs_round_robin_routing_tables + ): + logger.warning( + "Round-robin placement requires DeepEP-ll or NIXL backend. " + "Falling back to linear." + ) + return "linear" + + return "round_robin" + + def _calculate_expert_maps(self) -> None: + """Calculate expert mappings based on placement strategy.""" + ( + self._local_num_experts, + self._expert_map, + self._expert_mask, + ) = determine_expert_map( + ep_size=self.ep_size, + ep_rank=self.ep_rank, + global_num_experts=self.global_num_experts, + expert_placement_strategy=self._placement_strategy, + num_fused_shared_experts=self.num_fused_shared_experts, + return_expert_mask=self.rocm_aiter_enabled, + ) + + self._local_num_experts += self.num_fused_shared_experts + + def _init_routing_tables( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None: + """ + Ensure routing tables are initialized if needed for round-robin. + + This is a public method that can be called to explicitly initialize + routing tables. It's safe to call multiple times (idempotent). + """ + if self._placement_strategy != "round_robin": + return None + + if not self.moe_parallel_config.needs_round_robin_routing_tables: + return None + + if self._expert_map is None: + return None + + return self._init_round_robin_expert_routing_tables() + + def _init_round_robin_expert_routing_tables( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build routing tables for round-robin placement.""" + assert self.num_fused_shared_experts == 0, ( + "Round robin not supported for AITER." + ) + + global_indices = torch.arange( + self.global_num_experts, + dtype=torch.long, + ) + owner = torch.remainder(global_indices, self.ep_size) + local_index = torch.div(global_indices, self.ep_size, rounding_mode="floor") + + base = self.global_num_experts // self.ep_size + remainder = self.global_num_experts % self.ep_size + physical_offset = owner * base + + if remainder > 0: + remainder_tensor = torch.tensor( + remainder, + dtype=torch.long, + ) + physical_offset = physical_offset + torch.minimum(owner, remainder_tensor) + + global_to_physical = physical_offset + local_index + physical_to_global = torch.empty_like(global_to_physical) + physical_to_global[global_to_physical] = global_indices + + local_global = torch.arange( + self.ep_rank, + self.global_num_experts, + self.ep_size, + dtype=torch.long, + ) + if local_global.numel() != self._local_num_experts: + local_global = local_global[: self._local_num_experts] + + return (global_to_physical, physical_to_global, local_global) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 2eef89793a6e..ca18e8588798 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,13 +3,13 @@ from collections.abc import Callable, Iterable from enum import Enum -from typing import Literal, cast, get_args, overload +from typing import Literal, cast, overload import torch from torch.nn.parameter import UninitializedParameter from vllm._aiter_ops import rocm_aiter_ops -from vllm.config import VllmConfig, get_current_vllm_config +from vllm.config import get_current_vllm_config from vllm.config.parallel import ExpertPlacementStrategy from vllm.distributed import ( get_dp_group, @@ -26,8 +26,8 @@ FusedMoEQuantConfig, RoutingMethodType, ) -from vllm.model_executor.layers.fused_moe.experts.rocm_aiter_moe import ( - init_aiter_topK_meta_data, +from vllm.model_executor.layers.fused_moe.expert_map_manager import ( + ExpertMapManager, ) from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, @@ -68,152 +68,6 @@ class FusedMoeWeightScaleSupported(Enum): BLOCK = "block" -def determine_expert_map( - ep_size: int, - ep_rank: int, - global_num_experts: int, - expert_placement_strategy: ExpertPlacementStrategy = "linear", - num_fused_shared_experts: int = 0, - return_expert_mask: bool = False, -) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: - """ - Calculates how many experts should be assigned to each rank for EP and - creates a mapping from global to local expert index. Experts are - distributed evenly across ranks. Any remaining are assigned to the - last rank. - - Args: - ep_size: The size of the expert parallel group - ep_rank: The rank of the current process in the expert parallel - group - global_num_experts: The total number of experts in the model. - expert_placement_strategy: The expert placement strategy. - - Returns: - tuple[int, Optional[torch.Tensor]]: A tuple containing: - - local_num_experts (int): The number of experts assigned - to the current rank. - - expert_map (Optional[torch.Tensor]): A tensor of shape - (global_num_experts,) mapping from global to local index. - Contains -1 for experts not assigned to the current rank. - Returns None if ep_size is 1. - - expert_mask (Optional[torch.Tensor]): A tensor of shape - (global_num_experts + num_fused_shared_experts + 1,) - containing 1 for experts assigned to the current rank - and 0 for sentinel. - Returns None if ep_size is 1. - Used only when AITER MOE is enabled. - """ - assert ep_size > 0 - if ep_size == 1: - return (global_num_experts, None, None) - - # Distribute experts as evenly as possible to each rank. - base_experts = global_num_experts // ep_size - remainder = global_num_experts % ep_size - local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts - - # Create a tensor of size num_experts filled with -1 - expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32) - # Create an expert map for the local experts - if expert_placement_strategy == "linear": - start_idx = ep_rank * base_experts + min(ep_rank, remainder) - expert_map[start_idx : start_idx + local_num_experts] = torch.arange( - 0, local_num_experts, dtype=torch.int32 - ) - elif expert_placement_strategy == "round_robin": - local_log_experts = torch.arange( - ep_rank, global_num_experts, ep_size, dtype=torch.int32 - ) - - expert_map[local_log_experts] = torch.arange( - 0, local_num_experts, dtype=torch.int32 - ) - else: - raise ValueError( - "Unsupported expert placement strategy " - f"'{expert_placement_strategy}', expected one of " - f"{get_args(ExpertPlacementStrategy)}" - ) - - expert_mask = None - if return_expert_mask: - expert_mask = torch.ones( - (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32 - ) - expert_mask[-1] = 0 - expert_mask[:global_num_experts] = expert_map > -1 - expert_map = torch.cat( - ( - expert_map, - torch.tensor( - [local_num_experts + i for i in range(num_fused_shared_experts)], - dtype=torch.int32, - ), - ), - dim=0, - ) - - return (local_num_experts, expert_map, expert_mask) - - -def determine_expert_placement_strategy( - expert_placement_strategy: ExpertPlacementStrategy, - moe_parallel_config: FusedMoEParallelConfig, - num_expert_group: int | None, - num_redundant_experts: int, - enable_eplb: bool, -) -> ExpertPlacementStrategy: - if expert_placement_strategy == "round_robin": - round_robin_supported = ( - (num_expert_group is not None and num_expert_group > 1) - and num_redundant_experts == 0 - and not enable_eplb - ) - - if not round_robin_supported: - logger.warning( - "Round-robin expert placement is only supported for " - "models with multiple expert groups and no redundant " - "experts. Falling back to linear expert placement." - ) - return "linear" - if ( - moe_parallel_config.use_all2all_kernels - and not moe_parallel_config.needs_round_robin_routing_tables - ): - logger.warning( - "Round-robin expert placement currently only supports " - "the DeepEP low-latency or NIXL EP backend, but '%s' was configured. " - "Falling back to linear expert placement.", - moe_parallel_config.all2all_backend, - ) - return "linear" - - return expert_placement_strategy - - -def get_compressed_expert_map(expert_map: torch.Tensor) -> str: - """ - Compresses the expert map by removing any -1 entries. - - Args: - expert_map (torch.Tensor): A tensor of shape (global_num_experts,) - mapping from global to local index. Contains -1 for experts not - assigned to the current rank. - - Returns: - str: A string mapping from local to global index. - Using str to support hashing for logging once only. - """ - global_indices = torch.where(expert_map != -1)[0] - local_indices = expert_map[global_indices] - return ", ".join( - f"{local_index.item()}->{global_index.item()}" - for local_index, global_index in zip(local_indices, global_indices) - ) - - # --8<-- [start:fused_moe] @PluggableLayer.register("fused_moe") class FusedMoE(PluggableLayer): @@ -394,55 +248,27 @@ def __init__( "Redundant experts are only supported with EPLB." ) - self.expert_placement_strategy = determine_expert_placement_strategy( - expert_placement_strategy=self.expert_placement_strategy, - moe_parallel_config=self.moe_parallel_config, - num_expert_group=num_expert_group, - num_redundant_experts=num_redundant_experts, - enable_eplb=self.enable_eplb, - ) + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens - self._expert_map: torch.Tensor | None - local_num_experts, expert_map, expert_mask = determine_expert_map( - ep_size=self.ep_size, - ep_rank=self.ep_rank, - global_num_experts=self.global_num_experts, - expert_placement_strategy=self.expert_placement_strategy, - num_fused_shared_experts=self.num_fused_shared_experts, - return_expert_mask=self.rocm_aiter_fmoe_enabled, - ) - self.local_num_experts = local_num_experts - self.register_buffer("_expert_map", expert_map) - self.register_buffer("expert_mask", expert_mask) - self._maybe_init_expert_routing_tables() - logger.info_once( - "[EP Rank %s/%s] Expert parallelism is enabled. Expert " - "placement strategy: %s. Local/global" - " number of experts: %s/%s. Experts local to global index map:" - " %s.", - self.ep_rank, - self.ep_size, - self.expert_placement_strategy, - self.local_num_experts, - self.global_num_experts, - get_compressed_expert_map(self._expert_map), - ) - else: - self.local_num_experts, self._expert_map, self.expert_mask = ( - self.global_num_experts, - None, - None, - ) + # Create ExpertMapManager to handle expert mapping and placement for EP. + # See ExpertMapManager for a detailed description of what it does and when + # it is required. + self.expert_map_manager = ExpertMapManager( + max_num_batched_tokens=max_num_batched_tokens, + top_k=top_k, + global_num_experts=self.global_num_experts, + num_redundant_experts=num_redundant_experts, + num_expert_group=num_expert_group, + moe_parallel_config=self.moe_parallel_config, + placement_strategy=self.expert_placement_strategy, + enable_eplb=self.enable_eplb, + num_fused_shared_experts=self.num_fused_shared_experts, + rocm_aiter_enabled=self.rocm_aiter_fmoe_enabled, + ) - self.top_k = top_k + self.update_expert_map_info() - self._init_aiter_shared_experts_topK_buffer( - vllm_config=vllm_config, dp_size=dp_size_ - ) - if self.use_ep and self.rocm_aiter_fmoe_enabled: - assert self.expert_mask is None or torch.all( - (expert_mask == 0) | (expert_mask == 1) - ), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s." + self.top_k = top_k assert intermediate_size % self.tp_size == 0 intermediate_size_per_partition = intermediate_size // self.tp_size @@ -510,7 +336,7 @@ def __init__( in_dtype=moe_in_dtype, moe_backend=vllm_config.kernel_config.moe_backend, router_logits_dtype=router_logits_dtype, - max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens, + max_num_tokens=max_num_batched_tokens, has_bias=has_bias, is_act_and_mul=is_act_and_mul, is_lora_enabled=vllm_config.lora_config is not None, @@ -642,11 +468,8 @@ def maybe_init_modular_kernel(self) -> None: return None self.ensure_moe_quant_config_init() - # routing_tables only needed for round-robin expert placement with - # DeepEP all2all backend. - routing_tables = self._maybe_init_expert_routing_tables() prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize( - routing_tables=routing_tables + routing_tables=self._expert_routing_tables() ) if prepare_finalize is not None: logger.debug( @@ -698,16 +521,26 @@ def is_internal_router(self) -> bool: # By default, router/gate is called before FusedMoE forward pass return self.runner.is_internal_router() - def _maybe_init_expert_routing_tables( + def update_expert_map_info(self): + # Update local attributes from ExpertMapManager + self.local_num_experts = self.expert_map_manager.local_num_experts + self.expert_placement_strategy = self.expert_map_manager.placement_strategy + self.register_buffer("_expert_map", self.expert_map_manager.expert_map) + self.register_buffer("expert_mask", self.expert_map_manager.expert_mask) + + # Get routing tables from ExpertMapManager + routing_tables = self.expert_map_manager.routing_tables + if routing_tables is not None: + # Register routing tables as buffers for this layer + global_to_physical, physical_to_global, local_global = routing_tables + self.register_buffer("expert_global_to_physical", global_to_physical) + self.register_buffer("expert_physical_to_global", physical_to_global) + self.register_buffer("expert_local_to_global", local_global) + + def _expert_routing_tables( self, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None: - # Currently routing_tables only needed for round-robin expert placement - # with DeepEP-ll or NIXL EP all2all backends. - if self.expert_placement_strategy != "round_robin" or ( - not self.moe_parallel_config.needs_round_robin_routing_tables - ): - return None - + # Return cached routing tables if already registered as buffers if hasattr(self, "expert_global_to_physical"): return cast( tuple[torch.Tensor, torch.Tensor, torch.Tensor], @@ -717,85 +550,21 @@ def _maybe_init_expert_routing_tables( self.expert_local_to_global, ), ) + return None - if self._expert_map is None: - return None - - routing_tables = self.ensure_round_robin_expert_routing_tables( + def update_expert_map(self): + # Update ExpertMapManager with new EP configuration + # The moe_parallel_config (including ep_size and ep_rank) + # should already be updated. + # Note: ExpertMapManager.update() recalculates expert maps and + # reinitializes routing tables internally. + self.expert_map_manager.update( + self.moe_parallel_config, global_num_experts=self.global_num_experts, - ep_size=self.ep_size, - ep_rank=self.ep_rank, - local_num_experts=self.local_num_experts, - device=self._expert_map.device, - ) - - global_to_physical, physical_to_global, local_global = routing_tables - self.register_buffer("expert_global_to_physical", global_to_physical) - self.register_buffer("expert_physical_to_global", physical_to_global) - self.register_buffer("expert_local_to_global", local_global) - - return routing_tables - - @staticmethod - def ensure_round_robin_expert_routing_tables( - global_num_experts: int, - ep_size: int, - ep_rank: int, - local_num_experts: int, - device: torch.device | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - device_kwargs = {"device": device} if device is not None else {} - global_indices = torch.arange( - global_num_experts, dtype=torch.long, **device_kwargs - ) - owner = torch.remainder(global_indices, ep_size) - local_index = torch.div(global_indices, ep_size, rounding_mode="floor") - base = global_num_experts // ep_size - remainder = global_num_experts % ep_size - physical_offset = owner * base - if remainder > 0: - remainder_tensor = torch.tensor( - remainder, dtype=torch.long, **device_kwargs - ) - physical_offset = physical_offset + torch.minimum(owner, remainder_tensor) - - global_to_physical = physical_offset + local_index - physical_to_global = torch.empty_like(global_to_physical) - physical_to_global[global_to_physical] = global_indices - - local_global = torch.arange( - ep_rank, - global_num_experts, - ep_size, - dtype=torch.long, - **device_kwargs, ) - if local_global.numel() != local_num_experts: - local_global = local_global[:local_num_experts] - - return (global_to_physical, physical_to_global, local_global) - def update_expert_map(self): - # ep_size and ep_rank should already be updated - assert self._expert_map is not None - with self._expert_map.device: - local_num_experts, expert_map, expert_mask = determine_expert_map( - ep_size=self.ep_size, - ep_rank=self.ep_rank, - global_num_experts=self.global_num_experts, - expert_placement_strategy=self.expert_placement_strategy, - num_fused_shared_experts=self.num_fused_shared_experts, - return_expert_mask=self.rocm_aiter_fmoe_enabled, - ) - self.local_num_experts = local_num_experts - self.register_buffer("_expert_map", expert_map) - self.register_buffer("expert_mask", expert_mask) - self._maybe_init_expert_routing_tables() - if self.aiter_fmoe_shared_expert_enabled: - self._init_aiter_shared_experts_topK_buffer( - vllm_config=get_current_vllm_config(), - dp_size=get_dp_group().world_size, - ) + # Update local attributes from ExpertMapManager + self.update_expert_map_info() def _load_per_tensor_weight_scale( self, @@ -1063,26 +832,7 @@ def _load_g_idx( expert_data.copy_(loaded_weight) def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: - if self._expert_map is None: - return expert_id - return self._expert_map[expert_id].item() - - def _init_aiter_shared_experts_topK_buffer( - self, vllm_config: VllmConfig, dp_size: int - ): - if self.num_fused_shared_experts > 0: - init_aiter_topK_meta_data( - n_routed_experts=self.global_num_experts, - n_shared_experts=self.num_fused_shared_experts, - top_k=self.top_k, - tp_rank=self.ep_rank if self.use_ep else self.tp_rank, - tp_size=self.ep_size if self.use_ep else self.tp_size, - shared_experts_score=1.0, - max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens - * dp_size, - is_EP=self.use_ep, - ) - self.local_num_experts += self.num_fused_shared_experts + return self.expert_map_manager.map_global_to_local(expert_id) @overload def weight_loader( diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index d9cfce2c2141..074f216e908e 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -191,7 +191,7 @@ def _setup_kernel( moe_config=self.moe, backend=self.unquantized_backend, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py index 2ac2e28f20b5..0755a3b5fe51 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py @@ -194,7 +194,7 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: experts_cls=self.experts_cls, mxfp4_backend=self.mxfp4_backend, shared_experts=layer.shared_experts, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), ) def apply( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py index 29c673d0f6e3..46b7db1f0475 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py @@ -236,7 +236,7 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: moe_config=self.moe, experts_cls=self.experts_cls, shared_experts=layer.shared_experts, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), ) self.moe_kernel.fused_experts.process_weights_after_loading(layer) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py index bba7e0e7abce..433f7a5c76a7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py @@ -336,7 +336,7 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: moe_config=self.moe, fp8_backend=self.fp8_backend, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py index bad5b3895b8f..d39dbee747c0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_int8.py @@ -147,7 +147,7 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: moe_quant_config=self.moe_quant_config, moe_config=self.moe, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py index ecd0b54890d1..219a0526c481 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py @@ -138,7 +138,7 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: moe_config=self.moe, fp8_backend=self.fp8_backend, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index af4419ccbe98..d3d4a15a3b5e 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -786,7 +786,7 @@ def _setup_kernel( moe_config=self.moe, fp8_backend=self.fp8_backend, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, ) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 7b6f1f9cf6cd..0156744fcc42 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -769,7 +769,7 @@ def _setup_kernel(self, layer: FusedMoE) -> None: w2_g_idx=layer.w2_g_idx, w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices, w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 6566d671532a..13cbbabd7c3d 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -901,7 +901,7 @@ def _setup_kernel( moe_config=self.moe, fp8_backend=self.fp8_backend, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, ) @@ -1590,7 +1590,7 @@ def process_weights_after_loading(self, layer: FusedMoE) -> None: moe_config=self.moe, experts_cls=self.experts_cls, shared_experts=layer.shared_experts, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), ) self.moe_kernel.fused_experts.process_weights_after_loading(layer) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index d3f3033be7b3..ba7be3483ebf 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -369,7 +369,7 @@ def _setup_kernel( moe_config=self.moe, mxfp4_backend=self.mxfp4_backend, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, layer=layer, ) @@ -702,7 +702,7 @@ def _setup_kernel( moe_config=self.moe, mxfp4_backend=self.mxfp4_backend, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, layer=layer, ) diff --git a/vllm/model_executor/layers/quantization/online/fp8.py b/vllm/model_executor/layers/quantization/online/fp8.py index cad65c4c9fe4..3e81b791ef88 100644 --- a/vllm/model_executor/layers/quantization/online/fp8.py +++ b/vllm/model_executor/layers/quantization/online/fp8.py @@ -354,7 +354,7 @@ def _setup_kernel( moe_config=self.moe, fp8_backend=self.fp8_backend, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, ) diff --git a/vllm/model_executor/layers/quantization/online/int8.py b/vllm/model_executor/layers/quantization/online/int8.py index 4b4c87fbce96..f4d2f9a2a371 100644 --- a/vllm/model_executor/layers/quantization/online/int8.py +++ b/vllm/model_executor/layers/quantization/online/int8.py @@ -99,7 +99,7 @@ def _setup_kernel(self, layer: "FusedMoE") -> None: moe_quant_config=self.moe_quant_config, moe_config=self.moe, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, ) diff --git a/vllm/model_executor/layers/quantization/online/mxfp8.py b/vllm/model_executor/layers/quantization/online/mxfp8.py index 39a32604442c..312da8a12158 100644 --- a/vllm/model_executor/layers/quantization/online/mxfp8.py +++ b/vllm/model_executor/layers/quantization/online/mxfp8.py @@ -199,7 +199,7 @@ def _setup_kernel( moe_config=self.moe, fp8_backend=self.fp8_backend, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, ) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 3889e376b560..3caced3d3dbc 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1340,7 +1340,7 @@ def _setup_kernel(self, layer: FusedMoE): moe_config=self.moe, mxfp4_backend=self.mxfp4_backend, experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), + routing_tables=layer._expert_routing_tables(), shared_experts=layer.shared_experts, )