Skip to content
3 changes: 2 additions & 1 deletion docs/design/moe_kernel_features.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ th {
| naive | standard | all<sup>1</sup> | G,A,T | N | <sup>6</sup> | [layer.py][vllm.model_executor.layers.fused_moe.layer.FusedMoE] |
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] |
| flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize.FlashInferNVLinkTwoSidedPrepareAndFinalize] |
| flashinfer_nvlink_one_sided | standard | nvfp4 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize.FlashInferNVLinkOneSidedPrepareAndFinalize] |

!!! info "Table key"
1. All types: mxfp4, nvfp4, int4, int8, fp8
Expand Down
3 changes: 2 additions & 1 deletion docs/serving/expert_parallel_deployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ vLLM provides multiple communication backends for EP. Use `--all2all-backend` to
| `allgather_reducescatter` | Default backend | Standard all2all using allgather/reducescatter primitives | General purpose, works with any EP+DP configuration |
| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout, optimized for prefill | Prefill-dominated workloads, high-throughput scenarios |
| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios |
| `flashinfer_all2allv` | MNNVL systems | FlashInfer alltoallv kernels for multi-node NVLink | Systems with NVLink across nodes |
| `flashinfer_nvlink_one_sided` | MNNVL systems | FlashInfer's one-sided A2A strategy for multi-node NVLink | High-throughput workloads |
| `flashinfer_nvlink_two_sided` | MNNVL systems | FlashInfer's two-sided A2A strategy for multi-node NVLink | Systems with NVLink across nodes |
| `naive` | Testing/debugging | Simple broadcast-based implementation | Debugging, not recommended for production |

## Single Node Deployment
Expand Down
43 changes: 38 additions & 5 deletions tests/kernels/moe/modular_kernel_tools/mk_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.flashinfer import (
has_flashinfer_cutlass_fused_moe,
has_flashinfer_nvlink_one_sided,
)
from vllm.utils.import_utils import (
has_aiter,
has_deep_ep,
Expand Down Expand Up @@ -234,15 +237,15 @@ def expert_info(kind) -> ExpertInfo:
)

if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
FlashInferA2APrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize import ( # noqa: E501
FlashInferNVLinkTwoSidedPrepareAndFinalize,
)

register_prepare_and_finalize(
FlashInferA2APrepareAndFinalize,
FlashInferNVLinkTwoSidedPrepareAndFinalize,
standard_format,
nvfp4_types + fp8_types,
blocked_quantization_support=True,
Expand All @@ -263,6 +266,36 @@ def expert_info(kind) -> ExpertInfo:
FlashInferCutlassMoEPrepareAndFinalize = None
FlashInferExperts = None

if (
has_flashinfer_nvlink_one_sided()
and has_flashinfer_cutlass_fused_moe()
and current_platform.has_device_capability(100)
):
from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize import ( # noqa: E501
FlashInferNVLinkOneSidedPrepareAndFinalize,
)

register_prepare_and_finalize(
FlashInferNVLinkOneSidedPrepareAndFinalize,
standard_format,
nvfp4_types,
blocked_quantization_support=False,
backend="flashinfer_nvlink_one_sided",
supports_apply_weight_on_input=False,
)

if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
TrtLlmNvFp4ExpertsModular,
)

register_experts(
TrtLlmNvFp4ExpertsModular,
standard_format,
nvfp4_types,
blocked_quantization_support=False,
supports_expert_map=True,
)

if has_aiter():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
Expand Down
7 changes: 5 additions & 2 deletions vllm/config/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@
"mori",
"nixl_ep",
"allgather_reducescatter",
"flashinfer_all2allv",
"flashinfer_all2allv", # temporary alias for flashinfer_nvlink_two_sided
"flashinfer_nvlink_two_sided",
"flashinfer_nvlink_one_sided",
]


Expand Down Expand Up @@ -158,7 +160,8 @@ class ParallelConfig:
- "deepep_low_latency": Use deepep low-latency kernels\n
- "mori": Use mori kernels\n
- "nixl_ep": Use nixl-ep kernels\n
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
- "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl\n
- "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""

