Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions tests/model_executor/test_nemotron_h_quantization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import Mock, patch


def test_nemotron_h_lm_head_receives_quant_config():
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM

mock_quant_config = Mock()

mock_hf_config = Mock()
mock_hf_config.vocab_size = 128
mock_hf_config.hidden_size = 64

mock_vllm_config = Mock()
mock_vllm_config.model_config.hf_config = mock_hf_config
mock_vllm_config.model_config.dtype = None
mock_vllm_config.scheduler_config = Mock()
mock_vllm_config.quant_config = mock_quant_config

with (
patch("vllm.model_executor.models.nemotron_h.NemotronHModel") as MockModel,
patch("vllm.model_executor.models.nemotron_h.ParallelLMHead") as MockLMHead,
patch("vllm.model_executor.models.nemotron_h.LogitsProcessor"),
):
MockModel.return_value.make_empty_intermediate_tensors = Mock()
MockModel.return_value.has_moe = False

NemotronHForCausalLM(vllm_config=mock_vllm_config)

MockLMHead.assert_called_once()
call_kwargs = MockLMHead.call_args.kwargs
assert call_kwargs["quant_config"] is mock_quant_config
78 changes: 78 additions & 0 deletions tests/model_executor/test_qwen3_5_quantization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import Mock, patch


def test_qwen3_5_lm_head_receives_quant_config():
from vllm.model_executor.models.qwen3_5 import Qwen3_5ForCausalLMBase

mock_quant_config = Mock()

mock_hf_config = Mock()
mock_hf_config.tie_word_embeddings = False
mock_hf_config.vocab_size = 128
mock_hf_config.hidden_size = 64

mock_vllm_config = Mock()
mock_vllm_config.model_config.hf_text_config = mock_hf_config
mock_vllm_config.cache_config.mamba_cache_mode = "align"
mock_vllm_config.scheduler_config = Mock()
mock_vllm_config.quant_config = mock_quant_config
mock_vllm_config.lora_config = None

mock_pp_group = Mock()
mock_pp_group.is_last_rank = True

with (
patch("vllm.model_executor.models.qwen3_5.Qwen3_5Model") as MockModel,
patch("vllm.model_executor.models.qwen3_5.ParallelLMHead") as MockLMHead,
patch("vllm.model_executor.models.qwen3_5.LogitsProcessor"),
patch(
"vllm.model_executor.models.qwen3_5.get_pp_group",
return_value=mock_pp_group,
),
):
MockModel.return_value.make_empty_intermediate_tensors = Mock()

Qwen3_5ForCausalLMBase(vllm_config=mock_vllm_config)

MockLMHead.assert_called_once()
call_kwargs = MockLMHead.call_args.kwargs
assert call_kwargs["quant_config"] is mock_quant_config


def test_qwen3_5_mtp_lm_head_receives_quant_config():
from vllm.config import CompilationMode
from vllm.model_executor.models.qwen3_5_mtp import Qwen3_5MTP

mock_quant_config = Mock()

mock_hf_config = Mock()
mock_hf_config.tie_word_embeddings = False
mock_hf_config.vocab_size = 128
mock_hf_config.hidden_size = 64

mock_vllm_config = Mock()
mock_vllm_config.model_config.hf_text_config = mock_hf_config
mock_vllm_config.cache_config.mamba_cache_mode = "align"
mock_vllm_config.compilation_config.mode = CompilationMode.NONE
mock_vllm_config.quant_config = mock_quant_config

mock_pp_group = Mock()
mock_pp_group.is_last_rank = True

with (
patch("vllm.model_executor.models.qwen3_5_mtp.Qwen3_5MultiTokenPredictor"),
patch("vllm.model_executor.models.qwen3_5_mtp.ParallelLMHead") as MockLMHead,
patch("vllm.model_executor.models.qwen3_5_mtp.LogitsProcessor"),
patch(
"vllm.model_executor.models.qwen3_5_mtp.get_pp_group",
return_value=mock_pp_group,
),
):
Qwen3_5MTP(vllm_config=mock_vllm_config)

MockLMHead.assert_called_once()
call_kwargs = MockLMHead.call_args.kwargs
assert call_kwargs["quant_config"] is mock_quant_config
94 changes: 93 additions & 1 deletion tests/quantization/test_modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,24 @@

