Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 175 additions & 64 deletions tests/v1/kv_connector/unit/test_offloading_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,33 @@ def __repr__(self) -> str:

class MockOffloadingHandler(OffloadingHandler):
def __init__(self):
self.transfer_specs: dict[int, TransferSpec] = {}
self.completed_transfers: list[TransferResult] = []
self.completed_specs: list[TransferSpec] = []
self.waiting_jobs: set[int] = set()
self.completed_jobs: list[int] = []
self.flushed_jobs: set[int] = set()

def get_finished(self) -> list[TransferResult]:
    """Drain and return all transfer results completed since the last call."""
    results, self.completed_transfers = self.completed_transfers, []
    return results

def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
self.completed_specs.append(spec)
self.completed_transfers.append((job_id, True))
self.transfer_specs[job_id] = spec
self.waiting_jobs.add(job_id)
return True

def complete_jobs(self, job_ids: set[int]) -> None:
    """Mark each waiting job in job_ids as successfully completed.

    Jobs that are not currently waiting are ignored. A completed job is
    removed from the waiting set, recorded in completed_jobs, and reported
    as a successful (job_id, True) transfer result.
    """
    for jid in job_ids:
        if jid not in self.waiting_jobs:
            continue
        self.waiting_jobs.remove(jid)
        self.completed_jobs.append(jid)
        self.completed_transfers.append((jid, True))

def wait(self, job_ids: set[int]) -> None:
    """Record job_ids as flushed, then complete them immediately.

    The mock treats a wait() as a forced flush: every requested job is
    remembered in flushed_jobs and driven through complete_jobs().
    """
    self.flushed_jobs.update(job_ids)
    self.complete_jobs(job_ids)


class MockOffloadingSpec(OffloadingSpec):
def __init__(self, vllm_config: VllmConfig):
Expand All @@ -98,9 +112,22 @@ def get_handlers(
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler

def complete_transfers(self):
    """Complete every transfer the mock handler still has pending."""
    pending = set(self.handler.waiting_jobs)
    self.handler.complete_jobs(pending)

def get_completed_transfers(self) -> list[TransferSpec]:
specs = self.handler.completed_specs
self.handler.completed_specs = []
specs = [
self.handler.transfer_specs[job_id]
for job_id in self.handler.completed_jobs
]
self.handler.completed_jobs.clear()
return specs

def get_flushed_transfers(self):
    """Return the specs of all transfers flushed so far, draining the flushed set."""
    handler = self.handler
    flushed = [handler.transfer_specs[jid] for jid in handler.flushed_jobs]
    handler.flushed_jobs.clear()
    return flushed


Expand Down Expand Up @@ -170,12 +197,9 @@ def __init__(
# mapping (offloading address) -> gpu_block_index
self.offloaded: dict[Any, int] = {}

self.pending_loads_count: int = 0
self.pending_stores_count: int = 0
self.unsubmitted_stores_count = 0

self.completed_loads: list[TransferSummary] = []
self.completed_stores: list[TransferSummary] = []
self.flushed_gpu_block_indexes: set[int] = set()

# maps {block_id: block_offset}
self.gpu_block_index: dict[int, int] = {}
Expand All @@ -202,54 +226,60 @@ def new_request(self, token_ids: list[int]):

self.scheduler.add_request(req)

def _wait_for_transfers(self):
def _parse_transfers(self):
for transfer_spec in self.offloading_spec.get_flushed_transfers():
src_spec, dst_spec = transfer_spec
assert isinstance(src_spec, GPULoadStoreSpec)

for block_id in src_spec.block_ids:
self.flushed_gpu_block_indexes.add(
self.gpu_block_index[block_id.item()]
)

block_size_factor = self.offloaded_block_size // self.gpu_block_size

while self.pending_loads_count or self.pending_stores_count:
for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec

if isinstance(src_spec, GPULoadStoreSpec):
store = True
gpu_spec = src_spec
offload_spec = dst_spec
else:
store = False
gpu_spec = dst_spec
offload_spec = src_spec

assert isinstance(offload_spec, MockLoadStoreSpec)
assert isinstance(gpu_spec, GPULoadStoreSpec)

gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(self.gpu_block_index[block_id.item()])

# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
for block_hash in offload_spec.block_hashes:
for sub_block_idx in range(block_size_factor):
offload_addresses.append((block_hash, sub_block_idx))

if store:
assert len(gpu_block_indices) == len(offload_addresses)

self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_stores_count -= 1
else:
remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices
)
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[remainder_sub_block_count:]

self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_loads_count -= 1
for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec

if isinstance(src_spec, GPULoadStoreSpec):
store = True
gpu_spec = src_spec
offload_spec = dst_spec
else:
store = False
gpu_spec = dst_spec
offload_spec = src_spec

assert isinstance(offload_spec, MockLoadStoreSpec)
assert isinstance(gpu_spec, GPULoadStoreSpec)

gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(self.gpu_block_index[block_id.item()])

# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
for block_hash in offload_spec.block_hashes:
for sub_block_idx in range(block_size_factor):
offload_addresses.append((block_hash, sub_block_idx))

if store:
assert len(gpu_block_indices) == len(offload_addresses)

self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
else:
remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices
)
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[remainder_sub_block_count:]

self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses)
)

def _update_gpu_block_idx(self):
for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[
Expand All @@ -258,18 +288,19 @@ def _update_gpu_block_idx(self):
for block_idx, block in enumerate(blocks):
self.gpu_block_index[block.block_id] = block_idx

def _run(self, decoded_tokens: list[int]):
def _run(self, decoded_tokens: list[int], complete_transfers: bool):
"""
Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running.

Args:
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
"""

tokens_iter = iter(decoded_tokens)
token_id = next(tokens_iter, None)
while token_id is not None:
while True:
assert self.scheduler.requests

scheduler_output = self.scheduler.schedule()
Expand All @@ -279,17 +310,20 @@ def _run(self, decoded_tokens: list[int]):
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)

self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)

self.pending_stores_count += self.unsubmitted_stores_count
self.unsubmitted_stores_count = len(kv_connector_metadata.reqs_to_store)
if scheduler_output.preempted_req_ids:
self.worker_connector.handle_preemptions(
scheduler_output.preempted_req_ids
)

self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)

if scheduler_output.total_num_scheduled_tokens > 0:
self.worker_connector.wait_for_save()

if complete_transfers:
self.offloading_spec.complete_transfers()

finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids
)
Expand All @@ -300,15 +334,18 @@ def _run(self, decoded_tokens: list[int]):
reqs=self.scheduler.running,
finished_sending=finished_sending,
finished_recving=finished_recving,
token_id=token_id,
token_id=token_id or 0,
)

if self.scheduler.running:
token_id = next(tokens_iter, None)

self.scheduler.update_from_output(scheduler_output, model_runner_output)

self._wait_for_transfers()
if token_id is None:
break

self._parse_transfers()

# run one more step to update finished stored
if EOS_TOKEN_ID in decoded_tokens:
Expand All @@ -333,23 +370,28 @@ def _run(self, decoded_tokens: list[int]):
def run(
self,
decoded_tokens: list[int],
complete_transfers: bool = True,
expected_stored_gpu_block_indexes: tuple[int, ...] = (),
expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
expected_flushed_gpu_block_indexes: tuple[int, ...] = (),
):
"""
Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running.

Args:
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
expected_stored_gpu_block_indexes: GPU block indexes
that are expected to be written during the run.
expected_loaded_gpu_block_indexes: GPU block indexes
that are expected to be loaded during the run.
expected_flushed_gpu_block_indexes: GPU block indexes
that are expected to be flushed during the run.
"""

self.manager.reset_mock()
self._run(decoded_tokens)
self._run(decoded_tokens, complete_transfers)

loaded_gpu_block_indexes: set[int] = set()
for transfer in self.completed_loads:
Expand All @@ -373,6 +415,9 @@ def run(
assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes
self.completed_stores.clear()

assert set(expected_flushed_gpu_block_indexes) == self.flushed_gpu_block_indexes
self.flushed_gpu_block_indexes.clear()


@pytest.fixture
def request_runner():
Expand Down Expand Up @@ -539,3 +584,69 @@ def take_events() -> Iterable[OffloadingEvent]:
assert isinstance(event, BlockRemoved)
assert event.block_hashes == to_hashes([4, 5, 6])
assert event.medium == "B"


def test_request_preemption(request_runner):
    """Preemption flushes pending stores; resumption re-loads from offload.

    Stores two-plus offloaded blocks without completing their transfers,
    forces a preemption by exhausting the GPU block pool (which must flush
    the in-flight stores), then restores capacity and verifies the request
    resumes by loading the stored blocks back and storing the final one.
    """
    block_size_cpu = 12
    block_size_gpu = 4

    runner = request_runner(
        offloaded_block_size=block_size_cpu,
        gpu_block_size=block_size_gpu,
        num_gpu_blocks=100,
    )

    free_block_queue = runner.scheduler.kv_cache_manager.block_pool.free_block_queue
    initial_free_blocks = free_block_queue.num_free_blocks

    # 2 offloaded blocks, store all, without completing the transfers
    # gpu blocks = [0, 1, 2], [3, 4, 5]
    runner.new_request(token_ids=[0] * (2 * block_size_cpu))
    runner.manager.prepare_store.side_effect = generate_store_output
    runner.run(
        decoded_tokens=[0],
        complete_transfers=False,
    )

    # decode 2 more blocks minus 1 gpu block, storing [6, 7, 8] (no flush)
    runner.manager.prepare_store.side_effect = generate_store_output
    runner.run(
        decoded_tokens=[0] * (2 * block_size_cpu - block_size_gpu),
        complete_transfers=False,
    )

    # simulate KV cache running out of space
    free_block_queue.num_free_blocks = 0

    # request should be preempted now
    runner.run(
        decoded_tokens=[],
        complete_transfers=False,
        expected_flushed_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
        expected_stored_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
    )

    # restore KV cache space and reset GPU prefix cache
    free_block_queue.num_free_blocks = initial_free_blocks
    runner.scheduler.reset_prefix_cache()

    # request should now return from preemption:
    # re-load [0, ..., 8] from the CPU and store [9, 10, 11]
    runner.manager.lookup.return_value = 3
    runner.manager.prepare_store.side_effect = generate_store_output
    runner.run(
        decoded_tokens=[0] * block_size_gpu,
        expected_loaded_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
    )

    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(9, 10, 11),
    )
12 changes: 12 additions & 0 deletions tests/v1/kv_offload/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def get_finished(self) -> list[TransferResult]:
del self.transfers[job_id]
return finished

def wait(self, job_ids: set[int]) -> None:
    """Verify that every tracked job in job_ids has already finished."""
    for jid in job_ids:
        transfer = self.transfers.get(jid)
        if transfer:
            assert transfer.finished


class OffloadingHandler2To1(OffloadingHandler):
def __init__(self):
Expand All @@ -84,6 +90,12 @@ def get_finished(self) -> list[TransferResult]:
del self.transfers[job_id]
return finished

def wait(self, job_ids: set[int]) -> None:
    """Check that each job in job_ids that is still tracked has finished."""
    for job_id in job_ids:
        if transfer := self.transfers.get(job_id):
            assert transfer.finished


def test_offloading_worker():
"""
Expand Down
Loading