Skip to content
266 changes: 180 additions & 86 deletions vllm/model_executor/models/gemma4_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
PromptUpdate,
PromptUpdateDetails,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

Expand Down Expand Up @@ -960,6 +961,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.embed_vision = Gemma4MultimodalEmbedder(
config.vision_config, config.text_config
)
# _encoder_budget_bytes is lazily initialized on the first encoder call
# (see _encoder_max_batch); _encoder_bytes_per_patch is set in load_weights.
self._encoder_budget_bytes = 0
self._encoder_bytes_per_patch = 0

# ---- Audio tower (variants with audio_config) ----
if config.audio_config is not None:
Expand Down Expand Up @@ -1100,6 +1104,19 @@ def _parse_and_validate_multimodal_inputs(
)
return mm_input_by_modality

def _encoder_max_batch(self, patches_per_item: int) -> int:
    """Return how many items a single encoder call may process.

    When ``self._encoder_budget_bytes`` is still 0, it is first set to
    5% of ``current_platform.get_device_total_memory()`` and the result
    is logged. The estimated cost of one item is
    ``patches_per_item * self._encoder_bytes_per_patch``.

    Args:
        patches_per_item: Patch count of every item in the batch.

    Returns:
        ``budget // cost`` clamped to a minimum of 1, or 1 when the
        estimated cost is not positive (for example while
        ``self._encoder_bytes_per_patch`` is still 0).
    """
    if self._encoder_budget_bytes == 0:
        # A zero budget is treated as "not computed yet".
        device_bytes = current_platform.get_device_total_memory()
        self._encoder_budget_bytes = int(device_bytes * 0.05)
        gib = 1024**3
        logger.info(
            "Encoder memory budget: %.1fGB (total=%.1fGB)",
            self._encoder_budget_bytes / gib,
            device_bytes / gib,
        )

    item_bytes = patches_per_item * self._encoder_bytes_per_patch
    if item_bytes > 0:
        # Never drop below one item per call, even if one item alone
        # exceeds the budget.
        return max(1, self._encoder_budget_bytes // item_bytes)
    return 1

# ------------------------------------------------------------------ #
# Image processing
# ------------------------------------------------------------------ #
Expand All @@ -1108,73 +1125,103 @@ def _process_image_input(
self,
image_input: Gemma4ImageInputs,
) -> list[torch.Tensor]:
"""Batch-encode images through the vision tower.

Groups images by patch count (resolution bucket) so each
encoder call processes a uniform-shape batch with no
cross-resolution padding. Pooling is then applied per image, and
projection is applied once over a single concatenated tensor.
"""
pixel_values = image_input["pixel_values"]
pixel_position_ids = image_input["pixel_position_ids"]

# The HF image processor now outputs pre-patchified data:
# pixel_values: (num_images, max_patches, patch_pixels)
# pixel_position_ids: (num_images, max_patches, 2)
# We call the vision tower's forward() directly, which handles
# patch embedding, encoding, pooling, padding removal, and
# optional standardization internally.
vt = self.vision_tower
pooling_k2 = self.config.vision_config.pooling_kernel_size**2

# TODO: Move this per-image loop into the input processor to
# reduce dynamism at the model runner / engine core. This
# requires spatially padding all images to uniform (H_max,
# W_max) in _call_hf_processor() so they arrive as a single
# stacked tensor, tracking padded regions via image_sizes
# metadata, and validating numerical equivalence with the
# current per-image path.
#
# Concurrent requests with different image resolutions may
# arrive as a list of per-image tensors, while same-resolution
# batches may arrive as a stacked tensor. Both forms are
# iterable over the per-image dimension.

# Process each image individually through the vision tower.
# The vision tower's forward() strips padding and returns a
# flat tensor of valid tokens. We process per-image to get
# variable-length outputs matching the dynamic token count
# from get_image_repl.
per_image_features = []
for pv, pp in zip(pixel_values, pixel_position_ids, strict=True):
pv = pv.unsqueeze(0) # (1, max_patches, patch_pixels)
pp = pp.unsqueeze(0) # (1, max_patches, 2)

# Derive the pooler's output_length from the total patch
# count (including padding). The vision tower encoder
# processes ALL patches — padding patches get zero hidden
# states but still occupy sequence positions. The pooler's
# _avg_pool_by_positions requires:
# input_seq_len / output_length == k²
# where k == pooling_kernel_size. The image processor
# allocates max_patches = max_soft_tokens * k² total slots,
# so output_length = max_patches / k² == max_soft_tokens.
# Without this, the pooler falls back to
# config.image_seq_length (e.g. 280), which fails when a
# different max_soft_tokens was used at preprocessing time.
max_patches = pv.shape[1]
output_length = max_patches // pooling_k2

vt_output = vt(pv, pp, output_length=output_length)
# last_hidden_state: (num_valid_tokens, hidden_size)
# — already flat with padding stripped by the vision tower
per_image_features.append(vt_output.last_hidden_state)

# Project each image's features into LM embedding space.
# Per-image loop is required because images have variable
# token counts after padding removal.
# Cast to match the projection layer's dtype (model may be
# bf16 while the vision tower outputs fp32).
target_dtype = self.embed_vision.embedding_projection.weight.dtype
return [
self.embed_vision(inputs_embeds=img.unsqueeze(0).to(target_dtype)).squeeze(
0
# batches may arrive as a stacked tensor.
buckets: dict[int, list[tuple[int, torch.Tensor, torch.Tensor]]] = {}
total_images = (
len(pixel_values)
if isinstance(pixel_values, list)
else pixel_values.shape[0]
)

for idx in range(total_images):
pv = pixel_values[idx]
pp = pixel_position_ids[idx]
buckets.setdefault(pv.shape[0], []).append((idx, pv, pp))

# Encode each resolution bucket in memory-safe chunks.
last_hidden_states_map: dict[int, torch.Tensor] = {}
for patches, items in buckets.items():
max_batch_size = min(len(items), self._encoder_max_batch(patches))

for chunk_idx in range(0, len(items), max_batch_size):
chunk_items = items[chunk_idx : chunk_idx + max_batch_size]

pv_tensor = torch.cat(
[item[1].unsqueeze(0) for item in chunk_items], dim=0
)
pp_tensor = torch.cat(
[item[2].unsqueeze(0) for item in chunk_items], dim=0
)
pad_tensor = (pp_tensor == -1).all(dim=-1)

inputs_embeds = vt.patch_embedder(pv_tensor, pp_tensor, pad_tensor)
encoder_outputs = vt.encoder(
inputs_embeds=inputs_embeds,
attention_mask=~pad_tensor,
pixel_position_ids=pp_tensor,
)
hidden_states = encoder_outputs.last_hidden_state

for i, (orig_idx, _, _) in enumerate(chunk_items):
last_hidden_states_map[orig_idx] = hidden_states[i]

# Pool per image to strip padding and reduce spatial resolution.
all_valid_states: list[torch.Tensor] = [None] * total_images # type: ignore[list-item]
valid_lens = [0] * total_images

for orig_idx in range(total_images):
chunk_hidden = last_hidden_states_map[orig_idx]
output_length = chunk_hidden.shape[0] // pooling_k2

single_hidden = chunk_hidden.unsqueeze(0)
single_pos_ids = pixel_position_ids[orig_idx].unsqueeze(0)
padding_positions = (single_pos_ids == -1).all(dim=-1)

pooled_states, valid_mask = vt.pooler(
hidden_states=single_hidden,
pixel_position_ids=single_pos_ids,
padding_positions=padding_positions,
output_length=output_length,
)
for img in per_image_features
]
valid_states = pooled_states[valid_mask]

if getattr(vt.config, "standardize", False):
valid_states = (valid_states - vt.std_bias) * vt.std_scale

all_valid_states[orig_idx] = valid_states
valid_lens[orig_idx] = valid_states.shape[0]

target_dtype = self.embed_vision.embedding_projection.weight.dtype

# Project all images in a single batched call.
flat_valid_states = torch.cat(all_valid_states, dim=0).to(target_dtype)
flat_proj_embs = self.embed_vision(
inputs_embeds=flat_valid_states.unsqueeze(0)
).squeeze(0)

# Split back into per-image tensors (slicing returns views).
per_image_embeddings: list[torch.Tensor] = []
offset = 0
for length in valid_lens:
per_image_embeddings.append(flat_proj_embs[offset : offset + length])
offset += length

return per_image_embeddings

# ------------------------------------------------------------------ #
# Video processing (frames through vision tower)
Expand All @@ -1184,16 +1231,16 @@ def _process_video_input(
self,
video_input: dict[str, torch.Tensor],
) -> list[torch.Tensor]:
"""Process video frames through the vision tower.
"""Batch-encode video frames through the vision tower.

Reuses the image processing pipeline — Gemma4 has no separate
video tower; video frames are just images at lower resolution
(max_soft_tokens=70).
Gemma4 has no separate video tower; video frames are images at
lower resolution (max_soft_tokens=70). All frames across all
videos in the batch are encoded together in chunks, pooled one
frame at a time, and then projected in a single batched call.

Returns one concatenated embedding tensor per video (not per
frame), because vLLM treats one video as one multimodal item.
The flat_from_sizes field config groups all frames of a video
together, so embed_multimodal must return one tensor per video.
frame), matching the flat_from_sizes grouping that vLLM expects
for embed_multimodal.
"""
pixel_values = video_input["pixel_values_videos"]
pixel_position_ids = video_input["pixel_position_ids_videos"]
Expand All @@ -1203,35 +1250,74 @@ def _process_video_input(
pooling_k2 = self.config.vision_config.pooling_kernel_size**2
target_dtype = self.embed_vision.embedding_projection.weight.dtype

# Split flat tensors into per-video chunks
if isinstance(frame_counts, torch.Tensor):
fc_list = frame_counts.tolist()
else:
fc_list = list(frame_counts)

pv_per_video = torch.split(pixel_values, fc_list, dim=0)
pp_per_video = torch.split(pixel_position_ids, fc_list, dim=0)
total_frames = pixel_values.shape[0]
max_batch_size = min(
total_frames, self._encoder_max_batch(pixel_values.shape[1])
)

per_video_embeddings = []
for pv_chunk, pp_chunk in zip(pv_per_video, pp_per_video):
frame_embs = []
for i in range(pv_chunk.shape[0]):
pv = pv_chunk[i].unsqueeze(0)
pp = pp_chunk[i].unsqueeze(0)
padding_positions = (pixel_position_ids == -1).all(dim=-1)

max_patches = pv.shape[1]
output_length = max_patches // pooling_k2
# Encode frames in chunks bounded by _encoder_max_batch.
last_hidden_states_list: list[torch.Tensor] = []
for i in range(0, total_frames, max_batch_size):
pv_chunk = pixel_values[i : i + max_batch_size]
pp_chunk = pixel_position_ids[i : i + max_batch_size]
pad_chunk = padding_positions[i : i + max_batch_size]

vt_output = vt(pv, pp, output_length=output_length)
frame_emb = self.embed_vision(
inputs_embeds=(
vt_output.last_hidden_state.unsqueeze(0).to(target_dtype)
)
).squeeze(0)
frame_embs.append(frame_emb)
inputs_embeds = vt.patch_embedder(pv_chunk, pp_chunk, pad_chunk)
encoder_outputs = vt.encoder(
inputs_embeds=inputs_embeds,
attention_mask=~pad_chunk,
pixel_position_ids=pp_chunk,
)
last_hidden_states_list.append(encoder_outputs.last_hidden_state)

last_hidden_states = torch.cat(last_hidden_states_list, dim=0)

# Concatenate all frames of this video into one tensor.
per_video_embeddings.append(torch.cat(frame_embs, dim=0))
# Pool per frame to strip padding and reduce spatial resolution.
output_length = pixel_values.shape[1] // pooling_k2
all_frame_valid_states: list[torch.Tensor] = []
frame_valid_lens: list[int] = []

for i in range(total_frames):
single_hidden = last_hidden_states[i].unsqueeze(0)
single_pos_ids = pixel_position_ids[i].unsqueeze(0)
single_pad_pos = padding_positions[i].unsqueeze(0)

pooled_states, valid_mask = vt.pooler(
hidden_states=single_hidden,
pixel_position_ids=single_pos_ids,
padding_positions=single_pad_pos,
output_length=output_length,
)
valid_states = pooled_states[valid_mask]

if getattr(vt.config, "standardize", False):
valid_states = (valid_states - vt.std_bias) * vt.std_scale

all_frame_valid_states.append(valid_states)
frame_valid_lens.append(valid_states.shape[0])

# Project all frames in a single batched call.
flat_valid_states = torch.cat(all_frame_valid_states, dim=0).to(target_dtype)
Comment thread
lucianommartins marked this conversation as resolved.
flat_proj_embs = self.embed_vision(
inputs_embeds=flat_valid_states.unsqueeze(0)
).squeeze(0)

# Regroup into per-video tensors (slicing returns views).
per_video_embeddings: list[torch.Tensor] = []
frame_idx = 0
offset = 0
for count in fc_list:
video_tokens = sum(frame_valid_lens[frame_idx : frame_idx + count])
per_video_embeddings.append(flat_proj_embs[offset : offset + video_tokens])
offset += video_tokens
frame_idx += count

return per_video_embeddings

Expand Down Expand Up @@ -1452,7 +1538,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
self,
ignore_unexpected_prefixes=ignore_prefixes,
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

# Per-patch activation cost for dynamic encoder batch sizing.
vis_cfg = self.config.vision_config
self._encoder_bytes_per_patch = (
vis_cfg.hidden_size * 2 * vis_cfg.num_hidden_layers
)

return loaded

# ------------------------------------------------------------------ #
# LoRA / multimodal mapping
Expand Down
Loading