From fc6b5e62170efff8696a9acd49ccc98641d2c9a4 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 23 Mar 2026 22:47:23 +0000 Subject: [PATCH 01/38] add external hidden state cache Signed-off-by: Alex Brooks --- vllm_omni/worker/gpu_ar_model_runner.py | 36 ++++++++++++- vllm_omni/worker/gpu_model_runner.py | 68 +++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 62a0c857164..4ae9c63091a 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -201,6 +201,16 @@ def _capture_talker_mtp_graphs(self) -> None: finally: set_cudagraph_capturing_enabled(False) + def update_hidden_state_cache(self, hidden_states, num_tokens_unpadded): + """Updates the hidden cache state for prefix caching from + the current model's execution for the unpadded tokens. + """ + assert self.hidden_state_cache is not None + slot_mapping = self.input_batch.block_table[0].slot_mapping.gpu[:num_tokens_unpadded] + # View the cache as 2D so that we can treat our slots as row indices + flat_cache = self.hidden_state_cache.view(-1, self.hidden_state_cache.shape[-1]) + flat_cache[slot_mapping] = hidden_states[:num_tokens_unpadded] + @torch.inference_mode() def execute_model( self, @@ -476,6 +486,14 @@ def execute_model( hidden_states, multimodal_outputs = self.extract_multimodal_outputs(model_output) + # Cache hidden states if we've enabled hidden state prefix caching + # unless this isn't the last pipeline parallelism rank. + if self.cache_hidden_states and self.hidden_state_cache is not None and get_pp_group().is_last_rank: + self.update_hidden_state_cache( + hidden_states=hidden_states, + num_tokens_unpadded=num_tokens_unpadded, + ) + if not self.broadcast_pp_output: # Common case. 
if not get_pp_group().is_last_rank: @@ -749,6 +767,14 @@ def propose_draft_token_ids(sampled_token_ids): dtype=np.int32, ) + # Prior to applying the post-processing func, extract + # the prefix cached hidden states if it's enabled and + # we have them. + combined_hidden_states = self._get_merged_hidden_states( + hidden_states=hidden_states, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + ) + self._process_additional_information_updates( hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output ) @@ -787,8 +813,14 @@ def propose_draft_token_ids(sampled_token_ids): start = int(self.query_start_loc.cpu[idx]) sched = int(num_scheduled_tokens_np[idx]) end = start + sched - hidden_slice = hidden_states_cpu[start:end] - payload: dict[str, object] = {"hidden": hidden_slice} + # For prefix cache on hidden states - if it's a request + # in the combined hidden states, it's a cache hit, so we + # send the states that were already merged. + if combined_hidden_states and rid in combined_hidden_states: + req_hidden_states = combined_hidden_states[rid] + else: + req_hidden_states = hidden_states_cpu[start:end] + payload: dict[str, object] = {"hidden": req_hidden_states} if mm_cpu: mm_payload: dict[str, object] = {} for k, v in mm_cpu.items(): diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 5ff62c11b40..35cb7cec463 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -43,6 +43,10 @@ def __init__(self, *args, **kwargs): self.model_intermediate_buffer: dict[str, dict[str, Any]] = {} self._omni_num_scheduled_tokens_np: np.ndarray | None = None self._omni_last_model_output: object | None = None + # TODO add another gate for this + self.cache_hidden_states = self.cache_config.enable_prefix_caching + self.hidden_state_cache: torch.Tensor | None = None + self._new_req_cache_hit_ids: set[str] | None = None def initialize_metadata_builders(self, kv_cache_config, 
kernel_block_sizes): """Override to fix scheduler_metadata buffer size for FA3 + CUDA graph. @@ -69,6 +73,60 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): dtype=sm.dtype, device=sm.device, ) + # If we've enabled hidden state prefix caching, preallocate + # the cache; this is currently only used for the output hidden + # states of the stage. + # NOTE: hidden states are large; careful to ensure things are + # safely handled for memory in the future if this is expanded. + if self.cache_hidden_states: + # We don't handle hybrid cache types for hidden state caching yet; + # This does matter since + if len(kv_cache_config.kv_cache_groups) > 1: + logger.warning("Hidden state caching with multiple KV cache groups not yet supported, disabling.") + self.cache_hidden_states = False + else: + logger.info("Initializing hidden states prefix cache") + num_blocks = kv_cache_config.num_blocks + block_size = self.cache_config.block_size + hidden_size = self.model_config.get_hidden_size() + self.hidden_state_cache = torch.zeros( + (num_blocks, block_size, hidden_size), + dtype=self.dtype, + device=self.device, + ) + + def _get_merged_hidden_states(self, hidden_states, num_scheduled_tokens): + """When hidden state caching is enabled, takes the input hidden_states, + which only correspond to the scheduled tokens, and returns a mapping + from request IDs to their full hidden states. This is accomplished by + looking up the block IDs & scheduled token counts to split the + hidden_states. + + NOTE: We do not handle hybrid caches at the moment, which is why + we index into the first block table like this. 
+ """ + combined_hidden_states = {} + if self.cache_hidden_states and self._new_req_cache_hit_ids: + assert self.hidden_state_cache + for req_id in self._new_req_cache_hit_ids: + req_idx = self.input_batch.req_id_to_index[req_id] + num_computed = self.input_batch.num_computed_tokens_cpu[req_idx] + block_size = self.cache_config.block_size + # NOTE: vLLM only caches full blocks + num_cached_blocks = num_computed // block_size + # Get the block IDs attached to this cache hit and reindex into + # the flattened cached hidden states (i.e., 1 row per token). + block_ids = self.input_batch.block_table[0].block_table.gpu[req_idx, :num_cached_blocks] + cached_hs = self.hidden_state_cache[block_ids].reshape(-1, self.hidden_state_cache.shape[-1]) + + # Slice the hidden states corresponding to this request; + # we do this by using the query start + start = self.query_start_loc.gpu[req_idx] + new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] + # TODO: consider putting the actually hidden state cache on CPU + combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0).detach().to("cpu").contiguous() + + return combined_hidden_states @instrument(span_name="Loading (GPU)") def load_model(self, *args, **kwargs) -> None: @@ -234,6 +292,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput"): The SamplingMetadata is updated and copied to the GPU if there is a new/resumed/paused/finished request in the batch. """ + # Used for prefix cache + self._new_req_cache_hit_ids = set() + # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) @@ -294,6 +355,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput"): reqs_to_add.append(req_state) continue + # Since this is the first time the request has been scheduled, + # num_computed_tokens > 0 means that we have a hit in prefix + # caching; mark it so that we can manage the hidden states + # later on as needed. 
+ if self.cache_hidden_states and new_req_data.num_computed_tokens > 0: + self._new_req_cache_hit_ids.add(req_id) + sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params From e1c8acd1f8deff2888c34dca9a2deb1b2e18ebd5 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 24 Mar 2026 01:55:22 +0000 Subject: [PATCH 02/38] hacks for mm output cache Signed-off-by: Alex Brooks --- vllm_omni/worker/gpu_ar_model_runner.py | 61 +++++++++++++++++++++++-- vllm_omni/worker/gpu_model_runner.py | 53 +++++++++++++++------ 2 files changed, 97 insertions(+), 17 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 4ae9c63091a..0aee608648d 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -201,7 +201,7 @@ def _capture_talker_mtp_graphs(self) -> None: finally: set_cudagraph_capturing_enabled(False) - def update_hidden_state_cache(self, hidden_states, num_tokens_unpadded): + def update_hidden_state_cache(self, hidden_states, multimodal_outputs, num_tokens_unpadded): """Updates the hidden cache state for prefix caching from the current model's execution for the unpadded tokens. """ @@ -210,6 +210,19 @@ def update_hidden_state_cache(self, hidden_states, num_tokens_unpadded): # View the cache as 2D so that we can treat our slots as row indices flat_cache = self.hidden_state_cache.view(-1, self.hidden_state_cache.shape[-1]) flat_cache[slot_mapping] = hidden_states[:num_tokens_unpadded] + logger.info(f"[HS Cache WRITE] tokens={num_tokens_unpadded}") + + # Do the same for the cached multimodal outputs for this stage; + # for now we assume that all of the multimodal outputs cached + # are exactly the same size as the hidden states. + # TODO (Alex) make this more flexible. 
+ if self.mm_outputs_cache is not None: + for mm_out_key, mm_cache in self.mm_outputs_cache.items(): + assert mm_out_key in multimodal_outputs + mm_state = multimodal_outputs[mm_out_key] + flat_cache = mm_cache.view(-1, mm_cache.shape[-1]) + flat_cache[slot_mapping] = mm_state[:num_tokens_unpadded] + logger.info(f"[multimodal output Cache WRITE] tokens={num_tokens_unpadded}") @torch.inference_mode() def execute_model( @@ -491,6 +504,7 @@ def execute_model( if self.cache_hidden_states and self.hidden_state_cache is not None and get_pp_group().is_last_rank: self.update_hidden_state_cache( hidden_states=hidden_states, + multimodal_outputs=multimodal_outputs, num_tokens_unpadded=num_tokens_unpadded, ) @@ -771,12 +785,36 @@ def propose_draft_token_ids(sampled_token_ids): # the prefix cached hidden states if it's enabled and # we have them. combined_hidden_states = self._get_merged_hidden_states( + cache=self.hidden_state_cache, hidden_states=hidden_states, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, ) + # Do the same for multimodal outputs + # TODO (Alex) clean this up + combined_multimodal_outputs = {} + if ( + self._cacheable_mm_keys + and self.mm_outputs_cache + and multimodal_outputs + and isinstance(multimodal_outputs, dict) + ): + for mm_key in self._cacheable_mm_keys: + if mm_key in multimodal_outputs: + combined_multimodal_outputs[mm_key] = self._get_merged_hidden_states( + cache=self.mm_outputs_cache[mm_key], + hidden_states=multimodal_outputs[mm_key], + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + ) + else: + logger.error("Cacheable multimodal key %s is not present in multimodal outputs", mm_key) + self._process_additional_information_updates( - hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output + hidden_states, + multimodal_outputs, + num_scheduled_tokens_np, + scheduler_output, + combined_hidden_states, ) # Pre-copy multimodal tensors to CPU once (not per-request) to avoid @@ -817,14 +855,29 @@ def 
propose_draft_token_ids(sampled_token_ids): # in the combined hidden states, it's a cache hit, so we # send the states that were already merged. if combined_hidden_states and rid in combined_hidden_states: - req_hidden_states = combined_hidden_states[rid] + # TODO cleanup device management + req_hidden_states = combined_hidden_states[rid].detach().to("cpu").contiguous() else: req_hidden_states = hidden_states_cpu[start:end] payload: dict[str, object] = {"hidden": req_hidden_states} + + logger.info( + f"[HS] req={rid} hidden_shape={req_hidden_states.shape} " + f"cache_hit={rid in combined_hidden_states if combined_hidden_states else False}" + ) + if mm_cpu: mm_payload: dict[str, object] = {} for k, v in mm_cpu.items(): - if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: + if ( + combined_multimodal_outputs + and k in combined_multimodal_outputs + and rid in combined_multimodal_outputs[k] + ): + mm_payload[k] = combined_multimodal_outputs[k][rid].detach().to("cpu").contiguous() + logger.info(f"Cached mm key {k} | shape: {combined_multimodal_outputs[k][rid].shape}") + + elif isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: mm_payload[k] = v[start:end].contiguous() elif isinstance(v, dict): mm_payload[k] = {sk: sv[start:end].contiguous() for sk, sv in v.items()} diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 35cb7cec463..4b8fe3afe5f 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -47,6 +47,12 @@ def __init__(self, *args, **kwargs): self.cache_hidden_states = self.cache_config.enable_prefix_caching self.hidden_state_cache: torch.Tensor | None = None self._new_req_cache_hit_ids: set[str] | None = None + # HACK - for testing qwen3omni + self._cacheable_mm_keys: set[str] | None = None + self.mm_outputs_cache: dict[str, torch.Tensor] | None = None + if self.model_config.model_stage == "thinker": + logger.warning("HACK - for now 
cacheable mm keys are hardcoded to 0 & 24 for the thinker phase") + self._cacheable_mm_keys = {"0", "24"} def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): """Override to fix scheduler_metadata buffer size for FA3 + CUDA graph. @@ -79,6 +85,10 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): # NOTE: hidden states are large; careful to ensure things are # safely handled for memory in the future if this is expanded. if self.cache_hidden_states: + num_blocks = kv_cache_config.num_blocks + block_size = self.cache_config.block_size + hidden_size = self.model_config.get_hidden_size() + # We don't handle hybrid cache types for hidden state caching yet; # This does matter since if len(kv_cache_config.kv_cache_groups) > 1: @@ -86,16 +96,24 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): self.cache_hidden_states = False else: logger.info("Initializing hidden states prefix cache") - num_blocks = kv_cache_config.num_blocks - block_size = self.cache_config.block_size - hidden_size = self.model_config.get_hidden_size() self.hidden_state_cache = torch.zeros( (num_blocks, block_size, hidden_size), dtype=self.dtype, device=self.device, ) + # NOTE: For now we just use a dict of tensors even though all tensors are + # the same shape, since in the future this may not be the case + if self._cacheable_mm_keys is not None: + self.mm_outputs_cache = {} + for k in self._cacheable_mm_keys: + logger.info("Initializing multimodal output cache for key: %s", k) + self.mm_outputs_cache[k] = torch.zeros( + (num_blocks, block_size, hidden_size), + dtype=self.dtype, + device=self.device, + ) - def _get_merged_hidden_states(self, hidden_states, num_scheduled_tokens): + def _get_merged_hidden_states(self, cache: torch.Tensor | None, hidden_states, num_scheduled_tokens): """When hidden state caching is enabled, takes the input hidden_states, which only correspond to the scheduled tokens, and returns a mapping 
from request IDs to their full hidden states. This is accomplished by @@ -106,8 +124,7 @@ def _get_merged_hidden_states(self, hidden_states, num_scheduled_tokens): we index into the first block table like this. """ combined_hidden_states = {} - if self.cache_hidden_states and self._new_req_cache_hit_ids: - assert self.hidden_state_cache + if cache is not None and self._new_req_cache_hit_ids: for req_id in self._new_req_cache_hit_ids: req_idx = self.input_batch.req_id_to_index[req_id] num_computed = self.input_batch.num_computed_tokens_cpu[req_idx] @@ -117,14 +134,20 @@ def _get_merged_hidden_states(self, hidden_states, num_scheduled_tokens): # Get the block IDs attached to this cache hit and reindex into # the flattened cached hidden states (i.e., 1 row per token). block_ids = self.input_batch.block_table[0].block_table.gpu[req_idx, :num_cached_blocks] - cached_hs = self.hidden_state_cache[block_ids].reshape(-1, self.hidden_state_cache.shape[-1]) + cached_hs = cache[block_ids].reshape(-1, cache.shape[-1]) # Slice the hidden states corresponding to this request; # we do this by using the query start start = self.query_start_loc.gpu[req_idx] new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] # TODO: consider putting the actually hidden state cache on CPU - combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0).detach().to("cpu").contiguous() + combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0) + + logger.info( + f"[Cache combine] req={req_id} cached_blocks={num_cached_blocks} " + f"cached hidden states shape={cached_hs.shape} " + f"new hidden states shape={new_hs.shape}" + ) return combined_hidden_states @@ -1078,6 +1101,7 @@ def _process_additional_information_updates( multimodal_outputs: object, num_scheduled_tokens_np: np.ndarray, scheduler_output: "SchedulerOutput", + combined_hidden_states: torch.Tensor | None = None, ) -> None: """Process model-provided per-request updates and merge into 
model_intermediate_buffer.""" try: @@ -1086,11 +1110,14 @@ def _process_additional_information_updates( if hasattr(self.model, "has_postprocess") and self.model.has_postprocess: for req_index, req_id in enumerate(self.input_batch.req_ids): req_infos = self.model_intermediate_buffer.get(req_id, {}) - start_offset = int(self.query_start_loc.cpu[req_index]) - sched_tokens = int(num_scheduled_tokens_np[req_index]) - s, e = start_offset, start_offset + sched_tokens - # only consider to store data into update dict. - hidden_states_slice = hidden_states[s:e] + if combined_hidden_states and req_id in combined_hidden_states: + hidden_states_slice = combined_hidden_states[req_id] + else: + start_offset = int(self.query_start_loc.cpu[req_index]) + sched_tokens = int(num_scheduled_tokens_np[req_index]) + s, e = start_offset, start_offset + sched_tokens + # only consider to store data into update dict. + hidden_states_slice = hidden_states[s:e] update_dict = self.model.postprocess( hidden_states_slice, multimodal_outputs=multimodal_outputs, **req_infos ) From 60c78a3f33d19c2192ea7f1a90064c0ccffe7e17 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 27 Mar 2026 22:00:55 +0000 Subject: [PATCH 03/38] refactor combined mm states Signed-off-by: Alex Brooks --- vllm_omni/worker/gpu_ar_model_runner.py | 25 +++++-------------------- vllm_omni/worker/gpu_model_runner.py | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 0aee608648d..42f71aea5f1 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -782,32 +782,17 @@ def propose_draft_token_ids(sampled_token_ids): ) # Prior to applying the post-processing func, extract - # the prefix cached hidden states if it's enabled and - # we have them. + # the prefix cached hidden states and multimodal states. 
combined_hidden_states = self._get_merged_hidden_states( cache=self.hidden_state_cache, hidden_states=hidden_states, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, ) - # Do the same for multimodal outputs - # TODO (Alex) clean this up - combined_multimodal_outputs = {} - if ( - self._cacheable_mm_keys - and self.mm_outputs_cache - and multimodal_outputs - and isinstance(multimodal_outputs, dict) - ): - for mm_key in self._cacheable_mm_keys: - if mm_key in multimodal_outputs: - combined_multimodal_outputs[mm_key] = self._get_merged_hidden_states( - cache=self.mm_outputs_cache[mm_key], - hidden_states=multimodal_outputs[mm_key], - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - ) - else: - logger.error("Cacheable multimodal key %s is not present in multimodal outputs", mm_key) + combined_multimodal_outputs = self._get_merged_multimodal_states( + multimodal_outputs, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + ) self._process_additional_information_updates( hidden_states, diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 4b8fe3afe5f..20fdc571f84 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -113,6 +113,26 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): device=self.device, ) + def _get_merged_multimodal_states(self, multimodal_outputs, num_scheduled_tokens): + """Get the merged multimodal states if hidden state prefix caching is enabled.""" + combined_multimodal_outputs = {} + if ( + self._cacheable_mm_keys + and self.mm_outputs_cache + and multimodal_outputs + and isinstance(multimodal_outputs, dict) + ): + for mm_key in self._cacheable_mm_keys: + if mm_key in multimodal_outputs: + combined_multimodal_outputs[mm_key] = self._get_merged_hidden_states( + cache=self.mm_outputs_cache[mm_key], + hidden_states=multimodal_outputs[mm_key], + num_scheduled_tokens=num_scheduled_tokens, + ) + else: + 
logger.error("Cacheable multimodal key %s is not present in multimodal outputs", mm_key) + return combined_multimodal_outputs + def _get_merged_hidden_states(self, cache: torch.Tensor | None, hidden_states, num_scheduled_tokens): """When hidden state caching is enabled, takes the input hidden_states, which only correspond to the scheduled tokens, and returns a mapping From 0ea3ec6eb9f3a4511c973c6f2cef964ab1c70c89 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 29 Mar 2026 04:59:00 +0000 Subject: [PATCH 04/38] refactor prefix cache to util class Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache_utils.py | 169 ++++++++++++++++++++++++ vllm_omni/worker/gpu_ar_model_runner.py | 25 ++-- vllm_omni/worker/gpu_model_runner.py | 62 +++------ 3 files changed, 199 insertions(+), 57 deletions(-) create mode 100644 vllm_omni/core/prefix_cache_utils.py diff --git a/vllm_omni/core/prefix_cache_utils.py b/vllm_omni/core/prefix_cache_utils.py new file mode 100644 index 00000000000..f0632ddb475 --- /dev/null +++ b/vllm_omni/core/prefix_cache_utils.py @@ -0,0 +1,169 @@ +""" +Utilities for Prefix Caching in Omni models. +""" + +import torch +from vllm.logger import init_logger + +logger = init_logger(__name__) + +""" +NOTE for tomorrow - working on pulling the equivalent functionality out here. also, we should make sure +that we clear new requests, since I don't think we currently are + make sure that we pass the combined +multimodal states since we are not right now +""" + +# TODO - Make this configurable and factor in the number +# of multimodal tensors. +NUM_GPU_BLOCKS = 2048 +# TODO Make this generic, these are specific for qwen3 omni. +OMNI_HS_CACHE_KEY = "_OMNI_HIDDEN_STATES_KEY" +MM_CACHE_KEYS = ["0", "24"] +CACHEABLE_KEYS = MM_CACHE_KEYS + [OMNI_HS_CACHE_KEY] + + +class OmniTensorPrefixCache: + """Prefix cache for hidden states (model outputs) + and model specific multimodal outputs. 
+ + This class implements prefix caching in a non-invasive + way on top of vLLM by leveraging the same slot mappings + that the vLLM scheduler uses for the KV Cache. + + Conceptually, we take vLLM's mapping from: + (num_blocks, block_size) + + and translate it to rows in the 3D tensor of shape: + (num_blocks, block_size, feature_size) + + The above is generally a large tensor, especially since + multiple multimodal output tensors may be cached. To reduce + GPU pressure, we implement this as a two-tier cache, + keeping the large tensor offloaded on CPU, and transferring hot + blocks to a GPU buffer of size + (num_gpu_blocks, block_size, feature_size) + while maintaining a layer of indirection mapping the indices of + the CPU slot mappings to the GPU slot mappings. + """ + + def __init__(self, num_blocks: int, block_size: int, hidden_size: int, dtype, device): + self.new_req_cache_hit_ids: set[str] | None = None + + self.block_size = block_size + # TODO: Support CPU offload and combine these. 
+ self.omni_tensor_cache = { + cacheable_key: torch.zeros( + (num_blocks, block_size, hidden_size), + dtype=dtype, + device=device, + ) + for cacheable_key in CACHEABLE_KEYS + } + + self._new_req_cache_hit_ids: set[str] | None = None + + @property + def hidden_states_cache(self) -> torch.Tensor: + """Returns the hidden states cache.""" + return self.omni_tensor_cache[OMNI_HS_CACHE_KEY] + + @property + def mm_outputs_cache(self) -> dict[str, torch.Tensor]: + """Returns the model specific multimodal outputs cache.""" + return {k: v for k, v in self.omni_tensor_cache.items() if k != OMNI_HS_CACHE_KEY} + + def update_omni_tensor_prefix_cache(self, hidden_states, multimodal_outputs, num_tokens_unpadded, input_batch): + """Updates the hidden cache state for hidden states and multimodal outputs.""" + assert self.hidden_states_cache is not None + slot_mapping = input_batch.block_table[0].slot_mapping.gpu[:num_tokens_unpadded] + # View the cache as 2D so that we can treat our slots as row indices + flat_cache = self.hidden_states_cache.view(-1, self.hidden_states_cache.shape[-1]) + flat_cache[slot_mapping] = hidden_states[:num_tokens_unpadded] + logger.info(f"[HS Cache WRITE] tokens={num_tokens_unpadded}") + + # Do the same for the cached multimodal outputs for this stage; + # for now we assume that all of the multimodal outputs cached + # are exactly the same size as the hidden states. + # TODO (Alex) make this more flexible. 
+ if self.mm_outputs_cache is not None: + for mm_out_key, mm_cache in self.mm_outputs_cache.items(): + assert mm_out_key in multimodal_outputs + mm_state = multimodal_outputs[mm_out_key] + flat_cache = mm_cache.view(-1, mm_cache.shape[-1]) + flat_cache[slot_mapping] = mm_state[:num_tokens_unpadded] + logger.info(f"[multimodal output Cache WRITE] tokens={num_tokens_unpadded}") + + def _get_combined_states( + self, query_start_loc, input_batch, hidden_states, multimodal_outputs, num_scheduled_tokens + ): + combined_mm_states = self._get_merged_multimodal_states( + query_start_loc, input_batch, multimodal_outputs, num_scheduled_tokens + ) + combined_hidden_states = self._get_merged_hidden_states( + query_start_loc, input_batch, hidden_states, num_scheduled_tokens + ) + return combined_hidden_states, combined_mm_states + + def _get_merged_multimodal_states(self, query_start_loc, input_batch, multimodal_outputs, num_scheduled_tokens): + """Get the merged multimodal states if hidden state prefix caching is enabled.""" + combined_multimodal_outputs = {} + for mm_key in MM_CACHE_KEYS: + if mm_key in multimodal_outputs: + combined_multimodal_outputs[mm_key] = self._get_merged_tensors( + query_start_loc=query_start_loc, + input_batch=input_batch, + cache=self.mm_outputs_cache[mm_key], + hidden_states=multimodal_outputs[mm_key], + num_scheduled_tokens=num_scheduled_tokens, + ) + else: + logger.error("Cacheable multimodal key %s is not present in multimodal outputs", mm_key) + return combined_multimodal_outputs + + def _get_merged_hidden_states(self, query_start_loc, input_batch, hidden_states, num_scheduled_tokens): + return self._get_merged_tensors( + query_start_loc=query_start_loc, + input_batch=input_batch, + cache=self.hidden_states_cache, + hidden_states=hidden_states, + num_scheduled_tokens=num_scheduled_tokens, + ) + + def _get_merged_tensors( + self, query_start_loc, input_batch, cache: torch.Tensor, hidden_states: torch.Tensor, num_scheduled_tokens + ) -> dict[str, 
torch.Tensor]: + """When hidden state caching is enabled, takes the input hidden_states, + which only correspond to the scheduled tokens, and returns a mapping + from request IDs to their full hidden states. This is accomplished by + looking up the block IDs & scheduled token counts to split the + hidden_states. + + NOTE: We do not handle hybrid caches at the moment, which is why + we index into the first block table like this. + """ + combined_hidden_states = {} + if cache is not None and self._new_req_cache_hit_ids: + for req_id in self._new_req_cache_hit_ids: + req_idx = input_batch.req_id_to_index[req_id] + num_computed = input_batch.num_computed_tokens_cpu[req_idx] + # NOTE: vLLM only caches full blocks + num_cached_blocks = num_computed // self.block_size + # Get the block IDs attached to this cache hit and reindex into + # the flattened cached hidden states (i.e., 1 row per token). + block_ids = input_batch.block_table[0].block_table.gpu[req_idx, :num_cached_blocks] + cached_hs = cache[block_ids].reshape(-1, cache.shape[-1]) + + # Slice the hidden states corresponding to this request; + # we do this by using the query start + start = query_start_loc.gpu[req_idx] + new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] + # TODO: consider putting the actually hidden state cache on CPU + combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0) + + logger.info( + f"[Cache combine] req={req_id} cached_blocks={num_cached_blocks} " + f"cached hidden states shape={cached_hs.shape} " + f"new hidden states shape={new_hs.shape}" + ) + + return combined_hidden_states diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 42f71aea5f1..e01a9e51735 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -501,11 +501,12 @@ def execute_model( # Cache hidden states if we've enabled hidden state prefix caching # unless this isn't the last pipeline parallelism 
rank. - if self.cache_hidden_states and self.hidden_state_cache is not None and get_pp_group().is_last_rank: - self.update_hidden_state_cache( + if self.omni_prefix_cache is not None and get_pp_group().is_last_rank: + self.omni_prefix_cache.update_omni_tensor_prefix_cache( hidden_states=hidden_states, multimodal_outputs=multimodal_outputs, num_tokens_unpadded=num_tokens_unpadded, + input_batch=self.input_batch, ) if not self.broadcast_pp_output: @@ -783,16 +784,16 @@ def propose_draft_token_ids(sampled_token_ids): # Prior to applying the post-processing func, extract # the prefix cached hidden states and multimodal states. - combined_hidden_states = self._get_merged_hidden_states( - cache=self.hidden_state_cache, - hidden_states=hidden_states, - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - ) - - combined_multimodal_outputs = self._get_merged_multimodal_states( - multimodal_outputs, - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - ) + if self.omni_prefix_cache is None: + combined_hidden_states, combined_multimodal_outputs = None, None + else: + combined_hidden_states, combined_multimodal_outputs = self.omni_prefix_cache._get_combined_states( + query_start_loc=self.query_start_loc, + input_batch=self.input_batch, + hidden_states=hidden_states, + multimodal_outputs=multimodal_outputs, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + ) self._process_additional_information_updates( hidden_states, diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 20fdc571f84..72fbe477565 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -20,6 +20,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner, IntermediateTensors, PerLayerAttnMetadata from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices +from vllm_omni.core.prefix_cache_utils import OmniTensorPrefixCache from vllm_omni.engine.serialization import 
deserialize_additional_information from vllm_omni.model_executor.layers.rotary_embedding.mrope import OmniMRotaryEmbedding as MRotaryEmbedding from vllm_omni.model_executor.models.output_templates import OmniOutput @@ -43,16 +44,9 @@ def __init__(self, *args, **kwargs): self.model_intermediate_buffer: dict[str, dict[str, Any]] = {} self._omni_num_scheduled_tokens_np: np.ndarray | None = None self._omni_last_model_output: object | None = None - # TODO add another gate for this - self.cache_hidden_states = self.cache_config.enable_prefix_caching - self.hidden_state_cache: torch.Tensor | None = None - self._new_req_cache_hit_ids: set[str] | None = None - # HACK - for testing qwen3omni - self._cacheable_mm_keys: set[str] | None = None - self.mm_outputs_cache: dict[str, torch.Tensor] | None = None - if self.model_config.model_stage == "thinker": - logger.warning("HACK - for now cacheable mm keys are hardcoded to 0 & 24 for the thinker phase") - self._cacheable_mm_keys = {"0", "24"} + # The Omni tensor prefix cache will be allocated + # when we initialize the metadata builders if enabled + self.omni_prefix_cache = None def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): """Override to fix scheduler_metadata buffer size for FA3 + CUDA graph. @@ -79,39 +73,17 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): dtype=sm.dtype, device=sm.device, ) - # If we've enabled hidden state prefix caching, preallocate - # the cache; this is currently only used for the output hidden - # states of the stage. - # NOTE: hidden states are large; careful to ensure things are - # safely handled for memory in the future if this is expanded. 
- if self.cache_hidden_states: - num_blocks = kv_cache_config.num_blocks - block_size = self.cache_config.block_size - hidden_size = self.model_config.get_hidden_size() - - # We don't handle hybrid cache types for hidden state caching yet; - # This does matter since - if len(kv_cache_config.kv_cache_groups) > 1: - logger.warning("Hidden state caching with multiple KV cache groups not yet supported, disabling.") - self.cache_hidden_states = False - else: - logger.info("Initializing hidden states prefix cache") - self.hidden_state_cache = torch.zeros( - (num_blocks, block_size, hidden_size), - dtype=self.dtype, - device=self.device, - ) - # NOTE: For now we just use a dict of tensors even though all tensors are - # the same shape, since in the future this may not be the case - if self._cacheable_mm_keys is not None: - self.mm_outputs_cache = {} - for k in self._cacheable_mm_keys: - logger.info("Initializing multimodal output cache for key: %s", k) - self.mm_outputs_cache[k] = torch.zeros( - (num_blocks, block_size, hidden_size), - dtype=self.dtype, - device=self.device, - ) + + # Initialize the wrapper for both multimodal output tensors + # and for hidden states to be passed between stages + if self.cache_config.enable_prefix_caching: + self.omni_prefix_cache = OmniTensorPrefixCache( + num_blocks=kv_cache_config.num_blocks, + block_size=self.cache_config.block_size, + hidden_size=self.model_config.get_hidden_size(), + dtype=self.dtype, + device=self.device, + ) def _get_merged_multimodal_states(self, multimodal_outputs, num_scheduled_tokens): """Get the merged multimodal states if hidden state prefix caching is enabled.""" @@ -402,7 +374,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput"): # num_computed_tokens > 0 means that we have a hit in prefix # caching; mark it so that we can manage the hidden states # later on as needed. 
- if self.cache_hidden_states and new_req_data.num_computed_tokens > 0: + if self.omni_prefix_cache is not None and new_req_data.num_computed_tokens > 0: self._new_req_cache_hit_ids.add(req_id) sampling_params = new_req_data.sampling_params @@ -1121,7 +1093,7 @@ def _process_additional_information_updates( multimodal_outputs: object, num_scheduled_tokens_np: np.ndarray, scheduler_output: "SchedulerOutput", - combined_hidden_states: torch.Tensor | None = None, + combined_hidden_states: dict[str, torch.Tensor] | None = None, ) -> None: """Process model-provided per-request updates and merge into model_intermediate_buffer.""" try: From 529f21504a36719b73a9c479266edc13c159e0c1 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 29 Mar 2026 05:20:32 +0000 Subject: [PATCH 05/38] don't pass unused / wrong mm outputs to proc additional updates Signed-off-by: Alex Brooks --- vllm_omni/platforms/npu/worker/npu_ar_model_runner.py | 2 +- vllm_omni/worker/gpu_ar_model_runner.py | 1 - vllm_omni/worker/gpu_model_runner.py | 11 ++--------- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py index ffb997048bd..fb5c1cf1367 100644 --- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py @@ -639,7 +639,7 @@ def propose_draft_token_ids(sampled_token_ids): ) self._process_additional_information_updates( - hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output + hidden_states, num_scheduled_tokens_np, scheduler_output ) # Pre-copy multimodal tensors to CPU once (not per-request) to avoid diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index e01a9e51735..4dc958afed7 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -797,7 +797,6 @@ def propose_draft_token_ids(sampled_token_ids): 
self._process_additional_information_updates( hidden_states, - multimodal_outputs, num_scheduled_tokens_np, scheduler_output, combined_hidden_states, diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 72fbe477565..84092c1d763 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1090,7 +1090,6 @@ def _build_model_kwargs_extra(self) -> dict: def _process_additional_information_updates( self, hidden_states: torch.Tensor, - multimodal_outputs: object, num_scheduled_tokens_np: np.ndarray, scheduler_output: "SchedulerOutput", combined_hidden_states: dict[str, torch.Tensor] | None = None, @@ -1110,16 +1109,10 @@ def _process_additional_information_updates( s, e = start_offset, start_offset + sched_tokens # only consider to store data into update dict. hidden_states_slice = hidden_states[s:e] - update_dict = self.model.postprocess( - hidden_states_slice, multimodal_outputs=multimodal_outputs, **req_infos - ) + update_dict = self.model.postprocess(hidden_states_slice, **req_infos) self._update_intermediate_buffer(req_id, update_dict) except Exception as e: - logger.error( - f"Error merging for requests:{self.input_batch.req_ids} " - f"additional information update: {e}, with the multimodal_outputs " - f"as {multimodal_outputs}" - ) + logger.error(f"Error merging for requests:{self.input_batch.req_ids} additional information update: {e}") import traceback traceback.print_exc() From 283b3d718ccce003233494949af15ee1ded6d5ec Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 29 Mar 2026 05:33:27 +0000 Subject: [PATCH 06/38] more refactoring Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache_utils.py | 13 +++--- vllm_omni/worker/gpu_model_runner.py | 63 ++-------------------------- 2 files changed, 10 insertions(+), 66 deletions(-) diff --git a/vllm_omni/core/prefix_cache_utils.py b/vllm_omni/core/prefix_cache_utils.py index f0632ddb475..c519680b5ab 100644 --- 
a/vllm_omni/core/prefix_cache_utils.py +++ b/vllm_omni/core/prefix_cache_utils.py @@ -7,11 +7,6 @@ logger = init_logger(__name__) -""" -NOTE for tomorrow - working on pulling the equivalent functionality out here. also, we should make sure -that we clear new requests, since I don't think we currently are + make sure that we pass the combined -multimodal states since we are not right now -""" # TODO - Make this configurable and factor in the number # of multimodal tensors. @@ -60,7 +55,7 @@ def __init__(self, num_blocks: int, block_size: int, hidden_size: int, dtype, de for cacheable_key in CACHEABLE_KEYS } - self._new_req_cache_hit_ids: set[str] | None = None + self._new_req_cache_hit_ids: set[str] = set() @property def hidden_states_cache(self) -> torch.Tensor: @@ -72,6 +67,12 @@ def mm_outputs_cache(self) -> dict[str, torch.Tensor]: """Returns the model specific multimodal outputs cache.""" return {k: v for k, v in self.omni_tensor_cache.items() if k != OMNI_HS_CACHE_KEY} + def add_prefix_cached_new_req_id(self, req_id): + self._new_req_cache_hit_ids.add(req_id) + + def reset_prefix_cached_new_req_ids(self): + self._new_req_cache_hit_ids.clear() + def update_omni_tensor_prefix_cache(self, hidden_states, multimodal_outputs, num_tokens_unpadded, input_batch): """Updates the hidden cache state for for hidden states and multimodal outputs.""" assert self.hidden_states_cache is not None diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 84092c1d763..6792b5b148d 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -85,64 +85,6 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): device=self.device, ) - def _get_merged_multimodal_states(self, multimodal_outputs, num_scheduled_tokens): - """Get the merged multimodal states if hidden state prefix caching is enabled.""" - combined_multimodal_outputs = {} - if ( - self._cacheable_mm_keys - and 
self.mm_outputs_cache - and multimodal_outputs - and isinstance(multimodal_outputs, dict) - ): - for mm_key in self._cacheable_mm_keys: - if mm_key in multimodal_outputs: - combined_multimodal_outputs[mm_key] = self._get_merged_hidden_states( - cache=self.mm_outputs_cache[mm_key], - hidden_states=multimodal_outputs[mm_key], - num_scheduled_tokens=num_scheduled_tokens, - ) - else: - logger.error("Cacheable multimodal key %s is not present in multimodal outputs", mm_key) - return combined_multimodal_outputs - - def _get_merged_hidden_states(self, cache: torch.Tensor | None, hidden_states, num_scheduled_tokens): - """When hidden state caching is enabled, takes the input hidden_states, - which only correspond to the scheduled tokens, and returns a mapping - from request IDs to their full hidden states. This is accomplished by - looking up the block IDs & scheduled token counts to split the - hidden_states. - - NOTE: We do not handle hybrid caches at the moment, which is why - we index into the first block table like this. - """ - combined_hidden_states = {} - if cache is not None and self._new_req_cache_hit_ids: - for req_id in self._new_req_cache_hit_ids: - req_idx = self.input_batch.req_id_to_index[req_id] - num_computed = self.input_batch.num_computed_tokens_cpu[req_idx] - block_size = self.cache_config.block_size - # NOTE: vLLM only caches full blocks - num_cached_blocks = num_computed // block_size - # Get the block IDs attached to this cache hit and reindex into - # the flattened cached hidden states (i.e., 1 row per token). 
- block_ids = self.input_batch.block_table[0].block_table.gpu[req_idx, :num_cached_blocks] - cached_hs = cache[block_ids].reshape(-1, cache.shape[-1]) - - # Slice the hidden states corresponding to this request; - # we do this by using the query start - start = self.query_start_loc.gpu[req_idx] - new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] - # TODO: consider putting the actually hidden state cache on CPU - combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0) - - logger.info( - f"[Cache combine] req={req_id} cached_blocks={num_cached_blocks} " - f"cached hidden states shape={cached_hs.shape} " - f"new hidden states shape={new_hs.shape}" - ) - - return combined_hidden_states - @instrument(span_name="Loading (GPU)") def load_model(self, *args, **kwargs) -> None: super().load_model(*args, **kwargs) @@ -308,7 +250,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput"): new/resumed/paused/finished request in the batch. """ # Used for prefix cache - self._new_req_cache_hit_ids = set() + if self.omni_prefix_cache is not None: + self.omni_prefix_cache.reset_prefix_cached_new_req_ids() # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: @@ -375,7 +318,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput"): # caching; mark it so that we can manage the hidden states # later on as needed. 
if self.omni_prefix_cache is not None and new_req_data.num_computed_tokens > 0: - self._new_req_cache_hit_ids.add(req_id) + self.omni_prefix_cache.add_prefix_cached_new_req_id(req_id) sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params From 661e9ef7ef5ed415f8fb7ff1de4e9e5df11ffc31 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 00:15:35 +0000 Subject: [PATCH 07/38] add tests for init and update Signed-off-by: Alex Brooks --- tests/core/test_prefix_cache.py | 138 ++++++++++++++++ ...{prefix_cache_utils.py => prefix_cache.py} | 152 ++++++++++++------ vllm_omni/worker/gpu_ar_model_runner.py | 4 +- vllm_omni/worker/gpu_model_runner.py | 2 +- 4 files changed, 241 insertions(+), 55 deletions(-) create mode 100644 tests/core/test_prefix_cache.py rename vllm_omni/core/{prefix_cache_utils.py => prefix_cache.py} (52%) diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py new file mode 100644 index 00000000000..8c50483941e --- /dev/null +++ b/tests/core/test_prefix_cache.py @@ -0,0 +1,138 @@ +import pytest +import torch + +from vllm_omni.core.prefix_cache import OmniTensorPrefixCache + +""" +Things to test: + +1. Merging for hidden states +2. Merging for multimodal + -> Happy path + -> No hit cache (different stage) + -> Invalid key path (stage has wrong keys) +3. End to end? +4. CPU offload + -> It's all on the CPU + -> It's all on the GPU + -> It's a mix of on the CPU and GPU +5. 
+""" +NUM_BLOCKS = 10 +BLOCK_SIZE = 4 +HIDDEN_SIZE = 2 +DEVICE = torch.device("cuda") +DTYPE = torch.float32 +DEFAULT_SHAPE = torch.Size([NUM_BLOCKS, BLOCK_SIZE, HIDDEN_SIZE]) + + +def build_cache_with_mm_keys(mm_cache_keys) -> OmniTensorPrefixCache: + return OmniTensorPrefixCache( + num_blocks=NUM_BLOCKS, + block_size=BLOCK_SIZE, + hidden_size=HIDDEN_SIZE, + device=DEVICE, + dtype=DTYPE, + mm_cache_keys=mm_cache_keys, + ) + + +### Tests for initialization +def test_initialization_from_list_of_cache_keys(): + """Ensure that hidden states / mm outputs cache are created with the + correct sizes by default. + """ + mm_cache_keys = ["foo", "bar"] + cache = build_cache_with_mm_keys(mm_cache_keys) + assert isinstance(cache.hidden_states_cache, torch.Tensor) + assert cache.hidden_states_cache.shape == DEFAULT_SHAPE + assert set(mm_cache_keys) == set(cache.mm_outputs_cache.keys()) + for val in cache.mm_outputs_cache.values(): + assert isinstance(val, torch.Tensor) + assert val.shape == DEFAULT_SHAPE + + +def test_initialization_from_dict_of_cache_keys(): + """Ensure that keys in the mm outputs cache can have their own feature + sizes and fall back to the hidden states cache size if they map to None. 
+ """ + mm_cache_keys = { + "foo": 100, + "bar": 50, + "baz": None, + } + cache = build_cache_with_mm_keys(mm_cache_keys) + assert isinstance(cache.hidden_states_cache, torch.Tensor) + assert cache.hidden_states_cache.shape == DEFAULT_SHAPE + assert set(mm_cache_keys) == set(cache.mm_outputs_cache.keys()) + + for key, val in cache.mm_outputs_cache.items(): + assert isinstance(val, torch.Tensor) + hs_override = mm_cache_keys[key] if mm_cache_keys[key] is not None else HIDDEN_SIZE + expected_shape = torch.Size([NUM_BLOCKS, BLOCK_SIZE, hs_override]) + assert val.shape == expected_shape + + +### Tests for Update +def test_update_no_multimodal(): + """Test that slot mappings act as row indices hidden states.""" + cache = build_cache_with_mm_keys(mm_cache_keys=None) + + num_tokens_unpadded = 8 + # Map the hidden states to valid & unique slots + slot_offset = 6 # We'll put our states in slots 6, 7, 8, ..., 13 + slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded) + new_hidden_states = torch.rand((num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE, device=DEVICE) + + cache.update_omni_tensor_prefix_cache( + hidden_states=new_hidden_states, + multimodal_outputs=None, + num_tokens_unpadded=num_tokens_unpadded, + slot_mapping=slot_mapping, + ) + + # Ensure that if we reshape our 3D cache back to 2D, we can use the + # indices in our slot mappings to access the hidden states as expected + hs_rows = cache.hidden_states_cache.view(NUM_BLOCKS * BLOCK_SIZE, HIDDEN_SIZE) + for slot_idx, new_states in zip(slot_mapping, new_hidden_states): + slot_states = hs_rows[slot_idx] + assert torch.all(slot_states == new_states) + + +@pytest.mark.parametrize( + "mm_cache_keys", + [ + ("foo", "bar"), # All same feature dim (HIDDEN_SIZE) + {"foo": 100, "bar": 50, "baz": None}, # different feature dims + ], +) +def test_update_with_multimodal_outputs(mm_cache_keys): + """Test that slot mappings are correct for multimodal tensors.""" + cache = 
build_cache_with_mm_keys(mm_cache_keys) + + num_tokens_unpadded = 8 + # Map the hidden states to valid & unique slots + slot_offset = 6 # We'll put our states in slots 6, 7, 8, ..., 13 + slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded) + feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()} + mm_outputs = { + key: torch.rand((num_tokens_unpadded, feature_dims[key]), dtype=DTYPE, device=DEVICE) for key in mm_cache_keys + } + cache.update_omni_tensor_prefix_cache( + hidden_states=None, + multimodal_outputs=mm_outputs, + num_tokens_unpadded=num_tokens_unpadded, + slot_mapping=slot_mapping, + ) + + for mm_key in mm_cache_keys: + assert mm_key in cache.mm_outputs_cache + key_feat_dim = feature_dims[mm_key] + mm_state_rows = cache.mm_outputs_cache[mm_key].view(NUM_BLOCKS * BLOCK_SIZE, key_feat_dim) + + # Similar to hidden states, but for each key in the dict; + # Different tensors may have different feature dims + new_mm_states = mm_outputs[mm_key] + for slot_idx, new_states in zip(slot_mapping, new_mm_states): + slot_states = mm_state_rows[slot_idx] + assert torch.all(slot_states == new_states) diff --git a/vllm_omni/core/prefix_cache_utils.py b/vllm_omni/core/prefix_cache.py similarity index 52% rename from vllm_omni/core/prefix_cache_utils.py rename to vllm_omni/core/prefix_cache.py index c519680b5ab..8377ba2c1bb 100644 --- a/vllm_omni/core/prefix_cache_utils.py +++ b/vllm_omni/core/prefix_cache.py @@ -4,6 +4,7 @@ import torch from vllm.logger import init_logger +from vllm.v1.worker.gpu_input_batch import InputBatch logger = init_logger(__name__) @@ -12,9 +13,7 @@ # of multimodal tensors. NUM_GPU_BLOCKS = 2048 # TODO Make this generic, these are specific for qwen3 omni. 
-OMNI_HS_CACHE_KEY = "_OMNI_HIDDEN_STATES_KEY" MM_CACHE_KEYS = ["0", "24"] -CACHEABLE_KEYS = MM_CACHE_KEYS + [OMNI_HS_CACHE_KEY] class OmniTensorPrefixCache: @@ -31,71 +30,103 @@ class OmniTensorPrefixCache: and translate it to rows in the 3D tensor of shape: (num_blocks, block_size, feature_size) - The above is generally a large tensor, especially since - multiple multimodal output tensors may be cached. To reduce - GPU pressure, we implement this as a two-tier cache, where the - keeping the large tensor offloaded on CPU, and transferring hot - blocks to a GPU buffer of size - (num_gpu_blocks, block_size, feature_size) - while maintaining a layer of indirection mapping the indices of - the CPU slot mappings to the GPU slot mappings. + Currently all tensors are stored on device. """ - def __init__(self, num_blocks: int, block_size: int, hidden_size: int, dtype, device): - self.new_req_cache_hit_ids: set[str] | None = None - + def __init__( + self, + num_blocks: int, + block_size: int, + hidden_size: int, + dtype: torch.dtype, + device: torch.device, + mm_cache_keys: list[str] | dict[str, int | None] | None = MM_CACHE_KEYS, + ): + self.num_blocks = num_blocks self.block_size = block_size - # TODO: Support CPU offload and combine these. 
- self.omni_tensor_cache = { - cacheable_key: torch.zeros( - (num_blocks, block_size, hidden_size), - dtype=dtype, - device=device, - ) - for cacheable_key in CACHEABLE_KEYS - } + self.default_hidden_size = hidden_size + self.dtype = dtype + self.device = device + # TODO: Support CPU offload + self._initialize_omni_tensor_caches(mm_cache_keys) self._new_req_cache_hit_ids: set[str] = set() - @property - def hidden_states_cache(self) -> torch.Tensor: - """Returns the hidden states cache.""" - return self.omni_tensor_cache[OMNI_HS_CACHE_KEY] + def _initialize_omni_tensor_caches(self, mm_cache_keys: list[str] | dict[str, int | None] | None): + """Initialize the Omni Tensor cache tensors; this handles both the + hidden states cache and the multimodal outputs cache. + + The hidden_states cache is a tensor with shape: + (num_blocks, block_size, self.default_hidden_size) - @property - def mm_outputs_cache(self) -> dict[str, torch.Tensor]: - """Returns the model specific multimodal outputs cache.""" - return {k: v for k, v in self.omni_tensor_cache.items() if k != OMNI_HS_CACHE_KEY} + While the mm_outputs_cache is dict mapping keys to tensors of shape: + (num_blocks, block_size, feature_size) + + By default, if mm_cache_keys is a list, feature_size is set to the + default hidden size for all mm_output_keys. We also accept a dict + mapping to feature sizes on a per key basis, falling back to + self.default_hidden_size. for any keys that are None. 
+ """ + self.hidden_states_cache = self._get_cache_tensor() + + self.mm_outputs_cache = {} + if mm_cache_keys: + if isinstance(mm_cache_keys, dict): + for cache_key, hidden_size in mm_cache_keys.items(): + self.mm_outputs_cache[cache_key] = self._get_cache_tensor( + hidden_size=hidden_size, + ) + else: + for cache_key in mm_cache_keys: + self.mm_outputs_cache[cache_key] = self._get_cache_tensor() + + def _get_cache_tensor(self, hidden_size: int | None = None) -> torch.Tensor: + """Allocate a cache tensor for a specific key.""" + actual_hidden_size = hidden_size if hidden_size is not None else self.default_hidden_size + return torch.zeros( + (self.num_blocks, self.block_size, actual_hidden_size), + dtype=self.dtype, + device=self.device, + ) - def add_prefix_cached_new_req_id(self, req_id): + def add_prefix_cached_new_req_id(self, req_id: str): + """Adds a new request ID to the set of prefix cache hits on the batch.""" self._new_req_cache_hit_ids.add(req_id) def reset_prefix_cached_new_req_ids(self): + """Clears the cache hit IDs to prepare for a new engine step.""" self._new_req_cache_hit_ids.clear() - def update_omni_tensor_prefix_cache(self, hidden_states, multimodal_outputs, num_tokens_unpadded, input_batch): - """Updates the hidden cache state for for hidden states and multimodal outputs.""" - assert self.hidden_states_cache is not None - slot_mapping = input_batch.block_table[0].slot_mapping.gpu[:num_tokens_unpadded] - # View the cache as 2D so that we can treat our slots as row indices - flat_cache = self.hidden_states_cache.view(-1, self.hidden_states_cache.shape[-1]) - flat_cache[slot_mapping] = hidden_states[:num_tokens_unpadded] - logger.info(f"[HS Cache WRITE] tokens={num_tokens_unpadded}") - - # Do the same for the cached multimodal outputs for this stage; - # for now we assume that all of the multimodal outputs cached - # are exactly the same size as the hidden states. - # TODO (Alex) make this more flexible. 
- if self.mm_outputs_cache is not None: + def update_omni_tensor_prefix_cache( + self, + hidden_states: torch.Tensor | None, + multimodal_outputs: dict[str, torch.Tensor] | None, + num_tokens_unpadded: int, + slot_mapping: torch.Tensor, + ): + """Updates the hidden cache state for the provided hidden states and multimodal outputs.""" + unpadded_slot_mapping = slot_mapping[:num_tokens_unpadded] + if hidden_states is not None: + # View the cache as 2D so that we can treat our slots as row indices + flat_cache = self.hidden_states_cache.view(-1, self.hidden_states_cache.shape[-1]) + flat_cache[unpadded_slot_mapping] = hidden_states[:num_tokens_unpadded] + + # Do the same for the stage's cached multimodal outputs + if multimodal_outputs is not None: for mm_out_key, mm_cache in self.mm_outputs_cache.items(): assert mm_out_key in multimodal_outputs mm_state = multimodal_outputs[mm_out_key] flat_cache = mm_cache.view(-1, mm_cache.shape[-1]) - flat_cache[slot_mapping] = mm_state[:num_tokens_unpadded] + flat_cache[unpadded_slot_mapping] = mm_state[:num_tokens_unpadded] logger.info(f"[multimodal output Cache WRITE] tokens={num_tokens_unpadded}") def _get_combined_states( - self, query_start_loc, input_batch, hidden_states, multimodal_outputs, num_scheduled_tokens + self, + query_start_loc: torch.Tensor, + input_batch: InputBatch, + hidden_states: torch.Tensor, + multimodal_outputs: dict, + num_scheduled_tokens: dict[str, int], ): combined_mm_states = self._get_merged_multimodal_states( query_start_loc, input_batch, multimodal_outputs, num_scheduled_tokens @@ -105,9 +136,16 @@ def _get_combined_states( ) return combined_hidden_states, combined_mm_states - def _get_merged_multimodal_states(self, query_start_loc, input_batch, multimodal_outputs, num_scheduled_tokens): + def _get_merged_multimodal_states( + self, + query_start_loc: torch.Tensor, + input_batch: InputBatch, + multimodal_outputs: dict, + num_scheduled_tokens: dict[str, int], + ): """Get the merged multimodal states 
if hidden state prefix caching is enabled.""" combined_multimodal_outputs = {} + # TODO Ensure non cached keys are properly handled. for mm_key in MM_CACHE_KEYS: if mm_key in multimodal_outputs: combined_multimodal_outputs[mm_key] = self._get_merged_tensors( @@ -121,7 +159,13 @@ def _get_merged_multimodal_states(self, query_start_loc, input_batch, multimodal logger.error("Cacheable multimodal key %s is not present in multimodal outputs", mm_key) return combined_multimodal_outputs - def _get_merged_hidden_states(self, query_start_loc, input_batch, hidden_states, num_scheduled_tokens): + def _get_merged_hidden_states( + self, + query_start_loc: torch.Tensor, + input_batch: InputBatch, + hidden_states: torch.Tensor, + num_scheduled_tokens: dict[str, int], + ): return self._get_merged_tensors( query_start_loc=query_start_loc, input_batch=input_batch, @@ -131,7 +175,12 @@ def _get_merged_hidden_states(self, query_start_loc, input_batch, hidden_states, ) def _get_merged_tensors( - self, query_start_loc, input_batch, cache: torch.Tensor, hidden_states: torch.Tensor, num_scheduled_tokens + self, + query_start_loc: torch.Tensor, + input_batch: InputBatch, + cache: torch.Tensor, + hidden_states: torch.Tensor, + num_scheduled_tokens: dict[str, int], ) -> dict[str, torch.Tensor]: """When hidden state caching is enabled, takes the input hidden_states, which only correspond to the scheduled tokens, and returns a mapping @@ -156,9 +205,8 @@ def _get_merged_tensors( # Slice the hidden states corresponding to this request; # we do this by using the query start - start = query_start_loc.gpu[req_idx] + start = query_start_loc[req_idx] new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] - # TODO: consider putting the actually hidden state cache on CPU combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0) logger.info( diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 4dc958afed7..31f1143e7aa 100644 --- 
a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -506,7 +506,7 @@ def execute_model( hidden_states=hidden_states, multimodal_outputs=multimodal_outputs, num_tokens_unpadded=num_tokens_unpadded, - input_batch=self.input_batch, + slot_mapping=self.input_batch.block_table[0].slot_mapping.gpu, ) if not self.broadcast_pp_output: @@ -788,7 +788,7 @@ def propose_draft_token_ids(sampled_token_ids): combined_hidden_states, combined_multimodal_outputs = None, None else: combined_hidden_states, combined_multimodal_outputs = self.omni_prefix_cache._get_combined_states( - query_start_loc=self.query_start_loc, + query_start_loc=self.query_start_loc.gpu, input_batch=self.input_batch, hidden_states=hidden_states, multimodal_outputs=multimodal_outputs, diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 6792b5b148d..590cc3288af 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -20,7 +20,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner, IntermediateTensors, PerLayerAttnMetadata from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices -from vllm_omni.core.prefix_cache_utils import OmniTensorPrefixCache +from vllm_omni.core.prefix_cache import OmniTensorPrefixCache from vllm_omni.engine.serialization import deserialize_additional_information from vllm_omni.model_executor.layers.rotary_embedding.mrope import OmniMRotaryEmbedding as MRotaryEmbedding from vllm_omni.model_executor.models.output_templates import OmniOutput From 01dfa6d21584366af69022d0f2a15018daf3fc0a Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 01:24:04 +0000 Subject: [PATCH 08/38] make mm cache keys more generic Signed-off-by: Alex Brooks --- tests/core/test_prefix_cache.py | 39 +++++------ vllm_omni/core/prefix_cache.py | 69 +++++++++++++++---- .../models/qwen3_omni/qwen3_omni.py | 6 ++ vllm_omni/worker/gpu_model_runner.py | 1 + 4 files changed, 
77 insertions(+), 38 deletions(-) diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py index 8c50483941e..da10c1d6fe6 100644 --- a/tests/core/test_prefix_cache.py +++ b/tests/core/test_prefix_cache.py @@ -1,23 +1,10 @@ +from unittest.mock import patch + import pytest import torch from vllm_omni.core.prefix_cache import OmniTensorPrefixCache -""" -Things to test: - -1. Merging for hidden states -2. Merging for multimodal - -> Happy path - -> No hit cache (different stage) - -> Invalid key path (stage has wrong keys) -3. End to end? -4. CPU offload - -> It's all on the CPU - -> It's all on the GPU - -> It's a mix of on the CPU and GPU -5. -""" NUM_BLOCKS = 10 BLOCK_SIZE = 4 HIDDEN_SIZE = 2 @@ -27,14 +14,20 @@ def build_cache_with_mm_keys(mm_cache_keys) -> OmniTensorPrefixCache: - return OmniTensorPrefixCache( - num_blocks=NUM_BLOCKS, - block_size=BLOCK_SIZE, - hidden_size=HIDDEN_SIZE, - device=DEVICE, - dtype=DTYPE, - mm_cache_keys=mm_cache_keys, - ) + with patch( + "vllm_omni.core.prefix_cache.OmniTensorPrefixCache._resolve_mm_cache_keys", + return_value=mm_cache_keys, + ): + # Model config is only used for resolving the mm_cache_keys, + # so the value passed here doesn't matter since it's patched. + return OmniTensorPrefixCache( + num_blocks=NUM_BLOCKS, + block_size=BLOCK_SIZE, + hidden_size=HIDDEN_SIZE, + device=DEVICE, + dtype=DTYPE, + model_config=None, + ) ### Tests for initialization diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 8377ba2c1bb..3f6c78315ce 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -2,10 +2,15 @@ Utilities for Prefix Caching in Omni models. 
""" +from typing import TypeAlias + import torch from vllm.logger import init_logger +from vllm.model_executor.model_loader.utils import get_model_architecture from vllm.v1.worker.gpu_input_batch import InputBatch +from vllm_omni.config.model import OmniModelConfig + logger = init_logger(__name__) @@ -14,6 +19,11 @@ NUM_GPU_BLOCKS = 2048 # TODO Make this generic, these are specific for qwen3 omni. MM_CACHE_KEYS = ["0", "24"] +# NEXT ^ let's make this generic, we need to pull it off the model +# by getting the class before we initialize this class + +StageMMCacheKeys: TypeAlias = list[str] | dict[str, int | None] +ModelMMCacheKeys: TypeAlias = dict[str, StageMMCacheKeys] | None class OmniTensorPrefixCache: @@ -40,7 +50,7 @@ def __init__( hidden_size: int, dtype: torch.dtype, device: torch.device, - mm_cache_keys: list[str] | dict[str, int | None] | None = MM_CACHE_KEYS, + model_config: OmniModelConfig, ): self.num_blocks = num_blocks self.block_size = block_size @@ -49,10 +59,31 @@ def __init__( self.device = device # TODO: Support CPU offload - self._initialize_omni_tensor_caches(mm_cache_keys) + self.mm_cache_keys = self._resolve_mm_cache_keys(model_config) + self._initialize_omni_tensor_caches(self.mm_cache_keys) self._new_req_cache_hit_ids: set[str] = set() - def _initialize_omni_tensor_caches(self, mm_cache_keys: list[str] | dict[str, int | None] | None): + def _resolve_mm_cache_keys(self, model_config: OmniModelConfig) -> StageMMCacheKeys | None: + """Determined the configuration for multimodal caching for the current model + architecture and stage.""" + model_stage = model_config.model_stage + arch, arch_str = get_model_architecture(model_config) + if hasattr(arch, "_model_mm_cache_keys"): + model_mm_cache_keys = arch._model_mm_cache_keys + if model_stage in model_mm_cache_keys: + stage_mm_cache_keys = model_mm_cache_keys[model_stage] + logger.info(f"Resolved mm_cache_keys for stage {model_stage} - {stage_mm_cache_keys}") + return stage_mm_cache_keys + + # 
TODO: Move have_multimodal_outputs to class property and set this log to + # error level & to only go off if we actually have mm outputs. + logger.warning( + f"Model architecture {arch_str} does not have defined _mm_cache_keys and will" + " therefore not able leverage prefix caching for multimodal outputs. " + " As such, prefix caching may not be supported." + ) + + def _initialize_omni_tensor_caches(self, mm_cache_keys: StageMMCacheKeys | None): """Initialize the Omni Tensor cache tensors; this handles both the hidden states cache and the multimodal outputs cache. @@ -110,6 +141,7 @@ def update_omni_tensor_prefix_cache( # View the cache as 2D so that we can treat our slots as row indices flat_cache = self.hidden_states_cache.view(-1, self.hidden_states_cache.shape[-1]) flat_cache[unpadded_slot_mapping] = hidden_states[:num_tokens_unpadded] + logger.debug("Writing to hidden states for %s tokens", num_tokens_unpadded) # Do the same for the stage's cached multimodal outputs if multimodal_outputs is not None: @@ -118,7 +150,7 @@ def update_omni_tensor_prefix_cache( mm_state = multimodal_outputs[mm_out_key] flat_cache = mm_cache.view(-1, mm_cache.shape[-1]) flat_cache[unpadded_slot_mapping] = mm_state[:num_tokens_unpadded] - logger.info(f"[multimodal output Cache WRITE] tokens={num_tokens_unpadded}") + logger.debug("Writing to mm output cache for %s tokens", num_tokens_unpadded) def _get_combined_states( self, @@ -146,17 +178,24 @@ def _get_merged_multimodal_states( """Get the merged multimodal states if hidden state prefix caching is enabled.""" combined_multimodal_outputs = {} # TODO Ensure non cached keys are properly handled. 
- for mm_key in MM_CACHE_KEYS: - if mm_key in multimodal_outputs: - combined_multimodal_outputs[mm_key] = self._get_merged_tensors( - query_start_loc=query_start_loc, - input_batch=input_batch, - cache=self.mm_outputs_cache[mm_key], - hidden_states=multimodal_outputs[mm_key], - num_scheduled_tokens=num_scheduled_tokens, - ) - else: - logger.error("Cacheable multimodal key %s is not present in multimodal outputs", mm_key) + if self.mm_cache_keys is not None: + for mm_key in self.mm_cache_keys: + if mm_key in multimodal_outputs: + combined_multimodal_outputs[mm_key] = self._get_merged_tensors( + query_start_loc=query_start_loc, + input_batch=input_batch, + cache=self.mm_outputs_cache[mm_key], + hidden_states=multimodal_outputs[mm_key], + num_scheduled_tokens=num_scheduled_tokens, + ) + else: + logger.error("Cacheable multimodal key %s is not present in multimodal outputs", mm_key) + elif multimodal_outputs: + logger.warning( + " A model stage produced multimodal outputs, but has no defined mm_cache_keys; " + " this probably means that prefix caching is not fully supported for all stages " + "in this model" + ) return combined_multimodal_outputs def _get_merged_hidden_states( diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 7df69479734..5cd6781579b 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -36,6 +36,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler +from vllm_omni.core.prefix_cache import ModelMMCacheKeys from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin from vllm_omni.model_executor.models.output_templates import OmniOutput from vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker import ( @@ -98,6 +99,11 @@ class Qwen3OmniMoeForConditionalGeneration( """ realtime_max_tokens = 64 + # Currently, we 
only support prefix caching for the thinker stage + _model_mm_cache_keys: ModelMMCacheKeys = { + # keys 0 & 24 for the thinker have the same dimensionality as hidden states + "thinker": ["0", "24"] + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 590cc3288af..ab057a5f053 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -83,6 +83,7 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): hidden_size=self.model_config.get_hidden_size(), dtype=self.dtype, device=self.device, + model_config=self.model_config, ) @instrument(span_name="Loading (GPU)") From 53d59d9345aa496a449388cefb15346eebb9a5bd Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 01:24:38 +0000 Subject: [PATCH 09/38] remove unused keys Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 3f6c78315ce..92dac0338fa 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -17,10 +17,6 @@ # TODO - Make this configurable and factor in the number # of multimodal tensors. NUM_GPU_BLOCKS = 2048 -# TODO Make this generic, these are specific for qwen3 omni. 
-MM_CACHE_KEYS = ["0", "24"] -# NEXT ^ let's make this generic, we need to pull it off the model -# by getting the class before we initialize this class StageMMCacheKeys: TypeAlias = list[str] | dict[str, int | None] ModelMMCacheKeys: TypeAlias = dict[str, StageMMCacheKeys] | None From 60696aa48147be9bf69029917480f0d131d41bb5 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 01:26:29 +0000 Subject: [PATCH 10/38] update comment Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 92dac0338fa..7c27cce411a 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -65,14 +65,15 @@ def _resolve_mm_cache_keys(self, model_config: OmniModelConfig) -> StageMMCacheK model_stage = model_config.model_stage arch, arch_str = get_model_architecture(model_config) if hasattr(arch, "_model_mm_cache_keys"): - model_mm_cache_keys = arch._model_mm_cache_keys + model_mm_cache_keys = getattr(arch, "_model_mm_cache_keys") if model_stage in model_mm_cache_keys: stage_mm_cache_keys = model_mm_cache_keys[model_stage] logger.info(f"Resolved mm_cache_keys for stage {model_stage} - {stage_mm_cache_keys}") return stage_mm_cache_keys # TODO: Move have_multimodal_outputs to class property and set this log to - # error level & to only go off if we actually have mm outputs. + # only go off for models that support have_multimodal_outputs, since the + # hidden states caching is generic logger.warning( f"Model architecture {arch_str} does not have defined _mm_cache_keys and will" " therefore not able leverage prefix caching for multimodal outputs. 
" From d8072ac16ea1fa22d4eea9bd43cbb9244c001ebb Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 03:09:29 +0000 Subject: [PATCH 11/38] fix passthrough Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 69 ++++++++++--------------- vllm_omni/worker/gpu_ar_model_runner.py | 7 ++- 2 files changed, 32 insertions(+), 44 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 7c27cce411a..80cb170b4e5 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -149,23 +149,7 @@ def update_omni_tensor_prefix_cache( flat_cache[unpadded_slot_mapping] = mm_state[:num_tokens_unpadded] logger.debug("Writing to mm output cache for %s tokens", num_tokens_unpadded) - def _get_combined_states( - self, - query_start_loc: torch.Tensor, - input_batch: InputBatch, - hidden_states: torch.Tensor, - multimodal_outputs: dict, - num_scheduled_tokens: dict[str, int], - ): - combined_mm_states = self._get_merged_multimodal_states( - query_start_loc, input_batch, multimodal_outputs, num_scheduled_tokens - ) - combined_hidden_states = self._get_merged_hidden_states( - query_start_loc, input_batch, hidden_states, num_scheduled_tokens - ) - return combined_hidden_states, combined_mm_states - - def _get_merged_multimodal_states( + def get_merged_multimodal_states( self, query_start_loc: torch.Tensor, input_batch: InputBatch, @@ -195,19 +179,12 @@ def _get_merged_multimodal_states( ) return combined_multimodal_outputs - def _get_merged_hidden_states( - self, - query_start_loc: torch.Tensor, - input_batch: InputBatch, - hidden_states: torch.Tensor, - num_scheduled_tokens: dict[str, int], - ): + def get_merged_hidden_states(self, *args, **kwargs) -> dict[str, torch.Tensor]: + """Get the merged hidden states.""" return self._get_merged_tensors( - query_start_loc=query_start_loc, - input_batch=input_batch, + *args, + **kwargs, cache=self.hidden_states_cache, - hidden_states=hidden_states, - 
num_scheduled_tokens=num_scheduled_tokens, ) def _get_merged_tensors( @@ -228,15 +205,11 @@ def _get_merged_tensors( we index into the first block table like this. """ combined_hidden_states = {} - if cache is not None and self._new_req_cache_hit_ids: - for req_id in self._new_req_cache_hit_ids: - req_idx = input_batch.req_id_to_index[req_id] - num_computed = input_batch.num_computed_tokens_cpu[req_idx] - # NOTE: vLLM only caches full blocks - num_cached_blocks = num_computed // self.block_size - # Get the block IDs attached to this cache hit and reindex into - # the flattened cached hidden states (i.e., 1 row per token). - block_ids = input_batch.block_table[0].block_table.gpu[req_idx, :num_cached_blocks] + for req_id in input_batch.req_ids: + req_idx = input_batch.req_id_to_index[req_id] + + if req_id in self._new_req_cache_hit_ids: + block_ids = self._get_cached_block_ids(req_idx, input_batch) cached_hs = cache[block_ids].reshape(-1, cache.shape[-1]) # Slice the hidden states corresponding to this request; @@ -244,11 +217,21 @@ def _get_merged_tensors( start = query_start_loc[req_idx] new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0) - - logger.info( - f"[Cache combine] req={req_id} cached_blocks={num_cached_blocks} " - f"cached hidden states shape={cached_hs.shape} " - f"new hidden states shape={new_hs.shape}" - ) + else: + # cache miss for this request, pass through normally + start = query_start_loc[req_idx] + new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] + combined_hidden_states[req_id] = new_hs return combined_hidden_states + + def _get_cached_block_ids(self, req_idx: int, input_batch: InputBatch) -> torch.Tensor: + """Given an input batch and request index in the batch (not ID), get the + block IDs corresponding to the cache hit. 
+ """ + num_computed = input_batch.num_computed_tokens_cpu[req_idx] + # NOTE: vLLM only caches full blocks + num_cached_blocks = num_computed // self.block_size + # Get the block IDs attached to this cache hit and reindex into + # the flattened cached hidden states (i.e., 1 row per token). + return input_batch.block_table[0].block_table.gpu[req_idx, :num_cached_blocks] diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 31f1143e7aa..c4583627302 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -787,10 +787,15 @@ def propose_draft_token_ids(sampled_token_ids): if self.omni_prefix_cache is None: combined_hidden_states, combined_multimodal_outputs = None, None else: - combined_hidden_states, combined_multimodal_outputs = self.omni_prefix_cache._get_combined_states( + combined_hidden_states = self.omni_prefix_cache.get_merged_hidden_states( query_start_loc=self.query_start_loc.gpu, input_batch=self.input_batch, hidden_states=hidden_states, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + ) + combined_multimodal_outputs = self.omni_prefix_cache.get_merged_multimodal_states( + query_start_loc=self.query_start_loc.gpu, + input_batch=self.input_batch, multimodal_outputs=multimodal_outputs, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, ) From 3a18ffafc707808fb8b04fb9fd8fc5d7363a34e2 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 03:09:57 +0000 Subject: [PATCH 12/38] add merge test for hidden states Signed-off-by: Alex Brooks --- tests/core/test_prefix_cache.py | 89 ++++++++++++++++++++++++++++++++- 1 file changed, 87 insertions(+), 2 deletions(-) diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py index da10c1d6fe6..e8b24608c00 100644 --- a/tests/core/test_prefix_cache.py +++ b/tests/core/test_prefix_cache.py @@ -13,6 +13,13 @@ DEFAULT_SHAPE = torch.Size([NUM_BLOCKS, BLOCK_SIZE, HIDDEN_SIZE]) +class 
MockInputBatch: + def __init__(self, num_computed_tokens_cpu): + self.req_ids = ["req1", "req2"] + self.req_id_to_index = {req_id: i for i, req_id in enumerate(self.req_ids)} + self.num_computed_tokens_cpu = num_computed_tokens_cpu + + def build_cache_with_mm_keys(mm_cache_keys) -> OmniTensorPrefixCache: with patch( "vllm_omni.core.prefix_cache.OmniTensorPrefixCache._resolve_mm_cache_keys", @@ -73,7 +80,7 @@ def test_update_no_multimodal(): num_tokens_unpadded = 8 # Map the hidden states to valid & unique slots - slot_offset = 6 # We'll put our states in slots 6, 7, 8, ..., 13 + slot_offset = 8 slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded) new_hidden_states = torch.rand((num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE, device=DEVICE) @@ -105,7 +112,7 @@ def test_update_with_multimodal_outputs(mm_cache_keys): num_tokens_unpadded = 8 # Map the hidden states to valid & unique slots - slot_offset = 6 # We'll put our states in slots 6, 7, 8, ..., 13 + slot_offset = 8 slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded) feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()} mm_outputs = { @@ -129,3 +136,81 @@ def test_update_with_multimodal_outputs(mm_cache_keys): for slot_idx, new_states in zip(slot_mapping, new_mm_states): slot_states = mm_state_rows[slot_idx] assert torch.all(slot_states == new_states) + + +### Tests for Merging +def fake_get_cached_block_ids(self, req_idx, *args, **kwargs): + """Fake block table lookup. + + Assumption: + req_idx 0 is a cache hit with slots 8, 9, ..., 15 + req_idx 1 is a cache miss + """ + assert req_idx < 2 + if req_idx == 0: + # With the slot offset we provided (8), the corresponding + # blocks IDs are 2 & 3 because the block size is 4. 
+ return torch.tensor([2, 3], dtype=torch.long) + return torch.tensor([], dtype=torch.long) + + +def test_get_merged_hidden_states(): + """Ensure that hidden states are merged correctly.""" + cache = build_cache_with_mm_keys(mm_cache_keys=None) + + orig_num_tokens_unpadded = 8 + slot_offset = 8 # We'll put our states in slots 8, 9, 10, ..., 15 + # Map the hidden states to valid & unique slots + orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded) + orig_hidden_states = torch.rand((orig_num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE, device=DEVICE) + + cache.update_omni_tensor_prefix_cache( + hidden_states=orig_hidden_states, + multimodal_outputs=None, + num_tokens_unpadded=orig_num_tokens_unpadded, + slot_mapping=orig_slot_mapping, + ) + + # Say that we have two requests, but only one of them is a cache hit + num_new_toks_req1 = 3 + num_new_toks_req2 = 2 + cache.add_prefix_cached_new_req_id("req1") + + num_scheduled_tokens = { + "req1": num_new_toks_req1, + "req2": num_new_toks_req2, + } + new_hidden_states = torch.rand( + (num_new_toks_req1 + num_new_toks_req2, HIDDEN_SIZE), + dtype=DTYPE, + device=DEVICE, + ) + req1_new_states = new_hidden_states[:num_new_toks_req1] + req2_new_states = new_hidden_states[-num_new_toks_req2:] + + input_batch = MockInputBatch(num_computed_tokens_cpu=torch.Tensor([orig_num_tokens_unpadded, 0])) + + with patch( + "vllm_omni.core.prefix_cache.OmniTensorPrefixCache._get_cached_block_ids", + new=fake_get_cached_block_ids, + ): + merged_states = cache.get_merged_hidden_states( + query_start_loc=[0, num_new_toks_req1], + input_batch=input_batch, + hidden_states=new_hidden_states, + num_scheduled_tokens=num_scheduled_tokens, + ) + + assert "req1" in merged_states and "req2" in merged_states + req1_merged_states = merged_states["req1"] + req2_merged_states = merged_states["req2"] + + # First, check the partial cache hit case + assert req1_merged_states.shape == torch.Size([orig_num_tokens_unpadded + 
num_new_toks_req1, HIDDEN_SIZE]) + # Ensure that the req1 merged states are the cached states + the new req1 states + assert torch.all(req1_merged_states[:orig_num_tokens_unpadded] == orig_hidden_states) + assert torch.all(req1_merged_states[-num_new_toks_req1:] == req1_new_states) + + # Next, ensure that the cache miss case only has the new states + assert req2_merged_states.shape == torch.Size([num_new_toks_req2, HIDDEN_SIZE]) + assert torch.all(req2_merged_states == req2_new_states) From 71f16a4e55d7ba2f2f25720b932457071a2e6904 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 03:59:05 +0000 Subject: [PATCH 13/38] add test for multimodal state merging Signed-off-by: Alex Brooks --- tests/core/test_prefix_cache.py | 102 +++++++++++++++++++++++++++++--- 1 file changed, 95 insertions(+), 7 deletions(-) diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py index e8b24608c00..21c860690ba 100644 --- a/tests/core/test_prefix_cache.py +++ b/tests/core/test_prefix_cache.py @@ -79,7 +79,6 @@ def test_update_no_multimodal(): cache = build_cache_with_mm_keys(mm_cache_keys=None) num_tokens_unpadded = 8 - # Map the hidden states to valid & unique slots slot_offset = 8 slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded) new_hidden_states = torch.rand((num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE, device=DEVICE) @@ -111,7 +110,6 @@ def test_update_with_multimodal_outputs(mm_cache_keys): cache = build_cache_with_mm_keys(mm_cache_keys) num_tokens_unpadded = 8 - # Map the hidden states to valid & unique slots slot_offset = 8 slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded) feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()} @@ -132,10 +130,10 @@ def test_update_with_multimodal_outputs(mm_cache_keys): # Similar to hidden states, but for each key in the dict; # Different tensors may have different feature dims - new_mm_states = mm_outputs[mm_key] - for 
slot_idx, new_states in zip(slot_mapping, new_mm_states): + new_mm_outputs = mm_outputs[mm_key] + for slot_idx, new_output in zip(slot_mapping, new_mm_outputs): slot_states = mm_state_rows[slot_idx] - assert torch.all(slot_states == new_states) + assert torch.all(slot_states == new_output) ### Tests for Merging @@ -160,7 +158,6 @@ def test_get_merged_hidden_states(): orig_num_tokens_unpadded = 8 slot_offset = 8 # We'll put our states in slots 8, 9, 10, ..., 15 - # Map the hidden states to valid & unique slots orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded) orig_hidden_states = torch.rand((orig_num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE, device=DEVICE) @@ -205,7 +202,7 @@ def test_get_merged_hidden_states(): req1_merged_states = merged_states["req1"] req2_merged_states = merged_states["req2"] - # First, check the partial cache hit case + # First, check the cache hit case assert req1_merged_states.shape == torch.Size([orig_num_tokens_unpadded + num_new_toks_req1, HIDDEN_SIZE]) # Ensure that the req1 merged states are the cached states + the new req1 states assert torch.all(req1_merged_states[:orig_num_tokens_unpadded] == orig_hidden_states) @@ -214,3 +211,94 @@ def test_get_merged_hidden_states(): # Next, ensure that the cache miss case only has the new states assert req2_merged_states.shape == torch.Size([num_new_toks_req2, HIDDEN_SIZE]) assert torch.all(req2_merged_states == req2_new_states) + + +@pytest.mark.parametrize( + "mm_cache_keys", + [ + ("foo", "bar"), # All same feature dim (HIDDEN_SIZE) + {"foo": 100, "bar": 50, "baz": None}, # different feature dims + ], +) +def test_get_merged_multimodal_outputs(mm_cache_keys): + cache = build_cache_with_mm_keys(mm_cache_keys) + + orig_num_tokens_unpadded = 8 + slot_offset = 8 # We'll put our states in slots 8, 9, 10, ..., 15 + orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded) + feature_dims = {key: val.shape[-1] for key, val in 
cache.mm_outputs_cache.items()} + orig_mm_outputs = { + key: torch.rand((orig_num_tokens_unpadded, feature_dims[key]), dtype=DTYPE, device=DEVICE) + for key in mm_cache_keys + } + + cache.update_omni_tensor_prefix_cache( + hidden_states=None, + multimodal_outputs=orig_mm_outputs, + num_tokens_unpadded=orig_num_tokens_unpadded, + slot_mapping=orig_slot_mapping, + ) + + # Similar to hs test- say that we have two requests, but only one of them is a cache hit + num_new_toks_req1 = 3 + num_new_toks_req2 = 2 + cache.add_prefix_cached_new_req_id("req1") + + num_scheduled_tokens = { + "req1": num_new_toks_req1, + "req2": num_new_toks_req2, + } + + new_mm_outputs = {} + for mm_key in mm_cache_keys: + new_mm_outputs[mm_key] = torch.rand( + (num_new_toks_req1 + num_new_toks_req2, feature_dims[mm_key]), + dtype=DTYPE, + device=DEVICE, + ) + # We also want to make sure passthrough data (outside of our keys) isn't dropped + new_mm_outputs["passthrough_data"] = "Something else" + + input_batch = MockInputBatch(num_computed_tokens_cpu=torch.Tensor([orig_num_tokens_unpadded, 0])) + + with patch( + "vllm_omni.core.prefix_cache.OmniTensorPrefixCache._get_cached_block_ids", + new=fake_get_cached_block_ids, + ): + merged_mm_outputs = cache.get_merged_multimodal_states( + query_start_loc=[0, num_new_toks_req1], + input_batch=input_batch, + multimodal_outputs=new_mm_outputs, + num_scheduled_tokens=num_scheduled_tokens, + ) + + # Ensure the passthrough data wasn't dropped + assert "passthrough_data" in merged_mm_outputs + + for mm_key, mm_output in merged_mm_outputs.items(): + # Ensure passthrough data is just forwarded normally and not duplicated + if mm_key == "passthrough_data": + assert new_mm_outputs[mm_key] == mm_output + assert new_mm_outputs[mm_key] == mm_output + else: + assert mm_key in mm_cache_keys + assert isinstance(mm_output, dict) + assert "req1" in mm_output and "req2" in mm_output + curr_feat_dim = feature_dims[mm_key] + # Ensure that req1 (cache hit) merged the mm data 
+ req1_merged_mm_outputs = mm_output["req1"] + req1_new_mm_outputs = new_mm_outputs[mm_key][:num_new_toks_req1] + + assert req1_merged_mm_outputs.shape == torch.Size( + [orig_num_tokens_unpadded + num_new_toks_req1, curr_feat_dim] + ) + # Ensure that the req1 merged mm data are the cached data + the new data + assert torch.all(req1_merged_mm_outputs[:orig_num_tokens_unpadded] == orig_mm_outputs[mm_key]) + assert torch.all(req1_merged_mm_outputs[-num_new_toks_req1:] == req1_new_mm_outputs) + + # Ensure that req2 (cache miss) only has the new mm data + req2_merged_mm_outputs = mm_output["req2"] + req2_new_mm_outputs = new_mm_outputs[mm_key][-num_new_toks_req2:] + + assert req2_merged_mm_outputs.shape == torch.Size([num_new_toks_req2, curr_feat_dim]) + assert torch.all(req2_merged_mm_outputs == req2_new_mm_outputs) From 8f1dd8f3297eef10fe41a94d44dc861625ba2f15 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 03:59:39 +0000 Subject: [PATCH 14/38] fix passthrough data on mm outputs Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 80cb170b4e5..c5a9fa2fa8b 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -158,25 +158,25 @@ def get_merged_multimodal_states( ): """Get the merged multimodal states if hidden state prefix caching is enabled.""" combined_multimodal_outputs = {} - # TODO Ensure non cached keys are properly handled. 
- if self.mm_cache_keys is not None: - for mm_key in self.mm_cache_keys: - if mm_key in multimodal_outputs: - combined_multimodal_outputs[mm_key] = self._get_merged_tensors( - query_start_loc=query_start_loc, - input_batch=input_batch, - cache=self.mm_outputs_cache[mm_key], - hidden_states=multimodal_outputs[mm_key], - num_scheduled_tokens=num_scheduled_tokens, - ) - else: - logger.error("Cacheable multimodal key %s is not present in multimodal outputs", mm_key) - elif multimodal_outputs: + if self.mm_cache_keys is None and multimodal_outputs: logger.warning( " A model stage produced multimodal outputs, but has no defined mm_cache_keys; " " this probably means that prefix caching is not fully supported for all stages " "in this model" ) + + for mm_key, mm_val in multimodal_outputs.items(): + if self.mm_cache_keys is not None and mm_key in self.mm_cache_keys: + combined_multimodal_outputs[mm_key] = self._get_merged_tensors( + query_start_loc=query_start_loc, + input_batch=input_batch, + cache=self.mm_outputs_cache[mm_key], + hidden_states=mm_val, + num_scheduled_tokens=num_scheduled_tokens, + ) + else: + # Note an mm_cache_keys; pass through normally + combined_multimodal_outputs[mm_key] = mm_val return combined_multimodal_outputs def get_merged_hidden_states(self, *args, **kwargs) -> dict[str, torch.Tensor]: From a9ccc322b5b39dae393002e880386742a5fec15e Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 17:07:56 +0000 Subject: [PATCH 15/38] use cpu tensors for prefix cache Signed-off-by: Alex Brooks --- tests/core/test_prefix_cache.py | 15 ++++---------- vllm_omni/core/prefix_cache.py | 27 ++++++++++++++----------- vllm_omni/worker/gpu_ar_model_runner.py | 6 +++--- vllm_omni/worker/gpu_model_runner.py | 1 - 4 files changed, 22 insertions(+), 27 deletions(-) diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py index 21c860690ba..c81a1cf0fd0 100644 --- a/tests/core/test_prefix_cache.py +++ b/tests/core/test_prefix_cache.py @@ 
-8,7 +8,6 @@ NUM_BLOCKS = 10 BLOCK_SIZE = 4 HIDDEN_SIZE = 2 -DEVICE = torch.device("cuda") DTYPE = torch.float32 DEFAULT_SHAPE = torch.Size([NUM_BLOCKS, BLOCK_SIZE, HIDDEN_SIZE]) @@ -31,7 +30,6 @@ def build_cache_with_mm_keys(mm_cache_keys) -> OmniTensorPrefixCache: num_blocks=NUM_BLOCKS, block_size=BLOCK_SIZE, hidden_size=HIDDEN_SIZE, - device=DEVICE, dtype=DTYPE, model_config=None, ) @@ -81,7 +79,7 @@ def test_update_no_multimodal(): num_tokens_unpadded = 8 slot_offset = 8 slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded) - new_hidden_states = torch.rand((num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE, device=DEVICE) + new_hidden_states = torch.rand((num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE) cache.update_omni_tensor_prefix_cache( hidden_states=new_hidden_states, @@ -113,9 +111,7 @@ def test_update_with_multimodal_outputs(mm_cache_keys): slot_offset = 8 slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded) feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()} - mm_outputs = { - key: torch.rand((num_tokens_unpadded, feature_dims[key]), dtype=DTYPE, device=DEVICE) for key in mm_cache_keys - } + mm_outputs = {key: torch.rand((num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in mm_cache_keys} cache.update_omni_tensor_prefix_cache( hidden_states=None, multimodal_outputs=mm_outputs, @@ -159,7 +155,7 @@ def test_get_merged_hidden_states(): orig_num_tokens_unpadded = 8 slot_offset = 8 # We'll put our states in slots 8, 9, 10, ..., 15 orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded) - orig_hidden_states = torch.rand((orig_num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE, device=DEVICE) + orig_hidden_states = torch.rand((orig_num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE) cache.update_omni_tensor_prefix_cache( hidden_states=orig_hidden_states, @@ -180,7 +176,6 @@ def test_get_merged_hidden_states(): new_hidden_states = torch.rand( 
(num_new_toks_req1 + num_new_toks_req2, HIDDEN_SIZE), dtype=DTYPE, - device=DEVICE, ) req1_new_states = new_hidden_states[:num_new_toks_req1] req2_new_states = new_hidden_states[-num_new_toks_req2:] @@ -228,8 +223,7 @@ def test_get_merged_multimodal_outputs(mm_cache_keys): orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded) feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()} orig_mm_outputs = { - key: torch.rand((orig_num_tokens_unpadded, feature_dims[key]), dtype=DTYPE, device=DEVICE) - for key in mm_cache_keys + key: torch.rand((orig_num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in mm_cache_keys } cache.update_omni_tensor_prefix_cache( @@ -254,7 +248,6 @@ def test_get_merged_multimodal_outputs(mm_cache_keys): new_mm_outputs[mm_key] = torch.rand( (num_new_toks_req1 + num_new_toks_req2, feature_dims[mm_key]), dtype=DTYPE, - device=DEVICE, ) # We also want to make sure passthrough data (outside of our keys) isn't dropped new_mm_outputs["passthrough_data"] = "Something else" diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index c5a9fa2fa8b..7059cf80254 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -14,10 +14,6 @@ logger = init_logger(__name__) -# TODO - Make this configurable and factor in the number -# of multimodal tensors. -NUM_GPU_BLOCKS = 2048 - StageMMCacheKeys: TypeAlias = list[str] | dict[str, int | None] ModelMMCacheKeys: TypeAlias = dict[str, StageMMCacheKeys] | None @@ -35,8 +31,6 @@ class OmniTensorPrefixCache: and translate it to rows in the 3D tensor of shape: (num_blocks, block_size, feature_size) - - Currently all tensors are stored on device. 
""" def __init__( @@ -45,16 +39,13 @@ def __init__( block_size: int, hidden_size: int, dtype: torch.dtype, - device: torch.device, model_config: OmniModelConfig, ): self.num_blocks = num_blocks self.block_size = block_size self.default_hidden_size = hidden_size self.dtype = dtype - self.device = device - # TODO: Support CPU offload self.mm_cache_keys = self._resolve_mm_cache_keys(model_config) self._initialize_omni_tensor_caches(self.mm_cache_keys) self._new_req_cache_hit_ids: set[str] = set() @@ -109,12 +100,12 @@ def _initialize_omni_tensor_caches(self, mm_cache_keys: StageMMCacheKeys | None) self.mm_outputs_cache[cache_key] = self._get_cache_tensor() def _get_cache_tensor(self, hidden_size: int | None = None) -> torch.Tensor: - """Allocate a cache tensor for a specific key.""" + """Allocate a CPU cache tensor for a specific key.""" actual_hidden_size = hidden_size if hidden_size is not None else self.default_hidden_size return torch.zeros( (self.num_blocks, self.block_size, actual_hidden_size), dtype=self.dtype, - device=self.device, + device="cpu", ) def add_prefix_cached_new_req_id(self, req_id: str): @@ -125,6 +116,11 @@ def reset_prefix_cached_new_req_ids(self): """Clears the cache hit IDs to prepare for a new engine step.""" self._new_req_cache_hit_ids.clear() + @staticmethod + def _coerce_to_cpu_tensor(maybe_gpu_tensor: torch.Tensor) -> torch.Tensor: + """Convert GPU tensors -> contiguous CPU tensors if needed.""" + return maybe_gpu_tensor.detach().cpu().contiguous() + def update_omni_tensor_prefix_cache( self, hidden_states: torch.Tensor | None, @@ -133,8 +129,13 @@ def update_omni_tensor_prefix_cache( slot_mapping: torch.Tensor, ): """Updates the hidden cache state for the provided hidden states and multimodal outputs.""" + if hidden_states is not None: + hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states) + unpadded_slot_mapping = slot_mapping[:num_tokens_unpadded] if hidden_states is not None: + # Ensure that hidden states are on 
the CPU + hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states) # View the cache as 2D so that we can treat our slots as row indices flat_cache = self.hidden_states_cache.view(-1, self.hidden_states_cache.shape[-1]) flat_cache[unpadded_slot_mapping] = hidden_states[:num_tokens_unpadded] @@ -145,6 +146,7 @@ def update_omni_tensor_prefix_cache( for mm_out_key, mm_cache in self.mm_outputs_cache.items(): assert mm_out_key in multimodal_outputs mm_state = multimodal_outputs[mm_out_key] + mm_state = OmniTensorPrefixCache._coerce_to_cpu_tensor(mm_state) flat_cache = mm_cache.view(-1, mm_cache.shape[-1]) flat_cache[unpadded_slot_mapping] = mm_state[:num_tokens_unpadded] logger.debug("Writing to mm output cache for %s tokens", num_tokens_unpadded) @@ -205,6 +207,7 @@ def _get_merged_tensors( we index into the first block table like this. """ combined_hidden_states = {} + hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states) for req_id in input_batch.req_ids: req_idx = input_batch.req_id_to_index[req_id] @@ -234,4 +237,4 @@ def _get_cached_block_ids(self, req_idx: int, input_batch: InputBatch) -> torch. num_cached_blocks = num_computed // self.block_size # Get the block IDs attached to this cache hit and reindex into # the flattened cached hidden states (i.e., 1 row per token). 
- return input_batch.block_table[0].block_table.gpu[req_idx, :num_cached_blocks] + return input_batch.block_table[0].block_table.cpu[req_idx, :num_cached_blocks] diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index c4583627302..96d9b3a7d44 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -506,7 +506,7 @@ def execute_model( hidden_states=hidden_states, multimodal_outputs=multimodal_outputs, num_tokens_unpadded=num_tokens_unpadded, - slot_mapping=self.input_batch.block_table[0].slot_mapping.gpu, + slot_mapping=self.input_batch.block_table[0].slot_mapping.cpu, ) if not self.broadcast_pp_output: @@ -788,13 +788,13 @@ def propose_draft_token_ids(sampled_token_ids): combined_hidden_states, combined_multimodal_outputs = None, None else: combined_hidden_states = self.omni_prefix_cache.get_merged_hidden_states( - query_start_loc=self.query_start_loc.gpu, + query_start_loc=self.query_start_loc.cpu, input_batch=self.input_batch, hidden_states=hidden_states, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, ) combined_multimodal_outputs = self.omni_prefix_cache.get_merged_multimodal_states( - query_start_loc=self.query_start_loc.gpu, + query_start_loc=self.query_start_loc.cpu, input_batch=self.input_batch, multimodal_outputs=multimodal_outputs, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index ab057a5f053..444fa51c518 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -82,7 +82,6 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): block_size=self.cache_config.block_size, hidden_size=self.model_config.get_hidden_size(), dtype=self.dtype, - device=self.device, model_config=self.model_config, ) From 2320122c9b20da75d7cefa8edfd5dcfbb851bf01 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 
Mar 2026 17:53:42 +0000 Subject: [PATCH 16/38] docstring updates Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 7059cf80254..142adf235a0 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -19,18 +19,20 @@ class OmniTensorPrefixCache: - """Prefix cache for hidden states (model outputs) - and model specific multimodal outputs. + """Prefix cache for hidden states (model outputs) and model specific + multimodal outputs. - This class implements prefix caching in a non-invasive - way on top of vLLM by leveraging the same slot mappings - that the vLLM scheduler uses for the KV Cache + This class implements prefix caching in a non-invasive way on top of + vLLM by leveraging the same slot mappings that the vLLM scheduler uses + for the KV Cache. - Conceptually, we are vLLM's mapping from: - (num_blocks, block_size) + Conceptually, this means we are mapping vLLM's cache mapping: + (num_blocks, block_size) - and translate it to rows in the 3D tensor of shape: - (num_blocks, block_size, feature_size) + to 3D tensors of shape: + (num_blocks, block_size, feature_size) + + Note that feature_size may vary across multimodal_outputs. """ def __init__( From 7654d2ae852804d15b0bc9bd5d5defbf7d4f404a Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 18:12:49 +0000 Subject: [PATCH 17/38] warn for multi kv cache groups Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 142adf235a0..e302f42ee51 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -204,10 +204,14 @@ def _get_merged_tensors( from request IDs to their full hidden states. 
This is accomplished by looking up the block IDs & scheduled token counts to split the hidden_states. - - NOTE: We do not handle hybrid caches at the moment, which is why - we index into the first block table like this. """ + # We do not support hybrid caches at the moment. + if len(input_batch.block_table.block_tables) > 1: + logger.warning_once( + "Omni prefix caching is enabled, but the batch block table appears to" + " have multiple kv groups; only the first group will be used!" + ) + combined_hidden_states = {} hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states) for req_id in input_batch.req_ids: From 87e3ab3dc3e7dfe0910fc095048c916999a08a8b Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 20:41:32 +0000 Subject: [PATCH 18/38] infer mm cache keys Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 85 ++++++------------- .../models/qwen3_omni/qwen3_omni.py | 6 -- vllm_omni/worker/gpu_ar_model_runner.py | 5 ++ vllm_omni/worker/gpu_model_runner.py | 1 - 4 files changed, 32 insertions(+), 65 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index e302f42ee51..dc8ca42037b 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -2,22 +2,13 @@ Utilities for Prefix Caching in Omni models. """ -from typing import TypeAlias - import torch from vllm.logger import init_logger -from vllm.model_executor.model_loader.utils import get_model_architecture from vllm.v1.worker.gpu_input_batch import InputBatch -from vllm_omni.config.model import OmniModelConfig - logger = init_logger(__name__) -StageMMCacheKeys: TypeAlias = list[str] | dict[str, int | None] -ModelMMCacheKeys: TypeAlias = dict[str, StageMMCacheKeys] | None - - class OmniTensorPrefixCache: """Prefix cache for hidden states (model outputs) and model specific multimodal outputs. 
@@ -41,65 +32,43 @@ def __init__( block_size: int, hidden_size: int, dtype: torch.dtype, - model_config: OmniModelConfig, ): self.num_blocks = num_blocks self.block_size = block_size self.default_hidden_size = hidden_size self.dtype = dtype - self.mm_cache_keys = self._resolve_mm_cache_keys(model_config) - self._initialize_omni_tensor_caches(self.mm_cache_keys) - self._new_req_cache_hit_ids: set[str] = set() - - def _resolve_mm_cache_keys(self, model_config: OmniModelConfig) -> StageMMCacheKeys | None: - """Determined the configuration for multimodal caching for the current model - architecture and stage.""" - model_stage = model_config.model_stage - arch, arch_str = get_model_architecture(model_config) - if hasattr(arch, "_model_mm_cache_keys"): - model_mm_cache_keys = getattr(arch, "_model_mm_cache_keys") - if model_stage in model_mm_cache_keys: - stage_mm_cache_keys = model_mm_cache_keys[model_stage] - logger.info(f"Resolved mm_cache_keys for stage {model_stage} - {stage_mm_cache_keys}") - return stage_mm_cache_keys - - # TODO: Move have_multimodal_outputs to class property and set this log to - # only go off for models that support have_multimodal_outputs, since the - # hidden states caching is generic - logger.warning( - f"Model architecture {arch_str} does not have defined _mm_cache_keys and will" - " therefore not able leverage prefix caching for multimodal outputs. " - " As such, prefix caching may not be supported." - ) + # Initialize the hidden states cache immediately + self.hidden_states_cache = self._get_cache_tensor() - def _initialize_omni_tensor_caches(self, mm_cache_keys: StageMMCacheKeys | None): - """Initialize the Omni Tensor cache tensors; this handles both the - hidden states cache and the multimodal outputs cache. + # Defer initialization of the mm_outputs_cache until we + # actually see mm output tensors dependent on num tokens. 
+ self.mm_outputs_cache = {} + self.mm_cache_keys = set() + self._new_req_cache_hit_ids: set[str] = set() - The hidden_states cache is a tensor with shape: - (num_blocks, block_size, self.default_hidden_size) + def maybe_init_missing_mm_cache_keys(self, multimodal_outputs: dict, seq_len: int): + """Given multimodal outputs from executing the model, dynamically + determine which multimodal outputs are tensors depending on sequence + length and should be cached, and initialize the cache tensors + accordingly. - While the mm_outputs_cache is dict mapping keys to tensors of shape: - (num_blocks, block_size, feature_size) + NOTE: This is done to avoid the need for explicit specification of + cache keys for every model/stage and aligns with the current way + that we slice the multimodal outputs based on the first dimension. - By default, if mm_cache_keys is a list, feature_size is set to the - default hidden size for all mm_output_keys. We also accept a dict - mapping to feature sizes on a per key basis, falling back to - self.default_hidden_size. for any keys that are None. + This will usually be called by the first forward pass, i.e., + determined by the warmup. 
""" - self.hidden_states_cache = self._get_cache_tensor() - - self.mm_outputs_cache = {} - if mm_cache_keys: - if isinstance(mm_cache_keys, dict): - for cache_key, hidden_size in mm_cache_keys.items(): - self.mm_outputs_cache[cache_key] = self._get_cache_tensor( - hidden_size=hidden_size, - ) - else: - for cache_key in mm_cache_keys: - self.mm_outputs_cache[cache_key] = self._get_cache_tensor() + for key, val in multimodal_outputs.items(): + if isinstance(val, torch.Tensor) and val.shape[0] == seq_len and key not in self.mm_cache_keys: + feat_dim = val.shape[-1] + self.mm_outputs_cache[key] = self._get_cache_tensor( + hidden_size=feat_dim, + ) + self.mm_cache_keys.add(key) + new_tensor_shape = self.mm_outputs_cache[key].shape + logger.info("Initializing multimodal output cache of size %s for key: %s", list(new_tensor_shape), key) def _get_cache_tensor(self, hidden_size: int | None = None) -> torch.Tensor: """Allocate a CPU cache tensor for a specific key.""" @@ -179,7 +148,7 @@ def get_merged_multimodal_states( num_scheduled_tokens=num_scheduled_tokens, ) else: - # Note an mm_cache_keys; pass through normally + # Not an mm cache key; pass it normally combined_multimodal_outputs[mm_key] = mm_val return combined_multimodal_outputs diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 5cd6781579b..7df69479734 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -36,7 +36,6 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler -from vllm_omni.core.prefix_cache import ModelMMCacheKeys from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin from vllm_omni.model_executor.models.output_templates import OmniOutput from vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker import ( @@ -99,11 +98,6 @@ class 
Qwen3OmniMoeForConditionalGeneration( """ realtime_max_tokens = 64 - # Currently, we only support prefix caching for the thinker stage - _model_mm_cache_keys: ModelMMCacheKeys = { - # keys 0 & 24 for the thinker have the same dimensionality as hidden states - "thinker": ["0", "24"] - } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 96d9b3a7d44..5573214f713 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -502,6 +502,11 @@ def execute_model( # Cache hidden states if we've enabled hidden state prefix caching # unless this isn't the last pipeline parallelism rank. if self.omni_prefix_cache is not None and get_pp_group().is_last_rank: + self.omni_prefix_cache.maybe_init_missing_mm_cache_keys( + multimodal_outputs, + seq_len=hidden_states.shape[0], + ) + self.omni_prefix_cache.update_omni_tensor_prefix_cache( hidden_states=hidden_states, multimodal_outputs=multimodal_outputs, diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 444fa51c518..61e52d8f96e 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -82,7 +82,6 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): block_size=self.cache_config.block_size, hidden_size=self.model_config.get_hidden_size(), dtype=self.dtype, - model_config=self.model_config, ) @instrument(span_name="Loading (GPU)") From da3982b6c57dab412d44be9b1ff15c1f26d0b28f Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 20:41:52 +0000 Subject: [PATCH 19/38] update tests to use dynamic keys Signed-off-by: Alex Brooks --- tests/core/test_prefix_cache.py | 151 +++++++++++++++++++------------- 1 file changed, 91 insertions(+), 60 deletions(-) diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py index 
c81a1cf0fd0..003a00d5cd1 100644 --- a/tests/core/test_prefix_cache.py +++ b/tests/core/test_prefix_cache.py @@ -1,10 +1,11 @@ -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest import torch from vllm_omni.core.prefix_cache import OmniTensorPrefixCache +DEFAULT_SEQ_LEN = 15 NUM_BLOCKS = 10 BLOCK_SIZE = 4 HIDDEN_SIZE = 2 @@ -17,64 +18,94 @@ def __init__(self, num_computed_tokens_cpu): self.req_ids = ["req1", "req2"] self.req_id_to_index = {req_id: i for i, req_id in enumerate(self.req_ids)} self.num_computed_tokens_cpu = num_computed_tokens_cpu + # Block table is only mocked for validation of length; + # we don't actually need to add valid values here since + # we patch the table when testing. + self.block_table = Mock() + self.block_table.block_tables = [None] + + +def get_omni_pcache_with_mm_tensors(feat_dims, seq_len) -> OmniTensorPrefixCache: + """Build an OmniTensorPrefixCache and init mm tensors.""" + cache = get_omni_pcache() + mm_outputs = get_multimodal_outputs(feat_dims, seq_len) + cache.maybe_init_missing_mm_cache_keys(mm_outputs, seq_len) + return cache + + +def get_omni_pcache() -> OmniTensorPrefixCache: + """Build an OmniTensorPrefixCache, but don't init mm tensors.""" + cache = OmniTensorPrefixCache( + num_blocks=NUM_BLOCKS, + block_size=BLOCK_SIZE, + hidden_size=HIDDEN_SIZE, + dtype=DTYPE, + ) + return cache -def build_cache_with_mm_keys(mm_cache_keys) -> OmniTensorPrefixCache: - with patch( - "vllm_omni.core.prefix_cache.OmniTensorPrefixCache._resolve_mm_cache_keys", - return_value=mm_cache_keys, - ): - # Model config is only used for resolving the mm_cache_keys, - # so the value passed here doesn't matter since it's patched. 
- return OmniTensorPrefixCache( - num_blocks=NUM_BLOCKS, - block_size=BLOCK_SIZE, - hidden_size=HIDDEN_SIZE, - dtype=DTYPE, - model_config=None, - ) +def get_multimodal_outputs(feat_dims: dict[str, int], seq_len: int) -> dict[str, torch.Tensor]: + fake_mm_inputs = {} + for mm_key, feat_dim in feat_dims.items(): + fake_mm_inputs[mm_key] = torch.rand((seq_len, feat_dim), dtype=DTYPE) + return fake_mm_inputs ### Tests for initialization -def test_initialization_from_list_of_cache_keys(): - """Ensure that hidden states / mm outputs cache are created with the - correct sizes by default. - """ - mm_cache_keys = ["foo", "bar"] - cache = build_cache_with_mm_keys(mm_cache_keys) +def test_initialization_simple(): + """Check default initialization only creates the hidden states.""" + cache = get_omni_pcache() assert isinstance(cache.hidden_states_cache, torch.Tensor) assert cache.hidden_states_cache.shape == DEFAULT_SHAPE - assert set(mm_cache_keys) == set(cache.mm_outputs_cache.keys()) - for val in cache.mm_outputs_cache.values(): - assert isinstance(val, torch.Tensor) - assert val.shape == DEFAULT_SHAPE + assert len(cache.mm_outputs_cache) == 0 + assert len(cache.mm_cache_keys) == 0 -def test_initialization_from_dict_of_cache_keys(): - """Ensure that keys in the mm outputs cache can have their own feature - sizes and fall back to the hidden states cache size if they map to None. 
- """ - mm_cache_keys = { - "foo": 100, - "bar": 50, - "baz": None, - } - cache = build_cache_with_mm_keys(mm_cache_keys) - assert isinstance(cache.hidden_states_cache, torch.Tensor) - assert cache.hidden_states_cache.shape == DEFAULT_SHAPE - assert set(mm_cache_keys) == set(cache.mm_outputs_cache.keys()) +def test_initialization_with_multimodal(): + """Check initialization + registration of multimodal outputs.""" + cache = get_omni_pcache() + feat_dims = {"foo": 100, "bar": 50, "baz": 10} + mm_outputs = get_multimodal_outputs( + feat_dims, + seq_len=DEFAULT_SEQ_LEN, + ) + cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN) + assert len(cache.mm_cache_keys) == 3 + assert set(cache.mm_cache_keys) == set(feat_dims.keys()) + for mm_key in cache.mm_cache_keys: + cache_tensor = cache.mm_outputs_cache[mm_key] + assert isinstance(cache_tensor, torch.Tensor) + assert cache_tensor.shape[-1] == feat_dims[mm_key] + + +def test_init_missing_mm_cache_keys_is_idempotent(): + """Ensure that the cache doesn't reinitialize old keys.""" + cache = get_omni_pcache() + mm_key = "foo" + feat_dims = {mm_key: 100} + mm_outputs = get_multimodal_outputs( + feat_dims, + seq_len=DEFAULT_SEQ_LEN, + ) + cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN) + assert len(cache.mm_cache_keys) == 1 + assert mm_key in cache.mm_cache_keys + + # Cache is initialized to 0 - fill it with 1s + cache.mm_outputs_cache[mm_key].fill_(1) - for key, val in cache.mm_outputs_cache.items(): - assert isinstance(val, torch.Tensor) - hs_override = mm_cache_keys[key] if mm_cache_keys[key] is not None else HIDDEN_SIZE - expected_shape = torch.Size([NUM_BLOCKS, BLOCK_SIZE, hs_override]) - assert val.shape == expected_shape + # Ensure that running another initialization + # doesn't zero out our cache values + cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN) + assert len(cache.mm_cache_keys) == 1 + assert mm_key in cache.mm_cache_keys + assert 
torch.all(cache.mm_outputs_cache[mm_key] == 1) ### Tests for Update def test_update_no_multimodal(): """Test that slot mappings act as row indices hidden states.""" - cache = build_cache_with_mm_keys(mm_cache_keys=None) + cache = get_omni_pcache() num_tokens_unpadded = 8 slot_offset = 8 @@ -97,21 +128,21 @@ def test_update_no_multimodal(): @pytest.mark.parametrize( - "mm_cache_keys", + "feat_dims", [ - ("foo", "bar"), # All same feature dim (HIDDEN_SIZE) - {"foo": 100, "bar": 50, "baz": None}, # different feature dims + {"foo": 100, "bar": 100}, + {"foo": 100, "bar": 50, "baz": 10}, ], ) -def test_update_with_multimodal_outputs(mm_cache_keys): +def test_update_with_multimodal_outputs(feat_dims): """Test that slot mappings are correct for multimodal tensors.""" - cache = build_cache_with_mm_keys(mm_cache_keys) + cache = get_omni_pcache_with_mm_tensors(feat_dims, seq_len=DEFAULT_SEQ_LEN) num_tokens_unpadded = 8 slot_offset = 8 slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded) feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()} - mm_outputs = {key: torch.rand((num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in mm_cache_keys} + mm_outputs = {key: torch.rand((num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in cache.mm_cache_keys} cache.update_omni_tensor_prefix_cache( hidden_states=None, multimodal_outputs=mm_outputs, @@ -119,7 +150,7 @@ def test_update_with_multimodal_outputs(mm_cache_keys): slot_mapping=slot_mapping, ) - for mm_key in mm_cache_keys: + for mm_key in feat_dims.keys(): assert mm_key in cache.mm_outputs_cache key_feat_dim = feature_dims[mm_key] mm_state_rows = cache.mm_outputs_cache[mm_key].view(NUM_BLOCKS * BLOCK_SIZE, key_feat_dim) @@ -150,7 +181,7 @@ def fake_get_cached_block_ids(self, req_idx, *args, **kwargs): def test_get_merged_hidden_states(): """Ensure that hidden states are merged correctly.""" - cache = build_cache_with_mm_keys(mm_cache_keys=None) + cache = 
get_omni_pcache() orig_num_tokens_unpadded = 8 slot_offset = 8 # We'll put our states in slots 8, 9, 10, ..., 15 @@ -209,21 +240,21 @@ def test_get_merged_hidden_states(): @pytest.mark.parametrize( - "mm_cache_keys", + "feat_dims", [ - ("foo", "bar"), # All same feature dim (HIDDEN_SIZE) - {"foo": 100, "bar": 50, "baz": None}, # different feature dims + {"foo": 100, "bar": 100}, + {"foo": 100, "bar": 50, "baz": 10}, ], ) -def test_get_merged_multimodal_outputs(mm_cache_keys): - cache = build_cache_with_mm_keys(mm_cache_keys) +def test_get_merged_multimodal_outputs(feat_dims): + cache = get_omni_pcache_with_mm_tensors(feat_dims, seq_len=DEFAULT_SEQ_LEN) orig_num_tokens_unpadded = 8 slot_offset = 8 # We'll put our states in slots 8, 9, 10, ..., 15 orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded) feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()} orig_mm_outputs = { - key: torch.rand((orig_num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in mm_cache_keys + key: torch.rand((orig_num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in cache.mm_cache_keys } cache.update_omni_tensor_prefix_cache( @@ -244,7 +275,7 @@ def test_get_merged_multimodal_outputs(mm_cache_keys): } new_mm_outputs = {} - for mm_key in mm_cache_keys: + for mm_key in cache.mm_cache_keys: new_mm_outputs[mm_key] = torch.rand( (num_new_toks_req1 + num_new_toks_req2, feature_dims[mm_key]), dtype=DTYPE, @@ -274,7 +305,7 @@ def test_get_merged_multimodal_outputs(mm_cache_keys): assert new_mm_outputs[mm_key] == mm_output assert new_mm_outputs[mm_key] == mm_output else: - assert mm_key in mm_cache_keys + assert mm_key in cache.mm_cache_keys assert isinstance(mm_output, dict) assert "req1" in mm_output and "req2" in mm_output curr_feat_dim = feature_dims[mm_key] From 71b97b3d2474e73a0fea6c5e34355d695424abe0 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 21:04:20 +0000 Subject: [PATCH 20/38] wip 
refactoring for sample tokens Signed-off-by: Alex Brooks --- vllm_omni/worker/gpu_ar_model_runner.py | 97 ++++++++++++++++--------- 1 file changed, 61 insertions(+), 36 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 5573214f713..dd271dfe645 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -627,6 +627,54 @@ def _sample( return super()._sample(logits, spec_decode_metadata) + @staticmethod + def _build_mm_cpu(multimodal_outputs, seq_len) -> dict[str, object]: + # Pre-copy multimodal tensors to CPU once (not per-request) to avoid + # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. + mm_cpu: dict[str, object] = {} + if isinstance(multimodal_outputs, dict) and multimodal_outputs: + for k, v in multimodal_outputs.items(): + try: + if isinstance(v, torch.Tensor) and v.shape[0] == seq_len: + mm_cpu[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, dict): + sub_dict: dict[str, torch.Tensor] = {} + for sk, sv in v.items(): + if isinstance(sv, torch.Tensor) and sv.shape[0] == seq_len: + sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() + if sub_dict: + mm_cpu[k] = sub_dict + elif isinstance(v, list): + if len(v) == 0: + continue + cpu_list = [] + for elem in v: + if isinstance(elem, torch.Tensor): + cpu_list.append(elem.detach().to("cpu").contiguous()) + else: + cpu_list.append(elem) + mm_cpu[k] = cpu_list + except Exception as e: + logger.error(f"Error in merge multimodal outputs: {e}") + return mm_cpu + + @staticmethod + def _resolve_req_hidden_states( + hidden_states_cpu: torch.Tensor, + combined_hidden_states: dict[str, torch.Tensor] | None, + rid: str, + start: int, + end: int, + ): + if combined_hidden_states is not None: + # We always have all request IDs for prefix cache, even for + # partial cache misses, so this should never happen. 
+ if rid not in combined_hidden_states: + raise RuntimeError("Request IDs in the batch are missing from the merged states!") + return combined_hidden_states[rid] + # Prefix caching is disabled + return hidden_states_cpu[start:end] + @torch.inference_mode() def sample_tokens( self, @@ -812,33 +860,7 @@ def propose_draft_token_ids(sampled_token_ids): combined_hidden_states, ) - # Pre-copy multimodal tensors to CPU once (not per-request) to avoid - # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. - mm_cpu: dict[str, object] = {} - if isinstance(multimodal_outputs, dict) and multimodal_outputs: - for k, v in multimodal_outputs.items(): - try: - if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: - mm_cpu[k] = v.detach().to("cpu").contiguous() - elif isinstance(v, dict): - sub_dict: dict[str, torch.Tensor] = {} - for sk, sv in v.items(): - if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]: - sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() - if sub_dict: - mm_cpu[k] = sub_dict - elif isinstance(v, list): - if len(v) == 0: - continue - cpu_list = [] - for elem in v: - if isinstance(elem, torch.Tensor): - cpu_list.append(elem.detach().to("cpu").contiguous()) - else: - cpu_list.append(elem) - mm_cpu[k] = cpu_list - except Exception as e: - logger.error(f"Error in merge multimodal outputs: {e}") + mm_cpu = self._build_mm_cpu(multimodal_outputs, seq_len=hidden_states_cpu.shape[0]) pooler_output: list[dict[str, object]] = [] for rid in req_ids_output_copy: @@ -846,14 +868,15 @@ def propose_draft_token_ids(sampled_token_ids): start = int(self.query_start_loc.cpu[idx]) sched = int(num_scheduled_tokens_np[idx]) end = start + sched - # For prefix cache on hidden states - if it's a request - # in the combined hidden states, it's a cache hit, so we - # send the states that were already merged. 
- if combined_hidden_states and rid in combined_hidden_states: - # TODO cleanup device management - req_hidden_states = combined_hidden_states[rid].detach().to("cpu").contiguous() - else: - req_hidden_states = hidden_states_cpu[start:end] + # If prefix cache is enabled, we have already split everything + # by request and converted the states to CPU tensors + req_hidden_states = self._resolve_req_hidden_states( + hidden_states_cpu, + combined_hidden_states, + rid, + start, + end, + ) payload: dict[str, object] = {"hidden": req_hidden_states} logger.info( @@ -872,7 +895,9 @@ def propose_draft_token_ids(sampled_token_ids): mm_payload[k] = combined_multimodal_outputs[k][rid].detach().to("cpu").contiguous() logger.info(f"Cached mm key {k} | shape: {combined_multimodal_outputs[k][rid].shape}") - elif isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: + # We probably need to handle the passthrough data correctly here, + # since for now we just pass it as is. + if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: mm_payload[k] = v[start:end].contiguous() elif isinstance(v, dict): mm_payload[k] = {sk: sv[start:end].contiguous() for sk, sv in v.items()} From 6e5ce9aa38d964b16431c7c37d7010842f5b1965 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 22:08:44 +0000 Subject: [PATCH 21/38] fix passthrough edge cases Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 48 +++++++++-- vllm_omni/utils/mm_outputs.py | 62 ++++++++++++++ vllm_omni/worker/gpu_ar_model_runner.py | 102 +++++++----------------- 3 files changed, 134 insertions(+), 78 deletions(-) create mode 100644 vllm_omni/utils/mm_outputs.py diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index dc8ca42037b..a4ab7245c82 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -6,6 +6,8 @@ from vllm.logger import init_logger from vllm.v1.worker.gpu_input_batch import InputBatch +from 
vllm_omni.utils.mm_outputs import build_mm_cpu, to_payload_element + logger = init_logger(__name__) @@ -122,6 +124,26 @@ def update_omni_tensor_prefix_cache( flat_cache[unpadded_slot_mapping] = mm_state[:num_tokens_unpadded] logger.debug("Writing to mm output cache for %s tokens", num_tokens_unpadded) + def _coerce_to_payload_dict( + self, + element: object, + query_start_loc: torch.Tensor, + input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> dict[str, object]: + """Build the multimodal passthrough data per request for + the object under consideration. This is identical to the case + for no prefix cache when we tensor does have a first dimension + matching the seq len. + """ + elem_dict = {} + for req_id in input_batch.req_ids: + req_idx = input_batch.req_id_to_index[req_id] + start = query_start_loc[req_idx] + end = start + num_scheduled_tokens[req_id] + elem_dict[req_id] = to_payload_element(element, req_idx, start=start, end=end, seq_len=None) + return elem_dict + def get_merged_multimodal_states( self, query_start_loc: torch.Tensor, @@ -137,19 +159,33 @@ def get_merged_multimodal_states( " this probably means that prefix caching is not fully supported for all stages " "in this model" ) - - for mm_key, mm_val in multimodal_outputs.items(): + # First get the prefix cached tensors + for mm_key in self.mm_cache_keys: if self.mm_cache_keys is not None and mm_key in self.mm_cache_keys: combined_multimodal_outputs[mm_key] = self._get_merged_tensors( query_start_loc=query_start_loc, input_batch=input_batch, cache=self.mm_outputs_cache[mm_key], - hidden_states=mm_val, + hidden_states=multimodal_outputs[mm_key], num_scheduled_tokens=num_scheduled_tokens, ) - else: - # Not an mm cache key; pass it normally - combined_multimodal_outputs[mm_key] = mm_val + + # Then, get everything else (passthrough data); first, convert to CPU + # tensors similarly to the non prefix cached path, and then populate + # the subdicts mapping request IDs -> payload objects + 
passthrough_keys = set(multimodal_outputs.keys()) - self.mm_cache_keys + passthrough_mm_data = {k: v for k, v in multimodal_outputs.items() if k in passthrough_keys} + mm_cpu = build_mm_cpu( + multimodal_outputs=passthrough_mm_data, + seq_len=None, + ) + for mm_key, mm_val in mm_cpu.items(): + combined_multimodal_outputs[mm_key] = self._coerce_to_payload_dict( + element=mm_val, + query_start_loc=query_start_loc, + input_batch=input_batch, + num_scheduled_tokens=num_scheduled_tokens, + ) return combined_multimodal_outputs def get_merged_hidden_states(self, *args, **kwargs) -> dict[str, torch.Tensor]: diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py new file mode 100644 index 00000000000..03fd326069d --- /dev/null +++ b/vllm_omni/utils/mm_outputs.py @@ -0,0 +1,62 @@ +"""Utilities for handling multimodal outputs / building multimodal output +payloads, most of which are shared by the prefix cache / no prefix cache path. +""" + +import torch +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def build_mm_cpu(multimodal_outputs, seq_len: int | None) -> dict[str, object]: + # Pre-copy multimodal tensors to CPU once (not per-request) to avoid + # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. 
+ mm_cpu: dict[str, object] = {} + if isinstance(multimodal_outputs, dict) and multimodal_outputs: + for k, v in multimodal_outputs.items(): + try: + if isinstance(v, torch.Tensor) and v.shape[0] == seq_len: + mm_cpu[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, dict): + sub_dict: dict[str, torch.Tensor] = {} + for sk, sv in v.items(): + if isinstance(sv, torch.Tensor) and sv.shape[0] == seq_len: + sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() + if sub_dict: + mm_cpu[k] = sub_dict + elif isinstance(v, list): + if len(v) == 0: + continue + cpu_list = [] + for elem in v: + if isinstance(elem, torch.Tensor): + cpu_list.append(elem.detach().to("cpu").contiguous()) + else: + cpu_list.append(elem) + mm_cpu[k] = cpu_list + except Exception as e: + logger.error(f"Error in merge multimodal outputs: {e}") + return mm_cpu + + +def to_payload_element(element, idx, start, end, seq_len: int | None = None): + """Given""" + # Prefix cache won't hit this case because this is the considition + # for being a mm_cache_key in the multimodal outputs tensor. + if seq_len is not None and isinstance(element, torch.Tensor) and element.shape[0] == seq_len: + return element[start:end].contiguous() + # Every other case is shared between prefix cache (passthrough data) + # and running a model without prefix caching. + elif isinstance(element, dict): + return {sk: sv[start:end].contiguous() for sk, sv in element.items()} + elif isinstance(element, list): + element = element[idx] if idx < len(element) else element[0] + # Clone tensors to avoid cross-request aliasing + if isinstance(element, torch.Tensor): + element = element.clone() + return element + elif isinstance(element, torch.Tensor): + # List-derived tensor payloads are request-invariant; clone to + # avoid accidental cross-request aliasing on downstream mutation. 
+ return element.clone() + return element diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index dd271dfe645..ac5c6da1871 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -39,6 +39,7 @@ from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.utils.mm_outputs import build_mm_cpu, to_payload_element from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin @@ -627,37 +628,6 @@ def _sample( return super()._sample(logits, spec_decode_metadata) - @staticmethod - def _build_mm_cpu(multimodal_outputs, seq_len) -> dict[str, object]: - # Pre-copy multimodal tensors to CPU once (not per-request) to avoid - # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. - mm_cpu: dict[str, object] = {} - if isinstance(multimodal_outputs, dict) and multimodal_outputs: - for k, v in multimodal_outputs.items(): - try: - if isinstance(v, torch.Tensor) and v.shape[0] == seq_len: - mm_cpu[k] = v.detach().to("cpu").contiguous() - elif isinstance(v, dict): - sub_dict: dict[str, torch.Tensor] = {} - for sk, sv in v.items(): - if isinstance(sv, torch.Tensor) and sv.shape[0] == seq_len: - sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() - if sub_dict: - mm_cpu[k] = sub_dict - elif isinstance(v, list): - if len(v) == 0: - continue - cpu_list = [] - for elem in v: - if isinstance(elem, torch.Tensor): - cpu_list.append(elem.detach().to("cpu").contiguous()) - else: - cpu_list.append(elem) - mm_cpu[k] = cpu_list - except Exception as e: - logger.error(f"Error in merge multimodal outputs: {e}") - return mm_cpu - @staticmethod def _resolve_req_hidden_states( hidden_states_cpu: torch.Tensor, @@ -683,6 +653,13 @@ def sample_tokens( kv_extracted_req_ids = getattr(self, 
"kv_extracted_req_ids", None) self.kv_extracted_req_ids = None + # Used for prefix cache + combined_hidden_states = None + combined_multimodal_outputs = None + # Used when we don't use prefix cache; prefix cache builds the payloads + # internally since it already needs to do this for the cached tensors + mm_cpu = {} + if self.execute_model_state is None: kv_connector_output = self.kv_connector_output self.kv_connector_output = None @@ -714,6 +691,7 @@ def sample_tokens( slot_mappings, # OMNI: unpack slot_mappings for drafter ) = self.execute_model_state self.execute_model_state = None + seq_len = hidden_states.shape[0] # Apply structured output bitmasks if present. if grammar_output is not None: @@ -837,9 +815,7 @@ def propose_draft_token_ids(sampled_token_ids): # Prior to applying the post-processing func, extract # the prefix cached hidden states and multimodal states. - if self.omni_prefix_cache is None: - combined_hidden_states, combined_multimodal_outputs = None, None - else: + if self.omni_prefix_cache is not None: combined_hidden_states = self.omni_prefix_cache.get_merged_hidden_states( query_start_loc=self.query_start_loc.cpu, input_batch=self.input_batch, @@ -852,6 +828,8 @@ def propose_draft_token_ids(sampled_token_ids): multimodal_outputs=multimodal_outputs, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, ) + else: + mm_cpu = build_mm_cpu(multimodal_outputs, seq_len=seq_len) self._process_additional_information_updates( hidden_states, @@ -860,8 +838,6 @@ def propose_draft_token_ids(sampled_token_ids): combined_hidden_states, ) - mm_cpu = self._build_mm_cpu(multimodal_outputs, seq_len=hidden_states_cpu.shape[0]) - pooler_output: list[dict[str, object]] = [] for rid in req_ids_output_copy: idx = req_id_to_index_output_copy[rid] @@ -879,42 +855,24 @@ def propose_draft_token_ids(sampled_token_ids): ) payload: dict[str, object] = {"hidden": req_hidden_states} - logger.info( - f"[HS] req={rid} hidden_shape={req_hidden_states.shape} " - 
f"cache_hit={rid in combined_hidden_states if combined_hidden_states else False}" - ) - - if mm_cpu: - mm_payload: dict[str, object] = {} - for k, v in mm_cpu.items(): - if ( - combined_multimodal_outputs - and k in combined_multimodal_outputs - and rid in combined_multimodal_outputs[k] - ): - mm_payload[k] = combined_multimodal_outputs[k][rid].detach().to("cpu").contiguous() - logger.info(f"Cached mm key {k} | shape: {combined_multimodal_outputs[k][rid].shape}") - - # We probably need to handle the passthrough data correctly here, - # since for now we just pass it as is. - if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]: - mm_payload[k] = v[start:end].contiguous() - elif isinstance(v, dict): - mm_payload[k] = {sk: sv[start:end].contiguous() for sk, sv in v.items()} - elif isinstance(v, list): - element = v[idx] if idx < len(v) else v[0] - if element is not None: - if isinstance(element, torch.Tensor): - element = element.clone() - mm_payload[k] = element - # Skip None elements: msgspec cannot serialize None - # in dict[str, torch.Tensor] typed fields. - elif isinstance(v, torch.Tensor): - # List-derived tensor payloads are request-invariant; clone to - # avoid accidental cross-request aliasing on downstream mutation. - mm_payload[k] = v.clone() - else: - mm_payload[k] = v + mm_payload: dict[str, object] = {} + if combined_multimodal_outputs or mm_cpu: + if combined_multimodal_outputs: + # Prefix cache enabled; all items have already been processed + # and split apart for each request as needed, and all tensors + # have already been detached to the CPU. 
+ for mm_key in combined_multimodal_outputs.keys(): + mm_payload[mm_key] = combined_multimodal_outputs[mm_key][rid] + else: + # Prefix cache disabled; we still need to process the data + for mm_key, mm_val in mm_cpu.items(): + mm_payload[mm_key] = to_payload_element( + element=mm_val, + idx=idx, + start=start, + end=end, + seq_len=seq_len, + ) payload.update(mm_payload) pooler_output.append(payload) with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): From 51c3ffd002eb7803d36bb2b01f5a57d5c8a31862 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 22:26:34 +0000 Subject: [PATCH 22/38] improve docstrings Signed-off-by: Alex Brooks --- vllm_omni/utils/mm_outputs.py | 71 ++++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py index 03fd326069d..dbbdebb5a09 100644 --- a/vllm_omni/utils/mm_outputs.py +++ b/vllm_omni/utils/mm_outputs.py @@ -8,39 +8,58 @@ logger = init_logger(__name__) -def build_mm_cpu(multimodal_outputs, seq_len: int | None) -> dict[str, object]: +def build_mm_cpu(multimodal_outputs: dict, seq_len: int | None) -> dict[str, object]: + """Pre-copies multimodal tensor to CPU once (not per-request) to avoid + redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. + + In the case of prefix caching, the multimodal outputs provided will + only contain the passthrough data. + + Args: + multimodal_outputs: Multimodal dict mapping strings to objects. + seq_len: Optional sequence length (i.e., dim 0 of hidden states). + This should be set to None in the prefix caching case since we + only consider the passtrhough data. + """ # Pre-copy multimodal tensors to CPU once (not per-request) to avoid # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. 
mm_cpu: dict[str, object] = {} - if isinstance(multimodal_outputs, dict) and multimodal_outputs: + if multimodal_outputs: for k, v in multimodal_outputs.items(): - try: - if isinstance(v, torch.Tensor) and v.shape[0] == seq_len: - mm_cpu[k] = v.detach().to("cpu").contiguous() - elif isinstance(v, dict): - sub_dict: dict[str, torch.Tensor] = {} - for sk, sv in v.items(): - if isinstance(sv, torch.Tensor) and sv.shape[0] == seq_len: - sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() - if sub_dict: - mm_cpu[k] = sub_dict - elif isinstance(v, list): - if len(v) == 0: - continue - cpu_list = [] - for elem in v: - if isinstance(elem, torch.Tensor): - cpu_list.append(elem.detach().to("cpu").contiguous()) - else: - cpu_list.append(elem) - mm_cpu[k] = cpu_list - except Exception as e: - logger.error(f"Error in merge multimodal outputs: {e}") + if isinstance(v, torch.Tensor) and v.shape[0] == seq_len: + mm_cpu[k] = v.detach().to("cpu").contiguous() + elif isinstance(v, dict): + sub_dict: dict[str, torch.Tensor] = {} + for sk, sv in v.items(): + if isinstance(sv, torch.Tensor) and sv.shape[0] == seq_len: + sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() + if sub_dict: + mm_cpu[k] = sub_dict + elif isinstance(v, list) and len(v) > 0: + cpu_list = [] + for elem in v: + if isinstance(elem, torch.Tensor): + cpu_list.append(elem.detach().to("cpu").contiguous()) + else: + cpu_list.append(elem) + mm_cpu[k] = cpu_list return mm_cpu -def to_payload_element(element, idx, start, end, seq_len: int | None = None): - """Given""" +def to_payload_element(element, idx: int, start: int, end: int, seq_len: int | None = None): + """Build an mm payload element corresponding to one request index + from an element containing 0 or more CPU tensors. + + Args: + element: The object to be added to the payload. + idx: The index of the request. + start: The start index corresponding to the request idx. + end: The end index corresponding to the request idx. 
+ seq_len: Optional sequence length (i.e., dim 0 of hidden states). + This should be set to None in the prefix caching case, because + the condition that would be executed here is the same as the + criteria for being added to the multimodal outputs cache. + """ # Prefix cache won't hit this case because this is the considition # for being a mm_cache_key in the multimodal outputs tensor. if seq_len is not None and isinstance(element, torch.Tensor) and element.shape[0] == seq_len: From 6a5bb32875e95d475bef12033a7aae3e51f0a254 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 22:37:27 +0000 Subject: [PATCH 23/38] don't drop mm_cpu data Signed-off-by: Alex Brooks --- tests/core/test_prefix_cache.py | 9 +++++---- vllm_omni/utils/mm_outputs.py | 6 ++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py index 003a00d5cd1..6fea0247a4e 100644 --- a/tests/core/test_prefix_cache.py +++ b/tests/core/test_prefix_cache.py @@ -301,13 +301,14 @@ def test_get_merged_multimodal_outputs(feat_dims): for mm_key, mm_output in merged_mm_outputs.items(): # Ensure passthrough data is just forwarded normally and not duplicated + assert isinstance(mm_output, dict) + assert "req1" in mm_output and "req2" in mm_output if mm_key == "passthrough_data": - assert new_mm_outputs[mm_key] == mm_output - assert new_mm_outputs[mm_key] == mm_output + assert mm_key not in cache.mm_cache_keys + assert new_mm_outputs[mm_key] == mm_output["req1"] + assert new_mm_outputs[mm_key] == mm_output["req2"] else: assert mm_key in cache.mm_cache_keys - assert isinstance(mm_output, dict) - assert "req1" in mm_output and "req2" in mm_output curr_feat_dim = feature_dims[mm_key] # Ensure that req1 (cache hit) merged the mm data req1_merged_mm_outputs = mm_output["req1"] diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py index dbbdebb5a09..31236e9cc68 100644 --- a/vllm_omni/utils/mm_outputs.py +++ 
b/vllm_omni/utils/mm_outputs.py @@ -26,12 +26,12 @@ def build_mm_cpu(multimodal_outputs: dict, seq_len: int | None) -> dict[str, obj mm_cpu: dict[str, object] = {} if multimodal_outputs: for k, v in multimodal_outputs.items(): - if isinstance(v, torch.Tensor) and v.shape[0] == seq_len: + if isinstance(v, torch.Tensor): mm_cpu[k] = v.detach().to("cpu").contiguous() elif isinstance(v, dict): sub_dict: dict[str, torch.Tensor] = {} for sk, sv in v.items(): - if isinstance(sv, torch.Tensor) and sv.shape[0] == seq_len: + if isinstance(sv, torch.Tensor): sub_dict[str(sk)] = sv.detach().to("cpu").contiguous() if sub_dict: mm_cpu[k] = sub_dict @@ -43,6 +43,8 @@ def build_mm_cpu(multimodal_outputs: dict, seq_len: int | None) -> dict[str, obj else: cpu_list.append(elem) mm_cpu[k] = cpu_list + else: + mm_cpu[k] = v return mm_cpu From da67c30621ec7d558c92602995bd04edc97577b0 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 30 Mar 2026 22:41:38 +0000 Subject: [PATCH 24/38] minor cleanup Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 6 ++---- vllm_omni/utils/mm_outputs.py | 9 +++++++-- vllm_omni/worker/gpu_ar_model_runner.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index a4ab7245c82..7e0491d8fe7 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -175,10 +175,8 @@ def get_merged_multimodal_states( # the subdicts mapping request IDs -> payload objects passthrough_keys = set(multimodal_outputs.keys()) - self.mm_cache_keys passthrough_mm_data = {k: v for k, v in multimodal_outputs.items() if k in passthrough_keys} - mm_cpu = build_mm_cpu( - multimodal_outputs=passthrough_mm_data, - seq_len=None, - ) + mm_cpu = build_mm_cpu(multimodal_outputs=passthrough_mm_data) + for mm_key, mm_val in mm_cpu.items(): combined_multimodal_outputs[mm_key] = self._coerce_to_payload_dict( element=mm_val, diff --git a/vllm_omni/utils/mm_outputs.py 
b/vllm_omni/utils/mm_outputs.py index 31236e9cc68..aab63885f27 100644 --- a/vllm_omni/utils/mm_outputs.py +++ b/vllm_omni/utils/mm_outputs.py @@ -8,7 +8,7 @@ logger = init_logger(__name__) -def build_mm_cpu(multimodal_outputs: dict, seq_len: int | None) -> dict[str, object]: +def build_mm_cpu(multimodal_outputs: dict) -> dict[str, object]: """Pre-copies multimodal tensor to CPU once (not per-request) to avoid redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. @@ -24,6 +24,11 @@ def build_mm_cpu(multimodal_outputs: dict, seq_len: int | None) -> dict[str, obj # Pre-copy multimodal tensors to CPU once (not per-request) to avoid # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. mm_cpu: dict[str, object] = {} + # Currently there are some cases where this is true at the + # moment, which should be fixed. + if not isinstance(multimodal_outputs, dict): + logger.warning("Multimodal outputs are not a dict and will not be passed") + if multimodal_outputs: for k, v in multimodal_outputs.items(): if isinstance(v, torch.Tensor): @@ -48,7 +53,7 @@ def build_mm_cpu(multimodal_outputs: dict, seq_len: int | None) -> dict[str, obj return mm_cpu -def to_payload_element(element, idx: int, start: int, end: int, seq_len: int | None = None): +def to_payload_element(element: object, idx: int, start: int, end: int, seq_len: int | None = None): """Build an mm payload element corresponding to one request index from an element containing 0 or more CPU tensors. 
diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index ac5c6da1871..2f7913fd538 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -829,7 +829,7 @@ def propose_draft_token_ids(sampled_token_ids): num_scheduled_tokens=scheduler_output.num_scheduled_tokens, ) else: - mm_cpu = build_mm_cpu(multimodal_outputs, seq_len=seq_len) + mm_cpu = build_mm_cpu(multimodal_outputs) self._process_additional_information_updates( hidden_states, From 0879df0212a1bcc0afda0fb809ab83d88685f7cc Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 31 Mar 2026 02:15:48 +0000 Subject: [PATCH 25/38] add end to end online tests for prefix cache Signed-off-by: Alex Brooks --- tests/conftest.py | 6 ++ tests/e2e/online_serving/test_qwen3_omni.py | 63 +++++++++++++++++++-- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index adb87cbd728..54e7f41bcb1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1850,6 +1850,7 @@ class OmniResponse: e2e_latency: float | None = None success: bool = False error_message: str | None = None + cached_tokens: int | None = None @dataclass @@ -2345,6 +2346,11 @@ def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse: if hasattr(choice.message, "content") and choice.message.content is not None: text_content = choice.message.content + # Extract cached_tokens for prefix caching tests + usage = getattr(chat_completion, "usage", None) + if usage and (details := getattr(usage, "prompt_tokens_details", None)): + result.cached_tokens = details.cached_tokens + # Calculate end-to-end latency result.e2e_latency = time.perf_counter() - start_time diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py index f4aabb8b957..14cb4c5df17 100644 --- a/tests/e2e/online_serving/test_qwen3_omni.py +++ b/tests/e2e/online_serving/test_qwen3_omni.py @@ -23,11 +23,13 @@ 
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"] +QWEN3_OMNI_CONFIG_PATH = str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml") +QWEN3_OMNI_XPU_CONFIG_PATH = str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml") -def get_chunk_config(): +def get_chunk_config(config_path: str): path = modify_stage_config( - str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml"), + config_path, updates={ "async_chunk": True, "stage_args": { @@ -44,15 +46,41 @@ def get_chunk_config(): return path +def get_prefix_caching_config(config_path: str): + """Create a stage config with prefix caching enabled on the thinker (stage 0).""" + path = modify_stage_config( + config_path, + updates={ + "stage_args": { + 0: {"engine_args.enable_prefix_caching": True}, + }, + }, + ) + return path + + if current_omni_platform.is_xpu(): - stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")] + stage_configs = [QWEN3_OMNI_XPU_CONFIG_PATH] + prefix_caching_stage_configs = [get_prefix_caching_config(QWEN3_OMNI_XPU_CONFIG_PATH)] else: # MI325 GPU should share the same config as H100 - stage_configs = [get_chunk_config()] + stage_configs = [get_chunk_config(QWEN3_OMNI_CONFIG_PATH)] + prefix_caching_stage_configs = [get_prefix_caching_config(QWEN3_OMNI_CONFIG_PATH)] # Create parameter combinations for model and stage config test_params = [ OmniServerParams(model=model, stage_config_path=stage_config) for model in models for stage_config in stage_configs ] +# For prefix caching, we need to enable prompt token details so that we +# can determine if any tokens were cached. 
+prefix_test_params = [ + OmniServerParams( + model=model, + stage_config_path=stage_config, + server_args=["--enable-prompt-tokens-details"], # Enable prompt tokens details to get cached_tokens + ) + for model in models + for stage_config in prefix_caching_stage_configs +] def get_system_prompt(): @@ -147,3 +175,30 @@ def test_text_to_text_001(omni_server, openai_client) -> None: } openai_client.send_omni_request(request_config, request_num=get_max_batch_size()) + + +@pytest.mark.advanced_model +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2) +@pytest.mark.parametrize("omni_server", prefix_test_params, indirect=True) +def test_thinker_prefix_caching(omni_server, openai_client) -> None: + """ + Test thinker supports prefix caching by sending two identical + requests and checking the number of cached tokens. + """ + messages = dummy_messages_from_mix_data(system_prompt=get_system_prompt(), content_text=get_prompt()) + request_config = { + "model": omni_server.model, + "messages": messages, + "stream": False, + "modalities": ["text"], + } + + response_1 = openai_client.send_omni_request(request_config, request_num=1)[0] + response_2 = openai_client.send_omni_request(request_config, request_num=1)[0] + + assert response_1.success + assert response_2.success + assert response_2.cached_tokens is not None + assert response_2.cached_tokens > 0 From 3e1d86db671286e6b8fc3cec60a30124d2351061 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 31 Mar 2026 03:16:15 +0000 Subject: [PATCH 26/38] don't fix cache dtype Signed-off-by: Alex Brooks --- tests/core/test_prefix_cache.py | 8 +++++++- vllm_omni/core/prefix_cache.py | 10 +++++----- vllm_omni/worker/gpu_model_runner.py | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py index 6fea0247a4e..3e20ecfe685 100644 --- a/tests/core/test_prefix_cache.py +++ 
b/tests/core/test_prefix_cache.py @@ -10,6 +10,7 @@ BLOCK_SIZE = 4 HIDDEN_SIZE = 2 DTYPE = torch.float32 +OTHER_DTYPE = torch.float16 DEFAULT_SHAPE = torch.Size([NUM_BLOCKS, BLOCK_SIZE, HIDDEN_SIZE]) @@ -39,7 +40,7 @@ def get_omni_pcache() -> OmniTensorPrefixCache: num_blocks=NUM_BLOCKS, block_size=BLOCK_SIZE, hidden_size=HIDDEN_SIZE, - dtype=DTYPE, + hs_dtype=DTYPE, ) return cache @@ -69,6 +70,10 @@ def test_initialization_with_multimodal(): feat_dims, seq_len=DEFAULT_SEQ_LEN, ) + # Cast one of the keys to a different dtype; the dtype of the tensor + # that is used to initialize the cache dictates the cache dtype. + mm_outputs["foo"] = mm_outputs["foo"].to(OTHER_DTYPE) + cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN) assert len(cache.mm_cache_keys) == 3 assert set(cache.mm_cache_keys) == set(feat_dims.keys()) @@ -76,6 +81,7 @@ def test_initialization_with_multimodal(): cache_tensor = cache.mm_outputs_cache[mm_key] assert isinstance(cache_tensor, torch.Tensor) assert cache_tensor.shape[-1] == feat_dims[mm_key] + assert mm_outputs[mm_key].dtype == cache_tensor.dtype def test_init_missing_mm_cache_keys_is_idempotent(): diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 7e0491d8fe7..495cc98ad60 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -33,15 +33,14 @@ def __init__( num_blocks: int, block_size: int, hidden_size: int, - dtype: torch.dtype, + hs_dtype: torch.dtype, ): self.num_blocks = num_blocks self.block_size = block_size self.default_hidden_size = hidden_size - self.dtype = dtype # Initialize the hidden states cache immediately - self.hidden_states_cache = self._get_cache_tensor() + self.hidden_states_cache = self._get_cache_tensor(dtype=hs_dtype) # Defer initialization of the mm_outputs_cache until we # actually see mm output tensors dependent on num tokens. 
@@ -66,18 +65,19 @@ def maybe_init_missing_mm_cache_keys(self, multimodal_outputs: dict, seq_len: in if isinstance(val, torch.Tensor) and val.shape[0] == seq_len and key not in self.mm_cache_keys: feat_dim = val.shape[-1] self.mm_outputs_cache[key] = self._get_cache_tensor( + dtype=val.dtype, hidden_size=feat_dim, ) self.mm_cache_keys.add(key) new_tensor_shape = self.mm_outputs_cache[key].shape logger.info("Initializing multimodal output cache of size %s for key: %s", list(new_tensor_shape), key) - def _get_cache_tensor(self, hidden_size: int | None = None) -> torch.Tensor: + def _get_cache_tensor(self, dtype: torch.dtype, hidden_size: int | None = None) -> torch.Tensor: """Allocate a CPU cache tensor for a specific key.""" actual_hidden_size = hidden_size if hidden_size is not None else self.default_hidden_size return torch.zeros( (self.num_blocks, self.block_size, actual_hidden_size), - dtype=self.dtype, + dtype=dtype, device="cpu", ) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 61e52d8f96e..2b6c5d59ab2 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -81,7 +81,7 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes): num_blocks=kv_cache_config.num_blocks, block_size=self.cache_config.block_size, hidden_size=self.model_config.get_hidden_size(), - dtype=self.dtype, + hs_dtype=self.dtype, ) @instrument(span_name="Loading (GPU)") From 67fb53c9a883d2ba1a6566a171fdb59d462abd28 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 31 Mar 2026 03:26:40 +0000 Subject: [PATCH 27/38] fix docs Signed-off-by: Alex Brooks --- vllm_omni/utils/mm_outputs.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py index aab63885f27..884adb72e84 100644 --- a/vllm_omni/utils/mm_outputs.py +++ b/vllm_omni/utils/mm_outputs.py @@ -17,9 +17,6 @@ def build_mm_cpu(multimodal_outputs: dict) -> dict[str, object]: 
Args: multimodal_outputs: Multimodal dict mapping strings to objects. - seq_len: Optional sequence length (i.e., dim 0 of hidden states). - This should be set to None in the prefix caching case since we - only consider the passtrhough data. """ # Pre-copy multimodal tensors to CPU once (not per-request) to avoid # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU. From 0bf148ee73c1112f9ab85008547a68ea44e7b6ad Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 31 Mar 2026 04:42:37 +0000 Subject: [PATCH 28/38] guard type for init mm caches Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 15 +++++++-------- vllm_omni/worker/gpu_ar_model_runner.py | 13 +++++++++---- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 495cc98ad60..536da78b5dd 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -161,14 +161,13 @@ def get_merged_multimodal_states( ) # First get the prefix cached tensors for mm_key in self.mm_cache_keys: - if self.mm_cache_keys is not None and mm_key in self.mm_cache_keys: - combined_multimodal_outputs[mm_key] = self._get_merged_tensors( - query_start_loc=query_start_loc, - input_batch=input_batch, - cache=self.mm_outputs_cache[mm_key], - hidden_states=multimodal_outputs[mm_key], - num_scheduled_tokens=num_scheduled_tokens, - ) + combined_multimodal_outputs[mm_key] = self._get_merged_tensors( + query_start_loc=query_start_loc, + input_batch=input_batch, + cache=self.mm_outputs_cache[mm_key], + hidden_states=multimodal_outputs[mm_key], + num_scheduled_tokens=num_scheduled_tokens, + ) # Then, get everything else (passthrough data); first, convert to CPU # tensors similarly to the non prefix cached path, and then populate diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index 2f7913fd538..b1eeec0fcfb 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ 
b/vllm_omni/worker/gpu_ar_model_runner.py @@ -503,10 +503,15 @@ def execute_model( # Cache hidden states if we've enabled hidden state prefix caching # unless this isn't the last pipeline parallelism rank. if self.omni_prefix_cache is not None and get_pp_group().is_last_rank: - self.omni_prefix_cache.maybe_init_missing_mm_cache_keys( - multimodal_outputs, - seq_len=hidden_states.shape[0], - ) + if isinstance(multimodal_outputs, dict): + self.omni_prefix_cache.maybe_init_missing_mm_cache_keys( + multimodal_outputs, + seq_len=hidden_states.shape[0], + ) + else: + # This usually means that the stage doesn't have + # multimodal outputs, so only the hidden states cache + logger.warning_once("Omni prefix caching expects type dict, but got %s", type(multimodal_outputs)) self.omni_prefix_cache.update_omni_tensor_prefix_cache( hidden_states=hidden_states, From 05d1da4002725bbb93d54b9ed08c8df91df9cb6b Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 1 Apr 2026 07:06:22 +0000 Subject: [PATCH 29/38] add docs for prefix cache Signed-off-by: Alex Brooks --- docs/.nav.yml | 1 + docs/design/feature/prefix_caching.md | 127 ++++++++++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 docs/design/feature/prefix_caching.md diff --git a/docs/.nav.yml b/docs/.nav.yml index 86ce4a3b0c4..8b364a715fb 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -97,6 +97,7 @@ nav: - design/feature/disaggregated_inference.md - design/feature/ray_based_execution.md - design/feature/omni_connectors/ + - design/feature/prefix_caching.md - design/feature/cfg_parallel.md - design/feature/expert_parallel.md - design/feature/sequence_parallel.md diff --git a/docs/design/feature/prefix_caching.md b/docs/design/feature/prefix_caching.md new file mode 100644 index 00000000000..15c9b35bb85 --- /dev/null +++ b/docs/design/feature/prefix_caching.md @@ -0,0 +1,127 @@ +# Automatic Prefix Caching in Omni Models + + +--- + +## Table of Contents + +- [Overview](#overview) +- [High-Level 
Approach](#high-level-approach) +- [Example](#example) +- [Current Limitations](#current-limitations) + +--- + +### Overview + +Prefix caching in the context of kv-cache management is a useful optimization for avoiding redundant computations. The main idea is that we store portions of the kv-cache from processed requests, so that we can reuse them if incoming requests have the same prefix as previous requests. + +vLLM manages the kv-cache as blocks, which represent a span of tokens of a fixed length. Blocks are hashable by the content that they contain, which typically means the tokens within the span, but also could be influenced by other factors, e.g., LoRA and multimodal data. + +vLLM implements automatic prefix caching for managing its kv-cache, which is best understood by reading the design document [here](https://docs.vllm.ai/en/latest/design/prefix_caching/). vLLM-Omni builds on top of the prefix caching mechanism in a noninvasive way to allow caching between stages in Omni pipelines. This typically means for a given stage we aim to support caching for the following: + +- The last hidden states produced by the stage +- Model / stage specific multimodal data + +### High-Level Approach +!!! note "Note 1" + Prior to reading this section, it's recommended to take a look at the design documents in vLLM for [Automatic Prefix Caching](https://docs.vllm.ai/en/latest/features/automatic_prefix_caching/), which will make some of the concepts more clear. + +The main focus of vLLM-Omni's approach to prefix caching stage outputs is to build on vLLM's prefix caching in the least invasive way possible while minimizing impact for cache misses, and consuming a minimal amount of GPU memory. To understand the implementation, there are a few important things to note: + +- Between stages, device tensors are generally moved to CPU; this is important since we're just caching the outputs of stages, so it is okay to keep the entire cache on the CPU. 
+ +- For a tensor to be considered cacheable, the first dimension (currently) needs to be the same as the token count, as it allows us to reuse block/slot mappings for our externally maintained tensor caches. + +With this in mind, consider the set of blocks in a 2D layout, where the row represents the index of blocks being considered, and the columns represent the slots corresponding to tokens within each block. Since we know the `num_blocks` and `block_size` from our kv cache config, if we want to cache a tensor with feature size `D`, we can preallocate a CPU tensor of size `(num_blocks, block_size, D)`, and use the same block index and slot mapping to retrieve the corresponding feature vector. + + +### Example +!!! note "Note 2" + Prefix caching in vLLM Omni currently is only supported on AutoRegressive stages with one kv-cache group. + +The way in which vLLM Omni ties into vLLM's prefix caching is best understood by example. Say that we have the following: + +- `num_blocks=8` +- `block_size=4` +- `hidden_size=2` +- A stage specific multimodal output `mm_feature` with feature dimension `16` + +The prefix cache flow is then outlined below. + +1. When the model is initialized, we can determine the `hidden_size` from the `ModelConfig`, and allocate a cache of size `(num_blocks, block_size, hidden_size)`. + +2. Say we process the request `The quick brown fox was tired and slept beneath the shady tree`, which is 12 tokens and evenly divides into 3 blocks as shown below. + +``` + [ The quick brown fox ] [ was tired and slept ] [beneath the shady tree ] +Block 1: |<--- block tokens ---->| +Block 2: |<------- prefix ------>| |<--- block tokens --->| +Block 3: |<------------------ prefix -------------------->| |<--- block tokens ---->| +``` + +When the request processes, we inspect the multimodal outputs and identify the `mm_feature` tensor, which will be of shape `(seq_len, feature_dim)`, i.e., `(12, 16)` in this example. 
We note that the first axis is dependent on the `seq_len` and add a new cache_tensor of shape `(num_blocks, block_size, feature_dim)` to our multimodal cache for tensors. + + +3. If we lay out the cache as a 2D tensor of shape (`num_blocks`, `block_size`), we'll have something like the following: + +``` +0: [ The quick brown fox ] +1: [ was tired and slept ] +2: [beneath the shady tree ] +3: [EMPTY] +... +7: [EMPTY] +``` + +Or, if we flatten it down to 1D, +``` +0: The +1: quick +2: brown +3: fox +... +11: tree +12: [EMPTY] +... +``` + +which we can think of as row indices into the hidden states tensor if we view it as the 2D shape `(num_blocks x block_size, feature_dim)`. That is, the analogous flattened (from 3D -> 2D) mapping of the cache for hidden states becomes the following. +``` +0: [hidden state vector for The] +1: [hidden state vector for quick] +2: [hidden state vector for brown] +3: [hidden state vector for fox] +... +11: [hidden state vector for tree] +12: [EMPTY] +... +``` + +Similarly, for the multimodal outputs cache, the flattened coordinates are the same, but the `mm_feature` maps to vectors of length `16` instead of the hidden size of `2`. + + +4. Now, say that we receive a new request `The quick brown fox jumped over the dog`. + +``` + [ The quick brown fox ] [ jumped over the dog ] +Block 1: |<--- block tokens ---->| +Block 2: |<------- prefix ------>| |<--- block tokens --->| +``` + +Here, we will have a cache hit for `Block 1` which will be detected by vLLM based on the hash of the first block when it's handling the prefix caching on the kv-cache. As a result, when we get the output from the scheduler, we will see that `num_computed_tokens=4` (corresponding to the cached first block), and we only need to process the remaining 4 new tokens in the new prefill. + +Since we have the block indices / slot mappings from the kv cache manager, we can simply mirror the mappings and leverage the same indices for the cached hidden states and multimodal outputs. This allows us to look up the correct tensors from our externally maintained 3D caches.
+ +``` +0: [ The quick brown fox ] < already in the cache +1: [ was tired and slept ] +2: [beneath the shady tree ] +3: [ jumped over the dog ] < added on the second request +4: [EMPTY] +... +7: [EMPTY] +... +``` + +Finally, to pass the full hidden states and multimodal outputs to the next stage, we simply concatenate the cached contents with the corresponding new tensors computed from the current forward call. From a8a7c8b2b04b9e85cedc87dbc4ab1bbf05adc622 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 1 Apr 2026 08:40:37 +0000 Subject: [PATCH 30/38] don't split lists before postprocess Signed-off-by: Alex Brooks --- tests/core/test_prefix_cache.py | 8 ++++++++ vllm_omni/core/prefix_cache.py | 4 +++- .../npu/worker/npu_ar_model_runner.py | 2 +- vllm_omni/utils/mm_outputs.py | 12 ++++++++++-- vllm_omni/worker/gpu_ar_model_runner.py | 14 ++++++++++++-- vllm_omni/worker/gpu_model_runner.py | 19 +++++++++++++++++-- 6 files changed, 51 insertions(+), 8 deletions(-) diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py index 3e20ecfe685..7b24e88d92b 100644 --- a/tests/core/test_prefix_cache.py +++ b/tests/core/test_prefix_cache.py @@ -288,6 +288,9 @@ def test_get_merged_multimodal_outputs(feat_dims): ) # We also want to make sure passthrough data (outside of our keys) isn't dropped new_mm_outputs["passthrough_data"] = "Something else" + # Lists are a special case because we can't split them yet if we want to match + # the nonprefix cache behavior, because this runs before post process. 
+ new_mm_outputs["passthrough_list"] = ["should", "not", "split"] input_batch = MockInputBatch(num_computed_tokens_cpu=torch.Tensor([orig_num_tokens_unpadded, 0])) @@ -304,6 +307,7 @@ def test_get_merged_multimodal_outputs(feat_dims): # Ensure the passthrough data wasn't dropped assert "passthrough_data" in merged_mm_outputs + assert "passthrough_list" in merged_mm_outputs for mm_key, mm_output in merged_mm_outputs.items(): # Ensure passthrough data is just forwarded normally and not duplicated @@ -313,6 +317,10 @@ def test_get_merged_multimodal_outputs(feat_dims): assert mm_key not in cache.mm_cache_keys assert new_mm_outputs[mm_key] == mm_output["req1"] assert new_mm_outputs[mm_key] == mm_output["req2"] + elif mm_key == "passthrough_list": + assert mm_key not in cache.mm_cache_keys + assert new_mm_outputs[mm_key] == mm_output["req1"] + assert new_mm_outputs[mm_key] == mm_output["req2"] else: assert mm_key in cache.mm_cache_keys curr_feat_dim = feature_dims[mm_key] diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 536da78b5dd..91085c4cbf9 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -141,7 +141,9 @@ def _coerce_to_payload_dict( req_idx = input_batch.req_id_to_index[req_id] start = query_start_loc[req_idx] end = start + num_scheduled_tokens[req_id] - elem_dict[req_id] = to_payload_element(element, req_idx, start=start, end=end, seq_len=None) + elem_dict[req_id] = to_payload_element( + element, req_idx, start=start, end=end, pass_lists_through=True, seq_len=None + ) return elem_dict def get_merged_multimodal_states( diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py index fb5c1cf1367..ffb997048bd 100644 --- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py @@ -639,7 +639,7 @@ def propose_draft_token_ids(sampled_token_ids): ) 
self._process_additional_information_updates( - hidden_states, num_scheduled_tokens_np, scheduler_output + hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output ) # Pre-copy multimodal tensors to CPU once (not per-request) to avoid diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py index 884adb72e84..bf314b84d2b 100644 --- a/vllm_omni/utils/mm_outputs.py +++ b/vllm_omni/utils/mm_outputs.py @@ -50,7 +50,9 @@ def build_mm_cpu(multimodal_outputs: dict) -> dict[str, object]: return mm_cpu -def to_payload_element(element: object, idx: int, start: int, end: int, seq_len: int | None = None): +def to_payload_element( + element: object, idx: int, start: int, end: int, pass_lists_through: bool = False, seq_len: int | None = None +): """Build an mm payload element corresponding to one request index from an element containing 0 or more CPU tensors. @@ -59,6 +61,10 @@ def to_payload_element(element: object, idx: int, start: int, end: int, seq_len: idx: The index of the request. start: The start index corresponding to the request idx. end: The end index corresponding to the request idx. + pass_lists_through: bool Whether or not lists should be treated as + passthrough data; this should be False in normal cases, but True + if we need to avoid splitting nonempty lists prior to calling + postprocess, which is the case for prefix cache. seq_len: Optional sequence length (i.e., dim 0 of hidden states). 
This should be set to None in the prefix caching case, because the condition that would be executed here is the same as the @@ -73,8 +79,10 @@ def to_payload_element(element: object, idx: int, start: int, end: int, seq_len: elif isinstance(element, dict): return {sk: sv[start:end].contiguous() for sk, sv in element.items()} elif isinstance(element, list): + # For lists, clone tensors to avoid cross-request aliasing + if pass_lists_through: + return [elem.clone() if isinstance(elem, torch.Tensor) else elem for elem in element] element = element[idx] if idx < len(element) else element[0] - # Clone tensors to avoid cross-request aliasing if isinstance(element, torch.Tensor): element = element.clone() return element diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index b1eeec0fcfb..d6c7c0eac10 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -838,9 +838,11 @@ def propose_draft_token_ids(sampled_token_ids): self._process_additional_information_updates( hidden_states, + multimodal_outputs, num_scheduled_tokens_np, scheduler_output, combined_hidden_states, + combined_multimodal_outputs, ) pooler_output: list[dict[str, object]] = [] @@ -865,9 +867,16 @@ def propose_draft_token_ids(sampled_token_ids): if combined_multimodal_outputs: # Prefix cache enabled; all items have already been processed # and split apart for each request as needed, and all tensors - # have already been detached to the CPU. + # have already been detached to the CPU. The only exception is + # lists, which we keep as passthrough data for consistent behavior + # in postprocess. 
for mm_key in combined_multimodal_outputs.keys(): - mm_payload[mm_key] = combined_multimodal_outputs[mm_key][rid] + value = combined_multimodal_outputs[mm_key][rid] + if isinstance(value, list): + mm_payload[mm_key] = value[idx] if idx < len(value) else value[0] + else: + mm_payload[mm_key] = value + else: # Prefix cache disabled; we still need to process the data for mm_key, mm_val in mm_cpu.items(): @@ -876,6 +885,7 @@ def propose_draft_token_ids(sampled_token_ids): idx=idx, start=start, end=end, + pass_lists_through=False, seq_len=seq_len, ) payload.update(mm_payload) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 2b6c5d59ab2..de78011c75a 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -1032,9 +1032,11 @@ def _build_model_kwargs_extra(self) -> dict: def _process_additional_information_updates( self, hidden_states: torch.Tensor, + multimodal_outputs: object, num_scheduled_tokens_np: np.ndarray, scheduler_output: "SchedulerOutput", combined_hidden_states: dict[str, torch.Tensor] | None = None, + combined_multimodal_outputs: dict[str, object] | None = None, ) -> None: """Process model-provided per-request updates and merge into model_intermediate_buffer.""" try: @@ -1043,7 +1045,8 @@ def _process_additional_information_updates( if hasattr(self.model, "has_postprocess") and self.model.has_postprocess: for req_index, req_id in enumerate(self.input_batch.req_ids): req_infos = self.model_intermediate_buffer.get(req_id, {}) - if combined_hidden_states and req_id in combined_hidden_states: + if combined_hidden_states: + # Combined hidden states contains all hidden states for every request hidden_states_slice = combined_hidden_states[req_id] else: start_offset = int(self.query_start_loc.cpu[req_index]) @@ -1051,7 +1054,19 @@ def _process_additional_information_updates( s, e = start_offset, start_offset + sched_tokens # only consider to store data into update dict. 
hidden_states_slice = hidden_states[s:e] - update_dict = self.model.postprocess(hidden_states_slice, **req_infos) + + if combined_multimodal_outputs: + # NOTE this is a bit ugly, but the mm data is structured as a list of + # keys mapping to request IDs, and if enabled, we will always have all + # request IDs in every subdict, including for cache misses. + mm_out = {k: v[req_id] for k, v in combined_multimodal_outputs.items()} + else: + mm_out = multimodal_outputs + update_dict = self.model.postprocess( + hidden_states_slice, + multimodal_outputs=mm_out, + **req_infos, + ) self._update_intermediate_buffer(req_id, update_dict) except Exception as e: logger.error(f"Error merging for requests:{self.input_batch.req_ids} additional information update: {e}") From 0e57316bcf1825d04097b39a32255037571885d5 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 1 Apr 2026 16:50:44 +0000 Subject: [PATCH 31/38] dont always require all mm cache keys Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 91085c4cbf9..39e2dfb0512 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -117,11 +117,11 @@ def update_omni_tensor_prefix_cache( # Do the same for the stage's cached multimodal outputs if multimodal_outputs is not None: for mm_out_key, mm_cache in self.mm_outputs_cache.items(): - assert mm_out_key in multimodal_outputs - mm_state = multimodal_outputs[mm_out_key] - mm_state = OmniTensorPrefixCache._coerce_to_cpu_tensor(mm_state) - flat_cache = mm_cache.view(-1, mm_cache.shape[-1]) - flat_cache[unpadded_slot_mapping] = mm_state[:num_tokens_unpadded] + if mm_out_key in multimodal_outputs: + mm_state = multimodal_outputs[mm_out_key] + mm_state = OmniTensorPrefixCache._coerce_to_cpu_tensor(mm_state) + flat_cache = mm_cache.view(-1, mm_cache.shape[-1]) + flat_cache[unpadded_slot_mapping] = 
mm_state[:num_tokens_unpadded] logger.debug("Writing to mm output cache for %s tokens", num_tokens_unpadded) def _coerce_to_payload_dict( From 46635d14bf54c1c05fbfda356b484dc41ff33982 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 1 Apr 2026 16:58:40 +0000 Subject: [PATCH 32/38] minor Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 39e2dfb0512..17d81a046ba 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -102,9 +102,6 @@ def update_omni_tensor_prefix_cache( slot_mapping: torch.Tensor, ): """Updates the hidden cache state for the provided hidden states and multimodal outputs.""" - if hidden_states is not None: - hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states) - unpadded_slot_mapping = slot_mapping[:num_tokens_unpadded] if hidden_states is not None: # Ensure that hidden states are on the CPU @@ -155,12 +152,6 @@ def get_merged_multimodal_states( ): """Get the merged multimodal states if hidden state prefix caching is enabled.""" combined_multimodal_outputs = {} - if self.mm_cache_keys is None and multimodal_outputs: - logger.warning( - " A model stage produced multimodal outputs, but has no defined mm_cache_keys; " - " this probably means that prefix caching is not fully supported for all stages " - "in this model" - ) # First get the prefix cached tensors for mm_key in self.mm_cache_keys: combined_multimodal_outputs[mm_key] = self._get_merged_tensors( From c4b9a0caad12d13d9d371bc0d9aed889eedfa82d Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 8 Apr 2026 00:48:16 +0000 Subject: [PATCH 33/38] review changes Signed-off-by: Alex Brooks --- vllm_omni/core/prefix_cache.py | 24 ++++++++++++++++-------- vllm_omni/utils/mm_outputs.py | 2 +- vllm_omni/worker/gpu_ar_model_runner.py | 14 ++++++-------- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git 
a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 17d81a046ba..0ae24aac648 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -103,6 +103,7 @@ def update_omni_tensor_prefix_cache( ): """Updates the hidden cache state for the provided hidden states and multimodal outputs.""" unpadded_slot_mapping = slot_mapping[:num_tokens_unpadded] + if hidden_states is not None: # Ensure that hidden states are on the CPU hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states) @@ -113,6 +114,12 @@ def update_omni_tensor_prefix_cache( # Do the same for the stage's cached multimodal outputs if multimodal_outputs is not None: + # If we haven't initialized the keys already, do it now + self.maybe_init_missing_mm_cache_keys( + multimodal_outputs, + seq_len=num_tokens_unpadded, + ) + for mm_out_key, mm_cache in self.mm_outputs_cache.items(): if mm_out_key in multimodal_outputs: mm_state = multimodal_outputs[mm_out_key] @@ -152,15 +159,16 @@ def get_merged_multimodal_states( ): """Get the merged multimodal states if hidden state prefix caching is enabled.""" combined_multimodal_outputs = {} - # First get the prefix cached tensors + # First get the prefix cached tensors that are present in the mm data for mm_key in self.mm_cache_keys: - combined_multimodal_outputs[mm_key] = self._get_merged_tensors( - query_start_loc=query_start_loc, - input_batch=input_batch, - cache=self.mm_outputs_cache[mm_key], - hidden_states=multimodal_outputs[mm_key], - num_scheduled_tokens=num_scheduled_tokens, - ) + if mm_key in multimodal_outputs: + combined_multimodal_outputs[mm_key] = self._get_merged_tensors( + query_start_loc=query_start_loc, + input_batch=input_batch, + cache=self.mm_outputs_cache[mm_key], + hidden_states=multimodal_outputs[mm_key], + num_scheduled_tokens=num_scheduled_tokens, + ) # Then, get everything else (passthrough data); first, convert to CPU # tensors similarly to the non prefix cached path, and then populate 
diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py index bf314b84d2b..045ab3fe2cc 100644 --- a/vllm_omni/utils/mm_outputs.py +++ b/vllm_omni/utils/mm_outputs.py @@ -70,7 +70,7 @@ def to_payload_element( the condition that would be executed here is the same as the criteria for being added to the multimodal outputs cache. """ - # Prefix cache won't hit this case because this is the considition + # Prefix cache won't hit this case because this is the condition # for being a mm_cache_key in the multimodal outputs tensor. if seq_len is not None and isinstance(element, torch.Tensor) and element.shape[0] == seq_len: return element[start:end].contiguous() diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index d6c7c0eac10..cc6c952daa9 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -503,15 +503,13 @@ def execute_model( # Cache hidden states if we've enabled hidden state prefix caching # unless this isn't the last pipeline parallelism rank. if self.omni_prefix_cache is not None and get_pp_group().is_last_rank: - if isinstance(multimodal_outputs, dict): - self.omni_prefix_cache.maybe_init_missing_mm_cache_keys( - multimodal_outputs, - seq_len=hidden_states.shape[0], + # If this happens, it generally means the model is not following the correct + # interface yet and is therefore currently not compatible with prefix cache. 
+ if multimodal_outputs is not None and not isinstance(multimodal_outputs, dict): + logger.warning_once( + "prefix caching expects mm outputs to be a dict, but got %s", + type(multimodal_outputs), ) - else: - # This usually means that the stage doesn't have - # multimodal outputs, so only the hidden states cache - logger.warning_once("Omni prefix caching expects type dict, but got %s", type(multimodal_outputs)) self.omni_prefix_cache.update_omni_tensor_prefix_cache( hidden_states=hidden_states, From 8f0af0af5473739d0eb239d986384be7b9b84c13 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Wed, 8 Apr 2026 04:01:19 +0000 Subject: [PATCH 34/38] check against padded tok counts Signed-off-by: Alex Brooks --- tests/core/test_prefix_cache.py | 8 ++++++-- vllm_omni/core/prefix_cache.py | 25 ++++++++++++++++++++----- vllm_omni/worker/gpu_ar_model_runner.py | 1 + 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py index 7b24e88d92b..c3d8c1ff928 100644 --- a/tests/core/test_prefix_cache.py +++ b/tests/core/test_prefix_cache.py @@ -185,7 +185,8 @@ def fake_get_cached_block_ids(self, req_idx, *args, **kwargs): return torch.tensor([], dtype=torch.long) -def test_get_merged_hidden_states(): +@pytest.mark.parametrize("num_tokens_padded", [None, 16]) +def test_get_merged_hidden_states(num_tokens_padded): """Ensure that hidden states are merged correctly.""" cache = get_omni_pcache() @@ -199,6 +200,7 @@ def test_get_merged_hidden_states(): multimodal_outputs=None, num_tokens_unpadded=orig_num_tokens_unpadded, slot_mapping=orig_slot_mapping, + num_tokens_padded=num_tokens_padded, ) # Say that we have two requests, but only one of them is a cache hit @@ -245,6 +247,7 @@ def test_get_merged_hidden_states(): assert torch.all(req2_merged_states == req2_new_states) +@pytest.mark.parametrize("num_tokens_padded", [None, 16]) @pytest.mark.parametrize( "feat_dims", [ @@ -252,7 +255,7 @@ def 
test_get_merged_hidden_states(): {"foo": 100, "bar": 50, "baz": 10}, ], ) -def test_get_merged_multimodal_outputs(feat_dims): +def test_get_merged_multimodal_outputs(feat_dims, num_tokens_padded): cache = get_omni_pcache_with_mm_tensors(feat_dims, seq_len=DEFAULT_SEQ_LEN) orig_num_tokens_unpadded = 8 @@ -268,6 +271,7 @@ def test_get_merged_multimodal_outputs(feat_dims): multimodal_outputs=orig_mm_outputs, num_tokens_unpadded=orig_num_tokens_unpadded, slot_mapping=orig_slot_mapping, + num_tokens_padded=num_tokens_padded, ) # Similar to hs test- say that we have two requests, but only one of them is a cache hit diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py index 0ae24aac648..69e7346c4c1 100644 --- a/vllm_omni/core/prefix_cache.py +++ b/vllm_omni/core/prefix_cache.py @@ -100,32 +100,47 @@ def update_omni_tensor_prefix_cache( multimodal_outputs: dict[str, torch.Tensor] | None, num_tokens_unpadded: int, slot_mapping: torch.Tensor, + num_tokens_padded: int | None = None, ): - """Updates the hidden cache state for the provided hidden states and multimodal outputs.""" + """Updates the hidden cache state for the provided hidden states and multimodal outputs. 
+ + Args: + hidden_states: Hidden states tensor to cache (if any) + multimodal_outputs: Multimodal dict whose tensors may be cached + num_tokens_unpadded: Number of tokens without padding + slot_mapping: Slot mapping for the input sequence + num_tokens_padded: Total number of tokens including padding + """ unpadded_slot_mapping = slot_mapping[:num_tokens_unpadded] + if num_tokens_padded is None: + num_tokens_padded = num_tokens_unpadded if hidden_states is not None: + # Slice to unpadded portion before caching + hidden_states = hidden_states[:num_tokens_unpadded] # Ensure that hidden states are on the CPU hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states) # View the cache as 2D so that we can treat our slots as row indices flat_cache = self.hidden_states_cache.view(-1, self.hidden_states_cache.shape[-1]) - flat_cache[unpadded_slot_mapping] = hidden_states[:num_tokens_unpadded] + flat_cache[unpadded_slot_mapping] = hidden_states logger.debug("Writing to hidden states for %s tokens", num_tokens_unpadded) # Do the same for the stage's cached multimodal outputs if multimodal_outputs is not None: # If we haven't initialized the keys already, do it now + # We check against the padded token count since we haven't sliced yet self.maybe_init_missing_mm_cache_keys( multimodal_outputs, - seq_len=num_tokens_unpadded, + seq_len=num_tokens_padded, ) for mm_out_key, mm_cache in self.mm_outputs_cache.items(): if mm_out_key in multimodal_outputs: - mm_state = multimodal_outputs[mm_out_key] + # Slice to unpadded portion before caching + mm_state = multimodal_outputs[mm_out_key][:num_tokens_unpadded] mm_state = OmniTensorPrefixCache._coerce_to_cpu_tensor(mm_state) flat_cache = mm_cache.view(-1, mm_cache.shape[-1]) - flat_cache[unpadded_slot_mapping] = mm_state[:num_tokens_unpadded] + flat_cache[unpadded_slot_mapping] = mm_state logger.debug("Writing to mm output cache for %s tokens", num_tokens_unpadded) def _coerce_to_payload_dict( diff --git 
a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index cc6c952daa9..b1caaf8e447 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -516,6 +516,7 @@ def execute_model( multimodal_outputs=multimodal_outputs, num_tokens_unpadded=num_tokens_unpadded, slot_mapping=self.input_batch.block_table[0].slot_mapping.cpu, + num_tokens_padded=num_tokens_padded, ) if not self.broadcast_pp_output: From 4e37c10ef80a56bf024eb19b8fca85bd630c50b3 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 10 Apr 2026 03:52:26 +0000 Subject: [PATCH 35/38] add clarity in encoder caching <-> prefix caching Signed-off-by: Alex Brooks --- docs/design/feature/prefix_caching.md | 53 +++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/docs/design/feature/prefix_caching.md b/docs/design/feature/prefix_caching.md index 15c9b35bb85..ebad8b69106 100644 --- a/docs/design/feature/prefix_caching.md +++ b/docs/design/feature/prefix_caching.md @@ -8,7 +8,7 @@ - [Overview](#overview) - [High-Level Approach](#high-level-approach) - [Example](#example) -- [Current Limitations](#current-limitations) +- [What About Multimodal Inputs?](#what-about-multimodal-inputs) --- @@ -23,29 +23,32 @@ vLLM implements automatic prefix caching for managing its kv-cache, which is bes - The last hidden states produced by the stage - Model / stage specific multimodal data -### High-Level Approach !!! note "Note 1" + This document describes vLLM-Omni's mechanism for caching tensor outputs that are meant to be passed between stages, when requests have common prefixes, similar to the way in which vLLM has prefix caching for the kv-cache. This works in conjunction with vLLM's multimodal encoder caching, but is distinct. See the final section for a concrete example for how they tie together in practice. + +### High-Level Approach +!!! 
note "Note 2" Prior to reading this section, it's recommended to take a look at the design documents in vLLM for [Automatic Prefix Caching](https://docs.vllm.ai/en/latest/features/automatic_prefix_caching/), which will make some of the concepts more clear. The main focus of vLLM-Omni's approach to prefix caching stage outputs is to build on vLLM's prefix caching in the least invasive way possible while minimizing impact for cache misses, and consuming a minimal amount of GPU memory. To understand the implementation, there are a few important things to note: - Between stages, device tensors are generally moved to CPU; this is important since we're just caching the outputs of stages, so it is okay to keep the entire cache on the CPU. -- For a tensor to be considered cacheable, the first dimension (currently) needs to be the same as the token count, as it allows us to reuse block/slot mappings for our externally maintained tensor caches. +- For a tensor to be considered cacheable, the first dimension (currently) needs to be the same as the token count, as it allows us to reuse block/slot mappings for our externally maintained tensor caches. This allows us to dynamically discover the tensors to be marked as cacheable outputs in each Omni model without having to explicitly specify cacheable output field names in every model. With this in mind, consider the set of blocks in a 2D layout, where the row represents the index of blocks being considered, and the columns represent the slots corresponding to tokens within each block. Since we know the `num_blocks` and `block_size` from our kv cache config, if we want to cache a tensor with feature size `D`, we can preallocate a CPU tensor of size `(num_blocks, block_size, D)`, and use the same block index and slot mapping to retrieve the corresponding feature vector. ### Example -!!! note "Note 2" - Prefix caching in vLLM Omni currently is only supported on AutoRegressive stages with one kv-cache group. +!!! 
note "Note 3" + Prefix caching in vLLM-Omni currently is only supported on AutoRegressive stages with one kv-cache group. It can be enabled/disabled per-stage via the `enable_prefix_caching` parameter in the model's stage config. -The way in which vLLM Omni ties into vLLM's prefix caching is best understood by example. Say that we have the following: +The way in which vLLM-Omni ties into vLLM's prefix caching is best understood by example. Say that we have the following: - `num_blocks=8` - `block_size=4` - `hidden_size=2` -- A stage specific multimodal output `mm_feature` with feature dimension `16` +- A stage specific multimodal output tensor named `mm_feature` with feature dimension `16` The prefix cache flow is then outlined below. @@ -98,7 +101,7 @@ which we can think of as row indices into the hidden states tensor if we view it ... ``` -Similarly, for the multimodal outputs cache, the flattened coordinates are the same, but the `mm_feature` maps to vectors of length `16` instead of the hidden size of `2`. +Similarly, for the multimodal outputs cache, the flattened coordinates are the same, but the `mm_feature` maps to vectors of length `16` instead of the hidden size of `2`. Note that in practice, we may have multiple multimodal output tensors per forward pass, which may have different names and different feature dimensions. 4. Now, say that we receive a new request `The quick brown fox jumped over the dog`. @@ -125,3 +128,37 @@ Since we have the block indices / slot mappings from the kv cache manager, we ca ``` Finally, to pass the full hidden states and multimodal outputs to the next stage, we simply concatenate the cached contents with the corresponding new tensors computed from the current forward call. + + +### What About Multimodal Inputs? 
+It's also useful to consider how Omni prefix caching is handled when we have multimodal inputs that don't cleanly end on block boundaries, as well as how this works with multimodal encoder caching in vLLM. For example: + +``` + [ Im0 Im1 Im2 Im3 ] [ Im4 Im5 foo ] +Block 1: |<--- block tokens ---->| +Block 2: |<------- prefix ------>| |<--- block tokens --->| +``` + +In this case, only `Block 1` will have outputs stored in the prefix tensor cache, because vLLM does not store partial blocks. This may appear to be a problem at first glance, because the multimodal input is fragmented across a new block that wasn't cached. + +In reality, this isn't a big problem for correctness, because vLLM also maintains an encoder cache for multimodal inputs. In other words, after the first pass, we'll have the following: + +- The Block 1 hash, which is used for prefix caching +- The hash describing the image data starting at position 0 and with length 6 +- In vLLM's encoder cache, a mapping from the image hash above to the encoder output + + +To understand what happens, say we get the following input as a second request: +``` + [ Im0 Im1 Im2 Im3 ] [ Im4 Im5 bar baz ] +Block 1: |<--- block tokens ---->| +Block 2: |<------- prefix ------>| |<--- block tokens --->| +``` + +First, the scheduler will check for a prefix cache hit, which we will see on `Block 1`. As a result, we will have 4 tokens marked as precomputed, and only see the remaining 4 tokens in the following prefill. + +Because we have multimodal data in a scheduled span that isn't fully precomputed, we still need to call the visual encoder. However, since we have the image hash and encoder cache, we will retrieve the encoder outputs for `Im4` and `Im5` as we create the multimodal embeddings.
+ +When we pass our multimodal tensors to the language model component in the same stage, we'll then expect the same outputs, because the prefix caching behaviors in vLLM-Omni / vLLM match, so the LLM will use vLLM's KV cache manager's prefix caching to correctly handle the attention information for `Block 1` while calculating the outputs for `Block 2`, giving us the correct results for processing `Block 2` with the context of `Block 1`. + +Finally, we look up the output hidden states/multimodal tensors corresponding to the prefix cache hit `Block 1` and concatenate it with the forward pass result to get the final result, which is expected to be identical to the full hidden states when prefix caching is disabled. From b2aebe57438ff182367bf07335a19fa60df0fba5 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 10 Apr 2026 04:53:33 +0000 Subject: [PATCH 36/38] don't send None in payloads Signed-off-by: Alex Brooks --- vllm_omni/utils/mm_outputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py index 045ab3fe2cc..66d4e6ffe04 100644 --- a/vllm_omni/utils/mm_outputs.py +++ b/vllm_omni/utils/mm_outputs.py @@ -45,7 +45,7 @@ def build_mm_cpu(multimodal_outputs: dict) -> dict[str, object]: else: cpu_list.append(elem) mm_cpu[k] = cpu_list - else: + elif v is not None: mm_cpu[k] = v return mm_cpu From 1c873d37264cdc4ac4b5376342fb73d762d15ba2 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Fri, 10 Apr 2026 07:49:22 +0000 Subject: [PATCH 37/38] check output in prefix cache test Signed-off-by: Alex Brooks --- tests/e2e/online_serving/test_qwen3_omni.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py index 14cb4c5df17..c05f8f50674 100644 --- a/tests/e2e/online_serving/test_qwen3_omni.py +++ b/tests/e2e/online_serving/test_qwen3_omni.py @@ -103,6 +103,7 @@ def 
get_prompt(prompt_type="text_only"): prompts = { "text_only": "What is the capital of China? Answer in 20 words.", "mix": "What is recited in the audio? What is in this image? Describe the video briefly.", + "text_image": "What color are the squares in this image?", } return prompts.get(prompt_type, prompts["text_only"]) @@ -184,10 +185,17 @@ def test_text_to_text_001(omni_server, openai_client) -> None: @pytest.mark.parametrize("omni_server", prefix_test_params, indirect=True) def test_thinker_prefix_caching(omni_server, openai_client) -> None: """ - Test thinker supports prefix caching by sending two identical - requests and checking the number of cached tokens. + Test thinker prefix caching by sending identical requests with an image (i.e., + a large shared prefix) and verifying that the second request uses cached tokens + & produces the same output. """ - messages = dummy_messages_from_mix_data(system_prompt=get_system_prompt(), content_text=get_prompt()) + image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}" + messages = dummy_messages_from_mix_data( + system_prompt=get_system_prompt(), + image_data_url=image_data_url, + content_text=get_prompt("text_image"), + ) + request_config = { "model": omni_server.model, "messages": messages, @@ -201,4 +209,8 @@ def test_thinker_prefix_caching(omni_server, openai_client) -> None: assert response_1.success assert response_2.success assert response_2.cached_tokens is not None + # We should cache the vast majority of the prompt (image + up to last full block), + # and set seed in the CI config, so the second request should give an identical + # response for the generated input image, even if we use dummy weights assert response_2.cached_tokens > 0 + assert response_1.text_content == response_2.text_content From 7593e4b801490404829d2b74dd740e86fa06252e Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 14 Apr 2026 22:44:07 +0000 Subject: [PATCH 38/38] refactor pc to separate methods 
Signed-off-by: Alex Brooks --- vllm_omni/worker/gpu_ar_model_runner.py | 123 ++++++++++++++---------- 1 file changed, 72 insertions(+), 51 deletions(-) diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py index b1caaf8e447..f37b2224efb 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -202,28 +202,62 @@ def _capture_talker_mtp_graphs(self) -> None: finally: set_cudagraph_capturing_enabled(False) - def update_hidden_state_cache(self, hidden_states, multimodal_outputs, num_tokens_unpadded): - """Updates the hidden cache state for prefix caching from - the current model's execution for the unpadded tokens. + def _maybe_update_prefix_cache( + self, + hidden_states: torch.Tensor, + multimodal_outputs: dict, + num_tokens_unpadded: int, + num_tokens_padded: int, + ): + """If prefix caching is enabled and it's the last pipeline parallelism rank, + write the hidden states & multimodal outputs into the prefix cache based + on our batch slot mappings. """ - assert self.hidden_state_cache is not None - slot_mapping = self.input_batch.block_table[0].slot_mapping.gpu[:num_tokens_unpadded] - # View the cache as 2D so that we can treat our slots as row indices - flat_cache = self.hidden_state_cache.view(-1, self.hidden_state_cache.shape[-1]) - flat_cache[slot_mapping] = hidden_states[:num_tokens_unpadded] - logger.info(f"[HS Cache WRITE] tokens={num_tokens_unpadded}") - - # Do the same for the cached multimodal outputs for this stage; - # for now we assume that all of the multimodal outputs cached - # are exactly the same size as the hidden states. - # TODO (Alex) make this more flexible.
- if self.mm_outputs_cache is not None: - for mm_out_key, mm_cache in self.mm_outputs_cache.items(): - assert mm_out_key in multimodal_outputs - mm_state = multimodal_outputs[mm_out_key] - flat_cache = mm_cache.view(-1, mm_cache.shape[-1]) - flat_cache[slot_mapping] = mm_state[:num_tokens_unpadded] - logger.info(f"[multimodal output Cache WRITE] tokens={num_tokens_unpadded}") + # Cache hidden states if we've enabled hidden state prefix caching + # unless this isn't the last pipeline parallelism rank. + if self.omni_prefix_cache is not None and get_pp_group().is_last_rank: + # If this happens, it generally means the model is not following the correct + # interface yet and is therefore currently not compatible with prefix cache. + if multimodal_outputs is not None and not isinstance(multimodal_outputs, dict): + logger.warning_once( + "prefix caching expects mm outputs to be a dict, but got %s", + type(multimodal_outputs), + ) + + self.omni_prefix_cache.update_omni_tensor_prefix_cache( + hidden_states=hidden_states, + multimodal_outputs=multimodal_outputs, + num_tokens_unpadded=num_tokens_unpadded, + slot_mapping=self.input_batch.block_table[0].slot_mapping.cpu, + num_tokens_padded=num_tokens_padded, + ) + + def _maybe_get_combined_prefix_cache_tensors( + self, + hidden_states: torch.Tensor, + multimodal_outputs: dict, + num_scheduled_tokens: dict[str, int], + ) -> tuple[dict[str, torch.Tensor] | None, dict | None]: + """If prefix caching is enabled, extract the merged hidden states and multimodal outputs for + all requests in the batch (including those that aren't a hit on Prefix cache). + """ + # Prior to applying the post-processing func, extract + # the prefix cached hidden states and multimodal states. 
+ combined_hidden_states, combined_multimodal_outputs = None, None + if self.omni_prefix_cache is not None: + combined_hidden_states = self.omni_prefix_cache.get_merged_hidden_states( + query_start_loc=self.query_start_loc.cpu, + input_batch=self.input_batch, + hidden_states=hidden_states, + num_scheduled_tokens=num_scheduled_tokens, + ) + combined_multimodal_outputs = self.omni_prefix_cache.get_merged_multimodal_states( + query_start_loc=self.query_start_loc.cpu, + input_batch=self.input_batch, + multimodal_outputs=multimodal_outputs, + num_scheduled_tokens=num_scheduled_tokens, + ) + return combined_hidden_states, combined_multimodal_outputs @torch.inference_mode() def execute_model( @@ -500,24 +534,14 @@ def execute_model( hidden_states, multimodal_outputs = self.extract_multimodal_outputs(model_output) - # Cache hidden states if we've enabled hidden state prefix caching - # unless this isn't the last pipeline parallelism rank. - if self.omni_prefix_cache is not None and get_pp_group().is_last_rank: - # If this happens, it generally means the model is not following the correct - # interface yet and is therefore currently not compatible with prefix cache. - if multimodal_outputs is not None and not isinstance(multimodal_outputs, dict): - logger.warning_once( - "prefix caching expects mm outputs to be a dict, but got %s", - type(multimodal_outputs), - ) - - self.omni_prefix_cache.update_omni_tensor_prefix_cache( - hidden_states=hidden_states, - multimodal_outputs=multimodal_outputs, - num_tokens_unpadded=num_tokens_unpadded, - slot_mapping=self.input_batch.block_table[0].slot_mapping.cpu, - num_tokens_padded=num_tokens_padded, - ) + # Cache hidden states & multimodal outputs if we've enabled hidden state + # prefix caching unless this isn't the last pipeline parallelism rank. 
+ self._maybe_update_prefix_cache( + hidden_states=hidden_states, + multimodal_outputs=multimodal_outputs, + num_tokens_unpadded=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded, + ) if not self.broadcast_pp_output: # Common case. @@ -820,19 +844,16 @@ def propose_draft_token_ids(sampled_token_ids): # Prior to applying the post-processing func, extract # the prefix cached hidden states and multimodal states. if self.omni_prefix_cache is not None: - combined_hidden_states = self.omni_prefix_cache.get_merged_hidden_states( - query_start_loc=self.query_start_loc.cpu, - input_batch=self.input_batch, - hidden_states=hidden_states, - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - ) - combined_multimodal_outputs = self.omni_prefix_cache.get_merged_multimodal_states( - query_start_loc=self.query_start_loc.cpu, - input_batch=self.input_batch, - multimodal_outputs=multimodal_outputs, - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + ( + combined_hidden_states, + combined_multimodal_outputs, + ) = self._maybe_get_combined_prefix_cache_tensors( + hidden_states, + multimodal_outputs, + scheduler_output.num_scheduled_tokens, ) - else: + # Otherwise we don't have the mm CPU data yet, so we still need to build it + if self.omni_prefix_cache is None: mm_cpu = build_mm_cpu(multimodal_outputs) self._process_additional_information_updates(