diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index 3629782ca3ba..75d237f99091 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -61,7 +61,7 @@ from sglang.srt.models.gemma4_audio import Gemma4AudioEncoder from sglang.srt.models.gemma4_causal import Gemma4TextModel, pp_filter_load_weight from sglang.srt.models.gemma4_vision import Gemma4VisionEncoder -from sglang.srt.utils import add_prefix +from sglang.srt.utils import add_prefix, get_device_memory_capacity from sglang.srt.utils.hf_transformers_utils import get_processor logger = logging.getLogger(__name__) @@ -259,6 +259,10 @@ def __init__( self.logits_processor = LogitsProcessor(config.text_config) self.capture_aux_hidden_states = False + # Lazy-initialized dynamic batch sizing for the vision encoder. + self._encoder_budget_bytes = 0 + self._encoder_bytes_per_patch = 0 + self.post_init() @property @@ -396,124 +400,189 @@ def prepare_attn_masks( ) get_attn_backend().forward_metadata.custom_mask = bidirectional_attn_masks - def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: - vt = self.vision_tower + def _encoder_max_batch(self, patches_per_item: int) -> int: + """Max items per encoder call for a given per-item patch count. + + Lazily computes a budget of 5% of device memory on first use and reuses + it. Falls back to a single-item batch before the per-patch cost is known + (populated by `load_weights`) or when memory is unknown. + """ + if self._encoder_bytes_per_patch == 0: + return 1 + if self._encoder_budget_bytes == 0: + # MiB or None; cache None as -1 so we probe once, not every call. + total_mem_mib = get_device_memory_capacity(self.vision_tower.device.type) + if total_mem_mib: + self._encoder_budget_bytes = int(total_mem_mib * (1 << 20) * 0.05) + else: + self._encoder_budget_bytes = -1 + if self._encoder_budget_bytes < 0: + return 1 + cost = patches_per_item * self._encoder_bytes_per_patch + if cost <= 0: + return 1 + return max(1, self._encoder_budget_bytes // cost) + + def _flatten_pixel_lists( + self, + items: List[MultimodalDataItem], + position_ids_attr: str, + modality_label: str, + ) -> Tuple[List[Tuple[bool, object]], List[torch.Tensor], List[torch.Tensor]]: + """Walk `items` in order and return: + - `slots`: one `(is_prepass, payload)` per output group, in original + walk order. `payload` is a ready prepass embedding (`is_prepass=True`) + or an index into `pixel_values_list`. The slot order keeps the output + identical to the per-item loop even when a request interleaves prepass + embeddings and raw pixels. + - `pixel_values_list`: per-item pixel tensors (num_patches, patch_px); + video items contribute one entry per frame. + - `position_ids_list`: matching (num_patches, 2) tensors, -1 = padding. + """ + slots: List[Tuple[bool, object]] = [] + pixel_values_list: List[torch.Tensor] = [] + position_ids_list: List[torch.Tensor] = [] - all_embeds = [] for item in items: all_pixel_values = flatten_nested_list([item.feature]) all_position_ids = flatten_nested_list( - [getattr(item, "image_position_ids", None)] + [getattr(item, position_ids_attr, None)] ) for pv_idx, pv in enumerate(all_pixel_values): + # Caller pre-computed the embedding; nothing to encode. if ( pv.dim() in (2, 3) and pv.shape[-1] == self.config.text_config.hidden_size ): - all_embeds.append(pv.to(self.language_model.device)) + slots.append((True, pv.to(self.language_model.device))) continue if pv_idx >= len(all_position_ids) or all_position_ids[pv_idx] is None: raise ValueError( - f"pixel_values[{pv_idx}] has no matching image_position_ids. " - "The HF image processor likely renamed this output — " - "update ATTR_NAME_TO_MODALITY in the Gemma4 processor." + f"{modality_label}[{pv_idx}] has no matching " + f"{position_ids_attr}. The HF processor likely " + "renamed this output — update ATTR_NAME_TO_MODALITY " + "in the Gemma4 processor." ) pp = all_position_ids[pv_idx] - # Vision tower expects 3-D (batch, num_patches, ...). - # A single image may arrive as 2-D; add the batch dim if needed. + # Normalize to 3-D (items, num_patches, ...); 4-D video tensors + # (num_videos, num_frames, ...) flatten into the leading dim. if pv.dim() == 2: pv = pv.unsqueeze(0) if pp.dim() == 2: pp = pp.unsqueeze(0) + if pv.dim() == 4: + pv = pv.reshape(-1, pv.shape[-2], pv.shape[-1]) + if pp.dim() == 4: + pp = pp.reshape(-1, pp.shape[-2], pp.shape[-1]) - pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) - pp = pp.to(device=vt.device) - - pooled, pooler_mask = vt(pv, pp) + # unbind() returns views, so per-item split is copy-free. + for sub_pv, sub_pp in zip(pv.unbind(0), pp.unbind(0)): + slots.append((False, len(pixel_values_list))) + pixel_values_list.append(sub_pv) + position_ids_list.append(sub_pp) - for hs, mask in zip(pooled, pooler_mask): - real_tokens = hs[mask] - all_embeds.append( - self.embed_vision( - inputs_embeds=real_tokens.unsqueeze(0) - ).squeeze(0) - ) + return slots, pixel_values_list, position_ids_list - if all_embeds: - return torch.cat(all_embeds, dim=0) - else: - return torch.empty( - 0, - self.language_model.config.hidden_size, - device=next(self.parameters()).device, - dtype=self.language_model.dtype(), - ) - - def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: - """Encode video frames through the vision tower with video-specific pooling. - - Each video is (num_frames, num_patches, patch_pixels) with matching - position_ids (num_frames, num_patches, 2). Frames are flattened into - the batch dimension so each frame is encoded independently, then pooled - dynamically based on the input patch count and pooling_kernel_size. + def _batched_encode( + self, + pixel_values_list: List[torch.Tensor], + position_ids_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Run the vision tower on `pixel_values_list` in resolution buckets, + run `embed_vision` exactly once over all valid tokens, and return the + per-item embeddings in the original input order. """ - vt = self.vision_tower + if not pixel_values_list: + return [] - all_embeds = [] - for item in items: - all_pixel_values = flatten_nested_list([item.feature]) - all_position_ids = flatten_nested_list( - [getattr(item, "video_position_ids", None)] - ) - - for pv_idx, pv in enumerate(all_pixel_values): - if ( - pv.dim() in (2, 3) - and pv.shape[-1] == self.config.text_config.hidden_size - ): - all_embeds.append(pv.to(self.language_model.device)) - continue - - if pv_idx >= len(all_position_ids) or all_position_ids[pv_idx] is None: - raise ValueError( - f"pixel_values_videos[{pv_idx}] has no matching video_position_ids." - ) - pp = all_position_ids[pv_idx] - - # HF processor returns 4-D tensors - # (num_videos, num_frames, num_patches, ...) — collapse to - # 3-D (num_frames, num_patches, ...) so each frame is a - # batch element for the vision tower. - if pv.dim() == 4: - pv = pv.reshape(-1, pv.shape[-2], pv.shape[-1]) - if pp.dim() == 4: - pp = pp.reshape(-1, pp.shape[-2], pp.shape[-1]) + vt = self.vision_tower + target_device = vt.device + target_dtype = self.language_model.dtype() - pv = pv.to(device=vt.device, dtype=self.language_model.dtype()) - pp = pp.to(device=vt.device) + # Bucket by patch count so each encoder forward is a same-shape batch + # with no cross-resolution padding waste. + buckets: dict = {} + for idx, pv in enumerate(pixel_values_list): + buckets.setdefault(pv.shape[0], []).append(idx) - pooled, pooler_mask = vt(pv, pp) + per_item_valid_tokens: List[Optional[torch.Tensor]] = [None] * len( + pixel_values_list + ) - for hs, mask in zip(pooled, pooler_mask): - real_tokens = hs[mask] - all_embeds.append( - self.embed_vision( - inputs_embeds=real_tokens.unsqueeze(0) - ).squeeze(0) - ) + for patches, member_indices in buckets.items(): + max_batch = min(len(member_indices), self._encoder_max_batch(patches)) + + for chunk_start in range(0, len(member_indices), max_batch): + chunk_indices = member_indices[chunk_start : chunk_start + max_batch] + + pv_batch = torch.stack( + [pixel_values_list[i] for i in chunk_indices], dim=0 + ).to(device=target_device, dtype=target_dtype) + pp_batch = torch.stack( + [position_ids_list[i] for i in chunk_indices], dim=0 + ).to(device=target_device) + + # pooler_mask marks valid tokens; valid widths differ per item. + pooled, pooler_mask = vt(pv_batch, pp_batch) + + for chunk_pos, orig_idx in enumerate(chunk_indices): + per_item_valid_tokens[orig_idx] = pooled[chunk_pos][ + pooler_mask[chunk_pos] + ] + + # embed_vision is pointwise (RMSNorm + Linear), so one call over all + # valid tokens is identical to per-item projection. + valid_lens = [t.shape[0] for t in per_item_valid_tokens] + flat_tokens = torch.cat(per_item_valid_tokens, dim=0) + flat_projected = self.embed_vision( + inputs_embeds=flat_tokens.unsqueeze(0) + ).squeeze(0) + + per_item_embeds: List[torch.Tensor] = [] + offset = 0 + for length in valid_lens: + per_item_embeds.append(flat_projected[offset : offset + length]) + offset += length + return per_item_embeds + + def _gather_mm_features( + self, + items: List[MultimodalDataItem], + position_ids_attr: str, + modality_label: str, + ) -> torch.Tensor: + """Common driver shared by image and video paths.""" + slots, pv_list, pp_list = self._flatten_pixel_lists( + items, position_ids_attr, modality_label + ) + encoded_embeds = self._batched_encode(pv_list, pp_list) + # Reassemble in walk order to match the pre-batching per-item loop. + all_embeds = [ + payload if is_prepass else encoded_embeds[payload] + for is_prepass, payload in slots + ] if all_embeds: return torch.cat(all_embeds, dim=0) - else: - return torch.empty( - 0, - self.language_model.config.hidden_size, - device=next(self.parameters()).device, - dtype=self.language_model.dtype(), - ) + return torch.empty( + 0, + self.language_model.config.hidden_size, + device=next(self.parameters()).device, + dtype=self.language_model.dtype(), + ) + + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + return self._gather_mm_features(items, "image_position_ids", "pixel_values") + + def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + # Gemma4 has no separate video tower; frames go through the same + # bucketed image path. + return self._gather_mm_features( + items, "video_position_ids", "pixel_values_videos" + ) def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: if self.audio_tower is None: @@ -1023,6 +1092,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): names = sorted(p for p in unloaded_params if pred(p)) if names: logger.log(level, "%s: %s", msg, names) + + # Cache the per-patch cost for `_encoder_max_batch` here (not __init__) + # so it reflects the actually-loaded vision_config. + vis_cfg = getattr(self.config, "vision_config", None) + if vis_cfg is not None and self.pp_group.is_first_rank: + hidden = int(getattr(vis_cfg, "hidden_size", 0)) + num_layers = int(getattr(vis_cfg, "num_hidden_layers", 0)) + # 2 bytes/elem (bf16) × residual stream per patch × layers. + self._encoder_bytes_per_patch = hidden * 2 * num_layers + return loaded_params lora_pattern = re.compile( diff --git a/test/registered/unit/models/test_gemma4_mm_batched_encoder.py b/test/registered/unit/models/test_gemma4_mm_batched_encoder.py new file mode 100644 index 000000000000..4379533438d8 --- /dev/null +++ b/test/registered/unit/models/test_gemma4_mm_batched_encoder.py @@ -0,0 +1,226 @@ +"""Unit tests for the batched vision-encoder path in ``gemma4_mm.py``. + +The vision tower and embedder are stubbed so the tests run on CPU without the +real Gemma-4 checkpoint; they assert encoder/embedder call counts, output +ordering, budget-bound chunking, and device-agnostic budget computation. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import List + +import torch + +from sglang.srt.models import gemma4_mm as gemma4_mm_module +from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci + +register_cuda_ci(est_time=30, stage="base-b", runner_config="1-gpu-small") +register_amd_ci(est_time=30, suite="stage-b-test-1-gpu-small-amd") + + +def _make_fake_model( + hidden_size: int = 16, + *, + encoder_max_batch: int | None = None, + fail_pad: bool = False, +): + """Lightweight stand-in exposing only the attributes the encoder helpers + touch. The fake tower records call shapes and embeds each patch as a + constant vector keyed on its batch row, so per-item ordering is verifiable. + """ + + class _FakeTower: + device = torch.device("cpu") + + def __init__(self): + self.calls: List[tuple[torch.Tensor, torch.Tensor]] = [] + + def __call__(self, pv: torch.Tensor, pp: torch.Tensor): + self.calls.append((pv.clone(), pp.clone())) + b, n, _ = pv.shape + # pp == -1 marks padding (the real Gemma4 convention). + pooler_mask = (pp != -1).all(dim=-1) + hidden = ( + torch.arange(b, dtype=torch.float32) + .view(b, 1, 1) + .repeat(1, n, hidden_size) + ) + return hidden, pooler_mask + + class _FakeEmbedVision(torch.nn.Module): + def __init__(self, hidden): + super().__init__() + self.hidden = hidden + self.calls: List[torch.Tensor] = [] + + def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: + self.calls.append(inputs_embeds.clone()) + return inputs_embeds # identity projection + + class _LM: + def __init__(self, hidden): + self.config = SimpleNamespace(hidden_size=hidden) + self.device = torch.device("cpu") + + def dtype(self): + return torch.float32 + + text_config = SimpleNamespace(hidden_size=hidden_size) + config = SimpleNamespace(text_config=text_config) + + # Default to an effectively unbounded budget so batching runs; the + # `encoder_max_batch` kwarg overrides it to exercise chunking. + if encoder_max_batch is None: + budget = 1 << 40 + per_patch = 1 + else: + budget = encoder_max_batch + per_patch = 1 + + fake = SimpleNamespace( + config=config, + vision_tower=_FakeTower(), + embed_vision=_FakeEmbedVision(hidden_size), + language_model=_LM(hidden_size), + _encoder_budget_bytes=budget, + _encoder_bytes_per_patch=per_patch, + ) + # Bind the real (unbound) methods onto the fake instance. + cls = gemma4_mm_module.Gemma4ForConditionalGeneration + for name in [ + "_flatten_pixel_lists", + "_batched_encode", + "_gather_mm_features", + "_encoder_max_batch", + "get_image_feature", + "get_video_feature", + ]: + fn = getattr(cls, name) + setattr(fake, name, fn.__get__(fake, type(fake))) + + fake._fail_pad = fail_pad + # parameters() is used by the empty-input path; return one tensor. + fake.parameters = lambda: iter([torch.zeros(1)]) + return fake + + +def _make_item(num_images: int, num_patches: int): + """Construct a minimal MultimodalDataItem-like object with `num_images` + images each shaped (num_patches, 4).""" + pv_list = [torch.full((num_patches, 4), float(i)) for i in range(num_images)] + pp_list = [ + torch.arange(num_patches).unsqueeze(-1).repeat(1, 2).float() + for _ in range(num_images) + ] + return SimpleNamespace(feature=pv_list, image_position_ids=pp_list) + + +def test_single_resolution_single_call(): + fake = _make_fake_model() + item = _make_item(num_images=6, num_patches=10) + out = fake.get_image_feature([item]) + + # 1 encoder forward over [6, 10, 4] + assert len(fake.vision_tower.calls) == 1, fake.vision_tower.calls + pv, _ = fake.vision_tower.calls[0] + assert pv.shape == (6, 10, 4) + + # 1 batched embedder call over (1, 60, 16) + assert len(fake.embed_vision.calls) == 1 + assert fake.embed_vision.calls[0].shape == (1, 60, 16) + + # Output is (60, 16): 6 images × 10 valid patches × hidden 16 + assert out.shape == (60, 16) + + +def test_mixed_resolution_bucketing(): + fake = _make_fake_model() + # 2 small images (5 patches each) and 1 big image (12 patches) + small = _make_item(num_images=2, num_patches=5) + big = _make_item(num_images=1, num_patches=12) + fake.get_image_feature([small, big]) + + # Two buckets: one for 5 patches (batch=2), one for 12 patches (batch=1). + assert len(fake.vision_tower.calls) == 2 + shapes = sorted(call[0].shape for call in fake.vision_tower.calls) + assert shapes == [(1, 12, 4), (2, 5, 4)] + + # Still a single embedder call over all valid tokens. + assert len(fake.embed_vision.calls) == 1 + total_tokens = 2 * 5 + 1 * 12 + assert fake.embed_vision.calls[0].shape == (1, total_tokens, 16) + + +def test_chunking_when_max_batch_set(): + # With per_patch=1 and patches=2, cost-per-item = 2. + # budget=4 -> 4//2 = 2 items per chunk; 6 items -> 3 encoder calls. + fake = _make_fake_model(encoder_max_batch=4) + item = _make_item(num_images=6, num_patches=2) + fake.get_image_feature([item]) + assert len(fake.vision_tower.calls) == 3 + # Still 1 embedder call. + assert len(fake.embed_vision.calls) == 1 + + +def test_empty_returns_empty_tensor(): + fake = _make_fake_model() + out = fake.get_image_feature([]) + assert out.shape == (0, 16) + + +def test_prepass_real_interleave_preserves_order(): + """A prepass (already-embedded) entry between two raw-pixel entries must + stay in walk order, not be hoisted to the front. The prepass rows carry a + sentinel value (99) the tower never produces. + """ + fake = _make_fake_model(hidden_size=16) + + real0 = torch.zeros(4, 4) # raw pixels, image 0 + prepass = torch.full((3, 16), 99.0) # already at hidden_size -> prepass + real1 = torch.ones(4, 4) # raw pixels, image 1 + + pp = torch.arange(4).unsqueeze(-1).repeat(1, 2).float() + item = SimpleNamespace( + feature=[real0, prepass, real1], + image_position_ids=[pp, None, pp], + ) + + out = fake.get_image_feature([item]) + + # 4 (real0) + 3 (prepass) + 4 (real1) = 11 tokens, in that order. + assert out.shape == (11, 16) + # The 3 prepass rows must sit in the middle (rows 4..7), not at the front. + assert torch.all(out[4:7] == 99.0), out + # Surrounding rows are real-image outputs (never the 99 sentinel). + assert not torch.any(out[:4] == 99.0) + assert not torch.any(out[7:] == 99.0) + + +def test_lazy_budget_is_device_agnostic(): + """The budget is computed lazily from a device-agnostic memory query, so + batching stays active on non-CUDA devices (here: CPU) instead of falling + back to single-item batches. + """ + fake = _make_fake_model(hidden_size=16) + # Force the lazy-init path: zero budget, real per-patch cost. + fake._encoder_budget_bytes = 0 + fake._encoder_bytes_per_patch = 1 + + max_batch = fake._encoder_max_batch(patches_per_item=4) + assert fake._encoder_budget_bytes > 0, "expected a device-agnostic budget" + assert max_batch > 1, "budget should permit batching, not fall back to 1" + + item = _make_item(num_images=6, num_patches=4) + fake.get_image_feature([item]) + assert len(fake.vision_tower.calls) == 1, fake.vision_tower.calls + + +if __name__ == "__main__": + test_single_resolution_single_call() + test_mixed_resolution_bucketing() + test_chunking_when_max_batch_set() + test_empty_returns_empty_tensor() + test_prepass_real_interleave_preserves_order() + test_lazy_budget_is_device_agnostic() + print("ALL TESTS PASSED")