Skip to content
Merged
4 changes: 4 additions & 0 deletions tests/v1/kv_connector/unit/test_mooncake_store_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,18 @@ def test_worker_methods_delegate_to_store_worker():

worker = mock_worker_cls.return_value
worker.get_finished.return_value = ({"req-1"}, {"req-2"})
worker.get_block_ids_with_load_errors.return_value = {3, 4}
connector.bind_connector_metadata(metadata)

connector.register_kv_caches(kv_caches)
result = connector.get_finished(finished_req_ids)
invalid_block_ids = connector.get_block_ids_with_load_errors()

worker.register_kv_caches.assert_called_once_with(kv_caches)
worker.get_finished.assert_called_once_with(finished_req_ids, metadata)
assert result == ({"req-1"}, {"req-2"})
worker.get_block_ids_with_load_errors.assert_called_once_with()
assert invalid_block_ids == {3, 4}


def test_get_kv_connector_kv_cache_events_returns_none_when_empty():
Expand Down
66 changes: 65 additions & 1 deletion tests/v1/kv_connector/unit/test_mooncake_store_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def _make_store_sending_thread(
def _make_store_recving_thread(
store: MagicMock,
*,
tp_rank: int = 0,
disk_offload_buffer_budget_bytes: int | None = None,
) -> mooncake_store_worker.KVCacheStoreRecvingThread:
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec
Expand All @@ -93,7 +94,7 @@ def _make_store_recving_thread(
store=store,
token_databases=[token_database],
block_size=16,
tp_rank=0,
tp_rank=tp_rank,
ready_event=threading.Event(),
coord=coord,
disk_offload_buffer_budget_bytes=disk_offload_buffer_budget_bytes,
Expand Down Expand Up @@ -439,6 +440,67 @@ def test_store_sending_thread_only_skips_on_no_available_handle():
assert store.batch_put_from_multi_buffers.call_count == 2


def test_store_recving_thread_reports_failed_block_ids():
    """Negative per-key results are reported as load-error block ids, once."""
    mock_store = MagicMock()
    # One success followed by two negative (failed) per-key results.
    mock_store.batch_get_into_multi_buffers.return_value = [256, -5, -7]
    recv_thread = _make_store_recving_thread(mock_store)
    load_req = _make_load_req(
        "req-a",
        [b"a0", b"a1", b"a2"],
        token_len=48,
    )

    recv_thread._handle_request(load_req)

    # The request is still marked finished even though some keys failed.
    assert recv_thread.get_and_clear_finished_requests() == {"req-a"}
    assert recv_thread.get_and_clear_block_ids_with_load_errors() == {1, 2}
    # The accessor drains its state, so a second call reports nothing.
    assert recv_thread.get_and_clear_block_ids_with_load_errors() == set()


def test_store_recving_thread_reports_failed_block_ids_after_rotation():
    """Failed results map back to the right block id when tp_rank rotates keys."""
    mock_store = MagicMock()
    # Only the middle entry of the (rotated) result list is a failure.
    mock_store.batch_get_into_multi_buffers.return_value = [256, -5, 256]
    recv_thread = _make_store_recving_thread(mock_store, tp_rank=1)
    load_req = _make_load_req(
        "req-a",
        [b"a0", b"a1", b"a2"],
        token_len=48,
    )

    recv_thread._handle_request(load_req)

    # With a rotation of one, result index 1 is expected to be block id 2.
    assert recv_thread.get_and_clear_block_ids_with_load_errors() == {2}


def test_store_recving_thread_reports_all_attempted_blocks_on_exception():
    """An exception from the store marks every attempted block as failed."""
    mock_store = MagicMock()
    mock_store.batch_get_into_multi_buffers.side_effect = RuntimeError("boom")
    recv_thread = _make_store_recving_thread(mock_store)
    load_req = _make_load_req(
        "req-a",
        [b"a0", b"a1", b"a2"],
        token_len=48,
    )

    recv_thread._handle_request(load_req)

    # The request still completes; all three blocks are reported as errors.
    assert recv_thread.get_and_clear_finished_requests() == {"req-a"}
    assert recv_thread.get_and_clear_block_ids_with_load_errors() == {0, 1, 2}


def test_store_worker_get_block_ids_with_load_errors_delegates_to_recv_thread():
    """The worker forwards the query to its receiving thread exactly once."""
    recv_stub = MagicMock()
    recv_stub.get_and_clear_block_ids_with_load_errors.return_value = {3, 4}
    worker = _make_bare_worker()
    worker.kv_recv_thread = recv_stub

    reported = worker.get_block_ids_with_load_errors()

    assert reported == {3, 4}
    recv_stub.get_and_clear_block_ids_with_load_errors.assert_called_once_with()


def test_store_sending_thread_passes_replicate_config_when_preferred_segment_set():
store = MagicMock()
store.batch_is_exist.side_effect = lambda keys: [0] * len(keys)
Expand Down Expand Up @@ -638,6 +700,7 @@ def test_recv_thread_stops_after_first_failing_disk_offload_sub_batch():
thread._handle_request(req)

assert store.batch_get_into_multi_buffers.call_count == 1
assert thread.get_and_clear_block_ids_with_load_errors() == {0, 1}


def test_recv_thread_skips_split_when_budget_holds_all_keys():
Expand Down Expand Up @@ -682,6 +745,7 @@ def test_recv_thread_reports_unsplittable_key_larger_than_budget():
thread._handle_request(req)

assert store.batch_get_into_multi_buffers.call_count == 0
assert thread.get_and_clear_block_ids_with_load_errors() == {0}


def test_requester_worker_init_uses_positional_setup(tmp_path, monkeypatch):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ def get_finished(
assert isinstance(metadata, MooncakeStoreConnectorMetadata)
return self.connector_worker.get_finished(finished_req_ids, metadata)

def get_block_ids_with_load_errors(self) -> set[int]:
    """Return the load-error block ids reported by the connector worker.

    The call is forwarded unchanged to
    ``self.connector_worker.get_block_ids_with_load_errors()``; the worker
    must already be set.
    """
    worker = self.connector_worker
    assert worker is not None
    return worker.get_block_ids_with_load_errors()

def get_kv_connector_kv_cache_events(
self,
) -> MooncakeStoreKVEvents | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import threading
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Literal
from typing import Any, Literal, TypeVar

import regex as re
import torch
Expand Down Expand Up @@ -64,6 +64,12 @@
DEFAULT_LOCAL_BUFFER_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB

MOONCAKE_NO_AVAILABLE_HANDLE = -200
_T = TypeVar("_T")


def _rotate_list(values: list[_T], offset: int) -> list[_T]:
    """Return a new list with ``values`` rotated left by ``offset`` positions.

    The offset is reduced modulo ``len(values)``, so an offset at or beyond
    the list length wraps around instead of silently returning the list
    unrotated, and a negative offset rotates to the right. Offsets already in
    ``[0, len(values))`` behave exactly as plain slicing would.

    Args:
        values: The list to rotate. It is never mutated.
        offset: Number of positions to rotate towards the front.

    Returns:
        A newly built list holding the same elements in rotated order; an
        empty list when ``values`` is empty.
    """
    if not values:
        # Nothing to rotate, and the modulo below would divide by zero.
        return []
    offset %= len(values)
    return values[offset:] + values[:offset]


# Mirrors FileStorageConfig::local_buffer_size in Mooncake C++.
DEFAULT_MOONCAKE_DISK_STAGING_BUFFER_BYTES = 1280 * 1024 * 1024
Expand Down Expand Up @@ -644,6 +650,9 @@ def __init__(
ready_event,
name="KVCacheStoreRecvingThread",
)
# _invalid_block_ids can be accessed by both the Worker and RecvingThread
self._invalid_block_ids_lock = threading.Lock()
self._invalid_block_ids: set[int] = set()
self.disk_offload_buffer_budget_bytes = disk_offload_buffer_budget_bytes
self.usable_disk_offload_buffer_budget_bytes = (
None
Expand All @@ -654,6 +663,16 @@ def __init__(
)
self.coord = coord

def _add_load_error_block_ids(self, block_ids: list[int]) -> None:
    """Add ``block_ids`` to the set of blocks whose load failed.

    The set is modified while holding ``_invalid_block_ids_lock``.
    """
    with self._invalid_block_ids_lock:
        for failed_block_id in block_ids:
            self._invalid_block_ids.add(failed_block_id)

def get_and_clear_block_ids_with_load_errors(self) -> set[int]:
    """Return the recorded load-error block ids and reset the record.

    The snapshot and the reset both happen while holding
    ``_invalid_block_ids_lock``. The returned set is a copy, so the caller
    may mutate it freely.
    """
    with self._invalid_block_ids_lock:
        drained = set(self._invalid_block_ids)
        self._invalid_block_ids.clear()
        return drained

def _handle_request(self, req_meta: ReqMeta):
token_len = req_meta.load_spec.token_len # type: ignore[union-attr]
req_id = req_meta.req_id
Expand All @@ -670,6 +689,7 @@ def _handle_request(self, req_meta: ReqMeta):
addr_list: list[list[int]] = []
size_list: list[list[int]] = []
key_list: list[str] = []
block_id_list: list[int] = []
for g_idx, db in enumerate(self.token_databases):
mask = load_mask_per_group[g_idx]
for start, end, key in db.process_tokens(
Expand All @@ -678,25 +698,49 @@ def _handle_request(self, req_meta: ReqMeta):
chunk_idx = start // db.block_size
if chunk_idx >= len(mask) or not mask[chunk_idx]:
continue
addr, size, _ = db.prepare_value(start, end, req_meta.block_ids[g_idx])
addr, size, block_id = db.prepare_value(
start, end, req_meta.block_ids[g_idx]
)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
block_id_list.append(block_id)

if not key_list:
Comment thread
Dao007forever marked this conversation as resolved.
Outdated
# Scheduler only schedules loads when there are external tokens
# beyond the local cache, so block_hashes is non-empty and
# mask_num < token_len here. The remaining way to get an empty
# key_list is that every chunk was filtered out by the per-group
# load mask -- e.g. all candidate blocks lie in the SWA
# pre-window that the consumer's spec wouldn't populate locally.
logger.warning(
"Skipping Mooncake load for request %s: every chunk filtered "
"by per-group load mask (token_len=%d, mask_num=%d, "
"num_block_hashes=%d)",
req_id,
token_len,
mask_num,
len(req_meta.block_hashes),
)
self.set_finished_request(req_id)
self.request_queue.task_done()
return
Comment thread
Dao007forever marked this conversation as resolved.
Outdated

# Rotate lists by tp_rank for load balancing
# Rotate aligned lists by tp_rank for load balancing.
rotation = self.tp_rank % len(key_list)
Comment thread
Dao007forever marked this conversation as resolved.
key_list_c = key_list[rotation:] + key_list[:rotation]
addr_list_c = addr_list[rotation:] + addr_list[:rotation]
size_list_c = size_list[rotation:] + size_list[:rotation]
key_list_c = _rotate_list(key_list, rotation)
addr_list_c = _rotate_list(addr_list, rotation)
size_list_c = _rotate_list(size_list, rotation)
block_id_list_c = _rotate_list(block_id_list, rotation)
Comment thread
Dao007forever marked this conversation as resolved.

load_batches = [(key_list_c, addr_list_c, size_list_c)]
load_batches = [(key_list_c, addr_list_c, size_list_c, block_id_list_c)]
if self.usable_disk_offload_buffer_budget_bytes is not None:
total_staging_bytes = sum(
_estimate_disk_offload_staging_bytes(size) for size in size_list_c
)
if total_staging_bytes > self.usable_disk_offload_buffer_budget_bytes:
assert self.disk_offload_buffer_budget_bytes is not None
load_batches, oversized_key = _split_disk_offload_load_batches(
split_batches, oversized_key = _split_disk_offload_load_batches(
key_list_c,
addr_list_c,
size_list_c,
Expand All @@ -705,6 +749,9 @@ def _handle_request(self, req_meta: ReqMeta):
)
if oversized_key is not None:
oversized_key_index = key_list_c.index(oversized_key)
self._add_load_error_block_ids(
[block_id_list_c[oversized_key_index]]
)
Comment thread
Dao007forever marked this conversation as resolved.
Outdated
oversized_key_bytes = _estimate_disk_offload_staging_bytes(
size_list_c[oversized_key_index]
)
Expand All @@ -719,11 +766,24 @@ def _handle_request(self, req_meta: ReqMeta):
self.set_finished_request(req_id)
self.request_queue.task_done()
return
load_batches = []
block_id_offset = 0
for batch_keys, batch_addrs, batch_sizes in split_batches:
next_block_id_offset = block_id_offset + len(batch_keys)
batch_block_ids = block_id_list_c[
block_id_offset:next_block_id_offset
]
load_batches.append(
(batch_keys, batch_addrs, batch_sizes, batch_block_ids)
)
block_id_offset = next_block_id_offset

current_batch_keys: list[str] = key_list_c
current_batch_block_ids: list[int] = block_id_list_c
try:
for batch_keys, batch_addrs, batch_sizes in load_batches:
for batch_keys, batch_addrs, batch_sizes, batch_block_ids in load_batches:
current_batch_keys = batch_keys
current_batch_block_ids = batch_block_ids
tiers_by_key: dict[str, str] | None = None
if envs.VLLM_MOONCAKE_STORE_TIER_LOG:
tiers_by_key = _get_replica_tiers_by_key(self.store, batch_keys)
Expand All @@ -735,20 +795,26 @@ def _handle_request(self, req_meta: ReqMeta):
req_id, batch_keys, res, tiers_by_key
)
failed = [
(key, value)
for key, value in zip(batch_keys, res, strict=True)
(key, value, block_id)
for key, value, block_id in zip(
batch_keys, res, batch_block_ids, strict=True
)
if value < 0
]
if failed:
self._add_load_error_block_ids(
[block_id for _, _, block_id in failed]
)
logger.warning(
"Failed to get %d Mooncake keys from sub-batch "
"(batch_keys=%d, first_failures=%s)",
len(failed),
len(batch_keys),
failed[:3],
[(key, value) for key, value, _ in failed[:3]],
)
break
except Exception as e:
self._add_load_error_block_ids(current_batch_block_ids)
logger.warning(
"Failed to get Mooncake sub-batch %s, error: %s",
current_batch_keys[:3],
Expand Down Expand Up @@ -1155,6 +1221,11 @@ def get_finished(
)
return done_sending, done_recving

def get_block_ids_with_load_errors(self) -> set[int]:
    """Return the receiving thread's load-error block ids, draining them.

    When no receiving thread exists (``kv_recv_thread`` is ``None``) an
    empty set is returned.
    """
    recv_thread = self.kv_recv_thread
    return (
        set()
        if recv_thread is None
        else recv_thread.get_and_clear_block_ids_with_load_errors()
    )

def _get_and_clear_finished_sending(
self,
finished_req_ids: set[str],
Expand Down
Loading