Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions python/sglang/srt/models/ernie45_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
Expand Down Expand Up @@ -120,14 +121,16 @@ def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.norm1(x)
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
attn = self.attn(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
)
attn = rearrange(attn, "b s ... -> s b ...")
x = x + attn
Expand Down Expand Up @@ -388,7 +391,13 @@ def __init__(

norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = embed_dim // num_heads
self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2)
self.rotary_pos_emb = get_rope(
head_size=head_dim,
rotary_dim=head_dim // 2,
max_position=8192,
base=10000.0,
is_neox_style=True,
)
self.blocks = nn.ModuleList(
[
Ernie4_5_VisionBlock(
Expand All @@ -413,7 +422,9 @@ def dtype(self) -> torch.dtype:
def device(self) -> torch.device:
return self.blocks[0].mlp.fc2.weight.device

def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
def rot_pos_emb(
self, grid_thw: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pos_ids = []
for i in range(grid_thw.size(0)):
t, h, w = grid_thw[i].tolist()
Expand All @@ -440,11 +451,15 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
.flatten()
)
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb

# Use pre-computed cos_sin_cache from RotaryEmbedding
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)

cos_combined = cos[pos_ids].flatten(1)
sin_combined = sin[pos_ids].flatten(1)
return cos_combined, sin_combined, pos_ids

def forward(
self,
Expand All @@ -456,9 +471,11 @@ def forward(
x = self.patch_embed(x)

# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
grid_thw
)
rotary_pos_emb_cos = torch.cat([rotary_pos_emb_cos, rotary_pos_emb_cos], dim=-1)
rotary_pos_emb_sin = torch.cat([rotary_pos_emb_sin, rotary_pos_emb_sin], dim=-1)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
Expand All @@ -468,7 +485,12 @@ def forward(
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
)

final_output = self.ln(x)

Expand Down
Loading