Skip to content

Commit 9621f28

Browse files
committed
fuse rope
1 parent a8f3839 commit 9621f28

File tree

5 files changed

+164
-41
lines changed

5 files changed

+164
-41
lines changed

Diff for: csrc/gpu/fused_rotary_position_encoding.cu

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "helper.h"
16+
#include "paddle/extension.h"
17+
18+
// Rotates one (x, y) element pair of a single head in place.
//
// `arr` points at the first element of the head, `cos_ptr` / `sin_ptr` at the
// cosine and sine tables to use, and `rot_offset` selects the pair:
//   - IS_NEOX:  elements `rot_offset` and `embed_dim + rot_offset`
//   - otherwise: the adjacent elements `2 * rot_offset` and `2 * rot_offset + 1`
// In both layouts the cos/sin factor is read at index `rot_offset`.
template <typename T, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding_kernel(
    T* __restrict__ arr,
    const T* __restrict__ cos_ptr,
    const T* __restrict__ sin_ptr,
    int rot_offset,
    int embed_dim) {
  // IS_NEOX is a template constant, so these selections are resolved at
  // compile time.
  const int first = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int second = IS_NEOX ? embed_dim + rot_offset : 2 * rot_offset + 1;

  const T c = cos_ptr[rot_offset];
  const T s = sin_ptr[rot_offset];

  // Read both inputs before writing either output: the rotation needs the
  // original values of the pair.
  const T a = arr[first];
  const T b = arr[second];
  arr[first] = a * c - b * s;
  arr[second] = b * c + a * s;
}
44+
45+
46+
// Applies rotary position embedding in place to the query and key of every
// token.
//
// Launch layout: one thread block per token (blockIdx.x is the token index).
// The threads of a block stride over the (head, pair) work items with step
// blockDim.x, so any blockDim.x >= 1 produces the same result. No shared
// memory is used.
//
// For each token the cos/sin row starts at
// `cos_sin_cache + position_ids[token] * rot_dim`; its first `rot_dim / 2`
// entries are used as cosines and the following `rot_dim / 2` as sines.
// Within a token, head `h` starts at `h * head_size`, and tokens are
// `query_stride` / `key_stride` elements apart.
//
// The kernel performs no bounds checks: the caller must guarantee that
// position_ids holds one valid row index per block and that
// rot_dim <= head_size.
template <typename T, bool IS_NEOX>
__global__ void apply_rotary_embedding_kernel(
    T* __restrict__ query,  // [num_tokens, num_heads, head_size]
    T* __restrict__ key,    // [num_tokens, num_kv_heads, head_size]
    const int* __restrict__ position_ids,  // [num_tokens]
    const T* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
    const int rot_dim,
    const int64_t query_stride,
    const int64_t key_stride,
    const int num_heads,
    const int num_kv_heads,
    const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  const int embed_dim = rot_dim / 2;

  // Promote to 64 bits before multiplying so the element offset into the
  // cache cannot wrap around for large position * rot_dim products.
  const int64_t cache_offset =
      static_cast<int64_t>(position_ids[token_idx]) * rot_dim;
  const T* cos_ptr = cos_sin_cache + cache_offset;
  const T* sin_ptr = cos_ptr + embed_dim;

  T* token_query = query + token_idx * query_stride;
  T* token_key = key + token_idx * key_stride;

  // Query heads: work item i maps to head i / embed_dim, pair i % embed_dim.
  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding_kernel<T, IS_NEOX>(
        token_query + static_cast<int64_t>(head_idx) * head_size,
        cos_ptr,
        sin_ptr,
        rot_offset,
        embed_dim);
  }

  // Key heads: same mapping, using the key layout.
  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding_kernel<T, IS_NEOX>(
        token_key + static_cast<int64_t>(head_idx) * head_size,
        cos_ptr,
        sin_ptr,
        rot_offset,
        embed_dim);
  }
}
85+
86+
87+
// Host entry point of the fused rotary position encoding operator.
//
// Rotates `query` and `key` in place, launching one thread block per token on
// `query.stream()`. The launch is asynchronous; no synchronization is done
// here.
//
// Arguments:
//   query          [num_tokens, num_heads, head_size] or
//                  [num_tokens, num_heads * head_size]
//   key            [num_tokens, num_kv_heads, head_size] or
//                  [num_tokens, num_kv_heads * head_size]
//   position_ids   [num_tokens], int32 row indices into cos_sin_cache
//   cos_sin_cache  [max_position, rot_dim]; each row is rot_dim / 2 cosines
//                  followed by rot_dim / 2 sines
//   head_size      number of elements per head
//   is_neox        selects the pairing layout (see
//                  apply_token_rotary_embedding_kernel)
//
// An empty `query` makes the call a no-op. Invalid arguments raise through
// PD_CHECK instead of causing out-of-bounds accesses on the device.
//
// NOTE(review): PD_DISPATCH_FLOATING_TYPES presumably covers only float and
// double — confirm whether half / bfloat16 inputs are expected to reach here.
void FusedRotaryPositionEncoding(
    paddle::Tensor& query,  // [num_tokens, num_heads, head_size] or
                            // [num_tokens, num_heads * head_size]
    paddle::Tensor& key,
    // [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads *
    // head_size]
    const paddle::Tensor& position_ids,   // [num_tokens]
    const paddle::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    int head_size,
    bool is_neox) {
  PD_CHECK(head_size > 0, "head_size must be positive.");
  PD_CHECK(cos_sin_cache.dims().size() == 2,
           "cos_sin_cache must be 2-D: [max_position, rot_dim].");
  PD_CHECK(position_ids.dtype() == paddle::DataType::INT32,
           "position_ids must be an int32 tensor.");

  // Nothing to rotate. Returning here also avoids dividing by zero below and
  // launching with a zero-sized grid, which is an invalid configuration.
  if (query.numel() == 0) {
    return;
  }

  const int64_t num_tokens = query.dims()[0];
  PD_CHECK(position_ids.numel() >= num_tokens,
           "position_ids must provide one position per token.");

  const int num_heads =
      static_cast<int>(query.numel() / num_tokens / head_size);
  const int num_kv_heads =
      static_cast<int>(key.numel() / num_tokens / head_size);
  const int rot_dim = static_cast<int>(cos_sin_cache.dims()[1]);
  // The kernel splits each cache row into two equal halves and rotates
  // rot_dim elements of every head, so rot_dim must be even and fit in a head.
  PD_CHECK(rot_dim > 0 && rot_dim % 2 == 0,
           "rot_dim (cos_sin_cache.dims()[1]) must be a positive even number.");
  PD_CHECK(rot_dim <= head_size, "rot_dim must not exceed head_size.");

  const int64_t query_stride = static_cast<int64_t>(num_heads) * head_size;
  const int64_t key_stride = static_cast<int64_t>(num_kv_heads) * head_size;

  // One block per token; at most 512 threads, at least one so the launch
  // configuration stays valid even when there are no query heads.
  const int64_t pairs_per_token =
      static_cast<int64_t>(num_heads) * (rot_dim / 2);
  const int64_t threads =
      std::max<int64_t>(1, std::min<int64_t>(pairs_per_token, 512));
  dim3 grid(num_tokens);
  dim3 block(threads);

  PD_DISPATCH_FLOATING_TYPES(
      query.dtype(), "apply_rotary_embedding_kernel", [&] {
        if (is_neox) {
          apply_rotary_embedding_kernel<data_t, true>
              <<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
                                                   key.data<data_t>(),
                                                   position_ids.data<int>(),
                                                   cos_sin_cache.data<data_t>(),
                                                   rot_dim,
                                                   query_stride,
                                                   key_stride,
                                                   num_heads,
                                                   num_kv_heads,
                                                   head_size);
        } else {
          apply_rotary_embedding_kernel<data_t, false>
              <<<grid, block, 0, query.stream()>>>(query.data<data_t>(),
                                                   key.data<data_t>(),
                                                   position_ids.data<int>(),
                                                   cos_sin_cache.data<data_t>(),
                                                   rot_dim,
                                                   query_stride,
                                                   key_stride,
                                                   num_heads,
                                                   num_kv_heads,
                                                   head_size);
        }
      });
}
135+
136+
// Registers the custom operator `fused_rotary_position_encoding`, backed by
// FusedRotaryPositionEncoding. The in-place map ties each output to its input
// (query -> query_out, key -> key_out).
PD_BUILD_OP(fused_rotary_position_encoding)
    .Inputs({"query",
             "key",
             "position_ids",
             "cos_sin_cache"})
    .Outputs({"query_out",
              "key_out"})
    .Attrs({"head_size: int",
            "is_neox: bool"})
    .SetInplaceMap({{"query", "query_out"},
                    {"key", "key_out"}})
    .SetKernelFn(PD_KERNEL(FusedRotaryPositionEncoding));

