diff --git a/python/sglang/srt/models/ernie45_vl.py b/python/sglang/srt/models/ernie45_vl.py
index c371de793dc6..e2071f416667 100644
--- a/python/sglang/srt/models/ernie45_vl.py
+++ b/python/sglang/srt/models/ernie45_vl.py
@@ -30,6 +30,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
@@ -120,14 +121,16 @@ def forward(
         self,
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        position_embeddings: torch.Tensor,
+        rotary_pos_emb_cos: torch.Tensor,
+        rotary_pos_emb_sin: torch.Tensor,
     ) -> torch.Tensor:
         hidden_states = self.norm1(x)
         hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
         attn = self.attn(
             hidden_states,
             cu_seqlens=cu_seqlens,
-            position_embeddings=position_embeddings,
+            rotary_pos_emb_cos=rotary_pos_emb_cos,
+            rotary_pos_emb_sin=rotary_pos_emb_sin,
         )
         attn = rearrange(attn, "b s ... -> s b ...")
         x = x + attn
@@ -388,7 +391,13 @@ def __init__(
         norm_layer = partial(nn.LayerNorm, eps=norm_eps)
         head_dim = embed_dim // num_heads

-        self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2)
+        self.rotary_pos_emb = get_rope(
+            head_size=head_dim,
+            rotary_dim=head_dim // 2,
+            max_position=8192,
+            base=10000.0,
+            is_neox_style=True,
+        )
         self.blocks = nn.ModuleList(
             [
                 Ernie4_5_VisionBlock(
@@ -413,7 +422,9 @@ def dtype(self) -> torch.dtype:
     def device(self) -> torch.device:
         return self.blocks[0].mlp.fc2.weight.device

-    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+    def rot_pos_emb(
+        self, grid_thw: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         pos_ids = []
         for i in range(grid_thw.size(0)):
             t, h, w = grid_thw[i].tolist()
@@ -440,11 +451,15 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
                 .flatten()
             )
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
-        pos_ids = torch.cat(pos_ids, dim=0)
+        pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
-        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
-        return rotary_pos_emb
+
+        # Use pre-computed cos_sin_cache from RotaryEmbedding
+        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
+
+        cos_combined = cos[pos_ids].flatten(1)
+        sin_combined = sin[pos_ids].flatten(1)
+        return cos_combined, sin_combined, pos_ids

     def forward(
         self,
@@ -456,9 +471,11 @@ def forward(
         x = self.patch_embed(x)

         # compute position embedding
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        position_embeddings = (emb.cos(), emb.sin())
+        rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
+            grid_thw
+        )
+        rotary_pos_emb_cos = torch.cat([rotary_pos_emb_cos, rotary_pos_emb_cos], dim=-1)
+        rotary_pos_emb_sin = torch.cat([rotary_pos_emb_sin, rotary_pos_emb_sin], dim=-1)
         # compute cu_seqlens
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -468,7 +485,12 @@ def forward(
         # transformers
         x = x.unsqueeze(1)
         for blk in self.blocks:
-            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
+            x = blk(
+                x,
+                cu_seqlens=cu_seqlens,
+                rotary_pos_emb_cos=rotary_pos_emb_cos,
+                rotary_pos_emb_sin=rotary_pos_emb_sin,
+            )

         final_output = self.ln(x)
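
For reference, below is a minimal sketch of how per-token cos/sin tensors such as the `rotary_pos_emb_cos` / `rotary_pos_emb_sin` produced above are typically applied to query/key tensors in a neox-style rotary embedding (matching `is_neox_style=True` passed to `get_rope`). The helper names and tensor layouts are illustrative assumptions, not part of this patch or of `Ernie4_5_VisionAttention`:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style rotation: split the last dim in half and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_cos_sin(
    q: torch.Tensor,    # [seq, num_heads, head_dim] -- hypothetical layout
    k: torch.Tensor,    # [seq, num_heads, head_dim]
    cos: torch.Tensor,  # [seq, head_dim], e.g. rotary_pos_emb_cos
    sin: torch.Tensor,  # [seq, head_dim], e.g. rotary_pos_emb_sin
) -> tuple[torch.Tensor, torch.Tensor]:
    # Broadcast cos/sin over the head dimension before applying the rotation.
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot
```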