29 changes: 27 additions & 2 deletions tests/models/language/pooling/test_extract_hidden_states.py
@@ -11,7 +11,7 @@
     ["Qwen/Qwen3-0.6B"],
 )
 @torch.inference_mode
-def test_embed_models(hf_runner, vllm_runner, model: str):
+def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
     n_prompt_tokens = [55, 56, 57]
     token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
 
@@ -21,7 +21,7 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
         enforce_eager=True,
         runner="pooling",
         enable_chunked_prefill=False,
-        enable_prefix_caching=False,
+        enable_prefix_caching=True,
     ) as vllm_model:
         pooling_outputs = vllm_model.llm.encode(
             [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
@@ -30,4 +30,29 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
 
         for n, output in zip(n_prompt_tokens, pooling_outputs):
             assert len(output.prompt_token_ids) == n
             assert len(output.outputs.data) == n
+            assert output.num_cached_tokens == 0
+
+        # Test enable_prefix_caching together with all pooling: this
+        # request must skip reading the prefix cache, via
+        # request.skip_reading_prefix_cache.
+        pooling_outputs = vllm_model.llm.encode(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            pooling_task="token_embed",
+        )
+
+        for n, output in zip(n_prompt_tokens, pooling_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert len(output.outputs.data) == n
+            assert output.num_cached_tokens == 0
+
+        # A request with skip_reading_prefix_cache set can still write to
+        # the cache, to accelerate subsequent requests.
+        pooling_outputs = vllm_model.llm.encode(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            pooling_task="embed",
+        )
+
+        for n, output in zip(n_prompt_tokens, pooling_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert output.num_cached_tokens > 0
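For orientation, here is a minimal standalone sketch of the behavior this test pins down, using vLLM's offline API. The model name and flags mirror the test above; treat the exact constructor arguments as assumptions rather than a reference.

# Sketch only: token-level pooling bypasses prefix-cache reads but still
# writes blocks that later sequence-level requests can hit.
from vllm import LLM
from vllm.inputs import TokensPrompt

llm = LLM(
    model="Qwen/Qwen3-0.6B",
    runner="pooling",
    enforce_eager=True,
    enable_prefix_caching=True,
)
prompts = [TokensPrompt(prompt_token_ids=[1024 + i for i in range(55)])]

# Every prompt token needs a hidden state, so nothing is served from cache.
token_outputs = llm.encode(prompts, pooling_task="token_embed")
assert token_outputs[0].num_cached_tokens == 0
assert len(token_outputs[0].outputs.data) == 55

# The bypassed request still populated the cache for this prefix.
embed_outputs = llm.encode(prompts, pooling_task="embed")
assert embed_outputs[0].num_cached_tokens > 0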
12 changes: 12 additions & 0 deletions vllm/pooling_params.py
@@ -57,6 +57,7 @@ class PoolingParams(
     ## Internal use only
     task: PoolingTask | None = None
     requires_token_ids: bool = False
+    skip_reading_prefix_cache: bool | None = None
     extra_kwargs: dict[str, Any] | None = None
     output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
 
@@ -93,6 +94,8 @@ def verify(
         # plugin task uses io_processor.parse_request to verify inputs,
         # skipping PoolingParams verify
         if self.task == "plugin":
+            if self.skip_reading_prefix_cache is None:
+                self.skip_reading_prefix_cache = True
             return
 
         # NOTE: Task validation needs to be done against the model instance,
@@ -122,6 +125,15 @@ def _merge_default_parameters(
             if getattr(self, k, None) is None:
                 setattr(self, k, getattr(pooler_config, k))
 
+        if self.skip_reading_prefix_cache is None:
+            # If prefix caching is enabled, the output of an all-pooling
+            # task may contain fewer than n_prompt_tokens entries, so this
+            # request must skip reading the cache.
+            if self.task in ["token_embed", "token_classify"]:
+                self.skip_reading_prefix_cache = True
+            else:
+                self.skip_reading_prefix_cache = False
+
         self._verify_step_pooling(pooler_config, valid_parameters)
 
     def _verify_step_pooling(
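The default above is deliberately tri-state: None means "decide from the task". A self-contained sketch of the rule (names like PoolingParamsSketch and TOKEN_LEVEL_TASKS are illustrative, not vLLM identifiers):

from dataclasses import dataclass

TOKEN_LEVEL_TASKS = {"token_embed", "token_classify"}

@dataclass
class PoolingParamsSketch:
    task: str
    skip_reading_prefix_cache: bool | None = None  # None = not set by user

    def merge_defaults(self) -> None:
        if self.skip_reading_prefix_cache is None:
            # Token-level tasks return one vector per prompt token; a
            # prefix-cache hit would leave cached tokens without outputs.
            self.skip_reading_prefix_cache = self.task in TOKEN_LEVEL_TASKS

p = PoolingParamsSketch(task="token_embed")
p.merge_defaults()
assert p.skip_reading_prefix_cache is True

q = PoolingParamsSketch(task="embed")
q.merge_defaults()
assert q.skip_reading_prefix_cache is False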
8 changes: 8 additions & 0 deletions vllm/sampling_params.py
@@ -252,6 +252,8 @@ class SamplingParams(
     generated token can complete the sequence."""
     _bad_words_token_ids: list[list[int]] | None = None
 
+    skip_reading_prefix_cache: bool | None = None
+
     @staticmethod
     def from_optional(
         n: int | None = 1,
@@ -412,6 +414,12 @@ def __post_init__(self) -> None:
             self.structured_outputs = self.guided_decoding
             self.guided_decoding = None
 
+        if self.skip_reading_prefix_cache is None:
+            # If prefix caching is enabled, prompt logprobs may cover fewer
+            # than n_prompt_tokens tokens, so this request must skip
+            # reading the cache.
+            self.skip_reading_prefix_cache = self.prompt_logprobs is not None
+
     def _verify_args(self) -> None:
         if not isinstance(self.n, int):
             raise ValueError(f"n must be an int, but is of type {type(self.n)}")
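The SamplingParams side follows the same tri-state pattern. Sketched as a standalone helper (resolve_skip_reading_prefix_cache is an assumed name): an explicit user value wins; otherwise requesting prompt logprobs forces the cache read to be skipped, since logprobs must be computed for every prompt token.

def resolve_skip_reading_prefix_cache(
    explicit: bool | None, prompt_logprobs: int | None
) -> bool:
    # An explicit user setting takes precedence over the derived default.
    if explicit is not None:
        return explicit
    # Cached prompt tokens are not re-run through the model, so their
    # logprobs would be missing; skip the cache read when logprobs are on.
    return prompt_logprobs is not None

assert resolve_skip_reading_prefix_cache(None, 5) is True
assert resolve_skip_reading_prefix_cache(None, None) is False
assert resolve_skip_reading_prefix_cache(False, 5) is False  # user override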
11 changes: 5 additions & 6 deletions vllm/v1/core/kv_cache_manager.py
@@ -185,12 +185,11 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]:
             - A list of blocks that are computed for the request.
             - The number of computed tokens.
         """
-        # Prefix caching is disabled or
-        # When the request requires prompt logprobs, we skip prefix caching.
-        if not self.enable_caching or (
-            request.sampling_params is not None
-            and request.sampling_params.prompt_logprobs is not None
-        ):
+        # Skip looking for a prefix cache hit when prefix caching is
+        # disabled or the request is marked as skipping the KV cache read
+        # (which happens when the request requires prompt logprobs or
+        # calls a pooling model with all pooling).
+        if not self.enable_caching or request.skip_reading_prefix_cache:
             return self.empty_kv_cache_blocks, 0
 
         # NOTE: When all tokens hit the cache, we must recompute the last token
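The simplification here is that the manager now consults a single per-request flag instead of poking into sampling params. A sketch of the gate, with assumed names:

def get_computed_blocks_sketch(
    enable_caching: bool,
    skip_reading_prefix_cache: bool,
    find_longest_cache_hit,  # callable returning (blocks, num_computed_tokens)
):
    # No cache lookup when caching is disabled or the request opts out.
    if not enable_caching or skip_reading_prefix_cache:
        return [], 0
    return find_longest_cache_hit()

hit = lambda: (["block0", "block1"], 32)
assert get_computed_blocks_sketch(True, True, hit) == ([], 0)  # opted out
assert get_computed_blocks_sketch(True, False, hit) == (["block0", "block1"], 32)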
15 changes: 15 additions & 0 deletions vllm/v1/request.py
@@ -127,6 +127,8 @@ def __init__(
         self.get_hash_new_full_blocks = partial(block_hasher, self)
         self.block_hashes = self.get_hash_new_full_blocks()
 
+        self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
+
     @classmethod
     def from_engine_core_request(
         cls,
@@ -180,6 +182,19 @@ def num_tokens_with_spec(self) -> int:
     def num_output_tokens(self) -> int:
         return len(self._output_token_ids)
 
+    def get_skip_reading_prefix_cache(self) -> bool:
+        if (
+            self.sampling_params is not None
+            and self.sampling_params.skip_reading_prefix_cache is not None
+        ):
+            return self.sampling_params.skip_reading_prefix_cache
+        elif (
+            self.pooling_params is not None
+            and self.pooling_params.skip_reading_prefix_cache is not None
+        ):
+            return self.pooling_params.skip_reading_prefix_cache
+        return False
+
     def is_finished(self) -> bool:
         return RequestStatus.is_finished(self.status)
 
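The resolution order in get_skip_reading_prefix_cache, restated as a standalone sketch: sampling params are consulted first, then pooling params, and a request that sets neither defaults to False, meaning prefix-cache reads stay enabled.

def resolve_request_skip(
    sampling_skip: bool | None, pooling_skip: bool | None
) -> bool:
    if sampling_skip is not None:
        return sampling_skip
    if pooling_skip is not None:
        return pooling_skip
    return False

assert resolve_request_skip(True, None) is True    # prompt-logprobs path
assert resolve_request_skip(None, True) is True    # all-pooling path
assert resolve_request_skip(None, None) is False   # ordinary request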