Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,3 +1321,189 @@ def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():

with pytest.raises(ValueError, match="max_num_seqs"):
runner.initialize_kv_cache(kv_cache_config)


def test_v2_reshape_kv_cache_hybrid_attention_mamba():
    """Exercise the V2 runner's KV-cache reshaping on a hybrid config.

    The config mixes one full-attention layer with one Mamba layer (the
    situation that arises for hybrid attention models such as Qwen3.5,
    whose linear-attention / Gated DeltaNet layers are described by a
    ``MambaSpec``). The test covers three things:

    * ``_reshape_kv_cache`` splits each KV-manager block of the attention
      layer into smaller kernel-sized blocks,
    * the Mamba layer comes back as a list of per-state tensors that do
      not alias the attention buffer, and
    * ``_update_hybrid_attention_mamba_layout`` rewrites the strides (but
      not the shape or storage) of a cache whose leading dim is 2.
    """
    from unittest.mock import MagicMock

    from vllm.v1.kv_cache_interface import KVCacheTensor, MambaSpec
    from vllm.v1.worker.gpu.attn_utils import (
        _reshape_kv_cache,
        _update_hybrid_attention_mamba_layout,
    )

    # ---- Sizes -----------------------------------------------------------
    dtype = torch.float16
    head_size = 8
    num_kv_heads = 4
    num_blocks = 4
    # The KV manager hands out blocks of 32 tokens while the kernel works on
    # blocks of 16, so every manager block maps onto two kernel blocks.
    kv_manager_block_size = 32
    kernel_block_size = 16
    blocks_per_manager_block = kv_manager_block_size // kernel_block_size

    # ---- Layer specs -----------------------------------------------------
    attn_spec = FullAttentionSpec(
        block_size=kv_manager_block_size,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=dtype,
    )
    # Original author's note: 2 * 32 * 4 * 8 * 2 = 4096 bytes per page.
    attn_page_size = attn_spec.page_size_bytes

    # The Mamba layer carries two state tensors per block.
    conv_shape = (4, 16)  # (d_conv, d_inner)
    ssm_shape = (8, 16)  # (d_state, d_inner)
    state_shapes = (conv_shape, ssm_shape)
    mamba_spec = MambaSpec(
        block_size=kv_manager_block_size,
        shapes=state_shapes,
        dtypes=(dtype, dtype),
        # Padded up to the attention page size, as done for hybrid models.
        page_size_padded=attn_page_size,
    )

    # ---- KV cache config and backing storage -----------------------------
    attn_layer = "model.layers.0.self_attn.attn"
    mamba_layer = "model.layers.1.mixer"

    # Bytes of raw storage per layer; drives both the config and the buffers.
    bytes_per_layer = {
        attn_layer: attn_page_size * num_blocks,
        mamba_layer: mamba_spec.page_size_bytes * num_blocks,
    }

    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[
            KVCacheTensor(size=nbytes, shared_by=[layer])
            for layer, nbytes in bytes_per_layer.items()
        ],
        kv_cache_groups=[
            KVCacheGroupSpec(layer_names=[attn_layer], kv_cache_spec=attn_spec),
            KVCacheGroupSpec(layer_names=[mamba_layer], kv_cache_spec=mamba_spec),
        ],
    )

    device = torch.device("cpu")
    kv_cache_raw_tensors = {
        layer: torch.zeros(nbytes, dtype=torch.int8, device=device)
        for layer, nbytes in bytes_per_layer.items()
    }

    # ---- Stand-in attention backend --------------------------------------
    # Reports a FlashInfer-like layout:
    # (num_blocks, 2, block_size, num_kv_heads, head_size).
    def fake_kv_cache_shape(n_blocks, block_sz, n_kv_heads, h_size, cache_dtype):
        return (n_blocks, 2, block_sz, n_kv_heads, h_size)

    backend = MagicMock()
    # Asking the mock for a stride order raises AttributeError.
    backend.get_kv_cache_stride_order.side_effect = AttributeError
    backend.get_kv_cache_shape = fake_kv_cache_shape

    # One kernel block size per KV cache group, in group order
    # (attention first, then mamba).
    kernel_block_sizes = [kernel_block_size, kv_manager_block_size]

    # ---- Act -------------------------------------------------------------
    kv_caches = _reshape_kv_cache(
        kv_cache_config,
        kv_cache_raw_tensors,
        {attn_layer: backend},
        "auto",
        kernel_block_sizes,
    )

    # ---- Attention layer -------------------------------------------------
    attn_cache = kv_caches[attn_layer]
    assert isinstance(attn_cache, torch.Tensor)

    # 4 manager blocks * (32 / 16) = 8 kernel blocks, followed by K/V,
    # kernel block size, heads and head size. The leading dim is 8 rather
    # than 2, so (per the original author's note) the hybrid layout
    # adjustment is not expected to touch this cache.
    expected_attn_dims = (
        num_blocks * blocks_per_manager_block,
        2,
        kernel_block_size,
        num_kv_heads,
        head_size,
    )
    assert attn_cache.shape[:5] == expected_attn_dims

    # ---- Mamba layer -----------------------------------------------------
    mamba_cache = kv_caches[mamba_layer]
    assert isinstance(mamba_cache, list)
    assert len(mamba_cache) == 2  # conv state + ssm state

    for state, state_shape in zip(mamba_cache, state_shapes):
        assert state.shape == (num_blocks, *state_shape)
        assert state.dtype == dtype
    conv_tensor, ssm_tensor = mamba_cache

    # ---- No aliasing between the two layers ------------------------------
    # Filling the mamba states must leave the attention cache zeroed ...
    conv_tensor.fill_(1.0)
    ssm_tensor.fill_(2.0)
    assert (attn_cache == 0).all()

    # ... and filling the attention cache must leave the mamba states alone.
    attn_cache.fill_(3.0)
    assert (conv_tensor == 1.0).all()
    assert (ssm_tensor == 2.0).all()

    # ---- _update_hybrid_attention_mamba_layout, called directly ----------
    # Feed it a cache laid out as
    # (2, num_blocks, block_size, num_kv_heads, head_size), i.e. leading
    # dim == 2, which is the case the adjustment acts on.
    small_num_blocks = 4
    hidden_size = kernel_block_size * num_kv_heads * head_size
    kv_first_dims = (
        2,
        small_num_blocks,
        kernel_block_size,
        num_kv_heads,
        head_size,
    )
    raw = torch.arange(2 * small_num_blocks * hidden_size, dtype=dtype).reshape(
        kv_first_dims
    )
    test_kv_caches: dict[str, torch.Tensor] = {
        attn_layer: raw,
        mamba_layer: [conv_tensor, ssm_tensor],  # type: ignore[assignment]
    }
    _update_hybrid_attention_mamba_layout(kv_cache_config, test_kv_caches)

    updated = test_kv_caches[attn_layer]
    assert isinstance(updated, torch.Tensor)
    # The logical shape is untouched ...
    assert updated.shape == kv_first_dims
    # ... but the two leading strides are swapped around: stepping along
    # dim 0 moves by one hidden_size, stepping along dim 1 by two.
    assert updated.stride()[:2] == (hidden_size, 2 * hidden_size)
    # Still backed by the very same storage.
    assert updated.data_ptr() == raw.data_ptr()
Loading
Loading