Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions tests/v1/kv_connector/unit/test_mooncake_store_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,35 @@ def test_store_sending_thread_only_skips_on_no_available_handle():
assert store.batch_put_from_multi_buffers.call_count == 2


def test_store_sending_thread_releases_pin_on_batch_is_exist_failure():
    """A raising `batch_is_exist` must still decrement `stored_requests`.

    The decrement is what lets the scheduler drop `delay_free_blocks` and
    release the pinned GPU blocks, so it has to happen even though the
    exception propagates out of `_handle_request`.
    """
    failing_store = MagicMock()
    failing_store.batch_is_exist.side_effect = RuntimeError("mooncake down")
    sender = _make_store_sending_thread(failing_store)

    sender.add_stored_request("req-a")
    # The dedup lookup error is re-raised to the caller, not swallowed.
    with pytest.raises(RuntimeError):
        sender._handle_request(_make_store_req("req-a", [b"a0", b"a1"]))

    # The put must never be reached once the existence check has failed,
    # and the in-flight counter must be back to zero regardless.
    failing_store.batch_put_from_multi_buffers.assert_not_called()
    assert sender.stored_requests["req-a"] == 0


def test_store_sending_thread_releases_pin_on_batch_put_failure():
    """A raising `batch_put_from_multi_buffers` must still release the pin.

    The put error is logged (not re-raised) by `_handle_request`, and the
    `stored_requests` counter must still be decremented through the
    `finally` block.
    """
    store = MagicMock()
    # Report both blocks as missing so the request proceeds to the put step.
    store.batch_is_exist.return_value = [0, 0]
    store.batch_put_from_multi_buffers.side_effect = RuntimeError("rdma error")
    thread = _make_store_sending_thread(store)

    thread.add_stored_request("req-a")
    # Must not raise: the put failure is handled inside `_handle_request`.
    thread._handle_request(_make_store_req("req-a", [b"a0", b"a1"]))

    # Guard against a vacuous pass: the early-return paths (nothing to store,
    # pressure skip, all blocks already present) also decrement the counter,
    # so confirm the failing put was actually attempted.
    store.batch_put_from_multi_buffers.assert_called_once()
    assert thread.stored_requests["req-a"] == 0


def test_store_recving_thread_reports_failed_block_ids():
store = MagicMock()
store.batch_get_into_multi_buffers.return_value = [256, -5, -7]
Expand Down
325 changes: 164 additions & 161 deletions vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,187 +517,190 @@ def _handle_request(self, req_meta: ReqMeta):
if req_id not in self.stored_requests:
self.request_queue.task_done()
return
if token_len == 0:
self.dec_stored_request(req_id)
self.request_queue.task_done()
return
if self._should_skip_request(req_id):
logger.debug(
"Skipping Mooncake store for request %s while CPU/disk offloading "
"is under pressure",
req_id,
)
self.dec_stored_request(req_id)
self.request_queue.task_done()
return

# Within each lcm region only per-spec relevant chunks are loaded
# (e.g., SWA or linear attn), so mask out irrelevant chunks
store_masks = self.coord.store_mask(token_len)
starts: list[int] = []
ends: list[int] = []
keys: list[str] = []
block_hashes: list[BlockHash] = []
group_indices: list[int] = []
for g_idx, db in enumerate(self.token_databases):
mask = store_masks[g_idx]
for chunk_idx, (start, end, key) in enumerate(
db.process_tokens(token_len, req_meta.block_hashes)
):
if chunk_idx >= len(mask) or not mask[chunk_idx]:
continue
starts.append(start)
ends.append(end)
keys.append(key.to_string())
block_hashes.append(req_meta.block_hashes[chunk_idx])
group_indices.append(g_idx)

# Apply put_step striding for TP
sl = slice(self.tp_rank % self.put_step, None, self.put_step)
starts = starts[sl]
ends = ends[sl]
keys = keys[sl]
block_hashes = block_hashes[sl]
group_indices = group_indices[sl]

if not keys:
self.dec_stored_request(req_id)
return

