@@ -75,6 +75,51 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st
7575 return cos_freq , sin_freq , {freq_var : freq }
7676
7777
78+ def rope_freq_llama4 ( # pylint: disable=too-many-arguments,too-many-locals
79+ s : tir .Var ,
80+ d : tir .Var ,
81+ d_range : int ,
82+ theta : float ,
83+ dtype : str ,
84+ factor : float ,
85+ low_freq_factor : float ,
86+ high_freq_factor : float ,
87+ original_max_position_embeddings : float ,
88+ ):
89+ """Compute the inverse frequency of RoPE for llama4 RoPE scaling."""
90+ orig_freq = tir .const (1 , "float32" ) / tir .power (
91+ theta , 2 * (d // 2 ) / tir .const (d_range , "float32" )
92+ )
93+ orig_freq_var = tir .Var ("orig_freq" , "float32" )
94+
95+ llama4_inv_scaling_factor = 1.0 / factor
96+
97+ if high_freq_factor == low_freq_factor :
98+ wavelength = tir .const (2 * math .pi , "float32" ) / orig_freq_var
99+ threshold_wavelen = tir .const (original_max_position_embeddings / low_freq_factor , "float32" )
100+
101+ scaled_freq = tir .if_then_else (
102+ wavelength > threshold_wavelen , orig_freq_var / factor , orig_freq_var
103+ )
104+ smoothed_freq = s * scaled_freq
105+
106+ else :
107+ # Original smooth interpolation logic
108+ inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor )
109+
110+ llama4_alpha = original_max_position_embeddings / (2 * math .pi ) * inv_diff_freq_factor
111+ llama4_beta = low_freq_factor * inv_diff_freq_factor
112+ smooth = tir .max (0.0 , tir .min (1.0 , llama4_alpha * orig_freq_var - llama4_beta ))
113+ smoothed_freq = s * (
114+ (1.0 - smooth ) * orig_freq_var * llama4_inv_scaling_factor + smooth * orig_freq_var
115+ )
116+
117+ smoothed_freq_var = tir .Var ("smoothed_freq" , "float32" )
118+ cos_freq = tir .cos (smoothed_freq_var ).astype (dtype )
119+ sin_freq = tir .sin (smoothed_freq_var ).astype (dtype )
120+ return cos_freq , sin_freq , {smoothed_freq_var : smoothed_freq , orig_freq_var : orig_freq }
121+
122+
78123def rope_freq_llama3 ( # pylint: disable=too-many-arguments,too-many-locals
79124 s : tir .Var ,
80125 d : tir .Var ,
@@ -208,6 +253,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
208253 high_freq_factor = rope_scaling ["high_freq_factor" ],
209254 original_max_position_embeddings = rope_scaling ["original_max_position_embeddings" ],
210255 )
256+ if rope_scaling ["rope_type" ] == "llama4" :
257+ return partial (
258+ rope_freq_llama4 ,
259+ factor = rope_scaling ["factor" ],
260+ low_freq_factor = rope_scaling ["low_freq_factor" ],
261+ high_freq_factor = rope_scaling ["high_freq_factor" ],
262+ original_max_position_embeddings = rope_scaling ["original_max_position_embeddings" ],
263+ )
211264 if rope_scaling ["rope_type" ] == "longrope" :
212265 return partial (
213266 rope_freq_longrope ,
@@ -545,3 +598,184 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
545598 if is_longrope_scaling :
546599 return fused_rope_longrope_scaling
547600 return fused_rope
601+
602+
603+ def llama4_rope_with_position_map ( # pylint: disable=too-many-arguments
604+ theta : float ,
605+ scale : float ,
606+ head_dim : int ,
607+ num_q_heads : int ,
608+ num_kv_heads : int ,
609+ dtype : str ,
610+ rope_scaling : Dict [str , Any ],
611+ rotary_dim : Optional [int ] = None ,
612+ ):
613+ """Return the TIR function that computes Llama-style RoPE with q position map.
614+
615+ Parameters
616+ ----------
617+ theta : float
618+ The theta value, or "base" in RoPE, which controls the frequency.
619+
620+ scale : float
621+ The RoPE scaling factor.
622+
623+ head_dim : int
624+ The number of features on each head.
625+
626+ num_q_heads : int
627+ The number of query heads.
628+
629+ num_kv_heads : int
630+ The number of key/value heads. It differs from `num_q_heads` in group-query attention.
631+
632+ dtype : str
633+ The dtype of qkv data.
634+
635+ rope_scaling : Dict
636+ The configuration of RoPE scaling.
637+
638+ rotary_dim : int
639+ The number of dimensions in the embedding that RoPE is applied to. By default, the
640+ rotary_dim is the same as head_dim.
641+ """
642+ fused_heads = num_q_heads + num_kv_heads * 2
643+ if rotary_dim is None :
644+ rotary_dim = head_dim
645+ scale = tir .const (scale , "float32" )
646+ is_longrope_scaling = rope_scaling .get ("rope_type" ) == "longrope"
647+
648+ def _rope ( # pylint: disable=too-many-arguments
649+ x : T .Buffer ,
650+ s : tir .Var ,
651+ h : tir .Var ,
652+ d : tir .Var ,
653+ pos : tir .Var ,
654+ ext_factors : Optional [T .Buffer ] = None ,
655+ ):
656+ kwargs = {}
657+ if ext_factors :
658+ kwargs ["ext_factors" ] = ext_factors
659+ cos_freq , sin_freq , var_map = switch_rope_freq_func (rope_scaling )(
660+ pos * scale , d , rotary_dim , theta , "float32" , ** kwargs
661+ )
662+ cos = cos_freq * x [s , h , d ].astype ("float32" )
663+ if "rope_type" in rope_scaling and rope_scaling ["rope_type" ] == "gptj" :
664+ sin = sin_freq * tir .if_then_else (
665+ d % 2 == 0 ,
666+ - x [s , h , d + 1 ],
667+ x [s , h , d - 1 ],
668+ ).astype ("float32" )
669+ else :
670+ # Data layout is different for llama4 vs llama3
671+ sin = sin_freq * tir .if_then_else (
672+ d % 2 == 0 ,
673+ - x [s , h , d + 1 ],
674+ x [s , h , d - 1 ],
675+ ).astype ("float32" )
676+ expr = (cos + sin ).astype (dtype )
677+ for var , value in var_map .items ():
678+ expr = tir .Let (var , value , expr )
679+ return expr
680+
681+ @T .prim_func (private = True )
682+ def fused_rope ( # pylint: disable=too-many-locals
683+ var_qkv : T .handle ,
684+ var_position_map : T .handle ,
685+ var_q : T .handle ,
686+ var_k : T .handle ,
687+ var_v : T .handle ,
688+ apply_rope : T .int64 ,
689+ ):
690+ T .func_attr (
691+ {
692+ "op_pattern" : 8 , # 2 means injective, 8 means opaque
693+ "tir.noalias" : True ,
694+ }
695+ )
696+ seq_len = T .int32 ()
697+ position_map_elem_offset = T .int32 ()
698+ qkv = T .match_buffer (var_qkv , (seq_len , fused_heads , head_dim ), dtype )
699+ q = T .match_buffer (var_q , (seq_len , num_q_heads , head_dim ), dtype )
700+ k = T .match_buffer (var_k , (seq_len , num_kv_heads , head_dim ), dtype )
701+ v = T .match_buffer (var_v , (seq_len , num_kv_heads , head_dim ), dtype )
702+ position_map = T .match_buffer (
703+ var_position_map , (seq_len ,), "int32" , elem_offset = position_map_elem_offset
704+ )
705+ for iters in T .grid (seq_len , fused_heads , head_dim ):
706+ with T .block ("llama_fused_rope" ):
707+ s , h , d = T .axis .remap ("SSS" , iters )
708+ if h < num_q_heads :
709+ q [s , h , d ] = T .if_then_else (
710+ apply_rope > 0 and d < rotary_dim ,
711+ _rope (qkv , s , h , d , position_map [s ]),
712+ qkv [s , h , d ],
713+ )
714+ elif h < num_q_heads + num_kv_heads :
715+ k [s , h - num_q_heads , d ] = T .if_then_else (
716+ apply_rope > 0 and d < rotary_dim ,
717+ _rope (qkv , s , h , d , position_map [s ]),
718+ qkv [s , h , d ],
719+ )
720+ else :
721+ v [s , h - (num_q_heads + num_kv_heads ), d ] = qkv [s , h , d ]
722+
723+ @T .prim_func
724+ def fused_rope_longrope_scaling ( # pylint: disable=too-many-locals
725+ var_qkv : T .handle ,
726+ var_position_map : T .handle ,
727+ var_q : T .handle ,
728+ var_k : T .handle ,
729+ var_v : T .handle ,
730+ ext_factors : T .Buffer ((rotary_dim // 2 ,), "float32" ), # type: ignore
731+ ):
732+ T .func_attr (
733+ {
734+ "op_pattern" : 8 , # 2 means injective, 8 means opaque
735+ "tir.noalias" : True ,
736+ }
737+ )
738+ seq_len = T .int64 ()
739+ position_map_elem_offset = T .int64 ()
740+ qkv = T .match_buffer (var_qkv , (seq_len , fused_heads , head_dim ), dtype )
741+ q = T .match_buffer (var_q , (seq_len , num_q_heads , head_dim ), dtype )
742+ k = T .match_buffer (var_k , (seq_len , num_kv_heads , head_dim ), dtype )
743+ v = T .match_buffer (var_v , (seq_len , num_kv_heads , head_dim ), dtype )
744+ position_map = T .match_buffer (
745+ var_position_map , (seq_len ,), "int32" , elem_offset = position_map_elem_offset
746+ )
747+ for iters in T .grid (seq_len , fused_heads , head_dim ):
748+ with T .block ("llama_fused_rope" ):
749+ s , h , d = T .axis .remap ("SSS" , iters )
750+ if h < num_q_heads :
751+ q [s , h , d ] = T .if_then_else (
752+ d < rotary_dim ,
753+ _rope (
754+ qkv ,
755+ s ,
756+ h ,
757+ d ,
758+ position_map [s ],
759+ ext_factors if is_longrope_scaling else None ,
760+ ),
761+ qkv [s , h , d ],
762+ )
763+ elif h < num_q_heads + num_kv_heads :
764+ k [s , h - num_q_heads , d ] = T .if_then_else (
765+ d < rotary_dim ,
766+ _rope (
767+ qkv ,
768+ s ,
769+ h ,
770+ d ,
771+ position_map [s ],
772+ ext_factors if is_longrope_scaling else None ,
773+ ),
774+ qkv [s , h , d ],
775+ )
776+ else :
777+ v [s , h - (num_q_heads + num_kv_heads ), d ] = qkv [s , h , d ]
778+
779+ if is_longrope_scaling :
780+ return fused_rope_longrope_scaling
781+ return fused_rope
0 commit comments