[Bugfix] Fix block size used in EAGLE slot mapping#31540
vllm-bot merged 2 commits into vllm-project:main
Conversation
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Code Review
This pull request addresses a bug in the EAGLE speculative decoding implementation where the globally configured block size was used for slot mapping, causing crashes with hybrid models that use differing block sizes. The fix retrieves the block size from the `kv_cache_spec` of the relevant attention metadata builder within the `propose` and `propose_tree` methods. This change is well-targeted and improves the correctness of the EAGLE implementation for a wider range of models: the code no longer depends on a potentially incorrect global configuration. The changes are approved.
Pull request overview
This PR fixes a bug in EAGLE's slot mapping computation for hybrid models that use different block sizes for linear and attention layers. The fix retrieves the block size from the KV cache specification instead of using the global cache config, preventing crashes in long-context or high-concurrency scenarios.
Key Changes:
- Removed the `self.block_size` instance variable from the EAGLE class initialization
- Updated slot mapping calculations to use the block size from `kv_cache_spec` at runtime
- Applied the fix to both `propose()` and `propose_tree()` methods
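The shape of the fix can be sketched as follows. This is an illustrative model, not vLLM's actual classes: `KVCacheSpec`, `AttentionGroup`, and `EagleProposer` here are simplified stand-ins showing why the block size must come from the group's spec at runtime rather than from a value cached at init time.

```python
from dataclasses import dataclass


@dataclass
class KVCacheSpec:
    """Per-group KV-cache layout; hybrid models may have several."""
    block_size: int


@dataclass
class AttentionGroup:
    kv_cache_spec: KVCacheSpec


class EagleProposer:
    def __init__(self, global_block_size: int):
        # Before the fix, the proposer cached the global cache-config value:
        #   self.block_size = global_block_size
        # which is wrong when the attention group backing EAGLE uses a
        # different block size than the model's linear layers.
        self.global_block_size = global_block_size

    def slot_for(self, group: AttentionGroup,
                 position: int, physical_block: int) -> int:
        # After the fix: take the block size from the group's own spec.
        block_size = group.kv_cache_spec.block_size
        return physical_block * block_size + position % block_size
```

With a global block size of 16 but an attention-group block size of 64, `slot_for(group, position=70, physical_block=3)` yields `3 * 64 + 70 % 64 = 198`; using the stale global value would have produced a different, invalid slot.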
mgoin
left a comment
Seems reasonable to me. @benchislett have you seen this improve Qwen3-Next results?
This bugfix came about when fixing issues in a WIP branch for specdec support of NVIDIA Nemotron Nano V3 + EAGLE. I'm not sure if Qwen3-Next uses a different block size for its GDN attention, so it might or might not manifest this issue. I haven't checked.
#31186 remains open, but I can rerun it later this week to check if this fix applies. Given that the crash in that issue varies based on CUDA graph mode, I'm skeptical that this fix would resolve it.
Update: this does not fix #31186. It seems that Qwen3-Next actually uses the larger block size when allocating KV-cache for the MTP module, and in that case I can still reproduce the crash with full graphs. Seems unrelated, or at least not completely resolved yet.
### What this PR does / why we need it?
Upgrade vllm commit to 0105 (8be6432bdaf6275664d857b1e5e9bf8ed1ce299e)
1. Remove `maybe_padded_num_tokens` arg in `model_runner_v1.py` since vllm-project/vllm#31517 deleted the unused arg
2. Remove dense `Qwen/Qwen3-0.6B` in `tests/e2e/multicard/test_aclgraph_capture_replay.py` and `tests/e2e/multicard/test_data_parallel.py` due to vllm-project/vllm#30739, where offline data parallel mode will not be supported/useful for dense models
3. Adapt `vllm_ascend/worker/worker.py` due to vllm-project/vllm#31584
4. Adapt `self.block_size` usage due to vllm-project/vllm#31540
5. Modify `test_mla_v1.py` due to vllm-project/vllm#28454, which refactored `get_head_size()`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@7157596
Signed-off-by: wjunLu <wjunlu217@gmail.com>
Some hybrid models use a different block size for the linear and attention layers. In such cases, `vllm_config.cache_config.block_size` is not the same as the block size needed for EAGLE. This leads to subtle errors in the slot mapping computed by EAGLE, and causes crashes for long-context or high-concurrency workloads.
The fix is to get the block size from the specification of the KV-cache group being used.
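A worked example of why the block size matters for slot mapping. The formula below (`slot = block_table[pos // block_size] * block_size + pos % block_size`) is the standard paged-KV-cache indexing scheme; the numbers are illustrative, not taken from vLLM.

```python
def slot(block_table: list[int], pos: int, block_size: int) -> int:
    """Map a logical token position to a physical KV-cache slot."""
    return block_table[pos // block_size] * block_size + pos % block_size


# Logical block -> physical block, as allocated with block_size=16.
block_table = [5, 9]
pos = 17

# Correct: position 17 lives in logical block 1 (physical block 9), offset 1.
correct = slot(block_table, pos, block_size=16)   # 9 * 16 + 1 = 145

# Using the wrong (e.g. global linear-layer) block size of 32 instead:
# 17 // 32 = 0 selects the wrong block, and the offset no longer lines up.
wrong = slot(block_table, pos, block_size=32)     # 5 * 32 + 17 = 177
```

The mismatched result points at an entirely different physical slot, which silently corrupts the cache until the index runs past the allocated table and crashes, consistent with the long-context / high-concurrency failures described above.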