Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions tests/compile/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,45 @@ def test_compile_sizes_padding_validation():
assert sorted(config.compile_sizes) == [3, 5, 7]
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise


@pytest.mark.parametrize(
"capture_sizes, max_size, num_blocks, expected_sizes, expected_max",
[
# Normal capping: sizes filtered to <= num_blocks
(
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
512,
200,
[1, 2, 4, 8, 16, 32, 64, 128],
128,
),
# No capping needed: num_blocks >= max
([1, 2, 4, 8, 16], 16, 1000, [1, 2, 4, 8, 16], 16),
# Exact boundary: num_blocks == max (no capping)
([1, 2, 4, 8, 16, 32], 32, 32, [1, 2, 4, 8, 16, 32], 32),
# All sizes capped: num_blocks < smallest size
([8, 16, 32], 32, 4, [], 0),
# num_blocks <= 0: early return, no change
([1, 2, 4], 4, 0, [1, 2, 4], 4),
],
)
def test_adjust_cudagraph_sizes_for_mamba_cache(
capture_sizes, max_size, num_blocks, expected_sizes, expected_max
):
"""Test that cudagraph capture sizes are correctly capped to fit
available Mamba cache blocks.

See: https://github.com/vllm-project/vllm/issues/34094
"""
config = CompilationConfig(
cudagraph_capture_sizes=capture_sizes,
max_cudagraph_capture_size=max_size,
cudagraph_mode=CUDAGraphMode.NONE,
)
config.adjust_cudagraph_sizes_for_mamba_cache(num_blocks)
assert config.cudagraph_capture_sizes == expected_sizes
assert config.max_cudagraph_capture_size == expected_max
# Invariant: last element == max_cudagraph_capture_size
if expected_sizes:
assert config.cudagraph_capture_sizes[-1] == config.max_cudagraph_capture_size
120 changes: 120 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,3 +1199,123 @@ def test_is_uniform_decode() -> None:
num_reqs=15,
force_uniform_decode=False,
)


@pytest.mark.skipif(
current_platform.is_rocm(),
reason="Attention backend FLASHINFER is not supported on ROCm.",
)
def test_cudagraph_sizes_capped_for_mamba_cache():
"""Test that cudagraph capture sizes are capped to num_blocks for
hybrid models with Mamba layers.

See: https://github.com/vllm-project/vllm/issues/34094
"""
set_random_seed(42)

update_environment_variables(
{
"RANK": "0",
"LOCAL_RANK": "0",
"WORLD_SIZE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
from tests.utils import ensure_current_vllm_config

with ensure_current_vllm_config():
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=1)
torch.set_default_dtype(torch.float16)

model_config = ModelConfig(
model="ibm-granite/granite-4.0-tiny-preview",
dtype="float16",
)
scheduler_config = SchedulerConfig(
max_num_seqs=10,
max_num_batched_tokens=512,
max_model_len=512,
is_encoder_decoder=model_config.is_encoder_decoder,
)
cache_config = CacheConfig(
block_size=BLOCK_SIZE,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
)
parallel_config = ParallelConfig()
attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
attention_config=attention_config,
)

with set_current_vllm_config(vllm_config):
hf_config = vllm_config.model_config.hf_config
fwd_context = {}
for key in ["model.layers.0.self_attn.attn", "model.layers.1.self_attn.attn"]:
fwd_context[key] = Attention(
num_heads=model_config.get_num_attention_heads(parallel_config),
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
scale=1.0,
prefix=key,
)
for key in [
"model.layers.2.mixer",
"model.layers.3.mixer",
"model.layers.4.mixer",
"model.layers.5.mixer",
]:
fwd_context[key] = MambaMixer2(
hidden_size=hf_config.hidden_size,
ssm_state_size=hf_config.mamba_d_state,
conv_kernel_size=hf_config.mamba_d_conv,
intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
use_conv_bias=hf_config.mamba_conv_bias,
use_bias=hf_config.mamba_proj_bias,
n_groups=hf_config.mamba_n_groups,
num_heads=hf_config.mamba_n_heads,
head_dim=hf_config.mamba_d_head,
rms_norm_eps=hf_config.rms_norm_eps,
activation=hf_config.hidden_act,
cache_config=cache_config,
model_config=model_config,
prefix=key,
)
assert fwd_context is not None

