
[Bugfix] Fix block size used in EAGLE slot mapping #31540

Merged
vllm-bot merged 2 commits into vllm-project:main from CentML:fix-eagle-block-size
Jan 2, 2026

Conversation

@benchislett
Collaborator

@benchislett benchislett commented Dec 30, 2025

Some hybrid models use a different block size for the linear and attention layers. In such cases, vllm_config.cache_config.block_size is not the same as the block size needed for EAGLE.

This leads to subtle errors in the slot mapping computed by EAGLE, and causes crashes under long-context or high-concurrency workloads.

The fix is to get the block size from the specification of the kv cache group being used.
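The slot-mapping arithmetic the description refers to can be sketched as follows. This is an illustrative stand-in, not the actual vLLM code; `compute_slot_mapping` and its signature are hypothetical, but the index math (block lookup plus in-block offset) is the standard paged-KV-cache mapping the fix concerns.

```python
# Hypothetical sketch of the slot-mapping computation (names are
# illustrative, not the real vLLM API). The key point of the fix: the
# block size must come from the KV cache spec of the attention group
# EAGLE uses, not from the global vllm_config.cache_config.block_size.

def compute_slot_mapping(block_table, positions, block_size):
    """Map each token position to a physical KV-cache slot.

    block_table: physical block ids allocated for one request
    positions:   absolute token positions within the request
    block_size:  tokens per block *for this KV cache group*
    """
    return [
        block_table[pos // block_size] * block_size + pos % block_size
        for pos in positions
    ]

# With a hybrid model, the attention group's block size (e.g. 64 here)
# can differ from the global cache_config.block_size. Dividing by the
# wrong block size picks the wrong block, or indexes past the end of
# the block table once the context grows long enough.
attn_block_size = 64
block_table = [7, 3, 9]  # blocks allocated with the attention block size
print(compute_slot_mapping(block_table, [0, 63, 64, 130], attn_block_size))
# → [448, 511, 192, 578]
```

Note how positions 64 and 130 correctly fall into blocks 1 and 2 of the table; with a larger (global) block size they would be mapped into earlier blocks, corrupting the cache layout.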

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@benchislett benchislett added the bug Something isn't working label Dec 30, 2025
@benchislett benchislett requested review from LucasWilkinson, WoosukKwon and Copilot and removed request for luccafong December 30, 2025 18:33
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a critical bug in the EAGLE speculative decoding implementation where a globally configured block size was used for slot mapping, causing crashes with hybrid models that use varying block sizes. The fix correctly retrieves the block size from the kv_cache_spec of the relevant attention metadata builder within the propose and propose_tree methods. This change is well-targeted, logical, and significantly improves the robustness and correctness of the EAGLE implementation for a wider range of models. The code is now more reliable as it no longer depends on a potentially incorrect global configuration. The changes are approved.

Contributor

Copilot AI left a comment

Pull request overview

This PR fixes a bug in EAGLE's slot mapping computation for hybrid models that use different block sizes for linear and attention layers. The fix retrieves the block size from the KV cache specification instead of using the global cache config, preventing crashes in long-context or high-concurrency scenarios.

Key Changes:

  • Removed the self.block_size instance variable from the EAGLE class initialization
  • Updated slot mapping calculations to use the block size from kv_cache_spec at runtime
  • Applied the fix to both propose() and propose_tree() methods
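The shape of the change summarized above can be sketched minimally. The attribute path `attn_metadata_builder.kv_cache_spec.block_size` is taken from the discussion below; the surrounding classes are simplified stand-ins for the real vLLM types, not the actual implementation.

```python
from dataclasses import dataclass

# Minimal illustrative sketch: read block_size from the KV cache spec
# of the attention group at runtime, instead of caching the global
# cache_config.block_size at __init__. Class names are hypothetical.

@dataclass
class KVCacheSpec:
    block_size: int

@dataclass
class AttnMetadataBuilder:
    kv_cache_spec: KVCacheSpec

class EagleProposer:
    def __init__(self, attn_metadata_builder: AttnMetadataBuilder):
        # Before the fix: self.block_size = cache_config.block_size
        # (a single global value, wrong for hybrid models)
        self.attn_metadata_builder = attn_metadata_builder

    def propose(self, block_table: list[int], positions: list[int]) -> list[int]:
        # After the fix: the per-group block size is looked up when
        # proposing, so each KV cache group uses its own block size.
        bs = self.attn_metadata_builder.kv_cache_spec.block_size
        return [block_table[p // bs] * bs + p % bs for p in positions]
```

Dropping the cached instance variable means the proposer can never fall out of sync with the group's actual allocation granularity.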


Collaborator

@pavanimajety pavanimajety left a comment

LGTM, thank you.

@benchislett benchislett added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 30, 2025
@benchislett benchislett enabled auto-merge (squash) December 30, 2025 19:37
Member

@mgoin mgoin left a comment

Seems reasonable to me. @benchislett have you seen this improve Qwen3-Next results?

@benchislett
Collaborator Author

This bugfix came about when fixing issues in a WIP branch for specdec support of NVIDIA Nemotron Nano V3 + EAGLE. I'm not sure if Qwen3-Next uses a different block size for its GDN attention, so it might or might not manifest this issue. I haven't checked.

@benchislett
Collaborator Author

#31186 remains open, but I can rerun it later this week to check whether this fix applies. Given that the crash in that issue varies with CUDA graph mode, I'm skeptical that this fix would resolve it.

@benchislett
Collaborator Author

benchislett commented Dec 30, 2025

Update, does not fix #31186. It seems that Qwen3-Next actually uses the larger block size when allocating KV-Cache for the MTP module, and in that case attn_metadata_builder.kv_cache_spec.block_size == vllm_config.cache_config.block_size == 560.

Can still reproduce the crash with full graphs. Seems unrelated, or at least not completely resolved yet.

@vllm-bot vllm-bot merged commit ea53ca5 into vllm-project:main Jan 2, 2026
41 of 44 checks passed
@vadiklyutiy vadiklyutiy self-requested a review January 5, 2026 13:17
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jan 6, 2026
### What this PR does / why we need it?

Upgrade vllm commit to 0105 (8be6432bdaf6275664d857b1e5e9bf8ed1ce299e)

1. Remove `maybe_padded_num_tokens` arg in `model_runner_v1.py` since
vllm-project/vllm#31517 deleted the unused arg

2. Remove dense `Qwen/Qwen3-0.6B` in
`tests/e2e/multicard/test_aclgraph_capture_replay.py` and
`tests/e2e/multicard/test_data_parallel.py` due to
vllm-project/vllm#30739
where offline data parallel mode will not be supported/useful for dense
models

3. Adapt `vllm_ascend/worker/worker.py` due to
vllm-project/vllm#31584

4. Adapt `self.block_size` calling due to
vllm-project/vllm#31540

5. Modify `test_mla_v1.py` due to
vllm-project/vllm#28454, which refactored
`get_head_size()`

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@7157596

Signed-off-by: wjunLu <wjunlu217@gmail.com>
LucasWilkinson pushed a commit to neuralmagic/vllm that referenced this pull request Jan 6, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Rozwel-dx pushed a commit to Rozwel-dx/vllm-ascend that referenced this pull request Jan 8, 2026
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
aipaes pushed a commit to aipaes/vllm-ascend that referenced this pull request Jan 15, 2026
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026

Labels

bug (Something isn't working), ready (ONLY add when PR is ready to merge/full CI is needed), speculative-decoding, v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants