Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions python/sglang/srt/layers/attention/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,8 @@ def forward(
x: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
rotary_pos_emb_cos: Optional[torch.Tensor] = None,
rotary_pos_emb_sin: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
Expand Down Expand Up @@ -724,26 +726,34 @@ def forward(
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
]

if position_embeddings is not None:
original_shape = q.shape
cos = None
sin = None

if position_embeddings is not None:
if self.customized_position_embedding_applier is not None:
q, k = self.customized_position_embedding_applier(
q, k, position_embeddings, x_shape
)
q = q.view(original_shape)
k = k.view(original_shape)
else:
cos, sin = position_embeddings
elif rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
cos = rotary_pos_emb_cos
sin = rotary_pos_emb_sin

if cos is not None and sin is not None:
original_shape = q.shape

# [total_tokens, head, head_size]
q = q.view(-1, head, self.head_size)
k = k.view(-1, head, self.head_size)
# [total_tokens, head, head_size]
q = q.view(-1, head, self.head_size)
k = k.view(-1, head, self.head_size)

q, k = apply_rotary_pos_emb(q, k, cos, sin)
if cos.size(-1) * 2 == self.head_size:
cos = torch.cat([cos, cos], dim=-1)
sin = torch.cat([sin, sin], dim=-1)

q = q.view(original_shape)
k = k.view(original_shape)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q = q.view(original_shape)
k = k.view(original_shape)

if q.dim() == 4:
# [b, s, head, head_size] --> [b * s, head, head_size]
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ def get_cos_sin_with_position(self, positions):
sin.view(-1, 1, 1, last_dim).contiguous(),
)

def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the cos/sin rotary tables for positions ``0 .. seqlen - 1``.

    The precomputed ``cos_sin_cache`` stores the cos half and the sin half
    concatenated along the last dimension, so the prefix is simply split in two.

    Args:
        seqlen: number of leading positions to take from the cache.

    Returns:
        A ``(cos, sin)`` pair, each of shape ``(seqlen, last_dim // 2)``.
    """
    prefix = self.cos_sin_cache[:seqlen]
    cos_half, sin_half = torch.chunk(prefix, 2, dim=-1)
    return cos_half, sin_half

def forward_native(
self,
positions: torch.Tensor,
Expand Down
84 changes: 34 additions & 50 deletions python/sglang/srt/models/glm4v.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
Expand Down Expand Up @@ -157,7 +158,8 @@ def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
) -> torch.Tensor:
S, B, H = x.shape
# norm1: flatten to 2D -> [S*B, H], then reshape back
Expand All @@ -169,7 +171,8 @@ def forward(
attn = self.attn(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
)
attn = rearrange(attn, "b s h -> s b h")

Expand Down Expand Up @@ -363,44 +366,6 @@ def forward(
return embeddings


class Glm4vVisionRotaryEmbedding(nn.Module):
    """Rotary position embedding for the GLM-4V vision tower.

    Lazily materialises a table of rotary angles ``outer(arange(n), inv_freq)``
    and serves prefixes of it. Whenever a longer sequence is requested, the
    cached table is rebuilt at double the requested length so that repeated
    small growths are amortised.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        # Classic RoPE inverse frequencies: theta ** (-2i / dim), i = 0 .. dim/2 - 1.
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Grow the cached angle table so it covers at least ``seqlen`` positions."""
        if seqlen <= self._seq_len_cached:
            return
        target = seqlen * 2  # over-allocate to amortise future growth
        self._seq_len_cached = target
        # Recompute inv_freq on the buffer's current device so the cache tracks
        # any .to()/.cuda() move of the module since construction.
        exponents = (
            torch.arange(0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device)
            / self.dim
        )
        self.inv_freq = 1.0 / (self.theta**exponents)
        positions = torch.arange(
            target, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        self._freqs_cached = torch.outer(positions, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return angles for positions ``0 .. seqlen - 1``, shape ``(seqlen, dim // 2)``."""
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Glm4vVisionModel(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -431,7 +396,13 @@ def __init__(
)

head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
base=10000.0,
is_neox_style=True,
)

self.blocks = nn.ModuleList(
[
Expand Down Expand Up @@ -481,7 +452,9 @@ def dtype(self) -> torch.dtype:
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device

def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
def rot_pos_emb(
self, grid_thw: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
Expand All @@ -507,11 +480,15 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
.flatten()
)
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids

# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

cos_combined = cos[pos_ids].flatten(1)
sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined, pos_ids

def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
# patchify
Expand All @@ -520,7 +497,9 @@ def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
x = self.post_conv_layernorm(x)

# compute position embedding
rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
grid_thw
)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
Expand All @@ -532,14 +511,19 @@ def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
)

emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
rotary_pos_emb_tuple = (emb.cos(), emb.sin())
rotary_pos_emb_cos = torch.cat([rotary_pos_emb_cos, rotary_pos_emb_cos], dim=-1)
rotary_pos_emb_sin = torch.cat([rotary_pos_emb_sin, rotary_pos_emb_sin], dim=-1)

# x.shape: (s, b, d) where b=1 for vision processing
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=rotary_pos_emb_tuple)
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
)

