Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/rotary_embedding/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ def get_rope(
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
mrope_interleaved_glm=rope_scaling.get(
"mrope_interleaved_glm", False
),
)
elif rope_scaling.get("use_fope", False):
rotary_emb = FourierRotaryEmbedding(
Expand Down
35 changes: 35 additions & 0 deletions python/sglang/srt/layers/rotary_embedding/mrope.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ def __init__(
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
mrope_interleaved: bool = False,
mrope_interleaved_glm: bool = False,
) -> None:
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
self.mrope_section = mrope_section
self.mrope_interleaved = mrope_interleaved
self.mrope_interleaved_glm = mrope_interleaved_glm
if self.mrope_section:
expected_sum = rotary_dim // 2
actual_sum = sum(self.mrope_section)
Expand Down Expand Up @@ -86,6 +88,37 @@ def __init__(
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
)

# MRoPE axis_map interleaving pattern depends on mrope_section sizes.
# The algorithm cycles through axes [0(T), 1(H), 2(W)] round-robin,
# skipping any axis that has exhausted its allocated pairs.
#
# For GLM-V (mrope_section=[8,12,12]):
# T(8) < H(12) = W(12), so T exhausts first at pair 24.
# Result: [0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 1,1,2, 1,1,2, 2,2]
# After T runs out, only H and W fill the remaining slots.
#
# For Qwen3-VL (mrope_section=[24,20,20]):
# T(24) > H(20) = W(20), so H and W exhaust first near the tail.
# Result: [0,1,2, 0,1,2, ...repeated evenly..., 0,1, 0,1, 0,0]
# After H/W run out, T fills the remaining slots.

if self.mrope_interleaved_glm:
num_pairs = rotary_dim // 2
axis_map = torch.empty(num_pairs, dtype=torch.long)
assert sum(self.mrope_section) == num_pairs
counts = [0, 0, 0]
current_ax = 0

for i in range(num_pairs):
current_ax = i % 3
while counts[current_ax] >= self.mrope_section[current_ax]:
current_ax = (current_ax + 1) % 3

axis_map[i] = current_ax
counts[current_ax] += 1
self.register_buffer("axis_map", axis_map, persistent=False)
else:
self.axis_map = None
if get_global_server_args().rl_on_policy_target is not None:
self._forward_method = self.forward_native

Expand Down Expand Up @@ -214,7 +247,9 @@ def forward_triton(
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
self.mrope_interleaved_glm,
self.is_neox_style,
self.axis_map,
)
return query, key

Expand Down
18 changes: 15 additions & 3 deletions python/sglang/srt/layers/rotary_embedding/triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def _triton_mrope_forward_fused(
mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr,
is_interleaved: tl.constexpr,
is_interleaved_glm: tl.constexpr,
is_neox_style: tl.constexpr,
axis_map_ptr,
):
pid = tl.program_id(0)
q_ptr = q_ptr + pid * q_stride
Expand All @@ -46,9 +48,15 @@ def _triton_mrope_forward_fused(
w_sin = w_cos + half_rd
cos_offsets = tl.arange(0, pad_hd // 2)
if is_interleaved:
h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h)
w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w)
t_mask = ~(h_mask | w_mask)
if is_interleaved_glm:
axes = tl.load(axis_map_ptr + cos_offsets, mask=cos_offsets < (pad_hd // 2))
t_mask = axes == 0
h_mask = axes == 1
w_mask = axes == 2
else:
h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h)
w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w)
t_mask = ~(h_mask | w_mask)
else:
t_end = mrope_section_t
h_end = t_end + mrope_section_h
Expand Down Expand Up @@ -109,7 +117,9 @@ def triton_mrope_fused(
head_size: int,
rotary_dim: int,
mrope_interleaved: bool,
mrope_interleaved_glm: bool,
is_neox_style: bool,
axis_map: torch.Tensor,
) -> None:
num_tokens, n_q_dim = q.shape
n_k_dim = k.shape[1]
Expand Down Expand Up @@ -137,7 +147,9 @@ def triton_mrope_fused(
mrope_section[1],
mrope_section[2],
mrope_interleaved,
mrope_interleaved_glm,
is_neox_style,
axis_map,
)


Expand Down
Loading