diff --git a/python/sglang/srt/layers/rotary_embedding/factory.py b/python/sglang/srt/layers/rotary_embedding/factory.py index e95e9543f7f6..27e28577c96e 100644 --- a/python/sglang/srt/layers/rotary_embedding/factory.py +++ b/python/sglang/srt/layers/rotary_embedding/factory.py @@ -171,6 +171,9 @@ def get_rope( dtype, mrope_section=rope_scaling["mrope_section"], mrope_interleaved=rope_scaling.get("mrope_interleaved", False), + mrope_interleaved_glm=rope_scaling.get( + "mrope_interleaved_glm", False + ), ) elif rope_scaling.get("use_fope", False): rotary_emb = FourierRotaryEmbedding( diff --git a/python/sglang/srt/layers/rotary_embedding/mrope.py b/python/sglang/srt/layers/rotary_embedding/mrope.py index 237528fd1d47..9c93ad1ffd21 100644 --- a/python/sglang/srt/layers/rotary_embedding/mrope.py +++ b/python/sglang/srt/layers/rotary_embedding/mrope.py @@ -52,12 +52,14 @@ def __init__( dtype: torch.dtype, mrope_section: Optional[List[int]] = None, mrope_interleaved: bool = False, + mrope_interleaved_glm: bool = False, ) -> None: super().__init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) self.mrope_section = mrope_section self.mrope_interleaved = mrope_interleaved + self.mrope_interleaved_glm = mrope_interleaved_glm if self.mrope_section: expected_sum = rotary_dim // 2 actual_sum = sum(self.mrope_section) @@ -86,6 +88,37 @@ def __init__( f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})" ) + # MRoPE axis_map interleaving pattern depends on mrope_section sizes. + # The algorithm cycles through axes [0(T), 1(H), 2(W)] round-robin, + # skipping any axis that has exhausted its allocated pairs. + # + # For GLM-V (mrope_section=[8,12,12]): + # T(8) < H(12) = W(12), so T exhausts first at pair 24. + # Result: [0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 1,1,2, 1,1,2, 2,2] + # After T runs out, only H and W fill the remaining slots. + # + # For Qwen3-VL (mrope_section=[24,20,20]): + # T(24) > H(20) = W(20), so H and W exhaust first near the tail. + # Result: [0,1,2, 0,1,2, ...repeated evenly..., 0,1, 0,1, 0,0] + # After H/W run out, T fills the remaining slots. + + if self.mrope_interleaved_glm: + num_pairs = rotary_dim // 2 + axis_map = torch.empty(num_pairs, dtype=torch.long) + assert sum(self.mrope_section) == num_pairs + counts = [0, 0, 0] + current_ax = 0 + + for i in range(num_pairs): + current_ax = i % 3 + while counts[current_ax] >= self.mrope_section[current_ax]: + current_ax = (current_ax + 1) % 3 + + axis_map[i] = current_ax + counts[current_ax] += 1 + self.register_buffer("axis_map", axis_map, persistent=False) + else: + self.axis_map = None if get_global_server_args().rl_on_policy_target is not None: self._forward_method = self.forward_native @@ -214,7 +247,9 @@ def forward_triton( self.head_size, self.rotary_dim, self.mrope_interleaved, + self.mrope_interleaved_glm, self.is_neox_style, + self.axis_map, ) return query, key diff --git a/python/sglang/srt/layers/rotary_embedding/triton_kernels.py b/python/sglang/srt/layers/rotary_embedding/triton_kernels.py index 9a3d21bf83bb..0a8dc2c33c7b 100644 --- a/python/sglang/srt/layers/rotary_embedding/triton_kernels.py +++ b/python/sglang/srt/layers/rotary_embedding/triton_kernels.py @@ -29,7 +29,9 @@ def _triton_mrope_forward_fused( mrope_section_h: tl.constexpr, mrope_section_w: tl.constexpr, is_interleaved: tl.constexpr, + is_interleaved_glm: tl.constexpr, is_neox_style: tl.constexpr, + axis_map_ptr, ): pid = tl.program_id(0) q_ptr = q_ptr + pid * q_stride @@ -46,9 +48,15 @@ def _triton_mrope_forward_fused( w_sin = w_cos + half_rd cos_offsets = tl.arange(0, pad_hd // 2) if is_interleaved: - h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) - w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) - t_mask = ~(h_mask | w_mask) + if is_interleaved_glm: + axes = tl.load(axis_map_ptr + cos_offsets, mask=cos_offsets < (pad_hd // 2)) + t_mask = axes == 0 + h_mask = axes == 1 + w_mask = axes == 2 + else: + h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) + w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) + t_mask = ~(h_mask | w_mask) else: t_end = mrope_section_t h_end = t_end + mrope_section_h @@ -109,7 +117,9 @@ def triton_mrope_fused( head_size: int, rotary_dim: int, mrope_interleaved: bool, + mrope_interleaved_glm: bool, is_neox_style: bool, + axis_map: torch.Tensor, ) -> None: num_tokens, n_q_dim = q.shape n_k_dim = k.shape[1] @@ -137,7 +147,9 @@ def triton_mrope_fused( mrope_section[1], mrope_section[2], mrope_interleaved, + mrope_interleaved_glm, is_neox_style, + axis_map, )