Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/v1/core/test_async_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def test_prefix_caching_for_multi_turn():
req._all_token_ids = req.prompt_token_ids.copy()
req.all_token_ids = ConstantList(req._all_token_ids)
req.block_hashes = []
req.block_hashes = req.get_hash_new_full_blocks()
req.update_block_hashes()

# Schedule the next-turn requests.
for req in next_turn_requests:
Expand Down
6 changes: 2 additions & 4 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,10 +982,8 @@ def _update_request_as_session(

session._all_token_ids.extend(update.prompt_token_ids or ())
session.prompt_token_ids.extend(update.prompt_token_ids or ())
# Update block hashes for the new tokens
# (mirrors Request.append_output_token_ids)
if session.get_hash_new_full_blocks is not None:
session.block_hashes.extend(session.get_hash_new_full_blocks())
# Update block hashes for the new tokens.
session.update_block_hashes()
session.num_prompt_tokens = len(session.prompt_token_ids)
session.arrival_time = update.arrival_time
session.sampling_params = update.sampling_params
Expand Down
18 changes: 11 additions & 7 deletions vllm/v1/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from collections import deque
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any

import torch
Expand Down Expand Up @@ -164,10 +163,11 @@ def __init__(
self.num_external_computed_tokens = 0

self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
if block_hasher is not None:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
# Store the block hasher without binding self to avoid creating a
# reference cycle (Request -> partial -> Request) that prevents
# immediate garbage collection via reference counting.
self._block_hasher: Callable[[Request], list[BlockHash]] | None = block_hasher
self.update_block_hashes()

self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()

Expand Down Expand Up @@ -212,8 +212,12 @@ def append_output_token_ids(
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)

if self.get_hash_new_full_blocks is not None:
self.block_hashes.extend(self.get_hash_new_full_blocks())
self.update_block_hashes()

def update_block_hashes(self) -> None:
    """Hash any newly completed full blocks and append the results.

    No-op when no block hasher was configured for this request; otherwise
    the hasher is invoked with the request itself and its output is
    appended to ``self.block_hashes``.
    """
    hasher = self._block_hasher
    if hasher is None:
        return
    new_hashes = hasher(self)
    self.block_hashes.extend(new_hashes)

@property
def use_structured_output(self) -> bool:
Expand Down