409 changes: 231 additions & 178 deletions tests/v1/core/test_scheduler.py

Large diffs are not rendered by default.

17 changes: 10 additions & 7 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -44,7 +44,7 @@
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
from vllm.v1.request import RequestGenerationState

logger = init_logger(__name__)

@@ -208,15 +208,15 @@ def get_finished(
@abstractmethod
def get_num_new_matched_tokens(
self,
request: "Request",
request: "RequestGenerationState",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.

Args:
request (Request): the request object.
request (RequestGenerationState): the request generation state object.
num_computed_tokens (int): the number of locally
computed tokens for this request

@@ -231,7 +231,7 @@ def get_num_new_matched_tokens(
pass

@abstractmethod
def update_state_after_alloc(self, request: "Request",
def update_state_after_alloc(self, request: "RequestGenerationState",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
@@ -244,7 +244,7 @@ def update_state_after_alloc(self, request: "Request",
are allocated, after the load/transfer is complete.

Args:
request (Request): the request object.
request (RequestGenerationState): the request generation state object.
blocks (KVCacheBlocks): the blocks allocated for the request.
num_external_tokens (int): the number of tokens that will be
loaded from the external KV cache.
@@ -267,8 +267,11 @@ def build_connector_meta(

def request_finished(
self,
request: "Request",
block_ids: list[int],
request_id: str,
request_status: "RequestStatus",
kv_transfer_params: Optional[dict[str, Any]],
num_computed_tokens: int,
blocks: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
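For orientation, here is a minimal sketch of how a scheduler-side caller might use the new field-based `request_finished` signature above; the `free_blocks` callback and the surrounding function are hypothetical illustration, not part of this change:

```python
from typing import Any, Callable, Optional


def on_request_finished(
    connector,                      # any KVConnectorBase_V1 implementation
    request_id: str,
    request_status,                 # a vllm.v1.request.RequestStatus value
    kv_transfer_params: Optional[dict[str, Any]],
    num_computed_tokens: int,
    block_ids: list[int],
    free_blocks: Callable[[str, list[int]], None],
) -> Optional[dict[str, Any]]:
    # The connector now receives plain fields rather than the full Request
    # object, so this can be called even after the Request itself is gone.
    delay_free, xfer_params = connector.request_finished(
        request_id,
        request_status,
        kv_transfer_params,
        num_computed_tokens,
        block_ids,
    )
    if not delay_free:
        # No asynchronous transfer is pending, so the blocks can be freed now.
        free_blocks(request_id, block_ids)
    return xfer_params
```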
22 changes: 15 additions & 7 deletions vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -18,7 +18,7 @@
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.request import Request
from vllm.v1.request import RequestStatus

logger = init_logger(__name__)

@@ -176,13 +176,19 @@ def build_connector_meta(

def request_finished(
self,
request: "Request",
request_id: str,
request_status: "RequestStatus",
kv_transfer_params: Optional[dict[str, Any]],
num_computed_tokens: int,
blocks: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
async_saves = 0
kv_txfer_params = None

for c in self._connectors:
async_save, txfer_params = c.request_finished(request, blocks)
async_save, txfer_params = c.request_finished(
request_id, request_status, kv_transfer_params, num_computed_tokens, blocks
)
if async_save:
async_saves += 1
if txfer_params is not None:
@@ -192,10 +198,12 @@ def request_finished(
raise RuntimeError(
"Only one connector can produce KV transfer params")
kv_txfer_params = txfer_params
if async_saves > 1:
self._extra_async_saves[request.request_id] = async_saves - 1

if async_saves > 1 and request_id:
self._extra_async_saves[request_id] = async_saves - 1

# Clean up other state for this request.
self._requests_to_connector.pop(request.request_id, None)
# Clean up other state for this request
if request_id:
self._requests_to_connector.pop(request_id, None)

return async_saves > 0, kv_txfer_params
32 changes: 21 additions & 11 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -32,7 +32,7 @@
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
from vllm.v1.request import RequestGenerationState

Transfer = tuple[int, float] # (xfer_handle, start_time)
GET_META_MSG = b"get_meta_msg"
@@ -133,11 +133,15 @@ def build_connector_meta(

def request_finished(
self,
request: "Request",
block_ids: list[int],
request_id: str,
request_status: "RequestStatus",
kv_transfer_params: Optional[dict[str, Any]],
num_computed_tokens: int,
blocks: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
return self.connector_scheduler.request_finished(
request_id, request_status, kv_transfer_params, num_computed_tokens, blocks)

############################################################
# Worker Side Methods
@@ -280,26 +284,32 @@ def build_connector_meta(

def request_finished(
self,
request: "Request",
block_ids: list[int],
request_id: str,
request_status: "RequestStatus",
kv_transfer_params: Optional[dict[str, Any]],
num_computed_tokens: int,
blocks: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""

params = request.kv_transfer_params
params = kv_transfer_params
logger.debug(
"NIXLConnector request_finished, request_status=%s, "
"kv_transfer_params=%s", request.status, params)
"kv_transfer_params=%s", request_status, params)

if (params is None or not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
or request_status != RequestStatus.FINISHED_LENGTH_CAPPED):
return False, None

# Get computed blocks.
all_full = request.num_computed_tokens % self.block_size == 0
computed_block_ids = block_ids if all_full else block_ids[:-1]
# TODO(lucas): there is an edge case here where this request gets
# preempted by the next scheduler step (which may be past the stop token
# due to async scheduling).
all_full = num_computed_tokens % self.block_size == 0
computed_block_ids = blocks if all_full else blocks[:-1]

# If prompt < block_size, no xfer so free blocks immediately.
delay_free_blocks = len(computed_block_ids) > 0
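A small worked sketch of the block-trimming arithmetic above; the function name and example values are illustrative only:

```python
def trim_to_computed_blocks(block_ids: list[int],
                            num_computed_tokens: int,
                            block_size: int) -> list[int]:
    # A partially filled last block holds incomplete KV data, so it is
    # dropped from the set of blocks offered for transfer.
    all_full = num_computed_tokens % block_size == 0
    return block_ids if all_full else block_ids[:-1]


# Three blocks of 16 tokens, but only 40 tokens computed: the half-filled
# third block is excluded.
assert trim_to_computed_blocks([7, 8, 9], 40, 16) == [7, 8]
assert trim_to_computed_blocks([7, 8, 9], 48, 16) == [7, 8, 9]
```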
@@ -78,7 +78,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {}
self._requests_need_load: dict[str, "RequestGenerationState"] = {}
transfer_config = vllm_config.kv_transfer_config
self._storage_path = transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp")
@@ -224,7 +224,7 @@ def wait_for_save(self):

def get_num_new_matched_tokens(
self,
request: "Request",
request: "RequestGenerationState",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
@@ -255,11 +255,11 @@ def get_num_new_matched_tokens(
# Now, first num_tokens_to_check tokens are hit, we need to prepare
# the metadata for the worker connector to correctly load the KV
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
len(request.params.prompt_token_ids) - 1, self._block_size)

return num_tokens_to_check - num_computed_tokens, False

def update_state_after_alloc(self, request: "Request",
def update_state_after_alloc(self, request: "RequestGenerationState",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
@@ -314,8 +314,11 @@ def build_connector_meta(
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[cached_req.req_id]
total_tokens = (len(cached_req.new_token_ids) +
cached_req.num_computed_tokens)
# NOTE: new_token_ids is not available in CachedRequestData,
# use num_computed_tokens as the total since this is a resumed request
total_tokens = cached_req.num_computed_tokens

# Use all_token_ids from generation state (includes prompt + generated tokens)
token_ids = request.all_token_ids[:total_tokens]

# NOTE(rob): For resumed req, new_block_ids is all
@@ -338,14 +341,14 @@

def _found_match_for_request(
self,
request: "Request",
request: "RequestGenerationState",
) -> bool:
"""Check if the cache is hit for the request.
"""
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
len(request.params.prompt_token_ids) - 1, self._block_size)
foldername = self._generate_foldername_debug(torch.tensor(
request.prompt_token_ids)[:num_tokens_to_check],
request.params.prompt_token_ids)[:num_tokens_to_check],
create_folder=False)
return os.path.exists(foldername)

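The cache-hit check above relies on `align_to_block_size`; as a rough sketch (an assumption about its behaviour, not the actual helper), it rounds a token count down to the last full block boundary:

```python
def align_down_to_block_size(num_tokens: int, block_size: int) -> int:
    # Hypothetical stand-in: keep only complete blocks' worth of tokens.
    return (num_tokens // block_size) * block_size


# With a block size of 16, a 37-token prefix can only be matched against
# the first two full blocks (32 tokens).
assert align_down_to_block_size(37, 16) == 32
assert align_down_to_block_size(32, 16) == 32
```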
2 changes: 1 addition & 1 deletion vllm/reasoning/abs_reasoning_parsers.py
@@ -36,7 +36,7 @@ def vocab(self) -> dict[str, int]:
return self.model_tokenizer.get_vocab()

@abstractmethod
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
def is_reasoning_end(self, generated_input_ids: Sequence[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.

4 changes: 2 additions & 2 deletions vllm/reasoning/deepseek_r1_reasoning_parser.py
@@ -44,8 +44,8 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")

def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids
def is_reasoning_end(self, generated_token_ids: list[int]) -> bool:
return self.end_token_id in generated_token_ids

def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
4 changes: 2 additions & 2 deletions vllm/reasoning/qwen3_reasoning_parser.py
@@ -44,8 +44,8 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
"Qwen3 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")

def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids
def is_reasoning_end(self, generated_token_ids: list[int]) -> bool:
return self.think_end_token_id in generated_token_ids

def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
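Both reasoning parsers now scan generated tokens only, never the prompt; a toy sketch of the membership check, with a made-up end-of-think token id:

```python
THINK_END_TOKEN_ID = 12345  # hypothetical id for the "</think>" token


def is_reasoning_end(generated_token_ids: list[int]) -> bool:
    # True once the model has emitted the end-of-think marker.
    return THINK_END_TOKEN_ID in generated_token_ids


assert not is_reasoning_end([1, 2, 3])
assert is_reasoning_end([1, 2, THINK_END_TOKEN_ID, 4])
```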
22 changes: 15 additions & 7 deletions vllm/v1/core/block_pool.py
@@ -11,7 +11,7 @@
FreeKVCacheBlockQueue, KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens)
from vllm.v1.request import Request
from vllm.v1.request import RequestGenerationState

logger = init_logger(__name__)

@@ -95,7 +95,8 @@ def get_cached_block(

def cache_full_blocks(
self,
request: Request,
request: "RequestParams",
token_ids: Optional[list[int]],
blocks: list[KVCacheBlock],
block_hashes: list[BlockHash],
num_cached_blocks: int,
@@ -141,6 +142,9 @@ def cache_full_blocks(
parent_block_hash = prev_block_hash_value
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
else None)

num_new_blocks_cached = 0

for i, blk in enumerate(new_full_blocks):
assert blk.block_hash is None

@@ -152,13 +156,13 @@
# single_type_managers with the same block_size.
# In this case we simply reuse the block hash.
block_hash = new_block_hashes[i]
else:
elif token_ids is not None:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
blk_idx = num_cached_blocks + i
start_token_idx = blk_idx * block_size
end_token_idx = (blk_idx + 1) * block_size
block_tokens = request.all_token_ids[
block_tokens = token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == block_size, (
f"Expected {block_size} tokens, got "
@@ -175,6 +179,10 @@
block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
block_tokens, extra_keys)
block_hashes.append(block_hash)
else:
# Cannot compute the block hash since the tokens are
# not available, so we skip this block.
continue

# Update and add the full block to the cache.
block_hash_with_group_id = BlockHashWithGroupId(
@@ -185,15 +193,15 @@
if new_hashes is not None:
new_hashes.append(block_hash.hash_value)
prev_block_hash_value = block_hash.hash_value
num_new_blocks_cached += 1

if self.enable_kv_cache_events:
self.kv_event_queue.append(
BlockStored(
block_hashes=new_hashes,
parent_block_hash=parent_block_hash,
token_ids=request.
all_token_ids[num_cached_blocks *
block_size:num_full_blocks * block_size],
token_ids=request.all_token_ids[num_cached_blocks *
block_size:num_full_blocks * block_size],
block_size=block_size,
lora_id=request.lora_request.id
if request.lora_request else None,
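A small sketch of the per-block token slicing that `cache_full_blocks` performs when `token_ids` is available (hashing itself is elided; all names here are illustrative):

```python
from typing import Optional


def block_token_slices(token_ids: Optional[list[int]],
                       num_cached_blocks: int,
                       num_full_blocks: int,
                       block_size: int) -> list[list[int]]:
    # Mirrors the loop above: newly full block blk_idx covers tokens
    # [blk_idx * block_size, (blk_idx + 1) * block_size). When token_ids is
    # None, no hashes can be computed and the blocks are simply not cached.
    if token_ids is None:
        return []
    slices = []
    for blk_idx in range(num_cached_blocks, num_full_blocks):
        slices.append(token_ids[blk_idx * block_size:(blk_idx + 1) * block_size])
    return slices


# Two blocks already cached, four full blocks total, block size of 4.
tokens = list(range(20))
assert block_token_slices(tokens, 2, 4, 4) == [[8, 9, 10, 11], [12, 13, 14, 15]]
```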