diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py
new file mode 100644
index 000000000000..616fc7a86059
--- /dev/null
+++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py
@@ -0,0 +1,143 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import gc
+import random
+from typing import Optional, Union
+
+import pytest
+import torch
+
+from vllm import LLM, SamplingParams
+from vllm.config import CompilationConfig, CompilationLevel
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
+from vllm.model_executor.models.registry import ModelRegistry
+from vllm.model_executor.models.utils import extract_layer_index
+from vllm.sequence import IntermediateTensors
+
+from ...utils import fork_new_process_for_each_test
+
+
+class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds, **kwargs)
+        attn_metadata = get_forward_context().attn_metadata
+        # attn_metadata is None during dummy runs
+        if (attn_metadata is not None
+                and self.cache_config.kv_sharing_fast_prefill):
+            assert isinstance(attn_metadata, dict)  # true in V1
+            # Gemma3n-E2B has 30 layers, with the last 20 layers being
+            # cross-decoder layers. Check the attention metadata is correct.
+            for layer_name, metadata in attn_metadata.items():
+                layer_idx = extract_layer_index(layer_name)
+                if layer_idx >= 20:
+                    assert hasattr(metadata, 'logits_indices_padded')
+                    assert hasattr(metadata, 'num_logits_indices')
+                else:
+                    assert not hasattr(metadata, 'logits_indices_padded')
+                    assert not hasattr(metadata, 'num_logits_indices')
+
+            # The last layer will be a KV sharing layer
+            layer_attn_metadata = attn_metadata[
+                self.model.language_model.layers[-1].self_attn.attn.layer_name]
+            logits_indices_padded = (
+                layer_attn_metadata.logits_indices_padded)
+            assert logits_indices_padded is not None
+            num_logits_indices = layer_attn_metadata.num_logits_indices
+            assert num_logits_indices > 0
+            # Reset hidden states to random values and
+            # only set positions in logits_indices back to valid values.
+            # Because logits_indices are the only positions that are used
+            # for output token sampling, this still produces the same outputs.
+            logits_hs = hidden_states[logits_indices_padded]
+            hidden_states = torch.randn_like(hidden_states)
+            gen_indices = logits_indices_padded[:num_logits_indices]
+            hidden_states[gen_indices] = logits_hs[:num_logits_indices]
+
+        return hidden_states
+
+
+@pytest.fixture
+def test_prompts():
+    """
+    Adapted from tests/v1/e2e/test_spec_decode.py
+    """
+    prompt_types = ["repeat", "sentence"]
+    # Setting a higher num_prompts increases the chance of numerics mismatch,
+    # since matrix multiplication numerics depend on the batch dimension.
+    num_prompts = 10
+    prompts = []
+
+    random.seed(0)
+    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+
+    for kind in random_prompt_type_choices:
+        word_choices = ["test", "temp", "hello", "where"]
+        word = random.choice(word_choices)
+        if kind == "repeat":
+            prompt = f"""please repeat the word '{word}' 10 times."""
+        elif kind == "sentence":
+            prompt = f"""please give a ten-word sentence that
+                uses the word {word} at least once."""
+        else:
+            raise ValueError(f"Unknown prompt type: {kind}")
+        prompts.append(prompt)
+
+    return prompts
+
+
+@fork_new_process_for_each_test
+@pytest.mark.parametrize("enforce_eager", [True, False])
+def test_kv_sharing_fast_prefill(
+    monkeypatch: pytest.MonkeyPatch,
+    enforce_eager: bool,
+    test_prompts: list[str],
+):
+    ModelRegistry.register_model("Gemma3nForConditionalGeneration",
+                                 TestGemma3nForConditionalGeneration)
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
+    compilation_config = CompilationConfig(
+        # This allows the vLLM compilation backend to handle allocating and
+        # managing buffers for cudagraph
+        cudagraph_copy_inputs=True,
+        level=CompilationLevel.PIECEWISE
+        if not enforce_eager else CompilationLevel.NO_COMPILATION)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        llm = LLM(
+            model="google/gemma-3n-E2B-it",
+            enforce_eager=enforce_eager,
+            compilation_config=compilation_config,
+        )
+        ref_responses = llm.generate(test_prompts, sampling_params)
+
+        del llm
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        llm = LLM(model="google/gemma-3n-E2B-it",
+                  enforce_eager=enforce_eager,
+                  compilation_config=compilation_config,
+                  kv_sharing_fast_prefill=True)
+        optimized_responses = llm.generate(test_prompts, sampling_params)
+
+        misses = 0
+
+        for ref_response, optimized_response in zip(ref_responses,
+                                                    optimized_responses):
+            if ref_response.outputs[0].text != optimized_response.outputs[
+                    0].text:
+                misses += 1
+
+        assert misses == 0
diff --git a/vllm/config.py b/vllm/config.py
index 07df71ec51ef..599eab35b4da
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1684,6 +1684,16 @@ class CacheConfig:
     num_cpu_blocks: Optional[int] = field(default=None, init=False)
     """The number of blocks to allocate for CPU memory."""
 
+    kv_sharing_fast_prefill: bool = False
+    """This feature is a work in progress; currently no prefill optimization
+    takes place when this flag is enabled.
+
+    In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
+    some layers can skip tokens corresponding to prefill. This flag enables
+    attention metadata for eligible layers to be overridden with the metadata
+    necessary for implementing this optimization in some models (e.g. Gemma3n).
+    """
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -1725,6 +1735,11 @@ def _verify_args(self) -> Self:
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")
 
+        if self.kv_sharing_fast_prefill:
+            logger.warning_once(
+                "--kv-sharing-fast-prefill is currently a work in progress "
+                "and not functional yet (i.e. no prefill savings)")
+
         return self
 
     def _verify_cache_dtype(self) -> None:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 709968004718..efa5880096c1 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -438,6 +438,9 @@ class EngineArgs:
     # DEPRECATED
     enable_prompt_adapter: bool = False
 
+    kv_sharing_fast_prefill: bool = \
+        CacheConfig.kv_sharing_fast_prefill
+
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
         # without having to manually construct a
@@ -686,6 +689,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                  **cache_kwargs["cpu_offload_gb"])
         cache_group.add_argument("--calculate-kv-scales",
                                  **cache_kwargs["calculate_kv_scales"])
+        cache_group.add_argument("--kv-sharing-fast-prefill",
+                                 **cache_kwargs["kv_sharing_fast_prefill"])
 
         # Multimodal related configs
         multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1056,6 +1061,7 @@ def create_engine_config(
             prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
+            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
         )
 
         # Get the current placement group if Ray is initialized and
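Note: the new flag is plumbed through both the Python and CLI entry points above. A minimal usage sketch follows; the engine arguments mirror the e2e test at the top of this diff, and the serve command is assumed to pick the flag up via EngineArgs like the other cache options:

    from vllm import LLM

    # Python API, as exercised by the test above
    llm = LLM(model="google/gemma-3n-E2B-it", kv_sharing_fast_prefill=True)

    # Equivalent CLI form exposed by add_cli_args:
    #   vllm serve google/gemma-3n-E2B-it --kv-sharing-fast-prefill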
diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index 7d163320e0d6..a4390e783d41 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -771,6 +771,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         del lora_config  # Unused.
         super().__init__()
         self.config = config
+        self.cache_config = vllm_config.cache_config
         self.model = Gemma3nModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
         self.logits_processor = LogitsProcessor(
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index fc8649d587ee..56a49bb37c07 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -3,8 +3,8 @@
 import abc
 import functools
 from abc import abstractmethod
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
+from dataclasses import dataclass, make_dataclass
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
 
 import numpy as np
 import torch
@@ -501,3 +501,34 @@ def reorder_batch_to_split_decodes_and_prefills(
             modified_batch = True
 
     return modified_batch
+
+
+KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
+    ('logits_indices_padded', Optional[torch.Tensor], None),
+    ('num_logits_indices', int, 0),
+]
+
+
+def subclass_attention_metadata(
+    name_prefix: str,
+    metadata_cls: Any,
+    fields: list[tuple[str, Any, Any]],
+) -> Any:
+    """
+    Return a new subclass of `metadata_cls` with additional fields
+    """
+    name: str = name_prefix + metadata_cls.__name__  # type: ignore
+    Wrapped = make_dataclass(name, fields, bases=(metadata_cls, ))
+    return Wrapped
+
+
+def make_kv_sharing_fast_prefill_attention_metadata(
+        metadata_cls: Any, ) -> Any:
+    """
+    Return a new subclass of `metadata_cls` for fast prefill
+    """
+    return subclass_attention_metadata(
+        name_prefix="KVSharingFastPrefill",
+        metadata_cls=metadata_cls,
+        fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
+    )
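For context, the dynamic subclassing above relies only on the standard-library make_dataclass helper. A minimal, self-contained sketch of what make_kv_sharing_fast_prefill_attention_metadata produces, using a made-up DummyAttentionMetadata class in place of a real backend metadata class:

    from dataclasses import asdict, dataclass, make_dataclass
    from typing import Optional

    import torch

    @dataclass
    class DummyAttentionMetadata:  # stand-in for a backend metadata dataclass
        num_actual_tokens: int
        max_query_len: int

    FastPrefillDummyMetadata = make_dataclass(
        "KVSharingFastPrefillDummyAttentionMetadata",
        [("logits_indices_padded", Optional[torch.Tensor], None),
         ("num_logits_indices", int, 0)],
        bases=(DummyAttentionMetadata, ))

    base = DummyAttentionMetadata(num_actual_tokens=8, max_query_len=4)
    fast = FastPrefillDummyMetadata(**asdict(base),
                                    logits_indices_padded=torch.tensor([3, 7]),
                                    num_logits_indices=2)
    assert isinstance(fast, DummyAttentionMetadata)
    assert hasattr(fast, "logits_indices_padded")

The wrapped instance still passes isinstance checks against the original backend class, so existing attention code keeps working while eligible layers see the two extra fields.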
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 32004ced4aae..d82cfb62556e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import dataclasses
 import gc
 import time
 from contextlib import contextmanager
@@ -45,6 +46,7 @@
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata,
+    make_kv_sharing_fast_prefill_attention_metadata,
     make_local_attention_virtual_batches)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -317,6 +319,12 @@ def __init__(
         # means this layer will perform attention using the keys and values
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
+        self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()
+
+        self.kv_sharing_fast_prefill_logits_indices = None
+        if self.cache_config.kv_sharing_fast_prefill:
+            self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
+                self.max_num_tokens, dtype=torch.int32, device=self.device)
 
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
@@ -732,6 +740,55 @@ def _prepare_inputs(
 
         spec_decode_common_attn_metadata = None
 
+        use_spec_decode = len(
+            scheduler_output.scheduled_spec_decode_tokens) > 0
+        if not use_spec_decode:
+            # NOTE(woosuk): Due to chunked prefills, the batch may contain
+            # partial requests. While we should not sample any token
+            # from these partial requests, we do so for simplicity.
+            # We will ignore the sampled tokens from the partial requests.
+            # TODO: Support prompt logprobs.
+            logits_indices = query_start_loc[1:] - 1
+            spec_decode_metadata = None
+        else:
+            # Get the number of draft tokens for each request.
+            # Iterate over the dictionary rather than all requests since not all
+            # requests have draft tokens.
+            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+            for req_id, draft_token_ids in (
+                    scheduler_output.scheduled_spec_decode_tokens.items()):
+                req_idx = self.input_batch.req_id_to_index[req_id]
+                num_draft_tokens[req_idx] = len(draft_token_ids)
+
+            spec_decode_metadata = self._calc_spec_decode_metadata(
+                num_draft_tokens, cu_num_tokens)
+            logits_indices = spec_decode_metadata.logits_indices
+
+        logits_indices_padded = None
+        if self.cache_config.kv_sharing_fast_prefill:
+            assert self.kv_sharing_fast_prefill_logits_indices is not None
+            num_logits = logits_indices.shape[0]
+            assert num_logits > 0
+            self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
+                logits_indices)
+            # There might be leftover indices in logits_indices[num_logits:]
+            # from previous iterations, whose values may be greater than the
+            # batch size in the current iteration. To ensure indices are always
+            # valid, we fill the padded indices with the last index.
+            self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
+                logits_indices[-1].item())
+            if (self.use_cuda_graph
+                    and num_logits <= self.cudagraph_batch_sizes[-1]):
+                # Use piecewise CUDA graphs.
+                # Add padding to the batch size.
+                num_logits_padded = self.vllm_config.pad_for_cudagraph(
+                    num_logits)
+            else:
+                num_logits_padded = num_logits
+            logits_indices_padded = (
+                self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
+            )
+
         attn_metadata: dict[str, Any] = {}
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
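To make the index-padding scheme above concrete, here is a small standalone sketch of the same logic with illustrative sizes (the real code uses the persistent kv_sharing_fast_prefill_logits_indices buffer and pad_for_cudagraph):

    import torch

    max_num_tokens = 16
    buf = torch.zeros(max_num_tokens, dtype=torch.int32)

    # Suppose three requests end at token positions 4, 9 and 14.
    logits_indices = torch.tensor([4, 9, 14], dtype=torch.int32)
    num_logits = logits_indices.shape[0]

    buf[:num_logits].copy_(logits_indices)
    # Fill the tail with the last valid index so stale values from earlier
    # steps can never index out of range for the current batch.
    buf[num_logits:].fill_(logits_indices[-1].item())

    num_logits_padded = 4  # e.g. the next captured cudagraph batch size
    logits_indices_padded = buf[:num_logits_padded]
    # logits_indices_padded is now tensor([4, 9, 14, 14], dtype=torch.int32)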
@@ -787,37 +844,34 @@ def _prepare_inputs(
                     common_attn_metadata=common_attn_metadata,
                 ))
 
+            fast_prefill_metadata = attn_metadata_i
+            if (self.cache_config.kv_sharing_fast_prefill
+                    and self.kv_sharing_fast_prefill_eligible_layers):
+                # Dynamically create a dataclass type that inherits from
+                # the attention metadata type but includes the additional
+                # fields logits_indices_padded and num_logits_indices,
+                # which are required for prefill truncation.
+                fast_prefill_metadata_type = (
+                    make_kv_sharing_fast_prefill_attention_metadata(
+                        metadata_cls=type(attn_metadata_i), ))
+                fast_prefill_metadata = fast_prefill_metadata_type(
+                    **dataclasses.asdict(attn_metadata_i),
+                    logits_indices_padded=logits_indices_padded,
+                    num_logits_indices=logits_indices.size(0),
+                )
+
             for layer_name in kv_cache_group_spec.layer_names:
+                if (self.cache_config.kv_sharing_fast_prefill and layer_name
+                        in self.kv_sharing_fast_prefill_eligible_layers):
+                    attn_metadata[layer_name] = fast_prefill_metadata
+                    continue
+
                 attn_metadata[layer_name] = attn_metadata_i
 
         attention_cuda_graphs = all(
             b.can_run_in_cudagraph(common_attn_metadata)
             for b in self.attn_metadata_builders)
 
-        use_spec_decode = len(
-            scheduler_output.scheduled_spec_decode_tokens) > 0
-        if not use_spec_decode:
-            # NOTE(woosuk): Due to chunked prefills, the batch may contain
-            # partial requests. While we should not sample any token
-            # from these partial requests, we do so for simplicity.
-            # We will ignore the sampled tokens from the partial requests.
-            # TODO: Support prompt logprobs.
-            logits_indices = query_start_loc[1:] - 1
-            spec_decode_metadata = None
-        else:
-            # Get the number of draft tokens for each request.
-            # Iterate over the dictionary rather than all requests since not all
-            # requests have draft tokens.
-            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
-            for req_id, draft_token_ids in (
-                    scheduler_output.scheduled_spec_decode_tokens.items()):
-                req_idx = self.input_batch.req_id_to_index[req_id]
-                num_draft_tokens[req_idx] = len(draft_token_ids)
-
-            spec_decode_metadata = self._calc_spec_decode_metadata(
-                num_draft_tokens, cu_num_tokens)
-            logits_indices = spec_decode_metadata.logits_indices
-
         # Hot-Swap lora model
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
@@ -1364,6 +1418,7 @@ def execute_model(
                 spec_decode_metadata, num_scheduled_tokens_np,
                 spec_decode_common_attn_metadata) = (
                     self._prepare_inputs(scheduler_output))
+
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -2690,6 +2745,16 @@ def initialize_kv_cache_tensors(
             kv_cache_config.kv_cache_groups,
             kv_caches,
         )
+        attn_layers = get_layers_from_vllm_config(self.vllm_config,
+                                                  Attention)
+        # Iterate in reverse order and collect layers that re-use the KV cache
+        # of an earlier layer, e.g. in YOCO-like KV sharing setups (Gemma3n).
+        for layer_name in reversed(attn_layers):
+            if layer_name in self.shared_kv_cache_layers:
+                self.kv_sharing_fast_prefill_eligible_layers.add(
+                    layer_name)
+            else:
+                break
 
         bind_kv_cache(kv_caches,
                       self.compilation_config.static_forward_context,
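A minimal sketch of the trailing-layer selection in initialize_kv_cache_tensors above, with made-up layer names (the real code walks the Attention layers registered in the vLLM config):

    # Layers 0-3 own their KV cache; layers 4-5 re-use layer 3's cache.
    attn_layers = [f"model.layers.{i}.self_attn.attn" for i in range(6)]
    shared_kv_cache_layers = {
        "model.layers.4.self_attn.attn": "model.layers.3.self_attn.attn",
        "model.layers.5.self_attn.attn": "model.layers.3.self_attn.attn",
    }

    eligible: set[str] = set()
    # Walk backwards and stop at the first layer that owns its KV cache, so
    # only the trailing run of KV-sharing (cross-decoder) layers is marked.
    for layer_name in reversed(attn_layers):
        if layer_name in shared_kv_cache_layers:
            eligible.add(layer_name)
        else:
            break

    assert eligible == {
        "model.layers.4.self_attn.attn", "model.layers.5.self_attn.attn"
    }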