Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/features/mooncake_store_connector_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ the vLLM JSON config.

- `load_async` (bool): Enable asynchronous loading for better compute-I/O overlap. Default: `true`.
- `enable_cross_layers_blocks` (bool): Enable cross-layer block packing for reduced store operations. Default: `false`.
- `discard_partial_chunks` (bool): Discard partial block chunks during store. Default: `true`.
- `lookup_rpc_port` (int): Custom port for the ZMQ lookup RPC socket. Default: `0`.

## Notes
Expand Down
64 changes: 63 additions & 1 deletion tests/v1/kv_connector/unit/test_mooncake_store_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def _make_bare_scheduler() -> MooncakeStoreScheduler:
scheduler.kv_role = "kv_both"
scheduler.original_block_size = 16
scheduler._block_size = 16
scheduler._discard_partial_chunks = True
scheduler.load_specs = {}
scheduler._preempted_req_ids = set()
scheduler._unfinished_request_ids = {"req-0"}
Expand Down Expand Up @@ -344,3 +343,66 @@ def test_from_request_tracker_no_load_saves_normally():
assert req_meta.can_save is True
assert req_meta.load_spec is None
assert tracker.num_saved_tokens == 48


class _StubLookupClient:
def __init__(self, hit_tokens: int) -> None:
self._hit_tokens = hit_tokens

def lookup(self, token_len: int, block_hashes: list[bytes]) -> int:
return self._hit_tokens


def test_full_external_hit_keeps_kvpool_cached_tokens_block_aligned():
    """A full external hit is trimmed down to a block boundary.

    Scenario: 48-token prompt, block size 16, the external store reports all
    48 tokens as hits and the local prefix cache already covers 16 of them.
    """
    # When the external store hits the entire prompt, scheduler must leave at
    # least one token uncomputed for sampling but stay on a block boundary.
    # Otherwise the recv-side load mask floors token_len to
    # (num_tokens-1)//block_size, the tail partial chunk is dropped, and -- if
    # the local cache covers the aligned prefix -- key_list ends up empty
    # (ZeroDivisionError in the recv thread's `tp_rank % len(key_list)`).
    scheduler = _make_bare_scheduler()
    scheduler.load_async = True
    scheduler.client = _StubLookupClient(hit_tokens=48)  # full hit on 48-token prompt

    request = SimpleNamespace(
        request_id="req-0",
        num_tokens=48,
        block_hashes=[b"h0", b"h1", b"h2"],
    )

    need_to_allocate, load_async = scheduler.get_num_new_matched_tokens(
        request, num_computed_tokens=16
    )

    # 47 // 16 * 16 == 32 tokens left in external store after reserving the
    # sub-block tail for sampling. 32 - 16 (local) == 16 to load.
    assert need_to_allocate == 16
    assert load_async is True
    load_spec = scheduler.load_specs["req-0"]
    assert load_spec.vllm_cached_tokens == 16
    assert load_spec.kvpool_cached_tokens == 32
    assert load_spec.kvpool_cached_tokens % 16 == 0


def test_full_external_hit_with_full_local_hit_skips_load():
    """No load is scheduled when the local cache covers the aligned hit.

    Scenario: 48-token prompt, block size 16, the external store reports all
    48 tokens as hits and the local prefix cache already covers 32 of them.
    """
    # When local prefix cache already covers the block-aligned external hit,
    # there is nothing for the connector to load. The pre-fix behavior would
    # have scheduled a 15-token load that the recv thread couldn't translate
    # into any block-aligned key.
    scheduler = _make_bare_scheduler()
    scheduler.load_async = True
    scheduler.client = _StubLookupClient(hit_tokens=48)

    request = SimpleNamespace(
        request_id="req-0",
        num_tokens=48,
        block_hashes=[b"h0", b"h1", b"h2"],
    )

    need_to_allocate, load_async = scheduler.get_num_new_matched_tokens(
        request, num_computed_tokens=32
    )

    assert need_to_allocate == 0
    assert load_async is False
    assert "req-0" not in scheduler.load_specs
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,6 @@ def __init__(
kv_cache_config, vllm_config
)

self._discard_partial_chunks = (
vllm_config.kv_transfer_config.get_from_extra_config(
"discard_partial_chunks", True
)
)

