From 1897f159f28b6d7c13e8aa45cbffd5ae65d5240a Mon Sep 17 00:00:00 2001 From: Leo Tian Date: Thu, 26 Feb 2026 16:56:37 -0500 Subject: [PATCH 01/12] initial port Signed-off-by: wzhao18 --- vllm/config/parallel.py | 4 +- .../device_communicators/all2all.py | 126 +++++++++++++++++- .../device_communicators/cuda_communicator.py | 4 + .../device_communicators/mnnvl_compat.py | 10 +- .../layers/fused_moe/all2all_utils.py | 10 ++ .../model_executor/layers/fused_moe/config.py | 18 +++ .../flashinfer_a2a_prepare_finalize.py | 114 ++++++++++++++++ .../layers/fused_moe/oracle/nvfp4.py | 18 ++- vllm/utils/flashinfer.py | 9 ++ 9 files changed, 302 insertions(+), 11 deletions(-) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 6e84cf16b203..a403bd7e60f9 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -44,6 +44,7 @@ "mori", "allgather_reducescatter", "flashinfer_all2allv", + "flashinfer_moe_a2a", ] @@ -155,7 +156,8 @@ class ParallelConfig: - "deepep_high_throughput": Use deepep high-throughput kernels\n - "deepep_low_latency": Use deepep low-latency kernels\n - "mori": Use mori kernels\n - - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl""" + - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl\n + - "flashinfer_moe_a2a": Use flashinfer moe alltoall kernels""" max_parallel_loading_workers: int | None = None """Maximum number of parallel loading workers when loading model diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 3efcebd54a97..c2631d0c7ecc 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -3,12 +3,16 @@ from typing import Any import torch +import torch.distributed as dist import vllm.envs as envs from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from 
vllm.utils.flashinfer import has_flashinfer_all2all +from vllm.utils.flashinfer import ( + has_flashinfer_all2all, + has_flashinfer_moe_a2a, +) from vllm.utils.import_utils import has_deep_ep, has_mori from .base_device_communicator import All2AllManagerBase, Cache @@ -20,6 +24,15 @@ MnnvlMoe, # type: ignore[import-not-found] ) +if has_flashinfer_moe_a2a(): + from flashinfer.comm import Mapping # type: ignore[import-not-found] + from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found] + from flashinfer.comm.trtllm_moe_alltoall import ( + MoeAlltoAll, # type: ignore[import-not-found] + moe_a2a_get_workspace_size_per_rank, + ) + + logger = init_logger(__name__) @@ -608,3 +621,114 @@ def get_handle(self, kwargs): mori_kwargs, self._make_handle ) return handle + + +class FlashInferMoeA2AManager(All2AllManagerBase): + """ + All2All communication based on FlashInfer's MoeAlltoAll kernel. + This is a newer kernel from trtllm that should perform better than the kernel + used by flashinfer_all2allv. + """ + + rank: int + world_size: int + + def __init__(self, cpu_group): + assert has_flashinfer_moe_a2a(), ( + "flashinfer trtllm_moe_alltoall module not found. 
" + "Please install/check flashinfer" + ) + super().__init__(cpu_group) + logger.debug( + "Initialize FlashInfer MoeA2A rank=%d, world size=%d", + self.rank, + self.world_size, + ) + self.initialized = False + self.moe_alltoall: MoeAlltoAll | None = None + self.mapping = None + + def initialize( + self, + max_num_tokens: int, + top_k: int, + num_experts: int, + hidden_size: int, + ): + """Initialize the MoeAlltoAll workspace.""" + if self.initialized: + return + + self.cleanup() + gpus_per_node = torch.cuda.device_count() + logger.debug( + "Making MoeA2A mapping: rank=%d, world size=%d", + self.rank, + self.world_size, + ) + self.mapping = Mapping( + self.world_size, + self.rank, + gpus_per_node, + tp_size=self.world_size, + moe_ep_size=self.world_size, + ) + + from vllm.distributed.device_communicators.mnnvl_compat import ( + CustomCommunicator, + ) + + dp_config = MnnvlConfig( + comm_backend=CustomCommunicator(get_dp_group().cpu_group), + ) + total_dispatch_payload_size_per_token = ( + hidden_size // 2 # nvfp4 hidden states + + hidden_size // 16 # fp8 scaling factors + + top_k * 4 # int32 topks ids + + top_k * 4 # float32 topk weights + ) + combine_payload_size_per_token = hidden_size * 2 # bf16 hidden states + self.workspace_size = moe_a2a_get_workspace_size_per_rank( + ep_size=self.world_size, + max_num_tokens=max_num_tokens, + total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token, + combine_payload_size_per_token=combine_payload_size_per_token, + ) + + self.moe_alltoall = MoeAlltoAll( + mapping=self.mapping, + max_num_tokens=max_num_tokens, + top_k=top_k, + num_experts=num_experts, + workspace_size_per_rank=self.workspace_size, + mnnvl_config=dp_config, + ) + + self.gpus_per_node = gpus_per_node + self.max_num_tokens = max_num_tokens + self.top_k = top_k + self.num_experts = num_experts + self.hidden_size = hidden_size + self.initialized = True + + logger.info( + "FlashInfer MoeA2A initialized for rank %s, size %s", + self.rank, + 
self.world_size, + ) + dist.barrier() + + def get_handle(self, kwargs): + return self + + def cleanup(self): + """Clean up resources.""" + if self.initialized and self.moe_alltoall is not None: + try: + del self.moe_alltoall + except Exception as e: + logger.warning("Failed to cleanup FlashInfer MoeA2A workspace: %s", e) + finally: + self.moe_alltoall = None + self.mapping = None + self.initialized = False diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 5e18dbde91d2..51ef6f305f67 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -149,6 +149,10 @@ def __init__( self.all2all_manager = FlashInferAllToAllManager( self.cpu_group, tcp_store_group ) + elif self.all2all_backend == "flashinfer_moe_a2a": + from .all2all import FlashInferMoeA2AManager + + self.all2all_manager = FlashInferMoeA2AManager(self.cpu_group) else: raise ValueError(f"Unknown all2all backend: {self.all2all_backend}") diff --git a/vllm/distributed/device_communicators/mnnvl_compat.py b/vllm/distributed/device_communicators/mnnvl_compat.py index 81f4ae20738d..cf78d7bbcccb 100644 --- a/vllm/distributed/device_communicators/mnnvl_compat.py +++ b/vllm/distributed/device_communicators/mnnvl_compat.py @@ -25,14 +25,14 @@ def allgather(self, data: int): dist.all_gather_object(gathered, data, group=self._group) return gathered - # NOTE(rob): CommBackend is an abstract class, and bcast/barrier - # are unimplemented on vLLM side. If we need to utilize these - # methods in the future, can create a concrete implementation. 
def bcast(self, data: Any, root: int) -> Any: - raise NotImplementedError + obj_list = [data] + # broadcast_object_list mutates obj_list in-place + dist.broadcast_object_list(obj_list, src=root, group=self._group) + return obj_list[0] def barrier(self) -> None: - raise NotImplementedError + dist.barrier(group=self._group) def Split(self, color: int, key: int) -> "CustomCommunicator": return self diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 47ca95ee54cb..ba11bb2bd512 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -16,6 +16,7 @@ ) from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( FlashInferA2APrepareAndFinalize, + FlashInferMoeA2APrepareAndFinalize, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPrepareAndFinalize, @@ -202,6 +203,15 @@ def maybe_make_prepare_finalize( num_dispatchers=all2all_manager.world_size, ) + elif moe.use_fi_moe_all2all_kernels: + assert quant_config is not None + prepare_finalize = FlashInferMoeA2APrepareAndFinalize( + max_num_tokens=moe.max_num_tokens, + top_k=moe.experts_per_token, + num_experts=moe.num_experts, + hidden_size=moe.hidden_dim, + ) + elif moe.use_naive_all2all_kernels and allow_new_interface: prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep( use_monolithic=use_monolithic, diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index e0ed9130c2ce..87aaa0eef283 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -229,6 +229,7 @@ class FusedMoEQuantConfig: _w1: FusedMoEQuantDesc _w2: FusedMoEQuantDesc is_nvfp4_scale_swizzled: bool = True + _g1_scale_c: torch.Tensor | None = None def __post_init__(self): assert not self.per_act_token_quant or self.block_shape is None, ( @@ -344,6 
+345,10 @@ def w2_precision(self) -> "PrecisionConfig | None": def g2_alphas(self) -> torch.Tensor | None: return self._w2.alpha_or_gscale + @property + def g1_scale_c(self) -> torch.Tensor | None: + return self._g1_scale_c + @property def use_fp8_w8a8(self) -> bool: return self.quant_dtype == torch.float8_e4m3fn @@ -477,6 +482,7 @@ def make( w2_zp: torch.Tensor | None = None, weight_dtype: torch.dtype | str | None = None, is_nvfp4_scale_swizzled: bool = True, + g1_scale_c: torch.Tensor | None = None, ) -> "FusedMoEQuantConfig": """ General builder function for a FusedMoEQuantConfig. @@ -507,6 +513,7 @@ def make( - w1_zp: Optional w1 zero points for int4/int8 quantization. - w2_zp: Optional w2 zero points for int4/int8 quantization. - is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling. + - g1_scale_c: Pre-computed scale for TRT-LLM FP4 MoE kernel (a2_gscale * g1_alphas) """ assert not isinstance(quant_dtype, str) or quant_dtype in { "nvfp4", @@ -540,6 +547,7 @@ def make( weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias ), is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled, + _g1_scale_c=g1_scale_c, ) assert quant_config.per_act_token_quant == per_act_token_quant assert quant_config.per_out_ch_quant == per_out_ch_quant @@ -742,6 +750,7 @@ def nvfp4_moe_quant_config( w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, is_nvfp4_scale_swizzled: bool = True, + g1_scale_c: torch.Tensor | None = None, ) -> FusedMoEQuantConfig: """ Construct a quant config for mxfp4 activations and nvp4 weights. 
@@ -760,6 +769,7 @@ def nvfp4_moe_quant_config( per_out_ch_quant=False, block_shape=None, is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled, + g1_scale_c=g1_scale_c, ) @@ -962,6 +972,10 @@ def use_fi_all2allv_kernels(self): self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv" ) + @property + def use_fi_moe_a2a_kernels(self): + return self.use_all2all_kernels and self.all2all_backend == "flashinfer_moe_a2a" + @property def use_batched_activation_format(self): return self.use_deepep_ll_kernels @@ -1239,6 +1253,10 @@ def use_mori_kernels(self): def use_fi_all2allv_kernels(self): return self.moe_parallel_config.use_fi_all2allv_kernels + @property + def use_fi_moe_all2all_kernels(self): + return self.moe_parallel_config.use_fi_moe_a2a_kernels + @property def use_naive_all2all_kernels(self): return self.moe_parallel_config.use_naive_all2all_kernels diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py index 465d0ae8f2c4..2fcecf1e0868 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py @@ -224,3 +224,117 @@ def flashinfer_alltoall_combine( top_k=top_k, token_count=token_count, ) + + +class FlashInferMoeA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): + """FlashInfer implementation using the Moe AlltoAll kernel.""" + + def __init__( + self, + max_num_tokens: int, + top_k: int, + num_experts: int, + hidden_size: int, + ): + super().__init__() + self.max_num_tokens = max_num_tokens + self.top_k = top_k + self.num_experts = num_experts + self.hidden_size = hidden_size + + self.all2all_manager = get_ep_group().device_communicator.all2all_manager + self.all2all_manager.initialize( + max_num_tokens=self.max_num_tokens, + top_k=self.top_k, + num_experts=self.num_experts, + hidden_size=self.hidden_size, + ) + + def prepare( + self, + a1: 
torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + self._apply_router_weight_on_input( + a1, topk_weights, topk_ids, apply_router_weight_on_input + ) + + global_num_tokens_cpu = get_local_sizes() + self.runtime_max_tokens_per_rank = ( + max(global_num_tokens_cpu) + if global_num_tokens_cpu is not None + else a1.shape[0] + ) + + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=False, # delay swizzle to after comm + ) + + payloads = [] + payloads.append(a1q) + if a1q_scale is not None: + payloads.append(a1q_scale) + expert_id_payload_index = 2 + else: + expert_id_payload_index = 1 + payloads.append(topk_ids) + payloads.append(topk_weights) + + recv_payloads = self.all2all_manager.moe_alltoall.dispatch( + token_selected_experts=topk_ids, + input_payloads=payloads, + runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, + # invalid_token_expert_id=-1, + # expert_id_payload_index=expert_id_payload_index, + ) + if a1q_scale is not None: + a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads + # Apply scale interleaving only for CUTLASS (not TRT-LLM) + if ( + quant_config.quant_dtype == "nvfp4" + and quant_config.is_nvfp4_scale_swizzled + ): + a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1]) + a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv) + a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16) + else: + a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads + a1q_scale_recv = None + a1q_recv = a1q_recv.view(-1, a1q_recv.shape[-1]) + topk_ids_recv = topk_ids_recv.view(-1, topk_ids_recv.shape[-1]) + topk_weights_recv = topk_weights_recv.view(-1, topk_weights_recv.shape[-1]) + + 
return a1q_recv, a1q_scale_recv, None, topk_ids_recv, topk_weights_recv + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + assert self.all2all_manager.moe_alltoall is not None + + ep_size = self.all2all_manager.world_size + hidden_size = fused_expert_output.shape[-1] + fused_expert_output = fused_expert_output.view( + ep_size, self.runtime_max_tokens_per_rank, hidden_size + ) + + combined_output = self.all2all_manager.moe_alltoall.combine( + payload=fused_expert_output, + runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, + ) + output.copy_(combined_output) diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index dd1a24d863de..8f5e8f738ce7 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -376,6 +376,19 @@ def make_nvfp4_moe_quant_config( g1_alphas = a13_scale * w13_scale_2 g2_alphas = a2_scale * w2_scale_2 + + if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + return nvfp4_moe_quant_config( + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + a1_gscale=(1.0 / a13_scale), + a2_gscale=(1.0 / a2_scale), + w1_scale=w13_scale, + w2_scale=w2_scale, + is_nvfp4_scale_swizzled=False, + g1_scale_c=(g1_alphas / a2_scale), + ) + return nvfp4_moe_quant_config( g1_alphas=g1_alphas, g2_alphas=g2_alphas, @@ -383,10 +396,7 @@ def make_nvfp4_moe_quant_config( a2_gscale=(1.0 / a2_scale), w1_scale=w13_scale, w2_scale=w2_scale, - # NOTE(rob): this is a hack until the MoE kernels - # create their own quant configs. TRTLLM kernel - # does not accept swizzled input quant scales. 
- is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM), + is_nvfp4_scale_swizzled=True, ) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index c3ac839c21d1..355639517486 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -170,6 +170,14 @@ def has_flashinfer_all2all() -> bool: return True +@functools.cache +def has_flashinfer_moe_a2a() -> bool: + """Return `True` if FlashInfer trtllm_moe_alltoall module is available.""" + if not has_flashinfer_comm(): + return False + return importlib.util.find_spec("flashinfer.comm.trtllm_moe_alltoall") is not None + + @functools.cache def has_flashinfer_moe() -> bool: """Return `True` if FlashInfer MoE module is available.""" @@ -767,6 +775,7 @@ def should_use_flashinfer_for_blockscale_fp8_gemm( "has_flashinfer_moe", "has_flashinfer_comm", "has_flashinfer_all2all", + "has_flashinfer_moe_a2a", "has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutedsl_grouped_gemm_nt_masked", "has_flashinfer_fp8_blockscale_gemm", From 3e26d15acb43476dc183550338d9e9b703141c0e Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Fri, 27 Feb 2026 06:04:05 +0000 Subject: [PATCH 02/12] Functional fix Signed-off-by: wzhao18 --- .../layers/fused_moe/all2all_utils.py | 6 +++- .../flashinfer_a2a_prepare_finalize.py | 35 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index ba11bb2bd512..28c1374818b2 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -5,6 +5,7 @@ import torch +from vllm.config import get_current_vllm_config from vllm.distributed import ( get_ep_group, ) @@ -205,8 +206,11 @@ def maybe_make_prepare_finalize( elif moe.use_fi_moe_all2all_kernels: assert quant_config is not None + max_num_tokens = ( + get_current_vllm_config().scheduler_config.max_num_batched_tokens + ) 
prepare_finalize = FlashInferMoeA2APrepareAndFinalize( - max_num_tokens=moe.max_num_tokens, + max_num_tokens=max_num_tokens, top_k=moe.experts_per_token, num_experts=moe.num_experts, hidden_size=moe.hidden_dim, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py index 2fcecf1e0868..e05bca281872 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py @@ -235,12 +235,14 @@ def __init__( top_k: int, num_experts: int, hidden_size: int, + num_dispatchers: int = 1, ): super().__init__() self.max_num_tokens = max_num_tokens self.top_k = top_k self.num_experts = num_experts self.hidden_size = hidden_size + self.num_dispatchers_ = num_dispatchers self.all2all_manager = get_ep_group().device_communicator.all2all_manager self.all2all_manager.initialize( @@ -250,6 +252,37 @@ def __init__( hidden_size=self.hidden_size, ) + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + def output_is_reduced(self) -> bool: + return False + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def _apply_router_weight_on_input( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + """Apply router weight on input if needed.""" + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + a1.mul_(topk_weights.to(a1.dtype)) + def prepare( self, a1: torch.Tensor, @@ -259,6 +292,7 @@ def prepare( expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: 
FusedMoEQuantConfig, + defer_input_quant: bool = False, ) -> mk.PrepareResultType: self._apply_router_weight_on_input( a1, topk_weights, topk_ids, apply_router_weight_on_input @@ -305,6 +339,7 @@ def prepare( and quant_config.is_nvfp4_scale_swizzled ): a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1]) + a1q_scale_recv = a1q_scale_recv.view(torch.uint8) a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv) a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16) else: From 6ebc00ea5d14b4701d9bf75c1e897cd2c7df8ecb Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Fri, 27 Feb 2026 16:23:38 +0000 Subject: [PATCH 03/12] Clean up Signed-off-by: wzhao18 --- .../device_communicators/all2all.py | 186 +++++++++--------- .../model_executor/layers/fused_moe/config.py | 10 - .../flashinfer_a2a_prepare_finalize.py | 29 +-- .../layers/fused_moe/oracle/nvfp4.py | 18 +- 4 files changed, 103 insertions(+), 140 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index c2631d0c7ecc..75b8279f52b5 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -530,99 +530,6 @@ def cleanup(self): self.initialized = False -class MoriAll2AllManager(All2AllManagerBase): - def __init__(self, cpu_group): - assert has_mori(), ( - "MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md" - " to install MoRI kernels." 
- ) # noqa - import mori - - super().__init__(cpu_group) - self.handle_cache = Cache() - - torch._C._distributed_c10d._register_process_group("mori", cpu_group) - mori.shmem.shmem_torch_process_group_init("mori") - - def _make_all2all_kwargs( - self, - rank: int, - num_ep_ranks: int, - input_dtype: torch.dtype, - quant_dtype: torch.dtype, - token_hidden_size: int, - scale_dim: int, - scale_type_size: int, - max_num_tokens_per_dp_rank: int, - num_local_experts: int, - num_experts_per_token: int, - ): - import mori # type: ignore[import-not-found] - - from vllm.platforms.rocm import on_gfx942, on_gfx950 - - assert on_gfx942() or on_gfx950(), ( - "mori currently only support arch gfx942 and gfx950" - ) - - if not self.internode: - # single node - kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode - rdma_block_num = 0 - warp_num_per_block = 16 - block_num = 80 - else: - # multi node - kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1 - if on_gfx942(): - warp_num_per_block = 16 - block_num = 32 - rdma_block_num = 16 - elif on_gfx950(): - warp_num_per_block = 8 - block_num = 64 - rdma_block_num = 32 - else: - raise NotImplementedError( - "mori currently only support arch gfx942 and gfx950" - ) - - return dict( - rank=rank, - world_size=num_ep_ranks, - data_type=quant_dtype, - hidden_dim=token_hidden_size, - scale_dim=scale_dim, - scale_type_size=scale_type_size, - max_token_type_size=input_dtype.itemsize, - max_num_inp_token_per_rank=max_num_tokens_per_dp_rank, - num_experts_per_rank=num_local_experts, - num_experts_per_token=num_experts_per_token, - warp_num_per_block=warp_num_per_block, - block_num=block_num, - kernel_type=kernel_type, - rdma_block_num=rdma_block_num, - gpu_per_node=min(8, num_ep_ranks), - ) - - def _make_handle(self, **kwargs): - import mori # type: ignore[import-not-found] - - mori_config = mori.ops.EpDispatchCombineConfig(**kwargs) - handle = mori.ops.EpDispatchCombineOp(mori_config) - return handle - - def get_handle(self, 
kwargs): - import mori # type: ignore[import-not-found] - - mori_kwargs = self._make_all2all_kwargs(**kwargs) - logger.debug("MoRI all2all args %s", mori_kwargs) - handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create( - mori_kwargs, self._make_handle - ) - return handle - - class FlashInferMoeA2AManager(All2AllManagerBase): """ All2All communication based on FlashInfer's MoeAlltoAll kernel. @@ -732,3 +639,96 @@ def cleanup(self): self.moe_alltoall = None self.mapping = None self.initialized = False + + +class MoriAll2AllManager(All2AllManagerBase): + def __init__(self, cpu_group): + assert has_mori(), ( + "MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md" + " to install MoRI kernels." + ) # noqa + import mori + + super().__init__(cpu_group) + self.handle_cache = Cache() + + torch._C._distributed_c10d._register_process_group("mori", cpu_group) + mori.shmem.shmem_torch_process_group_init("mori") + + def _make_all2all_kwargs( + self, + rank: int, + num_ep_ranks: int, + input_dtype: torch.dtype, + quant_dtype: torch.dtype, + token_hidden_size: int, + scale_dim: int, + scale_type_size: int, + max_num_tokens_per_dp_rank: int, + num_local_experts: int, + num_experts_per_token: int, + ): + import mori # type: ignore[import-not-found] + + from vllm.platforms.rocm import on_gfx942, on_gfx950 + + assert on_gfx942() or on_gfx950(), ( + "mori currently only support arch gfx942 and gfx950" + ) + + if not self.internode: + # single node + kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode + rdma_block_num = 0 + warp_num_per_block = 16 + block_num = 80 + else: + # multi node + kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1 + if on_gfx942(): + warp_num_per_block = 16 + block_num = 32 + rdma_block_num = 16 + elif on_gfx950(): + warp_num_per_block = 8 + block_num = 64 + rdma_block_num = 32 + else: + raise NotImplementedError( + "mori currently only support arch gfx942 and gfx950" + ) + + return dict( + 
rank=rank, + world_size=num_ep_ranks, + data_type=quant_dtype, + hidden_dim=token_hidden_size, + scale_dim=scale_dim, + scale_type_size=scale_type_size, + max_token_type_size=input_dtype.itemsize, + max_num_inp_token_per_rank=max_num_tokens_per_dp_rank, + num_experts_per_rank=num_local_experts, + num_experts_per_token=num_experts_per_token, + warp_num_per_block=warp_num_per_block, + block_num=block_num, + kernel_type=kernel_type, + rdma_block_num=rdma_block_num, + gpu_per_node=min(8, num_ep_ranks), + ) + + def _make_handle(self, **kwargs): + import mori # type: ignore[import-not-found] + + mori_config = mori.ops.EpDispatchCombineConfig(**kwargs) + handle = mori.ops.EpDispatchCombineOp(mori_config) + return handle + + def get_handle(self, kwargs): + import mori # type: ignore[import-not-found] + + mori_kwargs = self._make_all2all_kwargs(**kwargs) + logger.debug("MoRI all2all args %s", mori_kwargs) + handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create( + mori_kwargs, self._make_handle + ) + return handle diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 87aaa0eef283..7c8ddfe48d5e 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -229,7 +229,6 @@ class FusedMoEQuantConfig: _w1: FusedMoEQuantDesc _w2: FusedMoEQuantDesc is_nvfp4_scale_swizzled: bool = True - _g1_scale_c: torch.Tensor | None = None def __post_init__(self): assert not self.per_act_token_quant or self.block_shape is None, ( @@ -345,10 +344,6 @@ def w2_precision(self) -> "PrecisionConfig | None": def g2_alphas(self) -> torch.Tensor | None: return self._w2.alpha_or_gscale - @property - def g1_scale_c(self) -> torch.Tensor | None: - return self._g1_scale_c - @property def use_fp8_w8a8(self) -> bool: return self.quant_dtype == torch.float8_e4m3fn @@ -482,7 +477,6 @@ def make( w2_zp: torch.Tensor | None = None, weight_dtype: torch.dtype | str | None = None, 
is_nvfp4_scale_swizzled: bool = True, - g1_scale_c: torch.Tensor | None = None, ) -> "FusedMoEQuantConfig": """ General builder function for a FusedMoEQuantConfig. @@ -513,7 +507,6 @@ def make( - w1_zp: Optional w1 zero points for int4/int8 quantization. - w2_zp: Optional w2 zero points for int4/int8 quantization. - is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling. - - g1_scale_c: Pre-computed scale for TRT-LLM FP4 MoE kernel (a2_gscale * g1_alphas) """ assert not isinstance(quant_dtype, str) or quant_dtype in { "nvfp4", @@ -547,7 +540,6 @@ def make( weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias ), is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled, - _g1_scale_c=g1_scale_c, ) assert quant_config.per_act_token_quant == per_act_token_quant assert quant_config.per_out_ch_quant == per_out_ch_quant @@ -750,7 +742,6 @@ def nvfp4_moe_quant_config( w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, is_nvfp4_scale_swizzled: bool = True, - g1_scale_c: torch.Tensor | None = None, ) -> FusedMoEQuantConfig: """ Construct a quant config for mxfp4 activations and nvp4 weights. 
@@ -769,7 +760,6 @@ def nvfp4_moe_quant_config( per_out_ch_quant=False, block_shape=None, is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled, - g1_scale_c=g1_scale_c, ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py index e05bca281872..4854d643217c 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py @@ -268,21 +268,6 @@ def output_is_reduced(self) -> bool: def topk_indices_dtype(self) -> torch.dtype | None: return None - def _apply_router_weight_on_input( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - ) -> None: - """Apply router weight on input if needed.""" - if apply_router_weight_on_input: - topk = topk_ids.size(1) - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1" - ) - a1.mul_(topk_weights.to(a1.dtype)) - def prepare( self, a1: torch.Tensor, @@ -294,9 +279,12 @@ def prepare( quant_config: FusedMoEQuantConfig, defer_input_quant: bool = False, ) -> mk.PrepareResultType: - self._apply_router_weight_on_input( - a1, topk_weights, topk_ids, apply_router_weight_on_input - ) + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + a1.mul_(topk_weights.to(a1.dtype)) global_num_tokens_cpu = get_local_sizes() self.runtime_max_tokens_per_rank = ( @@ -318,9 +306,6 @@ def prepare( payloads.append(a1q) if a1q_scale is not None: payloads.append(a1q_scale) - expert_id_payload_index = 2 - else: - expert_id_payload_index = 1 payloads.append(topk_ids) payloads.append(topk_weights) @@ -328,8 +313,6 @@ def prepare( token_selected_experts=topk_ids, input_payloads=payloads, runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, - # 
invalid_token_expert_id=-1, - # expert_id_payload_index=expert_id_payload_index, ) if a1q_scale is not None: a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 8f5e8f738ce7..dd1a24d863de 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -376,19 +376,6 @@ def make_nvfp4_moe_quant_config( g1_alphas = a13_scale * w13_scale_2 g2_alphas = a2_scale * w2_scale_2 - - if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - return nvfp4_moe_quant_config( - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, - a1_gscale=(1.0 / a13_scale), - a2_gscale=(1.0 / a2_scale), - w1_scale=w13_scale, - w2_scale=w2_scale, - is_nvfp4_scale_swizzled=False, - g1_scale_c=(g1_alphas / a2_scale), - ) - return nvfp4_moe_quant_config( g1_alphas=g1_alphas, g2_alphas=g2_alphas, @@ -396,7 +383,10 @@ def make_nvfp4_moe_quant_config( a2_gscale=(1.0 / a2_scale), w1_scale=w13_scale, w2_scale=w2_scale, - is_nvfp4_scale_swizzled=True, + # NOTE(rob): this is a hack until the MoE kernels + # create their own quant configs. TRTLLM kernel + # does not accept swizzled input quant scales. 
+ is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM), ) From 2661092a11813fb439f87b014416d62b3543385d Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Fri, 27 Feb 2026 19:41:29 +0000 Subject: [PATCH 04/12] Set num_dispatchers Signed-off-by: wzhao18 --- vllm/model_executor/layers/fused_moe/all2all_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 28c1374818b2..463492b16edd 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -214,6 +214,7 @@ def maybe_make_prepare_finalize( top_k=moe.experts_per_token, num_experts=moe.num_experts, hidden_size=moe.hidden_dim, + num_dispatchers=all2all_manager.world_size, ) elif moe.use_naive_all2all_kernels and allow_new_interface: From b74eaa332c1a876233b0a87016de896f9e37fb7e Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Fri, 27 Feb 2026 21:29:44 +0000 Subject: [PATCH 05/12] same flag for MoEConfig and moe_parallel_config Signed-off-by: wzhao18 --- vllm/model_executor/layers/fused_moe/all2all_utils.py | 2 +- vllm/model_executor/layers/fused_moe/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 463492b16edd..de21088f15f3 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -204,7 +204,7 @@ def maybe_make_prepare_finalize( num_dispatchers=all2all_manager.world_size, ) - elif moe.use_fi_moe_all2all_kernels: + elif moe.use_fi_moe_a2a_kernels: assert quant_config is not None max_num_tokens = ( get_current_vllm_config().scheduler_config.max_num_batched_tokens diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 7c8ddfe48d5e..a60f528a0680 100644 
--- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -1244,7 +1244,7 @@ def use_fi_all2allv_kernels(self): return self.moe_parallel_config.use_fi_all2allv_kernels @property - def use_fi_moe_all2all_kernels(self): + def use_fi_moe_a2a_kernels(self): return self.moe_parallel_config.use_fi_moe_a2a_kernels @property From 8c11d321527bdbb9aa64cc85d24293f65b2f3f03 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Wed, 4 Mar 2026 03:26:18 +0000 Subject: [PATCH 06/12] Disable flashinfer moe a2a for non-nvfp4 moe backends Signed-off-by: wzhao18 --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 1 + vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 5 ++++- vllm/model_executor/layers/fused_moe/fused_marlin_moe.py | 5 ++++- vllm/model_executor/layers/fused_moe/fused_moe.py | 5 ++++- vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py | 5 ++++- 5 files changed, 17 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 64848bf931ae..8093aecb7461 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -398,6 +398,7 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo return not ( moe_parallel_config.use_fi_all2allv_kernels or moe_parallel_config.use_deepep_ht_kernels + or moe_parallel_config.use_fi_moe_a2a_kernels ) def supports_chunking(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 8af439a0d435..4e71aa4e7870 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -152,7 +152,10 @@ def _supports_activation(activation: MoEActivation) -> bool: @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: # 
NOTE(rob): discovered an IMA with this combination. Needs investigation. - return not moe_parallel_config.use_fi_all2allv_kernels + return not ( + moe_parallel_config.use_fi_all2allv_kernels + or moe_parallel_config.use_fi_moe_a2a_kernels + ) def supports_chunking(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 280d090795e2..7004c7d06d40 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -600,7 +600,10 @@ def _supports_activation(activation: MoEActivation) -> bool: @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return not moe_parallel_config.use_fi_all2allv_kernels + return not ( + moe_parallel_config.use_fi_all2allv_kernels + or moe_parallel_config.use_fi_moe_a2a_kernels + ) @property def quant_type_id(self) -> int: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 023cdd0b4340..9275bc21c645 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1987,7 +1987,10 @@ def _supports_activation(activation: MoEActivation) -> bool: @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return not moe_parallel_config.use_fi_all2allv_kernels + return not ( + moe_parallel_config.use_fi_all2allv_kernels + or moe_parallel_config.use_fi_moe_a2a_kernels + ) def supports_chunking(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index c550cad9e892..4311db25feba 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -332,7 +332,10 @@ def 
_supports_activation(activation: MoEActivation) -> bool: @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return not moe_parallel_config.use_fi_all2allv_kernels + return not ( + moe_parallel_config.use_fi_all2allv_kernels + or moe_parallel_config.use_fi_moe_a2a_kernels + ) def supports_expert_map(self): return True From eec985e8f361b93329e16a40b5b9de1bcfaac123 Mon Sep 17 00:00:00 2001 From: Leo Tian Date: Wed, 4 Mar 2026 14:57:00 -0800 Subject: [PATCH 07/12] move implementation Signed-off-by: Leo Tian --- .../layers/fused_moe/all2all_utils.py | 6 +- .../flashinfer_a2a_prepare_finalize.py | 132 ---------------- .../prepare_finalize/flashinfer_moe_a2a.py | 142 ++++++++++++++++++ 3 files changed, 144 insertions(+), 136 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_moe_a2a.py diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index de21088f15f3..7ad682664f71 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -15,10 +15,8 @@ FusedMoEParallelConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( - FlashInferA2APrepareAndFinalize, - FlashInferMoeA2APrepareAndFinalize, -) +from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import FlashInferA2APrepareAndFinalize +from vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_moe_a2a import FlashInferMoeA2APrepareAndFinalize from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPrepareAndFinalize, ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py index 4854d643217c..465d0ae8f2c4 100644 --- 
a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py @@ -224,135 +224,3 @@ def flashinfer_alltoall_combine( top_k=top_k, token_count=token_count, ) - - -class FlashInferMoeA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): - """FlashInfer implementation using the Moe AlltoAll kernel.""" - - def __init__( - self, - max_num_tokens: int, - top_k: int, - num_experts: int, - hidden_size: int, - num_dispatchers: int = 1, - ): - super().__init__() - self.max_num_tokens = max_num_tokens - self.top_k = top_k - self.num_experts = num_experts - self.hidden_size = hidden_size - self.num_dispatchers_ = num_dispatchers - - self.all2all_manager = get_ep_group().device_communicator.all2all_manager - self.all2all_manager.initialize( - max_num_tokens=self.max_num_tokens, - top_k=self.top_k, - num_experts=self.num_experts, - hidden_size=self.hidden_size, - ) - - @property - def activation_format(self) -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.Standard - - def max_num_tokens_per_rank(self) -> int | None: - return None - - def num_dispatchers(self) -> int: - return self.num_dispatchers_ - - def output_is_reduced(self) -> bool: - return False - - def topk_indices_dtype(self) -> torch.dtype | None: - return None - - def prepare( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: torch.Tensor | None, - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - defer_input_quant: bool = False, - ) -> mk.PrepareResultType: - if apply_router_weight_on_input: - topk = topk_ids.size(1) - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1" - ) - a1.mul_(topk_weights.to(a1.dtype)) - - global_num_tokens_cpu = get_local_sizes() - self.runtime_max_tokens_per_rank = ( - max(global_num_tokens_cpu) - if global_num_tokens_cpu is not None - else a1.shape[0] 
- ) - - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=False, # delay swizzle to after comm - ) - - payloads = [] - payloads.append(a1q) - if a1q_scale is not None: - payloads.append(a1q_scale) - payloads.append(topk_ids) - payloads.append(topk_weights) - - recv_payloads = self.all2all_manager.moe_alltoall.dispatch( - token_selected_experts=topk_ids, - input_payloads=payloads, - runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, - ) - if a1q_scale is not None: - a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads - # Apply scale interleaving only for CUTLASS (not TRT-LLM) - if ( - quant_config.quant_dtype == "nvfp4" - and quant_config.is_nvfp4_scale_swizzled - ): - a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1]) - a1q_scale_recv = a1q_scale_recv.view(torch.uint8) - a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv) - a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16) - else: - a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads - a1q_scale_recv = None - a1q_recv = a1q_recv.view(-1, a1q_recv.shape[-1]) - topk_ids_recv = topk_ids_recv.view(-1, topk_ids_recv.shape[-1]) - topk_weights_recv = topk_weights_recv.view(-1, topk_weights_recv.shape[-1]) - - return a1q_recv, a1q_scale_recv, None, topk_ids_recv, topk_weights_recv - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - assert self.all2all_manager.moe_alltoall is not None - - ep_size = self.all2all_manager.world_size - hidden_size = fused_expert_output.shape[-1] - fused_expert_output = fused_expert_output.view( - ep_size, self.runtime_max_tokens_per_rank, hidden_size - ) - - combined_output = 
self.all2all_manager.moe_alltoall.combine( - payload=fused_expert_output, - runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, - ) - output.copy_(combined_output) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_moe_a2a.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_moe_a2a.py new file mode 100644 index 000000000000..7f4b64959443 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_moe_a2a.py @@ -0,0 +1,142 @@ +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_ep_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.utils.flashinfer import nvfp4_block_scale_interleave + +def get_local_sizes(): + return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() + +class FlashInferMoeA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): + """FlashInfer implementation using the Moe AlltoAll kernel.""" + + def __init__( + self, + max_num_tokens: int, + top_k: int, + num_experts: int, + hidden_size: int, + num_dispatchers: int = 1, + ): + super().__init__() + self.max_num_tokens = max_num_tokens + self.top_k = top_k + self.num_experts = num_experts + self.hidden_size = hidden_size + self.num_dispatchers_ = num_dispatchers + + self.all2all_manager = get_ep_group().device_communicator.all2all_manager + self.all2all_manager.initialize( + max_num_tokens=self.max_num_tokens, + top_k=self.top_k, + num_experts=self.num_experts, + hidden_size=self.hidden_size, + ) + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + def 
output_is_reduced(self) -> bool: + return False + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareResultType: + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + a1.mul_(topk_weights.to(a1.dtype)) + + global_num_tokens_cpu = get_local_sizes() + self.runtime_max_tokens_per_rank = ( + max(global_num_tokens_cpu) + if global_num_tokens_cpu is not None + else a1.shape[0] + ) + + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=False, # delay swizzle to after comm + ) + + payloads = [] + payloads.append(a1q) + if a1q_scale is not None: + payloads.append(a1q_scale) + payloads.append(topk_ids) + payloads.append(topk_weights) + + recv_payloads = self.all2all_manager.moe_alltoall.dispatch( + token_selected_experts=topk_ids, + input_payloads=payloads, + runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, + ) + if a1q_scale is not None: + a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads + # Apply scale interleaving only for CUTLASS (not TRT-LLM) + if ( + quant_config.quant_dtype == "nvfp4" + and quant_config.is_nvfp4_scale_swizzled + ): + a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1]) + a1q_scale_recv = a1q_scale_recv.view(torch.uint8) + a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv) + a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16) + else: + a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads + a1q_scale_recv = None + a1q_recv = a1q_recv.view(-1, 
a1q_recv.shape[-1]) + topk_ids_recv = topk_ids_recv.view(-1, topk_ids_recv.shape[-1]) + topk_weights_recv = topk_weights_recv.view(-1, topk_weights_recv.shape[-1]) + + return a1q_recv, a1q_scale_recv, None, topk_ids_recv, topk_weights_recv + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + assert self.all2all_manager.moe_alltoall is not None + + ep_size = self.all2all_manager.world_size + hidden_size = fused_expert_output.shape[-1] + fused_expert_output = fused_expert_output.view( + ep_size, self.runtime_max_tokens_per_rank, hidden_size + ) + + combined_output = self.all2all_manager.moe_alltoall.combine( + payload=fused_expert_output, + runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, + ) + output.copy_(combined_output) From e7f7b85df10894b43a9a369df632bd9a0e3e1648 Mon Sep 17 00:00:00 2001 From: Leo Tian Date: Thu, 5 Mar 2026 09:11:50 -0800 Subject: [PATCH 08/12] precommit fixes Signed-off-by: Leo Tian --- vllm/model_executor/layers/fused_moe/all2all_utils.py | 8 ++++++-- .../fused_moe/prepare_finalize/flashinfer_moe_a2a.py | 4 ++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 7ad682664f71..14086792b184 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -15,8 +15,9 @@ FusedMoEParallelConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import FlashInferA2APrepareAndFinalize -from vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_moe_a2a import FlashInferMoeA2APrepareAndFinalize +from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( + 
FlashInferA2APrepareAndFinalize, +) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPrepareAndFinalize, ) @@ -24,6 +25,9 @@ make_moe_prepare_and_finalize_naive_dp_ep, make_moe_prepare_and_finalize_no_dp_ep, ) +from vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_moe_a2a import ( + FlashInferMoeA2APrepareAndFinalize, +) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_ep, has_mori diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_moe_a2a.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_moe_a2a.py index 7f4b64959443..367565b236aa 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_moe_a2a.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_moe_a2a.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -7,9 +9,11 @@ from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils.flashinfer import nvfp4_block_scale_interleave + def get_local_sizes(): return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() + class FlashInferMoeA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """FlashInfer implementation using the Moe AlltoAll kernel.""" From b1060cda83ac15d84530a65490b919d4899f2821 Mon Sep 17 00:00:00 2001 From: Stefano Castagnetta Date: Tue, 10 Mar 2026 14:16:39 +0100 Subject: [PATCH 09/12] [Tests] Add test coverage for FlashInfer MoE A2A kernel backend - Register FlashInferMoeA2APrepareAndFinalize in mk_objects.py for combinatorial multi-GPU testing against compatible Expert backends - Register TrtLlmNvFp4ExpertsModular (previously missing from registry) - Add parametrized tests validating _supports_parallel_config incompatibility matrix for flashinfer_moe_a2a across 7 Expert 
types - Add parity test ensuring flashinfer_moe_a2a and flashinfer_all2allv share the same incompatibility matrix Signed-off-by: Leo Tian --- .../moe/modular_kernel_tools/mk_objects.py | 36 ++++- tests/kernels/moe/test_flashinfer_moe_a2a.py | 145 ++++++++++++++++++ 2 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 tests/kernels/moe/test_flashinfer_moe_a2a.py diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index ee4190859e4c..cde2522b078f 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -33,7 +33,10 @@ ) from vllm.platforms import current_platform from vllm.utils.deep_gemm import is_deep_gemm_supported -from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.flashinfer import ( + has_flashinfer_cutlass_fused_moe, + has_flashinfer_moe_a2a, +) from vllm.utils.import_utils import ( has_aiter, has_deep_ep, @@ -270,6 +273,37 @@ def expert_info(kind) -> ExpertInfo: FlashInferCutlassMoEPrepareAndFinalize = None FlashInferExperts = None +if ( + has_flashinfer_moe_a2a() + and has_flashinfer_cutlass_fused_moe() + and current_platform.has_device_capability(100) +): + from vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_moe_a2a import ( # noqa: E501 + FlashInferMoeA2APrepareAndFinalize, + ) + + register_prepare_and_finalize( + FlashInferMoeA2APrepareAndFinalize, + standard_format, + nvfp4_types, + blocked_quantization_support=False, + backend="flashinfer_moe_a2a", + supports_apply_weight_on_input=False, + ) + +if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100): + from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import ( + TrtLlmNvFp4ExpertsModular, + ) + + register_experts( + TrtLlmNvFp4ExpertsModular, + standard_format, + nvfp4_types, + blocked_quantization_support=False, + supports_chunking=True, + 
supports_expert_map=True, + ) if has_aiter(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( diff --git a/tests/kernels/moe/test_flashinfer_moe_a2a.py b/tests/kernels/moe/test_flashinfer_moe_a2a.py new file mode 100644 index 000000000000..d9a71ca00197 --- /dev/null +++ b/tests/kernels/moe/test_flashinfer_moe_a2a.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for FlashInfer MoE A2A (trtllm_moe_alltoall) kernel backend. + +Validates the _supports_parallel_config incompatibility matrix to ensure +each Expert backend correctly accepts or rejects the flashinfer_moe_a2a +parallel configuration. No GPU required. + +See also: + - mk_objects.py for combinatorial registration of the new P/F and Experts + - PR #36022 (FlashInfer MoE A2A kernel backend) +""" + +import importlib + +import pytest + +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, +) + + +def _make_parallel_config(all2all_backend: str) -> FusedMoEParallelConfig: + """Create a FusedMoEParallelConfig with EP enabled for the given backend.""" + return FusedMoEParallelConfig( + tp_size=1, + pcp_size=1, + dp_size=2, + ep_size=2, + tp_rank=0, + pcp_rank=0, + dp_rank=0, + ep_rank=0, + sp_size=1, + use_ep=True, + all2all_backend=all2all_backend, + enable_eplb=False, + ) + + +def _import_expert_cls(module_path: str, class_name: str, skip_reason: str | None): + """Import an Expert class, skipping the test if unavailable.""" + try: + mod = importlib.import_module(module_path) + return getattr(mod, class_name) + except (ImportError, AttributeError): + if skip_reason: + pytest.skip(skip_reason) + raise + + +# (module_path, class_name, supports_flashinfer_moe_a2a, skip_reason) +_EXPERT_COMPAT_CASES = [ + # Backends that reject flashinfer_moe_a2a (Standard format, no all2allv) + ( + "vllm.model_executor.layers.fused_moe.fused_moe", + "TritonExperts", + False, + None, + ), + 
( + "vllm.model_executor.layers.fused_moe.deep_gemm_moe", + "DeepGemmExperts", + False, + "requires deep_gemm", + ), + ( + "vllm.model_executor.layers.fused_moe.fused_marlin_moe", + "MarlinExperts", + False, + None, + ), + ( + "vllm.model_executor.layers.fused_moe.cutlass_moe", + "CutlassExpertsFp8", + False, + "requires cutlass_fp8", + ), + # Backends that accept flashinfer_moe_a2a + ( + "vllm.model_executor.layers.fused_moe.fused_batched_moe", + "BatchedTritonExperts", + True, + None, + ), + ( + "vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe", + "FlashInferExperts", + True, + "requires flashinfer_cutlass on Blackwell", + ), + ( + "vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe", + "TrtLlmNvFp4ExpertsModular", + True, + "requires flashinfer trtllm", + ), +] + + +@pytest.mark.parametrize( + "module_path,class_name,expected_support,skip_reason", + _EXPERT_COMPAT_CASES, + ids=[c[1] for c in _EXPERT_COMPAT_CASES], +) +def test_supports_parallel_config_flashinfer_moe_a2a( + module_path: str, + class_name: str, + expected_support: bool, + skip_reason: str | None, +): + """Verify _supports_parallel_config for the flashinfer_moe_a2a backend.""" + cls = _import_expert_cls(module_path, class_name, skip_reason) + config = _make_parallel_config("flashinfer_moe_a2a") + result = cls._supports_parallel_config(config) + assert result == expected_support, ( + f"{class_name}._supports_parallel_config('flashinfer_moe_a2a') " + f"returned {result}, expected {expected_support}" + ) + + +@pytest.mark.parametrize( + "module_path,class_name,expected_support,skip_reason", + _EXPERT_COMPAT_CASES, + ids=[c[1] for c in _EXPERT_COMPAT_CASES], +) +def test_supports_parallel_config_parity_with_all2allv( + module_path: str, + class_name: str, + expected_support: bool, + skip_reason: str | None, +): + """Verify flashinfer_moe_a2a and flashinfer_all2allv share the same + incompatibility matrix (both reject and accept the same Expert backends). 
+ """ + cls = _import_expert_cls(module_path, class_name, skip_reason) + config = _make_parallel_config("flashinfer_all2allv") + result = cls._supports_parallel_config(config) + assert result == expected_support, ( + f"{class_name}._supports_parallel_config('flashinfer_all2allv') " + f"returned {result}, expected {expected_support}. " + f"flashinfer_moe_a2a and flashinfer_all2allv should share the same " + f"incompatibility matrix." + ) From 12373feebc7b47918a41a1b9501b8e76ae565ba0 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 11 Mar 2026 10:06:07 -0700 Subject: [PATCH 10/12] rename all2allv and moe_a2a Signed-off-by: <> Signed-off-by: Leo Tian --- docs/design/moe_kernel_features.md | 2 +- docs/serving/expert_parallel_deployment.md | 2 +- .../moe/modular_kernel_tools/mk_objects.py | 20 +++++------ ...py => test_flashinfer_nvlink_one_sided.py} | 23 ++++++------- vllm/config/parallel.py | 7 ++-- .../device_communicators/all2all.py | 34 ++++++++++--------- .../device_communicators/cuda_communicator.py | 24 +++++++++---- .../device_communicators/mnnvl_compat.py | 4 +-- .../layers/fused_moe/all2all_utils.py | 16 ++++----- .../model_executor/layers/fused_moe/config.py | 22 +++++++----- .../layers/fused_moe/cutlass_moe.py | 4 +-- .../layers/fused_moe/deep_gemm_moe.py | 4 +-- ... 
flashinfer_nvlink_2s_prepare_finalize.py} | 2 +- .../layers/fused_moe/fused_marlin_moe.py | 4 +-- .../layers/fused_moe/fused_moe.py | 4 +-- vllm/model_executor/layers/fused_moe/layer.py | 2 +- ...fer_moe_a2a.py => flashinfer_nvlink_1s.py} | 2 +- .../layers/fused_moe/rocm_aiter_fused_moe.py | 4 +-- .../fused_moe/runner/default_moe_runner.py | 2 +- vllm/utils/flashinfer.py | 8 ++--- 20 files changed, 103 insertions(+), 87 deletions(-) rename tests/kernels/moe/{test_flashinfer_moe_a2a.py => test_flashinfer_nvlink_one_sided.py} (82%) rename vllm/model_executor/layers/fused_moe/{flashinfer_a2a_prepare_finalize.py => flashinfer_nvlink_2s_prepare_finalize.py} (98%) rename vllm/model_executor/layers/fused_moe/prepare_finalize/{flashinfer_moe_a2a.py => flashinfer_nvlink_1s.py} (98%) diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 0c92e597582e..5308ad030ff2 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -35,7 +35,7 @@ th { | naive | standard | all1 | G,A,T | N | 6 | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE] | | deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] | | deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] | -| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] | +| flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize.FlashInferNVLinkTwoSidedPrepareAndFinalize] | !!! info "Table key" 1. 
All types: mxfp4, nvfp4, int4, int8, fp8 diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index d469e20c9866..e52388805cc6 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -21,7 +21,7 @@ vLLM provides multiple communication backends for EP. Use `--all2all-backend` to | `allgather_reducescatter` | Default backend | Standard all2all using allgather/reducescatter primitives | General purpose, works with any EP+DP configuration | | `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout, optimized for prefill | Prefill-dominated workloads, high-throughput scenarios | | `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios | -| `flashinfer_all2allv` | MNNVL systems | FlashInfer alltoallv kernels for multi-node NVLink | Systems with NVLink across nodes | +| `flashinfer_nvlink_two_sided` | MNNVL systems | FlashInfer alltoallv kernels for multi-node NVLink | Systems with NVLink across nodes | | `naive` | Testing/debugging | Simple broadcast-based implementation | Debugging, not recommended for production | ## Single Node Deployment diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index cde2522b078f..228f854319c5 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -35,7 +35,7 @@ from vllm.utils.deep_gemm import is_deep_gemm_supported from vllm.utils.flashinfer import ( has_flashinfer_cutlass_fused_moe, - has_flashinfer_moe_a2a, + has_flashinfer_nvlink_one_sided, ) from vllm.utils.import_utils import ( has_aiter, @@ -243,15 +243,15 @@ def expert_info(kind) -> ExpertInfo: ) if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100): - from 
vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501 - FlashInferA2APrepareAndFinalize, - ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, ) + from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_2s_prepare_finalize import ( # noqa: E501 + FlashInferNVLinkTwoSidedPrepareAndFinalize, + ) register_prepare_and_finalize( - FlashInferA2APrepareAndFinalize, + FlashInferNVLinkTwoSidedPrepareAndFinalize, standard_format, nvfp4_types + fp8_types, blocked_quantization_support=True, @@ -274,20 +274,20 @@ def expert_info(kind) -> ExpertInfo: FlashInferExperts = None if ( - has_flashinfer_moe_a2a() + has_flashinfer_nvlink_one_sided() and has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100) ): - from vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_moe_a2a import ( # noqa: E501 - FlashInferMoeA2APrepareAndFinalize, + from vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_1s import ( # noqa: E501 + FlashInferNVLinkOneSidedPrepareAndFinalize, ) register_prepare_and_finalize( - FlashInferMoeA2APrepareAndFinalize, + FlashInferNVLinkOneSidedPrepareAndFinalize, standard_format, nvfp4_types, blocked_quantization_support=False, - backend="flashinfer_moe_a2a", + backend="flashinfer_nvlink_one_sided", supports_apply_weight_on_input=False, ) diff --git a/tests/kernels/moe/test_flashinfer_moe_a2a.py b/tests/kernels/moe/test_flashinfer_nvlink_one_sided.py similarity index 82% rename from tests/kernels/moe/test_flashinfer_moe_a2a.py rename to tests/kernels/moe/test_flashinfer_nvlink_one_sided.py index d9a71ca00197..cc6321f0163f 100644 --- a/tests/kernels/moe/test_flashinfer_moe_a2a.py +++ b/tests/kernels/moe/test_flashinfer_nvlink_one_sided.py @@ -1,15 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Tests for FlashInfer MoE A2A (trtllm_moe_alltoall) kernel backend. 
+Tests for FlashInfer MoeAlltoAll/One-sided NVLink (trtllm_moe_alltoall) kernel backend. Validates the _supports_parallel_config incompatibility matrix to ensure -each Expert backend correctly accepts or rejects the flashinfer_moe_a2a +each Expert backend correctly accepts or rejects the flashinfer_nvlink_one_sided parallel configuration. No GPU required. See also: - mk_objects.py for combinatorial registration of the new P/F and Experts - - PR #36022 (FlashInfer MoE A2A kernel backend) """ import importlib @@ -50,9 +49,9 @@ def _import_expert_cls(module_path: str, class_name: str, skip_reason: str | Non raise -# (module_path, class_name, supports_flashinfer_moe_a2a, skip_reason) +# (module_path, class_name, supports_flashinfer_nvlink_one_sided, skip_reason) _EXPERT_COMPAT_CASES = [ - # Backends that reject flashinfer_moe_a2a (Standard format, no all2allv) + # Backends that reject flashinfer_nvlink_one_sided (Standard format, no all2allv) ( "vllm.model_executor.layers.fused_moe.fused_moe", "TritonExperts", @@ -77,7 +76,7 @@ def _import_expert_cls(module_path: str, class_name: str, skip_reason: str | Non False, "requires cutlass_fp8", ), - # Backends that accept flashinfer_moe_a2a + # Backends that accept flashinfer_nvlink_one_sided ( "vllm.model_executor.layers.fused_moe.fused_batched_moe", "BatchedTritonExperts", @@ -104,18 +103,18 @@ def _import_expert_cls(module_path: str, class_name: str, skip_reason: str | Non _EXPERT_COMPAT_CASES, ids=[c[1] for c in _EXPERT_COMPAT_CASES], ) -def test_supports_parallel_config_flashinfer_moe_a2a( +def test_supports_parallel_config_flashinfer_nvlink_one_sided( module_path: str, class_name: str, expected_support: bool, skip_reason: str | None, ): - """Verify _supports_parallel_config for the flashinfer_moe_a2a backend.""" + """Verify _supports_parallel_config for the flashinfer_nvlink_one_sided backend.""" cls = _import_expert_cls(module_path, class_name, skip_reason) - config = _make_parallel_config("flashinfer_moe_a2a") + 
config = _make_parallel_config("flashinfer_nvlink_one_sided") result = cls._supports_parallel_config(config) assert result == expected_support, ( - f"{class_name}._supports_parallel_config('flashinfer_moe_a2a') " + f"{class_name}._supports_parallel_config('flashinfer_nvlink_one_sided') " f"returned {result}, expected {expected_support}" ) @@ -131,7 +130,7 @@ def test_supports_parallel_config_parity_with_all2allv( expected_support: bool, skip_reason: str | None, ): - """Verify flashinfer_moe_a2a and flashinfer_all2allv share the same + """Verify flashinfer_nvlink_one_sided and flashinfer_all2allv share the same incompatibility matrix (both reject and accept the same Expert backends). """ cls = _import_expert_cls(module_path, class_name, skip_reason) @@ -140,6 +139,6 @@ def test_supports_parallel_config_parity_with_all2allv( assert result == expected_support, ( f"{class_name}._supports_parallel_config('flashinfer_all2allv') " f"returned {result}, expected {expected_support}. " - f"flashinfer_moe_a2a and flashinfer_all2allv should share the same " + f"flashinfer_nvlink_one_sided and flashinfer_all2allv should share the same " f"incompatibility matrix." 
) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index a403bd7e60f9..f61c1f72058e 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -43,8 +43,9 @@ "deepep_low_latency", "mori", "allgather_reducescatter", + "flashinfer_nvlink_two_sided", "flashinfer_all2allv", - "flashinfer_moe_a2a", + "flashinfer_nvlink_one_sided", ] @@ -156,8 +157,8 @@ class ParallelConfig: - "deepep_high_throughput": Use deepep high-throughput kernels\n - "deepep_low_latency": Use deepep low-latency kernels\n - "mori": Use mori kernels\n - - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl\n - - "flashinfer_moe_a2a": Use flashinfer moe alltoall kernels""" + - "flashinfer_nvlink_two_sided": Use flashinfer alltoallv kernels for mnnvl\n + - "flashinfer_nvlink_one_sided": Use flashinfer moe alltoall kernels""" max_parallel_loading_workers: int | None = None """Maximum number of parallel loading workers when loading model diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 75b8279f52b5..27d3261b13b5 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -10,21 +10,21 @@ from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.utils.flashinfer import ( - has_flashinfer_all2all, - has_flashinfer_moe_a2a, + has_flashinfer_nvlink_one_sided, + has_flashinfer_nvlink_two_sided, ) from vllm.utils.import_utils import has_deep_ep, has_mori from .base_device_communicator import All2AllManagerBase, Cache -if has_flashinfer_all2all(): +if has_flashinfer_nvlink_two_sided(): from flashinfer.comm import Mapping # type: ignore[import-not-found] from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found] from flashinfer.comm.trtllm_alltoall import ( MnnvlMoe, # type: ignore[import-not-found] ) -if has_flashinfer_moe_a2a(): +if has_flashinfer_nvlink_one_sided(): from 
flashinfer.comm import Mapping # type: ignore[import-not-found] from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found] from flashinfer.comm.trtllm_moe_alltoall import ( @@ -426,9 +426,9 @@ def max_sms_used(self) -> int | None: return 0 -class FlashInferAllToAllManager(All2AllManagerBase): +class FlashInferNVLinkTwoSidedManager(All2AllManagerBase): """ - All2All communication based on flashinfer kernels. + All2All communication based on flashinfer all2allv/two-sided NVLink kernels. """ # This type lint could be removed after all of the work in @@ -437,7 +437,7 @@ class FlashInferAllToAllManager(All2AllManagerBase): world_size: int def __init__(self, cpu_group, tcp_store_group=None): - assert has_flashinfer_all2all(), ( + assert has_flashinfer_nvlink_two_sided(), ( "flashinfer all2all module not found. Please install/check flashinfer" ) # noqa super().__init__(cpu_group, tcp_store_group) @@ -494,7 +494,7 @@ def initialize( def ensure_alltoall_workspace_initialized(self): """Ensure workspace is initialized""" - if not has_flashinfer_all2all(): + if not has_flashinfer_nvlink_two_sided(): return False if self.world_size <= 1: @@ -530,24 +530,24 @@ def cleanup(self): self.initialized = False -class FlashInferMoeA2AManager(All2AllManagerBase): +class FlashInferNVLinkOneSidedManager(All2AllManagerBase): """ - All2All communication based on FlashInfer's MoeAlltoAll kernel. + All2All communication based on FlashInfer's MoeAlltoAll/One-sided NVLink kernel. This is a newer kernel from trtllm that should perform better than the kernel - used by flashinfer_all2allv. + used by flashinfer_nvlink_two_sided. """ rank: int world_size: int def __init__(self, cpu_group): - assert has_flashinfer_moe_a2a(), ( + assert has_flashinfer_nvlink_one_sided(), ( "flashinfer trtllm_moe_alltoall module not found. 
" "Please install/check flashinfer" ) super().__init__(cpu_group) logger.debug( - "Initialize FlashInfer MoeA2A rank=%d, world size=%d", + "Initialize FlashInfer One-sided NVLink rank=%d, world size=%d", self.rank, self.world_size, ) @@ -569,7 +569,7 @@ def initialize( self.cleanup() gpus_per_node = torch.cuda.device_count() logger.debug( - "Making MoeA2A mapping: rank=%d, world size=%d", + "Making One-sided NVLink mapping: rank=%d, world size=%d", self.rank, self.world_size, ) @@ -619,7 +619,7 @@ def initialize( self.initialized = True logger.info( - "FlashInfer MoeA2A initialized for rank %s, size %s", + "FlashInfer One-sided NVLink initialized for rank %s, size %s", self.rank, self.world_size, ) @@ -634,7 +634,9 @@ def cleanup(self): try: del self.moe_alltoall except Exception as e: - logger.warning("Failed to cleanup FlashInfer MoeA2A workspace: %s", e) + logger.warning( + "Failed to cleanup FlashInfer One-sided NVLink workspace: %s", e + ) finally: self.moe_alltoall = None self.mapping = None diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 51ef6f305f67..02377358c47e 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -143,16 +143,26 @@ def __init__( from .all2all import MoriAll2AllManager self.all2all_manager = MoriAll2AllManager(self.cpu_group) - elif self.all2all_backend == "flashinfer_all2allv": - from .all2all import FlashInferAllToAllManager - - self.all2all_manager = FlashInferAllToAllManager( + elif ( + self.all2all_backend == "flashinfer_all2allv" + or self.all2all_backend == "flashinfer_nvlink_two_sided" + ): + if self.all2all_backend == "flashinfer_all2allv": + logger.warning_once( + "'flashinfer_all2allv' is deprecated and has been renamed to " + "'flashinfer_nvlink_two_sided'. It will be removed in a future " + "release." 
+ ) + + from .all2all import FlashInferNVLinkTwoSidedManager + + self.all2all_manager = FlashInferNVLinkTwoSidedManager( self.cpu_group, tcp_store_group ) - elif self.all2all_backend == "flashinfer_moe_a2a": - from .all2all import FlashInferMoeA2AManager + elif self.all2all_backend == "flashinfer_nvlink_one_sided": + from .all2all import FlashInferNVLinkOneSidedManager - self.all2all_manager = FlashInferMoeA2AManager(self.cpu_group) + self.all2all_manager = FlashInferNVLinkOneSidedManager(self.cpu_group) else: raise ValueError(f"Unknown all2all backend: {self.all2all_backend}") diff --git a/vllm/distributed/device_communicators/mnnvl_compat.py b/vllm/distributed/device_communicators/mnnvl_compat.py index cf78d7bbcccb..2a431ad15f3f 100644 --- a/vllm/distributed/device_communicators/mnnvl_compat.py +++ b/vllm/distributed/device_communicators/mnnvl_compat.py @@ -5,9 +5,9 @@ import torch.distributed as dist from flashinfer.comm.mnnvl import CommBackend as CommBackend -from vllm.utils.flashinfer import has_flashinfer_all2all +from vllm.utils.flashinfer import has_flashinfer_nvlink_two_sided -assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found" +assert has_flashinfer_nvlink_two_sided(), "Flashinfer alltoallv module cannot be found" class CustomCommunicator(CommBackend): diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 14086792b184..0f05567e401d 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -15,8 +15,8 @@ FusedMoEParallelConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( - FlashInferA2APrepareAndFinalize, +from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_2s_prepare_finalize import ( + FlashInferNVLinkTwoSidedPrepareAndFinalize, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( 
FusedMoEPrepareAndFinalize, @@ -25,8 +25,8 @@ make_moe_prepare_and_finalize_naive_dp_ep, make_moe_prepare_and_finalize_no_dp_ep, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_moe_a2a import ( - FlashInferMoeA2APrepareAndFinalize, +from vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_1s import ( + FlashInferNVLinkOneSidedPrepareAndFinalize, ) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_ep, has_mori @@ -200,18 +200,18 @@ def maybe_make_prepare_finalize( use_fp8_dispatch=use_fp8_dispatch, ) - elif moe.use_fi_all2allv_kernels: + elif moe.use_fi_nvl_two_sided_kernels: assert quant_config is not None - prepare_finalize = FlashInferA2APrepareAndFinalize( + prepare_finalize = FlashInferNVLinkTwoSidedPrepareAndFinalize( num_dispatchers=all2all_manager.world_size, ) - elif moe.use_fi_moe_a2a_kernels: + elif moe.use_fi_nvl_one_sided_kernels: assert quant_config is not None max_num_tokens = ( get_current_vllm_config().scheduler_config.max_num_batched_tokens ) - prepare_finalize = FlashInferMoeA2APrepareAndFinalize( + prepare_finalize = FlashInferNVLinkOneSidedPrepareAndFinalize( max_num_tokens=max_num_tokens, top_k=moe.experts_per_token, num_experts=moe.num_experts, diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index a60f528a0680..ced382abbdd2 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -957,14 +957,18 @@ def use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" @property - def use_fi_all2allv_kernels(self): - return ( - self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv" + def use_fi_nvl_two_sided_kernels(self): + return self.use_all2all_kernels and ( + self.all2all_backend == "flashinfer_all2allv" + or self.all2all_backend == "flashinfer_nvlink_two_sided" ) @property - 
def use_fi_moe_a2a_kernels(self): - return self.use_all2all_kernels and self.all2all_backend == "flashinfer_moe_a2a" + def use_fi_nvl_one_sided_kernels(self): + return ( + self.use_all2all_kernels + and self.all2all_backend == "flashinfer_nvlink_one_sided" + ) @property def use_batched_activation_format(self): @@ -1240,12 +1244,12 @@ def use_mori_kernels(self): return self.moe_parallel_config.use_mori_kernels @property - def use_fi_all2allv_kernels(self): - return self.moe_parallel_config.use_fi_all2allv_kernels + def use_fi_nvl_two_sided_kernels(self): + return self.moe_parallel_config.use_fi_nvl_two_sided_kernels @property - def use_fi_moe_a2a_kernels(self): - return self.moe_parallel_config.use_fi_moe_a2a_kernels + def use_fi_nvl_one_sided_kernels(self): + return self.moe_parallel_config.use_fi_nvl_one_sided_kernels @property def use_naive_all2all_kernels(self): diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 8093aecb7461..4129f759015c 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -396,9 +396,9 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo # Note that the BATCHED activation format does not use # the expert map for identifying experts. 
return not ( - moe_parallel_config.use_fi_all2allv_kernels + moe_parallel_config.use_fi_nvl_two_sided_kernels or moe_parallel_config.use_deepep_ht_kernels - or moe_parallel_config.use_fi_moe_a2a_kernels + or moe_parallel_config.use_fi_nvl_one_sided_kernels ) def supports_chunking(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 4e71aa4e7870..3232cb9f27c8 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -153,8 +153,8 @@ def _supports_activation(activation: MoEActivation) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: # NOTE(rob): discovered an IMA with this combination. Needs investigation. return not ( - moe_parallel_config.use_fi_all2allv_kernels - or moe_parallel_config.use_fi_moe_a2a_kernels + moe_parallel_config.use_fi_nvl_two_sided_kernels + or moe_parallel_config.use_fi_nvl_one_sided_kernels ) def supports_chunking(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_2s_prepare_finalize.py similarity index 98% rename from vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py rename to vllm/model_executor/layers/fused_moe/flashinfer_nvlink_2s_prepare_finalize.py index 465d0ae8f2c4..be63bd4e3f61 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_2s_prepare_finalize.py @@ -18,7 +18,7 @@ def get_local_sizes(): return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() -class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): +class FlashInferNVLinkTwoSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """Base class for FlashInfer MoE prepare and finalize operations.""" def __init__( diff 
--git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 7004c7d06d40..95c3cfab89f3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -601,8 +601,8 @@ def _supports_activation(activation: MoEActivation) -> bool: @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return not ( - moe_parallel_config.use_fi_all2allv_kernels - or moe_parallel_config.use_fi_moe_a2a_kernels + moe_parallel_config.use_fi_nvl_two_sided_kernels + or moe_parallel_config.use_fi_nvl_one_sided_kernels ) @property diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9275bc21c645..65c296b266b3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1988,8 +1988,8 @@ def _supports_activation(activation: MoEActivation) -> bool: @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return not ( - moe_parallel_config.use_fi_all2allv_kernels - or moe_parallel_config.use_fi_moe_a2a_kernels + moe_parallel_config.use_fi_nvl_two_sided_kernels + or moe_parallel_config.use_fi_nvl_one_sided_kernels ) def supports_chunking(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6200477092ab..32a85315eeb2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -637,7 +637,7 @@ def _get_quant_method() -> FusedMoEMethodBase: self.use_overlapped = ( not ( (self.enable_eplb and backend != "allgather_reducescatter") - or self.moe_parallel_config.use_fi_all2allv_kernels + or self.moe_parallel_config.use_fi_nvl_two_sided_kernels ) and self._shared_experts is not None ) diff --git 
a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_moe_a2a.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_1s.py similarity index 98% rename from vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_moe_a2a.py rename to vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_1s.py index 367565b236aa..f1282a27978e 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_moe_a2a.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_1s.py @@ -14,7 +14,7 @@ def get_local_sizes(): return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() -class FlashInferMoeA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): +class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """FlashInfer implementation using the Moe AlltoAll kernel.""" def __init__( diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 4311db25feba..1fed3e96c563 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -333,8 +333,8 @@ def _supports_activation(activation: MoEActivation) -> bool: @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: return not ( - moe_parallel_config.use_fi_all2allv_kernels - or moe_parallel_config.use_fi_moe_a2a_kernels + moe_parallel_config.use_fi_nvl_two_sided_kernels + or moe_parallel_config.use_fi_nvl_one_sided_kernels ) def supports_expert_map(self): diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py index e9e849b25910..1edf98931a65 100644 --- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py @@ -233,7 
+233,7 @@ def use_dp_chunking(self) -> bool: return ( self.moe_config.moe_parallel_config.use_deepep_ll_kernels or self.moe_config.moe_parallel_config.use_mori_kernels - or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels + or self.moe_config.moe_parallel_config.use_fi_nvl_two_sided_kernels ) and envs.VLLM_ENABLE_MOE_DP_CHUNK def _maybe_setup_shared_experts_stream( diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 355639517486..fed44d04fb5e 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -150,7 +150,7 @@ def has_flashinfer_comm() -> bool: @functools.cache -def has_flashinfer_all2all() -> bool: +def has_flashinfer_nvlink_two_sided() -> bool: """Return `True` if FlashInfer mnnvl all2all is available.""" if not has_flashinfer_comm(): return False @@ -171,7 +171,7 @@ def has_flashinfer_all2all() -> bool: @functools.cache -def has_flashinfer_moe_a2a() -> bool: +def has_flashinfer_nvlink_one_sided() -> bool: """Return `True` if FlashInfer trtllm_moe_alltoall module is available.""" if not has_flashinfer_comm(): return False @@ -774,8 +774,8 @@ def should_use_flashinfer_for_blockscale_fp8_gemm( "autotune", "has_flashinfer_moe", "has_flashinfer_comm", - "has_flashinfer_all2all", - "has_flashinfer_moe_a2a", + "has_flashinfer_nvlink_two_sided", + "has_flashinfer_nvlink_one_sided", "has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutedsl_grouped_gemm_nt_masked", "has_flashinfer_fp8_blockscale_gemm", From 3e7f513c76ad98f435da46c58962cfa5180b3981 Mon Sep 17 00:00:00 2001 From: Leo Tian Date: Thu, 12 Mar 2026 12:23:56 -0700 Subject: [PATCH 11/12] fix topk dtype issue Signed-off-by: Leo Tian --- .../layers/fused_moe/prepare_finalize/flashinfer_nvlink_1s.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_1s.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_1s.py index f1282a27978e..bdde3da6b3a3 100644 --- 
a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_1s.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_1s.py @@ -54,7 +54,7 @@ def output_is_reduced(self) -> bool: return False def topk_indices_dtype(self) -> torch.dtype | None: - return None + return torch.int32 def prepare( self, From d7930ed4baf6f149783d0717c722fab498bf85c5 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Mon, 16 Mar 2026 03:28:43 +0000 Subject: [PATCH 12/12] Rename one-sided prepare finalize file name Signed-off-by: wzhao18 --- docs/design/moe_kernel_features.md | 1 + .../moe/modular_kernel_tools/mk_objects.py | 4 +- .../moe/test_flashinfer_nvlink_one_sided.py | 144 ------------------ .../layers/fused_moe/all2all_utils.py | 8 +- ...nfer_nvlink_one_sided_prepare_finalize.py} | 0 ...nfer_nvlink_two_sided_prepare_finalize.py} | 0 6 files changed, 7 insertions(+), 150 deletions(-) delete mode 100644 tests/kernels/moe/test_flashinfer_nvlink_one_sided.py rename vllm/model_executor/layers/fused_moe/{prepare_finalize/flashinfer_nvlink_1s.py => flashinfer_nvlink_one_sided_prepare_finalize.py} (100%) rename vllm/model_executor/layers/fused_moe/{flashinfer_nvlink_2s_prepare_finalize.py => flashinfer_nvlink_two_sided_prepare_finalize.py} (100%) diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index fa5791fbe926..ea8956e204a5 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -36,6 +36,7 @@ th { | deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] | | deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] | | flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | 
[`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize.FlashInferNVLinkTwoSidedPrepareAndFinalize] | +| flashinfer_nvlink_one_sided | standard | nvfp4 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize.FlashInferNVLinkOneSidedPrepareAndFinalize] | !!! info "Table key" 1. All types: mxfp4, nvfp4, int4, int8, fp8 diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index b5d29824b0a3..68cf07d7cf51 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -240,7 +240,7 @@ def expert_info(kind) -> ExpertInfo: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, ) - from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_2s_prepare_finalize import ( # noqa: E501 + from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize import ( # noqa: E501 FlashInferNVLinkTwoSidedPrepareAndFinalize, ) @@ -271,7 +271,7 @@ def expert_info(kind) -> ExpertInfo: and has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100) ): - from vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_1s import ( # noqa: E501 + from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize import ( # noqa: E501 FlashInferNVLinkOneSidedPrepareAndFinalize, ) diff --git a/tests/kernels/moe/test_flashinfer_nvlink_one_sided.py b/tests/kernels/moe/test_flashinfer_nvlink_one_sided.py deleted file mode 100644 index cc6321f0163f..000000000000 --- a/tests/kernels/moe/test_flashinfer_nvlink_one_sided.py +++ /dev/null @@ -1,144 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Tests for FlashInfer 
MoeAlltoAll/One-sided NVLink (trtllm_moe_alltoall) kernel backend. - -Validates the _supports_parallel_config incompatibility matrix to ensure -each Expert backend correctly accepts or rejects the flashinfer_nvlink_one_sided -parallel configuration. No GPU required. - -See also: - - mk_objects.py for combinatorial registration of the new P/F and Experts -""" - -import importlib - -import pytest - -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEParallelConfig, -) - - -def _make_parallel_config(all2all_backend: str) -> FusedMoEParallelConfig: - """Create a FusedMoEParallelConfig with EP enabled for the given backend.""" - return FusedMoEParallelConfig( - tp_size=1, - pcp_size=1, - dp_size=2, - ep_size=2, - tp_rank=0, - pcp_rank=0, - dp_rank=0, - ep_rank=0, - sp_size=1, - use_ep=True, - all2all_backend=all2all_backend, - enable_eplb=False, - ) - - -def _import_expert_cls(module_path: str, class_name: str, skip_reason: str | None): - """Import an Expert class, skipping the test if unavailable.""" - try: - mod = importlib.import_module(module_path) - return getattr(mod, class_name) - except (ImportError, AttributeError): - if skip_reason: - pytest.skip(skip_reason) - raise - - -# (module_path, class_name, supports_flashinfer_nvlink_one_sided, skip_reason) -_EXPERT_COMPAT_CASES = [ - # Backends that reject flashinfer_nvlink_one_sided (Standard format, no all2allv) - ( - "vllm.model_executor.layers.fused_moe.fused_moe", - "TritonExperts", - False, - None, - ), - ( - "vllm.model_executor.layers.fused_moe.deep_gemm_moe", - "DeepGemmExperts", - False, - "requires deep_gemm", - ), - ( - "vllm.model_executor.layers.fused_moe.fused_marlin_moe", - "MarlinExperts", - False, - None, - ), - ( - "vllm.model_executor.layers.fused_moe.cutlass_moe", - "CutlassExpertsFp8", - False, - "requires cutlass_fp8", - ), - # Backends that accept flashinfer_nvlink_one_sided - ( - "vllm.model_executor.layers.fused_moe.fused_batched_moe", - "BatchedTritonExperts", - True, - 
None, - ), - ( - "vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe", - "FlashInferExperts", - True, - "requires flashinfer_cutlass on Blackwell", - ), - ( - "vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe", - "TrtLlmNvFp4ExpertsModular", - True, - "requires flashinfer trtllm", - ), -] - - -@pytest.mark.parametrize( - "module_path,class_name,expected_support,skip_reason", - _EXPERT_COMPAT_CASES, - ids=[c[1] for c in _EXPERT_COMPAT_CASES], -) -def test_supports_parallel_config_flashinfer_nvlink_one_sided( - module_path: str, - class_name: str, - expected_support: bool, - skip_reason: str | None, -): - """Verify _supports_parallel_config for the flashinfer_nvlink_one_sided backend.""" - cls = _import_expert_cls(module_path, class_name, skip_reason) - config = _make_parallel_config("flashinfer_nvlink_one_sided") - result = cls._supports_parallel_config(config) - assert result == expected_support, ( - f"{class_name}._supports_parallel_config('flashinfer_nvlink_one_sided') " - f"returned {result}, expected {expected_support}" - ) - - -@pytest.mark.parametrize( - "module_path,class_name,expected_support,skip_reason", - _EXPERT_COMPAT_CASES, - ids=[c[1] for c in _EXPERT_COMPAT_CASES], -) -def test_supports_parallel_config_parity_with_all2allv( - module_path: str, - class_name: str, - expected_support: bool, - skip_reason: str | None, -): - """Verify flashinfer_nvlink_one_sided and flashinfer_all2allv share the same - incompatibility matrix (both reject and accept the same Expert backends). - """ - cls = _import_expert_cls(module_path, class_name, skip_reason) - config = _make_parallel_config("flashinfer_all2allv") - result = cls._supports_parallel_config(config) - assert result == expected_support, ( - f"{class_name}._supports_parallel_config('flashinfer_all2allv') " - f"returned {result}, expected {expected_support}. " - f"flashinfer_nvlink_one_sided and flashinfer_all2allv should share the same " - f"incompatibility matrix." 
- ) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index be4831e10b22..4498a8a9306c 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -15,7 +15,10 @@ FusedMoEParallelConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_2s_prepare_finalize import ( +from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize import ( # noqa: E501 + FlashInferNVLinkOneSidedPrepareAndFinalize, +) +from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize import ( # noqa: E501 FlashInferNVLinkTwoSidedPrepareAndFinalize, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -25,9 +28,6 @@ make_moe_prepare_and_finalize_naive_dp_ep, make_moe_prepare_and_finalize_no_dp_ep, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_1s import ( - FlashInferNVLinkOneSidedPrepareAndFinalize, -) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_ep, has_mori, has_nixl_ep diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_1s.py b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_one_sided_prepare_finalize.py similarity index 100% rename from vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_1s.py rename to vllm/model_executor/layers/fused_moe/flashinfer_nvlink_one_sided_prepare_finalize.py diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_2s_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_two_sided_prepare_finalize.py similarity index 100% rename from vllm/model_executor/layers/fused_moe/flashinfer_nvlink_2s_prepare_finalize.py rename to vllm/model_executor/layers/fused_moe/flashinfer_nvlink_two_sided_prepare_finalize.py