
[Bugfix][NIXL] Fix Async Scheduler timeout issue #25808

Merged
tlrmchlsmth merged 1 commit into vllm-project:main from NickLucche:fix-async-pd
Sep 27, 2025

Conversation

@NickLucche
Collaborator

@NickLucche NickLucche commented Sep 27, 2025

To keep prefill (P) KV blocks from being stranded when a request is aborted during decode, we introduced a TTL/timeout in #20139.

With async scheduling under high request concurrency, the following log line (and the logic behind it) can be observed:

ERROR 09-27 17:06:12 [nixl_connector.py:1101] Potentially invalid KV blocks for unrecognized request cmpl-141e42e3-ad61-46ab-88cc-fb6aaa47e363-0 were retrieved by a decode worker. They may have expired.

The current workflow on main is roughly:

____________________
iteration i (engine step)
____________________

(Scheduler)
build_connector_meta() -> meta.reqs_to_send = _reqs_need_send  # requests to time out are prepared for the worker
      |
(Scheduler)  # here P is done and we acknowledge it by setting a timeout
request_finished() -> self._reqs_need_send[request.request_id] = some_time
      |
worker runs . . .

____________________
iteration i+1
____________________
(Scheduler)
build_connector_meta() -> meta.reqs_to_send = _reqs_need_send  # prepare the request for which we set the timeout above
       |
     .  .  .
       |
(Worker)  # worker *receives* metadata about finished requests and starts tracking its timeout
start_load_kv() -> self._reqs_to_send.update(metadata.reqs_to_send)
. . .

Now, if D is allowed to decode the request and READ blocks from P between iteration i and iteration i+1 (this happens on a separate async channel, not tied to the scheduler step), the log.error above fires because self._reqs_to_send has not yet been updated!

In short, async scheduling widens the misalignment between the moment the request expiration is set (P side) and the moment blocks are read by D, because the scheduler runs one step ahead while the runner picks up the next batch right away.
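The race and the fix can be sketched in a few lines. This is a toy model, not the real connector code: the names mirror the PR (`_reqs_to_send`, `_reqs_to_process`, `start_load_kv`), but the class and method shapes are simplified assumptions for illustration.

```python
import time


class PrefillConnectorSketch:
    """Toy model of the P-side connector state across engine steps."""

    def __init__(self):
        # req_id -> expiry time; only populated at iteration i+1,
        # when the worker receives the scheduler's metadata.
        self._reqs_to_send = {}
        # The fix: requests the scheduler has already marked finished,
        # tracked from iteration i onward.
        self._reqs_to_process = set()

    def scheduler_marks_finished(self, req_id):
        # Iteration i: scheduler acknowledges P is done with req_id.
        self._reqs_to_process.add(req_id)

    def start_load_kv(self, reqs_to_send):
        # Iteration i+1: worker starts tracking the timeouts.
        self._reqs_to_send.update(reqs_to_send)

    def on_decode_read_notification(self, req_id):
        # Async channel: D may read blocks between iterations i and i+1.
        if req_id in self._reqs_to_send or req_id in self._reqs_to_process:
            return "ok"
        # Without the _reqs_to_process check, an early read hits this path.
        return "error: unrecognized request"


worker = PrefillConnectorSketch()
worker.scheduler_marks_finished("req-0")             # iteration i
print(worker.on_decode_read_notification("req-0"))   # early D read -> "ok"
worker.start_load_kv({"req-0": time.monotonic() + 120})  # iteration i+1
```

Without the `_reqs_to_process` membership check, the early read between the two iterations would fall through to the "unrecognized request" branch, which is exactly the log line reported above.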

Fixes #25777

Signed-off-by: NickLucche <nlucches@redhat.com>
@NickLucche NickLucche requested a review from ApostaC as a code owner September 27, 2025 17:07
@NickLucche NickLucche changed the title fix async scheduler timeout bug [Bugfix][NIXL] Fix Async Scheduler timeout issue Sep 27, 2025
@mergify mergify bot added the kv-connector label Sep 27, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a race condition in the asynchronous scheduler that could lead to requests being incorrectly marked as expired. The fix introduces a new state to track requests within a processing batch, preventing premature expiration and erroneous error logging. The logic appears sound and effectively resolves the described issue. I have one suggestion to enhance the robustness of the implementation.

"Releasing expired KV blocks for request %s which were "
"retrieved by %d decode worker(s) within %d seconds.", req_id,
count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
self._reqs_to_process.remove(req_id)

Severity: high

Using discard() instead of remove() would be more robust. In a complex distributed system with asynchronous operations, it's safer to use discard() to avoid potential KeyError exceptions if the req_id is unexpectedly not in the set due to race conditions. This would make the worker more resilient.

A similar change would be beneficial on line 1126.

Suggested change
self._reqs_to_process.remove(req_id)
self._reqs_to_process.discard(req_id)
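The difference between the two set methods is easy to demonstrate. A minimal standalone example (plain Python set semantics, not connector code):

```python
# set.discard() is a no-op when the element is missing;
# set.remove() raises KeyError. In asynchronous code where another
# path may have already dropped the entry, discard() is the safer choice.
reqs = {"req-a"}

reqs.discard("req-a")   # element removed
reqs.discard("req-a")   # already gone: no-op, no exception

try:
    reqs.remove("req-a")  # missing element: raises
except KeyError:
    print("remove() raised KeyError")
```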

Member

@tlrmchlsmth tlrmchlsmth left a comment


Tested in a WideEP setup -- fixes crashes I was seeing

@tlrmchlsmth tlrmchlsmth added this to the v0.11.0 Cherry Picks milestone Sep 27, 2025
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 27, 2025
@tlrmchlsmth tlrmchlsmth merged commit da63274 into vllm-project:main Sep 27, 2025
51 checks passed
simon-mo pushed a commit that referenced this pull request Sep 28, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
@njhill
Member

njhill commented Sep 29, 2025

Thanks for this @NickLucche. I'm wondering whether there's a possibility of a "leak" of entries in self._reqs_to_process, specifically if the scheduler decides not to designate these as requests to send, for example if the prefill is aborted while in progress, or it's decided that there are too few tokens, or whatever.

In that case the expiration won't be set and there will be no NIXL notifications, so I can't see where the corresponding req_id would ever be removed from that set.
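The leak scenario described above can be sketched as follows. This is a hypothetical illustration, not the real connector logic: `request_finished` and its `will_send_kv` flag are invented names standing in for the scheduler's decision about whether the request's KV will actually be sent.

```python
# Sketch of the suspected leak: a request enters the tracking set, but if
# it is never designated as a request to send, no expiry is set and no
# NIXL notification arrives, so nothing ever removes it.
reqs_to_process = set()   # mirrors self._reqs_to_process in the PR
reqs_to_send = {}         # req_id -> expiry time


def request_finished(req_id, will_send_kv):
    reqs_to_process.add(req_id)
    if will_send_kv:
        reqs_to_send[req_id] = "expiry"  # placeholder for a timestamp


request_finished("aborted-prefill", will_send_kv=False)
# "aborted-prefill" now sits in reqs_to_process with no removal path.
print("aborted-prefill" in reqs_to_process
      and "aborted-prefill" not in reqs_to_send)  # True
```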

pdasigi pushed a commit to pdasigi/vllm that referenced this pull request Oct 2, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
shyeh25 pushed a commit to shyeh25/vllm that referenced this pull request Oct 14, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed


Development

Successfully merging this pull request may close these issues.

[Bug][WideEP]: vLLM Crashes with D/P when --async-scheduling is used.

3 participants