diff --git a/tests/quantization/test_quark_maybe_update_config.py b/tests/quantization/test_quark_maybe_update_config.py new file mode 100644 index 000000000000..0142e869c22c --- /dev/null +++ b/tests/quantization/test_quark_maybe_update_config.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for QuarkConfig.maybe_update_config. + +Fetches real HF configs (metadata only, no model weights) to verify +that dynamic_mxfp4_quant is only enabled for DeepSeek-V3-family models. + +Run: pytest tests/quantization/test_quark_maybe_update_config.py -v +""" + +import pytest +from transformers import AutoConfig + +from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig + + +def _make_quark_config() -> QuarkConfig: + """Create a minimal QuarkConfig for testing.""" + return QuarkConfig(quant_config={}, kv_cache_group=[], pack_method="reorder") + + +# --------------------------------------------------------------------------- +# Non-deepseek models must not flip dynamic_mxfp4_quant +# --------------------------------------------------------------------------- +@pytest.mark.parametrize( + "model_name", + ["amd/MiniMax-M2.1-MXFP4"], +) +def test_non_deepseek_model_stays_false(model_name: str): + """Non-deepseek_v3 models must not enable dynamic_mxfp4_quant.""" + hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + qcfg = _make_quark_config() + + qcfg.maybe_update_config(model_name, hf_config=hf_config) + + assert qcfg.dynamic_mxfp4_quant is False + + +# --------------------------------------------------------------------------- +# DeepSeek-V3 family + fp4 must enable dynamic_mxfp4_quant +# --------------------------------------------------------------------------- +@pytest.mark.parametrize( + "model_name", + ["amd/DeepSeek-R1-MXFP4-ASQ"], +) +def test_deepseek_family_fp4_enables_flag(model_name: str): + hf_config = 
AutoConfig.from_pretrained(model_name, trust_remote_code=True) + qcfg = _make_quark_config() + + qcfg.maybe_update_config(model_name, hf_config=hf_config) + + assert qcfg.dynamic_mxfp4_quant is True + + +# --------------------------------------------------------------------------- +# Missing hf_config → dynamic_mxfp4_quant stays False +# --------------------------------------------------------------------------- +def test_missing_hf_config_stays_false(): + qcfg = _make_quark_config() + + qcfg.maybe_update_config("some/model") + + assert qcfg.dynamic_mxfp4_quant is False diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index b6be7f10bdb0..55f35c13553f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -526,7 +526,10 @@ def _get_quantization_config( f"method {model_config.quantization}. Supported dtypes: " f"{supported_dtypes}" ) - quant_config.maybe_update_config(model_config.model) + quant_config.maybe_update_config( + model_config.model, + hf_config=model_config.hf_config, + ) return quant_config return None diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 3cf3116f0670..58bb75d0a9ed 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -5,6 +5,7 @@ import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE +from transformers import PretrainedConfig from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -146,7 +147,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): self.modules_to_not_convert ) - def maybe_update_config(self, model_name: str, revision: str | None = None): + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): if self.modules_to_not_convert: return diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 
426b9aa71562..03dfaa7949c0 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -6,6 +6,7 @@ import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.nn import Parameter +from transformers import PretrainedConfig import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops @@ -332,7 +333,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): self.modules_to_not_convert ) - def maybe_update_config(self, model_name: str, revision: str | None = None): + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): if self.modules_to_not_convert: return diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 06fe4270c713..eedc62f7d4d5 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -7,6 +7,7 @@ import torch from torch import nn +from transformers import PretrainedConfig if TYPE_CHECKING: from vllm.model_executor.layers.quantization import QuantizationMethods @@ -168,10 +169,23 @@ def apply_vllm_mapper( # noqa: B027 # TODO (@kylesayrs): add implementations for all subclasses pass - def maybe_update_config(self, model_name: str): # noqa: B027 + def maybe_update_config( # noqa: B027 + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): """ Interface to update values after config initialization. + + Args: + model_name: The name of the model + hf_config: The Hugging Face config of the model + revision: The revision of the model + Returns: """ + # TODO: revision is never passed currently in vllm.py, + # but is used in subclasses, should we remove this parameter? 
pass def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool: diff --git a/vllm/model_executor/layers/quantization/cpu_wna16.py b/vllm/model_executor/layers/quantization/cpu_wna16.py index ea7afef27ebd..3dba317438ec 100644 --- a/vllm/model_executor/layers/quantization/cpu_wna16.py +++ b/vllm/model_executor/layers/quantization/cpu_wna16.py @@ -5,6 +5,7 @@ import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE +from transformers import PretrainedConfig from vllm._custom_ops import ( cpu_gemm_wna16, @@ -133,7 +134,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): self.modules_to_not_convert ) - def maybe_update_config(self, model_name: str, revision: str | None = None): + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): if self.modules_to_not_convert: return diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 154347a930a9..458741478538 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -9,6 +9,7 @@ import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.nn.parameter import Parameter +from transformers import PretrainedConfig from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -193,7 +194,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): self.modules_in_block_to_quantize ) - def maybe_update_config(self, model_name: str, revision: str | None = None): + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): if self.modules_in_block_to_quantize: if is_list_of(self.modules_in_block_to_quantize, list): # original modules_in_block_to_quantize: list[list[str]] diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py 
b/vllm/model_executor/layers/quantization/gptq_marlin.py index d7b2a366e1f0..8e367c88346f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -6,6 +6,7 @@ import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE +from transformers import PretrainedConfig import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops @@ -299,7 +300,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper): self.modules_in_block_to_quantize ) - def maybe_update_config(self, model_name: str, revision: str | None = None): + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): if self.modules_in_block_to_quantize: if is_list_of(self.modules_in_block_to_quantize, list): # original modules_in_block_to_quantize: list[list[str]] diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 78c64bac6187..d0362cedcf2b 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, cast import torch +from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention @@ -36,7 +37,6 @@ ) from vllm.model_executor.models.utils import WeightsMapper from vllm.platforms import current_platform -from vllm.transformers_utils.config import get_config if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -45,6 +45,10 @@ logger = init_logger(__name__) +# model_type values that use dynamic MXFP4 re-quantization for +# OCP MX fp4 Quark checkpoints +_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3"}) + class QuarkConfig(QuantizationConfig): def __init__( @@ -63,19 +67,28 @@ def __init__( self.pack_method = pack_method 
self.dynamic_mxfp4_quant = False
 
-    def maybe_update_config(self, model_name: str, revision: str | None = None):
-        self.hf_config = get_config(
-            model=model_name,
-            trust_remote_code=False,  # or get from model_config if available
-            revision=revision,
-            config_format="auto",
-        )
+    def maybe_update_config(
+        self,
+        model_name: str,
+        hf_config: PretrainedConfig | None = None,
+        revision: str | None = None,
+    ):
+        """Enable dynamic MXFP4 only for DeepSeek-V3-family + fp4 Quark checkpoints.
+
+        A missing ``hf_config`` or a null/partial quantization config leaves
+        ``dynamic_mxfp4_quant`` unchanged; ``model_name``/``revision`` are unused.
+        """
 
-        quant_config = getattr(self.hf_config, "quantization_config", None)
+        model_type = getattr(hf_config, "model_type", None)
+        if model_type not in _DEEPSEEK_V3_FAMILY_MODEL_TYPES:
+            return
+
+        quant_config = getattr(hf_config, "quantization_config", None)
         if quant_config is not None:
-            quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
-            model_type = self.hf_config.model_type
-            if quant_dtype == "fp4" and model_type == "deepseek_v3":
+            # "or {}": a key that is present but null must not raise.
+            global_cfg = quant_config.get("global_quant_config") or {}
+            quant_dtype = (global_cfg.get("weight") or {}).get("dtype")
+            if quant_dtype == "fp4":
                 self.dynamic_mxfp4_quant = True
 
     def get_linear_method(self) -> "QuarkLinearMethod":