# adapter
x = self.post_layernorm(x)
Expand Down
61 changes: 41 additions & 20 deletions python/sglang/srt/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@
import torch.nn as nn
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionRotaryEmbedding,
)

from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
from sglang.srt.distributed import (
Expand All @@ -39,6 +36,7 @@
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
Expand Down Expand Up @@ -188,14 +186,16 @@ def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.norm1(x)
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
attn = self.attn(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
)
attn = rearrange(attn, "b s ... -> s b ...")
x += attn
Expand Down Expand Up @@ -292,7 +292,13 @@ def __init__(
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
base=10000.0,
is_neox_style=True,
)

self.blocks = nn.ModuleList(
[
Expand Down Expand Up @@ -343,17 +349,24 @@ def dtype(self) -> torch.dtype:
def device(self) -> torch.device:
return self.patch_embed.proj.weight.device

def rot_pos_emb(
    self, grid_thw: list[list[int]]
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build per-token rotary cos/sin tables for a batch of (t, h, w) grids.

    NOTE(review): this span of the diff rendered both the pre- and post-refactor
    bodies interleaved; only the new implementation (cos/sin taken from the shared
    RotaryEmbedding cache) is kept here.

    Args:
        grid_thw: one ``[t, h, w]`` triple per image/video segment.

    Returns:
        ``(cos, sin)`` tensors of shape ``(total_tokens, 2 * half_dim)``, where each
        row concatenates the height-position row and the width-position row of the
        cached tables.
    """
    pos_ids = []
    for t, h, w in grid_thw:
        base = self.rot_pos_ids(h, w, self.spatial_merge_size)
        pos_ids.append(base if t == 1 else base.repeat(t, 1))

    pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
    # The largest spatial extent decides how many cached rows are needed.
    max_grid_size = max(max(h, w) for _, h, w in grid_thw)

    # Use pre-computed cos_sin_cache from RotaryEmbedding.
    cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

    # pos_ids is (tokens, 2): gather the (h, w) rows and lay them side by side.
    cos_combined = cos[pos_ids].flatten(1)
    sin_combined = sin[pos_ids].flatten(1)

    return cos_combined, sin_combined

def fast_pos_embed_interpolate(self, grid_thw):
num_grid_per_side = int(self.num_position_embeddings**0.5)
Expand Down Expand Up @@ -448,26 +461,34 @@ def forward(
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)

if isinstance(grid_thw, list):
grid_thw_list = grid_thw
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
else:
grid_thw_list = grid_thw.tolist()

pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
x += pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)

seq_len, _ = x.size()
rotary_pos_emb = rotary_pos_emb.to(x.device)

rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

# compute cu_seqlens
cu_seqlens = compute_cu_seqlens_from_grid_numpy(grid_thw)

x = x.unsqueeze(1)
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

deepstack_feature_lists = []
num_deepstack_captured = 0

for layer_num, blk in enumerate(self.blocks):
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
)

if layer_num in self.deepstack_visual_indexes:
deepstack_feature = self.deepstack_merger_list[num_deepstack_captured](
x
Expand Down
Loading