Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5988,6 +5988,58 @@ def maybe_add_kv_sharing_layers_to_kv_cache_groups(
else:
break

def _maybe_limit_cudagraph_sizes_by_num_blocks(self, num_blocks: int) -> None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be in GPU model runner?

"""
Limit cudagraph capture sizes based on num_blocks to prevent
assertion errors in GDN models where num_cache_lines (num_blocks)
can be smaller than the cudagraph capture batch size.

This is only applied when:
1. CUDAGraphMode > PIECEWISE (i.e., FULL or FULL_AND_PIECEWISE modes)
2. The model uses GDN attention (detected by GDN_ATTN backend)

Args:
num_blocks: The number of available KV cache blocks.
"""
cudagraph_mode = self.compilation_config.cudagraph_mode
if cudagraph_mode is None or not cudagraph_mode.has_full_cudagraphs():
return

max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
if max_cudagraph_size is None or max_cudagraph_size <= num_blocks:
return

# Check if model uses GDN attention
from vllm.model_executor.layers.attention_layer_base import (
AttentionLayerBase,
)

attn_layers = get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
has_gdn = any(
layer.get_attn_backend().get_name() == "GDN_ATTN"
for layer in attn_layers.values()
)
if not has_gdn:
return
Comment on lines +6021 to +6026
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this a specific change to GDN? I think we need to do this for all hybrid models actually.


original_sizes = self.compilation_config.cudagraph_capture_sizes or []
filtered_sizes = [s for s in original_sizes if s <= num_blocks]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if this is empty?

self.compilation_config.cudagraph_capture_sizes = filtered_sizes
# Set max_cudagraph_capture_size to the max of filtered sizes
new_max = max(filtered_sizes) if filtered_sizes else num_blocks

logger.warning(
"Limiting max_cudagraph_capture_size from %d to %d "
"due to num_blocks=%d constraint for GDN model",
max_cudagraph_size,
new_max,
num_blocks,
)
self.compilation_config.max_cudagraph_capture_size = new_max

def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Expand All @@ -5999,6 +6051,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
self.kv_cache_config = kv_cache_config
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)

# Limit cudagraph sizes for GDN models to prevent assertion errors
# when num_cache_lines (num_blocks) < batch size during cudagraph capture.
self._maybe_limit_cudagraph_sizes_by_num_blocks(kv_cache_config.num_blocks)

self.initialize_attn_backend(kv_cache_config)
# The kernel block size for all KV cache groups. For example, if
# kv_cache_manager uses block_size 256 for a given group, but the attention
Expand Down