diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py
index a3f2d30807fa..33aeec320e99 100644
--- a/python/sglang/srt/layers/attention/vision.py
+++ b/python/sglang/srt/layers/attention/vision.py
@@ -675,6 +675,8 @@ def forward(
         x: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        rotary_pos_emb_cos: Optional[torch.Tensor] = None,
+        rotary_pos_emb_sin: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -724,26 +726,34 @@ def forward(
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
             ]
 
-        if position_embeddings is not None:
-            original_shape = q.shape
+        cos = None
+        sin = None
+        if position_embeddings is not None:
             if self.customized_position_embedding_applier is not None:
                 q, k = self.customized_position_embedding_applier(
                     q, k, position_embeddings, x_shape
                 )
-                q = q.view(original_shape)
-                k = k.view(original_shape)
             else:
                 cos, sin = position_embeddings
+        elif rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
+            cos = rotary_pos_emb_cos
+            sin = rotary_pos_emb_sin
+
+        if cos is not None and sin is not None:
+            original_shape = q.shape
 
-                # [total_tokens, head, head_size]
-                q = q.view(-1, head, self.head_size)
-                k = k.view(-1, head, self.head_size)
+            # [total_tokens, head, head_size]
+            q = q.view(-1, head, self.head_size)
+            k = k.view(-1, head, self.head_size)
 
-                q, k = apply_rotary_pos_emb(q, k, cos, sin)
+            if cos.size(-1) * 2 == self.head_size:
+                cos = torch.cat([cos, cos], dim=-1)
+                sin = torch.cat([sin, sin], dim=-1)
 
-                q = q.view(original_shape)
-                k = k.view(original_shape)
+            q, k = apply_rotary_pos_emb(q, k, cos, sin)
+            q = q.view(original_shape)
+            k = k.view(original_shape)
 
         if q.dim() == 4:
             # [b, s, head, head_size] --> [b * s, head, head_size]
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 0385c446632e..a7ab6f35dc95 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -219,6 +219,11 @@ def get_cos_sin_with_position(self, positions):
             cos.view(-1, 1, 1, last_dim).contiguous(),
             sin.view(-1, 1, 1, last_dim).contiguous(),
         )
 
+    def get_cos_sin(self, seqlen: int) -> tuple[torch.Tensor, torch.Tensor]:
+        cos_sin = self.cos_sin_cache[:seqlen]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        return cos, sin
+
     def forward_native(
         self,
         positions: torch.Tensor,
diff --git a/python/sglang/srt/models/glm4v.py b/python/sglang/srt/models/glm4v.py
index 05bb14af4bf9..cee2dd8d3da0 100644
--- a/python/sglang/srt/models/glm4v.py
+++ b/python/sglang/srt/models/glm4v.py
@@ -44,6 +44,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
@@ -157,7 +158,8 @@ def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        position_embeddings: torch.Tensor,
+        rotary_pos_emb_cos: torch.Tensor,
+        rotary_pos_emb_sin: torch.Tensor,
     ) -> torch.Tensor:
         S, B, H = x.shape
         # norm1: flatten to 2D -> [S*B, H], then reshape back
@@ -169,7 +171,8 @@ def forward(
         attn = self.attn(
             hidden_states,
             cu_seqlens=cu_seqlens,
-            position_embeddings=position_embeddings,
+            rotary_pos_emb_cos=rotary_pos_emb_cos,
+            rotary_pos_emb_sin=rotary_pos_emb_sin,
         )
         attn = rearrange(attn, "b s h -> s b h")
@@ -363,44 +366,6 @@ def forward(
         return embeddings
 
 
-class Glm4vVisionRotaryEmbedding(nn.Module):
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self._seq_len_cached = 0
-        self._freqs_cached = None
-
-    def update_freqs_cache(self, seqlen: int) -> None:
-        if seqlen > self._seq_len_cached:
-            seqlen *= 2
-            self._seq_len_cached = seqlen
-            self.inv_freq = 1.0 / (
-                self.theta
-                ** (
-                    torch.arange(
-                        0,
-                        self.dim,
-                        2,
-                        dtype=torch.float,
-                        device=self.inv_freq.device,
-                    )
-                    / self.dim
-                )
-            )
-            seq = torch.arange(
-                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-            )
-            freqs = torch.outer(seq, self.inv_freq)
-            self._freqs_cached = freqs
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        self.update_freqs_cache(seqlen)
-        return self._freqs_cached[:seqlen]
-
-
 class Glm4vVisionModel(nn.Module):
     def __init__(
         self,
@@ -431,7 +396,13 @@ def __init__(
         )
 
         head_dim = self.hidden_size // self.num_heads
-        self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
+        self.rotary_pos_emb = get_rope(
+            head_size=head_dim,
+            rotary_dim=head_dim // 2,
+            max_position=8192,
+            base=10000.0,
+            is_neox_style=True,
+        )
 
         self.blocks = nn.ModuleList(
             [
@@ -481,7 +452,9 @@ def dtype(self) -> torch.dtype:
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+    def rot_pos_emb(
+        self, grid_thw: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pos_ids = []
         for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
@@ -507,11 +480,15 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
                 .flatten()
             )
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
-        pos_ids = torch.cat(pos_ids, dim=0)
+        pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
-        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
-        return rotary_pos_emb, pos_ids
+
+        # Use pre-computed cos_sin_cache from RotaryEmbedding
+        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
+
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)
+        return cos_combined, sin_combined, pos_ids
 
     def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         # patchify
@@ -520,7 +497,9 @@ def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         x = self.post_conv_layernorm(x)
 
         # compute position embedding
-        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
+            grid_thw
+        )
         # compute cu_seqlens
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -532,14 +511,19 @@ def forward(self, x: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
             x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
         )
 
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        rotary_pos_emb_tuple = (emb.cos(), emb.sin())
+        rotary_pos_emb_cos = torch.cat([rotary_pos_emb_cos, rotary_pos_emb_cos], dim=-1)
+        rotary_pos_emb_sin = torch.cat([rotary_pos_emb_sin, rotary_pos_emb_sin], dim=-1)
 
         # x.shape: (s, b, d) where b=1 for vision processing
         # transformers
         x = x.unsqueeze(1)
         for blk in self.blocks:
-            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=rotary_pos_emb_tuple)
+            x = blk(
+                x,
+                cu_seqlens=cu_seqlens,
+                rotary_pos_emb_cos=rotary_pos_emb_cos,
+                rotary_pos_emb_sin=rotary_pos_emb_sin,
+            )
 
         # adapter
         x = self.post_layernorm(x)
diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py
index 922d47632ab0..9e270d213ee0 100644
--- a/python/sglang/srt/models/qwen3_vl.py
+++ b/python/sglang/srt/models/qwen3_vl.py
@@ -24,9 +24,6 @@
 import torch.nn as nn
 from einops import rearrange
 from transformers.activations import ACT2FN
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2_5_VisionRotaryEmbedding,
-)
 
 from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
 from sglang.srt.distributed import (
@@ -39,6 +36,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
@@ -188,14 +186,16 @@ def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        position_embeddings: torch.Tensor,
+        rotary_pos_emb_cos: torch.Tensor,
+        rotary_pos_emb_sin: torch.Tensor,
     ) -> torch.Tensor:
         hidden_states = self.norm1(x)
         hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
         attn = self.attn(
             hidden_states,
             cu_seqlens=cu_seqlens,
-            position_embeddings=position_embeddings,
+            rotary_pos_emb_cos=rotary_pos_emb_cos,
+            rotary_pos_emb_sin=rotary_pos_emb_sin,
         )
         attn = rearrange(attn, "b s ... -> s b ...")
         x += attn
@@ -292,7 +292,13 @@ def __init__(
         self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
         norm_layer = partial(nn.LayerNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
-        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+        self.rotary_pos_emb = get_rope(
+            head_size=head_dim,
+            rotary_dim=head_dim // 2,
+            max_position=8192,
+            base=10000.0,
+            is_neox_style=True,
+        )
 
         self.blocks = nn.ModuleList(
             [
@@ -343,17 +349,24 @@ def dtype(self) -> torch.dtype:
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw):
+    def rot_pos_emb(
+        self, grid_thw: list[list[int]]
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         pos_ids = []
         for t, h, w in grid_thw:
             base = self.rot_pos_ids(h, w, self.spatial_merge_size)
             pos_ids.append(base if t == 1 else base.repeat(t, 1))
-        pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
-        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
-        return rotary_pos_emb
+        pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+
+        # Use pre-computed cos_sin_cache from RotaryEmbedding
+        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
+
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)
+
+        return cos_combined, sin_combined
 
     def fast_pos_embed_interpolate(self, grid_thw):
         num_grid_per_side = int(self.num_position_embeddings**0.5)
@@ -448,26 +461,34 @@ def forward(
         x = x.to(device=self.device, dtype=self.dtype)
         x = self.patch_embed(x)
 
+        if isinstance(grid_thw, list):
+            grid_thw_list = grid_thw
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+        else:
+            grid_thw_list = grid_thw.tolist()
+
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
         x += pos_embeds
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
-
-        seq_len, _ = x.size()
-        rotary_pos_emb = rotary_pos_emb.to(x.device)
-        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        position_embeddings = (emb.cos(), emb.sin())
+        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
 
         # compute cu_seqlens
         cu_seqlens = compute_cu_seqlens_from_grid_numpy(grid_thw)
         x = x.unsqueeze(1)
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
 
         deepstack_feature_lists = []
         num_deepstack_captured = 0
+
         for layer_num, blk in enumerate(self.blocks):
-            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
+            x = blk(
+                x,
+                cu_seqlens=cu_seqlens,
+                rotary_pos_emb_cos=rotary_pos_emb_cos,
+                rotary_pos_emb_sin=rotary_pos_emb_sin,
+            )
+
             if layer_num in self.deepstack_visual_indexes:
                 deepstack_feature = self.deepstack_merger_list[num_deepstack_captured](
                     x
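
Reviewer note (not part of the patch): the changes above converge all three models on one lookup path. `get_rope(..., is_neox_style=True)` precomputes a `cos_sin_cache`, the new `get_cos_sin(seqlen)` slices it into cos/sin halves, `rot_pos_emb` gathers rows by 2-D (h, w) position ids and flattens them, and `VisionAttention` doubles the last dimension to the head size before `apply_rotary_pos_emb`. The snippet below is a minimal standalone sketch of that flow; `build_cos_sin_cache` and the concrete sizes are illustrative assumptions, not SGLang APIs.

    import torch

    def build_cos_sin_cache(rotary_dim: int, max_position: int, base: float = 10000.0) -> torch.Tensor:
        # Standard neox-style inverse frequencies over even channel indices
        # (illustrative stand-in for what get_rope precomputes).
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
        freqs = torch.outer(torch.arange(max_position, dtype=torch.float), inv_freq)
        return torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # [max_position, rotary_dim]

    head_dim = 64
    cache = build_cos_sin_cache(rotary_dim=head_dim // 2, max_position=8192)

    # get_cos_sin(seqlen): slice the cache, then split it into cos / sin halves.
    cos, sin = cache[:48].chunk(2, dim=-1)                    # each [48, head_dim // 4]

    # rot_pos_emb: gather per-token (h, w) rows, then flatten the two axes together.
    pos_ids = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])  # [tokens, 2]
    cos_combined = cos[pos_ids].flatten(1)                    # [tokens, head_dim // 2]
    sin_combined = sin[pos_ids].flatten(1)

    # VisionAttention doubles the last dim to head_dim before apply_rotary_pos_emb
    # (the `cos.size(-1) * 2 == self.head_size` branch in vision.py).
    cos_full = torch.cat([cos_combined, cos_combined], dim=-1)
    sin_full = torch.cat([sin_combined, sin_combined], dim=-1)
    assert cos_full.shape == (4, head_dim)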