Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions python/sglang/srt/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,17 @@ def forward_with_cuda_graph(
# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)

window_index, cu_window_seqlens = self.get_window_index(grid_thw)
# NOTE: get_window_index returns:
# - window_index: torch.Tensor (CPU)
# - cu_window_seqlens: python list[int] (CPU)
window_index, cu_window_seqlens_list = self.get_window_index(grid_thw)

# Match runtime tensor values: unique_consecutive on CPU list (no GPU sync)
cu_window_seqlens_list = ViTCudaGraphRunner._unique_consecutive_list(
cu_window_seqlens_list
)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
cu_window_seqlens_list,
device=x.device,
dtype=torch.int32,
)
Expand Down Expand Up @@ -523,6 +531,29 @@ def forward_with_cuda_graph(
position_embeddings[1].to(x.device, x.dtype),
)

# CPU-side graph_key_hint (0 GPU sync)
# grid_thw is expected to be CPU tensor in normal pipeline.
# If it is CUDA, tolist() would sync; we intentionally avoid that case.
if isinstance(grid_thw, torch.Tensor):
assert (
grid_thw.device.type == "cpu"
), "grid_thw must be on CPU to avoid GPU sync for graph_key_hint."
grid_list = grid_thw.tolist()
else:
grid_list = grid_thw

# cu_seqlens in this model is built as: [0] + [0] + cumsum(prod(t,h,w))
cum = 0
tmp = [0]
for t, h, w in grid_list:
cum += int(t) * int(h) * int(w)
tmp.append(cum)
cu_seqlens_list = [0] + tmp # double-leading-0 to match existing code

full_sig = ViTCudaGraphRunner._hash_i32_list(cu_seqlens_list)
window_sig = ViTCudaGraphRunner._hash_i32_list(cu_window_seqlens_list)
graph_key_hint = (full_sig, window_sig)

# compute cu_seqlens - move cu_seqlens to GPU and make it int32
cu_seqlens = torch.cat(
[
Expand All @@ -540,6 +571,7 @@ def forward_with_cuda_graph(
cu_seqlens=cu_seqlens,
cu_window_seqlens=cu_window_seqlens,
output_indices=reverse_indices,
graph_key_hint=graph_key_hint,
)


Expand Down
19 changes: 17 additions & 2 deletions python/sglang/srt/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,22 @@ def forward_with_cuda_graph(
# rotary embedding -> (cos, sin)
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

# compute cu_seqlens
cu_seqlens = compute_cu_seqlens_from_grid_numpy(grid_thw)
# compute cu_seqlens (this utility is expected to return CPU data; no GPU sync)
cu_seqlens_cpu = compute_cu_seqlens_from_grid_numpy(grid_thw)

if isinstance(cu_seqlens_cpu, torch.Tensor):
assert (
cu_seqlens_cpu.device.type == "cpu"
), "compute_cu_seqlens_from_grid_numpy must return CPU tensor to avoid GPU sync."
cu_seqlens_list = cu_seqlens_cpu.to(torch.int32).contiguous().tolist()
else:
# numpy array / python list
cu_seqlens_list = [int(v) for v in cu_seqlens_cpu]

graph_key_hint = ViTCudaGraphRunner._hash_i32_list(cu_seqlens_list)

# move cu_seqlens to device for attention
cu_seqlens = cu_seqlens_cpu
if not isinstance(cu_seqlens, torch.Tensor):
cu_seqlens = torch.tensor(cu_seqlens, device=x.device, dtype=torch.int32)
else:
Expand All @@ -514,6 +528,7 @@ def forward_with_cuda_graph(
cu_seqlens=cu_seqlens,
cu_window_seqlens=None,
output_indices=None,
graph_key_hint=graph_key_hint,
)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
Expand Down
Loading
Loading