diff --git a/vllm_gaudi/extension/bucketing/vision.py b/vllm_gaudi/extension/bucketing/vision.py index 084526259f..52d90ddde6 100644 --- a/vllm_gaudi/extension/bucketing/vision.py +++ b/vllm_gaudi/extension/bucketing/vision.py @@ -21,8 +21,8 @@ }, 'qwen3_vl': { 'is_batch_based': False, - 'buckets': - [256, 512, 1024, 1350, 1602, 2048, 3072, 4096, 5120, 6144, 7168, 8192, 9216, 10240, 11264, 12288, 131076] + # patches per image + 'buckets': [196, 256, 441, 480, 576, 900, 1156] } } @@ -37,6 +37,8 @@ def __init__(self, model_name, is_batch_based=None): self.is_batch_based = is_batch_based if is_batch_based is not None else config['is_batch_based'] + self.qwen2_5_vl = 'qwen2_5_vl' in model_name.lower() + envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "") if envvar == 'None': @@ -85,15 +87,16 @@ def find_factor(self, desired_patches, orig): return None def find_padding(self, h_orig, w_orig, desired_patches): + merge_size = 2 best_pad_h, best_pad_w = 0, 0 if desired_patches % h_orig == 0: best_pad_h = 0 w_factor = desired_patches // h_orig - best_pad_w = w_factor - w_orig if (w_factor > w_orig and w_factor % 2 == 0) else 0 + best_pad_w = w_factor - w_orig if (w_factor > w_orig and w_factor % merge_size == 0) else 0 elif desired_patches % w_orig == 0: best_pad_w = 0 h_factor = desired_patches // w_orig - best_pad_h = h_factor - h_orig if (h_factor > h_orig and h_factor % 2 == 0) else 0 + best_pad_h = h_factor - h_orig if (h_factor > h_orig and h_factor % merge_size == 0) else 0 elif desired_patches % h_orig != 0 and desired_patches % w_orig != 0: if h_orig > w_orig: w_factor = self.find_factor(desired_patches, w_orig) @@ -163,3 +166,28 @@ def greedy_plan(self, batchsize, available_batchsizes): def __repr__(self): return str(self.multimodal_buckets) + + def bucket_to_image_resolution(self, patch_size: int = 14): + """ + Calculate image resolution by first determining height from target_patches, + then deriving width from aspect ratio. + """ + aspect_ratios = [ + (1, 1), # 1:1 square + (4, 3), # 4:3 landscape + (3, 4), # 3:4 portrait + (16, 9), # 16:9 widescreen + (9, 16), # 9:16 portrait + ] + merge_size = 2 # Qwen2.5/3VL spatial_merge_size + resolution_list = [] + for target_patches in self.multimodal_buckets: + for (ratio_w, ratio_h) in aspect_ratios: + grid_h = int(target_patches**0.5) + height = grid_h * patch_size + width = int(height * ratio_w / ratio_h) + grid_w = width // patch_size + if grid_w * grid_h // merge_size != 0: + grid_w = ((grid_w + merge_size - 1) // merge_size) * merge_size + resolution_list.append((grid_w * patch_size, height)) + return resolution_list diff --git a/vllm_gaudi/models/__init__.py b/vllm_gaudi/models/__init__.py index 64fad770e0..1cfa3b4de4 100644 --- a/vllm_gaudi/models/__init__.py +++ b/vllm_gaudi/models/__init__.py @@ -15,3 +15,7 @@ def register_model(): from vllm_gaudi.models.qwen3_vl import HpuQwen3_VLForConditionalGeneration # noqa: F401 ModelRegistry.register_model("Qwen3VLForConditionalGeneration", "vllm_gaudi.models.qwen3_vl:HpuQwen3_VLForConditionalGeneration") + + from vllm_gaudi.models.qwen3_vl_moe import HpuQwen3_VLMoeForConditionalGeneration # noqa: F401 + ModelRegistry.register_model("Qwen3VLMoeForConditionalGeneration", + "vllm_gaudi.models.qwen3_vl_moe:HpuQwen3_VLMoeForConditionalGeneration") diff --git a/vllm_gaudi/models/qwen2_5_vl.py b/vllm_gaudi/models/qwen2_5_vl.py index 2023f59122..98f699fd76 100644 --- a/vllm_gaudi/models/qwen2_5_vl.py +++ b/vllm_gaudi/models/qwen2_5_vl.py @@ -1,4 +1,3 @@ -import math import os from functools import partial from typing import Optional, Callable, Union @@ -34,6 +33,7 @@ from vllm.model_executor.models.utils import (maybe_prefix, cast_overflow_tensors) from vllm.multimodal.inputs import MultiModalFieldConfig +from vllm_gaudi.extension.runtime import get_config import habana_frameworks.torch.core as htcore from habana_frameworks.torch.hpex.kernels import FusedSDPA @@ -72,28 +72,27 @@ class HPU_Attention: in ['true', '1'] else 'None' @classmethod - def forward(cls, q, k, v, mask, q_block_size=64): + def forward(cls, q, k, v, mask, cu_seqlens, qwen2_5_vl, q_block_size=64): """ Support long sequence at prompt phase """ q_len = q.size(-2) - if q_len <= 65536: # need to investigate this crosspoint - return FusedSDPA.apply(q, k, v, mask, 0.0, False, None, cls.softmax_mode) - - assert q_len % q_block_size == 0 - q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size) - attn_output = torch.zeros_like(q) - - for i in range(q_tiles): - s, e = i * q_block_size, (i + 1) * q_block_size - row_q = q[:, :, s:e, :] - row_mask = mask[:, :, s:e, :] - attn_output[:, :, s:e, :] = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, False, None, cls.softmax_mode) - # TODO: markstep after a couple of iterations - # need to experiment the optimal number. - if i % 75 == 0: - htcore.mark_step() - return attn_output + lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + if mask is not None or len(lens) == 1: + if not qwen2_5_vl or (qwen2_5_vl and q_len < 65536): + return FusedSDPA.apply(q, k, v, mask, 0.0, False, None, cls.softmax_mode) + else: + return AttentionLongSequence.forward(q, k, v, mask, q_block_size, cls.softmax_mode) + else: + q_chunks = torch.split(q, lens, dim=2) + k_chunks = torch.split(k, lens, dim=2) + v_chunks = torch.split(v, lens, dim=2) + outputs = [] + for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): + output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0, False, None, cls.softmax_mode) + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=2) + return context_layer def create_block_diagonal_attention_mask(indices): @@ -148,6 +147,8 @@ def __init__( ) self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) + model_type = get_config().model_type + self.qwen2_5_vl = 'qwen2_5_vl' in model_type.lower() def forward( self, @@ -187,11 +188,9 @@ def forward( # performs full attention using the previous computed mask q1, k1, v1 = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v]) - output = HPU_Attention.forward(q1, k1, v1, attn_mask) + output = HPU_Attention.forward(q1, k1, v1, attn_mask, cu_seqlens, self.qwen2_5_vl) context_layer = rearrange(output, "b h s d -> b s h d ") - context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() - output, _ = self.proj(context_layer) return output diff --git a/vllm_gaudi/models/qwen3_moe.py b/vllm_gaudi/models/qwen3_moe.py new file mode 100644 index 0000000000..5df4f84f7e --- /dev/null +++ b/vllm_gaudi/models/qwen3_moe.py @@ -0,0 +1,48 @@ +import torch +from torch import nn + +from vllm.model_executor.models.qwen3_moe import ( + Qwen3MoeSparseMoeBlock as UpstreamQwen3MoeSparseMoeBlock, ) +from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.distributed import tensor_model_parallel_all_gather + + +class HpuQwen3MoeSparseMoeBlock(UpstreamQwen3MoeSparseMoeBlock): + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_dim = orig_shape[-1] + + hs = hidden_states.reshape(-1, hidden_dim) # (T, H) + num_tokens = hs.shape[0] + + if getattr(self, "is_sequence_parallel", False): + hs = sequence_parallel_chunk(hs) + + router_logits, _ = self.gate(hs) + out = self.experts(hidden_states=hs, router_logits=router_logits) + + if getattr(self, "is_sequence_parallel", False): + out = tensor_model_parallel_all_gather(out, 0) + out = out[:num_tokens] + + return out.reshape(*orig_shape[:-1], hidden_dim) + + +def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int: + lm_model = getattr(language_model, "model", None) + layers = getattr(lm_model, "layers", None) + if layers is None: + return + + for layer in layers: + mlp = getattr(layer, "mlp", None) + if mlp is None: + continue + + if isinstance(mlp, HpuQwen3MoeSparseMoeBlock): + continue + + if isinstance(mlp, UpstreamQwen3MoeSparseMoeBlock): + mlp.__class__ = HpuQwen3MoeSparseMoeBlock + mlp._hpu_accept_3d_installed = True diff --git a/vllm_gaudi/models/qwen3_vl.py b/vllm_gaudi/models/qwen3_vl.py index 82a72d3ebe..15b3d4d469 100644 --- a/vllm_gaudi/models/qwen3_vl.py +++ b/vllm_gaudi/models/qwen3_vl.py @@ -1,17 +1,23 @@ import torch +import numpy as np from .utils import _merge_multimodal_embeddings from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.interfaces import _require_is_multimodal + +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VLImageInputs, ) from vllm.model_executor.models.qwen3_vl import ( Qwen3VLForConditionalGeneration, Qwen3_VisionTransformer, Qwen3_VisionBlock, ) +from vllm.model_executor.models.vision import run_dp_sharded_mrope_vision_model + from vllm.model_executor.models.utils import maybe_prefix -from vllm_gaudi.models.qwen2_5_vl import (HPUQwen2_5_VisionAttention) +from vllm_gaudi.models.qwen2_5_vl import HPUQwen2_5_VisionAttention class HPUQwen3_VisionBlock(Qwen3_VisionBlock): @@ -48,6 +54,27 @@ def __init__( prefix=f"{prefix}.attn", ) + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + max_seqlen: torch.Tensor, # Only used for Flash Attention + attn_mask=None, + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + attn_mask=attn_mask, + max_seqlen=max_seqlen, + ) + + x = x + self.mlp(self.norm2(x)) + return x + class HPUQwen3_VisionTransformer(Qwen3_VisionTransformer): @@ -83,6 +110,51 @@ def __init__( ) for layer_idx in range(depth) ]) + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor | list[list[int]], + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True) + hidden_states = self.patch_embed(hidden_states) + + if isinstance(grid_thw, list): + grid_thw_list = grid_thw + grid_thw = np.array(grid_thw, dtype=np.int32) + else: + grid_thw_list = grid_thw.tolist() + grid_thw = grid_thw.numpy() + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) + + cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(axis=0, dtype=np.int32) + cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens]) + cu_seqlens = torch.from_numpy(cu_seqlens) + hidden_states = hidden_states.unsqueeze(1) + max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + attn_mask=attn_mask, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num) + deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat([hidden_states] + deepstack_feature_lists, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + return hidden_states + class HpuQwen3_VLForConditionalGeneration(Qwen3VLForConditionalGeneration): @@ -101,6 +173,58 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "visual"), ) + def create_block_diagonal_mask(self, + cu_seqlens: torch.Tensor, + grid_thw: list[int], + device: torch.device = None, + dtype: torch.dtype = torch.bool) -> torch.Tensor: + """ + Create block diagonal mask that excludes padded tokens for Qwen3VL attention. + Args: + cu_seqlens: Cumulative sequence lengths from grid dimensions + grid_thw: The grid dimensions with merge_size=2 compatibility + device: Target device for the mask + dtype: Data type for the mask (typically torch.bool) + + Returns: + Block diagonal attention mask with shape [total_seq_len, total_seq_len] + """ + if device is None: + device = cu_seqlens.device + + # Calculate total sequence length including padding + total_patches = int(grid_thw.prod(-1).sum().item()) + # Create mask with total size including padding + mask = torch.zeros(total_patches, total_patches, device=device, dtype=dtype) + cu_seqlens = cu_seqlens.tolist() + cu_seqlens = [0] + cu_seqlens + starts = cu_seqlens[:-1] + ends = cu_seqlens[1:] + for start, end in zip(starts, ends): + mask[start:end, start:end] = True + return mask + + def _process_image_input(self, image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model(self.visual, + pixel_values, + grid_thw.tolist(), + rope_type="rope_3d") + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw, attn_mask=None) + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() + return image_embeds.split(sizes) + def _compute_deepstack_embeds( self, inputs_embeds: torch.Tensor, diff --git a/vllm_gaudi/models/qwen3_vl_moe.py b/vllm_gaudi/models/qwen3_vl_moe.py new file mode 100644 index 0000000000..1f520d9afb --- /dev/null +++ b/vllm_gaudi/models/qwen3_vl_moe.py @@ -0,0 +1,36 @@ +from vllm.config import VllmConfig +from vllm.model_executor.models.utils import maybe_prefix + +from vllm.model_executor.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration +from vllm_gaudi.models.qwen3_vl import HPUQwen3_VisionTransformer, HpuQwen3_VLForConditionalGeneration + +from vllm_gaudi.models.qwen3_moe import upgrade_qwen3_moe_blocks_inplace + + +class HpuQwen3_VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + quant_config = getattr(self, "quant_config", None) + multimodal_config = getattr(vllm_config.model_config, "multimodal_config", None) + + if hasattr(self, "visual") and self.visual is not None: + self.visual = HPUQwen3_VisionTransformer( + self.config.vision_config, + norm_eps=getattr(self.config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=maybe_prefix(prefix, "visual"), + ) + + # qwen3 moe mlp blocks: make forward for 3d safe (b,s,h -> t,h) + lm = getattr(self, "language_model", None) + if lm is not None: + _n = upgrade_qwen3_moe_blocks_inplace(lm) + + def _compute_deepstack_embeds(self, *args, **kwargs): + return HpuQwen3_VLForConditionalGeneration._compute_deepstack_embeds(self, *args, **kwargs) + + def embed_input_ids(self, *args, **kwargs): + return HpuQwen3_VLForConditionalGeneration.embed_input_ids(self, *args, **kwargs) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 1004198603..7a4eea1975 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -4171,13 +4171,12 @@ def log_warmup(self, phase, i, max_i, first_dim, second_dim, third_dim, causal=F f"free_mem:{free_mem}") tqdm.write(msg) - def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, w, h, f): + def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, w, h): free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory()) msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " f"batch_size:{batch_size} " f"seq_len:{seq_len} " f"resolution:{w}X{h} " - f"frames:{f} " f"free_mem:{free_mem}") logger.info(msg) @@ -4854,42 +4853,26 @@ def _get_mm_dummy_batch( self, modality: str, image_args: int, - ratio_w: int, - ratio_h: int, + width: int, + height: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" assert self.mm_budget is not None + num_frames = 100 count = 1 - num_frames = 0 - batch = image_args if self.get_model().vision_bucket_manager.is_batch_based else count if self.get_model().vision_bucket_manager.is_batch_based: - # Create ImageDummyOptions for Gemma3 - w = 896 # pixels as in gemma3 config - h = 896 # pixels as in gemma3 config batch = image_args else: - patch_size = int(self.get_patch_size_from_model()) - # Calculate width and height to maintain aspect ratio and patch count - # Total patches = (width/patch_size) * (height/patch_size) - # We want: (w/ps) * (h/ps) = num_patch where num_patch is image_args - # And: w/h = ratio_w/ratio_h - grid_w = int(math.sqrt(image_args * ratio_w / ratio_h)) - grid_h = int(image_args / grid_w) - w = grid_w * patch_size - h = grid_h * patch_size + mm_options = self.model_config.get_multimodal_config().get_dummy_options(modality) + count = mm_options.count if mm_options and hasattr(mm_options, 'count') else count batch = count - if modality == 'image': - mm_options = {"image": ImageDummyOptions(count=count, width=w, height=h), "video": None} + mm_options = {"image": ImageDummyOptions(count=count, width=width, height=height), "video": None} elif modality == 'video': - video_options = self.model_config.get_multimodal_config().get_dummy_options("video") - num_frames = video_options.num_frames if video_options and hasattr(video_options, 'num_frames') else 100 - w = video_options.width if video_options and hasattr(video_options, 'width') else w - h = video_options.height if video_options and hasattr(video_options, 'height') else h - count = video_options.count if video_options and hasattr(video_options, 'count') else 1 + num_frames = mm_options.num_frames if mm_options and hasattr(mm_options, 'num_frames') else num_frames mm_options = { "image": None, - "video": VideoDummyOptions(count=count, num_frames=num_frames, width=w, height=h) + "video": VideoDummyOptions(count=count, num_frames=num_frames, width=width, height=height) } else: raise NotImplementedError(f"Modality '{modality}' is not supported") @@ -4899,6 +4882,7 @@ def _get_mm_dummy_batch( dummy_mm_inputs = profiler._get_dummy_mm_inputs(seq_len=4196, mm_counts={modality: count}, mm_options=mm_options) + dummy_mm_item = dummy_mm_inputs["mm_kwargs"][modality][0] # We use the cache so that the item is saved to the cache, # but not read from the cache @@ -4910,7 +4894,7 @@ def _get_mm_dummy_batch( dummy_mm_items, device=self.device, pin_memory=self.pin_memory, - )), w, h, num_frames + )) def warmup_multimodal_graphs(self, buckets): @@ -4921,44 +4905,63 @@ def warmup_multimodal_graphs(self, buckets): self.scheduler_config, self.mm_registry, ) if self.supports_mm_inputs else None + vision_bucket_manager = self.get_model().vision_bucket_manager + is_batch_based = vision_bucket_manager.is_batch_based + mm_config = self.model_config.get_multimodal_config() + + is_image_warmup = (mm_config is not None and mm_config.get_dummy_options("image") is not None + and self.mm_budget.mm_limits['image'] != 0) + is_video_warmup = (mm_config is not None and mm_config.get_dummy_options("video") is not None + and self.mm_budget.mm_limits['video'] != 999) + warmup_configs = { + "image": (0, lambda: mm_config.get_dummy_options("image")), + "video": (999, lambda: mm_config.get_dummy_options("video")) + } + width = height = None + warmup_lists = [] + for modality, (limit_value, get_options) in warmup_configs.items(): + if (mm_config and mm_config.get_dummy_options(modality) + and self.mm_budget.mm_limits[modality] != limit_value): + options = get_options() + width = options.width if hasattr(options, 'width') else None + height = options.height if hasattr(options, 'height') else None + if width is not None and height is not None: + warmup_lists.append((width, height)) + break - sanity_check = self.get_model().vision_bucket_manager.is_batch_based - - aspect_ratios = [ - (1, 1), # 1:1 square - (4, 3), # 4:3 landscape - (3, 4), # 3:4 portrait - (16, 9), # 16:9 widescreen - (9, 16), # 9:16 portrait - ] - - is_video_warmup = bool(self.model_config.get_multimodal_config() is not None and \ - self.model_config.get_multimodal_config().get_dummy_options("video") is not None \ - and self.mm_budget.mm_limits['video'] != 999) - - is_image_warmup = bool(self.model_config.get_multimodal_config() is not None and \ - self.model_config.get_multimodal_config().get_dummy_options("image") is not None \ - and self.mm_budget.mm_limits['image'] != 0) + if not is_batch_based and len(buckets) > 0: + patch_size = int(self.get_patch_size_from_model()) + warmup_lists = warmup_lists + \ + vision_bucket_manager.bucket_to_image_resolution(patch_size=patch_size) for modality, max_items in self.mm_budget.mm_limits.items(): if modality == 'image' and not is_image_warmup or modality == 'video' \ and not is_video_warmup: continue phase = f'Graph/Multimodal({modality})' - num_candidates = len(buckets) - for idx, img_arg in enumerate(buckets): - for (ratio_w, ratio_h) in aspect_ratios: - batched_dummy_mm_inputs, w, h, f = self._get_mm_dummy_batch(modality, img_arg, ratio_w, ratio_h) - dummy_encoder_outputs = \ - self.model.embed_multimodal( - **batched_dummy_mm_inputs) - if sanity_check: - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=img_arg, - ) - - self.graphed_buckets.add(img_arg) - self.log_warmup_multimodal(phase, idx, num_candidates, 1, 0, w, h, f) + candidates = buckets if is_batch_based else warmup_lists + for idx in range(len(candidates)): + if is_batch_based: + image_args = candidates[idx] + width = 896 # pixels as in gemma3 config + height = 896 # pixels as in gemma3 config + else: + image_args = None + width, height = candidates[idx] + batched_dummy_mm_inputs = self._get_mm_dummy_batch(modality, + image_args=image_args, + width=width, + height=height) + dummy_encoder_outputs = \ + self.model.embed_multimodal( + **batched_dummy_mm_inputs) + if is_batch_based: + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=candidates[idx], + ) + self.graphed_buckets.add(candidates[idx]) + self.log_warmup_multimodal(phase, idx, len(candidates), candidates[idx] if is_batch_based else 1, 0, + width, height) def _maybe_profile_unified_attn(self): unified_cfg_str = os.environ.get('VLLM_PROFILE_UNIFIED', None)