143 changes: 143 additions & 0 deletions tests/v1/e2e/test_kv_sharing_fast_prefill.py
@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc
import random
from typing import Optional, Union

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors

from ...utils import fork_new_process_for_each_test


class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, **kwargs)
        attn_metadata = get_forward_context().attn_metadata
        # attn_metadata is None during dummy runs
        if (attn_metadata is not None
                and self.cache_config.kv_sharing_fast_prefill):
            assert isinstance(attn_metadata, dict)  # true in V1
            # Gemma3n-E2B has 30 layers; layers with index >= 20 are
            # cross-decoder layers. Check that the attention metadata is
            # correct for both kinds of layers.
            for layer_name, metadata in attn_metadata.items():
                layer_idx = extract_layer_index(layer_name)
                if layer_idx >= 20:
                    assert hasattr(metadata, 'logits_indices_padded')
                    assert hasattr(metadata, 'num_logits_indices')
                else:
                    assert not hasattr(metadata, 'logits_indices_padded')
                    assert not hasattr(metadata, 'num_logits_indices')

            # The last layer is a KV sharing layer
            layer_attn_metadata = attn_metadata[
                self.model.language_model.layers[-1].self_attn.attn.layer_name]
            logits_indices_padded = layer_attn_metadata.logits_indices_padded
            assert logits_indices_padded is not None
            num_logits_indices = layer_attn_metadata.num_logits_indices
            assert num_logits_indices > 0
            # Reset hidden states to random values and restore valid values
            # only at logits_indices. Because logits_indices are the only
            # positions used for output token sampling, this still produces
            # the same outputs.
            logits_hs = hidden_states[logits_indices_padded]
            hidden_states = torch.randn_like(hidden_states)
            gen_indices = logits_indices_padded[:num_logits_indices]
            hidden_states[gen_indices] = logits_hs[:num_logits_indices]

        return hidden_states


@pytest.fixture
def test_prompts():
    """
    Adapted from tests/v1/e2e/test_spec_decode.py
    """
    prompt_types = ["repeat", "sentence"]
    # A higher number of prompts increases the chance of a numerics mismatch,
    # since matrix multiplication numerics depend on the batch dimension.
    num_prompts = 10
    prompts = []

    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)

    for kind in random_prompt_type_choices:
        word_choices = ["test", "temp", "hello", "where"]
        word = random.choice(word_choices)
        if kind == "repeat":
            prompt = f"""please repeat the word '{word}' 10 times."""
        elif kind == "sentence":
            prompt = f"""please give a ten-word sentence that
            uses the word {word} at least once."""
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append(prompt)

    return prompts


@fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_sharing_fast_prefill(
    monkeypatch: pytest.MonkeyPatch,
    enforce_eager: bool,
    test_prompts: list[str],
):
    ModelRegistry.register_model("Gemma3nForConditionalGeneration",
                                 TestGemma3nForConditionalGeneration)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
    compilation_config = CompilationConfig(
        # This allows the vLLM compilation backend to handle allocating and
        # managing buffers for cudagraph
        cudagraph_copy_inputs=True,
        level=CompilationLevel.PIECEWISE
        if not enforce_eager else CompilationLevel.NO_COMPILATION)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(
            model="google/gemma-3n-E2B-it",
            enforce_eager=enforce_eager,
            compilation_config=compilation_config,
        )
        ref_responses = llm.generate(test_prompts, sampling_params)

        del llm
        gc.collect()
        torch.cuda.empty_cache()

        llm = LLM(model="google/gemma-3n-E2B-it",
                  enforce_eager=enforce_eager,
                  compilation_config=compilation_config,
                  kv_sharing_fast_prefill=True)
        optimized_responses = llm.generate(test_prompts, sampling_params)

        misses = 0

        for ref_response, optimized_response in zip(ref_responses,
                                                    optimized_responses):
            if ref_response.outputs[0].text != optimized_response.outputs[
                    0].text:
                misses += 1

        assert misses == 0
15 changes: 15 additions & 0 deletions vllm/config.py
@@ -1684,6 +1684,16 @@ class CacheConfig:
    num_cpu_blocks: Optional[int] = field(default=None, init=False)
    """The number of blocks to allocate for CPU memory."""

    kv_sharing_fast_prefill: bool = False
    """This feature is work in progress; no prefill optimization takes place
    with this flag enabled yet.

    In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
    some layers can skip tokens corresponding to prefill. This flag enables
    attention metadata for eligible layers to be overridden with the metadata
    necessary for implementing this optimization in some models (e.g. Gemma3n).
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
@@ -1725,6 +1735,11 @@ def _verify_args(self) -> Self:
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

        if self.kv_sharing_fast_prefill:
            logger.warning_once(
                "--kv-sharing-fast-prefill is currently work in progress "
                "and not functional yet (i.e. no prefill savings)")

        return self

    def _verify_cache_dtype(self) -> None:
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
@@ -438,6 +438,9 @@ class EngineArgs:
    # DEPRECATED
    enable_prompt_adapter: bool = False

    kv_sharing_fast_prefill: bool = \
        CacheConfig.kv_sharing_fast_prefill

    def __post_init__(self):
        # support `EngineArgs(compilation_config={...})`
        # without having to manually construct a
@@ -686,6 +689,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                 **cache_kwargs["cpu_offload_gb"])
        cache_group.add_argument("--calculate-kv-scales",
                                 **cache_kwargs["calculate_kv_scales"])
        cache_group.add_argument("--kv-sharing-fast-prefill",
                                 **cache_kwargs["kv_sharing_fast_prefill"])

        # Multimodal related configs
        multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1056,6 +1061,7 @@ def create_engine_config(
            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
            cpu_offload_gb=self.cpu_offload_gb,
            calculate_kv_scales=self.calculate_kv_scales,
            kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
        )

        # Get the current placement group if Ray is initialized and
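With the wiring above in place, the new option is usable end to end through the offline `LLM` entrypoint; the equivalent CLI flag is `--kv-sharing-fast-prefill`. A minimal usage sketch (not part of this diff; the model and prompt are illustrative, and per the `CacheConfig` warning no prefill savings are expected yet):

```python
from vllm import LLM, SamplingParams

# kv_sharing_fast_prefill is forwarded from EngineArgs into CacheConfig,
# so passing it to the LLM constructor is sufficient.
llm = LLM(model="google/gemma-3n-E2B-it", kv_sharing_fast_prefill=True)
outputs = llm.generate(["please repeat the word 'test' 10 times."],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```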
1 change: 1 addition & 0 deletions vllm/model_executor/models/gemma3n.py
@@ -771,6 +771,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.cache_config = vllm_config.cache_config
        self.model = Gemma3nModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
        self.logits_processor = LogitsProcessor(
35 changes: 33 additions & 2 deletions vllm/v1/attention/backends/utils.py
@@ -3,8 +3,8 @@
import abc
import functools
from abc import abstractmethod
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
+from dataclasses import dataclass, make_dataclass
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar

import numpy as np
import torch
@@ -501,3 +501,34 @@ def reorder_batch_to_split_decodes_and_prefills(
            modified_batch = True

    return modified_batch


KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
    ('logits_indices_padded', Optional[torch.Tensor], None),
    ('num_logits_indices', int, 0),
]


def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Return a new subclass of `metadata_cls` with additional fields
    """
    name: str = name_prefix + metadata_cls.__name__  # type: ignore
    Wrapped = make_dataclass(name, fields, bases=(metadata_cls, ))
    return Wrapped


def make_kv_sharing_fast_prefill_attention_metadata(
        metadata_cls: Any, ) -> Any:
    """
    Return a new subclass of `metadata_cls` for fast prefill
    """
    return subclass_attention_metadata(
        name_prefix="KVSharingFastPrefill",
        metadata_cls=metadata_cls,
        fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS,
    )
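To illustrate what the new helpers produce, here is a small sketch (not part of this diff; `DummyMetadata` is a made-up stand-in for a real backend's attention metadata dataclass):

```python
from dataclasses import dataclass

import torch

from vllm.v1.attention.backends.utils import (
    make_kv_sharing_fast_prefill_attention_metadata)


@dataclass
class DummyMetadata:
    # Stand-in for a backend attention metadata class.
    num_actual_tokens: int = 0


# make_dataclass builds a subclass named "KVSharingFastPrefillDummyMetadata"
# that keeps the base fields and appends the two fast-prefill fields,
# defaulting to None and 0 respectively.
FastPrefillMetadata = make_kv_sharing_fast_prefill_attention_metadata(
    DummyMetadata)

md = FastPrefillMetadata(num_actual_tokens=4,
                         logits_indices_padded=torch.tensor([1, 3]),
                         num_logits_indices=2)
assert isinstance(md, DummyMetadata)
assert md.num_logits_indices == 2
```

This mirrors how the test override above consumes the extra fields: cross-decoder layers receive the subclassed metadata (and thus `logits_indices_padded` / `num_logits_indices`), while the remaining layers keep the unmodified base metadata class.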