Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def matmul_persistent(
K, N = b.shape

# DeepGEMM has minimum dimension requirements for TMA descriptors
MIN_DEEPGEMM_DIM = 16
MIN_DEEPGEMM_DIM = 2048 # wili, 2048 for Qwen3VL + TP4 / TP8, default value is 16

if (
_ENABLE_MM_DEEPGEMM
Expand Down
46 changes: 32 additions & 14 deletions python/sglang/srt/layers/attention/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
import functools
import math
import os # wili
from functools import lru_cache, partial
from typing import Any, Callable, Optional, Tuple

Expand Down Expand Up @@ -755,8 +756,13 @@ def __init__(
**kwargs,
):
super().__init__()
self.tp_size = 1 if use_data_parallel else get_attention_tp_size()
self.tp_rank = 0 if use_data_parallel else get_attention_tp_rank()
self.enable_vfly = bool(int(os.environ.get("ENABLE_VFLY", "0"))) # wili
if self.enable_vfly: # wili, reuse TP group but keep TP size as 1
self.tp_size = 1
self.tp_rank = 0
else: # wili, original code
self.tp_size = 1 if use_data_parallel else get_attention_tp_size()
self.tp_rank = 0 if use_data_parallel else get_attention_tp_rank()
self.dropout = dropout
self.head_size = embed_dim // num_heads
self.hidden_size_per_attention_head = dist_utils.divide(
Expand Down Expand Up @@ -1092,18 +1098,30 @@ def forward(
else:
q, k = self._apply_qk_norm(q, k)

output = self.qkv_backend.forward(
q=q,
k=k,
v=v,
bsz=bsz,
seq_len=s,
cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
sequence_lengths=sequence_lengths,
max_seqlen=max_seqlen,
output_ws=attn_output_ws,
)
if self.enable_vfly: # wili
assert (
attention_mask is None
) # wili, `attention_mask` is always None in our workflow
q = q.unsqueeze(
0
) # wili, vfly needs input as [batch_size, sequence_length, num_head, head_dim]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
output = self.processor(q, k, v, cu_seqlens)
output = output[0]
else: # wili, original code
output = self.qkv_backend.forward(
q=q,
k=k,
v=v,
bsz=bsz,
seq_len=s,
cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
sequence_lengths=sequence_lengths,
max_seqlen=max_seqlen,
output_ws=attn_output_ws,
)

assert output.dim() == 3, output.shape

Expand Down
8 changes: 8 additions & 0 deletions python/sglang/srt/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,15 @@ def load_model(

@staticmethod
def load_weights_and_postprocess(model, weights, target_device):

if bool(int(os.environ.get("ENABLE_VFLY", "0"))): # wili
print("[wili] Enable vision fly")
model.enable_vision_fly()
else:
print("[wili] Disable vision fly")

model.load_weights(weights)
model.visual.patch_embed.copy_conv3d_weight_to_linear() # wili, Conv3d -> Linear

for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
Expand Down
Loading
Loading