Commit f971595

[Relax] Operator and RoPE support for Llama4 (#18336)
Added Llama4 implementation and a new RoPE implementation
1 parent 118e3b1 commit f971595
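
The new scaling branch added in position_embedding.py (see the diff below) is selected through the model's rope_scaling configuration. As an illustration only, a rope_scaling entry that would reach it could look like the following; the key names are the ones the diff reads, while the values are invented:

# Hypothetical rope_scaling dict; only the key names are taken from this commit.
rope_scaling = {
    "rope_type": "llama4",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}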

File tree

Showing 4 changed files with 335 additions and 1 deletion.

python/tvm/relax/expr.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,7 @@
 import numpy as _np  # type: ignore

 import tvm_ffi
+
 import tvm.ir
 import tvm.relax
 from tvm import DataType
@@ -1153,6 +1154,9 @@ def const(
        - bool maps to "bool"
        - other using the same default rule as numpy.
     """
+    # Needed for bf16 and fp8 support (does not come with numpy)
+    import ml_dtypes  # pylint: disable=unused-import,import-outside-toplevel
+
     if isinstance(value, (Number, (bool, list))):
         value = _np.array(value, dtype=dtype)
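
Why the lazy import matters: as the added comment says, bf16 and fp8 dtypes do not come with numpy, so the dtype strings that relax.const passes to _np.array only resolve once ml_dtypes has been imported. A minimal sketch of what this enables (assuming the ml_dtypes package is installed):

# Sketch only: build a bfloat16 Relax constant. The "bfloat16" dtype string is
# resolvable because relax.const now imports ml_dtypes before calling _np.array.
from tvm import relax

c = relax.const([1.0, 2.0, 3.0], dtype="bfloat16")
print(c.data.dtype)  # expected: bfloat16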

python/tvm/relax/frontend/nn/llm/position_embedding.py

Lines changed: 234 additions & 0 deletions
@@ -75,6 +75,51 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st
     return cos_freq, sin_freq, {freq_var: freq}


+def rope_freq_llama4(  # pylint: disable=too-many-arguments,too-many-locals
+    s: tir.Var,
+    d: tir.Var,
+    d_range: int,
+    theta: float,
+    dtype: str,
+    factor: float,
+    low_freq_factor: float,
+    high_freq_factor: float,
+    original_max_position_embeddings: float,
+):
+    """Compute the inverse frequency of RoPE for llama4 RoPE scaling."""
+    orig_freq = tir.const(1, "float32") / tir.power(
+        theta, 2 * (d // 2) / tir.const(d_range, "float32")
+    )
+    orig_freq_var = tir.Var("orig_freq", "float32")
+
+    llama4_inv_scaling_factor = 1.0 / factor
+
+    if high_freq_factor == low_freq_factor:
+        wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var
+        threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32")
+
+        scaled_freq = tir.if_then_else(
+            wavelength > threshold_wavelen, orig_freq_var / factor, orig_freq_var
+        )
+        smoothed_freq = s * scaled_freq
+
+    else:
+        # Original smooth interpolation logic
+        inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor)
+
+        llama4_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor
+        llama4_beta = low_freq_factor * inv_diff_freq_factor
+        smooth = tir.max(0.0, tir.min(1.0, llama4_alpha * orig_freq_var - llama4_beta))
+        smoothed_freq = s * (
+            (1.0 - smooth) * orig_freq_var * llama4_inv_scaling_factor + smooth * orig_freq_var
+        )
+
+    smoothed_freq_var = tir.Var("smoothed_freq", "float32")
+    cos_freq = tir.cos(smoothed_freq_var).astype(dtype)
+    sin_freq = tir.sin(smoothed_freq_var).astype(dtype)
+    return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}
+
+
 def rope_freq_llama3(  # pylint: disable=too-many-arguments,too-many-locals
     s: tir.Var,
     d: tir.Var,
@@ -208,6 +253,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
             high_freq_factor=rope_scaling["high_freq_factor"],
             original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
         )
+    if rope_scaling["rope_type"] == "llama4":
+        return partial(
+            rope_freq_llama4,
+            factor=rope_scaling["factor"],
+            low_freq_factor=rope_scaling["low_freq_factor"],
+            high_freq_factor=rope_scaling["high_freq_factor"],
+            original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
+        )
     if rope_scaling["rope_type"] == "longrope":
         return partial(
             rope_freq_longrope,
@@ -545,3 +598,184 @@ def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
     if is_longrope_scaling:
         return fused_rope_longrope_scaling
     return fused_rope
+
+
+def llama4_rope_with_position_map(  # pylint: disable=too-many-arguments
+    theta: float,
+    scale: float,
+    head_dim: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+    dtype: str,
+    rope_scaling: Dict[str, Any],
+    rotary_dim: Optional[int] = None,
+):
+    """Return the TIR function that computes Llama-style RoPE with q position map.
+
+    Parameters
+    ----------
+    theta : float
+        The theta value, or "base" in RoPE, which controls the frequency.
+
+    scale : float
+        The RoPE scaling factor.
+
+    head_dim : int
+        The number of features on each head.
+
+    num_q_heads : int
+        The number of query heads.
+
+    num_kv_heads : int
+        The number of key/value heads. It differs from `num_q_heads` in group-query attention.
+
+    dtype : str
+        The dtype of qkv data.
+
+    rope_scaling : Dict
+        The configuration of RoPE scaling.
+
+    rotary_dim : int
+        The number of dimensions in the embedding that RoPE is applied to. By default, the
+        rotary_dim is the same as head_dim.
+    """
+    fused_heads = num_q_heads + num_kv_heads * 2
+    if rotary_dim is None:
+        rotary_dim = head_dim
+    scale = tir.const(scale, "float32")
+    is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
+
+    def _rope(  # pylint: disable=too-many-arguments
+        x: T.Buffer,
+        s: tir.Var,
+        h: tir.Var,
+        d: tir.Var,
+        pos: tir.Var,
+        ext_factors: Optional[T.Buffer] = None,
+    ):
+        kwargs = {}
+        if ext_factors:
+            kwargs["ext_factors"] = ext_factors
+        cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)(
+            pos * scale, d, rotary_dim, theta, "float32", **kwargs
+        )
+        cos = cos_freq * x[s, h, d].astype("float32")
+        if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj":
+            sin = sin_freq * tir.if_then_else(
+                d % 2 == 0,
+                -x[s, h, d + 1],
+                x[s, h, d - 1],
+            ).astype("float32")
+        else:
+            # Data layout is different for llama4 vs llama3
+            sin = sin_freq * tir.if_then_else(
+                d % 2 == 0,
+                -x[s, h, d + 1],
+                x[s, h, d - 1],
+            ).astype("float32")
+        expr = (cos + sin).astype(dtype)
+        for var, value in var_map.items():
+            expr = tir.Let(var, value, expr)
+        return expr
+
+    @T.prim_func(private=True)
+    def fused_rope(  # pylint: disable=too-many-locals
+        var_qkv: T.handle,
+        var_position_map: T.handle,
+        var_q: T.handle,
+        var_k: T.handle,
+        var_v: T.handle,
+        apply_rope: T.int64,
+    ):
+        T.func_attr(
+            {
+                "op_pattern": 8,  # 2 means injective, 8 means opaque
+                "tir.noalias": True,
+            }
+        )
+        seq_len = T.int32()
+        position_map_elem_offset = T.int32()
+        qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
+        q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
+        k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
+        v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
+        position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
+        )
+        for iters in T.grid(seq_len, fused_heads, head_dim):
+            with T.block("llama_fused_rope"):
+                s, h, d = T.axis.remap("SSS", iters)
+                if h < num_q_heads:
+                    q[s, h, d] = T.if_then_else(
+                        apply_rope > 0 and d < rotary_dim,
+                        _rope(qkv, s, h, d, position_map[s]),
+                        qkv[s, h, d],
+                    )
+                elif h < num_q_heads + num_kv_heads:
+                    k[s, h - num_q_heads, d] = T.if_then_else(
+                        apply_rope > 0 and d < rotary_dim,
+                        _rope(qkv, s, h, d, position_map[s]),
+                        qkv[s, h, d],
+                    )
+                else:
+                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+
+    @T.prim_func
+    def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
+        var_qkv: T.handle,
+        var_position_map: T.handle,
+        var_q: T.handle,
+        var_k: T.handle,
+        var_v: T.handle,
+        ext_factors: T.Buffer((rotary_dim // 2,), "float32"),  # type: ignore
+    ):
+        T.func_attr(
+            {
+                "op_pattern": 8,  # 2 means injective, 8 means opaque
+                "tir.noalias": True,
+            }
+        )
+        seq_len = T.int64()
+        position_map_elem_offset = T.int64()
+        qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
+        q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
+        k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
+        v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
+        position_map = T.match_buffer(
+            var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
+        )
+        for iters in T.grid(seq_len, fused_heads, head_dim):
+            with T.block("llama_fused_rope"):
+                s, h, d = T.axis.remap("SSS", iters)
+                if h < num_q_heads:
+                    q[s, h, d] = T.if_then_else(
+                        d < rotary_dim,
+                        _rope(
+                            qkv,
+                            s,
+                            h,
+                            d,
+                            position_map[s],
+                            ext_factors if is_longrope_scaling else None,
+                        ),
+                        qkv[s, h, d],
+                    )
+                elif h < num_q_heads + num_kv_heads:
+                    k[s, h - num_q_heads, d] = T.if_then_else(
+                        d < rotary_dim,
+                        _rope(
+                            qkv,
+                            s,
+                            h,
+                            d,
+                            position_map[s],
+                            ext_factors if is_longrope_scaling else None,
+                        ),
+                        qkv[s, h, d],
+                    )
+                else:
+                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+
+    if is_longrope_scaling:
+        return fused_rope_longrope_scaling
+    return fused_rope
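
For reference, the arithmetic that rope_freq_llama4 expresses in TIR can be sketched in plain numpy. This is an illustrative rewrite, not part of the commit, and the example arguments at the end are invented:

import math

import numpy as np


def llama4_scaled_inv_freq(head_dim, theta, factor, low_freq_factor,
                           high_freq_factor, original_max_position_embeddings):
    """Per-dimension inverse frequencies after Llama4-style RoPE scaling."""
    # Base RoPE inverse frequencies: 1 / theta ** (2 * (d // 2) / head_dim).
    d = np.arange(head_dim)
    orig_freq = 1.0 / theta ** (2 * (d // 2) / head_dim)

    if high_freq_factor == low_freq_factor:
        # Step scaling: frequencies whose wavelength exceeds the threshold are
        # divided by `factor`; the others are left untouched.
        wavelength = 2 * math.pi / orig_freq
        threshold_wavelen = original_max_position_embeddings / low_freq_factor
        return np.where(wavelength > threshold_wavelen, orig_freq / factor, orig_freq)

    # Smooth interpolation between the scaled and unscaled frequencies.
    inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor)
    alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor
    beta = low_freq_factor * inv_diff_freq_factor
    smooth = np.clip(alpha * orig_freq - beta, 0.0, 1.0)
    return (1.0 - smooth) * orig_freq / factor + smooth * orig_freq


# Invented example arguments. The TIR version additionally multiplies the
# frequency by the (scaled) position before taking cos/sin.
inv_freq = llama4_scaled_inv_freq(128, 500000.0, 8.0, 1.0, 4.0, 8192)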

python/tvm/relax/frontend/nn/op.py

Lines changed: 86 additions & 0 deletions
@@ -1174,6 +1174,92 @@ def exp(x: Tensor, name: str = "exp") -> Tensor:
     return wrap_nested(_op.exp(x._expr), name)


+def log(x: Tensor, name: str = "log") -> Tensor:
+    r"""Applies the natural logarithm function.
+
+    .. math::
+        \text{Log}(x) = \log(x)
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data to the operator.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    Note
+    ----
+    The input tensor is required to have float dtype
+    """
+    return wrap_nested(_op.log(x._expr), name)
+
+
+def floor(x: Tensor, name: str = "floor") -> Tensor:
+    r"""Computes the floor of the input tensor.
+
+    .. math::
+        \text{Floor}(x) = \floor(x)
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data to the operator.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+
+    Note
+    ----
+    The input tensor is required to have float dtype
+    """
+    return wrap_nested(_op.floor(x._expr), name)
+
+
+def arange(
+    start: int,
+    end: Optional[int] = None,
+    step: int = 1,
+    dtype: Optional[str] = "float32",
+    name: str = "arange",
+) -> Tensor:
+    r"""Construct a tensor with evenly spaced elements.
+
+    Parameters
+    ----------
+    start : int
+        The start of the interval.
+
+    end : Optional[int]
+        The end of the interval. If not given, it will be set to start,
+        and start will be set to 0.
+
+    step : int
+        The step size.
+
+    dtype : Optional[str]
+        The data type of the created tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.arange(start, end, step, dtype), name)
+
+
 def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor:
     """Permutes the dimensions of the input tensor.
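
The three new frontend ops are thin wrappers over the corresponding Relax operators, so they compose with the existing nn.op helpers. A hypothetical end-to-end sketch (the Demo module and its spec are made up; the export call follows the usual nn.Module.export_tvm pattern):

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op


class Demo(nn.Module):
    def forward(self, x: nn.Tensor):
        # arange builds a fresh tensor; log and floor wrap the Relax unary ops.
        idx = op.arange(0, 8, 1, dtype="float32")
        return op.add(op.floor(op.log(x)), idx)


mod, _ = Demo().export_tvm(spec={"forward": {"x": nn.spec.Tensor((8,), "float32")}})
mod.show()  # Relax IRModule containing the log, floor, and arange calls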
