|
18 | 18 | #include "mma_tensor_op.cuh" |
19 | 19 | #include "utils.cuh" |
20 | 20 |
|
| 21 | +template <typename T, int VecSize = 1, typename InT = T> |
| 22 | +__global__ void append_speculate_cache_T_rope_qk_norm_kernel( |
| 23 | + const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size, |
| 24 | + // head_size] |
| 25 | + T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, |
| 26 | + // head_size // 2] |
| 27 | + T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, |
| 28 | + // head_size // 2] |
| 29 | + T* __restrict__ q_out, |
| 30 | + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] |
| 31 | + const int* __restrict__ batch_id_per_token, // [num_tokens] |
| 32 | + const int* __restrict__ cu_seqlens_q, |
| 33 | + const int* __restrict__ seq_lens_decoder, // [bsz] |
| 34 | + const float* __restrict__ cos_emb, |
| 35 | + const float* __restrict__ sin_emb, |
| 36 | + const float* |
| 37 | + qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size] |
| 38 | + const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head] |
| 39 | + const int max_seq_len, |
| 40 | + const int max_blocks_per_seq, |
| 41 | + const int num_heads, |
| 42 | + const int output_inner_dim, |
| 43 | + const int head_size, |
| 44 | + const int block_size, |
| 45 | + const int elem_cnt, |
| 46 | + const int gqa_group_size, |
| 47 | + const float* q_norm_weight, |
| 48 | + const float* k_norm_weight, |
| 49 | + const float rms_norm_eps) { |
| 50 | + using LoadT = AlignedVector<T, VecSize>; |
| 51 | + using LoadFloat = AlignedVector<float, VecSize>; |
| 52 | + using LoadInT = AlignedVector<InT, VecSize>; |
| 53 | + constexpr int HalfVecSize = VecSize / 2; |
| 54 | + using LoadEmbT = AlignedVector<float, HalfVecSize>; |
| 55 | + LoadInT src_vec; |
| 56 | + LoadFloat scale_vec; |
| 57 | + LoadT bias_vec; |
| 58 | + LoadEmbT cos_emb_vec; |
| 59 | + LoadEmbT sin_emb_vec; |
| 60 | + LoadFloat tmp_vec; |
| 61 | + LoadFloat q_norm_vec; |
| 62 | + LoadFloat k_norm_vec; |
| 63 | + |
| 64 | + int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y; |
| 65 | + int64_t all_warp_num = gridDim.x * blockDim.y; |
| 66 | + int64_t all_head_dim = elem_cnt / head_size; |
| 67 | + |
| 68 | + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size; |
| 69 | + const int half_head_size = head_size / 2; |
| 70 | + for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) { |
| 71 | + int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize; |
| 72 | + const int token_id = linear_index / hidden_size; |
| 73 | + const int ori_bi = batch_id_per_token[token_id]; |
| 74 | + if (seq_lens_decoder[ori_bi] == 0) continue; |
| 75 | + const int bias = linear_index % hidden_size; |
| 76 | + const int hi = bias / head_size; // q + k + v |
| 77 | + const int h_bias = bias % head_size; |
| 78 | + const int start_token_idx = cu_seqlens_q[ori_bi]; |
| 79 | + const int write_seq_id = |
| 80 | + seq_lens_decoder[ori_bi] + token_id - start_token_idx; |
| 81 | + if (write_seq_id == 0) continue; |
| 82 | + |
| 83 | + const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; |
| 84 | + const int block_idx = block_table_now[write_seq_id / block_size]; |
| 85 | + if (block_idx < 0) { |
| 86 | + printf( |
| 87 | + "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " |
| 88 | + "%d %d %d %d\n", |
| 89 | + block_idx, |
| 90 | + write_seq_id, |
| 91 | + ori_bi, |
| 92 | + seq_lens_decoder[ori_bi], |
| 93 | + token_id, |
| 94 | + cu_seqlens_q[ori_bi]); |
| 95 | + } |
| 96 | + const int block_offset = write_seq_id % block_size; |
| 97 | + |
| 98 | + const int write_q_idx = |
| 99 | + token_id * output_inner_dim * head_size + hi * head_size + h_bias; |
| 100 | + |
| 101 | + const int bias_idx = hi * head_size + h_bias; |
| 102 | + Load<InT, VecSize>(&qkv[linear_index], &src_vec); |
| 103 | + if (qkv_biases) { |
| 104 | + Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec); |
| 105 | + } |
| 106 | + if (qkv_out_scales) { |
| 107 | + Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec); |
| 108 | + } |
| 109 | + if (hi < num_heads + gqa_group_size) { |
| 110 | + // q k rope |
| 111 | + const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2; |
| 112 | + Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec); |
| 113 | + Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec); |
| 114 | + } |
| 115 | + float thread_m2 = 0.0f; |
| 116 | + float warp_m2 = 0.0f; |
| 117 | +#pragma unroll |
| 118 | + for (int i = 0; i < HalfVecSize; i++) { |
| 119 | + // add_bias + rope |
| 120 | + float input_left = static_cast<float>(src_vec[2 * i]); |
| 121 | + float input_right = static_cast<float>(src_vec[2 * i + 1]); |
| 122 | + if (qkv_out_scales) { |
| 123 | + input_left *= scale_vec[2 * i]; |
| 124 | + input_right *= scale_vec[2 * i + 1]; |
| 125 | + } |
| 126 | + if (qkv_biases) { |
| 127 | + input_left = input_left + static_cast<float>(bias_vec[2 * i]); |
| 128 | + input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]); |
| 129 | + } |
| 130 | + if (hi < num_heads + gqa_group_size) { |
| 131 | + const float cos_tmp = cos_emb_vec[i]; |
| 132 | + const float sin_tmp = sin_emb_vec[i]; |
| 133 | + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; |
| 134 | + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; |
| 135 | + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; |
| 136 | + tmp_vec[2 * i] = tmp1; |
| 137 | + tmp_vec[2 * i + 1] = tmp2; |
| 138 | + } else { |
| 139 | + bias_vec[2 * i] = static_cast<T>(input_left); |
| 140 | + bias_vec[2 * i + 1] = static_cast<T>(input_right); |
| 141 | + } |
| 142 | + } |
| 143 | + if (hi < (num_heads + gqa_group_size)) { |
| 144 | + WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2); |
| 145 | + float row_variance = |
| 146 | + max(warp_m2 / head_size, 0.0f); |
| 147 | + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); |
| 148 | + if (hi < num_heads) { |
| 149 | + Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec); |
| 150 | + #pragma unroll |
| 151 | + for (int i = 0; i < VecSize; i++) { |
| 152 | + bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]); |
| 153 | + } |
| 154 | + } else { |
| 155 | + Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec); |
| 156 | + #pragma unroll |
| 157 | + for (int i = 0; i < VecSize; i++) { |
| 158 | + bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]); |
| 159 | + } |
| 160 | + } |
| 161 | + } |
| 162 | + if (hi < num_heads) { |
| 163 | + // write q |
| 164 | + Store<T, VecSize>(bias_vec, &q_out[write_q_idx]); |
| 165 | + } else { |
| 166 | + // write k/v |
| 167 | + const int kv_head_idx = (hi - num_heads) % gqa_group_size; |
| 168 | + const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size + |
| 169 | + kv_head_idx * block_size * head_size + |
| 170 | + block_offset * head_size + h_bias); |
| 171 | + // write |
| 172 | + if (hi < num_heads + gqa_group_size) { |
| 173 | + Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]); |
| 174 | + } else { |
| 175 | + Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]); |
| 176 | + } |
| 177 | + } |
| 178 | + } |
| 179 | +} |
| 180 | + |
21 | 181 | template <int VecSize = 4, int HeadDim = 128> |
22 | 182 | __global__ void append_clear_cache_int8_block( |
23 | 183 | uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, |
|
0 commit comments