
[Gemma4] Enable Fast Prefill Optimization #38879

Merged
robertgshaw2-redhat merged 2 commits into vllm-project:main from neuralmagic:gemma4-fast-prefill on Apr 6, 2026
Conversation

@LucasWilkinson (Collaborator) commented on Apr 3, 2026

Summary

Add --kv-sharing-fast-prefill support for Gemma 4 models, porting the YOCO (You Only Cache Once) fast prefill optimization from Gemma3n. When enabled, the cross-decoder layers (KV-shared) skip prefill tokens and only process decode tokens, significantly reducing prefill latency and improving throughput under concurrent load.

Shout-out to @sarckk for the original optimization (#22628).
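
As a rough illustration of the mechanism (a hedged sketch, not vLLM's actual implementation; cross_decoder_forward, decode_token_indices, and layer are hypothetical names), a KV-shared cross-decoder layer under fast prefill might run attention only over the decode-token slice of the batch:

# Hedged sketch of the YOCO fast-prefill idea; illustrative names only.
import torch

def cross_decoder_forward(layer, hidden_states: torch.Tensor,
                          decode_token_indices: torch.Tensor) -> torch.Tensor:
    # hidden_states: [num_tokens, hidden_size] for prefill + decode tokens combined.
    # A KV-shared layer reads a cache already written by an earlier
    # self-decoder layer, so during prefill only decode tokens need to pass
    # through it.
    out = hidden_states.clone()
    decode_states = hidden_states.index_select(0, decode_token_indices)
    decode_states = layer(decode_states)  # attends against the shared KV cache
    out.index_copy_(0, decode_token_indices, decode_states)
    return out

The intuition is that the prefill-position outputs of these layers never feed sampled logits, so skipping them should not change the generated tokens.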

Test Plan

GSM8K accuracy (Gemma4-E4B, 5-shot)

# FP=OFF (baseline)
lm_eval --model vllm --tasks gsm8k --num_fewshot 5 \
  --model_args pretrained=google/gemma-4-E4B-it,gpu_memory_utilization=0.9,max_model_len=4096,tensor_parallel_size=1,trust_remote_code=True,attention_backend=TRITON_ATTN,kv_sharing_fast_prefill=False \
  --batch_size auto --apply_chat_template --fewshot_as_multiturn

# FP=ON (this PR)
lm_eval --model vllm --tasks gsm8k --num_fewshot 5 \
  --model_args pretrained=google/gemma-4-E4B-it,gpu_memory_utilization=0.9,max_model_len=4096,tensor_parallel_size=1,trust_remote_code=True,attention_backend=TRITON_ATTN,kv_sharing_fast_prefill=True \
  --batch_size auto --apply_chat_template --fewshot_as_multiturn

Serving benchmark

# Start server (without fast prefill)
vllm serve google/gemma-4-E4B-it \
  --port 8434 \
  --disable-log-stats \
  --no-enable-prefix-caching \
  --max-num-seqs 128 \
  --max-model-len 32768 \
  --max-num-batched-tokens 8192 \
  --attention-backend TRITON_ATTN \
  --trust-remote-code

# Start server (with fast prefill)
vllm serve google/gemma-4-E4B-it \
  --port 8434 \
  --disable-log-stats \
  --no-enable-prefix-caching \
  --max-num-seqs 128 \
  --max-model-len 32768 \
  --max-num-batched-tokens 8192 \
  --attention-backend TRITON_ATTN \
  --trust-remote-code \
  --kv-sharing-fast-prefill

# Run benchmark (after server is ready)
# concurrency=8
vllm bench serve \
  --backend vllm \
  --ignore-eos \
  --port 8434 \
  --model google/gemma-4-E4B-it \
  --dataset-name random \
  --max-concurrency 8 \
  --request-rate inf \
  --num-prompts 256 \
  --random-input-len 8192 \
  --random-output-len 150

# concurrency=32
vllm bench serve \
  --backend vllm \
  --ignore-eos \
  --port 8434 \
  --model google/gemma-4-E4B-it \
  --dataset-name random \
  --max-concurrency 32 \
  --request-rate inf \
  --num-prompts 256 \
  --random-input-len 8192 \
  --random-output-len 150

Test Results

GSM8K accuracy (Gemma4-E4B, 5-shot)

No accuracy regression:

| Config | strict-match | flexible-extract |
|---|---|---|
| FP=OFF (baseline) | 0.1054 | 0.1751 |
| FP=ON (this PR) | 0.1031 | 0.1850 |

Serving performance (Gemma4-E4B, 1xB200, ISL=8192, OSL=150, n=256)

concurrency=8

| Metric | NORMAL | FAST_PREFILL | Delta |
|---|---|---|---|
| Throughput | 4.22 req/s | 5.06 req/s | +19.9% |
| Mean TTFT | 570 ms | 363 ms | -36.3% |
| Mean TPOT | 8.90 ms | 8.16 ms | -8.3% |

concurrency=32

| Metric | NORMAL | FAST_PREFILL | Delta |
|---|---|---|---|
| Throughput | 6.53 req/s | 9.07 req/s | +38.9% |
| Mean TTFT | 942 ms | 622 ms | -34.0% |
| Mean TPOT | 26.43 ms | 19.37 ms | -26.7% |

The mergify bot added the multi-modality and new-model labels on Apr 3, 2026

mergify Bot commented Apr 3, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces comprehensive support for the Gemma 4 model family, encompassing both text-only and multimodal (image, audio, and video) capabilities. Key additions include specialized proportional RoPE, logic to handle heterogeneous head dimensions, and custom parsers for reasoning and tool calls. The review feedback identifies several critical robustness and performance improvements: replacing process-terminating sys.exit(1) calls with ValueError exceptions, optimizing memory by conditionalizing tensor clones in the fast prefill path, ensuring global context consistency using try...finally blocks, and implementing bounds checking for batch sizes during profiling to prevent potential runtime crashes.
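
For instance, the global-context suggestion could be addressed with a try...finally pattern along these lines (a minimal sketch under assumed names; set_fast_prefill_context is not vLLM's real API):

from contextlib import contextmanager

_fast_prefill_enabled = False  # hypothetical module-level flag

@contextmanager
def set_fast_prefill_context(enabled: bool):
    # Sketch of restoring global state with try/finally, as the review asks.
    global _fast_prefill_enabled
    previous = _fast_prefill_enabled
    _fast_prefill_enabled = enabled
    try:
        yield
    finally:
        # Restore the prior value even if the forward pass raises, so the
        # flag never leaks across batches.
        _fast_prefill_enabled = previous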

Review comment threads on vllm/model_executor/models/gemma4.py and vllm/model_executor/models/gemma4_mm.py (two each).
Port the --kv-sharing-fast-prefill optimization from Gemma3n to Gemma4.
When enabled, cross-decoder layers (KV-shared) skip prefill tokens and
only process decode tokens, reducing TTFT by ~36% and improving
throughput by up to ~39% under concurrent load.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
The mergify bot removed the needs-rebase label on Apr 3, 2026

@RyanMullins left a comment


LGTM. Shared layers don't compute so you can early exit depending on the config.
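
A minimal sketch of that early-exit idea (config.kv_sharing_fast_prefill and layer.is_kv_shared are illustrative attribute names, not the actual ones):

def maybe_run_layer(layer, hidden_states, is_prefill: bool, config):
    # Hedged sketch: with fast prefill enabled, a KV-shared layer's prefill
    # output is never consumed, so the layer can be bypassed entirely.
    if config.kv_sharing_fast_prefill and layer.is_kv_shared and is_prefill:
        return hidden_states
    return layer(hidden_states)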

@LucasWilkinson added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Apr 3, 2026
@lk-chen (Collaborator) commented on Apr 5, 2026

Verified on TPU with the same setup as vllm-project/tpu-inference#2126 (comment); the MMMU-Pro score is identical before and after this PR. Performance metrics were not tested.

@robertgshaw2-redhat merged commit 47e6050 into vllm-project:main on Apr 6, 2026
57 checks passed
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Apr 6, 2026
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@khluu added this to the v0.19.1 cherry picks milestone on Apr 7, 2026
puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Rishi Puri <riship@nvidia.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
khluu pushed a commit that referenced this pull request Apr 10, 2026
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
(cherry picked from commit 47e6050)
Natfii pushed a commit to Navi-AI-Lab/nvllm that referenced this pull request Apr 10, 2026
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@bbrowning (Collaborator) commented:

Just an FYI: this may have introduced correctness issues in certain situations. While debugging #39392, reverting this PR appears to have fixed that problem for me. It may only affect certain hardware or setups; I'm using a DGX Spark, and the user in #39392 was using 8x RTX 4090.

khluu pushed a commit that referenced this pull request Apr 16, 2026
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
(cherry picked from commit 47e6050)
(cherry picked from commit fc29ef1)
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

Labels

multi-modality (Related to multi-modality (#4194)), new-model (Requests to new models), ready (ONLY add when PR is ready to merge/full CI is needed), tool-calling

Projects

Status: Done


7 participants