diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index cf0b965cc8c5..d623ce5c8549 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import contextlib
 from math import log2
 from typing import Optional
 
@@ -261,17 +262,40 @@ def workspace_shapes(
         expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         assert a.dim() == 2
-        # FIXME (varun): We should be able to dispatch only from the leader
-        # DP ranks in the case of TP > 1. At the moment, all the Ranks
-        # end up sending their tokens. This needs to be fixed.
-        num_dispatchers = self.num_dispatchers
         num_experts = local_num_experts
-        max_num_tokens = a.size(
-            0) if self.max_num_tokens is None else self.max_num_tokens
-        workspace13 = (num_experts, max_num_tokens * num_dispatchers,
-                       max(K, N))
-        workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
-        output = (num_experts, max_num_tokens * num_dispatchers, K)
+
+        # Tokens-per-expert capacity actually used by the backend for this
+        # call. For batched formats (DeepEP-LL / PPLX), aq has shape
+        # (E, T_backend, K)
+        # Prefer using aq.size(1) to avoid under-allocation during dummy/profile
+        # runs or when multiple dispatchers/ranks contribute tokens.
+        T_backend = aq.size(1) if aq.dim() == 3 else 0
+
+        # Fallback capacity from configuration/observation.
+        num_dispatchers = self.num_dispatchers
+        observed_M = a.size(0)
+        if self.max_num_tokens is None:
+            T_cfg = observed_M * num_dispatchers
+        else:
+            # Guard with observed_M to avoid under-estimation when TP>1 or
+            # during profiling runs.
+            max_num_tokens = max(self.max_num_tokens, observed_M)
+            if observed_M > self.max_num_tokens:
+                with contextlib.suppress(Exception):
+                    logger.debug_once(
+                        "[MoE Debug] Increasing workspace max_num_tokens "
+                        "from configured=%d to observed=%d to avoid OOM. "
+                        "(num_dispatchers=%d, E=%d, N=%d, K=%d)",
+                        self.max_num_tokens, observed_M, num_dispatchers,
+                        num_experts, N, K)
+            T_cfg = max_num_tokens * num_dispatchers
+
+        # Final capacity: honor backend's requested T if larger.
+        T_eff = max(T_backend, T_cfg)
+
+        workspace13 = (num_experts, T_eff, max(K, N))
+        workspace2 = (num_experts, T_eff, (N // 2))
+        output = (num_experts, T_eff, K)
         return (workspace13, workspace2, output, a.dtype)
 
     def apply(
@@ -306,6 +330,27 @@
         E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
             hidden_states, w1, w2, topk_ids)
 
+        # Debug (one-time): total dispatched tokens received on this EP rank.
+        # Avoid triggering CUDA Graph sync by skipping during graph capture or
+        # torch.compile. This reads a small scalar once for observability.
+        if (not torch.cuda.is_current_stream_capturing()
+                and not torch.compiler.is_compiling()):
+            try:
+                total_tokens = int(expert_num_tokens.sum().item())
+                logger.debug_once(
+                    "[MoE Debug] EP rank received tokens: total=%d, E=%d, "
+                    "max_tokens_per_dispatcher=%d, num_dispatchers=%d",
+                    total_tokens, E, max_num_tokens, self.num_dispatchers)
+            except Exception as e:
+                # Log the failure without triggering CUDA graph sync.
+                # Only prints once to avoid log spam.
+                with contextlib.suppress(Exception):
+                    logger.debug_once(
+                        "[MoE Debug] Skipped token-count log due to %r "
+                        "(E=%d, shape=%s, device=%s)", e, E,
+                        tuple(expert_num_tokens.size()),
+                        expert_num_tokens.device)
+
         workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
 
         # (from deepgemm docs) : A value hint (which is a value on CPU)
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
index f390f0a25875..f29987aa700a 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
@@ -10,7 +10,7 @@
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (
-    moe_kernel_quantize_input)
+    moe_kernel_quantize_input, restrict_dispatch_to_tp_leader)
 
 
 class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@@ -191,6 +191,10 @@
         quant_config: FusedMoEQuantConfig,
     ) -> tuple[Callable, mk.ReceiverType]:
 
+        # Restrict dispatch to TP leader to avoid duplicate work.
+        a1, topk_ids, topk_weights = restrict_dispatch_to_tp_leader(
+            a1, topk_ids, topk_weights)
+
         if apply_router_weight_on_input:
             topk = topk_ids.size(1)
             # TODO: this only works for topK=1, will need to update for topK>1
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index 101fc8798c42..7b8e5519c025 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -10,7 +10,8 @@
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (
-    moe_kernel_quantize_input, normalize_batched_scales_shape)
+    moe_kernel_quantize_input, normalize_batched_scales_shape,
+    restrict_dispatch_to_tp_leader)
 from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
                                       dbo_maybe_run_recv_hook)
 
@@ -148,6 +149,10 @@
                 "apply_router_weight_on_input is only implemented for topk=1")
             a1 = a1 * topk_weights.to(a1.dtype)
 
