[feature] Hidden State Prefix Caching #2164

Merged
Gaohan123 merged 39 commits into vllm-project:main from alex-jw-brooks:omni_hs_cache
Apr 15, 2026

Conversation

@alex-jw-brooks
Contributor

Purpose

For #1184

This takes a first pass at enabling prefix caching for hidden states. The core idea is to preallocate an external hidden state cache that mirrors the prefix cache maintained by upstream vLLM. This is accomplished by:

  • Saving request IDs for new requests that are cache hits (i.e., newly scheduled requests with > 0 computed tokens)
  • Calculating the number of cached blocks for each marked request
  • Reading & writing the cache of hidden states with the same mappings. In other words, the token mappings are essentially row indices into the hidden state cache (a rough sketch of this is shown below).
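
To make the mirroring idea concrete, here is a rough, hypothetical sketch; the class and attribute names (HiddenStateCache, block_size, slot_mapping) are illustrative and not the actual vllm_omni implementation. The hidden state cache is just a tensor with one row per KV-cache slot, written and read with the same block/slot mapping the scheduler already produces for the KV cache.

import torch

class HiddenStateCache:
    """Illustrative external cache mirroring vLLM's prefix cache layout."""

    def __init__(self, num_blocks: int, block_size: int, hidden_size: int):
        self.block_size = block_size
        # Row i corresponds to slot i (= block_id * block_size + offset),
        # i.e. the same mapping the upstream KV cache uses.
        self.cache = torch.zeros(num_blocks * block_size, hidden_size)

    def write(self, slot_mapping: torch.Tensor, hidden_states: torch.Tensor) -> None:
        # Store hidden states for newly computed tokens at the slots the
        # scheduler assigned to those tokens.
        self.cache[slot_mapping] = hidden_states

    def read(self, block_ids: list[int], num_cached_tokens: int) -> torch.Tensor:
        # For a request whose first num_cached_tokens tokens are prefix-cache
        # hits, gather the matching rows back out block by block.
        slots = torch.cat(
            [torch.arange(b * self.block_size, (b + 1) * self.block_size) for b in block_ids]
        )[:num_cached_tokens]
        return self.cache[slots]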

The same can be done for multimodal tensor outputs that need to be passed through, although this part is pretty hacked together for Qwen Omni at the moment and I'm still cleaning it up. The main constraint is that the extra tensors need to be the same length as the input tokens/hidden states; otherwise we won't know how to divide them into blocks.

I'll move this out of draft once there are some tests and the multimodal outputs part is cleaned up (in the next day or two).

Preliminary Results

I ran some benchmarks on the thinker for qwen3 omni (i.e., text only) with & without prefix caching to ensure we see a very large speedup on prefills for large prefix matches; this essentially runs one token generation with text modality for a few cases:

1. Text with a very large common prefix

With no prefix caching:

{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe_no_prefix_cache.yaml",
  "query_type": "text",
  "num_prompts": 20,
  "prefill_time": 3.964723326731473,
  "total_cached_tokens": 0,
  "total_prompt_tokens": 12170,
  "cache_hit_pct": 0.0
}

With prefix caching

{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe.yaml",
  "query_type": "text",
  "num_prompts": 20,
  "prefill_time": 0.45084997406229377,
  "total_cached_tokens": 11552,
  "total_prompt_tokens": 12170,
  "cache_hit_pct": 94.92193919474117
}

2. Image Inputs

Without prefix caching

{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe_no_prefix_cache.yaml",
  "query_type": "image",
  "num_prompts": 20,
  "prefill_time": 9.791648932732642,
  "total_cached_tokens": 0,
  "total_prompt_tokens": 42053,
  "cache_hit_pct": 0.0
}

With prefix caching

{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe.yaml",
  "query_type": "image",
  "num_prompts": 20,
  "prefill_time": 1.149583789985627,
  "total_cached_tokens": 41616,
  "total_prompt_tokens": 42053,
  "cache_hit_pct": 98.96083513661333
}

3. Audio Inputs

Without prefix caching

{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe_no_prefix_cache.yaml",
  "query_type": "audio",
  "num_prompts": 20,
  "prefill_time": 1.7638336326926947,
  "total_cached_tokens": 0,
  "total_prompt_tokens": 5498,
  "cache_hit_pct": 0.0
}

With prefix caching

{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe.yaml",
  "query_type": "audio",
  "num_prompts": 20,
  "prefill_time": 0.2613906688056886,
  "total_cached_tokens": 4848,
  "total_prompt_tokens": 5498,
  "cache_hit_pct": 88.17751909785376
}

Benchmark script for repro:
bench_prefix_caching.py
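
For context, a heavily simplified sketch of what a repro loop like this could look like; the attached bench_prefix_caching.py is the authoritative version, and the llm object, its generate call, and the per-output fields used below are assumptions for illustration only.

import time

def run_prefix_cache_benchmark(llm, prompts, max_tokens: int = 1) -> dict:
    # Time a single pass over the prompts; with max_tokens=1 the wall time is
    # dominated by prefill, so cached prefixes show up directly.
    start = time.perf_counter()
    outputs = llm.generate(prompts, max_tokens=max_tokens)
    prefill_time = time.perf_counter() - start

    # Field names below are placeholders for whatever the engine reports.
    total_prompt_tokens = sum(len(o.prompt_token_ids) for o in outputs)
    total_cached_tokens = sum(o.num_cached_tokens for o in outputs)
    return {
        "num_prompts": len(prompts),
        "prefill_time": prefill_time,
        "total_cached_tokens": total_cached_tokens,
        "total_prompt_tokens": total_prompt_tokens,
        "cache_hit_pct": 100.0 * total_cached_tokens / total_prompt_tokens,
    }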

CC: @tzhouam @amy-why-3459 @LJH-LBJ @Sy0307 @lishunyang12 @NickLucche

@LJH-LBJ
Contributor

LJH-LBJ commented Mar 26, 2026

Was async chunk enabled when you tested?

Collaborator

@tzhouam left a comment


Nice job. Please review the comments.

Comment thread vllm_omni/worker/gpu_ar_model_runner.py Outdated
# for now we assume that all of the multimodal outputs cached
# are exactly the same size as the hidden states.
# TODO (Alex) make this more flexible.
if self.mm_outputs_cache is not None:
Collaborator


One concern is for data parallel: will we retain some interface for cache-aware routing?

Contributor Author


Not at the moment, but I think this is consistent with existing behavior with normal prefix caching in vLLM - from this doc

For online deployments, this balancing can be optimized by taking into account the state of each DP engine - in particular its currently scheduled and waiting (queued) requests, and KV cache state. Each DP engine has an independent KV cache, and the benefit of prefix caching can be maximized by directing prompts intelligently.

The handling is similar in this case: the relationship with scheduled blocks/slots mirrors normal prefix caching, since we reuse the slot mapping for our own cache.

Comment thread vllm_omni/worker/gpu_model_runner.py Outdated
Comment thread vllm_omni/worker/gpu_model_runner.py Outdated
Comment thread vllm_omni/worker/gpu_model_runner.py Outdated
Comment thread vllm_omni/worker/gpu_ar_model_runner.py
Comment thread vllm_omni/worker/gpu_ar_model_runner.py Outdated
@alex-jw-brooks
Contributor Author

Hey @LJH-LBJ, not yet - the config for the run above is below for reproducibility (just change enable_prefix_caching to True/False for the thinker).

Config for first run above
# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
# Stage 0: Thinker (multimodal understanding + text generation)
# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)

async_chunk: false
stage_args:
  - stage_id: 0
    stage_type: llm  # Use llm stage type for AR stages
    runtime:
      devices: "0"
    engine_args:
      model_stage: thinker
      max_num_seqs: 64
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.9
      enforce_eager: false
      trust_remote_code: true
      engine_output_type: latent  # Output hidden states for talker
      distributed_executor_backend: "mp"
      enable_prefix_caching: true
      max_num_batched_tokens: 32768
      hf_config_name: thinker_config
      tensor_parallel_size: 1
      attention_backend: TRITON_ATTN
    final_output: true
    final_output_type: text
    is_comprehension: true
    default_sampling_params:
      temperature: 0.4
      top_p: 0.9
      top_k: 1
      max_tokens: 2048
      seed: 42
      detokenize: True
      repetition_penalty: 1.05

  - stage_id: 1
    stage_type: llm  # Use llm stage type for AR stages
    runtime:
      devices: "1"
    engine_args:
      model_stage: talker
      max_num_seqs: 64
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.6
      enforce_eager: false
      trust_remote_code: true
      engine_output_type: latent  # Output codec codes for code2wav
      enable_prefix_caching: false
      max_num_batched_tokens: 32768
      distributed_executor_backend: "mp"
      hf_config_name: talker_config
      attention_backend: TRITON_ATTN
    engine_input_source: [0]
    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
    # final_output: true
    # final_output_type: text
    default_sampling_params:
      temperature: 0.9
      top_k: 50
      max_tokens: 4096
      seed: 42
      detokenize: False
      repetition_penalty: 1.05
      stop_token_ids: [2150]

  - stage_id: 2
    stage_type: llm  # Use llm stage type for AR stages
    runtime:
      devices: "1"
    engine_args:
      model_stage: code2wav
      max_num_seqs: 32
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: generation
      scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
      enforce_eager: true
      trust_remote_code: true
      async_scheduling: false
      enable_prefix_caching: false
      engine_output_type: audio  # Final output: audio waveform
      gpu_memory_utilization: 0.1
      distributed_executor_backend: "mp"
      max_num_batched_tokens: 1000000
      hf_config_name: thinker_config
    engine_input_source: [1]
    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
    final_output: true
    final_output_type: audio
    default_sampling_params:
      temperature: 0.0
      top_p: 1.0
      top_k: -1
      max_tokens: 65536
      seed: 42
      detokenize: True
      repetition_penalty: 1.1

@LJH-LBJ
Contributor

LJH-LBJ commented Mar 30, 2026

Why does the test only cover the text-only case?

@LJH-LBJ
Contributor

LJH-LBJ commented Mar 30, 2026

I ran the benchmark with your PR at commit 32e8c99; it seems to work well.

vllm serve /workspace/models/Qwen3-Omni-30B-A3B-Instruct --omni --port 46354 --stage-configs-path ./vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml

vllm bench serve   --omni   --dataset-name random-mm   --port 46354   --max-concurrency 10   --model /workspace/models/Qwen3-Omni-30B-A3B-Instruct   --endpoint /v1/chat/completions   --backend openai-chat-omni   --num-prompts 100   --random-input-len 400   --ignore-eos   --percentile-metrics ttft,tpot,itl,e2el,audio_ttfp,audio_rtf   --random-output-len 100   --extra_body '{"modalities": ["text", "audio"]}' --random-prefix-len 200
enable_prefix_caching: false:
============ Serving Benchmark Result ============
Successful requests:                     100       
Failed requests:                         0         
Maximum request concurrency:             10        
Benchmark duration (s):                  258.46    
Request throughput (req/s):              0.39      
Peak concurrent requests:                14.00     
----------------End-to-end Latency----------------
Mean E2EL (ms):                          24740.90  
Median E2EL (ms):                        24658.21  
P99 E2EL (ms):                           29473.40  
================== Text Result ===================
Total input tokens:                      60000     
Total generated tokens:                  8721      
Output token throughput (tok/s):         33.74     
Peak output token throughput (tok/s):    270.00    
Peak concurrent requests:                14.00     
Total Token throughput (tok/s):          265.89    
---------------Time to First Token----------------
Mean TTFT (ms):                          3374.48   
Median TTFT (ms):                        3663.65   
P99 TTFT (ms):                           5713.33   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          75.53     
Median TPOT (ms):                        72.91     
P99 TPOT (ms):                           129.20    
---------------Inter-token Latency----------------
Mean ITL (ms):                           64.01     
Median ITL (ms):                         0.02      
P99 ITL (ms):                            1296.70   
================== Audio Result ==================
Total audio duration generated(s):       2952.11   
Total audio frames generated:            70850730  
Audio throughput(audio duration/s):      11.42     
---------------Time to First Packet---------------
Mean AUDIO_TTFP (ms):                    6100.76   
Median AUDIO_TTFP (ms):                  6146.65   
P99 AUDIO_TTFP (ms):                     8470.25   
-----------------Real Time Factor-----------------
Mean AUDIO_RTF:                          0.84      
Median AUDIO_RTF:                        0.84      
P99 AUDIO_RTF:                           1.02      
==================================================
					


enable_prefix_caching: true:
============ Serving Benchmark Result ============							
Successful requests:                     100       							
Failed requests:                         0         							
Maximum request concurrency:             10        							
Benchmark duration (s):                  225.66    							
Request throughput (req/s):              0.44      							
Peak concurrent requests:                15.00     							
----------------End-to-end Latency----------------							
Mean E2EL (ms):                          21319.82  							
Median E2EL (ms):                        21251.56  							
P99 E2EL (ms):                           25833.46  							
================== Text Result ===================							
Total input tokens:                      60000     							
Total generated tokens:                  8842      							
Output token throughput (tok/s):         39.18     							
Peak output token throughput (tok/s):    240.00    							
Peak concurrent requests:                15.00     							
Total Token throughput (tok/s):          305.07    							
---------------Time to First Token----------------							
Mean TTFT (ms):                          1530.10   							
Median TTFT (ms):                        1265.57   							
P99 TTFT (ms):                           3796.66   							
-----Time per Output Token (excl. 1st token)------							
Mean TPOT (ms):                          67.72     							
Median TPOT (ms):                        65.26     							
P99 TPOT (ms):                           96.90     							
---------------Inter-token Latency----------------							
Mean ITL (ms):                           58.41     							
Median ITL (ms):                         0.02      							
P99 ITL (ms):                            1200.07   							
================== Audio Result ==================							
Total audio duration generated(s):       2959.31   							
Total audio frames generated:            71023530  							
Audio throughput(audio duration/s):      13.11     							
---------------Time to First Packet---------------							
Mean AUDIO_TTFP (ms):                    3812.07   							
Median AUDIO_TTFP (ms):                  3638.03   							
P99 AUDIO_TTFP (ms):                     5903.20   							
-----------------Real Time Factor-----------------							
Mean AUDIO_RTF:                          0.73      							
Median AUDIO_RTF:                        0.72      							
P99 AUDIO_RTF:                           0.91      							
==================================================							

@alex-jw-brooks
Contributor Author

Thanks @LJH-LBJ - still working on tests and cleaning up some edge cases, so I hadn't added the multimodal merging tests yet, but they are there now 🙂

@alex-jw-brooks changed the title from [wip] Hidden State Prefix Caching to Hidden State Prefix Caching on Mar 31, 2026
@alex-jw-brooks marked this pull request as ready for review on March 31, 2026 04:05

@chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 14a314caf8


Comment thread vllm_omni/worker/gpu_model_runner.py Outdated
Comment thread vllm_omni/worker/gpu_ar_model_runner.py Outdated
@alex-jw-brooks
Contributor Author

Hey @LJH-LBJ @tzhouam @amy-why-3459, this is ready for a look when you have a moment.

Current state is:

  • The hidden states / multimodal output caches stay fully on CPU - since everything is coerced to CPU tensors when we create the payloads normally anyway, we can do this without issues 🎉
  • For now, multimodal output keys are determined dynamically - when the model executes and produces its multimodal outputs, keys mapping to tensors whose sequence length is dependent on the number of tokens are added. E.g., for the thinker, this will add 0 and 24 as keys mapping to their own CPU tensor caches on the first pass.
    • This avoids the need to explicitly specify which multimodal output keys are cacheable for every stage of every model, so hopefully it should 'just work', at least for stages running on the AR model runner with token inputs for now (see the sketch after this list).
  • Non-cached multimodal data is kept as pass-through data, so everything is built at the same time and can be added to the payload easily. I refactored a bit to make some of the mm processing parts common because it was a bit messy.
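
As a rough illustration of the dynamic key discovery (names here are made up for the sketch and are not the actual vllm_omni code): on each forward pass, any multimodal output tensor whose leading dimension matches the number of input tokens gets its own CPU cache, keyed by its position in the model's multimodal outputs; anything else stays as pass-through data.

import torch

def update_cacheable_mm_outputs(mm_outputs: dict[int, torch.Tensor],
                                num_tokens: int,
                                mm_caches: dict[int, torch.Tensor],
                                cache_rows: int) -> None:
    for key, tensor in mm_outputs.items():
        if tensor.shape[0] != num_tokens:
            # Length doesn't line up with the hidden states, so it can't be
            # split into the same blocks; leave it as pass-through data.
            continue
        if key not in mm_caches:
            # First time this key is seen: allocate a CPU cache that mirrors
            # the hidden-state cache layout (one row per slot).
            mm_caches[key] = torch.zeros(
                cache_rows, *tensor.shape[1:], dtype=tensor.dtype, device="cpu"
            )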

Still working on some design docs explaining how this works and how it hooks into vLLM's prefix caching, but I think I'll have it written by tomorrow

@amy-why-3459
Contributor

Excellent work. @LJH-LBJ Could you please help test the L4 test cases? If prefix caching significantly benefits TTFT, we'll consider changing the default enable_prefix_caching configuration for stage-0 to True.

@LJH-LBJ
Contributor

LJH-LBJ commented Apr 1, 2026

It looks really nice. Can you include the test results for tests/e2e/online_serving/test_qwen3_omni_expansion.py? I’d like to confirm whether it affects accuracy.

@amy-why-3459
Contributor

@princepride PTAL

@alex-jw-brooks
Contributor Author

alex-jw-brooks commented Apr 1, 2026

It looks really nice. Can you include the test results for tests/e2e/online_serving/test_qwen3_omni_expansion.py? I’d like to confirm whether it affects accuracy.

@LJH-LBJ I ran the tests and everything passed. Ideally it shouldn't affect the output at all, although after thinking more, I suspect for multimodal outputs there may be occasional differences since vLLM doesn't prefix cache on partial blocks (so I think the leftover part that is cut off may cache miss)

Collaborator

@lishunyang12 left a comment


Nice results on the benchmarks (8.8x on text, 8.5x on image, 6.7x on audio). Left a few comments — two potential bugs worth fixing before moving out of draft.

Comment thread vllm_omni/worker/gpu_ar_model_runner.py Outdated
Comment thread vllm_omni/core/prefix_cache.py Outdated
Comment thread vllm_omni/utils/mm_outputs.py Outdated
@tzhouam added the ready label (to trigger buildkite CI) on Apr 8, 2026
@tzhouam
Collaborator

tzhouam commented Apr 8, 2026

Please fix the pre-commit checks.

@amy-why-3459
Contributor

If this PR is already ready, please help add a nightly-test label. @tzhouam

@alex-jw-brooks
Contributor Author

alex-jw-brooks commented Apr 8, 2026

Hey @tzhouam @amy-why-3459, I think there is one more case that should probably be handled in this PR: the partial block case. Since vLLM only caches full blocks, this has weird implications for multimodal data that isn't evenly divisible by the block size.

I have an idea for how to handle this though and am working on it; I'll try to push a fix up tomorrow morning.
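
To make the concern concrete, a quick arithmetic sketch (a block size of 16 is just an example; the real value depends on the cache config):

# vLLM only caches complete blocks, so the tail of a prompt that doesn't
# fill a block never becomes a prefix-cache hit.
block_size = 16
prompt_tokens = 100  # e.g. a multimodal segment of 100 tokens

full_blocks = prompt_tokens // block_size              # 6 blocks
cacheable_tokens = full_blocks * block_size            # 96 tokens can ever hit
always_recomputed = prompt_tokens - cacheable_tokens   # last 4 always miss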

@tzhouam
Collaborator

tzhouam commented Apr 9, 2026

After handling the additional condition, please fix the failed CI and I will add a nightly-test label for more tests.

@alex-jw-brooks
Contributor Author

alex-jw-brooks commented Apr 15, 2026

Hey @Gaohan123, thanks for the review! Yes to both. I ran a quick test with async chunking enabled using mixed audio and image inputs with a huge shared prefix and saw similar numbers come back from the benchmark script above.

Without (generating a max of 30 tokens)

{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe_no_prefix_cache.yaml",
  "query_type": "mixed",
  "num_prompts": 20,
  "max_tokens": 30,
  "generation_wall_time": 10.93366788700223,
  "total_cached_tokens": 0,
  "total_prompt_tokens": 53810,
  "cache_hit_pct": 0.0,
}

With the same config & thinker prefix cache enabled on the same image + audio input for every req, so almost everything is cached:

{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe_prefix_cache.yaml",
  "query_type": "mixed",
  "num_prompts": 20,
  "max_tokens": 30,
  "generation_wall_time": 2.7216501152142882,
  "total_cached_tokens": 53136,
  "total_prompt_tokens": 53810,
  "cache_hit_pct": 98.74744471287865,
}

In general, the prefix cache is largely based on the scheduler output, so it should be compatible with most other features out of the box once it's enabled in different workers/schedulers

@Gaohan123 added this to the v0.20.0 milestone on Apr 15, 2026
Collaborator

@Gaohan123 left a comment


LGTM. Thanks

@Gaohan123 enabled auto-merge (squash) on April 15, 2026 16:52
@Gaohan123 merged commit f1e3f03 into vllm-project:main on Apr 15, 2026
7 of 8 checks passed
iancarrasco-b10 pushed a commit to basetenlabs/vllm-omni that referenced this pull request Apr 15, 2026
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: iancarrasco-b10 <ian.carrasco@baseten.co>
y123456y78 pushed a commit to y123456y78/vllm-omni that referenced this pull request Apr 15, 2026
Signed-off-by: Alex Brooks <albrooks@redhat.com>
y123456y78 pushed a commit to y123456y78/vllm-omni that referenced this pull request Apr 16, 2026
Signed-off-by: Alex Brooks <albrooks@redhat.com>
hongzhi-gao added a commit to hongzhi-gao/vllm-omni that referenced this pull request Apr 17, 2026
Validate GPU hidden-state update/merge (pytest markers core_model+cuda+L4 to match Buildkite CUDA unit test selection) and dual prefix-hit block table sharing vs isolation on CPU.

Complements tests/core/test_prefix_cache.py; relates to hidden-state prefix caching (vllm-project#2164).

Signed-off-by: hongzhigao <761417898@qq.com>
lvliang-intel pushed a commit to lvliang-intel/vllm-omni that referenced this pull request Apr 20, 2026
Signed-off-by: Alex Brooks <albrooks@redhat.com>