# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L554
class YaRNScalingMRotaryEmbedding(MRotaryEmbedding):
    """MRoPE-enabled rotary embedding with YaRN context scaling."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        mrope_section: Optional[List[int]] = None,
        mrope_interleaved: bool = False,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        truncate: bool = True,
    ) -> None:
        # All YaRN hyper-parameters must be assigned BEFORE calling
        # super().__init__(): the base constructor builds the cos/sin cache
        # through _compute_cos_sin_cache(), which reads them.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.truncate = truncate
        # Magnitude correction applied uniformly to cos and sin entries.
        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
            mrope_section=mrope_section,
            mrope_interleaved=mrope_interleaved,
        )

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend extrapolated and interpolated inverse frequencies (YaRN).

        Low frequencies are interpolated (stretched by ``scaling_factor``),
        high frequencies are extrapolated; a linear ramp between the
        beta_fast/beta_slow correction bounds mixes the two regimes.
        """
        exponents = (
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
        )
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
            self.truncate,
        )
        ramp = _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2, dtype=torch.float
        )
        # Weight of the extrapolated branch per frequency bin.
        blend = (1 - ramp) * self.extrapolation_factor
        return interpolated * (1 - blend) + extrapolated * blend

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the [num_positions, rotary_dim] cos||sin table.

        The table covers the YaRN-extended context
        (``max_position_embeddings * scaling_factor`` positions) and both
        halves are pre-multiplied by ``self.mscale``.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        positions = torch.arange(
            self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
        )
        angles = torch.outer(positions, inv_freq)
        return torch.cat(
            (angles.cos() * self.mscale, angles.sin() * self.mscale), dim=-1
        )
def create_yarn_rope_scaling(original_config, scaling_factor=2.0):
    """Build a YaRN ``rope_scaling`` dict from a model config.

    Copies the MRoPE-specific keys (``mrope_section`` /
    ``mrope_interleaved``) from the config's existing ``rope_scaling``,
    when present, so the YaRN variant keeps the same multimodal layout.
    """
    scaling = {
        "rope_type": "yarn",
        "factor": scaling_factor,
        "original_max_position_embeddings": original_config.max_position_embeddings,
    }
    # Missing or falsy (None/{}) rope_scaling contributes nothing.
    existing = getattr(original_config, "rope_scaling", None) or {}
    for key in ("mrope_section", "mrope_interleaved"):
        if key in existing:
            scaling[key] = existing[key]
    return scaling
= mrope_helper_class._forward_native( + query_native, key_native = mrope_helper_class.forward_native( positions, query.clone(), key.clone(),