max_parallel_loading_workers: int | None = None
"""Maximum number of parallel loading workers when loading model
Expand Down
138 changes: 132 additions & 6 deletions vllm/distributed/device_communicators/all2all.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,36 @@
from typing import Any

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.flashinfer import (
has_flashinfer_nvlink_one_sided,
has_flashinfer_nvlink_two_sided,
)
from vllm.utils.import_utils import has_deep_ep, has_mori

from .base_device_communicator import All2AllManagerBase, Cache

if has_flashinfer_all2all():
if has_flashinfer_nvlink_two_sided():
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_alltoall import (
MnnvlMoe, # type: ignore[import-not-found]
)

if has_flashinfer_nvlink_one_sided():
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_moe_alltoall import (
MoeAlltoAll, # type: ignore[import-not-found]
moe_a2a_get_workspace_size_per_rank,
)


logger = init_logger(__name__)


Expand Down Expand Up @@ -529,9 +542,9 @@ def max_sms_used(self) -> int | None:
return 0


class FlashInferAllToAllManager(All2AllManagerBase):
class FlashInferNVLinkTwoSidedManager(All2AllManagerBase):
"""
All2All communication based on flashinfer kernels.
All2All communication based on flashinfer all2allv/two-sided NVLink kernels.
"""

# This type lint could be removed after all of the work in
Expand All @@ -540,7 +553,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
world_size: int

def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), (
assert has_flashinfer_nvlink_two_sided(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group, tcp_store_group)
Expand Down Expand Up @@ -597,7 +610,7 @@ def initialize(

def ensure_alltoall_workspace_initialized(self):
"""Ensure workspace is initialized"""
if not has_flashinfer_all2all():
if not has_flashinfer_nvlink_two_sided():
return False

if self.world_size <= 1:
Expand Down Expand Up @@ -633,6 +646,119 @@ def cleanup(self):
self.initialized = False


class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
    """
    All2All communication based on FlashInfer's MoeAlltoAll/One-sided NVLink kernel.
    This is a newer kernel from trtllm that should perform better than the kernel
    used by flashinfer_nvlink_two_sided.

    Unlike the two-sided manager, this backend sizes and allocates a per-rank
    workspace up front (see ``initialize``), so callers must provide the model's
    token/expert/hidden dimensions before the first dispatch.
    """

    # Populated by All2AllManagerBase.__init__ from the provided cpu_group.
    rank: int
    world_size: int

    def __init__(self, cpu_group) -> None:
        """Create the manager; defers workspace allocation to ``initialize``.

        Args:
            cpu_group: CPU process group used by the base class to derive
                ``self.rank`` and ``self.world_size``.

        Raises:
            AssertionError: if FlashInfer's ``trtllm_moe_alltoall`` module is
                not importable in this environment.
        """
        assert has_flashinfer_nvlink_one_sided(), (
            "flashinfer trtllm_moe_alltoall module not found. "
            "Please install/check flashinfer"
        )
        super().__init__(cpu_group)
        logger.debug(
            "Initialize FlashInfer One-sided NVLink rank=%d, world size=%d",
            self.rank,
            self.world_size,
        )
        self.initialized = False
        self.moe_alltoall: MoeAlltoAll | None = None
        self.mapping = None

    def initialize(
        self,
        max_num_tokens: int,
        top_k: int,
        num_experts: int,
        hidden_size: int,
    ) -> None:
        """Initialize the MoeAlltoAll workspace.

        Idempotent: returns immediately if already initialized. Otherwise it
        (1) tears down any stale state, (2) builds the FlashInfer ``Mapping``
        for an EP-only layout (tp_size == moe_ep_size == world_size),
        (3) computes the per-rank workspace size from the dispatch/combine
        payload byte counts below, and (4) allocates the ``MoeAlltoAll``
        workspace. Ends with a collective barrier, so ALL ranks must call
        this together with the same arguments.

        Args:
            max_num_tokens: Max tokens per rank the workspace must hold.
            top_k: Experts selected per token.
            num_experts: Total expert count across the EP group.
            hidden_size: Model hidden dimension (elements per token).
        """
        if self.initialized:
            return

        # Drop any previously allocated workspace before re-creating it.
        self.cleanup()
        gpus_per_node = torch.accelerator.device_count()
        logger.debug(
            "Making One-sided NVLink mapping: rank=%d, world size=%d",
            self.rank,
            self.world_size,
        )
        self.mapping = Mapping(
            self.world_size,
            self.rank,
            gpus_per_node,
            tp_size=self.world_size,
            moe_ep_size=self.world_size,
        )

        # Imported lazily: mnnvl_compat asserts FlashInfer availability at
        # module import time, so it must not be imported unconditionally.
        from vllm.distributed.device_communicators.mnnvl_compat import (
            CustomCommunicator,
        )

        # Bridge vLLM's DP CPU process group into FlashInfer's comm backend.
        dp_config = MnnvlConfig(
            comm_backend=CustomCommunicator(get_dp_group().cpu_group),
        )
        # Per-token byte budget for the dispatch direction. NOTE(review):
        # assumes nvfp4 activations (0.5 byte/elem) with one fp8 scale per
        # 16 elements — confirm this matches the quantization scheme used
        # by the one-sided backend (docs list it as nvfp4-only).
        total_dispatch_payload_size_per_token = (
            hidden_size // 2  # nvfp4 hidden states (2 elems per byte)
            + hidden_size // 16  # fp8 scaling factors (1 per 16 elems)
            + top_k * 4  # int32 topk ids
            + top_k * 4  # float32 topk weights
        )
        # Combine direction returns full-precision-ish activations: bf16.
        combine_payload_size_per_token = hidden_size * 2  # bf16 hidden states
        self.workspace_size = moe_a2a_get_workspace_size_per_rank(
            ep_size=self.world_size,
            max_num_tokens=max_num_tokens,
            total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token,
            combine_payload_size_per_token=combine_payload_size_per_token,
        )

        self.moe_alltoall = MoeAlltoAll(
            mapping=self.mapping,
            max_num_tokens=max_num_tokens,
            top_k=top_k,
            num_experts=num_experts,
            workspace_size_per_rank=self.workspace_size,
            mnnvl_config=dp_config,
        )

        # Cache the init parameters so later callers can inspect/validate
        # the configuration this workspace was built for.
        self.gpus_per_node = gpus_per_node
        self.max_num_tokens = max_num_tokens
        self.top_k = top_k
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.initialized = True

        logger.info(
            "FlashInfer One-sided NVLink initialized for rank %s, size %s",
            self.rank,
            self.world_size,
        )
        # NOTE(review): barriers on the DEFAULT process group, not the EP/DP
        # subgroup this manager was built from — confirm that is intended
        # (it requires every rank in the world to reach this point).
        dist.barrier()

    def get_handle(self, kwargs):
        # The manager itself is the kernel handle; kwargs are accepted for
        # interface parity with other All2All managers but ignored here.
        return self

    def cleanup(self) -> None:
        """Clean up resources.

        Safe to call repeatedly; resets ``moe_alltoall``/``mapping`` and the
        initialized flag even if releasing the workspace raises.
        """
        if self.initialized and self.moe_alltoall is not None:
            try:
                # Dropping the last reference releases the workspace buffers.
                del self.moe_alltoall
            except Exception as e:
                logger.warning(
                    "Failed to cleanup FlashInfer One-sided NVLink workspace: %s", e
                )
            finally:
                self.moe_alltoall = None
                self.mapping = None
                self.initialized = False


class MoriAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
assert has_mori(), (
Expand Down
21 changes: 17 additions & 4 deletions vllm/distributed/device_communicators/cuda_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,25 @@ def __init__(
self.all2all_manager = NixlEPAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager

self.all2all_manager = FlashInferAllToAllManager(
elif (
self.all2all_backend == "flashinfer_all2allv"
or self.all2all_backend == "flashinfer_nvlink_two_sided"
):
if self.all2all_backend == "flashinfer_all2allv":
logger.warning_once(
"'flashinfer_all2allv' is deprecated and has been renamed to "
"'flashinfer_nvlink_two_sided'. It will be removed in a future "
"release."
)
from .all2all import FlashInferNVLinkTwoSidedManager

self.all2all_manager = FlashInferNVLinkTwoSidedManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "flashinfer_nvlink_one_sided":
from .all2all import FlashInferNVLinkOneSidedManager

self.all2all_manager = FlashInferNVLinkOneSidedManager(self.cpu_group)
else:
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")

Expand Down
14 changes: 7 additions & 7 deletions vllm/distributed/device_communicators/mnnvl_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend

from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.flashinfer import has_flashinfer_nvlink_two_sided

assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found"
assert has_flashinfer_nvlink_two_sided(), "Flashinfer alltoallv module cannot be found"


class CustomCommunicator(CommBackend):
Expand All @@ -25,14 +25,14 @@ def allgather(self, data: int):
dist.all_gather_object(gathered, data, group=self._group)
return gathered

# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def bcast(self, data: Any, root: int) -> Any:
raise NotImplementedError
obj_list = [data]
# broadcast_object_list mutates obj_list in-place
dist.broadcast_object_list(obj_list, src=root, group=self._group)
return obj_list[0]

def barrier(self) -> None:
raise NotImplementedError
dist.barrier(group=self._group)

def Split(self, color: int, key: int) -> "CustomCommunicator":
return self
Loading
Loading