Merged
27 commits
4844966
Extend scratch buffer for long prompts
cmikeh2 Aug 12, 2022
be4714d
Merge branch 'master' into cholmes/fix-long-seq-len-inference
cmikeh2 Aug 12, 2022
c6411d1
Fetch correct tail buffer for batched inputs.
cmikeh2 Aug 12, 2022
c074ed3
Style change
cmikeh2 Aug 12, 2022
3135774
Merge branch 'cholmes/fix-long-seq-len-inference' of https://github.c…
cmikeh2 Aug 12, 2022
777a36e
Merge branch 'master' into cholmes/fix-long-seq-len-inference
RezaYazdaniAminabadi Aug 13, 2022
ec6b1ad
Fix variable rename
cmikeh2 Aug 16, 2022
d897f98
Merge branch 'cholmes/fix-long-seq-len-inference' of https://github.c…
cmikeh2 Aug 16, 2022
6da9234
Merge branch 'master' into cholmes/fix-long-seq-len-inference
cmikeh2 Aug 16, 2022
606d344
Reduce maximum sequence length
cmikeh2 Aug 16, 2022
5269ba1
Merge branch 'master' into cholmes/fix-long-seq-len-inference
cmikeh2 Aug 23, 2022
89f2ded
Merge branch 'master' into cholmes/fix-long-seq-len-inference
RezaYazdaniAminabadi Sep 1, 2022
8e29808
Merge branch 'master' into cholmes/fix-long-seq-len-inference
cmikeh2 Sep 9, 2022
c824330
Add debug print
cmikeh2 Sep 9, 2022
d9eb076
Merge branch 'master' into cholmes/fix-long-seq-len-inference
cmikeh2 Sep 9, 2022
aafba00
Multi-batch inference fix
cmikeh2 Sep 10, 2022
4abd455
add batch-size at the tranform launch for the half-precision implemen…
Sep 11, 2022
603cc5b
Merge branch 'master' into cholmes/fix-long-seq-len-inference
RezaYazdaniAminabadi Sep 13, 2022
508712a
Merge branch 'master' into cholmes/fix-long-seq-len-inference
RezaYazdaniAminabadi Sep 19, 2022
51a6371
no need to throw error when there is no mask passed
Sep 22, 2022
9effa9e
Merge branch 'cholmes/fix-long-seq-len-inference' of github.com:micro…
Sep 22, 2022
c9652ec
Merge branch 'master' into cholmes/fix-long-seq-len-inference
RezaYazdaniAminabadi Sep 22, 2022
d8f5203
Increasing the token-length based on available memory for GPT models …
RezaYazdaniAminabadi Sep 22, 2022
b64c117
Merge branch 'master' into cholmes/fix-long-seq-len-inference
RezaYazdaniAminabadi Sep 22, 2022
48a8b96
fix bert issue & remove some dead code
Sep 22, 2022
c1d83f9
fix formating
Sep 22, 2022
a5f4d31
Merge branch 'master' into cholmes/fix-long-seq-len-inference
jeffra Sep 22, 2022
63 changes: 45 additions & 18 deletions csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -18,7 +18,8 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
- unsigned total_count)
+ unsigned total_count,
+ int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
@@ -31,13 +32,15 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
unsigned offset = head_id * head_size;

unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
+ unsigned seq_index = head_id % seq_len;
+ unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;

if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
- float k = key_layer[offset + lane];
+ float k = key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
@@ -47,7 +50,7 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

mixed_query[offset + lane] = q;
- key_layer[offset + lane] = k;
+ key_layer[k_offset + lane] = k;

lane += WARP_SIZE;
}
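
The k_offset arithmetic introduced above is the heart of the fix: keys are written with a per-slot stride of max_out_tokens rather than the current seq_len, so each sequence in a batch owns a fixed-size region of the key cache. A minimal standalone sketch of that index math follows (plain C++; the helper name and the "slot × max_out_tokens × head_size" layout reading are assumptions, not DeepSpeed's API):

// Sketch of the kernel's key-cache offset computation (hypothetical helper).
// Assumption: each (batch, head) slot reserves max_out_tokens positions, so a
// token's key lands at the same address regardless of the current seq_len.
unsigned key_cache_offset(unsigned head_id,
                          unsigned seq_len,
                          unsigned head_size,
                          int max_out_tokens)
{
    unsigned seq_index = head_id % seq_len;  // token position within the sequence
    unsigned slot = head_id / seq_len;       // which fixed-size cache slot
    return (seq_index + slot * max_out_tokens) * head_size;
}
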
@@ -61,7 +64,8 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
- unsigned total_count)
+ unsigned total_count,
+ int max_out_tokens)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
@@ -75,13 +79,15 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
unsigned offset = head_id * head_size;

unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
+ unsigned seq_index = head_id % seq_len;
+ unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;

if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
- float k = (float)key_layer[offset + lane];
+ float k = (float)key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
@@ -91,7 +97,7 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

mixed_query[offset + lane] = (__half)q;
- key_layer[offset + lane] = (__half)k;
+ key_layer[k_offset + lane] = (__half)k;

lane += WARP_SIZE;
}
@@ -105,7 +111,8 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
- unsigned total_count)
+ unsigned total_count,
+ int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
@@ -118,13 +125,15 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
unsigned offset = head_id * head_size;

unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
+ unsigned seq_index = head_id % seq_len;
+ unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;

if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
- float k = key_layer[offset + lane];
+ float k = key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
@@ -134,7 +143,7 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

mixed_query[offset + lane] = q;
- key_layer[offset + lane] = k;
+ key_layer[k_offset + lane] = k;

lane += WARP_SIZE;
}
@@ -147,7 +156,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
- unsigned total_count)
+ unsigned total_count,
+ int max_out_tokens)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
@@ -160,7 +170,7 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned seq_index = head_id % seq_len;
unsigned offset = head_id * head_size;
- unsigned k_offset = (seq_index + (head_id / seq_len) * MAX_OUT_TOKES) * head_size;
+ unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;

constexpr unsigned mask[32] = {
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
@@ -209,17 +219,32 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
- cudaStream_t stream)
+ cudaStream_t stream,
+ int max_out_tokens)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
if (rotate_every_two)
- apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
-     mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
+ apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(mixed_query,
+                                                            key_layer,
+                                                            rotary_dim,
+                                                            seq_len,
+                                                            offset,
+                                                            num_heads,
+                                                            head_size,
+                                                            total_count,
+                                                            max_out_tokens);
else if (rotate_half)
- apply_rotary_pos_emb1<<<grid_dims, block_dims, 0, stream>>>(
-     mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
+ apply_rotary_pos_emb1<<<grid_dims, block_dims, 0, stream>>>(mixed_query,
+                                                             key_layer,
+                                                             rotary_dim,
+                                                             seq_len,
+                                                             offset,
+                                                             num_heads,
+                                                             head_size,
+                                                             total_count,
+                                                             max_out_tokens);
}
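
The launch configuration itself is unchanged by this PR: each warp handles the rotary update for one (batch, head, token) work item, in 1024-thread blocks. A short sketch of the grid-size arithmetic used above, assuming MAX_WARP_NUM is the warps-per-block count (32 for 1024 threads at 32 lanes per warp):

// Ceiling division: enough blocks to cover total_count warp-sized work items.
constexpr unsigned MAX_WARP_NUM = 32;  // assumption: 1024 threads / 32 lanes per warp

unsigned rotary_grid_size(unsigned batch, unsigned num_heads, unsigned seq_len)
{
    unsigned total_count = batch * num_heads * seq_len;
    return (total_count - 1) / MAX_WARP_NUM + 1;  // ceil(total_count / MAX_WARP_NUM)
}
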

template void launch_apply_rotary_pos_emb<float>(float*,
@@ -232,7 +257,8 @@ template void launch_apply_rotary_pos_emb<float>(float*,
unsigned,
bool,
bool,
- cudaStream_t);
+ cudaStream_t,
+ int);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
@@ -243,7 +269,8 @@ template void launch_apply_rotary_pos_emb<__half>(__half*,
unsigned,
bool,
bool,
- cudaStream_t);
+ cudaStream_t,
+ int);

/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
4 changes: 0 additions & 4 deletions csrc/transformer/inference/csrc/dequantize.cu
Original file line number Diff line number Diff line change
@@ -152,10 +152,6 @@ __global__ void dequantize_kernel(__half* output,
q_h[1] = __float2half(local_scale * (float)q_int8[1]);
q_h[2] = __float2half(local_scale * (float)q_int8[2]);
q_h[3] = __float2half(local_scale * (float)q_int8[3]);
- // q_h[4] = __float2half(local_scale * (float)q_int8[4]);
- // q_h[5] = __float2half(local_scale * (float)q_int8[5]);
- // q_h[6] = __float2half(local_scale * (float)q_int8[6]);
- // q_h[7] = __float2half(local_scale * (float)q_int8[7]);
output_cast[tid] = q_f;
tid += blockDim.x;
}
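
The surviving loop body dequantizes four int8 values per iteration with a shared local_scale. A scalar sketch of that step, reduced to one element per thread (hypothetical standalone kernel, not DeepSpeed's vectorized dequantize_kernel):

#include <cuda_fp16.h>
#include <cstdint>

// One int8 -> fp16 conversion per thread, using a per-group scale factor.
__global__ void dequantize_int8_to_half(__half* out,
                                        const int8_t* in,
                                        const float* scales,
                                        int group_size,
                                        int n)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float s = scales[idx / group_size];           // scale shared by the group
        out[idx] = __float2half(s * (float)in[idx]);  // dequantize and narrow
    }
}
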
8 changes: 4 additions & 4 deletions csrc/transformer/inference/csrc/gelu.cu
Original file line number Diff line number Diff line change
@@ -188,7 +188,7 @@ __global__ void fused_bias_residual(float* input,
data.z = data.z + out.z + bias_data.z;
data.w = data.w + out.w + bias_data.w;
}
- output_cast[offset] = data;
+ input_cast[offset] = data;
}
}

@@ -260,7 +260,7 @@ __global__ void fused_bias_residual(__half* input,
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);

- output_cast[offset] = vals_vec;
+ input_cast[offset] = vals_vec;
}
#endif
}
@@ -324,7 +324,7 @@ __global__ void gptj_residual_add(float* input,
data.z = out.z + res_vec.z + (data.z + bias_data.z) * mp_scale;
data.w = out.w + res_vec.w + (data.w + bias_data.w) * mp_scale;

- output_cast[offset] = data;
+ input_cast[offset] = data;
}
}

@@ -390,7 +390,7 @@ __global__ void gptj_residual_add(__half* input,
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);

- output_cast[offset] = vals_vec;
+ input_cast[offset] = vals_vec;
}
#endif
}
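
All four gelu.cu hunks make the same change: the fused result is stored back through input_cast instead of output_cast, so the residual buffer is updated in place and can keep serving as the running hidden state. A simplified scalar CUDA sketch of that pattern (hypothetical kernel; the real kernels are vectorized float4/__half2 variants):

// In-place fused bias + residual: the result overwrites `input`.
__global__ void fused_bias_residual_inplace(float* input,
                                            const float* output,
                                            const float* bias,
                                            int hidden_size,
                                            int total_elems)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total_elems) {
        // residual (input) + block output + per-column bias, stored in place
        input[idx] = input[idx] + output[idx] + bias[idx % hidden_size];
    }
}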