import os
from typing import Any, NoReturn
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm.config.model import ModelConfig
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8Config,
ModelOptMixedPrecisionConfig,
ModelOptNvFp4Config,
ModelOptNvFp4LinearMethod,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)


@pytest.fixture(scope="function", autouse=True)
Expand Down Expand Up @@ -44,6 +55,87 @@ def _snapshot_download_or_skip(model_id: str) -> str:
_skip(f"Failed to download {model_id} from the HF Hub: {e}")


def _mock_lm_head() -> Mock:
lm_head = Mock(spec=ParallelLMHead)
lm_head.__class__ = ParallelLMHead
return lm_head


def _mixed_precision_config(quantized_layers: dict) -> ModelOptMixedPrecisionConfig:
return ModelOptMixedPrecisionConfig(
kv_cache_quant_method=None,
exclude_modules=[],
quantized_layers=quantized_layers,
fp8_config=ModelOptFp8Config(
quant_method="FP8",
is_checkpoint_fp8_serialized=True,
kv_cache_quant_method=None,
exclude_modules=[],
),
nvfp4_config=ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=None,
exclude_modules=[],
),
w4a16_nvfp4_config=ModelOptNvFp4Config(
quant_method="W4A16_NVFP4",
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=None,
exclude_modules=[],
),
)


def test_modelopt_nvfp4_quantizes_parallel_lm_head():
config = ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=None,
exclude_modules=[],
)

with patch(
"vllm.model_executor.layers.quantization.modelopt.init_nvfp4_linear_kernel"
):
method = config.get_quant_method(_mock_lm_head(), prefix="lm_head")

assert isinstance(method, ModelOptNvFp4LinearMethod)


def test_modelopt_nvfp4_leaves_excluded_parallel_lm_head_unquantized():
config = ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=None,
exclude_modules=["lm_head"],
)

method = config.get_quant_method(_mock_lm_head(), prefix="lm_head")

assert isinstance(method, UnquantizedLinearMethod)


def test_modelopt_mixed_precision_quantizes_parallel_lm_head():
config = _mixed_precision_config(
{"lm_head": {"quant_algo": "NVFP4", "group_size": 16}}
)

with patch(
"vllm.model_executor.layers.quantization.modelopt.init_nvfp4_linear_kernel"
):
method = config.get_quant_method(_mock_lm_head(), prefix="lm_head")

assert isinstance(method, ModelOptNvFp4LinearMethod)


def test_vocab_parallel_embedding_weight_loader_accepts_scalar_scale():
holder = Mock()
scale = torch.nn.Parameter(torch.empty(1))
loaded_scale = torch.tensor(2.0)

VocabParallelEmbedding.weight_loader(holder, scale, loaded_scale)

assert torch.equal(scale, loaded_scale.reshape(1))


@pytest.mark.skipif(
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
Expand Down
9 changes: 5 additions & 4 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
requantize_with_max_scale,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
Expand Down Expand Up @@ -187,7 +188,7 @@ def get_quant_method(

# handle exclusion
if self.is_layer_excluded(prefix):
if isinstance(layer, LinearBase):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
return None

Expand All @@ -200,7 +201,7 @@ def get_quant_method(
return UnquantizedLinearMethod()

# now, the layer is quantized, handle it here
if isinstance(layer, LinearBase):
if isinstance(layer, (LinearBase, ParallelLMHead)):
quant_method = self.LinearMethodCls(self)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
Expand Down Expand Up @@ -2393,13 +2394,13 @@ def get_quant_method(

# Excluded layers
if self.is_layer_excluded(prefix):
if isinstance(layer, LinearBase):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
return None

quant_algo = self._resolve_quant_algo(prefix)

if isinstance(layer, LinearBase):
if isinstance(layer, (LinearBase, ParallelLMHead)):
if quant_algo == "FP8":
return ModelOptFp8LinearMethod(self.fp8_config)
if quant_algo == "NVFP4":
Expand Down
7 changes: 7 additions & 0 deletions vllm/model_executor/layers/vocab_parallel_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def __init__(

if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Divide the weight matrix along the vocabulary dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(
Expand Down Expand Up @@ -438,6 +439,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if output_dim is None:
if (
loaded_weight.ndim == 0
and param.data.ndim == 1
and param.data.numel() == 1
):
loaded_weight = loaded_weight.reshape(1)
assert param.data.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
return
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/nemotron_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)

Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/qwen3_5_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
Expand Down
Loading