From 5b4d6b2b60d16cccee65fc86ad7e04da44675e86 Mon Sep 17 00:00:00 2001 From: "luoyuan.luo" Date: Wed, 17 Dec 2025 13:59:56 +0800 Subject: [PATCH] Support ViT CUDA Graph for Qwen3-VL --- .../sglang/srt/distributed/parallel_state.py | 2 +- python/sglang/srt/layers/attention/vision.py | 4 +- python/sglang/srt/models/qwen2_5_vl.py | 5 +- python/sglang/srt/models/qwen3_vl.py | 53 +++- .../srt/multimodal/vit_cuda_graph_runner.py | 232 +++++++++++++----- .../nightly/test_vlms_vit_cuda_graph.py | 1 + 6 files changed, 233 insertions(+), 64 deletions(-) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 216b8d5a316f..73f69b0871b4 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -629,7 +629,7 @@ def _all_reduce_out_place( qr_comm = self.qr_comm pymscclpp_comm = self.pymscclpp_comm torch_symm_mem_comm = self.torch_symm_mem_comm - assert any([qr_comm, ca_comm, pymscclpp_comm]) + assert any([qr_comm, ca_comm, pymscclpp_comm, torch_symm_mem_comm]) if outplace_all_reduce_method == "ca": assert not ca_comm.disabled out = ca_comm.custom_all_reduce(input_) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 33aeec320e99..f66c248447e9 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -294,7 +294,7 @@ def forward( Returns: [b * s, h, head_size] """ - if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH") and self.tp_size == 1: + if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH"): if "output_ws" not in kwargs: raise RuntimeError("output_ws should be prepared for cuda-graph mode") @@ -363,7 +363,7 @@ def forward( Returns: [b * s, h, head_size] """ - if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH") and self.tp_size == 1: + if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH"): max_seqlen = cu_seqlens[1] output = flash_attn_varlen_func( q, diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 7185de34c6f5..432179af7a93 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -170,7 +170,6 @@ def forward( position_embeddings: torch.Tensor, output_ws=None, ) -> torch.Tensor: - ws = output_ws S, B, H = x.shape # norm1: flatten to 2D -> [S*B, H], then reshape back x2d = x.reshape(-1, H) @@ -182,7 +181,7 @@ def forward( hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, - output_ws=ws, + output_ws=output_ws, ) attn = rearrange(attn, "b s h -> s b h") @@ -390,7 +389,7 @@ def forward( x: torch.Tensor, grid_thw: torch.Tensor, ) -> torch.Tensor: - if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH") and self.tp_size == 1: + if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH"): return self.forward_with_cuda_graph(x, grid_thw) # patchify diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 9e270d213ee0..b7adf8d8c7ad 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -57,8 +57,9 @@ compute_cu_seqlens_from_grid_numpy, ) from sglang.srt.multimodal.mm_utils import run_dp_sharded_mrope_vision_model +from sglang.srt.multimodal.vit_cuda_graph_runner import ViTCudaGraphRunner from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import add_prefix, get_int_env_var +from sglang.srt.utils import add_prefix, get_bool_env_var, get_int_env_var from 
sglang.srt.utils.hf_transformers_utils import get_processor logger = logging.getLogger(__name__) @@ -188,6 +189,7 @@ def forward( cu_seqlens: torch.Tensor, rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, + output_ws: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.norm1(x) hidden_states = rearrange(hidden_states, "s b ... -> b s ...") @@ -196,6 +198,7 @@ def forward( cu_seqlens=cu_seqlens, rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, + output_ws=output_ws, ) attn = rearrange(attn, "b s ... -> s b ...") x += attn @@ -341,6 +344,11 @@ def __init__( ] ) + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.cuda_graph_runner: Optional[ViTCudaGraphRunner] = ViTCudaGraphRunner(self) + @property def dtype(self) -> torch.dtype: return self.patch_embed.proj.weight.dtype @@ -458,6 +466,9 @@ def forward( x: torch.Tensor, grid_thw: torch.Tensor, ) -> torch.Tensor: + if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH"): + return self.forward_with_cuda_graph(x, grid_thw) + x = x.to(device=self.device, dtype=self.dtype) x = self.patch_embed(x) @@ -501,6 +512,46 @@ def forward( ) # [seq_len, hidden_size * (1 + depth_of_deepstack)] return hidden_states + def forward_with_cuda_graph( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + # patchify + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + + if isinstance(grid_thw, list): + grid_thw_list = grid_thw + grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + else: + grid_thw_list = grid_thw.tolist() + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + x += pos_embeds + + # rotary embedding -> (cos, sin) + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) + + # compute cu_seqlens + cu_seqlens = compute_cu_seqlens_from_grid_numpy(grid_thw) + if not isinstance(cu_seqlens, torch.Tensor): + cu_seqlens = torch.tensor(cu_seqlens, device=x.device, dtype=torch.int32) + else: + cu_seqlens = cu_seqlens.to(device=x.device, dtype=torch.int32) + cu_seqlens = cu_seqlens.contiguous() + + # blocks + merger + deepstack(optional) via CUDA Graph Runner + return self.cuda_graph_runner.run( + x=x, + position_embeddings=None, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + cu_seqlens=cu_seqlens, + cu_window_seqlens=None, + output_indices=None, + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/multimodal/vit_cuda_graph_runner.py b/python/sglang/srt/multimodal/vit_cuda_graph_runner.py index f5ba1b86dd62..683271be5141 100644 --- a/python/sglang/srt/multimodal/vit_cuda_graph_runner.py +++ b/python/sglang/srt/multimodal/vit_cuda_graph_runner.py @@ -16,7 +16,7 @@ from __future__ import annotations import inspect -from typing import Dict, Hashable, Optional, Tuple +from typing import Dict, Hashable, List, Optional, Tuple import torch import torch.nn as nn @@ -26,11 +26,18 @@ class ViTCudaGraphRunner: - """ViT CUDA Graph Runner + """Generic ViT CUDA Graph Runner. - Cached with graph_key = seq_len, for each seq_len capture once. - expose run(), internally call create_graph(). - exceed call invokes replay(). + This runner captures the "blocks + merger + deepstack merger (optional)" part + of a vision transformer into a CUDA graph and replays it for identical shapes. 
+
+    Optional for Qwen2.5-VL windowed attention:
+    - vit.fullatt_block_indexes: Sequence[int]
+    - run() provides both cu_seqlens and cu_window_seqlens
+
+    Optional for Qwen3-VL deepstack:
+    - vit.deepstack_visual_indexes: Sequence[int]
+    - vit.deepstack_merger_list: nn.ModuleList (same length as deepstack_visual_indexes)
     """

     def __init__(
@@ -55,8 +62,15 @@ def __init__(
         self.sin_cos_ws: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
         self.max_context_len = getattr(vit, "max_context_len", None)

+        # Qwen2.5-VL specific variable.
         self._fullatt_block_indexes = set(getattr(vit, "fullatt_block_indexes", ()))

+        # Qwen3-VL specific variables.
+        self._deepstack_visual_indexes = list(
+            getattr(vit, "deepstack_visual_indexes", []) or []
+        )
+        self._deepstack_merger_list = getattr(vit, "deepstack_merger_list", None)
+
         first_blk = vit.blocks[0]
         self._blk_accepts_output_ws = (
             "output_ws" in inspect.signature(first_blk.forward).parameters
         )
@@ -99,31 +113,50 @@ def _get_graph_key(self, x_3d: torch.Tensor) -> int:
         # x_3d: [S, B, H], B=1, S as graph_key
         return x_3d.shape[0]

-    def _create_graph(self, graph_key: int, temp_cos_sin):
+    def _create_graph(
+        self,
+        graph_key: int,
+        position_embeddings: Optional[
+            Tuple[torch.Tensor, torch.Tensor]
+        ] = None,  # (cos, sin), [S, D]
+        rotary_pos_emb_cos: Optional[torch.Tensor] = None,
+        rotary_pos_emb_sin: Optional[torch.Tensor] = None,
+    ):
+
         graph = torch.cuda.CUDAGraph()
         vit = self.vit

-        cu_window = self.cu_window_len[graph_key]
+        # Qwen2.5-VL
+        if self._fullatt_block_indexes:
+            cu_window = self.cu_window_len[graph_key]
+            cu_window_kk = self.cu_window_len_kk[graph_key]
+            max_window_len = int(cu_window_kk.max().item())
+
         cu_full = self.cu_full_len[graph_key]
-        cu_window_kk = self.cu_window_len_kk[graph_key]
         cu_full_kk = self.cu_full_len_kk[graph_key]
-
         max_full_len = int(cu_full_kk.max().item())
-        max_window_len = int(cu_window_kk.max().item())

         override_backend = get_global_server_args().mm_attention_backend

         with torch.cuda.graph(graph):
             y = None
+            deepstack_outs: List[torch.Tensor] = []
+            deepstack_capture_idx = 0
+
             for layer_num, blk in enumerate(vit.blocks):
-                if layer_num in vit.fullatt_block_indexes:
+                if self._fullatt_block_indexes:
+                    if layer_num in vit.fullatt_block_indexes:
+                        cu_seqlens_now = cu_full
+                        cu_seqlens_kk_now = cu_full_kk
+                        max_len = max_full_len
+                    else:
+                        cu_seqlens_now = cu_window
+                        cu_seqlens_kk_now = cu_window_kk
+                        max_len = max_window_len
+                else:
                     cu_seqlens_now = cu_full
                     cu_seqlens_kk_now = cu_full_kk
                     max_len = max_full_len
-                else:
-                    cu_seqlens_now = cu_window
-                    cu_seqlens_kk_now = cu_window_kk
-                    max_len = max_window_len

                 if override_backend == "triton_attn":
                     cu_seq_len_ws = [cu_seqlens_now, cu_seqlens_kk_now, max_len]
@@ -132,31 +165,75 @@ def _create_graph(self, graph_key: int, temp_cos_sin):
                 else:
                     raise RuntimeError("Not supported ViT attention backend")

-                if layer_num == 0:
-                    y = blk(
-                        self.block_input[graph_key],
-                        cu_seqlens=cu_seq_len_ws,
-                        position_embeddings=temp_cos_sin,
-                        output_ws=self.block_ws[graph_key],
-                    )
-                else:
-                    y = blk(
-                        y,
-                        cu_seqlens=cu_seq_len_ws,
-                        position_embeddings=temp_cos_sin,
-                        output_ws=self.block_ws[graph_key],
+                if position_embeddings is not None:
+                    if layer_num == 0:
+                        y = blk(
+                            self.block_input[graph_key],
+                            cu_seqlens=cu_seq_len_ws,
+                            position_embeddings=position_embeddings,
+                            output_ws=self.block_ws[graph_key],
+                        )
+                    else:
+                        y = blk(
+                            y,
+                            cu_seqlens=cu_seq_len_ws,
+                            position_embeddings=position_embeddings,
+                            output_ws=self.block_ws[graph_key],
+                        )
+                elif rotary_pos_emb_cos is not None and 
rotary_pos_emb_sin is not None: + if layer_num == 0: + y = blk( + self.block_input[graph_key], + cu_seqlens=cu_seq_len_ws, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + output_ws=self.block_ws[graph_key], + ) + else: + y = blk( + y, + cu_seqlens=cu_seq_len_ws, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + output_ws=self.block_ws[graph_key], + ) + + # Optional deepstack support (Qwen3-VL) + if ( + self._deepstack_visual_indexes + and layer_num in self._deepstack_visual_indexes + ): + if self._deepstack_merger_list is None: + raise RuntimeError( + "deepstack_visual_indexes exists but deepstack_merger_list is missing." + ) + deepstack_out = self._deepstack_merger_list[deepstack_capture_idx]( + y ) + deepstack_outs.append(deepstack_out) + deepstack_capture_idx += 1 + + main_out = vit.merger(y) - self.block_output[graph_key] = vit.merger(y) + if deepstack_outs: + self.block_output[graph_key] = torch.cat( + [main_out] + deepstack_outs, dim=1 + ) + else: + self.block_output[graph_key] = main_out self.block_graphs[graph_key] = graph def create_graph( self, x_3d: torch.Tensor, # [S, 1, H] - position_embeddings: Tuple[torch.Tensor, torch.Tensor], # (cos, sin), [S, D] cu_seqlens: torch.Tensor, cu_window_seqlens: torch.Tensor, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ], # (cos, sin), [S, D] + rotary_pos_emb_cos: Optional[torch.Tensor] = None, + rotary_pos_emb_sin: Optional[torch.Tensor] = None, ) -> int: vit = self.vit graph_key = self._get_graph_key(x_3d) @@ -164,16 +241,6 @@ def create_graph( if graph_key in self.block_graphs: return graph_key - # make sure rotary workspace - head_dim = position_embeddings[0].shape[1] - self._ensure_sin_cos_ws(graph_key, head_dim) - - used_cos_ws = self.sin_cos_ws[0][:graph_key, :] - used_sin_ws = self.sin_cos_ws[1][:graph_key, :] - used_cos_ws.copy_(position_embeddings[0]) - used_sin_ws.copy_(position_embeddings[1]) - temp_cos_sin = (used_cos_ws, used_sin_ws) - # pre-allocate workspace attn_module: VisionAttention = vit.blocks[0].attn num_heads = attn_module.num_attention_heads_per_partition @@ -194,15 +261,48 @@ def create_graph( dtype=self.dtype, ) - if graph_key not in self.cu_window_len: - self.cu_window_len[graph_key] = cu_window_seqlens - self.cu_full_len[graph_key] = cu_seqlens - self.cu_window_len_kk[graph_key] = ( - cu_window_seqlens[1:] - cu_window_seqlens[:-1] + # Qwen2.5-VL + if self._fullatt_block_indexes: + if graph_key not in self.cu_window_len: + self.cu_window_len[graph_key] = cu_window_seqlens + self.cu_full_len[graph_key] = cu_seqlens + self.cu_window_len_kk[graph_key] = ( + cu_window_seqlens[1:] - cu_window_seqlens[:-1] + ) + self.cu_full_len_kk[graph_key] = cu_seqlens[1:] - cu_seqlens[:-1] + else: + if graph_key not in self.cu_full_len: + self.cu_full_len[graph_key] = cu_seqlens + self.cu_full_len_kk[graph_key] = cu_seqlens[1:] - cu_seqlens[:-1] + + if position_embeddings is not None: + # make sure rotary workspace + head_dim = position_embeddings[0].shape[1] + self._ensure_sin_cos_ws(graph_key, head_dim) + + used_cos_ws = self.sin_cos_ws[0][:graph_key, :] + used_sin_ws = self.sin_cos_ws[1][:graph_key, :] + used_cos_ws.copy_(position_embeddings[0]) + used_sin_ws.copy_(position_embeddings[1]) + persist_position_embeddings = (used_cos_ws, used_sin_ws) + self._create_graph( + graph_key=graph_key, position_embeddings=persist_position_embeddings + ) + elif rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: + # make sure rotary 
workspace + head_dim = rotary_pos_emb_cos.shape[1] + self._ensure_sin_cos_ws(graph_key, head_dim) + + used_cos_ws = self.sin_cos_ws[0][:graph_key, :] + used_sin_ws = self.sin_cos_ws[1][:graph_key, :] + used_cos_ws.copy_(rotary_pos_emb_cos) + used_sin_ws.copy_(rotary_pos_emb_sin) + self._create_graph( + graph_key=graph_key, + position_embeddings=None, + rotary_pos_emb_cos=used_cos_ws, + rotary_pos_emb_sin=used_sin_ws, ) - self.cu_full_len_kk[graph_key] = cu_seqlens[1:] - cu_seqlens[:-1] - - self._create_graph(graph_key, temp_cos_sin) return graph_key @@ -210,16 +310,28 @@ def replay( self, graph_key: int, x_3d: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + rotary_pos_emb_cos: Optional[torch.Tensor] = None, + rotary_pos_emb_sin: Optional[torch.Tensor] = None, output_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # update rotary workspace content - head_dim = position_embeddings[0].shape[1] - self._ensure_sin_cos_ws(graph_key, head_dim) - used_cos_ws = self.sin_cos_ws[0][:graph_key, :] - used_sin_ws = self.sin_cos_ws[1][:graph_key, :] - used_cos_ws.copy_(position_embeddings[0]) - used_sin_ws.copy_(position_embeddings[1]) + + if position_embeddings is not None: + # update rotary workspace content + head_dim = position_embeddings[0].shape[1] + self._ensure_sin_cos_ws(graph_key, head_dim) + used_cos_ws = self.sin_cos_ws[0][:graph_key, :] + used_sin_ws = self.sin_cos_ws[1][:graph_key, :] + used_cos_ws.copy_(position_embeddings[0]) + used_sin_ws.copy_(position_embeddings[1]) + elif rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: + # update rotary workspace content + head_dim = rotary_pos_emb_cos.shape[1] + self._ensure_sin_cos_ws(graph_key, head_dim) + used_cos_ws = self.sin_cos_ws[0][:graph_key, :] + used_sin_ws = self.sin_cos_ws[1][:graph_key, :] + used_cos_ws.copy_(rotary_pos_emb_cos) + used_sin_ws.copy_(rotary_pos_emb_sin) # copy input self.block_input[graph_key].copy_(x_3d) @@ -238,9 +350,11 @@ def replay( def run( self, x: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], cu_seqlens: torch.Tensor, cu_window_seqlens: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], + rotary_pos_emb_cos: Optional[torch.Tensor] = None, + rotary_pos_emb_sin: Optional[torch.Tensor] = None, output_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: # x: [seq_len, hidden] -> [S, B=1, H] @@ -253,11 +367,15 @@ def run( position_embeddings=position_embeddings, cu_seqlens=cu_seqlens, cu_window_seqlens=cu_window_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, ) return self.replay( graph_key=graph_key, x_3d=x_3d, position_embeddings=position_embeddings, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, output_indices=output_indices, ) diff --git a/test/manual/nightly/test_vlms_vit_cuda_graph.py b/test/manual/nightly/test_vlms_vit_cuda_graph.py index 32948e398173..50e601126735 100644 --- a/test/manual/nightly/test_vlms_vit_cuda_graph.py +++ b/test/manual/nightly/test_vlms_vit_cuda_graph.py @@ -19,6 +19,7 @@ MODELS = [ SimpleNamespace(model="Qwen/Qwen2.5-VL-7B-Instruct", mmmu_accuracy=0.60), + SimpleNamespace(model="Qwen/Qwen3-VL-8B-Instruct", mmmu_accuracy=0.60), ]
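
Reviewer note: below is a minimal end-to-end smoke test for the new Qwen3-VL path, assuming this patch is applied. It is illustrative only and not part of the patch; the image URL is a placeholder, the prompt uses the standard Qwen-VL vision tokens, and the only feature-specific step is exporting SGLANG_VIT_ENABLE_CUDA_GRAPH before launching the engine.

    # Illustrative smoke test -- not part of this patch.
    import os

    # Opt in to the ViT CUDA-graph path. The flag is read via
    # get_bool_env_var() at forward time, so set it before launch.
    os.environ["SGLANG_VIT_ENABLE_CUDA_GRAPH"] = "1"

    import sglang as sgl

    if __name__ == "__main__":
        llm = sgl.Engine(model_path="Qwen/Qwen3-VL-8B-Instruct")
        # The first image of a given size triggers a graph capture keyed on
        # the flattened patch sequence length; identical shapes replay it.
        out = llm.generate(
            prompt="<|vision_start|><|image_pad|><|vision_end|>Describe this image.",
            image_data="https://example.com/cat.png",  # placeholder URL
            sampling_params={"max_new_tokens": 64},
        )
        print(out["text"])
        llm.shutdown()

The runner keys captured graphs on x_3d.shape[0] (the total ViT sequence length), so each distinct image resolution pays one capture and subsequent requests with the same shape replay the graph; this is also why replay() only refreshes the rotary cos/sin workspace and the input buffer instead of re-capturing.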