Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 55 additions & 10 deletions tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,12 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
# With sync scheduling, all-finished flush fires within this run.
# With async scheduling, the finish is delayed so flush fires later.
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored=(0, 1, 2),
expected_flushed=(0, 1, 2) if not async_scheduling else (),
)

# start a request to load the first block, but don't complete
Expand Down Expand Up @@ -325,6 +328,7 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool):
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored=(0, 1, 2),
expected_flushed=(0, 1, 2) if not async_scheduling else (),
)

# start a request to load the first block, but don't complete
Expand All @@ -351,6 +355,7 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool):
runner.run(
decoded_tokens=[],
expected_loaded=(0, 1, 2),
expected_flushed=(0, 1, 2),
)

# assert request is deleted
Expand Down Expand Up @@ -766,7 +771,11 @@ def test_do_remote_decode_stores_all_blocks(request_runner, async_scheduling: bo
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
runner.run(decoded_tokens=[EOS_TOKEN_ID], expected_stored=(0, 1, 2))
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored=(0, 1, 2),
expected_flushed=(0, 1, 2) if not async_scheduling else (),
)

# Reset GPU prefix cache so the next request must load from CPU.
runner.scheduler.reset_prefix_cache()
Expand Down Expand Up @@ -831,8 +840,13 @@ def test_fence_at_update_state_after_alloc(request_runner):
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
runner.run(decoded_tokens=[EOS_TOKEN_ID], complete_transfers=False)
assert runner.connector_scheduler._block_id_to_pending_jobs
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
complete_transfers=False,
expected_stored=(0,),
expected_flushed=(0,),
)
assert runner.connector_scheduler._block_id_to_pending_jobs == {}

runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * 4)
Expand All @@ -843,8 +857,6 @@ def test_fence_at_update_state_after_alloc(request_runner):
runner.run(
decoded_tokens=[],
complete_transfers=False,
expected_stored=(0,),
expected_flushed=(0,),
)
assert runner.connector_scheduler._block_id_to_pending_jobs == {}

Expand All @@ -864,8 +876,13 @@ def test_fence_at_build_store_jobs(request_runner):
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
runner.run(decoded_tokens=[EOS_TOKEN_ID], complete_transfers=False)
assert runner.connector_scheduler._block_id_to_pending_jobs
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
complete_transfers=False,
expected_stored=(0,),
expected_flushed=(0,),
)
assert runner.connector_scheduler._block_id_to_pending_jobs == {}

runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[1] * 4)
Expand All @@ -875,8 +892,6 @@ def test_fence_at_build_store_jobs(request_runner):
)
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored=(0,),
expected_flushed=(0,),
)
assert runner.connector_scheduler._block_id_to_pending_jobs == {}

Expand Down Expand Up @@ -921,6 +936,32 @@ def test_complete_store_called_per_job(request_runner, async_scheduling: bool):
assert runner.manager.complete_store.call_count == 0


def test_flush_all_jobs_when_no_requests_remain(request_runner):
"""When all tracked requests are finished, build_connector_meta flushes
all pending jobs since there will be no future step to complete them."""
block_size = 4
block_size_factor = 1
offloaded_block_size = block_size * block_size_factor

runner = request_runner(
block_size=block_size,
num_gpu_blocks=100,
async_scheduling=False,
block_size_factor=block_size_factor,
)

runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
complete_transfers=False,
expected_stored=(0,),
expected_flushed=(0,),
)


@pytest.mark.parametrize("async_scheduling", [True, False])
def test_reset_cache(request_runner, async_scheduling: bool):
"""reset_cache flushes in-flight loads, calls manager.reset_cache(), resets
Expand All @@ -942,7 +983,11 @@ def test_reset_cache(request_runner, async_scheduling: bool):
runner.manager.prepare_store.side_effect = lambda keys, req_context: (
generate_store_output(keys)
)
runner.run(decoded_tokens=[EOS_TOKEN_ID], expected_stored=(0, 1, 2))
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored=(0, 1, 2),
expected_flushed=(0, 1, 2) if not async_scheduling else (),
)

# Reset GPU prefix cache then start a request that loads from CPU.
# Leave the load in-flight so that reset_cache must flush it.
Expand Down
12 changes: 8 additions & 4 deletions tests/v1/kv_connector/unit/offloading_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,14 @@ def new_request(
def _parse_transfers(self):
for transfer_spec in self.offloading_spec.get_flushed_transfers():
src_spec, dst_spec = transfer_spec
assert isinstance(src_spec, GPULoadStoreSpec)

for block_id in src_spec.block_ids:
self.flushed_gpu_blocks.add(self.gpu_blocks[block_id.item()])
if isinstance(src_spec, GPULoadStoreSpec):
# store flush
for block_id in src_spec.block_ids:
self.flushed_gpu_blocks.add(self.gpu_blocks[block_id.item()])
else:
# load flush
for block_id in dst_spec.block_ids:
self.flushed_gpu_blocks.add(self.gpu_blocks[block_id.item()])

block_size_factor = self.block_size_factor

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,14 @@ def build_connector_meta(
assert self._jobs[any_jid].is_store
self._current_batch_jobs_to_flush.update(req_status.transfer_jobs)

# If all tracked requests are finished, flush all pending jobs
# (both store and load) - there might not be a future scheduler
# step to trigger their completion.
if self._req_status and all(
rs.req.is_finished() for rs in self._req_status.values()
):
self._current_batch_jobs_to_flush.update(self._jobs.keys())

meta = OffloadingConnectorMetadata(
load_jobs=self._current_batch_load_jobs,
store_jobs=self._build_store_jobs(scheduler_output),
Expand Down
Loading