Merged
63 changes: 63 additions & 0 deletions tests/quantization/test_quark_maybe_update_config.py
@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for QuarkConfig.maybe_update_config.

Fetches real HF configs (metadata only, no model weights) to verify
that dynamic_mxfp4_quant is only enabled for DeepSeek-V3-family models.

Run: pytest tests/quantization/test_quark_maybe_update_config.py -v
"""

import pytest
from transformers import AutoConfig

from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig


def _make_quark_config() -> QuarkConfig:
"""Create a minimal QuarkConfig for testing."""
return QuarkConfig(quant_config={}, kv_cache_group=[], pack_method="reorder")


# ---------------------------------------------------------------------------
# Non-deepseek models must not flip dynamic_mxfp4_quant
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"model_name",
["amd/MiniMax-M2.1-MXFP4"],
)
def test_non_deepseek_model_stays_false(model_name: str):
"""Non-deepseek_v3 models must not enable dynamic_mxfp4_quant."""
hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
qcfg = _make_quark_config()

qcfg.maybe_update_config(model_name, hf_config=hf_config)

assert qcfg.dynamic_mxfp4_quant is False


# ---------------------------------------------------------------------------
# DeepSeek-V3 family + fp4 must enable dynamic_mxfp4_quant
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"model_name",
["amd/DeepSeek-R1-MXFP4-ASQ"],
)
def test_deepseek_family_fp4_enables_flag(model_name: str):
hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
qcfg = _make_quark_config()

qcfg.maybe_update_config(model_name, hf_config=hf_config)

assert qcfg.dynamic_mxfp4_quant is True


# ---------------------------------------------------------------------------
# Missing hf_config → dynamic_mxfp4_quant stays False
# ---------------------------------------------------------------------------
def test_missing_hf_config_stays_false():
qcfg = _make_quark_config()

qcfg.maybe_update_config("some/model")

assert qcfg.dynamic_mxfp4_quant is False
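The tests above pull real configs from the Hugging Face Hub (metadata only). As a minimal offline sketch, assuming the nested Quark quantization_config layout read by quark.py below, the same positive path can be exercised with an in-memory config; this is illustrative and not part of the PR:

# Hypothetical offline check (not in this PR): build a PretrainedConfig in
# memory instead of downloading one, then verify the flag flips for a
# DeepSeek-V3-family model whose global weight dtype is "fp4".
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig


def test_in_memory_deepseek_fp4_enables_flag():
    hf_config = PretrainedConfig()
    hf_config.model_type = "deepseek_v3"  # member of the DeepSeek-V3 family set
    hf_config.quantization_config = {  # assumed Quark config layout
        "global_quant_config": {"weight": {"dtype": "fp4"}}
    }

    qcfg = QuarkConfig(quant_config={}, kv_cache_group=[], pack_method="reorder")
    qcfg.maybe_update_config("local/placeholder-model", hf_config=hf_config)

    assert qcfg.dynamic_mxfp4_quant is True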
5 changes: 4 additions & 1 deletion vllm/config/vllm.py
@@ -526,7 +526,10 @@ def _get_quantization_config(
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}"
)
quant_config.maybe_update_config(model_config.model)
quant_config.maybe_update_config(
model_config.model,
hf_config=model_config.hf_config,
)
return quant_config
return None

8 changes: 7 additions & 1 deletion vllm/model_executor/layers/quantization/awq.py
@@ -5,6 +5,7 @@

import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig

from vllm import _custom_ops as ops
from vllm.logger import init_logger
@@ -146,7 +147,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.modules_to_not_convert
)

def maybe_update_config(self, model_name: str, revision: str | None = None):
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_to_not_convert:
return

8 changes: 7 additions & 1 deletion vllm/model_executor/layers/quantization/awq_marlin.py
@@ -6,6 +6,7 @@
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn import Parameter
from transformers import PretrainedConfig

import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
@@ -332,7 +333,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.modules_to_not_convert
)

def maybe_update_config(self, model_name: str, revision: str | None = None):
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_to_not_convert:
return

16 changes: 15 additions & 1 deletion vllm/model_executor/layers/quantization/base_config.py
@@ -7,6 +7,7 @@

import torch
from torch import nn
from transformers import PretrainedConfig

if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -168,10 +169,23 @@ def apply_vllm_mapper( # noqa: B027
# TODO (@kylesayrs): add implementations for all subclasses
pass

def maybe_update_config(self, model_name: str): # noqa: B027
def maybe_update_config( # noqa: B027
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
"""
Interface to update values after config initialization.

Args:
model_name: The name of the model
hf_config: The Hugging Face config of the model
revision: The revision of the model
"""
# TODO: revision is never passed currently in vllm.py,
[inline review comments]
Contributor: yea should be okay to drop revision. cc @dllehr-amd
Collaborator (author): will do on a follow up PR
# but is used in subclasses, should we remove this parameter?
pass

def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
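As a sketch of how a subclass might adopt the widened interface (illustrative only, not from this PR; the class name, flag, and model-type condition are assumptions, and the other abstract methods of QuantizationConfig are omitted):

# Hypothetical subclass skeleton for the extended maybe_update_config signature.
# The remaining abstract methods of QuantizationConfig are omitted, so this
# class is not instantiable as written.
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class ExampleQuantConfig(QuantizationConfig):
    enable_fast_path: bool = False  # hypothetical flag toggled from the HF config

    def maybe_update_config(
        self,
        model_name: str,
        hf_config: PretrainedConfig | None = None,
        revision: str | None = None,
    ):
        # Use the config handed in by the engine instead of re-fetching it,
        # and tolerate hf_config=None so isolated unit tests keep working.
        if hf_config is None:
            return
        if getattr(hf_config, "model_type", None) == "llama":  # assumed condition
            self.enable_fast_path = True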
8 changes: 7 additions & 1 deletion vllm/model_executor/layers/quantization/cpu_wna16.py
@@ -5,6 +5,7 @@

import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig

from vllm._custom_ops import (
cpu_gemm_wna16,
@@ -133,7 +134,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.modules_to_not_convert
)

def maybe_update_config(self, model_name: str, revision: str | None = None):
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_to_not_convert:
return

8 changes: 7 additions & 1 deletion vllm/model_executor/layers/quantization/gptq.py
@@ -9,6 +9,7 @@
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn.parameter import Parameter
from transformers import PretrainedConfig

from vllm import _custom_ops as ops
from vllm.logger import init_logger
@@ -193,7 +194,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.modules_in_block_to_quantize
)

def maybe_update_config(self, model_name: str, revision: str | None = None):
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
8 changes: 7 additions & 1 deletion vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -6,6 +6,7 @@

import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig

import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
@@ -299,7 +300,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper):
self.modules_in_block_to_quantize
)

def maybe_update_config(self, model_name: str, revision: str | None = None):
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
37 changes: 25 additions & 12 deletions vllm/model_executor/layers/quantization/quark/quark.py
@@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Any, cast

import torch
from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
@@ -36,7 +37,6 @@
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config

if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -45,6 +45,10 @@

logger = init_logger(__name__)

# model_type values that use dynamic MXFP4 re-quantization for
# OCP MX fp4 Quark checkpoints
_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3"})


class QuarkConfig(QuantizationConfig):
def __init__(
@@ -63,19 +67,28 @@ def __init__(
self.pack_method = pack_method
self.dynamic_mxfp4_quant = False

def maybe_update_config(self, model_name: str, revision: str | None = None):
self.hf_config = get_config(
model=model_name,
trust_remote_code=False, # or get from model_config if available
revision=revision,
config_format="auto",
)
def maybe_update_config(
self,
model_name: str,
hf_config: PretrainedConfig | None = None,
revision: str | None = None,
):
"""Enable dynamic MXFP4 only for DeepSeek-V3-family + fp4 Quark checkpoints."""

quant_config = getattr(self.hf_config, "quantization_config", None)
if (
getattr(hf_config, "model_type", None)
not in _DEEPSEEK_V3_FAMILY_MODEL_TYPES
):
return

quant_config = getattr(hf_config, "quantization_config", None)
if quant_config is not None:
quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
model_type = self.hf_config.model_type
if quant_dtype == "fp4" and model_type == "deepseek_v3":
quant_dtype = (
quant_config.get("global_quant_config", {})
.get("weight", {})
.get("dtype")
)
if quant_dtype == "fp4":
self.dynamic_mxfp4_quant = True

def get_linear_method(self) -> "QuarkLinearMethod":
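A small sketch of the fall-through behavior of the new check (illustrative values, not from the PR): a missing quantization_config or a non-fp4 global weight dtype leaves dynamic_mxfp4_quant at False, thanks to the is-not-None guard and the chained .get() lookups.

# Illustrative negative cases for QuarkConfig.maybe_update_config.
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig

# DeepSeek-V3 model_type but no quantization_config section at all.
cfg_no_quant = PretrainedConfig()
cfg_no_quant.model_type = "deepseek_v3"
q1 = QuarkConfig(quant_config={}, kv_cache_group=[], pack_method="reorder")
q1.maybe_update_config("example/no-quant-section", hf_config=cfg_no_quant)
assert q1.dynamic_mxfp4_quant is False

# DeepSeek-V3 model_type with a non-fp4 global weight dtype (assumed layout and value).
cfg_fp8 = PretrainedConfig()
cfg_fp8.model_type = "deepseek_v3"
cfg_fp8.quantization_config = {
    "global_quant_config": {"weight": {"dtype": "fp8_e4m3"}}
}
q2 = QuarkConfig(quant_config={}, kv_cache_group=[], pack_method="reorder")
q2.maybe_update_config("example/non-fp4-checkpoint", hf_config=cfg_fp8)
assert q2.dynamic_mxfp4_quant is False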