
[Model Runner V2] Rebuild attn metadata between draft decode steps #41162

Merged

WoosukKwon merged 2 commits into vllm-project:main from TheEpicDolphin:mrv2-draft-decode-rebuild-attn-metadata on May 5, 2026

Conversation

TheEpicDolphin (Collaborator) commented Apr 28, 2026

Context

Investigating a DSV4 "invalid memory access" (IMA) crash that happens when MTP > 2. I believe the crash is caused by not rebuilding the attention metadata between draft decode steps. DSV4's attention metadata builders have position-dependent state that must be updated whenever the position advances. For example:

# Position-dependent builder state: SWA indices/lens are recomputed from the
# current seq_lens, so they go stale when positions advance without a rebuild.
self.decode_swa_lens[num_decode_tokens:] = 0
_compute_swa_indices_and_lens_kernel[(num_decode_tokens,)](
    self.decode_swa_indices,
    self.decode_swa_indices.stride(0),
    self.decode_swa_lens,
    self.window_size,
    query_start_loc,
    seq_lens,
    token_to_req_indices,
    is_valid_token,
    block_table,
    block_table.stride(0),
    self.block_size,
    TRITON_BLOCK_SIZE=1024,
)

During draft decoding, we update only the sequence lengths, positions, and slot mappings, but not the metadata properties derived from them:

# Update the inputs for the next step.
update_eagle_inputs(
    draft_tokens,
    hidden_states,
    self.input_buffers,
    self.hidden_states,
    self.max_model_len,
)
if attn_metadata is not None:
    self.block_tables.compute_slot_mappings(
        idx_mapping, query_start_loc, pos, num_tokens_padded
    )

Any attention backend with builders that have position-dependent state can potentially be affected by this bug.
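A toy illustration of the failure mode (the names here are illustrative, not vLLM's actual API): builder state derived from seq_lens goes stale once the position advances without a rebuild.

import torch

window_size, block_size = 4, 2

def build_swa_block_indices(seq_lens: torch.Tensor) -> torch.Tensor:
    # Index of the first block inside the sliding window for each sequence.
    window_start = (seq_lens - window_size).clamp(min=0)
    return window_start // block_size

seq_lens = torch.tensor([9])
swa_indices = build_swa_block_indices(seq_lens)  # built once at seq_len 9

seq_lens += 1  # draft decode advances the position, but nothing is rebuilt
stale = not torch.equal(swa_indices, build_swa_block_indices(seq_lens))
assert stale  # the cached indices now point at the wrong blocks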

This PR

Fixes the above issue by rebuilding attention metadata between draft generation steps in the multi-step decode loop. Because attention metadata building is not guaranteed to be a cudagraph-compatible operation, we capture the single-step draft generation and replay it in each iteration of the loop. The captured draft generation routine consists of the following operations, in order (a rough sketch follows the list):

  1. run_model - Draft model forward pass to produce hidden states.
  2. compute_logits - Computes logits from the hidden states.
  3. _sample_draft - Samples a draft token using either gumbel or greedy sampling, depending on what was configured.
  4. update_eagle_draft_inputs - Writes the output draft tokens (self.draft_tokens) for the current step, and updates the inputs for the next draft step, such as input ids, hidden states, positions, etc.
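A minimal capture/replay sketch of this structure, with stand-in names for the model and buffers (the scalar step tensor used here is explained below; the real implementation lives in the Model Runner V2 Eagle speculator):

import torch

dev = "cuda"
num_reqs, hidden_dim, num_spec_tokens = 8, 16, 3

inputs = torch.zeros(num_reqs, hidden_dim, device=dev)    # static input buffer
weight = torch.randn(hidden_dim, hidden_dim, device=dev)  # stand-in draft model
draft_tokens = torch.zeros(num_reqs, num_spec_tokens, dtype=torch.long, device=dev)
current_draft_step = torch.zeros(1, dtype=torch.long, device=dev)

def one_draft_step() -> None:
    hidden = inputs @ weight               # 1. run_model
    logits = hidden                        # 2. compute_logits
    tokens = logits.argmax(dim=-1)         # 3. _sample_draft (greedy here)
    # 4. update_eagle_draft_inputs: write this step's column of draft_tokens,
    #    indexed by a scalar tensor (a Python int would be frozen into the graph).
    draft_tokens.index_copy_(1, current_draft_step, tokens.unsqueeze(1))

# Warm up once outside capture, then capture a single step.
one_draft_step()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    one_draft_step()

for step in range(num_spec_tokens):
    # Rebuild attention metadata here, OUTSIDE the graph, since building it
    # is not guaranteed to be cudagraph-compatible.
    current_draft_step.fill_(step)
    graph.replay()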

Because cudagraph prevents passing an integer step idx to determine which draft tokens/logits column to write the outputs to, I had to introduce a scalar tensor called self.current_draft_step. That single-element tensor is used by update_eagle_draft_inputs to write the output draft token ids to the output buffer (self.draft_tokens).

Updating the correct column (step) of self.draft_logits was a bit trickier. I couldn't just pass it into the sample method as I did before:

return gumbel_sample(
    logits,
    idx_mapping,
    self.temperature,
    self.seeds,
    pos + 1,
    apply_temperature=True,
    processed_logits_out=self.draft_logits[:, step],
)

because step is now a scalar tensor instead of an int. Doing so would trigger advanced indexing, which returns a new copy of the tensor rather than a view of it. So I added support in gumbel_sample for accepting a column tensor (output_processed_logits_col) indicating which column of the logits tensor to write to.
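A self-contained illustration of the copy-vs-view pitfall and the indexed-write workaround (plain PyTorch; gumbel_sample's actual signature is in the diff, not shown here):

import torch

logits = torch.zeros(4, 3, 8)   # (num_reqs, steps, vocab)
step = torch.tensor(1)          # scalar step tensor, as in this PR

out = logits[:, step]           # tensor index -> advanced indexing -> a COPY
out.fill_(7.0)
assert logits.sum() == 0        # the write never reached logits

logits[:, 1].fill_(7.0)         # Python-int index -> a view; this write sticks
assert logits.sum() == 4 * 8 * 7.0

# So an in-place write through a tensor-valued column index must use an
# explicit indexed op, e.g.:
logits.index_copy_(1, step.view(1), torch.full((4, 1, 8), 9.0))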

Test Plan

Before this PR, running the following server and bench commands would crash with an IMA error:

**Server command**
VLLM_USE_FLASHINFER_MOE_FP8=0 \
FLASHINFER_DISABLE_VERSION_CHECK=1 \
VLLM_USE_V2_MODEL_RUNNER=1 \
HF_HOME=/data/hf-models \
MODEL_CKPT="deepseek-ai/DeepSeek-V4-Flash" \
vllm serve $MODEL_CKPT \
    -dp 4 -ep \
    --port 8003 \
    --block-size 256 \
    --kv-cache-dtype fp8 \
    --tokenizer-mode deepseek_v4 \
    --tool-call-parser deepseek_v4 \
    --reasoning-parser deepseek_v4 \
    --moe-backend deep_gemm_mega_moe \
    --no-enable-prefix-caching \
    --kernel-config.enable_flashinfer_autotune=False \
    --speculative-config '{"method": "mtp", "num_speculative_tokens": 3}'

**Bench command**
MODEL="deepseek-ai/DeepSeek-V4-Flash" \
PORT=8003 \
vllm-bench --backend openai-chat \
     --base-url http://127.0.0.1:$PORT \
     --model ${MODEL} \
     --dataset-name speed-bench \
     --speed-bench-config throughput_16k \
     --speed-bench-max-input-len 8192 \
     --speed-bench-category low_entropy \
     --output-len 1024 \
     --save-result \
     --num-warmups 1 \
     --seed 420 \
     --sweep-max-concurrency 1,2,1024 \
     --sweep-num-prompts-factor 5 \
     --temperature 0 \
     --top-p 1

With the changes in this PR, the same workload serves to completion.

Benchmarks

I measured this PR's effect on performance to make sure it doesn't introduce any significant regressions. All runs used 3 speculative tokens. The following benchmark command was used:

vllm-bench \
      --backend openai-chat \
      --base-url http://127.0.0.1:8000 \
      --model {model} \
      --dataset-name hf \
      --dataset-path philschmid/mt-bench \
      --ignore-eos \
      --request-rate inf \
      --temperature {temperature} \
      --num-warmups 50 \
      --output-len 1536 \
      --sweep-max-concurrency 1,16,64 \
      --sweep-num-prompts-factor 10

NOTE: GLM 4.7 Flash with 3 speculative tokens failed on main, likely due to the issue this PR fixes; hence the N/A "Before" entries in its temperature = 1.0 table below.

GLM 4.7 Flash

Base: zai-org/GLM-4.7-Flash
Draft: MTP

Temperature = 0.0

| Metric | Before (c=1) | After (c=1) | Delta | Before (c=16) | After (c=16) | Delta | Before (c=64) | After (c=64) | Delta |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Output Throughput (tok/s) | 205.38 | 202.61 | -1.3% | 2010.14 | 1887.19 | -6.1% | 6159.68 | 6504.64 | +5.6% |
| Request Throughput (req/s) | 1.6045 | 1.5829 | -1.3% | 15.7042 | 14.7437 | -6.1% | 48.1225 | 50.8175 | +5.6% |
| Total Token Throughput (tok/s) | 341.12 | 336.53 | -1.3% | 3075.18 | 2887.09 | -6.1% | 9364.64 | 9889.09 | +5.6% |
| Mean TTFT (ms) | 35.18 | 34.57 | -1.7% | 67.24 | 98.08 | +45.9% | 113.69 | 87.66 | -22.9% |
| Mean TPOT (ms) | 4.6302 | 4.7018 | +1.5% | 7.3146 | 7.6276 | +4.3% | 9.1607 | 9.0010 | -1.7% |
| Mean ITL (ms) | 10.0515 | 10.0525 | +0.0% | 15.5179 | 16.2109 | +4.5% | 19.4447 | 19.0327 | -2.1% |
| Mean E2E Latency (ms) | 623.21 | 631.70 | +1.4% | 996.19 | 1066.79 | +7.1% | 1277.09 | 1230.78 | -3.6% |
| Median TTFT (ms) | 33.51 | 33.81 | +0.9% | 50.97 | 50.86 | -0.2% | 65.88 | 64.89 | -1.5% |
| Median TPOT (ms) | 4.6429 | 4.7009 | +1.3% | 7.3075 | 7.3410 | +0.5% | 9.1235 | 8.9978 | -1.4% |
| Median ITL (ms) | 10.2447 | 10.2300 | -0.1% | 15.5085 | 15.2813 | -1.5% | 17.6433 | 17.6407 | -0.0% |
| P99 TTFT (ms) | 43.65 | 38.42 | -12.0% | 234.32 | 1004.84 | +328.8% | 432.87 | 260.04 | -39.9% |
| P99 TPOT (ms) | 4.8689 | 4.8699 | +0.0% | 8.1147 | 13.9298 | +71.7% | 11.2815 | 10.6485 | -5.6% |
| P99 ITL (ms) | 10.6638 | 10.7676 | +1.0% | 21.5950 | 21.2600 | -1.6% | 85.9391 | 77.6264 | -9.7% |
| Spec Decode Accept Rate (%) | 40.70 | 39.67 | -2.5% | 38.91 | 38.84 | -0.2% | 38.55 | 38.64 | +0.2% |
| Spec Decode Accept Length | 2.2209 | 2.1901 | -1.4% | 2.1674 | 2.1652 | -0.1% | 2.1565 | 2.1592 | +0.1% |

Temperature = 1.0

| Metric | Before (c=1) | After (c=1) | Before (c=16) | After (c=16) | Before (c=64) | After (c=64) |
| --- | --- | --- | --- | --- | --- | --- |
| Output Throughput (tok/s) | N/A | 192.23 | N/A | 1886.56 | N/A | 5665.77 |
| Request Throughput (req/s) | N/A | 1.5018 | N/A | 14.7387 | N/A | 44.2639 |
| Total Token Throughput (tok/s) | N/A | 319.28 | N/A | 2886.12 | N/A | 8613.75 |
| Mean TTFT (ms) | N/A | 35.69 | N/A | 59.20 | N/A | 174.90 |
| Mean TPOT (ms) | N/A | 4.9618 | N/A | 7.8593 | N/A | 9.5913 |
| Mean ITL (ms) | N/A | 10.0821 | N/A | 15.5942 | N/A | 19.0810 |
| Mean E2E Latency (ms) | N/A | 665.84 | N/A | 1057.33 | N/A | 1392.99 |
| Median TTFT (ms) | N/A | 34.15 | N/A | 50.26 | N/A | 62.04 |
| Median TPOT (ms) | N/A | 4.9235 | N/A | 7.8293 | N/A | 9.5426 |
| Median ITL (ms) | N/A | 10.2491 | N/A | 15.5669 | N/A | 17.7089 |
| P99 TTFT (ms) | N/A | 43.20 | N/A | 146.34 | N/A | 1109.73 |
| P99 TPOT (ms) | N/A | 5.3066 | N/A | 9.2968 | N/A | 12.0845 |
| P99 ITL (ms) | N/A | 10.7639 | N/A | 22.9061 | N/A | 69.2927 |
| Spec Decode Accept Rate (%) | N/A | 35.72 | N/A | 34.14 | N/A | 34.31 |
| Spec Decode Accept Length | N/A | 2.0715 | N/A | 2.0242 | N/A | 2.0292 |

GPT-OSS 20B

Base: openai/gpt-oss-20b
Draft: RedHatAI/gpt-oss-20b-speculator.eagle3

Temperature = 0.0

| Metric | Before (c=1) | After (c=1) | Delta | Before (c=16) | After (c=16) | Delta | Before (c=64) | After (c=64) | Delta |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Output Throughput (tok/s) | 431.39 | 432.03 | +0.1% | 3888.19 | 3781.10 | -2.8% | 9254.44 | 9171.40 | -0.9% |
| Request Throughput (req/s) | 3.6221 | 3.6274 | +0.1% | 31.0542 | 30.3398 | -2.3% | 74.4469 | 73.8605 | -0.8% |
| Total Token Throughput (tok/s) | 734.20 | 735.28 | +0.1% | 5970.57 | 5815.57 | -2.6% | 14146.07 | 14024.50 | -0.9% |
| Mean TTFT (ms) | 16.84 | 16.54 | -1.7% | 40.12 | 52.07 | +29.8% | 94.79 | 103.26 | +8.9% |
| Mean TPOT (ms) | 2.4485 | 2.4490 | +0.0% | 3.8069 | 3.8345 | +0.7% | 6.2519 | 6.2834 | +0.5% |
| Mean ITL (ms) | 5.3687 | 5.3656 | -0.1% | 9.4867 | 9.2596 | -2.4% | 15.1334 | 15.3950 | +1.7% |
| Mean E2E Latency (ms) | 276.05 | 275.64 | -0.1% | 496.46 | 505.92 | +1.9% | 830.77 | 838.04 | +0.9% |
| Median TTFT (ms) | 15.99 | 15.91 | -0.5% | 30.94 | 30.67 | -0.9% | 48.53 | 49.45 | +1.9% |
| Median TPOT (ms) | 2.0883 | 2.0857 | -0.1% | 3.6254 | 3.6568 | +0.9% | 5.8885 | 5.8287 | -1.0% |
| Median ITL (ms) | 5.3346 | 5.3215 | -0.2% | 8.6269 | 8.6313 | +0.1% | 13.8224 | 13.7318 | -0.7% |
| P99 TTFT (ms) | 23.76 | 23.53 | -1.0% | 134.73 | 267.67 | +98.7% | 531.30 | 423.95 | -20.2% |
| P99 TPOT (ms) | 5.5283 | 5.5468 | +0.3% | 9.8958 | 9.9159 | +0.2% | 16.2168 | 16.5772 | +2.2% |
| P99 ITL (ms) | 6.5873 | 5.8979 | -10.5% | 17.8788 | 15.1873 | -15.1% | 30.6647 | 36.1486 | +17.9% |
| Spec Decode Accept Rate (%) | 53.64 | 53.64 | +0.0% | 53.05 | 53.25 | +0.4% | 53.42 | 53.56 | +0.3% |
| Spec Decode Accept Length | 2.6091 | 2.6091 | +0.0% | 2.5914 | 2.5975 | +0.2% | 2.6025 | 2.6069 | +0.2% |

Temperature = 1.0

| Metric | Before (c=1) | After (c=1) | Delta | Before (c=16) | After (c=16) | Delta | Before (c=64) | After (c=64) | Delta |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Output Throughput (tok/s) | 359.76 | 357.88 | -0.5% | 3241.75 | 3217.54 | -0.7% | 7585.31 | 7497.40 | -1.2% |
| Request Throughput (req/s) | 3.0055 | 2.9898 | -0.5% | 25.8745 | 25.6877 | -0.7% | 60.1807 | 59.5238 | -1.1% |
| Total Token Throughput (tok/s) | 611.03 | 607.83 | -0.5% | 4976.80 | 4940.06 | -0.7% | 11539.55 | 11408.48 | -1.1% |
| Mean TTFT (ms) | 15.21 | 15.52 | +2.0% | 45.13 | 46.86 | +3.8% | 147.01 | 160.58 | +9.2% |
| Mean TPOT (ms) | 2.8696 | 2.8800 | +0.4% | 4.5033 | 4.5519 | +1.1% | 7.2455 | 7.1855 | -0.8% |
| Mean ITL (ms) | 5.3180 | 5.3450 | +0.5% | 9.1962 | 9.2448 | +0.5% | 15.1608 | 14.7416 | -2.8% |
| Mean E2E Latency (ms) | 332.69 | 334.42 | +0.5% | 591.27 | 598.70 | +1.3% | 1031.95 | 1041.75 | +0.9% |
| Median TTFT (ms) | 15.44 | 15.66 | +1.4% | 29.78 | 30.08 | +1.0% | 46.30 | 46.67 | +0.8% |
| Median TPOT (ms) | 2.6274 | 2.6387 | +0.4% | 4.3387 | 4.4336 | +2.2% | 6.9796 | 6.9897 | +0.1% |
| Median ITL (ms) | 5.3105 | 5.3321 | +0.4% | 8.6610 | 8.7203 | +0.7% | 13.7379 | 13.6492 | -0.6% |
| P99 TTFT (ms) | 18.04 | 18.68 | +3.5% | 208.67 | 224.98 | +7.8% | 1073.82 | 1246.30 | +16.1% |
| P99 TPOT (ms) | 5.4027 | 5.3922 | -0.2% | 9.7073 | 9.7852 | +0.8% | 15.4919 | 15.5205 | +0.2% |
| P99 ITL (ms) | 5.8432 | 5.8117 | -0.5% | 15.9700 | 15.7048 | -1.7% | 22.7490 | 23.4073 | +2.9% |
| Spec Decode Accept Rate (%) | 37.98 | 37.98 | +0.0% | 38.16 | 37.70 | -1.2% | 37.38 | 36.85 | -1.4% |
| Spec Decode Accept Length | 2.1393 | 2.1393 | +0.0% | 2.1449 | 2.1311 | -0.6% | 2.1214 | 2.1054 | -0.8% |

Llama3 8B

Base: meta-llama/Meta-Llama-3-8B-Instruct
Draft: yuhuili/EAGLE-LLaMA3-Instruct-8B

Temperature = 0.0

| Metric | Before (c=1) | After (c=1) | Delta | Before (c=16) | After (c=16) | Delta | Before (c=64) | After (c=64) | Delta |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Output Throughput (tok/s) | 291.05 | 291.20 | +0.1% | 4112.12 | 4092.28 | -0.5% | 9474.56 | 9376.01 | -1.0% |
| Request Throughput (req/s) | 2.2738 | 2.2750 | +0.1% | 32.1260 | 31.9710 | -0.5% | 74.0200 | 73.2500 | -1.0% |
| Total Token Throughput (tok/s) | 482.05 | 482.31 | +0.1% | 6289.06 | 6258.71 | -0.5% | 14401.06 | 14251.25 | -1.0% |
| Mean TTFT (ms) | 12.60 | 12.70 | +0.7% | 26.53 | 26.14 | -1.5% | 55.38 | 55.47 | +0.2% |
| Mean TPOT (ms) | 3.3634 | 3.3608 | -0.1% | 3.5028 | 3.5260 | +0.7% | 6.1544 | 6.1819 | +0.4% |
| Mean ITL (ms) | 6.2356 | 6.2307 | -0.1% | 6.9048 | 6.9399 | +0.5% | 12.1849 | 12.2880 | +0.8% |
| Mean E2E Latency (ms) | 439.76 | 439.51 | -0.1% | 471.38 | 473.94 | +0.5% | 836.99 | 840.57 | +0.4% |
| Median TTFT (ms) | 12.42 | 12.81 | +3.1% | 21.38 | 21.46 | +0.4% | 40.79 | 40.90 | +0.3% |
| Median TPOT (ms) | 3.2760 | 3.2760 | -0.0% | 3.3938 | 3.4484 | +1.6% | 6.0840 | 6.1086 | +0.4% |
| Median ITL (ms) | 6.3300 | 6.3252 | -0.1% | 6.8390 | 6.8551 | +0.2% | 11.3971 | 11.5826 | +1.6% |
| P99 TTFT (ms) | 14.81 | 14.92 | +0.8% | 86.46 | 68.64 | -20.6% | 199.94 | 233.97 | +17.0% |
| P99 TPOT (ms) | 3.9284 | 3.9243 | -0.1% | 4.8314 | 4.9058 | +1.5% | 8.2204 | 8.3512 | +1.6% |
| P99 ITL (ms) | 6.7395 | 6.7063 | -0.5% | 10.2078 | 10.4182 | +2.1% | 32.0866 | 35.0774 | +9.3% |
| Spec Decode Accept Rate (%) | 29.78 | 29.78 | +0.0% | 33.75 | 33.44 | -0.9% | 33.01 | 32.89 | -0.3% |
| Spec Decode Accept Length | 1.8933 | 1.8933 | +0.0% | 2.0126 | 2.0032 | -0.5% | 1.9902 | 1.9867 | -0.2% |

Temperature = 1.0

| Metric | Before (c=1) | After (c=1) | Delta | Before (c=16) | After (c=16) | Delta | Before (c=64) | After (c=64) | Delta |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Output Throughput (tok/s) | 276.85 | 276.67 | -0.1% | 3900.91 | 3876.46 | -0.6% | 8846.05 | 8837.11 | -0.1% |
| Request Throughput (req/s) | 2.1629 | 2.1615 | -0.1% | 30.4758 | 30.2848 | -0.6% | 69.1098 | 69.0399 | -0.1% |
| Total Token Throughput (tok/s) | 458.53 | 458.23 | -0.1% | 5966.02 | 5928.63 | -0.6% | 13445.74 | 13432.14 | -0.1% |
| Mean TTFT (ms) | 14.30 | 13.89 | -2.9% | 28.37 | 27.72 | -2.3% | 53.68 | 60.68 | +13.1% |
| Mean TPOT (ms) | 3.5275 | 3.5332 | +0.2% | 3.7094 | 3.6992 | -0.3% | 6.6082 | 6.5903 | -0.3% |
| Mean ITL (ms) | 6.5782 | 6.5889 | +0.2% | 7.1867 | 7.1916 | +0.1% | 12.6040 | 12.5465 | -0.5% |
| Mean E2E Latency (ms) | 462.29 | 462.61 | +0.1% | 499.47 | 497.52 | -0.4% | 892.92 | 897.66 | +0.5% |
| Median TTFT (ms) | 14.58 | 13.96 | -4.2% | 22.37 | 22.26 | -0.5% | 41.46 | 41.64 | +0.4% |
| Median TPOT (ms) | 3.4193 | 3.4279 | +0.3% | 3.5997 | 3.6155 | +0.4% | 6.6191 | 6.5779 | -0.6% |
| Median ITL (ms) | 6.6868 | 6.7009 | +0.2% | 7.1261 | 7.1350 | +0.1% | 12.0220 | 12.0475 | +0.2% |
| P99 TTFT (ms) | 15.77 | 15.83 | +0.4% | 80.21 | 77.24 | -3.7% | 209.44 | 249.50 | +19.1% |
| P99 TPOT (ms) | 4.7113 | 4.7176 | +0.1% | 5.1128 | 5.1216 | +0.2% | 9.2443 | 9.1805 | -0.7% |
| P99 ITL (ms) | 7.0063 | 7.0475 | +0.6% | 10.8579 | 10.6905 | -1.5% | 34.3122 | 31.8153 | -7.3% |
| Spec Decode Accept Rate (%) | 30.30 | 30.30 | +0.0% | 32.64 | 32.89 | +0.8% | 31.20 | 31.42 | +0.7% |
| Spec Decode Accept Length | 1.9091 | 1.9091 | +0.0% | 1.9792 | 1.9867 | +0.4% | 1.9361 | 1.9425 | +0.3% |

Qwen3 8B

Base: Qwen/Qwen3-8B
Draft: RedHatAI/Qwen3-8B-speculator.eagle3

Temperature = 0.0

| Metric | Before (c=1) | After (c=1) | Delta | Before (c=16) | After (c=16) | Delta | Before (c=64) | After (c=64) | Delta |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Output Throughput (tok/s) | 370.43 | 369.02 | -0.4% | 4961.65 | 4996.27 | +0.7% | 11192.22 | 11234.58 | +0.4% |
| Request Throughput (req/s) | 2.8940 | 2.8830 | -0.4% | 38.7629 | 39.0334 | +0.7% | 87.4392 | 87.7701 | +0.4% |
| Total Token Throughput (tok/s) | 657.22 | 654.73 | -0.4% | 7674.80 | 7728.37 | +0.7% | 17168.55 | 17233.53 | +0.4% |
| Mean TTFT (ms) | 13.74 | 14.78 | +7.6% | 26.07 | 27.09 | +3.9% | 55.76 | 58.28 | +4.5% |
| Mean TPOT (ms) | 2.6124 | 2.6145 | +0.1% | 2.8569 | 2.8478 | -0.3% | 5.1526 | 5.1110 | -0.8% |
| Mean ITL (ms) | 5.9349 | 5.9398 | +0.1% | 6.4985 | 6.4805 | -0.3% | 11.6197 | 11.9177 | +2.6% |
| Mean E2E Latency (ms) | 345.52 | 346.83 | +0.4% | 388.90 | 388.76 | -0.0% | 710.13 | 707.37 | -0.4% |
| Median TTFT (ms) | 13.90 | 15.00 | +7.9% | 20.97 | 22.09 | +5.4% | 42.73 | 44.18 | +3.4% |
| Median TPOT (ms) | 2.6899 | 2.6922 | +0.1% | 2.8418 | 2.8298 | -0.4% | 5.1640 | 5.1237 | -0.8% |
| Median ITL (ms) | 6.0418 | 6.0657 | +0.4% | 6.4218 | 6.4123 | -0.1% | 10.9392 | 10.6233 | -2.9% |
| P99 TTFT (ms) | 15.98 | 17.41 | +8.9% | 73.35 | 74.97 | +2.2% | 204.25 | 202.24 | -1.0% |
| P99 TPOT (ms) | 2.7856 | 2.7958 | +0.4% | 4.0276 | 4.0070 | -0.5% | 7.1513 | 7.1189 | -0.5% |
| P99 ITL (ms) | 6.5262 | 6.4506 | -1.2% | 10.2878 | 9.9438 | -3.3% | 23.9119 | 28.6535 | +19.8% |
| Spec Decode Accept Rate (%) | 44.63 | 44.63 | +0.0% | 44.49 | 44.54 | +0.1% | 43.77 | 43.72 | -0.1% |
| Spec Decode Accept Length | 2.3388 | 2.3388 | +0.0% | 2.3347 | 2.3363 | +0.1% | 2.3130 | 2.3116 | -0.1% |

Temperature = 1.0

| Metric | Before (c=1) | After (c=1) | Delta | Before (c=16) | After (c=16) | Delta | Before (c=64) | After (c=64) | Delta |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Output Throughput (tok/s) | 341.05 | 340.19 | -0.2% | 4664.82 | 4609.79 | -1.2% | 10704.67 | 10608.86 | -0.9% |
| Request Throughput (req/s) | 2.6644 | 2.6578 | -0.2% | 36.4439 | 36.0140 | -1.2% | 83.6302 | 82.8817 | -0.9% |
| Total Token Throughput (tok/s) | 605.09 | 603.58 | -0.2% | 7215.66 | 7130.55 | -1.2% | 16420.66 | 16273.70 | -0.9% |
| Mean TTFT (ms) | 14.07 | 14.80 | +5.2% | 26.83 | 29.91 | +11.5% | 57.01 | 54.84 | -3.8% |
| Mean TPOT (ms) | 2.8443 | 2.8458 | +0.1% | 3.0617 | 3.0640 | +0.1% | 5.4326 | 5.4696 | +0.7% |
| Mean ITL (ms) | 6.3593 | 6.3628 | +0.1% | 6.8155 | 6.9161 | +1.5% | 12.0516 | 12.1472 | +0.8% |
| Mean E2E Latency (ms) | 375.29 | 376.22 | +0.2% | 415.67 | 419.03 | +0.8% | 746.95 | 749.47 | +0.3% |
| Median TTFT (ms) | 13.32 | 14.32 | +7.5% | 21.77 | 23.28 | +7.0% | 44.47 | 43.76 | -1.6% |
| Median TPOT (ms) | 2.8877 | 2.8881 | +0.0% | 3.0725 | 3.0455 | -0.9% | 5.4397 | 5.4697 | +0.6% |
| Median ITL (ms) | 6.4812 | 6.4970 | +0.2% | 6.7226 | 6.7085 | -0.2% | 11.0800 | 11.1692 | +0.8% |
| P99 TTFT (ms) | 17.70 | 17.76 | +0.4% | 71.06 | 84.02 | +18.2% | 186.84 | 200.47 | +7.3% |
| P99 TPOT (ms) | 3.2338 | 3.2305 | -0.1% | 4.1242 | 4.1888 | +1.6% | 7.0481 | 7.3187 | +3.8% |
| P99 ITL (ms) | 6.7988 | 6.7734 | -0.4% | 10.6903 | 11.9196 | +11.5% | 23.1690 | 26.5182 | +14.5% |
| Spec Decode Accept Rate (%) | 43.07 | 43.07 | +0.0% | 42.54 | 42.56 | +0.0% | 42.38 | 42.27 | -0.3% |
| Spec Decode Accept Length | 2.2921 | 2.2921 | +0.0% | 2.2763 | 2.2768 | +0.0% | 2.2715 | 2.2682 | -0.1% |

Future Work

It's unfortunate that we now lose the fully cudagraph-captured loop for the last N-1 draft tokens, but these changes were necessary for correctness. It may be possible to restore it by introducing a new method on the AttentionMetadataBuilder API, similar to update_block_table, that updates the attention metadata in place under two assumptions: all requests are single-token decodes, and only the sequence position has changed since the last rebuild. With those assumptions, cudagraph-safe updates could be implemented on a per-backend basis; a sketch of such a hook follows.
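A rough sketch of what that hypothetical hook could look like (the method name and signature are illustrative, not an agreed-upon vLLM API):

class AttentionMetadataBuilder:
    # ... existing build()/update_block_table()-style methods elided ...

    def advance_decode_position(self, num_steps: int = 1) -> None:
        """Hypothetical in-place, cudagraph-safe metadata update.

        Assumes all requests are single-token decodes and that only the
        sequence position has advanced since the last full build. Backends
        with position-dependent state (e.g. SWA indices/lens) would override
        this to update their buffers in place.
        """
        raise NotImplementedError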

mergify (Bot) added the v1 label Apr 28, 2026
gemini-code-assist (Bot, Contributor) left a comment:

Code Review

This pull request refactors the EagleSpeculator to support multi-step decoding with CUDA Graphs, introducing a multi_step_decode method and splitting draft generation logic. It also includes optimizations for metadata building and fixes typos in kernel naming. Feedback focuses on two critical areas: first, the reassignment of attn_metadata within the decode loop is incompatible with FULL CUDA graph execution because it fails to update the captured static inputs. Second, moving hidden state updates from the fused Triton kernel to a separate torch.copy_ call introduces additional kernel launch overhead, which may negatively impact latency in speculative decoding.

Comment thread vllm/v1/worker/gpu/spec_decode/eagle/speculator.py Outdated
Comment thread vllm/v1/worker/gpu/spec_decode/eagle/speculator.py Outdated
TheEpicDolphin force-pushed the mrv2-draft-decode-rebuild-attn-metadata branch 4 times, most recently from bbe2722 to 3cd2adf on April 29, 2026 04:46
WoosukKwon added the ready label (ONLY add when PR is ready to merge/full CI is needed) Apr 29, 2026
TheEpicDolphin marked this pull request as ready for review April 29, 2026 17:47
claude (Bot) left a comment:

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Comment on lines +389 to +395:

# Update the inputs for the next step.
update_eagle_draft_inputs(
    num_reqs,
    self.decode_draft_tokens[:num_reqs],
    self.input_buffers,
    self.max_model_len,
)
Collaborator:
Shouldn't we skip this for the very last decode step?

TheEpicDolphin (Collaborator, Author) replied:

The issue is that we can't pass the current draft step into generate_draft during FULL cudagraph, so it can't know when to skip or not.

While it results in an extra update_eagle_draft_inputs call for the last draft token, we at least don't pay for the kernel launch overhead. If we moved it outside of generate_draft into the loop and skipped it on the last step, we'd pay for the kernel launch overhead on every step except the last.

TheEpicDolphin force-pushed the mrv2-draft-decode-rebuild-attn-metadata branch from 3cd2adf to d3fc140 on May 1, 2026 00:41
TheEpicDolphin marked this pull request as draft May 1, 2026 01:06
TheEpicDolphin force-pushed the mrv2-draft-decode-rebuild-attn-metadata branch 2 times, most recently from f960d1b to af1131f on May 1, 2026 02:38
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
TheEpicDolphin force-pushed the mrv2-draft-decode-rebuild-attn-metadata branch from af1131f to 4827469 on May 1, 2026 04:50
TheEpicDolphin marked this pull request as ready for review May 1, 2026 04:50
claude (Bot) left a comment:
Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

TheEpicDolphin (Collaborator, Author) commented:
I was able to put the current draft step in a scalar tensor and use it during the CG-captured generate_draft to:

  • Copy to self.draft_tokens in place
  • Copy to self.draft_logits in place
  • Early out during update_eagle_draft_inputs for the last speculative step

No more temporary buffers are needed. cc: @WoosukKwon
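For illustration, that tensor-driven early-out can be expressed as a masked write, since a captured graph cannot branch on a host-side value (a hedged sketch; the actual update_eagle_draft_inputs may be implemented differently):

import torch

def update_if_not_last(step: torch.Tensor, last_step: int,
                       buf: torch.Tensor, new_vals: torch.Tensor) -> None:
    # A 0-dim boolean condition broadcasts across the buffer, so the write
    # becomes a no-op on the final speculative step with no host-side branch,
    # keeping the op replay-safe inside a captured graph.
    buf.copy_(torch.where(step < last_step, new_vals, buf))

inputs = torch.zeros(4)
step = torch.tensor(2)
update_if_not_last(step, last_step=2, buf=inputs, new_vals=torch.ones(4))
assert inputs.sum() == 0  # step == last_step: inputs were left untouched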

WoosukKwon (Collaborator) left a comment:

LGTM

WoosukKwon enabled auto-merge (squash) May 4, 2026 22:51
WoosukKwon merged commit e1e4646 into vllm-project:main May 5, 2026
63 checks passed
TheEpicDolphin deleted the mrv2-draft-decode-rebuild-attn-metadata branch May 5, 2026 00:45
chaojun-zhang pushed a commit to chaojun-zhang/vllm that referenced this pull request May 6, 2026
Copilot AI pushed a commit to hongbolv/vllm that referenced this pull request May 7, 2026
…llm-project#41162)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>
ikaadil pushed a commit to ikaadil/vllm that referenced this pull request May 7, 2026
…llm-project#41162)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1
