diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 9c19456f1287..ea8956e204a5 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -35,7 +35,8 @@ th {
| naive | standard | all1 | G,A,T | N | 6 | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE] |
| deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
-| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] |
+| flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize.FlashInferNVLinkTwoSidedPrepareAndFinalize] |
+| flashinfer_nvlink_one_sided | standard | nvfp4 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize.FlashInferNVLinkOneSidedPrepareAndFinalize] |
!!! info "Table key"
1. All types: mxfp4, nvfp4, int4, int8, fp8
diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md
index cfad36c2d914..3b13872a23b8 100644
--- a/docs/serving/expert_parallel_deployment.md
+++ b/docs/serving/expert_parallel_deployment.md
@@ -21,7 +21,8 @@ vLLM provides multiple communication backends for EP. Use `--all2all-backend` to
| `allgather_reducescatter` | Default backend | Standard all2all using allgather/reducescatter primitives | General purpose, works with any EP+DP configuration |
| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout, optimized for prefill | Prefill-dominated workloads, high-throughput scenarios |
| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios |
-| `flashinfer_all2allv` | MNNVL systems | FlashInfer alltoallv kernels for multi-node NVLink | Systems with NVLink across nodes |
+| `flashinfer_nvlink_one_sided` | MNNVL systems | FlashInfer's one-sided A2A strategy for multi-node NVLink | High-throughput workloads |
+| `flashinfer_nvlink_two_sided` | MNNVL systems | FlashInfer's two-sided A2A strategy for multi-node NVLink | Systems with NVLink across nodes |
| `naive` | Testing/debugging | Simple broadcast-based implementation | Debugging, not recommended for production |
## Single Node Deployment
diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
index 38a9857ccfed..68cf07d7cf51 100644
--- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py
+++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
@@ -33,7 +33,10 @@
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
-from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+from vllm.utils.flashinfer import (
+ has_flashinfer_cutlass_fused_moe,
+ has_flashinfer_nvlink_one_sided,
+)
from vllm.utils.import_utils import (
has_aiter,
has_deep_ep,
@@ -234,15 +237,15 @@ def expert_info(kind) -> ExpertInfo:
)
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
- from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
- FlashInferA2APrepareAndFinalize,
- )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
+ from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize import ( # noqa: E501
+ FlashInferNVLinkTwoSidedPrepareAndFinalize,
+ )
register_prepare_and_finalize(
- FlashInferA2APrepareAndFinalize,
+ FlashInferNVLinkTwoSidedPrepareAndFinalize,
standard_format,
nvfp4_types + fp8_types,
blocked_quantization_support=True,
@@ -263,6 +266,36 @@ def expert_info(kind) -> ExpertInfo:
FlashInferCutlassMoEPrepareAndFinalize = None
FlashInferExperts = None
+if (
+ has_flashinfer_nvlink_one_sided()
+ and has_flashinfer_cutlass_fused_moe()
+ and current_platform.has_device_capability(100)
+):
+ from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize import ( # noqa: E501
+ FlashInferNVLinkOneSidedPrepareAndFinalize,
+ )
+
+ register_prepare_and_finalize(
+ FlashInferNVLinkOneSidedPrepareAndFinalize,
+ standard_format,
+ nvfp4_types,
+ blocked_quantization_support=False,
+ backend="flashinfer_nvlink_one_sided",
+ supports_apply_weight_on_input=False,
+ )
+
+if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
+ from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
+ TrtLlmNvFp4ExpertsModular,
+ )
+
+ register_experts(
+ TrtLlmNvFp4ExpertsModular,
+ standard_format,
+ nvfp4_types,
+ blocked_quantization_support=False,
+ supports_expert_map=True,
+ )
if has_aiter():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index fcad56133325..f7f952af66e1 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -45,7 +45,9 @@
"mori",
"nixl_ep",
"allgather_reducescatter",
- "flashinfer_all2allv",
+ "flashinfer_all2allv", # temporary alias for flashinfer_nvlink_two_sided
+ "flashinfer_nvlink_two_sided",
+ "flashinfer_nvlink_one_sided",
]
@@ -158,7 +160,8 @@ class ParallelConfig:
- "deepep_low_latency": Use deepep low-latency kernels\n
- "mori": Use mori kernels\n
- "nixl_ep": Use nixl-ep kernels\n
- - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
+ - "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
+ - "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
max_parallel_loading_workers: int | None = None
"""Maximum number of parallel loading workers when loading model
diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index de5c5a79c15c..0cdff90320da 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -4,23 +4,36 @@
from typing import Any
import torch
+import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
-from vllm.utils.flashinfer import has_flashinfer_all2all
+from vllm.utils.flashinfer import (
+ has_flashinfer_nvlink_one_sided,
+ has_flashinfer_nvlink_two_sided,
+)
from vllm.utils.import_utils import has_deep_ep, has_mori
from .base_device_communicator import All2AllManagerBase, Cache
-if has_flashinfer_all2all():
+if has_flashinfer_nvlink_two_sided():
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_alltoall import (
MnnvlMoe, # type: ignore[import-not-found]
)
+if has_flashinfer_nvlink_one_sided():
+ from flashinfer.comm import Mapping # type: ignore[import-not-found]
+ from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
+ from flashinfer.comm.trtllm_moe_alltoall import (
+ MoeAlltoAll, # type: ignore[import-not-found]
+ moe_a2a_get_workspace_size_per_rank,
+ )
+
+
logger = init_logger(__name__)
@@ -529,9 +542,9 @@ def max_sms_used(self) -> int | None:
return 0
-class FlashInferAllToAllManager(All2AllManagerBase):
+class FlashInferNVLinkTwoSidedManager(All2AllManagerBase):
"""
- All2All communication based on flashinfer kernels.
+ All2All communication based on flashinfer all2allv/two-sided NVLink kernels.
"""
# This type lint could be removed after all of the work in
@@ -540,7 +553,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
world_size: int
def __init__(self, cpu_group, tcp_store_group=None):
- assert has_flashinfer_all2all(), (
+ assert has_flashinfer_nvlink_two_sided(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group, tcp_store_group)
@@ -597,7 +610,7 @@ def initialize(
def ensure_alltoall_workspace_initialized(self):
"""Ensure workspace is initialized"""
- if not has_flashinfer_all2all():
+ if not has_flashinfer_nvlink_two_sided():
return False
if self.world_size <= 1:
@@ -633,6 +646,119 @@ def cleanup(self):
self.initialized = False
+class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
+ """
+ All2All communication based on FlashInfer's MoeAlltoAll/One-sided NVLink kernel.
+ This is a newer kernel from trtllm that should perform better than the kernel
+ used by flashinfer_nvlink_two_sided.
+ """
+
+ rank: int
+ world_size: int
+
+ def __init__(self, cpu_group):
+ assert has_flashinfer_nvlink_one_sided(), (
+ "flashinfer trtllm_moe_alltoall module not found. "
+ "Please install/check flashinfer"
+ )
+ super().__init__(cpu_group)
+ logger.debug(
+ "Initialize FlashInfer One-sided NVLink rank=%d, world size=%d",
+ self.rank,
+ self.world_size,
+ )
+ self.initialized = False
+ self.moe_alltoall: MoeAlltoAll | None = None
+ self.mapping = None
+
+ def initialize(
+ self,
+ max_num_tokens: int,
+ top_k: int,
+ num_experts: int,
+ hidden_size: int,
+ ):
+ """Initialize the MoeAlltoAll workspace."""
+ if self.initialized:
+ return
+
+ self.cleanup()
+ gpus_per_node = torch.accelerator.device_count()
+ logger.debug(
+ "Making One-sided NVLink mapping: rank=%d, world size=%d",
+ self.rank,
+ self.world_size,
+ )
+ self.mapping = Mapping(
+ self.world_size,
+ self.rank,
+ gpus_per_node,
+ tp_size=self.world_size,
+ moe_ep_size=self.world_size,
+ )
+
+ from vllm.distributed.device_communicators.mnnvl_compat import (
+ CustomCommunicator,
+ )
+
+ dp_config = MnnvlConfig(
+ comm_backend=CustomCommunicator(get_dp_group().cpu_group),
+ )
+ total_dispatch_payload_size_per_token = (
+ hidden_size // 2 # nvfp4 hidden states
+ + hidden_size // 16 # fp8 scaling factors
+            + top_k * 4  # int32 topk ids
+ + top_k * 4 # float32 topk weights
+ )
+ combine_payload_size_per_token = hidden_size * 2 # bf16 hidden states
+ self.workspace_size = moe_a2a_get_workspace_size_per_rank(
+ ep_size=self.world_size,
+ max_num_tokens=max_num_tokens,
+ total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token,
+ combine_payload_size_per_token=combine_payload_size_per_token,
+ )
+
+ self.moe_alltoall = MoeAlltoAll(
+ mapping=self.mapping,
+ max_num_tokens=max_num_tokens,
+ top_k=top_k,
+ num_experts=num_experts,
+ workspace_size_per_rank=self.workspace_size,
+ mnnvl_config=dp_config,
+ )
+
+ self.gpus_per_node = gpus_per_node
+ self.max_num_tokens = max_num_tokens
+ self.top_k = top_k
+ self.num_experts = num_experts
+ self.hidden_size = hidden_size
+ self.initialized = True
+
+ logger.info(
+ "FlashInfer One-sided NVLink initialized for rank %s, size %s",
+ self.rank,
+ self.world_size,
+ )
+ dist.barrier()
+
+ def get_handle(self, kwargs):
+ return self
+
+ def cleanup(self):
+ """Clean up resources."""
+ if self.initialized and self.moe_alltoall is not None:
+ try:
+ del self.moe_alltoall
+ except Exception as e:
+ logger.warning(
+ "Failed to cleanup FlashInfer One-sided NVLink workspace: %s", e
+ )
+ finally:
+ self.moe_alltoall = None
+ self.mapping = None
+ self.initialized = False
+
+
class MoriAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
assert has_mori(), (
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index faa3d093ad2d..bd5741e8dc72 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -149,12 +149,25 @@ def __init__(
self.all2all_manager = NixlEPAll2AllManager(
self.cpu_group, tcp_store_group
)
- elif self.all2all_backend == "flashinfer_all2allv":
- from .all2all import FlashInferAllToAllManager
-
- self.all2all_manager = FlashInferAllToAllManager(
+ elif (
+ self.all2all_backend == "flashinfer_all2allv"
+ or self.all2all_backend == "flashinfer_nvlink_two_sided"
+ ):
+ if self.all2all_backend == "flashinfer_all2allv":
+ logger.warning_once(
+                    "'flashinfer_all2allv' is deprecated and has been renamed to "
+                    "'flashinfer_nvlink_two_sided'. It will be removed in a future "
+                    "release."
+ )
+ from .all2all import FlashInferNVLinkTwoSidedManager
+
+ self.all2all_manager = FlashInferNVLinkTwoSidedManager(
self.cpu_group, tcp_store_group
)
+ elif self.all2all_backend == "flashinfer_nvlink_one_sided":
+ from .all2all import FlashInferNVLinkOneSidedManager
+
+ self.all2all_manager = FlashInferNVLinkOneSidedManager(self.cpu_group)
else:
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
diff --git a/vllm/distributed/device_communicators/mnnvl_compat.py b/vllm/distributed/device_communicators/mnnvl_compat.py
index 81f4ae20738d..2a431ad15f3f 100644
--- a/vllm/distributed/device_communicators/mnnvl_compat.py
+++ b/vllm/distributed/device_communicators/mnnvl_compat.py
@@ -5,9 +5,9 @@
import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend
-from vllm.utils.flashinfer import has_flashinfer_all2all
+from vllm.utils.flashinfer import has_flashinfer_nvlink_two_sided
-assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found"
+assert has_flashinfer_nvlink_two_sided(), "Flashinfer alltoallv module cannot be found"
class CustomCommunicator(CommBackend):
@@ -25,14 +25,14 @@ def allgather(self, data: int):
dist.all_gather_object(gathered, data, group=self._group)
return gathered
- # NOTE(rob): CommBackend is an abstract class, and bcast/barrier
- # are unimplemented on vLLM side. If we need to utilize these
- # methods in the future, can create a concrete implementation.
def bcast(self, data: Any, root: int) -> Any:
- raise NotImplementedError
+ obj_list = [data]
+ # broadcast_object_list mutates obj_list in-place
+ dist.broadcast_object_list(obj_list, src=root, group=self._group)
+ return obj_list[0]
def barrier(self) -> None:
- raise NotImplementedError
+ dist.barrier(group=self._group)
def Split(self, color: int, key: int) -> "CustomCommunicator":
return self
diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py
index 4d215645ecd4..4498a8a9306c 100644
--- a/vllm/model_executor/layers/fused_moe/all2all_utils.py
+++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -5,6 +5,7 @@
import torch
+from vllm.config import get_current_vllm_config
from vllm.distributed import (
get_ep_group,
)
@@ -14,8 +15,11 @@
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
-from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
- FlashInferA2APrepareAndFinalize,
+from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize import ( # noqa: E501
+ FlashInferNVLinkOneSidedPrepareAndFinalize,
+)
+from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize import ( # noqa: E501
+ FlashInferNVLinkTwoSidedPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
@@ -206,9 +210,22 @@ def maybe_make_prepare_finalize(
use_fp8_dispatch=use_fp8_dispatch,
)
- elif moe.use_fi_all2allv_kernels:
+ elif moe.use_fi_nvl_two_sided_kernels:
+ assert quant_config is not None
+ prepare_finalize = FlashInferNVLinkTwoSidedPrepareAndFinalize(
+ num_dispatchers=all2all_manager.world_size,
+ )
+
+ elif moe.use_fi_nvl_one_sided_kernels:
assert quant_config is not None
- prepare_finalize = FlashInferA2APrepareAndFinalize(
+ max_num_tokens = (
+ get_current_vllm_config().scheduler_config.max_num_batched_tokens
+ )
+ prepare_finalize = FlashInferNVLinkOneSidedPrepareAndFinalize(
+ max_num_tokens=max_num_tokens,
+ top_k=moe.experts_per_token,
+ num_experts=moe.num_experts,
+ hidden_size=moe.hidden_dim,
num_dispatchers=all2all_manager.world_size,
)
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 57c787ca65a1..2500387debe1 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -957,9 +957,17 @@ def use_deepep_ll_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@property
- def use_fi_all2allv_kernels(self):
+ def use_fi_nvl_two_sided_kernels(self):
+ return self.use_all2all_kernels and (
+ self.all2all_backend == "flashinfer_all2allv"
+ or self.all2all_backend == "flashinfer_nvlink_two_sided"
+ )
+
+ @property
+ def use_fi_nvl_one_sided_kernels(self):
return (
- self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv"
+ self.use_all2all_kernels
+ and self.all2all_backend == "flashinfer_nvlink_one_sided"
)
@property
@@ -1240,8 +1248,12 @@ def use_mori_kernels(self):
return self.moe_parallel_config.use_mori_kernels
@property
- def use_fi_all2allv_kernels(self):
- return self.moe_parallel_config.use_fi_all2allv_kernels
+ def use_fi_nvl_two_sided_kernels(self):
+ return self.moe_parallel_config.use_fi_nvl_two_sided_kernels
+
+ @property
+ def use_fi_nvl_one_sided_kernels(self):
+ return self.moe_parallel_config.use_fi_nvl_one_sided_kernels
@property
def use_naive_all2all_kernels(self):
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 69a30f89ef72..51a97e0a2610 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -396,8 +396,9 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
# Note that the BATCHED activation format does not use
# the expert map for identifying experts.
return not (
- moe_parallel_config.use_fi_all2allv_kernels
+ moe_parallel_config.use_fi_nvl_two_sided_kernels
or moe_parallel_config.use_deepep_ht_kernels
+ or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self) -> bool:
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index 18b3da34422e..03341378a13c 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -152,7 +152,10 @@ def _supports_activation(activation: MoEActivation) -> bool:
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# NOTE(rob): discovered an IMA with this combination. Needs investigation.
- return not moe_parallel_config.use_fi_all2allv_kernels
+ return not (
+ moe_parallel_config.use_fi_nvl_two_sided_kernels
+ or moe_parallel_config.use_fi_nvl_one_sided_kernels
+ )
def supports_expert_map(self) -> bool:
return True
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_one_sided_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_one_sided_prepare_finalize.py
new file mode 100644
index 000000000000..bdde3da6b3a3
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_one_sided_prepare_finalize.py
@@ -0,0 +1,146 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.distributed import get_ep_group
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
+from vllm.utils.flashinfer import nvfp4_block_scale_interleave
+
+
+def get_local_sizes():
+ return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+
+
+class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
+    """FlashInfer prepare/finalize using the one-sided MoeAlltoAll kernel."""
+
+ def __init__(
+ self,
+ max_num_tokens: int,
+ top_k: int,
+ num_experts: int,
+ hidden_size: int,
+ num_dispatchers: int = 1,
+ ):
+ super().__init__()
+ self.max_num_tokens = max_num_tokens
+ self.top_k = top_k
+ self.num_experts = num_experts
+ self.hidden_size = hidden_size
+ self.num_dispatchers_ = num_dispatchers
+
+ self.all2all_manager = get_ep_group().device_communicator.all2all_manager
+ self.all2all_manager.initialize(
+ max_num_tokens=self.max_num_tokens,
+ top_k=self.top_k,
+ num_experts=self.num_experts,
+ hidden_size=self.hidden_size,
+ )
+
+ @property
+ def activation_format(self) -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ def max_num_tokens_per_rank(self) -> int | None:
+ return None
+
+ def num_dispatchers(self) -> int:
+ return self.num_dispatchers_
+
+ def output_is_reduced(self) -> bool:
+ return False
+
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ return torch.int32
+
+ def prepare(
+ self,
+ a1: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_experts: int,
+ expert_map: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
+ ) -> mk.PrepareResultType:
+ if apply_router_weight_on_input:
+ topk = topk_ids.size(1)
+ assert topk == 1, (
+ "apply_router_weight_on_input is only implemented for topk=1"
+ )
+ a1.mul_(topk_weights.to(a1.dtype))
+
+ global_num_tokens_cpu = get_local_sizes()
+ self.runtime_max_tokens_per_rank = (
+ max(global_num_tokens_cpu)
+ if global_num_tokens_cpu is not None
+ else a1.shape[0]
+ )
+
+ a1q, a1q_scale = moe_kernel_quantize_input(
+ a1,
+ quant_config.a1_gscale,
+ quant_config.quant_dtype,
+ quant_config.per_act_token_quant,
+ quant_config.block_shape,
+ is_fp4_scale_swizzled=False, # delay swizzle to after comm
+ )
+
+ payloads = []
+ payloads.append(a1q)
+ if a1q_scale is not None:
+ payloads.append(a1q_scale)
+ payloads.append(topk_ids)
+ payloads.append(topk_weights)
+
+ recv_payloads = self.all2all_manager.moe_alltoall.dispatch(
+ token_selected_experts=topk_ids,
+ input_payloads=payloads,
+ runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
+ )
+ if a1q_scale is not None:
+ a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads
+ # Apply scale interleaving only for CUTLASS (not TRT-LLM)
+ if (
+ quant_config.quant_dtype == "nvfp4"
+ and quant_config.is_nvfp4_scale_swizzled
+ ):
+ a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1])
+ a1q_scale_recv = a1q_scale_recv.view(torch.uint8)
+ a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv)
+ a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16)
+ else:
+ a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads
+ a1q_scale_recv = None
+ a1q_recv = a1q_recv.view(-1, a1q_recv.shape[-1])
+ topk_ids_recv = topk_ids_recv.view(-1, topk_ids_recv.shape[-1])
+ topk_weights_recv = topk_weights_recv.view(-1, topk_weights_recv.shape[-1])
+
+ return a1q_recv, a1q_scale_recv, None, topk_ids_recv, topk_weights_recv
+
+ def finalize(
+ self,
+ output: torch.Tensor,
+ fused_expert_output: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ apply_router_weight_on_input: bool,
+ weight_and_reduce_impl: mk.TopKWeightAndReduce,
+ ) -> None:
+ assert self.all2all_manager.moe_alltoall is not None
+
+ ep_size = self.all2all_manager.world_size
+ hidden_size = fused_expert_output.shape[-1]
+ fused_expert_output = fused_expert_output.view(
+ ep_size, self.runtime_max_tokens_per_rank, hidden_size
+ )
+
+ combined_output = self.all2all_manager.moe_alltoall.combine(
+ payload=fused_expert_output,
+ runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
+ )
+ output.copy_(combined_output)
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_two_sided_prepare_finalize.py
similarity index 98%
rename from vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
rename to vllm/model_executor/layers/fused_moe/flashinfer_nvlink_two_sided_prepare_finalize.py
index 465d0ae8f2c4..be63bd4e3f61 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_nvlink_two_sided_prepare_finalize.py
@@ -18,7 +18,7 @@ def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
-class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
+class FlashInferNVLinkTwoSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""Base class for FlashInfer MoE prepare and finalize operations."""
def __init__(
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 86fef2528345..45575ab09c40 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -600,7 +600,10 @@ def _supports_activation(activation: MoEActivation) -> bool:
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
- return not moe_parallel_config.use_fi_all2allv_kernels
+ return not (
+ moe_parallel_config.use_fi_nvl_two_sided_kernels
+ or moe_parallel_config.use_fi_nvl_one_sided_kernels
+ )
@property
def quant_type_id(self) -> int:
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 70adac711f5a..03ca8ba119c0 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1965,7 +1965,10 @@ def _supports_activation(activation: MoEActivation) -> bool:
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
- return not moe_parallel_config.use_fi_all2allv_kernels
+ return not (
+ moe_parallel_config.use_fi_nvl_two_sided_kernels
+ or moe_parallel_config.use_fi_nvl_one_sided_kernels
+ )
def supports_expert_map(self) -> bool:
return True
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index fd759f22b1ff..7135cbbd2d7c 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -638,7 +638,7 @@ def _get_quant_method() -> FusedMoEMethodBase:
self.use_overlapped = (
not (
(self.enable_eplb and backend != "allgather_reducescatter")
- or self.moe_parallel_config.use_fi_all2allv_kernels
+ or self.moe_parallel_config.use_fi_nvl_two_sided_kernels
)
and self._shared_experts is not None
)
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 6d178d587c69..b1a4b0d59d2b 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -332,7 +332,10 @@ def _supports_activation(activation: MoEActivation) -> bool:
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
- return not moe_parallel_config.use_fi_all2allv_kernels
+ return not (
+ moe_parallel_config.use_fi_nvl_two_sided_kernels
+ or moe_parallel_config.use_fi_nvl_one_sided_kernels
+ )
def supports_expert_map(self):
return True
diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
index d3c950dcbb33..b6313776e85d 100644
--- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
+++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
@@ -233,7 +233,7 @@ def use_dp_chunking(self) -> bool:
return (
self.moe_config.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.moe_parallel_config.use_mori_kernels
- or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
+ or self.moe_config.moe_parallel_config.use_fi_nvl_two_sided_kernels
or self.moe_config.moe_parallel_config.use_nixl_ep_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index c3ac839c21d1..fed44d04fb5e 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -150,7 +150,7 @@ def has_flashinfer_comm() -> bool:
@functools.cache
-def has_flashinfer_all2all() -> bool:
+def has_flashinfer_nvlink_two_sided() -> bool:
"""Return `True` if FlashInfer mnnvl all2all is available."""
if not has_flashinfer_comm():
return False
@@ -170,6 +170,14 @@ def has_flashinfer_all2all() -> bool:
return True
+@functools.cache
+def has_flashinfer_nvlink_one_sided() -> bool:
+ """Return `True` if FlashInfer trtllm_moe_alltoall module is available."""
+ if not has_flashinfer_comm():
+ return False
+ return importlib.util.find_spec("flashinfer.comm.trtllm_moe_alltoall") is not None
+
+
@functools.cache
def has_flashinfer_moe() -> bool:
"""Return `True` if FlashInfer MoE module is available."""
@@ -766,7 +774,8 @@ def should_use_flashinfer_for_blockscale_fp8_gemm(
"autotune",
"has_flashinfer_moe",
"has_flashinfer_comm",
- "has_flashinfer_all2all",
+ "has_flashinfer_nvlink_two_sided",
+ "has_flashinfer_nvlink_one_sided",
"has_flashinfer_cutlass_fused_moe",
"has_flashinfer_cutedsl_grouped_gemm_nt_masked",
"has_flashinfer_fp8_blockscale_gemm",