[0.13.0][cherry-pick][BugFix] Support setting tp=1 for the Eagle draft model to take effect#5804

Merged
wangxiyuan merged 1 commit into vllm-project:releases/v0.13.0 from zhaomingyu13:releases on Jan 13, 2026

Conversation

@zhaomingyu13 (Contributor) commented Jan 12, 2026

What this PR does / why we need it?

According to the official documentation, setting `"draft_tensor_parallel_size": 1` is supposed to apply to the Eagle3 draft model. However, actual debugging showed that the draft model's tensor parallel size (tp) always matched the target model's, so the tp setting for the draft model did not take effect as expected.

Note: This feature has not yet been tested in combination with SP and DP; support for those configurations will be adapted later.

Does this PR introduce any user-facing change?

No

How was this patch tested?

```python
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The future of AI is",
    ]

    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    # Create an LLM with a tp=1 Eagle3 draft model under a tp=4 target model.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        tensor_parallel_size=4,
        gpu_memory_utilization=0.9,
        enforce_eager=True,
        speculative_config={
            "method": "eagle3",
            "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
            "draft_tensor_parallel_size": 1,
            "num_speculative_tokens": 3,
        },
    )

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    print(f"Outputs: {outputs}")
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```

Fixes vllm-project/vllm#31345

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
Co-authored-by: drslark <slarksblood@qq.com>
@gemini-code-assist bot left a comment

Code Review

This pull request aims to fix a bug where setting tensor_parallel_size=1 for the Eagle draft model was not taking effect. The approach of creating a temporary tensor parallel group and patching the global TP group during the draft model loading is sound. However, I've identified a critical issue in the implementation that prevents it from working as intended. The new tensor parallel group is being created with group_name="tp", which is the same name as the main tensor parallel group. This causes init_model_parallel_group to return the existing main TP group instead of creating a new one, rendering the patch ineffective. I have provided a suggestion to resolve this. The accompanying test additions and an unrelated but correct fix for UniformTypeKVCacheSpecs are well-implemented.
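The naming collision the review describes can be illustrated with a minimal standalone registry (a simplified stand-in, not vLLM's real `init_model_parallel_group`, which builds actual distributed process groups): when groups are keyed by `group_name`, reusing the name "tp" hands back the existing main TP group instead of creating the single-rank draft group.

```python
# Simplified stand-in for a name-keyed process-group registry.
_groups = {}

def init_model_parallel_group(ranks, group_name):
    # Returns the already-registered group when the name is reused.
    if group_name in _groups:
        return _groups[group_name]
    group = {"name": group_name, "ranks": list(ranks)}
    _groups[group_name] = group
    return group

main_tp = init_model_parallel_group([0, 1, 2, 3], group_name="tp")
# Buggy: reusing "tp" silently returns the existing 4-way group.
buggy_draft = init_model_parallel_group([0], group_name="tp")
# Fixed: a distinct name yields a genuinely new single-rank group.
fixed_draft = init_model_parallel_group([0], group_name="draft_tp")
```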

Comment thread: vllm_ascend/spec_decode/eagle_proposer.py
@wangxiyuan added the `ready` (ready for review) and `ready-for-test` (start test by label for PR) labels on Jan 12, 2026
@wangxiyuan merged commit 7c71736 into vllm-project:releases/v0.13.0 on Jan 13, 2026
17 checks passed
@wangxiyuan changed the title from "[BugFix] Support setting tp=1 for the Eagle draft model to take effect" to "[0.13.0][cherry-pick][BugFix] Support setting tp=1 for the Eagle draft model to take effect" on Jan 13, 2026