
Conversation

@zixi-qi (Collaborator) commented Aug 12, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting the before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Fix the llama4 spec decoding path where the draft model uses RoPE + global attention: under the current code, ChunkedLocalAttention would be used for these layers, which results in a crash.
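
A minimal sketch of the intended selection logic, for context. This is illustrative only, not the actual diff: pick_attn_cls and its arguments are made-up names, while the import paths follow the traceback and the discussion below.

```python
from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention


def pick_attn_cls(uses_rope: bool, attention_chunk_size: int | None) -> type:
    """Illustrative helper: ChunkedLocalAttention only makes sense when the
    config actually provides attention_chunk_size. A draft model whose layers
    use RoPE with global attention has attention_chunk_size=None and must fall
    back to the standard Attention class."""
    if uses_rope and attention_chunk_size is not None:
        return ChunkedLocalAttention
    return Attention
```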

Test Plan

CUDA_VISIBLE_DEVICES=4,5,6,7 VLLM_USE_V1=1 python examples/offline_inference/spec_decode.py  --num_spec_tokens 7 --num_prompts 1 --method eagle --model_dir /home/qizixi/models/llama4_scout/Llama-4-Scout-17B-16E-Instruct --eagle_dir /home/qizixi/models/llama4_scout/scout_draft_HF_20250605_202942 --tp 4

Test Result

  • Before the fix:
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596] Traceback (most recent call last):
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]   File "/home/qizixi/qizixi/vllm/vllm/v1/executor/multiproc_executor.py", line 591, in worker_busy_loop
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]     output = func(*args, **kwargs)
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]              ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]   File "/home/qizixi/uv_env/vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]     return func(*args, **kwargs)
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]   File "/home/qizixi/qizixi/vllm/vllm/v1/worker/gpu_worker.py", line 367, in execute_model
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]     output = self.model_runner.execute_model(scheduler_output,
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]   File "/home/qizixi/uv_env/vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]     return func(*args, **kwargs)
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]   File "/home/qizixi/qizixi/vllm/vllm/v1/worker/gpu_model_runner.py", line 1504, in execute_model
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]     self._prepare_inputs(scheduler_output))
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]   File "/home/qizixi/qizixi/vllm/vllm/v1/worker/gpu_model_runner.py", line 885, in _prepare_inputs
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]     attn_metadata_i = (builder.build(
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]                        ^^^^^^^^^^^^^^
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]   File "/home/qizixi/qizixi/vllm/vllm/v1/attention/backends/utils.py", line 554, in build
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]     build_preprocess_fn(common_attn_metadata),
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]   File "/home/qizixi/qizixi/vllm/vllm/attention/layers/chunked_local_attention.py", line 28, in build_preprocess_fn
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]     return make_local_attention_virtual_batches(attention_chunk_size, cm,
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]   File "/home/qizixi/qizixi/vllm/vllm/v1/attention/backends/utils.py", line 429, in make_local_attention_virtual_batches
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]     attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596]                        ~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~
(VllmWorker TP0 pid=2656884) ERROR 08-11 18:07:57 [multiproc_executor.py:596] TypeError: unsupported operand type(s) for %: 'int' and 'NoneType'
  • After the fix:
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2442.81it/s]
Processed prompts: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.54s/it, est. speed input: 240.16 toks/s, output: 72.42 toks/s]
--------------------------------------------------
total_num_output_tokens: 256
num_drafts: 94
num_draft_tokens: 658
num_accepted_tokens: 161
mean acceptance length: 2.71
--------------------------------------------------
acceptance at token 0: 0.89
acceptance at token 1: 0.82
acceptance at token 2: 0.00
acceptance at token 3: 0.00
acceptance at token 4: 0.00
acceptance at token 5: 0.00
acceptance at token 6: 0.00
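
As a quick consistency check on the numbers above, assuming mean acceptance length is reported as 1 plus the accepted tokens per draft; the printed values are consistent with this, but the formula is an assumption, not taken from the script.

```python
# Sanity check on the reported spec-decode stats (values copied from the output above).
num_drafts = 94
num_accepted_tokens = 161
per_position_acceptance = [0.89, 0.82, 0.00, 0.00, 0.00, 0.00, 0.00]

mean_from_counts = 1 + num_accepted_tokens / num_drafts     # ~2.71
mean_from_positions = 1 + sum(per_position_acceptance)      # ~2.71
print(f"{mean_from_counts:.2f} {mean_from_positions:.2f}")
```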

(Optional) Documentation Update

@mergify mergify bot added the llama (Related to Llama models) label Aug 12, 2025
@zixi-qi zixi-qi marked this pull request as ready for review August 12, 2025 02:16
@facebook-github-bot

@zixi-qi has imported this pull request. If you are a Meta employee, you can view this in D80059459.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request addresses a crash in the Llama4 speculative decoding path. The issue occurred when a model used Rotary Position Embeddings (RoPE) with global attention, where the code would incorrectly attempt to use ChunkedLocalAttention without a valid attention_chunk_size, leading to a TypeError. The fix correctly introduces a check for config.attention_chunk_size before selecting the attention mechanism. This ensures that ChunkedLocalAttention is only used when explicitly configured, falling back to the standard Attention class otherwise. The change is well-implemented, directly solves the reported crash, and improves the robustness of the model's attention mechanism selection. The provided test results confirm the fix is effective.
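
For reference, the TypeError from the before-fix trace can be reproduced in isolation. The snippet below is a simplified sketch of the arithmetic at the crash site in make_local_attention_virtual_batches, using plain ints and made-up values rather than the numpy arrays the real code operates on.

```python
# Simplified repro of the failure mode: the chunked-local metadata builder does
# modular arithmetic with attn_chunk_size, which is None when the model config
# never sets attention_chunk_size (the RoPE + global-attention draft model case).
seq_len = 24        # illustrative value
q_seqlen = 8        # illustrative value
attn_chunk_size = None

try:
    attn_chunk_size - ((seq_len - q_seqlen) % attn_chunk_size)
except TypeError as exc:
    print(exc)  # unsupported operand type(s) for %: 'int' and 'NoneType'
```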

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@WoosukKwon (Collaborator) left a comment

@luccafong Could you please review?

@luccafong (Collaborator) left a comment

thanks for fixing!

@houseroad (Collaborator) left a comment

Looks good, thanks!

@houseroad houseroad added the ready (ONLY add when PR is ready to merge/full CI is needed) label Aug 18, 2025
@houseroad houseroad enabled auto-merge (squash) August 18, 2025 07:05
@houseroad houseroad merged commit 5bfe0de into vllm-project:main Aug 19, 2025
40 checks passed
princepride pushed a commit to princepride/vllm that referenced this pull request Aug 20, 2025
divakar-amd pushed a commit to divakar-amd/vllm_upstream that referenced this pull request Aug 20, 2025
cyang49 pushed a commit to cyang49/vllm that referenced this pull request Aug 20, 2025
djmmoss pushed a commit to djmmoss/vllm that referenced this pull request Aug 21, 2025
Signed-off-by: qizixi <[email protected]>
Co-authored-by: Lu Fang <[email protected]>
Signed-off-by: Duncan Moss <[email protected]>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: qizixi <[email protected]>
Co-authored-by: Lu Fang <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
mengxingkongzhouhan pushed a commit to mengxingkongzhouhan/vllm that referenced this pull request Aug 30, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025

Labels

llama (Related to Llama models), ready (ONLY add when PR is ready to merge/full CI is needed)

6 participants