
[Bugfix][ptd_eagle] Fix buffer overflow in PTD EAGLE speculative deco…#1

Merged
laviier merged 1 commit into pard-pr from ptd-memory-fix on Jan 22, 2026

Conversation

Collaborator

@laviier laviier commented Jan 22, 2026

Previous bug: when running `vllm bench serve --backend vllm --served-model-name gpt-oss-120b --endpoint /v1/completions --dataset-name random --random-input-len 1600 --random-output-len 600 --num-prompts 100`, the server would crash with the following error:

(EngineCore_DP0 pid=12758)   File "<eval_with_key>.381", line 44, in forward
(EngineCore_DP0 pid=12758)     submod_0 = self.submod_0(l_input_ids_, s72, l_self_modules_embed_tokens_parameters_weight_, l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_0_modules_hidden_norm_parameters_weight_, l_hidden_states_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_bias_, l_positions_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_, s47);  l_input_ids_ = l_self_modules_embed_tokens_parameters_weight_ = l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_ = l_self_modules_layers_modules_0_modules_hidden_norm_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_bias_ = None
(EngineCore_DP0 pid=12758)                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=12758)   File "/home/ubuntu/github/vllm/vllm/compilation/cuda_graph.py", line 222, in __call__
(EngineCore_DP0 pid=12758)     return self.runnable(*args, **kwargs)
(EngineCore_DP0 pid=12758)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=12758)   File "/home/ubuntu/github/vllm/vllm/compilation/piecewise_backend.py", line 186, in __call__
(EngineCore_DP0 pid=12758)     assert range_entry is not None, (
(EngineCore_DP0 pid=12758)            ^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=12758) AssertionError: Shape: 8213 out of considered ranges: [(1, 8192)]

Parallel draft methods (PTD EAGLE) generate K draft tokens in a single forward pass using mask tokens, which requires larger buffers than sequential drafting. The inherited buffer allocation formula was insufficient, causing crashes under load.

Bug manifestation:

  • Sequential EAGLE: needs `max_num_batched_tokens + max_num_seqs` tokens
  • Parallel draft: needs `max_num_batched_tokens + max_num_seqs * num_speculative_tokens` tokens
  • Error: "AssertionError: Shape: 8213 out of considered ranges: [(1, 8192)]"
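The sizing difference above can be sketched as follows. This is an illustrative helper, not the actual vLLM code; the parameter names follow the PR text:

```python
def sequential_eagle_buffer(max_num_batched_tokens: int, max_num_seqs: int) -> int:
    """Sequential EAGLE drafts one token per sequence per forward pass."""
    return max_num_batched_tokens + max_num_seqs


def parallel_draft_buffer(
    max_num_batched_tokens: int,
    max_num_seqs: int,
    num_speculative_tokens: int,
) -> int:
    """Parallel draft (PTD EAGLE) appends num_speculative_tokens mask
    tokens per sequence, so the worst case grows by
    max_num_seqs * num_speculative_tokens instead of max_num_seqs."""
    return max_num_batched_tokens + max_num_seqs * num_speculative_tokens


# The crash scenario from the benchmark: 8192 batched tokens plus
# 7 in-flight requests, each contributing 3 mask tokens.
print(parallel_draft_buffer(8192, 7, 3))  # 8213, outside the (1, 8192) range
```

With the inherited sequential formula, the allocated buffer (and the compiled shape range) stops at 8192 + max_num_seqs, which the parallel draft pattern exceeds as soon as several requests draft in the same batch.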

This fix addresses two critical issues:

  1. Buffer Allocation (ptd_eagle.py):

    • Corrects max_num_tokens formula for parallel draft generation pattern
    • Reallocates all buffers (input_ids, positions, hidden_states, slot_buffer)
    • Adds ~6MB memory overhead (negligible given the 3-4x speedup)
  2. Compilation Ranges (vllm.py):

    • Extends compile_ranges_split_points when parallel_draft=True
    • Ensures CUDA graph compilation handles expanded token counts
    • Adds informative logging for parallel draft detection

The bug was caught during benchmarking with 100 prompts (1600 input, 600 output tokens), where the batch reached 8192 tokens + 7 requests * 3 masks = 8213 tokens, exceeding the compilation range upper bound of 8192.
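A minimal sketch of the split-point extension described in item 2, assuming a hypothetical `extend_split_points` helper (the actual change in vllm.py may be structured differently):

```python
def extend_split_points(
    split_points: list[int],
    max_num_batched_tokens: int,
    max_num_seqs: int,
    num_speculative_tokens: int,
    parallel_draft: bool,
) -> list[int]:
    # Hypothetical helper: when parallel_draft is enabled, add a split
    # point covering the expanded worst-case token count so that
    # compilation considers shapes beyond max_num_batched_tokens.
    if not parallel_draft:
        return split_points
    expanded = max_num_batched_tokens + max_num_seqs * num_speculative_tokens
    return sorted(set(split_points) | {expanded})


# Without the extension the considered range tops out at 8192; with it,
# the crash shape of 8213 falls inside a compiled range.
print(extend_split_points([8192], 8192, 7, 3, True))  # [8192, 8213]
```

The guard on `parallel_draft` keeps sequential EAGLE configurations on their original ranges, so only the parallel draft path pays for the extra compiled shape.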

Tested-by: Load testing with max batch size configurations

@laviier laviier requested a review from hai-meh-cs January 22, 2026 18:09
@laviier laviier self-assigned this Jan 22, 2026
@laviier laviier added the bug Something isn't working label Jan 22, 2026
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Fix buffer overflow in PTD EAGLE speculative decoding


Signed-off-by: Li Zhang <lzhanga@amazon.com>

Simplify updates to eagle files

Signed-off-by: Li Zhang <lzhanga@amazon.com>

Minor format updates
@laviier laviier merged commit 9a4d148 into pard-pr Jan 22, 2026
2 checks passed
