diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index a5326dfe84f6..2a15e79513d8 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( @@ -188,6 +191,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # The Deep Gemm kernels only support block size of 128 DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128] + # Class variable for one-shot logging to verify dispatch optimization + _logged_dispatch_once: bool = False + def __init__(self, max_num_tokens: int, num_dispatchers: int, @@ -199,6 +205,13 @@ def __init__(self, block_shape: Block quantization block shape. per_act_token_quant: Per activation token quantization flag. """ + logger.info( + "[MoE Debug] BatchedDeepGemmExperts.__init__ called with: " + "max_num_tokens=%s, num_dispatchers=%s, block_shape=%s, " + "per_act_token_quant=%s, DEEPGEMM_BLOCK_SHAPE=%s", max_num_tokens, + num_dispatchers, block_shape, per_act_token_quant, + self.DEEPGEMM_BLOCK_SHAPE) + super().__init__( FusedMoEQuantConfig( quant_dtype=torch.float8_e4m3fn, @@ -209,6 +222,10 @@ def __init__(self, self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers + logger.info( + "[MoE Debug] BatchedDeepGemmExperts initialized successfully with " + "final block_shape=%s", self.block_shape) + @property def activation_formats( self @@ -222,6 +239,56 @@ def supports_chunking(self) -> bool: def supports_expert_map(self) -> bool: return False + def _get_effective_num_dispatchers(self) -> int: + """ + Calculates the effective number of token dispatchers considering tensor + parallelism. + + When tensor parallelism (TP) is used (TP > 1), only the leader rank + (rank 0) in each TP group should dispatch tokens to avoid redundant + communication. This significantly reduces cross-rank communication + overhead in distributed environments. + + Returns: + int: The effective number of dispatchers to use. + When TP > 1: + - Returns max(1, num_dispatchers // tp_size) for leader ranks + (tp_rank == 0) + - Returns 0 for non-leader ranks (tp_rank != 0) + When TP <= 1: + - Returns the original num_dispatchers + + Note: + Leader ranks are guaranteed at least 1 dispatcher for stability, + while non-leader ranks return 0 to eliminate redundant dispatching. + """ + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + if tp_size <= 1: + # No TP or single device - use all dispatchers + return self.num_dispatchers + + # TP > 1 case + eff = (max(1, self.num_dispatchers // tp_size) if tp_rank == 0 else 0) + + # --- lightweight one-shot log for verification --- + if (not BatchedDeepGemmExperts._logged_dispatch_once + and os.getenv("VLLM_LOG_MOE_DISPATCH", "0") == "1"): + logger.info( + "[moe-dispatch-opt] tp_rank=%d/%d, num_dispatchers=%d -> " + "effective=%d, leader=%s, participates_a2a=%s", + tp_rank, + tp_size, + self.num_dispatchers, + eff, + str(tp_rank == 0), + str(eff > 0), + ) + BatchedDeepGemmExperts._logged_dispatch_once = True + + return eff + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceDelegate() @@ -239,10 +306,10 @@ def workspace_shapes( expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 - # FIXME (varun): We should be able to dispatch only from the leader - # DP ranks in the case of TP > 1. At the moment, all the Ranks - # end up sending their tokens. This needs to be fixed. - num_dispatchers = self.num_dispatchers + # Optimize token dispatch: only leader DP ranks dispatch tokens when + # TP > 1. This reduces cross-rank communication overhead in distributed + # MoE models. + num_dispatchers = self._get_effective_num_dispatchers() num_experts = local_num_experts max_num_tokens = a.size( 0) if self.max_num_tokens is None else self.max_num_tokens @@ -274,9 +341,31 @@ def apply( expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): + logger.info("[MoE Debug] *** BatchedDeepGemmExperts.apply() ENTRY *** " + "THIS IS THE DEEP GEMM IMPLEMENTATION BEING CALLED!") + logger.info( + "[MoE Debug] BatchedDeepGemmExperts.apply() parameters: " + "hidden_states.shape=%s, global_num_experts=%s, activation=%s", + hidden_states.shape, global_num_experts, activation) + assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens + # Monitor expert_num_tokens for workspace allocation analysis + if torch.cuda.is_current_stream_capturing(): + logger.debug( + "[MoE Monitor] skip logging during CUDA Graph capture") + else: + cpu_vals = expert_num_tokens.detach().to("cpu") + logger.info( + "[MoE Monitor] expert_num_tokens " + "shape=%s sum=%d max=%d values(sample)=%s", + tuple(expert_num_tokens.shape), + int(cpu_vals.sum().item()), + int(cpu_vals.max().item()), + cpu_vals.numpy(), + ) + assert hidden_states.ndim == 3 assert self.block_shape is not None @@ -288,17 +377,27 @@ def apply( E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( hidden_states, w1, w2, topk_ids) + logger.info( + "[MoE Debug] Problem size: E=%s, max_num_tokens=%s, N=%s, K=%s, " + "top_k_num=%s", E, max_num_tokens, N, K, top_k_num) + workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) # (from deepgemm docs) : A value hint (which is a value on CPU) # for the M expectation of each batch, correctly setting this value # may lead to better performance. expected_m = max_num_tokens + + logger.info("[MoE Debug] Calling first fp8_m_grouped_gemm_nt_masked") fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), workspace1, expert_num_tokens, expected_m) + logger.info("[MoE Debug] Calling silu_mul_fp8_quant_deep_gemm") a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, expert_num_tokens) + logger.info("[MoE Debug] Calling second fp8_m_grouped_gemm_nt_masked") fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output, expert_num_tokens, expected_m) + + logger.info("[MoE Debug] *** BatchedDeepGemmExperts.apply() EXIT ***") diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 89d7412ee223..637bcc439f1d 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -5,12 +5,15 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) +logger = init_logger(__name__) + class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -24,6 +27,11 @@ def __init__(self, block_shape: Optional[list[int]] = None, per_act_token_quant: bool = False, allow_deep_gemm: bool = False): + logger.info( + "[MoE Debug] BatchedTritonOrDeepGemmExperts.__init__ called with: " + "max_num_tokens=%s, num_dispatchers=%s, use_fp8_w8a8=%s, " + "allow_deep_gemm=%s, block_shape=%s", max_num_tokens, + num_dispatchers, use_fp8_w8a8, allow_deep_gemm, block_shape) assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" @@ -49,16 +57,34 @@ def __init__(self, block_shape=self.block_shape, ) + logger.debug( + "[MoE Debug] BatchedTritonOrDeepGemmExperts init: " + "allow_deep_gemm=%s, use_fp8_w8a8=%s, block_shape=%s, " + "DEEPGEMM_BLOCK_SHAPE=%s", allow_deep_gemm, use_fp8_w8a8, + self.block_shape, BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) + self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and self.block_shape == BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) + logger.debug( + "[MoE Debug] Final allow_deep_gemm decision: %s " + "(conditions: allow=%s, fp8=%s, shape_match=%s)", + self.allow_deep_gemm, allow_deep_gemm, use_fp8_w8a8, + self.block_shape == BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) + self.batched_deep_gemm_experts = BatchedDeepGemmExperts( max_num_tokens=max_num_tokens, num_dispatchers=num_dispatchers, block_shape=self.block_shape, # type: ignore[arg-type] ) if self.allow_deep_gemm else None + if self.allow_deep_gemm: + logger.debug( + "[MoE Debug] Created BatchedDeepGemmExperts successfully") + else: + logger.debug("[MoE Debug] Using BatchedTritonExperts fallback") + assert (self.batched_deep_gemm_experts is not None or self.batched_triton_experts is not None) @@ -154,11 +180,27 @@ def apply( expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): + logger.info( + "[MoE Debug] BatchedTritonOrDeepGemmExperts.apply() ENTRY: " + "allow_deep_gemm=%s, hidden_states.shape=%s, global_num_experts=%s", + self.allow_deep_gemm, hidden_states.shape, global_num_experts) + experts = (self.batched_deep_gemm_experts if self.allow_deep_gemm else self.batched_triton_experts) + + # Log which expert implementation is being used + if self.allow_deep_gemm: + logger.info( + "[MoE Debug] Using BatchedDeepGemmExperts for forward pass") + else: + logger.info( + "[MoE Debug] Using BatchedTritonExperts for forward pass") + assert experts is not None experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids, activation, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, workspace2, expert_tokens_meta, apply_router_weight_on_input) + + logger.info("[MoE Debug] BatchedTritonOrDeepGemmExperts.apply() EXIT") diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 2a3ae478f3ea..913a905207e1 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -11,6 +11,8 @@ TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_tp_group class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): @@ -193,6 +195,15 @@ def prepare_async( quant_config: FusedMoEQuantConfig, ) -> Callable: + # Only DP leader ranks (tp_rank == 0) should dispatch when TP > 1. + tp_world_size = get_tensor_model_parallel_world_size() + tp_rank_in_group = get_tp_group().rank_in_group if tp_world_size > 1 else 0 + if tp_world_size > 1 and tp_rank_in_group != 0: + # Non-leader TP ranks send zero tokens to avoid duplicate dispatch. + a1 = a1[:0] + topk_ids = topk_ids[:0] + topk_weights = topk_weights[:0] + if apply_router_weight_on_input: topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 1849e49e0ab5..ff5050822ccc 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -11,6 +11,8 @@ TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input, normalize_batched_scales_shape) +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_tp_group # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 @@ -147,6 +149,15 @@ def prepare_async( "apply_router_weight_on_input is only implemented for topk=1") a1 = a1 * topk_weights.to(a1.dtype) + # Only DP leader ranks (tp_rank == 0) should dispatch when TP > 1. + tp_world_size = get_tensor_model_parallel_world_size() + tp_rank_in_group = get_tp_group().rank_in_group if tp_world_size > 1 else 0 + if tp_world_size > 1 and tp_rank_in_group != 0: + # Non-leader TP ranks send zero tokens to avoid duplicate dispatch. + a1 = a1[:0] + topk_ids = topk_ids[:0] + topk_weights = topk_weights[:0] + # Dispatch expert_x, expert_num_tokens, self.handle, event, hook = \ self.buffer.low_latency_dispatch(a1, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a90a71159f72..56a015ee6013 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -154,9 +154,15 @@ def _maybe_make_prepare_finalize( all_to_all_args = dict() handle = all2all_manager.get_handle(all_to_all_args) + # Only DP leader ranks should dispatch when TP > 1. + # Use number of DP ranks (leaders) as dispatchers in that case. + tp_world_size = all2all_manager.tp_group.world_size + num_dispatchers = (all2all_manager.world_size // + tp_world_size) if tp_world_size > 1 else \ + all2all_manager.world_size prepare_finalize = DeepEPHTPrepareAndFinalize( handle, - num_dispatchers=all2all_manager.world_size, + num_dispatchers=num_dispatchers, dp_size=all2all_manager.dp_world_size, rank_expert_offset=all2all_manager.rank * moe.num_local_experts, @@ -183,7 +189,12 @@ def _maybe_make_prepare_finalize( prepare_finalize = DeepEPLLPrepareAndFinalize( handle, max_tokens_per_rank=moe.max_num_tokens, - num_dispatchers=all2all_manager.world_size, + # Only DP leader ranks should dispatch when TP > 1. + # Use number of DP ranks (leaders) as dispatchers in that case. + num_dispatchers=(all2all_manager.world_size // + all2all_manager.tp_group.world_size) + if all2all_manager.tp_group.world_size > 1 else + all2all_manager.world_size, use_fp8_dispatch=use_fp8_dispatch, ) @@ -212,6 +223,16 @@ def init_prepare_finalize(self, layer: torch.nn.Module): f"Attempt to override experts for {id(self)}!" self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() experts = self.select_gemm_impl(prepare_finalize, self.moe, layer) + + # Log which expert implementation was selected + allow = getattr(experts, "allow_deep_gemm", None) + use_fp8 = getattr(experts, "use_fp8_w8a8", None) + block_shape = getattr(experts, "block_shape", None) + logger.info( + "[MoE Debug] Expert implementation selected: %s, " + "allow_deep_gemm=%s, use_fp8_w8a8=%s, block_shape=%s", + type(experts).__name__, allow, use_fp8, block_shape) + self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, @@ -278,15 +299,22 @@ def select_gemm_impl( moe: FusedMoEConfig, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: + logger.info( + "[MoE Debug] select_gemm_impl called, activation_format=%s, " + "prepare_finalize=%s", prepare_finalize.activation_format, + type(prepare_finalize).__name__) if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): - logger.debug("BatchedTritonExperts %s", self.moe) + logger.info( + "[MoE Debug] Creating BatchedTritonExperts with moe=%s", + self.moe) return BatchedTritonExperts( max_num_tokens=self.moe.max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), ) else: - logger.debug("TritonExperts %s", self.moe) + logger.info("[MoE Debug] Creating TritonExperts with moe=%s", + self.moe) return TritonExperts() def create_weights(self, layer: torch.nn.Module, num_experts: int, @@ -788,6 +816,14 @@ def __init__( has_bias: bool = False, is_sequence_parallel=False, ): + logger.info( + "[MoE Debug] *** FusedMoE.__init__ ENTRY *** " + "Creating MoE layer with num_experts=%s, prefix='%s', " + "quant_config=%s, tp_size=%s, dp_size=%s, ep_size=%s", num_experts, + prefix, + type(quant_config).__name__ if quant_config else None, tp_size, + dp_size, ep_size) + super().__init__() if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -902,9 +938,23 @@ def __init__( self.moe_config = moe self.quant_config = quant_config + logger.info( + "[MoE Debug] MoE Config created: global_experts=%s, " + "local_experts=%s, max_tokens=%s, parallel_config=%s, " + "use_pplx=%s, use_deepep_ht=%s, use_deepep_ll=%s, " + "use_flashinfer_cutlass=%s", self.global_num_experts, + self.local_num_experts, moe.max_num_tokens, + f"tp={self.tp_size},dp={self.dp_size},ep={self.ep_size}", + moe.use_pplx_kernels, moe.use_deepep_ht_kernels, + moe.use_deepep_ll_kernels, moe.use_flashinfer_cutlass_kernels) + # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. quant_method: Optional[QuantizeMethodBase] = None + logger.info( + "[MoE Debug] Selecting quantization method: quant_config=%s", + type(quant_config).__name__ if quant_config else "None") + quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None else quant_config.get_quant_method(self, prefix)) @@ -912,6 +962,9 @@ def __init__( assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method + logger.info("[MoE Debug] Quantization method selected: %s", + type(quant_method).__name__) + if self.enable_eplb: from vllm.model_executor.layers.quantization.fp8 import ( Fp8MoEMethod) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 3d94626e5d8c..ffe9548eaa54 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -546,18 +546,34 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): # Check for DeepGemm support. self.allow_deep_gemm = False + logger.info( + "[MoE Debug] *** Fp8MoEMethod DeepGEMM Condition Check *** " + "VLLM_USE_DEEP_GEMM=%s, block_quant=%s, weight_block_size=%s", + envs.VLLM_USE_DEEP_GEMM, self.block_quant, + self.quant_config.weight_block_size) + if envs.VLLM_USE_DEEP_GEMM: + logger.info( + "[MoE Debug] VLLM_USE_DEEP_GEMM=True, checking conditions...") if not has_deep_gemm(): - logger.warning_once("Failed to import DeepGemm kernels.") + logger.warning( + "[MoE Debug] FAILED: DeepGemm kernels not available") elif not self.block_quant: - logger.warning_once("Model is not block quantized. Not using " - "DeepGemm kernels") + logger.warning( + "[MoE Debug] FAILED: Model is not block quantized") elif (is_deep_gemm_supported()): - logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") + logger.info( + "[MoE Debug] SUCCESS: All DeepGEMM conditions met!") self.allow_deep_gemm = True else: - logger.warning_once( - "DeepGemm not supported on the current platform.") + logger.warning( + "[MoE Debug] FAILED: DeepGemm not supported on platform") + else: + logger.info( + "[MoE Debug] VLLM_USE_DEEP_GEMM=False, skipping DeepGEMM") + + logger.info("[MoE Debug] *** FINAL DECISION: allow_deep_gemm=%s ***", + self.allow_deep_gemm) # Check for CutlassBlockScaledGroupedGemm support. self.allow_cutlass_block_scaled_grouped_gemm = False @@ -934,6 +950,12 @@ def select_gemm_impl( moe: FusedMoEConfig, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: + logger.info( + "[MoE Debug] Fp8MoEMethod.select_gemm_impl() ENTRY: " + "activation_format=%s, allow_deep_gemm=%s, " + "block_size=%s", prepare_finalize.activation_format, + self.allow_deep_gemm, self.quant_config.weight_block_size) + from vllm.model_executor.layers.fused_moe import ( BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts) @@ -945,11 +967,12 @@ def select_gemm_impl( max_num_tokens_per_rank = ( prepare_finalize.max_num_tokens_per_rank()) assert max_num_tokens_per_rank is not None - logger.debug( - "BatchedTritonOrDeepGemmExperts(%s): " - "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", - self.__class__.__name__, max_num_tokens_per_rank, - self.quant_config.weight_block_size, False) + logger.info( + "[MoE Debug] Creating BatchedTritonOrDeepGemmExperts with: " + "max_tokens_per_rank=%s, block_size=%s, allow_deep_gemm=%s, " + "num_dispatchers=%s", max_num_tokens_per_rank, + self.quant_config.weight_block_size, self.allow_deep_gemm, + prepare_finalize.num_dispatchers()) return BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(),