From c68e3b295d58f320b6f0c3647354466fad152da2 Mon Sep 17 00:00:00 2001 From: Luciano Martins Date: Tue, 19 May 2026 22:31:40 +0000 Subject: [PATCH] Batch vision encoder calls for Gemma4 image and video processing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Decompose monolithic vt() calls into patch_embedder → encoder → pooler → standardize → embed_vision, enabling batch parallelism across images/frames instead of serial per-item processing - Group images by resolution bucket (patch count) to avoid cross-resolution padding waste in batched encoder calls - Dynamically compute encoder batch ceiling from per-patch activation cost (hidden_size × 2 × num_hidden_layers) and runtime free GPU memory, lazily initialized on first inference after KV cache allocation to reflect actual memory availability - Batch embed_vision (RMSNorm + Linear) across all items in a single call since it is pointwise and resolution-independent Signed-off-by: Luciano Martins --- vllm/model_executor/models/gemma4_mm.py | 266 ++++++++++++++++-------- 1 file changed, 180 insertions(+), 86 deletions(-) diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py index 91be7d47f6f0..db8a8c9305f7 100644 --- a/vllm/model_executor/models/gemma4_mm.py +++ b/vllm/model_executor/models/gemma4_mm.py @@ -61,6 +61,7 @@ PromptUpdate, PromptUpdateDetails, ) +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -960,6 +961,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_vision = Gemma4MultimodalEmbedder( config.vision_config, config.text_config ) + # Lazy-initialized on first encoder call (see _encoder_max_batch). + self._encoder_budget_bytes = 0 + self._encoder_bytes_per_patch = 0 # ---- Audio tower (variants with audio_config) ---- if config.audio_config is not None: @@ -1100,6 +1104,19 @@ def _parse_and_validate_multimodal_inputs( ) return mm_input_by_modality + def _encoder_max_batch(self, patches_per_item: int) -> int: + """Max items per encoder call given per-item patch count.""" + if self._encoder_budget_bytes == 0: + total_mem = current_platform.get_device_total_memory() + self._encoder_budget_bytes = int(total_mem * 0.05) + logger.info( + "Encoder memory budget: %.1fGB (total=%.1fGB)", + self._encoder_budget_bytes / 1024**3, + total_mem / 1024**3, + ) + cost = patches_per_item * self._encoder_bytes_per_patch + return max(1, self._encoder_budget_bytes // cost) if cost > 0 else 1 + # ------------------------------------------------------------------ # # Image processing # ------------------------------------------------------------------ # @@ -1108,73 +1125,103 @@ def _process_image_input( self, image_input: Gemma4ImageInputs, ) -> list[torch.Tensor]: + """Batch-encode images through the vision tower. + + Groups images by patch count (resolution bucket) so each + encoder call processes a uniform-shape batch with no + cross-resolution padding. Pooling and projection are then + applied over a single concatenated tensor for all images. + """ pixel_values = image_input["pixel_values"] pixel_position_ids = image_input["pixel_position_ids"] - # The HF image processor now outputs pre-patchified data: - # pixel_values: (num_images, max_patches, patch_pixels) - # pixel_position_ids: (num_images, max_patches, 2) - # We call the vision tower's forward() directly, which handles - # patch embedding, encoding, pooling, padding removal, and - # optional standardization internally. vt = self.vision_tower pooling_k2 = self.config.vision_config.pooling_kernel_size**2 - # TODO: Move this per-image loop into the input processor to - # reduce dynamism at the model runner / engine core. This - # requires spatially padding all images to uniform (H_max, - # W_max) in _call_hf_processor() so they arrive as a single - # stacked tensor, tracking padded regions via image_sizes - # metadata, and validating numerical equivalence with the - # current per-image path. - # # Concurrent requests with different image resolutions may # arrive as a list of per-image tensors, while same-resolution - # batches may arrive as a stacked tensor. Both forms are - # iterable over the per-image dimension. - - # Process each image individually through the vision tower. - # The vision tower's forward() strips padding and returns a - # flat tensor of valid tokens. We process per-image to get - # variable-length outputs matching the dynamic token count - # from get_image_repl. - per_image_features = [] - for pv, pp in zip(pixel_values, pixel_position_ids, strict=True): - pv = pv.unsqueeze(0) # (1, max_patches, patch_pixels) - pp = pp.unsqueeze(0) # (1, max_patches, 2) - - # Derive the pooler's output_length from the total patch - # count (including padding). The vision tower encoder - # processes ALL patches — padding patches get zero hidden - # states but still occupy sequence positions. The pooler's - # _avg_pool_by_positions requires: - # input_seq_len / output_length == k² - # where k == pooling_kernel_size. The image processor - # allocates max_patches = max_soft_tokens * k² total slots, - # so output_length = max_patches / k² == max_soft_tokens. - # Without this, the pooler falls back to - # config.image_seq_length (e.g. 280), which fails when a - # different max_soft_tokens was used at preprocessing time. - max_patches = pv.shape[1] - output_length = max_patches // pooling_k2 - - vt_output = vt(pv, pp, output_length=output_length) - # last_hidden_state: (num_valid_tokens, hidden_size) - # — already flat with padding stripped by the vision tower - per_image_features.append(vt_output.last_hidden_state) - - # Project each image's features into LM embedding space. - # Per-image loop is required because images have variable - # token counts after padding removal. - # Cast to match the projection layer's dtype (model may be - # bf16 while the vision tower outputs fp32). - target_dtype = self.embed_vision.embedding_projection.weight.dtype - return [ - self.embed_vision(inputs_embeds=img.unsqueeze(0).to(target_dtype)).squeeze( - 0 + # batches may arrive as a stacked tensor. + buckets: dict[int, list[tuple[int, torch.Tensor, torch.Tensor]]] = {} + total_images = ( + len(pixel_values) + if isinstance(pixel_values, list) + else pixel_values.shape[0] + ) + + for idx in range(total_images): + pv = pixel_values[idx] + pp = pixel_position_ids[idx] + buckets.setdefault(pv.shape[0], []).append((idx, pv, pp)) + + # Encode each resolution bucket in memory-safe chunks. + last_hidden_states_map: dict[int, torch.Tensor] = {} + for patches, items in buckets.items(): + max_batch_size = min(len(items), self._encoder_max_batch(patches)) + + for chunk_idx in range(0, len(items), max_batch_size): + chunk_items = items[chunk_idx : chunk_idx + max_batch_size] + + pv_tensor = torch.cat( + [item[1].unsqueeze(0) for item in chunk_items], dim=0 + ) + pp_tensor = torch.cat( + [item[2].unsqueeze(0) for item in chunk_items], dim=0 + ) + pad_tensor = (pp_tensor == -1).all(dim=-1) + + inputs_embeds = vt.patch_embedder(pv_tensor, pp_tensor, pad_tensor) + encoder_outputs = vt.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~pad_tensor, + pixel_position_ids=pp_tensor, + ) + hidden_states = encoder_outputs.last_hidden_state + + for i, (orig_idx, _, _) in enumerate(chunk_items): + last_hidden_states_map[orig_idx] = hidden_states[i] + + # Pool per image to strip padding and reduce spatial resolution. + all_valid_states: list[torch.Tensor] = [None] * total_images # type: ignore[list-item] + valid_lens = [0] * total_images + + for orig_idx in range(total_images): + chunk_hidden = last_hidden_states_map[orig_idx] + output_length = chunk_hidden.shape[0] // pooling_k2 + + single_hidden = chunk_hidden.unsqueeze(0) + single_pos_ids = pixel_position_ids[orig_idx].unsqueeze(0) + padding_positions = (single_pos_ids == -1).all(dim=-1) + + pooled_states, valid_mask = vt.pooler( + hidden_states=single_hidden, + pixel_position_ids=single_pos_ids, + padding_positions=padding_positions, + output_length=output_length, ) - for img in per_image_features - ] + valid_states = pooled_states[valid_mask] + + if getattr(vt.config, "standardize", False): + valid_states = (valid_states - vt.std_bias) * vt.std_scale + + all_valid_states[orig_idx] = valid_states + valid_lens[orig_idx] = valid_states.shape[0] + + target_dtype = self.embed_vision.embedding_projection.weight.dtype + + # Project all images in a single batched call. + flat_valid_states = torch.cat(all_valid_states, dim=0).to(target_dtype) + flat_proj_embs = self.embed_vision( + inputs_embeds=flat_valid_states.unsqueeze(0) + ).squeeze(0) + + # Split back into per-image tensors (slicing returns views). + per_image_embeddings: list[torch.Tensor] = [] + offset = 0 + for length in valid_lens: + per_image_embeddings.append(flat_proj_embs[offset : offset + length]) + offset += length + + return per_image_embeddings # ------------------------------------------------------------------ # # Video processing (frames through vision tower) @@ -1184,16 +1231,16 @@ def _process_video_input( self, video_input: dict[str, torch.Tensor], ) -> list[torch.Tensor]: - """Process video frames through the vision tower. + """Batch-encode video frames through the vision tower. - Reuses the image processing pipeline — Gemma4 has no separate - video tower; video frames are just images at lower resolution - (max_soft_tokens=70). + Gemma4 has no separate video tower; video frames are images at + lower resolution (max_soft_tokens=70). All frames across all + videos in the batch are encoded together in chunks, then pooled + and projected in a single batched call. Returns one concatenated embedding tensor per video (not per - frame), because vLLM treats one video as one multimodal item. - The flat_from_sizes field config groups all frames of a video - together, so embed_multimodal must return one tensor per video. + frame), matching the flat_from_sizes grouping that vLLM expects + for embed_multimodal. """ pixel_values = video_input["pixel_values_videos"] pixel_position_ids = video_input["pixel_position_ids_videos"] @@ -1203,35 +1250,74 @@ def _process_video_input( pooling_k2 = self.config.vision_config.pooling_kernel_size**2 target_dtype = self.embed_vision.embedding_projection.weight.dtype - # Split flat tensors into per-video chunks if isinstance(frame_counts, torch.Tensor): fc_list = frame_counts.tolist() else: fc_list = list(frame_counts) - pv_per_video = torch.split(pixel_values, fc_list, dim=0) - pp_per_video = torch.split(pixel_position_ids, fc_list, dim=0) + total_frames = pixel_values.shape[0] + max_batch_size = min( + total_frames, self._encoder_max_batch(pixel_values.shape[1]) + ) - per_video_embeddings = [] - for pv_chunk, pp_chunk in zip(pv_per_video, pp_per_video): - frame_embs = [] - for i in range(pv_chunk.shape[0]): - pv = pv_chunk[i].unsqueeze(0) - pp = pp_chunk[i].unsqueeze(0) + padding_positions = (pixel_position_ids == -1).all(dim=-1) - max_patches = pv.shape[1] - output_length = max_patches // pooling_k2 + # Encode frames in chunks bounded by _encoder_max_batch. + last_hidden_states_list: list[torch.Tensor] = [] + for i in range(0, total_frames, max_batch_size): + pv_chunk = pixel_values[i : i + max_batch_size] + pp_chunk = pixel_position_ids[i : i + max_batch_size] + pad_chunk = padding_positions[i : i + max_batch_size] - vt_output = vt(pv, pp, output_length=output_length) - frame_emb = self.embed_vision( - inputs_embeds=( - vt_output.last_hidden_state.unsqueeze(0).to(target_dtype) - ) - ).squeeze(0) - frame_embs.append(frame_emb) + inputs_embeds = vt.patch_embedder(pv_chunk, pp_chunk, pad_chunk) + encoder_outputs = vt.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~pad_chunk, + pixel_position_ids=pp_chunk, + ) + last_hidden_states_list.append(encoder_outputs.last_hidden_state) + + last_hidden_states = torch.cat(last_hidden_states_list, dim=0) - # Concatenate all frames of this video into one tensor. - per_video_embeddings.append(torch.cat(frame_embs, dim=0)) + # Pool per frame to strip padding and reduce spatial resolution. + output_length = pixel_values.shape[1] // pooling_k2 + all_frame_valid_states: list[torch.Tensor] = [] + frame_valid_lens: list[int] = [] + + for i in range(total_frames): + single_hidden = last_hidden_states[i].unsqueeze(0) + single_pos_ids = pixel_position_ids[i].unsqueeze(0) + single_pad_pos = padding_positions[i].unsqueeze(0) + + pooled_states, valid_mask = vt.pooler( + hidden_states=single_hidden, + pixel_position_ids=single_pos_ids, + padding_positions=single_pad_pos, + output_length=output_length, + ) + valid_states = pooled_states[valid_mask] + + if getattr(vt.config, "standardize", False): + valid_states = (valid_states - vt.std_bias) * vt.std_scale + + all_frame_valid_states.append(valid_states) + frame_valid_lens.append(valid_states.shape[0]) + + # Project all frames in a single batched call. + flat_valid_states = torch.cat(all_frame_valid_states, dim=0).to(target_dtype) + flat_proj_embs = self.embed_vision( + inputs_embeds=flat_valid_states.unsqueeze(0) + ).squeeze(0) + + # Regroup into per-video tensors (slicing returns views). + per_video_embeddings: list[torch.Tensor] = [] + frame_idx = 0 + offset = 0 + for count in fc_list: + video_tokens = sum(frame_valid_lens[frame_idx : frame_idx + count]) + per_video_embeddings.append(flat_proj_embs[offset : offset + video_tokens]) + offset += video_tokens + frame_idx += count return per_video_embeddings @@ -1452,7 +1538,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self, ignore_unexpected_prefixes=ignore_prefixes, ) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + # Per-patch activation cost for dynamic encoder batch sizing. + vis_cfg = self.config.vision_config + self._encoder_bytes_per_patch = ( + vis_cfg.hidden_size * 2 * vis_cfg.num_hidden_layers + ) + + return loaded # ------------------------------------------------------------------ # # LoRA / multimodal mapping