# Per-request state
self.load_specs: dict[str, LoadSpec] = {} # to be loaded
self._request_trackers: dict[str, RequestTracker] = {} # scheduled new requests
Expand All @@ -88,18 +82,19 @@ def get_num_new_matched_tokens(
) -> tuple[int, bool]:
"""Check for external KV cache hit."""
# Look up against the full prefill range, not just the prompt.
if self._discard_partial_chunks:
token_len = request.num_tokens // self._block_size * self._block_size
else:
token_len = request.num_tokens

token_len = request.num_tokens // self._block_size * self._block_size
if token_len < self._block_size:
return 0, False

num_external_hit_tokens = self.client.lookup(token_len, request.block_hashes)

if num_external_hit_tokens == request.num_tokens:
num_external_hit_tokens -= 1
# Leave a sub-block tail uncomputed for sampling, on a block
Comment thread
Dao007forever marked this conversation as resolved.
# boundary so the recv-side load mask covers every yielded chunk.
num_external_hit_tokens = max(
0,
(request.num_tokens - 1) // self._block_size * self._block_size,
)

if num_external_hit_tokens < num_computed_tokens:
need_to_allocate = 0
Expand Down Expand Up @@ -214,9 +209,7 @@ def build_connector_meta(
self._request_trackers[request.req_id] = request_tracker

last_chunk_tokens_num = (
(len(prefill_tokens) // self._block_size * self._block_size)
if self._discard_partial_chunks
else len(prefill_tokens)
len(prefill_tokens) // self._block_size * self._block_size
)

req_meta = ReqMeta.from_request_tracker(
Expand All @@ -226,7 +219,6 @@ def build_connector_meta(
skip_save=force_skip_save,
block_hashes=request_real.block_hashes,
is_last_chunk=(request_tracker.token_len >= last_chunk_tokens_num),
discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
)
if req_meta is not None:
Expand Down Expand Up @@ -269,9 +261,7 @@ def build_connector_meta(
self._request_trackers[req_id] = request_tracker

last_chunk_tokens_num = (
(len(prefill_tokens) // self._block_size * self._block_size)
if self._discard_partial_chunks
else len(prefill_tokens)
len(prefill_tokens) // self._block_size * self._block_size
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
Expand All @@ -282,7 +272,6 @@ def build_connector_meta(
is_last_chunk=(
request_tracker.token_len >= last_chunk_tokens_num
),
discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
)
else:
Expand Down Expand Up @@ -310,9 +299,7 @@ def build_connector_meta(
request_tracker.update(new_block_ids)

last_chunk_tokens_num = (
(prefill_end // self._block_size * self._block_size)
if self._discard_partial_chunks
else prefill_end
prefill_end // self._block_size * self._block_size
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
Expand All @@ -323,7 +310,6 @@ def build_connector_meta(
is_last_chunk=(
request_tracker.token_len >= last_chunk_tokens_num
),
discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
)

Expand All @@ -341,10 +327,6 @@ def build_connector_meta(
if not load_spec:
continue
num_tokens_to_compute = load_spec.kvpool_cached_tokens
if (num_tokens_to_compute % self._block_size != 0) and (
num_tokens_to_compute == unfinished_req.num_tokens - 1
):
num_tokens_to_compute = num_tokens_to_compute + 1
Comment thread
Dao007forever marked this conversation as resolved.
request_tracker = RequestTracker(
req_id=request_id,
token_len=num_tokens_to_compute,
Expand All @@ -358,7 +340,6 @@ def build_connector_meta(
load_spec=load_spec,
skip_save=None,
block_hashes=unfinished_req.block_hashes,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1104,14 +1104,7 @@ def get_finished(
if load_spec is None or not load_spec.can_load:
continue

token_len = request.token_len_chunk
if (load_spec.kvpool_cached_tokens % self.block_size != 0) and (
load_spec.kvpool_cached_tokens == token_len - 1
):
token_len = load_spec.kvpool_cached_tokens + 1
else:
token_len = load_spec.kvpool_cached_tokens
load_spec.token_len = token_len
load_spec.token_len = load_spec.kvpool_cached_tokens

assert self.kv_recv_thread is not None
self.kv_recv_thread.add_request(request)
Expand Down
Loading