Commits
40 commits
1787b87
map more fullattn layers to one block
peakcrosser7 Mar 1, 2026
954db4a
add test_get_kv_cache_configs_with_mamba
peakcrosser7 Mar 2, 2026
75fe78e
fix padding layers
peakcrosser7 Mar 2, 2026
5817d4d
fix padding layers
peakcrosser7 Mar 2, 2026
2d6eefc
add test
peakcrosser7 Mar 2, 2026
a365a26
change group_size to property
peakcrosser7 Mar 3, 2026
c39fd7f
fix
peakcrosser7 Mar 3, 2026
5e54f79
restrict mamba_num_attn_pages to hybrid model only
peakcrosser7 Mar 3, 2026
4eae237
remove num_attn_pages in Attention
peakcrosser7 Mar 4, 2026
a96611e
move group_size to KVCacheSpec
peakcrosser7 Mar 4, 2026
6c6d68f
Merge branch 'main' into feat/multi_attn2mamba
peakcrosser7 Mar 5, 2026
f553c16
Revert "move group_size to KVCacheSpec"
peakcrosser7 Mar 7, 2026
09c429e
move group_size to AttentionSpec
peakcrosser7 Mar 7, 2026
e1dc529
refactor with merge and split layers
peakcrosser7 Mar 8, 2026
49c58f7
remove group_size in attention
peakcrosser7 Mar 9, 2026
632c5e9
update test
peakcrosser7 Mar 9, 2026
fa3dcc4
add tests
peakcrosser7 Mar 9, 2026
5bf8e67
refactor attn layout reshape
peakcrosser7 Mar 10, 2026
b978949
Merge branch 'main' into feat/multi_attn2mamba
peakcrosser7 Mar 15, 2026
372fc18
update after merged
peakcrosser7 Mar 15, 2026
4ebdc5a
remove _update_hybrid_attention_mamba_layout
peakcrosser7 Mar 21, 2026
0c97f85
Merge branch 'main' into feat/multi_attn2mamba
peakcrosser7 Mar 21, 2026
f307aa4
move get_hybrid_attention_mamba_layout to mamba_utils
peakcrosser7 Mar 21, 2026
6114cd8
revert blank line
peakcrosser7 Mar 21, 2026
938c092
rename group_size to pack_size
peakcrosser7 Mar 21, 2026
263363a
fix the test
peakcrosser7 Mar 21, 2026
fea9f3e
add comments
peakcrosser7 Mar 21, 2026
920ecc1
fix
peakcrosser7 Mar 21, 2026
54eccae
fix the name of tests
peakcrosser7 Mar 21, 2026
cda00e2
update comments
peakcrosser7 Mar 21, 2026
dc1b08b
add test_hybrid_attention_mamba_kv_cache_pack_size
peakcrosser7 Mar 24, 2026
e8f4c53
Merge branch 'main' into feat/multi_attn2mamba
peakcrosser7 Mar 25, 2026
6b1104d
fix block_size without packed in profiling
peakcrosser7 Mar 29, 2026
e8db555
rename mamba_num_attn_pages to attn_pack_size
peakcrosser7 Apr 4, 2026
3805bc3
Merge branch 'main' into feat/multi_attn2mamba
peakcrosser7 Apr 7, 2026
7a0e113
update after merge
peakcrosser7 Apr 7, 2026
5303953
rm _update_hybrid_attention_mamba_layout
peakcrosser7 Apr 9, 2026
144b3a3
add comment
peakcrosser7 Apr 9, 2026
024c600
add log
peakcrosser7 Apr 9, 2026
9b6d3c3
Merge branch 'main' into feat/multi_attn2mamba
peakcrosser7 Apr 12, 2026
195 changes: 195 additions & 0 deletions tests/v1/core/test_kv_cache_utils.py
@@ -34,6 +34,8 @@
init_none_hash,
is_kv_cache_spec_uniform,
make_block_hash_with_group_id,
merge_attn_layers_into_pack,
split_attn_layers_from_pack,
tensor_data,
)
from vllm.v1.kv_cache_interface import (
@@ -110,6 +112,7 @@ def new_kv_cache_spec(
page_size_padded=None,
sliding_window=None,
attention_chunk_size=None,
pack_size=1,
):
return FullAttentionSpec(
block_size=block_size,
@@ -119,6 +122,7 @@
page_size_padded=page_size_padded,
sliding_window=sliding_window,
attention_chunk_size=attention_chunk_size,
pack_size=pack_size,
)


@@ -2137,3 +2141,194 @@ def test_unify_hybrid_kv_cache_specs():

with pytest.raises(ValueError):
kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec)


def test_merge_attn_layers_into_pack():
attn_pack_size = 2

hybrid_kv_cache_specs = {
"layer_1": new_mamba_spec(),
"layer_2": new_kv_cache_spec(head_size=32),
"layer_3": new_kv_cache_spec(head_size=32),
}
merged_kv_cache_specs = merge_attn_layers_into_pack(
attn_pack_size, hybrid_kv_cache_specs
)
assert merged_kv_cache_specs == {
"layer_1": new_mamba_spec(),
"layer_2+layer_3": new_kv_cache_spec(head_size=32, pack_size=2),
}

hybrid_kv_cache_specs = {
"layer_1": new_mamba_spec(),
"layer_2": new_kv_cache_spec(head_size=32),
"layer_3": new_kv_cache_spec(head_size=32),
"layer_4": new_kv_cache_spec(head_size=32),
}
merged_kv_cache_specs = merge_attn_layers_into_pack(
attn_pack_size, hybrid_kv_cache_specs
)
assert merged_kv_cache_specs == {
"layer_1": new_mamba_spec(),
"layer_2+layer_3": new_kv_cache_spec(head_size=32, pack_size=attn_pack_size),
"layer_4": new_kv_cache_spec(head_size=32, pack_size=attn_pack_size),
}


def test_split_attn_layers_from_pack():
attn_pack_size = 2
expected_page_size = new_mamba_spec().page_size_bytes

kv_cache_config = KVCacheConfig(
num_blocks=20,
kv_cache_tensors=[
KVCacheTensor(
size=expected_page_size * 20,
shared_by=["layer_1", "layer_2+layer_3"],
),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1"], new_mamba_spec()),
KVCacheGroupSpec(
["layer_2+layer_3"],
new_kv_cache_spec(head_size=32, pack_size=2),
),
],
)
split_kv_cache_config = split_attn_layers_from_pack(attn_pack_size, kv_cache_config)

assert split_kv_cache_config == KVCacheConfig(
num_blocks=20,
kv_cache_tensors=[
KVCacheTensor(
size=expected_page_size * 20,
shared_by=["layer_1", "layer_2", "layer_3"],
),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1"], new_mamba_spec()),
KVCacheGroupSpec(
["layer_2", "layer_3"],
new_kv_cache_spec(head_size=32, pack_size=attn_pack_size),
),
],
)


def test_get_kv_cache_configs_with_mamba():
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config)

expected_page_size = new_mamba_spec().page_size_bytes

# Test 1: Pure mamba model (2 layers with same spec)
kv_cache_specs = {
"layer_1": new_mamba_spec(),
"layer_2": new_mamba_spec(),
}
available_memory = expected_page_size * 2 * 10
kv_cache_config = get_kv_cache_configs(
vllm_config, [kv_cache_specs], [available_memory]
)[0]

assert kv_cache_config == KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=expected_page_size * 10, shared_by=["layer_1"]),
KVCacheTensor(size=expected_page_size * 10, shared_by=["layer_2"]),
],
kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_mamba_spec())],
)

# Test 2: 1 mamba + 1 full
hybrid_kv_cache_specs = {
"layer_1": new_mamba_spec(),
"layer_2": new_kv_cache_spec(),
}
available_memory_hybrid = expected_page_size * 2 * 10
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [hybrid_kv_cache_specs], [available_memory_hybrid]
)[0]

assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=20,
kv_cache_tensors=[
KVCacheTensor(
size=expected_page_size * 20, shared_by=["layer_1", "layer_2"]
),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1"], new_mamba_spec()),
KVCacheGroupSpec(["layer_2"], new_kv_cache_spec()),
],
)

    # Test 3: 1 mamba + 2 full attention with pack_size 2
vllm_config.cache_config.attn_pack_size = 2
hybrid_kv_cache_specs = {
"layer_1": new_mamba_spec(),
"layer_2": new_kv_cache_spec(head_size=32),
"layer_3": new_kv_cache_spec(head_size=32),
}
available_memory_hybrid = expected_page_size * 2 * 10
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [hybrid_kv_cache_specs], [available_memory_hybrid]
)[0]

assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=20,
kv_cache_tensors=[
KVCacheTensor(
size=expected_page_size * 20,
shared_by=["layer_1", "layer_2", "layer_3"],
),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1"], new_mamba_spec()),
KVCacheGroupSpec(
["layer_2", "layer_3"],
new_kv_cache_spec(head_size=32, pack_size=2),
),
],
)

# Test 4: 2 mamba + 5 full (with 3 padding full)
vllm_config.cache_config.attn_pack_size = 2
hybrid_kv_cache_specs = {
"layer_1": new_mamba_spec(),
"layer_2": new_mamba_spec(),
"layer_3": new_kv_cache_spec(head_size=32),
"layer_4": new_kv_cache_spec(head_size=32),
"layer_5": new_kv_cache_spec(head_size=32),
"layer_6": new_kv_cache_spec(head_size=32),
"layer_7": new_kv_cache_spec(head_size=32),
}
available_memory_hybrid = expected_page_size * 2 * 10
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [hybrid_kv_cache_specs], [available_memory_hybrid]
)[0]

assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(
size=expected_page_size * 10,
shared_by=["layer_1", "layer_3", "layer_4", "layer_5", "layer_6"],
),
KVCacheTensor(
size=expected_page_size * 10,
shared_by=["layer_2", "layer_7"],
),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_mamba_spec()),
KVCacheGroupSpec(
["layer_3", "layer_4", "layer_7"],
new_kv_cache_spec(head_size=32, pack_size=2),
),
KVCacheGroupSpec(
["layer_5", "layer_6"],
new_kv_cache_spec(head_size=32, pack_size=2),
),
],
)
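
A minimal standalone sketch of the layer-packing idea the merge test above exercises, for readers skimming the diff: attention layers are grouped into packs of attn_pack_size, packed specs are keyed by the joined layer names, and mamba layers pass through unchanged. The AttnSpec, MambaSpec, and merge_attn_layers names below are illustrative stand-ins for this sketch, not vLLM's actual classes or the PR's API.

# Editorial sketch (not part of the diff); names here are hypothetical stand-ins.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class AttnSpec:
    """Stand-in for a full-attention KV cache spec."""
    head_size: int
    pack_size: int = 1


@dataclass(frozen=True)
class MambaSpec:
    """Stand-in for a mamba state spec (passed through untouched)."""
    state_size: int = 128


def merge_attn_layers(pack_size: int, specs: dict) -> dict:
    """Group attention layers into packs of `pack_size`, joining names with '+'.

    Non-attention (mamba) layers are kept as-is; a trailing, incomplete pack
    keeps its single layer name but still records the pack_size.
    """
    # Mamba layers pass through unchanged.
    merged = {name: spec for name, spec in specs.items()
              if not isinstance(spec, AttnSpec)}
    attn = [(name, spec) for name, spec in specs.items()
            if isinstance(spec, AttnSpec)]
    for i in range(0, len(attn), pack_size):
        chunk = attn[i:i + pack_size]
        merged_name = "+".join(name for name, _ in chunk)
        merged[merged_name] = replace(chunk[0][1], pack_size=pack_size)
    return merged


specs = {
    "layer_1": MambaSpec(),
    "layer_2": AttnSpec(head_size=32),
    "layer_3": AttnSpec(head_size=32),
}
print(merge_attn_layers(2, specs))
# {'layer_1': MambaSpec(state_size=128),
#  'layer_2+layer_3': AttnSpec(head_size=32, pack_size=2)}

This mirrors the first assertion in test_merge_attn_layers_into_pack: with attn_pack_size=2, "layer_2" and "layer_3" collapse into a single "layer_2+layer_3" entry whose spec carries pack_size=2, while the mamba layer is untouched.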