Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions tests/v1/kv_connector/unit/test_mooncake_store_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,85 @@ def test_store_mask_fast_path_single_attention_group():
assert len(coord.attention_groups) == 1
masks = coord.store_mask(64)
assert masks == ([True] * 4, [True] * 4)


# ----- Eagle / MTP interaction with load_mask -----


def test_lookup_with_eagle_pops_last_full_attention_block():
    """Lookup sanity check: with ``use_eagle`` on, the hit loses its last block.

    Companion to the eagle ``load_mask`` regression test; together they pin
    down the lookup -> load_mask round-trip contract.
    """
    coordinator = _make_coord(
        [KVCacheGroupSpec(["L0"], _full(16))],
        hash_block_size=16,
        use_eagle=True,
    )
    block_hashes = _hashes(4)
    # Every one of the four hashes is present in the external pool for group 0.
    pool = ExternalCachedBlockPool({(0, bytes(bh)) for bh in block_hashes})

    _unused_masks, hit_length = coordinator.find_longest_cache_hit(
        block_hashes, max_length=64, cached_block_pool=pool
    )

    # 4 cached blocks minus the one eagle pops -> 3 blocks * 16 = 48 tokens.
    expected_hit = 48
    assert hit_length == expected_hit


def test_load_mask_with_eagle_does_not_double_prune_full_attention():
    """Guard against silent KV corruption under MTP/EAGLE-3.

    On the recv side ``load_mask(block_hashes, token_len)`` is handed a
    ``token_len`` that ``lookup`` has already eagle-pruned. Popping a block
    again inside ``load_mask`` previously produced a mask one block too
    short; ``process_tokens`` would then emit a chunk beyond the mask, the
    worker would skip it without complaint, and the last block of the loaded
    prefix stayed uninitialized in the local KV.
    """
    coordinator = _make_coord(
        [KVCacheGroupSpec(["L0"], _full(16))],
        hash_block_size=16,
        use_eagle=True,
    )
    block_hashes = _hashes(4)
    pool = ExternalCachedBlockPool({(0, bytes(bh)) for bh in block_hashes})

    _unused_masks, hit_length = coordinator.find_longest_cache_hit(
        block_hashes, max_length=64, cached_block_pool=pool
    )
    # Eagle removed one block from the four cached ones.
    assert hit_length == 48

    load_masks = coordinator.load_mask(block_hashes, token_len=hit_length)

    # process_tokens(token_len=48, ...) yields chunk_id 0..2 (starting at
    # 0, 16 and 32). Each of those chunks needs its own mask slot, so the
    # mask has to be exactly three entries long and entirely True.
    full_attention_mask = load_masks[0]
    assert full_attention_mask == [True, True, True]


def test_load_mask_with_eagle_hybrid_full_plus_swa():
    """Eagle on a hybrid FullAttn + SWA layout.

    The FullAttn group's ``load_mask`` has to cover every chunk in
    ``[0, token_len)``, while the SWA group retains its tail-window mask.
    """
    coordinator = _make_coord(
        [
            KVCacheGroupSpec(["L0"], _full(16)),
            KVCacheGroupSpec(["L1"], _swa(16, 32)),
        ],
        hash_block_size=16,
        use_eagle=True,
    )
    block_hashes = _hashes(4)
    # All four hashes are cached for both group 0 and group 1.
    cached_keys = {
        (group_id, bytes(bh)) for group_id in (0, 1) for bh in block_hashes
    }
    pool = ExternalCachedBlockPool(cached_keys)

    _unused_masks, hit_length = coordinator.find_longest_cache_hit(
        block_hashes, max_length=64, cached_block_pool=pool
    )
    # Convergence is driven by FullAttn, from which eagle pops one block.
    assert hit_length == 48

    load_masks = coordinator.load_mask(block_hashes, token_len=hit_length)

    # FullAttn group: every chunk is populated locally.
    full_attention_mask = load_masks[0]
    assert full_attention_mask == [True, True, True]

    # SWA group: only the tail window, i.e. ceil((32-1)/16) = 2 trailing
    # blocks.
    swa_tail = load_masks[1][-2:]
    assert swa_tail == [True, True]


