[Bugfix][Model] Fix Eagle3 speculative decoding for Qwen3Next-based models#36527
NikitosKh wants to merge 1 commit into vllm-project:main
Conversation
`Qwen3NextModel.forward()` was missing auxiliary hidden state capture logic, making Eagle3 non-functional for all Qwen3Next-based models (including Qwen3.5). The `SupportsEagle3` protocol was declared on `Qwen3_5ForConditionalGeneration` via inheritance, but the inner model `forward()` never captured the states, silently breaking Eagle3.

Fix:
- Add `aux_hidden_state_layers` init and capture logic to `Qwen3NextModel.forward()`, mirroring the `Qwen2Model` pattern
- Add `SupportsEagle3` to `Qwen3NextForCausalLM` and `Qwen3_5ForCausalLMBase` with set/get methods
- Add 9 tests verifying protocol compliance, `forward()` behavior, and consistency with the Qwen2 reference implementation

Tested with Qwen3.5-9B + Eagle3 draft model on vLLM, achieving a 2.1x speedup over the autoregressive baseline.

Signed-off-by: NikitosKh <nikitak.khomich.work@gmail.com>
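The capture pattern the commit describes (mirroring `Qwen2Model`) can be sketched roughly as follows. This is an illustrative toy, not vLLM's actual code: the class and layer types are stand-ins, and in the real model the captured values are `hidden_states + residual` tensors rather than integers. The method names follow the `SupportsEagle3` protocol named in the commit message.

```python
# Illustrative sketch only -- not vLLM's code. Real layers return tensors and
# the captured value is hidden_states + residual; here integers stand in.
from itertools import islice


class ToyQwen3NextModel:
    def __init__(self, num_layers: int) -> None:
        # Each toy "layer" adds its position + 1 to the running hidden state.
        self.layers = [lambda h, i=i: h + i + 1 for i in range(num_layers)]
        self.start_layer, self.end_layer = 0, num_layers
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    # Method names follow the SupportsEagle3 protocol mentioned in the commit.
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        self.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        return self.aux_hidden_state_layers

    def forward(self, hidden: int):
        aux_hidden_states = []
        # islice(enumerate(...)) keeps absolute layer indices, so the capture
        # also works when a pipeline shard starts at a nonzero layer.
        for idx, layer in islice(
            enumerate(self.layers), self.start_layer, self.end_layer
        ):
            if idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(hidden)  # snapshot before this layer
            hidden = layer(hidden)
        if aux_hidden_states:
            return hidden, aux_hidden_states
        return hidden


model = ToyQwen3NextModel(num_layers=3)
model.set_aux_hidden_state_layers((1, 2))
out, aux = model.forward(0)
print(out, aux)  # 6 [1, 3]
```

Without `set_aux_hidden_state_layers`, `forward()` returns only the final hidden state, which matches the description of Eagle3 being a strictly opt-in side channel.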
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger a full CI run by default; only a small, essential subset of checks runs automatically, and you can ask your reviewers to trigger select CI tests on top of those. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request aims to fix Eagle3 speculative decoding for Qwen3Next-based models by implementing auxiliary hidden state capture in Qwen3NextModel. Two issues stand out. First, the auxiliary layer set is initialized by calling the parameterized generic, `tuple[int, ...]()`; this does run (a parameterized generic is callable on Python 3.9+ and yields an empty tuple here), but it is unidiomatic and easy to misread, and should be written as an annotated empty tuple. Second, and more seriously, the layer indexing logic in the forward pass is incompatible with Pipeline Parallelism, which may cause crashes or incorrect behavior in distributed configurations. The changes otherwise correctly follow the pattern from Qwen2Model.
```python
        else:
            self.norm = PPMissingLayer()

        self.aux_hidden_state_layers = tuple[int, ...]()
```
The initialization of self.aux_hidden_state_layers in Qwen3NextModel.__init__ calls the parameterized generic tuple[int, ...](). This is not a syntax error: since Python 3.9 a parameterized generic (types.GenericAlias) is callable and delegates to its origin type, so the expression evaluates to an empty tuple and the later `in` check in the forward pass works. It is, however, unidiomatic and easy to misread as a type annotation. To initialize an empty tuple, prefer an annotation with the literal ().
```diff
-        self.aux_hidden_state_layers = tuple[int, ...]()
+        self.aux_hidden_state_layers: tuple[int, ...] = ()
```
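As a sanity check of the two forms discussed above, a quick snippet (Python 3.9+ assumed; the variable names are illustrative):

```python
# Calling a parameterized generic delegates to the origin type (here tuple),
# so both forms below yield a plain empty tuple; the annotated literal is
# simply the idiomatic spelling.
via_generic_call = tuple[int, ...]()      # works on Python >= 3.9
annotated: tuple[int, ...] = ()

print(via_generic_call == annotated == ())  # True
print(0 in via_generic_call)                # False: membership tests behave normally
```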
```python
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
```
The current layer iteration logic using enumerate(islice(...)) causes idx to be relative to the current shard's start layer. Since self.aux_hidden_state_layers stores absolute layer indices, this check will fail for shards where start_layer > 0. This incompatibility with Pipeline Parallelism means no auxiliary hidden states will be captured on those shards, likely causing the Eagle3 engine to crash or produce incorrect results. To correctly support Pipeline Parallelism, the loop structure needs to be adjusted to ensure idx represents the absolute layer index.
```diff
-        for idx, layer in enumerate(
-            islice(self.layers, self.start_layer, self.end_layer)
-        ):
-            if idx in self.aux_hidden_state_layers:
+        for idx, layer in islice(
+            enumerate(self.layers), self.start_layer, self.end_layer
+        ):
+            if idx in self.aux_hidden_state_layers:
```
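A toy illustration of the relative-vs-absolute index difference described above (the layer list and shard bounds are made up for the example):

```python
from itertools import islice

layers = ["l0", "l1", "l2", "l3"]   # stand-in for self.layers
start_layer, end_layer = 2, 4       # this pipeline shard owns layers 2 and 3

# enumerate(islice(...)): indices restart at 0 on every shard (relative).
relative = [idx for idx, _ in enumerate(islice(layers, start_layer, end_layer))]

# islice(enumerate(...)): the global position is attached first (absolute).
absolute = [idx for idx, _ in islice(enumerate(layers), start_layer, end_layer)]

print(relative)  # [0, 1] -- an absolute aux index like 2 would never match
print(absolute)  # [2, 3] -- matches aux_hidden_state_layers as intended
```

On a shard with `start_layer == 0` the two forms coincide, which is why the bug only surfaces under Pipeline Parallelism.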
This pull request has merge conflicts that must be resolved before it can be merged.
Closing since #36658 covered the same fix |
Purpose
While training an Eagle3 draft model for Qwen3.5-9B (weights on HF), I discovered that Eagle3 speculative decoding is silently broken for all Qwen3Next-based models (Qwen3.5 and future models using this base).

The root cause is straightforward: `Qwen3NextModel.forward()` never captures auxiliary hidden states. The `SupportsEagle3` protocol is declared on `Qwen3_5ForConditionalGeneration` (inherited from `Qwen3VLForConditionalGeneration`), so vLLM happily calls `set_aux_hidden_state_layers()`, but the inner model's `forward()` just ignores the setting. No error is raised; Eagle3 simply doesn't work.

Compare with `Qwen2Model.forward()`, which correctly implements `enumerate(islice(...))` with aux hidden state capture. `Qwen3NextModel` was missing the same logic.

Changes
- `qwen3_next.py` (the actual fix):
  - initialize `self.aux_hidden_state_layers` in `Qwen3NextModel.__init__`
  - update `Qwen3NextModel.forward()` to capture `hidden_states + residual` at specified layer indices, mirroring the `Qwen2Model` pattern
  - add `SupportsEagle3` + `set_aux_hidden_state_layers` + `get_eagle3_aux_hidden_state_layers` to `Qwen3NextForCausalLM`
- `qwen3_5.py` (propagate to the Qwen3.5-specific class):
  - add `SupportsEagle3` to `Qwen3_5ForCausalLMBase` with the same Eagle3 methods
- `tests/v1/spec_decode/test_qwen3next_eagle3_support.py` (new):
  - 9 tests verifying protocol compliance, `forward()` behavior, and consistency with the `Qwen2Model` reference implementation; all fail before the fix, all pass after

Test Plan
Test Result
9/9 unit tests pass. E2E server starts, loads the draft model, and serves requests with speculative decoding active.
The modest accept length is due to the draft model being trained on a small dataset.
Benchmark on a single H100 (Qwen3.5-9B + our Eagle3 draft):