Skip to content
Merged
111 changes: 101 additions & 10 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2514,6 +2514,83 @@ def _get_llm_pos_ids_for_vision(
return llm_pos_ids


# Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L554
class YaRNScalingMRotaryEmbedding(MRotaryEmbedding):
    """MRoPE-enabled rotary embedding with YaRN context scaling.

    Combines multimodal rotary position embedding (MRoPE) with YaRN
    frequency interpolation, stretching the usable context window by
    ``scaling_factor`` beyond the model's original
    ``max_position_embeddings``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        mrope_section: Optional[List[int]] = None,
        mrope_interleaved: bool = False,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        truncate: bool = True,
    ) -> None:
        # The YaRN parameters must be assigned before calling
        # super().__init__(): the parent constructor builds the cos/sin
        # cache through the overridden _compute_cos_sin_cache below,
        # which reads these attributes.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.truncate = truncate
        # Magnitude correction factor applied to both cos and sin entries.
        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
            mrope_section=mrope_section,
            mrope_interleaved=mrope_interleaved,
        )

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Return YaRN-blended inverse frequencies of shape (rotary_dim // 2,).

        Low-frequency dimensions are interpolated (divided by
        ``scaling_factor``); high-frequency dimensions keep the original
        extrapolated values; a linear ramp blends the two in between.
        """
        exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
        theta = self.base ** (exponents / self.rotary_dim)
        extrapolated = 1.0 / theta
        interpolated = 1.0 / (scaling_factor * theta)

        low, high = _yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
            self.truncate,
        )
        # blend == 1 selects pure extrapolation; blend == 0 pure interpolation.
        ramp = _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2, dtype=torch.float
        )
        blend = (1 - ramp) * self.extrapolation_factor
        return interpolated * (1 - blend) + extrapolated * blend

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the mscale-adjusted cos/sin cache over the scaled range.

        Returns a tensor of shape
        (max_position_embeddings * scaling_factor, rotary_dim) with cos
        values in the first half of the last dim and sin in the second.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        positions = torch.arange(
            self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
        )
        # Outer product: one row of angles per position.
        angles = positions[:, None] * inv_freq[None, :]
        return torch.cat(
            (angles.cos() * self.mscale, angles.sin() * self.mscale), dim=-1
        )


class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
"""3D rotary positional embedding. [h w h w h w h w... t t t...]"""

Expand Down Expand Up @@ -2953,16 +3030,30 @@ def get_rope(
in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
}
extra_kwargs["truncate"] = rope_scaling.get("truncate", True)
rotary_emb = YaRNScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
if "mrope_section" in rope_scaling:
rotary_emb = YaRNScalingMRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
**extra_kwargs,
)
else:
rotary_emb = YaRNScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
Expand Down
29 changes: 27 additions & 2 deletions test/registered/rotary/test_mrope.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,22 @@ class MRoPETestInfo(NamedTuple):
num_tokens_list = [11, 8192]


def create_yarn_rope_scaling(original_config, scaling_factor=2.0):
    """Build a YaRN ``rope_scaling`` dict from a model config.

    Carries over the ``mrope_section`` / ``mrope_interleaved`` entries
    from the config's existing ``rope_scaling`` (when present) so the
    MRoPE layout is preserved under YaRN scaling.
    """
    scaled = {
        "rope_type": "yarn",
        "factor": scaling_factor,
        "original_max_position_embeddings": original_config.max_position_embeddings,
    }
    existing = getattr(original_config, "rope_scaling", None)
    if existing:
        for key in ("mrope_section", "mrope_interleaved"):
            if key in existing:
                scaled[key] = existing[key]
    return scaled


@pytest.mark.skipif(not (_is_cuda or _is_hip), reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize(
"model_info, model_name",
Expand All @@ -85,12 +101,16 @@ class MRoPETestInfo(NamedTuple):
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
@pytest.mark.parametrize(
"rope_scaling_type", ["default", "yarn"], ids=["mrope_default", "mrope_yarn"]
)
def test_mrope(
model_name: str,
model_info: MRoPETestInfo,
tp_size: int,
dtype: torch.dtype,
num_tokens: int,
rope_scaling_type: str,
):
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))

Expand All @@ -117,13 +137,18 @@ def test_mrope(
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
rotary_dim = int(head_dim * partial_rotary_factor)

if rope_scaling_type == "yarn":
rope_scaling_config = create_yarn_rope_scaling(config, scaling_factor=2.0)
else:
rope_scaling_config = config.rope_scaling

mrope_helper_class = get_rope(
head_size=head_dim,
rotary_dim=rotary_dim,
max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style,
rope_scaling=config.rope_scaling,
rope_scaling=rope_scaling_config,
dtype=dtype,
).to(device=device)

Expand All @@ -133,7 +158,7 @@ def test_mrope(
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
)

query_native, key_native = mrope_helper_class._forward_native(
query_native, key_native = mrope_helper_class.forward_native(
positions,
query.clone(),
key.clone(),
Expand Down
Loading