Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
def _make_bare_scheduler() -> MooncakeStoreScheduler:
scheduler = object.__new__(MooncakeStoreScheduler)
scheduler.kv_role = "kv_both"
scheduler.original_block_size = 16
scheduler._block_size = 16
scheduler.load_specs = {}
scheduler._preempted_req_ids = set()
Expand Down
82 changes: 78 additions & 4 deletions tests/v1/kv_connector/unit/test_mooncake_store_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def _make_store_req(req_id: str, block_hashes: list[bytes]) -> ReqMeta:
block_ids=([0, 1],),
block_hashes=block_hashes,
can_save=True,
original_block_size=16,
)


Expand Down Expand Up @@ -970,7 +969,6 @@ def test_store_sending_thread_clamps_token_len_to_lcm():
block_ids=([0, 1, 2],),
block_hashes=[b"a0", b"a1", b"a2"],
can_save=True,
original_block_size=16,
)
)

Expand Down Expand Up @@ -1010,7 +1008,6 @@ def test_store_sending_thread_skips_when_token_len_below_lcm():
block_ids=([0, 1],),
block_hashes=[b"a0", b"a1"],
can_save=True,
original_block_size=64,
)
)

Expand Down Expand Up @@ -1088,7 +1085,6 @@ def test_store_sending_thread_only_stores_swa_blocks_in_window():
block_ids=([0, 1], list(range(8))),
block_hashes=hs,
can_save=True,
original_block_size=32,
)
)

Expand All @@ -1104,6 +1100,84 @@ def test_store_sending_thread_only_stores_swa_blocks_in_window():
assert swa_hashes == {hs[3].hex(), hs[7].hex()}


def test_store_sending_thread_kv_events_use_group_chunk_metadata():
from vllm.v1.core.kv_cache_utils import BlockHash, maybe_convert_block_hash
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheGroupSpec,
SlidingWindowSpec,
)

store = MagicMock()
store.batch_is_exist.side_effect = lambda keys: [0] * len(keys)
store.batch_put_from_multi_buffers.return_value = [256, 256]

full_spec = FullAttentionSpec(
block_size=32, num_kv_heads=8, head_size=64, dtype=None
)
swa_spec = SlidingWindowSpec(
block_size=8,
num_kv_heads=8,
head_size=64,
dtype=None,
sliding_window=8,
)
coord = mooncake_store_worker.MooncakeStoreCoordinator(
[KVCacheGroupSpec(["L0"], full_spec), KVCacheGroupSpec(["L1"], swa_spec)],
scheduler_block_size=32,
hash_block_size=8,
)

db_full = ChunkedTokenDatabase(
KeyMetadata("test-model", 0, 0, 0, 0, group_id=0),
block_size=32,
hash_block_size=8,
)
db_full.set_kv_caches_base_addr([0x1000])
db_full.set_block_len([512])
db_swa = ChunkedTokenDatabase(
KeyMetadata("test-model", 0, 0, 0, 0, group_id=1),
block_size=8,
hash_block_size=8,
)
db_swa.set_kv_caches_base_addr([0x2000])
db_swa.set_block_len([128])

thread = _make_store_sending_thread(
store,
coord=coord,
token_databases=[db_full, db_swa],
block_size=32,
)
thread.enable_kv_event = True

hs = [bytes([i + 1]) * 4 for i in range(4)]
thread.add_stored_request("r0")
thread._handle_request(
ReqMeta(
req_id="r0",
token_len_chunk=32,
block_ids=([0], list(range(4))),
block_hashes=hs,
can_save=True,
token_ids=list(range(32)),
)
)

full_event, swa_event = thread.get_kv_events()
assert full_event.group_idx == 0
assert full_event.block_size == 32
assert full_event.token_ids == list(range(32))
assert full_event.block_hashes == [
maybe_convert_block_hash(BlockHash(b"".join(hs)))
]

assert swa_event.group_idx == 1
assert swa_event.block_size == 8
assert swa_event.token_ids == list(range(24, 32))
assert swa_event.block_hashes == [maybe_convert_block_hash(BlockHash(hs[3]))]


