Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/config/draft_model_arch_groundtruth.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"EagleDeepSeekMTPModel"
],
"model_type": "eagle",
"text_model_type": "deepseek_mtp",
"text_model_type": "eagle",
"hidden_size": 2560,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
Expand All @@ -55,7 +55,7 @@
"EagleLlamaForCausalLM"
],
"model_type": "eagle",
"text_model_type": "llama",
"text_model_type": "eagle",
"hidden_size": 4096,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
Expand All @@ -72,7 +72,7 @@
"Eagle3LlamaForCausalLM"
],
"model_type": "eagle",
"text_model_type": "llama",
"text_model_type": "eagle",
"hidden_size": 4096,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
Expand Down
22 changes: 22 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from vllm.config import (
CompilationConfig,
ModelConfig,
ParallelConfig,
PoolerConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
update_config,
)
Expand Down Expand Up @@ -1093,3 +1095,23 @@ def test_needs_dp_coordination(
vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)

assert vllm_config.needs_dp_coordinator == expected_needs_coordinator


def test_eagle_draft_model_config():
    """Test that EagleDraft model config is correctly set."""
    base_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True
    )
    spec_config = SpeculativeConfig(
        model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
        num_speculative_tokens=1,
        target_model_config=base_model_config,
        target_parallel_config=ParallelConfig(),
    )
    draft = spec_config.draft_model_config
    expected_arch = "EagleLlamaForCausalLM"
    # After SpeculativeConfig post-init, both the HF config and the derived
    # text config must report the Eagle architecture and model type.
    for hf_cfg in (draft.hf_config, draft.hf_text_config):
        assert hf_cfg.architectures == [expected_arch]
        assert hf_cfg.model_type == "eagle"
    # The ModelConfig-level views must agree with the HF configs.
    assert draft.architectures == [expected_arch]
    assert draft.architecture == expected_arch
14 changes: 14 additions & 0 deletions vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_hf_text_config
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import LazyLoader, has_arctic_inference

Expand Down Expand Up @@ -399,10 +400,23 @@ def __post_init__(self):
method=self.method,
model_type="eagle",
)
# EAGLEConfig primarily updates architectures, so update
# all architectures-related fields in draft_model_config
self.draft_model_config.hf_config = eagle_config
self.draft_model_config.hf_text_config = get_hf_text_config(
self.draft_model_config.hf_config
)
self.draft_model_config.model_arch_config = (
self.draft_model_config.get_model_arch_config()
)
model_info, arch = (
self.draft_model_config.registry.inspect_model_cls(
self.draft_model_config.architectures,
self.draft_model_config,
)
)
self.draft_model_config._model_info = model_info
self.draft_model_config._architecture = arch
Comment on lines +403 to +419
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While this correctly fixes the issue, manually re-running parts of ModelConfig.__post_init__ here makes the code fragile. If the initialization logic in ModelConfig changes in the future, this code might become outdated and introduce subtle bugs.

To improve maintainability and encapsulation, I suggest adding a private helper method to ModelConfig to reset these architecture-dependent fields. This would also avoid accessing private members like _model_info and _architecture from outside the class.

You could add the following method to vllm/config/model.py (this file is not in the PR, so it would be an expansion of scope, but would improve the codebase):

# In vllm/config/model.py, class ModelConfig:
def _reset_architecture_fields(self):
    """
    Resets fields that are derived from the model architecture.
    This is useful when hf_config.architectures is modified after
    initialization.
    """
    from vllm.transformers_utils.config import get_hf_text_config
    self.hf_text_config = get_hf_text_config(self.hf_config)
    self.model_arch_config = self.get_model_arch_config()
    model_info, arch = self.registry.inspect_model_cls(
        self.architectures,
        self,
    )
    self._model_info = model_info
    self._architecture = arch

Then, you can simplify the code here as follows:

                        # EAGLEConfig primarily updates architectures, so update
                        # all architectures-related fields in draft_model_config
                        self.draft_model_config.hf_config = eagle_config
                        self.draft_model_config._reset_architecture_fields()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since _reset_architecture_fields would only be used by Eagle draft models, I currently keep this logic inside speculative.py. Happy to change if the reviewers also agree with Gemini.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think Gemini's solution solves the problem that people may forget to update this code, so I'm OK with the current implementation.


if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens"
Expand Down
2 changes: 1 addition & 1 deletion vllm/transformers_utils/model_arch_config_convertor.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def is_deepseek_mla(self) -> bool:
# underlying architecture
return (
self.hf_text_config.model.model_type
in ("deepseek_v2", "deepseek_v3", "deepseek_v32")
in ("deepseek_v2", "deepseek_v3", "deepseek_v32", "deepseek_mtp")
and self.hf_text_config.kv_lora_rank is not None
)
return False
Expand Down