diff --git a/tests/v1/kv_connector/unit/test_mooncake_store_coordinator.py b/tests/v1/kv_connector/unit/test_mooncake_store_coordinator.py index 1fd601af59ce..492a905ed16c 100644 --- a/tests/v1/kv_connector/unit/test_mooncake_store_coordinator.py +++ b/tests/v1/kv_connector/unit/test_mooncake_store_coordinator.py @@ -300,3 +300,85 @@ def test_store_mask_fast_path_single_attention_group(): assert len(coord.attention_groups) == 1 masks = coord.store_mask(64) assert masks == ([True] * 4, [True] * 4) + + +# ----- Eagle / MTP interaction with load_mask ----- + + +def test_lookup_with_eagle_pops_last_full_attention_block(): + """Sanity: with use_eagle, find_longest_cache_hit drops the last block. + Pairs with the load_mask test below to lock the round-trip contract.""" + groups = [KVCacheGroupSpec(["L0"], _full(16))] + coord = _make_coord(groups, hash_block_size=16, use_eagle=True) + hs = _hashes(4) + cmap = ExternalCachedBlockPool({(0, bytes(h)) for h in hs}) + _masks, hit = coord.find_longest_cache_hit( + hs, max_length=64, cached_block_pool=cmap + ) + # 4 blocks present, eagle pops 1 → 3 blocks = 48 tokens. + assert hit == 48 + + +def test_load_mask_with_eagle_does_not_double_prune_full_attention(): + """Regression for silent KV corruption with MTP/EAGLE-3. + + The recv side calls ``load_mask(block_hashes, token_len)`` where + ``token_len`` is already the eagle-pruned hit length from ``lookup``. + A second eagle pop here used to shorten the mask by one extra block; + ``process_tokens`` then yielded a chunk past the mask, which the worker + silently skipped — leaving the trailing block of the loaded prefix + uninitialized in local KV. + """ + groups = [KVCacheGroupSpec(["L0"], _full(16))] + coord = _make_coord(groups, hash_block_size=16, use_eagle=True) + hs = _hashes(4) + cmap = ExternalCachedBlockPool({(0, bytes(h)) for h in hs}) + _masks, hit = coord.find_longest_cache_hit( + hs, max_length=64, cached_block_pool=cmap + ) + assert hit == 48 # eagle popped 1 block + + masks = coord.load_mask(hs, token_len=hit) + # Every chunk that process_tokens(token_len=48, ...) would yield must + # have a corresponding mask slot. process_tokens emits chunk_id 0..2 + # (start=0, 16, 32), so the mask must be length 3, all True. + assert masks[0] == [True, True, True] + + +def test_load_mask_with_eagle_hybrid_full_plus_swa(): + """Hybrid (FullAttn + SWA) with eagle: load_mask must cover every chunk + in [0, token_len) for the FullAttn group; SWA group keeps its + tail-window mask.""" + groups = [ + KVCacheGroupSpec(["L0"], _full(16)), + KVCacheGroupSpec(["L1"], _swa(16, 32)), + ] + coord = _make_coord(groups, hash_block_size=16, use_eagle=True) + hs = _hashes(4) + exists = {(g, bytes(h)) for g in (0, 1) for h in hs} + cmap = ExternalCachedBlockPool(exists) + _masks, hit = coord.find_longest_cache_hit( + hs, max_length=64, cached_block_pool=cmap + ) + # FullAttn dictates the convergence; eagle pops one block off it. + assert hit == 48 + + masks = coord.load_mask(hs, token_len=hit) + # FullAttn: all chunks populated locally. + assert masks[0] == [True, True, True] + # SWA: tail-window only (ceil((32-1)/16) = 2 trailing blocks). + assert masks[1][-2:] == [True, True] + + +def test_load_mask_without_eagle_unchanged(): + """Sanity: when eagle is off, load_mask is identical to the pre-fix path.""" + groups = [KVCacheGroupSpec(["L0"], _full(16))] + coord = _make_coord(groups, hash_block_size=16, use_eagle=False) + hs = _hashes(4) + cmap = ExternalCachedBlockPool({(0, bytes(h)) for h in hs}) + _masks, hit = coord.find_longest_cache_hit( + hs, max_length=64, cached_block_pool=cmap + ) + assert hit == 64 + masks = coord.load_mask(hs, token_len=hit) + assert masks[0] == [True, True, True, True] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py index 26f3aa263206..b16fdb7c16c0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/coordinator.py @@ -115,12 +115,21 @@ def find_longest_cache_hit( block_hashes: list[BlockHash], max_length: int, cached_block_pool: ExternalCachedBlockPool, + *, + apply_eagle: bool = True, ) -> tuple[tuple[list[bool], ...], int]: """Returns ``(load_mask_per_group, hit_length)``. ``mask[g][i]`` is True iff group ``g`` populates chunk ``i`` locally (e.g. SWA and Mamba tail-only); - recv-side callers skip False slots.""" + recv-side callers skip False slots. + + ``apply_eagle`` controls whether the per-spec ``use_eagle`` last-block + pop is applied. Lookup callers want it (the drafter requires recomputing + the last block); per-chunk mask callers must not, because ``token_len`` + already reflects the eagle-pruned hit length and a second pop would + leave the trailing block unloaded. + """ blocks_per_group, hit_length = self._find_hit_blocks( - block_hashes, max_length, cached_block_pool + block_hashes, max_length, cached_block_pool, apply_eagle=apply_eagle ) masks = tuple( [blk is not cached_block_pool.null_block for blk in blocks] @@ -137,8 +146,17 @@ def load_mask( spec would populate chunk ``i`` locally at length ``token_len`` (e.g. SWA / Mamba tail-only). """ + # ``apply_eagle=False`` because ``token_len`` is already the + # eagle-pruned hit length returned by ``client.lookup``. Re-applying + # the pop here would shorten the mask by one extra block; the recv + # thread would then silently skip the trailing chunk yielded by + # ``db.process_tokens`` and leave that block uninitialized in the + # local KV pool. masks, _ = self.find_longest_cache_hit( - block_hashes, token_len, ExternalCachedBlockPool() + block_hashes, + token_len, + ExternalCachedBlockPool(), + apply_eagle=False, ) return masks @@ -195,10 +213,17 @@ def _find_hit_blocks( block_hashes: list[BlockHash], max_length: int, cached_block_pool: ExternalCachedBlockPool, + *, + apply_eagle: bool = True, ) -> tuple[tuple[list[KVCacheBlock], ...], int]: """Mirrors HybridKVCacheCoordinator.find_longest_cache_hit but dispatches via spec_manager_map (we don't allocate managers). + + When ``apply_eagle`` is False, ignore ``eagle_attn_group_indices`` — + used by ``load_mask`` to avoid popping a second block on top of the + one already removed by the lookup. """ + eagle_indices = self.eagle_attn_group_indices if apply_eagle else set() if len(self.attention_groups) == 1: spec, group_ids, manager_cls = self.attention_groups[0] hashes = self.block_hashes_for_spec(block_hashes, spec) @@ -208,7 +233,7 @@ def _find_hit_blocks( kv_cache_group_ids=group_ids, block_pool=cast(BlockPool, cached_block_pool), kv_cache_spec=spec, - use_eagle=(0 in self.eagle_attn_group_indices), + use_eagle=(0 in eagle_indices), alignment_tokens=spec.block_size, ) num_groups = len(self.kv_cache_groups) @@ -237,9 +262,7 @@ def _find_hit_blocks( ) continue - use_eagle = ( - idx in self.eagle_attn_group_indices and idx not in eagle_verified - ) + use_eagle = idx in eagle_indices and idx not in eagle_verified _max_length = curr_hit_length if use_eagle: _max_length = min(curr_hit_length + spec.block_size, max_length)