10% faster RoPE embedding from HuyNguyen-hust (#238)
HuyNguyen-hust authored Mar 15, 2024
1 parent 9ac4ed6 commit 809bdbe
Showing 1 changed file with 29 additions and 19 deletions.
unsloth/kernels/rope_embedding.py (48 changes: 29 additions & 19 deletions)
@@ -24,7 +24,7 @@ def _rope_embedding(
     Q, Q_row_stride,
     cos, cos_row_stride,
     sin, sin_row_stride,
-    seqlen, head_dim,
+    seqlen, head_dim, group_size, n_heads,
     BACKWARD_PASS: tl.constexpr,
     BLOCK_SIZE : tl.constexpr,
 ):
@@ -34,7 +34,7 @@ def _rope_embedding(
         See our blog post for more info
     """
     row_position = tl.program_id(0)
-    head_position = tl.program_id(1)
+    group_head_position = tl.program_id(1)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     half_head_dim = head_dim // 2
     mask = col_offsets < half_head_dim
@@ -44,23 +44,25 @@ def _rope_embedding(
     cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
                    half_head_dim*0 + col_offsets, mask = mask, other = 0)
 
-    # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
-    Q1 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
-                 half_head_dim*0 + col_offsets, mask = mask, other = 0).to(sin1.dtype)
-    Q2 = tl.load(Q + row_position*Q_row_stride + head_position*head_dim + \
-                 half_head_dim*1 + col_offsets, mask = mask, other = 0).to(sin1.dtype)
-
     if BACKWARD_PASS:
         # See our blog post for more info.
         sin1 = -sin1
     pass
 
-    tl.store(Q + row_position*Q_row_stride + head_position*head_dim + \
-             half_head_dim*0 + col_offsets,
-             Q1*cos1 - Q2*sin1, mask = mask)
-    tl.store(Q + row_position*Q_row_stride + head_position*head_dim + \
-             half_head_dim*1 + col_offsets,
-             Q2*cos1 + Q1*sin1, mask = mask)
+    head_start = group_head_position * group_size
+    head_end = tl.math.min((head_start + group_size), n_heads)
+
+    for i in range(head_start, head_end):
+        offs_q1 = row_position * Q_row_stride + i * head_dim + col_offsets
+        offs_q2 = row_position * Q_row_stride + i * head_dim + col_offsets + half_head_dim
+
+        # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
+        Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
+        Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
+
+        tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
+        tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
+    pass
 pass
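Note (not part of the commit): before this change each program handled one (row, head) pair; now each program handles one (row, group-of-heads) pair and loops over the group, so the cos1/sin1 loads, which depend only on the sequence position, are reused for every head in the group. A minimal NumPy sketch of what one program instance now does; the shapes Q = (n_rows, n_heads, head_dim) with n_rows = batch * seq_len, cos/sin = (seq_len, head_dim // 2), and the function name are illustrative assumptions.

import numpy as np

def rope_rotate_group(Q, cos, sin, row, group, group_size, backward=False):
    # One program's work: rotate the heads [group*group_size, head_end) of one row in place.
    n_rows, n_heads, head_dim = Q.shape
    half = head_dim // 2
    seqlen = cos.shape[0]

    # cos/sin depend only on the sequence position, so one load serves
    # every head handled by this program.
    cos1 = cos[row % seqlen, :half]
    sin1 = sin[row % seqlen, :half]
    if backward:
        sin1 = -sin1  # same sign flip as BACKWARD_PASS in the kernel

    head_start = group * group_size
    head_end = min(head_start + group_size, n_heads)
    for i in range(head_start, head_end):
        Q1 = Q[row, i, :half].copy()
        Q2 = Q[row, i, half:].copy()
        Q[row, i, :half] = Q1 * cos1 - Q2 * sin1
        Q[row, i, half:] = Q2 * cos1 + Q1 * sin1
    return Q

Looping this over every row and every group (the launch grid in the next hunk) touches each head exactly once.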


@@ -75,12 +77,16 @@ def forward(ctx, Q, cos, sin):
 
         # [TODO] Changing blocksize to head_dim//2 seems to have
         # some concurrency / un-deterministic issues.
-        BLOCK_SIZE, num_warps = calculate_settings(head_dim) # (head_dim//2)
-        _rope_embedding[(n_rows, n_heads,)](
+        BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
+        group_size = 4 # 4 or 8, too large group_size can hurt performance.
+        n_groups = triton.cdiv(n_heads, group_size)
+
+        grid = (n_rows, n_groups, )
+        _rope_embedding[grid](
             Q,   Q.stride(0),
             cos, cos.stride(0),
             sin, sin.stride(0),
-            seq_len, head_dim,
+            seq_len, head_dim, group_size, n_heads,
             BACKWARD_PASS = False,
             BLOCK_SIZE = BLOCK_SIZE,
             num_warps = num_warps,
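The launch grid shrinks accordingly: instead of n_rows * n_heads programs there are now n_rows * ceil(n_heads / group_size). A quick sketch of the arithmetic; the 32-head, 4096-row figures are illustrative assumptions, not values from the commit.

import math

n_rows, n_heads, group_size = 4096, 32, 4   # example numbers only; n_rows = batch * seq_len
n_groups = math.ceil(n_heads / group_size)  # what triton.cdiv(n_heads, group_size) computes

old_grid = (n_rows, n_heads)    # one program per (row, head)
new_grid = (n_rows, n_groups)   # one program per (row, group of 4 heads)
print(old_grid, new_grid)       # (4096, 32) (4096, 8): 4x fewer programs launched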
@@ -102,11 +108,15 @@ def backward(ctx, dY):
         cos = ctx.cos
         sin = ctx.sin
 
-        _rope_embedding[(n_rows, n_heads,)](
+        group_size = 4 # 4 or 8, too large group_size can hurt performance.
+        n_groups = triton.cdiv(n_heads, group_size)
+
+        grid = (n_rows, n_groups, )
+        _rope_embedding[grid](
             dY,  dY .stride(0),
             cos, cos.stride(0),
             sin, sin.stride(0),
-            seq_len, head_dim,
+            seq_len, head_dim, group_size, n_heads,
             BACKWARD_PASS = True,
             BLOCK_SIZE = ctx.BLOCK_SIZE,
             num_warps = ctx.num_warps,
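The backward launch mirrors the forward one. Reusing the same kernel with BACKWARD_PASS = True works because the RoPE rotation is orthogonal: applying the same formulas with sin negated is the transposed Jacobian of the forward rotation. A small NumPy check of that identity (not part of the commit; sizes are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
half = 64
q1, q2 = rng.standard_normal(half), rng.standard_normal(half)
theta = rng.standard_normal(half)
cos1, sin1 = np.cos(theta), np.sin(theta)

# Forward rotation, exactly as the kernel computes it.
r1, r2 = q1 * cos1 - q2 * sin1, q2 * cos1 + q1 * sin1
# Same formulas with sin negated (what BACKWARD_PASS = True does).
b1, b2 = r1 * cos1 - r2 * (-sin1), r2 * cos1 + r1 * (-sin1)

# Rotating by the negated angle undoes the forward rotation.
assert np.allclose(b1, q1) and np.allclose(b2, q2)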
