From 729224662a630396b403dc04efeadeb889e8c62c Mon Sep 17 00:00:00 2001 From: "luoyuan.luo" Date: Sun, 11 Jan 2026 21:05:01 +0800 Subject: [PATCH] Refactor ViT CUDA Graph key generation --- python/sglang/srt/models/qwen2_5_vl.py | 36 ++- python/sglang/srt/models/qwen3_vl.py | 19 +- .../srt/multimodal/vit_cuda_graph_runner.py | 301 ++++++++++++------ 3 files changed, 247 insertions(+), 109 deletions(-) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 7f44978fc59c..56744e1dc7a7 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -488,9 +488,17 @@ def forward_with_cuda_graph( # compute position embedding rotary_pos_emb = self.rot_pos_emb(grid_thw) - window_index, cu_window_seqlens = self.get_window_index(grid_thw) + # NOTE: get_window_index returns: + # - window_index: torch.Tensor (CPU) + # - cu_window_seqlens: python list[int] (CPU) + window_index, cu_window_seqlens_list = self.get_window_index(grid_thw) + + # Match runtime tensor values: unique_consecutive on CPU list (no GPU sync) + cu_window_seqlens_list = ViTCudaGraphRunner._unique_consecutive_list( + cu_window_seqlens_list + ) cu_window_seqlens = torch.tensor( - cu_window_seqlens, + cu_window_seqlens_list, device=x.device, dtype=torch.int32, ) @@ -523,6 +531,29 @@ def forward_with_cuda_graph( position_embeddings[1].to(x.device, x.dtype), ) + # CPU-side graph_key_hint (0 GPU sync) + # grid_thw is expected to be CPU tensor in normal pipeline. + # If it is CUDA, tolist() would sync; we intentionally avoid that case. + if isinstance(grid_thw, torch.Tensor): + assert ( + grid_thw.device.type == "cpu" + ), "grid_thw must be on CPU to avoid GPU sync for graph_key_hint." + grid_list = grid_thw.tolist() + else: + grid_list = grid_thw + + # cu_seqlens in this model is built as: [0] + [0] + cumsum(prod(t,h,w)) + cum = 0 + tmp = [0] + for t, h, w in grid_list: + cum += int(t) * int(h) * int(w) + tmp.append(cum) + cu_seqlens_list = [0] + tmp # double-leading-0 to match existing code + + full_sig = ViTCudaGraphRunner._hash_i32_list(cu_seqlens_list) + window_sig = ViTCudaGraphRunner._hash_i32_list(cu_window_seqlens_list) + graph_key_hint = (full_sig, window_sig) + # compute cu_seqlens - move cu_seqlens to GPU and make it int32 cu_seqlens = torch.cat( [ @@ -540,6 +571,7 @@ def forward_with_cuda_graph( cu_seqlens=cu_seqlens, cu_window_seqlens=cu_window_seqlens, output_indices=reverse_indices, + graph_key_hint=graph_key_hint, ) diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 079f4584368d..541b7f09eb8f 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -497,8 +497,22 @@ def forward_with_cuda_graph( # rotary embedding -> (cos, sin) rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) - # compute cu_seqlens - cu_seqlens = compute_cu_seqlens_from_grid_numpy(grid_thw) + # compute cu_seqlens (this utility is expected to return CPU data; no GPU sync) + cu_seqlens_cpu = compute_cu_seqlens_from_grid_numpy(grid_thw) + + if isinstance(cu_seqlens_cpu, torch.Tensor): + assert ( + cu_seqlens_cpu.device.type == "cpu" + ), "compute_cu_seqlens_from_grid_numpy must return CPU tensor to avoid GPU sync." 
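+            # Host-only work: the int32 cast and tolist() on a CPU tensor never touch
+            # the GPU, so hashing cu_seqlens for the graph key adds no device sync.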
+ cu_seqlens_list = cu_seqlens_cpu.to(torch.int32).contiguous().tolist() + else: + # numpy array / python list + cu_seqlens_list = [int(v) for v in cu_seqlens_cpu] + + graph_key_hint = ViTCudaGraphRunner._hash_i32_list(cu_seqlens_list) + + # move cu_seqlens to device for attention + cu_seqlens = cu_seqlens_cpu if not isinstance(cu_seqlens, torch.Tensor): cu_seqlens = torch.tensor(cu_seqlens, device=x.device, dtype=torch.int32) else: @@ -514,6 +528,7 @@ def forward_with_cuda_graph( cu_seqlens=cu_seqlens, cu_window_seqlens=None, output_indices=None, + graph_key_hint=graph_key_hint, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/python/sglang/srt/multimodal/vit_cuda_graph_runner.py b/python/sglang/srt/multimodal/vit_cuda_graph_runner.py index 683271be5141..2ee10d17c649 100644 --- a/python/sglang/srt/multimodal/vit_cuda_graph_runner.py +++ b/python/sglang/srt/multimodal/vit_cuda_graph_runner.py @@ -15,12 +15,21 @@ """ViT CUDA Graph Runner class.""" from __future__ import annotations +import array +import hashlib import inspect from typing import Dict, Hashable, List, Optional, Tuple import torch import torch.nn as nn +# TorchDynamo may lazily trace/compile on first call. +# Tracing queries CUDA RNG state, which is forbidden during CUDA graph capture. +try: + import torch._dynamo as dynamo +except Exception: + dynamo = None + from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.server_args import get_global_server_args @@ -109,22 +118,55 @@ def _ensure_sin_cos_ws(self, seq_len: int, head_dim: int): ) self.sin_cos_ws = (cos_ws, sin_ws) - def _get_graph_key(self, x_3d: torch.Tensor) -> int: - # x_3d: [S, B, H], B=1, S as graph_key - return x_3d.shape[0] + @staticmethod + def _unique_consecutive_list(vals: List[int]) -> List[int]: + out: List[int] = [] + last = None + for v in vals: + v = int(v) + if last is None or v != last: + out.append(v) + last = v + return out + + @staticmethod + def _hash_i32_list(vals: List[int]) -> int: + """Stable 64-bit hash for a non-negative int list (CPU only).""" + # cu_seqlens / cu_window_seqlens are non-negative and int32-safe in practice. + arr = array.array("I", (int(v) & 0xFFFFFFFF for v in vals)) + digest = hashlib.blake2b(arr.tobytes(), digest_size=8).digest() + return int.from_bytes(digest, "little", signed=False) + + def _make_graph_key( + self, + seq_len: int, + *, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], + rotary_pos_emb_cos: Optional[torch.Tensor], + rotary_pos_emb_sin: Optional[torch.Tensor], + graph_key_hint: Optional[Hashable], + ) -> Hashable: + backend = get_global_server_args().mm_attention_backend + if position_embeddings is not None: + emb_mode = "pos" + elif rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: + emb_mode = "rot" + else: + emb_mode = "none" + # If caller doesn't provide hint, fall back to key by seq_len only. 
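+        # The key is (seq_len, backend, emb_mode, hint): two inputs with the same total
+        # seq_len but different cu_seqlens / cu_window_seqlens layouts now produce
+        # different hints and therefore capture separate graphs instead of sharing one.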
+        hint = graph_key_hint if graph_key_hint is not None else seq_len
+        return (int(seq_len), backend, emb_mode, hint)
 
     def _create_graph(
         self,
-        graph_key: int,
-        position_embeddings: Optional[
-            Tuple[torch.Tensor, torch.Tensor]
-        ] = None,  # (cos, sin), [S, D]
+        graph_key: Hashable,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         rotary_pos_emb_cos: Optional[torch.Tensor] = None,
         rotary_pos_emb_sin: Optional[torch.Tensor] = None,
     ):
         graph = torch.cuda.CUDAGraph()
         vit = self.vit
+        seq_len = int(graph_key[0])  # graph_key = (seq_len, backend, emb_mode, hint)
 
         # Qwen2.5-VL
         if self._fullatt_block_indexes:
@@ -138,89 +180,116 @@ def _create_graph(
 
         override_backend = get_global_server_args().mm_attention_backend
 
-        with torch.cuda.graph(graph):
-            y = None
-            deepstack_outs: List[torch.Tensor] = []
-            deepstack_capture_idx = 0
-
-            for layer_num, blk in enumerate(vit.blocks):
-                if self._fullatt_block_indexes:
-                    if layer_num in vit.fullatt_block_indexes:
+        def _capture_body():
+            with torch.cuda.graph(graph):
+                y = None
+                deepstack_outs: List[torch.Tensor] = []
+                deepstack_capture_idx = 0
+
+                for layer_num, blk in enumerate(vit.blocks):
+                    if self._fullatt_block_indexes:
+                        if layer_num in vit.fullatt_block_indexes:
+                            cu_seqlens_now = cu_full
+                            cu_seqlens_kk_now = cu_full_kk
+                            max_len = max_full_len
+                        else:
+                            cu_seqlens_now = cu_window
+                            cu_seqlens_kk_now = cu_window_kk
+                            max_len = max_window_len
+                    else:
                         cu_seqlens_now = cu_full
                         cu_seqlens_kk_now = cu_full_kk
                         max_len = max_full_len
+
+                    if override_backend == "triton_attn":
+                        cu_seq_len_ws = [cu_seqlens_now, cu_seqlens_kk_now, max_len]
+                    elif override_backend == "fa3":
+                        cu_seq_len_ws = [cu_seqlens_now, max_len]
                     else:
-                        cu_seqlens_now = cu_window
-                        cu_seqlens_kk_now = cu_window_kk
-                        max_len = max_window_len
-                else:
-                    cu_seqlens_now = cu_full
-                    cu_seqlens_kk_now = cu_full_kk
-                    max_len = max_full_len
-
-                if override_backend == "triton_attn":
-                    cu_seq_len_ws = [cu_seqlens_now, cu_seqlens_kk_now, max_len]
-                elif override_backend == "fa3":
-                    cu_seq_len_ws = [cu_seqlens_now, max_len]
-                else:
-                    raise RuntimeError("Not supported ViT attention backend")
-
-                if position_embeddings is not None:
-                    if layer_num == 0:
-                        y = blk(
-                            self.block_input[graph_key],
-                            cu_seqlens=cu_seq_len_ws,
-                            position_embeddings=position_embeddings,
-                            output_ws=self.block_ws[graph_key],
-                        )
-                    else:
-                        y = blk(
-                            y,
-                            cu_seqlens=cu_seq_len_ws,
-                            position_embeddings=position_embeddings,
-                            output_ws=self.block_ws[graph_key],
-                        )
-                elif rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
-                    if layer_num == 0:
-                        y = blk(
-                            self.block_input[graph_key],
-                            cu_seqlens=cu_seq_len_ws,
-                            rotary_pos_emb_cos=rotary_pos_emb_cos,
-                            rotary_pos_emb_sin=rotary_pos_emb_sin,
-                            output_ws=self.block_ws[graph_key],
-                        )
-                    else:
-                        y = blk(
-                            y,
-                            cu_seqlens=cu_seq_len_ws,
-                            rotary_pos_emb_cos=rotary_pos_emb_cos,
-                            rotary_pos_emb_sin=rotary_pos_emb_sin,
-                            output_ws=self.block_ws[graph_key],
-                        )
-
-                # Optional deepstack support (Qwen3-VL)
-                if (
-                    self._deepstack_visual_indexes
-                    and layer_num in self._deepstack_visual_indexes
-                ):
-                    if self._deepstack_merger_list is None:
-                        raise RuntimeError(
-                            "deepstack_visual_indexes exists but deepstack_merger_list is missing."
- ) - deepstack_out = self._deepstack_merger_list[deepstack_capture_idx]( - y + raise RuntimeError("Not supported ViT attention backend") + + if position_embeddings is not None: + if layer_num == 0: + y = blk( + self.block_input[graph_key], + cu_seqlens=cu_seq_len_ws, + position_embeddings=position_embeddings, + output_ws=self.block_ws[graph_key], + ) + else: + y = blk( + y, + cu_seqlens=cu_seq_len_ws, + position_embeddings=position_embeddings, + output_ws=self.block_ws[graph_key], + ) + elif ( + rotary_pos_emb_cos is not None + and rotary_pos_emb_sin is not None + ): + if layer_num == 0: + y = blk( + self.block_input[graph_key], + cu_seqlens=cu_seq_len_ws, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + output_ws=self.block_ws[graph_key], + ) + else: + y = blk( + y, + cu_seqlens=cu_seq_len_ws, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + output_ws=self.block_ws[graph_key], + ) + + # Optional deepstack (Qwen3-VL) + if ( + self._deepstack_visual_indexes + and layer_num in self._deepstack_visual_indexes + ): + if self._deepstack_merger_list is None: + raise RuntimeError( + "deepstack_visual_indexes exists but deepstack_merger_list is missing." + ) + deepstack_out = self._deepstack_merger_list[ + deepstack_capture_idx + ](y) + deepstack_outs.append(deepstack_out) + deepstack_capture_idx += 1 + + main_out = vit.merger(y) + + if deepstack_outs: + self.block_output[graph_key] = torch.cat( + [main_out] + deepstack_outs, dim=1 ) - deepstack_outs.append(deepstack_out) - deepstack_capture_idx += 1 - - main_out = vit.merger(y) - - if deepstack_outs: - self.block_output[graph_key] = torch.cat( - [main_out] + deepstack_outs, dim=1 - ) - else: - self.block_output[graph_key] = main_out + else: + self.block_output[graph_key] = main_out + + try: + import torch._dynamo as dynamo # noqa + + # wrap capture body + try: + _capture = dynamo.disable(_capture_body) + except Exception: + _capture = dynamo.disable()(_capture_body) + + # temporarily force global disable (if present) + prev_disable = None + has_flag = hasattr(dynamo, "config") and hasattr(dynamo.config, "disable") + if has_flag: + prev_disable = dynamo.config.disable + dynamo.config.disable = True + try: + _capture() + finally: + if has_flag: + dynamo.config.disable = prev_disable + except Exception: + _capture_body() self.block_graphs[graph_key] = graph @@ -228,15 +297,23 @@ def create_graph( self, x_3d: torch.Tensor, # [S, 1, H] cu_seqlens: torch.Tensor, - cu_window_seqlens: torch.Tensor, + cu_window_seqlens: Optional[torch.Tensor], position_embeddings: Optional[ Tuple[torch.Tensor, torch.Tensor] ], # (cos, sin), [S, D] rotary_pos_emb_cos: Optional[torch.Tensor] = None, rotary_pos_emb_sin: Optional[torch.Tensor] = None, - ) -> int: + graph_key_hint: Optional[Hashable] = None, + ) -> Hashable: vit = self.vit - graph_key = self._get_graph_key(x_3d) + seq_len = int(x_3d.shape[0]) + graph_key = self._make_graph_key( + seq_len, + position_embeddings=position_embeddings, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + graph_key_hint=graph_key_hint, + ) if graph_key in self.block_graphs: return graph_key @@ -254,7 +331,7 @@ def create_graph( x_3d, device=self.device ).contiguous() self.block_ws[graph_key] = torch.empty( - graph_key, + seq_len, num_heads, attn_head_dim, device=self.device, @@ -266,6 +343,10 @@ def create_graph( if graph_key not in self.cu_window_len: self.cu_window_len[graph_key] = cu_window_seqlens self.cu_full_len[graph_key] = 
cu_seqlens + if cu_window_seqlens is None: + raise ValueError( + "cu_window_seqlens is required when fullatt_block_indexes is set." + ) self.cu_window_len_kk[graph_key] = ( cu_window_seqlens[1:] - cu_window_seqlens[:-1] ) @@ -278,10 +359,10 @@ def create_graph( if position_embeddings is not None: # make sure rotary workspace head_dim = position_embeddings[0].shape[1] - self._ensure_sin_cos_ws(graph_key, head_dim) + self._ensure_sin_cos_ws(seq_len, head_dim) - used_cos_ws = self.sin_cos_ws[0][:graph_key, :] - used_sin_ws = self.sin_cos_ws[1][:graph_key, :] + used_cos_ws = self.sin_cos_ws[0][:seq_len, :] + used_sin_ws = self.sin_cos_ws[1][:seq_len, :] used_cos_ws.copy_(position_embeddings[0]) used_sin_ws.copy_(position_embeddings[1]) persist_position_embeddings = (used_cos_ws, used_sin_ws) @@ -291,10 +372,10 @@ def create_graph( elif rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: # make sure rotary workspace head_dim = rotary_pos_emb_cos.shape[1] - self._ensure_sin_cos_ws(graph_key, head_dim) + self._ensure_sin_cos_ws(seq_len, head_dim) - used_cos_ws = self.sin_cos_ws[0][:graph_key, :] - used_sin_ws = self.sin_cos_ws[1][:graph_key, :] + used_cos_ws = self.sin_cos_ws[0][:seq_len, :] + used_sin_ws = self.sin_cos_ws[1][:seq_len, :] used_cos_ws.copy_(rotary_pos_emb_cos) used_sin_ws.copy_(rotary_pos_emb_sin) self._create_graph( @@ -308,28 +389,29 @@ def create_graph( def replay( self, - graph_key: int, + graph_key: Hashable, x_3d: torch.Tensor, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, rotary_pos_emb_cos: Optional[torch.Tensor] = None, rotary_pos_emb_sin: Optional[torch.Tensor] = None, output_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: + seq_len = int(graph_key[0]) if position_embeddings is not None: # update rotary workspace content head_dim = position_embeddings[0].shape[1] - self._ensure_sin_cos_ws(graph_key, head_dim) - used_cos_ws = self.sin_cos_ws[0][:graph_key, :] - used_sin_ws = self.sin_cos_ws[1][:graph_key, :] + self._ensure_sin_cos_ws(seq_len, head_dim) + used_cos_ws = self.sin_cos_ws[0][:seq_len, :] + used_sin_ws = self.sin_cos_ws[1][:seq_len, :] used_cos_ws.copy_(position_embeddings[0]) used_sin_ws.copy_(position_embeddings[1]) elif rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: # update rotary workspace content head_dim = rotary_pos_emb_cos.shape[1] - self._ensure_sin_cos_ws(graph_key, head_dim) - used_cos_ws = self.sin_cos_ws[0][:graph_key, :] - used_sin_ws = self.sin_cos_ws[1][:graph_key, :] + self._ensure_sin_cos_ws(seq_len, head_dim) + used_cos_ws = self.sin_cos_ws[0][:seq_len, :] + used_sin_ws = self.sin_cos_ws[1][:seq_len, :] used_cos_ws.copy_(rotary_pos_emb_cos) used_sin_ws.copy_(rotary_pos_emb_sin) @@ -351,15 +433,23 @@ def run( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - cu_window_seqlens: torch.Tensor, + cu_window_seqlens: Optional[torch.Tensor], position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], rotary_pos_emb_cos: Optional[torch.Tensor] = None, rotary_pos_emb_sin: Optional[torch.Tensor] = None, output_indices: Optional[torch.Tensor] = None, + graph_key_hint: Optional[Hashable] = None, ) -> torch.Tensor: # x: [seq_len, hidden] -> [S, B=1, H] x_3d = x.unsqueeze(1) - graph_key = self._get_graph_key(x_3d) + seq_len = int(x_3d.shape[0]) + graph_key = self._make_graph_key( + seq_len, + position_embeddings=position_embeddings, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + graph_key_hint=graph_key_hint, + ) if graph_key not in 
self.block_graphs: self.create_graph( @@ -369,6 +459,7 @@ def run( cu_window_seqlens=cu_window_seqlens, rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, + graph_key_hint=graph_key_hint, ) return self.replay(