diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 9c19456f1287..ea8956e204a5 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -35,7 +35,8 @@ th { | naive | standard | all1 | G,A,T | N | 6 | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE] | | deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] | | deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] | -| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] | +| flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize.FlashInferNVLinkTwoSidedPrepareAndFinalize] | +| flashinfer_nvlink_one_sided | standard | nvfp4 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize.FlashInferNVLinkOneSidedPrepareAndFinalize] | !!! info "Table key" 1. All types: mxfp4, nvfp4, int4, int8, fp8 diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index cfad36c2d914..3b13872a23b8 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -21,7 +21,8 @@ vLLM provides multiple communication backends for EP. 
Use `--all2all-backend` to | `allgather_reducescatter` | Default backend | Standard all2all using allgather/reducescatter primitives | General purpose, works with any EP+DP configuration | | `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout, optimized for prefill | Prefill-dominated workloads, high-throughput scenarios | | `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios | -| `flashinfer_all2allv` | MNNVL systems | FlashInfer alltoallv kernels for multi-node NVLink | Systems with NVLink across nodes | +| `flashinfer_nvlink_one_sided` | MNNVL systems | FlashInfer's one-sided A2A strategy for multi-node NVLink | High-throughput workloads | +| `flashinfer_nvlink_two_sided` | MNNVL systems | FlashInfer's two-sided A2A strategy for multi-node NVLink | Systems with NVLink across nodes | | `naive` | Testing/debugging | Simple broadcast-based implementation | Debugging, not recommended for production | ## Single Node Deployment diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 38a9857ccfed..68cf07d7cf51 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -33,7 +33,10 @@ ) from vllm.platforms import current_platform from vllm.utils.deep_gemm import is_deep_gemm_supported -from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.flashinfer import ( + has_flashinfer_cutlass_fused_moe, + has_flashinfer_nvlink_one_sided, +) from vllm.utils.import_utils import ( has_aiter, has_deep_ep, @@ -234,15 +237,15 @@ def expert_info(kind) -> ExpertInfo: ) if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100): - from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501 - FlashInferA2APrepareAndFinalize, - ) 
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, ) + from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize import ( # noqa: E501 + FlashInferNVLinkTwoSidedPrepareAndFinalize, + ) register_prepare_and_finalize( - FlashInferA2APrepareAndFinalize, + FlashInferNVLinkTwoSidedPrepareAndFinalize, standard_format, nvfp4_types + fp8_types, blocked_quantization_support=True, @@ -263,6 +266,36 @@ def expert_info(kind) -> ExpertInfo: FlashInferCutlassMoEPrepareAndFinalize = None FlashInferExperts = None +if ( + has_flashinfer_nvlink_one_sided() + and has_flashinfer_cutlass_fused_moe() + and current_platform.has_device_capability(100) +): + from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize import ( # noqa: E501 + FlashInferNVLinkOneSidedPrepareAndFinalize, + ) + + register_prepare_and_finalize( + FlashInferNVLinkOneSidedPrepareAndFinalize, + standard_format, + nvfp4_types, + blocked_quantization_support=False, + backend="flashinfer_nvlink_one_sided", + supports_apply_weight_on_input=False, + ) + +if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100): + from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import ( + TrtLlmNvFp4ExpertsModular, + ) + + register_experts( + TrtLlmNvFp4ExpertsModular, + standard_format, + nvfp4_types, + blocked_quantization_support=False, + supports_expert_map=True, + ) if has_aiter(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index fcad56133325..f7f952af66e1 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -45,7 +45,9 @@ "mori", "nixl_ep", "allgather_reducescatter", - "flashinfer_all2allv", + "flashinfer_all2allv", # temporary alias for flashinfer_nvlink_two_sided + "flashinfer_nvlink_two_sided", + "flashinfer_nvlink_one_sided", ] @@ -158,7 +160,8 @@ class ParallelConfig: - 
"deepep_low_latency": Use deepep low-latency kernels\n - "mori": Use mori kernels\n - "nixl_ep": Use nixl-ep kernels\n - - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl""" + - "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl + - "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels""" max_parallel_loading_workers: int | None = None """Maximum number of parallel loading workers when loading model diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index de5c5a79c15c..0cdff90320da 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -4,23 +4,36 @@ from typing import Any import torch +import torch.distributed as dist import vllm.envs as envs from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.utils.flashinfer import has_flashinfer_all2all +from vllm.utils.flashinfer import ( + has_flashinfer_nvlink_one_sided, + has_flashinfer_nvlink_two_sided, +) from vllm.utils.import_utils import has_deep_ep, has_mori from .base_device_communicator import All2AllManagerBase, Cache -if has_flashinfer_all2all(): +if has_flashinfer_nvlink_two_sided(): from flashinfer.comm import Mapping # type: ignore[import-not-found] from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found] from flashinfer.comm.trtllm_alltoall import ( MnnvlMoe, # type: ignore[import-not-found] ) +if has_flashinfer_nvlink_one_sided(): + from flashinfer.comm import Mapping # type: ignore[import-not-found] + from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found] + from flashinfer.comm.trtllm_moe_alltoall import ( + MoeAlltoAll, # type: ignore[import-not-found] + moe_a2a_get_workspace_size_per_rank, + ) + + logger = init_logger(__name__) @@ -529,9 +542,9 @@ def max_sms_used(self) -> 
int | None: return 0 -class FlashInferAllToAllManager(All2AllManagerBase): +class FlashInferNVLinkTwoSidedManager(All2AllManagerBase): """ - All2All communication based on flashinfer kernels. + All2All communication based on flashinfer all2allv/two-sided NVLink kernels. """ # This type lint could be removed after all of the work in @@ -540,7 +553,7 @@ class FlashInferAllToAllManager(All2AllManagerBase): world_size: int def __init__(self, cpu_group, tcp_store_group=None): - assert has_flashinfer_all2all(), ( + assert has_flashinfer_nvlink_two_sided(), ( "flashinfer all2all module not found. Please install/check flashinfer" ) # noqa super().__init__(cpu_group, tcp_store_group) @@ -597,7 +610,7 @@ def initialize( def ensure_alltoall_workspace_initialized(self): """Ensure workspace is initialized""" - if not has_flashinfer_all2all(): + if not has_flashinfer_nvlink_two_sided(): return False if self.world_size <= 1: @@ -633,6 +646,119 @@ def cleanup(self): self.initialized = False +class FlashInferNVLinkOneSidedManager(All2AllManagerBase): + """ + All2All communication based on FlashInfer's MoeAlltoAll/One-sided NVLink kernel. + This is a newer kernel from trtllm that should perform better than the kernel + used by flashinfer_nvlink_two_sided. + """ + + rank: int + world_size: int + + def __init__(self, cpu_group): + assert has_flashinfer_nvlink_one_sided(), ( + "flashinfer trtllm_moe_alltoall module not found. 
" + "Please install/check flashinfer" + ) + super().__init__(cpu_group) + logger.debug( + "Initialize FlashInfer One-sided NVLink rank=%d, world size=%d", + self.rank, + self.world_size, + ) + self.initialized = False + self.moe_alltoall: MoeAlltoAll | None = None + self.mapping = None + + def initialize( + self, + max_num_tokens: int, + top_k: int, + num_experts: int, + hidden_size: int, + ): + """Initialize the MoeAlltoAll workspace.""" + if self.initialized: + return + + self.cleanup() + gpus_per_node = torch.accelerator.device_count() + logger.debug( + "Making One-sided NVLink mapping: rank=%d, world size=%d", + self.rank, + self.world_size, + ) + self.mapping = Mapping( + self.world_size, + self.rank, + gpus_per_node, + tp_size=self.world_size, + moe_ep_size=self.world_size, + ) + + from vllm.distributed.device_communicators.mnnvl_compat import ( + CustomCommunicator, + ) + + dp_config = MnnvlConfig( + comm_backend=CustomCommunicator(get_dp_group().cpu_group), + ) + total_dispatch_payload_size_per_token = ( + hidden_size // 2 # nvfp4 hidden states + + hidden_size // 16 # fp8 scaling factors + + top_k * 4 # int32 topks ids + + top_k * 4 # float32 topk weights + ) + combine_payload_size_per_token = hidden_size * 2 # bf16 hidden states + self.workspace_size = moe_a2a_get_workspace_size_per_rank( + ep_size=self.world_size, + max_num_tokens=max_num_tokens, + total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token, + combine_payload_size_per_token=combine_payload_size_per_token, + ) + + self.moe_alltoall = MoeAlltoAll( + mapping=self.mapping, + max_num_tokens=max_num_tokens, + top_k=top_k, + num_experts=num_experts, + workspace_size_per_rank=self.workspace_size, + mnnvl_config=dp_config, + ) + + self.gpus_per_node = gpus_per_node + self.max_num_tokens = max_num_tokens + self.top_k = top_k + self.num_experts = num_experts + self.hidden_size = hidden_size + self.initialized = True + + logger.info( + "FlashInfer One-sided NVLink initialized for rank 
%s, size %s",
+            self.rank,
+            self.world_size,
+        )
+        dist.barrier()
+
+    def get_handle(self, kwargs):
+        return self
+
+    def cleanup(self):
+        """Clean up resources."""
+        if self.initialized and self.moe_alltoall is not None:
+            try:
+                del self.moe_alltoall
+            except Exception as e:
+                logger.warning(
+                    "Failed to cleanup FlashInfer One-sided NVLink workspace: %s", e
+                )
+            finally:
+                self.moe_alltoall = None
+                self.mapping = None
+                self.initialized = False
+
+
 class MoriAll2AllManager(All2AllManagerBase):
     def __init__(self, cpu_group):
         assert has_mori(), (
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index faa3d093ad2d..bd5741e8dc72 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -149,12 +149,25 @@ def __init__(
             self.all2all_manager = NixlEPAll2AllManager(
                 self.cpu_group, tcp_store_group
             )
-        elif self.all2all_backend == "flashinfer_all2allv":
-            from .all2all import FlashInferAllToAllManager
-
-            self.all2all_manager = FlashInferAllToAllManager(
+        elif (
+            self.all2all_backend == "flashinfer_all2allv"
+            or self.all2all_backend == "flashinfer_nvlink_two_sided"
+        ):
+            if self.all2all_backend == "flashinfer_all2allv":
+                logger.warning_once(
+                    "'flashinfer_all2allv' is deprecated and has been renamed to "
+                    "'flashinfer_nvlink_two_sided'. It will be removed in a future "
+                    "release."
+ ) + from .all2all import FlashInferNVLinkTwoSidedManager + + self.all2all_manager = FlashInferNVLinkTwoSidedManager( self.cpu_group, tcp_store_group ) + elif self.all2all_backend == "flashinfer_nvlink_one_sided": + from .all2all import FlashInferNVLinkOneSidedManager + + self.all2all_manager = FlashInferNVLinkOneSidedManager(self.cpu_group) else: raise ValueError(f"Unknown all2all backend: {self.all2all_backend}") diff --git a/vllm/distributed/device_communicators/mnnvl_compat.py b/vllm/distributed/device_communicators/mnnvl_compat.py index 81f4ae20738d..2a431ad15f3f 100644 --- a/vllm/distributed/device_communicators/mnnvl_compat.py +++ b/vllm/distributed/device_communicators/mnnvl_compat.py @@ -5,9 +5,9 @@ import torch.distributed as dist from flashinfer.comm.mnnvl import CommBackend as CommBackend -from vllm.utils.flashinfer import has_flashinfer_all2all +from vllm.utils.flashinfer import has_flashinfer_nvlink_two_sided -assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found" +assert has_flashinfer_nvlink_two_sided(), "Flashinfer alltoallv module cannot be found" class CustomCommunicator(CommBackend): @@ -25,14 +25,14 @@ def allgather(self, data: int): dist.all_gather_object(gathered, data, group=self._group) return gathered - # NOTE(rob): CommBackend is an abstract class, and bcast/barrier - # are unimplemented on vLLM side. If we need to utilize these - # methods in the future, can create a concrete implementation. 
def bcast(self, data: Any, root: int) -> Any: - raise NotImplementedError + obj_list = [data] + # broadcast_object_list mutates obj_list in-place + dist.broadcast_object_list(obj_list, src=root, group=self._group) + return obj_list[0] def barrier(self) -> None: - raise NotImplementedError + dist.barrier(group=self._group) def Split(self, color: int, key: int) -> "CustomCommunicator": return self diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 4d215645ecd4..4498a8a9306c 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -5,6 +5,7 @@ import torch +from vllm.config import get_current_vllm_config from vllm.distributed import ( get_ep_group, ) @@ -14,8 +15,11 @@ FusedMoEParallelConfig, FusedMoEQuantConfig, ) -from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( - FlashInferA2APrepareAndFinalize, +from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize import ( # noqa: E501 + FlashInferNVLinkOneSidedPrepareAndFinalize, +) +from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize import ( # noqa: E501 + FlashInferNVLinkTwoSidedPrepareAndFinalize, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPrepareAndFinalize, @@ -206,9 +210,22 @@ def maybe_make_prepare_finalize( use_fp8_dispatch=use_fp8_dispatch, ) - elif moe.use_fi_all2allv_kernels: + elif moe.use_fi_nvl_two_sided_kernels: + assert quant_config is not None + prepare_finalize = FlashInferNVLinkTwoSidedPrepareAndFinalize( + num_dispatchers=all2all_manager.world_size, + ) + + elif moe.use_fi_nvl_one_sided_kernels: assert quant_config is not None - prepare_finalize = FlashInferA2APrepareAndFinalize( + max_num_tokens = ( + get_current_vllm_config().scheduler_config.max_num_batched_tokens + ) + prepare_finalize = 
FlashInferNVLinkOneSidedPrepareAndFinalize( + max_num_tokens=max_num_tokens, + top_k=moe.experts_per_token, + num_experts=moe.num_experts, + hidden_size=moe.hidden_dim, num_dispatchers=all2all_manager.world_size, ) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 57c787ca65a1..2500387debe1 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -957,9 +957,17 @@ def use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" @property - def use_fi_all2allv_kernels(self): + def use_fi_nvl_two_sided_kernels(self): + return self.use_all2all_kernels and ( + self.all2all_backend == "flashinfer_all2allv" + or self.all2all_backend == "flashinfer_nvlink_two_sided" + ) + + @property + def use_fi_nvl_one_sided_kernels(self): return ( - self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv" + self.use_all2all_kernels + and self.all2all_backend == "flashinfer_nvlink_one_sided" ) @property @@ -1240,8 +1248,12 @@ def use_mori_kernels(self): return self.moe_parallel_config.use_mori_kernels @property - def use_fi_all2allv_kernels(self): - return self.moe_parallel_config.use_fi_all2allv_kernels + def use_fi_nvl_two_sided_kernels(self): + return self.moe_parallel_config.use_fi_nvl_two_sided_kernels + + @property + def use_fi_nvl_one_sided_kernels(self): + return self.moe_parallel_config.use_fi_nvl_one_sided_kernels @property def use_naive_all2all_kernels(self): diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 69a30f89ef72..51a97e0a2610 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -396,8 +396,9 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo # Note that the BATCHED activation format does not use # the 
expert map for identifying experts. return not ( - moe_parallel_config.use_fi_all2allv_kernels + moe_parallel_config.use_fi_nvl_two_sided_kernels or moe_parallel_config.use_deepep_ht_kernels + or moe_parallel_config.use_fi_nvl_one_sided_kernels ) def supports_expert_map(self) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 18b3da34422e..03341378a13c 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -152,7 +152,10 @@ def _supports_activation(activation: MoEActivation) -> bool: @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: # NOTE(rob): discovered an IMA with this combination. Needs investigation. - return not moe_parallel_config.use_fi_all2allv_kernels + return not ( + moe_parallel_config.use_fi_nvl_two_sided_kernels + or moe_parallel_config.use_fi_nvl_one_sided_kernels + ) def supports_expert_map(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_one_sided_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_one_sided_prepare_finalize.py new file mode 100644 index 000000000000..bdde3da6b3a3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_one_sided_prepare_finalize.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_ep_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.utils.flashinfer import nvfp4_block_scale_interleave + + +def get_local_sizes(): + return 
get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() + + +class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): + """FlashInfer implementation using the Moe AlltoAll kernel.""" + + def __init__( + self, + max_num_tokens: int, + top_k: int, + num_experts: int, + hidden_size: int, + num_dispatchers: int = 1, + ): + super().__init__() + self.max_num_tokens = max_num_tokens + self.top_k = top_k + self.num_experts = num_experts + self.hidden_size = hidden_size + self.num_dispatchers_ = num_dispatchers + + self.all2all_manager = get_ep_group().device_communicator.all2all_manager + self.all2all_manager.initialize( + max_num_tokens=self.max_num_tokens, + top_k=self.top_k, + num_experts=self.num_experts, + hidden_size=self.hidden_size, + ) + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + def output_is_reduced(self) -> bool: + return False + + def topk_indices_dtype(self) -> torch.dtype | None: + return torch.int32 + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareResultType: + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + a1.mul_(topk_weights.to(a1.dtype)) + + global_num_tokens_cpu = get_local_sizes() + self.runtime_max_tokens_per_rank = ( + max(global_num_tokens_cpu) + if global_num_tokens_cpu is not None + else a1.shape[0] + ) + + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + 
quant_config.block_shape, + is_fp4_scale_swizzled=False, # delay swizzle to after comm + ) + + payloads = [] + payloads.append(a1q) + if a1q_scale is not None: + payloads.append(a1q_scale) + payloads.append(topk_ids) + payloads.append(topk_weights) + + recv_payloads = self.all2all_manager.moe_alltoall.dispatch( + token_selected_experts=topk_ids, + input_payloads=payloads, + runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, + ) + if a1q_scale is not None: + a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads + # Apply scale interleaving only for CUTLASS (not TRT-LLM) + if ( + quant_config.quant_dtype == "nvfp4" + and quant_config.is_nvfp4_scale_swizzled + ): + a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1]) + a1q_scale_recv = a1q_scale_recv.view(torch.uint8) + a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv) + a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16) + else: + a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads + a1q_scale_recv = None + a1q_recv = a1q_recv.view(-1, a1q_recv.shape[-1]) + topk_ids_recv = topk_ids_recv.view(-1, topk_ids_recv.shape[-1]) + topk_weights_recv = topk_weights_recv.view(-1, topk_weights_recv.shape[-1]) + + return a1q_recv, a1q_scale_recv, None, topk_ids_recv, topk_weights_recv + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + assert self.all2all_manager.moe_alltoall is not None + + ep_size = self.all2all_manager.world_size + hidden_size = fused_expert_output.shape[-1] + fused_expert_output = fused_expert_output.view( + ep_size, self.runtime_max_tokens_per_rank, hidden_size + ) + + combined_output = self.all2all_manager.moe_alltoall.combine( + payload=fused_expert_output, + runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, + ) + 
output.copy_(combined_output) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_two_sided_prepare_finalize.py similarity index 98% rename from vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py rename to vllm/model_executor/layers/fused_moe/flashinfer_nvlink_two_sided_prepare_finalize.py index 465d0ae8f2c4..be63bd4e3f61 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_two_sided_prepare_finalize.py @@ -18,7 +18,7 @@ def get_local_sizes(): return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() -class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): +class FlashInferNVLinkTwoSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """Base class for FlashInfer MoE prepare and finalize operations.""" def __init__( diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 86fef2528345..45575ab09c40 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -600,7 +600,10 @@ def _supports_activation(activation: MoEActivation) -> bool: @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return not moe_parallel_config.use_fi_all2allv_kernels + return not ( + moe_parallel_config.use_fi_nvl_two_sided_kernels + or moe_parallel_config.use_fi_nvl_one_sided_kernels + ) @property def quant_type_id(self) -> int: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 70adac711f5a..03ca8ba119c0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1965,7 +1965,10 @@ def 
_supports_activation(activation: MoEActivation) -> bool: @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return not moe_parallel_config.use_fi_all2allv_kernels + return not ( + moe_parallel_config.use_fi_nvl_two_sided_kernels + or moe_parallel_config.use_fi_nvl_one_sided_kernels + ) def supports_expert_map(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index fd759f22b1ff..7135cbbd2d7c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -638,7 +638,7 @@ def _get_quant_method() -> FusedMoEMethodBase: self.use_overlapped = ( not ( (self.enable_eplb and backend != "allgather_reducescatter") - or self.moe_parallel_config.use_fi_all2allv_kernels + or self.moe_parallel_config.use_fi_nvl_two_sided_kernels ) and self._shared_experts is not None ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 6d178d587c69..b1a4b0d59d2b 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -332,7 +332,10 @@ def _supports_activation(activation: MoEActivation) -> bool: @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return not moe_parallel_config.use_fi_all2allv_kernels + return not ( + moe_parallel_config.use_fi_nvl_two_sided_kernels + or moe_parallel_config.use_fi_nvl_one_sided_kernels + ) def supports_expert_map(self): return True diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py index d3c950dcbb33..b6313776e85d 100644 --- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py @@ 
-233,7 +233,7 @@ def use_dp_chunking(self) -> bool: return ( self.moe_config.moe_parallel_config.use_deepep_ll_kernels or self.moe_config.moe_parallel_config.use_mori_kernels - or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels + or self.moe_config.moe_parallel_config.use_fi_nvl_two_sided_kernels or self.moe_config.moe_parallel_config.use_nixl_ep_kernels ) and envs.VLLM_ENABLE_MOE_DP_CHUNK diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index c3ac839c21d1..fed44d04fb5e 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -150,7 +150,7 @@ def has_flashinfer_comm() -> bool: @functools.cache -def has_flashinfer_all2all() -> bool: +def has_flashinfer_nvlink_two_sided() -> bool: """Return `True` if FlashInfer mnnvl all2all is available.""" if not has_flashinfer_comm(): return False @@ -170,6 +170,14 @@ def has_flashinfer_all2all() -> bool: return True +@functools.cache +def has_flashinfer_nvlink_one_sided() -> bool: + """Return `True` if FlashInfer trtllm_moe_alltoall module is available.""" + if not has_flashinfer_comm(): + return False + return importlib.util.find_spec("flashinfer.comm.trtllm_moe_alltoall") is not None + + @functools.cache def has_flashinfer_moe() -> bool: """Return `True` if FlashInfer MoE module is available.""" @@ -766,7 +774,8 @@ def should_use_flashinfer_for_blockscale_fp8_gemm( "autotune", "has_flashinfer_moe", "has_flashinfer_comm", - "has_flashinfer_all2all", + "has_flashinfer_nvlink_two_sided", + "has_flashinfer_nvlink_one_sided", "has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutedsl_grouped_gemm_nt_masked", "has_flashinfer_fp8_blockscale_gemm",