
[P/D] rework mooncake connector and introduce its bootstrap server #31034

Merged
vllm-bot merged 15 commits into vllm-project:main from openanolis:dtcccc/mooncake_connector on Feb 3, 2026

Conversation

@dtcccc
Contributor

@dtcccc dtcccc commented Dec 19, 2025

Purpose

Rework the mooncake connector to achieve better performance and prepare for more features in the future.
Introduce a central bootstrap server on P.
1. Init phase:
All P workers register their info (dp/tp/pp rank, ZMQ worker address) with the bootstrap server.
After all P workers have finished registering, the proxy and D workers can query the server whenever they encounter a new engine_id.
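For illustration, a minimal sketch of the registration step, assuming a hypothetical /register_worker endpoint and payload shape (the real endpoint and fields live in MooncakeBootstrapServer):

# Hypothetical sketch of a P worker registering with the bootstrap server.
# The endpoint name and payload fields are assumptions, not the actual API.
import httpx

async def register_worker(bootstrap_addr: str, engine_id: str, dp_rank: int,
                          tp_rank: int, pp_rank: int, zmq_addr: str) -> None:
    payload = {
        "engine_id": engine_id,
        "dp_rank": dp_rank,
        "tp_rank": tp_rank,
        "pp_rank": pp_rank,
        "zmq_addr": zmq_addr,
    }
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"http://{bootstrap_addr}/register_worker", json=payload
        )
        resp.raise_for_status()  # surface registration failures to the caller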

Note:
(deprecated)

Since #30739, data_parallel_size and data_parallel_size_local are pinned to 1 for non-MoE models, so origin_data_parallel_size and origin_data_parallel_size_local are introduced to retrieve the raw values.
For a non-MoE model with dp_size > 1, all of its DP engines share the same engine_id, so dp_ranks and workers cannot be distinguished. Therefore [engine_id][dp_rank] is used as the key. See the comments in MooncakeBootstrapServer for details.

After startup, the proxy sends requests to both P and D concurrently. Due to #27987, P and D cannot know each other's exact request_id. #32630 seems to be a solution; as a workaround here, I simply drop the last 9 characters. See the comments in TruncatingDict.

Thanks to #33037 and #32937, all of these workarounds can now be dropped.

This design is partially inspired by sglang, aiming to improve Time To First Token (TTFT) performance and to lay the groundwork for a future layerwise transfer feature.

With a random dataset and max-concurrency = 1, TTFT on two A10 machines running Qwen2.5-7B-Instruct improves:
random-input-len 128: 83.41 ms -> 77.53 ms
random-input-len 1024: 252.35 ms -> 246.25 ms
This shows a TTFT gain of about 6 ms from running P and D simultaneously.

Other highlights of this PR:

  • Introduces the proxy and example scripts for the new bootstrap server architecture, along with updated documentation to facilitate learning and testing.
  • Prepares for the future introduction of heterogeneous TP. For now, heterogeneous TP and PP are explicitly rejected.
  • Picks up a bugfix from [Bugfix] Fix _reqs_to_process leak on abort #26012

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify

mergify bot commented Dec 19, 2025

Documentation preview: https://vllm--31034.org.readthedocs.build/en/31034/

@mergify mergify bot added the documentation (Improvements or additions to documentation) and kv-connector labels on Dec 19, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request is a significant and well-structured rework of the mooncake connector, introducing a central bootstrap server to improve performance and enable future features. The refactoring to an asynchronous design is a great improvement. My review focuses on the robustness of the new distributed communication patterns. I've identified a few high-severity issues related to error handling and idempotency that could impact the system's reliability, particularly under failure conditions.

Comment on lines +415 to +423
except Exception as e:
    err_msg = (
        e.response.text if isinstance(e, httpx.HTTPStatusError) else str(e)
    )
    logger.error(
        "Failed to register request %s with bootstrap server: %s",
        req_id,
        err_msg,
    )
Contributor


high

The failure to register a request with the bootstrap server is handled by logging an error, but the failure is not propagated. Since register_req_with_bootstrap is called in a fire-and-forget manner, the scheduler remains unaware of the failure. This will likely cause the request to hang on the decoder side until it times out, which is not a clean failure mode and makes debugging difficult. A more robust solution would be to communicate this failure back to the scheduler (e.g., via a queue) so it can abort the request immediately with a clear error reason.
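For illustration, a minimal sketch of the queue-based propagation suggested above; MooncakeScheduler, _failed_reqs, and the call signature are assumptions, not the connector's actual structure:

import asyncio
import logging

logger = logging.getLogger(__name__)

class MooncakeScheduler:  # hypothetical holder of the connector state
    def __init__(self) -> None:
        # Drained by the scheduler to abort requests with a clear reason.
        self._failed_reqs: asyncio.Queue[tuple[str, str]] = asyncio.Queue()

    async def register_req_with_bootstrap(
        self, client, url: str, req_id: str, payload: dict
    ) -> None:
        try:
            resp = await client.post(url, json=payload)
            resp.raise_for_status()
        except Exception as e:
            logger.error("Failed to register request %s: %s", req_id, e)
            # Record the failure instead of fire-and-forget, so the
            # decoder side is not left to hang until timeout.
            self._failed_reqs.put_nowait((req_id, str(e)))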

Collaborator


I think this is a valid concern, we should at least add a TODO.
Re: disconnects I am assuming the proxy forwards to both instances

Comment on lines +1170 to +1179
except Exception as e:
    err_msg = (
        e.response.text if isinstance(e, httpx.HTTPStatusError) else str(e)
    )
    logger.error(
        "Failed to query bootstrap server for %d requests: %s",
        len(req_ids),
        err_msg,
    )
    return {}
Contributor


high

The except block in batch_query_requests catches all exceptions and returns an empty dictionary. The calling function, handle_bootstrap_group, does not correctly handle this case, leading to the requests in the failed batch being effectively dropped. They are not retried or explicitly failed, which will likely cause them to hang until the scheduler times them out. This can make the system brittle, especially if the bootstrap server is temporarily unavailable. The client should implement retries for transient errors (e.g., 5xx status codes, connection errors) and handle non-retriable errors more explicitly.
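For illustration, a minimal retry sketch along the lines suggested above, retrying transport errors and 5xx responses with exponential backoff; the function name and parameters are illustrative:

import asyncio
import httpx

async def query_with_retries(client: httpx.AsyncClient, url: str,
                             payload: dict, max_attempts: int = 3) -> dict:
    delay = 0.1
    for attempt in range(1, max_attempts + 1):
        try:
            resp = await client.post(url, json=payload)
            resp.raise_for_status()
            return resp.json()
        except httpx.HTTPStatusError as e:
            if e.response.status_code < 500:
                raise  # 4xx: non-retriable, surface immediately
            last_exc = e  # 5xx: retriable server-side failure
        except httpx.TransportError as e:  # connection errors: retriable
            last_exc = e
        if attempt < max_attempts:
            await asyncio.sleep(delay)
            delay *= 2  # exponential backoff
    raise last_exc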

Comment on lines +1399 to +1404
if (reg := self.req_to_dp_rank.get(payload.req_id)) is not None:
    raise HTTPException(
        status_code=400,
        detail=f"Request '{payload.req_id}' already registered with rank {reg} "
        f"but still want to register with rank {payload.dp_rank}",
    )
Contributor


high

The register_request endpoint is not idempotent. If a client retries a registration request (e.g., due to a transient network error), the second attempt will fail with a 400 error because the request ID is already present. This prevents the implementation of a robust retry mechanism on the client side. The endpoint should handle duplicate registrations for the same request and DP rank gracefully, for example, by returning a success response.

        if (reg := self.req_to_dp_rank.get(payload.req_id)) is not None:
            if reg == payload.dp_rank:
                # Request is already registered with the same dp_rank, treat as success.
                return {"status": "ok"}
            raise HTTPException(
                status_code=400,
                detail=f"Request '{payload.req_id}' already registered with rank {reg} "
                f"but still want to register with rank {payload.dp_rank}",
            )

@@ -0,0 +1,344 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

Collaborator


why do we need a separate proxy for mooncake?

Contributor Author


From what I've seen so far, the proxies used by vLLM's various connectors aren't unified. For example, disagg_proxy_demo.py doesn't support the nixl-specific "do_remote_decode" and "do_remote_prefill" parameters, which is why nixl has its own toy_proxy_server.py. So, in my view, it's common practice for each connector to have its own dedicated proxy.

To get back to the main point, we are using this specific proxy because it aligns with the central bootstrap server architecture we want to introduce. This is a prerequisite for supporting layerwise transfer in the future. (The other existing proxies only send data to D after receiving the full result from P, which is too late for our use case.)


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +1055 to +1058
await sock.send(encoded_data)
while True:
    ret_msg = await sock.recv()
    response = self._xfer_resp_decoder.decode(ret_msg)


P1: KV pull waits indefinitely when prefiller never replies

The decoder’s receive_kv loop now waits on await sock.recv() with no timeout or cancellation. If the prefiller never responds (e.g., wrong bootstrap address or the producer crashes mid-transfer), this coroutine hangs forever and the request is never re-queued or marked finished, so the abort timeout VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT is never enforced and the decode side leaks the request until process shutdown. The previous implementation set a receive timeout; we should restore a bounded wait or explicit cancellation/retry.
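For illustration, a minimal sketch of a bounded wait, assuming a zmq.asyncio socket; the timeout constant is illustrative (e.g., it could be tied to VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT):

import asyncio

RECV_TIMEOUT_S = 30.0  # illustrative bound, not the connector's actual value

async def recv_with_timeout(sock) -> bytes:
    # asyncio.wait_for cancels the pending recv on timeout, so the caller
    # can fail or re-queue the request instead of leaking it.
    return await asyncio.wait_for(sock.recv(), timeout=RECV_TIMEOUT_S)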


@mergify

mergify bot commented Dec 19, 2025

Hi @dtcccc, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Collaborator

@NickLucche NickLucche left a comment


Thanks for the work once again @dtcccc !

After some thinking, I believe all in all these changes are already fine.

My main concern is the added complexity that this design brings in: we have a longer "failure-chain" to propagate through, and effectively an extra point of failure in the side FastAPI server, which needs failure handling for the registering/querying endpoints.
In particular, I feel the extra "registering" call that each P worker now has to perform in its flow is the weakest link, so we really have to be sure the benefits of having D be aware of the req at step 0 outweigh this overhead.
To that extent, it would be nice to have a broader benchmark sweep to confirm the TTFT gains, possibly comparing against a previous version plus "Refactored the sender thread using async coroutines" (in hindsight, this should've been a separate PR to help the review process here).

But I also understand that the alternative to the point above would be to "push" the dp_rank request-selection at the proxy level (or some Coordinator in front of the DP instances @robertgshaw2-redhat ) which would take away control from the connector and/or require more invasive changes. Therefore I am overall ok with this, but just wanted to bring up a few points for discussion.

One qq, what is currently stopping D from running query_requests before P registers the request (3/4)?

cc @wseaton for changes that are very similar to a past work of yours (+opinion on future failure handling?)

Comment on lines +66 to +71
http_log_level = logger.getEffectiveLevel()
# INFO logs from httpx are too noisy. Silence them.
# Set the vllm log level to DEBUG if we really want to see them.
if http_log_level == logging.INFO:
    http_log_level = logging.WARNING
logging.getLogger("httpx").setLevel(http_log_level)
Collaborator


this is a bit arbitrary, we should probably either do it in the MooncakeConnector init or push a separate global change cc @markmc

Either way, we should log that we're "silencing" logs.



class MooncakeConnectorMetadata(KVConnectorMetadata):
    def __init__(self):
-       self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
+       self.reqs_to_recv: list[PullReqMeta] = []
Collaborator


why did we switch to a list here?
If we want to group the reqs by remote boostrap server, can't we just use a dict[server_addr, list[meta]] ?
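For illustration, a sketch of the dict-based grouping; the PullReqMeta shape and the remote_addr field are assumptions:

from collections import defaultdict
from dataclasses import dataclass

@dataclass
class PullReqMeta:  # assumed shape, standing in for the connector's class
    req_id: str
    remote_addr: str

def group_by_server(pull_metas: list[PullReqMeta]) -> dict[str, list[PullReqMeta]]:
    grouped: dict[str, list[PullReqMeta]] = defaultdict(list)
    for meta in pull_metas:
        grouped[meta.remote_addr].append(meta)
    return grouped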


Comment on lines +1319 to +1321
class MooncakeBootstrapServer:
    """
    A centralized server running on the global rank 0 prefiller worker.
Collaborator


nit: could probably live in a separate mooncake_utils file

f"expected {self.tp_size}, got {payload.tp_size}"
),
)

Collaborator


logger debug with source payload info could help here
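For illustration, a debug log along those lines to slot into the rejection path; the payload field names are assumptions:

logger.debug(
    "Rejecting registration from engine_id=%s dp_rank=%s tp_rank=%s: "
    "expected tp_size=%s, got %s",
    payload.engine_id, payload.dp_rank, payload.tp_rank,
    self.tp_size, payload.tp_size,
)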

Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 7 potential issues.


    ready=asyncio.Event(),
)
for p_req_id in metadata.reqs_not_processed:
    send_meta = self.reqs_need_send.pop(p_req_id)

KeyError from dict.pop without default argument

High Severity

The code calls self.reqs_need_send.pop(p_req_id) without a default value, which raises KeyError when p_req_id doesn't exist in the dictionary. Since TruncatingDict doesn't implement the pop method, it falls back to the base MutableMapping.pop, which requires the key to exist unless a default is provided. The code then checks if send_meta:, suggesting it expects None as a possible return value, but the current implementation will crash instead of returning None.
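For illustration, the usual fix is to pass a default to pop, matching the existing `if send_meta:` check:

def pop_send_meta(reqs_need_send, p_req_id: str):
    # A default makes a missing key yield None instead of raising KeyError.
    return reqs_need_send.pop(p_req_id, None)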


if d_req_id not in self.reqs_need_send:
    # This req is not enqueued in P side yet, create it here.
    self.reqs_need_send[d_req_id] = SendBlockMeta(
        p_req_id="", local_block_ids=[], ready=asyncio.Event()

Race condition with asyncio.Event from wrong loop

High Severity

asyncio.Event() is created without specifying the event loop, defaulting to the current thread's event loop. At line 702 in send_kv_to_decode (which runs in sender_loop) and line 1211 in record_send_reqs (also in sender_loop), these events are created but may be used across different event loops. The event at line 702 is created when handling a decoder request, while the sender_loop is a separate background event loop. This creates a cross-loop event sharing issue that can cause race conditions or incorrect event signaling behavior.
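For illustration, one pattern that avoids cross-loop Event usage: create and await the Event only inside the sender loop, and have other threads signal it via call_soon_threadsafe; names are illustrative:

import asyncio

def signal_ready(sender_loop: asyncio.AbstractEventLoop,
                 ready: asyncio.Event) -> None:
    # Safe to call from any thread: set() executes inside sender_loop,
    # the loop the Event was created on and is awaited from.
    sender_loop.call_soon_threadsafe(ready.set)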


while True:
    for prefill_client in prefill_clients:
        for i in range(prefill_client["dp_size"]):
            yield prefill_client, i

Infinite generator causes early prefiller cycling

Medium Severity

The prefiller_cycle generator is initialized at line 115 before get_prefiller_info completes, causing it to iterate with prefill_client["dp_size"] being undefined or zero. At line 38, range(prefill_client["dp_size"]) will be range(0) since dp_size is only set later in line 60 of get_prefiller_info. This means the first request(s) may skip prefiller workers or behave incorrectly until get_prefiller_info finishes and populates dp_size.
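For illustration, a guard so the round-robin only yields once dp_size is populated; field names follow the report above:

def prefiller_cycle(prefill_clients: list[dict]):
    while True:
        # Only cycle over prefillers whose dp_size has been populated
        # by get_prefiller_info; fail fast if none are ready yet.
        ready = [c for c in prefill_clients if c.get("dp_size", 0) > 0]
        if not ready:
            raise RuntimeError("no prefiller has finished registering yet")
        for client in ready:
            for dp_rank in range(client["dp_size"]):
                yield client, dp_rank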


if block_ids:
    # Already gone through request_finished()
-   send_meta = self.reqs_need_send[req_id]
+   send_meta = self.reqs_need_send[p_req_id]

Potential KeyError accessing reqs_need_send without check

High Severity

The code accesses self.reqs_need_send[p_req_id] assuming the entry exists based on a comment saying "Already gone through request_finished()". However, there's a race condition where the scheduler's metadata with non-empty block_ids could arrive at the worker before the decoder's ZMQ request creates the entry in send_kv_to_decode (lines 698-704). Since reqs_need_send uses TruncatingDict which raises KeyError on missing keys, this can crash when metadata processing races ahead of decoder request handling.
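For illustration, a setdefault-based sketch that tolerates either side winning the race; the SendBlockMeta fields follow the snippet above:

import asyncio
from dataclasses import dataclass, field

@dataclass
class SendBlockMeta:  # fields follow the snippet above
    p_req_id: str
    local_block_ids: list
    ready: asyncio.Event = field(default_factory=asyncio.Event)

def get_or_create(reqs_need_send: dict, p_req_id: str) -> SendBlockMeta:
    # Creates the entry if the decoder's ZMQ request has not arrived yet,
    # so metadata processing never hits a missing key.
    return reqs_need_send.setdefault(
        p_req_id, SendBlockMeta(p_req_id=p_req_id, local_block_ids=[])
    )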


for remote_tp_rank in remote_tp_ranks:
    worker_addr = self._remote_agents[remote_engine_id][remote_dp_rank][
        remote_tp_rank
    ][0]

Missing validation causes KeyError for nested dict

High Severity

The code accesses a deeply nested dictionary self._remote_agents[remote_engine_id][remote_dp_rank][remote_tp_rank][0] without validating that remote_dp_rank and remote_tp_rank exist in the structure. While remote_engine_id is checked at line 1168, the specific dp_rank and tp_rank keys could be missing from the bootstrap server response, causing KeyError when attempting the nested access. The bootstrap server might not have registered workers for all expected rank combinations.
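For illustration, a defensive lookup that returns None for a missing rank instead of raising:

def lookup_worker_addr(remote_agents: dict, engine_id: str,
                       dp_rank: int, tp_rank: int):
    # Each .get() tolerates a rank missing from the bootstrap response.
    entry = remote_agents.get(engine_id, {}).get(dp_rank, {}).get(tp_rank)
    return entry[0] if entry else None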


"do_remote_decode": False,
"do_remote_prefill": True,
"remote_bootstrap_addr": prefill_client_info["bootstrap_addr"],
"remote_engine_id": prefill_client_info["dp_engine_id"][prefill_dp_rank],

KeyError from dp_engine_id with non-sequential ranks

High Severity

The prefiller_cycle generator yields ranks from 0 to dp_size-1 (line 38), but dp_engine_id dictionary is populated with actual dp_rank values from the bootstrap server (line 59), which may not be sequential from zero. When accessing prefill_client_info["dp_engine_id"][prefill_dp_rank] at line 301, if the bootstrap server returns non-sequential dp_ranks (e.g., ranks 2 and 3 instead of 0 and 1), the code will access rank 0 which doesn't exist in the dictionary, causing KeyError.
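For illustration, yielding the dp_ranks the bootstrap server actually reported rather than assuming 0..dp_size-1; field names follow the report above:

def prefiller_cycle(prefill_clients: list[dict]):
    while True:
        for client in prefill_clients:
            # Iterate the registered ranks, which may be non-sequential.
            for dp_rank in sorted(client["dp_engine_id"]):
                yield client, dp_rank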



# Initialize round-robin iterators
app.state.prefill_iterator = prefiller_cycle(app.state.prefill_clients)
app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))

Empty client lists cause StopIteration on requests

High Severity

When no --prefill or --decode arguments are provided, the iterators at lines 115-116 are created over empty collections. Line 115 creates a generator that never yields for empty prefill_clients, and line 116 creates itertools.cycle(range(0)) for empty decode_clients. When get_next_client calls next() on these iterators at lines 241 or 243, it raises StopIteration, crashing the request handler. There's no validation that at least one server of each type is configured.
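For illustration, a startup validation so an empty client list fails fast rather than raising StopIteration on the first request:

def validate_clients(prefill_clients: list, decode_clients: list) -> None:
    if not prefill_clients:
        raise ValueError("at least one --prefill server must be configured")
    if not decode_clients:
        raise ValueError("at least one --decode server must be configured")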


Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>
Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>
@dtcccc dtcccc requested a review from orozery as a code owner January 29, 2026 05:14
@dtcccc dtcccc force-pushed the dtcccc/mooncake_connector branch from 0228980 to 9988771 on January 29, 2026 10:59
@NickLucche NickLucche enabled auto-merge (squash) January 30, 2026 08:42
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Jan 30, 2026
@mergify

mergify bot commented Jan 30, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @dtcccc.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 30, 2026
…ctor

Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>
auto-merge was automatically disabled January 30, 2026 10:04

Head branch was pushed to by a user without write access

@mergify mergify bot removed the needs-rebase label Jan 30, 2026
@NickLucche NickLucche enabled auto-merge (squash) January 30, 2026 10:34
@NickLucche NickLucche disabled auto-merge February 3, 2026 15:38
@NickLucche NickLucche enabled auto-merge (squash) February 3, 2026 15:39
@vllm-bot vllm-bot merged commit 0d6ccf6 into vllm-project:main Feb 3, 2026
46 of 48 checks passed
PiratePai pushed a commit to PiratePai/epd_shm that referenced this pull request Feb 3, 2026
…llm-project#31034)

Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Signed-off-by: Pai <416932041@qq.com>
gameofdimension pushed a commit to gameofdimension/vllm that referenced this pull request Feb 5, 2026
…llm-project#31034)

Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Signed-off-by: felix01.yu <felix01.yu@vipshop.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…llm-project#31034)

Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>

Labels

documentation (Improvements or additions to documentation), kv-connector, ready (ONLY add when PR is ready to merge/full CI is needed), v1


4 participants