
[V1][Spec Decode] Add Dynamic SD#32374

Open
ekagra-ranjan wants to merge 45 commits into vllm-project:main from ekagra-ranjan:er-dynami-sd

Conversation

ekagra-ranjan (Contributor) commented Jan 15, 2026

Why is Dynamic SD needed?

SD methods need to verify K draft tokens for each sequence during decoding. As BS increases, the effective batch size becomes BS * K, which increases the compute required during verification. Once BS * K grows beyond a critical batch size, SD starts to hurt TPOT. DSD helps by tuning K down to an optimal value so that we continue to reap the benefits of SD.

Use cases

  • Serving a higher workload on the same deployment: K goes down as the workload increases.
  • RL rollouts, where we start with a high BS but end up with a small BS because a few long-tail requests generate many tokens and stall the progress of the current rollout: K goes up toward the end of the rollout.

What this PR does

Addresses #4565
V0 had milestone 0. V1 didn't have any form of Dynamic SD.

This PR implements something between Milestones 2 and 3 of Dynamic SD (DSD): we dynamically determine the proposed length for speculative decoding using runtime information, such as batch size and position-level acceptance rate, in conjunction with profiled parameters like the token acceptance rate (for cold start) and the relative costs of running the draft versus the target model. This approach allows us to adjust the proposed length in real time, optimizing performance based on current system conditions.

Before inference, the approach profiles on a representative dataset (similar to how the optimal K is selected for non-dynamic SD by iterating on a representative dataset):

  1. the position-level acceptance rate, to solve the cold-start problem
  2. the cost of running the draft and target models

During inference runtime, the optimal K is found using:

  1. the current batch size
  2. the average position-level acceptance rate the system has observed so far. The system waits for warmup_steps before it starts using the measured AR; until then, it uses the AR from offline profiling on the representative dataset.

This balances the cold-start problem against adapting to the running requests. There are many ways to extend this strategy, like resetting the AR after some number of steps, but those are left for future work. The purpose of this PR is to get at least a working baseline into vLLM.
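As a minimal sketch of this warmup switch (the function and argument names here are hypothetical, not the PR's actual API):

```python
def select_acceptance_rates(
    num_steps: int,
    warmup_steps: int,
    offline_ar_per_pos: list[float],
    accepted_per_pos: list[int],
    proposed_per_pos: list[int],
) -> list[float]:
    """Cold start: use the offline-profiled per-position acceptance
    rates until warmup_steps engine steps have elapsed; afterwards,
    use the rates measured from the running requests."""
    if num_steps < warmup_steps:
        return offline_ar_per_pos
    # Measured AR: accepted count per draft position / times proposed.
    return [
        accepted / max(proposed, 1)
        for accepted, proposed in zip(accepted_per_pos, proposed_per_pos)
    ]
```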

The PR computes goodput similarly to TurboSpec, with some changes to the formula to make it simpler and easier to extend to future models. For a given BS and K: goodput = AL / ITL, where AL (accepted length) is a function of K and ITL is a function of both K and BS.

TurboSpec, on the other hand, profiles the draft and target separately and builds a regression model, as a function of model config, KV cache size, and batch size, to find goodput. This PR follows a simplified approach: the ITL (inter-token latency) of the SD model, i.e., target + draft, is measured directly across batch sizes, which implicitly encapsulates the model config. This makes the setup easier to adapt when the model architecture changes (e.g., SWA) or some new feature complicates the equation in the future. The setup profiles a given set of batch sizes (BS) and numbers of draft tokens (K) and linearly interpolates between neighboring profiled values for each BS and K, between the min and max profiled BS and K. While simple, it works effectively, as shown in the results.
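To make this concrete, here is a minimal sketch of the ITL interpolation and the goodput-maximizing K search. It assumes the `batch_stats` layout of the generated config shown later in this description (per-BS rows mapping K to ITL in ms) and reads AL(K) as 1 plus the sum of the first K per-position acceptance rates; the names and structure are illustrative, not the PR's actual code:

```python
from bisect import bisect_left

def _interp(x: float, xs: list[float], ys: list[float]) -> float:
    """Piecewise-linear interpolation of y(x), clamped at the ends."""
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    i = bisect_left(xs, x)
    x0, x1 = xs[i - 1], xs[i]
    y0, y1 = ys[i - 1], ys[i]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)

def itl_ms(batch_stats: dict, bs: int, k: int) -> float:
    """ITL at (bs, k): interpolate over K within each profiled BS row,
    then over BS between the neighboring rows."""
    bss = sorted(int(b) for b in batch_stats)
    itl_at_k = []
    for b in bss:
        row = batch_stats[str(b)]
        ks = sorted(int(x) for x in row)
        itl_at_k.append(
            _interp(k, [float(x) for x in ks], [row[str(x)] for x in ks])
        )
    return _interp(bs, [float(b) for b in bss], itl_at_k)

def optimal_k(batch_stats: dict, acceptance_rate_per_pos: list[float], bs: int) -> int:
    """Pick the K that maximizes goodput = AL(K) / ITL(BS, K)."""
    best_k, best_goodput = 0, 1.0 / itl_ms(batch_stats, bs, 0)  # K=0: no SD
    al = 1.0  # the target model always produces one token per step
    for k, ar in enumerate(acceptance_rate_per_pos, start=1):
        al += ar  # AL(K) = 1 + sum of first K per-position acceptance rates
        goodput = al / itl_ms(batch_stats, bs, k)
        if goodput > best_goodput:
            best_k, best_goodput = k, goodput
    return best_k
```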

Results

Offline profiled on MTBench and Tested on MTBench

1xH100, Llama 3.1 8B, MTBench. The table measures TPOT (ms); lower is better.

| MTBench | Vanilla | EAGLE | Dynamic EAGLE |
|---|---|---|---|
| BS 1 | 6.3 | 3.98 | 3.98 |
| BS 4 | 6.38 | 4.03 | 4.05 |
| BS 16 | 6.77 | 4.45 | 4.45 |
| BS 64 | 7.94 | 6.78 | 6.56 |
| BS 128 | 10.15 | 11.19 | 9.88 |
| BS 256 | 16.2 | 19.96 | 17.2 |

As we can see,

  • At lower BS, DSD matches SD, and both are better than vanilla.
  • At higher BS, SD is worse than vanilla, while DSD is better than SD and closer to vanilla. However, DSD still pays the overhead of the draft model's prefill even when it assigns K=0 and the draft is unused during decode. This is fine because the BS can change later, so keeping all tokens prefilled in the draft model is needed.

Offline profiled on MTBench and Tested on InstructCoder

Profiled on MTBench. TPOT (ms); lower is better.

| InstructCoder | Vanilla | EAGLE | Dynamic EAGLE | Dynamic EAGLE with runtime AL |
|---|---|---|---|---|
| BS 128 | 12.69 | 11.55 | 11.85 | 11.43 |
| BS 256 | 21.19 | 21.5 | 21.07 | 21.07 |

Here, "Dynamic EAGLE" is not using runtime AL at all. As we can see adding runtime AL to goodput calculation after sometime give some minor improvement here so for this dataset MTBench numbers are well transferrable to InstrucrCoder but the runtime AL connection would help in adapting more to current workload.

Cmds

Generate DSD Config

```bash
time python3 vllm/v1/spec_decode/dynamic/generate_config.py \
    --method eagle \
    --model-dir 'meta-llama/Llama-3.1-8B-Instruct' \
    --draft-dir 'yuhuili/EAGLE-LLaMA3.1-Instruct-8B' \
    --tp 1 \
    --temp 0 \
    --top-p 1.0 \
    --top-k -1 \
    --max-vllm-batch-size 256 \
    --batch-size-list 1 4 16 64 256 \
    --num-speculative-tokens-list 1 3 5 \
    --num-batches 20 \
    --dataset-name hf \
    --dataset-path 'philschmid/mt-bench' \
    --no-oversample \
    --result-dir './log/dynamic_sd_test'
```
Example of `dynamic_speculative_config.json` generated
```json
{
    "is_online": false,
    "batch_stats": {
        "1": {
            "0": 6.520589930005372,
            "1": 7.367628160864115,
            "3": 8.84066498838365,
            "5": 10.32649097032845
        },
        "4": {
            "0": 6.601515458896756,
            "1": 7.472813129425049,
            "3": 8.981170016340911,
            "5": 10.400271974503994
        },
        "16": {
            "0": 6.898819003254175,
            "1": 7.852344075217843,
            "3": 9.518282022327185,
            "5": 11.196403065696359
        },
        "64": {
            "0": 7.774091092869639,
            "1": 9.656429989263415,
            "3": 13.497876934707165,
            "5": 16.831180080771446
        },
        "256": {
            "0": 14.491415582597256,
            "1": 27.138127014040947,
            "3": 41.848431108519435,
            "5": 57.40421102382243
        }
    },
    "max_num_speculative_tokens": 5,
    "acceptance_rate_per_pos": [
        0.6811801775995416,
        0.3914351188771126,
        0.20352334574620454,
        0.1014036092810083,
        0.051417931824692065
    ]
}
```
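Continuing the illustrative sketch from earlier, the generated config could be consumed roughly like this (again, not the PR's actual API):

```python
import json

# Load the profiled config generated by generate_config.py.
with open("dynamic_speculative_config.json") as f:
    cfg = json.load(f)

# Pick the goodput-optimal K for the current decode batch size.
k = optimal_k(cfg["batch_stats"], cfg["acceptance_rate_per_pos"], bs=64)
```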

Benchmark

We chose 20*MAX_CONCURRENCY as the number of prompts so that each setting runs at least 20 batches. Without this, since MTBench only has 80 samples, MAX_CONCURRENCY=1 would run 80 batches while MAX_CONCURRENCY=128 would run only a single batch.

```bash
# vanilla
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --port 9001 \
  --no-enable-prefix-caching \
  --max-num-seqs 256

# EAGLE
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --port 9001 \
  --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 3}' \
  --no-enable-prefix-caching \
  --max-num-seqs 256

# Dynamic EAGLE
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --port 9001 \
  --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 3, "dynamic_config_path": "log/dynamic_sd_test_2/tp-1_temp-0.0_top_p-1.0_top_k--1/philschmid/mt-bench/dynamic_speculative_config.json"}' \
  --no-enable-prefix-caching \
  --max-num-seqs 256

# change MAX_CONCURRENCY here.
MAX_CONCURRENCY=1
NUM_PROMPTS=$((MAX_CONCURRENCY * 20))
time vllm bench serve --port 9001 --save-result --save-detailed \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts ${NUM_PROMPTS} \
    --max-concurrency ${MAX_CONCURRENCY} \
    --result-dir "./log/EAGLE-1"
```

File changes:

  • vllm/v1/spec_decode/dynamic/generate_config.py is the master file that schedules the different scripts and produces the config used by DSD at runtime. It has several stages:
    • Step 1: Uses an offline script to get the AL across different positions.
      • vllm/v1/spec_decode/offline.py is used for this. It is offline_inference/spec_decode.py moved into vllm/ so that it can be imported here. This offline script is also used in CI tests, so it is an important file.
    • Step 2: Runs profiling to get the ITL across different BS and K using vllm bench sweep.
    • Step 3: Parses the values generated for each BS and K and collates their ITLs into config values.
    • Step 4: Saves the Dynamic SD config as a config file.
  • Adds the config class DynamicSpeculativeConfig in vllm/config/speculative.py, which holds the config values from DSD profiling, along with the path to them.
  • vllm/v1/spec_decode/dynamic/manager.py is the Dynamic SD Manager, which reads the ITLs from the DynamicSpeculativeConfig generated above, derives the optimal K for each BS by interpolating across the profiled K and BS values, and provides it to the SD method during proposal.
  • vllm/v1/worker/gpu_model_runner.py initializes the DSD Manager and provides the optimal K for the given BS during inference to the respective SD method.
  • Introduces spec_decoding_stats_all in the scheduler, which collects the stats used in dynamic/manager.py to compute the AR and switch to the updated values after warmup_steps.

After Async scheduling and padded drafter compatibility

The trend is similar to the synchronous scheduling results.

File changes for async and padded drafter

Old approach

`vllm/v1/core/sched/async_scheduler.py`

**Problem**: With async scheduling, when dynamic SD changes the optimal K (e.g., from 5 to 3), there is a pipeline latency issue: the scheduler has already committed accounting (num_computed_tokens, num_output_placeholders) for the in-flight batch using the old K.

**Solution**:

  • `_pending_optimal_k`: int | None — stores the optimal K from the model output, deferred until the next schedule() call.
  • `_in_flight_decode_req_k`: dict[str, int] — maps req_id -> committed spec token count for decode requests in the most recently dispatched batch. Used to know exactly which requests need accounting corrections and by how much.

New method _apply_pending_dynamic_sd_update(): Called at the start of schedule(). Applies the deferred K update:

  • Updates _spec_token_placeholders to the new K length (controls how many spec positions the scheduler reserves for future batches → reduces KV block waste).
  • Corrects the in-flight batch's over-committed accounting: for each request in _in_flight_decode_req_k, computes diff = committed_k - optimal_k. If diff > 0 (K decreased), subtracts diff from request.num_output_placeholders and request.num_computed_tokens. If diff <= 0 (K increased), just updates request.spec_token_ids for the next scheduling step (can't retroactively add tokens to an in-flight batch).

Override schedule(): Calls _apply_pending_dynamic_sd_update() then delegates to super().schedule().
Modified _update_after_schedule(): Resets and populates _in_flight_decode_req_k with req_id -> cur_num_spec_tokens for each non-prefill decode request that was just committed with spec tokens > 0.
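Condensed into code, the deferred correction described above could look roughly like this; the Request fields follow the description, but the scaffolding is invented for illustration:

```python
from dataclasses import dataclass, field

@dataclass
class Request:
    num_computed_tokens: int
    num_output_placeholders: int
    spec_token_ids: list[int] = field(default_factory=list)

def apply_pending_dynamic_sd_update(
    requests: dict[str, Request],
    in_flight_decode_req_k: dict[str, int],
    optimal_k: int,
    placeholder_token_id: int = -1,
) -> list[int]:
    """Apply a deferred K change at the start of schedule().

    Returns the new spec-token placeholders and corrects accounting
    for in-flight decode requests that committed more spec tokens
    than the new optimal K allows (sketch of the old approach)."""
    placeholders = [placeholder_token_id] * optimal_k
    for req_id, committed_k in in_flight_decode_req_k.items():
        req = requests[req_id]
        diff = committed_k - optimal_k
        if diff > 0:
            # K decreased: roll back the over-committed accounting.
            req.num_output_placeholders -= diff
            req.num_computed_tokens -= diff
        else:
            # K increased: can't add tokens to an in-flight batch, so
            # only extend placeholders for the next scheduling step.
            req.spec_token_ids = list(placeholders)
    return placeholders
```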

vllm/v1/worker/gpu_model_runner.py

Problem: the model runner still processes (and rejects) zero-padded speculative tokens beyond the optimal K, wasting compute. The SchedulerOutput seen by the model runner still contains the old (larger) K from when the batch was scheduled.

Solution:
New method _trim_spec_tokens_for_dynamic_sd(scheduler_output): Trims scheduled_spec_decode_tokens in-place to match self._optimal_num_speculative_tokens: for each request where scheduled_k > optimal_k, the spec tokens beyond optimal_k are dropped.

Modified _update_states(): Inserted a call to _trim_spec_tokens_for_dynamic_sd(scheduler_output) before the ngram_gpu handling block, conditioned on _optimal_num_speculative_tokens being not None, use_async_scheduling, and scheduled_spec_tokens being non-empty. This ordering ensures original_num_spec_per_req (saved for ngram_gpu's prev_num_draft_len restoration) is based on the dynamically trimmed K rather than the over-allocated K.

Modified take_draft_token_ids(): When dynamic SD reduced K below num_spec_tokens, truncates each request's draft token list to k entries (the GPU tensor is zero-padded to num_spec_tokens for scatter indexing, but the scheduler should only see real draft tokens).
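The trim itself is straightforward; a sketch under the same caveats (names mirror the description above, the scaffolding is illustrative):

```python
def trim_spec_tokens_for_dynamic_sd(
    scheduled_spec_decode_tokens: dict[str, list[int]],
    optimal_k: int,
) -> None:
    """Drop speculative tokens beyond the optimal K in-place so the
    model runner never verifies draft positions that dynamic SD has
    turned off for this step."""
    for spec_tokens in scheduled_spec_decode_tokens.values():
        if len(spec_tokens) > optimal_k:
            del spec_tokens[optimal_k:]
```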


New Approach

  • padded drafter

    • no padding is done.
    • The model runner at step N saves K in prev_num_spec_tokens during _copy_draft_token_ids_to_cpu() so that the model runner at step N+1 can correctly index draft_token_ids, using prev_num_spec_tokens (which changes) as the stride instead of num_spec_tokens (which is fixed).
  • async scheduling

    • scheduler.py at step N sets num_spec_tokens_to_schedule to send to the model runner at step N.
    • async_scheduler.py updates the spec token placeholders in _update_after_schedule() at step N so that the scheduler at step N+1 can account for the new K spec tokens to send for verification to the engine. The _spec_token_placeholders get saved into request.spec_token_ids in _update_after_schedule() of the async scheduler at step N, which is then used to create scheduled_spec_decode_tokens, which is consumed as draft_len in _prepare_input_ids().

So prev_num_spec_tokens decides how many draft token ids were drafted at step N, and draft_len decides how many of them will be verified at step N+1. draft_len <= prev_num_spec_tokens, since draft_len comes from the token budget available in this forward pass.
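As a rough illustration of this variable-stride indexing (an assumed layout for illustration, not the PR's actual tensor code):

```python
import torch

def gather_draft_token_ids(
    flat_draft_ids: torch.Tensor,   # step-N drafts, [num_reqs * prev_num_spec_tokens]
    prev_num_spec_tokens: int,      # K actually drafted at step N (varies per step)
    draft_lens: list[int],          # tokens to verify per request at step N+1
) -> list[list[int]]:
    """Index step-N draft tokens at step N+1 using the stride that was
    actually drafted (prev_num_spec_tokens) rather than the fixed max K."""
    out: list[list[int]] = []
    for i, draft_len in enumerate(draft_lens):
        # draft_len <= prev_num_spec_tokens: the token budget may allow
        # verifying fewer tokens than were drafted.
        start = i * prev_num_spec_tokens
        out.append(flat_draft_ids[start : start + draft_len].tolist())
    return out
```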

PENDING (some of them can be done in future PRs):

  • use online AL to refine the goodput after warmup
  • While this PR only tests EAGLE-1, the approach can be extended to other methods like EAGLE-3.
  • Probably vllm sweep can be used instead of the newly added profiling_client.py and profiling_server.py
  • padded drafter
  • async scheduling
  • add some tests


mergify Bot commented Jan 15, 2026

Documentation preview: https://vllm--32374.org.readthedocs.build/en/32374/

mergify Bot added the documentation and speculative-decoding labels (Jan 15, 2026)

mergify Bot commented Jan 15, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ekagra-ranjan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

gemini-code-assist Bot left a comment:

Code Review

This pull request introduces Dynamic Speculative Decoding (DSD), a significant performance enhancement for vLLM. The implementation involves profiling to gather runtime statistics and then using those to dynamically adjust the number of speculative tokens. The changes are extensive, adding new scripts for configuration generation, profiling, and a manager for DSD logic. While the overall approach is sound, I've identified several critical issues, including potential server crashes due to division by zero, command injection vulnerabilities in the profiling scripts, and other high-severity bugs that could lead to incorrect behavior or system instability. These issues should be addressed to ensure the feature is robust and secure.

cursor Bot left a comment.

mergify Bot commented Mar 17, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ekagra-ranjan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify Bot added the needs-rebase label (Mar 17, 2026)
mergify Bot removed the needs-rebase label (Mar 17, 2026)

mergify Bot commented Mar 17, 2026

Hi @ekagra-ranjan, the pre-commit checks have failed. Please run:

```bash
uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip: Is mypy failing? mypy is run differently in CI. If the failure is related to this check, use the following command to run it locally:

```bash
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
```

ekagra-ranjan (Author) commented:

Thank you @LucasWilkinson @benchislett @hmellor for having a look!

I have addressed the comments so far and updated the PR with async scheduling and padded drafter compatibility. The trend is similar to what was previously reported with the sync scheduler. I have added a summary of the changes for the async scheduler in the PR description.

benchislett (Collaborator) commented:

@ekagra-ranjan do we have a solution for CUDA graphs? Are they being employed for all values of K?

If so, how do you explain the slight discrepancy at low concurrencies when using dynamic SD?

LucasWilkinson (Collaborator) commented Mar 19, 2026

I did an early pass and just didn't have time to respond yet, but one thing I am wondering: it might be a lot cleaner to have the scheduler determine the number of draft tokens to generate (i.e., own the DynamicSpeculativeDecodingManager) for the next step and send that via an optional attribute in SchedulerOutput. This would mean we wouldn't have to send spec_decoding_stats_all to the model runner, and we would know the correct number of placeholders to append in the scheduler (avoiding the need to correct). Not to mention, it would make this feature more easily portable to other HW/plugins.

LucasWilkinson added a commit to neuralmagic/vllm that referenced this pull request Mar 24, 2026
Implement dynamic speculative decoding where the scheduler computes the
optimal number of draft tokens (K) based on batch size and acceptance
rates. This inverts the original PR's paradigm where model runner decided K.

Key changes:
- DynamicSpeculativeConfig: Config holding profiled ITL stats per (BS, K)
- DynamicSpeculativeDecodingManager: Computes optimal K using goodput = AL/ITL
- SchedulerOutput.num_spec_tokens_to_schedule: Scheduler tells model runner
  how many tokens to speculate
- EagleProposer.propose() accepts num_speculative_tokens parameter
- Stats tracking in manager for online acceptance rate updates

This approach is cleaner because:
- Scheduler already knows batch size and tracks acceptance stats
- No round-trip needed (scheduler doesn't wait for ModelRunnerOutput)
- Placeholder accounting is simpler - scheduler knows K at schedule time

Based on PR vllm-project#32374 by ekagra-ranjan.

Co-authored-by: Ekagra Ranjan <ekagra.ranjan@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
LucasWilkinson (Collaborator) commented Mar 24, 2026

> I did an early pass and just didn't have time to respond yet, but one thing I am wondering: it might be a lot cleaner to have the scheduler determine the number of draft tokens to generate (i.e., own the DynamicSpeculativeDecodingManager) for the next step and send that via an optional attribute in SchedulerOutput. This would mean we wouldn't have to send spec_decoding_stats_all to the model runner, and we would know the correct number of placeholders to append in the scheduler (avoiding the need to correct). Not to mention, it would make this feature more easily portable to other HW/plugins.

Vibe coded this here: #32374 (comment) to demonstrate the proposal

benchislett (Collaborator) commented:

@LucasWilkinson +1, this seems like the right high-level design to me. I don't see any downsides. @ekagra-ranjan, what do you think? Do you foresee any challenges?


mergify Bot commented Mar 31, 2026

Hi @ekagra-ranjan, the pre-commit checks have failed. Please run:

```bash
uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip: Is mypy failing? mypy is run differently in CI. If the failure is related to this check, use the following command to run it locally:

```bash
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
```


mergify Bot commented Apr 1, 2026

Hi @ekagra-ranjan, the pre-commit checks have failed. Please run:

```bash
uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip: Is mypy failing? mypy is run differently in CI. If the failure is related to this check, use the following command to run it locally:

```bash
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
```

ekagra-ranjan (Author) commented Apr 1, 2026

The DSD manager is now owned by the scheduler, and the PR has been updated to follow the high-level design suggested by @LucasWilkinson. We also no longer pad to max K and instead operate with the optimal K when dealing with draft_ids in the model runner. The gap between Dynamic EAGLE and vanilla at BS 256 has narrowed compared to the previous results.


Summary of the changes:

padded drafter

  • not padded to max K.
  • The model runner at step N saves K in prev_num_spec_tokens (newly added in this PR) during _copy_draft_token_ids_to_cpu() so that the model runner at step N+1 can correctly index draft_token_ids, using prev_num_spec_tokens (which changes) as the stride instead of num_spec_tokens (which is fixed).

async scheduling

  • scheduler.py at step N sets num_spec_tokens_to_schedule to send to the model runner at step N.
  • async_scheduler.py updates the spec token placeholders in _update_after_schedule() at step N so that the scheduler at step N+1 can account for the new K spec tokens to send for verification to the engine.
  • The _spec_token_placeholders get saved into request.spec_token_ids in _update_after_schedule() of the async scheduler at step N, which is then used to create scheduled_spec_decode_tokens, which is consumed as draft_len in _prepare_input_ids().

So prev_num_spec_tokens decides how many draft token ids were drafted at step N, and draft_len decides how many of them will be verified at step N+1. draft_len <= prev_num_spec_tokens, since draft_len comes from the token budget available in this forward pass.

Re this:

> do we have a solution for CUDA graphs? Are they being employed for all values of K?
> If so, how do you explain the slight discrepancy at low concurrencies when using dynamic SD?

@benchislett FCG is very likely not working with changing values of K. I don't have a solution for this yet, but I'll read the codebase around FCG in vLLM in more detail next. @LucasWilkinson said that his previous work on FCG can help here.


mergify Bot commented Apr 3, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ekagra-ranjan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

