
[Bug] Refactor max_num_batched_tokens to account for drafting#34898

Merged
benchislett merged 8 commits into vllm-project:main from CentML:bchislett/update-scheduler-slots-for-drafting on Feb 22, 2026

Conversation

@benchislett (Collaborator) commented Feb 19, 2026

Purpose

Alternative bugfix to #34671. Solves a crash of specdec on main.

To solve the consistency issue, we directly increase max_num_batched_tokens when initializing the VllmConfig, and then decrease it specifically in the scheduler so that the scheduling behaviour is unchanged.
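The "extend in the config, compensate in the scheduler" idea can be sketched as follows. This is an illustrative toy, not the actual vLLM code: the class and field names below are simplified stand-ins for VllmConfig, SchedulerConfig, and SpeculativeConfig, and the one-slot-per-speculative-token assumption is mine.

```python
from dataclasses import dataclass


@dataclass
class SchedulerConfig:
    max_num_batched_tokens: int  # after init: buffer capacity, not just budget
    max_num_seqs: int


@dataclass
class SpeculativeConfig:
    num_speculative_tokens: int

    @property
    def max_num_new_slots_for_drafting(self) -> int:
        # Simplifying assumption: one draft slot per speculative token.
        return self.num_speculative_tokens


def extend_for_drafting(sched: SchedulerConfig, spec: SpeculativeConfig) -> int:
    """Extend the token limit so model-runner buffers cover the worst case,
    and return the amount the scheduler must subtract back so that its
    effective scheduling budget is unchanged."""
    extra = spec.max_num_new_slots_for_drafting * sched.max_num_seqs
    sched.max_num_batched_tokens += extra
    return extra


sched = SchedulerConfig(max_num_batched_tokens=8192, max_num_seqs=128)
spec = SpeculativeConfig(num_speculative_tokens=5)
extra = extend_for_drafting(sched, spec)

# Buffers are sized from the extended value; the scheduler compensates.
scheduling_budget = sched.max_num_batched_tokens - extra
```

With 5 speculative tokens and 128 sequences, the buffers grow by 640 tokens while the scheduling budget stays at the original 8192.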

Testing

Tested with Qwen3-Next MTP and GSM8k repeated twice with various concurrencies. All pass with 85% accuracy, matching the non-spec baseline.

vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct  \
  --tokenizer-mode auto  --gpu-memory-utilization 0.8 \
  --speculative-config '{"method": "qwen3_next_mtp", "num_speculative_tokens": 5}' \
  --tensor-parallel-size 2 --port 8042

@gemini-code-assist (bot) left a comment

Code Review

The pull request refactors the handling of max_num_batched_tokens for speculative decoding by extending the limit in VllmConfig to account for drafting slots and then compensating with a subtraction in the V1 scheduler. This ensures that model runner buffers and CUDA graph ranges are correctly sized for the maximum possible batch size during verification. However, there are several critical issues: the extension logic in SpeculativeConfig incorrectly ignores the extra slots needed for serial drafting (which still requires a multi-token verification pass in V1), and the in-place modification of scheduler_config has unintended side effects on the Mamba cache alignment check, the V0 scheduler, and proposer-specific buffer allocations.

Comment on lines +1204 to +1207:

    self.scheduler_config.max_num_batched_tokens += (
        self.speculative_config.max_num_new_slots_for_drafting
        * self.scheduler_config.max_num_seqs
    )
Severity: high
Modifying scheduler_config.max_num_batched_tokens in place changes the meaning of this field from 'scheduling token budget' to 'maximum buffer capacity'. This has several problematic side effects:

  1. V0 Scheduler Inconsistency: The V0 scheduler (e.g., in vllm/core/scheduler.py) uses this field directly for scheduling decisions and lacks the compensating subtraction logic added to the V1 scheduler. This will result in an unintended increase in the scheduling budget for V0 when speculative decoding is enabled.
  2. Mamba Cache Alignment Check: The check at line 1116 (block_size <= max_num_batched_tokens) now validates against the extended buffer capacity instead of the actual scheduling budget. This could allow configurations where the scheduler is unable to schedule a full block, breaking the alignment requirement for Mamba models.
  3. Double Extension in Proposers: The V1 proposer base class (SpecDecodeBaseProposer in vllm/v1/spec_decode/eagle.py, line 105) performs its own extension logic based on the config value. Since the config value is now already extended, the proposer's internal buffers will be double-extended, wasting GPU memory.

Consider storing the original scheduling budget separately or ensuring that all components are updated to distinguish between the scheduling limit and the buffer capacity.
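The double-extension pitfall in point 3 can be illustrated with a toy sketch. The function names below are mine for illustration, not the real vLLM code:

```python
def extend_config(max_num_batched_tokens: int, draft_slots: int, max_seqs: int) -> int:
    # What the config-level change does at init time: grow the token limit
    # to cover worst-case drafting.
    return max_num_batched_tokens + draft_slots * max_seqs


def proposer_buffer_size(config_tokens: int, draft_slots: int, max_seqs: int) -> int:
    # A proposer that extends again, assuming the config value is still the
    # raw scheduling budget, over-allocates.
    return config_tokens + draft_slots * max_seqs


budget, slots, seqs = 8192, 5, 128
extended = extend_config(budget, slots, seqs)         # correct buffer size
double = proposer_buffer_size(extended, slots, seqs)  # double-extended, wasted memory
```

With these numbers the proposer would allocate for 640 tokens more than it can ever see in a forward pass.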

@benchislett (Collaborator, Author) replied:
  1. I don't think we need to consider V0 here, especially since this only affects speculative decoding. Is V0 still a concern?
  2. This is what we're trying to fix. Everything in the model runner should assume the larger buffer size, for safety.
  3. Fixed.

@LucasWilkinson (Collaborator) left a comment
Do you think it would make sense to just clone the vllm_config in WorkerBase.__init__ and only call _extend_max_num_batched_tokens_for_drafting on that clone, instead of modifying the scheduler config? Maybe that's too susceptible to consistency bugs.

I guess another option would be to add a max_num_tokens_per_forward_pass property to VllmConfig that returns the extended range. This would help create a separation of concerns between scheduler_config and the model executor: the model executor cares about max_num_tokens_per_forward_pass, while the scheduler cares about scheduler_config. We'd have to audit all scheduler_config.max_num_batched_tokens usages though, to see if they should be max_num_tokens_per_forward_pass :/
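A minimal sketch of this property-based suggestion, under the same simplifying assumptions as before (the class shapes and the max_num_new_slots_for_drafting field are simplified stand-ins; only the max_num_tokens_per_forward_pass name comes from the comment above):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SchedulerConfig:
    max_num_batched_tokens: int  # stays a pure scheduling budget
    max_num_seqs: int


@dataclass
class SpeculativeConfig:
    max_num_new_slots_for_drafting: int


@dataclass
class VllmConfig:
    scheduler_config: SchedulerConfig
    speculative_config: Optional[SpeculativeConfig] = None

    @property
    def max_num_tokens_per_forward_pass(self) -> int:
        # The model executor sizes buffers and CUDA graph ranges from this
        # derived value; the scheduler keeps using max_num_batched_tokens.
        base = self.scheduler_config.max_num_batched_tokens
        if self.speculative_config is None:
            return base
        return base + (self.speculative_config.max_num_new_slots_for_drafting
                       * self.scheduler_config.max_num_seqs)


cfg = VllmConfig(SchedulerConfig(8192, 128), SpeculativeConfig(5))
print(cfg.max_num_tokens_per_forward_pass)  # 8832
```

The upside is that no field is mutated in place, so the Mamba alignment check and any V0 code paths keep seeing the unmodified budget; the cost is the audit mentioned above.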

@LucasWilkinson (Collaborator) left a comment

I think this is fine for now to fix main. I really dislike the scheduler changes though (they feel a bit spaghetti); we should try to get rid of them in a follow-up soon.

@benchislett added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Feb 19, 2026
@DarkLight1337 (Member)

You can merge from main now, CI should be fixed

@tdoublep (Member) commented Feb 20, 2026

I'm trying to test the changes from this branch but when I do:

vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 --tensor-parallel-size 4

It seems to get stuck forever running a bunch of compilation behind the scenes, whereas on main this does not happen.

Update: it seems to work after merging in the newest commits from main.

@DarkLight1337 (Member)

Please fix https://buildkite.com/vllm/ci/builds/52393/steps/canvas?jid=019c7962-2b3e-4e59-9105-6a9d632edf7b&tab=output

@benchislett (Collaborator, Author)
Modified the implementation to fix the bug and be more in line with Lucas' suggestions. PTAL

@benchislett merged commit 682566b into vllm-project:main on Feb 22, 2026. 49 checks passed.

yugong333 pushed a commit to yugong333/vllm that referenced this pull request Feb 22, 2026
jmamou pushed a commit to jmamou/vllm that referenced this pull request Feb 23, 2026
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
askliar pushed a commit to askliar/vllm that referenced this pull request Mar 9, 2026
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026

Labels

bug (Something isn't working) · ready (ONLY add when PR is ready to merge/full CI is needed) · speculative-decoding · v1
