diff --git a/tests/v1/kv_connector/unit/test_mooncake_store_worker.py b/tests/v1/kv_connector/unit/test_mooncake_store_worker.py index 6adb045277fa..375aad4eeb8f 100644 --- a/tests/v1/kv_connector/unit/test_mooncake_store_worker.py +++ b/tests/v1/kv_connector/unit/test_mooncake_store_worker.py @@ -461,6 +461,35 @@ def test_store_sending_thread_only_skips_on_no_available_handle(): assert store.batch_put_from_multi_buffers.call_count == 2 +def test_store_sending_thread_releases_pin_on_batch_is_exist_failure(): + # `batch_is_exist` raising must still decrement `stored_requests` so the + # scheduler can drop `delay_free_blocks` and release the pinned GPU blocks. + store = MagicMock() + store.batch_is_exist.side_effect = RuntimeError("mooncake down") + thread = _make_store_sending_thread(store) + + thread.add_stored_request("req-a") + with pytest.raises(RuntimeError): + thread._handle_request(_make_store_req("req-a", [b"a0", b"a1"])) + + assert thread.stored_requests["req-a"] == 0 + store.batch_put_from_multi_buffers.assert_not_called() + + +def test_store_sending_thread_releases_pin_on_batch_put_failure(): + # `batch_put_from_multi_buffers` raising is logged (not re-raised), and the + # pin must still be released through the finally block. + store = MagicMock() + store.batch_is_exist.return_value = [0, 0] + store.batch_put_from_multi_buffers.side_effect = RuntimeError("rdma error") + thread = _make_store_sending_thread(store) + + thread.add_stored_request("req-a") + thread._handle_request(_make_store_req("req-a", [b"a0", b"a1"])) + + assert thread.stored_requests["req-a"] == 0 + + def test_store_recving_thread_reports_failed_block_ids(): store = MagicMock() store.batch_get_into_multi_buffers.return_value = [256, -5, -7] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py index 486c2553b6d1..cd4eb5c37136 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py @@ -517,187 +517,190 @@ def _handle_request(self, req_meta: ReqMeta): if req_id not in self.stored_requests: self.request_queue.task_done() return - if token_len == 0: - self.dec_stored_request(req_id) - self.request_queue.task_done() - return - if self._should_skip_request(req_id): - logger.debug( - "Skipping Mooncake store for request %s while CPU/disk offloading " - "is under pressure", - req_id, - ) - self.dec_stored_request(req_id) - self.request_queue.task_done() - return - # Within each lcm region only per-spec relevant chunks are loaded - # (e.g., SWA or linear attn), so mask out irrelevant chunks - store_masks = self.coord.store_mask(token_len) - starts: list[int] = [] - ends: list[int] = [] - keys: list[str] = [] - block_hashes: list[BlockHash] = [] - group_indices: list[int] = [] - for g_idx, db in enumerate(self.token_databases): - mask = store_masks[g_idx] - for chunk_idx, (start, end, key) in enumerate( - db.process_tokens(token_len, req_meta.block_hashes) - ): - if chunk_idx >= len(mask) or not mask[chunk_idx]: - continue - starts.append(start) - ends.append(end) - keys.append(key.to_string()) - block_hashes.append(req_meta.block_hashes[chunk_idx]) - group_indices.append(g_idx) - - # Apply put_step striding for TP - sl = slice(self.tp_rank % self.put_step, None, self.put_step) - starts = starts[sl] - ends = ends[sl] - keys = keys[sl] - block_hashes = block_hashes[sl] - group_indices = group_indices[sl] - - if not keys: - self.dec_stored_request(req_id) - return - - # Check which blocks already exist (dedup) - save_exists_start = time.perf_counter() + # Decrement the in-flight counter and signal task_done() in `finally` + # so the scheduler can release the GPU blocks it pinned for this + # request (via `delay_free_blocks`) even when the store path raises. try: - exists_states = self.store.batch_is_exist(keys) - except Exception: + if token_len == 0: + return + if self._should_skip_request(req_id): + logger.debug( + "Skipping Mooncake store for request %s while CPU/disk " + "offloading is under pressure", + req_id, + ) + return + + # Within each lcm region only per-spec relevant chunks are loaded + # (e.g., SWA or linear attn), so mask out irrelevant chunks + store_masks = self.coord.store_mask(token_len) + starts: list[int] = [] + ends: list[int] = [] + keys: list[str] = [] + block_hashes: list[BlockHash] = [] + group_indices: list[int] = [] + for g_idx, db in enumerate(self.token_databases): + mask = store_masks[g_idx] + for chunk_idx, (start, end, key) in enumerate( + db.process_tokens(token_len, req_meta.block_hashes) + ): + if chunk_idx >= len(mask) or not mask[chunk_idx]: + continue + starts.append(start) + ends.append(end) + keys.append(key.to_string()) + block_hashes.append(req_meta.block_hashes[chunk_idx]) + group_indices.append(g_idx) + + # Apply put_step striding for TP + sl = slice(self.tp_rank % self.put_step, None, self.put_step) + starts = starts[sl] + ends = ends[sl] + keys = keys[sl] + block_hashes = block_hashes[sl] + group_indices = group_indices[sl] + + if not keys: + return + + # Check which blocks already exist (dedup) + save_exists_start = time.perf_counter() + try: + exists_states = self.store.batch_is_exist(keys) + except Exception: + self._record_operation( + "save_exists", + save_exists_start, + len(keys), + status="error", + num_failed_keys=len(keys), + ) + raise self._record_operation( "save_exists", save_exists_start, len(keys), - status="error", - num_failed_keys=len(keys), ) - raise - self._record_operation( - "save_exists", - save_exists_start, - len(keys), - ) - missing_indices = [i for i, exists in enumerate(exists_states) if exists != 1] + missing_indices = [ + i for i, exists in enumerate(exists_states) if exists != 1 + ] - if not missing_indices: - self.dec_stored_request(req_id) - return + if not missing_indices: + return - starts = [starts[i] for i in missing_indices] - ends = [ends[i] for i in missing_indices] - keys = [keys[i] for i in missing_indices] - block_hashes = [block_hashes[i] for i in missing_indices] - group_indices = [group_indices[i] for i in missing_indices] + starts = [starts[i] for i in missing_indices] + ends = [ends[i] for i in missing_indices] + keys = [keys[i] for i in missing_indices] + block_hashes = [block_hashes[i] for i in missing_indices] + group_indices = [group_indices[i] for i in missing_indices] - logger.debug( - "Storing KV cache for %d blocks (groups=%s) for request %s", - len(keys), - set(group_indices), - req_id, - ) + logger.debug( + "Storing KV cache for %d blocks (groups=%s) for request %s", + len(keys), + set(group_indices), + req_id, + ) - addrs: list[list[int]] = [] - sizes: list[list[int]] = [] - stored_events: list[BlockStored] = [] - # parent_block_hash chains live within a group, not across. - prev_key_per_group: dict[int, Any] = {} - new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes] + addrs: list[list[int]] = [] + sizes: list[list[int]] = [] + stored_events: list[BlockStored] = [] + # parent_block_hash chains live within a group, not across. + prev_key_per_group: dict[int, Any] = {} + new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes] - for idx, (s, e, g_idx) in enumerate( - zip(starts, ends, group_indices, strict=True) - ): - db = self.token_databases[g_idx] - addr, size, _ = db.prepare_value(s, e, block_ids_per_group[g_idx]) - addrs.append(addr) - sizes.append(size) - - if self.enable_kv_event: - token_ids = ( - req_meta.token_ids[s:e] if req_meta.token_ids is not None else None - ) - stored_event = BlockStored( - block_hashes=[new_block_hashes[idx]], - parent_block_hash=prev_key_per_group.get(g_idx), - token_ids=token_ids, - block_size=req_meta.original_block_size, - lora_id=None, - medium="cpu", - lora_name=None, - ) - stored_events.append(stored_event) - prev_key_per_group[g_idx] = new_block_hashes[idx] + for idx, (s, e, g_idx) in enumerate( + zip(starts, ends, group_indices, strict=True) + ): + db = self.token_databases[g_idx] + addr, size, _ = db.prepare_value(s, e, block_ids_per_group[g_idx]) + addrs.append(addr) + sizes.append(size) + + if self.enable_kv_event: + token_ids = ( + req_meta.token_ids[s:e] + if req_meta.token_ids is not None + else None + ) + stored_event = BlockStored( + block_hashes=[new_block_hashes[idx]], + parent_block_hash=prev_key_per_group.get(g_idx), + token_ids=token_ids, + block_size=req_meta.original_block_size, + lora_id=None, + medium="cpu", + lora_name=None, + ) + stored_events.append(stored_event) + prev_key_per_group[g_idx] = new_block_hashes[idx] - if current_event is not None: - current_event.synchronize() + if current_event is not None: + current_event.synchronize() - batch_bytes = _sum_batch_bytes(sizes) - put_start = time.perf_counter() - try: - res = self.store.batch_put_from_multi_buffers( - keys, - addrs, - sizes, - self.replicate_config, - ) - failed = [i for i, v in enumerate(res) if v < 0] - self._record_operation( - "save_put", - put_start, - len(keys), - num_bytes=batch_bytes, - status="partial_failure" if failed else "ok", - num_failed_keys=len(failed), - ) - if failed: - failed_codes = set(res[i] for i in failed) - logger.warning( - "batch_put failed: %d/%d keys failed " - "(codes=%s, batch_bytes=%d, num_keys=%d), " - "first_key=%s", - len(failed), - len(keys), - failed_codes, - batch_bytes, + batch_bytes = _sum_batch_bytes(sizes) + put_start = time.perf_counter() + try: + res = self.store.batch_put_from_multi_buffers( + keys, + addrs, + sizes, + self.replicate_config, + ) + failed = [i for i, v in enumerate(res) if v < 0] + self._record_operation( + "save_put", + put_start, len(keys), - keys[0] if keys else "N/A", + num_bytes=batch_bytes, + status="partial_failure" if failed else "ok", + num_failed_keys=len(failed), ) - if ( - MOONCAKE_NO_AVAILABLE_HANDLE in failed_codes - and not self._mark_request_skipped_for_pressure(req_id) - ): + if failed: + failed_codes = set(res[i] for i in failed) logger.warning( - "Detected Mooncake CPU/disk offloading pressure " - "(NO_AVAILABLE_HANDLE); skipping future store " - "batches for request %s until a later store " - "batch succeeds", - req_id, + "batch_put failed: %d/%d keys failed " + "(codes=%s, batch_bytes=%d, num_keys=%d), " + "first_key=%s", + len(failed), + len(keys), + failed_codes, + batch_bytes, + len(keys), + keys[0] if keys else "N/A", ) - elif self._clear_store_pressure(): - logger.info( - "Mooncake CPU/disk offloading pressure cleared after a " - "successful store batch" + if ( + MOONCAKE_NO_AVAILABLE_HANDLE in failed_codes + and not self._mark_request_skipped_for_pressure(req_id) + ): + logger.warning( + "Detected Mooncake CPU/disk offloading pressure " + "(NO_AVAILABLE_HANDLE); skipping future store " + "batches for request %s until a later store " + "batch succeeds", + req_id, + ) + elif self._clear_store_pressure(): + logger.info( + "Mooncake CPU/disk offloading pressure cleared after a " + "successful store batch" + ) + except Exception as e: + self._record_operation( + "save_put", + put_start, + len(keys), + num_bytes=batch_bytes, + status="error", + num_failed_keys=len(keys), ) - except Exception as e: - self._record_operation( - "save_put", - put_start, - len(keys), - num_bytes=batch_bytes, - status="error", - num_failed_keys=len(keys), - ) - logger.error("Failed to put key %s, error: %s", keys, e) + logger.error("Failed to put key %s, error: %s", keys, e) - if self.enable_kv_event and stored_events: - self.update_kv_event(stored_events) - - self.dec_stored_request(req_id) - self.request_queue.task_done() + if self.enable_kv_event and stored_events: + self.update_kv_event(stored_events) + finally: + self.dec_stored_request(req_id) + self.request_queue.task_done() class KVCacheStoreRecvingThread(KVTransferThread):