
[NIXL] Add remote_request_id to kv_transfer_params #29665

Merged
njhill merged 2 commits into vllm-project:main from markmc:nixl-remote-request-id
Dec 5, 2025

Conversation

markmc (Member) commented Nov 28, 2025

Include the internal request ID that the prefill instance is expecting the decode instance to send it in the NIXL notification.

Right now, we rely on the proxy supplying the ID via X-Request-ID, and on prefill and decode mangling that ID in identical ways. This is obviously quite brittle; P should be explicit about which ID it expects from D.

Relates to #27987, which adds a random prefix to client-provided request IDs.
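The handoff this PR enables can be sketched roughly as follows. This is an illustrative sketch, not the actual vLLM code: the function names are hypothetical, while the kv_transfer_params keys mirror those visible in the logs.

```python
# Sketch of the P -> D ID handoff; function names are hypothetical,
# dict keys mirror the kv_transfer_params seen in the logs.

def prefill_request_finished(local_request_id: str) -> dict:
    """P returns kv_transfer_params naming the ID it expects back."""
    return {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_request_id": local_request_id,  # new field added by this PR
    }

def decode_notification_id(kv_transfer_params: dict) -> str:
    """D echoes P's ID verbatim instead of re-deriving it from X-Request-ID."""
    return kv_transfer_params["remote_request_id"]

params = prefill_request_finished("cmpl-prefill-id-0")
print(decode_notification_id(params))  # cmpl-prefill-id-0
```

The point is that only P's own internal ID ever appears in the notification, so proxy-side ID mangling no longer has to match on both sides.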

markmc (Member, Author) commented Nov 28, 2025

Example of it working on top of #27987

Prefill side:

$ grep cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0 nixl-prefill-1764323*
nixl-prefill-1764323995.log:(APIServer pid=3451400) DEBUG 11-28 05:02:04 [entrypoints/logger.py:37] Request cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0 details: prompt: 'Do you know the book Traction by Gino Wickman', prompt_token_ids: [128000, 5519, 499, 1440, 279, 2363, 350, 16597, 555, 480, 3394, 75206, 1543], prompt_embeds shape: None.
nixl-prefill-1764323995.log:(APIServer pid=3451400) INFO 11-28 05:02:04 [entrypoints/logger.py:47] Received request cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0: params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=1, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, structured_outputs=None, extra_args={'kv_transfer_params': {'do_remote_decode': True, 'do_remote_prefill': False, 'remote_engine_id': None, 'remote_block_ids': None, 'remote_host': None, 'remote_port': None}}), lora_request: None.
nixl-prefill-1764323995.log:(APIServer pid=3451400) INFO 11-28 05:02:04 [v1/engine/async_llm.py:360] Added request cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0.
nixl-prefill-1764323995.log:(EngineCore_DP0 pid=3452005) DEBUG 11-28 05:02:04 [distributed/.../v1/nixl_connector.py:620] NIXLConnector request_finished(cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0), request_status=FINISHED_LENGTH_CAPPED, kv_transfer_params={'do_remote_decode': True, 'do_remote_prefill': False, 'remote_engine_id': None, 'remote_block_ids': None, 'remote_host': None, 'remote_port': None}
nixl-prefill-1764323995.log:(EngineCore_DP0 pid=3452005) DEBUG 11-28 05:02:04 [distributed/.../v1/nixl_connector.py:655] NIXLConnector request_finished(cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0) waiting for 480 seconds for remote decode to fetch blocks
nixl-prefill-1764323995.log:(EngineCore_DP0 pid=3452005) DEBUG 11-28 05:02:04 [v1/core/sched/scheduler.py:1492] Finished sending KV transfer for request cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0

Decode side. Note that the decode request ID is cmpl-b82f4d1c-482061a8-551f-4964-925d-fc0925e9261b-0 while the prefill request ID is cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0:

$ grep cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0 nixl-decode-1764323*
nixl-decode-1764323996.log:(APIServer pid=3451466) INFO 11-28 05:02:04 [entrypoints/logger.py:47] Received request cmpl-b82f4d1c-482061a8-551f-4964-925d-fc0925e9261b-0: params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=120, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, structured_outputs=None, extra_args={'kv_transfer_params': {'do_remote_prefill': True, 'do_remote_decode': False, 'remote_block_ids': [91], 'remote_engine_id': '75192ab7-d1c6-48cf-bd82-ccef4ba0e1d6', 'remote_request_id': 'cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0', 'remote_host': 'localhost', 'remote_port': 5559, 'tp_size': 1}}), lora_request: None.
nixl-decode-1764323996.log:(EngineCore_DP0 pid=3452080) DEBUG 11-28 05:02:04 [distributed/.../v1/nixl_connector.py:486] NIXLConnector get_num_new_matched_tokens: num_computed_tokens=0, kv_transfer_params={'do_remote_prefill': True, 'do_remote_decode': False, 'remote_block_ids': [91], 'remote_engine_id': '75192ab7-d1c6-48cf-bd82-ccef4ba0e1d6', 'remote_request_id': 'cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0', 'remote_host': 'localhost', 'remote_port': 5559, 'tp_size': 1}
nixl-decode-1764323996.log:(EngineCore_DP0 pid=3452080) DEBUG 11-28 05:02:04 [distributed/.../v1/nixl_connector.py:507] NIXLConnector update_state_after_alloc: num_external_tokens=13, kv_transfer_params={'do_remote_prefill': True, 'do_remote_decode': False, 'remote_block_ids': [91], 'remote_engine_id': '75192ab7-d1c6-48cf-bd82-ccef4ba0e1d6', 'remote_request_id': 'cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0', 'remote_host': 'localhost', 'remote_port': 5559, 'tp_size': 1}
nixl-decode-1764323996.log:(EngineCore_DP0 pid=3452080) DEBUG 11-28 05:02:04 [distributed/.../v1/nixl_connector.py:507] NIXLConnector update_state_after_alloc: num_external_tokens=0, kv_transfer_params={'do_remote_prefill': False, 'do_remote_decode': False, 'remote_block_ids': [91], 'remote_engine_id': '75192ab7-d1c6-48cf-bd82-ccef4ba0e1d6', 'remote_request_id': 'cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0', 'remote_host': 'localhost', 'remote_port': 5559, 'tp_size': 1}
nixl-decode-1764323996.log:(EngineCore_DP0 pid=3452080) DEBUG 11-28 05:02:04 [distributed/.../v1/nixl_connector.py:620] NIXLConnector request_finished(cmpl-b82f4d1c-482061a8-551f-4964-925d-fc0925e9261b-0), request_status=FINISHED_LENGTH_CAPPED, kv_transfer_params={'do_remote_prefill': False, 'do_remote_decode': False, 'remote_block_ids': [91], 'remote_engine_id': '75192ab7-d1c6-48cf-bd82-ccef4ba0e1d6', 'remote_request_id': 'cmpl-b9d6c33a-482061a8-551f-4964-925d-fc0925e9261b-0', 'remote_host': 'localhost', 'remote_port': 5559, 'tp_size': 1}

gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces remote_request_id to make the NIXL notification more robust, which is a good improvement. The changes correctly plumb this new ID to where it's needed for constructing the notification ID. However, I've found a critical issue where direct dictionary access on kv_transfer_params could lead to a KeyError in certain paths. I've provided a suggestion to make this access safer.
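The bot's concern, sketched as a hypothetical minimal example: direct indexing raises KeyError when a peer (for instance an older worker) omits the new field, while `.get()` degrades gracefully.

```python
# An older peer's kv_transfer_params, missing the new field.
params = {"do_remote_prefill": True}

# params["remote_request_id"] would raise KeyError here.
request_id = params.get("remote_request_id")  # yields None instead
print(request_id)  # None
```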

@markmc markmc added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 28, 2025
@markmc markmc force-pushed the nixl-remote-request-id branch from c519e25 to 9d04ff1 Compare November 28, 2025 15:03
@mergify mergify bot added the v1 label Nov 28, 2025
njhill (Member) left a comment

Thanks @markmc this looks great. I guess a couple of things we should consider:

  • This would require P workers to be updated in tandem with D workers. I think you may have already been looking into questions related to that more generally?
  • Proxies would also require an update if they don't propagate arbitrary kv transfer params (we can check whether this is the case for the toy proxy, llm-d, etc)

I think we could handle the above issues for now by adjusting the P side logic in this PR to fall back on local request_id if remote_request_id isn't present in the kv transfer params. But they might become a concern again once we switch to enforced unique internal request ids.
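The suggested fallback might look like the following sketch; the helper name and signature are hypothetical, assuming the D side has both its local request ID and the received params in hand.

```python
def notification_request_id(local_request_id: str, kv_transfer_params: dict) -> str:
    # Prefer the ID that P explicitly asked for; fall back to the local
    # ID when talking to a P worker that predates this change.
    return kv_transfer_params.get("remote_request_id") or local_request_id

print(notification_request_id("local-0", {"remote_request_id": "remote-0"}))  # remote-0
print(notification_request_id("local-0", {}))  # local-0
```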

markmc (Member, Author) commented Dec 1, 2025
markmc commented Dec 1, 2025

Thanks @markmc this looks great. I guess a couple of things we should consider:

  • This would require P workers to be updated in tandem with D workers. I think you may have already been looking into questions related to that more generally?

Yes, see #29503

  • Proxies would also require an update if they don't propagate arbitrary kv transfer params (we can check whether this is the case for the toy proxy, llm-d, etc)

I think we could handle the above issues for now by adjusting the P side logic in this PR to fall back on local request_id if remote_request_id isn't present in the kv transfer params. But they might become a concern again once we switch to enforced unique internal request ids.

ITYM the D side logic? I'm sceptical ... it seems like trying to handle a situation that shouldn't happen. If rolling out a change like this causes unforeseen problems, it's probably better we hear about it.

njhill (Member) commented Dec 1, 2025

ITYM the D side logic?

Yes sorry, that's what I meant!

If rolling out a change like this causes unforeseen problems, it's probably better we hear about it.

Makes sense, it might then be worth at least detecting this situation and logging an appropriate warning and/or failing in an obvious manner? i.e. if D receives metadata with all the expected fields except this one.

markmc (Member, Author) commented Dec 1, 2025

If rolling out a change like this causes unforeseen problems, it's probably better we hear about it.

Makes sense, it might then be worth at least detecting this situation and logging an appropriate warning and/or failing in an obvious manner? i.e. if D receives metadata with all the expected fields except this one.

There's a "Got invalid KVTransferParams" warning for that already?

njhill (Member) commented Dec 2, 2025

If rolling out a change like this causes unforeseen problems, it's probably better we hear about it.

Makes sense, it might then be worth at least detecting this situation and logging an appropriate warning and/or failing in an obvious manner? i.e. if D receives metadata with all the expected fields except this one.

There's a "Got invalid KVTransferParams" warning for that already?

Ah, good point!

NickLucche (Collaborator) left a comment

I don't see how this change could break things internally, so I would give the green light here.
You already discussed routers too so that's great.

I am just going to ping our friends at semantic-router for awareness in the (unlikely) case this may lead to unintended consequences @Xunzhuo .

njhill (Member) left a comment

Thanks @markmc!

I wonder whether to be safe we should merge #29503 before this one though.

Xunzhuo (Member) commented Dec 3, 2025

thanks @NickLucche ! Overall looks good to me as well.

@markmc markmc force-pushed the nixl-remote-request-id branch from 9d04ff1 to 442a507 Compare December 5, 2025 15:21
markmc (Member, Author) commented Dec 5, 2025

Now that #29503 has landed, I've bumped NIXL_CONNECTOR_VERSION to be sure the incompatibility is caught at handshake time.
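A version bump like this turns a subtle protocol mismatch into an immediate handshake failure. Roughly, with hypothetical names and values rather than the actual connector code:

```python
NIXL_CONNECTOR_VERSION = 2  # hypothetical value, bumped by this PR

def check_handshake(peer_version: int) -> None:
    # Fail fast at handshake time rather than via a silent
    # request-ID mismatch (and stalled KV transfer) later.
    if peer_version != NIXL_CONNECTOR_VERSION:
        raise RuntimeError(
            f"NIXL connector version mismatch: "
            f"local={NIXL_CONNECTOR_VERSION}, remote={peer_version}"
        )

check_handshake(2)  # matching versions: no error
```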

@njhill njhill merged commit dff0a2b into vllm-project:main Dec 5, 2025
51 checks passed
iboiko-habana pushed a commit to vllm-project/vllm-gaudi that referenced this pull request Dec 8, 2025
Culprit: vllm-project/vllm#29665 and
vllm-project/vllm#27938

---------

Signed-off-by: Dobrzyniewicz, Agata <agata.dobrzyniewicz@intel.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
jianzs pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jan 22, 2026
This PR addresses a request ID mismatch issue in the PD
(Prefill-Decoding) separation deployment scenario for vllm-ascend.
Upstream vLLM recently mitigated request ID collisions by appending a
random suffix to each request_id (e.g., req-123 → req-123-abc), refer to
[PR-27987](vllm-project/vllm#27987) &
[PR-29665](vllm-project/vllm#29665). While this
works in single-node deployments, it breaks compatibility in
PD-separated setups: the Producer (Prefill node) and Consumer (Decoding
node) end up with different request_id values, preventing the Consumer
from correctly retrieving the KV cache generated by the Producer.
To resolve this, this PR introduces a new field remote_request_id in the
metadata passed via mooncake_connector. The Producer preserves and
forwards the original (unmodified) request_id as remote_request_id. The
Consumer then uses this remote_request_id—instead of its locally
generated suffixed ID—to fetch the correct KV cache from the Prefill
node.
This ensures consistent request identification across PD nodes while
maintaining compatibility with upstream vLLM’s request ID deduplication
mechanism.

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@d682094

Signed-off-by: ghphotoframe <854746559@qq.com>
Co-authored-by: jiangweixiang <jwx02384838@antgroup.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
(same commit message as above)
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
(same commit message as above)
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
(same commit message as above)
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
(same commit message as above)
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
(same commit message as above)
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
(same commit message as above)

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1


4 participants