From afc23aef058765e0ab3b6a937019fceb98d9eec4 Mon Sep 17 00:00:00 2001 From: "Raayan Dhar raayan.dhar@gmail.com" Date: Fri, 14 Nov 2025 18:03:22 -0800 Subject: [PATCH 1/5] initial support for mrope section with yarn Signed-off-by: Raayan Dhar raayan.dhar@gmail.com --- python/sglang/srt/layers/rotary_embedding.py | 107 +++++++++++++++++-- 1 file changed, 97 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index e9f37513845c..e800c00c9610 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -2298,6 +2298,79 @@ def _get_llm_pos_ids_for_vision( return llm_pos_ids +class YaRNScalingMRotaryEmbedding(MRotaryEmbedding): + """MRoPE-enabled rotary embedding with YaRN context scaling.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + mrope_section: Optional[List[int]] = None, + mrope_interleaved: bool = False, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + mrope_section=mrope_section, + mrope_interleaved=mrope_interleaved, + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, dtype=torch.float32 + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + class DualChunkRotaryEmbedding(CustomOp): """Rotary positional embedding for Dual Chunk Attention.""" @@ -2652,16 +2725,30 @@ def get_rope( in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow") } extra_kwargs["truncate"] = rope_scaling.get("truncate", True) - rotary_emb = YaRNScalingRotaryEmbedding( - head_size, - rotary_dim, - original_max_position, - base, - is_neox_style, - scaling_factor, - dtype, - **extra_kwargs, - ) + if "mrope_section" in rope_scaling: + rotary_emb = YaRNScalingMRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling.get("mrope_interleaved", False), + **extra_kwargs, + ) + else: + rotary_emb = YaRNScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) elif scaling_type == "deepseek_yarn": scaling_factor = rope_scaling["factor"] original_max_position = rope_scaling["original_max_position_embeddings"] From e9a408135725360f2981f9d32733e7935f0feae6 Mon Sep 17 00:00:00 2001 From: "Raayan Dhar raayan.dhar@gmail.com" Date: Sat, 15 Nov 2025 14:38:11 -0800 Subject: [PATCH 2/5] add credit Signed-off-by: Raayan Dhar raayan.dhar@gmail.com --- python/sglang/srt/layers/rotary_embedding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index e800c00c9610..9ae4747df90e 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -2298,6 +2298,7 @@ def _get_llm_pos_ids_for_vision( return llm_pos_ids +# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L554 class YaRNScalingMRotaryEmbedding(MRotaryEmbedding): """MRoPE-enabled rotary embedding with YaRN context scaling.""" From 23c64e8015d7b31b443a26a354805992b4075e8a Mon Sep 17 00:00:00 2001 From: "Raayan Dhar raayan.dhar@gmail.com" Date: Thu, 27 Nov 2025 23:40:26 -0800 Subject: [PATCH 3/5] add unit test Signed-off-by: Raayan Dhar raayan.dhar@gmail.com --- test/srt/rotary_embedding/test_mrope.py | 28 ++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/test/srt/rotary_embedding/test_mrope.py b/test/srt/rotary_embedding/test_mrope.py index c9837f87facd..b625291bc3b9 100644 --- a/test/srt/rotary_embedding/test_mrope.py +++ b/test/srt/rotary_embedding/test_mrope.py @@ -68,6 +68,23 @@ class MRoPETestInfo(NamedTuple): num_tokens_list = [11, 8192] +def create_yarn_rope_scaling(original_config, scaling_factor=2.0): + yarn_config = { + "rope_type": "yarn", + "factor": scaling_factor, + "original_max_position_embeddings": original_config.max_position_embeddings, + } + print(original_config) + if hasattr(original_config, "rope_scaling") and original_config.rope_scaling: + if "mrope_section" in original_config.rope_scaling: + yarn_config["mrope_section"] = original_config.rope_scaling["mrope_section"] + if "mrope_interleaved" in original_config.rope_scaling: + yarn_config["mrope_interleaved"] = original_config.rope_scaling[ + "mrope_interleaved" + ] + return yarn_config + + @pytest.mark.skipif(not _is_cuda, reason="Skipping CUDA/ROCm only tests.") @pytest.mark.parametrize( "model_info, model_name", @@ -79,12 +96,16 @@ class MRoPETestInfo(NamedTuple): @pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("num_tokens", num_tokens_list) +@pytest.mark.parametrize( + "rope_scaling_type", ["default", "yarn"], ids=["mrope_default", "mrope_yarn"] +) def test_mrope( model_name: str, model_info: MRoPETestInfo, tp_size: int, dtype: torch.dtype, num_tokens: int, + rope_scaling_type: str, ): set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) @@ -111,13 +132,18 @@ def test_mrope( partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) rotary_dim = int(head_dim * partial_rotary_factor) + if rope_scaling_type == "yarn": + rope_scaling_config = create_yarn_rope_scaling(config, scaling_factor=2.0) + else: + rope_scaling_config = config.rope_scaling + mrope_helper_class = get_rope( head_size=head_dim, rotary_dim=rotary_dim, max_position=max_position, base=rope_theta, is_neox_style=is_neox_style, - rope_scaling=config.rope_scaling, + rope_scaling=rope_scaling_config, dtype=dtype, ).to(device=device) From 7f4581c58bda552c79df8917294bbe1ea8d81b85 Mon Sep 17 00:00:00 2001 From: "Raayan Dhar raayan.dhar@gmail.com" Date: Tue, 2 Dec 2025 11:36:54 -0800 Subject: [PATCH 4/5] caught an extraneous print Signed-off-by: Raayan Dhar raayan.dhar@gmail.com --- test/srt/rotary_embedding/test_mrope.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/srt/rotary_embedding/test_mrope.py b/test/srt/rotary_embedding/test_mrope.py index b625291bc3b9..873f9385081a 100644 --- a/test/srt/rotary_embedding/test_mrope.py +++ b/test/srt/rotary_embedding/test_mrope.py @@ -74,7 +74,6 @@ def create_yarn_rope_scaling(original_config, scaling_factor=2.0): "factor": scaling_factor, "original_max_position_embeddings": original_config.max_position_embeddings, } - print(original_config) if hasattr(original_config, "rope_scaling") and original_config.rope_scaling: if "mrope_section" in original_config.rope_scaling: yarn_config["mrope_section"] = original_config.rope_scaling["mrope_section"] From 6908180be1279194dc815a54c8086b0ce1ef57dd Mon Sep 17 00:00:00 2001 From: "Raayan Dhar raayan.dhar@gmail.com" Date: Mon, 5 Jan 2026 10:32:38 -0800 Subject: [PATCH 5/5] fix unit test Signed-off-by: Raayan Dhar raayan.dhar@gmail.com --- python/sglang/srt/layers/rotary_embedding.py | 3 +++ test/srt/rotary_embedding/test_mrope.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index f929a65803a4..a693c0a6c043 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -2343,12 +2343,14 @@ def __init__( attn_factor: float = 1, beta_fast: int = 32, beta_slow: int = 1, + truncate: bool = True, ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow + self.truncate = truncate self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) super().__init__( head_size, @@ -2374,6 +2376,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: self.rotary_dim, self.base, self.max_position_embeddings, + self.truncate, ) inv_freq_mask = ( 1 diff --git a/test/srt/rotary_embedding/test_mrope.py b/test/srt/rotary_embedding/test_mrope.py index 8ae4f9a37a21..e9d377b0c335 100644 --- a/test/srt/rotary_embedding/test_mrope.py +++ b/test/srt/rotary_embedding/test_mrope.py @@ -152,7 +152,7 @@ def test_mrope( num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device ) - query_native, key_native = mrope_helper_class._forward_native( + query_native, key_native = mrope_helper_class.forward_native( positions, query.clone(), key.clone(),