[feature] Hidden State Prefix Caching #2164
Conversation
Is async chunking enabled when you test?
tzhouam
left a comment
Nice job. Please review the comments.
```python
# for now we assume that all of the multimodal outputs cached
# are exactly the same size as the hidden states.
# TODO (Alex) make this more flexible.
if self.mm_outputs_cache is not None:
```
One concern is data parallel: will we retain some interface for cache-aware routing?
Not at the moment, but I think this is consistent with the existing behavior of normal prefix caching in vLLM. From this doc:
> For online deployments, this balancing can be optimized by taking into account the state of each DP engine - in particular its currently scheduled and waiting (queued) requests, and KV cache state. Each DP engine has an independent KV cache, and the benefit of prefix caching can be maximized by directing prompts intelligently.
The handling is similar in this case: since we reuse the slot mapping for our own cache, the relationship with scheduled blocks/slots mirrors normal prefix caching.
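As a rough illustration of that relationship (names are mine, not the PR's actual identifiers): an external hidden-state cache can be addressed with the same per-token slot mapping the scheduler already produces for the KV cache, so a KV prefix hit implies the corresponding hidden states are already in the side cache.

```python
import torch

# Illustrative sketch only (assumed names/shapes, not the PR's code): a
# preallocated hidden-state cache indexed by the same slot mapping used
# for the KV cache blocks.
NUM_BLOCKS, BLOCK_SIZE, HIDDEN_SIZE = 1024, 16, 4096
hidden_state_cache = torch.zeros(NUM_BLOCKS * BLOCK_SIZE, HIDDEN_SIZE)

def write_hidden_states(slot_mapping: torch.Tensor, hidden_states: torch.Tensor) -> None:
    # slot_mapping[i] is the flat cache slot the scheduler assigned to token i,
    # so writing here keeps the side cache aligned with the KV-cache blocks.
    hidden_state_cache[slot_mapping] = hidden_states

def read_cached_prefix(slot_mapping: torch.Tensor) -> torch.Tensor:
    # On a prefix hit, gather the hidden states for the cached prefix tokens.
    return hidden_state_cache[slot_mapping]
```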
Hey @LJH-LBJ, not yet - the config for the run above is below for reproducibility (just change …).

Config for first run above:

```yaml
# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
# Stage 0: Thinker (multimodal understanding + text generation)
# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
async_chunk: false
stage_args:
- stage_id: 0
stage_type: llm # Use llm stage type for AR stages
runtime:
devices: "0"
engine_args:
model_stage: thinker
max_num_seqs: 64
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.9
enforce_eager: false
trust_remote_code: true
engine_output_type: latent # Output hidden states for talker
distributed_executor_backend: "mp"
enable_prefix_caching: true
max_num_batched_tokens: 32768
hf_config_name: thinker_config
tensor_parallel_size: 1
attention_backend: TRITON_ATTN
final_output: true
final_output_type: text
is_comprehension: true
default_sampling_params:
temperature: 0.4
top_p: 0.9
top_k: 1
max_tokens: 2048
seed: 42
detokenize: True
repetition_penalty: 1.05
- stage_id: 1
stage_type: llm # Use llm stage type for AR stages
runtime:
devices: "1"
engine_args:
model_stage: talker
max_num_seqs: 64
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.6
enforce_eager: false
trust_remote_code: true
engine_output_type: latent # Output codec codes for code2wav
enable_prefix_caching: false
max_num_batched_tokens: 32768
distributed_executor_backend: "mp"
hf_config_name: talker_config
attention_backend: TRITON_ATTN
engine_input_source: [0]
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
# final_output: true
# final_output_type: text
default_sampling_params:
temperature: 0.9
top_k: 50
max_tokens: 4096
seed: 42
detokenize: False
repetition_penalty: 1.05
stop_token_ids: [2150]
- stage_id: 2
stage_type: llm # Use llm stage type for AR stages
runtime:
devices: "1"
engine_args:
model_stage: code2wav
max_num_seqs: 32
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: generation
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
trust_remote_code: true
async_scheduling: false
enable_prefix_caching: false
engine_output_type: audio # Final output: audio waveform
gpu_memory_utilization: 0.1
distributed_executor_backend: "mp"
max_num_batched_tokens: 1000000
hf_config_name: thinker_config
engine_input_source: [1]
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
final_output: true
final_output_type: audio
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
max_tokens: 65536
seed: 42
detokenize: True
      repetition_penalty: 1.1
```
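For reproducing the no-prefix-cache baseline used in the benchmark results further below, the two stage configs presumably differ only in the thinker stage's `enable_prefix_caching` flag. A minimal sketch of deriving one from the other (file paths are taken from the benchmark JSON below and keys from the config above; they are assumptions, not files verified in the repo):

```python
import yaml

# Sketch only: flip the thinker stage's prefix-caching flag to produce the
# no-prefix-cache baseline config. Paths/keys are assumptions based on this
# PR's benchmark output and the config shown above.
with open("prefix_cache_benchmarks/qwen3_omni_moe.yaml") as f:
    cfg = yaml.safe_load(f)
cfg["stage_args"][0]["engine_args"]["enable_prefix_caching"] = False
with open("prefix_cache_benchmarks/qwen3_omni_moe_no_prefix_cache.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```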
Why does the test only cover the text-only case?
I ran a benchmark with your PR at commit 32e8c99, and it seems to work well.
Thanks @LJH-LBJ - still working on tests and cleaning up some edge cases, so I hadn't added the multimodal merging tests yet, but they are there now 🙂
Hey @LJH-LBJ @tzhouam @amy-why-3459, this is ready for a look when you have a moment. Current state: still working on some design docs explaining how this works and how it hooks into vLLM's prefix caching, but I think I'll have it written by tomorrow.
Excellent work. @LJH-LBJ Could you please help test the L4 test cases? If prefix caching significantly benefits TTFT, we'll consider changing the default.
It looks really nice. Can you include the test results for …
@princepride PTAL
@LJH-LBJ I ran the tests and everything passed. Ideally it shouldn't affect the output at all, although after thinking more, I suspect there may be occasional differences for multimodal outputs, since vLLM doesn't prefix cache partial blocks (so the leftover part that is cut off may cache miss).
lishunyang12
left a comment
Nice results on the benchmarks (8.8x on text, 8.5x on image, 6.7x on audio). Left a few comments — two potential bugs worth fixing before moving out of draft.
Please fix the pre-commit checks.
If this PR is ready, please help add a nightly-test label. @tzhouam
Hey @tzhouam @amy-why-3459, I think there is one more case that should probably be handled in this PR: the partial block case. Since vLLM only caches full blocks, this has weird implications for multimodal data that isn't evenly divisible by the block size. I have an idea for how to handle this and am working on it; I'll try to push a fix up tomorrow morning.
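To make the full-block constraint concrete (illustrative only, with an assumed block size): because vLLM only caches complete blocks, any trailing partial block of a matched prefix is recomputed, which is what makes multimodal spans that aren't block-aligned awkward.

```python
def cached_token_count(matched_prefix_len: int, block_size: int = 16) -> int:
    # vLLM only caches full blocks, so the trailing partial block always misses.
    return (matched_prefix_len // block_size) * block_size

# e.g. a 100-token matched prefix with block_size=16 reuses only 96 tokens;
# the remaining 4 (and their hidden states / multimodal outputs) are recomputed.
assert cached_token_count(100) == 96
```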
After handling the additional condition, please fix the failing CI and I will add a nightly-test label for more tests.
Hey @Gaohan123, thanks for the review! Yes to both. I ran a quick test with async chunking enabled, using mixed audio and image inputs with a huge shared prefix, and saw similar numbers come back from the benchmark script above.

Without prefix caching (generating a max of 30 tokens):

With the same config & thinker prefix cache enabled on the same image + audio input for every request, so almost everything is cached:

In general, the prefix cache is largely based on the scheduler output, so it should be compatible with most other features out of the box once it's enabled in different workers/schedulers.
Validate GPU hidden-state update/merge (pytest markers core_model+cuda+L4 to match Buildkite CUDA unit test selection) and dual prefix-hit block table sharing vs isolation on CPU. Complements tests/core/test_prefix_cache.py; relates to hidden-state prefix caching (vllm-project#2164).
Purpose
For #1184
This takes a first pass at enabling prefix caching for hidden states. The core idea is to preallocate an external hidden state cache and have it mirror the prefix cache maintained by upstream vLLM; this is accomplished by reusing the slot mapping from the scheduler output to index into our own cache.
The same can be done for multimodal tensor outputs that need to be passed through, although this part is pretty hacked together for Qwen Omni at the moment; I'm still working on cleaning it up. The main constraint is that the extra tensors need to be the same length as the input tokens/hidden states, otherwise we won't know how to divide them into blocks.
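A minimal sketch of that constraint (assumed shapes and names, not the PR's implementation): per-token multimodal outputs can only be mirrored into the cache if they have exactly one row per prompt token, so they split into the same fixed-size blocks as the hidden states, and any trailing partial block is dropped just like vLLM's full-block-only prefix cache.

```python
import torch

def split_mm_outputs_into_blocks(mm_outputs: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # mm_outputs: (num_tokens, feat_dim), aligned 1:1 with the prompt tokens.
    num_tokens = mm_outputs.shape[0]
    usable = (num_tokens // block_size) * block_size
    # Drop the trailing partial block, mirroring vLLM's full-block-only caching.
    return mm_outputs[:usable].reshape(-1, block_size, mm_outputs.shape[-1])
```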
I'll move this out of draft once there are some tests and the multimodal outputs handling is cleaned up (in the next day or two).
Preliminary Results
I ran some benchmarks on the thinker for Qwen3 Omni (i.e., text only) with and without prefix caching to ensure we see a large speedup on prefills for large prefix matches; this essentially runs one-token text generation for a few cases:
1. Text with a very large common prefix

With no prefix caching:

```json
{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe_no_prefix_cache.yaml",
  "query_type": "text",
  "num_prompts": 20,
  "prefill_time": 3.964723326731473,
  "total_cached_tokens": 0,
  "total_prompt_tokens": 12170,
  "cache_hit_pct": 0.0
}
```

With prefix caching:

```json
{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe.yaml",
  "query_type": "text",
  "num_prompts": 20,
  "prefill_time": 0.45084997406229377,
  "total_cached_tokens": 11552,
  "total_prompt_tokens": 12170,
  "cache_hit_pct": 94.92193919474117
}
```

2. Image Inputs

Without prefix caching:

```json
{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe_no_prefix_cache.yaml",
  "query_type": "image",
  "num_prompts": 20,
  "prefill_time": 9.791648932732642,
  "total_cached_tokens": 0,
  "total_prompt_tokens": 42053,
  "cache_hit_pct": 0.0
}
```

With prefix caching:

```json
{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe.yaml",
  "query_type": "image",
  "num_prompts": 20,
  "prefill_time": 1.149583789985627,
  "total_cached_tokens": 41616,
  "total_prompt_tokens": 42053,
  "cache_hit_pct": 98.96083513661333
}
```

3. Audio Inputs

Without prefix caching:

```json
{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe_no_prefix_cache.yaml",
  "query_type": "audio",
  "num_prompts": 20,
  "prefill_time": 1.7638336326926947,
  "total_cached_tokens": 0,
  "total_prompt_tokens": 5498,
  "cache_hit_pct": 0.0
}
```

With prefix caching:

```json
{
  "stage_configs_path": "prefix_cache_benchmarks/qwen3_omni_moe.yaml",
  "query_type": "audio",
  "num_prompts": 20,
  "prefill_time": 0.2613906688056886,
  "total_cached_tokens": 4848,
  "total_prompt_tokens": 5498,
  "cache_hit_pct": 88.17751909785376
}
```

Benchmark script for repro:
bench_prefix_caching.py
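The attached script isn't reproduced here; below is only a sketch of how the summary fields in the JSON above could be computed from per-request stats (all names are assumptions, not the actual bench_prefix_caching.py):

```python
def summarize_prefix_cache_run(prefill_times, cached_tokens, prompt_tokens):
    # Aggregate per-request stats into the fields reported in the JSON above.
    total_cached = sum(cached_tokens)
    total_prompt = sum(prompt_tokens)
    return {
        "prefill_time": sum(prefill_times),
        "total_cached_tokens": total_cached,
        "total_prompt_tokens": total_prompt,
        "cache_hit_pct": 100.0 * total_cached / total_prompt,
    }
```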
CC: @tzhouam @amy-why-3459 @LJH-LBJ @Sy0307 @lishunyang12 @NickLucche