Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)

def get_block_ids_with_load_errors(self) -> set[int]:
    """Get the set of block IDs that failed to load."""
    worker = self.connector_worker
    return worker.get_block_ids_with_load_errors() if worker is not None else set()

def wait_for_layer_load(self, layer_name: str) -> None:
    """No-op: MooncakeConnector does not perform layerwise saving."""
    return
Expand Down Expand Up @@ -541,6 +547,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):

self.finished_sending_reqs: set[ReqId] = set()
self.finished_recving_reqs: set[ReqId] = set()
self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
# Track invalid block IDs for failed transfers (similar to nixl_connector)
self._invalid_block_ids: set[int] = set()

self.block_size = vllm_config.cache_config.block_size
self.model_config = vllm_config.model_config
Expand Down Expand Up @@ -974,6 +983,39 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
    """Return request ids whose KV-cache pull finished since the last call.

    Swaps out the accumulated finished set, then scans tracked in-flight
    pulls and expires any whose deadline has passed: their local block ids
    are recorded in ``self._invalid_block_ids`` (surfaced later through
    ``get_block_ids_with_load_errors``) and the request is reported as
    finished so the scheduler can clean it up instead of waiting forever.

    Returns:
        Set of decode-side request ids that completed (or timed out).
    """
    finished_recving_reqs = self.finished_recving_reqs
    self.finished_recving_reqs = set()

    # Handle timeout to avoid stranding blocks on remote
    now = time.perf_counter()

    expired_req_ids = []
    # Iterate over snapshots: receiver coroutines on the same event loop
    # may mutate these dicts between scheduler polls.
    for remote_engine_id, pull_metas in list(self.reqs_to_recv.items()):
        for req_id, pull_meta in list(pull_metas.items()):
            if pull_meta.expire_time < now:
                logger.warning(
                    "Request %s timed out after %d seconds without "
                    "finishing KV transfer from remote engine %s. "
                    "Marking %d blocks as invalid to prevent garbage output.",
                    pull_meta.d_req_id,
                    envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
                    remote_engine_id,
                    len(pull_meta.local_block_ids),
                )
                # Mark blocks as invalid to prevent garbage output
                self._invalid_block_ids.update(pull_meta.local_block_ids)
                finished_recving_reqs.add(pull_meta.d_req_id)
                expired_req_ids.append((remote_engine_id, req_id))

    # Remove expired requests from tracking. Also drop a per-engine dict
    # once it becomes empty: reqs_to_recv is a defaultdict, so stale empty
    # entries for departed engines would otherwise accumulate unboundedly.
    for remote_engine_id, req_id in expired_req_ids:
        pull_metas = self.reqs_to_recv.get(remote_engine_id)
        if pull_metas is not None:
            pull_metas.pop(req_id, None)
            if not pull_metas:
                del self.reqs_to_recv[remote_engine_id]

    return finished_recving_reqs

async def fetch_finished_sending_reqs(self) -> set[ReqId]:
Expand Down Expand Up @@ -1088,7 +1130,17 @@ async def receive_kv_from_single_worker(
except zmq.ContextTerminated:
logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
except Exception as e:
logger.error("MooncakeXferMetadata transfer failed for %s: %s", req_ids, e)
logger.error(
"MooncakeXferMetadata transfer failed for %s: %s. "
"Marking associated blocks as invalid to prevent garbage output.",
req_ids,
e,
)
# Add failed requests to finished_recving_reqs for scheduler cleanup
for req_id, pull_meta in pull_metas.items():
# Mark blocks as invalid to prevent garbage output
self._invalid_block_ids.update(pull_meta.local_block_ids)
self.finished_recving_reqs.add(pull_meta.d_req_id)
return

def process_pulling_result(
Expand All @@ -1110,10 +1162,18 @@ def process_pulling_result(

if response.err_reqs:
logger.error(
"pulling kv_caches for %s failed: %s",
"pulling kv_caches for %s failed: %s. "
"Marking associated blocks as invalid to prevent garbage output.",
response.err_reqs,
response.err_msg,
)
# Add failed requests to finished_recving_reqs for scheduler cleanup
for req_id in response.err_reqs:
if req_id in pull_metas:
pull_meta = pull_metas[req_id]
# Mark blocks as invalid to prevent garbage output
self._invalid_block_ids.update(pull_meta.local_block_ids)
self.finished_recving_reqs.add(pull_meta.d_req_id)

async def _connect_to_prefiller_bootstrap(self, remote_bootstrap_addr: str):
url = remote_bootstrap_addr + "/query"
Expand Down Expand Up @@ -1157,8 +1217,13 @@ def receive_kv(
raise NotImplementedError(
"Mooncake: Heterogeneous TP is not supported yet."
)
expire_time = time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
for pull_meta in pull_metas.values():
pull_meta.pull_tasks_count = count
# Set expire time to avoid infinitely waiting for remote KV transfer
# Only set if not already set to avoid race conditions
if pull_meta.expire_time == float("inf"):
pull_meta.expire_time = expire_time
for remote_tp_rank in remote_tp_ranks:
worker_addr = self._remote_agents[remote_engine_id][remote_tp_rank][0]
asyncio.create_task(
Expand Down Expand Up @@ -1228,6 +1293,8 @@ async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):

def start_load_kv(self, metadata: MooncakeConnectorMetadata):
if not self.is_kv_producer and metadata.reqs_to_recv:
# Store requests to receive for timeout tracking
self.reqs_to_recv.update(metadata.reqs_to_recv)
asyncio.run_coroutine_threadsafe(
self._start_load_kv(metadata.reqs_to_recv), self.receiver_loop
)
Expand All @@ -1239,6 +1306,20 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
self.record_send_reqs(metadata), self.sender_loop
)

def get_block_ids_with_load_errors(self) -> set[int]:
    """
    Return and clear the set of block IDs that failed to load.

    Called by the scheduler to identify blocks that must be retried
    after a Mooncake transfer failure or timeout.

    Returns:
        Set of block IDs that encountered load errors; empty if none.
    """
    # Atomically hand back the current set and reset the accumulator.
    invalid, self._invalid_block_ids = self._invalid_block_ids, set()
    return invalid


def group_concurrent_contiguous(
src_indices: list[int], dst_indices: list[int]
Expand Down
Loading