
[KV Connector]: Support KV push from Prefill to Decode node using Nixl KV Connector #35264

Merged
NickLucche merged 7 commits into vllm-project:main from snadampal:push_kv_from_ptod on Jun 12, 2026


@snadampal snadampal commented Feb 25, 2026

Contributor

RFC #36923

Implements a KV push feature in which the Prefill node pushes
its KV blocks to the Decode node as soon as the model executor
completes the forward pass and the request finishes.
The implementation supports heterogeneous TP and
heterogeneous block sizes between the P and D nodes.

It covers both scenarios:
Scenario 1: D registers its blocks with P before P finishes generating the KV.
Scenario 2: P has the KV ready before D registers.
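
To make the two scenarios concrete, here is a minimal sketch of a rendezvous table on the Prefill side. The class `PushMatcher` and its method names are hypothetical and exist only for illustration; they are not the classes added by this PR, and the real connector exchanges registrations over NIXL notifications rather than direct method calls.

```python
from dataclasses import dataclass, field


@dataclass
class PushMatcher:
    """Toy rendezvous table on the Prefill side (hypothetical, not the PR's class)."""

    # D registrations that arrived before P finished: req_id -> remote block ids
    pending_registrations: dict[str, list[int]] = field(default_factory=dict)
    # P results that finished before D registered: req_id -> local block ids
    finished_blocks: dict[str, list[int]] = field(default_factory=dict)
    # Transfers that can be issued: (req_id, local block ids, remote block ids)
    ready: list[tuple[str, list[int], list[int]]] = field(default_factory=list)

    def on_decode_registration(self, req_id: str, remote_ids: list[int]) -> None:
        local_ids = self.finished_blocks.pop(req_id, None)
        if local_ids is None:
            # Scenario 1: D registered first, so stash until P finishes.
            self.pending_registrations[req_id] = remote_ids
        else:
            # Scenario 2: P was already done, so the push can start now.
            self.ready.append((req_id, local_ids, remote_ids))

    def on_prefill_finished(self, req_id: str, local_ids: list[int]) -> None:
        remote_ids = self.pending_registrations.pop(req_id, None)
        if remote_ids is None:
            # Scenario 2: P finished first, so stash until D registers.
            self.finished_blocks[req_id] = local_ids
        else:
            # Scenario 1: D was already registered, so the push can start now.
            self.ready.append((req_id, local_ids, remote_ids))
```

Whichever side arrives second finds the other side's entry and produces a ready transfer, so the ordering of the two events does not matter to the caller.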

Purpose

Improve TTFT for P-D disaggregated inference deployments.

Test Plan

Manually tested inference on a P-D disaggregated setup.

Test Result

Tested a 1P1D configuration, pushing KV from the prefill node to the decode node, and validated the accuracy of the results.

Benchmarking results

I ran vllm bench serve with the sonnet dataset for different input and output token lengths. KV push mode (this PR) showed 1.2x - 3.0x TTFT improvements over pull mode across the different input and output lengths. The following performance numbers were measured on an AWS P5en instance; similar improvements were observed on AWS Trn2 instances, where the feature was originally developed.

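| Mode | Prefill/Decode TP | Input Len | Output Len | QPS | Mean TTFT (ms) | Median TTFT (ms) | P99 TTFT (ms) | Mean TPOT (ms) | Req/s |
|------|-------------------|-----------|------------|-----|----------------|------------------|---------------|----------------|-------|
| pull | 4 | 512 | 64 | 4 | 110.94 | 85.07 | 740.03 | 12.51 | 3.86 |
| push | 4 | 512 | 64 | 4 | 74.46 | 77.13 | 100.24 | 12.50 | 3.87 |
| pull | 4 | 512 | 128 | 4 | 301.92 | 89.22 | 1125.98 | 12.66 | 3.63 |
| push | 4 | 512 | 128 | 4 | 93.77 | 79.95 | 302.37 | 12.59 | 3.75 |
| pull | 4 | 512 | 128 | 8 | 2948.11 | 2682.32 | 7258.96 | 12.70 | 4.65 |
| push | 4 | 512 | 128 | 8 | 1212.48 | 1275.05 | 2826.72 | 12.67 | 5.92 |
| pull | 4 | 1024 | 128 | 4 | 330.94 | 91.26 | 1163.70 | 12.74 | 3.62 |
| push | 4 | 1024 | 128 | 4 | 273.83 | 84.10 | 1096.88 | 12.79 | 3.67 |
| pull | 4 | 2048 | 128 | 4 | 352.22 | 98.53 | 1225.20 | 12.88 | 3.61 |
| push | 4 | 2048 | 128 | 4 | 239.65 | 88.23 | 1042.32 | 12.85 | 3.73 |
| pull | 8 | 512 | 64 | 8 | 85.77 | 65.69 | 257.58 | 8.95 | 7.62 |
| push | 8 | 512 | 64 | 8 | 59.41 | 60.32 | 71.88 | 8.84 | 7.63 |
| pull | 8 | 512 | 128 | 16 | 3672.89 | 3628.99 | 7880.98 | 9.07 | 6.45 |
| push | 8 | 512 | 128 | 16 | 1653.70 | 1791.73 | 3681.04 | 9.06 | 9.04 |
| pull | 8 | 1024 | 128 | 8 | 986.94 | 776.67 | 2764.92 | 9.10 | 6.03 |
| push | 8 | 1024 | 128 | 8 | 674.82 | 551.53 | 2082.18 | 9.09 | 6.37 |
| pull | 8 | 2048 | 128 | 8 | 1048.93 | 828.82 | 2905.03 | 9.20 | 5.99 |
| push | 8 | 2048 | 128 | 8 | 648.08 | 371.25 | 1859.21 | 9.19 | 6.46 |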
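
For readers who want to derive per-configuration speedups from the numbers above, the following sketch computes the pull/push ratio of mean TTFT for a few rows copied from the table. The dictionary layout and the helper name `ttft_speedup` are illustrative choices, not part of the benchmark tooling.

```python
# (TP, input len, output len, QPS) -> (pull mean TTFT ms, push mean TTFT ms)
# Values are copied from the benchmark table; only a subset is shown.
mean_ttft_ms = {
    (4, 512, 64, 4): (110.94, 74.46),
    (4, 512, 128, 4): (301.92, 93.77),
    (4, 1024, 128, 4): (330.94, 273.83),
    (8, 512, 128, 16): (3672.89, 1653.70),
}


def ttft_speedup(pull_ms: float, push_ms: float) -> float:
    """Return how many times lower the push-mode TTFT is than pull mode."""
    return pull_ms / push_ms


speedups = {
    config: ttft_speedup(pull_ms, push_ms)
    for config, (pull_ms, push_ms) in mean_ttft_ms.items()
}

for config, ratio in speedups.items():
    print(f"TP={config[0]} in={config[1]} out={config[2]} qps={config[3]}: {ratio:.2f}x")
```

The same ratio can be taken over the median or P99 columns to see how much of the gain comes from the tail.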

TODO:

  1. Run unit tests

Phase2:

  1. Add layer-wise KV push to realize the full performance gains of the KV push model

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@snadampal snadampal marked this pull request as ready for review February 25, 2026 05:44

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a significant feature to support KV cache push from a Prefill node to a Decode node using the Nixl KV Connector, aiming to improve TTFT in disaggregated inference setups. The implementation correctly handles two synchronization scenarios between the nodes and includes necessary changes across the connector, engine, and worker components. The logic for handling block registration and the two push scenarios seems well-thought-out. However, I've identified a critical bug in the _write_blocks method that could lead to an UnboundLocalError for models with local attention, which needs to be addressed.

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Outdated
@mergify

mergify Bot commented Feb 25, 2026

Contributor

Hi @snadampal, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@yewentao256 yewentao256 left a comment

Member

Please address the review comments and the pre-commit issue.

Also, please reduce the diff to make it easier to review.

Please also add lm_eval metrics to make sure the accuracy is correct.

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Outdated
@snadampal

Contributor Author

@yewentao256, thanks for the review; I will take care of the comments. By the way, I'm implementing heterogeneous TP and heterogeneous block_size between the P and D nodes (the current version supports only homogeneous configurations) and will update the PR soon.
I'm also unifying the pull and push mode logic, for example combining _read_blocks and _write_blocks to remove redundant code, which should bring down the size a bit.

@snadampal snadampal force-pushed the push_kv_from_ptod branch from e9bc2b5 to a2aa7a4 Compare March 6, 2026 21:09
@snadampal snadampal requested a review from njhill as a code owner March 6, 2026 21:09
@mergify

mergify Bot commented Mar 6, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @snadampal.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 6, 2026
@snadampal

Contributor Author

Hi @yewentao256, I have updated the PR with the features I mentioned above, and I've unified the push and pull modes as much as possible to avoid redundant code. It is still a large PR, but this is the most I could condense it. Please review and let me know your comments.

I based this on yesterday's commit, and today it already has conflicts :(
I will rebase again and resolve them after the review is done.

@snadampal

Contributor Author

I will write a design doc (in RFC format) to help with the reviews.

@snadampal

snadampal commented Mar 12, 2026

Contributor Author

I have captured the design and implementation details in the following design/RFC doc. I hope this helps with reviewing the PR.
#36923 - [RFC]: [KV Connector]: Support KV push from Prefill to Decode node using Nixl Connector

cc: @yewentao256 , @NickLucche

@snadampal

snadampal commented Apr 2, 2026

Contributor Author

I have updated the PR to split the implementation into 8 commits. The first 4 commits are purely a refactor of the existing code with zero functionality change, and the last 4 commits are:
(i) P allows block registration from remote
(ii) D registers blocks with P
(iii) P pushes its local blocks into remote blocks, and
(iv) D uses push mode

This should make the review much easier, and we can also take the work commit by commit into mainline without disturbing anything. I have rebased it onto the latest (so it now includes the hybrid KV feature) and rerun the benchmarks and accuracy tests; the performance numbers are the same as those I posted above.

As per Nicolo's suggestion, I have also hosted it as an out-of-tree connector as a plugin for people to try out quickly on top of vllm.
https://github.com/snadampal/vllm-nixl-pushpull-connector

cc:@NickLucche, @yewentao256

@snadampal snadampal force-pushed the push_kv_from_ptod branch from cbbc457 to f071ab2 Compare April 3, 2026 20:24
@mergify

mergify Bot commented Apr 6, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @snadampal.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 6, 2026

@NickLucche NickLucche left a comment

Member

thanks for the work @snadampal !
Will try to run the numbers on my side too and then we can iron out a few details with this PR

@snadampal

snadampal commented Apr 10, 2026

Contributor Author

Hi @NickLucche, I dug deeper into the Push mode performance gains and have some interesting findings:

  1. Push mode gives a 3x or larger TTFT gain over Pull mode when there is a 100% prefix cache hit on the D node. In this case:
  • Pull mode incurs the latency of Prefill node request handling plus the proxy round-trip. If the prefill node doesn't hit its cache, this includes the complete prefill forward pass latency, which is on the order of several hundred milliseconds.
  • In Push mode, the request can start decoding right away, with no serialization behind Prefill or the proxy.
  2. Push mode TTFT is around 1.3x - 2x faster and TPOT is 30% faster (hence higher decode throughput) than Pull mode when KV cache memory is fragmented. When the memory is fragmented, the KV blocks are non-contiguous, so Nixl can't coalesce the descriptors into fewer ones. For the llama-3.3-70b-instruct model, I saw the descriptor count grow from 118 to 18880, which increased the nixl postXfer() duration by 100x (460 usec to 46 ms).
  • In Pull mode, start_load_kv -> nixl.transfer() -> postXfer() sits on the decode critical path, just before the forward pass, so it directly affects TPOT.
  • In Push mode, start_push_kv -> nixl.transfer() -> postXfer() runs in a separate thread and doesn't affect the worker forward pass. Moreover, it runs on the Prefill node, so it doesn't affect the Decode latencies.
    Note: Even when the Prefill node has to pull KV (my other PR for multi-turn optimization), the postXfer() effect will not be as bad as on the Decode node, because the engine loop steps are much longer, on the order of hundreds of ms, so the additional overhead from postXfer() will not be magnified.
  3. Push mode is around 10% faster than Pull mode when there is no prefix cache on the Decode node and the KV cache is not fragmented. But this scenario happens only on a fresh instance start.
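
The fragmentation effect in finding 2 can be illustrated with a simplified model: if each run of consecutive block ids can be described by one transfer descriptor, then the descriptor count equals the number of runs. The function below is a hypothetical sketch of that counting, not NIXL's or the connector's actual coalescing logic.

```python
def count_coalesced_descriptors(block_ids: list[int]) -> int:
    """Count runs of consecutive block ids; each run maps to one descriptor.

    Simplified model for illustration only: real descriptors also depend on
    layers, KV layout, and region boundaries.
    """
    if not block_ids:
        return 0
    runs = 1
    for prev, cur in zip(block_ids, block_ids[1:]):
        if cur != prev + 1:
            # A gap in the ids starts a new contiguous run.
            runs += 1
    return runs


contiguous = list(range(100, 116))     # 16 adjacent blocks
fragmented = list(range(100, 132, 2))  # 16 blocks with a gap after each one

print(count_coalesced_descriptors(contiguous))
print(count_coalesced_descriptors(fragmented))
```

With the same number of blocks, the fragmented layout needs one descriptor per block while the contiguous layout needs a single one, which is the kind of growth that makes the per-transfer posting cost scale up.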

Please let me know your observations.

@mergify mergify Bot removed the needs-rebase label May 14, 2026
@mergify

mergify Bot commented May 14, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @snadampal.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 14, 2026

@NickLucche NickLucche left a comment

Member

Mega job here @snadampal !
Code is clean and the docs diagram goes a long way toward explaining the flow.
I also feel pretty good about the pull separation in terms of maintainability.

I left some comments, none of them is major, so feel free to address them as you see fit, we can even defer some of these ideas to a follow up PR (especially the one around thread state separation, if at all).

Comment thread examples/disaggregated/disaggregated_serving/disagg_proxy_pushconnector_demo.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/push_worker.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/push_scheduler.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/push_worker.py
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/push_scheduler.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/push_worker.py
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/push_worker.py Outdated
@snadampal snadampal force-pushed the push_kv_from_ptod branch from f706249 to 2958f67 Compare June 10, 2026 18:17
@mergify mergify Bot removed the needs-rebase label Jun 10, 2026
@snadampal snadampal force-pushed the push_kv_from_ptod branch 2 times, most recently from 6ec30d4 to f253643 Compare June 10, 2026 19:37
@snadampal

snadampal commented Jun 10, 2026

Contributor Author

Hi @NickLucche, thanks for the great feedback, as always! I have incorporated everything except the new Writer class.
I'm deferring the Writer change to the next PR because, based on the perf improvements in Push mode, I'm adding similar threading logic for Pull mode as well, and as part of that I plan to move the threading logic to BaseConnector if possible.

The PR has now been rebased, with all the feedback incorporated, and tested with the unit tests, the proxy, and the benchmarking script.

@snadampal snadampal force-pushed the push_kv_from_ptod branch from f253643 to d598939 Compare June 11, 2026 00:47
@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 11, 2026
@NickLucche NickLucche enabled auto-merge (squash) June 11, 2026 08:47
@mergify

mergify Bot commented Jun 11, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @snadampal.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 11, 2026
auto-merge was automatically disabled June 11, 2026 20:03

Head branch was pushed to by a user without write access

@snadampal snadampal force-pushed the push_kv_from_ptod branch from 7c0767d to dc8bdee Compare June 11, 2026 20:03
@mergify mergify Bot removed the needs-rebase label Jun 11, 2026
@mergify

mergify Bot commented Jun 12, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @snadampal.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 12, 2026
snadampal and others added 7 commits June 12, 2026 08:52
…NixlPullConnector

Refactor the NIXL KV connector into base and pull-specific classes to enable
adding push-based KV transfer as a separate subclass.

- The monolithic NixlConnector, NixlConnectorScheduler, and
NixlConnectorWorker are split into NixlBaseConnector/NixlPullConnector,
NixlBaseConnectorScheduler/NixlPullConnectorScheduler, and
NixlBaseConnectorWorker/NixlPullConnectorWorker respectively.

- Base classes contain shared infrastructure (memory registration,
handshakes, TP mapping, transfer completion tracking, heartbeats,
post-processing).

- Pull classes contain READ-specific logic (start_load_kv,
_read_blocks_for_req, _read_blocks, _get_new_notifs, get_num_new_matched_tokens,
update_state_after_alloc, request_finished).

- Backward-compatible shim files preserve existing
imports (NixlConnector, NixlConnectorScheduler, NixlConnectorWorker).

- No functional changes.

Signed-off-by: Sunita Nadampalli <nadampal@amazon.com>
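
As a rough illustration of the base/pull split with a backward-compatible alias described in this commit, consider the sketch below. The class names come from the commit message, but every method body and the attributes `registered_regions` and `reads` are hypothetical stand-ins; the real classes carry far more state and take different constructor arguments.

```python
class NixlBaseConnectorWorker:
    """Shared infrastructure, stubbed here as a simple region registry."""

    def __init__(self) -> None:
        self.registered_regions: list[str] = []

    def register_memory(self, region: str) -> None:
        # Stand-in for the shared memory registration logic.
        self.registered_regions.append(region)

    def start_load_kv(self, metadata: dict[str, list[int]]) -> None:
        # Direction-specific behavior is left to subclasses.
        raise NotImplementedError


class NixlPullConnectorWorker(NixlBaseConnectorWorker):
    """READ-specific logic lives in the pull subclass."""

    def __init__(self) -> None:
        super().__init__()
        self.reads: list[tuple[str, list[int]]] = []

    def start_load_kv(self, metadata: dict[str, list[int]]) -> None:
        # Record one read per request instead of issuing a real transfer.
        for req_id, block_ids in metadata.items():
            self.reads.append((req_id, block_ids))


# Backward-compatible shim: the old name keeps resolving to the pull worker.
NixlConnectorWorker = NixlPullConnectorWorker
```

Existing code that constructs `NixlConnectorWorker` keeps working unchanged, while a push subclass can be added next to the pull one without touching the shared base.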
Signed-off-by: Sunita Nadampalli <nadampal@amazon.com>
Scheduler stages registrations / finished blocks via metadata

Worker runs a dedicated nixl-push-writer thread that owns all push NIXL ops
(get_new_notifs, send_notif for PUSH_REG, make_prepped_xfer, transfer)
off the engine main thread. Writer wakes on engine signals from
start_load_kv / get_finished, self-polls only while waiting for
unmatched PUSH_REG notifs, and clears state on lease expiration.

Adds has_pending_push_work() to keep EngineCore stepping while pushes
are in flight, plus a soft per-registration watchdog.

D-registers-first and P-finishes-first cases are handled symmetrically;
request-id matching falls back to UUID extraction to tolerate retries.

Co-authored-by: Sunita Nadampalli <nadampal@amazon.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Signed-off-by: Sunita Nadampalli <nadampal@amazon.com>
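
The writer-thread pattern in this commit can be sketched with the standard library alone. The class below is a hypothetical, heavily simplified model: `PushWriter`, `submit`, and `completed` are illustrative names, and the real writer performs NIXL operations where this sketch only records the request id. Only the thread name `nixl-push-writer` and the method name `has_pending_push_work` are taken from the commit message.

```python
import queue
import threading


class PushWriter:
    """Toy writer thread that takes push work off the caller's thread."""

    def __init__(self) -> None:
        self._work: queue.Queue = queue.Queue()
        self._wake = threading.Event()
        self._stop = threading.Event()
        self._lock = threading.Lock()
        self._in_flight = 0
        self.completed: list[str] = []
        self._thread = threading.Thread(
            target=self._run, name="nixl-push-writer", daemon=True
        )
        self._thread.start()

    def submit(self, req_id: str) -> None:
        """Called from the engine thread: enqueue work and wake the writer."""
        with self._lock:
            self._in_flight += 1
        self._work.put(req_id)
        self._wake.set()

    def has_pending_push_work(self) -> bool:
        """True while submitted pushes have not completed yet."""
        with self._lock:
            return self._in_flight > 0

    def _drain(self) -> None:
        while True:
            try:
                req_id = self._work.get_nowait()
            except queue.Empty:
                return
            # Placeholder for the real transfer; here we only record it.
            self.completed.append(req_id)
            with self._lock:
                self._in_flight -= 1

    def _run(self) -> None:
        while not self._stop.is_set():
            # Sleep until signalled, with a timeout as a self-poll fallback.
            self._wake.wait(timeout=0.1)
            self._wake.clear()
            self._drain()
        # Final drain so work submitted just before shutdown is not lost.
        self._drain()

    def shutdown(self) -> None:
        self._stop.set()
        self._wake.set()
        self._thread.join()
```

The engine loop can poll `has_pending_push_work()` to decide whether to keep stepping, while all transfer work stays on the writer thread.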
Add tests/v1/kv_connector/unit/test_nixl_push_connector.py covering the
push (WRITE) connector's scheduler and writer-thread mechanics without
requiring a real NIXL agent or network.

Coverage (30 tests):
- TestPushScheduler (6): D-side update_state_after_alloc staging, P-side
  request_finished staging, build_connector_meta drain on both sides,
  has_pending_push_work lifecycle, update_connector_output cleanup,
  registration watchdog expiration.
- TestPushWriterMatching (4): scenario 1 (P finished first) and
  scenario 2 (D registered first), fuzzy base-request-id matching for
  retried requests, malformed PUSH_REG payload drops.
- TestPushWriterStartLoadKv (2): finished-blocks inbox match against
  stashed D registration, start_load_kv enqueueing to writer queues.
- TestPushWriterNotifs (2): forwarded completion notif processing,
  get_finished eviction enqueue + writer wake.
- TestPushSchedulerNegative (6): no kv_transfer_params, invalid
  remote_block_ids payload, num_external_tokens=0, RUNNING status,
  empty block groups, unknown-request update_connector_output.
- TestPushWriterNegative (10): empty pop helpers, no-fuzzy when base
  UUIDs differ, non-dict payload, non-string request_id, idempotent
  duplicate PUSH_REG, eviction cardinality, no-op get_finished still
  wakes the writer, unknown completion notif, empty-metadata
  start_load_kv.

Signed-off-by: Sunita Nadampalli <nadampal@amazon.com>
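
The fuzzy base-request-id matching exercised by these tests can be approximated as follows. This is a sketch under the assumption that a retried request keeps the same embedded UUID while other parts of the id change; the id format and the helper names `base_request_id` and `find_match` are hypothetical and may differ from the PR's implementation.

```python
import re
from typing import Optional

# Canonical 8-4-4-4-12 hexadecimal UUID text form.
_UUID_RE = re.compile(
    r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
)


def base_request_id(request_id: str) -> Optional[str]:
    """Extract the first embedded UUID, lowercased, or None if absent."""
    match = _UUID_RE.search(request_id)
    return match.group(0).lower() if match else None


def find_match(request_id: str, candidates: dict[str, list[int]]) -> Optional[str]:
    """Prefer an exact id match, then fall back to matching on the UUID."""
    if request_id in candidates:
        return request_id
    base = base_request_id(request_id)
    if base is None:
        return None
    for candidate in candidates:
        if base_request_id(candidate) == base:
            return candidate
    return None
```

An id with no embedded UUID, or with a different UUID, yields no fuzzy match, which corresponds to the negative cases listed above.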
Add examples/disaggregated/disaggregated_serving/disagg_proxy_pushconnector_demo.py,
the push-mode counterpart to disagg_proxy_demo.py. Same client-facing API
(/v1/completions, /v1/chat/completions, /status, /instances/add, xPyD
round-robin scheduling, runtime instance add) but speaks the
NixlPushConnector wire protocol on D's side: kv_transfer_params carries
P's coordinates (engine_id, host, side-channel port, tp_size).
P's prefill leg is issued first (max_tokens=1, do_remote_decode=True)
and drained, then the decode leg is streamed to the client
while D registers blocks with P and waits for the NIXL WRITE.

Required CLI: --prefill-engine-id, --prefill-kv-host,
--prefill-side-channel-port, --prefill-tp-size (P coordinates that pull
mode would learn from P's response but push mode needs up-front).

Signed-off-by: Sunita Nadampalli <nadampal@amazon.com>
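
To show how the two request legs described in this commit might be assembled, here is a minimal sketch of the payloads a proxy could build. The `max_tokens=1` and `do_remote_decode=True` settings come from the commit message; the `PrefillCoordinates` type, the function names, and the key names inside the decode leg's `kv_transfer_params` are assumptions made for illustration and may not match the actual wire protocol.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PrefillCoordinates:
    """P's coordinates, mirroring the four required CLI flags (hypothetical type)."""

    engine_id: str
    kv_host: str
    side_channel_port: int
    tp_size: int


def build_prefill_request(prompt: str) -> dict:
    """Prefill leg: generate a single token and mark the request for remote decode."""
    return {
        "prompt": prompt,
        "max_tokens": 1,
        "kv_transfer_params": {"do_remote_decode": True},
    }


def build_decode_request(prompt: str, max_tokens: int, p: PrefillCoordinates) -> dict:
    """Decode leg: tell D where P lives so it can register its blocks with P."""
    return {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "stream": True,
        # Key names below are assumptions, not the connector's documented schema.
        "kv_transfer_params": {
            "remote_engine_id": p.engine_id,
            "remote_host": p.kv_host,
            "remote_port": p.side_channel_port,
            "tp_size": p.tp_size,
        },
    }
```

Because push mode cannot learn P's coordinates from the prefill response, the proxy has to hold them up front, which is why they are passed in as an explicit argument here.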
Adds docs/design/nixl_kv_push_connector.md covering the push-mode
threading model, wake sources, writer-local
matching tables, PUSH_REG wire format, scheduler responsibilities,
watchdogs/leases, and failure handling. Includes a Mermaid sequence
diagram of the end-to-end P/D flow.

Signed-off-by: Sunita Nadampalli <nadampal@amazon.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
@mergify mergify Bot removed the needs-rebase label Jun 12, 2026
@NickLucche NickLucche enabled auto-merge (squash) June 12, 2026 07:30
@NickLucche NickLucche merged commit 88ed636 into vllm-project:main Jun 12, 2026
88 checks passed

Labels

documentation, kv-connector, ready, v1
