Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 165 additions & 86 deletions python/sglang/srt/models/gemma4_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from sglang.srt.models.gemma4_audio import Gemma4AudioEncoder
from sglang.srt.models.gemma4_causal import Gemma4TextModel, pp_filter_load_weight
from sglang.srt.models.gemma4_vision import Gemma4VisionEncoder
from sglang.srt.utils import add_prefix
from sglang.srt.utils import add_prefix, get_device_memory_capacity
from sglang.srt.utils.hf_transformers_utils import get_processor

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -259,6 +259,10 @@ def __init__(
self.logits_processor = LogitsProcessor(config.text_config)
self.capture_aux_hidden_states = False

# Lazy-initialized dynamic batch sizing for the vision encoder.
self._encoder_budget_bytes = 0
self._encoder_bytes_per_patch = 0

self.post_init()

@property
Expand Down Expand Up @@ -396,124 +400,189 @@ def prepare_attn_masks(
)
get_attn_backend().forward_metadata.custom_mask = bidirectional_attn_masks

def _encoder_max_batch(self, patches_per_item: int) -> int:
    """Return the max number of items per vision-encoder call.

    The limit is a memory budget divided by the estimated cost of one item
    (``patches_per_item * self._encoder_bytes_per_patch``). The budget is
    5% of the device memory capacity; it is probed on first use and cached
    in ``self._encoder_budget_bytes`` (``0`` = not probed yet, ``-1`` =
    capacity unknown, so the probe is not repeated on every call).

    Args:
        patches_per_item: Patch count of each item in the prospective batch.

    Returns:
        A batch size of at least 1. Falls back to 1 while the per-patch cost
        is still 0 (it is populated by ``load_weights``), when the device
        memory capacity is unknown, or when the per-item cost is not
        positive.
    """
    if self._encoder_bytes_per_patch == 0:
        return 1

    if self._encoder_budget_bytes == 0:
        # MiB or None; cache None as -1 so we probe once, not every call.
        total_mem_mib = get_device_memory_capacity(self.vision_tower.device.type)
        if total_mem_mib:
            self._encoder_budget_bytes = int(total_mem_mib * (1 << 20) * 0.05)
        else:
            self._encoder_budget_bytes = -1

    if self._encoder_budget_bytes < 0:
        return 1

    cost = patches_per_item * self._encoder_bytes_per_patch
    if cost <= 0:
        return 1
    return max(1, self._encoder_budget_bytes // cost)

def _flatten_pixel_lists(
    self,
    items: List[MultimodalDataItem],
    position_ids_attr: str,
    modality_label: str,
) -> Tuple[List[Tuple[bool, object]], List[torch.Tensor], List[torch.Tensor]]:
    """Walk `items` in order and split them into encoder inputs and slots.

    Args:
        items: Multimodal items whose `feature` holds pixel tensors or
            pre-computed embeddings.
        position_ids_attr: Name of the item attribute that holds the
            position ids (e.g. "image_position_ids").
        modality_label: Label used in the error message to identify the
            offending input.

    Returns:
        A tuple `(slots, pixel_values_list, position_ids_list)`:
        - `slots`: one `(is_prepass, payload)` per output group, in original
          walk order. `payload` is a ready prepass embedding
          (`is_prepass=True`) or an index into `pixel_values_list`. The slot
          order keeps the output identical to a per-item loop even when a
          request interleaves prepass embeddings and raw pixels.
        - `pixel_values_list`: per-item pixel tensors
          (num_patches, patch_px); video items contribute one entry per
          frame.
        - `position_ids_list`: matching (num_patches, 2) tensors,
          -1 = padding.

    Raises:
        ValueError: If a pixel tensor has no matching position ids, or if
            the two disagree on the number of items after normalization.
    """
    slots: List[Tuple[bool, object]] = []
    pixel_values_list: List[torch.Tensor] = []
    position_ids_list: List[torch.Tensor] = []

    for item in items:
        all_pixel_values = flatten_nested_list([item.feature])
        all_position_ids = flatten_nested_list(
            [getattr(item, position_ids_attr, None)]
        )

        for pv_idx, pv in enumerate(all_pixel_values):
            # Caller pre-computed the embedding; nothing to encode.
            if (
                pv.dim() in (2, 3)
                and pv.shape[-1] == self.config.text_config.hidden_size
            ):
                slots.append((True, pv.to(self.language_model.device)))
                continue

            if pv_idx >= len(all_position_ids) or all_position_ids[pv_idx] is None:
                raise ValueError(
                    f"{modality_label}[{pv_idx}] has no matching "
                    f"{position_ids_attr}. The HF processor likely "
                    "renamed this output — update ATTR_NAME_TO_MODALITY "
                    "in the Gemma4 processor."
                )
            pp = all_position_ids[pv_idx]

            # Normalize to 3-D (items, num_patches, ...); 4-D video tensors
            # (num_videos, num_frames, ...) flatten into the leading dim.
            if pv.dim() == 2:
                pv = pv.unsqueeze(0)
            if pp.dim() == 2:
                pp = pp.unsqueeze(0)
            if pv.dim() == 4:
                pv = pv.reshape(-1, pv.shape[-2], pv.shape[-1])
            if pp.dim() == 4:
                pp = pp.reshape(-1, pp.shape[-2], pp.shape[-1])

            # zip() below would silently drop the surplus entries and
            # misalign embeddings with their placeholders, so fail loudly.
            if pv.shape[0] != pp.shape[0]:
                raise ValueError(
                    f"{modality_label}[{pv_idx}] has {pv.shape[0]} item(s) but "
                    f"{position_ids_attr} has {pp.shape[0]}."
                )

            # unbind() returns views, so per-item split is copy-free.
            for sub_pv, sub_pp in zip(pv.unbind(0), pp.unbind(0)):
                slots.append((False, len(pixel_values_list)))
                pixel_values_list.append(sub_pv)
                position_ids_list.append(sub_pp)

    return slots, pixel_values_list, position_ids_list

def _batched_encode(
    self,
    pixel_values_list: List[torch.Tensor],
    position_ids_list: List[torch.Tensor],
) -> List[torch.Tensor]:
    """Encode per-item pixel tensors and return per-item embeddings.

    Runs the vision tower on `pixel_values_list` in resolution buckets
    (items with the same patch count are stacked into one batch, split into
    chunks of at most `self._encoder_max_batch(patches)` items), runs
    `embed_vision` exactly once over all valid tokens, and returns the
    per-item embeddings in the original input order.

    Args:
        pixel_values_list: One (num_patches, patch_px) tensor per item.
        position_ids_list: Matching position-id tensor for each item.

    Returns:
        One embedding tensor per input item, in input order. An empty input
        yields an empty list without touching the encoder.
    """
    if not pixel_values_list:
        return []

    vt = self.vision_tower
    target_device = vt.device
    target_dtype = self.language_model.dtype()

    # Bucket by patch count so each encoder forward is a same-shape batch
    # with no cross-resolution padding waste.
    buckets: dict = {}
    for idx, pv in enumerate(pixel_values_list):
        buckets.setdefault(pv.shape[0], []).append(idx)

    per_item_valid_tokens: List[Optional[torch.Tensor]] = [None] * len(
        pixel_values_list
    )

    for patches, member_indices in buckets.items():
        # Buckets are never empty, so max_batch is always >= 1.
        max_batch = min(len(member_indices), self._encoder_max_batch(patches))

        for chunk_start in range(0, len(member_indices), max_batch):
            chunk_indices = member_indices[chunk_start : chunk_start + max_batch]

            pv_batch = torch.stack(
                [pixel_values_list[i] for i in chunk_indices], dim=0
            ).to(device=target_device, dtype=target_dtype)
            pp_batch = torch.stack(
                [position_ids_list[i] for i in chunk_indices], dim=0
            ).to(device=target_device)

            # pooler_mask marks valid tokens; valid widths differ per item.
            pooled, pooler_mask = vt(pv_batch, pp_batch)

            # Scatter results back to each item's original position.
            for chunk_pos, orig_idx in enumerate(chunk_indices):
                per_item_valid_tokens[orig_idx] = pooled[chunk_pos][
                    pooler_mask[chunk_pos]
                ]

    # embed_vision is pointwise (RMSNorm + Linear), so one call over all
    # valid tokens is identical to per-item projection.
    valid_lens = [t.shape[0] for t in per_item_valid_tokens]
    flat_tokens = torch.cat(per_item_valid_tokens, dim=0)
    flat_projected = self.embed_vision(
        inputs_embeds=flat_tokens.unsqueeze(0)
    ).squeeze(0)

    # Cut the flat projection back into per-item views, in input order.
    return list(torch.split(flat_projected, valid_lens, dim=0))

def _gather_mm_features(
    self,
    items: List[MultimodalDataItem],
    position_ids_attr: str,
    modality_label: str,
) -> torch.Tensor:
    """Common driver shared by the image and video paths.

    Flattens `items` into encoder inputs plus ordered slots, encodes the raw
    pixel inputs in batches, then stitches encoded and pre-computed
    embeddings back together in the original walk order.

    Args:
        items: Multimodal items to embed.
        position_ids_attr: Name of the item attribute that holds the
            position ids.
        modality_label: Label forwarded to `_flatten_pixel_lists` for its
            error message.

    Returns:
        All embeddings concatenated along dim 0, or an empty
        (0, hidden_size) tensor when there is nothing to embed.
    """
    slots, pv_list, pp_list = self._flatten_pixel_lists(
        items, position_ids_attr, modality_label
    )
    encoded_embeds = self._batched_encode(pv_list, pp_list)

    # Reassemble in walk order to match the pre-batching per-item loop.
    all_embeds = [
        payload if is_prepass else encoded_embeds[payload]
        for is_prepass, payload in slots
    ]

    if all_embeds:
        return torch.cat(all_embeds, dim=0)

    return torch.empty(
        0,
        self.language_model.config.hidden_size,
        device=next(self.parameters()).device,
        dtype=self.language_model.dtype(),
    )

def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
    """Embed image items through the shared multimodal feature driver.

    Position ids are read from each item's `image_position_ids` attribute and
    inputs are labelled `pixel_values`.
    """
    position_ids_attr = "image_position_ids"
    modality_label = "pixel_values"
    return self._gather_mm_features(items, position_ids_attr, modality_label)

def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
    """Embed video items through the shared multimodal feature driver.

    Position ids are read from each item's `video_position_ids` attribute and
    inputs are labelled `pixel_values_videos`.
    """
    # Gemma4 has no separate video tower; frames go through the same
    # bucketed image path.
    position_ids_attr = "video_position_ids"
    modality_label = "pixel_values_videos"
    return self._gather_mm_features(items, position_ids_attr, modality_label)

def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
if self.audio_tower is None:
Expand Down Expand Up @@ -1023,6 +1092,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
names = sorted(p for p in unloaded_params if pred(p))
if names:
logger.log(level, "%s: %s", msg, names)

# Cache the per-patch cost for `_encoder_max_batch` here (not __init__)
# so it reflects the actually-loaded vision_config.
vis_cfg = getattr(self.config, "vision_config", None)
if vis_cfg is not None and self.pp_group.is_first_rank:
hidden = int(getattr(vis_cfg, "hidden_size", 0))
num_layers = int(getattr(vis_cfg, "num_hidden_layers", 0))
# 2 bytes/elem (bf16) × residual stream per patch × layers.
self._encoder_bytes_per_patch = hidden * 2 * num_layers

return loaded_params

lora_pattern = re.compile(
Expand Down
Loading
Loading