Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/speculative_decoding/speculators/test_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
pytest.param(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
id="qwen3-eagle3-speculator"),
pytest.param(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
id="qwen3-eagle3-speculator-w4a16-verifier"),
])
def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path,
monkeypatch):
Expand Down
7 changes: 6 additions & 1 deletion vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def __init__(self,

config = config or vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
quant_config = self.get_quant_config(vllm_config)

self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
Expand Down Expand Up @@ -328,6 +328,11 @@ def forward(
hidden_states = self.mlp(hidden_states)
return hidden_states, residual

def get_quant_config(
        self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
    """Return the quantization config used by this decoder layer.

    Base behavior simply forwards the verifier model's quant config from
    ``vllm_config``. Subclasses (e.g. the Eagle3 drafter layer) override
    this to source quantization settings from a different model config.
    """
    # No layer-specific logic here: defer entirely to the global config.
    quant_config = vllm_config.quant_config
    return quant_config


@support_torch_compile
class LlamaModel(nn.Module):
Expand Down
14 changes: 13 additions & 1 deletion vllm/model_executor/models/llama_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Expand All @@ -33,7 +35,7 @@ def __init__(self,
super().__init__(vllm_config, prefix=prefix, config=config)

config = config or vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
quant_config = self.get_quant_config(vllm_config)

# override qkv
self.self_attn.qkv_proj = QKVParallelLinear(
Expand All @@ -53,6 +55,16 @@ def __init__(self,
else:
self._residual_norm = self._norm_after_residual

def get_quant_config(
        self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
    """Use drafter's quantization config instead of verifier's.

    The Eagle3 draft layer may be quantized differently from the target
    (verifier) model, so the quant config is derived from the speculative
    draft model's config rather than ``vllm_config.quant_config``.

    NOTE(review): assumes ``vllm_config.speculative_config`` is set when
    this layer is constructed (true on the Eagle3 drafter path) — confirm
    before reusing elsewhere.
    """
    speculative_config = vllm_config.speculative_config
    draft_model_config = speculative_config.draft_model_config

    # No draft model config means there is nothing to quantize against.
    if not draft_model_config:
        return None

    return VllmConfig.get_quantization_config(
        draft_model_config, vllm_config.load_config)

def _norm_before_residual(
self,
hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
Expand Down