diff --git a/python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py b/python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py index 80bb66e7ea67..9b6ca64605b6 100644 --- a/python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py +++ b/python/sglang/srt/batch_invariant_ops/batch_invariant_ops.py @@ -268,7 +268,7 @@ def matmul_persistent( K, N = b.shape # DeepGEMM has minimum dimension requirements for TMA descriptors - MIN_DEEPGEMM_DIM = 16 + MIN_DEEPGEMM_DIM = 2048 # wili, 2048 for Qwen3VL + TP4 / TP8, default value is 16 if ( _ENABLE_MM_DEEPGEMM diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 4abe5b2c4049..deef6f068c5c 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -3,6 +3,7 @@ import dataclasses import functools import math +import os # wili from functools import lru_cache, partial from typing import Any, Callable, Optional, Tuple @@ -755,8 +756,13 @@ def __init__( **kwargs, ): super().__init__() - self.tp_size = 1 if use_data_parallel else get_attention_tp_size() - self.tp_rank = 0 if use_data_parallel else get_attention_tp_rank() + self.enable_vfly = bool(int(os.environ.get("ENABLE_VFLY", "0"))) # wili + if self.enable_vfly: # wili, reuse TP group but keep TP size as 1 + self.tp_size = 1 + self.tp_rank = 0 + else: # wili, original code + self.tp_size = 1 if use_data_parallel else get_attention_tp_size() + self.tp_rank = 0 if use_data_parallel else get_attention_tp_rank() self.dropout = dropout self.head_size = embed_dim // num_heads self.hidden_size_per_attention_head = dist_utils.divide( @@ -1092,18 +1098,30 @@ def forward( else: q, k = self._apply_qk_norm(q, k) - output = self.qkv_backend.forward( - q=q, - k=k, - v=v, - bsz=bsz, - seq_len=s, - cu_seqlens=cu_seqlens, - attention_mask=attention_mask, - sequence_lengths=sequence_lengths, - max_seqlen=max_seqlen, - output_ws=attn_output_ws, - ) + if self.enable_vfly: # 
wili + assert ( + attention_mask is None + ) # wili, `attention_mask` is always None in our workflow + q = q.unsqueeze( + 0 + ) # wili, vfly needs input as [batch_size, sequence_length, num_head, head_dim] + k = k.unsqueeze(0) + v = v.unsqueeze(0) + output = self.processor(q, k, v, cu_seqlens) + output = output[0] + else: # wili, original code + output = self.qkv_backend.forward( + q=q, + k=k, + v=v, + bsz=bsz, + seq_len=s, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + sequence_lengths=sequence_lengths, + max_seqlen=max_seqlen, + output_ws=attn_output_ws, + ) assert output.dim() == 3, output.shape diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index da5c315afb84..5d3a628cd623 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -683,7 +683,17 @@ def load_model( @staticmethod def load_weights_and_postprocess(model, weights, target_device): + + if bool(int(os.environ.get("ENABLE_VFLY", "0"))): # wili + print("[wili] Enable vision fly") + model.enable_vision_fly() + else: + print("[wili] Disable vision fly") + model.load_weights(weights) + # wili, this loader is shared by every model: only run the Conv3d -> Linear hook when the model actually provides it + patch_embed = getattr(getattr(model, "visual", None), "patch_embed", None) + if hasattr(patch_embed, "copy_conv3d_weight_to_linear"): + patch_embed.copy_conv3d_weight_to_linear() # wili, Conv3d -> Linear for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 34a645078c71..7c9737e69fce 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -16,15 +16,17 @@ import logging import math +import os # wili import re from functools import lru_cache, partial from typing import Callable, Iterable, List, Optional, Tuple, Union -import numpy as np +import numpy as np # wili import torch import torch.nn as nn from einops import rearrange from transformers.activations import ACT2FN +from vfly.utils.parallel import dit_sp_gather, dit_sp_split # wili from sglang.srt.configs.qwen3_vl 
import Qwen3VLConfig, Qwen3VLVisionConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size @@ -90,8 +92,12 @@ def __init__( use_data_parallel: bool = False, ): super().__init__() - self.tp_size = 1 if use_data_parallel else get_attention_tp_size() - self.tp_rank = 0 if use_data_parallel else get_attention_tp_rank() + if bool(int(os.environ.get("ENABLE_VFLY", "0"))): # wili, for vfly + self.tp_size = 1 # wili, reuse TP group but keep TP size as 1 + self.tp_rank = 0 # wili + else: # wili, original code + self.tp_size = 1 if use_data_parallel else get_attention_tp_size() + self.tp_rank = 0 if use_data_parallel else get_attention_tp_rank() self.linear_fc1 = ColumnParallelLinear( in_features, hidden_features, @@ -119,6 +125,8 @@ def forward(self, x: torch.Tensor): return mlp_output +# wili, original code of class Qwen3VLVisionPatchEmbed +original_Qwen3VLVisionPatchEmbed = """ class Qwen3VLVisionPatchEmbed(nn.Module): def __init__(self, config) -> None: super().__init__() @@ -149,6 +157,44 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: -1, self.embed_dim ) return hidden_states +""" + + +# wili, improved version of class Qwen3VLVisionPatchEmbed +class Qwen3VLVisionPatchEmbed(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d( + self.in_channels, + self.embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=True, + ) + k = self.in_channels * self.temporal_patch_size * self.patch_size**2 + self.linear = nn.Linear( + in_features=k, + out_features=self.embed_dim, + bias=True, + dtype=self.proj.weight.dtype, + ) + + def copy_conv3d_weight_to_linear(self): + # Call this after model loading in `sglang/srt/model_loader/loader.py: 
load_weights_and_postprocess()` + print("Copy weights from Conv3d to Linear in PatchEmbed") + with torch.no_grad(): + self.linear.weight.copy_(self.proj.weight.view(self.embed_dim, -1)) + self.linear.bias.copy_(self.proj.bias) + del self.proj + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.linear(hidden_states) class Qwen3_VisionBlock(nn.Module): @@ -246,8 +292,12 @@ def __init__( self.norm = norm_layer( self.hidden_size if use_postshuffle_norm else context_dim ) - self.tp_size = 1 if use_data_parallel else get_attention_tp_size() - self.tp_rank = 0 if use_data_parallel else get_attention_tp_rank() + if bool(int(os.environ.get("ENABLE_VFLY", "0"))): # wili, for vfly + self.tp_size = 1 # wili, reuse TP group but keep TP size as 1 + self.tp_rank = 0 # wili + else: # wili, original code + self.tp_size = 1 if use_data_parallel else get_attention_tp_size() + self.tp_rank = 0 if use_data_parallel else get_attention_tp_rank() self.linear_fc1 = ColumnParallelLinear( self.hidden_size, self.hidden_size, @@ -332,6 +382,7 @@ def __init__( base=10000.0, is_neox_style=True, ) + self.enable_vfly = bool(int(os.environ.get("ENABLE_VFLY", "0"))) # wili workspace_buffer = None if get_global_server_args().mm_attention_backend == "flashinfer_cudnn": @@ -387,18 +438,23 @@ def __init__( ] ) - self.tp_size = ( - 1 if use_data_parallel else get_tensor_model_parallel_world_size() - ) + if bool(int(os.environ.get("ENABLE_VFLY", "0"))): # wili, for vfly + self.tp_size = 1 + else: # wili, original code, but seems useless? 
+ self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) self.cuda_graph_runner: Optional[ViTCudaGraphRunner] = ViTCudaGraphRunner(self) @property def dtype(self) -> torch.dtype: - return self.patch_embed.proj.weight.dtype + # return self.patch_embed.proj.weight.dtype # wili, Conv3d -> Linear + return self.patch_embed.linear.weight.dtype # wili, Conv3d -> Linear @property def device(self) -> torch.device: - return self.patch_embed.proj.weight.device + # return self.patch_embed.proj.weight.device # wili, Conv3d -> Linear + return self.patch_embed.linear.weight.device # wili, Conv3d -> Linear def rot_pos_emb( self, grid_thw: list[list[int]] @@ -419,6 +475,39 @@ def rot_pos_emb( return cos_combined, sin_combined + def rot_pos_emb_v2(self, grid_thw): # wili, TODO: align logic with original code + """ + grid_thw: LongTensor on CPU / GPU, shape [N, 3], value (t,h,w) per row + return : bfloat16 tensor on GPU, shape [Σ(t*h*w), 2 * 18] + """ + device = grid_thw.device + m = self.spatial_merge_size + + pos_ids_list = [] + + for t, h, w in grid_thw: + t, h, w = t.item(), h.item(), w.item() + + hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w) + hpos_ids = ( + hpos_ids.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten() + ) + + wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1) + wpos_ids = ( + wpos_ids.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten() + ) + + sample_pos = torch.stack([hpos_ids, wpos_ids], dim=-1) # [h*w, 2] + pos_ids_list.append(sample_pos.repeat(t, 1)) # [t*h*w, 2] + + pos_ids = torch.cat(pos_ids_list, dim=0) + + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + def _get_interpolation_indices(self, dim_size: int) -> torch.Tensor: """ Compute continuous interpolation indices for a single dimension. 
@@ -715,6 +804,163 @@ def compute_flashinfer_sequence_lengths_padded( seq_lens = np.concatenate([seq_lens, pad], axis=0) # (B_padded,) return seq_lens + def fast_pos_embed_interpolate_v2( + self, grid_thw: torch.Tensor + ): # wili, TODO: align logic with original code + """ + grid_thw: LongTensor on CPU / GPU, shape [N, 3], value (t,h,w) per row + return : bfloat16 tensor on GPU, shape [Σ(t*h*w), self.pos_embed.embedding_dim] + """ + device = self.pos_embed.weight.device + dtype = self.pos_embed.weight.dtype + grid_thw = grid_thw.to(device, non_blocking=True) + num_grid = int(self.num_position_embeddings**0.5) + m_size = self.spatial_merge_size + embedding_dim = self.pos_embed.embedding_dim + + num_patch_per_clip = grid_thw.prod( + dim=1 + ) # [t_i * h_i * w_i for i in range len(grid_thw)] + num_patch_quad = num_patch_per_clip * 4 # 4 indice / weights per patch + num_elements = int(num_patch_per_clip.sum()) # number of total patches, on CPU + + offset = torch.cat( + [ + torch.tensor([0], dtype=torch.long, device=device), + num_patch_per_clip.cumsum(0), + ] + ) + offset_quad = offset * 4 + + idx_all = torch.empty(num_patch_quad.sum(), dtype=torch.long, device=device) + wgt_all = torch.empty(num_patch_quad.sum(), dtype=dtype, device=device) + + for st, ed, (t, h, w) in zip(offset_quad[:-1], offset_quad[1:], grid_thw): + h_idx = torch.linspace(0, num_grid - 1, h, device=device) + w_idx = torch.linspace(0, num_grid - 1, w, device=device) + + h_floor = h_idx.floor().long() + w_floor = w_idx.floor().long() + h_ceil = (h_floor + 1).clamp_max(num_grid - 1) + w_ceil = (w_floor + 1).clamp_max(num_grid - 1) + + hf, wf = torch.meshgrid(h_floor, w_floor, indexing="ij") + hc, wf = torch.meshgrid(h_ceil, w_floor, indexing="ij") + hf, wc = torch.meshgrid(h_floor, w_ceil, indexing="ij") + hc, wc = torch.meshgrid(h_ceil, w_ceil, indexing="ij") + idx4 = torch.stack( + [ + hf * num_grid + wf, + hf * num_grid + wc, + hc * num_grid + wf, + hc * num_grid + wc, + ], + dim=-1, + ) + + dh = 
(h_idx - h_floor.float()).view(-1, 1) + dw = (w_idx - w_floor.float()).view(1, -1) + w4 = torch.stack( + [(1 - dh) * (1 - dw), (1 - dh) * dw, dh * (1 - dw), dh * dw], dim=-1 + ) + + idx_all[st:ed] = idx4.flatten().repeat_interleave(t) + wgt_all[st:ed] = w4.flatten().repeat_interleave(t) + + patch_pos_embed = self.pos_embed(idx_all) * wgt_all.unsqueeze(1) + patch_pos_embed = patch_pos_embed.view(-1, 4, embedding_dim).sum(dim=1) + + out = torch.empty([num_elements, embedding_dim], dtype=dtype, device=device) + for st, ed, (t, h, w) in zip(offset[:-1], offset[1:], grid_thw): + emb = patch_pos_embed[st:ed] + emb = emb.view(t, h // m_size, m_size, w // m_size, m_size, embedding_dim) + emb = emb.permute(0, 1, 3, 2, 4, 5).contiguous().flatten(0, 4) + out[st:ed] = emb + + return out + + def fast_pos_embed_interpolate_v3( + self, grid_thw: torch.Tensor + ): # wili, TODO: align logic with original code + """ + grid_thw: LongTensor on CPU / GPU, shape [N, 3], value (t,h,w) per row + return : bfloat16 tensor on GPU, shape [Σ(t*h*w), self.pos_embed.embedding_dim] + """ + device = self.pos_embed.weight.device + dtype = self.pos_embed.weight.dtype + grid_thw_cpu = grid_thw.detach().cpu().numpy() + num_grid = int(self.num_position_embeddings**0.5) + m_size = self.spatial_merge_size + embedding_dim = self.pos_embed.embedding_dim + + num_patch_per_clip = [int(t * h * w) for t, h, w in grid_thw_cpu] + total_patches = sum(num_patch_per_clip) + + idx_all = np.empty((total_patches, 4), dtype=np.int64) + wgt_all = np.empty((total_patches, 4), dtype=np.float32) + + offset = 0 + for t, h, w in grid_thw_cpu: + h_idx = np.linspace(0, num_grid - 1, h) + w_idx = np.linspace(0, num_grid - 1, w) + h_floor = np.floor(h_idx).astype(int) + w_floor = np.floor(w_idx).astype(int) + h_ceil = np.clip(h_floor + 1, 0, num_grid - 1) + w_ceil = np.clip(w_floor + 1, 0, num_grid - 1) + + hf, wf = np.meshgrid(h_floor, w_floor, indexing="ij") + hc, wf2 = np.meshgrid(h_ceil, w_floor, indexing="ij") + hf2, wc = 
np.meshgrid(h_floor, w_ceil, indexing="ij") + hc2, wc2 = np.meshgrid(h_ceil, w_ceil, indexing="ij") + idx4 = np.stack( + [ + hf * num_grid + wf, + hf2 * num_grid + wc, + hc * num_grid + wf2, + hc2 * num_grid + wc2, + ], + axis=-1, + ) # [h, w, 4] + + dh = (h_idx - h_floor).reshape(-1, 1) + dw = (w_idx - w_floor).reshape(1, -1) + w4 = np.stack( + [(1 - dh) * (1 - dw), (1 - dh) * dw, dh * (1 - dw), dh * dw], axis=-1 + ) # [h, w, 4] + + idx4 = np.tile(idx4, (t, 1, 1, 1)) # [t, h, w, 4] + w4 = np.tile(w4, (t, 1, 1, 1)) # [t, h, w, 4] + + patch_count = t * h * w + idx_all[offset : offset + patch_count] = idx4.reshape(-1, 4) + wgt_all[offset : offset + patch_count] = w4.reshape(-1, 4) + offset += patch_count + + idx_all = torch.from_numpy(idx_all.reshape(-1)).to(device) + wgt_all = torch.from_numpy(wgt_all.reshape(-1)).to(device, dtype=dtype) + + patch_pos_embed = self.pos_embed(idx_all) * wgt_all.unsqueeze(1) + patch_pos_embed = patch_pos_embed.view(-1, 4, embedding_dim).sum(dim=1) + + offset_cumsum = np.cumsum([0] + num_patch_per_clip) + out = torch.empty([total_patches, embedding_dim], dtype=dtype, device=device) + + # Permute indices rather than values + all_indices = np.empty(total_patches, dtype=np.int32) + for st, ed, (t, h, w) in zip( + offset_cumsum[:-1], offset_cumsum[1:], grid_thw_cpu + ): + base_idx = np.arange(st, ed).reshape(t, h, w) + base_idx = base_idx.reshape(t, h // m_size, m_size, w // m_size, m_size) + base_idx = base_idx.transpose(0, 1, 3, 2, 4) + base_idx = base_idx.reshape(-1) + all_indices[st:ed] = base_idx + + all_indices = torch.from_numpy(all_indices) + out[:] = patch_pos_embed[all_indices] + + return out + def forward( self, x: torch.Tensor, @@ -724,6 +970,9 @@ def forward( return self.forward_with_cuda_graph(x, grid_thw) x = x.to(device=self.device, dtype=self.dtype) + if isinstance(grid_thw, torch.Tensor): # wili, a list has no .to(); the list branch below handles it + grid_thw = grid_thw.to(device=self.device) + # wili, TODO: align logic with original code x = self.patch_embed(x) if isinstance(grid_thw, list): @@ -733,10 
+982,17 @@ def forward( grid_thw_list = grid_thw.tolist() grid_thw = grid_thw.cpu().numpy() - pos_embeds = self.fast_pos_embed_interpolate_from_list(grid_thw_list) + pos_embeds = self.fast_pos_embed_interpolate_from_list( + grid_thw_list + ) # wili, original code + # pos_embeds = self.fast_pos_embed_interpolate_v2(grid_thw) # wili + # pos_embeds = self.fast_pos_embed_interpolate_v3(grid_thw) # wili, TODO: align logic with original code x += pos_embeds - rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb( + grid_thw_list + ) # wili, original code + # rotary_pos_emb = self.rot_pos_emb_v2(grid_thw) # wili, TODO: align logic with original code # ---- build token indptr (B+1,) ---- token_cu_seqlens = np.repeat( @@ -794,7 +1050,52 @@ def forward( cu_seqlens = cu_seqlens.to("cpu") max_seqlen = None - x = x.unsqueeze(1) + x = x.unsqueeze( + 1 + ) # wili, [sl,bs=1,h], transpose to [bs,sl,h] in Qwen3_VisionBlock just before attention + # print("[wili] ==== shape before blocks ====================================") + # print(f"{x.shape = }") + # print(f"{rotary_pos_emb_cos.shape = }") + # print("[wili] ==== shape before blocks ====================================") + + if self.enable_vfly: # wili + # wili, pad sequence length to be multiple of 32 + # For example we use a picture of 1608x828 as input, + # It is resize (by `smart_resize()` in library transformers) to 1600x832 with aligned 32 + # Then it is merged pixels (by `_preprocess()` in library transformers) to sequence length of 1600/16*832/16=5200 with 16x16 per a block + # So the the sequence length here (x.shape[1]) is 5200 + # Using cp=8, the sequence length per worker will be 5200 / 8 = 650, + # Noticing this line (Qwen3VLMoeVisionPatchMerger.forward, around Line 278 in this file): + # x = self.norm(x.view(-1, self.hidden_size)) # here x.shape == [650, 1152], self.hidden_size == 4608 + # So 650 / 4 = 162.5, leads to a error + # As a 
workaround, we pad the sequence 5200 to 5216 here with aligned 32, + # So 5216 / 8 / 4 = 163, OK for the PatchMerger. + seq_len = x.shape[0] + pad_length = int((seq_len + 31) / 32) * 32 - seq_len + pad_size = [x.size(i) for i in range(x.ndim)] + pad_size[0] = pad_length + x = torch.cat([x, x.new_zeros(*pad_size)], dim=0).contiguous() + pad_size = [ + rotary_pos_emb_cos.size(i) for i in range(rotary_pos_emb_cos.ndim) + ] + pad_size[0] = pad_length + rotary_pos_emb_cos = torch.cat( + [rotary_pos_emb_cos, rotary_pos_emb_cos.new_zeros(*pad_size)], dim=0 + ).contiguous() + rotary_pos_emb_sin = torch.cat( + [rotary_pos_emb_sin, rotary_pos_emb_sin.new_zeros(*pad_size)], dim=0 + ).contiguous() + + x = dit_sp_split( + x, dim=0 + ) # wili, split sequence parts for distributed processing + rotary_pos_emb_cos = dit_sp_split(rotary_pos_emb_cos, dim=0) # wili + rotary_pos_emb_sin = dit_sp_split(rotary_pos_emb_sin, dim=0) # wili + + # print("[wili] ==== shape after dit_sp_split ====================================") + # print(f"{x.shape = }") + # print(f"{rotary_pos_emb_cos.shape = }") + # print("[wili] ==== shape after dit_sp_split ====================================") cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) @@ -821,6 +1122,25 @@ def forward( hidden_states = torch.cat( [x] + deepstack_feature_lists, dim=1 ) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + + if self.enable_vfly: # wili + hidden_states = hidden_states.unsqueeze( + 0 + ) # wili, [batch_size=1, seq_len, hidden_size * (1 + depth_of_deepstack)] + # print("[wili] ==== shape before dit_sp_gather ====================================") + # print(f"{hidden_states.shape = }") + # print("[wili] ==== shape before dit_sp_gather ====================================") + + hidden_states = dit_sp_gather( + hidden_states, dim=1 + ) # wili, gather sequence parts back + + hidden_states = hidden_states.squeeze(0) + + # print("[wili] ==== shape after blocks ====================================") + # 
print(f"{hidden_states.shape = }") + # print("[wili] ==== shape after blocks ====================================") + return hidden_states def forward_with_cuda_graph( @@ -1035,7 +1355,8 @@ def __init__( # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. quant_config=None, norm_eps=getattr(config, "rms_norm_eps", 1e-6), - prefix=add_prefix("model.visual", prefix), + # prefix=add_prefix("model.visual", prefix), # wili, original code + prefix=add_prefix("visual", prefix), # wili, add back this line for QWen3VL use_data_parallel=self.use_data_parallel, ) @@ -1051,7 +1372,10 @@ def __init__( self.model = language_model_cls( config=self.config, quant_config=quant_config, - prefix=add_prefix("model.language_model", prefix), + # prefix=add_prefix("model.language_model", prefix), # wili, original code + prefix=add_prefix( + "model", prefix + ), # wili, add back this line for Qwen3VL ) if self.pp_group.is_last_rank: if self.pp_group.world_size == 1 and self.config.tie_word_embeddings: @@ -1376,5 +1700,111 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + def enable_vision_fly(self): # wili + # Copy from /work/qwen3-vl/third_party/VisionFly/examples/qwenimage.py + from vfly import setup_configs + from vfly.configs.parallel import DiTParallelConfig, get_dit_parallel_config + from vfly.configs.pipeline import PipelineConfig + from vfly.layers import VflyAttnProcessor, apply_vfly_linear, apply_vfly_norm + from vfly.utils import get_logger + + from .common import ( + BaseArgumentParser, + create_vfly_config, + validate_parallel_config, + ) + + logger = get_logger(__name__) + + class VflyQwenDoubleStreamAttnProcessor2_0(VflyAttnProcessor): + + def __init__(self): + super().__init__() + logger.debug("VflyQwenDoubleStreamAttnProcessor2_0 initialized") + + def __call__( + self, + q: torch.FloatTensor, + k: 
torch.FloatTensor, + v: torch.FloatTensor, + cu_seqlens: torch.IntTensor, + ) -> torch.FloatTensor: + pfg = get_dit_parallel_config() + world_size = pfg.ulysses_size() # Only ulysses is supported + + seq_lens_list = cu_seqlens.diff() + max_seqlen = torch.max(seq_lens_list) + total_seq_len = cu_seqlens[-1].item() + seq_len_padded = ( + (total_seq_len + world_size - 1) // world_size * world_size + ) + uneven_number = seq_len_padded - total_seq_len + seq_len_cur_rank = q.shape[1] + if torch.distributed.get_rank() == world_size - 1: + seq_len_cur_rank = seq_len_cur_rank - uneven_number + + parallel_config = DiTParallelConfig() + parallel_config.set_config( + cfg_size=pfg.cp_size(), + ulysses_size=pfg.ulysses_size(), + ring_size=pfg.ring_size(), + ) + PipelineConfig.set_uneven_cp_config( + total_seq_len, seq_len_padded, seq_len_cur_rank, parallel_config + ) + + return self.vfly_attn( + q, + k, + v, + tensor_layout="NHD", + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens.clone(), + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + ) + + # Setup argument parser + parser = BaseArgumentParser("") + parser.set_defaults( + ulysses=torch.distributed.get_world_size(), attn_type="flash-attn3" + ) + args = parser.parse_args([]) + """ + enable_autotuner = False + if args.linear_type == "auto" or args.attn_type == "auto": + enable_autotuner = True + if not args.disable_torch_compile: + logger.warning("Disable torch compile when using autotuner") + args.disable_torch_compile = True + if args.enable_vfly_cpu_offload: + logger.warning("Disable vfly cpu offload when using autotuner") + args.enable_vfly_cpu_offload = False + """ + # Validate configuration + validate_parallel_config(args) + + # Load pipeline + vfly_configs = create_vfly_config(args) + setup_configs(**vfly_configs) + + pipe = self.visual + for name, module in pipe.blocks.named_modules(): + if isinstance(module, VisionAttention): + attn_processor = VflyQwenDoubleStreamAttnProcessor2_0() + attn_processor.name = name 
+ module.processor = attn_processor + apply_vfly_linear(pipe, load_parameters=True) + apply_vfly_norm( + pipe, + rmsnorm=["norm_q", "norm_k", "norm_added_q", "norm_added_k"], + load_parameters=True, + ) + """ + if not args.disable_torch_compile: + self.visual = torch.compile(self.visual, mode=args.torch_compile_mode) + """ + return + EntryClass = Qwen3VLForConditionalGeneration diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 22adf71e2946..f4f3b66b9757 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -431,7 +431,12 @@ def _load_single_item( try: if modality == Modality.IMAGE: img, _ = load_image(data) - if discard_alpha_channel and img.mode != "RGB": + if ( + discard_alpha_channel + and not isinstance(img, torch.Tensor) + and img.mode != "RGB" + ): # wili, check the type first: GPU-decoded tensors have no PIL `mode` string + # if discard_alpha_channel and img.mode != "RGB": # wili, original code img = img.convert("RGB") return img elif modality == Modality.VIDEO: diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index cb3204e3d794..6142d1a9c5ef 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -101,6 +101,8 @@ from sglang.srt.server_args import ServerArgs +from torchvision.io import decode_jpeg # wili + logger = logging.getLogger(__name__) torch_release = pkg_version.parse(torch.__version__).release @@ -923,8 +925,24 @@ def load_image( elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")): image = Image.open(image_file) elif image_file.startswith("data:"): - image_file = image_file.split(",")[1] - image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True))) + image_metadata, image_file = image_file.split(",", 1) # wili, split once so extra commas cannot break the unpacking + if torch.cuda.is_available() and ( + "jpg" in image_metadata or "jpeg" in image_metadata + ): # wili, for jpeg base64 on NVIDIA GPU; hosts without CUDA fall back to PIL + image_bytes = 
pybase64.b64decode(image_file, validate=True) + image = torch.frombuffer(image_bytes, dtype=torch.uint8) + image = decode_jpeg(image, device="cuda") + # import cupy as cp # wili, deprecated solution + # from nvidia import nvimgcodec # wili, deprecated solution + # code_stream = nvimgcodec.CodeStream(image_bytes) + # image = torch.from_dlpack(img_cupy.to_dlpack()).clone() + # img_cupy = nvimgcodec.Decoder().decode(code_stream) + # elif ("jp2" in image_metadata) or ("j2k" in image_metadata): # wili, specified for jpeg2000 base64, not supported yet + # pass + else: + image = Image.open( + BytesIO(pybase64.b64decode(image_file, validate=True)) + ) # wili, original code elif isinstance(image_file, str): image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True))) else: