Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions vllm/distributed/device_communicators/all2all.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ def dispatch(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
Expand All @@ -76,6 +81,7 @@ def dispatch(
router_logits = self.naive_multicast(
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
)

return hidden_states, router_logits

def combine(
Expand Down Expand Up @@ -113,23 +119,34 @@ def dispatch(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
Comment thread
jiahanc marked this conversation as resolved.
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None

dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
hidden_states, router_logits = dist_group.all_gatherv(
[hidden_states, router_logits],

tensors_to_gather = [hidden_states, router_logits]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
Comment thread
jiahanc marked this conversation as resolved.

gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0,
sizes=sizes,
)
return hidden_states, router_logits

if extra_tensors is not None:
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1]

def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
Expand Down Expand Up @@ -204,6 +221,7 @@ def dispatch(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

Expand Down Expand Up @@ -251,6 +269,7 @@ def dispatch(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Any
from weakref import WeakValueDictionary

import torch
Expand Down Expand Up @@ -68,7 +69,11 @@ def dispatch(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
):
extra_tensors: list[torch.Tensor] | None = None,
) -> Any:
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError

def set_num_sms(self, num_sms: int):
Expand Down
16 changes: 11 additions & 5 deletions vllm/distributed/device_communicators/cuda_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,17 +318,23 @@ def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None):

return output_list

def dispatch(
def dispatch( # type: ignore[override]
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits, is_sequence_parallel
return self.all2all_manager.dispatch(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
Comment on lines +332 to +336
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Dispatch passes extra tensors to backends that cannot accept them

The new pre-quantized path now calls all2all_manager.dispatch(...) with an extra_tensors argument (see call below), but only NaiveAll2AllManager and AgRsAll2AllManager were updated to accept that parameter. Other supported backends (e.g., PPLXAll2AllManager.dispatch at vllm/distributed/device_communicators/all2all.py:229-235 and DeepEPAll2AllManagerBase.dispatch at lines 276-282) still take only (hidden_states, router_logits, is_sequence_parallel). When ModelOpt FP4 MoE runs with those backends and post_quant_allgather is enabled, this call will raise TypeError: dispatch() takes 4 positional arguments but 5 were given, crashing inference for those configurations.

Useful? React with 👍 / 👎.

)
return hidden_states, router_logits

def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
Expand Down
13 changes: 10 additions & 3 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,10 +1007,17 @@ def dispatch(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states, router_logits, is_sequence_parallel
return self.device_communicator.dispatch( # type: ignore[call-arg]
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, router_logits
Expand Down
12 changes: 12 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ def select_gemm_impl(
"implementation based on the prepare_finalize"
)

def prepare_dp_allgather_tensor(
    self,
    layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
    # Quant methods that support the pre-quantized allgather path must
    # override this hook; the base implementation always rejects the call.
    message = (
        "Method 'prepare_dp_allgather_tensor' is not implemented in "
        f"{self.__class__.__name__}."
    )
    raise NotImplementedError(message)

@abstractmethod
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
Expand Down
41 changes: 39 additions & 2 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
is_flashinfer_supporting_global_sf,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import (
aux_stream,
Expand Down Expand Up @@ -1933,10 +1934,46 @@ def forward_impl(
)

with sp_ctx:
extra_tensors = None
if do_naive_dispatch_combine:
hidden_states_combined, router_logits = get_ep_group().dispatch(
hidden_states, router_logits, self.is_sequence_parallel
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4FusedMoE,
)

post_quant_allgather = (
has_flashinfer_trtllm_fused_moe()
and self.quant_method is not None
and self.dp_size > 1
and self.use_ep
and isinstance(self.quant_method, ModelOptNvFp4FusedMoE)
)
if post_quant_allgather:
hidden_states_to_dispatch, extra_tensors = (
self.quant_method.prepare_dp_allgather_tensor(
self, hidden_states, router_logits
)
)
else:
hidden_states_to_dispatch = hidden_states

dispatch_res = get_ep_group().dispatch(
hidden_states_to_dispatch,
router_logits,
Comment on lines +1960 to +1962
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Dispatch uses undefined hidden_states_to_dispatch

In the DP dispatch path the call to get_ep_group().dispatch(...) uses hidden_states_to_dispatch, but that variable is only assigned inside the if post_quant_allgather branch above; when the optimization is disabled (e.g., any dp_size>1 run that is not ModelOpt TRTLLM), this block is skipped and hidden_states_to_dispatch is undefined, so the forward will crash with an UnboundLocalError before any dispatch occurs.

Useful? React with 👍 / 👎.

self.is_sequence_parallel,
extra_tensors=extra_tensors,
)
if extra_tensors is not None:
hidden_states_combined, router_logits, extra_tensors_combined = (
dispatch_res
)
hidden_states_combined = (
hidden_states_combined,
extra_tensors_combined[0],
)
else:
hidden_states_combined, router_logits = dispatch_res

# Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
Expand Down
25 changes: 24 additions & 1 deletion vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,24 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w2_blockscale_swizzled, requires_grad=False
)

def prepare_dp_allgather_tensor(
    self,
    layer: FusedMoE,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Optionally prepare extra tensors to carry through DP allgather/EP."""
    # Local import: flashinfer is an optional dependency of this quant method.
    import flashinfer

    # Quantize the activations to FP4 before the allgather so only the
    # compact representation (plus its scale factors) travels over the wire.
    quantized, scale_factors = flashinfer.fp4_quantize(
        hidden_states,
        layer.w13_input_scale_quant,
        is_sf_swizzled_layout=False,
    )
    return quantized, [scale_factors]

def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
Expand Down Expand Up @@ -1584,8 +1602,13 @@ def apply(
e_score_correction_bias=layer.e_score_correction_bias,
)

# Hidden_states in select_experts is only used to extract metadata
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
hidden_states=x_routing,
router_logits=router_logits,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(

def flashinfer_trtllm_fp4_moe(
layer: torch.nn.Module,
x: torch.Tensor,
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor,
top_k: int,
global_num_experts: int,
Expand Down Expand Up @@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe(
from vllm.model_executor.models.llama4 import Llama4MoE

# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
if isinstance(x, tuple):
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# hidden_states has already been quantized to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)

# Determine routing method type
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
Expand Down Expand Up @@ -364,13 +368,17 @@ def flashinfer_trtllm_fp4_routed_moe(
torch.bfloat16
).view(torch.int16)

# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
if isinstance(x, tuple):
# hidden_states has already been quantized to FP4
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)

# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
Expand Down
17 changes: 17 additions & 0 deletions vllm/utils/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool:
)


@functools.cache
def has_flashinfer_trtllm_fused_moe() -> bool:
    """Return `True` if FlashInfer TRTLLM fused MoE is available."""
    if not has_flashinfer_moe():
        return False
    # All TRTLLM fused-MoE entry points must be present for the feature
    # to be considered usable.
    needed_attrs = (
        "trtllm_fp8_block_scale_moe",
        "trtllm_fp8_per_tensor_scale_moe",
        "trtllm_fp4_block_scale_moe",
    )
    return all(
        (mod := _get_submodule("flashinfer.fused_moe")) and hasattr(mod, attr)
        for attr in needed_attrs
    )


@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""
Expand Down