From bcf3413e0141b2abc4ee348af75196d557af321d Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Wed, 3 Dec 2025 20:15:33 -0800 Subject: [PATCH 1/6] move quant before all gather in fp4 Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../device_communicators/all2all.py | 22 +++++++-- .../device_communicators/cuda_communicator.py | 16 +++++-- vllm/distributed/parallel_state.py | 13 +++-- vllm/model_executor/layers/fused_moe/layer.py | 48 ++++++++++++++++++- .../quantization/utils/flashinfer_fp4_moe.py | 16 ++++--- vllm/utils/flashinfer.py | 17 +++++++ 6 files changed, 111 insertions(+), 21 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index c40dde26b741..b91e5b9bb422 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -76,6 +76,7 @@ def dispatch( router_logits = self.naive_multicast( router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel ) + return hidden_states, router_logits def combine( @@ -113,7 +114,11 @@ def dispatch( hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): """ Gather hidden_states and router_logits from all dp ranks. """ @@ -121,15 +126,22 @@ def dispatch( assert dp_metadata is not None sizes = dp_metadata.get_chunk_sizes_across_dp_rank() assert sizes is not None - dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] - hidden_states, router_logits = dist_group.all_gatherv( - [hidden_states, router_logits], + + tensors_to_gather = [hidden_states, router_logits] + if extra_tensors is not None: + tensors_to_gather.extend(extra_tensors) + + gathered_tensors = dist_group.all_gatherv( + tensors_to_gather, dim=0, sizes=sizes, ) - return hidden_states, router_logits + + if extra_tensors is not None: + return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) + return gathered_tensors[0], gathered_tensors[1] def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index cd9c267beb5b..9542498c453e 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -318,17 +318,23 @@ def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None): return output_list - def dispatch( + def dispatch( # type: ignore[override] self, hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): assert self.all2all_manager is not None - hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits, is_sequence_parallel + return self.all2all_manager.dispatch( + hidden_states, + router_logits, + is_sequence_parallel, + extra_tensors, # type: ignore[call-arg] ) - return hidden_states, router_logits def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 338cb1f1814b..f5ada5a009ec 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1007,10 +1007,17 @@ def dispatch( hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): if self.device_communicator is not None: - return self.device_communicator.dispatch( - hidden_states, router_logits, is_sequence_parallel + return self.device_communicator.dispatch( # type: ignore[call-arg] + hidden_states, + router_logits, + is_sequence_parallel, + extra_tensors, ) else: return hidden_states, router_logits diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index cc3afade709d..5f472b818cd7 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -44,6 +44,7 @@ is_flashinfer_supporting_global_sf, ) from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe from vllm.utils.math_utils import cdiv, round_up from vllm.utils.torch_utils import ( aux_stream, @@ -1933,10 +1934,53 @@ def forward_impl( ) with sp_ctx: + extra_tensors = None + hidden_states_to_dispatch = hidden_states if do_naive_dispatch_combine: - hidden_states_combined, router_logits = get_ep_group().dispatch( - hidden_states, router_logits, self.is_sequence_parallel + # Avoid circular import + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4FusedMoE, + ) + + post_quant_allgather = ( + has_flashinfer_trtllm_fused_moe() + and self.quant_method is not None + and self.dp_size > 1 + and self.use_ep + and isinstance(self.quant_method, ModelOptNvFp4FusedMoE) ) + if post_quant_allgather: + import flashinfer + + a1_gscale = self.w13_input_scale_quant + ( + hidden_states_fp4, + hidden_states_sf, + ) = flashinfer.fp4_quantize( + hidden_states, + a1_gscale, + is_sf_swizzled_layout=False, + ) + extra_tensors = [hidden_states_sf] + hidden_states_to_dispatch = hidden_states_fp4 + + dispatch_res = get_ep_group().dispatch( + hidden_states_to_dispatch, + router_logits, + self.is_sequence_parallel, + extra_tensors=extra_tensors, + ) + if extra_tensors is not None: + hidden_states_combined, router_logits, extra_tensors_combined = ( + dispatch_res + ) + hidden_states_combined = ( + hidden_states_combined, + extra_tensors_combined[0], + ) + else: + hidden_states_combined, router_logits = dispatch_res + # Run shared experts before matrix multiply. # because matrix multiply maybe modify the hidden_states. if has_separate_shared_experts and not use_shared_experts_stream: diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index e424cd0e1ac9..ffda394918e6 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe( from vllm.model_executor.models.llama4 import Llama4MoE # Quantize input to FP4 - a1_gscale = layer.w13_input_scale_quant - (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - a1_gscale, - is_sf_swizzled_layout=False, - ) + if isinstance(x, tuple): + hidden_states_fp4, hidden_states_scale_linear_fp4 = x + else: + # hidden_states is the already quantized + a1_gscale = layer.w13_input_scale_quant + (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) # Determine routing method type use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 5019b771f4a1..1c2710be3173 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool: ) +@functools.cache +def has_flashinfer_trtllm_fused_moe() -> bool: + """Return `True` if FlashInfer TRTLLM fused MoE is available.""" + if not has_flashinfer_moe(): + return False + required_functions = [ + ("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"), + ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"), + ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"), + ] + for module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True + + @functools.cache def has_flashinfer_cutlass_fused_moe() -> bool: """Return `True` if FlashInfer CUTLASS fused MoE is available.""" From 193328c53f04cbeb336239e9409c1779ec31c540 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Thu, 4 Dec 2025 13:47:44 -0800 Subject: [PATCH 2/6] extra tensor in base class Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- vllm/distributed/device_communicators/all2all.py | 7 +++++++ .../device_communicators/base_device_communicator.py | 8 +++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index b91e5b9bb422..7a4e81cf967d 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -64,7 +64,12 @@ def dispatch( hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if extra_tensors is not None: + raise NotImplementedError( + "extra_tensors is not supported for NaiveAll2AllManager" + ) sp_size = self.tp_group.world_size if is_sequence_parallel else 1 dp_metadata = get_forward_context().dp_metadata assert dp_metadata is not None @@ -216,6 +221,7 @@ def dispatch( hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError @@ -263,6 +269,7 @@ def dispatch( hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 3a849da70e4c..51a1c969ec63 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -3,6 +3,8 @@ import threading from weakref import WeakValueDictionary +from typing import Any + import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -68,7 +70,11 @@ def dispatch( hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ): + extra_tensors: list[torch.Tensor] | None = None, + ) -> Any: + # Subclasses should either: + # - implement handling for extra_tensors, or + # - raise a clear error if extra_tensors is not supported. raise NotImplementedError def set_num_sms(self, num_sms: int): From d8127ce9fa20d9afdb6ddc2b464244703c2be1dd Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Mon, 8 Dec 2025 12:30:38 -0800 Subject: [PATCH 3/6] move get extra tensor to base class Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../layers/fused_moe/fused_moe_method_base.py | 11 +++++++++++ vllm/model_executor/layers/fused_moe/layer.py | 17 ++++------------- .../layers/quantization/modelopt.py | 18 ++++++++++++++++++ .../quantization/utils/flashinfer_fp4_moe.py | 2 +- 4 files changed, 34 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 8c9d8a2777d5..305feb1cd126 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -71,6 +71,17 @@ def select_gemm_impl( "implementation based on the prepare_finalize" ) + def prepare_dp_allgather_tensor( + self, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Hook to prepare tensors and extra tensors for DP allgather + EP dispatch.""" + raise NotImplementedError( + f"Method 'prepare_dp_allgather_tensor' is not implemented in {self.__class__.__name__}." + ) + @abstractmethod def get_fused_moe_quant_config( self, layer: torch.nn.Module diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5f472b818cd7..e17d3d4afed7 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1935,7 +1935,6 @@ def forward_impl( with sp_ctx: extra_tensors = None - hidden_states_to_dispatch = hidden_states if do_naive_dispatch_combine: # Avoid circular import from vllm.model_executor.layers.quantization.modelopt import ( @@ -1950,19 +1949,11 @@ def forward_impl( and isinstance(self.quant_method, ModelOptNvFp4FusedMoE) ) if post_quant_allgather: - import flashinfer - - a1_gscale = self.w13_input_scale_quant - ( - hidden_states_fp4, - hidden_states_sf, - ) = flashinfer.fp4_quantize( - hidden_states, - a1_gscale, - is_sf_swizzled_layout=False, + hidden_states_to_dispatch, extra_tensors = ( + self.quant_method.prepare_dp_allgather_tensor( + self, hidden_states, router_logits + ) ) - extra_tensors = [hidden_states_sf] - hidden_states_to_dispatch = hidden_states_fp4 dispatch_res = get_ep_group().dispatch( hidden_states_to_dispatch, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 030d85080a34..648ce5af54a4 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1530,6 +1530,24 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w2_blockscale_swizzled, requires_grad=False ) + def prepare_dp_allgather_tensor( + self, + layer: FusedMoE, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Optionally prepare extra tensors to carry through DP allgather/EP.""" + import flashinfer + + a1_gscale = layer.w13_input_scale_quant + hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize( + hidden_states, + a1_gscale, + is_sf_swizzled_layout=False, + ) + extra_tensors: list[torch.Tensor] = [hidden_states_sf] + return hidden_states_fp4, extra_tensors + def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index ffda394918e6..d1d796c849a4 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( def flashinfer_trtllm_fp4_moe( layer: torch.nn.Module, - x: torch.Tensor, + x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], router_logits: torch.Tensor, top_k: int, global_num_experts: int, From 4d15aae83b206723351ec5b406dd8e345617cb0a Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Mon, 8 Dec 2025 12:41:11 -0800 Subject: [PATCH 4/6] lint Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../device_communicators/base_device_communicator.py | 3 +-- vllm/model_executor/layers/fused_moe/fused_moe_method_base.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 51a1c969ec63..caeff54406b5 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading -from weakref import WeakValueDictionary - from typing import Any +from weakref import WeakValueDictionary import torch import torch.distributed as dist diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 305feb1cd126..a46e3972ed8e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -79,7 +79,8 @@ def prepare_dp_allgather_tensor( ) -> tuple[torch.Tensor, list[torch.Tensor]]: """Hook to prepare tensors and extra tensors for DP allgather + EP dispatch.""" raise NotImplementedError( - f"Method 'prepare_dp_allgather_tensor' is not implemented in {self.__class__.__name__}." + "Method 'prepare_dp_allgather_tensor' is not implemented in " + f"{self.__class__.__name__}." ) @abstractmethod From dd5ae909b44d18e47cffb7251d768a77bd09211c Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Thu, 11 Dec 2025 23:44:59 +0000 Subject: [PATCH 5/6] add quant before comm for routed trtllmgen also Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../layers/quantization/modelopt.py | 7 ++++++- .../quantization/utils/flashinfer_fp4_moe.py | 18 +++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 648ce5af54a4..3d5a0b8b41c5 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1602,8 +1602,13 @@ def apply( e_score_correction_bias=layer.e_score_correction_bias, ) + # Hidden_states in select_experts is only used to extract metadata + if isinstance(x, tuple): + x_routing, _ = x + else: + x_routing = x topk_weights, topk_ids, _ = layer.select_experts( - hidden_states=x, + hidden_states=x_routing, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index d1d796c849a4..003e5c26659e 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -368,13 +368,17 @@ def flashinfer_trtllm_fp4_routed_moe( torch.bfloat16 ).view(torch.int16) - # Quantize input to FP4 - a1_gscale = layer.w13_input_scale_quant - (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - a1_gscale, - is_sf_swizzled_layout=False, - ) + if isinstance(x, tuple): + # Hidden_states is the already quantized + hidden_states_fp4, hidden_states_scale_linear_fp4 = x + else: + # Quantize input to FP4 + a1_gscale = layer.w13_input_scale_quant + (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) # Call TRT-LLM FP4 block-scale MoE kernel out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe( From 15f9155a1dcb4c6ecbfd2c42033ea6d44437e282 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 12 Dec 2025 19:11:00 +0000 Subject: [PATCH 6/6] fix Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- vllm/model_executor/layers/fused_moe/layer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e17d3d4afed7..b39ce415a0f8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1954,6 +1954,8 @@ def forward_impl( self, hidden_states, router_logits ) ) + else: + hidden_states_to_dispatch = hidden_states dispatch_res = get_ep_group().dispatch( hidden_states_to_dispatch,