diff --git a/python/sglang/srt/multimodal/vit_cuda_graph_runner.py b/python/sglang/srt/multimodal/vit_cuda_graph_runner.py index cfdf62915a55..8819cfdaba86 100644 --- a/python/sglang/srt/multimodal/vit_cuda_graph_runner.py +++ b/python/sglang/srt/multimodal/vit_cuda_graph_runner.py @@ -17,11 +17,13 @@ from __future__ import annotations import inspect +from contextlib import nullcontext from typing import Dict, Hashable, List, Optional, Tuple import torch import torch.nn as nn +from sglang.srt.distributed.parallel_state import get_tp_group from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.server_args import get_global_server_args @@ -139,7 +141,11 @@ def _create_graph( override_backend = get_global_server_args().mm_attention_backend - with torch.cuda.graph(graph): + tp_group = get_tp_group() + ca_comm = tp_group.ca_comm + capture_ctx = ca_comm.capture() if ca_comm is not None else nullcontext() + + with capture_ctx, torch.cuda.graph(graph): y = None deepstack_outs: List[torch.Tensor] = [] deepstack_capture_idx = 0