runner = GPUModelRunner(vllm_config, DEVICE)
kv_cache_spec = runner.get_kv_cache_spec()

available_memory = 5 * GiB_bytes
kv_cache_config = get_kv_cache_configs(
vllm_config, [kv_cache_spec], [available_memory]
)[0]
num_blocks = kv_cache_config.num_blocks

# Set max_cudagraph_capture_size to a value larger than num_blocks
# to trigger the Mamba capping logic.
large_max = num_blocks + 100
compilation_config = vllm_config.compilation_config
compilation_config.max_cudagraph_capture_size = large_max
compilation_config.cudagraph_capture_sizes = [
s for s in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] if s <= large_max
]

runner.initialize_kv_cache(kv_cache_config)

# After initialization, cudagraph sizes should be capped
assert compilation_config.max_cudagraph_capture_size <= num_blocks
assert all(s <= num_blocks for s in compilation_config.cudagraph_capture_sizes)
# Invariant: last element == max
if compilation_config.cudagraph_capture_sizes:
assert (
compilation_config.cudagraph_capture_sizes[-1]
== compilation_config.max_cudagraph_capture_size
)
52 changes: 52 additions & 0 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,58 @@ def adjust_cudagraph_sizes_for_spec_decode(
self.max_cudagraph_capture_size = rounded_sizes[-1]
self.cudagraph_capture_sizes = rounded_sizes

def adjust_cudagraph_sizes_for_mamba_cache(
self, num_mamba_cache_blocks: int
) -> None:
"""Cap cudagraph capture sizes to available Mamba cache blocks.

For hybrid Mamba/attention models, the Mamba conv_state and
ssm_state tensors have their first dimension equal to num_blocks
(from KVCacheConfig). During CUDA graph capture the decode batch
size equals num_tokens, so capture sizes exceeding num_blocks
would cause out-of-bounds access in Mamba kernels.

See: https://github.com/vllm-project/vllm/issues/34094
"""
if not self.cudagraph_capture_sizes or num_mamba_cache_blocks <= 0:
return

assert self.max_cudagraph_capture_size is not None

if num_mamba_cache_blocks >= self.max_cudagraph_capture_size:
return

capped_sizes = [
s for s in self.cudagraph_capture_sizes if s <= num_mamba_cache_blocks
]

if len(capped_sizes) == 0:
logger.warning(
"No valid cudagraph capture sizes remain after capping "
"to Mamba cache blocks (%d). The smallest capture size "
"was %d. Disabling cudagraph capture. Consider reducing "
"max_num_seqs or increasing available GPU memory.",
num_mamba_cache_blocks,
self.cudagraph_capture_sizes[0],
)
self.cudagraph_capture_sizes = []
self.max_cudagraph_capture_size = 0
return

logger.warning(
"Capping cudagraph capture sizes from max %d to %d to fit "
"Mamba cache blocks (%d blocks available). This limits the "
"maximum batch size that can use CUDA graphs. To increase "
"this limit, reduce max_num_seqs or increase available GPU "
"memory.",
self.max_cudagraph_capture_size,
capped_sizes[-1],
num_mamba_cache_blocks,
)

self.max_cudagraph_capture_size = capped_sizes[-1]
self.cudagraph_capture_sizes = capped_sizes

def get_compile_ranges(self) -> list[Range]:
"""Get the compile ranges for the compilation config."""
if self.compile_ranges_split_points is None:
Expand Down
16 changes: 16 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5700,6 +5700,22 @@ def _check_and_update_cudagraph_mode(
self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
)

# If the model has Mamba layers and cudagraph mode includes FULL
# decode, cap cudagraph capture sizes to the number of available
# Mamba cache blocks. Each decode request needs one conv_state
# cache line, so capture batch sizes cannot exceed num_blocks.
# Only FULL decode graphs are affected because PIECEWISE captures
# run GDN/Mamba ops eagerly (prefill path, no causal_conv1d_update).
# See: https://github.com/vllm-project/vllm/issues/34094
if cudagraph_mode.has_full_cudagraphs():
has_mamba = any(
isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups
)
if has_mamba and self.kv_cache_config is not None:
self.compilation_config.adjust_cudagraph_sizes_for_mamba_cache(
self.kv_cache_config.num_blocks
)

# Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode.
self.compilation_config.cudagraph_mode = cudagraph_mode
Expand Down