diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 4e3706645ef2..54b796fde3bf 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -36,7 +36,7 @@ th {
| deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ht.DeepEPHTPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ll.DeepEPLLPrepareAndFinalize] |
| flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_two_sided.FlashInferNVLinkTwoSidedPrepareAndFinalize] |
-| flashinfer_nvlink_one_sided | standard | nvfp4 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] |
+| flashinfer_nvlink_one_sided | standard | nvfp4,bf16,mxfp8 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] |
!!! info "Table key"
1. All types: mxfp4, nvfp4, int4, int8, fp8
diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index 340b6ff1cf29..57ef6e9cf148 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -577,6 +577,8 @@ def initialize(
top_k: int,
num_experts: int,
hidden_size: int,
+ dispatch_dtype_bytes_per_elem: int = 0,
+ dispatch_scale_bytes_per_token: int = 0,
):
"""Initialize the MoeAlltoAll workspace."""
if self.initialized:
@@ -607,9 +609,13 @@ def initialize(
ep_config = MnnvlConfig(
comm_backend=CustomCommunicator(self.cpu_group),
)
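+        # A dispatch_dtype_bytes_per_elem of 0 keeps the original nvfp4 sizing of
+        # two fp4 values packed per byte (hidden_size // 2 bytes per token).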
+ if dispatch_dtype_bytes_per_elem == 0:
+ hidden_bytes = hidden_size // 2
+ else:
+ hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem
total_dispatch_payload_size_per_token = (
- hidden_size // 2 # nvfp4 hidden states
- + hidden_size // 16 # fp8 scaling factors
+ hidden_bytes
+ + dispatch_scale_bytes_per_token
             + top_k * 4  # int32 topk ids
+ top_k * 4 # float32 topk weights
)
diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py
index fba1d4c692af..2a6f0c71d936 100644
--- a/vllm/model_executor/layers/fused_moe/all2all_utils.py
+++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -228,23 +228,37 @@ def maybe_make_prepare_finalize(
elif moe.use_fi_nvl_one_sided_kernels:
assert quant_config is not None
- if quant_config.quant_dtype != "nvfp4":
- raise ValueError(
- "The 'flashinfer_nvlink_one_sided' all2all backend only "
- "supports nvfp4 activation quantization, but got "
- f"quant_dtype={quant_config.quant_dtype!r}. Use a different "
- "all2all backend (e.g. 'flashinfer_nvlink_two_sided' or "
- "'allgather_reducescatter') for non-nvfp4 models."
- )
max_num_tokens = (
get_current_vllm_config().scheduler_config.max_num_batched_tokens
)
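+            # Size the one-sided NVLink workspace for the payload actually dispatched:
+            # bf16 (quant_dtype=None) sends 2-byte elements with no scales, nvfp4 packs
+            # two elements per byte (signalled by 0) with one 1-byte scale per 16 elements,
+            # and mxfp8 sends 1-byte elements with one 1-byte scale per 32 elements of the
+            # alignment-padded hidden dim.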
+ if quant_config.quant_dtype is None:
+ dispatch_dtype_bytes_per_elem = 2
+ dispatch_scale_bytes_per_token = 0
+ elif quant_config.quant_dtype == "nvfp4":
+ dispatch_dtype_bytes_per_elem = 0
+ dispatch_scale_bytes_per_token = moe.hidden_dim // 16
+ elif quant_config.quant_dtype == "mxfp8":
+ dispatch_dtype_bytes_per_elem = 1
+ align = quant_config.mx_alignment
+ if align > 0:
+ padded_k = ((moe.hidden_dim + align - 1) // align) * align
+ else:
+ padded_k = moe.hidden_dim
+ dispatch_scale_bytes_per_token = padded_k // 32
+ else:
+                raise NotImplementedError(
+                    "The 'flashinfer_nvlink_one_sided' all2all backend currently "
+                    "supports nvfp4, mxfp8, and bf16 (quant_dtype=None) activations; "
+                    f"got quant_dtype={quant_config.quant_dtype!r}"
+                )
prepare_finalize = FlashInferNVLinkOneSidedPrepareAndFinalize(
max_num_tokens=max_num_tokens,
top_k=moe.experts_per_token,
num_experts=moe.num_experts,
hidden_size=moe.hidden_dim,
num_dispatchers=all2all_manager.world_size,
+ dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem,
+ dispatch_scale_bytes_per_token=dispatch_scale_bytes_per_token,
)
elif moe.use_ag_rs_all2all_kernels and allow_new_interface:
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 565df1324f62..8ffa5cffb551 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -254,6 +254,8 @@ class FusedMoEQuantConfig:
gemm1_beta: float | None = None
gemm1_clamp_limit: float | None = None
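+    # Row alignment (in elements) applied when quantizing activations to an MX
+    # format; 0 falls back to the quantization kernel's default of one MX block (32).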
+ mx_alignment: int = 0
+
def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (
"illegal quantization"
@@ -712,6 +714,7 @@ def mxfp4_mxfp8_moe_quant_config(
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
+ mx_alignment: int = 0,
) -> FusedMoEQuantConfig:
"""
     Construct a quant config for mxfp4 weights and mxfp8 activations.
@@ -724,6 +727,7 @@ def mxfp4_mxfp8_moe_quant_config(
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
+ mx_alignment=mx_alignment,
)
diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py
index f7af9aea70ad..69e5b7fe4f0e 100644
--- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py
@@ -44,6 +44,9 @@ def __init__(
moe_config.intermediate_size_per_partition
)
self.hidden_dim = moe_config.hidden_dim
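+        # The kernel writes bf16 output at the model's unpadded hidden size; fall
+        # back to hidden_dim when the config does not track a separate unpadded size.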
+ self.hidden_dim_unpadded = (
+ moe_config.hidden_dim_unpadded or moe_config.hidden_dim
+ )
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
@@ -82,9 +85,6 @@ def __init__(
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)
- # P1-5 fix: use public quant_dtype property instead of private _a1
- self.use_mxfp8_input = quant_config.quant_dtype == "mxfp8"
-
@staticmethod
def _supports_current_device() -> bool:
p = current_platform
@@ -121,8 +121,7 @@ def supports_expert_map(self) -> bool:
@property
def expects_unquantized_inputs(self) -> bool:
- # Expert handles MXFP8 quantization internally if needed
- return True
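+        # Activation quantization (when configured) now happens in prepare(), so
+        # this expert no longer quantizes its inputs internally.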
+ return False
class TrtLlmMxfp4ExpertsMonolithic(
@@ -181,24 +180,19 @@ def apply(
) -> torch.Tensor:
from flashinfer import trtllm_fp4_block_scale_moe
- # Handle input quantization
- if self.use_mxfp8_input:
- from flashinfer import mxfp8_quantize
-
- x_quant, x_scale = mxfp8_quantize(
- hidden_states,
- is_sf_swizzled_layout=False,
- alignment=256,
- )
- x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
- *hidden_states.shape[:-1], -1
- )
+ if a1q_scale is not None:
+ x_quant = hidden_states
+ x_scale = a1q_scale.view(torch.float8_e4m3fn)
else:
assert hidden_states.dtype == torch.bfloat16
x_quant = hidden_states
x_scale = None
-
- output = torch.empty_like(hidden_states)
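+        # Allocate the bf16 output at the unpadded hidden size, which may be
+        # narrower than the (possibly padded) width of hidden_states.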
+ output = torch.empty(
+ *hidden_states.shape[:-1],
+ self.hidden_dim_unpadded,
+ dtype=torch.bfloat16,
+ device=hidden_states.device,
+ )
from vllm.utils.flashinfer import _is_fi_autotuning, autotune
@@ -244,10 +238,6 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
Moved from trtllm_moe.py.
"""
- @property
- def expects_unquantized_inputs(self) -> bool:
- return True
-
@staticmethod
def _supports_parallel_config(
moe_parallel_config: FusedMoEParallelConfig,
@@ -284,7 +274,7 @@ def workspace_shapes(
# The workspaces for this implementation are managed by flashinfer.
workspace1 = (0,)
workspace2 = (0,)
- output = (M, K)
+ output = (M, self.hidden_dim_unpadded)
return (workspace1, workspace2, output)
def apply(
@@ -310,18 +300,9 @@ def apply(
intermediate_size = self.intermediate_size_per_partition
local_expert_offset = self.moe_config.ep_rank * local_num_experts
- # Handle input quantization
- if self.use_mxfp8_input:
- from flashinfer import mxfp8_quantize
-
- x_quant, x_scale = mxfp8_quantize(
- hidden_states,
- is_sf_swizzled_layout=False,
- alignment=256,
- )
- x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
- *hidden_states.shape[:-1], -1
- )
+ if a1q_scale is not None:
+ x_quant = hidden_states
+ x_scale = a1q_scale.view(torch.float8_e4m3fn)
else:
assert hidden_states.dtype == torch.bfloat16
x_quant = hidden_states
diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
index f476d980d555..c1423362d737 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
@@ -1195,10 +1195,18 @@ def make_mxfp4_moe_quant_config(
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
- elif mxfp4_backend in (
- Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
- Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
- ):
+ elif mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8:
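+        # The TRTLLM mxfp4/mxfp8 path quantizes activations with a 256-element row
+        # alignment (previously hard-coded in the expert's internal mxfp8_quantize call).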
+ return mxfp4_mxfp8_moe_quant_config(
+ w1_bias=w1_bias,
+ w2_bias=w2_bias,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ gemm1_alpha=gemm1_alpha,
+ gemm1_beta=gemm1_beta,
+ gemm1_clamp_limit=swiglu_limit,
+ mx_alignment=256,
+ )
+ elif mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8:
return mxfp4_mxfp8_moe_quant_config(
w1_bias=w1_bias,
w2_bias=w2_bias,
@@ -1250,7 +1258,6 @@ def make_mxfp4_moe_kernel(
"""Create a FusedMoEKernel for the given MXFP4 backend."""
is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
- # Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py
index a04ff3b8b68f..6cc0d01cde6b 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py
@@ -31,6 +31,8 @@ def __init__(
num_experts: int,
hidden_size: int,
num_dispatchers: int = 1,
+ dispatch_dtype_bytes_per_elem: int = 0,
+ dispatch_scale_bytes_per_token: int = 0,
):
super().__init__()
self.max_num_tokens = max_num_tokens
@@ -38,6 +40,7 @@ def __init__(
self.num_experts = num_experts
self.hidden_size = hidden_size
self.num_dispatchers_ = num_dispatchers
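+        # Scale factors are one byte per element for both nvfp4 and mxfp8, so the
+        # per-token scale byte count doubles as the element count.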
+ self.scale_elems_per_token = dispatch_scale_bytes_per_token
device_communicator = get_ep_group().device_communicator
assert device_communicator is not None
@@ -49,6 +52,8 @@ def __init__(
top_k=self.top_k,
num_experts=self.num_experts,
hidden_size=self.hidden_size,
+ dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem,
+ dispatch_scale_bytes_per_token=dispatch_scale_bytes_per_token,
)
@property
@@ -92,19 +97,24 @@ def prepare(
else a1.shape[0]
)
- a1q, a1q_scale = moe_kernel_quantize_input(
- a1,
- quant_config.a1_gscale,
- quant_config.quant_dtype,
- quant_config.per_act_token_quant,
- quant_config.block_shape,
- is_fp4_scale_swizzled=False, # delay swizzle to after comm
- )
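+        # When input quantization is deferred, dispatch the raw activations with no scales.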
+ if defer_input_quant:
+ a1q, a1q_scale = a1, None
+ else:
+ a1q, a1q_scale = moe_kernel_quantize_input(
+ a1,
+ quant_config.a1_gscale,
+ quant_config.quant_dtype,
+ quant_config.per_act_token_quant,
+ quant_config.block_shape,
+ is_fp4_scale_swizzled=False, # delay swizzle to after comm
+ mx_alignment=quant_config.mx_alignment,
+ )
payloads = []
payloads.append(a1q)
if a1q_scale is not None:
payloads.append(a1q_scale)
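+        # Record which payload slot carries the expert ids so the A2A kernel can
+        # stamp invalid tokens with invalid_token_expert_id.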
+ topk_ids_payload_index = len(payloads)
payloads.append(topk_ids)
payloads.append(topk_weights)
@@ -113,6 +123,8 @@ def prepare(
token_selected_experts=topk_ids,
input_payloads=payloads,
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
+            invalid_token_expert_id=-1,  # follow the TRTLLM convention
+ expert_id_payload_index=topk_ids_payload_index,
)
if a1q_scale is not None:
a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads
@@ -124,7 +136,8 @@ def prepare(
a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1])
a1q_scale_recv = a1q_scale_recv.view(torch.uint8)
a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv)
- a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16)
+ assert self.scale_elems_per_token > 0
+ a1q_scale_recv = a1q_scale_recv.view(-1, self.scale_elems_per_token)
else:
a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads
a1q_scale_recv = None
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py
index 47fe293d511e..78be414759f7 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py
@@ -174,6 +174,7 @@ def flashinfer_alltoall_dispatch(
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
+ mx_alignment=quant_config.mx_alignment,
)
x = MnnvlMoe.mnnvl_moe_alltoallv(
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
index 2b21e2db9f68..5b3325ad0195 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
@@ -40,6 +40,7 @@ def _quantize_and_setup_dispatch(
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
is_fp4_scale_swizzled=False,
+ mx_alignment=quant_config.mx_alignment,
)
# Skip gathering scales if we have static quantization
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py
index b9d57da08326..31a35bd60218 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py
@@ -31,6 +31,7 @@ def _quantize_input(
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
+ mx_alignment=quant_config.mx_alignment,
)
return a1q, a1q_scale
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index ffab3ca0bfa9..ed24cbe2b233 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -208,11 +208,12 @@ def _mxfp8_e4m3_quantize(
per_act_token_quant: bool,
block_shape: list[int] | None = None,
is_sf_swizzled_layout: bool = False,
+ mx_alignment: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
assert A_scale is None
assert not per_act_token_quant
assert block_shape is None or block_shape == [1, 32]
- return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout)
+ return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout, mx_alignment)
def _mxfp6_e3m2_quantize(
@@ -258,6 +259,7 @@ def moe_kernel_quantize_input(
is_fp4_scale_swizzled: bool = True,
ocp_mx_scheme: str | None = None,
quantization_emulation: bool = False,
+ mx_alignment: int = 0,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation
if ocp_mx_scheme is not None:
@@ -319,7 +321,8 @@ def moe_kernel_quantize_input(
A_scale,
per_act_token_quant,
block_shape,
- is_sf_swizzled_layout=is_fp4_scale_swizzled,
+ is_sf_swizzled_layout=False,
+ mx_alignment=mx_alignment,
)
elif quant_dtype == "mxfp6_e3m2":
if not quantization_emulation:
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
index b9b7bd542738..a12918225348 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
@@ -85,7 +85,9 @@ def _mxfp8_e4m3_quantize_torch(
def _mxfp8_e4m3_quantize_impl(
- x: torch.Tensor, is_sf_swizzled_layout: bool = False
+ x: torch.Tensor,
+ is_sf_swizzled_layout: bool = False,
+ alignment: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm.platforms import current_platform
@@ -93,7 +95,9 @@ def _mxfp8_e4m3_quantize_impl(
from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize
x_q, x_scales = flashinfer_mxfp8_quantize(
- x, is_sf_swizzled_layout=is_sf_swizzled_layout
+ x,
+ is_sf_swizzled_layout=is_sf_swizzled_layout,
+        alignment=alignment if alignment > 0 else 32,  # default to one MX block (32)
)
if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout:
x_scales = x_scales.view(x.size(0), -1)
@@ -103,9 +107,11 @@ def _mxfp8_e4m3_quantize_impl(
def mxfp8_e4m3_quantize(
- x: torch.Tensor, is_sf_swizzled_layout: bool = False
+ x: torch.Tensor,
+ is_sf_swizzled_layout: bool = False,
+ alignment: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
- return torch.ops.vllm.mxfp8_quantize(x, is_sf_swizzled_layout)
+ return torch.ops.vllm.mxfp8_quantize(x, is_sf_swizzled_layout, alignment)
def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
@@ -125,7 +131,9 @@ def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor
def mxfp8_e4m3_quantize_fake(
- x: torch.Tensor, is_sf_swizzled_layout: bool = False
+ x: torch.Tensor,
+ is_sf_swizzled_layout: bool = False,
+ alignment: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fake implementation for torch.compile tracing."""
fp_data = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE)