Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional
from typing import Literal, Optional, overload

from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
Expand Down Expand Up @@ -37,7 +37,24 @@ def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
tuple(blk1 + blk2
for blk1, blk2 in zip(self.blocks, other.blocks)))

def get_block_ids(self) -> tuple[list[int], ...]:
@overload
def get_block_ids(
self,
allow_none: Literal[False] = False,
) -> tuple[list[int], ...]:
...

@overload
def get_block_ids(
self,
allow_none: Literal[True] = True,
) -> Optional[tuple[list[int], ...]]:
...

def get_block_ids(
self,
allow_none: bool = False,
):
"""
Converts the KVCacheBlocks instance to block_ids.

Expand All @@ -46,6 +63,8 @@ def get_block_ids(self) -> tuple[list[int], ...]:
* the outer tuple corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group
"""
if allow_none and all(len(group) == 0 for group in self.blocks):
return None
return tuple([blk.block_id for blk in group] for group in self.blocks)

def get_unhashed_block_ids(self) -> list[int]:
Expand Down Expand Up @@ -348,10 +367,13 @@ def take_events(self) -> list[KVCacheEvent]:
"""
return self.block_pool.take_events()

def get_blocks(self, request_id: str) -> KVCacheBlocks:
"""Get the blocks of a request."""
return KVCacheBlocks(self.coordinator.get_blocks(request_id))

def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
"""Get the block ids of a request."""
return KVCacheBlocks(
self.coordinator.get_blocks(request_id)).get_block_ids()
return self.get_blocks(request_id).get_block_ids()

def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled."""
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/core/sched/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class CachedRequestData:
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty.
new_token_ids: list[list[int]]
new_block_ids: list[tuple[list[int], ...]]
new_block_ids: list[Optional[tuple[list[int], ...]]]
num_computed_tokens: list[int]

@property
Expand Down
24 changes: 12 additions & 12 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
Expand Down Expand Up @@ -185,7 +185,7 @@ def schedule(self) -> SchedulerOutput:
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}

req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
Expand Down Expand Up @@ -288,8 +288,7 @@ def schedule(self) -> SchedulerOutput:
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
req_to_new_blocks[request.request_id] = new_blocks
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
Expand Down Expand Up @@ -496,8 +495,8 @@ def schedule(self) -> SchedulerOutput:

if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
req_to_new_blocks[request.request_id] = (
self.kv_cache_manager.get_blocks(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
Expand Down Expand Up @@ -546,16 +545,16 @@ def schedule(self) -> SchedulerOutput:
)
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
req_to_new_block_ids[req.request_id])
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_block_ids,
req_to_new_blocks,
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
Expand Down Expand Up @@ -628,11 +627,11 @@ def _make_cached_request_data(
resumed_reqs: list[Request],
num_scheduled_tokens: dict[str, int],
spec_decode_tokens: dict[str, list[int]],
req_to_new_block_ids: dict[str, tuple[list[int], ...]],
req_to_new_blocks: dict[str, KVCacheBlocks],
) -> CachedRequestData:
req_ids: list[str] = []
new_token_ids: list[list[int]] = []
new_block_ids: list[tuple[list[int], ...]] = []
new_block_ids: list[Optional[tuple[list[int], ...]]] = []
num_computed_tokens: list[int] = []

use_connector = self.connector is not None
Expand All @@ -655,7 +654,8 @@ def _make_cached_request_data(
# out of bounds errors. TODO: Remove this once the KVConnector
# is updated to handle token IDs properly.
new_token_ids.append([])
new_block_ids.append(req_to_new_block_ids[req_id])
new_block_ids.append(
req_to_new_blocks[req_id].get_block_ids(allow_none=True))
num_computed_tokens.append(req.num_computed_tokens)
# Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list.
Expand Down
14 changes: 9 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,11 +574,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:

# Update the block IDs.
if not resumed_from_preemption:
# Append the new blocks to the existing block IDs.
for block_ids, new_ids in zip(req_state.block_ids,
new_block_ids):
block_ids.extend(new_ids)
if new_block_ids is not None:
# Append the new blocks to the existing block IDs.
for block_ids, new_ids in zip(req_state.block_ids,
new_block_ids):
block_ids.extend(new_ids)
else:
assert new_block_ids is not None
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids
Expand All @@ -594,7 +596,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index)
if new_block_ids is not None:
self.input_batch.block_table.append_row(
new_block_ids, req_index)

# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
Expand Down
14 changes: 9 additions & 5 deletions vllm/v1/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
if not resumed_from_preemption:
# Append the new blocks to the existing block IDs.
for block_ids, new_ids in zip(req_state.block_ids,
new_block_ids):
block_ids.extend(new_ids)
if new_block_ids is not None:
# Append the new blocks to the existing block IDs.
for block_ids, new_ids in zip(req_state.block_ids,
new_block_ids):
block_ids.extend(new_ids)
else:
assert new_block_ids is not None
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids
Expand All @@ -438,7 +440,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index)
if new_block_ids is not None:
self.input_batch.block_table.append_row(
new_block_ids, req_index)

# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
Expand Down