From d7b687414868b351ab8c36a390918a066ae7c666 Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Mon, 11 Aug 2025 17:52:51 -0400 Subject: [PATCH 1/3] changes to support llama4; Note: cleanup needed custom rope for llama4 custom rope for llama4 and ops cleanup rope dead code remove debugging cases reformat moved ml_dtypes import lint fix lint fix remove imports fix tests black format undo lint for 2 lines --- python/tvm/relax/expr.py | 4 + .../frontend/nn/llm/position_embedding.py | 236 +++++++++++++++++- python/tvm/relax/frontend/nn/op.py | 86 +++++++ tests/python/relax/test_frontend_nn_op.py | 12 +- 4 files changed, 336 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 1a7a5c224add..8dd4eff5c703 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -22,6 +22,7 @@ import numpy as _np # type: ignore import tvm_ffi + import tvm.ir import tvm.relax from tvm import DataType @@ -1153,6 +1154,9 @@ def const( - bool maps to "bool" - other using the same default rule as numpy. """ + # Needed for bf16 and fp8 support (does not come with numpy) + import ml_dtypes # pylint: disable=unused-import,import-outside-toplevel + if isinstance(value, (Number, (bool, list))): value = _np.array(value, dtype=dtype) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 1a1659b29e18..dfe2c83a3f4a 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -75,6 +75,51 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st return cos_freq, sin_freq, {freq_var: freq} +def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals + s: tir.Var, + d: tir.Var, + d_range: int, + theta: float, + dtype: str, + factor: float, + low_freq_factor: float, + high_freq_factor: float, + original_max_position_embeddings: float, +): + """Compute the inverse frequency of RoPE for llama3 RoPE scaling.""" + orig_freq = tir.const(1, "float32") / tir.power( + theta, 2 * (d // 2) / tir.const(d_range, "float32") + ) + orig_freq_var = tir.Var("orig_freq", "float32") + + llama3_inv_scaling_factor = 1.0 / factor + + if high_freq_factor == low_freq_factor: + wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var + threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32") + + scaled_freq = tir.if_then_else( + wavelength > threshold_wavelen, orig_freq_var / factor, orig_freq_var + ) + smoothed_freq = s * scaled_freq + + else: + # Original smooth interpolation logic + inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor) + + llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor + llama3_beta = low_freq_factor * inv_diff_freq_factor + smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta)) + smoothed_freq = s * ( + (1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var + ) + + smoothed_freq_var = tir.Var("smoothed_freq", "float32") + cos_freq = tir.cos(smoothed_freq_var).astype(dtype) + sin_freq = tir.sin(smoothed_freq_var).astype(dtype) + return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq} + + def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals s: tir.Var, d: tir.Var, @@ -208,6 +253,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable: 
high_freq_factor=rope_scaling["high_freq_factor"], original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], ) + if rope_scaling["rope_type"] == "llama4": + return partial( + rope_freq_llama4, + factor=rope_scaling["factor"], + low_freq_factor=rope_scaling["low_freq_factor"], + high_freq_factor=rope_scaling["high_freq_factor"], + original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], + ) if rope_scaling["rope_type"] == "longrope": return partial( rope_freq_longrope, @@ -444,14 +497,195 @@ def _rope( # pylint: disable=too-many-arguments expr = tir.Let(var, value, expr) return expr + @T.prim_func(private=True) + def fused_rope( # pylint: disable=too-many-locals + var_qkv: T.handle, + var_position_map: T.handle, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + apply_rope: T.int64, + ): + T.func_attr( + { + "op_pattern": 8, # 2 means injective, 8 means opaque + "tir.noalias": True, + } + ) + seq_len = T.int32() + position_map_elem_offset = T.int32() + qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) + position_map = T.match_buffer( + var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset + ) + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", iters) + if h < num_q_heads: + q[s, h, d] = T.if_then_else( + apply_rope > 0 and d < rotary_dim, + _rope(qkv, s, h, d, position_map[s]), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + apply_rope > 0 and d < rotary_dim, + _rope(qkv, s, h, d, position_map[s]), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + @T.prim_func + def fused_rope_longrope_scaling( # pylint: disable=too-many-locals + var_qkv: T.handle, + var_position_map: T.handle, + var_q: T.handle, + var_k: T.handle, + var_v: T.handle, + ext_factors: T.Buffer((rotary_dim // 2,), "float32"), # type: ignore + ): + T.func_attr( + { + "op_pattern": 8, # 2 means injective, 8 means opaque + "tir.noalias": True, + } + ) + seq_len = T.int64() + position_map_elem_offset = T.int64() + qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype) + q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype) + k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype) + v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype) + position_map = T.match_buffer( + var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset + ) + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.block("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", iters) + if h < num_q_heads: + q[s, h, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + ext_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + ext_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + + if is_longrope_scaling: + return fused_rope_longrope_scaling + return fused_rope + + +def llama4_rope_with_position_map( # pylint: 
disable=too-many-arguments + theta: float, + scale: float, + head_dim: int, + num_q_heads: int, + num_kv_heads: int, + dtype: str, + rope_scaling: Dict[str, Any], + rotary_dim: Optional[int] = None, +): + """Return the TIR function that computes Llama-style RoPE with q position map. + + Parameters + ---------- + theta : float + The theta value, or "base" in RoPE, which controls the frequency. + + scale : float + The RoPE scaling factor. + + head_dim : int + The number of features on each head. + + num_q_heads : int + The number of query heads. + + num_kv_heads : int + The number of key/value heads. It differs from `num_q_heads` in group-query attention. + + dtype : str + The dtype of qkv data. + + rope_scaling : Dict + The configuration of RoPE scaling. + + rotary_dim : int + The number of dimensions in the embedding that RoPE is applied to. By default, the + rotary_dim is the same as head_dim. + """ + fused_heads = num_q_heads + num_kv_heads * 2 + if rotary_dim is None: + rotary_dim = head_dim + scale = tir.const(scale, "float32") + is_longrope_scaling = rope_scaling.get("rope_type") == "longrope" + + def _rope( # pylint: disable=too-many-arguments + x: T.Buffer, + s: tir.Var, + h: tir.Var, + d: tir.Var, + pos: tir.Var, + ext_factors: Optional[T.Buffer] = None, + ): + kwargs = {} + if ext_factors: + kwargs["ext_factors"] = ext_factors + cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)( + pos * scale, d, rotary_dim, theta, "float32", **kwargs + ) + cos = cos_freq * x[s, h, d].astype("float32") + if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj": + sin = sin_freq * tir.if_then_else( + d % 2 == 0, + -x[s, h, d + 1], + x[s, h, d - 1], + ).astype("float32") + else: + # Data layout is different for llama4 vs llama3 + sin = sin_freq * tir.if_then_else( + d % 2 == 0, + -x[s, h, d + 1], + x[s, h, d - 1], + ).astype("float32") + expr = (cos + sin).astype(dtype) + for var, value in var_map.items(): + expr = tir.Let(var, value, expr) + return expr + + @T.prim_func(private=True) def fused_rope( # pylint: disable=too-many-locals var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, - apply_rope: T.int32, + apply_rope: T.int64, ): T.func_attr( { diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 714ae9478250..50d4772d8ca1 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1174,6 +1174,92 @@ def exp(x: Tensor, name: str = "exp") -> Tensor: return wrap_nested(_op.exp(x._expr), name) +def log(x: Tensor, name: str = "log") -> Tensor: + r"""Applies the natural logarithm function. + + .. math:: + \text{Log}(x) = \log(x) + + Parameters + ---------- + x : Tensor + The input data to the operator. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + Note + ---- + The input tensor is required to have float dtype + """ + return wrap_nested(_op.log(x._expr), name) + + +def floor(x: Tensor, name: str = "floor") -> Tensor: + r"""Computes the floor of the input tensor. + + .. math:: + \text{Floor}(x) = \floor(x) + + Parameters + ---------- + x : Tensor + The input data to the operator. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. 
+ + Note + ---- + The input tensor is required to have float dtype + """ + return wrap_nested(_op.floor(x._expr), name) + + +def arange( + start: int, + end: Optional[int] = None, + step: int = 1, + dtype: Optional[str] = "float32", + name: str = "arange", +) -> Tensor: + r"""Construct a tensor with evenly spaced elements. + + Parameters + ---------- + start : int + The start of the interval. + + end : Optional[int] + The end of the interval. If not given, it will be set to start, + and start will be set to 0. + + step : int + The step size. + + dtype : Optional[str] + The data type of the created tensor. + + name : str + Name hint. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.arange(start, end, step, dtype), name) + + def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor: """Permutes the dimensions of the input tensor. diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index e827f643b33c..28c11f6dfaf5 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -384,6 +384,8 @@ def test( def test_nn(): class Model(Module): def test(self, x: Tensor, weight: Tensor, bias: Tensor): + log_out = op.log(x) + floor_out = op.floor(x) relu_out = op.relu(x) relu6_out = op.relu6(x) silu_out = op.silu(x) @@ -409,6 +411,8 @@ def test( ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)): R.func_attr({"num_input": 4}) with R.dataflow(): + log: R.Tensor((2, 3, 4, 5), dtype="float32") = R.log(x) + floor: R.Tensor((2, 3, 4, 5), dtype="float32") = R.floor(x) relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x) relu6: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu6(x) silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x) @@ -463,6 +467,8 @@ def test(self, x: Tensor): ) zeros_out = op.zeros([10, 10]) zeros_fp16_out = op.zeros([10, 10], dtype="float16") + + arange_out = op.arange(0, 10, 1, "float32") return x # fmt: off @@ -476,6 +482,7 @@ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Ten full2: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32") zeros: R.Tensor((10, 10), dtype="float32") = R.zeros(R.shape([10, 10]), dtype="float32") zeros1: R.Tensor((10, 10), dtype="float16") = R.zeros(R.shape([10, 10]), dtype="float16") + arange: R.Tensor((10,), dtype="float32") = R.arange(T.int64(0), T.int64(10), T.int64(1), dtype="float32") gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io,) R.output(gv1) return gv1 @@ -504,7 +511,10 @@ def test( lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32") lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1, axis=[1]) lv3: R.Tensor((5,), dtype="float32") = R.arange( - R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="float32" + R.prim_value(T.int64(0)), + R.prim_value(T.int64(5)), + R.prim_value(T.int64(1)), + dtype="float32", ) lv4: R.Tensor((5,), dtype="float32") = R.multiply( R.const(-9.2103404998779297, "float32"), lv3 From 34728a1265abd524a02354d24f17024ccda6ee1d Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Mon, 22 Sep 2025 23:08:12 -0400 Subject: [PATCH 2/3] remove private=true --- python/tvm/relax/frontend/nn/llm/position_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py 
index dfe2c83a3f4a..269ced3f4922 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -497,14 +497,14 @@ def _rope( # pylint: disable=too-many-arguments expr = tir.Let(var, value, expr) return expr - @T.prim_func(private=True) + @T.prim_func def fused_rope( # pylint: disable=too-many-locals var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, - apply_rope: T.int64, + apply_rope: T.int32, ): T.func_attr( { From 7459e299d44e5ab2c6a35fa9142dd231d6442b47 Mon Sep 17 00:00:00 2001 From: Pranav Venkatram Date: Tue, 23 Sep 2025 10:20:09 -0400 Subject: [PATCH 3/3] correct rope variable names for llama4 --- .../tvm/relax/frontend/nn/llm/position_embedding.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 269ced3f4922..6fda4b0bca62 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -86,13 +86,13 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals high_freq_factor: float, original_max_position_embeddings: float, ): - """Compute the inverse frequency of RoPE for llama3 RoPE scaling.""" + """Compute the inverse frequency of RoPE for llama4 RoPE scaling.""" orig_freq = tir.const(1, "float32") / tir.power( theta, 2 * (d // 2) / tir.const(d_range, "float32") ) orig_freq_var = tir.Var("orig_freq", "float32") - llama3_inv_scaling_factor = 1.0 / factor + llama4_inv_scaling_factor = 1.0 / factor if high_freq_factor == low_freq_factor: wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var @@ -107,11 +107,11 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals # Original smooth interpolation logic inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor) - llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor - llama3_beta = low_freq_factor * inv_diff_freq_factor - smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta)) + llama4_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor + llama4_beta = low_freq_factor * inv_diff_freq_factor + smooth = tir.max(0.0, tir.min(1.0, llama4_alpha * orig_freq_var - llama4_beta)) smoothed_freq = s * ( - (1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var + (1.0 - smooth) * orig_freq_var * llama4_inv_scaling_factor + smooth * orig_freq_var ) smoothed_freq_var = tir.Var("smoothed_freq", "float32")
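
Note for reviewers (not part of the patch): a plain-NumPy mirror of what the new rope_freq_llama4 helper computes can be easier to eyeball than the TIR above. The sketch below is an illustrative assumption-laden reference only; the function name rope_freq_llama4_ref and the config values in the example call (theta, factor, low/high frequency factors, original_max_position_embeddings) are made up for demonstration and are not taken from any model config. It follows the frequency math added by this series: the llama3-style smooth interpolation in the general case, and the hard wavelength threshold used when high_freq_factor == low_freq_factor.

import numpy as np

def rope_freq_llama4_ref(pos, d, d_range, theta, factor,
                         low_freq_factor, high_freq_factor,
                         original_max_position_embeddings):
    # Base inverse frequency, matching 1 / theta ** (2 * (d // 2) / d_range) in the TIR helper.
    orig_freq = 1.0 / theta ** (2 * (d // 2) / d_range)
    if high_freq_factor == low_freq_factor:
        # Degenerate case: hard threshold on wavelength instead of smooth blending.
        wavelength = 2 * np.pi / orig_freq
        threshold_wavelen = original_max_position_embeddings / low_freq_factor
        scaled = np.where(wavelength > threshold_wavelen, orig_freq / factor, orig_freq)
        freq = pos * scaled
    else:
        # Smooth interpolation between scaled (orig_freq / factor) and unscaled frequencies.
        inv_diff = 1.0 / (high_freq_factor - low_freq_factor)
        alpha = original_max_position_embeddings / (2 * np.pi) * inv_diff
        beta = low_freq_factor * inv_diff
        smooth = np.clip(alpha * orig_freq - beta, 0.0, 1.0)
        freq = pos * ((1.0 - smooth) * orig_freq / factor + smooth * orig_freq)
    return np.cos(freq), np.sin(freq)

# Illustrative call with hypothetical config values (position 5, head_dim 128):
cos_f, sin_f = rope_freq_llama4_ref(
    pos=5, d=np.arange(128), d_range=128, theta=500000.0, factor=8.0,
    low_freq_factor=1.0, high_freq_factor=4.0,
    original_max_position_embeddings=8192,
)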