vllm/v1/core/kv_cache_manager.py (8 additions & 3 deletions)
@@ -164,7 +164,8 @@ def allocate_slots(
         self,
         request: Request,
         num_tokens: int,
-        new_computed_blocks: Optional[list[KVCacheBlock]] = None
+        new_computed_blocks: Optional[list[KVCacheBlock]] = None,
+        num_spec_tokens: int = 0,

Collaborator:
I have two points to discuss:

  1. Should we use "num_lookahead_tokens" to reduce confusion? After all, these slots are for the proposed tokens that will be verified in the next step.
  2. Should we consider these slots together with the preallocated blocks? Specifically, if the preallocated blocks can already cover the spec tokens, do we even need to allocate additional slots?

Contributor:
I have seen the term lookahead_tokens before. Can you share why it is more general than spec_tokens? Is it because it can also cover jump tokens?

Collaborator:
No, jump tokens should be counted in new_tokens. I just feel num_spec_tokens is confusing because here it means the spec tokens we're going to propose by the end of this step, while Request also has spec_tokens, and those were generated by the previous step for verification.

Collaborator:
+1 to @comaniac; I have the same two questions.

Collaborator Author:
  1. I am good with num_lookahead_tokens, will change it here.
  2. Yeah, sure, we can do it in a more conservative way:
     preallocated_blocks -= num_lookahead_tokens // block_size

Contributor:
    preallocated_blocks -= num_lookahead_tokens // block_size

We might have to revert this once the number of draft tokens becomes large, especially with tree attention, since then the number of draft tokens is roughly equal to the number of preallocated tokens, which would lead to frequent block allocations.

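For concreteness, a minimal sketch of the conservative adjustment discussed above; the function name and variables are illustrative stand-ins, not vLLM's actual code:

    # Sketch only: shrink the preallocation by however many whole blocks the
    # lookahead (spec) tokens will consume, so the total reservation stays flat.
    def adjusted_preallocation(num_preallocated_blocks: int,
                               num_lookahead_tokens: int,
                               block_size: int) -> int:
        reduction = num_lookahead_tokens // block_size
        return max(0, num_preallocated_blocks - reduction)

    # With 16-token blocks and 4 lookahead tokens, nothing is trimmed yet; the
    # reduction only kicks in once the lookahead spans whole blocks, which is
    # the regime the comment above is worried about.
    assert adjusted_preallocation(4, 4, 16) == 4
    assert adjusted_preallocation(4, 32, 16) == 2
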
     ) -> Optional[list[KVCacheBlock]]:
         """Add slots for a request with new tokens to append.

@@ -174,6 +175,9 @@ def allocate_slots(
                 not include the tokens that have already been computed.
             new_computed_blocks: A list of new computed blocks just hitting the
                 prefix caching.
+            num_spec_tokens: The number of speculative tokens to allocate.
+                This field is only used by EAGLE; we allocate these slots for
+                the proposer heads.

         Blocks layout:
         -----------------------------------------------------------------------
@@ -211,8 +215,9 @@ def allocate_slots(
         # the new prefix caching hits
         num_computed_tokens = (request.num_computed_tokens +
                                len(new_computed_blocks) * self.block_size)
-        num_required_blocks = cdiv(num_computed_tokens + num_tokens,
-                                   self.block_size)
+        num_required_blocks = cdiv(
+            num_computed_tokens + num_tokens + num_spec_tokens,
+            self.block_size)
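
As a worked sketch of the change above, with a standalone ceiling-division helper and made-up numbers (these values are assumptions for illustration, not taken from the PR):

    # Ceiling division, matching the cdiv used in the diff above.
    def cdiv(a: int, b: int) -> int:
        return -(-a // b)

    block_size = 16
    num_computed_tokens = 30   # tokens already held in the KV cache
    num_tokens = 8             # new tokens to verify this step
    num_spec_tokens = 4        # draft tokens the proposer will emit

    # Before this PR: only computed + new tokens were counted.
    cdiv(num_computed_tokens + num_tokens, block_size)                    # 38 tokens -> 3 blocks
    # With this PR: the proposer's tokens are reserved up front as well.
    cdiv(num_computed_tokens + num_tokens + num_spec_tokens, block_size)  # 42 tokens -> 3 blocks

    # Here the extra reservation is free; it only costs an additional block
    # when num_computed_tokens + num_tokens lands near the end of a block.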

Contributor (@ekagra-ranjan), Apr 11, 2025:
@luyuzhe111 @wwl2755 - moving the discussion of why this PR is expected to improve the AL here.

I have a hypothesis. Without this PR, the queries in the draft can go out of bounds in the block_table and pick up an incorrect address and value, which will corrupt the answer. block_table is used in the FA CUDA kernels, and maybe we don't check for illegal memory accesses there.

Let's say the page size is 16. This corruption will arise when there are fewer than K slots left in the last block. The preallocated-block computation (extra 4 blocks) won't trigger in this case since the last block is not full. As K increases, the chances of this increase, so K=4 has a higher chance of hitting it than K=2, which is reflected here.

But then block_table is also gathered here to form the slot_mapping for the queries, so an out-of-bounds index should have given an error, which it did not when using bs=1 with MTBench, so I am not sure the above hypothesis is correct.

Let me know what you guys think.

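A rough sketch of the window described in this hypothesis, treating K as the number of draft tokens and assuming a 16-token page; the helper below is purely illustrative:

    # Per the hypothesis: when the last block is only partially filled, the
    # preallocation path does not add blocks, so the K draft tokens must fit
    # into whatever slots remain in that block.
    block_size = 16

    def can_overflow(tokens_in_last_block: int, k: int) -> bool:
        free_slots = block_size - tokens_in_last_block
        return 0 < free_slots < k

    for k in (2, 4):
        dangerous = sum(can_overflow(fill, k) for fill in range(1, block_size + 1))
        print(f"K={k}: {dangerous}/{block_size} fill levels could spill past the block")
    # K=2 -> 1/16, K=4 -> 3/16: the window grows with K, consistent with the
    # observation that K=4 shows the problem more often than K=2.
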
Contributor:
@WoosukKwon @LiuXiaoxuanPKU - can you also share your insight as to why this PR is expected to increase AL?

Contributor (@wwl2755), Apr 11, 2025:
QQ: has the claim that "this PR can increase AL" already been benchmarked, or is it set up as a goal of this PR?

Collaborator Author (@LiuXiaoxuanPKU), Apr 11, 2025:
From a high level: without this PR, the current scheduler does not actually allocate slots for the proposed tokens; it only allocates slots for verification. Therefore, it's not guaranteed that the KV cache of the proposer heads is not contaminated.

Contributor (@ekagra-ranjan), Apr 11, 2025:
@LiuXiaoxuanPKU can you help us understand at a slightly deeper level which code line would be at fault?

My understanding is that if the scheduler doesn't allocate slots for the proposed tokens, then torch should have thrown an error here when the new proposed tokens become the query. However, that didn't happen in our MTBench benchmark, so is there perhaps no corruption without this PR?

Collaborator Author:
Thanks for asking! The code here will not trigger an error because block_table is always a tensor of shape [batch_size, max_num_blocks_per_request]; if blocks are not allocated, their entries in the block table default to 0.

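A toy illustration of why the lookup fails silently rather than raising; the shapes and values below are made up for the example:

    import torch

    batch_size, max_num_blocks_per_request = 2, 8
    # Fixed-shape block table; unallocated entries keep the default value 0.
    block_table = torch.zeros(batch_size, max_num_blocks_per_request, dtype=torch.int32)
    block_table[0, :3] = torch.tensor([17, 42, 5], dtype=torch.int32)  # request 0 owns 3 blocks

    # A draft-token query that logically needs a 4th block still indexes in bounds.
    physical_block = block_table[0, 3]  # tensor(0, dtype=torch.int32), no IndexError
    # The query silently reads/writes physical block 0, which is the corruption risk.
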
         num_new_blocks = (num_required_blocks - len(req_blocks) -
                           len(new_computed_blocks))

vllm/v1/core/sched/scheduler.py (10 additions & 2 deletions)
@@ -7,7 +7,8 @@
 from collections.abc import Iterable
 from typing import Optional, Union

-from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
+from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
+                         SpeculativeConfig)
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@@ -38,6 +39,7 @@ def __init__(
         cache_config: CacheConfig,
         lora_config: Optional[LoRAConfig],
         kv_cache_config: KVCacheConfig,
+        speculative_config: SpeculativeConfig,
         structured_output_manager: StructuredOutputManager,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         include_finished_set: bool = False,
@@ -112,6 +114,10 @@ def __init__(
         self.encoder_cache_manager = EncoderCacheManager(
             cache_size=encoder_cache_size)

+        self.num_spec_tokens = 0
+        if speculative_config and speculative_config.method == "eagle":
+            self.num_spec_tokens = speculative_config.num_speculative_tokens

     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -188,7 +194,9 @@ def schedule(self) -> SchedulerOutput:

             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
-                    request, num_new_tokens)
+                    request,
+                    num_new_tokens,
+                    num_spec_tokens=self.num_spec_tokens)

Contributor (@ekagra-ranjan), Apr 10, 2025:
My understanding is that num_new_tokens is needed for the target model to verify the spec token ids from the previous step, while num_spec_tokens is the number of spec tokens the draft model is supposed to generate at the end of this step.

Based on that, if num_new_tokens is 8 and num_spec_tokens is 4, we can end up allocating 1 block (16 tokens) such that the same block holds both the target model's and the draft model's KV cache?

Contributor:
My understanding is similar. My interpretation is that it temporarily acquires num_spec_tokens extra slots for the draft tokens, and the extra reservation does not accumulate across iterations.

Contributor (@ekagra-ranjan), Apr 10, 2025:
If the same blocks are shared by the target and draft models, won't that be an issue? The KV caches of the target and draft model would be adjacent in the logical mapping of the block tables, so the draft model would attend to the target's KV cache.

Contributor:
Hmm, I think it should not cause a problem as long as the actual starting KV cache slot for the draft model is tracked somehow?

Collaborator Author:
Thanks for the discussion here!

  1. num_new_tokens is for verification; num_spec_tokens is for the proposer heads. "Based on that, if num_new_tokens is 8 and num_spec_tokens is 4, we can end up allocating 1 block (16 tokens) such that the same block holds both the target model's and the draft model's KV cache?" --> yes, exactly.
  2. KV cache corruption: currently, each layer's KV cache is allocated independently, but they all share the same slot mapping. We can think of the KV cache as a map: {layer0_kv: [], layer1_kv: [], ..., layerk_kv: [], eagle_layer_kv: []}. During each generation step, using the example above, we first verify tokens, which writes KV to layer0_kv...layerk_kv with slot mapping [0,1,2,3,4,5,6,7]; it does not write to the draft KV. If, say, only 2 tokens are accepted, 3 tokens are generated. In the proposing phase, we send those three tokens to the EAGLE proposer with slot mapping [0,1,2], which populates the draft KV cache for the generated tokens and also proposes the next token. We allocate 12 slots (8+4) in total because it's possible that all tokens (with slot ids 0-7) are accepted; in that case, the proposed tokens need to write to the KV cache at ids [8,9,10,11].

Let me know if there is any confusion here!

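A compact sketch of the bookkeeping described above, with plain Python lists standing in for the per-layer KV caches (everything here is illustrative; the real caches are tensors indexed through slot_mapping):

    num_new_tokens, num_spec_tokens = 8, 4
    slots_reserved = num_new_tokens + num_spec_tokens   # 12 slots reserved up front

    # Separate buffer per target layer plus one for the EAGLE proposer layer;
    # they are distinct buffers but share the same slot ids.
    kv_cache = {f"layer{i}_kv": [None] * slots_reserved for i in range(2)}
    kv_cache["eagle_layer_kv"] = [None] * slots_reserved

    # Verification: the target layers write KV for all 8 candidate tokens.
    for name in ("layer0_kv", "layer1_kv"):
        for slot in range(num_new_tokens):               # slots [0..7]
            kv_cache[name][slot] = f"target_kv_{slot}"

    # Suppose only 2 drafts are accepted, so 3 tokens were generated in total.
    for slot in range(3):                                # slot mapping [0, 1, 2]
        kv_cache["eagle_layer_kv"][slot] = f"draft_kv_{slot}"

    # Worst case, all 8 candidates are accepted and the proposer needs slots
    # 8-11, which is exactly why 8 + 4 slots are reserved instead of 8.
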
Contributor (@ekagra-ranjan), Apr 10, 2025:
I think it makes sense now. The block_table where the allocated slots get recorded is shared across all layers, and EAGLE is just a layer on top of the target model's layers. When we add blocks for num_new_tokens + num_spec_tokens, the target model will use just the num_new_tokens slots, but when all the drafts are accepted, the draft layer will use the full num_new_tokens + num_spec_tokens slots.

                 if new_blocks is None:
                     # The request cannot be scheduled.
                     # Preempt the lowest-priority request.
vllm/v1/engine/core.py (1 addition & 0 deletions)
@@ -98,6 +98,7 @@ def __init__(
             cache_config=vllm_config.cache_config,
             lora_config=vllm_config.lora_config,
             kv_cache_config=kv_cache_config,
+            speculative_config=vllm_config.speculative_config,
             structured_output_manager=self.structured_output_manager,
             include_finished_set=vllm_config.parallel_config.data_parallel_size
             > 1,