Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 196 additions & 1 deletion tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2466,6 +2466,201 @@ def test_eagle_with_sliding_window():
assert num_tokens == 0


def test_eagle_swa_alignment_caches_extra_block():
    """Regression test for SWA + EAGLE when `sliding_window <= alignment_tokens`.

    The cache-hit alignment is the lcm of the per-group block sizes. When
    that alignment exceeds the SWA window, the SWA mask used to keep only the
    last block of every aligned segment. An EAGLE/MTP lookup, however, needs
    ``tail + 1`` contiguous cached blocks, and the extra block sits at the
    first position of the following segment, which was left uncached. The
    fix caches that extra block whenever ``use_eagle=True``.
    """
    block_size = 8
    # The full-attention group uses blocks of 4 * block_size tokens, which
    # makes the lcm/alignment 4 * block_size as well.
    alignment_tokens = 4 * block_size
    full_spec = FullAttentionSpec(
        block_size=alignment_tokens,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    # The SWA group has sliding_window == block_size, i.e. a one-block tail.
    # Without the fix, the second cached block required for the two-block
    # EAGLE match is never created, so the EAGLE cache hit fails outright.
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(["full"], full_spec),
            KVCacheGroupSpec(["swa_mtp"], swa_spec, is_eagle_group=True),
        ],
    )
    kv_manager = make_kv_cache_manager(
        kv_cache_config=config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # Warm the cache with a long prompt: 16 SWA blocks == 4 aligned segments.
    # Every SWA block is filled with its own block index as the token id.
    prompt = [pos // block_size for pos in range(16 * block_size)]
    warm_req = make_request("0", prompt, block_size, sha256)
    hit_blocks, _ = kv_manager.get_computed_blocks(warm_req)
    already_computed = len(hit_blocks.blocks[0]) * block_size
    allocated = kv_manager.allocate_slots(
        warm_req,
        len(prompt),
        already_computed,
        hit_blocks,
    )
    assert allocated is not None
    kv_manager.free(warm_req)

    # An identical prompt must now produce an EAGLE cache hit. Without the
    # fix the hit length is 0; with it, the reported hit (after EAGLE drops
    # its lookahead block) sits on an alignment boundary.
    probe_req = make_request("1", prompt, block_size, sha256)
    _, hit_tokens = kv_manager.get_computed_blocks(probe_req)
    assert hit_tokens > 0, (
        "EAGLE + SWA with sliding_window <= alignment failed to find any "
        "cache hit; the +1 block past each segment boundary must be cached."
    )
    # Each aligned segment accounts for 4 * block_size = 32 tokens, and EAGLE
    # removes the final block (block_size tokens) from the hit.
    assert hit_tokens % alignment_tokens == 0


def test_eagle_swa_boundary_caches_post_boundary_block():
    """Check that EAGLE + SWA caches the first block past an alignment boundary.

    The prompt is 40 tokens long, SWA blocks hold 8 tokens, and the hybrid
    alignment is 32 tokens. Reusing a 32-token prefix requires SWA blocks 3
    and 4 to be cached: block 3 is the tail of the segment, while block 4 is
    the EAGLE lookahead block that is dropped once the lookup is done.
    """
    block_size = 8
    full_group = KVCacheGroupSpec(
        ["full"],
        FullAttentionSpec(
            block_size=4 * block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float16,
        ),
    )
    swa_eagle_group = KVCacheGroupSpec(
        ["swa_mtp"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=block_size,
        ),
        is_eagle_group=True,
    )
    config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[full_group, swa_eagle_group],
    )
    kv_manager = make_kv_cache_manager(
        kv_cache_config=config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # Five SWA blocks, each filled with its own block index as the token id.
    prompt = []
    for block_idx in range(5):
        prompt.extend([block_idx] * block_size)

    first_req = make_request("0", prompt, block_size, sha256)
    hit_blocks, _ = kv_manager.get_computed_blocks(first_req)
    already_computed = len(hit_blocks.blocks[0]) * block_size
    allocated = kv_manager.allocate_slots(
        first_req,
        len(prompt),
        already_computed,
        hit_blocks,
    )
    assert allocated is not None

    # Group 1 is the SWA/EAGLE group: both the segment tail (block 3) and
    # the post-boundary lookahead block (block 4) must be in the cache.
    block_pool = kv_manager.block_pool
    for block_idx in (3, 4):
        assert block_pool.get_cached_block(
            first_req.block_hashes[block_idx], kv_cache_group_ids=[1]
        )
    kv_manager.free(first_req)

    # Same prompt plus one extra token: exactly one aligned segment is reused.
    second_req = make_request("1", prompt + [999], block_size, sha256)
    _, hit_tokens = kv_manager.get_computed_blocks(second_req)
    assert hit_tokens == 4 * block_size


def test_eagle_grouped_swa_siblings_use_same_cache_mask():
    """SWA sibling groups must both cache the EAGLE lookahead block."""
    block_size = 8
    full_spec = FullAttentionSpec(
        block_size=4 * block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    # One spec object shared by the non-EAGLE and EAGLE sliding-window groups.
    shared_swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    groups = [
        KVCacheGroupSpec(["full"], full_spec),
        KVCacheGroupSpec(["swa_main"], shared_swa_spec),
        KVCacheGroupSpec(["swa_mtp"], shared_swa_spec, is_eagle_group=True),
    ]
    config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=groups,
    )
    kv_manager = make_kv_cache_manager(
        kv_cache_config=config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # Nine SWA blocks, each filled with its own block index as the token id.
    prompt = [pos // block_size for pos in range(9 * block_size)]
    first_req = make_request("0", prompt, block_size, sha256)
    hit_blocks, _ = kv_manager.get_computed_blocks(first_req)
    already_computed = len(hit_blocks.blocks[0]) * block_size
    allocated = kv_manager.allocate_slots(
        first_req,
        len(prompt),
        already_computed,
        hit_blocks,
    )
    assert allocated is not None

    # Blocks 4 and 8 must be cached for SWA groups 1 and 2 together.
    block_pool = kv_manager.block_pool
    for block_idx in (4, 8):
        assert block_pool.get_cached_block(
            first_req.block_hashes[block_idx], kv_cache_group_ids=[1, 2]
        )
    kv_manager.free(first_req)

    # Same prompt plus one extra token: the hit must cover eight SWA blocks.
    second_req = make_request("1", prompt + [999], block_size, sha256)
    _, hit_tokens = kv_manager.get_computed_blocks(second_req)
    assert hit_tokens == 8 * block_size


def test_different_block_size():
block_size = 16
# full attention and sliding window attention layers have the same page size:
Expand Down Expand Up @@ -2614,7 +2809,7 @@ def test_hybrid_cache_blocks_swa_tail_window_only():


def test_hybrid_cache_blocks_clamped_to_lcm():
"""HybridKVCacheCoordinator.cache_blocks() clamps to lcm_block_size.
"""HybridKVCacheCoordinator.cache_blocks() clamps to scheduler_block_size.
Chunks past the last lcm-aligned boundary can never participate in a
cache hit (find_longest_cache_hit always returns lcm-aligned hits), so
caching them only pollutes the prefix-cache hash map and keeps blocks
Expand Down
4 changes: 2 additions & 2 deletions tests/v1/core/test_single_type_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def run_one_case(block_is_cached, tail_token, expect_length):
kv_cache_group_ids=[0],
block_pool=block_pool,
kv_cache_spec=chunked_local_attention_spec,
use_eagle=False,
drop_eagle_block=False,
alignment_tokens=block_size,
)[0]
assert len(computed_blocks) == expect_length
Expand Down Expand Up @@ -157,7 +157,7 @@ def run_one_case(block_is_cached, expect_length):
kv_cache_group_ids=[0],
block_pool=block_pool,
kv_cache_spec=sliding_window_spec,
use_eagle=False,
drop_eagle_block=False,
alignment_tokens=block_size,
)[0]
assert len(computed_blocks) == expect_length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _find_hit_blocks(
kv_cache_group_ids=group_ids,
block_pool=cast(BlockPool, cached_block_pool),
kv_cache_spec=spec,
use_eagle=(0 in eagle_indices),
drop_eagle_block=(0 in eagle_indices),
alignment_tokens=spec.block_size,
)
num_groups = len(self.kv_cache_groups)
Expand Down Expand Up @@ -262,9 +262,9 @@ def _find_hit_blocks(
)
continue

use_eagle = idx in eagle_indices and idx not in eagle_verified
drop_eagle_block = idx in eagle_indices and idx not in eagle_verified
_max_length = curr_hit_length
if use_eagle:
if drop_eagle_block:
_max_length = min(curr_hit_length + spec.block_size, max_length)
hashes = self.block_hashes_for_spec(block_hashes, spec)
hit_blocks = manager_cls.find_longest_cache_hit(
Expand All @@ -273,11 +273,11 @@ def _find_hit_blocks(
kv_cache_group_ids=group_ids,
block_pool=cast(BlockPool, cached_block_pool),
kv_cache_spec=spec,
use_eagle=use_eagle,
drop_eagle_block=drop_eagle_block,
alignment_tokens=self.lcm_block_size,
)
_new_hit_length = len(hit_blocks[0]) * spec.block_size
if use_eagle:
if drop_eagle_block:
eagle_verified.add(idx)
elif _new_hit_length < curr_hit_length:
eagle_verified.clear()
Expand Down
Loading
Loading