Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/sglang/srt/multimodal/vit_cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
from __future__ import annotations

import inspect
from contextlib import nullcontext
from typing import Dict, Hashable, List, Optional, Tuple

import torch
import torch.nn as nn

from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.server_args import get_global_server_args

Expand Down Expand Up @@ -139,7 +141,11 @@ def _create_graph(

override_backend = get_global_server_args().mm_attention_backend

with torch.cuda.graph(graph):
tp_group = get_tp_group()
ca_comm = tp_group.ca_comm
capture_ctx = ca_comm.capture() if ca_comm is not None else nullcontext()

with capture_ctx, torch.cuda.graph(graph):
y = None
deepstack_outs: List[torch.Tensor] = []
deepstack_capture_idx = 0
Expand Down
Loading