[Gemma4] perf: batch vision encoder and embed_vision calls #26606
Open
pyc96 wants to merge 2 commits into
Collaborator, Author
/tag-and-rerun-ci
Motivation
Gemma4ForConditionalGeneration.get_image_feature and get_video_feature previously iterated one image (or one video frame) at a time through self.vision_tower(...), then once more through self.embed_vision(...) per pooled output. For a request carrying N images this issues 2*N GPU dispatches where 2 suffice.
The vision tower (Gemma4VisionEncoder.forward) already accepts a batched first dim [B, num_patches, patch_pixels], and embed_vision is pointwise (RMSNorm + Linear) along the token axis, so both loops are unnecessary serialization that limits throughput for concurrent multi-image requests.
Algorithm follows vllm-project/vllm#43169 (Apache-2.0).
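To illustrate why a pointwise embed step can be batched, here is a minimal pure-Python sketch. It is not the PR's code: CountingEmbed, embed_per_item, and embed_batched are hypothetical stand-ins, and plain lists replace tensors. It applies a toy RMSNorm + Linear per token, once in a per-item loop and once over the concatenated tokens, then checks that the results agree and counts the calls.

```python
import math
from typing import List

Token = List[float]


def rms_norm(token: Token, eps: float = 1e-6) -> Token:
    # Normalize one token by its root-mean-square; no cross-token interaction.
    mean_sq = sum(x * x for x in token) / len(token)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [x * scale for x in token]


def linear(token: Token, weight: List[List[float]]) -> Token:
    # weight is laid out as [out_features][in_features].
    return [sum(w * x for w, x in zip(row, token)) for row in weight]


class CountingEmbed:
    """Hypothetical stand-in for embed_vision that counts how often it is called."""

    def __init__(self, weight: List[List[float]]) -> None:
        self.weight = weight
        self.calls = 0

    def __call__(self, tokens: List[Token]) -> List[Token]:
        self.calls += 1
        return [linear(rms_norm(t), self.weight) for t in tokens]


def embed_per_item(embed: CountingEmbed, items: List[List[Token]]) -> List[List[Token]]:
    # Old pattern: one call per item.
    return [embed(tokens) for tokens in items]


def embed_batched(embed: CountingEmbed, items: List[List[Token]]) -> List[List[Token]]:
    # New pattern: concatenate all tokens, call once, split back per item.
    flat = [t for tokens in items for t in tokens]
    out = embed(flat)
    result, start = [], 0
    for tokens in items:
        result.append(out[start:start + len(tokens)])
        start += len(tokens)
    return result


weight = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
items = [
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    [[-1.0, 0.5, 2.5]],
    [[0.1, 0.2, 0.3], [3.0, 2.0, 1.0], [7.0, 8.0, 9.0]],
]
loop_embed = CountingEmbed(weight)
batch_embed = CountingEmbed(weight)
per_item = embed_per_item(loop_embed, items)
batched = embed_batched(batch_embed, items)
```

In this toy the two paths agree exactly because each token goes through the same arithmetic in both; on a GPU, equality of batched and unbatched kernels has to be checked empirically, as the accuracy tests in this PR do.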
Modifications
Replace the two per-item loops with three helpers (output is shape- and order-identical to the previous code):
- _flatten_pixel_lists: walks items in order, normalizes shapes, and records an ordered slot list so that prepass (already-embedded) entries and raw-pixel entries reassemble in their original positions.
- _batched_encode: buckets entries by patch count (resolution bucket) so each encoder forward is a uniform-shape batch with no cross-resolution padding; optionally chunks a bucket to bound encoder activation memory (_encoder_max_batch); then runs embed_vision exactly once over the concatenated valid tokens.
- _gather_mm_features: driver shared by the image and video paths; reassembles per-item outputs in original walk order.
_encoder_max_batch is derived lazily from a per-process budget (5% of device memory) and a per-patch activation cost that is cached at the end of load_weights from the loaded vision_config; if these are unavailable, it degrades to a single-item batch.
Accuracy Tests
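As a rough illustration of the helper flow described under Modifications (bucket by patch count, chunk each bucket to a budget, reassemble in walk order), the sketch below uses plain Python lists in place of tensors. The names max_batch_for, batched_encode, and fake_encoder, and the integer-division budget rule, are assumptions made for this example rather than the PR's implementation. Prepass slots and the single embed_vision call are not modeled.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterator, List, Optional, Tuple

Item = List[int]  # one item = a list of patches; its length is the patch count


def chunked(indices: List[int], max_batch: int) -> Iterator[List[int]]:
    # Split a bucket into consecutive chunks of at most max_batch entries.
    step = max(1, max_batch)
    for start in range(0, len(indices), step):
        yield indices[start:start + step]


def max_batch_for(num_patches: int,
                  budget_bytes: Optional[int],
                  bytes_per_patch: Optional[int]) -> int:
    # Hypothetical budget rule: how many items of this patch count fit in the
    # budget. Falls back to single-item batches when the inputs are unknown.
    if not num_patches or not budget_bytes or not bytes_per_patch:
        return 1
    return max(1, budget_bytes // (num_patches * bytes_per_patch))


def batched_encode(items: List[Item],
                   encode_batch: Callable[[List[Item]], list],
                   budget_bytes: Optional[int] = None,
                   bytes_per_patch: Optional[int] = None) -> list:
    # Group item indices by patch count so every encoder call is uniform-shape.
    buckets: Dict[int, List[int]] = defaultdict(list)
    for idx, patches in enumerate(items):
        buckets[len(patches)].append(idx)

    outputs: list = [None] * len(items)
    for num_patches, indices in buckets.items():
        limit = max_batch_for(num_patches, budget_bytes, bytes_per_patch)
        for chunk in chunked(indices, limit):
            encoded = encode_batch([items[i] for i in chunk])
            # Scatter results back to the original positions.
            for i, out in zip(chunk, encoded):
                outputs[i] = out
    return outputs


calls: List[Tuple[int, int]] = []


def fake_encoder(batch: List[Item]) -> List[int]:
    # Stand-in encoder: requires uniform patch count, records (batch, patches).
    assert len({len(p) for p in batch}) == 1
    calls.append((len(batch), len(batch[0])))
    return [sum(p) for p in batch]


items = [[1, 2], [3, 4, 5], [6, 7], [8, 9], [10, 11, 12]]
out = batched_encode(items, fake_encoder, budget_bytes=8, bytes_per_patch=2)
```

With this toy budget the three 2-patch items are encoded as one batch of two plus one batch of one, the 3-patch items are encoded one at a time, and out keeps the original item order.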
30-prompt colored-image labelling (temperature=0, seed=0, 1–6 images/prompt) on google/gemma-4-E2B-it, 1× B200: the patched output is identical to baseline on this test, with every one of the 30 responses matching character-for-character.
MMMU-val (100 samples, sglang.test.run_eval mmmu, google/gemma-4-E2B-it): identical score and per-subject breakdown, so the change is accuracy-neutral on a real multimodal reasoning benchmark, and the score is above the 0.26 threshold registered for the Gemma4 family in test/registered/eval/test_vlms_mmmu_eval.py.
Speed Tests and Profiling
1× B200 (sm_100a), bf16, TP=1, --attention-backend triton. Load gen: vllm bench serve --dataset-name random-mm, 6× 480×480 images/prompt, 50 in / 10 out, 100 prompts. Geomean of 2 steady runs.
Cache-off (--disable-radix-cache): isolates the encoder.
Cache-on (default radix cache): parity within noise. On the cyclic random-mm distribution the radix cache hits >60%, so the encoder is rarely re-run. Production workloads with diverse (low-cache-hit) user images sit closer to the cache-off numbers.
Profiling (torch profiler, 12-prompt batch = 72 images): CPU-side op-launch counts confirm that the per-image loop collapses into one batched forward per resolution bucket. Ops counted: aten::bmm (vision attn), aten::linear, aten::gelu.
Checklist
Unit tests: test/srt/models/test_gemma4_mm_batched_encoder.py (5 CPU-only tests) covering single-resolution single-call, mixed-resolution bucketing, budget-bound chunking, empty input, and interleaved prepass/raw ordering. All pass.
CI States
Latest PR Test (Base): ✅ Run #26801493849
Latest PR Test (Extra): ❌ Run #26801493739