diff --git a/tests/evals/gsm8k/configs/models-mi3xx.txt b/tests/evals/gsm8k/configs/models-mi3xx.txt index 6cf833b64642..dfa4bc8eb53f 100644 --- a/tests/evals/gsm8k/configs/models-mi3xx.txt +++ b/tests/evals/gsm8k/configs/models-mi3xx.txt @@ -2,3 +2,5 @@ DeepSeek-R1-TP_MI325.yaml DeepSeek-R1-DP_MI325.yaml DeepSeek-V3.2-TP_MI325.yaml DeepSeek-V3.2-DP_MI325.yaml +Qwen3-30B-A3B-NVFP4.yaml +Qwen3.5-35B-A3B-MXFP4-TP2.yaml \ No newline at end of file diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index 30f69f62130f..db3719b2f432 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -120,3 +120,20 @@ def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch): with vllm_runner(model, enforce_eager=eager) as llm: output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2) assert output[0][1] == "1 2 3 4 5 6" + + +# Qwen3-30B-A3B is 60 GB vs Llama-4-Scout-17B-16E-Instruct-FP4 that is 210 GB. +@pytest.mark.parametrize( + "model", + [ + "nvidia/Qwen3-30B-A3B-NVFP4", + "RedHatAI/Qwen3-30B-A3B-NVFP4", + ], +) +@pytest.mark.parametrize("eager", EAGER) +@pytest.mark.parametrize("backend", ["emulation"]) +def test_nvfp4_moe(vllm_runner, model, eager, backend, monkeypatch): + monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend) + with vllm_runner(model, enforce_eager=eager, moe_backend="emulation") as llm: + output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2) + assert output[0][1] == "1 2 3 4 5 6" diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index c51b3461c63b..a84fa3e936d1 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -228,6 +228,8 @@ def get_model_args( model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2", excepted_value=10.6, ), + # This one raises `RuntimeError: wrong! device_gemm with the specified compilation + # parameters does not support this GEMM problem` on MI355X. 
AccuracyTestConfig( model_name="fxmarty/qwen_1.5-moe-a2.7b-mxfp4", excepted_value=12.4 ), @@ -238,8 +240,13 @@ def get_model_args( not QUARK_MXFP4_AVAILABLE, reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available", ) -@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS) -@pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize( + "config", + [pytest.param(val, id=f"config:{val}") for val in WIKITEXT_ACCURACY_CONFIGS], +) +@pytest.mark.parametrize( + "tp_size", [pytest.param(val, id=f"tp_size:{val}") for val in [1, 2]] +) def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int): device_count = torch.accelerator.device_count() if device_count < tp_size: @@ -266,6 +273,54 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int): ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" +@pytest.mark.skipif( + not QUARK_MXFP4_AVAILABLE, + reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available", +) +@pytest.mark.parametrize("tp_size", [1, 2]) +def test_nvfp4_wikitext_correctness(tp_size: int): + device_count = torch.accelerator.device_count() + if device_count < tp_size: + pytest.skip(f"This test requires >={tp_size} gpus, got only {device_count}") + + # model_name = "amd-quark/Qwen3-30B-A3B-nvfp4-quark" + # NOTE: expected_value from nvidia/Qwen3-30B-A3B-NVFP4 + expected_value = 11.2391 + + model_name = "amd-quark/Qwen3-30B-A3B-nvfp4-quark" + task = "wikitext" + + rtol = 0.25 + + config = AccuracyTestConfig( + model_name=model_name, + excepted_value=expected_value, + ) + + model_args = config.get_model_args( + tp_size=tp_size, + kwargs={ + "cudagraph_capture_sizes": [16], + }, + ) + model_args.pop("add_bos_token") + + # Smaller cudagraph_capture_sizes to speed up the test. 
+ results = lm_eval.simple_evaluate( + model="vllm", + model_args=model_args, + tasks=task, + batch_size=64, + ) + + EXPECTED_VALUE = config.excepted_value + measured_value = results["results"][task]["word_perplexity,none"] + assert ( + measured_value < EXPECTED_VALUE + rtol + and measured_value > EXPECTED_VALUE - rtol + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + + @pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS) @pytest.mark.skipif( not QUARK_MXFP4_AVAILABLE, diff --git a/vllm/config/kernel.py b/vllm/config/kernel.py index f3ffbe4e8b19..8d8e37a0549a 100644 --- a/vllm/config/kernel.py +++ b/vllm/config/kernel.py @@ -115,6 +115,7 @@ def with_default( "flashinfer_cutedsl", "marlin", "aiter", + "emulation", ] @@ -142,7 +143,10 @@ class KernelConfig: - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only) - "marlin": Use Marlin kernels (weight-only quantization) - - "aiter": Use AMD AITer kernels (ROCm only)""" + - "aiter": Use AMD AITer kernels (ROCm only) + - "emulation": use BF16/FP16 GEMM, dequantizing weights and + running QDQ on activations. + """ @field_validator("moe_backend", mode="before") @classmethod diff --git a/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py new file mode 100644 index 000000000000..f1a0ee7ac52d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +NVFP4 quantization emulation for MoE. + +This file implements NVFP4 emulation for NVFP4 MOE in case the hardware used does not +natively support NVFP4 MOE. + +Weights are dequantized on the fly during each forward, we fall back to calling +`TritonExperts` using BF16, and fake NVFP4 quantize-dequantize +is applied on `a13`, `a2`. 
+""" + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( + dequantize_to_dtype, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kNvfp4Dynamic, + kNvfp4Static, +) + +logger = init_logger(__name__) + + +class Nvfp4QuantizationEmulationTritonExperts(TritonExperts): + """ + Extension of TritonExperts to support emulated NVFP4 MoE experts. + + It may be used for NVFP4 models when the device does not have + native support for this dtype. + """ + + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(moe_config, quant_config) + logger.warning_once( + "Using Nvfp4QuantizationEmulationTritonExperts MOE backend. This will" + " dequantize weights on the fly and may be slower than native" + " quantized MOE. Consider using a device with native quantization" + " support (e.g. Nvidia Blackwell) for better performance." + ) + + # `TritonExperts.apply` expects pre-dequantized weights, + # which we handle in `apply` below. 
+ self.w1_scale_val = self.quant_config.w1_scale + self.w2_scale_val = self.quant_config.w2_scale + + self.quant_config._w1.scale = None + self.quant_config._w2.scale = None + + self.quantization_emulation = True + + @property + def quant_dtype(self) -> torch.dtype | str | None: + return "nvfp4" + + @property + def expects_unquantized_inputs(self) -> bool: + return True + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + """ + Apply emulated quantized MoE computation. + + This dequantizes the weights on the fly and calls fused_experts_impl + with activation quantization support. 
+ """ + # Dequantize weights if they are quantized + # For NVFP4, weights are packed in uint8 format + # w1 shape: [num_experts, 2*intermediate_size, hidden_size//2] + # w2 shape: [num_experts, hidden_size, intermediate_size//2] + assert w1.dtype == torch.uint8 + assert w2.dtype == torch.uint8 + + # Dequantize w1 from packed NVFP4 to fp16/bf16 + w13_global_scale = self.quant_config.g1_alphas + + w1_dequant = dequantize_to_dtype( + tensor_fp4=w1, + tensor_sf=self.w1_scale_val, + global_scale=w13_global_scale, + dtype=hidden_states.dtype, + block_size=16, + swizzle=False, + ) + + # Dequantize w2 from packed NVFP4 to fp16/bf16 + w2_global_scale = self.quant_config.g2_alphas + + w2_dequant = dequantize_to_dtype( + tensor_fp4=w2, + tensor_sf=self.w2_scale_val, + global_scale=w2_global_scale, + dtype=hidden_states.dtype, + block_size=16, + swizzle=False, + ) + + hidden_states, _ = moe_kernel_quantize_input( + A=hidden_states, + A_scale=self.quant_config.a1_gscale, + quant_dtype="nvfp4", + per_act_token_quant=False, + quantization_emulation=True, + ) + + # Activation quantization/dequantization is deferred to + # `moe_kernel_quantize_input` in TritonExperts.apply. 
+ super().apply( + output=output, + hidden_states=hidden_states, + w1=w1_dequant, + w2=w2_dequant, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + a1q_scale=None, + a2_scale=self.quant_config.a2_gscale, + workspace13=workspace13, + workspace2=workspace2, + expert_tokens_meta=expert_tokens_meta, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py b/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py new file mode 100644 index 000000000000..f1d2669e6b89 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/experts/ocp_mx_emulation_moe.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +OCP MX quantization emulation for MoE. + +This file implements OCP MX (MXFP4/MXFP6) emulation for MoE in case the +hardware used does not natively support OCP MX MoE. + +Weights are dequantized on the fly during each forward, we fall back to calling +`TritonExperts` using BF16, and fake OCP MX quantize-dequantize +is applied on activations via `moe_kernel_quantize_input`. 
+""" + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 +from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 +from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( + OCP_MX_Scheme, +) + +logger = init_logger(__name__) + + +class OCP_MXQuantizationEmulationTritonExperts(TritonExperts): + """ + Extension of TritonExperts to support emulated OCP MX MoE experts. + + It may be used for OCP MX (MXFP4/MXFP6) models when the device does not + have native support for these dtypes. + """ + + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(moe_config, quant_config) + logger.warning_once( + "Using OCP_MXQuantizationEmulationTritonExperts MOE backend. This" + " will dequantize weights on the fly and may be slower than native" + " quantized MOE. Consider using a device with native OCP MX" + " quantization support for better performance." + ) + + self.ocp_mx_scheme = quant_config.ocp_mx_scheme + assert self.ocp_mx_scheme is not None, ( + "ocp_mx_scheme must be set in quant_config for" + " OCP_MXQuantizationEmulationTritonExperts" + ) + + # `TritonExperts.apply` expects pre-dequantized weights, + # which we handle in `apply` below. 
+ self.w1_scale_val = self.quant_config.w1_scale + self.w2_scale_val = self.quant_config.w2_scale + + self.quant_config._w1.scale = None + self.quant_config._w2.scale = None + + self.quantization_emulation = True + + if self.ocp_mx_scheme in { + OCP_MX_Scheme.w_mxfp4_a_mxfp4, + }: + # Weight has to be dequantized for mxfp4 emulation. + self._quant_dtype = "mxfp4" + elif self.ocp_mx_scheme in [ + OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2, + OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3, + OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2, + OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3, + ]: + self._quant_dtype = "mxfp6" + elif self.ocp_mx_scheme in [ + OCP_MX_Scheme.w_mxfp4_a_fp8, + OCP_MX_Scheme.w_mxfp6_e3m2_a_fp8, + ]: + # TODO: double check this one + self._quant_dtype = "mxfp8" + + @property + def quant_dtype(self) -> torch.dtype | str | None: + return self._quant_dtype + + @property + def expects_unquantized_inputs(self) -> bool: + return True + + @staticmethod + def _supports_quant_scheme( + weight_key, + activation_key, + ) -> bool: + # This class is used for emulation only - the oracle selects it + # directly rather than via quant scheme matching. 
+ return True + + def _dequant_weights( + self, + w: torch.Tensor, + w_scale: torch.Tensor, + dtype: torch.dtype, + ) -> torch.Tensor: + """Dequantize weights based on the OCP MX scheme.""" + if self.ocp_mx_scheme.startswith("w_mxfp4"): # type: ignore[union-attr] + return dequant_mxfp4(w, w_scale, dtype) + elif self.ocp_mx_scheme.startswith("w_mxfp6_e3m2"): # type: ignore[union-attr] + return dequant_mxfp6(w, w_scale, quant_dtype="fp6_e3m2", float_dtype=dtype) + elif self.ocp_mx_scheme.startswith("w_mxfp6_e2m3"): # type: ignore[union-attr] + return dequant_mxfp6(w, w_scale, quant_dtype="fp6_e2m3", float_dtype=dtype) + else: + raise NotImplementedError(f"Unsupported ocp_mx_scheme={self.ocp_mx_scheme}") + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + """ + Apply emulated quantized MoE computation. + + This dequantizes the weights on the fly and calls TritonExperts.apply + with activation quantization support. 
+ """ + assert w1.dtype == torch.uint8 + assert w2.dtype == torch.uint8 + + # Dequantize w1 and w2 from packed OCP MX format to bf16/fp16 + w1_dequant = self._dequant_weights(w1, self.w1_scale_val, hidden_states.dtype) + w2_dequant = self._dequant_weights(w2, self.w2_scale_val, hidden_states.dtype) + + # Apply activation QDQ if needed by the OCP MX scheme + hidden_states, _ = moe_kernel_quantize_input( + A=hidden_states, + A_scale=None, + quant_dtype=self.quant_config.quant_dtype, + per_act_token_quant=False, + ocp_mx_scheme=self.ocp_mx_scheme, + quantization_emulation=True, + ) + + # Activation quantization/dequantization is deferred to + # `moe_kernel_quantize_input` in TritonExperts.apply. + super().apply( + output=output, + hidden_states=hidden_states, + w1=w1_dequant, + w2=w2_dequant, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + a1q_scale=None, + a2_scale=None, + workspace13=workspace13, + workspace2=workspace2, + expert_tokens_meta=expert_tokens_meta, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py index 81b778c8f4a7..74bf98645378 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import flashinfer + import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -181,6 +181,8 @@ def apply( expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): + import flashinfer + assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] assert a1q_scale is not None assert self.quant_config.w1_scale is not None @@ -299,6 +301,8 @@ def 
apply( routed_scaling_factor: float | None = None, topk_group: int | None = None, ) -> torch.Tensor: + import flashinfer + assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] assert a1q_scale is not None assert self.quant_config.w1_scale is not None diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6de25da051ad..cb59b95078e6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -36,8 +36,6 @@ disable_inplace, moe_kernel_quantize_input, ) -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 -from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8Dynamic128Sym, @@ -1657,22 +1655,18 @@ def fused_experts_impl( w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: + if ocp_mx_scheme is not None: + raise NotImplementedError( + f"Using ocp_mx_scheme={ocp_mx_scheme} in functional fused_experts call is " + "deprecated. Please use OCP_MXQuantizationEmulationTritonExperts." + ) + # Convert string activation to enum for internal use activation_enum = MoEActivation.from_str(activation) # Check constraints. 
if use_int4_w4a16: assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch" - elif ocp_mx_scheme is not None: - if ocp_mx_scheme.startswith("w_mxfp4"): - # 16bit activation and fp4x2 packed weight - assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch" - elif ocp_mx_scheme.startswith("w_mxfp6"): - assert hidden_states.size(1) == (w1.size(2) * 4) // 3, ( - "hidden size mismatch" - ) - else: - raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") else: assert hidden_states.size(1) == w1.size(2), ( f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}" @@ -1697,7 +1691,6 @@ def fused_experts_impl( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, - ocp_mx_scheme=ocp_mx_scheme, dtype=hidden_states.dtype, ) @@ -1706,7 +1699,7 @@ def fused_experts_impl( quant_dtype = _get_config_quant_dtype( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, - ocp_mx_scheme=ocp_mx_scheme, + ocp_mx_scheme=None, ) get_config_func = functools.partial( @@ -1751,44 +1744,12 @@ def fused_experts_impl( out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states) - if ocp_mx_scheme is not None: - # TODO: On platforms for which `current_platform.supports_mx()` is True - # and for which we have a native OCP mx fused MOE kernel, - # this dequantization step should not be done. - if ocp_mx_scheme.startswith("w_mxfp4"): - # Weight has to be dequantized for mxfp4 emulation. 
- w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) - w1_scale = None - w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) - w2_scale = None - elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"): - w1 = dequant_mxfp6( - w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype - ) - w1_scale = None - w2 = dequant_mxfp6( - w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype - ) - w2_scale = None - elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"): - w1 = dequant_mxfp6( - w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype - ) - w1_scale = None - w2 = dequant_mxfp6( - w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype - ) - w2_scale = None - else: - raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") - qhidden_states, a1q_scale = moe_kernel_quantize_input( A=hidden_states, A_scale=a1_scale, quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, block_shape=block_shape, - ocp_mx_scheme=ocp_mx_scheme, ) # SPARSITY_FACTOR is a heuristic margin ensuring num_tokens * top_k @@ -1856,7 +1817,6 @@ def fused_experts_impl( quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, block_shape=block_shape, - ocp_mx_scheme=ocp_mx_scheme, ) if expert_map is not None: @@ -1902,6 +1862,9 @@ def __init__( moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): + # Whether quantized MOE runs natively, or through + # higher-precision + activation QDQ. 
+ self.quantization_emulation = False super().__init__(moe_config, quant_config) @staticmethod @@ -2112,6 +2075,7 @@ def apply( self.quant_dtype, self.per_act_token_quant, self.block_shape, + quantization_emulation=self.quantization_emulation, ) invoke_fused_moe_triton_kernel( diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index d4a0817e0be0..b42b48f293e5 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -61,6 +61,8 @@ class Mxfp4MoeBackend(Enum): TRITON_UNFUSED = "TRITON_UNFUSED" # XPU XPU = "XPU" + # Emulation + EMULATION = "EMULATION" # Backends that share the same TRTLLM weight format @@ -142,6 +144,13 @@ def backend_to_kernel_cls( return [XPUExpertsMXFp4] + elif backend == Mxfp4MoeBackend.EMULATION: + from vllm.model_executor.layers.fused_moe.experts.ocp_mx_emulation_moe import ( + OCP_MXQuantizationEmulationTritonExperts, + ) + + return [OCP_MXQuantizationEmulationTritonExperts] + else: raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}") @@ -157,6 +166,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend: "marlin": Mxfp4MoeBackend.MARLIN, "aiter": Mxfp4MoeBackend.AITER, "xpu": Mxfp4MoeBackend.XPU, + "emulation": Mxfp4MoeBackend.EMULATION, } if backend := mapping.get(runner_backend): return backend @@ -180,6 +190,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]: Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN, Mxfp4MoeBackend.XPU, + Mxfp4MoeBackend.EMULATION, ] return _AVAILABLE_BACKENDS @@ -762,6 +773,17 @@ def _interleave_mxfp4_cutlass_sm90(w): w13_bias, w2_bias, ) + elif mxfp4_backend == Mxfp4MoeBackend.EMULATION: + # No additional transformation needed for emulation backend, + # weights are dequantized on the fly in the experts class. 
+ return ( + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) else: raise ValueError( f"Unsupported mxfp4_backend: {mxfp4_backend}: " diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 597d784d3b63..6d0b66cb9f53 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -45,6 +45,7 @@ class NvFp4MoeBackend(Enum): FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED" VLLM_CUTLASS = "VLLM_CUTLASS" MARLIN = "MARLIN" + EMULATION = "EMULATION" FLASHINFER_NVFP4_MOE_BACKENDS = [ @@ -118,6 +119,12 @@ def backend_to_kernel_cls( ) return [MarlinExperts] + elif backend == NvFp4MoeBackend.EMULATION: + from vllm.model_executor.layers.fused_moe.experts.nvfp4_emulation_moe import ( + Nvfp4QuantizationEmulationTritonExperts, + ) + + return [Nvfp4QuantizationEmulationTritonExperts] else: raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}") @@ -130,6 +137,7 @@ def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend: "flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS, "flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL, "marlin": NvFp4MoeBackend.MARLIN, + "emulation": NvFp4MoeBackend.EMULATION, } if backend := mapping.get(runner_backend): return backend @@ -157,6 +165,7 @@ def select_nvfp4_moe_backend( NvFp4MoeBackend.FLASHINFER_CUTLASS, NvFp4MoeBackend.VLLM_CUTLASS, NvFp4MoeBackend.MARLIN, + NvFp4MoeBackend.EMULATION, ] # NOTE(rob): this is kind of a hack. 
We need to peak into @@ -372,6 +381,30 @@ def convert_to_nvfp4_moe_kernel_format( w2_scale_2=w2_scale_2, is_act_and_mul=is_act_and_mul, ) + elif nvfp4_backend == NvFp4MoeBackend.EMULATION: + if a13_scale is None or a2_scale is None: + raise ValueError( + "Activation global scales should not be None, got" + f" a13_scale={a13_scale}, a2_scale={a2_scale}" + ) + + if torch.unique(a13_scale).numel() != 1 or torch.unique(a2_scale).numel() != 1: + logger.warning_once( + "In NVFP4 linear, the activation global scale for inputs are different" + " for MOE w13 (gate_up_proj) layer or MOE w2 (down_proj). Using" + " a13_scale = a13_scale.max() and a2_scale = a2_scale.max()." + ) + + # 1. We take the max following e.g. quantization/utils/flashinfer_fp4_moe.py. + # 2. moe_kernel_quantize_input -> ref_nvfp4_quant_dequant + # use the inverse scale directly (large global scale). + # NOTE: Before this point, `a13_scale` and `a2_scale` are such that: + # `FP8_MAX = activation[expert_id].abs().max() * global_scale[expert_id]`, + # and `global_scale[expert_id]` are small (~1e-4). + # Taking the largest global scale likely results in overflowing the FP8 range + # for other experts - other selection strategies may be used. + a13_scale = 1.0 / a13_scale.max().to(torch.float32) + a2_scale = 1.0 / a2_scale.max().to(torch.float32) else: raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}") @@ -403,6 +436,15 @@ def make_nvfp4_moe_quant_config( w1_scale=w13_scale, w2_scale=w2_scale, ) + elif backend == NvFp4MoeBackend.EMULATION: + return nvfp4_moe_quant_config( + g1_alphas=w13_scale_2, + g2_alphas=w2_scale_2, + a1_gscale=a13_scale, + a2_gscale=a2_scale, + w1_scale=w13_scale, + w2_scale=w2_scale, + ) # Pass w13_scale_2 / w2_scale_2 directly as g1/g2_alphas. 
# The expert's process_weights_after_loading will fuse activation diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index ce1e49bc4b0b..d8e174051b51 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -22,6 +22,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( mxfp8_e4m3_quantize, ) +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( + ref_nvfp4_quant_dequant, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( per_tensor_dequantize, ) @@ -253,6 +256,7 @@ def moe_kernel_quantize_input( block_shape: list[int] | None = None, is_fp4_scale_swizzled: bool = True, ocp_mx_scheme: str | None = None, + quantization_emulation: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: # Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation if ocp_mx_scheme is not None: @@ -274,16 +278,41 @@ def moe_kernel_quantize_input( # activation quantization below. if quant_dtype == current_platform.fp8_dtype(): + if quantization_emulation: + raise NotImplementedError( + f"moe_kernel_quantize_input does not support quant_dtype={quant_dtype}" + " MOE quantization emulation. Please open an issue." + ) return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.int8: + if quantization_emulation: + raise NotImplementedError( + "moe_kernel_quantize_input does not support quant_dtype=torch.int8" + " MOE quantization emulation. Please open an issue." 
+ ) return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == "nvfp4": - return _nvfp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scale_swizzled) + if not quantization_emulation: + return _nvfp4_quantize( + A, A_scale, is_sf_swizzled_layout=is_fp4_scale_swizzled + ) + else: + return ref_nvfp4_quant_dequant(A, A_scale, block_size=16) elif quant_dtype == "mxfp4": + if not quantization_emulation: + raise NotImplementedError( + "moe_kernel_quantize_input should not be used for native" + " quant_dtype='mxfp4' MOE. Please open an issue." + ) return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == "mxfp8": # TODO: `quant_dtype == "mxfp8"` is ambiguous, # should be fp8_e4m3. OCP MX also defines `fp8_e5m2`. + if quantization_emulation: + raise NotImplementedError( + "moe_kernel_quantize_input does not support quant_dtype='mxfp8' MOE " + "quantization emulation. Please open an issue." + ) return _mxfp8_e4m3_quantize( A, A_scale, @@ -292,8 +321,20 @@ def moe_kernel_quantize_input( is_sf_swizzled_layout=is_fp4_scale_swizzled, ) elif quant_dtype == "mxfp6_e3m2": + if not quantization_emulation: + raise NotImplementedError( + "moe_kernel_quantize_input should not be used for native " + " quant_dtype='mxfp6_e3m2'MOE. Please open an issue." + ) + return _mxfp6_e3m2_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == "mxfp6_e2m3": + if not quantization_emulation: + raise NotImplementedError( + "moe_kernel_quantize_input should not be used for native" + " quant_dtype='mxfp6_e2m3' MOE. Please open an issue." 
+ ) + return _mxfp6_e2m3_quantize(A, A_scale, per_act_token_quant, block_shape) else: return A, A_scale diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 33bd0cfc22e8..f53c87002c84 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -25,6 +25,7 @@ QuarkMoEMethod, ) from vllm.model_executor.layers.quantization.quark.schemes import ( + QuarkNVFP4, QuarkOCP_MX, QuarkScheme, QuarkW4A8_MXFP4_FP8, @@ -83,13 +84,15 @@ def maybe_update_config( quant_config = getattr(hf_config, "quantization_config", None) if quant_config is not None: - quant_dtype = ( - quant_config.get("global_quant_config", {}) - .get("weight", {}) - .get("dtype") - ) - if quant_dtype == "fp4": - self.dynamic_mxfp4_quant = True + # global_quant_config's weight may be a list for NVFP4. + weight_config = quant_dtype = quant_config.get( + "global_quant_config", {} + ).get("weight", []) + + if not isinstance(weight_config, list): + quant_dtype = weight_config["dtype"] + if quant_dtype == "fp4": + self.dynamic_mxfp4_quant = True def get_linear_method(self) -> "QuarkLinearMethod": return QuarkLinearMethod(self) @@ -420,6 +423,54 @@ def _is_dynamic_per_token_w8a8( and is_weight_symmetric ) + def _is_nvfp4( + self, + weight_quant: dict[str, Any] | list[dict[str, Any]] | None, + input_quant: dict[str, Any] | list[dict[str, Any]] | None, + ) -> bool: + # Confirm weights and input quantized. 
+ if weight_quant is None or input_quant is None: + return False + + # Confirm both weight_quant and input_quant are lists with 2 elements + if not isinstance(weight_quant, list) or len(weight_quant) != 2: + return False + if not isinstance(input_quant, list) or len(input_quant) != 2: + return False + + # First element should be fp4 with per_group quantization + is_fp4_per_group_weight = ( + weight_quant[0].get("dtype") == "fp4" + and weight_quant[0].get("qscheme") == "per_group" + and weight_quant[0].get("group_size") == 16 + and not weight_quant[0].get("is_dynamic") + ) + is_fp4_per_group_input = ( + input_quant[0].get("dtype") == "fp4" + and input_quant[0].get("qscheme") == "per_group" + and input_quant[0].get("group_size") == 16 + and input_quant[0].get("is_dynamic") + ) + + # Second element should be fp8_e4m3 with per_tensor quantization + is_fp8_per_tensor_weight = ( + weight_quant[1].get("dtype") == "fp8_e4m3" + and weight_quant[1].get("qscheme") == "per_tensor" + and not weight_quant[1].get("is_dynamic") + ) + is_fp8_per_tensor_input = ( + input_quant[1].get("dtype") == "fp8_e4m3" + and input_quant[1].get("qscheme") == "per_tensor" + and not input_quant[1].get("is_dynamic") + ) + + return ( + is_fp4_per_group_weight # type: ignore[return-value] + and is_fp4_per_group_input + and is_fp8_per_tensor_weight + and is_fp8_per_tensor_input + ) + def _is_w_ocp_mx_a_x( self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None ) -> bool: @@ -568,7 +619,9 @@ def _get_scheme_from_config( weight_config = cast(dict[str, Any], config.get("weight")) input_config = cast(dict[str, Any], config.get("input_tensors")) - if self._is_fp8_w8a8(weight_config, input_config): + if self._is_nvfp4(weight_config, input_config): + return QuarkNVFP4() + elif self._is_fp8_w8a8(weight_config, input_config): is_fp8_w8a8_supported = self._check_scheme_supported( QuarkW8A8Fp8.get_min_capability(), error=False ) diff --git 
a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 58ed8940b385..ac9927a5bdda 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -30,12 +30,19 @@ from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( TRITON_BACKENDS, Mxfp4MoeBackend, + backend_to_kernel_cls, convert_to_mxfp4_moe_kernel_format, make_mxfp4_moe_kernel, make_mxfp4_moe_quant_config, mxfp4_round_up_hidden_size_and_intermediate_size, select_mxfp4_moe_backend, ) +from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( + convert_to_nvfp4_moe_kernel_format, + make_nvfp4_moe_kernel, + make_nvfp4_moe_quant_config, + select_nvfp4_moe_backend, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, ) @@ -46,7 +53,11 @@ OCP_MX_BLOCK_SIZE, OCP_MX_Scheme, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + kNvfp4Dynamic, + kNvfp4Static, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, @@ -60,7 +71,9 @@ __all__ = [ "QuarkMoEMethod", + "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod", + "QuarkNvfp4MoEMethod", "QuarkOCP_MX_MoEMethod_OSS", ] @@ -90,6 +103,10 @@ def get_moe_method( if quant_config._is_fp8_w4a8(weight_config, input_config): return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config) + elif quant_config._is_nvfp4(weight_config, input_config): + return QuarkNvfp4MoEMethod( + weight_config, input_config, module.moe_config, quant_config + ) elif quant_config._is_fp8_w8a8(weight_config, input_config): return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config) elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config): @@ -986,7 +1003,10 @@ def 
+            # We use the same code path for both MXFP4 and MXFP6 emulation.
+ if ( + self.mxfp4_backend is not None + and self.mxfp4_backend != Mxfp4MoeBackend.EMULATION + ): hidden_size, intermediate_size_per_partition = ( mxfp4_round_up_hidden_size_and_intermediate_size( self.mxfp4_backend, hidden_size, intermediate_size_per_partition @@ -1237,21 +1268,13 @@ def process_weights_after_loading(self, layer): ) # For w_mxfp4, use oracle functions - if ( + if self.emulate or ( self.ocp_mx_scheme == "w_mxfp4" and self.mxfp4_backend != Mxfp4MoeBackend.NONE ): self._setup_kernel_via_oracle(layer) return - # TODO(bowenbao): gradually migrate to oracles. - # secondly, process mxfp weights for other schemes - if self.emulate: - # Build quant config for emulation path - self.moe_quant_config = self.get_fused_moe_quant_config(layer) - torch.accelerator.empty_cache() - return - # Existing AITER path for w_mxfp4_a_mxfp4 and other schemes from aiter.utility.fp4_utils import e8m0_shuffle @@ -1345,9 +1368,9 @@ def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: # For w_mxfp4 with oracle backend, use oracle function - if ( - self.ocp_mx_scheme == "w_mxfp4" - and self.mxfp4_backend != Mxfp4MoeBackend.NONE + if self.ocp_mx_scheme == "w_mxfp4" and self.mxfp4_backend not in ( + Mxfp4MoeBackend.NONE, + Mxfp4MoeBackend.EMULATION, ): w1_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale @@ -1362,9 +1385,7 @@ def get_fused_moe_quant_config( w2_bias=getattr(layer, "w2_bias", None), ) - # Existing code for other schemes - # TODO(bowenbao): kept for emulation fallback, to be refactored into - # dedicated emulation backend. 
+        # TODO: Refactor this to use the modular MoE kernel as well.
+ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts, + ) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=not self.moe.disable_inplace, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - expert_map=layer.expert_map, - quant_config=self.moe_quant_config, - ) + return rocm_aiter_fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=layer.activation, + quant_config=self.moe_quant_config, + moe_config=layer.moe_config, + expert_map=layer.expert_map, + ) def apply_monolithic( self, @@ -1613,3 +1618,228 @@ def apply_monolithic( unpadded_N_w2=self.moe.hidden_dim_unpadded, unpadded_K_w2=self.moe.intermediate_size_per_partition_unpadded, ) + + +class QuarkNvfp4MoEMethod(QuarkMoEMethod): + def __init__( + self, + weight_config: dict[str, Any], + input_config: dict[str, Any], + moe: FusedMoEConfig, + quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 + ): + super().__init__(moe) + self.weight_quant = weight_config + self.input_quant = input_config + self.quant_config = quant_config + self.group_size = 16 + + # Select experts implementation. 
+ self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( + config=self.moe, + weight_key=kNvfp4Static, + activation_key=kNvfp4Dynamic, + ) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.num_experts = num_experts + layer.params_dtype = params_dtype + layer.quant_config = self.quant_config + weight_dtype = torch.uint8 + weight_scale_dtype = torch.float8_e4m3fn + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + + # GEMM 1 - w13 weight + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + dtype=weight_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # GEMM 2 - w2 weight + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=weight_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Weight scales (per-group FP8 scales) + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + 
layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # Global weight scales (per-tensor FP32 scales) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + + w13_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, w13_num_shards, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) + set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) + + w2_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) + set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) + + # Input global scales (per-tensor FP32 scales) + w13_input_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, w13_num_shards, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_input_scale_2", w13_input_scale_2) + set_weight_attrs(w13_input_scale_2, extra_weight_attrs) + + w2_input_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_input_scale_2", w2_input_scale_2) + set_weight_attrs(w2_input_scale_2, extra_weight_attrs) + + def process_weights_after_loading(self, layer: FusedMoE) -> None: + """ + Convert NVFP4 MoE weights into kernel format and setup the kernel. 
+ """ + + if not torch.allclose( + layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] + ): + raise ValueError("Different global scales for w1 and w3 is not supported.") + + # Use a single gscale for w13 + w13_weight_scale_2 = torch.maximum( + layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] + ).contiguous() + + w2_weight_scale_2 = layer.w2_weight_scale_2 + + ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) = convert_to_nvfp4_moe_kernel_format( + nvfp4_backend=self.nvfp4_backend, + layer=layer, + w13=layer.w13_weight, + w13_scale=layer.w13_weight_scale, + w13_scale_2=w13_weight_scale_2, + a13_scale=layer.w13_input_scale_2, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + w2_scale_2=w2_weight_scale_2, + a2_scale=layer.w2_input_scale_2, + is_act_and_mul=self.moe.is_act_and_mul, + ) + + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w13_weight_scale_2", w13_scale_2) + replace_parameter(layer, "w13_input_scale_2", a13_scale) + + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w2_weight_scale", w2_scale) + replace_parameter(layer, "w2_weight_scale_2", w2_scale_2) + replace_parameter(layer, "w2_input_scale_2", a2_scale) + + # Setup modular kernel. 
+ self.moe_quant_config = self.get_fused_moe_quant_config(layer) + if self.moe_quant_config: + assert self.experts_cls is not None + self.moe_mk = make_nvfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + experts_cls=self.experts_cls, + shared_experts=layer.shared_experts, + routing_tables=layer._maybe_init_expert_routing_tables(), + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return make_nvfp4_moe_quant_config( + backend=self.nvfp4_backend, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w13_scale_2=layer.w13_weight_scale_2, + w2_scale_2=layer.w2_weight_scale_2, + a13_scale=layer.w13_input_scale_2, + a2_scale=layer.w2_input_scale_2, + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: Any | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert self.moe_mk is not None + return self.moe_mk.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + shared_experts_input=shared_experts_input, + ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py index a5e33a0442b1..1ef5824fec53 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .quark_nvfp4 import QuarkNVFP4 from .quark_ocp_mx import QuarkOCP_MX from .quark_scheme import QuarkScheme from .quark_w4a8_mxfp4_fp8 import QuarkW4A8_MXFP4_FP8 @@ -13,4 +14,5 @@ "QuarkW8A8Int8", 
"QuarkOCP_MX", "QuarkW4A8_MXFP4_FP8", + "QuarkNVFP4", ] diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_nvfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_nvfp4.py new file mode 100644 index 000000000000..c71b255d15b7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_nvfp4.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.kernels.linear import init_nvfp4_linear_kernel +from vllm.model_executor.kernels.linear.nvfp4.emulation import ( + EmulationNvFp4LinearKernel, +) +from vllm.model_executor.layers.quantization.quark.schemes.quark_scheme import ( + QuarkScheme, +) +from vllm.model_executor.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) + +__all__ = ["QuarkNVFP4"] + +logger = init_logger(__name__) + + +class QuarkNVFP4(QuarkScheme): + """ + Quark NVFP4 quantization scheme. + + Supports loading NVFP4 checkpoints with the following structure: + - weight: uint8, shape [out_features, in_features // 2] (packed FP4) + - weight_scale: float8_e4m3fn, shape [out_features, in_features // group_size] + - weight_scale_2: bfloat16/float32, scalar (global weight scale) + - input_scale_2: bfloat16/float32, scalar (global input scale) + """ + + def __init__( + self, + ): + self.kernel = init_nvfp4_linear_kernel() + self.group_size = 16 + + if not isinstance(self.kernel, EmulationNvFp4LinearKernel): + logger.warning_once( + "Only EmulationNvFp4LinearKernel is tested with" + " QuarkNVFP4, got kernel=%s. 
Use at your own risk.", + type(self.kernel).__name__, + ) + + @classmethod + def get_min_capability(cls) -> int: + # FP4 requires Turing (75) or newer + return 75 + + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + if input_size_per_partition % self.group_size != 0: + raise ValueError( + f"Input size per partition ({input_size_per_partition}) must be " + f"divisible by group size ({self.group_size})" + ) + + # Weight: FP4 packed as uint8 (2 FP4 values per uint8) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # Per-group weight scale (FP8 E4M3) + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + # Global weight scale (scalar, per partition) + weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale_2", weight_scale_2) + + # Global input scale (scalar, per partition) + input_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("input_scale_2", input_scale_2) + + def process_weights_after_loading(self, 
layer: torch.nn.Module) -> None: + input_global_scale = layer.input_scale_2.max().to(torch.float32) + layer.input_global_scale = Parameter(input_global_scale, requires_grad=False) + del layer.input_scale_2 + + weight_global_scale = layer.weight_scale_2.to(torch.float32) + + if torch.unique(weight_global_scale).numel() != 1: + logger.warning_once( + "In NVFP4 linear, the global scale for weight are different" + " for parallel layers (e.g. q_proj, k_proj, v_proj). This" + " will likely result in reduced accuracy. Please verify the" + " model accuracy. Consider using a checkpoint with a shared" + " global NVFP4 scale for fused layers." + ) + + weight_global_scale = weight_global_scale.max() + + layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False) + del layer.weight_scale_2 + + layer.alpha = Parameter( + layer.input_global_scale * layer.weight_global_scale, requires_grad=False + ) + layer.input_global_scale_inv = Parameter( + (1.0 / layer.input_global_scale).to(torch.float32), requires_grad=False + ) + + # Convert layer to NVFP4 linear kernel format + self.kernel.process_weights_after_loading(layer) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.kernel.apply_weights(layer=layer, x=x, bias=bias) diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index 98ac1a4f355e..ee55e5d39e70 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -17,7 +17,8 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: return False return all(deep_compare(dict1[k], dict2[k]) for k in dict1) elif isinstance(dict1, list): - return set(dict1) == set(dict2) + # `dict1` may be a list of dict. 
+        # `dict1` may be a list of dicts; compare element-wise (assumes equal lengths — verify).
+    `global_scale` is expected to have a single element; returns (dequantized tensor, None).
+ """ + x_m, x_k = x.shape + output_dtype = x.dtype + + # quantize input to (FP4 and interleaved block scale) + x_fp4, x_blockscale = ref_nvfp4_quant(x, global_scale, block_size) + + # dequantize input + x_fp4 = x_fp4.reshape(x_m, x_k // block_size, block_size) + x_blockscale = x_blockscale.unsqueeze(-1) / global_scale + x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype) + del x_fp4, x_blockscale + + return x_dq, None + + def run_nvfp4_emulations( x: torch.Tensor, input_global_scale: torch.Tensor, @@ -125,18 +174,10 @@ def run_nvfp4_emulations( weight_global_scale: torch.Tensor, swizzle: bool | None = True, ): - group_size = 16 - x_m, x_k = x.shape output_dtype = x.dtype + group_size = 16 - # quantize input to (FP4 and interleaved block scale) - x_fp4, x_blockscale = ref_nvfp4_quant(x, input_global_scale, group_size) - - # dequantize input - x_fp4 = x_fp4.reshape(x_m, x_k // group_size, group_size) - x_blockscale = x_blockscale.unsqueeze(-1) / input_global_scale - x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype) - del x_fp4, x_blockscale + x_dq, _ = ref_nvfp4_quant_dequant(x, input_global_scale, block_size=group_size) # dequantize weight w_fp4 = weight.data.view(torch.uint8)