Diff for: csrc/setup_cuda.py

+1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def get_gencode_flags():
104104
"./gpu/quant_int8.cu",
105105
"./gpu/dequant_int8.cu",
106106
"./gpu/get_position_ids.cu",
107+
"./gpu/fused_rotary_position_encoding.cu",
107108
"./gpu/flash_attn_bwd.cc",
108109
"./gpu/tune_cublaslt_gemm.cu",
109110
"./gpu/sample_kernels/top_p_sampling_reject.cu",

Diff for: paddlenlp/experimental/transformers/deepseek_v2/modeling.py

+9-28
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import os
1617
from functools import partial
1718
from typing import Tuple
1819

@@ -60,7 +61,6 @@ class DeepseekScalingRotaryEmbedding(nn.Layer):
6061

6162
def __init__(
6263
self,
63-
# head_size: int,
6464
rotary_dim: int,
6565
max_position_embeddings: int,
6666
base: int,
@@ -114,10 +114,9 @@ def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
114114
def _compute_cos_sin_cache(self) -> paddle.Tensor:
115115
inv_freq = self._compute_inv_freq(self.scaling_factor)
116116
t = paddle.arange(self.max_position_embeddings * self.scaling_factor, dtype=paddle.float32)
117-
freqs = paddle.einsum("i,j -> ij", t, inv_freq)
118-
emb = paddle.concat((freqs, freqs), axis=-1)
119-
cos = emb.cos() * self.mscale
120-
sin = emb.sin() * self.mscale
117+
freqs = paddle.einsum("i,j->ij", t, inv_freq)
118+
cos = freqs.cos() * self.mscale
119+
sin = freqs.sin() * self.mscale
121120
cache = paddle.concat((cos, sin), axis=-1)
122121
return cache.cast(self._dtype)
123122

@@ -127,29 +126,17 @@ def forward(
127126
query: paddle.Tensor,
128127
key: paddle.Tensor,
129128
) -> Tuple[paddle.Tensor, paddle.Tensor]:
130-
cos_sin = self.cos_sin_cache[position_ids].unsqueeze(1)
131-
cos, sin = cos_sin.chunk(2, axis=-1)
129+
from paddlenlp_ops import fused_rotary_position_encoding
132130

133-
s, h, d = query.shape
134-
query = query.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
135-
136-
s, h, d = key.shape
137-
key = key.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
138-
139-
def rotate_half(x):
140-
"""Rotates half the hidden axiss of the input."""
141-
x1 = x[..., : x.shape[-1] // 2]
142-
x2 = x[..., x.shape[-1] // 2 :]
143-
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x
144-
145-
query = (query * cos) + (rotate_half(query) * sin)
146-
key = (key * cos) + (rotate_half(key) * sin)
131+
# In-place operations that update the query and key tensors.
132+
os.environ["stride_in_no_check_dy2st_diff"] = "1"
133+
fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False)
147134

148135
return query, key
149136

150137

151138
class DeepseekV2RMSNorm(nn.Layer):
152-
def __init__(self, config):
139+
def __init__(self, config: DeepseekV2Config):
153140
super().__init__()
154141
self.eps = config.rms_norm_eps
155142
self.weight = paddle.create_parameter(
@@ -170,9 +157,7 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
170157

171158
self.config = config
172159

173-
self.append_attn = config.append_attn
174160
self.max_seq_len = config.max_seq_len
175-
self.block_size = config.block_size
176161

177162
self.vocab_size = config.vocab_size
178163
self.hidden_size = config.hidden_size
@@ -181,12 +166,9 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
181166
self.num_key_value_heads = config.num_key_value_heads
182167
self.num_layers = config.num_hidden_layers
183168
self.rms_norm_eps = config.rms_norm_eps
184-
self.max_position_embeddings = config.max_position_embeddings
185169
self.quant_type = config.quant_type
186170
self.rope_theta = config.rope_theta
187171

188-
self.use_neox = False
189-
190172
self.use_weight_only = False
191173
if config.quant_type == "weight_only_int8":
192174
self.use_weight_only = True
@@ -499,7 +481,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
499481
rope_theta=self.rope_theta,
500482
rotary_emb=self.rotary_emb,
501483
norm_type="rmsnorm",
502-
use_neox_rotary_style=self.use_neox,
503484
rank_id=config.tensor_parallel_rank,
504485
moe_config=moe_config,
505486
mla_config=mla_config,

Diff for: paddlenlp/experimental/transformers/fused_transformer_layers.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -974,13 +974,13 @@ def compute_qkv_linear(self, ln_out, i):
974974
query = paddle.matmul(ln_out, self.q_proj_weights[i])
975975

976976
query = query.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
977-
query_nope, query_pe = paddle.split(
978-
query, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
977+
query_nope, query_pe = query.split(
978+
[self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
979979
)
980980

981981
compressed_kv = paddle.matmul(ln_out, self.kv_a_proj_with_mqa_weights[i])
982-
compressed_kv, key_pe = paddle.split(
983-
compressed_kv, [self.config.mla_config.kv_lora_rank, self.config.mla_config.qk_rope_head_dim], axis=-1
982+
compressed_kv, key_pe = compressed_kv.split(
983+
[self.config.mla_config.kv_lora_rank, self.config.mla_config.qk_rope_head_dim], axis=-1
984984
)
985985
key_pe = key_pe.reshape([-1, 1, self.config.mla_config.qk_rope_head_dim])
986986
compressed_kv = self.norm_func(
@@ -994,8 +994,8 @@ def compute_qkv_linear(self, ln_out, i):
994994
key_value = key_value.reshape(
995995
[-1, self.num_heads, self.config.mla_config.qk_nope_head_dim + self.config.mla_config.v_head_dim]
996996
)
997-
key_nope, value = paddle.split(
998-
key_value, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
997+
key_nope, value = key_value.split(
998+
[self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
999999
)
10001000

10011001
query_pe, key_pe = self.config.rotary_emb(self.position_ids, query_pe, key_pe)
@@ -1305,6 +1305,7 @@ def pre_process(self, **kwargs):
13051305

13061306
from paddlenlp_ops import get_position_ids
13071307

1308+
# In-place operations that compute the position_ids.
13081309
get_position_ids(seq_lens_encoder, seq_lens_decoder, self.position_ids)
13091310

13101311
def post_process(self, **kwargs):
@@ -1827,8 +1828,8 @@ def compute_qkv_linear(self, ln_out, i):
18271828
)
18281829

18291830
query = query.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
1830-
query_nope, query_pe = paddle.split(
1831-
query, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
1831+
query_nope, query_pe = query.split(
1832+
[self.config.mla_config.qk_nope_head_dim, self.config.mla_config.qk_rope_head_dim], axis=-1
18321833
)
18331834

18341835
compressed_kv = weight_only_linear(
@@ -1837,8 +1838,8 @@ def compute_qkv_linear(self, ln_out, i):
18371838
weight_scale=self.kv_a_proj_with_mqa_weights_scale[i],
18381839
weight_dtype=self.weight_dtype,
18391840
)
1840-
compressed_kv, key_pe = paddle.split(
1841-
compressed_kv, [self.config.mla_config.kv_lora_rank, self.config.mla_config.qk_rope_head_dim], axis=-1
1841+
compressed_kv, key_pe = compressed_kv.split(
1842+
[self.config.mla_config.kv_lora_rank, self.config.mla_config.qk_rope_head_dim], axis=-1
18421843
)
18431844
key_pe = key_pe.reshape([-1, 1, self.config.mla_config.qk_rope_head_dim])
18441845
compressed_kv = self.norm_func(
@@ -1857,8 +1858,8 @@ def compute_qkv_linear(self, ln_out, i):
18571858
key_value = key_value.reshape(
18581859
[-1, self.num_heads, self.config.mla_config.qk_nope_head_dim + self.config.mla_config.v_head_dim]
18591860
)
1860-
key_nope, value = paddle.split(
1861-
key_value, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
1861+
key_nope, value = key_value.split(
1862+
[self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
18621863
)
18631864

18641865
query_pe, key_pe = self.config.rotary_emb(self.position_ids, query_pe, key_pe)

Diff for: paddlenlp/experimental/transformers/qwen2_moe/modeling.py

-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ def __init__(self, config: Qwen2MoeConfig):
7878
self.num_key_value_heads = config.num_key_value_heads
7979
self.num_layers = config.num_hidden_layers
8080
self.rms_norm_eps = config.rms_norm_eps
81-
self.max_position_embeddings = config.max_position_embeddings
8281
self.quant_type = config.quant_type
8382
self.rope_theta = config.rope_theta
8483

0 commit comments

Comments
 (0)