diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index d89b79514950..10136c9cb1bc 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -95,6 +95,14 @@ Below, you can find an explanation of every engine argument for vLLM: For example, a value of 0.5 would imply 50% GPU memory utilization. If unspecified, will use the default value of 0.9. +.. option:: --prefix-pool-memory-utilization + + The fraction of GPU memory to be used for prefix caching, which can range from 0 to --gpu-memory-utilization. + For example, a value of 0.5 would imply 50% GPU memory utilization. + If unspecified, will use the default value of 0. A value of 0 means no prefix caching at all. + The size of the prefixes relative to the length of the rest of the prompts and the generated + sequences should dictate the relative value of this parameter with respect to --gpu-memory-utilization. + .. option:: --max-num-batched-tokens Maximum number of batched tokens per iteration. diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 1e301bedfc21..95a68ca6537b 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -2,8 +2,11 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`. 
""" +from importlib import reload + import pytest +import vllm.model_executor.parallel_utils.parallel_state as parallel_state from vllm import LLM, SamplingParams prefix = ( @@ -20,12 +23,20 @@ @pytest.mark.parametrize("model", ["facebook/opt-125m"]) @pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("prefix_pool_memory_utilization", [0, 0.1, 0.2]) def test_prefix_caching( example_prompts, model: str, max_tokens: int, + prefix_pool_memory_utilization: float, ): - llm = LLM(model=model) + # IMPORTANT: If this line is removed from here, adding more than 1 item to + # any of the parametrization lists above causes all tests but the first one + # to fail with the message: "AssertionError: tensor model parallel group is + # already initialized." + reload(parallel_state) + llm = LLM(model=model, + prefix_pool_memory_utilization=prefix_pool_memory_utilization) # -1 since the last token can change when concatenating prompts. prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 prompts = [prefix + prompt for prompt in example_prompts] @@ -38,4 +49,42 @@ def test_prefix_caching( outputs_without_prefix, outputs_with_prefix): assert (output_without_prefix.outputs[0].token_ids == output_with_prefix.outputs[0].token_ids) - assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1 + if prefix_pool_memory_utilization == 0: + assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 0 + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("prefix_pool_memory_utilization", [0, 0.1, 0.3, 0.4]) +def test_prefix_caching_with_multiple_prefixes( + example_prompts, model: str, max_tokens: int, + prefix_pool_memory_utilization: float): + """ + Tests that the scheduler prefix pool size (length) does not go over the + maximum capacity at any moment in time. 
+ """ + # IMPORTANT: If this line is removed from here, adding more than 1 item to + # any of the parametrization lists above causes all tests but the first one + # to fail with the message: "AssertionError: tensor model parallel group is + # already initialized." + reload(parallel_state) + llm = LLM(model=model, + prefix_pool_memory_utilization=prefix_pool_memory_utilization) + + # Use 10 different prefixes: + for i in range(10): + new_prefix = str(i) + ' ' + prefix + # -1 since the last token can change when concatenating prompts. + prefix_pos = len(llm.llm_engine.tokenizer.encode(new_prefix)) - 1 + prompts = [new_prefix + prompt for prompt in example_prompts] + sampling_params = SamplingParams(temperature=0.0, + max_tokens=max_tokens) + outputs_with_prefix = llm.generate(prompts, + sampling_params, + prefix_pos=[prefix_pos] * + len(prompts)) + outputs_without_prefix = llm.generate(prompts, sampling_params) + for output_without_prefix, output_with_prefix in zip( + outputs_without_prefix, outputs_with_prefix): + assert (output_without_prefix.outputs[0].token_ids == + output_with_prefix.outputs[0].token_ids) diff --git a/tests/prefix_caching/test_prefix_pool.py b/tests/prefix_caching/test_prefix_pool.py new file mode 100644 index 000000000000..2a4b6709f29e --- /dev/null +++ b/tests/prefix_caching/test_prefix_pool.py @@ -0,0 +1,138 @@ +from vllm.prefix import PrefixPool + +import pytest + + +@pytest.fixture +def no_max_capacity_prefix_pool() -> PrefixPool: + return PrefixPool(block_size=32, max_capacity_in_blocks=float('inf')) + + +def test_prefix_length_behaviours(no_max_capacity_prefix_pool: PrefixPool): + """ + This test checks that prefixes of length less than pool.block_size are not created and are not added to the pool. + It also checks that prefixes of length greater than or equal to pool.block_size are created and added to the pool. 
+ """ + prefix_1 = no_max_capacity_prefix_pool.add_or_get_prefix( + list(range(no_max_capacity_prefix_pool.block_size - 1))) + prefix_2 = no_max_capacity_prefix_pool.add_or_get_prefix( + list(range(no_max_capacity_prefix_pool.block_size))) + prefix_3 = no_max_capacity_prefix_pool.add_or_get_prefix( + list(range(no_max_capacity_prefix_pool.block_size * 2))) + assert prefix_1 is None + assert prefix_2 is not None + assert prefix_3 is not None + assert len(no_max_capacity_prefix_pool) == 2 + + +def test_same_prefix_added_twice(no_max_capacity_prefix_pool: PrefixPool): + """ + Tests that when a prefix is added more than once to the pool, all subsequent additions + return the same prefix object that was created the first time. + """ + prefix_1 = no_max_capacity_prefix_pool.add_or_get_prefix( + list(range(no_max_capacity_prefix_pool.block_size))) + prefix_2 = no_max_capacity_prefix_pool.add_or_get_prefix( + list(range(no_max_capacity_prefix_pool.block_size))) + assert prefix_1 is prefix_2 + assert len(no_max_capacity_prefix_pool) == 1 + + +def test_prefix_pool_max_capacity(): + """ + Tests that the pool is evicting prefixes when it reaches max capacity. 
+ """ + max_capacity_in_blocks = 2 + max_capacity_prefix_pool = PrefixPool( + block_size=32, max_capacity_in_blocks=max_capacity_in_blocks) + + # Tests that on the third insertion, a new object is created because the capacity limit was reached, + # but that the newly created object is equal to the old object + prefix_1 = max_capacity_prefix_pool.add_or_get_prefix( + list(range(max_capacity_prefix_pool.block_size))) + _ = max_capacity_prefix_pool.add_or_get_prefix( + list(range(max_capacity_prefix_pool.block_size * 2))) + prefix_3 = max_capacity_prefix_pool.add_or_get_prefix( + list(range(max_capacity_prefix_pool.block_size))) + assert prefix_1 is not prefix_3 + assert prefix_1 == prefix_3 + + assert len(max_capacity_prefix_pool) == 1 + assert max_capacity_prefix_pool.current_block_usage == 1 + + +def test_current_block_usage(): + """ + Tests that the current_block_usage property remains the same throughout the + lifetime of the pool when adding prefixes that are always the same length equal + to the max capacity. + """ + max_capacity_in_blocks = 2 + max_capacity_prefix_pool = PrefixPool( + block_size=32, max_capacity_in_blocks=max_capacity_in_blocks) + + for _ in range(10): + _ = max_capacity_prefix_pool.add_or_get_prefix( + list( + range(max_capacity_prefix_pool.block_size * + max_capacity_in_blocks))) + assert len(max_capacity_prefix_pool) == 1 + assert max_capacity_prefix_pool.current_block_usage == max_capacity_in_blocks + + +def test_prefix_truncation_1(): + """ + Tests that prefix is truncated if it exceeds the max capacity. 
+ """ + prefix_pool = PrefixPool(block_size=1, max_capacity_in_blocks=2) + prefix = prefix_pool.add_or_get_prefix([1, 2, 3, 4]) + assert prefix.token_ids == (1, 2) + + +def test_prefix_truncation_2(): + """ + Testing truncation on non-block boundary + """ + prefix_pool = PrefixPool(block_size=2, max_capacity_in_blocks=3) + prefix = prefix_pool.add_or_get_prefix([1, 2, 3, 4, 5]) + assert prefix.token_ids == (1, 2, 3, 4) + + +def test_prefix_truncation_3(): + """ + Tests truncation because of both max capacity exceeded and no block boundary. + """ + prefix_pool = PrefixPool(block_size=2, max_capacity_in_blocks=2) + prefix = prefix_pool.add_or_get_prefix([1, 2, 3, 4, 5]) + assert prefix.token_ids == (1, 2, 3, 4) + + +def test_none_prefix_returned_1(): + """ + Tests that when the max capacity is zero, no prefix is created and None is returned. + """ + prefix_pool = PrefixPool(block_size=32, max_capacity_in_blocks=0) + prefix = prefix_pool.add_or_get_prefix(list(range(prefix_pool.block_size))) + assert prefix is None + assert len(prefix_pool) == 0 + + +def test_none_prefix_returned_2(): + """ + Tests that when prefix length is less than block size, a None prefix is returned. + """ + prefix_pool = PrefixPool(block_size=32, max_capacity_in_blocks=2) + prefix = prefix_pool.add_or_get_prefix( + list(range(prefix_pool.block_size - 1))) + assert prefix is None + assert len(prefix_pool) == 0 + + +def test_assertion_raised_with_invalid_max_capacity(): + with pytest.raises(AssertionError): + _ = PrefixPool(32, max_capacity_in_blocks=-1) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) diff --git a/vllm/block.py b/vllm/block.py index 5fe39ed47b2f..8a0c2d054abb 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -60,6 +60,12 @@ def __init__( self.block_number = block_number self.block_size = block_size + # Contains the number of sequences that share this block that + # are currently allocated in the same device as this block. 
+ # Notice that prefix blocks will have an extra 1 added to this + # reference count to guarantee that prefix blocks are not deallocated + # by the standard way that the block manager uses to free the memory + # for blocks. self.ref_count = 0 def __repr__(self) -> str: diff --git a/vllm/config.py b/vllm/config.py index 197f20c1ec9a..822e914c6697 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -279,19 +279,19 @@ class CacheConfig: cache_dtype: Data type for kv cache storage. """ - def __init__( - self, - block_size: int, - gpu_memory_utilization: float, - swap_space: int, - cache_dtype: str, - sliding_window: Optional[int] = None, - ) -> None: + def __init__(self, + block_size: int, + gpu_memory_utilization: float, + swap_space: int, + cache_dtype: str, + sliding_window: Optional[int] = None, + prefix_pool_memory_utilization: float = 0) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB self.cache_dtype = cache_dtype self.sliding_window = sliding_window + self.prefix_pool_memory_utilization = prefix_pool_memory_utilization self._verify_args() self._verify_cache_dtype() @@ -304,6 +304,36 @@ def _verify_args(self) -> None: raise ValueError( "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + if self.prefix_pool_memory_utilization < 0: + raise ValueError( + "prefix_pool_memory_utilization must be non negative. " + f"{self.prefix_pool_memory_utilization}.") + if self.prefix_pool_memory_utilization > self.gpu_memory_utilization: + raise ValueError( + "prefix_pool_memory_utilization must be less than or equal to " + "gpu_memory_utilization.") + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype == "fp8_e5m2": + nvcc_cuda_version = get_nvcc_cuda_version() + if nvcc_cuda_version < Version("11.8"): + raise ValueError( + "FP8 is not supported when cuda version is lower than 11.8." 
+ ) + device_name = torch.cuda.get_device_name() + if "AMD" in device_name: + raise NotImplementedError( + "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.") + logger.info( + "Using fp8_e5m2 data type to store kv cache. It reduces " + "the GPU memory footprint and boosts the performance. " + "But it may cause slight accuracy drop. " + "Currently we only support fp8 without scaling factors and " + "make e5m2 as a default format.") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 7f91051f03ac..450166a8da31 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -5,6 +5,7 @@ from vllm.block import BlockTable, PhysicalTokenBlock from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device +from vllm.prefix import Prefix class BlockAllocator: @@ -47,6 +48,12 @@ def free(self, block: PhysicalTokenBlock) -> None: if block.ref_count == 0: self.free_blocks.append(block) + def force_free(self, block: PhysicalTokenBlock) -> None: + """Force free a block without checking its ref count. + Currently used to free prefix blocks, whose block.ref_count is going + to be 1 at this moment in time. Handle with care!""" + self.free_blocks.append(block) + def get_num_free_blocks(self) -> int: return len(self.free_blocks) @@ -133,12 +140,15 @@ def allocate(self, seq_group: SequenceGroup) -> None: num_prefix_blocks = 0 prefix = seq_group.prefix - if prefix is not None and prefix.allocated: - # Prefix has already been allocated. Use the existing block table. - num_prompt_blocks -= prefix.get_num_blocks() - for block in prefix.block_table: - block.ref_count += seq_group.num_seqs() - block_table.append(block) + if prefix is not None: + prefix.seq_ref_count += seq_group.num_seqs() + + if prefix.allocated: + # Prefix has already been allocated. 
Use the existing block table. + num_prompt_blocks -= prefix.get_num_blocks() + for block in prefix.block_table: + block.ref_count += seq_group.num_seqs() + block_table.append(block) for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None @@ -155,6 +165,9 @@ def allocate(self, seq_group: SequenceGroup) -> None: # KV cache in this run. num_prefix_blocks = prefix.get_num_blocks() prefix_block_table = block_table[:num_prefix_blocks] + # The extra reference count increment on the prefix blocks + # guarantees that the prefix blocks will not be freed even + # if all sequences that share the prefix are finished. for block in prefix_block_table: block.ref_count += 1 prefix.set_block_table(prefix_block_table) @@ -202,11 +215,16 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: self.gpu_allocator.free(last_block) return last_block.block_number, new_block.block_number - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + def fork(self, + parent_seq: Sequence, + child_seq: Sequence, + prefix: Optional[Prefix] = None) -> None: # NOTE: fork does not allocate a new physical block. # Thus, it is always safe from OOM. src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.copy() + if prefix is not None: + prefix.seq_ref_count += 1 for block in src_block_table: block.ref_count += 1 @@ -223,12 +241,19 @@ def _get_physical_blocks( def can_swap_in(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) + seq_group_blocks = len(blocks) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) num_free_blocks = self.gpu_allocator.get_num_free_blocks() + + # If sequence has a prefix and prefix is already allocated, + # there is no need to count blocks for the prefix. 
+ if seq_group.prefix is not None and seq_group.prefix.allocated: + seq_group_blocks -= seq_group.prefix.get_num_blocks() + # NOTE: Conservatively, we assume that every sequence will allocate # at least one free block right after the swap-in. # NOTE: This should match the logic in can_append_slot(). - num_required_blocks = len(blocks) + num_swapped_seqs + num_required_blocks = seq_group_blocks + num_swapped_seqs return num_free_blocks - num_required_blocks >= self.watermark_blocks def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: @@ -276,6 +301,8 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: block_table = self.block_tables[seq.seq_id] for gpu_block in block_table: + # TODO(jadielam) The `in` operation on a list (the block_table) might be slow + # if the list is long. Think of alternatives to speed this up. if (seq_group.prefix is not None and gpu_block in seq_group.prefix.block_table): # NOTE: We do not swap out the prefix blocks for now. @@ -306,14 +333,24 @@ def _free_block_table(self, block_table: BlockTable) -> None: else: self.cpu_allocator.free(block) - def free(self, seq: Sequence) -> None: + def free(self, seq: Sequence, prefix: Optional[Prefix] = None) -> None: if seq.seq_id not in self.block_tables: # Already freed or haven't been scheduled yet. return + if prefix is not None: + prefix.seq_ref_count -= 1 block_table = self.block_tables[seq.seq_id] self._free_block_table(block_table) del self.block_tables[seq.seq_id] + def free_prefix_blocks(self, prefix: Prefix) -> None: + # Safety check here, don't know if necessary. 
+ assert prefix.allocated and prefix.seq_ref_count == 0 + block_table = prefix.block_table + for block in set(block_table): + self.gpu_allocator.force_free(block) + prefix.block_table = None + def reset(self) -> None: for block_table in self.block_tables.values(): self._free_block_table(block_table) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4fdf9ec341cf..99c118be6f5a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from vllm.prefix import PrefixPool +from vllm.prefix import PrefixPool, Prefix logger = init_logger(__name__) @@ -98,7 +98,10 @@ def __init__( sliding_window=self.cache_config.sliding_window) # Create the prefix pool to cache the prefixes. - self.prefix_pool = PrefixPool(self.cache_config.block_size) + self.prefix_pool = PrefixPool( + self.cache_config.block_size, + max_capacity_in_blocks= \ + int(self.cache_config.prefix_pool_memory_utilization * self.cache_config.num_gpu_blocks)) # Sequence groups in the WAITING state. self.waiting: Deque[SequenceGroup] = deque() @@ -149,7 +152,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: if seq.is_finished(): continue seq.status = SequenceStatus.FINISHED_ABORTED - self.free_seq(seq) + self.free_seq(seq, aborted_group.prefix) def has_unfinished_seqs(self) -> bool: return self.waiting or self.running or self.swapped @@ -167,6 +170,22 @@ def _schedule(self) -> SchedulerOutputs: now = time.monotonic() # Join waiting sequences if possible. + # TODO(jadielam): There is an inefficiency here in the way we are scheduling + # sequences. We should first process swapped sequences, and then process waiting + # sequences (until resources allow of course). 
+ # Currently when we process swapped sequences, we do not consider afterwards if + # there are enough resources to add some of the waiting sequences. + # In some situations (that is, when the number of swapped sequences is small) + # we might be leaving some resources unused because of that. + + # TODO(jadielam): Also, there seems to be a bigger problem with the current code: + # On this one I am not so sure, but I am making a note of it here to not forget. + # If there are no swapped sequences, we never consider if we have enough + # resources to add the next token on the currently running sequences. + # Question that is not clear to me regarding this: Do we mix and match and run + # requests in the prompt stage together with requests in the generation stage + # on an engine step? If the answer is NO, then this is not a problem, but + # if the answer is YES, then my concerns here might be valid. if not self.swapped: ignored_seq_groups: List[SequenceGroup] = [] scheduled: List[SequenceGroup] = [] @@ -388,16 +407,28 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs - def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: - self.block_manager.fork(parent_seq, child_seq) + def fork_seq(self, + parent_seq: Sequence, + child_seq: Sequence, + prefix: Optional[Prefix] = None) -> None: + self.block_manager.fork(parent_seq, child_seq, prefix) - def free_seq(self, seq: Sequence) -> None: - self.block_manager.free(seq) + def free_seq(self, seq: Sequence, prefix: Optional[Prefix] = None) -> None: + self.block_manager.free(seq, prefix) def free_finished_seq_groups(self) -> None: self.running = deque(seq_group for seq_group in self.running if not seq_group.is_finished()) + def free_old_prefixes(self): + """ + Deallocates the GPU memory of prefixes that have been moved to the list of candidates + for deallocation in the prefix pool and 
that are not currently being used by any sequence group. + """ + prefixes_to_free = self.prefix_pool.get_prefixes_to_deallocate() + for prefix in prefixes_to_free: + self.block_manager.free_prefix_blocks(prefix) + def _allocate(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): @@ -423,6 +454,14 @@ def _preempt( blocks_to_swap_out: Dict[int, int], preemption_mode: Optional[PreemptionMode] = None, ) -> None: + """ + Preempts the given sequence group. PreemptionMode.RECOMPUTE can only + be used if the sequence group has a single sequence. + + Raises: + AssertionError if PreemptionMode.RECOMPUTE is used for a + sequence group with multiple sequences. + """ # If preemption mode is not specified, we determine the mode as follows: # We use recomputation by default since it incurs lower overhead than # swapping. However, when the sequence group has multiple sequences @@ -454,7 +493,7 @@ def _preempt_by_recompute( assert len(seqs) == 1 for seq in seqs: seq.status = SequenceStatus.WAITING - self.block_manager.free(seq) + self.block_manager.free(seq, seq_group.prefix) # NOTE: For FCFS, we insert the preempted sequence group to the front # of the waiting queue. self.waiting.appendleft(seq_group) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 231ce3321cdc..2e57d21f7ad4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -27,6 +27,7 @@ class EngineArgs: block_size: int = 16 swap_space: int = 4 # GiB gpu_memory_utilization: float = 0.90 + prefix_pool_memory_utilization: float = 0.0 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_paddings: int = 256 @@ -179,6 +180,15 @@ def add_cli_args( help='the fraction of GPU memory to be used for ' 'the model executor, which can range from 0 to 1.' 
'If unspecified, will use the default value of 0.9.') + parser.add_argument( + '--prefix-pool-memory-utilization', + type=float, + default=EngineArgs.prefix_pool_memory_utilization, + help= + 'the fraction of GPU memory to be used by prefixes allocated and ' + 'present in the prefix pool. If 0, no prefix cache is used. It cannot be ' + 'larger than gpu_memory_utilization. If unspecified, will use the default of 0.' + ) parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, @@ -279,7 +289,8 @@ def create_engine_configs( cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, - model_config.get_sliding_window()) + model_config.get_sliding_window(), + self.prefix_pool_memory_utilization) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0d836a1fb13a..9120b419efde 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -565,7 +565,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # not be used in the future iterations. parent.status = SequenceStatus.FINISHED_ABORTED seq_group.remove(parent.seq_id) - self.scheduler.free_seq(parent) + self.scheduler.free_seq(parent, seq_group.prefix) continue # Fork the parent sequence if there are multiple child samples. for child_sample in child_samples[:-1]: @@ -594,7 +594,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq is not parent: seq_group.add(seq) if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) + self.scheduler.fork_seq(parent, seq, seq_group.prefix) # Free the finished and selected parent sequences' memory in block # manager. Keep them in the sequence group as candidate output. @@ -602,7 +602,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # old sequences. 
for seq, parent in child_seqs: if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) + self.scheduler.free_seq(seq, seq_group.prefix) return # Beam search case @@ -689,13 +689,13 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq is not parent: seq_group.add(seq) if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) + self.scheduler.fork_seq(parent, seq, seq_group.prefix) # Free the finished and selected parent sequences' memory in block # manager. Keep them in the sequence group as candidate output. for seq, parent in selected_child_seqs: if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) + self.scheduler.free_seq(seq, seq_group.prefix) # Remove the unselected parent sequences from the sequence group and # free their memory in block manager. @@ -704,7 +704,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Remove the parent sequence if it is not selected for next # iteration seq_group.remove(seq.seq_id) - self.scheduler.free_seq(seq) + self.scheduler.free_seq(seq, seq_group.prefix) def _process_model_outputs( self, output: SamplerOutput, @@ -732,6 +732,11 @@ def _process_model_outputs( and not seq_group.prefix.computed): seq_group.prefix.computed = True + # Remove prefixes that are due for removal because they are no longer + # being used by any sequence group and have been moved to the candidates + # to remove list in the PrefixPool + self.scheduler.free_old_prefixes() + if self.log_stats: # Log the system stats. 
self._log_system_stats(scheduler_outputs.prompt_run, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 614e6fa520c8..fbef31c7c0c3 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -84,6 +84,7 @@ def __init__( enforce_eager: bool = False, max_context_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, + prefix_pool_memory_utilization: float = 0, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: @@ -104,6 +105,7 @@ def __init__( enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, + prefix_pool_memory_utilization=prefix_pool_memory_utilization, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args(engine_args) diff --git a/vllm/prefix.py b/vllm/prefix.py index 5b6e8e4b92be..eac9ef490d46 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -1,4 +1,5 @@ -from typing import Dict, List, Sequence, Tuple, Optional +from typing import Any, List, Sequence, Tuple, Optional +from collections import OrderedDict from vllm.block import BlockTable @@ -15,11 +16,7 @@ class Prefix: block_size: The block size of the executed model. """ - def __init__( - self, - token_ids: Sequence[int], - block_size: int, - ) -> None: + def __init__(self, token_ids: Sequence[int], block_size: int) -> None: self.token_ids = tuple(token_ids) self.block_size = block_size self.length = len(token_ids) @@ -28,6 +25,13 @@ def __init__( self.block_table: Optional[BlockTable] = None self.computed = False + # Contains a reference count of the number of sequences that share this + # prefix, regardless of whether they are swapped out or not. + # Must not be initialized to 1 at creation time because a prefix might be created + # and thrown away, or sequences sharing this prefix might never be allocated. 
+ self.seq_ref_count = 0 + self.expired = False + @property def allocated(self) -> bool: return self.block_table is not None @@ -44,15 +48,19 @@ def get_length(self) -> int: def __hash__(self) -> int: return self.hash + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Prefix): + return False + return self.hash == other.hash + def set_block_table(self, block_table: BlockTable) -> None: self.block_table = block_table.copy() class PrefixPool: - """Manages all the prompt prefixes. - - NOTE: This feature is experimental and may be replaced with automatic - prefix caching in the future. + """Manages all the prompt prefixes. If max_capacity_in_blocks is finite, + the pool will act as an LRU cache and remove the least recently used prefix once + the capacity is reached. Args: block_size: The block size of the executed model. @@ -60,28 +68,137 @@ class PrefixPool: Attributes: prefixes: A list of all the prefixes. block_size: The block size of the executed model. + max_capacity_in_blocks: The maximum number of blocks that can be used for prefixes in the pool + at any given time. The default value is 0, which effectively means there is no prefix cache. + If the value is float('inf'), it means that the capacity of the pool is unbounded (not recommended). + """ - def __init__( - self, - block_size: int, - ) -> None: - # TODO(zhuohan): Add a capacity limit to the prefix pool. - self.prefixes: Dict[int, Prefix] = {} + def __init__(self, + block_size: int, + max_capacity_in_blocks: int | float = 0) -> None: + self.prefixes: OrderedDict[int, Prefix] = OrderedDict() + self.block_size = block_size + self._current_block_usage = 0 + + self.max_allowed_prefix_length = float('inf') + if max_capacity_in_blocks < float('inf'): + assert max_capacity_in_blocks >= 0, "max_capacity must be non-negative." 
+ self.max_allowed_prefix_length = self.block_size * max_capacity_in_blocks + + self.max_capacity_in_blocks = max_capacity_in_blocks + + self._candidates_to_deallocate: List[Prefix] = [] + + def __len__(self) -> int: + """ + Returns the number of prefixes in the pool. + """ + return len(self.prefixes) + + @property + def current_block_usage(self) -> int: + """ + Returns the number of blocks currently used by the pool. + """ + return self._current_block_usage def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: new_length = len(token_ids) // self.block_size * self.block_size + new_length = min(new_length, self.max_allowed_prefix_length) return tuple(token_ids[:new_length]) - def add_or_get_prefix(self, token_ids: Sequence[int], - lora_int_id: int) -> Optional[Prefix]: + def add_or_get_prefix(self, + token_ids: Sequence[int], + lora_int_id: int = 0) -> Optional[Prefix]: + """ + Arguments: + - token_ids: The token ids of the prefix to add to the pool. + - lora_int_id: The lora_int_id of the request, which will be used to hash the prefix too. + If the lora_int_id is not given, defaults to 0. + + Adds a prefix to the pool if it does not already exist. If it does exist, + it returns the existing prefix. If the pool is at max capacity, it removes + the least recently used prefix before adding the new prefix. + + There are two situations when None is returned: + 1. If the length of the token_ids of the prefix is less than the block_size, no + prefix is created and None is returned. + 2. If the max_capacity of the pool is 0, then no prefix is created and None is returned. + + There are also two situations where the prefix is shortened: + 1. If the length of the token_ids of the prefix is not a multiple of the block_size. + 2. If the number of blocks needed to allocate the prefix exceeds the max_capacity of the pool, + the prefix is shortened to fit the max_capacity. 
Notice that this second occurrence happens once + we have already attempted all other recourses to be able to allocate the prefix in its entirety, such + as evicting older prefixes from the pool. + """ + if self.max_capacity_in_blocks == 0: + # Prefix cache is disabled. + return None + token_ids = self._truncate_token_ids(token_ids) if len(token_ids) == 0: # Prefix is empty. return None + + # Check first if prefix exists, moving it to the end of the OrderedDict + # so that the LRU policy is maintained. Return the existing prefix. prefix = Prefix(token_ids, self.block_size) prefix_hash = hash((prefix, lora_int_id)) - if prefix_hash not in self.prefixes: - self.prefixes[prefix_hash] = prefix - return self.prefixes[prefix_hash] + if prefix_hash in self.prefixes: + prefix = self.prefixes[prefix_hash] + self.prefixes.move_to_end(prefix_hash) + return prefix + + # Prefix does not exist. Add created prefix to the pool and return it. + # Always, before adding anything to the pool, check the capacity constraints and + # remove the least recently used prefix if capacity constraints are violated. + prefix_num_blocks = prefix.get_num_blocks() + + while self._current_block_usage > 0 and prefix_num_blocks > self.max_capacity_in_blocks - self._current_block_usage: + _, candidate_prefix = self.prefixes.popitem(last=False) + self._candidates_to_deallocate.append(candidate_prefix) + self._current_block_usage -= candidate_prefix.get_num_blocks() + + self.prefixes[prefix_hash] = prefix + self._current_block_usage += prefix_num_blocks + return prefix + + def get_prefixes_to_deallocate(self) -> List[Prefix]: + """ + Returns a list of prefixes that are ready to be deallocated. + For a prefix to be deallocated, it must fulfill the following two conditions: + 1. It must have been allocated already. + 2. It must have a seq_ref_count of 0. + + Condition number 1 is not evident, but is necessary because of the following rare situation: + 1. 
Prefix A is created, added to the pool, and assigned to a sequence group S. + Sequence group S becomes part of the sequence groups waiting to be allocated. + 2. At some point in the future, while sequence group S is still waiting to be allocated, + prefix A is removed from the pool because of capacity constraints. + 3. If we remove prefix A from the self._candidates_to_deallocate list at this point, + we will end up with a memory leak because of the following situation: + 3.1. Sequence group S eventually gets allocated, together with prefix A, + which is no longer in any data structure in the pool. + 3.2. Prefix A memory will never be removed from the GPU, even if its seq_ref_count + reaches 0 in the future, because it is not in the pool anymore and that + means that the .get_prefixes_to_deallocate() function will not return it. + """ + indexes_to_remove = [ + i for i, prefix in enumerate(self._candidates_to_deallocate) + if prefix.seq_ref_count == 0 and prefix.allocated + ] + + # Mark the prefix as expired, so that if a sequence group still in the + # waiting list that shares this prefix tries to allocate it as a prefix, + # it will fail. + for i in indexes_to_remove: + prefix = self._candidates_to_deallocate[i] + prefix.expired = True + + # Popping needs to happen with the indexes_to_remove list in reverse order + # so that we don't get Index out of range errors + return [ + self._candidates_to_deallocate.pop(i) + for i in indexes_to_remove[::-1] + ] diff --git a/vllm/sequence.py b/vllm/sequence.py index d28627f47498..9632806cfe17 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -176,6 +176,10 @@ def append_token_id( self.data.append_token_id(token_id, logprobs[token_id]) def get_len(self) -> int: + """ + Returns the length of the sequence in number of tokens. This length + includes the prompt and the generated tokens up to that point.
+ """ return self.data.get_len() def get_prompt_len(self) -> int: @@ -255,7 +259,7 @@ def __init__( self.sampling_params = sampling_params self.arrival_time = arrival_time self.lora_request = lora_request - self.prefix: Optional[Prefix] = prefix + self._prefix: Optional[Prefix] = prefix self.prompt_logprobs: Optional[PromptLogprobs] = None @property @@ -264,6 +268,20 @@ def prompt(self) -> str: # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).prompt + @property + def prefix(self) -> Optional[Prefix]: + """ + If the prefix has been marked as expired, set the prefix to None. + The reason for this is that when a prefix has expired it means that it has + been moved out of the prefix pool, so is no longer valid, and it should + no longer be allocated as a prefix. + + Return the prefix. + """ + if self._prefix is not None and self._prefix.expired: + self._prefix = None + return self._prefix + @property def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt.