169 changes: 169 additions & 0 deletions tests/model_executor/test_eagle_quantization.py
@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import Mock, patch

import pytest
import torch

from vllm.config import LoadConfig, ModelConfig, SpeculativeConfig, VllmConfig
from vllm.model_executor.models.utils import get_draft_quant_config
from vllm.platforms import current_platform

DEVICES = (
    [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
    if current_platform.is_cuda_alike()
    else ["cpu"]
)


def test_get_draft_quant_config_with_draft_model():
    mock_draft_model_config = Mock(spec=ModelConfig)
    mock_load_config = Mock(spec=LoadConfig)
    mock_speculative_config = Mock(spec=SpeculativeConfig)
    mock_speculative_config.draft_model_config = mock_draft_model_config

    mock_vllm_config = Mock(spec=VllmConfig)
    mock_vllm_config.speculative_config = mock_speculative_config
    mock_vllm_config.load_config = mock_load_config

    mock_quant_config = Mock()
    with patch.object(
        VllmConfig, "get_quantization_config", return_value=mock_quant_config
    ):
        result = get_draft_quant_config(mock_vllm_config)

        # Verify the function calls get_quantization_config with draft model config
        VllmConfig.get_quantization_config.assert_called_once_with(
            mock_draft_model_config, mock_load_config
        )
    assert result == mock_quant_config


def test_get_draft_quant_config_without_draft_model():
    mock_speculative_config = Mock(spec=SpeculativeConfig)
    mock_speculative_config.draft_model_config = None

    mock_vllm_config = Mock(spec=VllmConfig)
    mock_vllm_config.speculative_config = mock_speculative_config
    mock_vllm_config.load_config = Mock(spec=LoadConfig)

    result = get_draft_quant_config(mock_vllm_config)

    assert result is None


@torch.inference_mode()
@pytest.mark.parametrize("device", DEVICES)
def test_fc_layer_quant_config_usage(dist_init, device) -> None:
    import torch

    from vllm.model_executor.layers.linear import ReplicatedLinear

    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)

    input_size = 256
    output_size = 128

    fc_no_quant = ReplicatedLinear(
        input_size=input_size,
        output_size=output_size,
        bias=False,
        params_dtype=torch.float16,
        quant_config=None,
        prefix="fc",
    )

    assert fc_no_quant.quant_config is None
    assert fc_no_quant.input_size == input_size
    assert fc_no_quant.output_size == output_size

    mock_quant_config = Mock()
    fc_with_quant = ReplicatedLinear(
        input_size=input_size,
        output_size=output_size,
        bias=False,
        params_dtype=torch.float16,
        quant_config=mock_quant_config,
        prefix="fc",
    )

    assert fc_with_quant.quant_config == mock_quant_config

    # Check forward pass
    x = torch.randn(2, input_size, dtype=torch.float16)
    output, _ = fc_no_quant(x)
    assert output.shape == (2, output_size)


def test_kv_cache_scale_name_handling():
    # Mock a quant config that supports cache scales
    mock_quant_config = Mock()
    mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")

    # Condition check in load_weights
    name = "layers.0.self_attn.k_proj.weight"
    scale_name = mock_quant_config.get_cache_scale(name)

    # Check if get_cache_scale is called and returns expected value
    mock_quant_config.get_cache_scale.assert_called_once_with(name)
    assert scale_name == "layers.0.self_attn.kv_scale"


def test_kv_cache_scale_name_no_scale():
    # Mock a quant config that returns None for get_cache_scale
    mock_quant_config = Mock()
    mock_quant_config.get_cache_scale = Mock(return_value=None)

    name = "layers.0.mlp.gate_proj.weight"
    scale_name = mock_quant_config.get_cache_scale(name)

    # Should return None for weights that don't have cache scales
    assert scale_name is None


def test_maybe_remap_kv_scale_name():
    from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name

    params_dict = {
        "layers.0.self_attn.kv_scale": Mock(),
        "layers.1.self_attn.kv_scale": Mock(),
    }

    name = "layers.0.self_attn.some_scale"
    remapped = maybe_remap_kv_scale_name(name, params_dict)

    assert remapped in params_dict or remapped == name or remapped is None


def test_load_weights_kv_scale_handling():
    kv_scale_param = Mock()
    kv_scale_param.weight_loader = Mock()

    params_dict = {
        "layers.0.self_attn.kv_scale": kv_scale_param,
    }

    mock_quant_config = Mock()
    mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")

    # load_weights logic for KV cache scales
    name = "layers.0.self_attn.k_proj.weight"
    loaded_weight_tensor = torch.tensor([1.0, 2.0])

    if mock_quant_config is not None:
        scale_name = mock_quant_config.get_cache_scale(name)
        if scale_name:
            param = params_dict[scale_name]
            assert param is kv_scale_param
            weight_to_load = (
                loaded_weight_tensor
                if loaded_weight_tensor.dim() == 0
                else loaded_weight_tensor[0]
            )

            assert scale_name == "layers.0.self_attn.kv_scale"
            assert weight_to_load == loaded_weight_tensor[0]
48 changes: 36 additions & 12 deletions vllm/model_executor/models/llama_eagle.py
@@ -11,13 +11,17 @@
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM

from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, get_draft_quant_config, maybe_prefix

logger = init_logger(__name__)

@@ -40,14 +44,7 @@ def __init__(

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Use drafter's quantization config instead of verifier's."""
        draft_model_config = vllm_config.speculative_config.draft_model_config
        draft_load_config = vllm_config.load_config

        return (
            VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
            if draft_model_config
            else None
        )
        return get_draft_quant_config(vllm_config)


@support_torch_compile
@@ -63,6 +60,9 @@ def __init__(
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        # Get drafter's quantization config
        self.quant_config = get_draft_quant_config(vllm_config)

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
@@ -80,8 +80,14 @@
                for i in range(self.config.num_hidden_layers)
            ]
        )
        self.fc = torch.nn.Linear(
            self.config.hidden_size * 2, self.config.hidden_size, bias=False
        self.fc = ReplicatedLinear(
            input_size=self.config.hidden_size * 2,
            output_size=self.config.hidden_size,
            bias=False,
            params_dtype=vllm_config.model_config.dtype,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "fc"),
            return_bias=False,
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -117,6 +123,24 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Handle kv cache quantization scales
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            # Remapping the name FP8 kv-scale
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
Comment on lines +126 to +143 (Contributor, severity: high):
This block of code for handling KV cache quantization scales and remapping FP8 scale names is duplicated in vllm/model_executor/models/llama_eagle3.py. To improve maintainability and avoid potential bugs from inconsistent updates, this logic should be extracted into a shared utility function, perhaps in vllm.model_executor.model_loader.weight_utils. This would follow the same good practice you've already applied by refactoring get_draft_quant_config.
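A minimal sketch of the shape such a shared helper could take; the name maybe_load_kv_cache_scale, its signature, and its placement are illustrative only and not part of this PR:

from collections.abc import Mapping

import torch

from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


def maybe_load_kv_cache_scale(
    name: str,
    loaded_weight: torch.Tensor,
    params_dict: Mapping[str, torch.nn.Parameter],
    quant_config: QuantizationConfig | None,
    loaded_params: set[str],
) -> bool:
    """Try to consume `loaded_weight` as a KV cache quantization scale.

    Returns True if the weight was loaded as a scale, in which case the
    caller should skip the rest of its weight-loading logic for this entry.
    """
    if quant_config is None:
        return False
    scale_name = quant_config.get_cache_scale(name)
    if not scale_name:
        return False
    param = params_dict[scale_name]
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    # Scales may arrive as 0-d or 1-d tensors; load a scalar either way.
    loaded_weight = loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
    weight_loader(param, loaded_weight)
    loaded_params.add(scale_name)
    return True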

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
57 changes: 40 additions & 17 deletions vllm/model_executor/models/llama_eagle3.py
@@ -11,20 +11,23 @@
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors

from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, get_draft_quant_config, maybe_prefix

logger = init_logger(__name__)

@@ -67,14 +70,7 @@ def __init__(

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Use drafter's quantization config instead of verifier's."""
        draft_model_config = vllm_config.speculative_config.draft_model_config
        draft_load_config = vllm_config.load_config

        return (
            VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
            if draft_model_config
            else None
        )
        return get_draft_quant_config(vllm_config)

    def _norm_before_residual(
        self, hidden_states: torch.Tensor
@@ -141,6 +137,9 @@ def __init__(
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        # Get drafter's quantization config
        self.quant_config = get_draft_quant_config(vllm_config)

        current_vllm_config = get_current_vllm_config()

        self.embed_tokens = VocabParallelEmbedding(
@@ -161,13 +160,19 @@
            ]
        )
        if hasattr(self.config, "target_hidden_size"):
            self.fc = torch.nn.Linear(
                self.config.target_hidden_size * 3, self.config.hidden_size, bias=False
            )
            fc_input_size = self.config.target_hidden_size * 3
        else:
            self.fc = torch.nn.Linear(
                self.config.hidden_size * 3, self.config.hidden_size, bias=False
            )
            fc_input_size = self.config.hidden_size * 3
        self.fc = ReplicatedLinear(
            input_size=fc_input_size,
            output_size=self.config.hidden_size,
            bias=False,
            params_dtype=vllm_config.model_config.dtype,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "fc"),
            return_bias=False,
        )

        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
@@ -212,6 +217,24 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        for name, loaded_weight in weights:
            if "midlayer." in name:
                name = name.replace("midlayer.", "layers.0.")
            # Handle kv cache quantization scales
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            # Remapping the name FP8 kv-scale
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
Comment on lines +220 to +237 (Contributor, severity: high):
This logic for handling KV cache quantization scales is duplicated from vllm/model_executor/models/llama_eagle.py. To improve maintainability, please consider refactoring this shared logic into a common utility function.
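For illustration only, here is a small self-contained sketch of how both load_weights loops could delegate to the hypothetical maybe_load_kv_cache_scale helper sketched in the previous comment, using mocks in the same style as the new test file; none of this is part of the PR:

from unittest.mock import Mock

import torch

kv_scale_param = Mock(weight_loader=Mock())
params_dict = {"layers.0.self_attn.kv_scale": kv_scale_param}

quant_config = Mock()
quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")

loaded_params: set[str] = set()
weights = [("layers.0.self_attn.k_proj.weight", torch.tensor([0.5, 0.5]))]

for name, loaded_weight in weights:
    # Both Eagle load_weights loops would start with this single call
    # instead of the duplicated inline block.
    if maybe_load_kv_cache_scale(
        name, loaded_weight, params_dict, quant_config, loaded_params
    ):
        continue
    # ... remaining stacked-params / default weight loading logic ...

assert "layers.0.self_attn.kv_scale" in loaded_params
kv_scale_param.weight_loader.assert_called_once()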

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
22 changes: 22 additions & 0 deletions vllm/model_executor/models/utils.py
@@ -707,6 +707,28 @@ def maybe_prefix(prefix: str, name: str) -> str:
    return name if not prefix else f"{prefix}.{name}"


def get_draft_quant_config(vllm_config: VllmConfig):
    """Get the quantization config for draft models.

    Draft models should use their own quantization config instead of the
    verifier/target model's config. This helper retrieves the draft model's
    quantization config.

    Args:
        vllm_config: The vLLM configuration object.

    Returns:
        The draft model's quantization config if available, None otherwise.
    """
    draft_model_config = vllm_config.speculative_config.draft_model_config
    draft_load_config = vllm_config.load_config

    return (
        VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
        if draft_model_config
        else None
    )


def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int:
"""
Extract the layer index from the module name.