Skip to content

[Gemma4] perf: batch vision encoder and embed_vision calls#26606

Open
pyc96 wants to merge 2 commits into
sgl-project:mainfrom
pyc96:pyc/upstream-gemma4-mm-batched-encoder
Open

[Gemma4] perf: batch vision encoder and embed_vision calls#26606
pyc96 wants to merge 2 commits into
sgl-project:mainfrom
pyc96:pyc/upstream-gemma4-mm-batched-encoder

Conversation

@pyc96
Copy link
Copy Markdown
Collaborator

@pyc96 pyc96 commented May 28, 2026

Motivation

Gemma4ForConditionalGeneration.get_image_feature and get_video_feature previously iterated one image (or one video frame) at a time through self.vision_tower(...), then once more through self.embed_vision(...) per pooled output. For a request carrying N images this issues 2*N GPU dispatches where 2 suffice.

The vision tower (Gemma4VisionEncoder.forward) already accepts a batched first dim [B, num_patches, patch_pixels], and embed_vision is pointwise (RMSNorm + Linear) along the token axis, so both loops are unnecessary serialization that limits throughput for concurrent multi-image requests.

Algorithm follows vllm-project/vllm#43169 (Apache-2.0).

Modifications

Replace the two per-item loops with three helpers (output is shape- and order-identical to the previous code):

  • _flatten_pixel_lists — walk items in order, normalize shapes, and record an ordered slot list so prepass (already-embedded) entries and raw-pixel entries reassemble in their original positions.
  • _batched_encode — bucket entries by patch count (resolution bucket) so each encoder forward is a uniform-shape batch with no cross-resolution padding; optionally chunk a bucket to bound encoder activation memory (_encoder_max_batch); then run embed_vision exactly once over the concatenated valid tokens.
  • _gather_mm_features — driver shared by the image and video paths; reassembles per-item outputs in original walk order.

_encoder_max_batch lazily derives a per-process budget (5% of device memory) and a per-patch activation cost cached at the end of load_weights from the loaded vision_config; if unavailable it degrades to a single-item batch.

Accuracy Tests

30-prompt colored-image labelling (temperature=0, seed=0, 1–6 images/prompt) on google/gemma-4-E2B-it, 1× B200:

framework accuracy char-for-char vs baseline
baseline (this PR's base) 26/30 (86.7%) (reference)
patched 26/30 (86.7%) 30/30 identical

The patch is provably output-identical to baseline — every one of the 30 responses matches character-for-character.

MMMU-val (100 samples, sglang.test.run_eval mmmu, google/gemma-4-E2B-it):

framework MMMU score latency (s)
baseline 0.2703 7.31
patched 0.2703 7.16

Identical score and per-subject breakdown — accuracy-neutral on a real multimodal reasoning benchmark, and above the 0.26 threshold registered for the Gemma4 family in test/registered/eval/test_vlms_mmmu_eval.py.

Speed Tests and Profiling

1× B200 (sm_100a), bf16, TP=1, --attention-backend triton. Load gen: vllm bench serve --dataset-name random-mm, 6× 480×480 images/prompt, 50 in / 10 out, 100 prompts. Geomean of 2 steady runs.

Cache-off (--disable-radix-cache, isolates the encoder):

metric baseline patched delta
duration (s) 15.66 10.73 1.46× faster
total tok/s 10,326 15,069 +45.9%
median TTFT (ms) 10,550 7,894 −25.2%

Cache-on (default radix cache): parity within noise — on the cyclic random-mm distribution the radix cache hits >60%, so the encoder is rarely re-run. Production workloads with diverse (low-cache-hit) user images sit closer to the cache-off numbers.

Profiling (torch profiler, 12-prompt batch = 72 images): CPU-side op-launch counts confirm the per-image loop collapses into one batched forward per resolution bucket:

op baseline patched reduction
aten::bmm (vision attn) 288 24 12×
aten::linear 6,018 1,662 3.6×
aten::gelu 1,362 306 4.5×

Checklist

Unit tests: test/srt/models/test_gemma4_mm_batched_encoder.py (5 CPU-only tests) — single-resolution single-call, mixed-resolution bucketing, budget-bound chunking, empty input, and interleaved prepass/raw ordering. All pass.


CI States

Latest PR Test (Base): ✅ Run #26801493849
Latest PR Test (Extra): ❌ Run #26801493739

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@pyc96 pyc96 force-pushed the pyc/upstream-gemma4-mm-batched-encoder branch from 3e1d52b to 57653b6 Compare May 29, 2026 21:56
@pyc96 pyc96 force-pushed the pyc/upstream-gemma4-mm-batched-encoder branch 2 times, most recently from 1d6f8f7 to da5c5e3 Compare June 2, 2026 05:24
@pyc96 pyc96 force-pushed the pyc/upstream-gemma4-mm-batched-encoder branch from da5c5e3 to 26cb235 Compare June 2, 2026 05:28
@pyc96 pyc96 marked this pull request as ready for review June 2, 2026 05:58
@pyc96 pyc96 requested a review from kpham-sgl as a code owner June 2, 2026 05:58
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@pyc96
Copy link
Copy Markdown
Collaborator Author

pyc96 commented Jun 2, 2026

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label Jun 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant