diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index d1b25aa92e8d..2ad620b0a7f1 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -47,9 +47,6 @@
     disable_inplace,
     moe_kernel_quantize_input,
 )
-from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
-from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
-from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
 from vllm.model_executor.utils import maybe_disable_graph_partition
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
@@ -2017,7 +2014,8 @@ def fused_experts_impl(
     # Check constraints.
     if use_int4_w4a16:
         assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
-    elif ocp_mx_scheme is not None:
+    elif ocp_mx_scheme is not None and w1_scale is not None:
+        # Size checks for packed weights (native MXFP or before dequantization)
         if ocp_mx_scheme in {
             "w_mxfp4_a_mxfp4",
             "w_mxfp4_a_mxfp6_e3m2",
@@ -2035,6 +2033,7 @@
         else:
             raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
     else:
+        # Normal weights or MXFP emulation (weights already dequantized)
         assert hidden_states.size(1) == w1.size(2), (
             f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
         )
@@ -2112,41 +2111,6 @@
     else:
         out_hidden_states = torch.empty_like(hidden_states)
 
-    if ocp_mx_scheme is not None:
-        # TODO: On platforms for which `current_platform.supports_mx()` is True
-        # and for which we have a native OCP mx fused MOE kernel,
-        # this dequantization step should not be done.
-        if ocp_mx_scheme in {
-            OCP_MX_Scheme.w_mxfp4_a_mxfp4,
-            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
-            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
-        }:
-            # Weight has to be dequantized for mxfp4 emulation.
-            w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
-            w1_scale = None
-            w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
-            w2_scale = None
-        elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
-            w1 = dequant_mxfp6(
-                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
-            )
-            w1_scale = None
-            w2 = dequant_mxfp6(
-                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
-            )
-            w2_scale = None
-        elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
-            w1 = dequant_mxfp6(
-                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
-            )
-            w1_scale = None
-            w2 = dequant_mxfp6(
-                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
-            )
-            w2_scale = None
-        else:
-            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
-
     for chunk in range((num_tokens // CHUNK_SIZE) + 1):
         begin_chunk_idx, end_chunk_idx = (
             chunk * CHUNK_SIZE,
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 6b731314825a..1de28d177170 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -25,6 +25,12 @@
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_fp8_moe_layer_for_marlin,
 )
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
+    dequant_mxfp4,
+)
+from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
+    dequant_mxfp6,
+)
 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
     OCP_MX_BLOCK_SIZE,
     OCP_MX_Scheme,
@@ -635,6 +641,56 @@ def get_packed_dim(self, dim: int, quant_dtype: str):
         assert (dim * 3) % 4 == 0
         return (dim * 3) // 4
 
+    def _dequantize_weights(
+        self,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        w1_scale: torch.Tensor,
+        w2_scale: torch.Tensor,
+        dtype: torch.dtype,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Dequantize MXFP4/MXFP6 weights to high precision for emulation.
+
+        Args:
+            w1: Packed w13 weights (uint8)
+            w2: Packed w2 weights (uint8)
+            w1_scale: Weight scales for w13
+            w2_scale: Weight scales for w2
+            dtype: Target dtype for dequantization (fp16/bf16/fp32)
+
+        Returns:
+            Tuple of (dequantized_w1, dequantized_w2)
+        """
+        if self.ocp_mx_scheme in {
+            OCP_MX_Scheme.w_mxfp4_a_mxfp4,
+            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
+            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
+        }:
+            # MXFP4 weights
+            dequant_w1 = dequant_mxfp4(w1, w1_scale, dtype)
+            dequant_w2 = dequant_mxfp4(w2, w2_scale, dtype)
+        elif self.ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
+            # MXFP6 e3m2 weights
+            dequant_w1 = dequant_mxfp6(
+                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=dtype
+            )
+            dequant_w2 = dequant_mxfp6(
+                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=dtype
+            )
+        elif self.ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
+            # MXFP6 e2m3 weights
+            dequant_w1 = dequant_mxfp6(
+                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=dtype
+            )
+            dequant_w2 = dequant_mxfp6(
+                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=dtype
+            )
+        else:
+            raise NotImplementedError(f"Unsupported ocp_mx_scheme={self.ocp_mx_scheme}")
+
+        return dequant_w1, dequant_w2
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -736,15 +792,29 @@ def process_weights_after_loading(self, layer):
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return ocp_mx_moe_quant_config(
-            quant_dtype=self.input_dtype,
-            weight_dtype=self.weight_dtype,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=None,
-            a2_scale=None,
-            block_shape=None,
-        )
+        if self.emulate:
+            # Emulation mode: weights are dequantized in apply(), but intermediate
+            # activations still need quantization. Set scales to None since
+            # weights are already dequantized.
+            return ocp_mx_moe_quant_config(
+                quant_dtype=self.input_dtype,
+                weight_dtype=self.weight_dtype,
+                w1_scale=None,
+                w2_scale=None,
+                a1_scale=None,
+                a2_scale=None,
+                block_shape=None,
+            )
+        else:
+            return ocp_mx_moe_quant_config(
+                quant_dtype=self.input_dtype,
+                weight_dtype=self.weight_dtype,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=None,
+                a2_scale=None,
+                block_shape=None,
+            )
 
     @property
     def allow_inplace(self) -> bool:
@@ -780,10 +850,19 @@ def apply(
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts
 
-            out = fused_experts(
-                x,
+            # Dequantize weights for MXFP emulation
+            dequant_w1, dequant_w2 = self._dequantize_weights(
                 layer.w13_weight,
                 layer.w2_weight,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                x.dtype,
+            )
+
+            out = fused_experts(
+                x,
+                dequant_w1,
+                dequant_w2,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 inplace=True,
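For reviewers unfamiliar with OCP MX formats, the sketch below illustrates numerically what the dequantization step moved into `_dequantize_weights` does: unpack the low-bit codes and multiply each block of 32 elements by its shared power-of-two (E8M0) scale. This is a minimal standalone reference, not vLLM's `dequant_mxfp4`; the nibble packing order (even-indexed element in the low nibble) and the uint8 biased-exponent scale layout are assumptions for illustration only.

```python
import torch

# E2M1 (MXFP4) code points -> float values, per the OCP MX spec.
FP4_E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)
MX_BLOCK = 32  # OCP MX block size (OCP_MX_BLOCK_SIZE in vLLM)


def dequant_mxfp4_reference(
    packed: torch.Tensor,      # uint8, shape (..., K // 2), two FP4 codes per byte
    scale_e8m0: torch.Tensor,  # uint8, shape (..., K // MX_BLOCK), biased exponents
    dtype: torch.dtype,
) -> torch.Tensor:
    # Unpack nibbles (assumption: low nibble holds the even-indexed element).
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    codes = torch.stack([lo, hi], dim=-1).reshape(*packed.shape[:-1], -1)
    vals = FP4_E2M1_LUT[codes.long()]
    # E8M0 block scale is a pure power of two: 2 ** (stored_exponent - 127).
    scale = torch.exp2(scale_e8m0.to(torch.float32) - 127.0)
    # Apply one scale per block of 32 dequantized elements.
    vals = vals.reshape(*vals.shape[:-1], -1, MX_BLOCK) * scale.unsqueeze(-1)
    return vals.reshape(*codes.shape).to(dtype)


if __name__ == "__main__":
    w = torch.randint(0, 256, (8, 64, 32), dtype=torch.uint8)  # packed, K = 64
    s = torch.full((8, 64, 2), 127, dtype=torch.uint8)         # scale = 2**0
    print(dequant_mxfp4_reference(w, s, torch.bfloat16).shape)  # (8, 64, 64)
```

With this PR, the emulation path consumes the weight scales inside `QuarkOCP_MX_MoEMethod.apply()` via `_dequantize_weights`, so `get_fused_moe_quant_config` passes `w1_scale=None`/`w2_scale=None` in emulation mode and `fused_experts_impl` falls through to the unpacked hidden-size check for the already-dequantized weights.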