
Add: Eagle3 support for Qwen3.5#36658

Merged
vllm-bot merged 4 commits into vllm-project:main from neuralmagic:add-eagle3-support-qwen3.5
Mar 11, 2026

Conversation

@rahul-tuli
Contributor

@rahul-tuli rahul-tuli commented Mar 10, 2026

This PR adds support for EAGLE-3 speculative decoding to Qwen3.5, enabling faster inference with draft models like BLR2/Qwen3.5-9B-Eagle3-ShareGPT.

Changes

Modified Files

  • vllm/model_executor/models/qwen3_next.py
  • vllm/model_executor/models/qwen3_5.py

Implementation Details

  1. Updated Qwen3NextModel (qwen3_next.py)

    • Added aux_hidden_state_layers attribute to track which layers output auxiliary hidden states
    • Modified forward() to collect auxiliary hidden states at specified global layer indices
    • Returns (hidden_states, aux_hidden_states) when auxiliary states are collected; otherwise returns hidden_states unchanged (zero overhead when Eagle3 is not active)
  2. Added SupportsEagle3 Interface to Qwen3_5ForCausalLMBase (qwen3_5.py)

    • Imported and added SupportsEagle3 to Qwen3_5ForCausalLMBase class inheritance — inherited automatically by both Qwen3_5ForCausalLM and Qwen3_5MoeForCausalLM
    • Implements set_aux_hidden_state_layers() and get_eagle3_aux_hidden_state_layers() returning layer indices (2, num_layers // 2, num_layers - 3)
    • Kept self.aux_hidden_state_layers = () in Qwen3_5Model.__init__ because it calls super(Qwen3NextModel, self).__init__(), skipping Qwen3NextModel.__init__
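The collection pattern described above can be sketched as follows. This is a simplified, hypothetical stand-in, not the actual `Qwen3NextModel.forward` (the real method also threads positions, KV caches, and `IntermediateTensors`); only the aux-hidden-state logic is kept:

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Minimal sketch of the aux-hidden-state collection pattern."""

    def __init__(self, num_layers: int = 8, hidden: int = 16) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(hidden, hidden) for _ in range(num_layers)
        )
        # Empty by default, so non-Eagle3 runs take the fast path below.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def forward(self, hidden_states: torch.Tensor):
        aux_hidden_states: list[torch.Tensor] = []
        for idx, layer in enumerate(self.layers):
            if idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(hidden_states)
            hidden_states = layer(hidden_states)
        if aux_hidden_states:
            # Eagle3 active: return final states plus the collected ones.
            return hidden_states, aux_hidden_states
        # Eagle3 inactive: return type is unchanged, zero overhead.
        return hidden_states


model = TinyModel()
x = torch.randn(2, 16)
out = model(x)                       # plain tensor when Eagle3 is off
model.aux_hidden_state_layers = (2, 4, 5)
final, aux = model(x)                # tuple once layers are selected
```

The key property mirrored here is the one the PR description calls out: when `aux_hidden_state_layers` is empty, the return value and type are identical to the pre-change behavior.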

Testing

Tested with Qwen/Qwen3.5-9B and EAGLE-3 drafter BLR2/Qwen3.5-9B-Eagle3-ShareGPT on mt-bench:

CUDA_VISIBLE_DEVICES=0 python examples/offline_inference/spec_decode.py \
    --model-dir Qwen/Qwen3.5-9B \
    --eagle-dir BLR2/Qwen3.5-9B-Eagle3-ShareGPT \
    --method eagle3 \
    --num-spec-tokens 3 \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts 80 \
    --enable-chunked-prefill \
    --temp 0 \
    --print-output
--------------------------------------------------
total_num_output_tokens: 20480
num_drafts: 8750
num_draft_tokens: 26250
num_accepted_tokens: 11720
mean acceptance length: 2.34
--------------------------------------------------
acceptance at token 0: 0.68
acceptance at token 1: 0.42
acceptance at token 2: 0.24

Related

This implementation follows the same pattern as existing EAGLE-3 support in:

  • Qwen2ForCausalLM
  • Qwen3ForCausalLM
  • LlamaForCausalLM

Offline inference script

from vllm import LLM, SamplingParams

# Initialize with EAGLE-3 speculative decoding
llm = LLM(
    model="Qwen/Qwen3.5-9B",
    tensor_parallel_size=1,
    speculative_config={
        "model": "BLR2/Qwen3.5-9B-Eagle3-ShareGPT",
        "method": "eagle3",
        "num_speculative_tokens": 3,
    },
    max_model_len=16384,
)

# Generate with speculative decoding
prompts = [
    "Hello, my name is",
    "The capital of France is",
]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(f"Prompt: {output.prompt}")
    print(f"Generated: {output.outputs[0].text}")

@mergify mergify Bot added the qwen Related to Qwen models label Mar 10, 2026
@rahul-tuli rahul-tuli marked this pull request as ready for review March 10, 2026 14:23
@rahul-tuli rahul-tuli requested a review from sighingnow as a code owner March 10, 2026 14:23
Contributor

@gambletan gambletan left a comment


Nice addition of Eagle3 support for Qwen3.5 and Qwen3Next.

One thing I noticed: in qwen3_next.py, the forward() method's return type annotation is changed to torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]], but the caller of this method needs to handle the new tuple return type. Is there a common dispatch layer (e.g., in the Eagle3 speculative decoding code) that already pattern-matches on isinstance(result, tuple) for other models? If not, this could break callers that expect only torch.Tensor | IntermediateTensors.
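A sketch of the kind of dispatch the reviewer is asking about; whether an equivalent already exists in the Eagle3 runner is exactly the open question, and the helper name here is illustrative, not vLLM code:

```python
def split_model_output(output):
    """Normalize a model forward() result into (hidden_states, aux list).

    Illustrative helper: models that collect auxiliary hidden states
    return a (hidden_states, aux_hidden_states) tuple, while all other
    models return the hidden states directly.
    """
    if isinstance(output, tuple):
        hidden_states, aux_hidden_states = output
        return hidden_states, aux_hidden_states
    return output, []


# Placeholder strings stand in for tensors.
h, aux = split_model_output("hidden")            # legacy return shape
h2, aux2 = split_model_output(("hidden", ["a1", "a2"]))  # Eagle3 shape
```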

Also, in qwen3_5.py line ~601, get_eagle3_aux_hidden_state_layers hardcodes (2, num_layers // 2, num_layers - 3). If the model has very few layers (e.g., a small variant with < 5 layers), num_layers - 3 could overlap with layer 2 or even be negative. A guard like assert num_layers >= 6 or similar would make this more robust against unexpected model configurations.
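One way to express the guard being suggested, as a standalone sketch (illustrative only; an actual fix would live inside get_eagle3_aux_hidden_state_layers):

```python
def safe_aux_layers(num_layers: int) -> tuple[int, ...]:
    """Deduplicate, sort, and bounds-check the candidate layer indices.

    Sketch of the suggested guard, not vLLM code: a set drops
    duplicates, and the filter discards negative or out-of-range
    indices that (2, num_layers // 2, num_layers - 3) can produce
    for small models.
    """
    candidates = {2, num_layers // 2, num_layers - 3}
    return tuple(sorted(i for i in candidates if 0 <= i < num_layers))


safe_aux_layers(36)  # typical model: three distinct valid indices
safe_aux_layers(4)   # tiny model: {2, 2, 1} collapses to two indices
```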

Minor: aux_hidden_state_layers is initialized as an empty tuple () in Qwen3NextModel.__init__ but the forward method checks if aux_hidden_states: (checking the list, not the config tuple). This works correctly but could be slightly confusing to future readers — a brief comment clarifying that aux_hidden_states is populated only when self.aux_hidden_state_layers is non-empty would help.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for EAGLE-3 speculative decoding to Qwen3.5 models, which is a great enhancement for inference performance. The changes are well-structured and follow the existing patterns for Eagle3 support in other models within the vLLM codebase. The implementation correctly collects auxiliary hidden states with minimal overhead when the feature is not active. I have one suggestion to improve the robustness of the layer index generation to ensure it gracefully handles models with a small number of layers.

I am having trouble creating individual review comments. My feedback is included below.

vllm/model_executor/models/qwen3_5.py (584-586)

high

The current implementation for generating auxiliary layer indices can produce out-of-bounds (e.g., negative or too large) or duplicate values when num_layers is small. While the current usage with the in operator in the forward pass implicitly filters these invalid indices, this approach is brittle and not explicit. For robustness and clarity, it's better to ensure that the returned tuple contains only unique, sorted, and valid layer indices. This prevents potential issues if this method is used in other contexts in the future where direct indexing might be assumed.

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        num_layers = len(self.model.layers)
        indices = {2, num_layers // 2, num_layers - 3}
        return tuple(sorted(i for i in indices if 0 <= i < num_layers))

class Qwen3_5ForCausalLMBase(
    nn.Module,
    HasInnerState,
    SupportsEagle3,
Collaborator


Please also add SupportsEagle. It's not currently used everywhere but I'm trying to get all the models to have both for consistency, at least for now. See #36063

Collaborator

@benchislett benchislett left a comment


LGTM

@benchislett
Collaborator

@gambletan none of that feedback is relevant. This PR's implementation is the canonical way of handling EAGLE3 support in vLLM

@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 10, 2026
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) March 10, 2026 16:15
@vllm-bot vllm-bot merged commit 9d07a3d into vllm-project:main Mar 11, 2026
58 of 63 checks passed
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
