diff --git a/tests/config/draft_model_arch_groundtruth.json b/tests/config/draft_model_arch_groundtruth.json
index dfe6f3d39e93..5ea8136e9bd9 100644
--- a/tests/config/draft_model_arch_groundtruth.json
+++ b/tests/config/draft_model_arch_groundtruth.json
@@ -38,7 +38,7 @@
             "EagleDeepSeekMTPModel"
         ],
         "model_type": "eagle",
-        "text_model_type": "deepseek_mtp",
+        "text_model_type": "eagle",
         "hidden_size": 2560,
         "total_num_hidden_layers": 1,
         "total_num_attention_heads": 32,
@@ -55,7 +55,7 @@
             "EagleLlamaForCausalLM"
         ],
         "model_type": "eagle",
-        "text_model_type": "llama",
+        "text_model_type": "eagle",
         "hidden_size": 4096,
         "total_num_hidden_layers": 1,
         "total_num_attention_heads": 32,
@@ -72,7 +72,7 @@
             "Eagle3LlamaForCausalLM"
         ],
         "model_type": "eagle",
-        "text_model_type": "llama",
+        "text_model_type": "eagle",
         "hidden_size": 4096,
         "total_num_hidden_layers": 1,
         "total_num_attention_heads": 32,
diff --git a/tests/test_config.py b/tests/test_config.py
index da5080fadbb9..a35c34bbf8d6 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -13,8 +13,10 @@ from vllm.config import (
     CompilationConfig,
     ModelConfig,
+    ParallelConfig,
     PoolerConfig,
     SchedulerConfig,
+    SpeculativeConfig,
     VllmConfig,
     update_config,
 )
@@ -1093,3 +1095,23 @@ def test_needs_dp_coordination(
     vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)
     assert vllm_config.needs_dp_coordinator == expected_needs_coordinator
+
+
+def test_eagle_draft_model_config():
+    """Test that EagleDraft model config is correctly set."""
+    target_model_config = ModelConfig(
+        "meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True
+    )
+    speculative_config = SpeculativeConfig(
+        model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
+        num_speculative_tokens=1,
+        target_model_config=target_model_config,
+        target_parallel_config=ParallelConfig(),
+    )
+    draft_model_config = speculative_config.draft_model_config
+    assert draft_model_config.hf_config.architectures == ["EagleLlamaForCausalLM"]
+    assert draft_model_config.hf_text_config.architectures == ["EagleLlamaForCausalLM"]
+    assert draft_model_config.hf_config.model_type == "eagle"
+    assert draft_model_config.hf_text_config.model_type == "eagle"
+    assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
+    assert draft_model_config.architecture == "EagleLlamaForCausalLM"
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index e80fa532a2c4..b6170cb12d1c 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -12,6 +12,7 @@
 from vllm.config.parallel import ParallelConfig
 from vllm.config.utils import config
 from vllm.logger import init_logger
+from vllm.transformers_utils.config import get_hf_text_config
 from vllm.utils.hashing import safe_hash
 from vllm.utils.import_utils import LazyLoader, has_arctic_inference
@@ -399,10 +400,23 @@
                 method=self.method,
                 model_type="eagle",
             )
+            # EAGLEConfig primarily updates architectures, so update
+            # all architectures-related fields in draft_model_config
             self.draft_model_config.hf_config = eagle_config
+            self.draft_model_config.hf_text_config = get_hf_text_config(
+                self.draft_model_config.hf_config
+            )
             self.draft_model_config.model_arch_config = (
                 self.draft_model_config.get_model_arch_config()
             )
+            model_info, arch = (
+                self.draft_model_config.registry.inspect_model_cls(
+                    self.draft_model_config.architectures,
+                    self.draft_model_config,
+                )
+            )
+            self.draft_model_config._model_info = model_info
+            self.draft_model_config._architecture = arch
 
             if self.num_speculative_tokens is not None and hasattr(
                 self.draft_model_config.hf_config, "num_lookahead_tokens"
diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py
index dc067a09419b..6df4bb64dceb 100644
--- a/vllm/transformers_utils/model_arch_config_convertor.py
+++ b/vllm/transformers_utils/model_arch_config_convertor.py
@@ -201,7 +201,7 @@ def is_deepseek_mla(self) -> bool:
             # underlying architecture
             return (
                 self.hf_text_config.model.model_type
-                in ("deepseek_v2", "deepseek_v3", "deepseek_v32")
+                in ("deepseek_v2", "deepseek_v3", "deepseek_v32", "deepseek_mtp")
                 and self.hf_text_config.kv_lora_rank is not None
             )
         return False