+        # Restrict dispatch to TP leader to avoid duplicate work.
+        a1, topk_ids, topk_weights = restrict_dispatch_to_tp_leader(
+            a1, topk_ids, topk_weights)
+
         # Dispatch
         expert_x, expert_num_tokens, handle, _, hook= \
             self.buffer.low_latency_dispatch(a1,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index da513d75da4d..1dc48f2118cc 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -162,9 +162,15 @@
             all_to_all_args = dict()
             handle = all2all_manager.get_handle(all_to_all_args)
 
+            # Only DP leader ranks should dispatch when TP > 1.
+            # Use number of DP ranks (leaders) as dispatchers in that case.
+            tp_world_size = all2all_manager.tp_group.world_size
+            num_dispatchers = (all2all_manager.world_size //
+                               tp_world_size) if tp_world_size > 1 else \
+                all2all_manager.world_size
             prepare_finalize = DeepEPHTPrepareAndFinalize(
                 handle,
-                num_dispatchers=all2all_manager.world_size,
+                num_dispatchers=num_dispatchers,
                 dp_size=all2all_manager.dp_world_size,
                 rank_expert_offset=all2all_manager.rank *
                 moe.num_local_experts,
@@ -190,7 +196,12 @@
             prepare_finalize = DeepEPLLPrepareAndFinalize(
                 handle,
                 max_tokens_per_rank=moe.max_num_tokens,
-                num_dispatchers=all2all_manager.world_size,
+                # Only DP leader ranks should dispatch when TP > 1.
+                # Use number of DP ranks (leaders) as dispatchers in that case.
+                num_dispatchers=(all2all_manager.world_size //
+                                 all2all_manager.tp_group.world_size)
+                if all2all_manager.tp_group.world_size > 1 else
+                all2all_manager.world_size,
                 use_fp8_dispatch=use_fp8_dispatch,
             )
 
@@ -224,6 +235,14 @@
             f"Attempt to override experts for {id(self)}!"
         self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
         experts = self.select_gemm_impl(prepare_finalize, layer)
+        # Log which expert implementation was selected
+        allow = getattr(experts, "allow_deep_gemm", None)
+        use_fp8 = getattr(experts, "use_fp8_w8a8", None)
+        block_shape = getattr(experts, "block_shape", None)
+        logger.debug(
+            "[MoE Debug] Expert implementation selected: %s, "
+            "allow_deep_gemm=%s, use_fp8_w8a8=%s, block_shape=%s",
+            type(experts).__name__, allow, use_fp8, block_shape)
         self.fused_experts = FusedMoEModularKernel(
             prepare_finalize,
             experts,
@@ -301,7 +320,9 @@
         assert self.moe_quant_config is not None
         if (prepare_finalize.activation_format ==
                 FusedMoEActivationFormat.BatchedExperts):
-            logger.debug("BatchedTritonExperts %s", self.moe)
+            logger.debug(
+                "[MoE Debug] Creating BatchedTritonExperts with moe=%s",
+                self.moe)
             return BatchedTritonExperts(
                 max_num_tokens=self.moe.max_num_tokens,
                 num_dispatchers=prepare_finalize.num_dispatchers(),
@@ -840,6 +861,14 @@
         has_bias: bool = False,
         is_sequence_parallel=False,
     ):
+        logger.debug(
+            "[MoE Debug] *** FusedMoE.__init__ ENTRY *** "
+            "Creating MoE layer with num_experts=%s, prefix='%s', "
+            "quant_config=%s, tp_size=%s, dp_size=%s, ep_size=%s", num_experts,
+            prefix,
+            type(quant_config).__name__ if quant_config else None, tp_size,
+            dp_size, ep_size)
+
         super().__init__()
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -977,9 +1006,23 @@
         self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
         self.quant_config = quant_config
 
+        logger.debug(
+            "[MoE Debug] MoE Config created: global_experts=%s, "
+            "local_experts=%s, max_tokens=%s, parallel_config=%s, "
+            "use_pplx=%s, use_deepep_ht=%s, use_deepep_ll=%s, "
+            "use_flashinfer_cutlass=%s", self.global_num_experts,
+            self.local_num_experts, moe.max_num_tokens,
+            f"tp={self.tp_size},dp={self.dp_size},ep={self.ep_size}",
+            moe.use_pplx_kernels, moe.use_deepep_ht_kernels,
+            moe.use_deepep_ll_kernels, moe.use_flashinfer_cutlass_kernels)
+
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
         quant_method: Optional[QuantizeMethodBase] = None
+        logger.debug(
+            "[MoE Debug] Selecting quantization method: quant_config=%s",
+            type(quant_config).__name__ if quant_config else "None")
+
         quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
                         else quant_config.get_quant_method(self, prefix))
 
@@ -987,6 +1030,9 @@
         assert isinstance(quant_method, FusedMoEMethodBase)
         self.quant_method = quant_method
 
+        logger.debug("[MoE Debug] Quantization method selected: %s",
+                     type(quant_method).__name__)
+
         if self.enable_eplb:
             from vllm.model_executor.layers.quantization.fp8 import (
                 Fp8MoEMethod)
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index 678942e568d8..0bc6d4f83973 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -6,6 +6,8 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_tp_group
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.model_executor.layers.quantization.utils.int8_utils import (
@@ -269,6 +271,28 @@
         expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
         assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
 
 
+
+def restrict_dispatch_to_tp_leader(
+        *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
+    """Restrict dispatch to the TP leader rank.
+    If tensor model parallelism is enabled (TP > 1), only ranks with
+    ``tp_rank_in_group == 0`` should perform dispatch. Non-leader ranks
+    return empty tensors to avoid duplicate dispatch work.
+
+    Returns the input tensors unchanged on the TP leader or when TP == 1;
+    otherwise returns zero-length views of the inputs along the first dim.
+    """
+    tp_world_size = get_tensor_model_parallel_world_size()
+    if tp_world_size <= 1:
+        return tensors
+
+    tp_rank_in_group = get_tp_group().rank_in_group
+    if tp_rank_in_group != 0:
+        return tuple(t[:0] for t in tensors)
+
+    return tensors
+
+
 def activation_without_mul(activation: str) -> str:
     return activation + "_no_mul"