-
-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[Eagle] [Quantization] Add complete quantization support to the draft model in Eagle #27434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
shreyas269
wants to merge
3
commits into
vllm-project:main
from
capitalone-contributions:shreyas269-vllm-quantize-eagle
Closed
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,169 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from unittest.mock import Mock, patch | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm.config import LoadConfig, ModelConfig, SpeculativeConfig, VllmConfig | ||
| from vllm.model_executor.models.utils import get_draft_quant_config | ||
| from vllm.platforms import current_platform | ||
|
|
||
| DEVICES = ( | ||
| [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] | ||
| if current_platform.is_cuda_alike() | ||
| else ["cpu"] | ||
| ) | ||
|
|
||
|
|
||
| def test_get_draft_quant_config_with_draft_model(): | ||
| mock_draft_model_config = Mock(spec=ModelConfig) | ||
| mock_load_config = Mock(spec=LoadConfig) | ||
| mock_speculative_config = Mock(spec=SpeculativeConfig) | ||
| mock_speculative_config.draft_model_config = mock_draft_model_config | ||
|
|
||
| mock_vllm_config = Mock(spec=VllmConfig) | ||
| mock_vllm_config.speculative_config = mock_speculative_config | ||
| mock_vllm_config.load_config = mock_load_config | ||
|
|
||
| mock_quant_config = Mock() | ||
| with patch.object( | ||
| VllmConfig, "get_quantization_config", return_value=mock_quant_config | ||
| ): | ||
| result = get_draft_quant_config(mock_vllm_config) | ||
|
|
||
| # Verify the function calls get_quantization_config with draft model config | ||
| VllmConfig.get_quantization_config.assert_called_once_with( | ||
| mock_draft_model_config, mock_load_config | ||
| ) | ||
| assert result == mock_quant_config | ||
|
|
||
|
|
||
| def test_get_draft_quant_config_without_draft_model(): | ||
| mock_speculative_config = Mock(spec=SpeculativeConfig) | ||
| mock_speculative_config.draft_model_config = None | ||
|
|
||
| mock_vllm_config = Mock(spec=VllmConfig) | ||
| mock_vllm_config.speculative_config = mock_speculative_config | ||
| mock_vllm_config.load_config = Mock(spec=LoadConfig) | ||
|
|
||
| result = get_draft_quant_config(mock_vllm_config) | ||
|
|
||
| assert result is None | ||
|
|
||
|
|
||
| @torch.inference_mode() | ||
| @pytest.mark.parametrize("device", DEVICES) | ||
| def test_fc_layer_quant_config_usage(dist_init, device) -> None: | ||
| import torch | ||
|
|
||
| from vllm.model_executor.layers.linear import ReplicatedLinear | ||
|
|
||
| if current_platform.is_cuda_alike(): | ||
| torch.cuda.set_device(device) | ||
|
|
||
| torch.set_default_device(device) | ||
|
|
||
| input_size = 256 | ||
| output_size = 128 | ||
|
|
||
| fc_no_quant = ReplicatedLinear( | ||
| input_size=input_size, | ||
| output_size=output_size, | ||
| bias=False, | ||
| params_dtype=torch.float16, | ||
| quant_config=None, | ||
| prefix="fc", | ||
| ) | ||
|
|
||
| assert fc_no_quant.quant_config is None | ||
| assert fc_no_quant.input_size == input_size | ||
| assert fc_no_quant.output_size == output_size | ||
|
|
||
| mock_quant_config = Mock() | ||
| fc_with_quant = ReplicatedLinear( | ||
| input_size=input_size, | ||
| output_size=output_size, | ||
| bias=False, | ||
| params_dtype=torch.float16, | ||
| quant_config=mock_quant_config, | ||
| prefix="fc", | ||
| ) | ||
|
|
||
| assert fc_with_quant.quant_config == mock_quant_config | ||
|
|
||
| # Check forward pass | ||
| x = torch.randn(2, input_size, dtype=torch.float16) | ||
| output, _ = fc_no_quant(x) | ||
| assert output.shape == (2, output_size) | ||
|
|
||
|
|
||
| def test_kv_cache_scale_name_handling(): | ||
| # Mock a quant config that supports cache scales | ||
| mock_quant_config = Mock() | ||
| mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale") | ||
|
|
||
| # Condition check in load_weights | ||
| name = "layers.0.self_attn.k_proj.weight" | ||
| scale_name = mock_quant_config.get_cache_scale(name) | ||
|
|
||
| # Check if get_cache_scale is called and returns expected value | ||
| mock_quant_config.get_cache_scale.assert_called_once_with(name) | ||
| assert scale_name == "layers.0.self_attn.kv_scale" | ||
|
|
||
|
|
||
| def test_kv_cache_scale_name_no_scale(): | ||
| # Mock a quant config that returns None for get_cache_scale | ||
| mock_quant_config = Mock() | ||
| mock_quant_config.get_cache_scale = Mock(return_value=None) | ||
|
|
||
| name = "layers.0.mlp.gate_proj.weight" | ||
| scale_name = mock_quant_config.get_cache_scale(name) | ||
|
|
||
| # Should return None for weights that don't have cache scales | ||
| assert scale_name is None | ||
|
|
||
|
|
||
| def test_maybe_remap_kv_scale_name(): | ||
| from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name | ||
|
|
||
| params_dict = { | ||
| "layers.0.self_attn.kv_scale": Mock(), | ||
| "layers.1.self_attn.kv_scale": Mock(), | ||
| } | ||
|
|
||
| name = "layers.0.self_attn.some_scale" | ||
| remapped = maybe_remap_kv_scale_name(name, params_dict) | ||
|
|
||
| assert remapped in params_dict or remapped == name or remapped is None | ||
|
|
||
|
|
||
| def test_load_weights_kv_scale_handling(): | ||
| kv_scale_param = Mock() | ||
| kv_scale_param.weight_loader = Mock() | ||
|
|
||
| params_dict = { | ||
| "layers.0.self_attn.kv_scale": kv_scale_param, | ||
| } | ||
|
|
||
| mock_quant_config = Mock() | ||
| mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale") | ||
|
|
||
| # Load_weights logic for KV cache scales | ||
| name = "layers.0.self_attn.k_proj.weight" | ||
| loaded_weight_tensor = torch.tensor([1.0, 2.0]) | ||
|
|
||
| if mock_quant_config is not None: | ||
| scale_name = mock_quant_config.get_cache_scale(name) | ||
| if scale_name: | ||
| param = params_dict[scale_name] | ||
| assert param is kv_scale_param | ||
| weight_to_load = ( | ||
| loaded_weight_tensor | ||
| if loaded_weight_tensor.dim() == 0 | ||
| else loaded_weight_tensor[0] | ||
| ) | ||
|
|
||
| assert scale_name == "layers.0.self_attn.kv_scale" | ||
| assert weight_to_load == loaded_weight_tensor[0] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,20 +11,23 @@ | |
| from vllm.config import VllmConfig, get_current_vllm_config | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.layernorm import RMSNorm | ||
| from vllm.model_executor.layers.linear import QKVParallelLinear | ||
| from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear | ||
| from vllm.model_executor.layers.logits_processor import LogitsProcessor | ||
| from vllm.model_executor.layers.quantization.base_config import QuantizationConfig | ||
| from vllm.model_executor.layers.vocab_parallel_embedding import ( | ||
| DEFAULT_VOCAB_PADDING_SIZE, | ||
| ParallelLMHead, | ||
| VocabParallelEmbedding, | ||
| ) | ||
| from vllm.model_executor.model_loader.weight_utils import default_weight_loader | ||
| from vllm.model_executor.model_loader.weight_utils import ( | ||
| default_weight_loader, | ||
| maybe_remap_kv_scale_name, | ||
| ) | ||
| from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM | ||
| from vllm.multimodal import MULTIMODAL_REGISTRY | ||
| from vllm.multimodal.inputs import NestedTensors | ||
|
|
||
| from .utils import AutoWeightsLoader, maybe_prefix | ||
| from .utils import AutoWeightsLoader, get_draft_quant_config, maybe_prefix | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
@@ -67,14 +70,7 @@ def __init__( | |
|
|
||
| def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None: | ||
| """Use drafter's quantization config instead of verifier's.""" | ||
| draft_model_config = vllm_config.speculative_config.draft_model_config | ||
| draft_load_config = vllm_config.load_config | ||
|
|
||
| return ( | ||
| VllmConfig.get_quantization_config(draft_model_config, draft_load_config) | ||
| if draft_model_config | ||
| else None | ||
| ) | ||
| return get_draft_quant_config(vllm_config) | ||
|
|
||
| def _norm_before_residual( | ||
| self, hidden_states: torch.Tensor | ||
|
|
@@ -141,6 +137,9 @@ def __init__( | |
| self.config = vllm_config.speculative_config.draft_model_config.hf_config | ||
| self.vocab_size = self.config.vocab_size | ||
|
|
||
| # Get drafter's quantization config | ||
| self.quant_config = get_draft_quant_config(vllm_config) | ||
|
|
||
| current_vllm_config = get_current_vllm_config() | ||
|
|
||
| self.embed_tokens = VocabParallelEmbedding( | ||
|
|
@@ -161,13 +160,19 @@ def __init__( | |
| ] | ||
| ) | ||
| if hasattr(self.config, "target_hidden_size"): | ||
| self.fc = torch.nn.Linear( | ||
| self.config.target_hidden_size * 3, self.config.hidden_size, bias=False | ||
| ) | ||
| fc_input_size = self.config.target_hidden_size * 3 | ||
| else: | ||
| self.fc = torch.nn.Linear( | ||
| self.config.hidden_size * 3, self.config.hidden_size, bias=False | ||
| ) | ||
| fc_input_size = self.config.hidden_size * 3 | ||
| self.fc = ReplicatedLinear( | ||
| input_size=fc_input_size, | ||
| output_size=self.config.hidden_size, | ||
| bias=False, | ||
| params_dtype=vllm_config.model_config.dtype, | ||
| quant_config=self.quant_config, | ||
| prefix=maybe_prefix(prefix, "fc"), | ||
| return_bias=False, | ||
| ) | ||
|
|
||
| self.norm = RMSNorm( | ||
| self.config.hidden_size, | ||
| eps=self.config.rms_norm_eps, | ||
|
|
@@ -212,6 +217,24 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | |
| for name, loaded_weight in weights: | ||
| if "midlayer." in name: | ||
| name = name.replace("midlayer.", "layers.0.") | ||
| # Handle kv cache quantization scales | ||
| if self.quant_config is not None and ( | ||
| scale_name := self.quant_config.get_cache_scale(name) | ||
| ): | ||
| # Loading kv cache quantization scales | ||
| param = params_dict[scale_name] | ||
| weight_loader = getattr(param, "weight_loader", default_weight_loader) | ||
| loaded_weight = ( | ||
| loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] | ||
| ) | ||
| weight_loader(param, loaded_weight) | ||
| loaded_params.add(scale_name) | ||
| continue | ||
| # Remapping the name FP8 kv-scale | ||
| if "scale" in name: | ||
| name = maybe_remap_kv_scale_name(name, params_dict) | ||
| if name is None: | ||
| continue | ||
|
Comment on lines
+220
to
+237
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| for param_name, weight_name, shard_id in stacked_params_mapping: | ||
| if weight_name not in name: | ||
| continue | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block of code for handling KV cache quantization scales and remapping FP8 scale names is duplicated in
vllm/model_executor/models/llama_eagle3.py. To improve maintainability and avoid potential bugs from inconsistent updates, this logic should be extracted into a shared utility function, perhaps invllm.model_executor.model_loader.weight_utils. This would follow the same good practice you've already applied by refactoringget_draft_quant_config.