Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions python/sglang/srt/managers/mm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,7 @@ def embed_mm_inputs(
data_embedding_func_mapping: Dict[Modality, DataEmbeddingFunc] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
use_deepstack: Dict[Modality, bool] = {},
prealloc_deepstack: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
"""
Embed multimodal inputs and integrate them with text token embeddings.
Expand Down Expand Up @@ -1019,12 +1020,16 @@ def embed_mm_inputs(
deepstack_embedding_shape = input_embeds.shape[:-1] + (
input_embeds.shape[-1] * num_deepstack_embeddings,
)
# a zero-filled embedding, with the same length of input_embeds, but different hidden_size
input_deepstack_embeds = torch.zeros(
deepstack_embedding_shape,
device=input_embeds.device,
dtype=input_embeds.dtype,
)
if prealloc_deepstack is not None:
assert prealloc_deepstack.shape == deepstack_embedding_shape
input_deepstack_embeds = prealloc_deepstack
input_deepstack_embeds.zero_()
else:
input_deepstack_embeds = torch.zeros(
deepstack_embedding_shape,
device=input_embeds.device,
dtype=input_embeds.dtype,
)

other_info["input_deepstack_embeds"] = input_deepstack_embeds

Expand Down Expand Up @@ -1091,6 +1096,7 @@ def general_mm_embed_routine(
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
if forward_batch.mm_inputs[i] is not None
]
prealloc_deepstack = kwargs.get("input_deepstack_embeds", None)
input_embeds, other_info = embed_mm_inputs(
mm_inputs_list=mm_inputs_list,
extend_prefix_lens=extend_prefix_lens,
Expand All @@ -1101,6 +1107,7 @@ def general_mm_embed_routine(
data_embedding_func_mapping=data_embedding_funcs,
placeholder_tokens=placeholder_tokens,
use_deepstack=use_deepstack,
prealloc_deepstack=prealloc_deepstack,
)
# add for qwen3_vl deepstack
if use_deepstack:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,18 @@ def __init__(self, model_runner: ModelRunner):
mrope_positions = torch.zeros(
(3, self.max_num_tokens), dtype=torch.int64
)
if hasattr(self.model_runner.model, "num_deepstack_embeddings"):
deepstack_slots = getattr(
self.model_runner.model, "num_deepstack_embeddings", 3
)
self.input_deepstack_embeds = torch.zeros(
(
self.max_num_tokens,
self.model_runner.model_config.hidden_size
* deepstack_slots,
),
dtype=self.model_runner.dtype,
)
else:
input_embeds = None
mrope_positions = None
Expand Down Expand Up @@ -392,6 +404,9 @@ def warmup_torch_compile(self, num_tokens: int):
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
set_dp_buffer_len(None, num_tokens, forward_batch.dp_padding_mode.is_max_len())
set_is_extend_in_batch(False)
kwargs = {}
if self.is_multimodal and hasattr(self, "input_deepstack_embeds"):
kwargs["input_deepstack_embeds"] = self.input_deepstack_embeds[:num_tokens]
with set_forward_context(
forward_batch,
self.attention_layers,
Expand All @@ -403,6 +418,7 @@ def warmup_torch_compile(self, num_tokens: int):
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
**kwargs,
)

def _cache_loc_dtype(self):
Expand Down Expand Up @@ -547,6 +563,10 @@ def capture_one_batch_size(self, num_tokens: int):

self.model_runner.attn_backend.init_forward_metadata(forward_batch)

kwargs = {}
if self.is_multimodal and hasattr(self, "input_deepstack_embeds"):
kwargs["input_deepstack_embeds"] = self.input_deepstack_embeds[:num_tokens]

# Run and capture
def run_once():
# Clean intermediate result cache for DP attention
Expand All @@ -560,7 +580,6 @@ def run_once():
# It is True in this context but we need to set it to use low latency deepep mode.
set_is_extend_in_batch(False)

kwargs = {}
with set_forward_context(
forward_batch,
self.attention_layers,
Expand Down Expand Up @@ -735,6 +754,11 @@ def replay(
# Due to the dispatch kernel for MLA model, we init the metadata with original forward_batch
self.model_runner.attn_backend.init_forward_metadata(forward_batch)
static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
if self.is_multimodal and hasattr(self, "input_deepstack_embeds"):
static_num_tokens = len(static_forward_batch.input_ids)
kwargs["input_deepstack_embeds"] = self.input_deepstack_embeds[
:static_num_tokens
]
# Replay
with set_forward_context(
static_forward_batch,
Expand Down
59 changes: 41 additions & 18 deletions python/sglang/srt/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,16 +913,37 @@ def __init__(
self.deepstack_embed_to_decoder_layer = range(
len(config.vision_config.deepstack_visual_indexes)
)
self.deepstack_embeds_buffer = None

def get_deepstack_embeds(
self, layer_idx: int, input_deepstack_embeds: Optional[torch.Tensor]
self,
layer_idx: int,
input_deepstack_embeds: Optional[torch.Tensor],
seq_len: int,
) -> Optional[torch.Tensor]:
"""Get deepstack embeddings for a given layer index, or None if not applicable."""
if (
input_deepstack_embeds is None
or layer_idx not in self.deepstack_embed_to_decoder_layer
):
if layer_idx not in self.deepstack_embed_to_decoder_layer:
return None
if input_deepstack_embeds is None:
device = next(self.parameters()).device
dtype = next(self.parameters()).dtype
total = self.hidden_size * len(self.deepstack_embed_to_decoder_layer)
if (
self.deepstack_embeds_buffer is None
or self.deepstack_embeds_buffer.size(0) < seq_len
):
new_len = max(
seq_len,
(
self.deepstack_embeds_buffer.size(0) * 2
if self.deepstack_embeds_buffer is not None
else seq_len
),
)
self.deepstack_embeds_buffer = torch.zeros(
new_len, total, dtype=dtype, device=device
).contiguous()
sep = self.hidden_size * layer_idx
return self.deepstack_embeds_buffer[:seq_len, sep : sep + self.hidden_size]
sep = self.hidden_size * layer_idx
return input_deepstack_embeds[:, sep : sep + self.hidden_size]

Expand Down Expand Up @@ -963,7 +984,7 @@ def forward(
# The order matters because addition with different tensors is not associative in practice.
# Deepstack for prev_layer is applied at the start of current layer via post_residual_addition.
deepstack_embeds = self.get_deepstack_embeds(
layer_idx - 1, input_deepstack_embeds
layer_idx - 1, input_deepstack_embeds, hidden_states.shape[0]
)
hidden_states, residual = layer(
positions,
Expand All @@ -975,7 +996,7 @@ def forward(

# Handle deepstack for the last processed layer if it exists.
last_deepstack = self.get_deepstack_embeds(
self.end_layer - 1, input_deepstack_embeds
self.end_layer - 1, input_deepstack_embeds, hidden_states.shape[0]
)

if not self.pp_group.is_last_rank:
Expand Down Expand Up @@ -1241,6 +1262,7 @@ def forward(
forward_batch: ForwardBatch,
get_embedding: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
**kwargs,
):
"""Run forward pass for Qwen3-VL.

Expand All @@ -1266,16 +1288,17 @@ def forward(
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)

hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
multimodal_model=self,
positions=positions,
use_deepstack=self.use_deepstack,
pp_proxy_tensors=pp_proxy_tensors,
)
with torch.no_grad():
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
multimodal_model=self,
positions=positions,
use_deepstack=self.use_deepstack,
pp_proxy_tensors=pp_proxy_tensors,
**kwargs,
)

