diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py
index e6a69dc8a94..23097bf2a08 100644
--- a/tests/v1/core/test_single_type_kv_cache_manager.py
+++ b/tests/v1/core/test_single_type_kv_cache_manager.py
@@ -21,13 +21,23 @@
 pytestmark = pytest.mark.cpu_test
 
 
-def get_sliding_window_manager(sliding_window_spec, block_pool):
-    return SlidingWindowManager(sliding_window_spec, block_pool, kv_cache_group_id=0)
+def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True):
+    return SlidingWindowManager(
+        sliding_window_spec,
+        block_pool,
+        enable_caching=enable_caching,
+        kv_cache_group_id=0,
+    )
 
 
-def get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool):
+def get_chunked_local_attention_manager(
+    chunked_local_attention_spec, block_pool, enable_caching=True
+):
     return ChunkedLocalAttentionManager(
-        chunked_local_attention_spec, block_pool, kv_cache_group_id=0
+        chunked_local_attention_spec,
+        block_pool,
+        enable_caching=enable_caching,
+        kv_cache_group_id=0,
     )
 
 
@@ -332,11 +342,53 @@ def test_get_num_blocks_to_allocate():
     ]
 
     assert (
-        manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20
+        manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
+        == 20
     )
     assert (
-        manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15
+        manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
+        == 15
+    )
+
+
+def test_evictable_cached_blocks_not_double_allocated():
+    block_size = 2
+    sliding_window_length = 2 * block_size
+    sliding_window_spec = SlidingWindowSpec(
+        block_size=block_size,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        sliding_window=sliding_window_length,
+    )
+
+    block_pool = BlockPool(
+        num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+    )
+    manager = get_sliding_window_manager(sliding_window_spec, block_pool)
+
+    request_id = "req"
+    evictable_block = block_pool.blocks[1]  # ref_cnt == 0, eviction candidate
+
+    num_blocks_to_allocate = manager.get_num_blocks_to_allocate(
+        request_id=request_id,
+        num_tokens=2 * block_size,
+        new_computed_blocks=[evictable_block],
+        total_computed_tokens=block_size,
+    )
+    # Free capacity check should count evictable cached blocks, but allocation
+    # should only allocate the truly new block.
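+    # Expected: 2 = 1 truly new block + 1 evictable cached block, since the
+    # evictable block leaves the free queue once it is touched.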
+    assert num_blocks_to_allocate == 2
+
+    manager.allocate_new_computed_blocks(
+        request_id,
+        [evictable_block],
+        num_local_computed_tokens=block_size,
+        num_external_computed_tokens=0,
+    )
+    new_blocks = manager.allocate_new_blocks(request_id, num_tokens=4)
+    assert len(new_blocks) == 1
+    assert len(manager.req_to_blocks[request_id]) == 2
 
 
 def test_chunked_local_attention_get_num_blocks_to_allocate():
@@ -359,8 +411,10 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
     ]
 
     assert (
-        manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20
+        manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
+        == 20
     )
     assert (
-        manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15
+        manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
+        == 15
     )
diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
index a6f06d1b16a..148ff632a42 100644
--- a/vllm/v1/core/block_pool.py
+++ b/vllm/v1/core/block_pool.py
@@ -254,6 +254,10 @@ def cache_full_blocks(
             [] if self.enable_kv_cache_events else None
         )
         for i, blk in enumerate(new_full_blocks):
+            # Some blocks may be null blocks when sparse attention (e.g.,
+            # sliding window attention) is enabled. We skip null blocks here.
+            if blk.is_null:
+                continue
             assert blk.block_hash is None
             block_hash = new_block_hashes[i]
 
@@ -361,7 +365,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
         )
         return True
 
-    def touch(self, blocks: tuple[Sequence[KVCacheBlock], ...]) -> None:
+    def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
         """Touch a block increases its reference count by 1, and may remove
         the block from the free queue. This is used when a block is hit by
         another request with the same prefix.
@@ -369,15 +373,14 @@ def touch(self, blocks: tuple[Sequence[KVCacheBlock], ...]) -> None:
         Args:
             blocks: A list of blocks to touch.
         """
-        for blocks_per_group in blocks:
-            for block in blocks_per_group:
-                # ref_cnt=0 means this block is in the free list (i.e. eviction
-                # candidate), so remove it.
-                if block.ref_cnt == 0 and not block.is_null:
-                    self.free_block_queue.remove(block)
-                block.ref_cnt += 1
-                if self.metrics_collector:
-                    self.metrics_collector.on_block_accessed(block)
+        for block in blocks:
+            # ref_cnt=0 means this block is in the free list (i.e. eviction
+            # candidate), so remove it.
+            if block.ref_cnt == 0 and not block.is_null:
+                self.free_block_queue.remove(block)
+            block.ref_cnt += 1
+            if self.metrics_collector:
+                self.metrics_collector.on_block_accessed(block)
 
     def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
         """Free a list of blocks. The blocks should be ordered by their
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index 4b09b76c1c5..1d00873e606 100644
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -60,6 +60,7 @@ def __init__(
             get_manager_for_kv_cache_spec(
                 kv_cache_spec=kv_cache_group.kv_cache_spec,
                 block_pool=self.block_pool,
+                enable_caching=enable_caching,
                 kv_cache_group_id=i,
                 dcp_world_size=dcp_world_size,
                 pcp_world_size=pcp_world_size,
@@ -73,6 +74,7 @@ def get_num_blocks_to_allocate(
         num_tokens: int,
         new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
         num_encoder_tokens: int,
+        total_computed_tokens: int,
     ) -> int:
         """
         Get the number of blocks needed to be allocated for the request.
@@ -85,9 +87,10 @@
             prefix caching.
             num_encoder_tokens: The number of encoder tokens for allocating
                 blocks for cross-attention.
+            total_computed_tokens: The total number of computed tokens,
+                including both local and external computed tokens.
 
         Returns:
-            The number of blocks.
+            The number of blocks to allocate.
         """
         num_blocks_to_allocate = 0
         for i, manager in enumerate(self.single_type_managers):
@@ -95,30 +98,48 @@
                 # For cross-attention, we issue a single static allocation
                 # of blocks based on the number of encoder input tokens.
                 num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
-                    request_id, num_encoder_tokens, []
+                    request_id, num_encoder_tokens, [], 0
                 )
             else:
                 num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
-                    request_id, num_tokens, new_computed_blocks[i]
+                    request_id,
+                    num_tokens,
+                    new_computed_blocks[i],
+                    total_computed_tokens,
                 )
         return num_blocks_to_allocate
 
-    def save_new_computed_blocks(
-        self, request_id: str, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...]
+    def allocate_new_computed_blocks(
+        self,
+        request_id: str,
+        new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
+        num_local_computed_tokens: int,
+        num_external_computed_tokens: int,
     ) -> None:
         """
-        Add the new computed blocks to the request.
+        Add the new computed blocks to the request. Optionally allocate new
+        blocks for external computed tokens (if any).
 
         Args:
            request_id: The request ID.
            new_computed_blocks: The new computed blocks just hitting the
                prefix cache.
+           num_local_computed_tokens: The number of local computed tokens.
+           num_external_computed_tokens: The number of external computed tokens.
        """
        for i, manager in enumerate(self.single_type_managers):
-            manager.save_new_computed_blocks(request_id, new_computed_blocks[i])
+            manager.allocate_new_computed_blocks(
+                request_id,
+                new_computed_blocks[i],
+                num_local_computed_tokens,
+                num_external_computed_tokens,
+            )
 
     def allocate_new_blocks(
-        self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0
+        self,
+        request_id: str,
+        num_tokens: int,
+        num_encoder_tokens: int = 0,
     ) -> tuple[list[KVCacheBlock], ...]:
         """
         Allocate new blocks for the request to give it at least `num_tokens`
@@ -184,17 +205,20 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
             for manager in self.single_type_managers
         ]
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
+    def remove_skipped_blocks(
+        self, request_id: str, total_computed_tokens: int
+    ) -> None:
         """
         Remove the blocks that are no longer needed from `blocks` and replace
         the removed blocks with null_block.
 
         Args:
             request_id: The request ID.
-            num_computed_tokens: The number of tokens that have been computed.
+            total_computed_tokens: The total number of computed tokens, including
+                local computed tokens and external computed tokens.
""" for manager in self.single_type_managers: - manager.remove_skipped_blocks(request_id, num_computed_tokens) + manager.remove_skipped_blocks(request_id, total_computed_tokens) def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: """ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 13086a66f6e..2197107c1fc 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -210,6 +210,7 @@ def allocate_slots( num_new_computed_tokens: int = 0, new_computed_blocks: KVCacheBlocks | None = None, num_lookahead_tokens: int = 0, + num_external_computed_tokens: int = 0, delay_cache_blocks: bool = False, num_encoder_tokens: int = 0, ) -> KVCacheBlocks | None: @@ -217,16 +218,16 @@ def allocate_slots( Args: request: The request to allocate slots. - num_new_tokens: The number of tokens to allocate, including external - tokens. Note that this does not include tokens that have - already been computed locally (i.e. new_computed_blocks). + num_new_tokens: The number of new tokens to be allocated and computed. num_new_computed_tokens: The number of new computed tokens just hitting the prefix caching, excluding external tokens. new_computed_blocks: The cached blocks for the above new computed - tokens. + tokens, grouped as a tuple by kv cache groups. num_lookahead_tokens: The number of speculative tokens to allocate. This is used by spec decode proposers with kv-cache such as eagle. + num_external_computed_tokens: The number of tokens that their + KV caches are not cached by vLLM but cached by the connector. delay_cache_blocks: Whether to skip caching the blocks. This is used by P/D when allocating blocks used in a KV transfer which will complete in a future step. @@ -236,29 +237,81 @@ def allocate_slots( Blocks layout: ``` - ----------------------------------------------------------------------- - | < computed > | < new computed > | < new > | < pre-allocated > | - ----------------------------------------------------------------------- - | < required > | - -------------------------------------------------- - | < full > | - ------------------------------------------------ - | | - -------------- + ---------------------------------------------------------------------- + | < comp > | < new_comp > | < ext_comp > | < new > | < lookahead > | + ---------------------------------------------------------------------- + | < to be computed > | + ---------------------------------------------------------------------- + | < to be allocated > | + ---------------------------------------------------------------------- + | < to be cached (roughly, | + | details below)> | + ---------------------------------------------------------------------- + | Prefix-cached tokens from either vLLM | + | or connector. Can be safely removed if | + | they are outside sliding window. | + ---------------------------------------------------------------------- + | < cached by vLLM > | not cached by | + | vLLM, but | + | ref_cnt | ref_cnt not | cached by | + | increased| increased yet| connector | + ---------------------------------------------------------------------- ``` - The following *_blocks are illustrated in this layout. 
+
+        Abbreviations:
+
+        ```
+        comp = request.num_computed_tokens
+        new_comp = num_new_computed_tokens
+                 = len(new_computed_blocks) * block_size
+        ext_comp = num_external_computed_tokens, cached by the connector
+        new = num_new_tokens, including unverified draft tokens
+        lookahead = num_lookahead_tokens
+        ```
+
+        NOTE: for new tokens which include both verified and unverified draft
+        tokens, we only cache the verified tokens (by capping the number at
+        `request.num_tokens`).
+
+        The allocation has three stages:
+        - Free unnecessary blocks in `comp` and check
+          if we have sufficient free blocks (return None if not).
+        - Handle prefix tokens (`comp + new_comp + ext_comp`):
+            - Free unnecessary blocks (e.g. outside sliding window)
+            - Allocate new blocks for `ext_comp` tokens inside
+              sliding window
+        - Allocate new blocks for tokens to be computed (`new + lookahead`)
 
         Returns:
             A list of new allocated blocks.
         """
-        if num_new_tokens == 0:
-            raise ValueError("num_new_tokens must be greater than 0")
+        # When loading KV data asynchronously, we may have zero new tokens to
+        # compute while still allocating slots for externally computed tokens.
+        if num_new_tokens == 0 and num_external_computed_tokens == 0:
+            raise ValueError(
+                "num_new_tokens must be greater than 0 when there are no "
+                "external computed tokens"
+            )
 
         if new_computed_blocks is not None:
             new_computed_block_list = new_computed_blocks.blocks
         else:
             new_computed_block_list = self.empty_kv_cache_blocks.blocks
 
+        # The number of local computed tokens is the number of tokens that
+        # have already been computed plus the new prefix cache hits.
+        num_local_computed_tokens = (
+            request.num_computed_tokens + num_new_computed_tokens
+        )
+        total_computed_tokens = min(
+            num_local_computed_tokens + num_external_computed_tokens,
+            self.max_model_len,
+        )
+        num_tokens_need_slot = min(
+            total_computed_tokens + num_new_tokens + num_lookahead_tokens,
+            self.max_model_len,
+        )
+
         # Free the blocks that are skipped during the attention computation
         # (e.g., tokens outside the sliding window).
         # We can do this even if we cannot schedule this request due to
@@ -266,15 +319,7 @@
         # Should call this function before allocating new blocks to reduce
         # the number of evicted blocks.
         self.coordinator.remove_skipped_blocks(
-            request.request_id, request.num_computed_tokens
-        )
-
-        # The number of computed tokens is the number of computed tokens plus
-        # the new prefix caching hits
-        num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens
-        num_tokens_need_slot = min(
-            num_computed_tokens + num_new_tokens + num_lookahead_tokens,
-            self.max_model_len,
+            request.request_id, total_computed_tokens
         )
 
         num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
@@ -282,25 +327,25 @@
             num_tokens=num_tokens_need_slot,
             new_computed_blocks=new_computed_block_list,
             num_encoder_tokens=num_encoder_tokens,
+            total_computed_tokens=num_local_computed_tokens
+            + num_external_computed_tokens,
         )
 
         if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
             # Cannot allocate new blocks
             return None
 
-        # Touch the computed blocks to make sure they won't be evicted.
-        if self.enable_caching:
-            self.block_pool.touch(new_computed_block_list)
-        else:
-            assert not any(new_computed_block_list), (
-                "Computed blocks should be empty when prefix caching is disabled"
-            )
-
-        if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
+        if (
+            new_computed_block_list is not self.empty_kv_cache_blocks.blocks
+            or num_external_computed_tokens > 0
+        ):
             # Append the new computed blocks to the request blocks until now to
             # avoid the case where the new blocks cannot be allocated.
-            self.coordinator.save_new_computed_blocks(
-                request.request_id, new_computed_block_list
+            self.coordinator.allocate_new_computed_blocks(
+                request_id=request.request_id,
+                new_computed_blocks=new_computed_block_list,
+                num_local_computed_tokens=num_local_computed_tokens,
+                num_external_computed_tokens=num_external_computed_tokens,
             )
 
         new_blocks = self.coordinator.allocate_new_blocks(
@@ -312,12 +357,14 @@
 
         if not self.enable_caching or delay_cache_blocks:
             return self.create_kv_cache_blocks(new_blocks)
 
-        # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
-        # num_new_tokens, but must exclude "non-committable" tokens (e.g.,
-        # draft tokens that could be rejected). Therefore, we cap the number
-        # at `request.num_tokens`, ensuring only "finalized" tokens are cached.
+        # NOTE(woosuk): We want to commit (cache) up to num_local_computed_tokens
+        # + num_external_computed_tokens + num_new_tokens, but must exclude
+        # "non-committable" tokens (e.g., draft tokens that could be rejected).
+        # Therefore, we cap the number at `request.num_tokens`, ensuring only
+        # "finalized" tokens are cached.
         num_tokens_to_cache = min(
-            num_computed_tokens + num_new_tokens, request.num_tokens
+            total_computed_tokens + num_new_tokens,
+            request.num_tokens,
         )
         self.coordinator.cache_blocks(request, num_tokens_to_cache)
 
@@ -333,6 +380,19 @@ def free(self, request: Request) -> None:
         """
         self.coordinator.free(request.request_id)
 
+    def remove_skipped_blocks(
+        self, request_id: str, total_computed_tokens: int
+    ) -> None:
+        """Remove the blocks that are no longer needed from `blocks` and replace
+        the removed blocks with null_block.
+
+        Args:
+            request_id: The request ID.
+            total_computed_tokens: The total number of computed tokens, including
+                local computed tokens and external computed tokens.
+        """
+        self.coordinator.remove_skipped_blocks(request_id, total_computed_tokens)
+
     def evict_blocks(self, block_ids: set[int]) -> None:
         """evict blocks from the prefix cache by their block IDs.
 
@@ -408,7 +468,13 @@ def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
         return self.get_blocks(request_id).get_block_ids()
 
     def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
-        """Cache the blocks for the request, if enabled."""
+        """Cache the blocks for the request, if enabled.
+
+        Args:
+            request: The request whose blocks should be cached.
+            num_computed_tokens: The number of computed tokens, including tokens
+                that are already cached and tokens to be cached.
+ """ if self.enable_caching: self.coordinator.cache_blocks(request, num_computed_tokens) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index da8339558b1..3e617628974 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -587,10 +587,11 @@ def schedule(self) -> SchedulerOutput: new_blocks = self.kv_cache_manager.allocate_slots( request, - num_new_tokens + num_external_computed_tokens, - num_new_local_computed_tokens, - new_computed_blocks, + num_new_tokens, + num_new_computed_tokens=num_new_local_computed_tokens, + new_computed_blocks=new_computed_blocks, num_lookahead_tokens=effective_lookahead_tokens, + num_external_computed_tokens=num_external_computed_tokens, delay_cache_blocks=load_kv_async, num_encoder_tokens=num_encoder_tokens, ) @@ -606,7 +607,7 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: self.connector.update_state_after_alloc( request, - new_computed_blocks + new_blocks, + self.kv_cache_manager.get_blocks(request.request_id), num_external_computed_tokens, ) @@ -1580,6 +1581,13 @@ def _connector_finished( if self.connector is None: return False, None + # Free any out-of-window prefix blocks before we hand the block table to + # the connector. + self.kv_cache_manager.remove_skipped_blocks( + request_id=request.request_id, + total_computed_tokens=request.num_tokens, + ) + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) if not isinstance(self.connector, SupportsHMA): diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index e8a0a39b1f9..ddc50ae5124 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -30,6 +30,7 @@ def __init__( self, kv_cache_spec: KVCacheSpec, block_pool: BlockPool, + enable_caching: bool, kv_cache_group_id: int, dcp_world_size: int = 1, pcp_world_size: int = 1, @@ -48,6 +49,7 @@ def __init__( self.block_size *= dcp_world_size * pcp_world_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool + self.enable_caching = enable_caching # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request @@ -68,6 +70,7 @@ def get_num_blocks_to_allocate( request_id: str, num_tokens: int, new_computed_blocks: Sequence[KVCacheBlock], + total_computed_tokens: int, ) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -78,46 +81,121 @@ def get_num_blocks_to_allocate( tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. + total_computed_tokens: Include both local and external computed + tokens. Returns: - The number of blocks. + The number of blocks to allocate. """ num_required_blocks = cdiv(num_tokens, self.block_size) - num_new_blocks = ( - num_required_blocks - - len(new_computed_blocks) - - len(self.req_to_blocks[request_id]) + num_req_blocks = len(self.req_to_blocks.get(request_id, ())) + + if request_id in self.num_cached_block: + # Fast-path: a running request won't have any new prefix-cache hits. + assert len(new_computed_blocks) == 0 + # NOTE: With speculative decoding, request's blocks may be allocated + # for draft tokens which are later rejected. In this case, + # num_required_blocks may be smaller than num_req_blocks. 
+            return max(num_required_blocks - num_req_blocks, 0)
+
+        num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
+        num_local_computed_blocks = len(new_computed_blocks) + num_req_blocks
+        # Number of whole blocks that are skipped by the attention window.
+        # If nothing is skipped, this is 0.
+        num_skipped_blocks = num_skipped_tokens // self.block_size
+        # We need blocks for the non-skipped suffix. If there are still
+        # local-computed blocks inside the window, they contribute to the
+        # required capacity; otherwise, skipped blocks dominate.
+        num_new_blocks = max(
+            num_required_blocks - max(num_skipped_blocks, num_local_computed_blocks),
+            0,
         )
-        # If a computed block of a request is an eviction candidate (in the
-        # free queue and ref_cnt == 0), it will be changed from a free block
-        # to a computed block when the request is allocated, so we also count
-        # it as needed to be allocated.
-        num_evictable_computed_blocks = sum(
-            blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
+
+        # Among the `new_computed_blocks`, the first `num_skipped_blocks` worth
+        # of blocks are skipped; `num_req_blocks` of those may already be in
+        # `req_to_blocks`, so only skip the remainder from `new_computed_blocks`.
+        num_skipped_new_computed_blocks = max(0, num_skipped_blocks - num_req_blocks)
+
+        # If a computed block is an eviction candidate (in the free queue and
+        # ref_cnt == 0), it will be removed from the free queue when touched by
+        # the allocated request, so we must count it in the free-capacity check.
+        num_evictable_blocks = sum(
+            blk.ref_cnt == 0 and not blk.is_null
+            for blk in new_computed_blocks[num_skipped_new_computed_blocks:]
         )
-        return num_new_blocks + num_evictable_computed_blocks
+        return num_new_blocks + num_evictable_blocks
 
-    def save_new_computed_blocks(
-        self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
+    def allocate_new_computed_blocks(
+        self,
+        request_id: str,
+        new_computed_blocks: Sequence[KVCacheBlock],
+        num_local_computed_tokens: int,
+        num_external_computed_tokens: int,
    ) -> None:
        """
-        Add the new computed blocks to the request.
+        Add the new computed blocks to the request. This involves four steps:
+        1. Touch the computed blocks to make sure they won't be evicted.
+        2. (Optional) For sliding window, pad the skipped prefix with null
+           blocks.
+        3. Add the remaining computed blocks.
+        4. (Optional) For KV connectors, allocate new blocks for external
+           computed tokens (if any).
 
         Args:
             request_id: The request ID.
             new_computed_blocks: The new computed blocks just hitting the
                 prefix cache.
+            num_local_computed_tokens: The number of local computed tokens.
+            num_external_computed_tokens: The number of external computed tokens.
         """
-        if request_id not in self.num_cached_block:
-            # A new request.
-            req_blocks = self.req_to_blocks[request_id]
-            assert len(req_blocks) == 0
-            req_blocks.extend(new_computed_blocks)
-            self.num_cached_block[request_id] = len(new_computed_blocks)
-        else:
-            # A running request. Should not have new computed blocks.
+
+        if request_id in self.num_cached_block:
+            # Fast-path: a running request won't have any new prefix-cache hits.
+            # It should not have any new computed blocks.
             assert len(new_computed_blocks) == 0
+            return
+
+        # A new request.
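+        # Build its block table from scratch: null placeholders for the
+        # skipped prefix, then the prefix-cache hits, then (optionally)
+        # fresh blocks for connector-provided tokens.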
+        req_blocks = self.req_to_blocks[request_id]
+        assert len(req_blocks) == 0
+        num_total_computed_tokens = (
+            num_local_computed_tokens + num_external_computed_tokens
+        )
+        num_skipped_tokens = self.get_num_skipped_tokens(num_total_computed_tokens)
+        num_skipped_blocks = num_skipped_tokens // self.block_size
+        if num_skipped_blocks > 0:
+            # It is possible that all new computed blocks are skipped when
+            # num_skipped_blocks >= len(new_computed_blocks).
+            new_computed_blocks = new_computed_blocks[num_skipped_blocks:]
+            # Some external computed tokens may be skipped too.
+            num_external_computed_tokens = min(
+                num_total_computed_tokens - num_skipped_tokens,
+                num_external_computed_tokens,
+            )
+
+        # Touch the computed blocks to make sure they won't be evicted.
+        if self.enable_caching:
+            self.block_pool.touch(new_computed_blocks)
+        else:
+            assert not any(new_computed_blocks), (
+                "Computed blocks should be empty when prefix caching is disabled"
+            )
+
+        # Skipped blocks are padded with null blocks.
+        req_blocks.extend([self._null_block] * num_skipped_blocks)
+        # Add the remaining computed blocks.
+        req_blocks.extend(new_computed_blocks)
+        # All prefix-cache hits (including the skipped null placeholders) count
+        # as already cached; record them so cache_blocks() will not try to
+        # re-cache blocks that already have a block_hash set.
+        self.num_cached_block[request_id] = len(req_blocks)
+
+        if num_external_computed_tokens > 0:
+            # Allocate new blocks for external computed tokens.
+            allocated_blocks = self.block_pool.get_new_blocks(
+                cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
+            )
+            req_blocks.extend(allocated_blocks)
 
     def allocate_new_blocks(
         self, request_id: str, num_tokens: int
@@ -252,7 +330,9 @@ def find_longest_cache_hit(
 
         raise NotImplementedError
 
-    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
+    def remove_skipped_blocks(
+        self, request_id: str, total_computed_tokens: int
+    ) -> None:
         """
         Remove and free the blocks that are no longer needed for attention
         computation. The removed blocks should be replaced by null_block.
@@ -262,18 +342,24 @@
 
         Args:
             request_id: The request ID.
-            num_computed_tokens: The number of tokens that have been computed.
+            total_computed_tokens: The total number of computed tokens, including
+                local computed tokens and external computed tokens.
         """
         # Remove the blocks that will be skipped during attention computation.
-        num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
+        num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
         if num_skipped_tokens <= 0:
             # This indicates that ALL tokens are inside attention window.
             # Thus we do not need to free any blocks outside attention window.
             # A typical case is full attention that we never free any token
             # before the request is finished.
             return
-        num_skipped_blocks = num_skipped_tokens // self.block_size
         blocks = self.req_to_blocks[request_id]
+        num_skipped_blocks = num_skipped_tokens // self.block_size
+        # `num_skipped_tokens` may include tokens that haven't been allocated yet
+        # (e.g., when the attention window moves into the external computed tokens
+        # range), so we must cap to the number of blocks that currently exist for
+        # this request.
+        num_skipped_blocks = min(num_skipped_blocks, len(blocks))
         removed_blocks: list[KVCacheBlock] = []
         # Because the block starts from index 0, the num_skipped_block-th block
         # corresponds to index num_skipped_blocks - 1.
@@ -486,7 +572,7 @@ def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
         Returns:
             The number of tokens that will be skipped for attention computation.
         """
-        return num_computed_tokens - self.sliding_window + 1
+        return max(0, num_computed_tokens - self.sliding_window + 1)
 
     def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
         """
@@ -711,6 +797,7 @@ def get_num_blocks_to_allocate(
         request_id: str,
         num_tokens: int,
         new_computed_blocks: Sequence[KVCacheBlock],
+        total_computed_tokens: int,
     ) -> int:
         # Allocate extra `num_speculative_blocks` blocks for
         # speculative decoding (MTP/EAGLE) with linear attention.
@@ -721,7 +808,7 @@
             * self.kv_cache_spec.num_speculative_blocks
         )
         return super().get_num_blocks_to_allocate(
-            request_id, num_tokens, new_computed_blocks
+            request_id, num_tokens, new_computed_blocks, total_computed_tokens
         )
 
     def allocate_new_blocks(
@@ -749,8 +836,12 @@ def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
 
 class CrossAttentionManager(SingleTypeKVCacheManager):
     """Manager for cross-attention KV cache in encoder-decoder models."""
 
-    def save_new_computed_blocks(
-        self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
+    def allocate_new_computed_blocks(
+        self,
+        request_id: str,
+        new_computed_blocks: Sequence[KVCacheBlock],
+        num_local_computed_tokens: int,
+        num_external_computed_tokens: int,
     ) -> None:
         # We do not cache blocks for cross-attention to be shared between
         # requests, so `new_computed_blocks` should always be empty.
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 68fe0853370..fae7fa620a2 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -624,7 +624,7 @@ def execute_model(
             output = self.model_runner.execute_model(
                 scheduler_output, intermediate_tensors
             )
-            if isinstance(output, (ModelRunnerOutput, NoneType)):
+            if isinstance(output, ModelRunnerOutput | NoneType):
                 return output
 
             assert isinstance(output, IntermediateTensors)
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 5f6136b178b..ab22d0af63a 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -304,6 +304,13 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
     def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
+        # Initialize the KV cache connector here because it requires
+        # `kv_cache_config`.
+        # NOTE(Kuntai): This needs to be done before `initialize_kv_cache`,
+        # because `initialize_kv_cache` will inject kv cache groups not
+        # related to the kv cache connector (e.g. kv cache sharing layers).
+        ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
+
         self.model_runner.initialize_kv_cache(kv_cache_config)
 
     def check_health(self) -> None:
@@ -336,8 +343,6 @@ def _init_tpu_worker_distributed_environment(
             parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
         )
 
-        ensure_kv_transfer_initialized(vllm_config)
-
     def shutdown(self) -> None:
         self.model_runner.ensure_kv_transfer_shutdown()
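
A minimal, self-contained sketch of the sliding-window skip arithmetic the
changes above depend on. The names (`num_skipped_tokens`, `pad_skipped_blocks`,
`NULL`) are illustrative stand-ins rather than the vLLM API, and block IDs are
plain ints:

```python
# Sketch of SlidingWindowManager.get_num_skipped_tokens and the null-block
# padding in allocate_new_computed_blocks (simplified; not the vLLM API).

NULL = None  # stand-in for the shared null-block placeholder


def num_skipped_tokens(total_computed_tokens: int, sliding_window: int) -> int:
    # Tokens strictly before the attention window can be dropped; the max()
    # mirrors the fix above so short prefixes never yield a negative count.
    return max(0, total_computed_tokens - sliding_window + 1)


def pad_skipped_blocks(
    cached_blocks: list[int],
    total_computed_tokens: int,
    sliding_window: int,
    block_size: int,
) -> list[int | None]:
    # Whole blocks fully outside the window become NULL placeholders; only
    # the suffix of cached blocks is kept (and would be touched/ref-counted).
    skipped = num_skipped_tokens(total_computed_tokens, sliding_window) // block_size
    # Cap at the blocks that actually exist, mirroring the capping added to
    # remove_skipped_blocks above.
    skipped = min(skipped, len(cached_blocks))
    return [NULL] * skipped + cached_blocks[skipped:]


# Example: block_size=2, sliding_window=4, 8 computed tokens in 4 blocks.
# num_skipped_tokens = 8 - 4 + 1 = 5, and 5 // 2 = 2 whole blocks skipped.
assert pad_skipped_blocks([10, 11, 12, 13], 8, 4, 2) == [NULL, NULL, 12, 13]
```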