Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 92 additions & 40 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,7 @@ def _triton_mrope_forward(
mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr,
is_interleaved: tl.constexpr,
is_neox_style: tl.constexpr,
):
# Adapted from
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
Expand Down Expand Up @@ -1124,51 +1125,99 @@ def _triton_mrope_forward(
# program instance (i.e. for the current token) separately
# ####################################################################
# left half of the head
first_half_q_offsets = (
tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
)
first_half_k_offsets = (
tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
)
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
tl.arange(0, pad_hd // 2)[None, :] < rd // 2
)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
tl.arange(0, pad_hd // 2)[None, :] < rd // 2
)
if is_neox_style:
first_half_q_offsets = (
tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
)
first_half_k_offsets = (
tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
)
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
tl.arange(0, pad_hd // 2)[None, :] < rd // 2
)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
tl.arange(0, pad_hd // 2)[None, :] < rd // 2
)

q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
sin_row.dtype
)
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
sin_row.dtype
)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
sin_row.dtype
)
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
sin_row.dtype
)

# right half of the head
second_half_q_offsets = first_half_q_offsets + (rd // 2)
second_half_k_offsets = first_half_k_offsets + (rd // 2)
second_q_mask = first_q_mask
second_k_mask = first_k_mask
# right half of the head
second_half_q_offsets = first_half_q_offsets + (rd // 2)
second_half_k_offsets = first_half_k_offsets + (rd // 2)
second_q_mask = first_q_mask
second_k_mask = first_k_mask

q_tile_2 = tl.load(
q_ptr + second_half_q_offsets, mask=second_q_mask, other=0
).to(sin_row.dtype)
k_tile_2 = tl.load(
k_ptr + second_half_k_offsets, mask=second_k_mask, other=0
).to(sin_row.dtype)

# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
# Since cos and sin are now half-size,
# we use the same cos_row and sin_row for both halves
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
else:
base_q = tl.arange(0, pad_n_qh)[:, None] * hd
base_k = tl.arange(0, pad_n_kh)[:, None] * hd
even_idx = 2 * tl.arange(0, pad_hd // 2)[None, :]
odd_idx = even_idx + 1

even_q_offsets = base_q + even_idx
odd_q_offsets = base_q + odd_idx
even_k_offsets = base_k + even_idx
odd_k_offsets = base_k + odd_idx

idx_mask = tl.arange(0, pad_hd // 2)[None, :] < (rd // 2)
qn_mask = tl.arange(0, pad_n_qh)[:, None] < n_qh
kn_mask = tl.arange(0, pad_n_kh)[:, None] < n_kh

even_q_mask = qn_mask & idx_mask
odd_q_mask = qn_mask & idx_mask
even_k_mask = kn_mask & idx_mask
odd_k_mask = kn_mask & idx_mask

q_tile_1 = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(
sin_row.dtype
)
k_tile_1 = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(
sin_row.dtype
)

q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
sin_row.dtype
)
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
sin_row.dtype
)
q_tile_2 = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(
sin_row.dtype
)
k_tile_2 = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(
sin_row.dtype
)

# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
# Since cos and sin are now half-size,
# we use the same cos_row and sin_row for both halves
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
# y = [x_even, x_odd] * [cos, cos] + [-x_odd, x_even] * [sin, sin]
# NeoX-style rotary embedding:
# Each (even, odd) channel pair forms one rotation arm.
# cos_row and sin_row each have length rd//2, shared across all (even, odd) pairs.
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_ptr + even_q_offsets, new_q_tile_1, mask=even_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_ptr + odd_q_offsets, new_q_tile_2, mask=odd_q_mask)

new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_ptr + even_k_offsets, new_k_tile_1, mask=even_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + odd_k_offsets, new_k_tile_2, mask=odd_k_mask)


def triton_mrope(
Expand All @@ -1180,6 +1229,7 @@ def triton_mrope(
head_size: int,
rotary_dim: int,
mrope_interleaved: bool,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
"""The mrope triton kernel.

Expand Down Expand Up @@ -1230,6 +1280,7 @@ def triton_mrope(
mrope_section[1],
mrope_section[2],
mrope_interleaved,
is_neox_style,
)
return q, k

Expand Down Expand Up @@ -1400,6 +1451,7 @@ def _forward_triton(
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
self.is_neox_style,
)

return q.reshape(query_shape), k.reshape(key_shape)
Expand Down
Loading
Loading