diff --git a/docs/.nav.yml b/docs/.nav.yml
index 441ef9f521e..79d7c38e274 100644
--- a/docs/.nav.yml
+++ b/docs/.nav.yml
@@ -98,6 +98,7 @@ nav:
       - design/feature/disaggregated_inference.md
       - design/feature/ray_based_execution.md
       - design/feature/omni_connectors/
+      - design/feature/prefix_caching.md
       - design/feature/cfg_parallel.md
       - design/feature/expert_parallel.md
       - design/feature/sequence_parallel.md
diff --git a/docs/design/feature/prefix_caching.md b/docs/design/feature/prefix_caching.md
new file mode 100644
index 00000000000..ebad8b69106
--- /dev/null
+++ b/docs/design/feature/prefix_caching.md
@@ -0,0 +1,164 @@
+# Automatic Prefix Caching in Omni Models
+
+
+---
+
+## Table of Contents
+
+- [Overview](#overview)
+- [High-Level Approach](#high-level-approach)
+- [Example](#example)
+- [What About Multimodal Inputs?](#what-about-multimodal-inputs)
+
+---
+
+### Overview
+
+Prefix caching in the context of kv-cache management is a useful optimization for avoiding redundant computation. The main idea is to store portions of the kv-cache from processed requests so that they can be reused when an incoming request shares a prefix with a previous one.
+
+vLLM manages the kv-cache as blocks, each of which represents a fixed-length span of tokens. Blocks are hashed by their content, which typically means the tokens within the span, but the hash can also be influenced by other factors, e.g., LoRA and multimodal data.
+
+vLLM implements automatic prefix caching for managing its kv-cache, which is best understood by reading the [design document](https://docs.vllm.ai/en/latest/design/prefix_caching/). vLLM-Omni builds on top of this prefix caching mechanism in a noninvasive way to allow caching between stages in Omni pipelines. For a given stage, this typically means we aim to support caching for the following:
+
+- The last hidden states produced by the stage
+- Model / stage specific multimodal data
+
+!!! note "Note 1"
+    This document describes vLLM-Omni's mechanism for caching tensor outputs that are meant to be passed between stages when requests have common prefixes, similar to the way in which vLLM has prefix caching for the kv-cache. This works in conjunction with vLLM's multimodal encoder caching, but is distinct from it. See the final section for a concrete example of how they tie together in practice.
+
+### High-Level Approach
+!!! note "Note 2"
+    Before reading this section, it's recommended to take a look at the design documents in vLLM for [Automatic Prefix Caching](https://docs.vllm.ai/en/latest/features/automatic_prefix_caching/), which will make some of the concepts clearer.
+
+The main focus of vLLM-Omni's approach to prefix caching stage outputs is to build on vLLM's prefix caching in the least invasive way possible, while minimizing the impact of cache misses and consuming as little GPU memory as possible. To understand the implementation, there are a few important things to note:
+
+- Between stages, device tensors are generally moved to CPU; this is important since we're just caching the outputs of stages, so it is okay to keep the entire cache on the CPU.
+
+- For a tensor to be considered cacheable, its first dimension (currently) needs to equal the token count, as this allows us to reuse block/slot mappings for our externally maintained tensor caches. It also allows us to dynamically discover the tensors to be marked as cacheable outputs in each Omni model without having to explicitly specify cacheable output field names in every model.
+
+With this in mind, consider the set of blocks in a 2D layout, where each row corresponds to a block index and each column corresponds to a token slot within that block. Since we know `num_blocks` and `block_size` from our kv cache config, if we want to cache a tensor with feature size `D`, we can preallocate a CPU tensor of size `(num_blocks, block_size, D)`, and use the same block index and slot mapping to retrieve the corresponding feature vector.
+
+
+### Example
+!!! note "Note 3"
+    Prefix caching in vLLM-Omni is currently only supported on AutoRegressive stages with one kv-cache group. It can be enabled/disabled per-stage via the `enable_prefix_caching` parameter in the model's stage config.
+
+The way in which vLLM-Omni ties into vLLM's prefix caching is best understood by example. Say that we have the following:
+
+- `num_blocks=8`
+- `block_size=4`
+- `hidden_size=2`
+- A stage specific multimodal output tensor named `mm_feature` with feature dimension `16`
+
+The prefix cache flow is then outlined below.
+
+1. When the model is initialized, we can determine the `hidden_size` from the `ModelConfig`, and allocate a cache of size `(num_blocks, block_size, hidden_size)`.
+
+2. Say we process the request `The quick brown fox was tired and slept beneath the shady tree`, which is 12 tokens and evenly divides into 3 blocks as shown below.
+
+```
+         [ The quick brown fox ] [ was tired and slept ] [beneath the shady tree ]
+Block 1: |<--- block tokens ---->|
+Block 2: |<------- prefix ------>| |<--- block tokens --->|
+Block 3: |<------------------ prefix -------------------->| |<--- block tokens ---->|
+```
+
+When the request is processed, we inspect the multimodal outputs and identify the `mm_feature` tensor, which will be of shape `(seq_len, feature_dim)`, i.e., `(12, 16)` in this example. We note that the first axis depends on the `seq_len` and add a new cache tensor of shape `(num_blocks, block_size, feature_dim)` to our multimodal cache for tensors.
+
+
+3. If we lay out the cache as a 2D tensor of shape (`num_blocks`, `block_size`), we'll have something like the following:
+
+```
+0: [ The quick brown fox ]
+1: [ was tired and slept ]
+2: [beneath the shady tree ]
+3: [EMPTY]
+...
+7: [EMPTY]
+```
+
+Or, if we flatten it down to 1D,
+```
+0: The
+1: quick
+2: brown
+3: fox
+...
+11: tree
+12: [EMPTY]
+...
+```
+
+which we can think of as row indices into the hidden states tensor if we view it as the 2D shape `(num_blocks * block_size, hidden_size)`. That is, the analogous flattened (from 3D -> 2D) mapping of the cache for hidden states becomes the following.
+```
+0:
+1:
+2:
+3:
+...
+11:
+12: [EMPTY]
+...
+```
+
+Similarly, for the multimodal outputs cache, the flattened coordinates are the same, but `mm_feature` maps to vectors of length `16` instead of the hidden size of `2`. Note that in practice, we may have multiple multimodal output tensors per forward pass, which may have different names and different feature dimensions.
+
+
+4. Now, say that we receive a new request `The quick brown fox jumped over the dog`.
+
+```
+         [ The quick brown fox ] [ jumped over the dog ]
+Block 1: |<--- block tokens ---->|
+Block 2: |<------- prefix ------>| |<--- block tokens --->|
+```
+
+Here, we will have a cache hit for `Block 1`, which vLLM detects from the hash of the first block while handling prefix caching on the kv-cache. As a result, when we get the output from the scheduler, we will see that `num_computed_tokens=4` (corresponding to the cached first block), and we only need to process the remaining 4 new tokens in the new prefill.
+
+Since we have the block indices / slot mappings from the kv cache manager, we can simply mirror the mappings and use the same indices for the cached hidden states and multimodal outputs. This allows us to look up the correct tensors from our externally maintained 3D caches.
+
+```
+0: [ The quick brown fox ] < already in the cache
+1: [ was tired and slept ]
+2: [beneath the shady tree ]
+3: [ jumped over the dog ] < added on the second request
+4: [EMPTY]
+...
+7: [EMPTY]
+...
+```
+
+Finally, to pass the full hidden states and multimodal outputs to the next stage, we simply concatenate the cached contents with the corresponding new tensors computed from the current forward call.
+
+
+### What About Multimodal Inputs?
+It's also useful to consider how Omni prefix caching behaves when we have multimodal inputs that don't cleanly end on block boundaries, as well as how this works with multimodal encoder caching in vLLM. For example:
+
+```
+         [ Im0 Im1 Im2 Im3 ] [ Im4 Im5 foo ]
+Block 1: |<--- block tokens ---->|
+Block 2: |<------- prefix ------>| |<--- block tokens --->|
+```
+
+In this case, only `Block 1` will have outputs stored in the prefix tensor cache, because vLLM does not store partial blocks. This may appear to be a problem at first glance, because the multimodal input spills over into a block that wasn't cached.
+
+In reality, this isn't a big problem for correctness, because vLLM also maintains an encoder cache for multimodal inputs. In other words, after the first pass, we'll have the following:
+
+- The Block 1 hash, which is used for prefix caching
+- The hash describing the image data starting at position 0 and with length 6
+- In vLLM's encoder cache, a mapping from the image hash above to the encoder output
+
+
+To understand what happens, say we get the following input as a second request:
+```
+         [ Im0 Im1 Im2 Im3 ] [ Im4 Im5 bar baz ]
+Block 1: |<--- block tokens ---->|
+Block 2: |<------- prefix ------>| |<--- block tokens --->|
+```
+
+First, the scheduler will check for a prefix cache hit, which we will see on `Block 1`. As a result, we will have 4 tokens marked as precomputed, and only see the remaining 4 tokens in the following prefill.
+
+Because we have multimodal data in a scheduled span that isn't fully precomputed, we still need to call the visual encoder. However, since we have the image hash and encoder cache, we will retrieve the encoder outputs for `Im4` and `Im5` as we create the multimodal embeddings.
+
+When we pass our multimodal tensors to the language model component in the same stage, we expect the same outputs, because the prefix caching behaviors in vLLM-Omni and vLLM match. The LLM uses the prefix caching in vLLM's KV cache manager to correctly handle the attention information for `Block 1` while calculating the outputs for `Block 2`, giving us the correct results for processing `Block 2` with the context of `Block 1`.
+
+Finally, we look up the output hidden states/multimodal tensors corresponding to the prefix cache hit on `Block 1` and concatenate them with the forward pass result to get the final result, which is expected to be identical to the full hidden states when prefix caching is disabled.
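As a rough illustration of the indexing walked through in the design document above, the sketch below uses plain Python lists in place of CPU tensors. It is not the implementation added by this change: the helper names (`new_flat_cache`, `slots_for_blocks`, `write_rows`, `merge_with_cached`, `fake_states`) are hypothetical, and the only assumption taken from the text is that a block with index `b` covers the flat slots `b * block_size` through `b * block_size + block_size - 1`.

```python
# Hypothetical, list-based stand-in for the (num_blocks, block_size, D) caches.
NUM_BLOCKS, BLOCK_SIZE, HIDDEN_SIZE = 8, 4, 2


def new_flat_cache(num_blocks, block_size, feat_dim):
    """Flattened 2D view of the cache: one zeroed row per token slot."""
    return [[0.0] * feat_dim for _ in range(num_blocks * block_size)]


def slots_for_blocks(block_ids, block_size):
    """Expand block indices into the flat slot indices they cover."""
    return [b * block_size + off for b in block_ids for off in range(block_size)]


def write_rows(flat_cache, slot_mapping, rows):
    """Store one feature row per slot, mirroring the kv-cache slot mapping."""
    for slot, row in zip(slot_mapping, rows):
        flat_cache[slot] = list(row)


def merge_with_cached(flat_cache, cached_block_ids, block_size, new_rows):
    """Concatenate rows cached for the prefix with freshly computed rows."""
    cached = [flat_cache[s] for s in slots_for_blocks(cached_block_ids, block_size)]
    return cached + [list(r) for r in new_rows]


def fake_states(start, count):
    """Made-up feature rows of width HIDDEN_SIZE, for illustration only."""
    return [[float(i), float(i) + 0.5] for i in range(start, start + count)]


cache = new_flat_cache(NUM_BLOCKS, BLOCK_SIZE, HIDDEN_SIZE)

# Request 1: 12 tokens fill blocks 0, 1 and 2 (slots 0..11).
first_states = fake_states(0, 12)
write_rows(cache, slots_for_blocks([0, 1, 2], BLOCK_SIZE), first_states)

# Request 2: block 0 is a prefix hit; the 4 new tokens are stored in block 3.
second_new = fake_states(100, 4)
write_rows(cache, slots_for_blocks([3], BLOCK_SIZE), second_new)
merged = merge_with_cached(cache, [0], BLOCK_SIZE, second_new)

print(len(merged))  # 4 cached rows + 4 new rows
```

The same slot arithmetic applies to any cached multimodal output; only the row width changes (for example `16` for `mm_feature` in the walkthrough).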
diff --git a/tests/conftest.py b/tests/conftest.py
index 098fd8d970c..ad1008b7263 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1850,6 +1850,7 @@ class OmniResponse:
     e2e_latency: float | None = None
     success: bool = False
     error_message: str | None = None
+    cached_tokens: int | None = None


 @dataclass
@@ -2345,6 +2346,11 @@ def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse:
             if hasattr(choice.message, "content") and choice.message.content is not None:
                 text_content = choice.message.content

+        # Extract cached_tokens for prefix caching tests
+        usage = getattr(chat_completion, "usage", None)
+        if usage and (details := getattr(usage, "prompt_tokens_details", None)):
+            result.cached_tokens = details.cached_tokens
+
         # Calculate end-to-end latency
         result.e2e_latency = time.perf_counter() - start_time

diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py
new file mode 100644
index 00000000000..c3d8c1ff928
--- /dev/null
+++ b/tests/core/test_prefix_cache.py
@@ -0,0 +1,347 @@
+from unittest.mock import Mock, patch
+
+import pytest
+import torch
+
+from vllm_omni.core.prefix_cache import OmniTensorPrefixCache
+
+DEFAULT_SEQ_LEN = 15
+NUM_BLOCKS = 10
+BLOCK_SIZE = 4
+HIDDEN_SIZE = 2
+DTYPE = torch.float32
+OTHER_DTYPE = torch.float16
+DEFAULT_SHAPE = torch.Size([NUM_BLOCKS, BLOCK_SIZE, HIDDEN_SIZE])
+
+
+class MockInputBatch:
+    def __init__(self, num_computed_tokens_cpu):
+        self.req_ids = ["req1", "req2"]
+        self.req_id_to_index = {req_id: i for i, req_id in enumerate(self.req_ids)}
+        self.num_computed_tokens_cpu = num_computed_tokens_cpu
+        # Block table is only mocked for validation of length;
+        # we don't actually need to add valid values here since
+        # we patch the table when testing.
+        self.block_table = Mock()
+        self.block_table.block_tables = [None]
+
+
+def get_omni_pcache_with_mm_tensors(feat_dims, seq_len) -> OmniTensorPrefixCache:
+    """Build an OmniTensorPrefixCache and init mm tensors."""
+    cache = get_omni_pcache()
+    mm_outputs = get_multimodal_outputs(feat_dims, seq_len)
+    cache.maybe_init_missing_mm_cache_keys(mm_outputs, seq_len)
+    return cache
+
+
+def get_omni_pcache() -> OmniTensorPrefixCache:
+    """Build an OmniTensorPrefixCache, but don't init mm tensors."""
+    cache = OmniTensorPrefixCache(
+        num_blocks=NUM_BLOCKS,
+        block_size=BLOCK_SIZE,
+        hidden_size=HIDDEN_SIZE,
+        hs_dtype=DTYPE,
+    )
+    return cache
+
+
+def get_multimodal_outputs(feat_dims: dict[str, int], seq_len: int) -> dict[str, torch.Tensor]:
+    fake_mm_inputs = {}
+    for mm_key, feat_dim in feat_dims.items():
+        fake_mm_inputs[mm_key] = torch.rand((seq_len, feat_dim), dtype=DTYPE)
+    return fake_mm_inputs
+
+
+### Tests for initialization
+def test_initialization_simple():
+    """Check default initialization only creates the hidden states."""
+    cache = get_omni_pcache()
+    assert isinstance(cache.hidden_states_cache, torch.Tensor)
+    assert cache.hidden_states_cache.shape == DEFAULT_SHAPE
+    assert len(cache.mm_outputs_cache) == 0
+    assert len(cache.mm_cache_keys) == 0
+
+
+def test_initialization_with_multimodal():
+    """Check initialization + registration of multimodal outputs."""
+    cache = get_omni_pcache()
+    feat_dims = {"foo": 100, "bar": 50, "baz": 10}
+    mm_outputs = get_multimodal_outputs(
+        feat_dims,
+        seq_len=DEFAULT_SEQ_LEN,
+    )
+    # Cast one of the keys to a different dtype; the dtype of the tensor
+    # that is used to initialize the cache dictates the cache dtype.
+    mm_outputs["foo"] = mm_outputs["foo"].to(OTHER_DTYPE)
+
+    cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN)
+    assert len(cache.mm_cache_keys) == 3
+    assert set(cache.mm_cache_keys) == set(feat_dims.keys())
+    for mm_key in cache.mm_cache_keys:
+        cache_tensor = cache.mm_outputs_cache[mm_key]
+        assert isinstance(cache_tensor, torch.Tensor)
+        assert cache_tensor.shape[-1] == feat_dims[mm_key]
+        assert mm_outputs[mm_key].dtype == cache_tensor.dtype
+
+
+def test_init_missing_mm_cache_keys_is_idempotent():
+    """Ensure that the cache doesn't reinitialize old keys."""
+    cache = get_omni_pcache()
+    mm_key = "foo"
+    feat_dims = {mm_key: 100}
+    mm_outputs = get_multimodal_outputs(
+        feat_dims,
+        seq_len=DEFAULT_SEQ_LEN,
+    )
+    cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN)
+    assert len(cache.mm_cache_keys) == 1
+    assert mm_key in cache.mm_cache_keys
+
+    # Cache is initialized to 0 - fill it with 1s
+    cache.mm_outputs_cache[mm_key].fill_(1)
+
+    # Ensure that running another initialization
+    # doesn't zero out our cache values
+    cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN)
+    assert len(cache.mm_cache_keys) == 1
+    assert mm_key in cache.mm_cache_keys
+    assert torch.all(cache.mm_outputs_cache[mm_key] == 1)
+
+
+### Tests for Update
+def test_update_no_multimodal():
+    """Test that slot mappings act as row indices into the hidden states."""
+    cache = get_omni_pcache()
+
+    num_tokens_unpadded = 8
+    slot_offset = 8
+    slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded)
+    new_hidden_states = torch.rand((num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE)
+
+    cache.update_omni_tensor_prefix_cache(
+        hidden_states=new_hidden_states,
+        multimodal_outputs=None,
+        num_tokens_unpadded=num_tokens_unpadded,
+        slot_mapping=slot_mapping,
+    )
+
+    # Ensure that if we reshape our 3D cache back to 2D, we can use the
+    # indices in our slot mappings to access the hidden states as expected
+    hs_rows = cache.hidden_states_cache.view(NUM_BLOCKS * BLOCK_SIZE, HIDDEN_SIZE)
+    for slot_idx, new_states in zip(slot_mapping, new_hidden_states):
+        slot_states = hs_rows[slot_idx]
+        assert torch.all(slot_states == new_states)
+
+
+@pytest.mark.parametrize(
+    "feat_dims",
+    [
+        {"foo": 100, "bar": 100},
+        {"foo": 100, "bar": 50, "baz": 10},
+    ],
+)
+def test_update_with_multimodal_outputs(feat_dims):
+    """Test that slot mappings are correct for multimodal tensors."""
+    cache = get_omni_pcache_with_mm_tensors(feat_dims, seq_len=DEFAULT_SEQ_LEN)
+
+    num_tokens_unpadded = 8
+    slot_offset = 8
+    slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded)
+    feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()}
+    mm_outputs = {key: torch.rand((num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in cache.mm_cache_keys}
+    cache.update_omni_tensor_prefix_cache(
+        hidden_states=None,
+        multimodal_outputs=mm_outputs,
+        num_tokens_unpadded=num_tokens_unpadded,
+        slot_mapping=slot_mapping,
+    )
+
+    for mm_key in feat_dims.keys():
+        assert mm_key in cache.mm_outputs_cache
+        key_feat_dim = feature_dims[mm_key]
+        mm_state_rows = cache.mm_outputs_cache[mm_key].view(NUM_BLOCKS * BLOCK_SIZE, key_feat_dim)
+
+        # Similar to hidden states, but for each key in the dict;
+        # Different tensors may have different feature dims
+        new_mm_outputs = mm_outputs[mm_key]
+        for slot_idx, new_output in zip(slot_mapping, new_mm_outputs):
+            slot_states = mm_state_rows[slot_idx]
+            assert torch.all(slot_states == new_output)
+
+
+### Tests for Merging
+def fake_get_cached_block_ids(self, req_idx, *args, **kwargs):
+    """Fake block table lookup.
+
+    Assumption:
+        req_idx 0 is a cache hit with slots 8, 9, ..., 15
+        req_idx 1 is a cache miss
+    """
+    assert req_idx < 2
+    if req_idx == 0:
+        # With the slot offset we provided (8), the corresponding
+        # block IDs are 2 & 3 because the block size is 4.
+        return torch.tensor([2, 3], dtype=torch.long)
+    return torch.tensor([], dtype=torch.long)
+
+
+@pytest.mark.parametrize("num_tokens_padded", [None, 16])
+def test_get_merged_hidden_states(num_tokens_padded):
+    """Ensure that hidden states are merged correctly."""
+    cache = get_omni_pcache()
+
+    orig_num_tokens_unpadded = 8
+    slot_offset = 8  # We'll put our states in slots 8, 9, 10, ..., 15
+    orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded)
+    orig_hidden_states = torch.rand((orig_num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE)
+
+    cache.update_omni_tensor_prefix_cache(
+        hidden_states=orig_hidden_states,
+        multimodal_outputs=None,
+        num_tokens_unpadded=orig_num_tokens_unpadded,
+        slot_mapping=orig_slot_mapping,
+        num_tokens_padded=num_tokens_padded,
+    )
+
+    # Say that we have two requests, but only one of them is a cache hit
+    num_new_toks_req1 = 3
+    num_new_toks_req2 = 2
+    cache.add_prefix_cached_new_req_id("req1")
+
+    num_scheduled_tokens = {
+        "req1": num_new_toks_req1,
+        "req2": num_new_toks_req2,
+    }
+    new_hidden_states = torch.rand(
+        (num_new_toks_req1 + num_new_toks_req2, HIDDEN_SIZE),
+        dtype=DTYPE,
+    )
+    req1_new_states = new_hidden_states[:num_new_toks_req1]
+    req2_new_states = new_hidden_states[-num_new_toks_req2:]
+
+    input_batch = MockInputBatch(num_computed_tokens_cpu=torch.Tensor([orig_num_tokens_unpadded, 0]))
+
+    with patch(
+        "vllm_omni.core.prefix_cache.OmniTensorPrefixCache._get_cached_block_ids",
+        new=fake_get_cached_block_ids,
+    ):
+        merged_states = cache.get_merged_hidden_states(
+            query_start_loc=[0, num_new_toks_req1],
+            input_batch=input_batch,
+            hidden_states=new_hidden_states,
+            num_scheduled_tokens=num_scheduled_tokens,
+        )
+
+    assert "req1" in merged_states and "req2" in merged_states
+    req1_merged_states = merged_states["req1"]
+    req2_merged_states = merged_states["req2"]
+
+    # First, check the cache hit case
+    assert req1_merged_states.shape == torch.Size([orig_num_tokens_unpadded + num_new_toks_req1, HIDDEN_SIZE])
+    # Ensure that the req1 merged states are the cached states + the new req1 states
+    assert torch.all(req1_merged_states[:orig_num_tokens_unpadded] == orig_hidden_states)
+    assert torch.all(req1_merged_states[-num_new_toks_req1:] == req1_new_states)
+
+    # Next, ensure that the cache miss case only has the new states
+    assert req2_merged_states.shape == torch.Size([num_new_toks_req2, HIDDEN_SIZE])
+    assert torch.all(req2_merged_states == req2_new_states)
+
+
+@pytest.mark.parametrize("num_tokens_padded", [None, 16])
+@pytest.mark.parametrize(
+    "feat_dims",
+    [
+        {"foo": 100, "bar": 100},
+        {"foo": 100, "bar": 50, "baz": 10},
+    ],
+)
+def test_get_merged_multimodal_outputs(feat_dims, num_tokens_padded):
+    cache = get_omni_pcache_with_mm_tensors(feat_dims, seq_len=DEFAULT_SEQ_LEN)
+
+    orig_num_tokens_unpadded = 8
+    slot_offset = 8  # We'll put our states in slots 8, 9, 10, ..., 15
+    orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded)
+    feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()}
+    orig_mm_outputs = {
+        key: torch.rand((orig_num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in cache.mm_cache_keys
+    }
+
+    cache.update_omni_tensor_prefix_cache(
+        hidden_states=None,
+        multimodal_outputs=orig_mm_outputs,
+        num_tokens_unpadded=orig_num_tokens_unpadded,
+        slot_mapping=orig_slot_mapping,
+        num_tokens_padded=num_tokens_padded,
+    )
+
+    # Similar to hs test- say that we have two requests, but only one of them is a cache hit
+    num_new_toks_req1 = 3
+    num_new_toks_req2 = 2
+    cache.add_prefix_cached_new_req_id("req1")
+
+    num_scheduled_tokens = {
+        "req1": num_new_toks_req1,
+        "req2": num_new_toks_req2,
+    }
+
+    new_mm_outputs = {}
+    for mm_key in cache.mm_cache_keys:
+        new_mm_outputs[mm_key] = torch.rand(
+            (num_new_toks_req1 + num_new_toks_req2, feature_dims[mm_key]),
+            dtype=DTYPE,
+        )
+    # We also want to make sure passthrough data (outside of our keys) isn't dropped
+    new_mm_outputs["passthrough_data"] = "Something else"
+    # Lists are a special case because we can't split them yet if we want to match
+    # the nonprefix cache behavior, because this runs before post process.
+    new_mm_outputs["passthrough_list"] = ["should", "not", "split"]
+
+    input_batch = MockInputBatch(num_computed_tokens_cpu=torch.Tensor([orig_num_tokens_unpadded, 0]))
+
+    with patch(
+        "vllm_omni.core.prefix_cache.OmniTensorPrefixCache._get_cached_block_ids",
+        new=fake_get_cached_block_ids,
+    ):
+        merged_mm_outputs = cache.get_merged_multimodal_states(
+            query_start_loc=[0, num_new_toks_req1],
+            input_batch=input_batch,
+            multimodal_outputs=new_mm_outputs,
+            num_scheduled_tokens=num_scheduled_tokens,
+        )
+
+    # Ensure the passthrough data wasn't dropped
+    assert "passthrough_data" in merged_mm_outputs
+    assert "passthrough_list" in merged_mm_outputs
+
+    for mm_key, mm_output in merged_mm_outputs.items():
+        # Ensure passthrough data is just forwarded normally and not duplicated
+        assert isinstance(mm_output, dict)
+        assert "req1" in mm_output and "req2" in mm_output
+        if mm_key == "passthrough_data":
+            assert mm_key not in cache.mm_cache_keys
+            assert new_mm_outputs[mm_key] == mm_output["req1"]
+            assert new_mm_outputs[mm_key] == mm_output["req2"]
+        elif mm_key == "passthrough_list":
+            assert mm_key not in cache.mm_cache_keys
+            assert new_mm_outputs[mm_key] == mm_output["req1"]
+            assert new_mm_outputs[mm_key] == mm_output["req2"]
+        else:
+            assert mm_key in cache.mm_cache_keys
+            curr_feat_dim = feature_dims[mm_key]
+            # Ensure that req1 (cache hit) merged the mm data
+            req1_merged_mm_outputs = mm_output["req1"]
+            req1_new_mm_outputs = new_mm_outputs[mm_key][:num_new_toks_req1]
+
+            assert req1_merged_mm_outputs.shape == torch.Size(
+                [orig_num_tokens_unpadded + num_new_toks_req1, curr_feat_dim]
+            )
+            # Ensure that the req1 merged mm data are the cached data + the new data
+            assert torch.all(req1_merged_mm_outputs[:orig_num_tokens_unpadded] == orig_mm_outputs[mm_key])
+            assert torch.all(req1_merged_mm_outputs[-num_new_toks_req1:] == req1_new_mm_outputs)
+
+            # Ensure that req2 (cache miss) only has the new mm data
+            req2_merged_mm_outputs = mm_output["req2"]
+            req2_new_mm_outputs = new_mm_outputs[mm_key][-num_new_toks_req2:]
+
+            assert req2_merged_mm_outputs.shape == torch.Size([num_new_toks_req2, curr_feat_dim])
+            assert torch.all(req2_merged_mm_outputs == req2_new_mm_outputs)
diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py
index f4aabb8b957..c05f8f50674 100644
--- a/tests/e2e/online_serving/test_qwen3_omni.py
+++ b/tests/e2e/online_serving/test_qwen3_omni.py
@@ -23,11 +23,13 @@


 models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
+QWEN3_OMNI_CONFIG_PATH = str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")
+QWEN3_OMNI_XPU_CONFIG_PATH = str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")


-def get_chunk_config():
+def get_chunk_config(config_path: str):
     path = modify_stage_config(
-        str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml"),
+        config_path,
         updates={
             "async_chunk": True,
             "stage_args": {
@@ -44,15 +46,41 @@ def get_chunk_config():
     return path


+def get_prefix_caching_config(config_path: str):
+    """Create a stage config with prefix caching enabled on the thinker (stage 0)."""
+    path = modify_stage_config(
+        config_path,
+        updates={
+            "stage_args": {
+                0: {"engine_args.enable_prefix_caching": True},
+            },
+        },
+    )
+    return path
+
+
 if current_omni_platform.is_xpu():
-    stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")]
+    stage_configs = [QWEN3_OMNI_XPU_CONFIG_PATH]
+    prefix_caching_stage_configs = [get_prefix_caching_config(QWEN3_OMNI_XPU_CONFIG_PATH)]
 else:
     # MI325 GPU should share the same config as H100
-    stage_configs = [get_chunk_config()]
+    stage_configs = [get_chunk_config(QWEN3_OMNI_CONFIG_PATH)]
+    prefix_caching_stage_configs = [get_prefix_caching_config(QWEN3_OMNI_CONFIG_PATH)]
 # Create parameter combinations for model and stage config
 test_params = [
     OmniServerParams(model=model, stage_config_path=stage_config) for model in models for stage_config in stage_configs
 ]
+# For prefix caching, we need to enable prompt token details so that we
+# can determine if any tokens were cached.
+prefix_test_params = [
+    OmniServerParams(
+        model=model,
+        stage_config_path=stage_config,
+        server_args=["--enable-prompt-tokens-details"],  # Enable prompt tokens details to get cached_tokens
+    )
+    for model in models
+    for stage_config in prefix_caching_stage_configs
+]


 def get_system_prompt():
@@ -75,6 +103,7 @@ def get_prompt(prompt_type="text_only"):
     prompts = {
         "text_only": "What is the capital of China? Answer in 20 words.",
         "mix": "What is recited in the audio? What is in this image? Describe the video briefly.",
+        "text_image": "What color are the squares in this image?",
     }
     return prompts.get(prompt_type, prompts["text_only"])

@@ -147,3 +176,41 @@ def test_text_to_text_001(omni_server, openai_client) -> None:
     }

     openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+
+
+@pytest.mark.advanced_model
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.parametrize("omni_server", prefix_test_params, indirect=True)
+def test_thinker_prefix_caching(omni_server, openai_client) -> None:
+    """
+    Test thinker prefix caching by sending identical requests with an image (i.e.,
+    a large shared prefix) and verifying that the second request uses cached tokens
+    & produces the same output.
+    """
+    image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+    messages = dummy_messages_from_mix_data(
+        system_prompt=get_system_prompt(),
+        image_data_url=image_data_url,
+        content_text=get_prompt("text_image"),
+    )
+
+    request_config = {
+        "model": omni_server.model,
+        "messages": messages,
+        "stream": False,
+        "modalities": ["text"],
+    }
+
+    response_1 = openai_client.send_omni_request(request_config, request_num=1)[0]
+    response_2 = openai_client.send_omni_request(request_config, request_num=1)[0]
+
+    assert response_1.success
+    assert response_2.success
+    assert response_2.cached_tokens is not None
+    # We should cache the vast majority of the prompt (image + up to last full block),
+    # and set seed in the CI config, so the second request should give an identical
+    # response for the generated input image, even if we use dummy weights
+    assert response_2.cached_tokens > 0
+    assert response_1.text_content == response_2.text_content
diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py
new file mode 100644
index 00000000000..69e7346c4c1
--- /dev/null
+++ b/vllm_omni/core/prefix_cache.py
@@ -0,0 +1,264 @@
+"""
+Utilities for Prefix Caching in Omni models.
+"""
+
+import torch
+from vllm.logger import init_logger
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+from vllm_omni.utils.mm_outputs import build_mm_cpu, to_payload_element
+
+logger = init_logger(__name__)
+
+
+class OmniTensorPrefixCache:
+    """Prefix cache for hidden states (model outputs) and model specific
+    multimodal outputs.
+
+    This class implements prefix caching in a non-invasive way on top of
+    vLLM by leveraging the same slot mappings that the vLLM scheduler uses
+    for the KV Cache.
+
+    Conceptually, this means we are mapping vLLM's cache mapping:
+        (num_blocks, block_size)
+
+    to 3D tensors of shape:
+        (num_blocks, block_size, feature_size)
+
+    Note that feature_size may vary across multimodal_outputs.
+    """
+
+    def __init__(
+        self,
+        num_blocks: int,
+        block_size: int,
+        hidden_size: int,
+        hs_dtype: torch.dtype,
+    ):
+        self.num_blocks = num_blocks
+        self.block_size = block_size
+        self.default_hidden_size = hidden_size
+
+        # Initialize the hidden states cache immediately
+        self.hidden_states_cache = self._get_cache_tensor(dtype=hs_dtype)
+
+        # Defer initialization of the mm_outputs_cache until we
+        # actually see mm output tensors dependent on num tokens.
+        self.mm_outputs_cache = {}
+        self.mm_cache_keys = set()
+        self._new_req_cache_hit_ids: set[str] = set()
+
+    def maybe_init_missing_mm_cache_keys(self, multimodal_outputs: dict, seq_len: int):
+        """Given multimodal outputs from executing the model, dynamically
+        determine which multimodal outputs are tensors depending on sequence
+        length and should be cached, and initialize the cache tensors
+        accordingly.
+
+        NOTE: This is done to avoid the need for explicit specification of
+        cache keys for every model/stage and aligns with the current way
+        that we slice the multimodal outputs based on the first dimension.
+
+        This will usually be called by the first forward pass, i.e.,
+        determined by the warmup.
+        """
+        for key, val in multimodal_outputs.items():
+            if isinstance(val, torch.Tensor) and val.shape[0] == seq_len and key not in self.mm_cache_keys:
+                feat_dim = val.shape[-1]
+                self.mm_outputs_cache[key] = self._get_cache_tensor(
+                    dtype=val.dtype,
+                    hidden_size=feat_dim,
+                )
+                self.mm_cache_keys.add(key)
+                new_tensor_shape = self.mm_outputs_cache[key].shape
+                logger.info("Initializing multimodal output cache of size %s for key: %s", list(new_tensor_shape), key)
+
+    def _get_cache_tensor(self, dtype: torch.dtype, hidden_size: int | None = None) -> torch.Tensor:
+        """Allocate a CPU cache tensor for a specific key."""
+        actual_hidden_size = hidden_size if hidden_size is not None else self.default_hidden_size
+        return torch.zeros(
+            (self.num_blocks, self.block_size, actual_hidden_size),
+            dtype=dtype,
+            device="cpu",
+        )
+
+    def add_prefix_cached_new_req_id(self, req_id: str):
+        """Adds a new request ID to the set of prefix cache hits on the batch."""
+        self._new_req_cache_hit_ids.add(req_id)
+
+    def reset_prefix_cached_new_req_ids(self):
+        """Clears the cache hit IDs to prepare for a new engine step."""
+        self._new_req_cache_hit_ids.clear()
+
+    @staticmethod
+    def _coerce_to_cpu_tensor(maybe_gpu_tensor: torch.Tensor) -> torch.Tensor:
+        """Convert GPU tensors -> contiguous CPU tensors if needed."""
+        return maybe_gpu_tensor.detach().cpu().contiguous()
+
+    def update_omni_tensor_prefix_cache(
+        self,
+        hidden_states: torch.Tensor | None,
+        multimodal_outputs: dict[str, torch.Tensor] | None,
+        num_tokens_unpadded: int,
+        slot_mapping: torch.Tensor,
+        num_tokens_padded: int | None = None,
+    ):
+        """Updates the hidden cache state for the provided hidden states and multimodal outputs.
+ + Args: + hidden_states: Hidden states tensor to cache (if any) + multimodal_outputs: Multimodal dict whose tensors may be cached + num_tokens_unpadded: Number of tokens without padding + slot_mapping: Slot mapping for the input sequence + num_tokens_padded: Total number of tokens including padding + """ + unpadded_slot_mapping = slot_mapping[:num_tokens_unpadded] + if num_tokens_padded is None: + num_tokens_padded = num_tokens_unpadded + + if hidden_states is not None: + # Slice to unpadded portion before caching + hidden_states = hidden_states[:num_tokens_unpadded] + # Ensure that hidden states are on the CPU + hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states) + # View the cache as 2D so that we can treat our slots as row indices + flat_cache = self.hidden_states_cache.view(-1, self.hidden_states_cache.shape[-1]) + flat_cache[unpadded_slot_mapping] = hidden_states + logger.debug("Writing to hidden states for %s tokens", num_tokens_unpadded) + + # Do the same for the stage's cached multimodal outputs + if multimodal_outputs is not None: + # If we haven't initialized the keys already, do it now + # We check against the padded token count since we haven't sliced yet + self.maybe_init_missing_mm_cache_keys( + multimodal_outputs, + seq_len=num_tokens_padded, + ) + + for mm_out_key, mm_cache in self.mm_outputs_cache.items(): + if mm_out_key in multimodal_outputs: + # Slice to unpadded portion before caching + mm_state = multimodal_outputs[mm_out_key][:num_tokens_unpadded] + mm_state = OmniTensorPrefixCache._coerce_to_cpu_tensor(mm_state) + flat_cache = mm_cache.view(-1, mm_cache.shape[-1]) + flat_cache[unpadded_slot_mapping] = mm_state + logger.debug("Writing to mm output cache for %s tokens", num_tokens_unpadded) + + def _coerce_to_payload_dict( + self, + element: object, + query_start_loc: torch.Tensor, + input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> dict[str, object]: + """Build the multimodal passthrough data per 
+ request for + the object under consideration. This is identical to the case + for no prefix cache when the tensor does not have a first dimension + matching the seq len. + """ + elem_dict = {} + for req_id in input_batch.req_ids: + req_idx = input_batch.req_id_to_index[req_id] + start = query_start_loc[req_idx] + end = start + num_scheduled_tokens[req_id] + elem_dict[req_id] = to_payload_element( + element, req_idx, start=start, end=end, pass_lists_through=True, seq_len=None + ) + return elem_dict + + def get_merged_multimodal_states( + self, + query_start_loc: torch.Tensor, + input_batch: InputBatch, + multimodal_outputs: dict, + num_scheduled_tokens: dict[str, int], + ): + """Get the merged multimodal states if hidden state prefix caching is enabled.""" + combined_multimodal_outputs = {} + # First get the prefix cached tensors that are present in the mm data + for mm_key in self.mm_cache_keys: + if mm_key in multimodal_outputs: + combined_multimodal_outputs[mm_key] = self._get_merged_tensors( + query_start_loc=query_start_loc, + input_batch=input_batch, + cache=self.mm_outputs_cache[mm_key], + hidden_states=multimodal_outputs[mm_key], + num_scheduled_tokens=num_scheduled_tokens, + ) + + # Then, get everything else (passthrough data); first, convert to CPU + # tensors similarly to the non prefix cached path, and then populate + # the subdicts mapping request IDs -> payload objects + passthrough_keys = set(multimodal_outputs.keys()) - self.mm_cache_keys + passthrough_mm_data = {k: v for k, v in multimodal_outputs.items() if k in passthrough_keys} + mm_cpu = build_mm_cpu(multimodal_outputs=passthrough_mm_data) + + for mm_key, mm_val in mm_cpu.items(): + combined_multimodal_outputs[mm_key] = self._coerce_to_payload_dict( + element=mm_val, + query_start_loc=query_start_loc, + input_batch=input_batch, + num_scheduled_tokens=num_scheduled_tokens, + ) + return combined_multimodal_outputs + + def get_merged_hidden_states(self, *args, **kwargs) -> dict[str, torch.Tensor]: +
"""Get the merged hidden states.""" + return self._get_merged_tensors( + *args, + **kwargs, + cache=self.hidden_states_cache, + ) + + def _get_merged_tensors( + self, + query_start_loc: torch.Tensor, + input_batch: InputBatch, + cache: torch.Tensor, + hidden_states: torch.Tensor, + num_scheduled_tokens: dict[str, int], + ) -> dict[str, torch.Tensor]: + """When hidden state caching is enabled, takes the input hidden_states, + which only correspond to the scheduled tokens, and returns a mapping + from request IDs to their full hidden states. This is accomplished by + looking up the block IDs & scheduled token counts to split the + hidden_states. + """ + # We do not support hybrid caches at the moment. + if len(input_batch.block_table.block_tables) > 1: + logger.warning_once( + "Omni prefix caching is enabled, but the batch block table appears to" + " have multiple kv groups; only the first group will be used!" + ) + + combined_hidden_states = {} + hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states) + for req_id in input_batch.req_ids: + req_idx = input_batch.req_id_to_index[req_id] + + if req_id in self._new_req_cache_hit_ids: + block_ids = self._get_cached_block_ids(req_idx, input_batch) + cached_hs = cache[block_ids].reshape(-1, cache.shape[-1]) + + # Slice the hidden states corresponding to this request; + # we do this by using the query start + start = query_start_loc[req_idx] + new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] + combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0) + else: + # cache miss for this request, pass through normally + start = query_start_loc[req_idx] + new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] + combined_hidden_states[req_id] = new_hs + + return combined_hidden_states + + def _get_cached_block_ids(self, req_idx: int, input_batch: InputBatch) -> torch.Tensor: + """Given an input batch and request index in the batch (not ID), get the + block IDs 
+ corresponding to the cache hit. + """ + num_computed = input_batch.num_computed_tokens_cpu[req_idx] + # NOTE: vLLM only caches full blocks + num_cached_blocks = num_computed // self.block_size + # Get the block IDs attached to this cache hit and reindex into + # the flattened cached hidden states (i.e., 1 row per token). + return input_batch.block_table[0].block_table.cpu[req_idx, :num_cached_blocks] diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py new file mode 100644 index 00000000000..66d4e6ffe04 --- /dev/null +++ b/vllm_omni/utils/mm_outputs.py @@ -0,0 +1,93 @@ +"""Utilities for handling multimodal outputs / building multimodal output +payloads, most of which are shared by the prefix cache / no prefix cache path. +""" + +import torch +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def build_mm_cpu(multimodal_outputs: dict) -> dict[str, object]: + """Pre-copies multimodal tensors to CPU once (not per-request) to avoid + redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. + + In the case of prefix caching, the multimodal outputs provided will + only contain the passthrough data. + + Args: + multimodal_outputs: Multimodal dict mapping strings to objects. + """ + # Pre-copy multimodal tensors to CPU once (not per-request) to avoid + # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. + mm_cpu: dict[str, object] = {} + # There are currently some cases where the outputs are not a dict, + # which should be fixed. 
+ if not isinstance(multimodal_outputs, dict): + logger.warning("Multimodal outputs are not a dict and will not be passed") + + if isinstance(multimodal_outputs, dict): + for k, v in multimodal_outputs.items(): + if isinstance(v, torch.Tensor): + mm_cpu[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, dict): + sub_dict: dict[str, torch.Tensor] = {} + for sk, sv in v.items(): + if isinstance(sv, torch.Tensor): + sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() + if sub_dict: + mm_cpu[k] = sub_dict + elif isinstance(v, list) and len(v) > 0: + cpu_list = [] + for elem in v: + if isinstance(elem, torch.Tensor): + cpu_list.append(elem.detach().to("cpu").contiguous()) + else: + cpu_list.append(elem) + mm_cpu[k] = cpu_list + elif v is not None: + mm_cpu[k] = v + return mm_cpu + + +def to_payload_element( + element: object, idx: int, start: int, end: int, pass_lists_through: bool = False, seq_len: int | None = None +): + """Build an mm payload element corresponding to one request index + from an element containing 0 or more CPU tensors. + + Args: + element: The object to be added to the payload. + idx: The index of the request. + start: The start index corresponding to the request idx. + end: The end index corresponding to the request idx. + pass_lists_through: Whether or not lists should be treated as + passthrough data; this should be False in normal cases, but True + if we need to avoid splitting nonempty lists prior to calling + postprocess, which is the case for prefix cache. + seq_len: Optional sequence length (i.e., dim 0 of hidden states). + This should be set to None in the prefix caching case, because + the condition that would be executed here is the same as the + criteria for being added to the multimodal outputs cache. + """ + # Prefix cache won't hit this case because this is the condition + # for being a mm_cache_key in the multimodal outputs tensor. 
+ if seq_len is not None and isinstance(element, torch.Tensor) and element.shape[0] == seq_len: + return element[start:end].contiguous() + # Every other case is shared between prefix cache (passthrough data) + # and running a model without prefix caching. + elif isinstance(element, dict): + return {sk: sv[start:end].contiguous() for sk, sv in element.items()} + elif isinstance(element, list): + # For lists, clone tensors to avoid cross-request aliasing + if pass_lists_through: + return [elem.clone() if isinstance(elem, torch.Tensor) else elem for elem in element] + element = element[idx] if idx < len(element) else element[0] + if isinstance(element, torch.Tensor): + element = element.clone() + return element + elif isinstance(element, torch.Tensor): + # List-derived tensor payloads are request-invariant; clone to + # avoid accidental cross-request aliasing on downstream mutation. + return element.clone() + return element diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 62a0c857164..f37b2224efb 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -39,6 +39,7 @@ from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.utils.mm_outputs import build_mm_cpu, to_payload_element from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin @@ -201,6 +202,63 @@ def _capture_talker_mtp_graphs(self) -> None: finally: set_cudagraph_capturing_enabled(False) + def _maybe_update_prefix_cache( + self, + hidden_states: torch.Tensor, + multimodal_outputs: dict, + num_tokens_unpadded: int, + num_tokens_padded: int, + ): + """If prefix caching is enabled and it's the last pipeline parallelism rank, + retrieve the hidden states & multimodal outputs from the prefix cache based + on our batch 
slot mappings. + """ + # Cache hidden states if we've enabled hidden state prefix caching + # unless this isn't the last pipeline parallelism rank. + if self.omni_prefix_cache is not None and get_pp_group().is_last_rank: + # If this happens, it generally means the model is not following the correct + # interface yet and is therefore currently not compatible with prefix cache. + if multimodal_outputs is not None and not isinstance(multimodal_outputs, dict): + logger.warning_once( + "prefix caching expects mm outputs to be a dict, but got %s", + type(multimodal_outputs), + ) + + self.omni_prefix_cache.update_omni_tensor_prefix_cache( + hidden_states=hidden_states, + multimodal_outputs=multimodal_outputs, + num_tokens_unpadded=num_tokens_unpadded, + slot_mapping=self.input_batch.block_table[0].slot_mapping.cpu, + num_tokens_padded=num_tokens_padded, + ) + + def _maybe_get_combined_prefix_cache_tensors( + self, + hidden_states: torch.Tensor, + multimodal_outputs: dict, + num_scheduled_tokens: dict[str, int], + ) -> tuple[dict[str, torch.Tensor] | None, dict | None]: + """If prefix caching is enabled, extract the merged hidden states and multimodal outputs for + all requests in the batch (including those that aren't a hit on Prefix cache). + """ + # Prior to applying the post-processing func, extract + # the prefix cached hidden states and multimodal states. 
+ combined_hidden_states, combined_multimodal_outputs = None, None + if self.omni_prefix_cache is not None: + combined_hidden_states = self.omni_prefix_cache.get_merged_hidden_states( + query_start_loc=self.query_start_loc.cpu, + input_batch=self.input_batch, + hidden_states=hidden_states, + num_scheduled_tokens=num_scheduled_tokens, + ) + combined_multimodal_outputs = self.omni_prefix_cache.get_merged_multimodal_states( + query_start_loc=self.query_start_loc.cpu, + input_batch=self.input_batch, + multimodal_outputs=multimodal_outputs, + num_scheduled_tokens=num_scheduled_tokens, + ) + return combined_hidden_states, combined_multimodal_outputs + @torch.inference_mode() def execute_model( self, @@ -476,6 +534,15 @@ def execute_model( hidden_states, multimodal_outputs = self.extract_multimodal_outputs(model_output) + # Cache hidden states & multimodal outputs if we've enabled hidden state + # prefix caching unless this isn't the last pipeline parallelism rank. + self._maybe_update_prefix_cache( + hidden_states=hidden_states, + multimodal_outputs=multimodal_outputs, + num_tokens_unpadded=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded, + ) + if not self.broadcast_pp_output: # Common case. if not get_pp_group().is_last_rank: @@ -589,6 +656,23 @@ def _sample( return super()._sample(logits, spec_decode_metadata) + @staticmethod + def _resolve_req_hidden_states( + hidden_states_cpu: torch.Tensor, + combined_hidden_states: dict[str, torch.Tensor] | None, + rid: str, + start: int, + end: int, + ): + if combined_hidden_states is not None: + # We always have all request IDs for prefix cache, even for + # partial cache misses, so this should never happen. 
+ if rid not in combined_hidden_states: + raise RuntimeError("Request IDs in the batch are missing from the merged states!") + return combined_hidden_states[rid] + # Prefix caching is disabled + return hidden_states_cpu[start:end] + @torch.inference_mode() def sample_tokens( self, @@ -597,6 +681,13 @@ def sample_tokens( kv_extracted_req_ids = getattr(self, "kv_extracted_req_ids", None) self.kv_extracted_req_ids = None + # Used for prefix cache + combined_hidden_states = None + combined_multimodal_outputs = None + # Used when we don't use prefix cache; prefix cache builds the payloads + # internally since it already needs to do this for the cached tensors + mm_cpu = {} + if self.execute_model_state is None: kv_connector_output = self.kv_connector_output self.kv_connector_output = None @@ -628,6 +719,7 @@ def sample_tokens( slot_mappings, # OMNI: unpack slot_mappings for drafter ) = self.execute_model_state self.execute_model_state = None + seq_len = hidden_states.shape[0] # Apply structured output bitmasks if present. if grammar_output is not None: @@ -749,67 +841,73 @@ def propose_draft_token_ids(sampled_token_ids): dtype=np.int32, ) + # Prior to applying the post-processing func, extract + # the prefix cached hidden states and multimodal states. 
+ if self.omni_prefix_cache is not None: + ( + combined_hidden_states, + combined_multimodal_outputs, + ) = self._maybe_get_combined_prefix_cache_tensors( + hidden_states, + multimodal_outputs, + scheduler_output.num_scheduled_tokens, + ) + # Otherwise we don't have the mm CPU data yet, so we still need to build it + if self.omni_prefix_cache is None: + mm_cpu = build_mm_cpu(multimodal_outputs) + self._process_additional_information_updates( - hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output + hidden_states, + multimodal_outputs, + num_scheduled_tokens_np, + scheduler_output, + combined_hidden_states, + combined_multimodal_outputs, ) - # Pre-copy multimodal tensors to CPU once (not per-request) to avoid - # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. - mm_cpu: dict[str, object] = {} - if isinstance(multimodal_outputs, dict) and multimodal_outputs: - for k, v in multimodal_outputs.items(): - try: - if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: - mm_cpu[k] = v.detach().to("cpu").contiguous() - elif isinstance(v, dict): - sub_dict: dict[str, torch.Tensor] = {} - for sk, sv in v.items(): - if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]: - sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() - if sub_dict: - mm_cpu[k] = sub_dict - elif isinstance(v, list): - if len(v) == 0: - continue - cpu_list = [] - for elem in v: - if isinstance(elem, torch.Tensor): - cpu_list.append(elem.detach().to("cpu").contiguous()) - else: - cpu_list.append(elem) - mm_cpu[k] = cpu_list - except Exception as e: - logger.error(f"Error in merge multimodal outputs: {e}") - pooler_output: list[dict[str, object]] = [] for rid in req_ids_output_copy: idx = req_id_to_index_output_copy[rid] start = int(self.query_start_loc.cpu[idx]) sched = int(num_scheduled_tokens_np[idx]) end = start + sched - hidden_slice = hidden_states_cpu[start:end] - payload: dict[str, object] = {"hidden": 
hidden_slice} - if mm_cpu: - mm_payload: dict[str, object] = {} - for k, v in mm_cpu.items(): - if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: - mm_payload[k] = v[start:end].contiguous() - elif isinstance(v, dict): - mm_payload[k] = {sk: sv[start:end].contiguous() for sk, sv in v.items()} - elif isinstance(v, list): - element = v[idx] if idx < len(v) else v[0] - if element is not None: - if isinstance(element, torch.Tensor): - element = element.clone() - mm_payload[k] = element - # Skip None elements: msgspec cannot serialize None - # in dict[str, torch.Tensor] typed fields. - elif isinstance(v, torch.Tensor): - # List-derived tensor payloads are request-invariant; clone to - # avoid accidental cross-request aliasing on downstream mutation. - mm_payload[k] = v.clone() - else: - mm_payload[k] = v + # If prefix cache is enabled, we have already split everything + # by request and converted the states to CPU tensors + req_hidden_states = self._resolve_req_hidden_states( + hidden_states_cpu, + combined_hidden_states, + rid, + start, + end, + ) + payload: dict[str, object] = {"hidden": req_hidden_states} + + mm_payload: dict[str, object] = {} + if combined_multimodal_outputs or mm_cpu: + if combined_multimodal_outputs: + # Prefix cache enabled; all items have already been processed + # and split apart for each request as needed, and all tensors + # have already been detached to the CPU. The only exception is + # lists, which we keep as passthrough data for consistent behavior + # in postprocess. 
+ for mm_key in combined_multimodal_outputs.keys(): + value = combined_multimodal_outputs[mm_key][rid] + if isinstance(value, list): + mm_payload[mm_key] = value[idx] if idx < len(value) else value[0] + else: + mm_payload[mm_key] = value + + else: + # Prefix cache disabled; we still need to process the data + for mm_key, mm_val in mm_cpu.items(): + mm_payload[mm_key] = to_payload_element( + element=mm_val, + idx=idx, + start=start, + end=end, + pass_lists_through=False, + seq_len=seq_len, + ) payload.update(mm_payload) pooler_output.append(payload) with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 5ff62c11b40..de78011c75a 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -20,6 +20,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner, IntermediateTensors, PerLayerAttnMetadata from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices +from vllm_omni.core.prefix_cache import OmniTensorPrefixCache from vllm_omni.engine.serialization import deserialize_additional_information from vllm_omni.model_executor.layers.rotary_embedding.mrope import OmniMRotaryEmbedding as MRotaryEmbedding from vllm_omni.model_executor.models.output_templates import OmniOutput @@ -43,6 +44,9 @@ def __init__(self, *args, **kwargs): self.model_intermediate_buffer: dict[str, dict[str, Any]] = {} self._omni_num_scheduled_tokens_np: np.ndarray | None = None self._omni_last_model_output: object | None = None + # The Omni tensor prefix cache will be allocated + # when we initialize the metadata builders if enabled + self.omni_prefix_cache = None def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): """Override to fix scheduler_metadata buffer size for FA3 + CUDA graph. 
@@ -70,6 +74,16 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): device=sm.device, ) + # Initialize the wrapper for both multimodal output tensors + # and for hidden states to be passed between stages + if self.cache_config.enable_prefix_caching: + self.omni_prefix_cache = OmniTensorPrefixCache( + num_blocks=kv_cache_config.num_blocks, + block_size=self.cache_config.block_size, + hidden_size=self.model_config.get_hidden_size(), + hs_dtype=self.dtype, + ) + @instrument(span_name="Loading (GPU)") def load_model(self, *args, **kwargs) -> None: super().load_model(*args, **kwargs) @@ -234,6 +248,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput"): The SamplingMetadata is updated and copied to the GPU if there is a new/resumed/paused/finished request in the batch. """ + # Used for prefix cache + if self.omni_prefix_cache is not None: + self.omni_prefix_cache.reset_prefix_cached_new_req_ids() + # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) @@ -294,6 +312,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput"): reqs_to_add.append(req_state) continue + # Since this is the first time the request has been scheduled, + # num_computed_tokens > 0 means that we have a hit in prefix + # caching; mark it so that we can manage the hidden states + # later on as needed. 
+ if self.omni_prefix_cache is not None and new_req_data.num_computed_tokens > 0: + self.omni_prefix_cache.add_prefix_cached_new_req_id(req_id) + sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params @@ -1010,6 +1035,8 @@ def _process_additional_information_updates( multimodal_outputs: object, num_scheduled_tokens_np: np.ndarray, scheduler_output: "SchedulerOutput", + combined_hidden_states: dict[str, torch.Tensor] | None = None, + combined_multimodal_outputs: dict[str, object] | None = None, ) -> None: """Process model-provided per-request updates and merge into model_intermediate_buffer.""" try: @@ -1018,21 +1045,31 @@ def _process_additional_information_updates( if hasattr(self.model, "has_postprocess") and self.model.has_postprocess: for req_index, req_id in enumerate(self.input_batch.req_ids): req_infos = self.model_intermediate_buffer.get(req_id, {}) - start_offset = int(self.query_start_loc.cpu[req_index]) - sched_tokens = int(num_scheduled_tokens_np[req_index]) - s, e = start_offset, start_offset + sched_tokens - # only consider to store data into update dict. - hidden_states_slice = hidden_states[s:e] + if combined_hidden_states: + # Combined hidden states contains all hidden states for every request + hidden_states_slice = combined_hidden_states[req_id] + else: + start_offset = int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + s, e = start_offset, start_offset + sched_tokens + # only consider to store data into update dict. + hidden_states_slice = hidden_states[s:e] + + if combined_multimodal_outputs: + # NOTE: this is a bit ugly, but the mm data is structured as a dict mapping + # keys to per-request subdicts, and if enabled, we will always have all + # request IDs in every subdict, including for cache misses. 
+ mm_out = {k: v[req_id] for k, v in combined_multimodal_outputs.items()} + else: + mm_out = multimodal_outputs update_dict = self.model.postprocess( - hidden_states_slice, multimodal_outputs=multimodal_outputs, **req_infos + hidden_states_slice, + multimodal_outputs=mm_out, + **req_infos, ) self._update_intermediate_buffer(req_id, update_dict) except Exception as e: - logger.error( - f"Error merging for requests:{self.input_batch.req_ids} " - f"additional information update: {e}, with the multimodal_outputs " - f"as {multimodal_outputs}" - ) + logger.error(f"Error merging for requests:{self.input_batch.req_ids} additional information update: {e}") import traceback traceback.print_exc()
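
The write and merge paths of `OmniTensorPrefixCache` in this diff can be illustrated with a small, dependency-free sketch. It is not the implementation from the PR: `ToyTensorPrefixCache`, `update`, and `merged` are hypothetical names, plain Python lists stand in for CPU tensors, and the block table is passed in directly rather than read from an `InputBatch`. It only aims to show the two ideas the code relies on: rows are scattered into a flattened `(num_blocks * block_size)` view using the slot mapping, and on a cache hit only full blocks (`num_computed_tokens // block_size`) are read back and concatenated with the newly scheduled rows.

```python
class ToyTensorPrefixCache:
    """Hypothetical stand-in for a block/slot-indexed CPU tensor cache."""

    def __init__(self, num_blocks, block_size, hidden_size):
        self.block_size = block_size
        # Flattened view: one row per slot, slot = block_id * block_size + offset.
        self.rows = [[0.0] * hidden_size for _ in range(num_blocks * block_size)]

    def update(self, hidden_states, slot_mapping, num_tokens_unpadded):
        # Slice away padding first, then write each token row into its slot.
        slots = slot_mapping[:num_tokens_unpadded]
        states = hidden_states[:num_tokens_unpadded]
        for slot, row in zip(slots, states):
            self.rows[slot] = list(row)

    def merged(self, new_rows, block_ids, num_computed_tokens):
        # Only full blocks are treated as cached, so round down.
        num_cached_blocks = num_computed_tokens // self.block_size
        cached = []
        for block_id in block_ids[:num_cached_blocks]:
            start = block_id * self.block_size
            cached.extend(self.rows[start:start + self.block_size])
        # Cached prefix rows come first, followed by the newly computed rows.
        return cached + [list(row) for row in new_rows]


cache = ToyTensorPrefixCache(num_blocks=3, block_size=4, hidden_size=2)

# First request: 6 real tokens padded to 8; its blocks (2, then 0) are not contiguous.
first = [[float(t), float(t) + 0.5] for t in range(6)]
padded = first + [[-1.0, -1.0]] * 2
slot_mapping = [8, 9, 10, 11, 0, 1, 7, 7]
cache.update(padded, slot_mapping, num_tokens_unpadded=6)

# Second request: shares the first full block (4 tokens) and schedules 2 new tokens.
new_rows = [[40.0, 40.5], [50.0, 50.5]]
merged = cache.merged(new_rows, block_ids=[2, 1], num_computed_tokens=4)
```

Here `merged` holds the four rows cached in block 2 followed by the two new rows, and the padding rows are never written to the cache.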