[Bugfix] limit cudagraph capture sizes by num_blocks for GDN models #34881
ZJY0516 wants to merge 7 commits into vllm-project:main from
Conversation
Code Review
This pull request addresses an assertion error by improving the validation logic in causal_conv1d_update. The new assertion correctly checks that all conv_state_indices are within the valid range of num_cache_lines. However, I've identified a critical edge case in the new code. When the batch size is zero, conv_state_indices.max() will be called on an empty tensor, causing a RuntimeError. I've provided a suggestion to fix this.
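A minimal sketch of the suggested guard, with plain Python lists standing in for the tensor logic (`check_conv_state_indices` is a hypothetical helper name, not code from the PR): skip the range check entirely when the batch is empty, since taking the max of an empty collection raises at runtime.

```python
def check_conv_state_indices(conv_state_indices, num_cache_lines):
    # Sketch of the suggested fix: max() on an empty sequence raises an
    # error (RuntimeError for an empty torch.Tensor), so guard on
    # emptiness before validating the index range.
    if len(conv_state_indices) == 0:
        return  # nothing to validate for a zero-size batch
    assert max(conv_state_indices) < num_cache_lines, (
        "conv_state_indices out of range of num_cache_lines"
    )
```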
tdoublep left a comment
I don't really understand this change. If the assert is duplicated - why doesn't the first one fail?
This check seems like a correct thing to have imo.
Because

OK, but then we read the batch size from the

Yes, you are right. I'll try to fix it in another way
Hi @ZJY0516, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
```python
has_gdn = any(
    layer.get_attn_backend().get_name() == "GDN_ATTN"
    for layer in attn_layers.values()
)
if not has_gdn:
    return
```
Why is this a specific change to GDN? I think we need to do this for all hybrid models actually.
```python
return

original_sizes = self.compilation_config.cudagraph_capture_sizes or []
filtered_sizes = [s for s in original_sizes if s <= num_blocks]
```
What happens if this is empty?
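One way the empty case could be handled, sketched as a hypothetical fallback (this is an assumption about a possible resolution, not the PR's actual code; `limit_capture_sizes` is an illustrative name):

```python
def limit_capture_sizes(original_sizes, num_blocks):
    # Filter capture sizes to those that fit in the available blocks.
    filtered = [s for s in original_sizes if s <= num_blocks]
    if not filtered:
        # Assumption: if no configured size fits, fall back to a single
        # capture size of num_blocks (or capture nothing when there are
        # no blocks at all) rather than leaving an empty list behind.
        filtered = [num_blocks] if num_blocks > 0 else []
    return filtered
```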
```python
else:
    break

def _maybe_limit_cudagraph_sizes_by_num_blocks(self, num_blocks: int) -> None:
```
Does this need to be in GPU model runner?
Purpose
FIX #34094
Fixes a corner-case bug where vLLM crashes with an AssertionError during CUDA graph capture for GDN (Gated Delta Net) models when the configured cudagraph capture size exceeds the number of available KV cache blocks (num_blocks).
Problem

For GDN models using CUDA graphs, vLLM determines the cudagraph capture sizes at config initialization time based on `max_num_seqs`. However, the actual number of KV cache blocks (`num_blocks`) is determined later, during memory profiling, and can be smaller than the cudagraph capture size due to memory constraints. When this happens, an assertion in `causal_conv1d_update` fails because `num_cache_lines` (which equals `num_blocks`) is smaller than the batch size. The GDN attention backend uses `num_blocks` as the number of cache lines for conv states; when CUDA graph capture creates batches larger than `num_blocks`, the conv state cache doesn't have enough slots, triggering the assertion.
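To make the failure mode concrete, a toy numeric sketch (the values are illustrative only, not taken from the bug report):

```python
# Illustrative numbers for the GDN cudagraph-capture failure mode.
num_blocks = 128          # KV cache blocks found during memory profiling;
                          # also the number of conv-state cache lines
capture_batch_size = 256  # a cudagraph capture size derived from max_num_seqs

# Each sequence in the captured batch needs its own conv-state cache line,
# so this configuration is invalid: the batch outgrows the cache, and
# causal_conv1d_update's range check on conv_state_indices must fail.
assert capture_batch_size > num_blocks  # the bad configuration the PR avoids
```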
Solution

Add `_maybe_limit_cudagraph_sizes_by_num_blocks()` in `GPUModelRunner.initialize_kv_cache()` to:

- Check if the model uses GDN attention (by detecting the `GDN_ATTN` backend)
- Check if `CUDAGraphMode` has FULL mode enabled (only affects `FULL`/`FULL_AND_PIECEWISE` modes)
- If `max_cudagraph_capture_size > num_blocks`, filter `cudagraph_capture_sizes` to only include sizes ≤ `num_blocks` and update `max_cudagraph_capture_size` accordingly

This ensures CUDA graph capture never attempts batch sizes larger than the available cache lines for GDN models.
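The steps above can be sketched as follows. This is a minimal standalone sketch, not the PR's actual implementation: the config class is stubbed with a dataclass, and the GDN/`CUDAGraphMode` checks are assumed to have already passed.

```python
from dataclasses import dataclass, field

@dataclass
class CompilationConfig:
    # Stub mirroring the two fields the PR description mentions.
    cudagraph_capture_sizes: list = field(default_factory=list)
    max_cudagraph_capture_size: int = 0

def maybe_limit_cudagraph_sizes_by_num_blocks(cfg, num_blocks):
    # Assumed precondition: the model uses GDN attention and a FULL
    # cudagraph mode; otherwise this step is skipped entirely.
    if cfg.max_cudagraph_capture_size <= num_blocks:
        return  # every configured capture size already fits
    # Keep only capture sizes that fit within the available blocks,
    # then shrink the maximum to match.
    cfg.cudagraph_capture_sizes = [
        s for s in cfg.cudagraph_capture_sizes if s <= num_blocks
    ]
    if cfg.cudagraph_capture_sizes:
        cfg.max_cudagraph_capture_size = max(cfg.cudagraph_capture_sizes)
```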
Test
cc @tdoublep @ywang96