From 3b11a35401caab014cbe7050c2d86754fa517f2c Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 20 Apr 2026 21:09:34 +0000 Subject: [PATCH 1/6] Add bf16 + defer-input-quant support to flashinfer_nvlink_one_sided all2all The one-sided MoeAlltoAll dispatch workspace was hardcoded for nvfp4 hidden states + fp8 scales, so any other activation dtype overran the buffer. Parameterize the workspace sizing by bytes-per-elem and whether an fp8 scale payload is present, then route non-nvfp4 quant configs to a bf16 dispatch (2 B/elem, no scale) via a new defer_input_quant hint. trtllm_mxfp4 experts already advertise expects_unquantized_inputs=True (they call mxfp8_quantize internally). Wire make_mxfp4_moe_kernel to pass that signal into maybe_make_prepare_finalize, and have the one-sided prepare() honor the per-call defer_input_quant flag by shipping a1 as bf16 with no scale payload. Two-sided already handled this. NOTE: the flashinfer moe_a2a_dispatch C++ kernel only templates top_k in {1, 2, 4, 8}; models with other top_k (e.g. DeepSeek-V4 top_k=6) must use flashinfer_nvlink_two_sided instead. 
Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Yongye Zhu --- docs/design/moe_kernel_features.md | 2 +- .../device_communicators/all2all.py | 21 +++++++++++++--- .../layers/fused_moe/all2all_utils.py | 15 +++++++++++ .../layers/fused_moe/oracle/mxfp4.py | 10 ++++++++ .../flashinfer_nvlink_one_sided.py | 25 +++++++++++++------ 5 files changed, 61 insertions(+), 12 deletions(-) diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 4e3706645ef2..2fdcdceae836 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -36,7 +36,7 @@ th { | deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ht.DeepEPHTPrepareAndFinalize] | | deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ll.DeepEPLLPrepareAndFinalize] | | flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_two_sided.FlashInferNVLinkTwoSidedPrepareAndFinalize] | -| flashinfer_nvlink_one_sided | standard | nvfp4 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] | +| flashinfer_nvlink_one_sided | standard | nvfp4,bf16 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] | !!! info "Table key" 1. 
All types: mxfp4, nvfp4, int4, int8, fp8 diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 340b6ff1cf29..f871ce9c6ff6 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -577,8 +577,18 @@ def initialize( top_k: int, num_experts: int, hidden_size: int, + dispatch_dtype_bytes_per_elem: int = 0, + dispatch_has_fp8_scale: bool = True, ): - """Initialize the MoeAlltoAll workspace.""" + """Initialize the MoeAlltoAll workspace. + + dispatch_dtype_bytes_per_elem: bytes/elem for the dispatched hidden + states. Use 0 as a sentinel for sub-byte nvfp4 (0.5 B/elem); use + 1 for fp8, 2 for bf16/fp16. + dispatch_has_fp8_scale: whether a per-16-elem fp8 scale tensor is + dispatched alongside the hidden states (true for nvfp4/fp8, + false for bf16 passthrough). + """ if self.initialized: return @@ -607,9 +617,14 @@ def initialize( ep_config = MnnvlConfig( comm_backend=CustomCommunicator(self.cpu_group), ) + if dispatch_dtype_bytes_per_elem == 0: + hidden_bytes = hidden_size // 2 # nvfp4 + else: + hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem + scale_bytes = hidden_size // 16 if dispatch_has_fp8_scale else 0 total_dispatch_payload_size_per_token = ( - hidden_size // 2 # nvfp4 hidden states - + hidden_size // 16 # fp8 scaling factors + hidden_bytes + + scale_bytes + top_k * 4 # int32 topks ids + top_k * 4 # float32 topk weights ) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index fba1d4c692af..aeef2c32fd16 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -92,6 +92,7 @@ def maybe_make_prepare_finalize( routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, allow_new_interface: bool = False, use_monolithic: bool = False, + defer_input_quant: bool = False, ) 
-> FusedMoEPrepareAndFinalize | None: # NOTE(rob): we are migrating each quant_method to hold the MK # in all cases. The allow_new_interface=False flag allow us to fall @@ -239,12 +240,26 @@ def maybe_make_prepare_finalize( max_num_tokens = ( get_current_vllm_config().scheduler_config.max_num_batched_tokens ) + if defer_input_quant or quant_config.quant_dtype is None: + # Experts (e.g. trtllm_mxfp4 with mxfp8 activations) quantize + # post-dispatch; ship bf16 tokens with no per-token scale payload. + dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 2, False + elif quant_config.quant_dtype == "nvfp4": + dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 0, True + else: + raise NotImplementedError( + "flashinfer_nvlink_one_sided dispatch only supports nvfp4, " + "bf16, and defer_input_quant paths today; got " + f"quant_dtype={quant_config.quant_dtype!r}" + ) prepare_finalize = FlashInferNVLinkOneSidedPrepareAndFinalize( max_num_tokens=max_num_tokens, top_k=moe.experts_per_token, num_experts=moe.num_experts, hidden_size=moe.hidden_dim, num_dispatchers=all2all_manager.world_size, + dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem, + dispatch_has_fp8_scale=dispatch_has_fp8_scale, ) elif moe.use_ag_rs_all2all_kernels and allow_new_interface: diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index f476d980d555..55b1f1185a8d 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -1250,6 +1250,15 @@ def make_mxfp4_moe_kernel( """Create a FusedMoEKernel for the given MXFP4 backend.""" is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic) + # Some experts (trtllm_mxfp4 with mxfp8 activations) prefer bf16 tokens + # on dispatch and quantize internally; signal this to the prepare/finalize + # so workspace + prepare path ship bf16 instead of the quant_config dtype. 
+ from vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe import ( + TrtLlmMxfp4ExpertsBase, + ) + + defer_input_quant = issubclass(experts_cls, TrtLlmMxfp4ExpertsBase) + # Create Prepare/Finalize. prepare_finalize = maybe_make_prepare_finalize( moe=moe_config, @@ -1257,6 +1266,7 @@ def make_mxfp4_moe_kernel( routing_tables=routing_tables, allow_new_interface=True, use_monolithic=is_monolithic, + defer_input_quant=defer_input_quant, ) assert prepare_finalize is not None diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py index a04ff3b8b68f..40b94cfbec61 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py @@ -31,6 +31,8 @@ def __init__( num_experts: int, hidden_size: int, num_dispatchers: int = 1, + dispatch_dtype_bytes_per_elem: int = 0, + dispatch_has_fp8_scale: bool = True, ): super().__init__() self.max_num_tokens = max_num_tokens @@ -49,6 +51,8 @@ def __init__( top_k=self.top_k, num_experts=self.num_experts, hidden_size=self.hidden_size, + dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem, + dispatch_has_fp8_scale=dispatch_has_fp8_scale, ) @property @@ -92,14 +96,19 @@ def prepare( else a1.shape[0] ) - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=False, # delay swizzle to after comm - ) + if defer_input_quant: + # Experts (e.g. trtllm_mxfp4_moe with mxfp8 activations) will + # quantize post-dispatch. Ship bf16 tokens and skip scales. 
+ a1q, a1q_scale = a1, None + else: + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=False, # delay swizzle to after comm + ) payloads = [] payloads.append(a1q) From 5c5dc8080f4542249d2c387e859f4d28da2c10ac Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Tue, 28 Apr 2026 22:01:20 +0000 Subject: [PATCH 2/6] mxfp8 dispatch support Signed-off-by: Yongye Zhu --- docs/design/moe_kernel_features.md | 2 +- .../device_communicators/all2all.py | 17 ++----- .../layers/fused_moe/all2all_utils.py | 33 +++++++------ .../model_executor/layers/fused_moe/config.py | 4 ++ .../fused_moe/experts/trtllm_mxfp4_moe.py | 48 +++++-------------- .../layers/fused_moe/oracle/mxfp4.py | 27 +++++------ .../flashinfer_nvlink_one_sided.py | 11 +++-- .../flashinfer_nvlink_two_sided.py | 1 + .../fused_moe/prepare_finalize/naive_dp_ep.py | 1 + .../fused_moe/prepare_finalize/no_dp_ep.py | 1 + vllm/model_executor/layers/fused_moe/utils.py | 5 +- .../layers/quantization/utils/mxfp8_utils.py | 18 +++++-- 12 files changed, 76 insertions(+), 92 deletions(-) diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 2fdcdceae836..54b796fde3bf 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -36,7 +36,7 @@ th { | deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ht.DeepEPHTPrepareAndFinalize] | | deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ll.DeepEPLLPrepareAndFinalize] | | flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | 
[`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_two_sided.FlashInferNVLinkTwoSidedPrepareAndFinalize] | -| flashinfer_nvlink_one_sided | standard | nvfp4,bf16 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] | +| flashinfer_nvlink_one_sided | standard | nvfp4,bf16,mxfp8 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] | !!! info "Table key" 1. All types: mxfp4, nvfp4, int4, int8, fp8 diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index f871ce9c6ff6..57ef6e9cf148 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -578,17 +578,9 @@ def initialize( num_experts: int, hidden_size: int, dispatch_dtype_bytes_per_elem: int = 0, - dispatch_has_fp8_scale: bool = True, + dispatch_scale_bytes_per_token: int = 0, ): - """Initialize the MoeAlltoAll workspace. - - dispatch_dtype_bytes_per_elem: bytes/elem for the dispatched hidden - states. Use 0 as a sentinel for sub-byte nvfp4 (0.5 B/elem); use - 1 for fp8, 2 for bf16/fp16. - dispatch_has_fp8_scale: whether a per-16-elem fp8 scale tensor is - dispatched alongside the hidden states (true for nvfp4/fp8, - false for bf16 passthrough). 
- """ + """Initialize the MoeAlltoAll workspace.""" if self.initialized: return @@ -618,13 +610,12 @@ def initialize( comm_backend=CustomCommunicator(self.cpu_group), ) if dispatch_dtype_bytes_per_elem == 0: - hidden_bytes = hidden_size // 2 # nvfp4 + hidden_bytes = hidden_size // 2 else: hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem - scale_bytes = hidden_size // 16 if dispatch_has_fp8_scale else 0 total_dispatch_payload_size_per_token = ( hidden_bytes - + scale_bytes + + dispatch_scale_bytes_per_token + top_k * 4 # int32 topks ids + top_k * 4 # float32 topk weights ) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index aeef2c32fd16..2a6f0c71d936 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -92,7 +92,6 @@ def maybe_make_prepare_finalize( routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, allow_new_interface: bool = False, use_monolithic: bool = False, - defer_input_quant: bool = False, ) -> FusedMoEPrepareAndFinalize | None: # NOTE(rob): we are migrating each quant_method to hold the MK # in all cases. The allow_new_interface=False flag allow us to fall @@ -229,27 +228,27 @@ def maybe_make_prepare_finalize( elif moe.use_fi_nvl_one_sided_kernels: assert quant_config is not None - if quant_config.quant_dtype != "nvfp4": - raise ValueError( - "The 'flashinfer_nvlink_one_sided' all2all backend only " - "supports nvfp4 activation quantization, but got " - f"quant_dtype={quant_config.quant_dtype!r}. Use a different " - "all2all backend (e.g. 'flashinfer_nvlink_two_sided' or " - "'allgather_reducescatter') for non-nvfp4 models." - ) max_num_tokens = ( get_current_vllm_config().scheduler_config.max_num_batched_tokens ) - if defer_input_quant or quant_config.quant_dtype is None: - # Experts (e.g. 
trtllm_mxfp4 with mxfp8 activations) quantize - # post-dispatch; ship bf16 tokens with no per-token scale payload. - dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 2, False + if quant_config.quant_dtype is None: + dispatch_dtype_bytes_per_elem = 2 + dispatch_scale_bytes_per_token = 0 elif quant_config.quant_dtype == "nvfp4": - dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 0, True + dispatch_dtype_bytes_per_elem = 0 + dispatch_scale_bytes_per_token = moe.hidden_dim // 16 + elif quant_config.quant_dtype == "mxfp8": + dispatch_dtype_bytes_per_elem = 1 + align = quant_config.mx_alignment + if align > 0: + padded_k = ((moe.hidden_dim + align - 1) // align) * align + else: + padded_k = moe.hidden_dim + dispatch_scale_bytes_per_token = padded_k // 32 else: raise NotImplementedError( - "flashinfer_nvlink_one_sided dispatch only supports nvfp4, " - "bf16, and defer_input_quant paths today; got " + "flashinfer_nvlink_one_sided dispatch supports nvfp4, mxfp8, " + "and bf16 (quant_dtype=None) today; got " f"quant_dtype={quant_config.quant_dtype!r}" ) prepare_finalize = FlashInferNVLinkOneSidedPrepareAndFinalize( @@ -259,7 +258,7 @@ def maybe_make_prepare_finalize( hidden_size=moe.hidden_dim, num_dispatchers=all2all_manager.world_size, dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem, - dispatch_has_fp8_scale=dispatch_has_fp8_scale, + dispatch_scale_bytes_per_token=dispatch_scale_bytes_per_token, ) elif moe.use_ag_rs_all2all_kernels and allow_new_interface: diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 565df1324f62..8ffa5cffb551 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -254,6 +254,8 @@ class FusedMoEQuantConfig: gemm1_beta: float | None = None gemm1_clamp_limit: float | None = None + mx_alignment: int = 0 + def __post_init__(self): assert not self.per_act_token_quant or self.block_shape is None, ( "illegal 
quantization" @@ -712,6 +714,7 @@ def mxfp4_mxfp8_moe_quant_config( gemm1_alpha: float | None = None, gemm1_beta: float | None = None, gemm1_clamp_limit: float | None = None, + mx_alignment: int = 0, ) -> FusedMoEQuantConfig: """ Construct a quant config for mxfp4 activations and mxfp4 weights. @@ -724,6 +727,7 @@ def mxfp4_mxfp8_moe_quant_config( gemm1_alpha=gemm1_alpha, gemm1_beta=gemm1_beta, gemm1_clamp_limit=gemm1_clamp_limit, + mx_alignment=mx_alignment, ) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py index f7af9aea70ad..01f455296f3a 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py @@ -82,9 +82,6 @@ def __init__( get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) - # P1-5 fix: use public quant_dtype property instead of private _a1 - self.use_mxfp8_input = quant_config.quant_dtype == "mxfp8" - @staticmethod def _supports_current_device() -> bool: p = current_platform @@ -121,8 +118,7 @@ def supports_expert_map(self) -> bool: @property def expects_unquantized_inputs(self) -> bool: - # Expert handles MXFP8 quantization internally if needed - return True + return False class TrtLlmMxfp4ExpertsMonolithic( @@ -181,24 +177,19 @@ def apply( ) -> torch.Tensor: from flashinfer import trtllm_fp4_block_scale_moe - # Handle input quantization - if self.use_mxfp8_input: - from flashinfer import mxfp8_quantize - - x_quant, x_scale = mxfp8_quantize( - hidden_states, - is_sf_swizzled_layout=False, - alignment=256, - ) - x_scale = x_scale.view(torch.float8_e4m3fn).reshape( - *hidden_states.shape[:-1], -1 - ) + if a1q_scale is not None: + x_quant = hidden_states + x_scale = a1q_scale.view(torch.float8_e4m3fn) else: assert hidden_states.dtype == torch.bfloat16 x_quant = hidden_states x_scale = None - - output = torch.empty_like(hidden_states) + output = 
torch.empty( + *hidden_states.shape[:-1], + self.hidden_dim, + dtype=torch.bfloat16, + device=hidden_states.device, + ) from vllm.utils.flashinfer import _is_fi_autotuning, autotune @@ -244,10 +235,6 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula Moved from trtllm_moe.py. """ - @property - def expects_unquantized_inputs(self) -> bool: - return True - @staticmethod def _supports_parallel_config( moe_parallel_config: FusedMoEParallelConfig, @@ -310,18 +297,9 @@ def apply( intermediate_size = self.intermediate_size_per_partition local_expert_offset = self.moe_config.ep_rank * local_num_experts - # Handle input quantization - if self.use_mxfp8_input: - from flashinfer import mxfp8_quantize - - x_quant, x_scale = mxfp8_quantize( - hidden_states, - is_sf_swizzled_layout=False, - alignment=256, - ) - x_scale = x_scale.view(torch.float8_e4m3fn).reshape( - *hidden_states.shape[:-1], -1 - ) + if a1q_scale is not None: + x_quant = hidden_states + x_scale = a1q_scale.view(torch.float8_e4m3fn) else: assert hidden_states.dtype == torch.bfloat16 x_quant = hidden_states diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index 55b1f1185a8d..c1423362d737 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -1195,10 +1195,18 @@ def make_mxfp4_moe_quant_config( gemm1_beta=gemm1_beta, gemm1_clamp_limit=swiglu_limit, ) - elif mxfp4_backend in ( - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, - ): + elif mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8: + return mxfp4_mxfp8_moe_quant_config( + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=swiglu_limit, + mx_alignment=256, + ) + elif mxfp4_backend == 
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8: return mxfp4_mxfp8_moe_quant_config( w1_bias=w1_bias, w2_bias=w2_bias, @@ -1250,23 +1258,12 @@ def make_mxfp4_moe_kernel( """Create a FusedMoEKernel for the given MXFP4 backend.""" is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic) - # Some experts (trtllm_mxfp4 with mxfp8 activations) prefer bf16 tokens - # on dispatch and quantize internally; signal this to the prepare/finalize - # so workspace + prepare path ship bf16 instead of the quant_config dtype. - from vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe import ( - TrtLlmMxfp4ExpertsBase, - ) - - defer_input_quant = issubclass(experts_cls, TrtLlmMxfp4ExpertsBase) - - # Create Prepare/Finalize. prepare_finalize = maybe_make_prepare_finalize( moe=moe_config, quant_config=moe_quant_config, routing_tables=routing_tables, allow_new_interface=True, use_monolithic=is_monolithic, - defer_input_quant=defer_input_quant, ) assert prepare_finalize is not None diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py index 40b94cfbec61..ef8cc7d17b21 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py @@ -32,7 +32,7 @@ def __init__( hidden_size: int, num_dispatchers: int = 1, dispatch_dtype_bytes_per_elem: int = 0, - dispatch_has_fp8_scale: bool = True, + dispatch_scale_bytes_per_token: int = 0, ): super().__init__() self.max_num_tokens = max_num_tokens @@ -40,6 +40,7 @@ def __init__( self.num_experts = num_experts self.hidden_size = hidden_size self.num_dispatchers_ = num_dispatchers + self.scale_elems_per_token = dispatch_scale_bytes_per_token device_communicator = get_ep_group().device_communicator assert device_communicator is not None @@ -52,7 +53,7 @@ def __init__( 
num_experts=self.num_experts, hidden_size=self.hidden_size, dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem, - dispatch_has_fp8_scale=dispatch_has_fp8_scale, + dispatch_scale_bytes_per_token=dispatch_scale_bytes_per_token, ) @property @@ -97,8 +98,6 @@ def prepare( ) if defer_input_quant: - # Experts (e.g. trtllm_mxfp4_moe with mxfp8 activations) will - # quantize post-dispatch. Ship bf16 tokens and skip scales. a1q, a1q_scale = a1, None else: a1q, a1q_scale = moe_kernel_quantize_input( @@ -108,6 +107,7 @@ def prepare( quant_config.per_act_token_quant, quant_config.block_shape, is_fp4_scale_swizzled=False, # delay swizzle to after comm + mx_alignment=quant_config.mx_alignment, ) payloads = [] @@ -133,7 +133,8 @@ def prepare( a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1]) a1q_scale_recv = a1q_scale_recv.view(torch.uint8) a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv) - a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16) + assert self.scale_elems_per_token > 0 + a1q_scale_recv = a1q_scale_recv.view(-1, self.scale_elems_per_token) else: a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads a1q_scale_recv = None diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py index 47fe293d511e..78be414759f7 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py @@ -174,6 +174,7 @@ def flashinfer_alltoall_dispatch( # the hidden states, breaking the A2A kernel. So, we # delay the swizzling until after the A2A. 
is_fp4_scale_swizzled=False, + mx_alignment=quant_config.mx_alignment, ) x = MnnvlMoe.mnnvl_moe_alltoallv( diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py index 2b21e2db9f68..5b3325ad0195 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py @@ -40,6 +40,7 @@ def _quantize_and_setup_dispatch( per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape, is_fp4_scale_swizzled=False, + mx_alignment=quant_config.mx_alignment, ) # Skip gathering scales if we have static quantization diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py index b9d57da08326..31a35bd60218 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py @@ -31,6 +31,7 @@ def _quantize_input( per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape, is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled, + mx_alignment=quant_config.mx_alignment, ) return a1q, a1q_scale diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index ffab3ca0bfa9..23d3f53fe6a7 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -208,11 +208,12 @@ def _mxfp8_e4m3_quantize( per_act_token_quant: bool, block_shape: list[int] | None = None, is_sf_swizzled_layout: bool = False, + mx_alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: assert A_scale is None assert not per_act_token_quant assert block_shape is None or block_shape == [1, 32] - return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout) + return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout, 
mx_alignment) def _mxfp6_e3m2_quantize( @@ -258,6 +259,7 @@ def moe_kernel_quantize_input( is_fp4_scale_swizzled: bool = True, ocp_mx_scheme: str | None = None, quantization_emulation: bool = False, + mx_alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor | None]: # Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation if ocp_mx_scheme is not None: @@ -320,6 +322,7 @@ def moe_kernel_quantize_input( per_act_token_quant, block_shape, is_sf_swizzled_layout=is_fp4_scale_swizzled, + mx_alignment=mx_alignment, ) elif quant_dtype == "mxfp6_e3m2": if not quantization_emulation: diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py index b9b7bd542738..9237d03efd98 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -85,7 +85,9 @@ def _mxfp8_e4m3_quantize_torch( def _mxfp8_e4m3_quantize_impl( - x: torch.Tensor, is_sf_swizzled_layout: bool = False + x: torch.Tensor, + is_sf_swizzled_layout: bool = False, + alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: from vllm.platforms import current_platform @@ -93,7 +95,9 @@ def _mxfp8_e4m3_quantize_impl( from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize x_q, x_scales = flashinfer_mxfp8_quantize( - x, is_sf_swizzled_layout=is_sf_swizzled_layout + x, + is_sf_swizzled_layout=is_sf_swizzled_layout, + alignment=alignment if alignment > 0 else None, ) if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout: x_scales = x_scales.view(x.size(0), -1) @@ -103,9 +107,11 @@ def _mxfp8_e4m3_quantize_impl( def mxfp8_e4m3_quantize( - x: torch.Tensor, is_sf_swizzled_layout: bool = False + x: torch.Tensor, + is_sf_swizzled_layout: bool = False, + alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: - return torch.ops.vllm.mxfp8_quantize(x, is_sf_swizzled_layout) + return torch.ops.vllm.mxfp8_quantize(x, 
is_sf_swizzled_layout, alignment) def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor: @@ -125,7 +131,9 @@ def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor def mxfp8_e4m3_quantize_fake( - x: torch.Tensor, is_sf_swizzled_layout: bool = False + x: torch.Tensor, + is_sf_swizzled_layout: bool = False, + alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: """Fake implementation for torch.compile tracing.""" fp_data = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE) From 44ff180df72e8311f5e37d6861956f2cc7a32850 Mon Sep 17 00:00:00 2001 From: Zijing Liu Date: Tue, 28 Apr 2026 18:22:27 -0700 Subject: [PATCH 3/6] Sanitize unfilled recv slots in flashinfer_nvlink_one_sided dispatch (#9) Padded rows in the [ep_size, max_num_tokens, ...] workspace retain stale topk_ids from prior dispatch calls (the workspace is zeroed only once at init). Those stale ids cause the downstream trtllm_fp4 grouped GEMM to do phantom work for random local experts every layer, which (a) inflates expert GEMM time and (b) creates the cross-rank skew that the combine kernel then has to wait on. Setting `invalid_token_expert_id` to `num_experts` (one past the valid expert range) makes the flashinfer worker overwrite all top_k topk_ids slots of padded rows with that sentinel (moeA2ASanitizeExpertIdsKernel in moeAlltoAllKernels.cu); the trtllm grouped GEMM then sees those rows as routed to no local expert (out of [local_expert_offset, local_expert_offset + local_num_experts)) and skips them. 
Signed-off-by: Zijing Liu Signed-off-by: Yongye Zhu --- .../fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py index ef8cc7d17b21..b41700d9d82e 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py @@ -114,6 +114,7 @@ def prepare( payloads.append(a1q) if a1q_scale is not None: payloads.append(a1q_scale) + topk_ids_payload_index = len(payloads) payloads.append(topk_ids) payloads.append(topk_weights) @@ -122,6 +123,8 @@ def prepare( token_selected_experts=topk_ids, input_payloads=payloads, runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, + invalid_token_expert_id=num_experts, + expert_id_payload_index=topk_ids_payload_index, ) if a1q_scale is not None: a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads From 43d304ecd7497d09d1b07a2ff7e3db07b98630c4 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Thu, 30 Apr 2026 04:46:55 +0000 Subject: [PATCH 4/6] update Signed-off-by: Yongye Zhu --- .../fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py | 2 +- vllm/model_executor/layers/fused_moe/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py index b41700d9d82e..6cc0d01cde6b 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py @@ -123,7 +123,7 @@ def prepare( token_selected_experts=topk_ids, input_payloads=payloads, 
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, - invalid_token_expert_id=num_experts, + invalid_token_expert_id=-1, # Follow TRTLLM Pattern expert_id_payload_index=topk_ids_payload_index, ) if a1q_scale is not None: diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 23d3f53fe6a7..ed24cbe2b233 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -321,7 +321,7 @@ def moe_kernel_quantize_input( A_scale, per_act_token_quant, block_shape, - is_sf_swizzled_layout=is_fp4_scale_swizzled, + is_sf_swizzled_layout=False, mx_alignment=mx_alignment, ) elif quant_dtype == "mxfp6_e3m2": From 57226cf3cad9e74c423b6913ccfd6469681e703e Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Thu, 30 Apr 2026 06:05:23 +0000 Subject: [PATCH 5/6] Use hidden_dim_unpadded for trtllm_mxfp4 output buffer When prepare-side mxfp8_quantize pads K to mx_alignment (e.g. gpt-oss hidden=2880 -> 3072 with align=256), pre-PR's torch.empty_like(hidden_states) naturally produced an unpadded output because hidden_states was the original bf16 input. With prepare-side quantize, hidden_states entering apply() is the padded fp8 tensor, so allocating output by self.hidden_dim (which is the post-roundup padded value from maybe_roundup_sizes) propagates padding into lm_head. Use moe_config.hidden_dim_unpadded so trtllm internally truncates back to the original hidden, matching pre-PR behavior. Apply the same fix to the modular workspace_shapes for non-aligned hiddens with EP. 
Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Yongye Zhu --- .../layers/fused_moe/experts/trtllm_mxfp4_moe.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py index 01f455296f3a..69e5b7fe4f0e 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py @@ -44,6 +44,9 @@ def __init__( moe_config.intermediate_size_per_partition ) self.hidden_dim = moe_config.hidden_dim + self.hidden_dim_unpadded = ( + moe_config.hidden_dim_unpadded or moe_config.hidden_dim + ) self.local_num_experts = moe_config.num_local_experts self.ep_rank = moe_config.moe_parallel_config.ep_rank @@ -186,7 +189,7 @@ def apply( x_scale = None output = torch.empty( *hidden_states.shape[:-1], - self.hidden_dim, + self.hidden_dim_unpadded, dtype=torch.bfloat16, device=hidden_states.device, ) @@ -271,7 +274,7 @@ def workspace_shapes( # The workspaces for this implementation are managed by flashinfer. 
workspace1 = (0,) workspace2 = (0,) - output = (M, K) + output = (M, self.hidden_dim_unpadded) return (workspace1, workspace2, output) def apply( From 96d415aad8dc2cc4d187f829b603a69e6df6a39e Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Thu, 30 Apr 2026 16:33:43 +0000 Subject: [PATCH 6/6] change alignment none to default32 Signed-off-by: Yongye Zhu --- vllm/model_executor/layers/quantization/utils/mxfp8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py index 9237d03efd98..a12918225348 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -97,7 +97,7 @@ def _mxfp8_e4m3_quantize_impl( x_q, x_scales = flashinfer_mxfp8_quantize( x, is_sf_swizzled_layout=is_sf_swizzled_layout, - alignment=alignment if alignment > 0 else None, + alignment=alignment if alignment > 0 else 32, ) if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout: x_scales = x_scales.view(x.size(0), -1)