Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 204 additions & 4 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
MambaSpec,
SlidingWindowSpec,
)

Expand Down Expand Up @@ -101,8 +102,15 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:


def make_kv_cache_config_hybrid_model(
block_size: int, num_blocks: int
block_size: int, num_blocks: int, second_spec_type: str = "sliding_window"
) -> KVCacheConfig:
if second_spec_type == "sliding_window":
second_spec = SlidingWindowSpec(
block_size, 1, 1, torch.float32, sliding_window=2 * block_size
)
elif second_spec_type == "full_attention":
second_spec = FullAttentionSpec(block_size, 2, 1, torch.float32)

return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
Expand All @@ -113,16 +121,46 @@ def make_kv_cache_config_hybrid_model(
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size, 1, 1, torch.float32, sliding_window=2 * block_size
),
second_spec,
),
KVCacheGroupSpec(
["layer3"],
second_spec,
),
],
)


def make_kv_cache_config_three_types(
    block_size: int, num_blocks: int, third_spec_type: str = "mamba"
) -> KVCacheConfig:
    """Build a KVCacheConfig with three single-layer KV cache groups.

    The first two groups are fixed: ``layer1`` uses a ``FullAttentionSpec``
    and ``layer2`` uses a ``SlidingWindowSpec`` whose window is
    ``2 * block_size``. The spec of the third group (``layer3``) is chosen
    by ``third_spec_type``.

    Args:
        block_size: Passed as the first argument of every spec and used to
            derive the sliding-window sizes.
        num_blocks: Forwarded as ``num_blocks`` of the returned config.
        third_spec_type: Selects the ``layer3`` spec:
            ``"mamba"`` -> ``MambaSpec``;
            ``"sliding_window"`` -> ``SlidingWindowSpec`` with a window of
            ``4 * block_size`` (differs from the ``layer2`` window);
            ``"full_attention"`` -> ``FullAttentionSpec`` whose second
            positional argument is 2 instead of 1 (differs from ``layer1``).

    Returns:
        A ``KVCacheConfig`` with an empty ``kv_cache_tensors`` list and the
        three groups described above.

    Raises:
        ValueError: If ``third_spec_type`` is not one of the supported names.
    """
    if third_spec_type == "mamba":
        # NOTE(review): MambaSpec is called with the same positional layout
        # as the attention specs -- confirm this matches its actual fields.
        third_spec = MambaSpec(block_size, 1, 1, torch.float32)
    elif third_spec_type == "sliding_window":
        third_spec = SlidingWindowSpec(
            block_size, 1, 1, torch.float32, sliding_window=4 * block_size
        )
    elif third_spec_type == "full_attention":
        third_spec = FullAttentionSpec(block_size, 2, 1, torch.float32)
    else:
        # Fail loudly here; without this branch ``third_spec`` would be left
        # unbound and the function would die later with UnboundLocalError.
        raise ValueError(
            f"Unknown third_spec_type: {third_spec_type!r}; expected "
            "'mamba', 'sliding_window' or 'full_attention'"
        )

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer1"],
                FullAttentionSpec(block_size, 1, 1, torch.float32),
            ),
            KVCacheGroupSpec(
                ["layer2"],
                SlidingWindowSpec(
                    block_size, 1, 1, torch.float32, sliding_window=2 * block_size
                ),
            ),
            KVCacheGroupSpec(
                ["layer3"],
                third_spec,
            ),
        ],
    )

Expand Down Expand Up @@ -406,6 +444,168 @@ def test_partial_request_hit(
)


