239 changes: 235 additions & 4 deletions tests/v1/core/test_prefix_caching.py
@@ -35,6 +35,7 @@
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
MambaSpec,
SlidingWindowSpec,
)

@@ -106,8 +107,23 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:


def make_kv_cache_config_hybrid_model(
block_size: int, num_blocks: int
block_size: int, num_blocks: int, second_spec_type: str = "sliding_window"
) -> KVCacheConfig:
if second_spec_type == "sliding_window":
second_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=2 * block_size,
)
elif second_spec_type == "mamba":
second_spec = MambaSpec(
block_size=block_size,
shapes=(1, 1),
dtypes=(torch.float32,),
)

return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
@@ -123,16 +139,49 @@ def make_kv_cache_config_hybrid_model(
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
second_spec,
),
KVCacheGroupSpec(
["layer3"],
second_spec,
),
],
)


def make_kv_cache_config_three_types(
block_size: int, num_blocks: int, third_spec_type: str = "mamba"
) -> KVCacheConfig:
if third_spec_type == "mamba":
third_spec = MambaSpec(
block_size=block_size,
shapes=(1, 1),
dtypes=(torch.float32,),
)
elif third_spec_type == "sliding_window":
third_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=4 * block_size,
)

return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=2 * block_size,
),
),
KVCacheGroupSpec(
["layer3"],
["layer2"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
@@ -141,6 +190,10 @@ def make_kv_cache_config_hybrid_model(
sliding_window=2 * block_size,
),
),
KVCacheGroupSpec(
["layer3"],
third_spec,
),
],
)

@@ -424,6 +477,184 @@ def test_partial_request_hit(
)


def _make_hybrid_kv_cache_config(
block_size: int, num_blocks: int, spec_types: list[str]
) -> KVCacheConfig:
"""
Create a KVCacheConfig with the specified spec types.

Args:
block_size: The block size for KV cache.
num_blocks: The number of blocks in the KV cache.
spec_types: List of spec type strings. Supported types:
- "full": FullAttentionSpec
- "sliding_window": SlidingWindowSpec with window=2*block_size
- "sliding_window_large": SlidingWindowSpec with window=4*block_size
- "mamba": MambaSpec
"""
spec_map = {
"full": lambda: FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
"sliding_window": lambda: SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=2 * block_size,
),
"sliding_window_large": lambda: SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=4 * block_size,
),
"mamba": lambda: MambaSpec(
block_size=block_size,
shapes=(1, 1),
dtypes=(torch.float32,),
),
}

kv_cache_groups = [
KVCacheGroupSpec([f"layer{i}"], spec_map[spec_type]())
for i, spec_type in enumerate(spec_types)
]

return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=kv_cache_groups,
)


# Test cases covering various combinations of KV cache spec types:
# - Varying number of groups (2, 3, or 4)
# - 0, 1, or 2 full attention groups
# - Sliding window with different window sizes
# - Interleaved group IDs (full attn and other types mixed)
# - Mamba spec combinations
_HYBRID_MODEL_TEST_CASES = [
# 2 groups: 1 full + 1 other
pytest.param(["full", "sliding_window"], id="2g-full+sw"),
pytest.param(["full", "mamba"], id="2g-full+mamba"),
# 2 groups: 0 full (all other types)
pytest.param(["sliding_window", "mamba"], id="2g-sw+mamba"),
pytest.param(["sliding_window", "sliding_window_large"], id="2g-sw+sw_large"),
# 3 groups: 1 full + 2 others (same type)
pytest.param(["full", "sliding_window", "sliding_window"], id="3g-full+2sw"),
pytest.param(["full", "mamba", "mamba"], id="3g-full+2mamba"),
# 3 groups: 1 full + 2 others (different types)
pytest.param(["full", "sliding_window", "mamba"], id="3g-full+sw+mamba"),
pytest.param(
["full", "sliding_window", "sliding_window_large"],
id="3g-full+sw+sw_large",
),
# 3 groups: 2 full + 1 other
pytest.param(["full", "full", "sliding_window"], id="3g-2full+sw"),
pytest.param(["full", "full", "mamba"], id="3g-2full+mamba"),
# 4 groups: interleaved (full, other, full, other)
pytest.param(
["full", "sliding_window", "full", "sliding_window_large"],
id="4g-interleaved-full+sw+sw_large",
),
pytest.param(
["full", "mamba", "full", "mamba"],
id="4g-interleaved-full+mamba",
),
# 4 groups: interleaved with different sliding windows
pytest.param(
["full", "sliding_window", "full", "sliding_window_large"],
id="4g-interleaved-full+sw_mixed",
),
# 4 groups: 0 full (all other types)
pytest.param(
["sliding_window", "mamba", "sliding_window_large", "mamba"],
id="4g-sw+mamba+sw_large+mamba",
),
# 4 groups: 2 full + 2 others (grouped)
pytest.param(
["full", "full", "sliding_window", "mamba"],
id="4g-2full+sw+mamba",
),
]


@pytest.mark.parametrize("spec_types", _HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations(spec_types: list[str]):
"""
Test prefix caching with hybrid models containing various combinations of
KV cache spec types.

This unified test covers:
- Various combinations (full attn + other attn types)
- Varying number of groups (2, 3, or 4)
- 0, 1, or 2 full attention groups in the combination
- Two sliding_window attn groups with different window sizes
- Interleaved group IDs (full attn and other types alternating)
- Mamba spec with other attention types
"""
block_size = 16
num_groups = len(spec_types)
# Allocate enough blocks for all groups
num_blocks = 10 * num_groups

kv_cache_config = _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types)
manager = KVCacheManager(
kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)

hash_fn = sha256

# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(block_size)]
unique_token_ids = [3] * 7
all_token_ids = common_token_ids + unique_token_ids
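    # 48 common + 7 unique = 55 tokens in total (the value passed to
    # allocate_slots below).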

# First request: no cache hit initially
req0 = make_request("0", all_token_ids, block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

assert len(req0.block_hashes) == 3
assert not computed_blocks.blocks[0] # No cache hit initially
assert num_computed_tokens == 0

blocks = manager.allocate_slots(
req0, 55, len(computed_blocks.blocks[0]) * block_size, computed_blocks
)
assert blocks is not None
# Should have blocks for all groups
assert len(blocks.get_block_ids()) == num_groups

# Second request: should hit cached blocks for common prefix
req1 = make_request("1", common_token_ids + [4] * 5, block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)

# Should hit cached blocks for all groups
assert num_computed_tokens == 3 * block_size
assert len(computed_blocks.blocks) == num_groups

# Allocate and verify blocks for second request
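    # 48 common + 5 new - 48 already computed = 5 tokens still needing slots.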
blocks = manager.allocate_slots(
req1,
len(common_token_ids) + 5 - num_computed_tokens,
num_computed_tokens,
computed_blocks,
)
assert blocks is not None
assert len(blocks.get_block_ids()) == num_groups

manager.free(req0)
manager.free(req1)


def test_prefill_plp():
"""Test prefill with APC and some prompt logprobs (plp) requests.
