Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,43 @@ def forward(
return query_out.type_as(query), key_out.type_as(key)


class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.

Credits to the Reddit users /u/bloc97 and /u/emozilla
"""

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_alpha: float,
dtype: torch.dtype,
) -> None:
self.scaling_alpha = scaling_alpha
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)

def _compute_cos_sin_cache(self) -> torch.Tensor:
max_len = self.max_position_embeddings
base = self.base * self.scaling_alpha ** (
self.rotary_dim / (self.rotary_dim - 2)
)
Comment on lines +916 to +918
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The calculation self.rotary_dim / (self.rotary_dim - 2) can lead to a division-by-zero error if self.rotary_dim is equal to 2. Adding a check would make the implementation more robust against unexpected configurations.


inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float)

freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache


class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""

Expand Down Expand Up @@ -1234,15 +1271,26 @@ def get_rope(
)
elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
if "alpha" in rope_scaling:
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling["alpha"],
dtype,
)
else:
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
Expand Down
Loading
Loading