if self.pp_group.is_last_rank:
if not get_embedding:
Expand Down
37 changes: 29 additions & 8 deletions python/sglang/srt/models/qwen3_vl_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,40 @@ def __init__(
# This approach follows the original implementation.
# TODO: make config of type Qwen3VLMoeConfig, so that we can directly obtain deepstack_visual_indexes.
self.deepstack_embed_to_decoder_layer = range(3)
self.deepstack_embeds_buffer = None

def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens

def get_deepstack_embeds(
self, layer_idx: int, input_deepstack_embeds: Optional[torch.Tensor]
self,
layer_idx: int,
input_deepstack_embeds: Optional[torch.Tensor],
seq_len: int,
) -> Optional[torch.Tensor]:
"""Get deepstack embeddings for a given layer index, or None if not applicable."""
if (
input_deepstack_embeds is None
or layer_idx not in self.deepstack_embed_to_decoder_layer
):
if layer_idx not in self.deepstack_embed_to_decoder_layer:
return None
if input_deepstack_embeds is None:
device = next(self.parameters()).device
dtype = next(self.parameters()).dtype
total = self.hidden_size * len(self.deepstack_embed_to_decoder_layer)
if (
self.deepstack_embeds_buffer is None
or self.deepstack_embeds_buffer.size(0) < seq_len
):
new_len = max(
seq_len,
(
self.deepstack_embeds_buffer.size(0) * 2
if self.deepstack_embeds_buffer is not None
else seq_len
),
)
self.deepstack_embeds_buffer = torch.zeros(
new_len, total, dtype=dtype, device=device
).contiguous()
sep = self.hidden_size * layer_idx
return self.deepstack_embeds_buffer[:seq_len, sep : sep + self.hidden_size]
sep = self.hidden_size * layer_idx
return input_deepstack_embeds[:, sep : sep + self.hidden_size]

Expand Down Expand Up @@ -109,7 +130,7 @@ def forward(
# The order matters because addition with different tensors is not associative in practice.
# Deepstack for prev_layer is applied at the start of current layer via post_residual_addition.
deepstack_embeds = self.get_deepstack_embeds(
layer_idx - 1, input_deepstack_embeds
layer_idx - 1, input_deepstack_embeds, hidden_states.shape[0]
)
hidden_states, residual = layer(
positions,
Expand All @@ -121,7 +142,7 @@ def forward(

# Handle deepstack for the last processed layer if it exists.
last_deepstack = self.get_deepstack_embeds(
self.end_layer - 1, input_deepstack_embeds
self.end_layer - 1, input_deepstack_embeds, hidden_states.shape[0]
)

if not self.pp_group.is_last_rank:
Expand Down
8 changes: 5 additions & 3 deletions test/manual/nightly/test_vlms_piecewise_cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
)

MODELS = [
SimpleNamespace(model="Qwen/Qwen2.5-VL-7B-Instruct", mmmu_accuracy=0.60),
SimpleNamespace(model="Qwen/Qwen2.5-VL-7B-Instruct", mmmu_accuracy=0.50),
SimpleNamespace(model="Qwen/Qwen3-VL-8B-Instruct", mmmu_accuracy=0.50),
SimpleNamespace(model="Qwen/Qwen3-VL-30B-A3B-Instruct", mmmu_accuracy=0.50),
]


Expand Down Expand Up @@ -59,7 +61,7 @@ def run_mmmu_eval(
"""
# -------- fixed settings --------
model = "openai_compatible"
tp = 1
tp = 2
tasks = "mmmu_val"
batch_size = 32
log_suffix = "openai_compatible"
Expand Down Expand Up @@ -138,7 +140,7 @@ def _run_vlm_mmmu_test(
"--piecewise-cuda-graph-max-tokens",
"8192",
"--enable-piecewise-cuda-graph",
"--tp=8",
"--tp=2",
"--piecewise-cuda-graph-compiler=eager",
"--disable-radix-cache",
"--log-level",
Expand Down
Loading