Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/models/engine_args.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ Below, you can find an explanation of every engine argument for vLLM:
For example, a value of 0.5 would imply 50% GPU memory utilization.
If unspecified, will use the default value of 0.9.

.. option:: --prefix-caching-memory-utilization <fraction>

The fraction of GPU memory to be used for prefix caching, which can range from 0 to the value of --gpu-memory-utilization.
For example, a value of 0.5 would reserve 50% of the GPU memory for cached prefixes.
If unspecified, will use the default value of 0. A value of 0 disables prefix caching entirely.
Choose this value relative to --gpu-memory-utilization based on how long your shared prefixes are
compared to the rest of the prompts and the generated sequences.

.. option:: --max-num-batched-tokens <tokens>

Maximum number of batched tokens per iteration.
Expand Down
53 changes: 51 additions & 2 deletions tests/prefix_caching/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
from importlib import reload

import pytest

import vllm.model_executor.parallel_utils.parallel_state as parallel_state
from vllm import LLM, SamplingParams

prefix = (
Expand All @@ -20,12 +23,20 @@

@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("prefix_pool_memory_utilization", [0, 0.1, 0.2])
def test_prefix_caching(
example_prompts,
model: str,
max_tokens: int,
prefix_pool_memory_utilization: float,
):
llm = LLM(model=model)
# IMPORTANT: If this line is removed from here, adding more than 1 item to
# any of the parametrization lists above causes all tests but the first one
# to fail with the message: "AssertionError: tensor model parallel group is
# already initialized."
reload(parallel_state)
llm = LLM(model=model,
prefix_pool_memory_utilization=prefix_pool_memory_utilization)
# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
prompts = [prefix + prompt for prompt in example_prompts]
Expand All @@ -38,4 +49,42 @@ def test_prefix_caching(
outputs_without_prefix, outputs_with_prefix):
assert (output_without_prefix.outputs[0].token_ids ==
output_with_prefix.outputs[0].token_ids)
assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1
if prefix_pool_memory_utilization == 0:
assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 0


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("prefix_pool_memory_utilization", [0, 0.1, 0.3, 0.4])
def test_prefix_caching_with_multiple_prefixes(
        example_prompts, model: str, max_tokens: int,
        prefix_pool_memory_utilization: float):
    """
    Generates with ten distinct prefixes and checks that prefix caching
    never changes the output: for every prefix, greedy generation with
    prefix_pos must produce the same token ids as generation without it.

    NOTE(review): the docstring previously claimed the pool size is checked
    against its maximum capacity, but no such assertion exists below —
    the capacity is only exercised indirectly via the parametrized
    utilization values. Consider also asserting on
    llm.llm_engine.scheduler.prefix_pool directly.
    """
    # IMPORTANT: If this line is removed from here, adding more than 1 item to
    # any of the parametrization lists above causes all tests but the first one
    # to fail with the message: "AssertionError: tensor model parallel group is
    # already initialized."
    reload(parallel_state)
    llm = LLM(model=model,
              prefix_pool_memory_utilization=prefix_pool_memory_utilization)

    # Use 10 different prefixes:
    for i in range(10):
        new_prefix = str(i) + ' ' + prefix
        # -1 since the last token can change when concatenating prompts.
        prefix_pos = len(llm.llm_engine.tokenizer.encode(new_prefix)) - 1
        prompts = [new_prefix + prompt for prompt in example_prompts]
        sampling_params = SamplingParams(temperature=0.0,
                                         max_tokens=max_tokens)
        outputs_with_prefix = llm.generate(prompts,
                                           sampling_params,
                                           prefix_pos=[prefix_pos] *
                                           len(prompts))
        outputs_without_prefix = llm.generate(prompts, sampling_params)
        for output_without_prefix, output_with_prefix in zip(
                outputs_without_prefix, outputs_with_prefix):
            assert (output_without_prefix.outputs[0].token_ids ==
                    output_with_prefix.outputs[0].token_ids)
138 changes: 138 additions & 0 deletions tests/prefix_caching/test_prefix_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from vllm.prefix import PrefixPool

import pytest


@pytest.fixture
def no_max_capacity_prefix_pool() -> PrefixPool:
    """Provide a PrefixPool whose block capacity is effectively unbounded."""
    unbounded = float('inf')
    return PrefixPool(block_size=32, max_capacity_in_blocks=unbounded)


def test_prefix_length_behaviours(no_max_capacity_prefix_pool: PrefixPool):
    """
    Prefixes shorter than pool.block_size must be rejected (None returned,
    nothing stored); prefixes of at least pool.block_size must be created
    and added to the pool.
    """
    pool = no_max_capacity_prefix_pool
    block = pool.block_size
    too_short = pool.add_or_get_prefix(list(range(block - 1)))
    exactly_one_block = pool.add_or_get_prefix(list(range(block)))
    two_blocks = pool.add_or_get_prefix(list(range(block * 2)))
    assert too_short is None
    assert exactly_one_block is not None
    assert two_blocks is not None
    # Only the two block-aligned prefixes were stored.
    assert len(pool) == 2


def test_same_prefix_added_twice(no_max_capacity_prefix_pool: PrefixPool):
    """
    Re-adding an identical token sequence must hand back the very prefix
    object created on the first insertion, without growing the pool.
    """
    pool = no_max_capacity_prefix_pool
    tokens = list(range(pool.block_size))
    first = pool.add_or_get_prefix(list(tokens))
    second = pool.add_or_get_prefix(list(tokens))
    assert first is second
    assert len(pool) == 1


def test_prefix_pool_max_capacity():
    """
    The pool must evict prefixes once its block capacity is reached.

    After a two-block prefix fills the pool, re-adding the first prefix
    yields a brand-new object that still compares equal to the evicted one.
    """
    capacity = 2
    pool = PrefixPool(block_size=32, max_capacity_in_blocks=capacity)

    original = pool.add_or_get_prefix(list(range(pool.block_size)))
    # This insertion consumes the whole capacity, forcing eviction.
    pool.add_or_get_prefix(list(range(pool.block_size * 2)))
    readded = pool.add_or_get_prefix(list(range(pool.block_size)))

    assert original is not readded
    assert original == readded
    assert len(pool) == 1
    assert pool.current_block_usage == 1


def test_current_block_usage():
    """
    Repeatedly inserting the same capacity-sized prefix must keep the pool
    at one entry with current_block_usage pinned at the maximum, throughout
    the pool's lifetime.
    """
    capacity = 2
    pool = PrefixPool(block_size=32, max_capacity_in_blocks=capacity)

    full_length_tokens = list(range(pool.block_size * capacity))
    for _ in range(10):
        pool.add_or_get_prefix(list(full_length_tokens))
        assert len(pool) == 1
        assert pool.current_block_usage == capacity


def test_prefix_truncation_1():
    """A prefix exceeding the pool's max capacity must be truncated to fit."""
    pool = PrefixPool(block_size=1, max_capacity_in_blocks=2)
    stored = pool.add_or_get_prefix([1, 2, 3, 4])
    # Only the first two single-token blocks survive.
    assert stored.token_ids == (1, 2)


def test_prefix_truncation_2():
    """A prefix ending off a block boundary is truncated to whole blocks."""
    pool = PrefixPool(block_size=2, max_capacity_in_blocks=3)
    stored = pool.add_or_get_prefix([1, 2, 3, 4, 5])
    # Five tokens round down to two full blocks of size two.
    assert stored.token_ids == (1, 2, 3, 4)


def test_prefix_truncation_3():
    """
    Truncation applies when the prefix both exceeds max capacity and does
    not end on a block boundary.
    """
    pool = PrefixPool(block_size=2, max_capacity_in_blocks=2)
    stored = pool.add_or_get_prefix([1, 2, 3, 4, 5])
    assert stored.token_ids == (1, 2, 3, 4)


def test_none_prefix_returned_1():
    """
    A pool with zero block capacity must refuse every prefix: None comes
    back and nothing is stored.
    """
    pool = PrefixPool(block_size=32, max_capacity_in_blocks=0)
    rejected = pool.add_or_get_prefix(list(range(pool.block_size)))
    assert rejected is None
    assert len(pool) == 0


def test_none_prefix_returned_2():
    """
    A prefix shorter than one block must be refused: None comes back and
    nothing is stored.
    """
    pool = PrefixPool(block_size=32, max_capacity_in_blocks=2)
    rejected = pool.add_or_get_prefix(list(range(pool.block_size - 1)))
    assert rejected is None
    assert len(pool) == 0


def test_assertion_raised_with_invalid_max_capacity():
    """A negative max capacity must trip the constructor's assertion."""
    with pytest.raises(AssertionError):
        PrefixPool(32, max_capacity_in_blocks=-1)


if __name__ == "__main__":
    # pytest is already imported at module level; re-importing it here was
    # redundant. Run this file's tests directly when executed as a script.
    pytest.main([__file__])
6 changes: 6 additions & 0 deletions vllm/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def __init__(
self.block_number = block_number
self.block_size = block_size

# Contains the number of sequences that share this block that
# are currently allocated in the same device as this block.
# Notice that prefix blocks will have an extra 1 added to this
# reference count to guarantee that prefix blocks are not deallocated
# by the standard way that the block manager uses to free the memory
# for blocks.
self.ref_count = 0

def __repr__(self) -> str:
Expand Down
46 changes: 38 additions & 8 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,19 +279,19 @@ class CacheConfig:
cache_dtype: Data type for kv cache storage.
"""

def __init__(
self,
block_size: int,
gpu_memory_utilization: float,
swap_space: int,
cache_dtype: str,
sliding_window: Optional[int] = None,
) -> None:
def __init__(self,
             block_size: int,
             gpu_memory_utilization: float,
             swap_space: int,
             cache_dtype: str,
             sliding_window: Optional[int] = None,
             prefix_pool_memory_utilization: float = 0) -> None:
    """Initialize the KV-cache configuration.

    Args:
        block_size: Number of tokens held per cache block.
        gpu_memory_utilization: Fraction of GPU memory vLLM may use
            (validated by _verify_args).
        swap_space: CPU swap space per GPU, in gigabytes (stored as bytes
            via the _GB multiplier).
        cache_dtype: Data type for KV-cache storage ("auto" or
            "fp8_e5m2"; validated by _verify_cache_dtype).
        sliding_window: Optional sliding-window size; None disables it.
        prefix_pool_memory_utilization: Fraction of GPU memory dedicated
            to the prefix-caching pool; must be non-negative and no larger
            than gpu_memory_utilization. 0 disables prefix caching.
    """
    self.block_size = block_size
    self.gpu_memory_utilization = gpu_memory_utilization
    # swap_space is given in GB; keep the byte count internally.
    self.swap_space_bytes = swap_space * _GB
    self.cache_dtype = cache_dtype
    self.sliding_window = sliding_window
    self.prefix_pool_memory_utilization = prefix_pool_memory_utilization
    # Validate eagerly so a bad configuration fails at construction time.
    self._verify_args()
    self._verify_cache_dtype()

Expand All @@ -304,6 +304,36 @@ def _verify_args(self) -> None:
raise ValueError(
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
if self.prefix_pool_memory_utilization < 0:
raise ValueError(
"prefix_pool_memory_utilization must be non negative. "
f"{self.prefix_pool_memory_utilization}.")
if self.prefix_pool_memory_utilization > self.gpu_memory_utilization:
raise ValueError(
"prefix_pool_memory_utilization must be less than or equal to "
"gpu_memory_utilization.")

def _verify_cache_dtype(self) -> None:
    """Check that self.cache_dtype is a supported KV-cache data type."""
    if self.cache_dtype == "auto":
        # Nothing to validate for the default dtype.
        return
    if self.cache_dtype != "fp8_e5m2":
        raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    # fp8_e5m2 path: needs a recent CUDA toolkit and a non-AMD device.
    nvcc_cuda_version = get_nvcc_cuda_version()
    if nvcc_cuda_version < Version("11.8"):
        raise ValueError(
            "FP8 is not supported when cuda version is lower than 11.8.")
    device_name = torch.cuda.get_device_name()
    if "AMD" in device_name:
        raise NotImplementedError(
            "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
    logger.info(
        "Using fp8_e5m2 data type to store kv cache. It reduces "
        "the GPU memory footprint and boosts the performance. "
        "But it may cause slight accuracy drop. "
        "Currently we only support fp8 without scaling factors and "
        "make e5m2 as a default format.")

def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
Expand Down
Loading