Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_basic_lifecycle():
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1

Expand All @@ -67,6 +68,7 @@ def test_basic_lifecycle():
assert len(scheduler.waiting) == 0

# ... but blocks should not be freed.
assert len(scheduler.requests) == 1
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_id]
Expand All @@ -76,6 +78,7 @@ def test_basic_lifecycle():
# STEP (2): Send Finished to PB.
# (2a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids
Expand All @@ -92,6 +95,7 @@ def test_basic_lifecycle():
# STEP (3): Finished sending.
# (3a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
Expand Down Expand Up @@ -133,6 +137,7 @@ def test_short_prompt_lifecycle():
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1

Expand Down Expand Up @@ -178,7 +183,7 @@ def test_prefix_cache_lifecycle():
reqs=[request_normal], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)

#####################
Expand Down Expand Up @@ -213,3 +218,45 @@ def test_prefix_cache_lifecycle():
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)


def test_abort_during_kv_transfer():
"""Test aborting request does not release blocks for remote decode."""

vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)

# Prime the KVCache.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
Comment on lines +231 to +232
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: not really important for this test, we could simplify


request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)

scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request])
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)

# Request removed from PB but blocks should not be freed.
assert len(scheduler.requests) == 1

# Abort the request, and check the blocks are still not freed
scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED)
assert len(scheduler.requests) == 1

# Simulate a finished sending notification
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request.request_id]
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)
17 changes: 11 additions & 6 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
temporary buffer alloc by the CacheManager.
update_connector_output() - update KVConnector state after
output is received from worker-side connectors.
request_finished() - called when a request is finished, with
the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.
request_finished() - called once when a request is finished,
with the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or if the
connector now assumes responsibility for freeing the
the blocks asynchronously. Also optionally returns KV
transfer params.
take_events() - returns new KV events that were collected
by the connector since the last call.

Expand Down Expand Up @@ -362,7 +363,11 @@ def request_finished(
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Called exactly once when a request has finished, before its blocks are
freed.

The connector may assumes responsibility for freeing the the blocks
asynchronously by returning True.

Returns:
True if the request is being saved/sent asynchronously and blocks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
# Remove all requests that are not to be processed (eg aborted).
for req_id in metadata.reqs_not_processed:
self._reqs_to_process.discard(req_id)
# We should never get an abort after setting an expiry timer
assert req_id not in self._reqs_to_send

# Add to requests that are waiting to be read and track expiration.
for req_id, expiration_time in metadata.reqs_to_send.items():
Expand Down
12 changes: 3 additions & 9 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,7 +1191,7 @@ def finish_requests(
# First pass: collect requests to remove from queues
for req_id in request_ids:
request = self.requests.get(req_id)
if request is None:
if request is None or request.is_finished():
# Invalid request ID.
continue

Expand Down Expand Up @@ -1369,14 +1369,8 @@ def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
self.finished_recving_kv_req_ids.add(req_id)
for req_id in kv_connector_output.finished_sending or ():
logger.debug("Finished sending KV transfer for request %s", req_id)
if req_id not in self.requests:
logger.warning(
"Got finished sending KV transfer for request %s,"
"but the request is already freed.",
req_id,
)
else:
self._free_blocks(self.requests[req_id])
assert req_id in self.requests
self._free_blocks(self.requests[req_id])

def _update_requests_with_invalid_blocks(
self, requests: Iterable[Request], invalid_block_ids: set[int]
Expand Down