Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from vllm.distributed.kv_events import BlockRemoved, BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler import (
OffloadingConnectorScheduler,
TransferJobStatus,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
Expand Down Expand Up @@ -258,9 +260,12 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
# With sync scheduling, all-finished flush fires within this run.
# With async scheduling, the finish is delayed so flush fires later.
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored=(0, 1, 2),
expected_flushed=(0, 1, 2) if not async_scheduling else (),
)

# start a request to load the first block, but don't complete
Expand Down Expand Up @@ -325,6 +330,7 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool):
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored=(0, 1, 2),
expected_flushed=(0, 1, 2) if not async_scheduling else (),
)

# start a request to load the first block, but don't complete
Expand Down Expand Up @@ -766,7 +772,11 @@ def test_do_remote_decode_stores_all_blocks(request_runner, async_scheduling: bo
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
runner.run(decoded_tokens=[EOS_TOKEN_ID], expected_stored=(0, 1, 2))
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored=(0, 1, 2),
expected_flushed=(0, 1, 2) if not async_scheduling else (),
)

# Reset GPU prefix cache so the next request must load from CPU.
runner.scheduler.reset_prefix_cache()
Expand Down Expand Up @@ -831,8 +841,13 @@ def test_fence_at_update_state_after_alloc(request_runner):
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
runner.run(decoded_tokens=[EOS_TOKEN_ID], complete_transfers=False)
assert runner.connector_scheduler._block_id_to_pending_jobs
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
complete_transfers=False,
expected_stored=(0,),
expected_flushed=(0,),
)
assert runner.connector_scheduler._block_id_to_pending_jobs == {}

runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * 4)
Expand All @@ -843,8 +858,6 @@ def test_fence_at_update_state_after_alloc(request_runner):
runner.run(
decoded_tokens=[],
complete_transfers=False,
expected_stored=(0,),
expected_flushed=(0,),
)
assert runner.connector_scheduler._block_id_to_pending_jobs == {}

Expand All @@ -864,8 +877,13 @@ def test_fence_at_build_store_jobs(request_runner):
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
runner.run(decoded_tokens=[EOS_TOKEN_ID], complete_transfers=False)
assert runner.connector_scheduler._block_id_to_pending_jobs
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
complete_transfers=False,
expected_stored=(0,),
expected_flushed=(0,),
)
assert runner.connector_scheduler._block_id_to_pending_jobs == {}

runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[1] * 4)
Expand All @@ -875,8 +893,6 @@ def test_fence_at_build_store_jobs(request_runner):
)
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored=(0,),
expected_flushed=(0,),
)
assert runner.connector_scheduler._block_id_to_pending_jobs == {}

Expand Down Expand Up @@ -919,3 +935,29 @@ def test_complete_store_called_per_job(request_runner, async_scheduling: bool):
# Finish: no store pending -> no further call.
runner.run(decoded_tokens=[EOS_TOKEN_ID])
assert runner.manager.complete_store.call_count == 0


def test_flush_all_jobs_when_no_requests_remain(request_runner):
"""When all tracked requests are finished, build_connector_meta flushes
all pending jobs since there will be no future step to complete them."""
block_size = 4
block_size_factor = 1
offloaded_block_size = block_size * block_size_factor

runner = request_runner(
block_size=block_size,
num_gpu_blocks=100,
async_scheduling=False,
block_size_factor=block_size_factor,
)

runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
complete_transfers=False,
expected_stored=(0,),
expected_flushed=(0,),
)
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,14 @@ def build_connector_meta(
assert self._jobs[any_jid].is_store
self._current_batch_jobs_to_flush.update(req_status.transfer_jobs)

# If all tracked requests are finished, flush all pending store
# jobs - there might not be a future scheduler step to trigger their
# completion.
if self._req_status and all(
rs.req.is_finished() for rs in self._req_status.values()
):
self._current_batch_jobs_to_flush.update(self._jobs.keys())

meta = OffloadingConnectorMetadata(
load_jobs=self._current_batch_load_jobs,
store_jobs=self._build_store_jobs(scheduler_output),
Expand Down
Loading