diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py
index 672fe2384e40..a58f669b22e3 100644
--- a/python/sglang/srt/managers/mm_utils.py
+++ b/python/sglang/srt/managers/mm_utils.py
@@ -919,6 +919,7 @@ def embed_mm_inputs(
     data_embedding_func_mapping: Dict[Modality, DataEmbeddingFunc] = None,
     placeholder_tokens: dict[Modality, List[int]] = None,
     use_deepstack: Dict[Modality, bool] = {},
+    prealloc_deepstack: Optional[torch.Tensor] = None,
 ) -> Optional[torch.Tensor]:
     """
     Embed multimodal inputs and integrate them with text token embeddings.
@@ -1019,12 +1020,16 @@ def embed_mm_inputs(
         deepstack_embedding_shape = input_embeds.shape[:-1] + (
             input_embeds.shape[-1] * num_deepstack_embeddings,
         )
-        # a zero-filled embedding, with the same length of input_embeds, but different hidden_size
-        input_deepstack_embeds = torch.zeros(
-            deepstack_embedding_shape,
-            device=input_embeds.device,
-            dtype=input_embeds.dtype,
-        )
+        if prealloc_deepstack is not None:
+            assert prealloc_deepstack.shape == deepstack_embedding_shape
+            input_deepstack_embeds = prealloc_deepstack
+            input_deepstack_embeds.zero_()
+        else:
+            input_deepstack_embeds = torch.zeros(
+                deepstack_embedding_shape,
+                device=input_embeds.device,
+                dtype=input_embeds.dtype,
+            )

         other_info["input_deepstack_embeds"] = input_deepstack_embeds

@@ -1091,6 +1096,7 @@ def general_mm_embed_routine(
             for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
             if forward_batch.mm_inputs[i] is not None
         ]
+        prealloc_deepstack = kwargs.get("input_deepstack_embeds", None)
         input_embeds, other_info = embed_mm_inputs(
             mm_inputs_list=mm_inputs_list,
             extend_prefix_lens=extend_prefix_lens,
@@ -1101,6 +1107,7 @@ def general_mm_embed_routine(
             data_embedding_func_mapping=data_embedding_funcs,
             placeholder_tokens=placeholder_tokens,
             use_deepstack=use_deepstack,
+            prealloc_deepstack=prealloc_deepstack,
         )
         # add for qwen3_vl deepstack
         if use_deepstack:
diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
index feccc5c5608b..37364f3d46d7 100644
--- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
@@ -245,6 +245,18 @@ def __init__(self, model_runner: ModelRunner):
                 mrope_positions = torch.zeros(
                     (3, self.max_num_tokens), dtype=torch.int64
                 )
+            if hasattr(self.model_runner.model, "num_deepstack_embeddings"):
+                deepstack_slots = getattr(
+                    self.model_runner.model, "num_deepstack_embeddings", 3
+                )
+                self.input_deepstack_embeds = torch.zeros(
+                    (
+                        self.max_num_tokens,
+                        self.model_runner.model_config.hidden_size
+                        * deepstack_slots,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
         else:
             input_embeds = None
             mrope_positions = None
@@ -392,6 +404,9 @@ def warmup_torch_compile(self, num_tokens: int):
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(None, num_tokens, forward_batch.dp_padding_mode.is_max_len())
         set_is_extend_in_batch(False)
+        kwargs = {}
+        if self.is_multimodal and hasattr(self, "input_deepstack_embeds"):
+            kwargs["input_deepstack_embeds"] = self.input_deepstack_embeds[:num_tokens]
         with set_forward_context(
             forward_batch,
             self.attention_layers,
@@ -403,6 +418,7 @@ def warmup_torch_compile(self, num_tokens: int):
                 forward_batch.input_ids,
                 forward_batch.positions,
                 forward_batch,
+                **kwargs,
             )

     def _cache_loc_dtype(self):
@@ -547,6 +563,10 @@ def capture_one_batch_size(self, num_tokens: int):
         self.model_runner.attn_backend.init_forward_metadata(forward_batch)

+        kwargs = {}
+        if self.is_multimodal and hasattr(self, "input_deepstack_embeds"):
+            kwargs["input_deepstack_embeds"] = self.input_deepstack_embeds[:num_tokens]
+
         # Run and capture
         def run_once():
             # Clean intermediate result cache for DP attention
@@ -560,7 +580,6 @@ def run_once():
             # It is True in this context but we need to set it to use low latency deepep mode.
             set_is_extend_in_batch(False)

-            kwargs = {}
             with set_forward_context(
                 forward_batch,
                 self.attention_layers,
@@ -735,6 +754,11 @@ def replay(
         # Due to the dispatch kernel for MLA model, we init the metadata with original forward_batch
         self.model_runner.attn_backend.init_forward_metadata(forward_batch)
         static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
+        if self.is_multimodal and hasattr(self, "input_deepstack_embeds"):
+            static_num_tokens = len(static_forward_batch.input_ids)
+            kwargs["input_deepstack_embeds"] = self.input_deepstack_embeds[
+                :static_num_tokens
+            ]

         # Replay
         with set_forward_context(
             static_forward_batch,
diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py
index 09c443b894c1..0a3c445b1625 100644
--- a/python/sglang/srt/models/qwen3_vl.py
+++ b/python/sglang/srt/models/qwen3_vl.py
@@ -913,16 +913,37 @@ def __init__(
         self.deepstack_embed_to_decoder_layer = range(
             len(config.vision_config.deepstack_visual_indexes)
         )
+        self.deepstack_embeds_buffer = None

     def get_deepstack_embeds(
-        self, layer_idx: int, input_deepstack_embeds: Optional[torch.Tensor]
+        self,
+        layer_idx: int,
+        input_deepstack_embeds: Optional[torch.Tensor],
+        seq_len: int,
     ) -> Optional[torch.Tensor]:
-        """Get deepstack embeddings for a given layer index, or None if not applicable."""
-        if (
-            input_deepstack_embeds is None
-            or layer_idx not in self.deepstack_embed_to_decoder_layer
-        ):
+        if layer_idx not in self.deepstack_embed_to_decoder_layer:
             return None
+        if input_deepstack_embeds is None:
+            device = next(self.parameters()).device
+            dtype = next(self.parameters()).dtype
+            total = self.hidden_size * len(self.deepstack_embed_to_decoder_layer)
+            if (
+                self.deepstack_embeds_buffer is None
+                or self.deepstack_embeds_buffer.size(0) < seq_len
+            ):
+                new_len = max(
+                    seq_len,
+                    (
+                        self.deepstack_embeds_buffer.size(0) * 2
+                        if self.deepstack_embeds_buffer is not None
+                        else seq_len
+                    ),
+                )
+                self.deepstack_embeds_buffer = torch.zeros(
+                    new_len, total, dtype=dtype, device=device
+                ).contiguous()
+            sep = self.hidden_size * layer_idx
+            return self.deepstack_embeds_buffer[:seq_len, sep : sep + self.hidden_size]
         sep = self.hidden_size * layer_idx
         return input_deepstack_embeds[:, sep : sep + self.hidden_size]

@@ -963,7 +984,7 @@ def forward(
             # The order matters because addition with different tensors is not associative in practice.
             # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition.
             deepstack_embeds = self.get_deepstack_embeds(
-                layer_idx - 1, input_deepstack_embeds
+                layer_idx - 1, input_deepstack_embeds, hidden_states.shape[0]
             )
             hidden_states, residual = layer(
                 positions,
@@ -975,7 +996,7 @@ def forward(

         # Handle deepstack for the last processed layer if it exists.
         last_deepstack = self.get_deepstack_embeds(
-            self.end_layer - 1, input_deepstack_embeds
+            self.end_layer - 1, input_deepstack_embeds, hidden_states.shape[0]
         )

         if not self.pp_group.is_last_rank:
@@ -1241,6 +1262,7 @@ def forward(
         forward_batch: ForwardBatch,
         get_embedding: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        **kwargs,
     ):
         """Run forward pass for Qwen3-VL.

@@ -1266,16 +1288,17 @@ def forward(
                 "multimodal section rotary embedding requires "
                 f"(3, seq_len) positions, but got {positions.size()}"
             )
-
-        hidden_states = general_mm_embed_routine(
-            input_ids=input_ids,
-            forward_batch=forward_batch,
-            language_model=self.model,
-            multimodal_model=self,
-            positions=positions,
-            use_deepstack=self.use_deepstack,
-            pp_proxy_tensors=pp_proxy_tensors,
-        )
+        with torch.no_grad():
+            hidden_states = general_mm_embed_routine(
+                input_ids=input_ids,
+                forward_batch=forward_batch,
+                language_model=self.model,
+                multimodal_model=self,
+                positions=positions,
+                use_deepstack=self.use_deepstack,
+                pp_proxy_tensors=pp_proxy_tensors,
+                **kwargs,
+            )

         if self.pp_group.is_last_rank:
             if not get_embedding:
diff --git a/python/sglang/srt/models/qwen3_vl_moe.py b/python/sglang/srt/models/qwen3_vl_moe.py
index f98460665d37..6b0050d1cea8 100644
--- a/python/sglang/srt/models/qwen3_vl_moe.py
+++ b/python/sglang/srt/models/qwen3_vl_moe.py
@@ -57,19 +57,40 @@ def __init__(
         # This approach follows the original implementation.
         # TODO: make config of type Qwen3VLMoeConfig, so that we can directly obtain deepstack_visual_indexes.
         self.deepstack_embed_to_decoder_layer = range(3)
+        self.deepstack_embeds_buffer = None

     def get_input_embeddings(self) -> nn.Embedding:
         return self.embed_tokens

     def get_deepstack_embeds(
-        self, layer_idx: int, input_deepstack_embeds: Optional[torch.Tensor]
+        self,
+        layer_idx: int,
+        input_deepstack_embeds: Optional[torch.Tensor],
+        seq_len: int,
     ) -> Optional[torch.Tensor]:
-        """Get deepstack embeddings for a given layer index, or None if not applicable."""
-        if (
-            input_deepstack_embeds is None
-            or layer_idx not in self.deepstack_embed_to_decoder_layer
-        ):
+        if layer_idx not in self.deepstack_embed_to_decoder_layer:
             return None
+        if input_deepstack_embeds is None:
+            device = next(self.parameters()).device
+            dtype = next(self.parameters()).dtype
+            total = self.hidden_size * len(self.deepstack_embed_to_decoder_layer)
+            if (
+                self.deepstack_embeds_buffer is None
+                or self.deepstack_embeds_buffer.size(0) < seq_len
+            ):
+                new_len = max(
+                    seq_len,
+                    (
+                        self.deepstack_embeds_buffer.size(0) * 2
+                        if self.deepstack_embeds_buffer is not None
+                        else seq_len
+                    ),
+                )
+                self.deepstack_embeds_buffer = torch.zeros(
+                    new_len, total, dtype=dtype, device=device
+                ).contiguous()
+            sep = self.hidden_size * layer_idx
+            return self.deepstack_embeds_buffer[:seq_len, sep : sep + self.hidden_size]
         sep = self.hidden_size * layer_idx
         return input_deepstack_embeds[:, sep : sep + self.hidden_size]

@@ -109,7 +130,7 @@ def forward(
             # The order matters because addition with different tensors is not associative in practice.
             # Deepstack for prev_layer is applied at the start of current layer via post_residual_addition.
             deepstack_embeds = self.get_deepstack_embeds(
-                layer_idx - 1, input_deepstack_embeds
+                layer_idx - 1, input_deepstack_embeds, hidden_states.shape[0]
             )
             hidden_states, residual = layer(
                 positions,
@@ -121,7 +142,7 @@ def forward(

         # Handle deepstack for the last processed layer if it exists.
         last_deepstack = self.get_deepstack_embeds(
-            self.end_layer - 1, input_deepstack_embeds
+            self.end_layer - 1, input_deepstack_embeds, hidden_states.shape[0]
         )

         if not self.pp_group.is_last_rank:
diff --git a/test/manual/nightly/test_vlms_piecewise_cuda_graph.py b/test/manual/nightly/test_vlms_piecewise_cuda_graph.py
index 7a72dd3fa41c..fb32a6e10caa 100644
--- a/test/manual/nightly/test_vlms_piecewise_cuda_graph.py
+++ b/test/manual/nightly/test_vlms_piecewise_cuda_graph.py
@@ -18,7 +18,9 @@
 )

 MODELS = [
-    SimpleNamespace(model="Qwen/Qwen2.5-VL-7B-Instruct", mmmu_accuracy=0.60),
+    SimpleNamespace(model="Qwen/Qwen2.5-VL-7B-Instruct", mmmu_accuracy=0.50),
+    SimpleNamespace(model="Qwen/Qwen3-VL-8B-Instruct", mmmu_accuracy=0.50),
+    SimpleNamespace(model="Qwen/Qwen3-VL-30B-A3B-Instruct", mmmu_accuracy=0.50),
 ]


@@ -59,7 +61,7 @@ def run_mmmu_eval(
     """
     # -------- fixed settings --------
     model = "openai_compatible"
-    tp = 1
+    tp = 2
     tasks = "mmmu_val"
     batch_size = 32
     log_suffix = "openai_compatible"
@@ -138,7 +140,7 @@ def _run_vlm_mmmu_test(
         "--piecewise-cuda-graph-max-tokens",
         "8192",
         "--enable-piecewise-cuda-graph",
-        "--tp=8",
+        "--tp=2",
         "--piecewise-cuda-graph-compiler=eager",
         "--disable-radix-cache",
         "--log-level",
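
For context, the buffer-reuse pattern this diff introduces can be sketched in isolation. The snippet below is not part of the change, and the names (`DeepstackScratch`, `num_slots`) are hypothetical stand-ins; it only illustrates the two behaviors the diff combines: zeroing a caller-supplied preallocation in place (as `embed_mm_inputs` now does with `prealloc_deepstack`, fed from the runner's static `input_deepstack_embeds` slab) and geometrically growing a lazily allocated fallback buffer (as `get_deepstack_embeds` does when no embeddings are passed in).

```python
# Illustrative sketch only (not SGLang code); DeepstackScratch and num_slots are
# hypothetical names for the structures that appear in the diff above.
from typing import Optional

import torch


class DeepstackScratch:
    """Manages a (seq_len, hidden_size * num_slots) scratch tensor for deepstack embeds."""

    def __init__(self, hidden_size: int, num_slots: int, dtype: torch.dtype = torch.float32):
        self.hidden_size = hidden_size
        self.num_slots = num_slots
        self.dtype = dtype
        self._buffer: Optional[torch.Tensor] = None  # lazily grown fallback buffer

    def get(self, seq_len: int, prealloc: Optional[torch.Tensor] = None) -> torch.Tensor:
        width = self.hidden_size * self.num_slots
        if prealloc is not None:
            # Reuse the caller's preallocated slab: zero it in place instead of
            # allocating a fresh tensor, so its storage address never changes.
            assert prealloc.shape == (seq_len, width)
            prealloc.zero_()
            return prealloc
        # No preallocation: keep a persistent buffer and double its length when it
        # is too small, mirroring the growth rule in get_deepstack_embeds.
        if self._buffer is None or self._buffer.size(0) < seq_len:
            new_len = seq_len if self._buffer is None else max(seq_len, self._buffer.size(0) * 2)
            self._buffer = torch.zeros(new_len, width, dtype=self.dtype)
        return self._buffer[:seq_len]


if __name__ == "__main__":
    scratch = DeepstackScratch(hidden_size=8, num_slots=3)
    static = torch.empty(16, 8 * 3)           # stands in for the runner's static slab
    view = scratch.get(16, prealloc=static)   # zeroed in place, same storage as `static`
    assert view.data_ptr() == static.data_ptr()
    fallback = scratch.get(4)                 # no prealloc: served from the grown buffer
    print(view.shape, fallback.shape)
```

Presumably the point of the preallocation path is to keep the deepstack tensor's storage stable across piecewise CUDA graph capture and replay; the `data_ptr()` assertion in the sketch checks exactly that property.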