From ae51908ea96cd2cba166e151ced3e7275d105a2c Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Thu, 24 Jul 2025 22:58:09 -0700 Subject: [PATCH 01/12] Add option to propagate padded logits_indices to model Signed-off-by: Yong Hoon Shin --- vllm/envs.py | 5 +++++ vllm/forward_context.py | 3 +++ vllm/v1/worker/gpu_model_runner.py | 25 +++++++++++++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index 0eff741519ae..3812163221ad 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -143,6 +143,7 @@ VLLM_USE_CUDNN_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" + VLLM_COMPUTE_PADDED_LOGITS_INDICES: bool = False def get_default_cache_root(): @@ -991,6 +992,10 @@ def get_vllm_port() -> Optional[int]: # The default value is "VLLM". "VLLM_PROCESS_NAME_PREFIX": lambda: os.getenv("VLLM_PROCESS_NAME_PREFIX", "VLLM"), + + # Enable computing and propagating cudagraph padded logits indices + "VLLM_COMPUTE_PADDED_LOGITS_INDICES": + lambda: bool(int(os.getenv("VLLM_COMPUTE_PADDD_LOGITS_INDICES", "0"))), } # --8<-- [end:env-vars-definition] diff --git a/vllm/forward_context.py b/vllm/forward_context.py index dd55b19feeaf..fda837eb82fd 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -95,6 +95,7 @@ class ForwardContext: # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None skip_cuda_graphs: bool = False + logits_indices_padded: Optional[torch.Tensor] = None _forward_context: Optional[ForwardContext] = None @@ -116,6 +117,7 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, skip_cuda_graphs: bool = False, + logits_indices_padded: Optional[torch.Tensor] = None, ): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -141,6 +143,7 @@ def set_forward_context( attn_metadata=attn_metadata, dp_metadata=dp_metadata, skip_cuda_graphs=skip_cuda_graphs, + logits_indices_padded=logits_indices_padded, ) try: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 32004ced4aae..88f8cd313054 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -318,6 +318,12 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. self.shared_kv_cache_layers: dict[str, str] = {} + self.logits_indices = None + if envs.VLLM_COMPUTE_PADDED_LOGITS_INDICES: + self.logits_indices = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ Update the order of requests in the batch based on the attention @@ -1364,6 +1370,7 @@ def execute_model( spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata) = ( self._prepare_inputs(scheduler_output)) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1436,6 +1443,23 @@ def execute_model( # compiled with full CUDA graphs, we have to skip them entirely. 
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs + logits_indices_padded = None + if envs.VLLM_COMPUTE_PADDED_LOGITS_INDICES: + assert self.logits_indices is not None + num_logits = logits_indices.shape[0] + self.logits_indices[:num_logits].copy_(logits_indices) + # Ensure we keep duplicates instead of zeros + self.logits_indices[num_logits:].fill_(logits_indices[-1].item()) + if (self.use_cuda_graph + and num_logits <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_logits_padded = self.vllm_config.pad_for_cudagraph( + num_logits) + else: + num_logits_padded = num_logits + logits_indices_padded = self.logits_indices[:num_logits_padded] + # Run the model. # Use persistent buffers for CUDA graphs. with set_forward_context( @@ -1444,6 +1468,7 @@ def execute_model( num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, skip_cuda_graphs=skip_cuda_graphs, + logits_indices_padded=logits_indices_padded, ): self.maybe_setup_kv_connector(scheduler_output) From 69279d2e2e6f565ef8d2bea802b83b0c476f0226 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Thu, 24 Jul 2025 23:11:28 -0700 Subject: [PATCH 02/12] Fix typo Signed-off-by: Yong Hoon Shin --- vllm/envs.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index 3812163221ad..a0c3b62ead6c 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -995,7 +995,7 @@ def get_vllm_port() -> Optional[int]: # Enable computing and propagating cudagraph padded logits indices "VLLM_COMPUTE_PADDED_LOGITS_INDICES": - lambda: bool(int(os.getenv("VLLM_COMPUTE_PADDD_LOGITS_INDICES", "0"))), + lambda: bool(int(os.getenv("VLLM_COMPUTE_PADDED_LOGITS_INDICES", "0"))), } # --8<-- [end:env-vars-definition] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 88f8cd313054..d16f5f84ccac 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1447,6 +1447,7 @@ def execute_model( if envs.VLLM_COMPUTE_PADDED_LOGITS_INDICES: assert self.logits_indices is not None num_logits = logits_indices.shape[0] + assert num_logits > 0 self.logits_indices[:num_logits].copy_(logits_indices) # Ensure we keep duplicates instead of zeros self.logits_indices[num_logits:].fill_(logits_indices[-1].item()) From 94df2f1dc374a1129014772e4ddc64ed4d60ccbc Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Thu, 24 Jul 2025 23:28:18 -0700 Subject: [PATCH 03/12] Fix lint Signed-off-by: Yong Hoon Shin --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d16f5f84ccac..68490002b507 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1447,7 +1447,7 @@ def execute_model( if envs.VLLM_COMPUTE_PADDED_LOGITS_INDICES: assert self.logits_indices is not None num_logits = logits_indices.shape[0] - assert num_logits > 0 + assert num_logits > 0 self.logits_indices[:num_logits].copy_(logits_indices) # Ensure we keep duplicates instead of zeros self.logits_indices[num_logits:].fill_(logits_indices[-1].item()) From 4d19b7b6d033e7695d3b54bbf5cf7187e0fb7e29 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Mon, 28 Jul 2025 10:54:45 -0700 Subject: [PATCH 04/12] Subclass attn metadata for cross-decoder layers to propagate logits_indices Signed-off-by: Yong Hoon Shin --- 
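Note (illustrative sketch, not part of the diff below): the attention-metadata subclassing introduced in this patch relies on `dataclasses.make_dataclass`, which builds a new dataclass that inherits from a backend's metadata class and adds the two extra fields (`logits_indices_padded`, `num_logits_indices`). A minimal standalone example of the pattern, using a hypothetical `DummyMetadata` stand-in and a plain list where the real code uses `Optional[torch.Tensor]`:

    from dataclasses import asdict, dataclass, make_dataclass
    from typing import Optional

    @dataclass
    class DummyMetadata:  # hypothetical stand-in for a backend metadata class
        num_actual_tokens: int = 0

    # Build a subclass carrying the two extra fields for eligible layers.
    TruncatedPrefillDummyMetadata = make_dataclass(
        "TruncatedPrefillDummyMetadata",
        [("logits_indices_padded", Optional[list], None),
         ("num_logits_indices", int, 0)],
        bases=(DummyMetadata, ),
    )

    base = DummyMetadata(num_actual_tokens=8)
    extended = TruncatedPrefillDummyMetadata(**asdict(base),
                                             logits_indices_padded=[3, 7, 7],
                                             num_logits_indices=2)
    assert isinstance(extended, DummyMetadata)  # backends still see base type

Because the generated class inherits from the original metadata type, code written against the base class keeps working; only layers that opt in read the extra fields.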
.../e2e/test_kv_sharing_truncated_prefill.py | 130 ++++++++++++++++++ vllm/config.py | 4 + vllm/engine/arg_utils.py | 8 ++ vllm/entrypoints/llm.py | 3 + vllm/envs.py | 5 - vllm/forward_context.py | 3 - vllm/model_executor/models/gemma3n.py | 22 ++- vllm/v1/attention/backends/utils.py | 17 ++- vllm/v1/worker/gpu_model_runner.py | 128 +++++++++++------ 9 files changed, 264 insertions(+), 56 deletions(-) create mode 100644 tests/v1/e2e/test_kv_sharing_truncated_prefill.py diff --git a/tests/v1/e2e/test_kv_sharing_truncated_prefill.py b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py new file mode 100644 index 000000000000..27c7f179275a --- /dev/null +++ b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import gc +import random +from typing import Optional, Union + +import pytest +import torch + +from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig, CompilationLevel +from vllm.forward_context import get_forward_context +from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration +from vllm.model_executor.models.registry import ModelRegistry +from vllm.sequence import IntermediateTensors + +from ...utils import fork_new_process_for_each_test + + +class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration): + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) + attn_metadata = get_forward_context().attn_metadata + # attn_metadata is None during dummy runs + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) # true in V1 + # Layer 20 is a cross-decoder layer in YOCO + layer_attn_metadata = attn_metadata['model.language_model.layers.20.self_attn.attn'] + if hasattr(layer_attn_metadata, 'logits_indices_padded'): + # This field is only set when + # enable_kv_sharing_truncated_prefill is set to True + assert self.cache_config.enable_kv_sharing_truncated_prefill + logits_indices_padded = ( + layer_attn_metadata.logits_indices_padded + ) + assert logits_indices_padded is not None + num_logits_indices = layer_attn_metadata.num_logits_indices + assert num_logits_indices > 0 + + logits_hs = hidden_states[logits_indices_padded] + hidden_states = torch.randn_like(hidden_states) + gen_indices = logits_indices_padded[:num_logits_indices] + # Only set logits for logits_indices to valid values + hidden_states[gen_indices] = logits_hs[:num_logits_indices] + + return hidden_states + +@pytest.fixture +def test_prompts(): + """ + Adapted from tests/v1/e2e/test_spec_decode.py + """ + prompt_types = ["repeat", "sentence"] + # Setting higher num prompts increases the chance of numerics mismatch + # due to matrix multiplication numerics depending on batch dimension + num_prompts = 10 + prompts = [] + + random.seed(0) + random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) + + for kind in random_prompt_type_choices: + word_choices = ["test", "temp", "hello", "where"] + word = random.choice(word_choices) + if kind == "repeat": + prompt = f"""please repeat the word '{word}' 10 times.""" + elif kind == "sentence": + prompt = f"""please give a ten-word sentence that + uses the word {word} at least 
once.""" + else: + raise ValueError(f"Unknown prompt type: {kind}") + prompts.append(prompt) + + return prompts + + +@fork_new_process_for_each_test +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_kv_sharing_truncated_prefill( + monkeypatch: pytest.MonkeyPatch, + enforce_eager: bool, + test_prompts: list[str], +): + ModelRegistry.register_model("Gemma3nForConditionalGeneration", TestGemma3nForConditionalGeneration) + sampling_params = SamplingParams(temperature=0.0, max_tokens=100) + compilation_config = CompilationConfig( + # This allows vLLM compilation backend to handle allocating and + # managing buffers for cudagraph + cudagraph_copy_inputs=True, + level=CompilationLevel. + PIECEWISE if not enforce_eager else CompilationLevel.NO_COMPILATION) + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + llm = LLM( + model="google/gemma-3n-E2B-it", + enforce_eager=enforce_eager, + compilation_config=compilation_config, + ) + ref_responses = llm.generate(test_prompts, sampling_params) + + del llm + gc.collect() + torch.cuda.empty_cache() + + llm = LLM(model="google/gemma-3n-E2B-it", + enforce_eager=enforce_eager, + compilation_config=compilation_config, + enable_kv_sharing_truncated_prefill=True) + optimized_responses = llm.generate(test_prompts, sampling_params) + + misses = 0 + + for ref_response, optimized_response in zip(ref_responses, + optimized_responses): + if ref_response.outputs[0].text != optimized_response.outputs[ + 0].text: + misses += 1 + + assert misses == 0 diff --git a/vllm/config.py b/vllm/config.py index 07df71ec51ef..0a77cd1317cb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1684,6 +1684,10 @@ class CacheConfig: num_cpu_blocks: Optional[int] = field(default=None, init=False) """The number of blocks to allocate for CPU memory.""" + enable_kv_sharing_truncated_prefill: bool = False + """Skip prefill for tokens where applicable in YOCO-like KV-sharing + setups (e.g. Gemma3n)""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 709968004718..10a2331799d5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -438,6 +438,9 @@ class EngineArgs: # DEPRECATED enable_prompt_adapter: bool = False + enable_kv_sharing_truncated_prefill: bool = \ + CacheConfig.enable_kv_sharing_truncated_prefill + def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a @@ -686,6 +689,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **cache_kwargs["cpu_offload_gb"]) cache_group.add_argument("--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]) + cache_group.add_argument( + "--enable-kv-sharing-truncated-prefill", + **cache_kwargs["enable_kv_sharing_truncated_prefill"]) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -1056,6 +1062,8 @@ def create_engine_config( prefix_caching_hash_algo=self.prefix_caching_hash_algo, cpu_offload_gb=self.cpu_offload_gb, calculate_kv_scales=self.calculate_kv_scales, + enable_kv_sharing_truncated_prefill=self. 
+ enable_kv_sharing_truncated_prefill, ) # Get the current placement group if Ray is initialized and diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 2f766a2dae57..96d7bfa0358a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -193,6 +193,7 @@ def __init__( override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, + enable_kv_sharing_truncated_prefill: bool = False, **kwargs, ) -> None: """LLM constructor.""" @@ -266,6 +267,8 @@ def __init__( mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, + enable_kv_sharing_truncated_prefill=\ + enable_kv_sharing_truncated_prefill, **kwargs, ) diff --git a/vllm/envs.py b/vllm/envs.py index a0c3b62ead6c..0eff741519ae 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -143,7 +143,6 @@ VLLM_USE_CUDNN_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" - VLLM_COMPUTE_PADDED_LOGITS_INDICES: bool = False def get_default_cache_root(): @@ -992,10 +991,6 @@ def get_vllm_port() -> Optional[int]: # The default value is "VLLM". "VLLM_PROCESS_NAME_PREFIX": lambda: os.getenv("VLLM_PROCESS_NAME_PREFIX", "VLLM"), - - # Enable computing and propagating cudagraph padded logits indices - "VLLM_COMPUTE_PADDED_LOGITS_INDICES": - lambda: bool(int(os.getenv("VLLM_COMPUTE_PADDED_LOGITS_INDICES", "0"))), } # --8<-- [end:env-vars-definition] diff --git a/vllm/forward_context.py b/vllm/forward_context.py index fda837eb82fd..dd55b19feeaf 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -95,7 +95,6 @@ class ForwardContext: # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None skip_cuda_graphs: bool = False - logits_indices_padded: Optional[torch.Tensor] = None _forward_context: Optional[ForwardContext] = None @@ -117,7 +116,6 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, skip_cuda_graphs: bool = False, - logits_indices_padded: Optional[torch.Tensor] = None, ): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -143,7 +141,6 @@ def set_forward_context( attn_metadata=attn_metadata, dp_metadata=dp_metadata, skip_cuda_graphs=skip_cuda_graphs, - logits_indices_padded=logits_indices_padded, ) try: diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index 7d163320e0d6..36a9fda19938 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -581,6 +581,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: Gemma3nDecoderLayer( config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") + + first_kv_shared_layer_idx = (config.num_hidden_layers - + config.num_kv_shared_layers) + # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO) + self.self_decoder_layers = self.layers[:first_kv_shared_layer_idx] + # Layer idx 20-34 are cross-decoder layers in YOCO + # Refer to YOCO paper https://arxiv.org/abs/2405.05254 + self.cross_decoder_layers = self.layers[first_kv_shared_layer_idx:] + self.norm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, @@ -646,7 +655,17 @@ def forward( hidden_states = torch.stack(hidden_states, dim=0) # Transformer blocks. 
- for layer_idx, layer in enumerate(self.layers): + for layer_idx, layer in enumerate(self.self_decoder_layers): + # [altup_num_inputs, num_tokens, hidden_size] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + per_layer_input=per_layer_inputs[:, layer_idx, :], + **kwargs, + ) + + for layer_idx, layer in enumerate(self.cross_decoder_layers, + start=len(self.self_decoder_layers)): # [altup_num_inputs, num_tokens, hidden_size] hidden_states = layer( positions=positions, @@ -771,6 +790,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): del lora_config # Unused. super().__init__() self.config = config + self.cache_config = vllm_config.cache_config self.model = Gemma3nModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index fc8649d587ee..620664a5c8ba 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -3,8 +3,8 @@ import abc import functools from abc import abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar +from dataclasses import dataclass, make_dataclass +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar import numpy as np import torch @@ -501,3 +501,16 @@ def reorder_batch_to_split_decodes_and_prefills( modified_batch = True return modified_batch + + +def subclass_attention_metadata( + name_prefix: str, + metadata_cls: Any, + fields: list[tuple[str, Any, Any]], +) -> Any: + """ + Return a new subclass of `metadata_cls` with additional fields + """ + name: str = name_prefix + metadata_cls.__name__ # type: ignore + Wrapped = make_dataclass(name, fields, bases=(metadata_cls, )) + return Wrapped diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 68490002b507..dcf5f5381265 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses import gc import time from contextlib import contextmanager @@ -45,7 +46,7 @@ from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, - make_local_attention_virtual_batches) + make_local_attention_virtual_batches, subclass_attention_metadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, @@ -317,9 +318,10 @@ def __init__( # means this layer will perform attention using the keys and values # from the KV cache of `shared_kv_cache_layers[layer_name]`. self.shared_kv_cache_layers: dict[str, str] = {} + self.truncated_prefill_eligible_layers: set[str] = set() self.logits_indices = None - if envs.VLLM_COMPUTE_PADDED_LOGITS_INDICES: + if self.cache_config.enable_kv_sharing_truncated_prefill: self.logits_indices = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=self.device) @@ -738,6 +740,51 @@ def _prepare_inputs( spec_decode_common_attn_metadata = None + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. 
While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + spec_decode_metadata = None + else: + # Get the number of draft tokens for each request. + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in ( + scheduler_output.scheduled_spec_decode_tokens.items()): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens) + logits_indices = spec_decode_metadata.logits_indices + + logits_indices_padded = None + if self.cache_config.enable_kv_sharing_truncated_prefill: + assert self.logits_indices is not None + num_logits = logits_indices.shape[0] + assert num_logits > 0 + self.logits_indices[:num_logits].copy_(logits_indices) + # self.logits_indices[num_logits:] might have leftover indices from + # previous iterations, whose values may be greater than the batch + # size in the current iteration. To ensure the indices are always + # valid, we fill the padded indices with the last index. + self.logits_indices[num_logits:].fill_(logits_indices[-1].item()) + if (self.use_cuda_graph + and num_logits <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_logits_padded = self.vllm_config.pad_for_cudagraph( + num_logits) + else: + num_logits_padded = num_logits + logits_indices_padded = self.logits_indices[:num_logits_padded] + attn_metadata: dict[str, Any] = {} # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. @@ -794,36 +841,37 @@ def _prepare_inputs( )) for layer_name in kv_cache_group_spec.layer_names: + if (self.cache_config.enable_kv_sharing_truncated_prefill and + layer_name in self.truncated_prefill_eligible_layers): + fields = [ + ('logits_indices_padded', Optional[torch.Tensor], + None), + ('num_logits_indices', int, 0), + ] + # Dynamically createa a dataclass type that inherits + # from attention metadata type but includes additional + # fields logits_indices_padded and num_logits_indices + # which are required for prefill truncation + truncated_prefill_metadata_type = ( + subclass_attention_metadata( + name_prefix="TruncatedPrefill", + metadata_cls=type(attn_metadata_i), + fields=fields, + )) + attn_metadata[ + layer_name] = truncated_prefill_metadata_type( + **dataclasses.asdict(attn_metadata_i), + logits_indices_padded=logits_indices_padded, + num_logits_indices=logits_indices.size(0), + ) + continue + attn_metadata[layer_name] = attn_metadata_i attention_cuda_graphs = all( b.can_run_in_cudagraph(common_attn_metadata) for b in self.attn_metadata_builders) - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 - if not use_spec_decode: - # NOTE(woosuk): Due to chunked prefills, the batch may contain - # partial requests. While we should not sample any token - # from these partial requests, we do so for simplicity. - # We will ignore the sampled tokens from the partial requests. - # TODO: Support prompt logprobs. - logits_indices = query_start_loc[1:] - 1 - spec_decode_metadata = None - else: - # Get the number of draft tokens for each request. 
- # Iterate over the dictionary rather than all requests since not all - # requests have draft tokens. - num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) - for req_id, draft_token_ids in ( - scheduler_output.scheduled_spec_decode_tokens.items()): - req_idx = self.input_batch.req_id_to_index[req_id] - num_draft_tokens[req_idx] = len(draft_token_ids) - - spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens) - logits_indices = spec_decode_metadata.logits_indices - # Hot-Swap lora model if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) @@ -1443,24 +1491,6 @@ def execute_model( # compiled with full CUDA graphs, we have to skip them entirely. skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs - logits_indices_padded = None - if envs.VLLM_COMPUTE_PADDED_LOGITS_INDICES: - assert self.logits_indices is not None - num_logits = logits_indices.shape[0] - assert num_logits > 0 - self.logits_indices[:num_logits].copy_(logits_indices) - # Ensure we keep duplicates instead of zeros - self.logits_indices[num_logits:].fill_(logits_indices[-1].item()) - if (self.use_cuda_graph - and num_logits <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_logits_padded = self.vllm_config.pad_for_cudagraph( - num_logits) - else: - num_logits_padded = num_logits - logits_indices_padded = self.logits_indices[:num_logits_padded] - # Run the model. # Use persistent buffers for CUDA graphs. with set_forward_context( @@ -1469,7 +1499,6 @@ def execute_model( num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, skip_cuda_graphs=skip_cuda_graphs, - logits_indices_padded=logits_indices_padded, ): self.maybe_setup_kv_connector(scheduler_output) @@ -2716,6 +2745,15 @@ def initialize_kv_cache_tensors( kv_cache_config.kv_cache_groups, kv_caches, ) + attn_layers = get_layers_from_vllm_config(self.vllm_config, + Attention) + # Iterate in reversed order and add layers that re-use KV cache + # e.g. 
in YOCO-like KV sharing setups used for Gemma3n + for layer_name in reversed(attn_layers): + if layer_name in self.shared_kv_cache_layers: + self.truncated_prefill_eligible_layers.add(layer_name) + else: + break bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, From c5c7404a63e865f80275e56755e9f9182ca5cab9 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Mon, 28 Jul 2025 11:33:01 -0700 Subject: [PATCH 05/12] Fix lint Signed-off-by: Yong Hoon Shin --- .../e2e/test_kv_sharing_truncated_prefill.py | 31 +++++++++++-------- vllm/config.py | 6 ++-- vllm/v1/worker/gpu_model_runner.py | 4 +-- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_truncated_prefill.py b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py index 27c7f179275a..11fe09ab62d7 100644 --- a/tests/v1/e2e/test_kv_sharing_truncated_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py @@ -19,6 +19,7 @@ class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration): + def forward( self, input_ids: torch.Tensor, @@ -32,28 +33,31 @@ def forward( attn_metadata = get_forward_context().attn_metadata # attn_metadata is None during dummy runs if attn_metadata is not None: - assert isinstance(attn_metadata, dict) # true in V1 - # Layer 20 is a cross-decoder layer in YOCO - layer_attn_metadata = attn_metadata['model.language_model.layers.20.self_attn.attn'] + assert isinstance(attn_metadata, dict) # true in V1 + # Layer 20 is a cross-decoder layer for Gemma3n + layer_attn_metadata = attn_metadata[ + 'model.language_model.layers.20.self_attn.attn'] if hasattr(layer_attn_metadata, 'logits_indices_padded'): # This field is only set when - # enable_kv_sharing_truncated_prefill is set to True + # enable_kv_sharing_truncated_prefill is set to True assert self.cache_config.enable_kv_sharing_truncated_prefill logits_indices_padded = ( - layer_attn_metadata.logits_indices_padded - ) + layer_attn_metadata.logits_indices_padded) assert logits_indices_padded is not None num_logits_indices = layer_attn_metadata.num_logits_indices assert num_logits_indices > 0 - - logits_hs = hidden_states[logits_indices_padded] + # Reset hidden states to random values and + # only set logits at logits_indices to valid values + # Because logits_indices are the only positions that are used + # for output token sampling, this still produces same outputs + logits_hs = hidden_states[logits_indices_padded] hidden_states = torch.randn_like(hidden_states) gen_indices = logits_indices_padded[:num_logits_indices] - # Only set logits for logits_indices to valid values hidden_states[gen_indices] = logits_hs[:num_logits_indices] return hidden_states + @pytest.fixture def test_prompts(): """ @@ -90,14 +94,15 @@ def test_kv_sharing_truncated_prefill( enforce_eager: bool, test_prompts: list[str], ): - ModelRegistry.register_model("Gemma3nForConditionalGeneration", TestGemma3nForConditionalGeneration) + ModelRegistry.register_model("Gemma3nForConditionalGeneration", + TestGemma3nForConditionalGeneration) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) compilation_config = CompilationConfig( - # This allows vLLM compilation backend to handle allocating and + # This allows vLLM compilation backend to handle allocating and # managing buffers for cudagraph cudagraph_copy_inputs=True, - level=CompilationLevel. 
- PIECEWISE if not enforce_eager else CompilationLevel.NO_COMPILATION) + level=CompilationLevel.PIECEWISE + if not enforce_eager else CompilationLevel.NO_COMPILATION) with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") diff --git a/vllm/config.py b/vllm/config.py index 0a77cd1317cb..231b4a403db6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1685,8 +1685,10 @@ class CacheConfig: """The number of blocks to allocate for CPU memory.""" enable_kv_sharing_truncated_prefill: bool = False - """Skip prefill for tokens where applicable in YOCO-like KV-sharing - setups (e.g. Gemma3n)""" + """In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254), + some layers can skip tokens corresponding to prefill. This flag enables + attention metadata for eligible layers to be overriden with metadata + necessary for implementating this optimization in some models""" def compute_hash(self) -> str: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index dcf5f5381265..04ee56cb8008 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -848,8 +848,8 @@ def _prepare_inputs( None), ('num_logits_indices', int, 0), ] - # Dynamically createa a dataclass type that inherits - # from attention metadata type but includes additional + # Dynamically create a a dataclass type that inherits + # from attention metadata type but includes additional # fields logits_indices_padded and num_logits_indices # which are required for prefill truncation truncated_prefill_metadata_type = ( From 051a32bb7ca548840c25db96ed60ddf76390c810 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Mon, 28 Jul 2025 12:11:30 -0700 Subject: [PATCH 06/12] Add test for attn metadata override Signed-off-by: Yong Hoon Shin --- .../e2e/test_kv_sharing_truncated_prefill.py | 46 +++++++++++-------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_truncated_prefill.py b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py index 11fe09ab62d7..9c6dcb8fb5b0 100644 --- a/tests/v1/e2e/test_kv_sharing_truncated_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py @@ -13,6 +13,7 @@ from vllm.forward_context import get_forward_context from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration from vllm.model_executor.models.registry import ModelRegistry +from vllm.model_executor.models.utils import extract_layer_index from vllm.sequence import IntermediateTensors from ...utils import fork_new_process_for_each_test @@ -32,28 +33,35 @@ def forward( inputs_embeds, **kwargs) attn_metadata = get_forward_context().attn_metadata # attn_metadata is None during dummy runs - if attn_metadata is not None: + if (attn_metadata is not None + and self.cache_config.enable_kv_sharing_truncated_prefill): assert isinstance(attn_metadata, dict) # true in V1 - # Layer 20 is a cross-decoder layer for Gemma3n + # Gemma3n-E2B has 30 layers, with last 20 layers being + # cross-decoder layers. 
Check attention metadata is correct + for layer_name, metadata in attn_metadata.items(): + layer_idx = extract_layer_index(layer_name) + if layer_idx >= 20: + assert hasattr(metadata, 'logits_indices_padded') + assert hasattr(metadata, 'num_logits_indices') + else: + assert not hasattr(metadata, 'logits_indices_padded') + assert not hasattr(metadata, 'num_logits_indices') + + # Layer 20 is the first cross-decoder layer for Gemma3n layer_attn_metadata = attn_metadata[ 'model.language_model.layers.20.self_attn.attn'] - if hasattr(layer_attn_metadata, 'logits_indices_padded'): - # This field is only set when - # enable_kv_sharing_truncated_prefill is set to True - assert self.cache_config.enable_kv_sharing_truncated_prefill - logits_indices_padded = ( - layer_attn_metadata.logits_indices_padded) - assert logits_indices_padded is not None - num_logits_indices = layer_attn_metadata.num_logits_indices - assert num_logits_indices > 0 - # Reset hidden states to random values and - # only set logits at logits_indices to valid values - # Because logits_indices are the only positions that are used - # for output token sampling, this still produces same outputs - logits_hs = hidden_states[logits_indices_padded] - hidden_states = torch.randn_like(hidden_states) - gen_indices = logits_indices_padded[:num_logits_indices] - hidden_states[gen_indices] = logits_hs[:num_logits_indices] + logits_indices_padded = (layer_attn_metadata.logits_indices_padded) + assert logits_indices_padded is not None + num_logits_indices = layer_attn_metadata.num_logits_indices + assert num_logits_indices > 0 + # Reset hidden states to random values and + # only set logits at logits_indices to valid values + # Because logits_indices are the only positions that are used + # for output token sampling, this still produces same outputs + logits_hs = hidden_states[logits_indices_padded] + hidden_states = torch.randn_like(hidden_states) + gen_indices = logits_indices_padded[:num_logits_indices] + hidden_states[gen_indices] = logits_hs[:num_logits_indices] return hidden_states From 1edb2196218ed4863287550b9cae951c0c9474cc Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Mon, 28 Jul 2025 13:34:49 -0700 Subject: [PATCH 07/12] Create attn metadata subclass once Signed-off-by: Yong Hoon Shin --- vllm/v1/attention/backends/utils.py | 17 +++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 29 +++++++++++++---------------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 620664a5c8ba..ffc4f919c426 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -503,6 +503,12 @@ def reorder_batch_to_split_decodes_and_prefills( return modified_batch +TRUNCATED_PREFILL_METADATA_FIELDS = [ + ('logits_indices_padded', Optional[torch.Tensor], None), + ('num_logits_indices', int, 0), +] + + def subclass_attention_metadata( name_prefix: str, metadata_cls: Any, @@ -514,3 +520,14 @@ def subclass_attention_metadata( name: str = name_prefix + metadata_cls.__name__ # type: ignore Wrapped = make_dataclass(name, fields, bases=(metadata_cls, )) return Wrapped + + +def make_truncated_prefill_attention_metadata(metadata_cls: Any, ) -> Any: + """ + Return a new subclass of `metadata_cls` for truncated prefill + """ + return subclass_attention_metadata( + name_prefix="TruncatedPrefill", + metadata_cls=metadata_cls, + fields=TRUNCATED_PREFILL_METADATA_FIELDS, + ) diff --git a/vllm/v1/worker/gpu_model_runner.py 
b/vllm/v1/worker/gpu_model_runner.py index 04ee56cb8008..63b82c19c141 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -46,7 +46,8 @@ from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, - make_local_attention_virtual_batches, subclass_attention_metadata) + make_local_attention_virtual_batches, + make_truncated_prefill_attention_metadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, @@ -840,24 +841,20 @@ def _prepare_inputs( common_attn_metadata=common_attn_metadata, )) + truncated_prefill_metadata_type = type(attn_metadata_i) + if (self.cache_config.enable_kv_sharing_truncated_prefill + and self.truncated_prefill_eligible_layers): + # Dynamically create a a dataclass type that inherits + # from attention metadata type but includes additional + # fields logits_indices_padded and num_logits_indices + # which are required for prefill truncation + truncated_prefill_metadata_type = ( + make_truncated_prefill_attention_metadata( + metadata_cls=type(attn_metadata_i), )) + for layer_name in kv_cache_group_spec.layer_names: if (self.cache_config.enable_kv_sharing_truncated_prefill and layer_name in self.truncated_prefill_eligible_layers): - fields = [ - ('logits_indices_padded', Optional[torch.Tensor], - None), - ('num_logits_indices', int, 0), - ] - # Dynamically create a a dataclass type that inherits - # from attention metadata type but includes additional - # fields logits_indices_padded and num_logits_indices - # which are required for prefill truncation - truncated_prefill_metadata_type = ( - subclass_attention_metadata( - name_prefix="TruncatedPrefill", - metadata_cls=type(attn_metadata_i), - fields=fields, - )) attn_metadata[ layer_name] = truncated_prefill_metadata_type( **dataclasses.asdict(attn_metadata_i), From 034f08e65e596139d8da7fe53c8a892209b64ba3 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Mon, 28 Jul 2025 14:44:57 -0700 Subject: [PATCH 08/12] Address comments Signed-off-by: Yong Hoon Shin --- .../e2e/test_kv_sharing_truncated_prefill.py | 4 ++-- vllm/config.py | 13 ++++++++++-- vllm/model_executor/models/gemma3n.py | 21 +------------------ vllm/v1/worker/gpu_model_runner.py | 16 +++++++------- 4 files changed, 22 insertions(+), 32 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_truncated_prefill.py b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py index 9c6dcb8fb5b0..3ec3bb94b4ef 100644 --- a/tests/v1/e2e/test_kv_sharing_truncated_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py @@ -47,9 +47,9 @@ def forward( assert not hasattr(metadata, 'logits_indices_padded') assert not hasattr(metadata, 'num_logits_indices') - # Layer 20 is the first cross-decoder layer for Gemma3n + # Last layer will be a KV sharing layer layer_attn_metadata = attn_metadata[ - 'model.language_model.layers.20.self_attn.attn'] + self.model.language_model.layers[-1].self_attn.attn.layer_name] logits_indices_padded = (layer_attn_metadata.logits_indices_padded) assert logits_indices_padded is not None num_logits_indices = layer_attn_metadata.num_logits_indices diff --git a/vllm/config.py b/vllm/config.py index 231b4a403db6..829d7c3be974 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1685,10 +1685,14 @@ class CacheConfig: """The number of blocks to allocate for CPU memory.""" 
enable_kv_sharing_truncated_prefill: bool = False - """In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254), + """This feature is work in progress and no prefill optimization takes place + with this flag enabled currently. + + In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254), some layers can skip tokens corresponding to prefill. This flag enables attention metadata for eligible layers to be overriden with metadata - necessary for implementating this optimization in some models""" + necessary for implementating this optimization in some models (e.g. Gemma3n) + """ def compute_hash(self) -> str: """ @@ -1731,6 +1735,11 @@ def _verify_args(self) -> Self: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + if self.enable_kv_sharing_truncated_prefill: + logger.warning_once( + "This feature is currently work in progress " + "and not functional yet (i.e. no prefill savings)") + return self def _verify_cache_dtype(self) -> None: diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index 36a9fda19938..a4390e783d41 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -581,15 +581,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: Gemma3nDecoderLayer( config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") - - first_kv_shared_layer_idx = (config.num_hidden_layers - - config.num_kv_shared_layers) - # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO) - self.self_decoder_layers = self.layers[:first_kv_shared_layer_idx] - # Layer idx 20-34 are cross-decoder layers in YOCO - # Refer to YOCO paper https://arxiv.org/abs/2405.05254 - self.cross_decoder_layers = self.layers[first_kv_shared_layer_idx:] - self.norm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, @@ -655,17 +646,7 @@ def forward( hidden_states = torch.stack(hidden_states, dim=0) # Transformer blocks. 
- for layer_idx, layer in enumerate(self.self_decoder_layers): - # [altup_num_inputs, num_tokens, hidden_size] - hidden_states = layer( - positions=positions, - hidden_states=hidden_states, - per_layer_input=per_layer_inputs[:, layer_idx, :], - **kwargs, - ) - - for layer_idx, layer in enumerate(self.cross_decoder_layers, - start=len(self.self_decoder_layers)): + for layer_idx, layer in enumerate(self.layers): # [altup_num_inputs, num_tokens, hidden_size] hidden_states = layer( positions=positions, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 63b82c19c141..465c966e3b6a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -841,7 +841,7 @@ def _prepare_inputs( common_attn_metadata=common_attn_metadata, )) - truncated_prefill_metadata_type = type(attn_metadata_i) + truncated_prefill_metadata = attn_metadata_i if (self.cache_config.enable_kv_sharing_truncated_prefill and self.truncated_prefill_eligible_layers): # Dynamically create a a dataclass type that inherits @@ -851,16 +851,16 @@ def _prepare_inputs( truncated_prefill_metadata_type = ( make_truncated_prefill_attention_metadata( metadata_cls=type(attn_metadata_i), )) + truncated_prefill_metadata = truncated_prefill_metadata_type( + **dataclasses.asdict(attn_metadata_i), + logits_indices_padded=logits_indices_padded, + num_logits_indices=logits_indices.size(0), + ) for layer_name in kv_cache_group_spec.layer_names: if (self.cache_config.enable_kv_sharing_truncated_prefill and layer_name in self.truncated_prefill_eligible_layers): - attn_metadata[ - layer_name] = truncated_prefill_metadata_type( - **dataclasses.asdict(attn_metadata_i), - logits_indices_padded=logits_indices_padded, - num_logits_indices=logits_indices.size(0), - ) + attn_metadata[layer_name] = truncated_prefill_metadata continue attn_metadata[layer_name] = attn_metadata_i @@ -2745,7 +2745,7 @@ def initialize_kv_cache_tensors( attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) # Iterate in reversed order and add layers that re-use KV cache - # e.g. in YOCO-like KV sharing setups used for Gemma3n + # e.g. in YOCO-like KV sharing setups (e.g. Gemma3n) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: self.truncated_prefill_eligible_layers.add(layer_name) From 29cf6dc43bc36868659b8a338d957f1e543ce1b2 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Mon, 28 Jul 2025 15:05:50 -0700 Subject: [PATCH 09/12] More comments Signed-off-by: Yong Hoon Shin --- vllm/entrypoints/llm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 96d7bfa0358a..c23af66bd183 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -145,6 +145,10 @@ class LLM: compilation_config: Either an integer or a dictionary. If it is an integer, it is used as the level of compilation optimization. If it is a dictionary, it can specify the full compilation configuration. + enable_kv_sharing_truncated_prefill: Work in progress feature to + enable metadata required to skip prefill in certain KV sharing + setups (e.g. YOCO). See + [CacheConfig][vllm.config.CacheConfig]. **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs]. 
Note: From 2e3b4c352c32c9838416a61c778b8a5830b0d746 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Mon, 28 Jul 2025 17:59:17 -0700 Subject: [PATCH 10/12] Rename truncated prefill to fast prefill Signed-off-by: Yong Hoon Shin --- ...ill.py => test_kv_sharing_fast_prefill.py} | 6 ++-- vllm/config.py | 4 +-- vllm/engine/arg_utils.py | 12 ++++---- vllm/entrypoints/llm.py | 8 ++--- vllm/v1/attention/backends/utils.py | 11 +++---- vllm/v1/worker/gpu_model_runner.py | 30 +++++++++---------- 6 files changed, 35 insertions(+), 36 deletions(-) rename tests/v1/e2e/{test_kv_sharing_truncated_prefill.py => test_kv_sharing_fast_prefill.py} (96%) diff --git a/tests/v1/e2e/test_kv_sharing_truncated_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py similarity index 96% rename from tests/v1/e2e/test_kv_sharing_truncated_prefill.py rename to tests/v1/e2e/test_kv_sharing_fast_prefill.py index 3ec3bb94b4ef..616fc7a86059 100644 --- a/tests/v1/e2e/test_kv_sharing_truncated_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -34,7 +34,7 @@ def forward( attn_metadata = get_forward_context().attn_metadata # attn_metadata is None during dummy runs if (attn_metadata is not None - and self.cache_config.enable_kv_sharing_truncated_prefill): + and self.cache_config.kv_sharing_fast_prefill): assert isinstance(attn_metadata, dict) # true in V1 # Gemma3n-E2B has 30 layers, with last 20 layers being # cross-decoder layers. Check attention metadata is correct @@ -97,7 +97,7 @@ def test_prompts(): @fork_new_process_for_each_test @pytest.mark.parametrize("enforce_eager", [True, False]) -def test_kv_sharing_truncated_prefill( +def test_kv_sharing_fast_prefill( monkeypatch: pytest.MonkeyPatch, enforce_eager: bool, test_prompts: list[str], @@ -129,7 +129,7 @@ def test_kv_sharing_truncated_prefill( llm = LLM(model="google/gemma-3n-E2B-it", enforce_eager=enforce_eager, compilation_config=compilation_config, - enable_kv_sharing_truncated_prefill=True) + kv_sharing_fast_prefill=True) optimized_responses = llm.generate(test_prompts, sampling_params) misses = 0 diff --git a/vllm/config.py b/vllm/config.py index 829d7c3be974..c582cebedf9b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1684,7 +1684,7 @@ class CacheConfig: num_cpu_blocks: Optional[int] = field(default=None, init=False) """The number of blocks to allocate for CPU memory.""" - enable_kv_sharing_truncated_prefill: bool = False + kv_sharing_fast_prefill: bool = False """This feature is work in progress and no prefill optimization takes place with this flag enabled currently. @@ -1735,7 +1735,7 @@ def _verify_args(self) -> Self: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") - if self.enable_kv_sharing_truncated_prefill: + if self.kv_sharing_fast_prefill: logger.warning_once( "This feature is currently work in progress " "and not functional yet (i.e. 
no prefill savings)") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 10a2331799d5..efa5880096c1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -438,8 +438,8 @@ class EngineArgs: # DEPRECATED enable_prompt_adapter: bool = False - enable_kv_sharing_truncated_prefill: bool = \ - CacheConfig.enable_kv_sharing_truncated_prefill + kv_sharing_fast_prefill: bool = \ + CacheConfig.kv_sharing_fast_prefill def __post_init__(self): # support `EngineArgs(compilation_config={...})` @@ -689,9 +689,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **cache_kwargs["cpu_offload_gb"]) cache_group.add_argument("--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]) - cache_group.add_argument( - "--enable-kv-sharing-truncated-prefill", - **cache_kwargs["enable_kv_sharing_truncated_prefill"]) + cache_group.add_argument("--kv-sharing-fast-prefill", + **cache_kwargs["kv_sharing_fast_prefill"]) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -1062,8 +1061,7 @@ def create_engine_config( prefix_caching_hash_algo=self.prefix_caching_hash_algo, cpu_offload_gb=self.cpu_offload_gb, calculate_kv_scales=self.calculate_kv_scales, - enable_kv_sharing_truncated_prefill=self. - enable_kv_sharing_truncated_prefill, + kv_sharing_fast_prefill=self.kv_sharing_fast_prefill, ) # Get the current placement group if Ray is initialized and diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c23af66bd183..b4b0ed743a07 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -145,7 +145,7 @@ class LLM: compilation_config: Either an integer or a dictionary. If it is an integer, it is used as the level of compilation optimization. If it is a dictionary, it can specify the full compilation configuration. - enable_kv_sharing_truncated_prefill: Work in progress feature to + kv_sharing_fast_prefill: Work in progress feature to enable metadata required to skip prefill in certain KV sharing setups (e.g. YOCO). See [CacheConfig][vllm.config.CacheConfig]. 
@@ -197,7 +197,7 @@ def __init__( override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, - enable_kv_sharing_truncated_prefill: bool = False, + kv_sharing_fast_prefill: bool = False, **kwargs, ) -> None: """LLM constructor.""" @@ -271,8 +271,8 @@ def __init__( mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, - enable_kv_sharing_truncated_prefill=\ - enable_kv_sharing_truncated_prefill, + kv_sharing_fast_prefill=\ + kv_sharing_fast_prefill, **kwargs, ) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index ffc4f919c426..5afa71768d20 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -503,7 +503,7 @@ def reorder_batch_to_split_decodes_and_prefills( return modified_batch -TRUNCATED_PREFILL_METADATA_FIELDS = [ +FAST_PREFILL_METADATA_FIELDS = [ ('logits_indices_padded', Optional[torch.Tensor], None), ('num_logits_indices', int, 0), ] @@ -522,12 +522,13 @@ def subclass_attention_metadata( return Wrapped -def make_truncated_prefill_attention_metadata(metadata_cls: Any, ) -> Any: +def make_kv_sharing_fast_prefill_attention_metadata( + metadata_cls: Any, ) -> Any: """ - Return a new subclass of `metadata_cls` for truncated prefill + Return a new subclass of `metadata_cls` for fast prefill """ return subclass_attention_metadata( - name_prefix="TruncatedPrefill", + name_prefix="FastPrefill", metadata_cls=metadata_cls, - fields=TRUNCATED_PREFILL_METADATA_FIELDS, + fields=FAST_PREFILL_METADATA_FIELDS, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 465c966e3b6a..6815b7d345e9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -46,8 +46,8 @@ from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, - make_local_attention_virtual_batches, - make_truncated_prefill_attention_metadata) + make_kv_sharing_fast_prefill_attention_metadata, + make_local_attention_virtual_batches) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, @@ -319,10 +319,10 @@ def __init__( # means this layer will perform attention using the keys and values # from the KV cache of `shared_kv_cache_layers[layer_name]`. 
self.shared_kv_cache_layers: dict[str, str] = {} - self.truncated_prefill_eligible_layers: set[str] = set() + self.fast_prefill_eligible_layers: set[str] = set() self.logits_indices = None - if self.cache_config.enable_kv_sharing_truncated_prefill: + if self.cache_config.kv_sharing_fast_prefill: self.logits_indices = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=self.device) @@ -766,7 +766,7 @@ def _prepare_inputs( logits_indices = spec_decode_metadata.logits_indices logits_indices_padded = None - if self.cache_config.enable_kv_sharing_truncated_prefill: + if self.cache_config.kv_sharing_fast_prefill: assert self.logits_indices is not None num_logits = logits_indices.shape[0] assert num_logits > 0 @@ -841,26 +841,26 @@ def _prepare_inputs( common_attn_metadata=common_attn_metadata, )) - truncated_prefill_metadata = attn_metadata_i - if (self.cache_config.enable_kv_sharing_truncated_prefill - and self.truncated_prefill_eligible_layers): + fast_prefill_metadata = attn_metadata_i + if (self.cache_config.kv_sharing_fast_prefill + and self.fast_prefill_eligible_layers): # Dynamically create a a dataclass type that inherits # from attention metadata type but includes additional # fields logits_indices_padded and num_logits_indices # which are required for prefill truncation - truncated_prefill_metadata_type = ( - make_truncated_prefill_attention_metadata( + fast_prefill_metadata_type = ( + make_kv_sharing_fast_prefill_attention_metadata( metadata_cls=type(attn_metadata_i), )) - truncated_prefill_metadata = truncated_prefill_metadata_type( + fast_prefill_metadata = fast_prefill_metadata_type( **dataclasses.asdict(attn_metadata_i), logits_indices_padded=logits_indices_padded, num_logits_indices=logits_indices.size(0), ) for layer_name in kv_cache_group_spec.layer_names: - if (self.cache_config.enable_kv_sharing_truncated_prefill and - layer_name in self.truncated_prefill_eligible_layers): - attn_metadata[layer_name] = truncated_prefill_metadata + if (self.cache_config.kv_sharing_fast_prefill + and layer_name in self.fast_prefill_eligible_layers): + attn_metadata[layer_name] = fast_prefill_metadata continue attn_metadata[layer_name] = attn_metadata_i @@ -2748,7 +2748,7 @@ def initialize_kv_cache_tensors( # e.g. in YOCO-like KV sharing setups (e.g. Gemma3n) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: - self.truncated_prefill_eligible_layers.add(layer_name) + self.fast_prefill_eligible_layers.add(layer_name) else: break From eaa142ac6382ee4de756cc04f159a61ffae20534 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Mon, 28 Jul 2025 18:14:11 -0700 Subject: [PATCH 11/12] Update warning log Signed-off-by: Yong Hoon Shin --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index c582cebedf9b..599eab35b4da 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1737,7 +1737,7 @@ def _verify_args(self) -> Self: if self.kv_sharing_fast_prefill: logger.warning_once( - "This feature is currently work in progress " + "--kv-sharing-fast-prefill is currently work in progress " "and not functional yet (i.e. 
no prefill savings)") return self From ba83304143428b5ee18994c923a31a7eecf71699 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Tue, 29 Jul 2025 09:20:31 -0700 Subject: [PATCH 12/12] Address comments Signed-off-by: Yong Hoon Shin --- vllm/entrypoints/llm.py | 7 ------ vllm/v1/attention/backends/utils.py | 6 ++--- vllm/v1/worker/gpu_model_runner.py | 36 ++++++++++++++++------------- 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b4b0ed743a07..2f766a2dae57 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -145,10 +145,6 @@ class LLM: compilation_config: Either an integer or a dictionary. If it is an integer, it is used as the level of compilation optimization. If it is a dictionary, it can specify the full compilation configuration. - kv_sharing_fast_prefill: Work in progress feature to - enable metadata required to skip prefill in certain KV sharing - setups (e.g. YOCO). See - [CacheConfig][vllm.config.CacheConfig]. **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs]. Note: @@ -197,7 +193,6 @@ def __init__( override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, - kv_sharing_fast_prefill: bool = False, **kwargs, ) -> None: """LLM constructor.""" @@ -271,8 +266,6 @@ def __init__( mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, - kv_sharing_fast_prefill=\ - kv_sharing_fast_prefill, **kwargs, ) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 5afa71768d20..56a49bb37c07 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -503,7 +503,7 @@ def reorder_batch_to_split_decodes_and_prefills( return modified_batch -FAST_PREFILL_METADATA_FIELDS = [ +KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ ('logits_indices_padded', Optional[torch.Tensor], None), ('num_logits_indices', int, 0), ] @@ -528,7 +528,7 @@ def make_kv_sharing_fast_prefill_attention_metadata( Return a new subclass of `metadata_cls` for fast prefill """ return subclass_attention_metadata( - name_prefix="FastPrefill", + name_prefix="KVSharingFastPrefill", metadata_cls=metadata_cls, - fields=FAST_PREFILL_METADATA_FIELDS, + fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6815b7d345e9..d82cfb62556e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -319,13 +319,12 @@ def __init__( # means this layer will perform attention using the keys and values # from the KV cache of `shared_kv_cache_layers[layer_name]`. 
self.shared_kv_cache_layers: dict[str, str] = {} - self.fast_prefill_eligible_layers: set[str] = set() + self.kv_sharing_fast_prefill_eligible_layers: set[str] = set() - self.logits_indices = None + self.kv_sharing_fast_prefill_logits_indices = None if self.cache_config.kv_sharing_fast_prefill: - self.logits_indices = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) + self.kv_sharing_fast_prefill_logits_indices = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device=self.device) def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ @@ -767,15 +766,17 @@ def _prepare_inputs( logits_indices_padded = None if self.cache_config.kv_sharing_fast_prefill: - assert self.logits_indices is not None + assert self.kv_sharing_fast_prefill_logits_indices is not None num_logits = logits_indices.shape[0] assert num_logits > 0 - self.logits_indices[:num_logits].copy_(logits_indices) - # self.logits_indices[num_logits:] might have leftover indices from - # previous iterations, whose values may be greater than the batch - # size in the current iteration. To ensure the indices are always + self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( + logits_indices) + # There might be leftover indices in logits_indices[num_logits:] + # from previous iterations, whose values may be greater than the + # batch size in the current iteration. To ensure indices are always # valid, we fill the padded indices with the last index. - self.logits_indices[num_logits:].fill_(logits_indices[-1].item()) + self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( + logits_indices[-1].item()) if (self.use_cuda_graph and num_logits <= self.cudagraph_batch_sizes[-1]): # Use piecewise CUDA graphs. @@ -784,7 +785,9 @@ def _prepare_inputs( num_logits) else: num_logits_padded = num_logits - logits_indices_padded = self.logits_indices[:num_logits_padded] + logits_indices_padded = ( + self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded] + ) attn_metadata: dict[str, Any] = {} # Prepare the attention metadata for each KV cache group and make layers @@ -843,7 +846,7 @@ def _prepare_inputs( fast_prefill_metadata = attn_metadata_i if (self.cache_config.kv_sharing_fast_prefill - and self.fast_prefill_eligible_layers): + and self.kv_sharing_fast_prefill_eligible_layers): # Dynamically create a a dataclass type that inherits # from attention metadata type but includes additional # fields logits_indices_padded and num_logits_indices @@ -858,8 +861,8 @@ def _prepare_inputs( ) for layer_name in kv_cache_group_spec.layer_names: - if (self.cache_config.kv_sharing_fast_prefill - and layer_name in self.fast_prefill_eligible_layers): + if (self.cache_config.kv_sharing_fast_prefill and layer_name + in self.kv_sharing_fast_prefill_eligible_layers): attn_metadata[layer_name] = fast_prefill_metadata continue @@ -2748,7 +2751,8 @@ def initialize_kv_cache_tensors( # e.g. in YOCO-like KV sharing setups (e.g. Gemma3n) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: - self.fast_prefill_eligible_layers.add(layer_name) + self.kv_sharing_fast_prefill_eligible_layers.add( + layer_name) else: break
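Note (illustrative sketch, separate from the patches above): the padding scheme in `_prepare_inputs` copies the real logits indices into the head of a persistent buffer, fills the tail with the last valid index so padded slots stay in-bounds and merely duplicate the final position, then slices the buffer to the cudagraph-padded length. A rough CPU-only example with arbitrary sizes (the hard-coded `num_logits_padded` stands in for `pad_for_cudagraph`):

    import torch

    max_num_tokens = 16  # stands in for self.max_num_tokens
    buf = torch.zeros(max_num_tokens, dtype=torch.int32)  # persistent buffer

    logits_indices = torch.tensor([3, 7, 11], dtype=torch.int32)
    num_logits = logits_indices.shape[0]

    buf[:num_logits].copy_(logits_indices)
    # Pad with the last valid index rather than leaving stale entries, so a
    # gather through the padded slice never reads past the current batch.
    buf[num_logits:].fill_(logits_indices[-1].item())

    num_logits_padded = 4  # assumed cudagraph-padded size for this example
    logits_indices_padded = buf[:num_logits_padded]
    assert logits_indices_padded.tolist() == [3, 7, 11, 11]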