@pytest.mark.parametrize("second_spec_type", ["sliding_window", "full_attention"])
def test_prefill_hybrid_model_parametrized(second_spec_type: str):
    """
    Prefix caching on the hybrid-model config, once per second-spec flavour:
    - "sliding_window": SlidingWindow (original behavior)
    - "full_attention": a different FullAttentionSpec (multiple
      FullAttention types in one config)
    """
    block_size = 16
    kv_cache_config = make_kv_cache_config_hybrid_model(
        block_size, 21, second_spec_type
    )
    manager = KVCacheManager(
        kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # Shared prefix of 3 complete blocks (48 tokens); block k holds token id k.
    shared_prefix = [tok for tok in range(3) for _ in range(block_size)]

    # First request: shared prefix followed by 7 extra tokens (55 in total).
    first_prompt = shared_prefix + [3] * 7
    first_req = make_request("0", first_prompt, block_size, sha256)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(first_req)

    # Expect 3 block hashes and no cache hit on the very first lookup.
    assert len(first_req.block_hashes) == 3
    assert not hit_blocks.blocks[0]
    assert num_hit_tokens == 0

    num_cached_tokens = len(hit_blocks.blocks[0]) * block_size
    allocated = manager.allocate_slots(
        first_req, len(first_prompt), num_cached_tokens, hit_blocks
    )
    assert allocated is not None
    # One block-id list per KV cache group.
    assert len(allocated.get_block_ids()) == 3

    # Second request reuses the shared prefix, so it should hit the cache.
    second_req = make_request("1", shared_prefix + [4] * 5, block_size, sha256)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(second_req)

    # All three prefix blocks are expected to be found, for all 3 groups.
    assert num_hit_tokens == 3 * block_size
    assert len(hit_blocks.blocks) == 3

    manager.free(first_req)
    manager.free(second_req)


@pytest.mark.parametrize(
    "third_spec_type",
    ["mamba", "sliding_window", "full_attention"],
)
def test_prefill_three_spec_types(third_spec_type: str):
    """
    Prefix caching with three different KV cache spec types in one config:
    FullAttention + SlidingWindow + a third type (Mamba, a different
    SlidingWindow, or a different FullAttention).
    """
    block_size = 16
    kv_cache_config = make_kv_cache_config_three_types(
        block_size, 31, third_spec_type
    )
    manager = KVCacheManager(
        kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # Shared prefix of 3 complete blocks (48 tokens); block k holds token id k.
    shared_prefix = [tok for tok in range(3) for _ in range(block_size)]

    # First request: shared prefix followed by 7 extra tokens (55 in total).
    first_prompt = shared_prefix + [3] * 7
    first_req = make_request("0", first_prompt, block_size, sha256)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(first_req)

    # Expect 3 block hashes and no cache hit on the very first lookup.
    assert len(first_req.block_hashes) == 3
    assert not hit_blocks.blocks[0]
    assert num_hit_tokens == 0

    num_cached_tokens = len(hit_blocks.blocks[0]) * block_size
    allocated = manager.allocate_slots(
        first_req, len(first_prompt), num_cached_tokens, hit_blocks
    )
    assert allocated is not None
    # One block-id list per KV cache group.
    assert len(allocated.get_block_ids()) == 3

    # Second request reuses the shared prefix, so it should hit the cache.
    second_req = make_request("1", shared_prefix + [4] * 5, block_size, sha256)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(second_req)

    # All three prefix blocks are expected to be found, for all 3 groups.
    assert num_hit_tokens == 3 * block_size
    assert len(hit_blocks.blocks) == 3

    manager.free(first_req)
    manager.free(second_req)


def test_interleaved_group_ids():
    """
    Check allocation and prefix-cache hits when spec types alternate by group.
    Groups 0 and 2 are FullAttention; groups 1 and 3 are SlidingWindow.
    """
    block_size = 16

    # Even group index -> full attention, odd -> sliding window. A fresh spec
    # instance is created for every group, one layer ("layer<i>") per group.
    kv_cache_groups = []
    for group_idx in range(4):
        if group_idx % 2 == 0:
            spec = FullAttentionSpec(block_size, 1, 1, torch.float32)
        else:
            spec = SlidingWindowSpec(
                block_size, 1, 1, torch.float32, sliding_window=2 * block_size
            )
        kv_cache_groups.append(KVCacheGroupSpec([f"layer{group_idx}"], spec))

    manager = KVCacheManager(
        KVCacheConfig(
            num_blocks=50,
            kv_cache_tensors=[],
            kv_cache_groups=kv_cache_groups,
        ),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # Shared prefix of 3 complete blocks; block k holds token id k.
    shared_prefix = [tok for tok in range(3) for _ in range(block_size)]
    first_req = make_request("0", shared_prefix + [3] * 7, block_size, sha256)

    hit_blocks, _ = manager.get_computed_blocks(first_req)
    num_cached_tokens = len(hit_blocks.blocks[0]) * block_size
    allocated = manager.allocate_slots(first_req, 55, num_cached_tokens, hit_blocks)

    # Allocation must succeed and yield one block-id list per group (4 groups).
    assert allocated is not None
    assert len(allocated.get_block_ids()) == 4

    # Free the first request, then look up the same prefix from a new request.
    manager.free(first_req)
    second_req = make_request("1", shared_prefix + [4] * 5, block_size, sha256)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(second_req)

    assert num_hit_tokens == 3 * block_size
    assert len(hit_blocks.blocks) == 4

    manager.free(second_req)


def test_prefill_plp():
"""Test prefill with APC and some prompt logprobs (plp) requests.

Expand Down
Loading