# Check which blocks already exist (dedup)
save_exists_start = time.perf_counter()
# Decrement the in-flight counter and signal task_done() in `finally`
# so the scheduler can release the GPU blocks it pinned for this
# request (via `delay_free_blocks`) even when the store path raises.
try:
exists_states = self.store.batch_is_exist(keys)
except Exception:
if token_len == 0:
return
if self._should_skip_request(req_id):
logger.debug(
"Skipping Mooncake store for request %s while CPU/disk "
"offloading is under pressure",
req_id,
)
return

# Within each lcm region only per-spec relevant chunks are loaded
# (e.g., SWA or linear attn), so mask out irrelevant chunks
store_masks = self.coord.store_mask(token_len)
starts: list[int] = []
ends: list[int] = []
keys: list[str] = []
block_hashes: list[BlockHash] = []
group_indices: list[int] = []
for g_idx, db in enumerate(self.token_databases):
mask = store_masks[g_idx]
for chunk_idx, (start, end, key) in enumerate(
db.process_tokens(token_len, req_meta.block_hashes)
):
if chunk_idx >= len(mask) or not mask[chunk_idx]:
continue
starts.append(start)
ends.append(end)
keys.append(key.to_string())
block_hashes.append(req_meta.block_hashes[chunk_idx])
group_indices.append(g_idx)

# Apply put_step striding for TP
sl = slice(self.tp_rank % self.put_step, None, self.put_step)
starts = starts[sl]
ends = ends[sl]
keys = keys[sl]
block_hashes = block_hashes[sl]
group_indices = group_indices[sl]

if not keys:
return

# Check which blocks already exist (dedup)
save_exists_start = time.perf_counter()
try:
exists_states = self.store.batch_is_exist(keys)
except Exception:
self._record_operation(
"save_exists",
save_exists_start,
len(keys),
status="error",
num_failed_keys=len(keys),
)
raise
self._record_operation(
"save_exists",
save_exists_start,
len(keys),
status="error",
num_failed_keys=len(keys),
)
raise
self._record_operation(
"save_exists",
save_exists_start,
len(keys),
)
missing_indices = [i for i, exists in enumerate(exists_states) if exists != 1]
missing_indices = [
i for i, exists in enumerate(exists_states) if exists != 1
]

if not missing_indices:
self.dec_stored_request(req_id)
return
if not missing_indices:
return

starts = [starts[i] for i in missing_indices]
ends = [ends[i] for i in missing_indices]
keys = [keys[i] for i in missing_indices]
block_hashes = [block_hashes[i] for i in missing_indices]
group_indices = [group_indices[i] for i in missing_indices]
starts = [starts[i] for i in missing_indices]
ends = [ends[i] for i in missing_indices]
keys = [keys[i] for i in missing_indices]
block_hashes = [block_hashes[i] for i in missing_indices]
group_indices = [group_indices[i] for i in missing_indices]

logger.debug(
"Storing KV cache for %d blocks (groups=%s) for request %s",
len(keys),
set(group_indices),
req_id,
)
logger.debug(
"Storing KV cache for %d blocks (groups=%s) for request %s",
len(keys),
set(group_indices),
req_id,
)

addrs: list[list[int]] = []
sizes: list[list[int]] = []
stored_events: list[BlockStored] = []
# parent_block_hash chains live within a group, not across.
prev_key_per_group: dict[int, Any] = {}
new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes]
addrs: list[list[int]] = []
sizes: list[list[int]] = []
stored_events: list[BlockStored] = []
# parent_block_hash chains live within a group, not across.
prev_key_per_group: dict[int, Any] = {}
new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes]

for idx, (s, e, g_idx) in enumerate(
zip(starts, ends, group_indices, strict=True)
):
db = self.token_databases[g_idx]
addr, size, _ = db.prepare_value(s, e, block_ids_per_group[g_idx])
addrs.append(addr)
sizes.append(size)

if self.enable_kv_event:
token_ids = (
req_meta.token_ids[s:e] if req_meta.token_ids is not None else None
)
stored_event = BlockStored(
block_hashes=[new_block_hashes[idx]],
parent_block_hash=prev_key_per_group.get(g_idx),
token_ids=token_ids,
block_size=req_meta.original_block_size,
lora_id=None,
medium="cpu",
lora_name=None,
)
stored_events.append(stored_event)
prev_key_per_group[g_idx] = new_block_hashes[idx]
for idx, (s, e, g_idx) in enumerate(
zip(starts, ends, group_indices, strict=True)
):
db = self.token_databases[g_idx]
addr, size, _ = db.prepare_value(s, e, block_ids_per_group[g_idx])
addrs.append(addr)
sizes.append(size)

