[Nixl][PD] Lease renewal TTL KV blocks on P #38027

Closed
NickLucche wants to merge 11 commits into vllm-project:main from NickLucche:lease-ttl-kv-blocks

Conversation

@NickLucche
Collaborator

This PR adds heartbeat-based lease renewal for NIXL KV cache blocks in disaggregated P/D deployments.

Problem

The existing timeout/TTL mechanism, which ensures KV cache blocks are eventually cleared on P after D disconnects or a request is aborted in edge cases, can lead to severe performance degradation due to the retention of "dead" blocks.
With the current default VLLM_NIXL_ABORT_REQUEST_TIMEOUT=480s, when D crashes, P may retain several GBs of cache for in-flight requests for up to 8 minutes. Subsequent requests hitting P will only have a portion of the cache at their disposal.

Approach

We augment the TTL mechanism with lease renewal logic (a minimal sketch follows the list below):

  • P grants a short initial lease (~30s) when prefill completes
  • D periodically refreshes the lease for its active requests (extending the TTL by VLLM_NIXL_KV_LEASE_EXTENSION)
  • If D crashes or stops heartbeating, P reclaims the blocks once the lease expires (typically much sooner than the original VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
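
As a rough illustration of this protocol, here is a minimal sketch of the P-side lease bookkeeping; the function and variable names are illustrative only, not the PR's actual code, and timestamps are assumed to be monotonic seconds:

  import time

  VLLM_NIXL_KV_LEASE_DURATION = 30.0   # initial lease granted when prefill completes
  VLLM_NIXL_KV_LEASE_EXTENSION = 20.0  # extension granted per heartbeat

  _kv_leases: dict[str, float] = {}    # req_id -> lease expiration timestamp

  def grant_lease(req_id: str) -> None:
      # P side: called once prefill for req_id finishes and blocks await D's pull.
      _kv_leases[req_id] = time.monotonic() + VLLM_NIXL_KV_LEASE_DURATION

  def extend_lease(req_id: str) -> None:
      # P side: called when a heartbeat naming req_id arrives from D.
      if req_id in _kv_leases:
          _kv_leases[req_id] = time.monotonic() + VLLM_NIXL_KV_LEASE_EXTENSION

  def expired_requests() -> list[str]:
      # Requests whose blocks P may reclaim (D crashed or stopped heartbeating).
      now = time.monotonic()
      return [r for r, exp in _kv_leases.items() if exp < now]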

Key Design Decisions:

  1. Per-request leasing (not per-instance)

We lease at the request/transfer level rather than per-D-instance (e.g., "D0 is alive, refresh all its blocks"). This is because P has no notion of which D its KV blocks belong to. If P tracked per-D ownership, we couldn't defer D
selection until after prefill completes - we'd have to select both P and D upfront in the load balancer.

In practice, D batches lease extensions toward the same P by grouping requests with the same remote_engine_id.
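
For illustration, the per-engine grouping could look like the following; the helper and its input shape are hypothetical, and only the idea of batching per remote_engine_id comes from the PR:

  from collections import defaultdict

  def group_reqs_by_engine(active: dict[str, str]) -> dict[str, list[str]]:
      """Group in-flight req_ids (mapped req_id -> remote_engine_id) per P engine."""
      by_engine: dict[str, list[str]] = defaultdict(list)
      for req_id, engine_id in active.items():
          by_engine[engine_id].append(req_id)
      return dict(by_engine)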

  2. NIXL notifications as communication channel

Rather than adding FE API changes or managing additional ZMQ connections, we reuse the existing NIXL notification system (send_notif / get_new_notifs) to send heartbeats.
The notification system is backend-specific, and fallback from IB/RoCE to TCP as the transport medium is already handled.
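
A rough sketch of the heartbeat flow over those notification primitives is below. The "HB:" wire format matches the batched message quoted later in this review; the helper names and the assumed agent-to-list-of-bytes return shape of get_new_notifs are assumptions for this sketch:

  def send_heartbeat(nixl_wrapper, agent_names: list[str], req_ids: list[str]) -> None:
      # D side: one batched message per remote P engine, fanned out to its workers.
      heartbeat_msg = ("HB:" + ",".join(req_ids)).encode()
      for agent_name in agent_names:
          nixl_wrapper.send_notif(agent_name, notif_msg=heartbeat_msg)

  def process_heartbeats(nixl_wrapper, extend_lease) -> None:
      # P side: heartbeats arrive alongside regular transfer-done notifications.
      for _agent, notifs in nixl_wrapper.get_new_notifs().items():
          for notif in notifs:
              if notif.startswith(b"HB:"):
                  for req_id in notif[3:].decode().split(","):
                      extend_lease(req_id)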

  3. No background thread - heartbeats happen in the forward loop

Heartbeat sending/processing happens in start_load_kv / get_finished, not a separate background thread. This means timing isn't precise - longer model execution delays heartbeats. However:

  • The tight forward loop keeps timing reasonable in practice
  • Lease durations are configured with sufficient margin (defaults: 5s heartbeat interval, 20s extension, 30s initial lease). Empirically, the defaults are at least an order of magnitude larger than a forward pass.
  • Avoids lock complexity between threads. We want to strive for simplicity here rather than millisecond-precise timing
  • It can still be extended to dynamic interval calculation at runtime; a minimal sketch of the gating logic follows this list.
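
Here is what the interval gating in the forward loop might look like, stripped down. The class members, entry-point name, and send helper are illustrative; only the 5s default and the start_load_kv / get_finished call sites come from the PR description:

  import time

  class HeartbeatSketch:
      _heartbeat_interval = 5.0  # VLLM_NIXL_KV_HEARTBEAT_INTERVAL
      _last_heartbeat_time = 0.0

      def maybe_send_heartbeats(self) -> None:
          # Invoked from the forward loop (start_load_kv / get_finished), not a thread,
          # so heartbeats are only as timely as the loop iteration itself.
          now = time.monotonic()
          if now - self._last_heartbeat_time < self._heartbeat_interval:
              return
          self._last_heartbeat_time = now
          self._send_heartbeats_per_engine()

      def _send_heartbeats_per_engine(self) -> None:
          # Placeholder: batch req_ids per remote_engine_id and send via send_notif.
          pass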

A consideration on Heterogeneous TP support: When P TP > D TP (e.g., P TP=4, D TP=2), a single D worker pulls KV blocks from multiple P workers. Each P worker independently tracks lease expiration. Therefore, heartbeats must be sent to ALL P workers for a given engine.
Conversely, when D TP > P TP, a single P worker will receive notifications from multiple D workers. This refreshes the TTL multiple times, but with no downside.

Default Configuration

  | Variable                        | Default | Description                     |
  |---------------------------------|---------|---------------------------------|
  | VLLM_NIXL_KV_LEASE_DURATION     | 30s     | Initial lease duration on P     |
  | VLLM_NIXL_KV_LEASE_EXTENSION    | 20s     | Extension granted per heartbeat |
  | VLLM_NIXL_KV_HEARTBEAT_INTERVAL | 5s      | Heartbeat send interval on D    |

Test With

  pytest tests/v1/kv_connector/unit/test_nixl_connector.py -v -k "heartbeat or lease"

cc @robertgshaw2-redhat @markmc @ZhanqiuHu

Signed-off-by: NickLucche <nlucches@redhat.com>
@mergify
Contributor

mergify Bot commented Mar 24, 2026

Documentation preview: https://vllm--38027.org.readthedocs.build/en/38027/

@mergify mergify Bot added the documentation (Improvements or additions to documentation), v1, and kv-connector labels Mar 24, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a heartbeat-based lease management system for KV cache blocks in the Nixl connector. It replaces the single VLLM_NIXL_ABORT_REQUEST_TIMEOUT with VLLM_NIXL_KV_LEASE_DURATION, VLLM_NIXL_KV_LEASE_EXTENSION, and VLLM_NIXL_KV_HEARTBEAT_INTERVAL to manage KV block leases between producer and consumer. The changes include implementing D-side heartbeat sending, P-side lease extension and expiration, and updating related documentation and environment variable definitions. Review comments highlight inconsistencies in default values between documentation and code for the new environment variables, suggest improving a test by using the caplog fixture, point out a potential log pollution issue with empty heartbeat messages, and identify an inefficiency in the cleanup logic for _pending_transfers_by_engine.

Comment thread docs/features/nixl_connector_usage.md
Comment thread tests/v1/kv_connector/unit/test_nixl_connector.py
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Outdated
Comment on lines +2484 to +2485
for engine_reqs in self._pending_transfers_by_engine.values():
    engine_reqs.discard(req_id)
Contributor


high

The cleanup of _pending_transfers_by_engine iterates through all engine request sets to remove the completed req_id. This is inefficient, with a complexity of O(N_engines).

You can optimize this to O(1). Inside _pop_done_transfers, you can access self._recving_metadata to get the meta object for the req_id. From the metadata, you can get the engine_id and directly remove the req_id from the specific engine's request set in _pending_transfers_by_engine without iterating over all engines.
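
For concreteness, a sketch of the suggested O(1) cleanup might look like this; it assumes the metadata object exposes the remote engine id under an attribute such as remote_engine_id, which is not verified against the PR's code:

  def _cleanup_pending_transfer(self, req_id: str) -> None:
      meta = self._recving_metadata.get(req_id)
      if meta is None:
          return
      engine_reqs = self._pending_transfers_by_engine.get(meta.remote_engine_id)
      if engine_reqs is not None:
          engine_reqs.discard(req_id)
          if not engine_reqs:
              # Drop the empty set so stale engine keys don't accumulate.
              del self._pending_transfers_by_engine[meta.remote_engine_id]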

Collaborator Author


I'd rather keep the tight execution path as lean as possible; there's really no need to clean up in _pop_done_transfers if the data structure is well bounded

Member


But this code is cleaning in pop_done_transfers() ? (Not sure I understand your comment)

The issue here is you have no information here about which engine this request relates to

It would make sense to me if transfers was req_id -> (engine_id, handle) which would solve this?
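
Roughly, the suggested mapping would carry the owning engine alongside each handle, so cleanup needs no scan over engines (types here are only illustrative):

  # req_id -> list of (remote_engine_id, transfer_handle)
  _recving_transfers: dict[str, list[tuple[str, object]]] = {}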

Comment thread vllm/envs.py
@NickLucche NickLucche added the ready (ONLY add when PR is ready to merge/full CI is needed) label Mar 30, 2026
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat left a comment


I have to say, I don't quite understand the design of this lease extension

The key issue we are solving is that requests get stuck in the WAITING queue in the scheduler if the KV cache is overloaded.

IIUC, if the requests are in the WAITING queue, they have not yet been passed to the NIXL connector. So we need a mechanism by which the requests that have NOT YET been passed to the NIXL connector to be able to extend the leases

Am I missing something on this?

@NickLucche NickLucche marked this pull request as draft April 1, 2026 16:20
@mergify
Contributor

mergify Bot commented Apr 6, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 6, 2026
Member

@markmc markmc left a comment


Overall looks great ... until I saw @robertgshaw2-redhat's comment 🤣

we need a mechanism by which the requests that have NOT YET been passed to the NIXL connector to be able to extend the leases


- `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional)
- Default: 480
- If a request is aborted and the decoder has not yet read the KV-cache blocks through the nixl channel, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
Member


Just thinking about ... "is it ok to just drop this, what about compatibility?" ... which reminds me ...

We need P and D to upgrade to this in lockstep, so that means we need to update NIXL_CONNECTOR_VERSION

- `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional)
- Default: 480
- If a request is aborted and the decoder has not yet read the KV-cache blocks through the nixl channel, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
- `VLLM_NIXL_KV_LEASE_DURATION`: Initial lease duration (in seconds) for KV blocks on the prefiller. (Optional)
Member


See #25700 - let's model this as config properly. Probably kv_connector_extra_config is the place

The "lease_extension" and "heartbeat_interval" configs are particularly niche - we would probably be fine with a hard-coded lease_extension = lease_duration * 2 /3 and heartbeat_interval = lease_duration / 6. And if anyone really needs to tweak these, we can add that later 👍

# In progress transfers.
# [req_id -> list[handle]]
# In-progress transfer tracking (D-side / consumer).
# Keyed by req_id to ensure ALL handles complete before marking done.
Member


Not sure I understand this comment

# grow indefinitely when new P remotes are added.
for k in list(self._pending_transfers_by_engine.keys()):
    if not self._pending_transfers_by_engine[k]:
        del self._pending_transfers_by_engine[k]
Member


This seems a bit odd - why not clear out the per-engine entry in _pop_done_transfers() when the set is empty?

continue

# Build batched heartbeat message: "HB:req1,req2,req3,..."
heartbeat_msg = ("HB:" + ",".join(req_ids)).encode()
Member


Is there any limit to the notification message size? This string can be pretty huge?

# Build batched heartbeat message: "HB:req1,req2,req3,..."
heartbeat_msg = ("HB:" + ",".join(req_ids)).encode()

# Send to ALL remote agents we handhshaked with for this remote.
Member


Suggested change
# Send to ALL remote agents we handhshaked with for this remote.
# Send to ALL remote agents we handshaked with for this remote.

try:
    self.nixl_wrapper.send_notif(agent_name, notif_msg=heartbeat_msg)
    num_notifs += 1
except Exception as e:
Member


I'm really not a fan of catching Exception like this ... e.g. we shouldn't swallow a simple programming error like TypeError ... but you're just copying an existing pattern here, so nevermind

for handle in handles:
    self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear()
self._pending_transfers_by_engine.clear()
Member


This is unnecessary IMO - no need to free memory in shutdown(), that'll be handled by garbage collection. Only need to release other resources here


# Heartbeat/lease management for D-side (consumer).
# Single timestamp suffices - heartbeat interval limits overall send rate,
# not per-engine. New engines get fresh leases on P-side anyway.
Member


Scratched my head on this one ... you're saying "there's no need to track last heartbeat per engine". I think a comment like this would be more helpful in send_lease_heartbeats()

# Check if enough time has passed since last heartbeat.
if now - self._last_heartbeat_time < self._heartbeat_interval:
    return

Member


I think I'd expect to see

self._last_heartbeat_time = now

here

why would we want to come back here again immediately if no notifications are sent?
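
To make the suggestion concrete, the assignment would sit right after the early return in the quoted snippet (illustrative placement, not a patch from the PR):

  # Check if enough time has passed since last heartbeat.
  if now - self._last_heartbeat_time < self._heartbeat_interval:
      return
  self._last_heartbeat_time = now  # avoid re-entering immediately when nothing is sent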

@markmc
Member

markmc commented Apr 16, 2026

Relevant: https://github.com/llm-d/llm-d/blob/84ba2f7004abf3ba0e323fd8ffe3f8ce3c94656f/docs/wip-docs-new/architecture/advanced/disaggregation/operations-vllm.md

KV blocks are stranded on the P instance until the timeout VLLM_NIXL_ABORT_REQUEST_TIMEOUT, which defaults to 480s. We are currently working on a lease-extension strategy that will dramatically shorten the timeout window.

@NickLucche
Collaborator Author

closing in favor of #41383

@NickLucche NickLucche closed this May 8, 2026

Labels

documentation (Improvements or additions to documentation), kv-connector, needs-rebase, ready (ONLY add when PR is ready to merge/full CI is needed), v1
