Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def _all_reduce_out_place(
qr_comm = self.qr_comm
pymscclpp_comm = self.pymscclpp_comm
torch_symm_mem_comm = self.torch_symm_mem_comm
assert any([qr_comm, ca_comm, pymscclpp_comm])
assert any([qr_comm, ca_comm, pymscclpp_comm, torch_symm_mem_comm])
if outplace_all_reduce_method == "ca":
assert not ca_comm.disabled
out = ca_comm.custom_all_reduce(input_)
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/attention/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def forward(
Returns:
[b * s, h, head_size]
"""
if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH") and self.tp_size == 1:
if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH"):
if "output_ws" not in kwargs:
raise RuntimeError("output_ws should be prepared for cuda-graph mode")

Expand Down Expand Up @@ -363,7 +363,7 @@ def forward(
Returns:
[b * s, h, head_size]
"""
if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH") and self.tp_size == 1:
if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH"):
max_seqlen = cu_seqlens[1]
output = flash_attn_varlen_func(
q,
Expand Down
5 changes: 2 additions & 3 deletions python/sglang/srt/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ def forward(
position_embeddings: torch.Tensor,
output_ws=None,
) -> torch.Tensor:
ws = output_ws
S, B, H = x.shape
# norm1: flatten to 2D -> [S*B, H], then reshape back
x2d = x.reshape(-1, H)
Expand All @@ -182,7 +181,7 @@ def forward(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
output_ws=ws,
output_ws=output_ws,
)
attn = rearrange(attn, "b s h -> s b h")

Expand Down Expand Up @@ -390,7 +389,7 @@ def forward(
x: torch.Tensor,
grid_thw: torch.Tensor,
) -> torch.Tensor:
if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH") and self.tp_size == 1:
if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH"):
return self.forward_with_cuda_graph(x, grid_thw)

# patchify
Expand Down
53 changes: 52 additions & 1 deletion python/sglang/srt/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@
compute_cu_seqlens_from_grid_numpy,
)
from sglang.srt.multimodal.mm_utils import run_dp_sharded_mrope_vision_model
from sglang.srt.multimodal.vit_cuda_graph_runner import ViTCudaGraphRunner
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, get_int_env_var
from sglang.srt.utils import add_prefix, get_bool_env_var, get_int_env_var
from sglang.srt.utils.hf_transformers_utils import get_processor

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -188,6 +189,7 @@ def forward(
cu_seqlens: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
output_ws: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.norm1(x)
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
Expand All @@ -196,6 +198,7 @@ def forward(
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
output_ws=output_ws,
)
attn = rearrange(attn, "b s ... -> s b ...")
x += attn
Expand Down Expand Up @@ -341,6 +344,11 @@ def __init__(
]
)

self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.cuda_graph_runner: Optional[ViTCudaGraphRunner] = ViTCudaGraphRunner(self)

@property
def dtype(self) -> torch.dtype:
return self.patch_embed.proj.weight.dtype
Expand Down Expand Up @@ -458,6 +466,9 @@ def forward(
x: torch.Tensor,
grid_thw: torch.Tensor,
) -> torch.Tensor:
if get_bool_env_var("SGLANG_VIT_ENABLE_CUDA_GRAPH"):
return self.forward_with_cuda_graph(x, grid_thw)
Comment thread
yuan-luo marked this conversation as resolved.

x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)

Expand Down Expand Up @@ -501,6 +512,46 @@ def forward(
) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
return hidden_states

def forward_with_cuda_graph(
    self,
    x: torch.Tensor,
    grid_thw: torch.Tensor,
) -> torch.Tensor:
    """CUDA-graph path for the vision tower forward pass.

    Performs the dynamic-shape preamble eagerly (patchify, positional
    embeddings, rotary cos/sin tables, cu_seqlens), then hands off to
    ``self.cuda_graph_runner`` which replays the transformer blocks +
    merger (and optional deepstack) inside a captured CUDA graph.

    Args:
        x: Raw pixel patches; moved to ``self.device``/``self.dtype`` here.
        grid_thw: Per-image (t, h, w) grid sizes; accepted as either a
            Python list or a tensor.

    Returns:
        The hidden states produced by the graph runner
        (same contract as the eager ``forward``).
    """
    # patchify: convert raw input to patch embeddings on the model device
    x = x.to(device=self.device, dtype=self.dtype)
    x = self.patch_embed(x)

    # Normalize grid_thw so we have BOTH a tensor (for interpolation /
    # cu_seqlens computation) and a plain list (for rot_pos_emb).
    if isinstance(grid_thw, list):
        grid_thw_list = grid_thw
        grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
    else:
        grid_thw_list = grid_thw.tolist()

    pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
    x += pos_embeds

    # rotary embedding -> (cos, sin)
    rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

    # compute cu_seqlens; helper may return a numpy array or a tensor,
    # so coerce to a contiguous int32 tensor on the input device either way
    cu_seqlens = compute_cu_seqlens_from_grid_numpy(grid_thw)
    if isinstance(cu_seqlens, torch.Tensor):
        cu_seqlens = cu_seqlens.to(device=x.device, dtype=torch.int32)
    else:
        cu_seqlens = torch.tensor(cu_seqlens, device=x.device, dtype=torch.int32)
    cu_seqlens = cu_seqlens.contiguous()

    # blocks + merger + deepstack(optional) via CUDA Graph Runner
    return self.cuda_graph_runner.run(
        x=x,
        position_embeddings=None,
        rotary_pos_emb_cos=rotary_pos_emb_cos,
        rotary_pos_emb_sin=rotary_pos_emb_sin,
        cu_seqlens=cu_seqlens,
        cu_window_seqlens=None,
        output_indices=None,
    )

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
Expand Down
Loading
Loading