if self.enable_kv_event:
token_ids = (
req_meta.token_ids[s:e]
if req_meta.token_ids is not None
else None
)
stored_event = BlockStored(
block_hashes=[new_block_hashes[idx]],
parent_block_hash=prev_key_per_group.get(g_idx),
token_ids=token_ids,
block_size=req_meta.original_block_size,
lora_id=None,
medium="cpu",
lora_name=None,
)
stored_events.append(stored_event)
prev_key_per_group[g_idx] = new_block_hashes[idx]

if current_event is not None:
current_event.synchronize()
if current_event is not None:
current_event.synchronize()

batch_bytes = _sum_batch_bytes(sizes)
put_start = time.perf_counter()
try:
res = self.store.batch_put_from_multi_buffers(
keys,
addrs,
sizes,
self.replicate_config,
)
failed = [i for i, v in enumerate(res) if v < 0]
self._record_operation(
"save_put",
put_start,
len(keys),
num_bytes=batch_bytes,
status="partial_failure" if failed else "ok",
num_failed_keys=len(failed),
)
if failed:
failed_codes = set(res[i] for i in failed)
logger.warning(
"batch_put failed: %d/%d keys failed "
"(codes=%s, batch_bytes=%d, num_keys=%d), "
"first_key=%s",
len(failed),
len(keys),
failed_codes,
batch_bytes,
batch_bytes = _sum_batch_bytes(sizes)
put_start = time.perf_counter()
try:
res = self.store.batch_put_from_multi_buffers(
keys,
addrs,
sizes,
self.replicate_config,
)
failed = [i for i, v in enumerate(res) if v < 0]
self._record_operation(
"save_put",
put_start,
len(keys),
keys[0] if keys else "N/A",
num_bytes=batch_bytes,
status="partial_failure" if failed else "ok",
num_failed_keys=len(failed),
)
if (
MOONCAKE_NO_AVAILABLE_HANDLE in failed_codes
and not self._mark_request_skipped_for_pressure(req_id)
):
if failed:
failed_codes = set(res[i] for i in failed)
logger.warning(
"Detected Mooncake CPU/disk offloading pressure "
"(NO_AVAILABLE_HANDLE); skipping future store "
"batches for request %s until a later store "
"batch succeeds",
req_id,
"batch_put failed: %d/%d keys failed "
"(codes=%s, batch_bytes=%d, num_keys=%d), "
"first_key=%s",
len(failed),
len(keys),
failed_codes,
batch_bytes,
len(keys),
keys[0] if keys else "N/A",
)
elif self._clear_store_pressure():
logger.info(
"Mooncake CPU/disk offloading pressure cleared after a "
"successful store batch"
if (
MOONCAKE_NO_AVAILABLE_HANDLE in failed_codes
and not self._mark_request_skipped_for_pressure(req_id)
):
logger.warning(
"Detected Mooncake CPU/disk offloading pressure "
"(NO_AVAILABLE_HANDLE); skipping future store "
"batches for request %s until a later store "
"batch succeeds",
req_id,
)
elif self._clear_store_pressure():
logger.info(
"Mooncake CPU/disk offloading pressure cleared after a "
"successful store batch"
)
except Exception as e:
self._record_operation(
"save_put",
put_start,
len(keys),
num_bytes=batch_bytes,
status="error",
num_failed_keys=len(keys),
)
except Exception as e:
self._record_operation(
"save_put",
put_start,
len(keys),
num_bytes=batch_bytes,
status="error",
num_failed_keys=len(keys),
)
logger.error("Failed to put key %s, error: %s", keys, e)
logger.error("Failed to put key %s, error: %s", keys, e)

if self.enable_kv_event and stored_events:
self.update_kv_event(stored_events)

self.dec_stored_request(req_id)
self.request_queue.task_done()
if self.enable_kv_event and stored_events:
self.update_kv_event(stored_events)
finally:
self.dec_stored_request(req_id)
self.request_queue.task_done()


class KVCacheStoreRecvingThread(KVTransferThread):
Expand Down
Loading