@@ -532,6 +532,40 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
532532 return cache
533533
534534
class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose base is rescaled by a fixed NTK alpha.

    Construction is delegated to RotaryEmbedding; only the cos/sin table
    differs: the rotary base is first multiplied by
    ``scaling_alpha ** (rotary_dim / (rotary_dim - 2))``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_alpha: float,
        dtype: torch.dtype,
    ) -> None:
        # Stored before delegating to the parent: presumably the parent
        # initializer builds the table through _compute_cos_sin_cache(),
        # which reads this attribute -- confirm against RotaryEmbedding.
        self.scaling_alpha = scaling_alpha
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin lookup table from the alpha-scaled base.

        Returns a tensor with one row per position index below
        ``max_position_embeddings``; each row holds the cosines of that
        position's rotation angles followed by their sines.
        """
        # Alpha-scaled base used for the Hunyuan dynamic NTK variant.
        exponent = self.rotary_dim / (self.rotary_dim - 2)
        scaled_base = self.base * self.scaling_alpha**exponent
        inv_freq = self._compute_inv_freq(scaled_base)

        positions = torch.arange(self.max_position_embeddings,
                                 dtype=torch.float)
        # Outer product: angles[p, k] = positions[p] * inv_freq[k].
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        return torch.cat((angles.cos(), angles.sin()), dim=-1)
567+
568+
535569# Inverse dim formula to find dim based on number of rotations
536570def _yarn_find_correction_dim (num_rotations : int ,
537571 dim : int ,
@@ -1810,9 +1844,15 @@ def get_rope(
18101844 mixed_b )
18111845 elif scaling_type == "dynamic" :
18121846 scaling_factor = rope_scaling ["factor" ]
1813- rotary_emb = DynamicNTKScalingRotaryEmbedding (
1814- head_size , rotary_dim , max_position , base , is_neox_style ,
1815- scaling_factor , dtype )
1847+ scaling_alpha = rope_scaling ["alpha" ]
1848+ if scaling_alpha :
1849+ rotary_emb = DynamicNTKAlphaRotaryEmbedding (
1850+ head_size , rotary_dim , max_position , base , is_neox_style ,
1851+ scaling_alpha , dtype )
1852+ else :
1853+ rotary_emb = DynamicNTKScalingRotaryEmbedding (
1854+ head_size , rotary_dim , max_position , base , is_neox_style ,
1855+ scaling_factor , dtype )
18161856 elif scaling_type == "yarn" :
18171857 scaling_factor = rope_scaling ["factor" ]
18181858 original_max_position = rope_scaling [
0 commit comments