Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 45 additions & 18 deletions csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
Expand All @@ -31,13 +32,15 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
unsigned offset = head_id * head_size;

unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;

if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float k = key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
Expand All @@ -47,7 +50,7 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
key_layer[k_offset + lane] = k;

lane += WARP_SIZE;
}
Expand All @@ -61,7 +64,8 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
unsigned total_count,
int max_out_tokens)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
Expand All @@ -75,13 +79,15 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
unsigned offset = head_id * head_size;

unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;

if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float k = (float)key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
Expand All @@ -91,7 +97,7 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
key_layer[k_offset + lane] = (__half)k;

lane += WARP_SIZE;
}
Expand All @@ -105,7 +111,8 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
Expand All @@ -118,13 +125,15 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
unsigned offset = head_id * head_size;

unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;

if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float k = key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
Expand All @@ -134,7 +143,7 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
key_layer[k_offset + lane] = k;

lane += WARP_SIZE;
}
Expand All @@ -147,7 +156,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
unsigned total_count,
int max_out_tokens)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
Expand All @@ -160,7 +170,7 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned seq_index = head_id % seq_len;
unsigned offset = head_id * head_size;
unsigned k_offset = (seq_index + (head_id / seq_len) * MAX_OUT_TOKES) * head_size;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;

constexpr unsigned mask[32] = {
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
Expand Down Expand Up @@ -209,17 +219,32 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
cudaStream_t stream)
cudaStream_t stream,
int max_out_tokens)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
if (rotate_every_two)
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(mixed_query,
key_layer,
rotary_dim,
seq_len,
offset,
num_heads,
head_size,
total_count,
max_out_tokens);
else if (rotate_half)
apply_rotary_pos_emb1<<<grid_dims, block_dims, 0, stream>>>(
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
apply_rotary_pos_emb1<<<grid_dims, block_dims, 0, stream>>>(mixed_query,
key_layer,
rotary_dim,
seq_len,
offset,
num_heads,
head_size,
total_count,
max_out_tokens);
}

template void launch_apply_rotary_pos_emb<float>(float*,
Expand All @@ -232,7 +257,8 @@ template void launch_apply_rotary_pos_emb<float>(float*,
unsigned,
bool,
bool,
cudaStream_t);
cudaStream_t,
int);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
Expand All @@ -243,7 +269,8 @@ template void launch_apply_rotary_pos_emb<__half>(__half*,
unsigned,
bool,
bool,
cudaStream_t);
cudaStream_t,
int);

/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
Expand Down
4 changes: 0 additions & 4 deletions csrc/transformer/inference/csrc/dequantize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,6 @@ __global__ void dequantize_kernel(__half* output,
q_h[1] = __float2half(local_scale * (float)q_int8[1]);
q_h[2] = __float2half(local_scale * (float)q_int8[2]);
q_h[3] = __float2half(local_scale * (float)q_int8[3]);
// q_h[4] = __float2half(local_scale * (float)q_int8[4]);
// q_h[5] = __float2half(local_scale * (float)q_int8[5]);
// q_h[6] = __float2half(local_scale * (float)q_int8[6]);
// q_h[7] = __float2half(local_scale * (float)q_int8[7]);
output_cast[tid] = q_f;
tid += blockDim.x;
}
Expand Down
8 changes: 4 additions & 4 deletions csrc/transformer/inference/csrc/gelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ __global__ void fused_bias_residual(float* input,
data.z = data.z + out.z + bias_data.z;
data.w = data.w + out.w + bias_data.w;
}
output_cast[offset] = data;
input_cast[offset] = data;
}
}

Expand Down Expand Up @@ -260,7 +260,7 @@ __global__ void fused_bias_residual(__half* input,
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);

output_cast[offset] = vals_vec;
input_cast[offset] = vals_vec;
}
#endif
}
Expand Down Expand Up @@ -324,7 +324,7 @@ __global__ void gptj_residual_add(float* input,
data.z = out.z + res_vec.z + (data.z + bias_data.z) * mp_scale;
data.w = out.w + res_vec.w + (data.w + bias_data.w) * mp_scale;

output_cast[offset] = data;
input_cast[offset] = data;
}
}

Expand Down Expand Up @@ -390,7 +390,7 @@ __global__ void gptj_residual_add(__half* input,
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);

output_cast[offset] = vals_vec;
input_cast[offset] = vals_vec;
}
#endif
}
Expand Down
Loading