[NIXL] Ignore abort on already-finished request #25067

Merged
njhill merged 3 commits into vllm-project:main from markmc:pd-abort
Oct 10, 2025
Conversation

@markmc (Member) commented Sep 17, 2025

This situation can occur when the API server receives a client disconnect (and thus sends an abort) around the same time a prefill completes and we keep the blocks (delay_free_blocks) around for a remote decode. We should assume the blocks may be used, and so we ignore the abort. If they are not used, they should be freed by the connector after a timeout.

The original error was:

[scheduler.py:1183] Finished sending KV transfer for request cmpl-37c560d3-5680-4bd1-97f9-7ed31a56de60-0
  File "/opt/vllm-source/vllm/v1/engine/core.py", line 292, in step
     engine_core_outputs = self.scheduler.update_from_output(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm-source/vllm/v1/core/sched/scheduler.py", line 893, in update_from_output
    self._update_from_kv_xfer_finished(
  File "/opt/vllm-source/vllm/v1/core/sched/scheduler.py", line 1184, in _update_from_kv_xfer_finish>
    self._free_blocks(self.requests[req_id])
                      ~~~~~~~~~~~~~^^^^^^^^

KeyError: 'cmpl-37c560d3-5680-4bd1-97f9-7ed31a56de60-0'

But since #25844 we would log a warning instead of crashing. This fix makes it so that the situation in _update_from_kv_xfer_finished() should never occur.

Observed under heavy load in a multi-node, llm-d 4P1D test environment. See llm-d/llm-d#187

More recently, #26012 introduced another case where this situation would cause a crash:

            logger.warning(
                "Releasing expired KV blocks for request %s which were "
                "retrieved by %d decode worker(s) within %d seconds.",
                req_id,
                count,
                envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
            )
>           self._reqs_to_process.remove(req_id)
E           KeyError: '0'

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:1238: KeyError
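For context, the crash above comes down to Python set semantics: `set.remove()` raises `KeyError` for a missing element, while `set.discard()` is a no-op. A minimal standalone illustration (the variable name mirrors `_reqs_to_process` but is otherwise hypothetical):

```python
# set.remove() raises KeyError when the element is absent, which is exactly
# the failure mode in the traceback above; set.discard() would not raise.
reqs_to_process = {"1", "2"}

reqs_to_process.discard("0")    # missing id: silently ignored
print(sorted(reqs_to_process))  # -> ['1', '2']

try:
    reqs_to_process.remove("0")  # missing id: raises
except KeyError as exc:
    print("KeyError:", exc)      # -> KeyError: '0'
```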

@markmc (Member, Author) commented Sep 17, 2025

Related to llm-d/llm-d#187

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively addresses a critical race condition that could lead to a KeyError when a request is aborted during a KV cache transfer. The core change in vllm/v1/core/sched/scheduler.py to skip aborting already-finished requests is a clean and correct solution. The addition of the test_abort_during_kv_transfer unit test is excellent, as it specifically validates the fix for the identified scenario. The other changes, including test assertions and logging improvements, further enhance the robustness and debuggability of the codebase. Overall, this is a high-quality contribution that improves the stability of the system under heavy load.

@NickLucche (Collaborator) left a comment

Not sure how this wouldn't affect regular abort workflow..
Let's quickly discuss offline with @njhill

@markmc (Member, Author) commented Sep 17, 2025

Based on chat with @NickLucche there's a very obvious question I need to get to the bottom of:

If P is done with the request (finished, capped length), then how is the request being aborted in P ?

@mergify mergify bot added the kv-connector label Sep 18, 2025
@markmc (Member, Author) commented Sep 19, 2025

Closing this for now - we really don't know what's going on here, my theory above doesn't make much sense, and we haven't been able to reproduce it lately.

@markmc markmc closed this Sep 19, 2025
@markmc markmc mentioned this pull request Sep 29, 2025
@markmc (Member, Author) commented Oct 6, 2025

A reminder, we're talking about this scenario:

    def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
        for req_id in kv_connector_output.finished_sending or ():
            logger.debug("Finished sending KV transfer for request %s", req_id)
            if req_id not in self.requests:
                logger.warning(
                    "Got finished sending KV transfer for request %s,"
                    "but the request is already freed.",
                    req_id,
                )
            else:
                self._free_blocks(self.requests[req_id])

After the scheduler (engine core) finishes a request after prefill (due to max_tokens=1), the blocks are kept around for the decode worker, and when (much later) the timeout expires, the request has already been freed in the scheduler (or, at least, removed from Scheduler.requests)

I've managed to reproduce this, albeit quite artificially - but I think it's clear there is a race condition here, and it's the only plausible way I can see of hitting this situation.

The race condition is something like this (GPT-5's effort to illustrate it)

                +--------------------+                    +--------------------+
                |     API Server     |                    |    Engine Core     |
                +--------------------+                    +--------------------+
                         |                                           |
                         |---- send request ------------------------>|
                         |                                           |
                         |<--- stream partial outputs ---------------|
                         |<--- stream partial outputs ---------------|
                         |                                           |
                         |<--- final output -------------------------|
                         |                                           |
        [Client disconnects just before or during final output]      |
                         |                                           |
                         |  asyncio.CancelledError / GeneratorExit   |
                         |                                           |
                         |---- send abort(request_id) -------------->|
                         |                                           |
                         |                    [race window]          |
                         |       Engine finishes and frees request   |
                         |         before abort arrives              |
                         |                                           |
                         |<------------------- done -----------------|
                         |                                           |
                         |---- abort arrives too late -------------->|
                         |                                           |
                         |        [request already freed]            |
                         |        -> spurious abort handling         |
                         |        -> potential log warning/error     |
                         |                                           |

Why do I think this is the likely scenario? The engine core abort_requests() is the only call to finish_requests() which contains the only call to free_request() other than update_from_output(). And the only other reason the API server would send an abort is if it hits a stop string on detokenization.

My conclusion - we need to handle the case where an exception or client disconnect causes the engine core to receive an abort after a request has finished
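A hedged sketch of what handling this case could look like (hypothetical names, not the actual vLLM scheduler): finish_requests() simply skips a request that is no longer running, leaving block freeing to the connector's timeout.

```python
from enum import Enum, auto


class RequestStatus(Enum):
    RUNNING = auto()
    FINISHED_LENGTH_CAPPED = auto()
    FINISHED_ABORTED = auto()


class Request:
    def __init__(self, req_id: str):
        self.request_id = req_id
        self.status = RequestStatus.RUNNING


class Scheduler:
    def __init__(self):
        # req_id -> Request; entries removed only once blocks are freed
        self.requests: dict[str, Request] = {}

    def finish_requests(self, req_ids, finished_status):
        for req_id in req_ids:
            request = self.requests.get(req_id)
            if request is None or request.status != RequestStatus.RUNNING:
                # Abort arrived after the request already finished (e.g. a
                # prefill holding blocks for a remote decode): ignore it and
                # let the connector free the blocks after its timeout.
                continue
            request.status = finished_status
            # ... free blocks / notify the connector here ...


sched = Scheduler()
sched.requests["r1"] = Request("r1")
sched.requests["r1"].status = RequestStatus.FINISHED_LENGTH_CAPPED
sched.finish_requests(["r1", "gone"], RequestStatus.FINISHED_ABORTED)
print(sched.requests["r1"].status)  # -> RequestStatus.FINISHED_LENGTH_CAPPED
```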

@markmc (Member, Author) commented Oct 6, 2025

See #26012 (comment) - I think #26012 has introduced another KeyError crash in this scenario, but this time in the worker.

@markmc markmc reopened this Oct 6, 2025
@markmc markmc requested a review from ApostaC as a code owner October 6, 2025 14:14
@markmc markmc marked this pull request as draft October 6, 2025 14:14
@markmc markmc force-pushed the pd-abort branch 2 times, most recently from be65ca4 to 887f18d on October 6, 2025 14:15
@markmc markmc changed the title from "[V1][NIXL] Keep prefilled blocks around even if aborted" to "[NIXL] Free prefilled blocks if aborted" on Oct 6, 2025
@markmc markmc marked this pull request as ready for review October 6, 2025 19:07
@njhill (Member) left a comment

Thanks @markmc. I spent some time studying this and I think you're right re this race condition and I can see how it could cause a double-free type key error.

Comment on lines +1202 to +1205
logger.debug("Aborting a previously finished request %s", req_id)
request.status = finished_status
self._connector_finished(request)
self._free_blocks(request)
Member commented:

I actually think it may be better here to just do nothing (i.e. just continue / skip this req). The expiration will handle block freeing once the timeout triggers. We could free them immediately, but that would require keeping track of another possible state (or else we would need to ignore double-frees). And this is only an edge case anyhow.

I think for sure we shouldn't call _connector_finished here since that would have already been called (the reason for the race condition in the first place).

Member Author commented:

I actually think it may be better here to just do nothing (i.e. just continue / skip this req).

Yeah, that's what I had this PR do in its first iteration, but that was before I understood that the decode worker could not continue in this situation, so we know we can free

The expiration will handle block freeing once the timeout triggers.

I really don't love the prospect of a rare scenario where many client-aborted requests strand blocks for up to 8 minutes. You might kill a benchmark run, restart it, and get unpredictable results?

We could free them immediately, but that would require keeping track of another possible state (or else we would need to ignore double-frees).

We already guard against the (impossible) possibility of another abort by checking whether the request is in scheduler.requests

The "but the request is already freed" check in _update_from_kv_xfer_finished() is an example of us ignoring a double free already - I actually think we could make that unnecessary, but it's fine as defensive programming

Not sure I can see another issue?

And this is only an edge case anyhow.

We do need to do something to handle this edge case though - personally I think handling the abort by freeing is likely to be less brittle than just ignoring the abort and leaving it for the expiry timer to handle

I think for sure we shouldn't call _connector_finished here since that would have already been called (the reason for the race condition in the first place).

If the request is aborted, and you free the request, then the connector should also delete the timer - calling request_finished() is what tells the connector to do that, e.g. in NIXL:

        if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
            # Also include the case of a P/D Prefill request with immediate
            # block free (eg abort). Stop tracking this request.
            self._reqs_not_processed.add(request.request_id)
            return False, None

Member commented:

the decode worker could not continue in this situation, so we know we can free

This is true in the OpenAI API server + NixlConnector P/D case but not in the general case. If folks are using the AsyncLLM interface directly they can call abort "out of band", but then would still get the output for the request with any kv transfer params that the connector had returned.

For connectors in general, the contract is that if request_finished() returned async_save=True then the connector may be using/saving the blocks asynchronously until it notifies the framework that it's finished with them. So I'm not sure it would be "safe" to free them here since in a sense the connector owns the blocks at this point.

I really don't love the prospect of a rare scenario where many client-aborted requests strand blocks for up to 8 minutes. You might kill a benchmark run, restart it, and get unpredictable results?

This is fair! But in theory this should be a very narrow window. Given this, and looking closer at the flow, I think the reason we've encountered it is likely that in the core model loop, abort (and other new) requests aren't processed between the model execution / forward pass and the call to scheduler.update_from_outputs() which handles the request completion including notifying the scheduler.

We should probably look at changing it to process pending requests in-between so that in this case, the request status will have already been updated to ABORTED when it's passed to request_finished() (and then the nixl connector will return async_save=False, so the blocks will be freed immediately).

If the request is aborted, and you free the request, then the connector should also delete the timer - calling request_finished() is what tells the connector to do that

My assumption of the contract of this method is that it should be called exactly once for each request.
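The "exactly once" contract can be made explicit with an assertion; a toy connector (hypothetical, not the NIXL connector API) might track seen requests:

```python
class ToyConnector:
    """Illustrates the contract that request_finished() is called once."""

    def __init__(self):
        self._seen: set[str] = set()

    def request_finished(self, request_id: str) -> bool:
        assert request_id not in self._seen, (
            f"request_finished() called twice for {request_id}")
        self._seen.add(request_id)
        # Returning True means the connector now owns the blocks and will
        # tell the framework when they can be freed.
        return True


conn = ToyConnector()
assert conn.request_finished("r1")
try:
    conn.request_finished("r1")  # violates the contract
except AssertionError as exc:
    print(exc)  # -> request_finished() called twice for r1
```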

Member Author commented:

the decode worker could not continue in this situation, so we know we can free

This is true in the OpenAI API server + NixlConnector P/D case but not in the general case. If folks are using the AsyncLLM interface directly they can call abort "out of band", but then would still get the output for the request with any kv transfer params that the connector had returned.

No, because any finished request gets removed from the output processor before it's returned, and then it's filtered out of any call to abort()

For connectors in general, the contract is that if request_finished() returned async_save=True then the connector may be using/saving the blocks asynchronously until it notifies the framework that it's finished with them. So I'm not sure it would be "safe" to free them here since in a sense the connector owns the blocks at this point.

I appreciate you're focused on a clear API contract with connectors, especially since we allow for out-of-repo connectors

I really don't love the prospect of a rare scenario where many client-aborted requests strand blocks for up to 8 minutes. You might kill a benchmark run, restart it, and get unpredictable results?

This is fair! But in theory this should be a very narrow window. Given this, and looking closer at the flow, I think the reason we've encountered it is likely that in the core model loop, abort (and other new) requests aren't processed between the model execution / forward pass and the call to scheduler.update_from_outputs() which handles the request completion including notifying the scheduler.

We should probably look at changing it to process pending requests in-between so that in this case, the request status will have already been updated to ABORTED when it's passed to request_finished() (and then the nixl connector will return async_save=False, so the blocks will be freed immediately).

If the request is aborted, and you free the request, then the connector should also delete the timer - calling request_finished() is what tells the connector to do that

My assumption of the contract of this method is that it should be called exactly once for each request.

Ok, I'll change back to ignoring the abort and update the connector API contract documentation to say this outright

Member Author commented:

We should probably look at changing it to process pending requests in-between so that in this case

Filed as #26400

markmc added a commit to markmc/vllm that referenced this pull request Oct 7, 2025
We have observed a rare scenario with AsyncLLM where a client disconnect
triggers an abort request after the request has finished, but before
AsyncLLM has processed the request output.

See vllm-project#26012, vllm-project#25067, vllm-project#25844, and llm-d/llm-d#187.

Without the fix, the unit test fails with:

```
            logger.warning(
                "Releasing expired KV blocks for request %s which were "
                "retrieved by %d decode worker(s) within %d seconds.",
                req_id,
                count,
                envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
            )
>           self._reqs_to_process.remove(req_id)
E           KeyError: '0'

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:1238: KeyError
```

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This situation can occur when the API server receives a client
disconnect (and thus sends an abort) around the same time a prefill
completes and we keep the blocks (delay_free_blocks) around for a
remote decode. We should assume the blocks may be used, and so
we ignore the abort. If they are not used, they should be freed
by the connector after a timeout.

The original error was:

```
[scheduler.py:1183] Finished sending KV transfer for request cmpl-37c560d3-5680-4bd1-97f9-7ed31a56de60-0
  File "/opt/vllm-source/vllm/v1/engine/core.py", line 292, in step
     engine_core_outputs = self.scheduler.update_from_output(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm-source/vllm/v1/core/sched/scheduler.py", line 893, in update_from_output
    self._update_from_kv_xfer_finished(
  File "/opt/vllm-source/vllm/v1/core/sched/scheduler.py", line 1184, in _update_from_kv_xfer_finish>
    self._free_blocks(self.requests[req_id])
                      ~~~~~~~~~~~~~^^^^^^^^

KeyError: 'cmpl-37c560d3-5680-4bd1-97f9-7ed31a56de60-0'
```

But since vllm-project#25844 we would log a warning instead of
crashing. This fix makes it so that the situation in
`_update_from_kv_xfer_finished()` should never occur.

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
@markmc markmc changed the title from "[NIXL] Free prefilled blocks if aborted" to "[NIXL] Ignore abort on already-finished request" on Oct 8, 2025
markmc added 2 commits October 8, 2025 04:36
1. request_finished() should be called exactly once

2. Returning True from request_finished() means the
connector assumes responsibility for when the request
should be freed

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Now that we guard against abort-after-finished, the current
assumptions in the NIXL connector are correct, but an assertion
helps document the assumption more clearly.

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
@njhill (Member) left a comment

Thanks @markmc! and thanks for the doc improvements!

@njhill njhill enabled auto-merge (squash) October 10, 2025 02:49
@NickLucche (Collaborator) left a comment

Clean, thanks for the fine work @markmc

Comment on lines +231 to +232
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
Collaborator commented:

nit: not really important for this test, we could simplify

@njhill njhill merged commit 784c231 into vllm-project:main Oct 10, 2025
52 checks passed
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: bbartels <benjamin@bartels.dev>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
jianzs added a commit to jianzs/vllm that referenced this pull request Feb 27, 2026
This is an enhancement to PR vllm-project#25067 which ignored aborts on finished
requests and relied on timeout-based cleanup. Instead of waiting for
the connector timeout to free blocks, immediately free them when
receiving FINISHED_ABORTED for an already-finished request.

This enables earlier KV cache memory reclamation, which is especially
important under heavy load in multi-node scenarios where memory
pressure is high.

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
jianzs added a commit to jianzs/vllm that referenced this pull request Mar 5, 2026
This is an enhancement to PR vllm-project#25067 which ignored aborts on finished
requests and relied on timeout-based cleanup. Instead of waiting for
the connector timeout to free blocks, immediately free them when
receiving FINISHED_ABORTED for an already-finished request.

This enables earlier KV cache memory reclamation, which is especially
important under heavy load in multi-node scenarios where memory
pressure is high.

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Labels: kv-connector, ready, v1