def test_load_mask_without_eagle_unchanged():
    """Sanity check: with eagle disabled, ``load_mask`` matches the pre-fix path."""
    coordinator = _make_coord(
        [KVCacheGroupSpec(["L0"], _full(16))],
        hash_block_size=16,
        use_eagle=False,
    )
    block_hashes = _hashes(4)
    pool = ExternalCachedBlockPool({(0, bytes(bh)) for bh in block_hashes})

    _unused_masks, hit_length = coordinator.find_longest_cache_hit(
        block_hashes, max_length=64, cached_block_pool=pool
    )
    # No eagle pop: all four cached blocks count toward the hit.
    assert hit_length == 64

    load_masks = coordinator.load_mask(block_hashes, token_len=hit_length)
    full_attention_mask = load_masks[0]
    assert full_attention_mask == [True, True, True, True]
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,21 @@ def find_longest_cache_hit(
block_hashes: list[BlockHash],
max_length: int,
cached_block_pool: ExternalCachedBlockPool,
*,
apply_eagle: bool = True,
) -> tuple[tuple[list[bool], ...], int]:
"""Returns ``(load_mask_per_group, hit_length)``. ``mask[g][i]`` is True iff
group ``g`` populates chunk ``i`` locally (e.g. SWA and Mamba tail-only);
recv-side callers skip False slots."""
recv-side callers skip False slots.

``apply_eagle`` controls whether the per-spec ``use_eagle`` last-block
pop is applied. Lookup callers want it (the drafter requires recomputing
the last block); per-chunk mask callers must not, because ``token_len``
already reflects the eagle-pruned hit length and a second pop would
leave the trailing block unloaded.
"""
blocks_per_group, hit_length = self._find_hit_blocks(
block_hashes, max_length, cached_block_pool
block_hashes, max_length, cached_block_pool, apply_eagle=apply_eagle
)
masks = tuple(
[blk is not cached_block_pool.null_block for blk in blocks]
Expand All @@ -137,8 +146,17 @@ def load_mask(
spec would populate chunk ``i`` locally at length ``token_len``
(e.g. SWA / Mamba tail-only).
"""
# ``apply_eagle=False`` because ``token_len`` is already the
# eagle-pruned hit length returned by ``client.lookup``. Re-applying
# the pop here would shorten the mask by one extra block; the recv
# thread would then silently skip the trailing chunk yielded by
# ``db.process_tokens`` and leave that block uninitialized in the
# local KV pool.
masks, _ = self.find_longest_cache_hit(
block_hashes, token_len, ExternalCachedBlockPool()
block_hashes,
token_len,
ExternalCachedBlockPool(),
apply_eagle=False,
Comment thread
Dao007forever marked this conversation as resolved.
)
return masks

Expand Down Expand Up @@ -195,10 +213,17 @@ def _find_hit_blocks(
block_hashes: list[BlockHash],
max_length: int,
cached_block_pool: ExternalCachedBlockPool,
*,
apply_eagle: bool = True,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
"""Mirrors HybridKVCacheCoordinator.find_longest_cache_hit but
dispatches via spec_manager_map (we don't allocate managers).

When ``apply_eagle`` is False, ignore ``eagle_attn_group_indices`` —
used by ``load_mask`` to avoid popping a second block on top of the
one already removed by the lookup.
"""
eagle_indices = self.eagle_attn_group_indices if apply_eagle else set()
if len(self.attention_groups) == 1:
spec, group_ids, manager_cls = self.attention_groups[0]
hashes = self.block_hashes_for_spec(block_hashes, spec)
Expand All @@ -208,7 +233,7 @@ def _find_hit_blocks(
kv_cache_group_ids=group_ids,
block_pool=cast(BlockPool, cached_block_pool),
kv_cache_spec=spec,
use_eagle=(0 in self.eagle_attn_group_indices),
use_eagle=(0 in eagle_indices),
alignment_tokens=spec.block_size,
)
num_groups = len(self.kv_cache_groups)
Expand Down Expand Up @@ -237,9 +262,7 @@ def _find_hit_blocks(
)
continue

use_eagle = (
idx in self.eagle_attn_group_indices and idx not in eagle_verified
)
use_eagle = idx in eagle_indices and idx not in eagle_verified
_max_length = curr_hit_length
if use_eagle:
_max_length = min(curr_hit_length + spec.block_size, max_length)
Expand Down
Loading