Commit 67c096c

Merge branch 'develop' into develop
2 parents 8e425ed + 4c998c3 commit 67c096c

13 files changed (+505 -175 lines)

custom_ops/gpu_ops/append_attention.cu

Lines changed: 8 additions & 2 deletions
@@ -277,7 +277,10 @@ void AppendAttentionKernel(
         exec_stream,
         &qkv_out,
         const_cast<paddle::Tensor*>(&key_cache),
-        const_cast<paddle::Tensor*>(&value_cache));
+        const_cast<paddle::Tensor*>(&value_cache),
+        q_norm_weight,
+        k_norm_weight,
+        rms_norm_eps);
   } else {
     SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
         meta_data,
@@ -300,7 +303,10 @@ void AppendAttentionKernel(
         exec_stream,
         &qkv_out,
         const_cast<paddle::Tensor*>(&key_cache),
-        const_cast<paddle::Tensor*>(&value_cache));
+        const_cast<paddle::Tensor*>(&value_cache),
+        q_norm_weight,
+        k_norm_weight,
+        rms_norm_eps);
     }
   } else {
     if (qkv_out_scales) {
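
Both call sites patched in AppendAttentionKernel now forward q_norm_weight, k_norm_weight, and rms_norm_eps, the inputs for the QK RMSNorm that the write-cache kernels apply after rotary embedding. As a rough host-side reference of what that normalization computes (a minimal sketch only; the helper name and plain-float types are illustrative, not part of the commit):

// Reference-only sketch: per-head RMSNorm with a learned per-channel weight.
// head_size, weight (q_norm_weight or k_norm_weight), and rms_norm_eps mirror
// the kernel parameters; the function itself is hypothetical.
#include <cmath>

void rms_norm_head_ref(const float* x, const float* weight, float* out,
                       int head_size, float rms_norm_eps) {
  float sum_sq = 0.0f;
  for (int i = 0; i < head_size; ++i) sum_sq += x[i] * x[i];
  const float inv_rms = 1.0f / std::sqrt(sum_sq / head_size + rms_norm_eps);
  for (int i = 0; i < head_size; ++i) out[i] = x[i] * inv_rms * weight[i];
}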

custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,6 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
     float row_variance =
         max(warp_m2 / head_size, 0.0f);
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
-
     if (hi < num_heads) {  // q
       Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
@@ -129,6 +128,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
       }
     } else {  // k
       Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
+#pragma unroll
      for (int i = 0; i < VecSize; i++) {
        out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
      }
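
The first hunk only drops a stray blank line; the second adds a #pragma unroll ahead of the k-branch loop. In CUDA the pragma applies only to the loop that immediately follows it, so without this line the k-branch loop was left to the compiler's default unrolling even though the q branch already requested it. A minimal illustration of the placement rule (the loop shown is the k-branch loop from the diff):

// #pragma unroll must sit directly above the loop it targets; it does not
// propagate to later loops in the same scope.
#pragma unroll
for (int i = 0; i < VecSize; i++) {  // trip count is compile-time, so fully unrolled
  out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}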

custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh

Lines changed: 160 additions & 0 deletions
@@ -18,6 +18,166 @@
 #include "mma_tensor_op.cuh"
 #include "utils.cuh"
 
+template <typename T, int VecSize = 1, typename InT = T>
+__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
+    const InT* __restrict__ qkv,  // [token_num, num_heads + 2 * gqa_group_size,
+                                  // head_size]
+    T* __restrict__ key_cache,    // [num_blocks, gqa_group_size, block_size,
+                                  // head_size // 2]
+    T* __restrict__ value_cache,  // [num_blocks, gqa_group_size, block_size,
+                                  // head_size // 2]
+    T* __restrict__ q_out,
+    const int* __restrict__ block_tables,        // [bsz, max_blocks_per_seq]
+    const int* __restrict__ batch_id_per_token,  // [num_tokens]
+    const int* __restrict__ cu_seqlens_q,
+    const int* __restrict__ seq_lens_decoder,  // [bsz]
+    const float* __restrict__ cos_emb,
+    const float* __restrict__ sin_emb,
+    const float*
+        qkv_out_scales,   // [(num_heads + 2 * gqa_group_size) * head_size]
+    const T* qkv_biases,  // [num_head + 2 * gqa_group_size, dim_head]
+    const int max_seq_len,
+    const int max_blocks_per_seq,
+    const int num_heads,
+    const int output_inner_dim,
+    const int head_size,
+    const int block_size,
+    const int elem_cnt,
+    const int gqa_group_size,
+    const float* q_norm_weight,
+    const float* k_norm_weight,
+    const float rms_norm_eps) {
+  using LoadT = AlignedVector<T, VecSize>;
+  using LoadFloat = AlignedVector<float, VecSize>;
+  using LoadInT = AlignedVector<InT, VecSize>;
+  constexpr int HalfVecSize = VecSize / 2;
+  using LoadEmbT = AlignedVector<float, HalfVecSize>;
+  LoadInT src_vec;
+  LoadFloat scale_vec;
+  LoadT bias_vec;
+  LoadEmbT cos_emb_vec;
+  LoadEmbT sin_emb_vec;
+  LoadFloat tmp_vec;
+  LoadFloat q_norm_vec;
+  LoadFloat k_norm_vec;
+
+  int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
+  int64_t all_warp_num = gridDim.x * blockDim.y;
+  int64_t all_head_dim = elem_cnt / head_size;
+
+  const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
+  const int half_head_size = head_size / 2;
+  for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
+    int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
+    const int token_id = linear_index / hidden_size;
+    const int ori_bi = batch_id_per_token[token_id];
+    if (seq_lens_decoder[ori_bi] == 0) continue;
+    const int bias = linear_index % hidden_size;
+    const int hi = bias / head_size;  // q + k + v
+    const int h_bias = bias % head_size;
+    const int start_token_idx = cu_seqlens_q[ori_bi];
+    const int write_seq_id =
+        seq_lens_decoder[ori_bi] + token_id - start_token_idx;
+    if (write_seq_id == 0) continue;
+
+    const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
+    const int block_idx = block_table_now[write_seq_id / block_size];
+    if (block_idx < 0) {
+      printf(
+          "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
+          "%d %d %d %d\n",
+          block_idx,
+          write_seq_id,
+          ori_bi,
+          seq_lens_decoder[ori_bi],
+          token_id,
+          cu_seqlens_q[ori_bi]);
+    }
+    const int block_offset = write_seq_id % block_size;
+
+    const int write_q_idx =
+        token_id * output_inner_dim * head_size + hi * head_size + h_bias;
+
+    const int bias_idx = hi * head_size + h_bias;
+    Load<InT, VecSize>(&qkv[linear_index], &src_vec);
+    if (qkv_biases) {
+      Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
+    }
+    if (qkv_out_scales) {
+      Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
+    }
+    if (hi < num_heads + gqa_group_size) {
+      // q k rope
+      const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
+      Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
+      Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+    }
+    float thread_m2 = 0.0f;
+    float warp_m2 = 0.0f;
+#pragma unroll
+    for (int i = 0; i < HalfVecSize; i++) {
+      // add_bias + rope
+      float input_left = static_cast<float>(src_vec[2 * i]);
+      float input_right = static_cast<float>(src_vec[2 * i + 1]);
+      if (qkv_out_scales) {
+        input_left *= scale_vec[2 * i];
+        input_right *= scale_vec[2 * i + 1];
+      }
+      if (qkv_biases) {
+        input_left = input_left + static_cast<float>(bias_vec[2 * i]);
+        input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
+      }
+      if (hi < num_heads + gqa_group_size) {
+        const float cos_tmp = cos_emb_vec[i];
+        const float sin_tmp = sin_emb_vec[i];
+        float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+        float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
+        thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
+        tmp_vec[2 * i] = tmp1;
+        tmp_vec[2 * i + 1] = tmp2;
+      } else {
+        bias_vec[2 * i] = static_cast<T>(input_left);
+        bias_vec[2 * i + 1] = static_cast<T>(input_right);
+      }
+    }
+    if (hi < (num_heads + gqa_group_size)) {
+      WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
+      float row_variance =
+          max(warp_m2 / head_size, 0.0f);
+      float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
+      if (hi < num_heads) {
+        Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
+#pragma unroll
+        for (int i = 0; i < VecSize; i++) {
+          bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
+        }
+      } else {
+        Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
+#pragma unroll
+        for (int i = 0; i < VecSize; i++) {
+          bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
+        }
+      }
+    }
+    if (hi < num_heads) {
+      // write q
+      Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
+    } else {
+      // write k/v
+      const int kv_head_idx = (hi - num_heads) % gqa_group_size;
+      const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
+                           kv_head_idx * block_size * head_size +
+                           block_offset * head_size + h_bias);
+      // write
+      if (hi < num_heads + gqa_group_size) {
+        Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
+      } else {
+        Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
+      }
+    }
+  }
+}
+
 template <int VecSize = 4, int HeadDim = 128>
 __global__ void append_clear_cache_int8_block(
     uint8_t* __restrict__ key_cache,  // [num_blocks, gqa_group_size,
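
The new kernel assigns one warp per (token, head) row: threadIdx.x indexes VecSize-wide vectors within head_size, threadIdx.y selects the warp inside the block, and the grid-stride loop walks all elem_cnt / head_size rows. A launch shaped along those lines might look like the sketch below; the block/grid sizes and the surrounding variables are assumptions for illustration, not the launch code from this commit:

// Hypothetical launch sketch for append_speculate_cache_T_rope_qk_norm_kernel.
// Assumes head_size == 128 and VecSize == 4 so one 32-thread warp covers a head.
constexpr int kVecSize = 4;
const int threads_per_head = head_size / kVecSize;  // 32 lanes per warp
const int warps_per_block = 4;                      // assumed block shape
dim3 block(threads_per_head, warps_per_block);
const int num_rows = elem_cnt / head_size;          // (token, head) rows to process
dim3 grid((num_rows + warps_per_block - 1) / warps_per_block);
append_speculate_cache_T_rope_qk_norm_kernel<T, kVecSize, InT>
    <<<grid, block, 0, stream>>>(qkv, key_cache, value_cache, q_out,
                                 block_tables, batch_id_per_token, cu_seqlens_q,
                                 seq_lens_decoder, cos_emb, sin_emb,
                                 qkv_out_scales, qkv_biases, max_seq_len,
                                 max_blocks_per_seq, num_heads, output_inner_dim,
                                 head_size, block_size, elem_cnt, gqa_group_size,
                                 q_norm_weight, k_norm_weight, rms_norm_eps);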