def _auto_set_ready_event(*args, **kwargs):
"""Side effect for mocked thread constructors that auto-sets ready_event."""
for arg in args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ class ReqMeta:
current_event: torch.cuda.Event | None = None

token_ids: list[int] | None = None
original_block_size: int | None = None

@staticmethod
def from_request_tracker(
Expand All @@ -223,7 +222,6 @@ def from_request_tracker(
skip_save: bool | None = False,
block_hashes: list[BlockHash] | None = None,
is_last_chunk: bool | None = None,
original_block_size: int | None = None,
) -> "ReqMeta | None":
"""Create ReqMeta from a RequestTracker."""
if block_hashes is None:
Expand Down Expand Up @@ -274,7 +272,6 @@ def from_request_tracker(
block_hashes=block_hashes,
is_last_chunk=is_last_chunk,
token_ids=token_ids,
original_block_size=original_block_size,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,7 @@ def __init__(
)
self.client = LookupKeyClient(vllm_config)

self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
self.original_block_size = vllm_config.cache_config.block_size
# LCM for multi-group HMA; bs * pcp * dcp for single-group. Matches
# the engine's own scheduler block size by construction.
# Align with the engine's own scheduler_block_size and hash_block_size.
self._block_size, self._hash_block_size = resolve_kv_cache_block_sizes(
kv_cache_config, vllm_config
)
Expand Down Expand Up @@ -221,7 +217,6 @@ def build_connector_meta(
skip_save=force_skip_save,
block_hashes=request_real.block_hashes,
is_last_chunk=(request_tracker.token_len >= last_chunk_tokens_num),
original_block_size=self.original_block_size,
)
if req_meta is not None:
meta.add_request(req_meta)
Expand Down Expand Up @@ -274,7 +269,6 @@ def build_connector_meta(
is_last_chunk=(
request_tracker.token_len >= last_chunk_tokens_num
),
original_block_size=self.original_block_size,
)
else:
# Decode/chunked request
Expand Down Expand Up @@ -312,7 +306,6 @@ def build_connector_meta(
is_last_chunk=(
request_tracker.token_len >= last_chunk_tokens_num
),
original_block_size=self.original_block_size,
)

if req_meta is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,8 @@ def _clear_store_pressure(self) -> bool:
return True

def _handle_request(self, req_meta: ReqMeta):
# Cache hits are always a multiple of ``lcm_block_size`` tokens
# Cache hits are always a multiple of ``lcm_block_size`` tokens, which
# is also ``store_mask``'s precondition.
lcm_block_size = self.coord.lcm_block_size
token_len = req_meta.token_len_chunk // lcm_block_size * lcm_block_size
block_ids_per_group = req_meta.block_ids
Expand Down Expand Up @@ -550,7 +551,7 @@ def _handle_request(self, req_meta: ReqMeta):
starts.append(start)
ends.append(end)
keys.append(key.to_string())
block_hashes.append(req_meta.block_hashes[chunk_idx])
block_hashes.append(BlockHash(bytes.fromhex(key.chunk_hash)))
group_indices.append(g_idx)

# Apply put_step striding for TP
Expand Down Expand Up @@ -627,10 +628,11 @@ def _handle_request(self, req_meta: ReqMeta):
block_hashes=[new_block_hashes[idx]],
parent_block_hash=prev_key_per_group.get(g_idx),
token_ids=token_ids,
block_size=req_meta.original_block_size,
block_size=db.block_size,
lora_id=None,
medium="cpu",
lora_name=None,
group_idx=g_idx,
)
stored_events.append(stored_event)
prev_key_per_group[g_idx] = new_block_hashes[idx]
Expand Down Expand Up @@ -947,7 +949,6 @@ def __init__(
"load_async", True
)
self.cache_config = vllm_config.cache_config
self.original_block_size = self.cache_config.block_size
self.block_size, self.hash_block_size = resolve_kv_cache_block_sizes(
kv_cache_config, vllm_config
)
Expand Down
Loading