From 075939b288fbd66157f6666374c315a62e6300f8 Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Thu, 18 Jan 2024 09:15:49 -0500 Subject: [PATCH 01/18] Added max_capacity to prefix pool, together with tests --- tests/prefix_caching/test_prefix_pool.py | 61 ++++++++++++++++++++ vllm/prefix.py | 71 +++++++++++++++++++----- 2 files changed, 117 insertions(+), 15 deletions(-) create mode 100644 tests/prefix_caching/test_prefix_pool.py diff --git a/tests/prefix_caching/test_prefix_pool.py b/tests/prefix_caching/test_prefix_pool.py new file mode 100644 index 000000000000..3c7f3e5900dd --- /dev/null +++ b/tests/prefix_caching/test_prefix_pool.py @@ -0,0 +1,61 @@ +from vllm.prefix import PrefixPool + +import pytest + +@pytest.fixture +def no_max_capacity_prefix_pool() -> PrefixPool: + return PrefixPool(block_size=32) + + +def test_prefix_length_behaviours(no_max_capacity_prefix_pool: PrefixPool): + """ + This test checks that prefixes of length less than pool.block_size are not created and are not added to the pool. + It also checks that prefixes of length equal to or greater to pool.block_size are created and added to the pool. + """ + prefix_1 = no_max_capacity_prefix_pool.add_or_get_prefix(list(range(no_max_capacity_prefix_pool.block_size - 1))) + prefix_2 = no_max_capacity_prefix_pool.add_or_get_prefix(list(range(no_max_capacity_prefix_pool.block_size))) + prefix_3 = no_max_capacity_prefix_pool.add_or_get_prefix(list(range(no_max_capacity_prefix_pool.block_size * 2))) + assert prefix_1 is None + assert prefix_2 is not None + assert prefix_3 is not None + assert len(no_max_capacity_prefix_pool) == 2 + +def test_same_prefix_added_twice(no_max_capacity_prefix_pool: PrefixPool): + """ + Tests that when a prefix is added more than once to the pool, all subsequent additions + return the same prefix object that was created the first time. + """ + prefix_1 = no_max_capacity_prefix_pool.add_or_get_prefix(list(range(no_max_capacity_prefix_pool.block_size))) + prefix_2 = no_max_capacity_prefix_pool.add_or_get_prefix(list(range(no_max_capacity_prefix_pool.block_size))) + assert prefix_1 is prefix_2 + assert len(no_max_capacity_prefix_pool) == 1 + +def test_prefix_pool_max_capacity(): + max_capacity = 1 + max_capacity_prefix_pool = PrefixPool(block_size=32, max_capacity=max_capacity) + + # Tests that on the third insertion, new object is created because capacity limits reached, + # but that the newly created object is equal to the old object + prefix_1 = max_capacity_prefix_pool.add_or_get_prefix(list(range(max_capacity_prefix_pool.block_size))) + _ = max_capacity_prefix_pool.add_or_get_prefix(list(range(max_capacity_prefix_pool.block_size * 2))) + prefix_3 = max_capacity_prefix_pool.add_or_get_prefix(list(range(max_capacity_prefix_pool.block_size))) + assert prefix_1 is not prefix_3 + assert prefix_1 == prefix_3 + + # Tests that the max capacity remains the same + for i in range(10): + _ = max_capacity_prefix_pool.add_or_get_prefix(list(range(max_capacity_prefix_pool.block_size + i))) + assert len(max_capacity_prefix_pool) == max_capacity + + +def test_assertion_raised_with_invalid_max_capacity(): + with pytest.raises(AssertionError): + _ = PrefixPool(32, max_capacity=-1) + + with pytest.raises(AssertionError): + _ = PrefixPool(32, max_capacity=0) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) \ No newline at end of file diff --git a/vllm/prefix.py b/vllm/prefix.py index 5b6e8e4b92be..0425099c2aa7 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -1,4 +1,5 @@ -from typing import Dict, List, Sequence, Tuple, Optional +from typing import Any, Dict, List, Sequence, Tuple, Optional +from collections import OrderedDict from vllm.block import BlockTable @@ -18,7 +19,7 @@ class Prefix: def __init__( self, token_ids: Sequence[int], - block_size: int, + block_size: int ) -> None: self.token_ids = tuple(token_ids) self.block_size = block_size @@ -43,16 +44,20 @@ def get_length(self) -> int: def __hash__(self) -> int: return self.hash + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Prefix): + return False + return self.hash == other.hash def set_block_table(self, block_table: BlockTable) -> None: self.block_table = block_table.copy() class PrefixPool: - """Manages all the prompt prefixes. - - NOTE: This feature is experimental and may be replaced with automatic - prefix caching in the future. + """Manages all the prompt prefixes. If the max_capacity argument is not None, + the pool will act as a LRU cache and remove the least recently used prefix once + the capacity is reached. Args: block_size: The block size of the executed model. @@ -60,28 +65,64 @@ class PrefixPool: Attributes: prefixes: A list of all the prefixes. block_size: The block size of the executed model. + max_capacity: The maximum number of prefixes allowed in the pool at any given time. + The default value is None, which means there is no limit. It can only take positive + values. """ def __init__( self, block_size: int, + max_capacity: Optional[int] = None ) -> None: - # TODO(zhuohan): Add a capacity limit to the prefix pool. - self.prefixes: Dict[int, Prefix] = {} + self.prefixes: OrderedDict[int, Prefix] = OrderedDict() self.block_size = block_size - + + if max_capacity is not None: + # NOTE(to remove after consultation): I have also been thinking if we need to + # assert that max_capacity must be greater than or equal to the max allowed + # batch size. My own analysis has led me to believe that this is not necessary + # because even if a prefix is removed from the pool before it even gets computed, + # this will not stop the prefix from being computed. It only means that the prefix + # will not stay allocated for next cycles of computation and future requests + # with the prefix will have to recompute it. + assert max_capacity > 0, "max_capacity must be positive." + + self.max_capacity = max_capacity + + def __len__(self) -> int: + return len(self.prefixes) + def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: new_length = len(token_ids) // self.block_size * self.block_size return tuple(token_ids[:new_length]) - def add_or_get_prefix(self, token_ids: Sequence[int], - lora_int_id: int) -> Optional[Prefix]: + def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: + """ + Adds a prefix to the pool if it does not already exist. If it does exist, + it returns the existing prefix. If the pool is at max capacity, it removes + the least recently used prefix before adding the new prefix. + + Notice that if the length of token_ids is less than the block_size, no + prefix is created and None is returned. + """ token_ids = self._truncate_token_ids(token_ids) if len(token_ids) == 0: # Prefix is empty. return None + + # Check first if prefix exists, moving it to the end of the OrderedDict. + # so that the LRU policy is maintained. Return the existing prefix. prefix = Prefix(token_ids, self.block_size) - prefix_hash = hash((prefix, lora_int_id)) - if prefix_hash not in self.prefixes: - self.prefixes[prefix_hash] = prefix - return self.prefixes[prefix_hash] + prefix_hash = hash(prefix) + if prefix_hash in self.prefixes: + prefix = self.prefixes[prefix_hash] + self.prefixes.move_to_end(prefix_hash) + return prefix + + # Prefix does not exist. Add created prefix to the pool and return it. + # Always, before adding anything to the pool, checking the capacity constraints. + if len(self.prefixes) == self.max_capacity: + _ = self.prefixes.popitem(last=False) + self.prefixes[prefix_hash] = prefix + return prefix From ee90c30230af6cd454e6bc8a37ff7bc39b2e00bf Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Fri, 19 Jan 2024 17:10:00 -0500 Subject: [PATCH 02/18] Added a deallocation mechanism for prefixes --- tests/prefix_caching/test_prefix_pool.py | 35 ++++++++---- vllm/block.py | 6 ++ vllm/core/block_manager.py | 42 +++++++++++++- vllm/core/scheduler.py | 50 ++++++++++++++--- vllm/engine/llm_engine.py | 17 ++++-- vllm/prefix.py | 71 +++++++++++++++++------- vllm/sequence.py | 4 ++ 7 files changed, 177 insertions(+), 48 deletions(-) diff --git a/tests/prefix_caching/test_prefix_pool.py b/tests/prefix_caching/test_prefix_pool.py index 3c7f3e5900dd..653406e3ee8e 100644 --- a/tests/prefix_caching/test_prefix_pool.py +++ b/tests/prefix_caching/test_prefix_pool.py @@ -2,6 +2,7 @@ import pytest + @pytest.fixture def no_max_capacity_prefix_pool() -> PrefixPool: return PrefixPool(block_size=32) @@ -12,39 +13,51 @@ def test_prefix_length_behaviours(no_max_capacity_prefix_pool: PrefixPool): This test checks that prefixes of length less than pool.block_size are not created and are not added to the pool. It also checks that prefixes of length equal to or greater to pool.block_size are created and added to the pool. """ - prefix_1 = no_max_capacity_prefix_pool.add_or_get_prefix(list(range(no_max_capacity_prefix_pool.block_size - 1))) - prefix_2 = no_max_capacity_prefix_pool.add_or_get_prefix(list(range(no_max_capacity_prefix_pool.block_size))) - prefix_3 = no_max_capacity_prefix_pool.add_or_get_prefix(list(range(no_max_capacity_prefix_pool.block_size * 2))) + prefix_1 = no_max_capacity_prefix_pool.add_or_get_prefix( + list(range(no_max_capacity_prefix_pool.block_size - 1))) + prefix_2 = no_max_capacity_prefix_pool.add_or_get_prefix( + list(range(no_max_capacity_prefix_pool.block_size))) + prefix_3 = no_max_capacity_prefix_pool.add_or_get_prefix( + list(range(no_max_capacity_prefix_pool.block_size * 2))) assert prefix_1 is None assert prefix_2 is not None assert prefix_3 is not None assert len(no_max_capacity_prefix_pool) == 2 + def test_same_prefix_added_twice(no_max_capacity_prefix_pool: PrefixPool): """ Tests that when a prefix is added more than once to the pool, all subsequent additions return the same prefix object that was created the first time. """ - prefix_1 = no_max_capacity_prefix_pool.add_or_get_prefix(list(range(no_max_capacity_prefix_pool.block_size))) - prefix_2 = no_max_capacity_prefix_pool.add_or_get_prefix(list(range(no_max_capacity_prefix_pool.block_size))) + prefix_1 = no_max_capacity_prefix_pool.add_or_get_prefix( + list(range(no_max_capacity_prefix_pool.block_size))) + prefix_2 = no_max_capacity_prefix_pool.add_or_get_prefix( + list(range(no_max_capacity_prefix_pool.block_size))) assert prefix_1 is prefix_2 assert len(no_max_capacity_prefix_pool) == 1 + def test_prefix_pool_max_capacity(): max_capacity = 1 - max_capacity_prefix_pool = PrefixPool(block_size=32, max_capacity=max_capacity) + max_capacity_prefix_pool = PrefixPool(block_size=32, + max_capacity=max_capacity) # Tests that on the third insertion, new object is created because capacity limits reached, # but that the newly created object is equal to the old object - prefix_1 = max_capacity_prefix_pool.add_or_get_prefix(list(range(max_capacity_prefix_pool.block_size))) - _ = max_capacity_prefix_pool.add_or_get_prefix(list(range(max_capacity_prefix_pool.block_size * 2))) - prefix_3 = max_capacity_prefix_pool.add_or_get_prefix(list(range(max_capacity_prefix_pool.block_size))) + prefix_1 = max_capacity_prefix_pool.add_or_get_prefix( + list(range(max_capacity_prefix_pool.block_size))) + _ = max_capacity_prefix_pool.add_or_get_prefix( + list(range(max_capacity_prefix_pool.block_size * 2))) + prefix_3 = max_capacity_prefix_pool.add_or_get_prefix( + list(range(max_capacity_prefix_pool.block_size))) assert prefix_1 is not prefix_3 assert prefix_1 == prefix_3 # Tests that the max capacity remains the same for i in range(10): - _ = max_capacity_prefix_pool.add_or_get_prefix(list(range(max_capacity_prefix_pool.block_size + i))) + _ = max_capacity_prefix_pool.add_or_get_prefix( + list(range(max_capacity_prefix_pool.block_size + i))) assert len(max_capacity_prefix_pool) == max_capacity @@ -58,4 +71,4 @@ def test_assertion_raised_with_invalid_max_capacity(): if __name__ == "__main__": import pytest - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/vllm/block.py b/vllm/block.py index 5fe39ed47b2f..8a0c2d054abb 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -60,6 +60,12 @@ def __init__( self.block_number = block_number self.block_size = block_size + # Contains the number of sequences that share this block that + # are currently allocated in the same device as this block. + # Notice that prefix blocks will have an extra 1 added to this + # reference count to guarantee that prefix blocks are not deallocated + # by the standard way that the block manager uses to free the memory + # for blocks. self.ref_count = 0 def __repr__(self) -> str: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 7f91051f03ac..917d86b8ad53 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -5,6 +5,7 @@ from vllm.block import BlockTable, PhysicalTokenBlock from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device +from vllm.prefix import Prefix class BlockAllocator: @@ -47,6 +48,12 @@ def free(self, block: PhysicalTokenBlock) -> None: if block.ref_count == 0: self.free_blocks.append(block) + def force_free(self, block: PhysicalTokenBlock) -> None: + """Force free a block without checking its ref count. + Currently used to free prefix blocks, whose block.ref_count is going + to be 1 at this moment in time. Handle with care!""" + self.free_blocks.append(block) + def get_num_free_blocks(self) -> int: return len(self.free_blocks) @@ -133,6 +140,9 @@ def allocate(self, seq_group: SequenceGroup) -> None: num_prefix_blocks = 0 prefix = seq_group.prefix + # Update the reference counts to the prefix from all the seqs + # in this group. + prefix.seq_ref_count += seq_group.num_seqs() if prefix is not None and prefix.allocated: # Prefix has already been allocated. Use the existing block table. num_prompt_blocks -= prefix.get_num_blocks() @@ -155,6 +165,9 @@ def allocate(self, seq_group: SequenceGroup) -> None: # KV cache in this run. num_prefix_blocks = prefix.get_num_blocks() prefix_block_table = block_table[:num_prefix_blocks] + # The extra reference count increment on the prefix blocks + # guarantees that the prefix blocks will not be freed even + # if all sequences that share the prefix are finished. for block in prefix_block_table: block.ref_count += 1 prefix.set_block_table(prefix_block_table) @@ -202,11 +215,16 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: self.gpu_allocator.free(last_block) return last_block.block_number, new_block.block_number - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + def fork(self, + parent_seq: Sequence, + child_seq: Sequence, + prefix: Optional[Prefix] = None) -> None: # NOTE: fork does not allocate a new physical block. # Thus, it is always safe from OOM. src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.copy() + if prefix is not None: + prefix.seq_ref_count += 1 for block in src_block_table: block.ref_count += 1 @@ -223,12 +241,19 @@ def _get_physical_blocks( def can_swap_in(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) + seq_group_blocks = len(blocks) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) num_free_blocks = self.gpu_allocator.get_num_free_blocks() + + # If sequence has a prefix and prefix is already allocated, + # there is no need to count blocks for the prefix. + if seq_group.prefix is not None and seq_group.prefix.allocated: + seq_group_blocks -= seq_group.prefix.get_num_blocks() + # NOTE: Conservatively, we assume that every sequence will allocate # at least one free block right after the swap-in. # NOTE: This should match the logic in can_append_slot(). - num_required_blocks = len(blocks) + num_swapped_seqs + num_required_blocks = seq_group_blocks + num_swapped_seqs return num_free_blocks - num_required_blocks >= self.watermark_blocks def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: @@ -276,6 +301,8 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: block_table = self.block_tables[seq.seq_id] for gpu_block in block_table: + # TODO(jadielam) The `in` operation on a list (the block_table) might be slow + # if the list is long. Think of alternatives to speed this up. if (seq_group.prefix is not None and gpu_block in seq_group.prefix.block_table): # NOTE: We do not swap out the prefix blocks for now. @@ -306,14 +333,23 @@ def _free_block_table(self, block_table: BlockTable) -> None: else: self.cpu_allocator.free(block) - def free(self, seq: Sequence) -> None: + def free(self, seq: Sequence, prefix: Optional[Prefix] = None) -> None: if seq.seq_id not in self.block_tables: # Already freed or haven't been scheduled yet. return + if prefix is not None: + prefix.seq_ref_count -= 1 block_table = self.block_tables[seq.seq_id] self._free_block_table(block_table) del self.block_tables[seq.seq_id] + def free_prefix_blocks(self, prefix: Prefix) -> None: + # Safety check here, don't know if necessary. + assert prefix.allocated and prefix.seq_ref_count == 0 + block_table = prefix.block_table + for block in set(block_table): + self.gpu_allocator.force_free(block) + def reset(self) -> None: for block_table in self.block_tables.values(): self._free_block_table(block_table) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4fdf9ec341cf..d5e89c411060 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from vllm.prefix import PrefixPool +from vllm.prefix import PrefixPool, Prefix logger = init_logger(__name__) @@ -149,7 +149,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: if seq.is_finished(): continue seq.status = SequenceStatus.FINISHED_ABORTED - self.free_seq(seq) + self.free_seq(seq, aborted_group.prefix) def has_unfinished_seqs(self) -> bool: return self.waiting or self.running or self.swapped @@ -167,6 +167,22 @@ def _schedule(self) -> SchedulerOutputs: now = time.monotonic() # Join waiting sequences if possible. + # TODO(jadielam): There is an inefficiency here in the way we are scheduling + # sequences. We should first process swapped sequences, and then process waiting + # sequences (until resources allow of course). + # Currently when we process swapped sequences, we do not consider afterwards if + # there are enough resources to add some of the waiting sequences. + # In some situations (that is, when the number of swapped sequences is small) + # we might be leaving some resources unused because of that. + + # TODO(jadielam): Also, there seems to be a bigger problem with the current code: + # On this one I am not so sure, but I am making a note of it here to not forget. + # If there are no swapped sequences, we never consider if we have enough + # resources to add the next token on the currently running sequences. + # Question that is not clear to me regarding this: Do we mix and match and run + # requests in the prompt stage together with requests in the generation stage + # on an engine step? If the answer is NO, then this is not a problem, but + # if the answer is YES, then my concerns here might be valid. if not self.swapped: ignored_seq_groups: List[SequenceGroup] = [] scheduled: List[SequenceGroup] = [] @@ -388,16 +404,28 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs - def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: - self.block_manager.fork(parent_seq, child_seq) + def fork_seq(self, + parent_seq: Sequence, + child_seq: Sequence, + prefix: Optional[Prefix] = None) -> None: + self.block_manager.fork(parent_seq, child_seq, prefix) - def free_seq(self, seq: Sequence) -> None: - self.block_manager.free(seq) + def free_seq(self, seq: Sequence, prefix: Optional[Prefix] = None) -> None: + self.block_manager.free(seq, prefix) def free_finished_seq_groups(self) -> None: self.running = deque(seq_group for seq_group in self.running if not seq_group.is_finished()) + def free_old_prefixes(self): + """ + Deallocates the GPU memory of prefixes that have been moved to the list of candidates + for deallocation in the prefix pool and that are not currently being used by any sequence group. + """ + prefixes_to_free = self.prefix_pool.get_prefixes_to_free() + for prefix in prefixes_to_free: + self.block_manager.free_prefix(prefix) + def _allocate(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): @@ -423,6 +451,14 @@ def _preempt( blocks_to_swap_out: Dict[int, int], preemption_mode: Optional[PreemptionMode] = None, ) -> None: + """ + Preempts the given sequence group. PreemptionMode.RECOMPUTE can only + be used if the sequence group has a single sequence. + + Raises: + AssertionError if PreemptionMode.RECOMPUTE is used for a + sequence group with multiple sequences. + """ # If preemption mode is not specified, we determine the mode as follows: # We use recomputation by default since it incurs lower overhead than # swapping. However, when the sequence group has multiple sequences @@ -454,7 +490,7 @@ def _preempt_by_recompute( assert len(seqs) == 1 for seq in seqs: seq.status = SequenceStatus.WAITING - self.block_manager.free(seq) + self.block_manager.free(seq, seq_group.prefix) # NOTE: For FCFS, we insert the preempted sequence group to the front # of the waiting queue. self.waiting.appendleft(seq_group) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0d836a1fb13a..9120b419efde 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -565,7 +565,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # not be used in the future iterations. parent.status = SequenceStatus.FINISHED_ABORTED seq_group.remove(parent.seq_id) - self.scheduler.free_seq(parent) + self.scheduler.free_seq(parent, seq_group.prefix) continue # Fork the parent sequence if there are multiple child samples. for child_sample in child_samples[:-1]: @@ -594,7 +594,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq is not parent: seq_group.add(seq) if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) + self.scheduler.fork_seq(parent, seq, seq_group.prefix) # Free the finished and selected parent sequences' memory in block # manager. Keep them in the sequence group as candidate output. @@ -602,7 +602,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # old sequences. for seq, parent in child_seqs: if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) + self.scheduler.free_seq(seq, seq_group.prefix) return # Beam search case @@ -689,13 +689,13 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq is not parent: seq_group.add(seq) if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) + self.scheduler.fork_seq(parent, seq, seq_group.prefix) # Free the finished and selected parent sequences' memory in block # manager. Keep them in the sequence group as candidate output. for seq, parent in selected_child_seqs: if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) + self.scheduler.free_seq(seq, seq_group.prefix) # Remove the unselected parent sequences from the sequence group and # free their memory in block manager. @@ -704,7 +704,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Remove the parent sequence if it is not selected for next # iteration seq_group.remove(seq.seq_id) - self.scheduler.free_seq(seq) + self.scheduler.free_seq(seq, seq_group.prefix) def _process_model_outputs( self, output: SamplerOutput, @@ -732,6 +732,11 @@ def _process_model_outputs( and not seq_group.prefix.computed): seq_group.prefix.computed = True + # Remove prefixes that are due for removal because they are no longer + # being used by any sequence group and have been moved to the candidates + # to remove list in the PrefixPool + self.scheduler.free_old_prefixes() + if self.log_stats: # Log the system stats. self._log_system_stats(scheduler_outputs.prompt_run, diff --git a/vllm/prefix.py b/vllm/prefix.py index 0425099c2aa7..2e456a1c1eab 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Sequence, Tuple, Optional +from typing import Any, List, Sequence, Tuple, Optional from collections import OrderedDict from vllm.block import BlockTable @@ -16,11 +16,7 @@ class Prefix: block_size: The block size of the executed model. """ - def __init__( - self, - token_ids: Sequence[int], - block_size: int - ) -> None: + def __init__(self, token_ids: Sequence[int], block_size: int) -> None: self.token_ids = tuple(token_ids) self.block_size = block_size self.length = len(token_ids) @@ -29,6 +25,12 @@ def __init__( self.block_table: Optional[BlockTable] = None self.computed = False + # Contains a reference count of the number of sequences that share this + # prefix, regardless of whether they are swapped out or not. + # Must not be initialized to 1 at creation time because a prefix might be created + # and thrown away, or sequences sharing this prefix might never be allocated. + self.seq_ref_count = 0 + @property def allocated(self) -> bool: return self.block_table is not None @@ -44,7 +46,7 @@ def get_length(self) -> int: def __hash__(self) -> int: return self.hash - + def __eq__(self, other: Any) -> bool: if not isinstance(other, Prefix): return False @@ -70,29 +72,28 @@ class PrefixPool: values. """ - def __init__( - self, - block_size: int, - max_capacity: Optional[int] = None - ) -> None: + def __init__(self, + block_size: int, + max_capacity: Optional[int] = None) -> None: self.prefixes: OrderedDict[int, Prefix] = OrderedDict() self.block_size = block_size - + if max_capacity is not None: - # NOTE(to remove after consultation): I have also been thinking if we need to - # assert that max_capacity must be greater than or equal to the max allowed + # NOTE(to remove after consultation): I have also been thinking if we need to + # assert that max_capacity must be greater than or equal to the max allowed # batch size. My own analysis has led me to believe that this is not necessary # because even if a prefix is removed from the pool before it even gets computed, # this will not stop the prefix from being computed. It only means that the prefix # will not stay allocated for next cycles of computation and future requests # with the prefix will have to recompute it. assert max_capacity > 0, "max_capacity must be positive." - + self.max_capacity = max_capacity + self._candidates_to_free: List[Prefix] = [] def __len__(self) -> int: return len(self.prefixes) - + def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: new_length = len(token_ids) // self.block_size * self.block_size return tuple(token_ids[:new_length]) @@ -110,7 +111,7 @@ def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: if len(token_ids) == 0: # Prefix is empty. return None - + # Check first if prefix exists, moving it to the end of the OrderedDict. # so that the LRU policy is maintained. Return the existing prefix. prefix = Prefix(token_ids, self.block_size) @@ -119,10 +120,38 @@ def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: prefix = self.prefixes[prefix_hash] self.prefixes.move_to_end(prefix_hash) return prefix - + # Prefix does not exist. Add created prefix to the pool and return it. - # Always, before adding anything to the pool, checking the capacity constraints. + # Always, before adding anything to the pool, check the capacity constraints and + # remove the least recently used prefix if capacity constraints are violated. if len(self.prefixes) == self.max_capacity: - _ = self.prefixes.popitem(last=False) + _, candidate_prefix = self.prefixes.popitem(last=False) + self._candidates_to_free.append(candidate_prefix) self.prefixes[prefix_hash] = prefix return prefix + + def get_prefixes_to_free(self) -> List[Prefix]: + """ + Returns a list of prefixes that are ready to be deallocated. + For a prefix to be deallocated, it must fulfill the following two conditions: + 1. It must have been allocated already. + 2. It must have a seq_ref_count of 0. + + Condition number 1 is not evident, but is necessary because of the following rare situation: + 1. Prefix A is created, added to the pool, and assigned to a sequence group S. + Sequence group S becomes part of the sequence groups waiting to be allocated + 2. At some point in the future, while sequence group S is still waiting to be allocated, + prefix A is removed from the pool because of capacity constraints. + 3. If we remove prefix A from the self._candidates_to_free list at this point, + we will end up with a memory leak because of the following situation: + 3.1. Sequence group S eventually gets allocated, altogether with prefix A, + which is no longer in any data structure in the pool. + 3.2 Prefix A memory will never be removed from the GPU, even if its seq_ref_count + reaches 0 in the future, because it is not in the pool anymore and that + means that the .get_prefixes_to_free() function will not return it. + """ + indexes_to_remove = [ + i for i, prefix in enumerate(self._candidates_to_free) + if prefix.seq_ref_count == 0 and prefix.allocated + ] + return [self._candidates_to_free.pop(i) for i in indexes_to_remove] diff --git a/vllm/sequence.py b/vllm/sequence.py index d28627f47498..cd3f5e5a4450 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -176,6 +176,10 @@ def append_token_id( self.data.append_token_id(token_id, logprobs[token_id]) def get_len(self) -> int: + """ + Returns the lenght of the sequence in number of tokens. This length + includes the prompt and the generated tokens up to that point. + """ return self.data.get_len() def get_prompt_len(self) -> int: From dac2137d33abb08d83a71d46d6295853c3151915 Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Fri, 19 Jan 2024 18:36:22 -0500 Subject: [PATCH 03/18] Fixing error making tests cases to fail --- vllm/core/block_manager.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 917d86b8ad53..00f78d616793 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -140,15 +140,15 @@ def allocate(self, seq_group: SequenceGroup) -> None: num_prefix_blocks = 0 prefix = seq_group.prefix - # Update the reference counts to the prefix from all the seqs - # in this group. - prefix.seq_ref_count += seq_group.num_seqs() - if prefix is not None and prefix.allocated: - # Prefix has already been allocated. Use the existing block table. - num_prompt_blocks -= prefix.get_num_blocks() - for block in prefix.block_table: - block.ref_count += seq_group.num_seqs() - block_table.append(block) + if prefix is not None: + prefix.seq_ref_count += seq_group.num_seqs() + + if prefix.allocated: + # Prefix has already been allocated. Use the existing block table. + num_prompt_blocks -= prefix.get_num_blocks() + for block in prefix.block_table: + block.ref_count += seq_group.num_seqs() + block_table.append(block) for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None From eb0ed00ad179a5fa0d27eb0ca165b8ef4f5f3f2f Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Sat, 20 Jan 2024 10:51:58 -0500 Subject: [PATCH 04/18] Fixed incorrect function name. Changed status of prefix to not allocated after deallocating it --- vllm/core/block_manager.py | 1 + vllm/core/scheduler.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 00f78d616793..450166a8da31 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -349,6 +349,7 @@ def free_prefix_blocks(self, prefix: Prefix) -> None: block_table = prefix.block_table for block in set(block_table): self.gpu_allocator.force_free(block) + prefix.block_table = None def reset(self) -> None: for block_table in self.block_tables.values(): diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index d5e89c411060..30388b893ecc 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -424,7 +424,7 @@ def free_old_prefixes(self): """ prefixes_to_free = self.prefix_pool.get_prefixes_to_free() for prefix in prefixes_to_free: - self.block_manager.free_prefix(prefix) + self.block_manager.free_prefix_blocks(prefix) def _allocate(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) From 38526fbf0c212f5671689303c42f1d882724b59e Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Sat, 20 Jan 2024 21:19:50 -0500 Subject: [PATCH 05/18] Removed unnecessary comment --- vllm/prefix.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vllm/prefix.py b/vllm/prefix.py index 2e456a1c1eab..2f160d8ae831 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -79,13 +79,6 @@ def __init__(self, self.block_size = block_size if max_capacity is not None: - # NOTE(to remove after consultation): I have also been thinking if we need to - # assert that max_capacity must be greater than or equal to the max allowed - # batch size. My own analysis has led me to believe that this is not necessary - # because even if a prefix is removed from the pool before it even gets computed, - # this will not stop the prefix from being computed. It only means that the prefix - # will not stay allocated for next cycles of computation and future requests - # with the prefix will have to recompute it. assert max_capacity > 0, "max_capacity must be positive." self.max_capacity = max_capacity From f1508123f2ca32052eb7c31f3d97ab75b34ecd40 Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Mon, 22 Jan 2024 08:19:33 -0500 Subject: [PATCH 06/18] Added prefix_pool_max_capacity parameter to the pertinent constructors and factories --- vllm/config.py | 7 +++++++ vllm/core/scheduler.py | 5 ++++- vllm/engine/arg_utils.py | 10 +++++++++- vllm/entrypoints/llm.py | 2 ++ 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 197f20c1ec9a..efd6e08468fe 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -286,12 +286,14 @@ def __init__( swap_space: int, cache_dtype: str, sliding_window: Optional[int] = None, + prefix_pool_max_capacity: Optional[int] = None ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB self.cache_dtype = cache_dtype self.sliding_window = sliding_window + self.prefix_pool_max_capacity = prefix_pool_max_capacity self._verify_args() self._verify_cache_dtype() @@ -304,6 +306,11 @@ def _verify_args(self) -> None: raise ValueError( "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + if self.prefix_pool_max_capacity is not None: + if self.prefix_pool_max_capacity <= 0: + raise ValueError( + "prefix_pool_max_capacity must be positive. Got " + f"{self.prefix_pool_max_capacity}.") def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 30388b893ecc..5d7333c89cad 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -98,7 +98,10 @@ def __init__( sliding_window=self.cache_config.sliding_window) # Create the prefix pool to cache the prefixes. - self.prefix_pool = PrefixPool(self.cache_config.block_size) + self.prefix_pool = PrefixPool( + self.cache_config.block_size, + max_capacity=self.cache_config.prefix_pool_max_capacity + ) # Sequence groups in the WAITING state. self.waiting: Deque[SequenceGroup] = deque() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 231ce3321cdc..f05e7951e97f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -27,6 +27,7 @@ class EngineArgs: block_size: int = 16 swap_space: int = 4 # GiB gpu_memory_utilization: float = 0.90 + prefix_pool_max_capacity: Optional[int] = None max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_paddings: int = 256 @@ -179,6 +180,12 @@ def add_cli_args( help='the fraction of GPU memory to be used for ' 'the model executor, which can range from 0 to 1.' 'If unspecified, will use the default value of 0.9.') + parser.add_argument( + '--prefix-pool-max-capacity', + type=int, + default=EngineArgs.prefix_pool_max_capacity, + help='the maximum number of prefixes allowed in the prefix pool. If None, ' + 'there is no limit. It can only take positive values.') parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, @@ -279,7 +286,8 @@ def create_engine_configs( cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, - model_config.get_sliding_window()) + model_config.get_sliding_window(), + self.prefix_pool_max_capacity) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 614e6fa520c8..0ec53d3d2382 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -84,6 +84,7 @@ def __init__( enforce_eager: bool = False, max_context_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, + prefix_pool_max_capacity: Optional[int] = None, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: @@ -104,6 +105,7 @@ def __init__( enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, + prefix_pool_max_capacity=prefix_pool_max_capacity, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args(engine_args) From 2e66162754f5592e5ef3d6b0d36217373ce8358d Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Mon, 22 Jan 2024 08:22:53 -0500 Subject: [PATCH 07/18] Fixed formatting issues --- vllm/config.py | 28 ++++++++++++++++++++++++---- vllm/core/scheduler.py | 5 ++--- vllm/engine/arg_utils.py | 3 ++- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index efd6e08468fe..f033e7de8b4a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -306,11 +306,31 @@ def _verify_args(self) -> None: raise ValueError( "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") - if self.prefix_pool_max_capacity is not None: - if self.prefix_pool_max_capacity <= 0: + if self.prefix_pool_max_capacity is not None and self.prefix_pool_max_capacity <= 0: + raise ValueError("prefix_pool_max_capacity must be positive. Got " + f"{self.prefix_pool_max_capacity}.") + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype == "fp8_e5m2": + nvcc_cuda_version = get_nvcc_cuda_version() + if nvcc_cuda_version < Version("11.8"): raise ValueError( - "prefix_pool_max_capacity must be positive. Got " - f"{self.prefix_pool_max_capacity}.") + "FP8 is not supported when cuda version is lower than 11.8." + ) + device_name = torch.cuda.get_device_name() + if "AMD" in device_name: + raise NotImplementedError( + "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.") + logger.info( + "Using fp8_e5m2 data type to store kv cache. It reduces " + "the GPU memory footprint and boosts the performance. " + "But it may cause slight accuracy drop. " + "Currently we only support fp8 without scaling factors and " + "make e5m2 as a default format.") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5d7333c89cad..3758a7ea8429 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -99,9 +99,8 @@ def __init__( # Create the prefix pool to cache the prefixes. self.prefix_pool = PrefixPool( - self.cache_config.block_size, - max_capacity=self.cache_config.prefix_pool_max_capacity - ) + self.cache_config.block_size, + max_capacity=self.cache_config.prefix_pool_max_capacity) # Sequence groups in the WAITING state. self.waiting: Deque[SequenceGroup] = deque() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f05e7951e97f..88fab8588ecb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -184,7 +184,8 @@ def add_cli_args( '--prefix-pool-max-capacity', type=int, default=EngineArgs.prefix_pool_max_capacity, - help='the maximum number of prefixes allowed in the prefix pool. If None, ' + help= + 'the maximum number of prefixes allowed in the prefix pool. If None, ' 'there is no limit. It can only take positive values.') parser.add_argument('--max-num-batched-tokens', type=int, From fe0258f96fd05dfc40f6d340c6697bf2d752f1e3 Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Mon, 22 Jan 2024 09:25:19 -0500 Subject: [PATCH 08/18] Added test cases to test prefix pool with max capacity --- tests/prefix_caching/test_prefix_caching.py | 42 ++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 1e301bedfc21..db902ba74a3e 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -2,8 +2,12 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`. """ +from typing import Optional +from importlib import reload + import pytest +import vllm.model_executor.parallel_utils.parallel_state as parallel_state from vllm import LLM, SamplingParams prefix = ( @@ -20,12 +24,19 @@ @pytest.mark.parametrize("model", ["facebook/opt-125m"]) @pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("prefix_pool_max_capacity", [None, 1, 2]) def test_prefix_caching( example_prompts, model: str, max_tokens: int, + prefix_pool_max_capacity: Optional[int], ): - llm = LLM(model=model) + # IMPORTANT: If this line is removed from here, adding more than 1 item to + # any of the parametrization lists above causes all tests but the first one + # to fail with the message: "AssertionError: tensor model parallel group is + # already initialized." + reload(parallel_state) + llm = LLM(model=model, prefix_pool_max_capacity=prefix_pool_max_capacity) # -1 since the last token can change when concatenating prompts. prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 prompts = [prefix + prompt for prompt in example_prompts] @@ -39,3 +50,32 @@ def test_prefix_caching( assert (output_without_prefix.outputs[0].token_ids == output_with_prefix.outputs[0].token_ids) assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1 + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("prefix_pool_max_capacity", [1, 2, 4, 6]) +def test_prefix_caching_with_multiple_prefixes( + example_prompts, model: str, max_tokens: int, + prefix_pool_max_capacity: Optional[int]): + """ + Tests that the scheduler prefix pool size (length) does not go over the + maximum capacity at any moment in time. + """ + reload(parallel_state) + llm = LLM(model="facebook/opt-125m", + prefix_pool_max_capacity=prefix_pool_max_capacity) + # -1 since the last token can change when concatenating prompts. + + # Use 10 different prefixes: + for i in range(prefix_pool_max_capacity + 1): + new_prefix = str(i) + ' ' + prefix + prefix_pos = len(llm.llm_engine.tokenizer.encode(new_prefix)) - 1 + prompts = [new_prefix + prompt for prompt in example_prompts] + sampling_params = SamplingParams(temperature=0.0, + max_tokens=max_tokens) + _ = llm.generate(prompts, + sampling_params, + prefix_pos=[prefix_pos] * len(prompts)) + assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == min( + i + 1, prefix_pool_max_capacity) From f770a64373e4fee5e76d3e34c069ad5f96280169 Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Mon, 22 Jan 2024 11:57:57 -0500 Subject: [PATCH 09/18] Fixing Index out of range bug --- tests/prefix_caching/test_prefix_caching.py | 8 ++++++-- vllm/prefix.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index db902ba74a3e..4eb5bd1222ea 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -62,14 +62,18 @@ def test_prefix_caching_with_multiple_prefixes( Tests that the scheduler prefix pool size (length) does not go over the maximum capacity at any moment in time. """ + # IMPORTANT: If this line is removed from here, adding more than 1 item to + # any of the parametrization lists above causes all tests but the first one + # to fail with the message: "AssertionError: tensor model parallel group is + # already initialized." reload(parallel_state) llm = LLM(model="facebook/opt-125m", prefix_pool_max_capacity=prefix_pool_max_capacity) - # -1 since the last token can change when concatenating prompts. - + # Use 10 different prefixes: for i in range(prefix_pool_max_capacity + 1): new_prefix = str(i) + ' ' + prefix + # -1 since the last token can change when concatenating prompts. prefix_pos = len(llm.llm_engine.tokenizer.encode(new_prefix)) - 1 prompts = [new_prefix + prompt for prompt in example_prompts] sampling_params = SamplingParams(temperature=0.0, diff --git a/vllm/prefix.py b/vllm/prefix.py index 2f160d8ae831..f8b8d1c51329 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -147,4 +147,6 @@ def get_prefixes_to_free(self) -> List[Prefix]: i for i, prefix in enumerate(self._candidates_to_free) if prefix.seq_ref_count == 0 and prefix.allocated ] - return [self._candidates_to_free.pop(i) for i in indexes_to_remove] + # Popping needs to happen with the indexes_to_remove list in reverse order + # so that we don't get Index out of range errors + return [self._candidates_to_free.pop(i) for i in indexes_to_remove[::-1]] From 58634a77df1118d8eff292d819853a08638c1f22 Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Mon, 22 Jan 2024 11:59:29 -0500 Subject: [PATCH 10/18] Fixing formatting --- tests/prefix_caching/test_prefix_caching.py | 2 +- vllm/prefix.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 4eb5bd1222ea..d7380b88e4c2 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -69,7 +69,7 @@ def test_prefix_caching_with_multiple_prefixes( reload(parallel_state) llm = LLM(model="facebook/opt-125m", prefix_pool_max_capacity=prefix_pool_max_capacity) - + # Use 10 different prefixes: for i in range(prefix_pool_max_capacity + 1): new_prefix = str(i) + ' ' + prefix diff --git a/vllm/prefix.py b/vllm/prefix.py index f8b8d1c51329..2c6e453c6e88 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -149,4 +149,6 @@ def get_prefixes_to_free(self) -> List[Prefix]: ] # Popping needs to happen with the indexes_to_remove list in reverse order # so that we don't get Index out of range errors - return [self._candidates_to_free.pop(i) for i in indexes_to_remove[::-1]] + return [ + self._candidates_to_free.pop(i) for i in indexes_to_remove[::-1] + ] From 83cf2b07293e9cbe9b3aa607fb397a488c12ac3b Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Wed, 24 Jan 2024 20:43:01 -0500 Subject: [PATCH 11/18] Capacity of prefix pool now measured in terms of gpu blocks --- tests/prefix_caching/test_prefix_caching.py | 31 +++++---- tests/prefix_caching/test_prefix_pool.py | 76 ++++++++++++++++++--- vllm/config.py | 15 ++-- vllm/core/scheduler.py | 5 +- vllm/engine/arg_utils.py | 12 ++-- vllm/entrypoints/llm.py | 4 +- vllm/prefix.py | 65 +++++++++++++----- 7 files changed, 154 insertions(+), 54 deletions(-) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index d7380b88e4c2..0ecba11344b1 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -24,19 +24,19 @@ @pytest.mark.parametrize("model", ["facebook/opt-125m"]) @pytest.mark.parametrize("max_tokens", [16]) -@pytest.mark.parametrize("prefix_pool_max_capacity", [None, 1, 2]) +@pytest.mark.parametrize("prefix_pool_memory_utilization", [0, 0.1, 0.2]) def test_prefix_caching( example_prompts, model: str, max_tokens: int, - prefix_pool_max_capacity: Optional[int], + prefix_pool_memory_utilization: float, ): # IMPORTANT: If this line is removed from here, adding more than 1 item to # any of the parametrization lists above causes all tests but the first one # to fail with the message: "AssertionError: tensor model parallel group is # already initialized." reload(parallel_state) - llm = LLM(model=model, prefix_pool_max_capacity=prefix_pool_max_capacity) + llm = LLM(model=model, prefix_pool_memory_utilization=prefix_pool_memory_utilization) # -1 since the last token can change when concatenating prompts. prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 prompts = [prefix + prompt for prompt in example_prompts] @@ -49,15 +49,15 @@ def test_prefix_caching( outputs_without_prefix, outputs_with_prefix): assert (output_without_prefix.outputs[0].token_ids == output_with_prefix.outputs[0].token_ids) - assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1 - + if prefix_pool_memory_utilization == 0: + assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 0 @pytest.mark.parametrize("model", ["facebook/opt-125m"]) @pytest.mark.parametrize("max_tokens", [16]) -@pytest.mark.parametrize("prefix_pool_max_capacity", [1, 2, 4, 6]) +@pytest.mark.parametrize("prefix_pool_memory_utilization", [0, 0.1, 0.3, 0.4]) def test_prefix_caching_with_multiple_prefixes( example_prompts, model: str, max_tokens: int, - prefix_pool_max_capacity: Optional[int]): + prefix_pool_memory_utilization: float): """ Tests that the scheduler prefix pool size (length) does not go over the maximum capacity at any moment in time. @@ -67,19 +67,24 @@ def test_prefix_caching_with_multiple_prefixes( # to fail with the message: "AssertionError: tensor model parallel group is # already initialized." reload(parallel_state) - llm = LLM(model="facebook/opt-125m", - prefix_pool_max_capacity=prefix_pool_max_capacity) + llm = LLM(model=model, + prefix_pool_memory_utilization=prefix_pool_memory_utilization) # Use 10 different prefixes: - for i in range(prefix_pool_max_capacity + 1): + for i in range(10): new_prefix = str(i) + ' ' + prefix # -1 since the last token can change when concatenating prompts. prefix_pos = len(llm.llm_engine.tokenizer.encode(new_prefix)) - 1 prompts = [new_prefix + prompt for prompt in example_prompts] sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - _ = llm.generate(prompts, + outputs_with_prefix = llm.generate(prompts, sampling_params, prefix_pos=[prefix_pos] * len(prompts)) - assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == min( - i + 1, prefix_pool_max_capacity) + outputs_without_prefix = llm.generate(prompts, + sampling_params) + for output_without_prefix, output_with_prefix in zip( + outputs_without_prefix, outputs_with_prefix): + assert (output_without_prefix.outputs[0].token_ids == + output_with_prefix.outputs[0].token_ids) + diff --git a/tests/prefix_caching/test_prefix_pool.py b/tests/prefix_caching/test_prefix_pool.py index 653406e3ee8e..77e69d8cf4dd 100644 --- a/tests/prefix_caching/test_prefix_pool.py +++ b/tests/prefix_caching/test_prefix_pool.py @@ -5,7 +5,7 @@ @pytest.fixture def no_max_capacity_prefix_pool() -> PrefixPool: - return PrefixPool(block_size=32) + return PrefixPool(block_size=32, max_capacity_in_blocks=float('inf')) def test_prefix_length_behaviours(no_max_capacity_prefix_pool: PrefixPool): @@ -39,9 +39,12 @@ def test_same_prefix_added_twice(no_max_capacity_prefix_pool: PrefixPool): def test_prefix_pool_max_capacity(): - max_capacity = 1 + """ + Tests that the pool is evicting prefixes when it reaches max capacity. + """ + max_capacity_in_blocks = 2 max_capacity_prefix_pool = PrefixPool(block_size=32, - max_capacity=max_capacity) + max_capacity_in_blocks=max_capacity_in_blocks) # Tests that on the third insertion, new object is created because capacity limits reached, # but that the newly created object is equal to the old object @@ -53,20 +56,73 @@ def test_prefix_pool_max_capacity(): list(range(max_capacity_prefix_pool.block_size))) assert prefix_1 is not prefix_3 assert prefix_1 == prefix_3 + + assert len(max_capacity_prefix_pool) == 1 + assert max_capacity_prefix_pool.current_block_usage == 1 - # Tests that the max capacity remains the same +def test_current_block_usage(): + """ + Tests that the current_block_usage property remains the same thorough the + lifetime of the pool when adding prefixes that are always the same length equal + to the max capacity. + """ + max_capacity_in_blocks = 2 + max_capacity_prefix_pool = PrefixPool(block_size=32, + max_capacity_in_blocks=max_capacity_in_blocks) + for i in range(10): _ = max_capacity_prefix_pool.add_or_get_prefix( - list(range(max_capacity_prefix_pool.block_size + i))) - assert len(max_capacity_prefix_pool) == max_capacity + list(range(max_capacity_prefix_pool.block_size * max_capacity_in_blocks))) + assert len(max_capacity_prefix_pool) == 1 + assert max_capacity_prefix_pool.current_block_usage == max_capacity_in_blocks +def test_prefix_truncation_1(): + """ + Tests that prefix is truncated if it exceeds the max capacity. + """ + prefix_pool = PrefixPool(block_size=1, max_capacity_in_blocks=2) + prefix = prefix_pool.add_or_get_prefix([1, 2, 3, 4]) + assert prefix.token_ids == (1, 2) -def test_assertion_raised_with_invalid_max_capacity(): - with pytest.raises(AssertionError): - _ = PrefixPool(32, max_capacity=-1) +def test_prefix_truncation_2(): + """ + Testing truncation on non-block boundary + """ + prefix_pool = PrefixPool(block_size=2, max_capacity_in_blocks=3) + prefix = prefix_pool.add_or_get_prefix([1, 2, 3, 4, 5]) + assert prefix.token_ids == (1, 2, 3, 4) +def test_prefix_truncation_3(): + """ + Tests truncation because of both max capacity exceeded and no block boundary. + """ + prefix_pool = PrefixPool(block_size=2, max_capacity_in_blocks=2) + prefix = prefix_pool.add_or_get_prefix([1, 2, 3, 4, 5]) + assert prefix.token_ids == (1, 2, 3, 4) + +def test_none_prefix_returned_1(): + """ + Tests that when the max capacity is zero, no prefix is created and None is returned. + """ + prefix_pool = PrefixPool(block_size=32, max_capacity_in_blocks=0) + prefix = prefix_pool.add_or_get_prefix( + list(range(prefix_pool.block_size))) + assert prefix is None + assert len(prefix_pool) == 0 + +def test_none_prefix_returned_2(): + """ + Tests that when prefix length is less than block size, a None prefix is returned. + """ + prefix_pool = PrefixPool(block_size=32, max_capacity_in_blocks=2) + prefix = prefix_pool.add_or_get_prefix( + list(range(prefix_pool.block_size - 1))) + assert prefix is None + assert len(prefix_pool) == 0 + +def test_assertion_raised_with_invalid_max_capacity(): with pytest.raises(AssertionError): - _ = PrefixPool(32, max_capacity=0) + _ = PrefixPool(32, max_capacity_in_blocks=-1) if __name__ == "__main__": diff --git a/vllm/config.py b/vllm/config.py index f033e7de8b4a..03e0f7fdcdbf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -286,14 +286,14 @@ def __init__( swap_space: int, cache_dtype: str, sliding_window: Optional[int] = None, - prefix_pool_max_capacity: Optional[int] = None + prefix_pool_memory_utilization: float = 0 ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB self.cache_dtype = cache_dtype self.sliding_window = sliding_window - self.prefix_pool_max_capacity = prefix_pool_max_capacity + self.prefix_pool_memory_utilization = prefix_pool_memory_utilization self._verify_args() self._verify_cache_dtype() @@ -306,9 +306,14 @@ def _verify_args(self) -> None: raise ValueError( "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") - if self.prefix_pool_max_capacity is not None and self.prefix_pool_max_capacity <= 0: - raise ValueError("prefix_pool_max_capacity must be positive. Got " - f"{self.prefix_pool_max_capacity}.") + if self.prefix_pool_memory_utilization < 0: + raise ValueError("prefix_pool_memory_utilization must be non negative. " + f"{self.prefix_pool_memory_utilization}.") + if self.prefix_pool_memory_utilization > self.gpu_memory_utilization: + raise ValueError( + "prefix_pool_memory_utilization must be less than or equal to " + "gpu_memory_utilization." + ) def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 3758a7ea8429..99c118be6f5a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -100,7 +100,8 @@ def __init__( # Create the prefix pool to cache the prefixes. self.prefix_pool = PrefixPool( self.cache_config.block_size, - max_capacity=self.cache_config.prefix_pool_max_capacity) + max_capacity_in_blocks= \ + int(self.cache_config.prefix_pool_memory_utilization * self.cache_config.num_gpu_blocks)) # Sequence groups in the WAITING state. self.waiting: Deque[SequenceGroup] = deque() @@ -424,7 +425,7 @@ def free_old_prefixes(self): Deallocates the GPU memory of prefixes that have been moved to the list of candidates for deallocation in the prefix pool and that are not currently being used by any sequence group. """ - prefixes_to_free = self.prefix_pool.get_prefixes_to_free() + prefixes_to_free = self.prefix_pool.get_prefixes_to_deallocate() for prefix in prefixes_to_free: self.block_manager.free_prefix_blocks(prefix) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 88fab8588ecb..f47b3793f846 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -27,7 +27,7 @@ class EngineArgs: block_size: int = 16 swap_space: int = 4 # GiB gpu_memory_utilization: float = 0.90 - prefix_pool_max_capacity: Optional[int] = None + prefix_pool_memory_utilization: float = 0.0 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_paddings: int = 256 @@ -183,10 +183,10 @@ def add_cli_args( parser.add_argument( '--prefix-pool-max-capacity', type=int, - default=EngineArgs.prefix_pool_max_capacity, - help= - 'the maximum number of prefixes allowed in the prefix pool. If None, ' - 'there is no limit. It can only take positive values.') + default=EngineArgs.prefix_pool_memory_utilization, + help='the fraction of GPU memory to be used by prefixes allocated and ' + 'present in the prefix pool. If 0, no prefix cache is used. It cannot be ' + 'larger than gpu_memory_utilization. If unspecified, will use the default of 0.') parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, @@ -288,7 +288,7 @@ def create_engine_configs( self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window(), - self.prefix_pool_max_capacity) + self.prefix_pool_memory_utilization) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0ec53d3d2382..fbef31c7c0c3 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -84,7 +84,7 @@ def __init__( enforce_eager: bool = False, max_context_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, - prefix_pool_max_capacity: Optional[int] = None, + prefix_pool_memory_utilization: float = 0, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: @@ -105,7 +105,7 @@ def __init__( enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, - prefix_pool_max_capacity=prefix_pool_max_capacity, + prefix_pool_memory_utilization=prefix_pool_memory_utilization, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args(engine_args) diff --git a/vllm/prefix.py b/vllm/prefix.py index 2c6e453c6e88..f705c2c67ca6 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -67,28 +67,43 @@ class PrefixPool: Attributes: prefixes: A list of all the prefixes. block_size: The block size of the executed model. - max_capacity: The maximum number of prefixes allowed in the pool at any given time. - The default value is None, which means there is no limit. It can only take positive - values. + max_capacity_in_blocks: The maximum number of blocks that can be used for prefixes in the pool + at any given time. The default value is 0, which effectively means there is no prefix cache. + If the value is float('inf'), is means that the capacity of the pool is unbounded (not recommended). """ def __init__(self, block_size: int, - max_capacity: Optional[int] = None) -> None: + max_capacity_in_blocks: int | float = 0) -> None: self.prefixes: OrderedDict[int, Prefix] = OrderedDict() self.block_size = block_size + self._current_block_usage = 0 + + self.max_allowed_prefix_length = float('inf') + if max_capacity_in_blocks < float('inf'): + assert max_capacity_in_blocks >= 0, "max_capacity must be non-negative." + self.max_allowed_prefix_length = self.block_size * max_capacity_in_blocks - if max_capacity is not None: - assert max_capacity > 0, "max_capacity must be positive." - - self.max_capacity = max_capacity - self._candidates_to_free: List[Prefix] = [] + self.max_capacity_in_blocks = max_capacity_in_blocks + + self._candidates_to_deallocate: List[Prefix] = [] def __len__(self) -> int: + """ + Returns the number of prefixes in the pool. + """ return len(self.prefixes) + + @property + def current_block_usage(self) -> int: + """ + Returns the number of blocks currently used by the pool. + """ + return self._current_block_usage def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: new_length = len(token_ids) // self.block_size * self.block_size + new_length = min(new_length, self.max_allowed_prefix_length) return tuple(token_ids[:new_length]) def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: @@ -97,14 +112,27 @@ def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: it returns the existing prefix. If the pool is at max capacity, it removes the least recently used prefix before adding the new prefix. - Notice that if the length of token_ids is less than the block_size, no + There are two situations when None is returned: + 1. If the length of the token_ids of the prefix is less than the block_size, no prefix is created and None is returned. + 2. If the max_capacity of the pool is 0, then no prefix is created and None is returned. + + There is also two situations where the prefix is shortened to fit block boundaries: + 1. If the length of the token_ids of the prefix is not a multiple of the block_size. + 2. If the number of blocks needed to allocate the prefix exceeds the max_capacity of the pool, + the prefix is shortened to fit the max_capacity. Notice that this second occurence happens once + we have already attempted all other recourses to be able to allocate the prefix on its entirity, such + as evicting older prefixes from the pool. """ + if self.max_capacity_in_blocks == 0: + # Prefix cache is disabled. + return None + token_ids = self._truncate_token_ids(token_ids) if len(token_ids) == 0: # Prefix is empty. return None - + # Check first if prefix exists, moving it to the end of the OrderedDict. # so that the LRU policy is maintained. Return the existing prefix. prefix = Prefix(token_ids, self.block_size) @@ -117,13 +145,18 @@ def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: # Prefix does not exist. Add created prefix to the pool and return it. # Always, before adding anything to the pool, check the capacity constraints and # remove the least recently used prefix if capacity constraints are violated. - if len(self.prefixes) == self.max_capacity: + prefix_num_blocks = prefix.get_num_blocks() + + while self._current_block_usage > 0 and prefix_num_blocks > self.max_capacity_in_blocks - self._current_block_usage: _, candidate_prefix = self.prefixes.popitem(last=False) - self._candidates_to_free.append(candidate_prefix) + self._candidates_to_deallocate.append(candidate_prefix) + self._current_block_usage -= candidate_prefix.get_num_blocks() + self.prefixes[prefix_hash] = prefix + self._current_block_usage += prefix_num_blocks return prefix - def get_prefixes_to_free(self) -> List[Prefix]: + def get_prefixes_to_deallocate(self) -> List[Prefix]: """ Returns a list of prefixes that are ready to be deallocated. For a prefix to be deallocated, it must fulfill the following two conditions: @@ -144,11 +177,11 @@ def get_prefixes_to_free(self) -> List[Prefix]: means that the .get_prefixes_to_free() function will not return it. """ indexes_to_remove = [ - i for i, prefix in enumerate(self._candidates_to_free) + i for i, prefix in enumerate(self._candidates_to_deallocate) if prefix.seq_ref_count == 0 and prefix.allocated ] # Popping needs to happen with the indexes_to_remove list in reverse order # so that we don't get Index out of range errors return [ - self._candidates_to_free.pop(i) for i in indexes_to_remove[::-1] + self._candidates_to_deallocate.pop(i) for i in indexes_to_remove[::-1] ] From 880c33ffb8a4e7e2badadb92b3fce0ecbeee0a4d Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Wed, 24 Jan 2024 20:45:52 -0500 Subject: [PATCH 12/18] Fixing formatting --- tests/prefix_caching/test_prefix_caching.py | 18 ++++++------- tests/prefix_caching/test_prefix_pool.py | 28 +++++++++++++-------- vllm/config.py | 8 +++--- vllm/engine/arg_utils.py | 6 +++-- vllm/prefix.py | 15 +++++------ 5 files changed, 43 insertions(+), 32 deletions(-) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 0ecba11344b1..95a68ca6537b 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -2,7 +2,6 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`. """ -from typing import Optional from importlib import reload import pytest @@ -36,7 +35,8 @@ def test_prefix_caching( # to fail with the message: "AssertionError: tensor model parallel group is # already initialized." reload(parallel_state) - llm = LLM(model=model, prefix_pool_memory_utilization=prefix_pool_memory_utilization) + llm = LLM(model=model, + prefix_pool_memory_utilization=prefix_pool_memory_utilization) # -1 since the last token can change when concatenating prompts. prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 prompts = [prefix + prompt for prompt in example_prompts] @@ -52,6 +52,7 @@ def test_prefix_caching( if prefix_pool_memory_utilization == 0: assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 0 + @pytest.mark.parametrize("model", ["facebook/opt-125m"]) @pytest.mark.parametrize("max_tokens", [16]) @pytest.mark.parametrize("prefix_pool_memory_utilization", [0, 0.1, 0.3, 0.4]) @@ -79,12 +80,11 @@ def test_prefix_caching_with_multiple_prefixes( sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs_with_prefix = llm.generate(prompts, - sampling_params, - prefix_pos=[prefix_pos] * len(prompts)) - outputs_without_prefix = llm.generate(prompts, - sampling_params) + sampling_params, + prefix_pos=[prefix_pos] * + len(prompts)) + outputs_without_prefix = llm.generate(prompts, sampling_params) for output_without_prefix, output_with_prefix in zip( - outputs_without_prefix, outputs_with_prefix): + outputs_without_prefix, outputs_with_prefix): assert (output_without_prefix.outputs[0].token_ids == - output_with_prefix.outputs[0].token_ids) - + output_with_prefix.outputs[0].token_ids) diff --git a/tests/prefix_caching/test_prefix_pool.py b/tests/prefix_caching/test_prefix_pool.py index 77e69d8cf4dd..2a4b6709f29e 100644 --- a/tests/prefix_caching/test_prefix_pool.py +++ b/tests/prefix_caching/test_prefix_pool.py @@ -43,8 +43,8 @@ def test_prefix_pool_max_capacity(): Tests that the pool is evicting prefixes when it reaches max capacity. """ max_capacity_in_blocks = 2 - max_capacity_prefix_pool = PrefixPool(block_size=32, - max_capacity_in_blocks=max_capacity_in_blocks) + max_capacity_prefix_pool = PrefixPool( + block_size=32, max_capacity_in_blocks=max_capacity_in_blocks) # Tests that on the third insertion, new object is created because capacity limits reached, # but that the newly created object is equal to the old object @@ -56,10 +56,11 @@ def test_prefix_pool_max_capacity(): list(range(max_capacity_prefix_pool.block_size))) assert prefix_1 is not prefix_3 assert prefix_1 == prefix_3 - + assert len(max_capacity_prefix_pool) == 1 assert max_capacity_prefix_pool.current_block_usage == 1 + def test_current_block_usage(): """ Tests that the current_block_usage property remains the same thorough the @@ -67,15 +68,18 @@ def test_current_block_usage(): to the max capacity. """ max_capacity_in_blocks = 2 - max_capacity_prefix_pool = PrefixPool(block_size=32, - max_capacity_in_blocks=max_capacity_in_blocks) - - for i in range(10): + max_capacity_prefix_pool = PrefixPool( + block_size=32, max_capacity_in_blocks=max_capacity_in_blocks) + + for _ in range(10): _ = max_capacity_prefix_pool.add_or_get_prefix( - list(range(max_capacity_prefix_pool.block_size * max_capacity_in_blocks))) + list( + range(max_capacity_prefix_pool.block_size * + max_capacity_in_blocks))) assert len(max_capacity_prefix_pool) == 1 assert max_capacity_prefix_pool.current_block_usage == max_capacity_in_blocks + def test_prefix_truncation_1(): """ Tests that prefix is truncated if it exceeds the max capacity. @@ -84,6 +88,7 @@ def test_prefix_truncation_1(): prefix = prefix_pool.add_or_get_prefix([1, 2, 3, 4]) assert prefix.token_ids == (1, 2) + def test_prefix_truncation_2(): """ Testing truncation on non-block boundary @@ -92,6 +97,7 @@ def test_prefix_truncation_2(): prefix = prefix_pool.add_or_get_prefix([1, 2, 3, 4, 5]) assert prefix.token_ids == (1, 2, 3, 4) + def test_prefix_truncation_3(): """ Tests truncation because of both max capacity exceeded and no block boundary. @@ -100,16 +106,17 @@ def test_prefix_truncation_3(): prefix = prefix_pool.add_or_get_prefix([1, 2, 3, 4, 5]) assert prefix.token_ids == (1, 2, 3, 4) + def test_none_prefix_returned_1(): """ Tests that when the max capacity is zero, no prefix is created and None is returned. """ prefix_pool = PrefixPool(block_size=32, max_capacity_in_blocks=0) - prefix = prefix_pool.add_or_get_prefix( - list(range(prefix_pool.block_size))) + prefix = prefix_pool.add_or_get_prefix(list(range(prefix_pool.block_size))) assert prefix is None assert len(prefix_pool) == 0 + def test_none_prefix_returned_2(): """ Tests that when prefix length is less than block size, a None prefix is returned. @@ -120,6 +127,7 @@ def test_none_prefix_returned_2(): assert prefix is None assert len(prefix_pool) == 0 + def test_assertion_raised_with_invalid_max_capacity(): with pytest.raises(AssertionError): _ = PrefixPool(32, max_capacity_in_blocks=-1) diff --git a/vllm/config.py b/vllm/config.py index 03e0f7fdcdbf..af40e82a62dd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -307,13 +307,13 @@ def _verify_args(self) -> None: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") if self.prefix_pool_memory_utilization < 0: - raise ValueError("prefix_pool_memory_utilization must be non negative. " - f"{self.prefix_pool_memory_utilization}.") + raise ValueError( + "prefix_pool_memory_utilization must be non negative. " + f"{self.prefix_pool_memory_utilization}.") if self.prefix_pool_memory_utilization > self.gpu_memory_utilization: raise ValueError( "prefix_pool_memory_utilization must be less than or equal to " - "gpu_memory_utilization." - ) + "gpu_memory_utilization.") def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f47b3793f846..9e2414ad481d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -184,9 +184,11 @@ def add_cli_args( '--prefix-pool-max-capacity', type=int, default=EngineArgs.prefix_pool_memory_utilization, - help='the fraction of GPU memory to be used by prefixes allocated and ' + help= + 'the fraction of GPU memory to be used by prefixes allocated and ' 'present in the prefix pool. If 0, no prefix cache is used. It cannot be ' - 'larger than gpu_memory_utilization. If unspecified, will use the default of 0.') + 'larger than gpu_memory_utilization. If unspecified, will use the default of 0.' + ) parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, diff --git a/vllm/prefix.py b/vllm/prefix.py index f705c2c67ca6..1f254c456745 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -78,14 +78,14 @@ def __init__(self, self.prefixes: OrderedDict[int, Prefix] = OrderedDict() self.block_size = block_size self._current_block_usage = 0 - + self.max_allowed_prefix_length = float('inf') if max_capacity_in_blocks < float('inf'): assert max_capacity_in_blocks >= 0, "max_capacity must be non-negative." self.max_allowed_prefix_length = self.block_size * max_capacity_in_blocks self.max_capacity_in_blocks = max_capacity_in_blocks - + self._candidates_to_deallocate: List[Prefix] = [] def __len__(self) -> int: @@ -93,7 +93,7 @@ def __len__(self) -> int: Returns the number of prefixes in the pool. """ return len(self.prefixes) - + @property def current_block_usage(self) -> int: """ @@ -127,12 +127,12 @@ def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: if self.max_capacity_in_blocks == 0: # Prefix cache is disabled. return None - + token_ids = self._truncate_token_ids(token_ids) if len(token_ids) == 0: # Prefix is empty. return None - + # Check first if prefix exists, moving it to the end of the OrderedDict. # so that the LRU policy is maintained. Return the existing prefix. prefix = Prefix(token_ids, self.block_size) @@ -151,7 +151,7 @@ def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: _, candidate_prefix = self.prefixes.popitem(last=False) self._candidates_to_deallocate.append(candidate_prefix) self._current_block_usage -= candidate_prefix.get_num_blocks() - + self.prefixes[prefix_hash] = prefix self._current_block_usage += prefix_num_blocks return prefix @@ -183,5 +183,6 @@ def get_prefixes_to_deallocate(self) -> List[Prefix]: # Popping needs to happen with the indexes_to_remove list in reverse order # so that we don't get Index out of range errors return [ - self._candidates_to_deallocate.pop(i) for i in indexes_to_remove[::-1] + self._candidates_to_deallocate.pop(i) + for i in indexes_to_remove[::-1] ] From 7348f5500267b682be8517a5c2d10f5d06cbc4d0 Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Wed, 24 Jan 2024 21:37:13 -0500 Subject: [PATCH 13/18] Fixed arg_utils typos --- vllm/engine/arg_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9e2414ad481d..2e57d21f7ad4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -181,8 +181,8 @@ def add_cli_args( 'the model executor, which can range from 0 to 1.' 'If unspecified, will use the default value of 0.9.') parser.add_argument( - '--prefix-pool-max-capacity', - type=int, + '--prefix-pool-memory-utilization', + type=float, default=EngineArgs.prefix_pool_memory_utilization, help= 'the fraction of GPU memory to be used by prefixes allocated and ' From 4e6b020d271d7d5912fab489e05773d5060440d7 Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Fri, 26 Jan 2024 11:17:36 -0500 Subject: [PATCH 14/18] Solves memory leak issue with prefixes in waiting list --- vllm/prefix.py | 9 +++++++++ vllm/sequence.py | 16 +++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/vllm/prefix.py b/vllm/prefix.py index 1f254c456745..952ef6bcaf70 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -30,6 +30,7 @@ def __init__(self, token_ids: Sequence[int], block_size: int) -> None: # Must not be initialized to 1 at creation time because a prefix might be created # and thrown away, or sequences sharing this prefix might never be allocated. self.seq_ref_count = 0 + self.expired = False @property def allocated(self) -> bool: @@ -180,6 +181,14 @@ def get_prefixes_to_deallocate(self) -> List[Prefix]: i for i, prefix in enumerate(self._candidates_to_deallocate) if prefix.seq_ref_count == 0 and prefix.allocated ] + + # Mark the prefix as expired, so that if a sequence group still in the + # waiting list that shares this prefix tries to allocate it as a prefix, + # it will fail. + for i in indexes_to_remove: + prefix = self._candidates_to_deallocate[i] + prefix.expired = True + # Popping needs to happen with the indexes_to_remove list in reverse order # so that we don't get Index out of range errors return [ diff --git a/vllm/sequence.py b/vllm/sequence.py index cd3f5e5a4450..3d56b0bfc94a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -259,7 +259,7 @@ def __init__( self.sampling_params = sampling_params self.arrival_time = arrival_time self.lora_request = lora_request - self.prefix: Optional[Prefix] = prefix + self._prefix: Optional[Prefix] = prefix self.prompt_logprobs: Optional[PromptLogprobs] = None @property @@ -268,6 +268,20 @@ def prompt(self) -> str: # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).prompt + @property + def prefix(self) -> Optional[Prefix]: + """ + If the prefix has been marked as expired, set the prefix to None. + The reason for this is that when a prefix has expired it means that it has + been moved out of the prefix pool, so is no longer valid, and it should + no longer be allocated as a prefix. + + Return the prefix. + """ + if self._prefix is not None and self._prefix.expired: + self._prefix = None + return self._prefix + @property def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. From 36e71a3526f8ffda7ac135f3b238a166d2cbccaf Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Fri, 26 Jan 2024 11:35:40 -0500 Subject: [PATCH 15/18] Added support for lora_int_id --- vllm/prefix.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/prefix.py b/vllm/prefix.py index 952ef6bcaf70..bb95213087b8 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -107,8 +107,13 @@ def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: new_length = min(new_length, self.max_allowed_prefix_length) return tuple(token_ids[:new_length]) - def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: + def add_or_get_prefix(self, token_ids: Sequence[int], lora_int_id: int = 0) -> Optional[Prefix]: """ + Arguments: + - token_ids: The token ids of the prefix to add to the pool. + - lora_int_id: The lora_int_id of the request, which will be used to hash the prefix too. + If the lora_int_id is not given, defaults to 0. + Adds a prefix to the pool if it does not already exist. If it does exist, it returns the existing prefix. If the pool is at max capacity, it removes the least recently used prefix before adding the new prefix. @@ -137,7 +142,7 @@ def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: # Check first if prefix exists, moving it to the end of the OrderedDict. # so that the LRU policy is maintained. Return the existing prefix. prefix = Prefix(token_ids, self.block_size) - prefix_hash = hash(prefix) + prefix_hash = hash((prefix, lora_int_id)) if prefix_hash in self.prefixes: prefix = self.prefixes[prefix_hash] self.prefixes.move_to_end(prefix_hash) From 4566e39c010d749e65ba424990a2f188d3bd9284 Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Fri, 26 Jan 2024 11:41:21 -0500 Subject: [PATCH 16/18] Fixed formtting issues --- vllm/prefix.py | 8 +++++--- vllm/sequence.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/prefix.py b/vllm/prefix.py index bb95213087b8..eac9ef490d46 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -107,7 +107,9 @@ def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: new_length = min(new_length, self.max_allowed_prefix_length) return tuple(token_ids[:new_length]) - def add_or_get_prefix(self, token_ids: Sequence[int], lora_int_id: int = 0) -> Optional[Prefix]: + def add_or_get_prefix(self, + token_ids: Sequence[int], + lora_int_id: int = 0) -> Optional[Prefix]: """ Arguments: - token_ids: The token ids of the prefix to add to the pool. @@ -186,14 +188,14 @@ def get_prefixes_to_deallocate(self) -> List[Prefix]: i for i, prefix in enumerate(self._candidates_to_deallocate) if prefix.seq_ref_count == 0 and prefix.allocated ] - + # Mark the prefix as expired, so that if a sequence group still in the # waiting list that shares this prefix tries to allocate it as a prefix, # it will fail. for i in indexes_to_remove: prefix = self._candidates_to_deallocate[i] prefix.expired = True - + # Popping needs to happen with the indexes_to_remove list in reverse order # so that we don't get Index out of range errors return [ diff --git a/vllm/sequence.py b/vllm/sequence.py index 3d56b0bfc94a..9632806cfe17 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -281,7 +281,7 @@ def prefix(self) -> Optional[Prefix]: if self._prefix is not None and self._prefix.expired: self._prefix = None return self._prefix - + @property def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. From 8f83072b322688964d6f32212d136fb1acc230bb Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Fri, 26 Jan 2024 11:48:39 -0500 Subject: [PATCH 17/18] Added documentation for the prefix-caching-memory-utilization parameter --- docs/source/models/engine_args.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index d89b79514950..10136c9cb1bc 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -95,6 +95,14 @@ Below, you can find an explanation of every engine argument for vLLM: For example, a value of 0.5 would imply 50% GPU memory utilization. If unspecified, will use the default value of 0.9. +.. option:: --prefix-caching-memory-utilization + + The fraction of GPU memory to be used for the prefix caching, which can range from 0 to --gpu-memory-utilization. + For example, a value of 0.5 would imply 50% GPU memory utilization. + If unspecified, will use the default value of 0. A value of 0 means no prefix caching at all. + The size of the prefixes relative to the length of the rest of the prompts and the generated + sequences should dictate the relative value of this parameter with respect to gpu-memory-utilization. + .. option:: --max-num-batched-tokens Maximum number of batched tokens per iteration. From 9bd33a61195bb59cc855902b8751a1ebb2b238ac Mon Sep 17 00:00:00 2001 From: Jadiel de Armas Date: Tue, 30 Jan 2024 14:38:55 -0500 Subject: [PATCH 18/18] Fixing formatting issues --- vllm/config.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index af40e82a62dd..822e914c6697 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -279,15 +279,13 @@ class CacheConfig: cache_dtype: Data type for kv cache storage. """ - def __init__( - self, - block_size: int, - gpu_memory_utilization: float, - swap_space: int, - cache_dtype: str, - sliding_window: Optional[int] = None, - prefix_pool_memory_utilization: float = 0 - ) -> None: + def __init__(self, + block_size: int, + gpu_memory_utilization: float, + swap_space: int, + cache_dtype: str, + sliding_window: Optional[int] = None, + prefix_pool_memory_utilization: float = 0) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB