4 changes: 4 additions & 0 deletions python/tvm/relax/expr.py
@@ -22,6 +22,7 @@
import numpy as _np # type: ignore

import tvm_ffi

import tvm.ir
import tvm.relax
from tvm import DataType
@@ -1153,6 +1154,9 @@ def const(
- bool maps to "bool"
- other using the same default rule as numpy.
"""
# Needed for bf16 and fp8 support (does not come with numpy)
import ml_dtypes # pylint: disable=unused-import,import-outside-toplevel

if isinstance(value, (Number, (bool, list))):
value = _np.array(value, dtype=dtype)

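For context, a minimal sketch of what the import enables (assumes ml_dtypes is installed; not part of the diff): importing ml_dtypes makes the extended floating-point dtypes available to numpy, which appears to be what the const path above relies on for bf16/fp8 constants.

import numpy as np
import ml_dtypes  # provides bfloat16 and float8 scalar types usable as numpy dtypes

x = np.array([1.0, 2.0], dtype=ml_dtypes.bfloat16)    # bf16 ndarray
y = np.array([0.5], dtype=ml_dtypes.float8_e4m3fn)    # one of the fp8 variants
print(x.dtype, y.dtype)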
234 changes: 234 additions & 0 deletions python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -75,6 +75,51 @@ def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: st
return cos_freq, sin_freq, {freq_var: freq}


def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals
s: tir.Var,
d: tir.Var,
d_range: int,
theta: float,
dtype: str,
factor: float,
low_freq_factor: float,
high_freq_factor: float,
original_max_position_embeddings: float,
):
"""Compute the inverse frequency of RoPE for llama4 RoPE scaling."""
orig_freq = tir.const(1, "float32") / tir.power(
theta, 2 * (d // 2) / tir.const(d_range, "float32")
)
orig_freq_var = tir.Var("orig_freq", "float32")

llama3_inv_scaling_factor = 1.0 / factor

if high_freq_factor == low_freq_factor:
wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var
threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32")

scaled_freq = tir.if_then_else(
wavelength > threshold_wavelen, orig_freq_var / factor, orig_freq_var
)
smoothed_freq = s * scaled_freq

else:
# Original smooth interpolation logic
inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor)

llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor
llama3_beta = low_freq_factor * inv_diff_freq_factor
smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta))
smoothed_freq = s * (
(1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var
)

smoothed_freq_var = tir.Var("smoothed_freq", "float32")
cos_freq = tir.cos(smoothed_freq_var).astype(dtype)
sin_freq = tir.sin(smoothed_freq_var).astype(dtype)
return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}
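For readers who want to sanity-check the scaling outside TIR, here is an eager NumPy restatement of the expressions above (a sketch only; the helper name and the example numbers are illustrative, not from the diff):

import math
import numpy as np

def llama4_scaled_freq(d, d_range, theta, factor, low_freq_factor,
                       high_freq_factor, original_max_position_embeddings):
    """Eager mirror of the TIR frequency computation above (illustrative sketch)."""
    orig_freq = 1.0 / np.power(theta, 2 * (d // 2) / d_range)
    if high_freq_factor == low_freq_factor:
        # Uniformly scale only the long-wavelength (low-frequency) dimensions.
        wavelength = 2 * math.pi / orig_freq
        threshold_wavelen = original_max_position_embeddings / low_freq_factor
        return np.where(wavelength > threshold_wavelen, orig_freq / factor, orig_freq)
    # llama3-style smooth interpolation between scaled and unscaled frequencies.
    inv_diff = 1.0 / (high_freq_factor - low_freq_factor)
    alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff
    beta = low_freq_factor * inv_diff
    smooth = np.clip(alpha * orig_freq - beta, 0.0, 1.0)
    return (1.0 - smooth) * orig_freq / factor + smooth * orig_freq

# The rotation angle for position s and lane d is then s * freq, whose cos/sin feed
# the RoPE rotation. Example with illustrative numbers:
freqs = llama4_scaled_freq(np.arange(128), 128, 500000.0, 8.0, 1.0, 4.0, 8192)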


def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals
s: tir.Var,
d: tir.Var,
@@ -208,6 +253,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
high_freq_factor=rope_scaling["high_freq_factor"],
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
if rope_scaling["rope_type"] == "llama4":
return partial(
rope_freq_llama4,
factor=rope_scaling["factor"],
low_freq_factor=rope_scaling["low_freq_factor"],
high_freq_factor=rope_scaling["high_freq_factor"],
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
if rope_scaling["rope_type"] == "longrope":
return partial(
rope_freq_longrope,
@@ -545,3 +598,184 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
if is_longrope_scaling:
return fused_rope_longrope_scaling
return fused_rope


def llama4_rope_with_position_map( # pylint: disable=too-many-arguments
theta: float,
scale: float,
head_dim: int,
num_q_heads: int,
num_kv_heads: int,
dtype: str,
rope_scaling: Dict[str, Any],
rotary_dim: Optional[int] = None,
):
"""Return the TIR function that computes Llama-style RoPE with q position map.

Parameters
----------
theta : float
The theta value, or "base" in RoPE, which controls the frequency.

scale : float
The RoPE scaling factor.

head_dim : int
The number of features on each head.

num_q_heads : int
The number of query heads.

num_kv_heads : int
The number of key/value heads. It differs from `num_q_heads` in group-query attention.

dtype : str
The dtype of qkv data.

rope_scaling : Dict
The configuration of RoPE scaling.

rotary_dim : int
The number of dimensions in the embedding that RoPE is applied to. By default, the
rotary_dim is the same as head_dim.
"""
fused_heads = num_q_heads + num_kv_heads * 2
if rotary_dim is None:
rotary_dim = head_dim
scale = tir.const(scale, "float32")
is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"

def _rope( # pylint: disable=too-many-arguments
x: T.Buffer,
s: tir.Var,
h: tir.Var,
d: tir.Var,
pos: tir.Var,
ext_factors: Optional[T.Buffer] = None,
):
kwargs = {}
if ext_factors:
kwargs["ext_factors"] = ext_factors
cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)(
pos * scale, d, rotary_dim, theta, "float32", **kwargs
)
cos = cos_freq * x[s, h, d].astype("float32")
# llama4 rotates interleaved (even, odd) lane pairs, i.e. the GPT-J style pairing,
# for every rope_type (unlike llama3's half-rotated layout), so a single rotation
# expression covers both the gptj and default cases.
sin = sin_freq * tir.if_then_else(
d % 2 == 0,
-x[s, h, d + 1],
x[s, h, d - 1],
).astype("float32")
expr = (cos + sin).astype(dtype)
for var, value in var_map.items():
expr = tir.Let(var, value, expr)
return expr

@T.prim_func(private=True)
def fused_rope( # pylint: disable=too-many-locals
var_qkv: T.handle,
var_position_map: T.handle,
var_q: T.handle,
var_k: T.handle,
var_v: T.handle,
apply_rope: T.int64,
):
T.func_attr(
{
"op_pattern": 8, # 2 means injective, 8 means opaque
"tir.noalias": True,
}
)
seq_len = T.int32()
position_map_elem_offset = T.int32()
qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
position_map = T.match_buffer(
var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
)
for iters in T.grid(seq_len, fused_heads, head_dim):
with T.block("llama_fused_rope"):
s, h, d = T.axis.remap("SSS", iters)
if h < num_q_heads:
q[s, h, d] = T.if_then_else(
apply_rope > 0 and d < rotary_dim,
_rope(qkv, s, h, d, position_map[s]),
qkv[s, h, d],
)
elif h < num_q_heads + num_kv_heads:
k[s, h - num_q_heads, d] = T.if_then_else(
apply_rope > 0 and d < rotary_dim,
_rope(qkv, s, h, d, position_map[s]),
qkv[s, h, d],
)
else:
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

@T.prim_func
def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
var_qkv: T.handle,
var_position_map: T.handle,
var_q: T.handle,
var_k: T.handle,
var_v: T.handle,
ext_factors: T.Buffer((rotary_dim // 2,), "float32"), # type: ignore
):
T.func_attr(
{
"op_pattern": 8, # 2 means injective, 8 means opaque
"tir.noalias": True,
}
)
seq_len = T.int64()
position_map_elem_offset = T.int64()
qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
position_map = T.match_buffer(
var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
)
for iters in T.grid(seq_len, fused_heads, head_dim):
with T.block("llama_fused_rope"):
s, h, d = T.axis.remap("SSS", iters)
if h < num_q_heads:
q[s, h, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
ext_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
elif h < num_q_heads + num_kv_heads:
k[s, h - num_q_heads, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
ext_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
else:
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

if is_longrope_scaling:
return fused_rope_longrope_scaling
return fused_rope
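A rough usage sketch of the new builder; the config values, head counts, and dtype below are illustrative assumptions, not taken from the diff:

from tvm.relax.frontend.nn.llm.position_embedding import llama4_rope_with_position_map

# Illustrative rope_scaling config; real values come from the model's config.
rope_scaling = {
    "rope_type": "llama4",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
rope_func = llama4_rope_with_position_map(
    theta=500000.0,
    scale=1.0,
    head_dim=128,
    num_q_heads=32,
    num_kv_heads=8,
    dtype="float16",
    rope_scaling=rope_scaling,
)
# The returned prim_func reads a fused qkv buffer of shape
# (seq_len, num_q_heads + 2 * num_kv_heads, head_dim) plus an int32 position map,
# rotates the q/k slices (when apply_rope > 0 and d < rotary_dim), and copies v through.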
86 changes: 86 additions & 0 deletions python/tvm/relax/frontend/nn/op.py
@@ -1174,6 +1174,92 @@ def exp(x: Tensor, name: str = "exp") -> Tensor:
return wrap_nested(_op.exp(x._expr), name)


def log(x: Tensor, name: str = "log") -> Tensor:
r"""Applies the natural logarithm function.

.. math::
\text{Log}(x) = \log(x)

Parameters
----------
x : Tensor
The input data to the operator.

name : str
Name hint.

Returns
-------
result : Tensor
The computed result.

Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.log(x._expr), name)


def floor(x: Tensor, name: str = "floor") -> Tensor:
r"""Computes the floor of the input tensor.

.. math::
\text{Floor}(x) = \lfloor x \rfloor

Parameters
----------
x : Tensor
The input data to the operator.

name : str
Name hint.

Returns
-------
result : Tensor
The computed result.

Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.floor(x._expr), name)


def arange(
start: int,
end: Optional[int] = None,
step: int = 1,
dtype: Optional[str] = "float32",
name: str = "arange",
) -> Tensor:
r"""Construct a tensor with evenly spaced elements.

Parameters
----------
start : int
The start of the interval.

end : Optional[int]
The end of the interval. If not given, it will be set to start,
and start will be set to 0.

step : int
The step size.

dtype : Optional[str]
The data type of the created tensor.

name : str
Name hint.

Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.arange(start, end, step, dtype), name)
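A short usage sketch for the three new frontend ops; the module, shapes, and the use of op.add (assumed to be among the existing elementwise helpers) are illustrative, not from the diff:

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op

class Demo(nn.Module):
    def forward(self, x: nn.Tensor):
        idx = op.arange(0, 8, 1, dtype="float32")   # 1-D tensor [0, 1, ..., 7]
        y = op.floor(op.log(x))                     # elementwise natural log, then floor
        return op.add(y, idx)                       # assumes x broadcasts against shape (8,)

# Exporting the module (e.g. via nn.Module's export_tvm with an nn.spec) would lower
# these calls to the corresponding Relax log/floor/arange ops.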


def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor:
"""Permutes the dimensions of the input tensor.
