Skip to content
Closed
65 changes: 55 additions & 10 deletions vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from math import log2
from typing import Optional

Expand Down Expand Up @@ -261,17 +262,40 @@ def workspace_shapes(
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)

# Tokens-per-expert capacity actually used by the backend for this
# call. For batched formats (DeepEP-LL / PPLX), aq has shape
# (E, T_backend, K)
# Prefer using aq.size(1) to avoid under-allocation during dummy/profile
# runs or when multiple dispatchers/ranks contribute tokens.
T_backend = aq.size(1) if aq.dim() == 3 else 0

# Fallback capacity from configuration/observation.
num_dispatchers = self.num_dispatchers
observed_M = a.size(0)
if self.max_num_tokens is None:
T_cfg = observed_M * num_dispatchers
else:
# Guard with observed_M to avoid under-estimation when TP>1 or
# during profiling runs.
max_num_tokens = max(self.max_num_tokens, observed_M)
if observed_M > self.max_num_tokens:
with contextlib.suppress(Exception):
logger.debug_once(
"[MoE Debug] Increasing workspace max_num_tokens "
"from configured=%d to observed=%d to avoid OOM. "
"(num_dispatchers=%d, E=%d, N=%d, K=%d)",
self.max_num_tokens, observed_M, num_dispatchers,
num_experts, N, K)
T_cfg = max_num_tokens * num_dispatchers

# Final capacity: honor backend's requested T if larger.
T_eff = max(T_backend, T_cfg)

workspace13 = (num_experts, T_eff, max(K, N))
workspace2 = (num_experts, T_eff, (N // 2))
output = (num_experts, T_eff, K)
Comment on lines +267 to +298
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these changes might be redundant with another PR that makes the workspaces effectively global. See #23693

return (workspace13, workspace2, output, a.dtype)

def apply(
Expand Down Expand Up @@ -306,6 +330,27 @@ def apply(
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids)

# Debug (one-time): total dispatched tokens received on this EP rank.
# Avoid triggering CUDA Graph sync by skipping during graph capture or
# torch.compile. This reads a small scalar once for observability.
if (not torch.cuda.is_current_stream_capturing()
and not torch.compiler.is_compiling()):
try:
total_tokens = int(expert_num_tokens.sum().item())
logger.debug_once(
"[MoE Debug] EP rank received tokens: total=%d, E=%d, "
"max_tokens_per_dispatcher=%d, num_dispatchers=%d",
total_tokens, E, max_num_tokens, self.num_dispatchers)
except Exception as e:
# Log the failure without triggering CUDA graph sync.
# Only prints once to avoid log spam.
with contextlib.suppress(Exception):
logger.debug_once(
"[MoE Debug] Skipped token-count log due to %r "
"(E=%d, shape=%s, device=%s)", e, E,
tuple(expert_num_tokens.size()),
expert_num_tokens.device)

workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))

# (from deepgemm docs) : A value hint (which is a value on CPU)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
moe_kernel_quantize_input, restrict_dispatch_to_tp_leader)


class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
Expand Down Expand Up @@ -191,6 +191,10 @@ def prepare_async(
quant_config: FusedMoEQuantConfig,
) -> tuple[Callable, mk.ReceiverType]:

# Restrict dispatch to TP leader to avoid duplicate work.
a1, topk_ids, topk_weights = restrict_dispatch_to_tp_leader(
a1, topk_ids, topk_weights)

if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input, normalize_batched_scales_shape)
moe_kernel_quantize_input, normalize_batched_scales_shape,
restrict_dispatch_to_tp_leader)
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
dbo_maybe_run_recv_hook)

