Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/model_executor/test_qwen3_5_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,51 @@ def test_qwen3_5_mtp_lm_head_receives_quant_config():
MockLMHead.assert_called_once()
call_kwargs = MockLMHead.call_args.kwargs
assert call_kwargs["quant_config"] is mock_quant_config


def test_qwen3_moe_lm_head_receives_quant_config():
    """Verify Qwen3MoeForCausalLM forwards its quant config to the LM head.

    ``Qwen3MoeModel``, ``ParallelLMHead`` and ``LogitsProcessor`` are patched
    in the ``qwen3_moe`` module, so only the constructor wiring is exercised:
    the ``quant_config`` keyword given to ``ParallelLMHead`` must be the very
    object stored on the vLLM config.
    """
    from vllm.model_executor.models.qwen3_moe import (
        Qwen3MoeDecoderLayer,
        Qwen3MoeForCausalLM,
        Qwen3MoeSparseMoeBlock,
    )

    quant_cfg = Mock()

    # Minimal HF text config: untied embeddings plus the two sizes set here.
    hf_cfg = Mock()
    hf_cfg.tie_word_embeddings = False
    hf_cfg.vocab_size = 128
    hf_cfg.hidden_size = 64

    vllm_cfg = Mock()
    vllm_cfg.model_config.hf_text_config = hf_cfg
    vllm_cfg.quant_config = quant_cfg

    # A single decoder layer with a sparse-MoE block is enough for the
    # metadata scan in Qwen3MoeForCausalLM.__init__; both objects are
    # allocated without running their constructors so the full model is
    # never built.
    moe_block = object.__new__(Qwen3MoeSparseMoeBlock)
    moe_fields = (
        ("experts", Mock()),
        ("n_logical_experts", 2),
        ("n_physical_experts", 2),
        ("n_local_physical_experts", 2),
        ("n_routed_experts", 2),
        ("n_redundant_experts", 0),
    )
    for field_name, field_value in moe_fields:
        object.__setattr__(moe_block, field_name, field_value)

    decoder_layer = object.__new__(Qwen3MoeDecoderLayer)
    object.__setattr__(decoder_layer, "mlp", moe_block)

    with (
        patch("vllm.model_executor.models.qwen3_moe.Qwen3MoeModel") as model_cls,
        patch("vllm.model_executor.models.qwen3_moe.ParallelLMHead") as lm_head_cls,
        patch("vllm.model_executor.models.qwen3_moe.LogitsProcessor"),
    ):
        inner_model = model_cls.return_value
        inner_model.make_empty_intermediate_tensors = Mock()
        inner_model.layers = [decoder_layer]

        Qwen3MoeForCausalLM(vllm_config=vllm_cfg)

    # The LM head must be built exactly once, with the same config object.
    lm_head_cls.assert_called_once()
    assert lm_head_cls.call_args.kwargs["quant_config"] is quant_cfg
17 changes: 17 additions & 0 deletions tests/quantization/test_modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ModelOptMixedPrecisionConfig,
ModelOptNvFp4Config,
ModelOptNvFp4LinearMethod,
ModelOptNvFp4W4A16LinearMethod,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
Expand Down Expand Up @@ -126,6 +127,22 @@ def test_modelopt_mixed_precision_quantizes_parallel_lm_head():
assert isinstance(method, ModelOptNvFp4LinearMethod)


@pytest.mark.parametrize("prefix", ["lm_head", "model.lm_head"])
def test_modelopt_mixed_precision_quantizes_w4a16_parallel_lm_head(prefix):
"""Official ModelOpt mixed-precision NVFP4 checkpoints may quantize
``lm_head`` as W4A16_NVFP4 instead of leaving it BF16. Keep this covered
separately from generic linear layers because LM heads are implemented by
``ParallelLMHead`` rather than ``LinearBase``.
"""
config = _mixed_precision_config(
{"lm_head": {"quant_algo": "W4A16_NVFP4", "group_size": 16}}
)

method = config.get_quant_method(_mock_lm_head(), prefix=prefix)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Mock Marlin support before instantiating W4A16 method

This new parametrized test runs unconditionally, but `get_quant_method()` constructs `ModelOptNvFp4W4A16LinearMethod`, whose constructor directly instantiates `MarlinNvFp4LinearKernel`; that kernel asserts `is_fp4_marlin_supported()` (i.e. `current_platform.is_cuda()` and a device capability of at least 75). In CPU-only or unsupported-GPU test jobs the test therefore fails with an `AssertionError`, unlike the adjacent NVFP4 tests, which patch kernel selection. Either patch the Marlin kernel/support check here, or skip the test when FP4 Marlin is unavailable.

Useful? React with 👍 / 👎.


assert isinstance(method, ModelOptNvFp4W4A16LinearMethod)


def test_vocab_parallel_embedding_weight_loader_accepts_scalar_scale():
holder = Mock()
scale = torch.nn.Parameter(torch.empty(1))
Expand Down
Loading