4 changes: 4 additions & 0 deletions python/tvm/relax/expr.py
@@ -22,6 +22,7 @@
import numpy as _np # type: ignore

import tvm_ffi

import tvm.ir
import tvm.relax
from tvm import DataType
@@ -1153,6 +1154,9 @@ def const(
- bool maps to "bool"
- other using the same default rule as numpy.
"""
# Needed for bf16 and fp8 support (these dtypes are not provided by numpy itself)
import ml_dtypes # pylint: disable=unused-import,import-outside-toplevel

if isinstance(value, (Number, (bool, list))):
value = _np.array(value, dtype=dtype)
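
# Illustrative sketch (placeholder values): once ml_dtypes is imported, numpy can
# resolve the "bfloat16" dtype string, which is what the _np.array call above
# relies on for bf16/fp8 constants, e.g.:
#
#     relax.const([1.0, 2.0], dtype="bfloat16")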

234 changes: 234 additions & 0 deletions python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -75,6 +75,51 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st
return cos_freq, sin_freq, {freq_var: freq}


def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals
s: tir.Var,
d: tir.Var,
d_range: int,
theta: float,
dtype: str,
factor: float,
low_freq_factor: float,
high_freq_factor: float,
original_max_position_embeddings: float,
):
"""Compute the inverse frequency of RoPE for llama4 RoPE scaling."""
orig_freq = tir.const(1, "float32") / tir.power(
theta, 2 * (d // 2) / tir.const(d_range, "float32")
)
orig_freq_var = tir.Var("orig_freq", "float32")

llama4_inv_scaling_factor = 1.0 / factor

if high_freq_factor == low_freq_factor:
wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var
threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32")

scaled_freq = tir.if_then_else(
wavelength > threshold_wavelen, orig_freq_var / factor, orig_freq_var
)
smoothed_freq = s * scaled_freq

else:
# Original smooth interpolation logic
inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor)

llama4_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor
llama4_beta = low_freq_factor * inv_diff_freq_factor
smooth = tir.max(0.0, tir.min(1.0, llama4_alpha * orig_freq_var - llama4_beta))
smoothed_freq = s * (
(1.0 - smooth) * orig_freq_var * llama4_inv_scaling_factor + smooth * orig_freq_var
)

smoothed_freq_var = tir.Var("smoothed_freq", "float32")
cos_freq = tir.cos(smoothed_freq_var).astype(dtype)
sin_freq = tir.sin(smoothed_freq_var).astype(dtype)
return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}
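
# A NumPy reference of the same smoothing (an illustrative helper, not used by the
# kernels): it mirrors the TIR expressions above and yields one inverse frequency
# per rotary pair, which can be handy for inspecting the scaled frequencies.
def _llama4_inv_freq_reference(  # pylint: disable=too-many-arguments
    d_range: int,
    theta: float,
    factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    original_max_position_embeddings: float,
):
    import numpy as np  # pylint: disable=import-outside-toplevel

    # One entry per pair of rotary dims; the TIR code evaluates the same value
    # for both elements of a pair via 2 * (d // 2).
    inv_freq = 1.0 / theta ** (np.arange(0, d_range, 2, dtype="float32") / d_range)
    if high_freq_factor == low_freq_factor:
        wavelength = 2 * np.pi / inv_freq
        threshold = original_max_position_embeddings / low_freq_factor
        return np.where(wavelength > threshold, inv_freq / factor, inv_freq)
    inv_diff = 1.0 / (high_freq_factor - low_freq_factor)
    alpha = original_max_position_embeddings / (2 * np.pi) * inv_diff
    beta = low_freq_factor * inv_diff
    smooth = np.clip(alpha * inv_freq - beta, 0.0, 1.0)
    return (1.0 - smooth) * inv_freq / factor + smooth * inv_freq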


def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals
s: tir.Var,
d: tir.Var,
@@ -208,6 +253,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
high_freq_factor=rope_scaling["high_freq_factor"],
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
if rope_scaling["rope_type"] == "llama4":
return partial(
rope_freq_llama4,
factor=rope_scaling["factor"],
low_freq_factor=rope_scaling["low_freq_factor"],
high_freq_factor=rope_scaling["high_freq_factor"],
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
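# For reference, a rope_scaling config of roughly this shape (placeholder values)
# takes the llama4 branch above; the keys mirror the lookups in this function:
#   {"rope_type": "llama4", "factor": 8.0, "low_freq_factor": 1.0,
#    "high_freq_factor": 4.0, "original_max_position_embeddings": 8192}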
if rope_scaling["rope_type"] == "longrope":
return partial(
rope_freq_longrope,
@@ -545,3 +598,184 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
if is_longrope_scaling:
return fused_rope_longrope_scaling
return fused_rope


def llama4_rope_with_position_map( # pylint: disable=too-many-arguments
theta: float,
scale: float,
head_dim: int,
num_q_heads: int,
num_kv_heads: int,
dtype: str,
rope_scaling: Dict[str, Any],
rotary_dim: Optional[int] = None,
):
"""Return the TIR function that computes Llama-style RoPE with q position map.

Parameters
----------
theta : float
The theta value, or "base" in RoPE, which controls the frequency.

scale : float
The RoPE scaling factor.

head_dim : int
The number of features on each head.

num_q_heads : int
The number of query heads.

num_kv_heads : int
The number of key/value heads. It differs from `num_q_heads` in group-query attention.

dtype : str
The dtype of qkv data.

rope_scaling : Dict
The configuration of RoPE scaling.

rotary_dim : int
The number of dimensions in the embedding that RoPE is applied to. By default, the
rotary_dim is the same as head_dim.
"""
fused_heads = num_q_heads + num_kv_heads * 2
if rotary_dim is None:
rotary_dim = head_dim
scale = tir.const(scale, "float32")
is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"

def _rope( # pylint: disable=too-many-arguments
x: T.Buffer,
s: tir.Var,
h: tir.Var,
d: tir.Var,
pos: tir.Var,
ext_factors: Optional[T.Buffer] = None,
):
kwargs = {}
if ext_factors:
kwargs["ext_factors"] = ext_factors
cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)(
pos * scale, d, rotary_dim, theta, "float32", **kwargs
)
cos = cos_freq * x[s, h, d].astype("float32")
if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj":
sin = sin_freq * tir.if_then_else(
d % 2 == 0,
-x[s, h, d + 1],
x[s, h, d - 1],
).astype("float32")
else:
# Llama4 keeps the interleaved (GPT-J-style) even/odd pairing of rotary dims,
# unlike llama3's half-rotation split, so the indexing matches the gptj branch.
sin = sin_freq * tir.if_then_else(
d % 2 == 0,
-x[s, h, d + 1],
x[s, h, d - 1],
).astype("float32")
expr = (cos + sin).astype(dtype)
for var, value in var_map.items():
expr = tir.Let(var, value, expr)
return expr

@T.prim_func(private=True)
def fused_rope( # pylint: disable=too-many-locals
var_qkv: T.handle,
var_position_map: T.handle,
var_q: T.handle,
var_k: T.handle,
var_v: T.handle,
apply_rope: T.int64,
):
T.func_attr(
{
"op_pattern": 8, # 2 means injective, 8 means opaque
"tir.noalias": True,
}
)
seq_len = T.int32()
position_map_elem_offset = T.int32()
qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
position_map = T.match_buffer(
var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
)
for iters in T.grid(seq_len, fused_heads, head_dim):
with T.block("llama_fused_rope"):
s, h, d = T.axis.remap("SSS", iters)
if h < num_q_heads:
q[s, h, d] = T.if_then_else(
apply_rope > 0 and d < rotary_dim,
_rope(qkv, s, h, d, position_map[s]),
qkv[s, h, d],
)
elif h < num_q_heads + num_kv_heads:
k[s, h - num_q_heads, d] = T.if_then_else(
apply_rope > 0 and d < rotary_dim,
_rope(qkv, s, h, d, position_map[s]),
qkv[s, h, d],
)
else:
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

@T.prim_func
def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
var_qkv: T.handle,
var_position_map: T.handle,
var_q: T.handle,
var_k: T.handle,
var_v: T.handle,
ext_factors: T.Buffer((rotary_dim // 2,), "float32"), # type: ignore
):
T.func_attr(
{
"op_pattern": 8, # 2 means injective, 8 means opaque
"tir.noalias": True,
}
)
seq_len = T.int64()
position_map_elem_offset = T.int64()
qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
position_map = T.match_buffer(
var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
)
for iters in T.grid(seq_len, fused_heads, head_dim):
with T.block("llama_fused_rope"):
s, h, d = T.axis.remap("SSS", iters)
if h < num_q_heads:
q[s, h, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
ext_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
elif h < num_q_heads + num_kv_heads:
k[s, h - num_q_heads, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
ext_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
else:
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

if is_longrope_scaling:
return fused_rope_longrope_scaling
return fused_rope
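
# A minimal construction sketch for the builder above; the hyperparameters are
# placeholders, not tied to a specific model. The returned object is a TIR
# PrimFunc that splits fused qkv into q/k/v and applies RoPE to the q/k heads
# at the per-token positions supplied through `position_map`.
def _example_build_llama4_rope():
    return llama4_rope_with_position_map(
        theta=500000.0,
        scale=1.0,
        head_dim=128,
        num_q_heads=40,
        num_kv_heads=8,
        dtype="float16",
        rope_scaling={
            "rope_type": "llama4",
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192,
        },
    )
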
86 changes: 86 additions & 0 deletions python/tvm/relax/frontend/nn/op.py
@@ -1174,6 +1174,92 @@ def exp(x: Tensor, name: str = "exp") -> Tensor:
return wrap_nested(_op.exp(x._expr), name)


def log(x: Tensor, name: str = "log") -> Tensor:
r"""Applies the natural logarithm function.

.. math::
\text{Log}(x) = \log(x)

Parameters
----------
x : Tensor
The input data to the operator.

name : str
Name hint.

Returns
-------
result : Tensor
The computed result.

Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.log(x._expr), name)


def floor(x: Tensor, name: str = "floor") -> Tensor:
r"""Computes the floor of the input tensor.

.. math::
\text{Floor}(x) = \lfloor x \rfloor

Parameters
----------
x : Tensor
The input data to the operator.

name : str
Name hint.

Returns
-------
result : Tensor
The computed result.

Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.floor(x._expr), name)


def arange(
start: int,
end: Optional[int] = None,
step: int = 1,
dtype: Optional[str] = "float32",
name: str = "arange",
) -> Tensor:
r"""Construct a tensor with evenly spaced elements.

Parameters
----------
start : int
The start of the interval.

end : Optional[int]
The end of the interval. If not given, `start` is taken as the end and the
interval starts at 0.

step : int
The step size.

dtype : Optional[str]
The data type of the created tensor.

name : str
Name hint.

Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.arange(start, end, step, dtype), name)
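
# Usage sketch for the log/floor/arange ops added above, driven through a tiny
# nn.Module. The module name, input shape, and export spec are illustrative
# assumptions, not part of this change.
def _example_log_floor_arange():
    from tvm.relax.frontend import nn  # pylint: disable=import-outside-toplevel

    class _Demo(nn.Module):
        def forward(self, x: nn.Tensor):
            # floor(log(x)) element-wise, plus an evenly spaced helper tensor.
            aux = nn.op.arange(0, 8, dtype="float32")
            return nn.op.floor(nn.op.log(x)), aux

    # Export to an IRModule with a single float32 input of length 8.
    mod, _ = _Demo().export_tvm(spec={"forward": {"x": nn.spec.Tensor([8], "float32")}})
    return mod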


def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor:
"""Permutes the dimensions of the input tensor.