Expand Down Expand Up @@ -148,6 +149,10 @@ def prepare_async(
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * topk_weights.to(a1.dtype)

# Restrict dispatch to TP leader to avoid duplicate work.
a1, topk_ids, topk_weights = restrict_dispatch_to_tp_leader(
a1, topk_ids, topk_weights)

# Dispatch
expert_x, expert_num_tokens, handle, _, hook= \
self.buffer.low_latency_dispatch(a1,
Expand Down
52 changes: 49 additions & 3 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,15 @@ def _maybe_make_prepare_finalize(

all_to_all_args = dict()
handle = all2all_manager.get_handle(all_to_all_args)
# Only DP leader ranks should dispatch when TP > 1.
# Use number of DP ranks (leaders) as dispatchers in that case.
tp_world_size = all2all_manager.tp_group.world_size
num_dispatchers = (all2all_manager.world_size //
tp_world_size) if tp_world_size > 1 else \
all2all_manager.world_size
prepare_finalize = DeepEPHTPrepareAndFinalize(
handle,
num_dispatchers=all2all_manager.world_size,
num_dispatchers=num_dispatchers,
dp_size=all2all_manager.dp_world_size,
rank_expert_offset=all2all_manager.rank *
moe.num_local_experts,
Expand All @@ -190,7 +196,12 @@ def _maybe_make_prepare_finalize(
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
# Only DP leader ranks should dispatch when TP > 1.
# Use number of DP ranks (leaders) as dispatchers in that case.
num_dispatchers=(all2all_manager.world_size //
all2all_manager.tp_group.world_size)
if all2all_manager.tp_group.world_size > 1 else
all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
)

Expand Down Expand Up @@ -224,6 +235,14 @@ def init_prepare_finalize(self, layer: torch.nn.Module):
f"Attempt to override experts for {id(self)}!"
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, layer)
# Log which expert implementation was selected
allow = getattr(experts, "allow_deep_gemm", None)
use_fp8 = getattr(experts, "use_fp8_w8a8", None)
block_shape = getattr(experts, "block_shape", None)
logger.debug(
"[MoE Debug] Expert implementation selected: %s, "
"allow_deep_gemm=%s, use_fp8_w8a8=%s, block_shape=%s",
type(experts).__name__, allow, use_fp8, block_shape)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
Expand Down Expand Up @@ -301,7 +320,9 @@ def select_gemm_impl(
assert self.moe_quant_config is not None
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
logger.debug("BatchedTritonExperts %s", self.moe)
logger.debug(
"[MoE Debug] Creating BatchedTritonExperts with moe=%s",
self.moe)
return BatchedTritonExperts(
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
Expand Down Expand Up @@ -840,6 +861,14 @@ def __init__(
has_bias: bool = False,
is_sequence_parallel=False,
):
logger.debug(
"[MoE Debug] *** FusedMoE.__init__ ENTRY *** "
"Creating MoE layer with num_experts=%s, prefix='%s', "
"quant_config=%s, tp_size=%s, dp_size=%s, ep_size=%s", num_experts,
prefix,
type(quant_config).__name__ if quant_config else None, tp_size,
dp_size, ep_size)

super().__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
Expand Down Expand Up @@ -977,16 +1006,33 @@ def __init__(
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.quant_config = quant_config

logger.debug(
"[MoE Debug] MoE Config created: global_experts=%s, "
"local_experts=%s, max_tokens=%s, parallel_config=%s, "
"use_pplx=%s, use_deepep_ht=%s, use_deepep_ll=%s, "
"use_flashinfer_cutlass=%s", self.global_num_experts,
self.local_num_experts, moe.max_num_tokens,
f"tp={self.tp_size},dp={self.dp_size},ep={self.ep_size}",
moe.use_pplx_kernels, moe.use_deepep_ht_kernels,
moe.use_deepep_ll_kernels, moe.use_flashinfer_cutlass_kernels)

# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
quant_method: Optional[QuantizeMethodBase] = None
logger.debug(
"[MoE Debug] Selecting quantization method: quant_config=%s",
type(quant_config).__name__ if quant_config else "None")

quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
else quant_config.get_quant_method(self, prefix))

assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method

logger.debug("[MoE Debug] Quantization method selected: %s",
type(quant_method).__name__)

if self.enable_eplb:
from vllm.model_executor.layers.quantization.fp8 import (
Fp8MoEMethod)
Expand Down
24 changes: 24 additions & 0 deletions vllm/model_executor/layers/fused_moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import torch

from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
Expand Down Expand Up @@ -269,6 +271,28 @@ def _validate_scale_shape(
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"


def restrict_dispatch_to_tp_leader(
        *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Restrict dispatch to the TP leader rank.

    When tensor model parallelism is active (TP > 1), only the rank with
    ``tp_rank_in_group == 0`` should perform dispatch; every other TP rank
    receives zero-length views of the inputs so it does no duplicate
    dispatch work.

    Returns the input tensors unchanged on the TP leader or when TP == 1;
    otherwise returns length-0 slices of the inputs along the first dim.
    """
    if get_tensor_model_parallel_world_size() <= 1:
        # TP disabled: every rank dispatches its own tokens.
        return tensors

    if get_tp_group().rank_in_group == 0:
        # This rank is the TP leader; it keeps its tokens.
        return tensors

    # Non-leader TP rank: hand back zero-length slices (views, not copies).
    return tuple(t[:0] for t in tensors)


def activation_without_mul(activation: str) -> str:
    """Return the name of the "no-mul" variant of *activation*."""
    return f"{activation}_no_mul"
Loading