diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 4e3706645ef2..54b796fde3bf 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -36,7 +36,7 @@ th { | deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ht.DeepEPHTPrepareAndFinalize] | | deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ll.DeepEPLLPrepareAndFinalize] | | flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_two_sided.FlashInferNVLinkTwoSidedPrepareAndFinalize] | -| flashinfer_nvlink_one_sided | standard | nvfp4 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] | +| flashinfer_nvlink_one_sided | standard | nvfp4,bf16,mxfp8 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] | !!! info "Table key" 1. All types: mxfp4, nvfp4, int4, int8, fp8 diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 340b6ff1cf29..57ef6e9cf148 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -577,6 +577,8 @@ def initialize( top_k: int, num_experts: int, hidden_size: int, + dispatch_dtype_bytes_per_elem: int = 0, + dispatch_scale_bytes_per_token: int = 0, ): """Initialize the MoeAlltoAll workspace.""" if self.initialized: @@ -607,9 +609,13 @@ def initialize( ep_config = MnnvlConfig( comm_backend=CustomCommunicator(self.cpu_group), ) + if dispatch_dtype_bytes_per_elem == 0: + hidden_bytes = hidden_size // 2 + else: + hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem total_dispatch_payload_size_per_token = ( - hidden_size // 2 # nvfp4 hidden states - + hidden_size // 16 # fp8 scaling factors + hidden_bytes + + dispatch_scale_bytes_per_token + top_k * 4 # int32 topks ids + top_k * 4 # float32 topk weights ) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index fba1d4c692af..2a6f0c71d936 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -228,23 +228,37 @@ def maybe_make_prepare_finalize( elif moe.use_fi_nvl_one_sided_kernels: assert quant_config is not None - if quant_config.quant_dtype != "nvfp4": - raise ValueError( - "The 'flashinfer_nvlink_one_sided' all2all backend only " - "supports nvfp4 activation quantization, but got " - f"quant_dtype={quant_config.quant_dtype!r}. Use a different " - "all2all backend (e.g. 'flashinfer_nvlink_two_sided' or " - "'allgather_reducescatter') for non-nvfp4 models." 
-        )
         max_num_tokens = (
             get_current_vllm_config().scheduler_config.max_num_batched_tokens
         )
+        if quant_config.quant_dtype is None:
+            dispatch_dtype_bytes_per_elem = 2
+            dispatch_scale_bytes_per_token = 0
+        elif quant_config.quant_dtype == "nvfp4":
+            dispatch_dtype_bytes_per_elem = 0
+            dispatch_scale_bytes_per_token = moe.hidden_dim // 16
+        elif quant_config.quant_dtype == "mxfp8":
+            dispatch_dtype_bytes_per_elem = 1
+            align = quant_config.mx_alignment
+            if align > 0:
+                padded_k = ((moe.hidden_dim + align - 1) // align) * align
+            else:
+                padded_k = moe.hidden_dim
+            dispatch_scale_bytes_per_token = padded_k // 32
+        else:
+            raise NotImplementedError(
+                "flashinfer_nvlink_one_sided dispatch supports nvfp4, mxfp8, "
+                "and bf16 (quant_dtype=None) today; got "
+                f"quant_dtype={quant_config.quant_dtype!r}"
+            )
         prepare_finalize = FlashInferNVLinkOneSidedPrepareAndFinalize(
             max_num_tokens=max_num_tokens,
             top_k=moe.experts_per_token,
             num_experts=moe.num_experts,
             hidden_size=moe.hidden_dim,
             num_dispatchers=all2all_manager.world_size,
+            dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem,
+            dispatch_scale_bytes_per_token=dispatch_scale_bytes_per_token,
         )

     elif moe.use_ag_rs_all2all_kernels and allow_new_interface:
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 565df1324f62..8ffa5cffb551 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -254,6 +254,8 @@ class FusedMoEQuantConfig:
     gemm1_beta: float | None = None
     gemm1_clamp_limit: float | None = None

+    mx_alignment: int = 0
+
    def __post_init__(self):
         assert not self.per_act_token_quant or self.block_shape is None, (
             "illegal quantization"
@@ -712,6 +714,7 @@ def mxfp4_mxfp8_moe_quant_config(
     gemm1_alpha: float | None = None,
     gemm1_beta: float | None = None,
     gemm1_clamp_limit: float | None = None,
+    mx_alignment: int = 0,
 ) -> FusedMoEQuantConfig:
     """
     Construct a quant config for mxfp8 activations and mxfp4 weights.
@@ -724,6 +727,7 @@ def mxfp4_mxfp8_moe_quant_config( gemm1_alpha=gemm1_alpha, gemm1_beta=gemm1_beta, gemm1_clamp_limit=gemm1_clamp_limit, + mx_alignment=mx_alignment, ) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py index f7af9aea70ad..69e5b7fe4f0e 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py @@ -44,6 +44,9 @@ def __init__( moe_config.intermediate_size_per_partition ) self.hidden_dim = moe_config.hidden_dim + self.hidden_dim_unpadded = ( + moe_config.hidden_dim_unpadded or moe_config.hidden_dim + ) self.local_num_experts = moe_config.num_local_experts self.ep_rank = moe_config.moe_parallel_config.ep_rank @@ -82,9 +85,6 @@ def __init__( get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) - # P1-5 fix: use public quant_dtype property instead of private _a1 - self.use_mxfp8_input = quant_config.quant_dtype == "mxfp8" - @staticmethod def _supports_current_device() -> bool: p = current_platform @@ -121,8 +121,7 @@ def supports_expert_map(self) -> bool: @property def expects_unquantized_inputs(self) -> bool: - # Expert handles MXFP8 quantization internally if needed - return True + return False class TrtLlmMxfp4ExpertsMonolithic( @@ -181,24 +180,19 @@ def apply( ) -> torch.Tensor: from flashinfer import trtllm_fp4_block_scale_moe - # Handle input quantization - if self.use_mxfp8_input: - from flashinfer import mxfp8_quantize - - x_quant, x_scale = mxfp8_quantize( - hidden_states, - is_sf_swizzled_layout=False, - alignment=256, - ) - x_scale = x_scale.view(torch.float8_e4m3fn).reshape( - *hidden_states.shape[:-1], -1 - ) + if a1q_scale is not None: + x_quant = hidden_states + x_scale = a1q_scale.view(torch.float8_e4m3fn) else: assert hidden_states.dtype == torch.bfloat16 x_quant = hidden_states x_scale = None - - output = torch.empty_like(hidden_states) + output = torch.empty( + *hidden_states.shape[:-1], + self.hidden_dim_unpadded, + dtype=torch.bfloat16, + device=hidden_states.device, + ) from vllm.utils.flashinfer import _is_fi_autotuning, autotune @@ -244,10 +238,6 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula Moved from trtllm_moe.py. """ - @property - def expects_unquantized_inputs(self) -> bool: - return True - @staticmethod def _supports_parallel_config( moe_parallel_config: FusedMoEParallelConfig, @@ -284,7 +274,7 @@ def workspace_shapes( # The workspaces for this implementation are managed by flashinfer. 
workspace1 = (0,) workspace2 = (0,) - output = (M, K) + output = (M, self.hidden_dim_unpadded) return (workspace1, workspace2, output) def apply( @@ -310,18 +300,9 @@ def apply( intermediate_size = self.intermediate_size_per_partition local_expert_offset = self.moe_config.ep_rank * local_num_experts - # Handle input quantization - if self.use_mxfp8_input: - from flashinfer import mxfp8_quantize - - x_quant, x_scale = mxfp8_quantize( - hidden_states, - is_sf_swizzled_layout=False, - alignment=256, - ) - x_scale = x_scale.view(torch.float8_e4m3fn).reshape( - *hidden_states.shape[:-1], -1 - ) + if a1q_scale is not None: + x_quant = hidden_states + x_scale = a1q_scale.view(torch.float8_e4m3fn) else: assert hidden_states.dtype == torch.bfloat16 x_quant = hidden_states diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index f476d980d555..c1423362d737 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -1195,10 +1195,18 @@ def make_mxfp4_moe_quant_config( gemm1_beta=gemm1_beta, gemm1_clamp_limit=swiglu_limit, ) - elif mxfp4_backend in ( - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, - ): + elif mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8: + return mxfp4_mxfp8_moe_quant_config( + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=swiglu_limit, + mx_alignment=256, + ) + elif mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8: return mxfp4_mxfp8_moe_quant_config( w1_bias=w1_bias, w2_bias=w2_bias, @@ -1250,7 +1258,6 @@ def make_mxfp4_moe_kernel( """Create a FusedMoEKernel for the given MXFP4 backend.""" is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic) - # Create Prepare/Finalize. 
prepare_finalize = maybe_make_prepare_finalize( moe=moe_config, quant_config=moe_quant_config, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py index a04ff3b8b68f..6cc0d01cde6b 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py @@ -31,6 +31,8 @@ def __init__( num_experts: int, hidden_size: int, num_dispatchers: int = 1, + dispatch_dtype_bytes_per_elem: int = 0, + dispatch_scale_bytes_per_token: int = 0, ): super().__init__() self.max_num_tokens = max_num_tokens @@ -38,6 +40,7 @@ def __init__( self.num_experts = num_experts self.hidden_size = hidden_size self.num_dispatchers_ = num_dispatchers + self.scale_elems_per_token = dispatch_scale_bytes_per_token device_communicator = get_ep_group().device_communicator assert device_communicator is not None @@ -49,6 +52,8 @@ def __init__( top_k=self.top_k, num_experts=self.num_experts, hidden_size=self.hidden_size, + dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem, + dispatch_scale_bytes_per_token=dispatch_scale_bytes_per_token, ) @property @@ -92,19 +97,24 @@ def prepare( else a1.shape[0] ) - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=False, # delay swizzle to after comm - ) + if defer_input_quant: + a1q, a1q_scale = a1, None + else: + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=False, # delay swizzle to after comm + mx_alignment=quant_config.mx_alignment, + ) payloads = [] payloads.append(a1q) if a1q_scale is not None: payloads.append(a1q_scale) + topk_ids_payload_index = len(payloads) payloads.append(topk_ids) payloads.append(topk_weights) @@ -113,6 +123,8 @@ def prepare( token_selected_experts=topk_ids, input_payloads=payloads, runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, + invalid_token_expert_id=-1, # Follow TRTLLM Pattern + expert_id_payload_index=topk_ids_payload_index, ) if a1q_scale is not None: a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads @@ -124,7 +136,8 @@ def prepare( a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1]) a1q_scale_recv = a1q_scale_recv.view(torch.uint8) a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv) - a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16) + assert self.scale_elems_per_token > 0 + a1q_scale_recv = a1q_scale_recv.view(-1, self.scale_elems_per_token) else: a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads a1q_scale_recv = None diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py index 47fe293d511e..78be414759f7 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py @@ -174,6 +174,7 @@ def flashinfer_alltoall_dispatch( # the hidden states, breaking the A2A kernel. So, we # delay the swizzling until after the A2A. 
             is_fp4_scale_swizzled=False,
+            mx_alignment=quant_config.mx_alignment,
         )

         x = MnnvlMoe.mnnvl_moe_alltoallv(
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
index 2b21e2db9f68..5b3325ad0195 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
@@ -40,6 +40,7 @@ def _quantize_and_setup_dispatch(
         per_act_token_quant=quant_config.per_act_token_quant,
         block_shape=quant_config.block_shape,
         is_fp4_scale_swizzled=False,
+        mx_alignment=quant_config.mx_alignment,
     )

     # Skip gathering scales if we have static quantization
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py
index b9d57da08326..31a35bd60218 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py
@@ -31,6 +31,7 @@ def _quantize_input(
         per_act_token_quant=quant_config.per_act_token_quant,
         block_shape=quant_config.block_shape,
         is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
+        mx_alignment=quant_config.mx_alignment,
     )
     return a1q, a1q_scale
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index ffab3ca0bfa9..ed24cbe2b233 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -208,11 +208,12 @@ def _mxfp8_e4m3_quantize(
     per_act_token_quant: bool,
     block_shape: list[int] | None = None,
     is_sf_swizzled_layout: bool = False,
+    mx_alignment: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     assert A_scale is None
     assert not per_act_token_quant
     assert block_shape is None or block_shape == [1, 32]
-    return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout)
+    return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout, mx_alignment)


 def _mxfp6_e3m2_quantize(
@@ -258,6 +259,7 @@ def moe_kernel_quantize_input(
     is_fp4_scale_swizzled: bool = True,
     ocp_mx_scheme: str | None = None,
     quantization_emulation: bool = False,
+    mx_alignment: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
     # Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation
     if ocp_mx_scheme is not None:
@@ -319,7 +321,8 @@ def moe_kernel_quantize_input(
             A_scale,
             per_act_token_quant,
             block_shape,
-            is_sf_swizzled_layout=is_fp4_scale_swizzled,
+            is_sf_swizzled_layout=False,
+            mx_alignment=mx_alignment,
         )
     elif quant_dtype == "mxfp6_e3m2":
         if not quantization_emulation:
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
index b9b7bd542738..a12918225348 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
@@ -85,7 +85,9 @@ def _mxfp8_e4m3_quantize_torch(


 def _mxfp8_e4m3_quantize_impl(
-    x: torch.Tensor, is_sf_swizzled_layout: bool = False
+    x: torch.Tensor,
+    is_sf_swizzled_layout: bool = False,
+    alignment: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     from vllm.platforms import current_platform

@@ -93,7 +95,9 @@ def _mxfp8_e4m3_quantize_impl(
         from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize

         x_q, x_scales = flashinfer_mxfp8_quantize(
-            x, is_sf_swizzled_layout=is_sf_swizzled_layout
+            x,
+            is_sf_swizzled_layout=is_sf_swizzled_layout,
+            alignment=alignment if alignment > 0 else 32,
         )
         if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout:
             x_scales = x_scales.view(x.size(0), -1)
@@ -103,9 +107,11 @@


 def mxfp8_e4m3_quantize(
-    x: torch.Tensor, is_sf_swizzled_layout: bool = False
+    x: torch.Tensor,
+    is_sf_swizzled_layout: bool = False,
+    alignment: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    return torch.ops.vllm.mxfp8_quantize(x, is_sf_swizzled_layout)
+    return torch.ops.vllm.mxfp8_quantize(x, is_sf_swizzled_layout, alignment)


 def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
@@ -125,7 +131,9 @@ def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor


 def mxfp8_e4m3_quantize_fake(
-    x: torch.Tensor, is_sf_swizzled_layout: bool = False
+    x: torch.Tensor,
+    is_sf_swizzled_layout: bool = False,
+    alignment: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Fake implementation for torch.compile tracing."""
     fp_data = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE)
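
---

Note on the dispatch payload sizing above: the same per-token byte accounting now appears in both `MoeAlltoAll.initialize()` and `maybe_make_prepare_finalize()`. Below is a minimal, self-contained sketch of that accounting for the three supported activation dtypes. The helper `dispatch_bytes_per_token` is hypothetical (illustration only, not part of this patch or the vLLM API); `hidden_dim`, `top_k`, `quant_dtype`, and `mx_alignment` mirror the fields used in the diff, and the example dimensions are assumptions.

```python
# Hypothetical sketch of the per-token dispatch payload sizing used by the
# flashinfer_nvlink_one_sided backend; illustrative only, not vLLM API.


def dispatch_bytes_per_token(
    hidden_dim: int,
    top_k: int,
    quant_dtype: str | None,
    mx_alignment: int = 0,
) -> int:
    if quant_dtype is None:
        # bf16: 2 bytes per element, no scaling factors travel with the token.
        hidden_bytes, scale_bytes = 2 * hidden_dim, 0
    elif quant_dtype == "nvfp4":
        # nvfp4: 2 elements per byte, one fp8 scale per 16-element block.
        hidden_bytes, scale_bytes = hidden_dim // 2, hidden_dim // 16
    elif quant_dtype == "mxfp8":
        # mxfp8: 1 byte per element, one scale per 32-element block; the scale
        # count uses K padded up to mx_alignment, matching the output shape of
        # flashinfer's mxfp8_quantize(..., alignment=...).
        if mx_alignment > 0:
            padded_k = ((hidden_dim + mx_alignment - 1) // mx_alignment) * mx_alignment
        else:
            padded_k = hidden_dim
        hidden_bytes, scale_bytes = hidden_dim, padded_k // 32
    else:
        raise NotImplementedError(quant_dtype)
    # Every token also carries int32 top-k expert ids and float32 top-k weights.
    return hidden_bytes + scale_bytes + top_k * 4 + top_k * 4


# Example (assumed dims): hidden_dim=2880, top_k=4; the TRT-LLM mxfp4/mxfp8
# backend sets mx_alignment=256, padding K to 3072 for the scale count:
assert dispatch_bytes_per_token(2880, 4, "mxfp8", 256) == 2880 + 96 + 16 + 16
assert dispatch_bytes_per_token(2880, 4, "nvfp4") == 1440 + 180 + 16 + 16
```

In the patch itself, nvfp4 keeps `dispatch_dtype_bytes_per_elem = 0`, which makes `MoeAlltoAll.initialize()` fall back to its original `hidden_size // 2` half-byte sizing path, while bf16 and mxfp8 pass 2 and 1 bytes per element respectively.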