From 48449661848086a3f69af38fc1b462b26195d931 Mon Sep 17 00:00:00 2001 From: cmikeh2 Date: Fri, 12 Aug 2022 06:15:40 +0500 Subject: [PATCH 01/12] Extend scratch buffer for long prompts --- .../transformer/inference/csrc/pt_binding.cpp | 41 +++++++++++++++---- .../inference/includes/custom_cuda_layers.h | 2 +- .../inference/transformer_inference.py | 9 +++- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 360beaa10284..1ddca769de00 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -59,6 +59,7 @@ at::Tensor ds_softmax(at::Tensor& attn_scores, template void allocate_workspace(size_t hidden_dim, + size_t prompt_len, size_t max_seq_len, size_t batch_size, unsigned num_layers, @@ -66,6 +67,9 @@ void allocate_workspace(size_t hidden_dim, { size_t _workSpaceSize = 16 * (hidden_dim * batch_size * max_seq_len) + (num_layers * batch_size * max_seq_len * hidden_dim * 2); // KV-cache + size_t prompt_output = batch_size * head_size * prompt_len * prompt_len; + _workSpaceSize = _workSpaceSize + prompt_output; + Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T)); } @@ -81,10 +85,13 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) float alpha = 1; float gemm_beta = 0.0; - if (!workspace) { - allocate_workspace(W.size(1), MAX_OUT_TOKES, Q.size(0), 1); + /* + // Reallocate memory if we received a new prompt + if (!workspace || input.size(1) != 1) { + allocate_workspace(W.size(1), MAX_OUT_TOKES, Q.size(0), 1, head_size); workspace = (T*)Context::Instance().GetWorkSpace(); } + */ auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options); unsigned m = W.size(1); @@ -305,7 +312,15 @@ void attention_unfused(T* prev_key_cont, float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0; float alpha = norm_factor * norm_factor / layer_scale; float gemm_beta = 0.0; - T* workspace = (T*)output + bsz * seq_len * heads * k; + T* workspace; + if (seq_len == 1) { + workspace = (T*)output + bsz * seq_len * heads * k; + } else { + // If we are doing the prompt, switch to the tail workspace + T* scratch = (T*)Context::Instance().GetWorkSpace(); + workspace = scratch + (Context::Instance().get_workspace_size() / sizeof(T)) - + heads * seq_len * seq_len; + } cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), @@ -591,14 +606,19 @@ std::vector ds_qkv_gemm(at::Tensor& input, at::Tensor& beta, const float epsilon, bool add_bias, - unsigned num_layers) + unsigned num_layers, + bool bloom_seq_len, + unsigned head_size) { int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); - if (!workspace) { + // Reallocate memory if we receive a new prompt + if (!workspace || input.size(1) != 1) { cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - allocate_workspace(input.size(2), MAX_OUT_TOKES, input.size(0), num_layers); + const int max_seq_len = (bloom_seq_len) ? 
128 : MAX_OUT_TOKES; + allocate_workspace( + input.size(2), input.size(1), max_seq_len, input.size(0), num_layers, head_size); workspace = (T*)Context::Instance().GetWorkSpace(); } auto options = at::TensorOptions() @@ -699,7 +719,8 @@ template at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias, - unsigned num_layers) + unsigned num_layers, + unsigned head_size) { auto input_cont = input.contiguous(); auto options = at::TensorOptions() @@ -710,10 +731,12 @@ at::Tensor ds_linear_layer(at::Tensor& input, int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); - if (!workspace) { + // Reallocate memory if we received a new prompt + if (!workspace || input.size(1) != 1) { cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - allocate_workspace(input.size(2), MAX_OUT_TOKES, input.size(0), num_layers); + allocate_workspace( + input.size(2), input.size(1), MAX_OUT_TOKES, input.size(0), num_layers, head_size); workspace = (T*)Context::Instance().GetWorkSpace(); } auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index c2bb30126cd6..a2bc1e7b328a 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -17,7 +17,7 @@ #include #include -#define MAX_OUT_TOKES 128 +#define MAX_OUT_TOKES 1024 #define MAX_WARP_NUM 32 #define WARP_SIZE 32 diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index d38cf8c3d395..e552940bdac8 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -401,6 +401,8 @@ def compute_attention(qkv_out, input_mask): def selfAttention_fp(): vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \ inference_cuda_module.vector_matmul_fp32 + + head_size = (attn_qkvw.shape[-1] // 3 // num_attention_heads_per_partition) if not config.pre_layer_norm: linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \ inference_cuda_module.linear_layer_fp32 @@ -408,7 +410,8 @@ def selfAttention_fp(): qkv_out = linear_func(input, attn_qkvw, attn_qkvb, - DeepSpeedTransformerInference.layer_id) + DeepSpeedTransformerInference.layer_id, + head_size) else: qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \ inference_cuda_module.qkv_gemm_fp32 @@ -421,7 +424,9 @@ def selfAttention_fp(): config.epsilon, (attn_qkvb is not None), 1 if config.bigscience_bloom else - DeepSpeedTransformerInference.layer_id) + DeepSpeedTransformerInference.layer_id, + config.bigscience_bloom, + head_size) context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask) output = vector_matmul_func(context_layer, attn_ow, False) From c6411d13d2b31affe8f272351fb1b297b4cbeedf Mon Sep 17 00:00:00 2001 From: cmikeh2 Date: Fri, 12 Aug 2022 21:35:58 +0500 Subject: [PATCH 02/12] Fetch correct tail buffer for batched inputs. 
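
Context for the one-line fix below: during the prompt phase the attention
scores are staged in a scratch region at the tail of the shared inference
workspace, and the tail pointer has to step back far enough to hold one
seq_len x seq_len score tile per head for every sequence in the batch, not
just one. A minimal standalone sketch of the corrected offset arithmetic
(the helper name and signature are illustrative; bsz, heads and seq_len
follow the diff, and workspace_elems stands for
get_workspace_size() / sizeof(T)):

    #include <cstddef>

    // Element offset, measured from the start of the workspace, of the tail
    // scratch region holding the prompt attention scores: one
    // seq_len x seq_len tile per head, per batch element. This patch adds
    // the bsz factor so batched prompts do not run past the buffer.
    static size_t prompt_score_offset(size_t workspace_elems,
                                      size_t bsz,
                                      size_t heads,
                                      size_t seq_len)
    {
        return workspace_elems - bsz * heads * seq_len * seq_len;
    }
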
--- csrc/transformer/inference/csrc/pt_binding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 1ddca769de00..8d80da39470e 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -319,7 +319,7 @@ void attention_unfused(T* prev_key_cont, // If we are doing the prompt, switch to the tail workspace T* scratch = (T*)Context::Instance().GetWorkSpace(); workspace = scratch + (Context::Instance().get_workspace_size() / sizeof(T)) - - heads * seq_len * seq_len; + bsz * heads * seq_len * seq_len; } cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); From c074ed37d2194456794446ec6971ef0b1eda8f8f Mon Sep 17 00:00:00 2001 From: cmikeh2 Date: Fri, 12 Aug 2022 21:36:48 +0500 Subject: [PATCH 03/12] Style change --- csrc/transformer/inference/csrc/pt_binding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 8d80da39470e..51c55b1e3c3a 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -607,7 +607,7 @@ std::vector ds_qkv_gemm(at::Tensor& input, const float epsilon, bool add_bias, unsigned num_layers, - bool bloom_seq_len, + bool is_bloom, unsigned head_size) { int bsz = input.size(0) * input.size(1); From ec6b1ad3a5d05a22440776ed7f4f60c7620b7e2a Mon Sep 17 00:00:00 2001 From: cmikeh2 Date: Tue, 16 Aug 2022 20:53:56 +0500 Subject: [PATCH 04/12] Fix variable rename --- csrc/transformer/inference/csrc/pt_binding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 51c55b1e3c3a..54680997faf0 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -616,7 +616,7 @@ std::vector ds_qkv_gemm(at::Tensor& input, if (!workspace || input.size(1) != 1) { cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - const int max_seq_len = (bloom_seq_len) ? 128 : MAX_OUT_TOKES; + const int max_seq_len = (is_bloom) ? 
128 : MAX_OUT_TOKES; allocate_workspace( input.size(2), input.size(1), max_seq_len, input.size(0), num_layers, head_size); workspace = (T*)Context::Instance().GetWorkSpace(); From 606d3447a616e5c1592d6c808710b130e466501f Mon Sep 17 00:00:00 2001 From: cmikeh2 Date: Wed, 17 Aug 2022 01:59:47 +0500 Subject: [PATCH 05/12] Reduce maximum sequence length --- csrc/transformer/inference/includes/custom_cuda_layers.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index fecd6a1ab9de..d1cd3361b448 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -17,7 +17,7 @@ #include #include -#define MAX_OUT_TOKES 1024 +#define MAX_OUT_TOKES 512 #define MAX_WARP_NUM 32 #define WARP_SIZE 32 From c82433048382a23e8cea4b7638e31acc2e9117c7 Mon Sep 17 00:00:00 2001 From: cmikeh2 Date: Fri, 9 Sep 2022 20:53:52 +0000 Subject: [PATCH 06/12] Add debug print --- csrc/transformer/inference/includes/inference_context.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index a6f6613fc6a5..f314f4485d6f 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -85,7 +85,12 @@ class Context { cudaMalloc(&_workspace, size); } - if (!_workspace) { throw std::runtime_error("Workspace is null."); } + if (!_workspace) { + size_t total_size, free_size; + cudaMemGetInfo(&free_size, &total_size); + printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", size, free_size, total_size); + throw std::runtime_error("Workspace is null."); + } _workSpaceSize = size; } From aafba00c81eaf29c0c2b209a94bc31f4de942936 Mon Sep 17 00:00:00 2001 From: cmikeh2 Date: Sat, 10 Sep 2022 02:08:52 +0000 Subject: [PATCH 07/12] Multi-batch inference fix --- csrc/transformer/inference/csrc/transform.cu | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index 9230516238c0..092724ad9e1d 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -28,10 +28,6 @@ __global__ void bias_add_transform_0213(float* output, int d1_stride = hidden_dim; int d2_stride = hidden_dim / heads; - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - // int d2_out_stride = d2_stride * seq_length; - int d0 = blockIdx.x; // Batch int d1 = blockIdx.y; // Sequence ID (0-127) int cnt = blockIdx.z / head_ext; // Hidden count @@ -39,6 +35,7 @@ __global__ void bias_add_transform_0213(float* output, int d3 = threadIdx.x; // Values (groups of 4) int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : MAX_OUT_TOKES); + int d0_out_stride = hidden_dim * (cnt == 0 ? 
seq_length : MAX_OUT_TOKES); const float4* vals_vec = reinterpret_cast(vals); float4* output_vec = @@ -50,7 +47,7 @@ __global__ void bias_add_transform_0213(float* output, vals_vec += (d2 * d2_stride); output_vec += (d1 * d2_stride); - output_vec += (d0 * d0_stride); + output_vec += (d0 * d0_out_stride); output_vec += (d2 * d2_out_stride); unsigned seq_id = d1 + seq_offset; @@ -105,6 +102,8 @@ __global__ void bias_add_transform_0213(__half* output, // q int d3 = threadIdx.x; // Values (groups of 4) int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : MAX_OUT_TOKES); + int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : MAX_OUT_TOKES); + float4 vals_arr; float4 output_arr; @@ -121,7 +120,7 @@ __global__ void bias_add_transform_0213(__half* output, // q vals_vec += (d2 * d2_stride); output_vec += (d1 * d2_stride); - output_vec += (d0 * d0_stride); + output_vec += (d0 * d0_out_stride); output_vec += (d2 * d2_out_stride); unsigned seq_id = d1 + seq_offset; From 4abd455521965930d0e921de8afc0073ea7df9d1 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Mon, 12 Sep 2022 02:48:08 +0500 Subject: [PATCH 08/12] add batch-size at the tranform launch for the half-precision implementation --- csrc/transformer/inference/csrc/transform.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index 092724ad9e1d..e72df0f14232 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -225,7 +225,7 @@ void launch_bias_add_transform_0213<__half>(__half* output, hidden_dim >>= 3; int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1; dim3 block_dim(hidden_dim / heads, (heads / head_ext)); - dim3 grid_dim(1, seq_length, (trans_count * head_ext)); + dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); bias_add_transform_0213<<>>(output, k_cache, v_cache, From 51a63715ba8106656b836212f9587215f674d953 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 22 Sep 2022 08:12:32 +0500 Subject: [PATCH 09/12] no need to throw error when there is no mask passed --- csrc/transformer/inference/csrc/pt_binding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 7022ed82296d..cf04f96b8783 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -45,7 +45,7 @@ inline auto get_attn_mask_stride(at::Tensor& attn_mask) -> int // Bert style models have always a mask stride of 1. return 1; } else if (trnsfrmr_type == TransformerType::UNKNOWN) { - throw std::runtime_error("Unknown transformer type."); + return 0; } // this is just to make the compiler happy. 
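
A standalone sketch of the branch this patch changes (illustrative only: the
enum value name BERTType and the helper name are stand-ins, and the other
branches of the real get_attn_mask_stride are omitted). The point of the
change is that an UNKNOWN mask layout, i.e. no usable attention mask was
passed, now reports a stride of 0 instead of aborting inference:

    enum class TransformerType { BERTType, UNKNOWN /* other types omitted */ };

    static int attn_mask_stride_sketch(TransformerType trnsfrmr_type)
    {
        if (trnsfrmr_type == TransformerType::BERTType) {
            // Bert style models have always a mask stride of 1.
            return 1;
        } else if (trnsfrmr_type == TransformerType::UNKNOWN) {
            // Previously: throw std::runtime_error("Unknown transformer type.");
            // With no mask passed there is nothing to stride over.
            return 0;
        }
        // this is just to make the compiler happy.
        return 0;
    }
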
From d8f52032223ab24b194bbdd8be1df2586de8c3a3 Mon Sep 17 00:00:00 2001 From: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> Date: Thu, 22 Sep 2022 12:16:45 -0700 Subject: [PATCH 10/12] Increasing the token-length based on available memory for GPT models (#2280) * increasing the token-length based on available memory & reduce memory alloc * merging * formating * fix compile issue * fix the max_out_tokens to use a dynamic range based on available memory * fix the issue with empty prompt * fix residual-add * fix some issues with unit tests * fix formatting --- .../inference/csrc/apply_rotary_pos_emb.cu | 63 +++++-- csrc/transformer/inference/csrc/dequantize.cu | 4 - csrc/transformer/inference/csrc/gelu.cu | 8 +- .../transformer/inference/csrc/pt_binding.cpp | 160 ++++++++++-------- csrc/transformer/inference/csrc/transform.cu | 29 ++-- .../inference/includes/inference_context.h | 61 +++++-- .../includes/inference_cuda_layers.h | 7 +- .../inference/transformer_inference.py | 105 ++++++------ 8 files changed, 270 insertions(+), 167 deletions(-) mode change 100755 => 100644 deepspeed/ops/transformer/inference/transformer_inference.py diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index e7279d77e985..4a91975a73ca 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -18,7 +18,8 @@ __global__ void apply_rotary_pos_emb(float* mixed_query, unsigned seq_offset, unsigned num_heads, unsigned head_size, - unsigned total_count) + unsigned total_count, + int max_out_tokens) { cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile g = cg::tiled_partition(b); @@ -31,13 +32,15 @@ __global__ void apply_rotary_pos_emb(float* mixed_query, unsigned offset = head_id * head_size; unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; + unsigned seq_index = head_id % seq_len; + unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; if (head_id < total_count) { while (lane < rotary_dim) { float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; float q = mixed_query[offset + lane]; - float k = key_layer[offset + lane]; + float k = key_layer[k_offset + lane]; float rotary_sign = (lane % 2 == 1 ? 
-1.0 : 1.0); float q_rot = (q * rotary_sign); float k_rot = (k * rotary_sign); @@ -47,7 +50,7 @@ __global__ void apply_rotary_pos_emb(float* mixed_query, k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); mixed_query[offset + lane] = q; - key_layer[offset + lane] = k; + key_layer[k_offset + lane] = k; lane += WARP_SIZE; } @@ -61,7 +64,8 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query, unsigned seq_offset, unsigned num_heads, unsigned head_size, - unsigned total_count) + unsigned total_count, + int max_out_tokens) { #if __CUDA_ARCH__ >= 700 cg::thread_block b = cg::this_thread_block(); @@ -75,13 +79,15 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query, unsigned offset = head_id * head_size; unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; + unsigned seq_index = head_id % seq_len; + unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; if (head_id < total_count) { while (lane < rotary_dim) { float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; float q = (float)mixed_query[offset + lane]; - float k = (float)key_layer[offset + lane]; + float k = (float)key_layer[k_offset + lane]; float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0); float q_rot = (q * rotary_sign); float k_rot = (k * rotary_sign); @@ -91,7 +97,7 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query, k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); mixed_query[offset + lane] = (__half)q; - key_layer[offset + lane] = (__half)k; + key_layer[k_offset + lane] = (__half)k; lane += WARP_SIZE; } @@ -105,7 +111,8 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query, unsigned seq_offset, unsigned num_heads, unsigned head_size, - unsigned total_count) + unsigned total_count, + int max_out_tokens) { cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile g = cg::tiled_partition(b); @@ -118,13 +125,15 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query, unsigned offset = head_id * head_size; unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; + unsigned seq_index = head_id % seq_len; + unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; if (head_id < total_count) { while (lane < rotary_dim) { float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; float q = mixed_query[offset + lane]; - float k = key_layer[offset + lane]; + float k = key_layer[k_offset + lane]; float rotary_sign = (lane % 2 == 1 ? 
-1.0 : 1.0); float q_rot = (q * rotary_sign); float k_rot = (k * rotary_sign); @@ -134,7 +143,7 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query, k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); mixed_query[offset + lane] = q; - key_layer[offset + lane] = k; + key_layer[k_offset + lane] = k; lane += WARP_SIZE; } @@ -147,7 +156,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, unsigned seq_offset, unsigned num_heads, unsigned head_size, - unsigned total_count) + unsigned total_count, + int max_out_tokens) { #if __CUDA_ARCH__ >= 700 cg::thread_block b = cg::this_thread_block(); @@ -160,7 +170,7 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid; unsigned seq_index = head_id % seq_len; unsigned offset = head_id * head_size; - unsigned k_offset = (seq_index + (head_id / seq_len) * MAX_OUT_TOKES) * head_size; + unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size; constexpr unsigned mask[32] = { 0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000, @@ -209,17 +219,32 @@ void launch_apply_rotary_pos_emb(T* mixed_query, unsigned batch, bool rotate_half, bool rotate_every_two, - cudaStream_t stream) + cudaStream_t stream, + int max_out_tokens) { int total_count = batch * num_heads * seq_len; dim3 block_dims(1024); dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size); if (rotate_every_two) - apply_rotary_pos_emb<<>>( - mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count); + apply_rotary_pos_emb<<>>(mixed_query, + key_layer, + rotary_dim, + seq_len, + offset, + num_heads, + head_size, + total_count, + max_out_tokens); else if (rotate_half) - apply_rotary_pos_emb1<<>>( - mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count); + apply_rotary_pos_emb1<<>>(mixed_query, + key_layer, + rotary_dim, + seq_len, + offset, + num_heads, + head_size, + total_count, + max_out_tokens); } template void launch_apply_rotary_pos_emb(float*, @@ -232,7 +257,8 @@ template void launch_apply_rotary_pos_emb(float*, unsigned, bool, bool, - cudaStream_t); + cudaStream_t, + int); template void launch_apply_rotary_pos_emb<__half>(__half*, __half*, unsigned, @@ -243,7 +269,8 @@ template void launch_apply_rotary_pos_emb<__half>(__half*, unsigned, bool, bool, - cudaStream_t); + cudaStream_t, + int); /* __global__ void apply_rotary_pos_emb(float* mixed_query, diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu index 3018845bc3b8..3843c2b6ea8b 100644 --- a/csrc/transformer/inference/csrc/dequantize.cu +++ b/csrc/transformer/inference/csrc/dequantize.cu @@ -152,10 +152,6 @@ __global__ void dequantize_kernel(__half* output, q_h[1] = __float2half(local_scale * (float)q_int8[1]); q_h[2] = __float2half(local_scale * (float)q_int8[2]); q_h[3] = __float2half(local_scale * (float)q_int8[3]); - // q_h[4] = __float2half(local_scale * (float)q_int8[4]); - // q_h[5] = __float2half(local_scale * (float)q_int8[5]); - // q_h[6] = __float2half(local_scale * (float)q_int8[6]); - // q_h[7] = __float2half(local_scale * (float)q_int8[7]); output_cast[tid] = q_f; tid += blockDim.x; } diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index 3f9bf4ca740a..8bc58769ede7 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -188,7 +188,7 @@ __global__ void fused_bias_residual(float* input, data.z = 
data.z + out.z + bias_data.z; data.w = data.w + out.w + bias_data.w; } - output_cast[offset] = data; + input_cast[offset] = data; } } @@ -260,7 +260,7 @@ __global__ void fused_bias_residual(__half* input, vals_half[0] = __float22half2_rn(low_data); vals_half[1] = __float22half2_rn(high_data); - output_cast[offset] = vals_vec; + input_cast[offset] = vals_vec; } #endif } @@ -324,7 +324,7 @@ __global__ void gptj_residual_add(float* input, data.z = out.z + res_vec.z + (data.z + bias_data.z) * mp_scale; data.w = out.w + res_vec.w + (data.w + bias_data.w) * mp_scale; - output_cast[offset] = data; + input_cast[offset] = data; } } @@ -390,7 +390,7 @@ __global__ void gptj_residual_add(__half* input, vals_half[0] = __float22half2_rn(low_data); vals_half[1] = __float22half2_rn(high_data); - output_cast[offset] = vals_vec; + input_cast[offset] = vals_vec; } #endif } diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 97a012147a5f..65549cdcd71a 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -102,18 +102,14 @@ at::Tensor ds_softmax(at::Tensor& attn_scores, template void allocate_workspace(size_t hidden_dim, - size_t prompt_len, - size_t max_seq_len, size_t batch_size, unsigned num_layers, - size_t head_size = 128) + unsigned mp_size = 1, + bool external_cache = false, + unsigned rank = 0) { - size_t _workSpaceSize = 16 * (hidden_dim * batch_size * max_seq_len) + - (num_layers * batch_size * max_seq_len * hidden_dim * 2); // KV-cache - size_t prompt_output = batch_size * head_size * prompt_len * prompt_len; - _workSpaceSize = _workSpaceSize + prompt_output; - - Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T)); + Context::Instance().GenWorkSpace( + num_layers, batch_size, hidden_dim, mp_size, external_cache, sizeof(T), rank); } template @@ -131,8 +127,8 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) /* // Reallocate memory if we received a new prompt if (!workspace || input.size(1) != 1) { - allocate_workspace(W.size(1), MAX_OUT_TOKES, Q.size(0), 1, head_size); - workspace = (T*)Context::Instance().GetWorkSpace(); + allocate_workspace(W.size(1), Context::Instance().GetMaxTokenLenght(), Q.size(0), 1, + head_size); workspace = (T*)Context::Instance().GetWorkSpace(); } */ @@ -362,8 +358,8 @@ void attention_unfused(T* prev_key_cont, } else { // If we are doing the prompt, switch to the tail workspace T* scratch = (T*)Context::Instance().GetWorkSpace(); - workspace = scratch + (Context::Instance().get_workspace_size() / sizeof(T)) - - bsz * heads * seq_len * seq_len; + workspace = scratch + ((Context::Instance().get_workspace_size() / sizeof(T)) - + bsz * heads * seq_len * soft_len); } cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); @@ -378,7 +374,7 @@ void attention_unfused(T* prev_key_cont, workspace, CUBLAS_OP_T, CUBLAS_OP_N, - MAX_OUT_TOKES * k, + Context::Instance().GetMaxTokenLenght() * k, seq_len * k, seq_len * soft_len, bsz * heads, @@ -411,7 +407,7 @@ void attention_unfused(T* prev_key_cont, (T*)output, CUBLAS_OP_N, CUBLAS_OP_N, - MAX_OUT_TOKES * k, + Context::Instance().GetMaxTokenLenght() * k, seq_len * soft_len, seq_len * k, bsz * heads, @@ -459,12 +455,11 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options); auto query_cont = workspace + 8 * buf_size; - size_t offset = - 16 * (hidden_dim * 
bsz * MAX_OUT_TOKES) + layer_id * 2 * bsz * MAX_OUT_TOKES * hidden_dim; - + size_t offset = 16 * (hidden_dim * bsz * Context::Instance().GetMaxTokenLenght()) + + layer_id * 2 * bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim; unsigned all_tokens = soft_len; auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1); - size_t value_offset = bsz * MAX_OUT_TOKES * hidden_dim; + size_t value_offset = bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim; T* temp_buf = (T*)output.data_ptr() + at::numel(output); launch_bias_add_transform_0213((T*)query_cont, @@ -482,7 +477,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, rotate_half, rotate_every_two, Context::Instance().GetCurrentStream(), - 3); + 3, + Context::Instance().GetMaxTokenLenght()); if (rotary_dim > 0 && rotate_half) launch_apply_rotary_pos_emb(query_cont, kv_cache, @@ -494,7 +490,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, bsz, rotate_half, rotate_every_two, - Context::Instance().GetCurrentStream()); + Context::Instance().GetCurrentStream(), + Context::Instance().GetMaxTokenLenght()); attention_unfused(workspace + offset, (T*)query_cont, @@ -629,16 +626,17 @@ void ds_layernorm_internal(T* workspace, } template -void quantized_gemm(at::Tensor& output, +void quantized_gemm(void* output, T* input, at::Tensor& weight, at::Tensor& qscale, int groups, int bsz) { - auto weight16 = at::empty({weight.size(0), weight.size(1)}, output.options()); + T* weight16 = (T*)Context::Instance().GetWorkSpace() + + 12 * Context::Instance().GetMaxTokenLenght() * weight.size(1); - launch_dequantize((T*)weight16.data_ptr(), + launch_dequantize(weight16, (int8_t*)weight.data_ptr(), (float*)qscale.data_ptr(), weight.size(0), @@ -656,9 +654,9 @@ void quantized_gemm(at::Tensor& output, weight.size(1), &alpha, &gemm_beta, - (T*)weight16.data_ptr(), + weight16, (T*)input, - (T*)output.data_ptr(), + (T*)output, #ifdef __HIP_PLATFORM_HCC__ rocblas_gemm_algo_standard); #else @@ -682,10 +680,9 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, T* workspace = (T*)Context::Instance().GetWorkSpace(); workspace += (3 * bsz * input.size(2)); ds_layernorm_internal(workspace, input, gamma, beta, epsilon); - // cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream()); if (q_int8) { - quantized_gemm(output, workspace, weight, q_scale, q_scale.size(0), bsz); + quantized_gemm(output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -728,22 +725,20 @@ std::vector ds_qkv_gemm(at::Tensor& input, const float epsilon, bool add_bias, unsigned num_layers, - bool is_bloom, - unsigned head_size, + bool external_cache, + unsigned mp_size, + unsigned rank, bool q_int8) { int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); int out_size = q_int8 ? weight.size(0) : weight.size(1); - // Reallocate memory if we receive a new prompt - if (!workspace || input.size(1) != 1) { + if (!workspace) cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - const int max_seq_len = (is_bloom) ? 
128 : MAX_OUT_TOKES; - allocate_workspace( - input.size(2), input.size(1), max_seq_len, input.size(0), num_layers, head_size); - workspace = (T*)Context::Instance().GetWorkSpace(); - } + allocate_workspace(input.size(2), input.size(0), num_layers, mp_size, external_cache, rank); + workspace = (T*)Context::Instance().GetWorkSpace(); + auto options = at::TensorOptions() .dtype(input.options().dtype()) .layout(at::kStrided) @@ -840,8 +835,7 @@ template at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias, - unsigned num_layers, - unsigned head_size) + unsigned num_layers) { auto input_cont = input.contiguous(); auto options = at::TensorOptions() @@ -856,8 +850,7 @@ at::Tensor ds_linear_layer(at::Tensor& input, if (!workspace || input.size(1) != 1) { cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - allocate_workspace( - input.size(2), input.size(1), MAX_OUT_TOKES, input.size(0), num_layers, head_size); + allocate_workspace(input.size(2), input.size(0), num_layers); workspace = (T*)Context::Instance().GetWorkSpace(); } auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); @@ -925,18 +918,20 @@ at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& q_scale, bool q_int8) { - auto input_cont = input.contiguous(); auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) + .dtype(input.options().dtype()) .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); int out_size = q_int8 ? weight.size(0) : weight.size(1); - int bsz = input_cont.size(0) * input_cont.size(1); - auto output = at::empty({input_cont.size(0), input_cont.size(1), out_size}, options); + int bsz = input.size(0) * input.size(1); + + T* workspace = (T*)Context::Instance().GetWorkSpace(); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); if (q_int8) { - quantized_gemm(output, (T*)input_cont.data_ptr(), weight, q_scale, q_scale.size(0), bsz); + quantized_gemm( + output.data_ptr(), (T*)input.data_ptr(), weight, q_scale, q_scale.size(0), bsz); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -947,11 +942,11 @@ at::Tensor ds_vector_matmul(at::Tensor& input, CUBLAS_OP_N, weight.size(1), bsz, - input_cont.size(2), + input.size(2), &alpha, &gemm_beta, (T*)weight.data_ptr(), - (T*)input_cont.data_ptr(), + (T*)input.data_ptr(), (T*)output.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ rocblas_gemm_algo_standard); @@ -988,6 +983,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, at::Tensor& residual, at::Tensor& input_bias, at::Tensor& weight, + at::Tensor& weight1, at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta, @@ -995,13 +991,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, bool preLayerNorm, bool mlp_after_attn, at::Tensor& q_scale, + at::Tensor& q_scale1, bool q_int8, ActivationFuncType act_func_type) { int bsz = input.size(0) * input.size(1); - auto inp_norm = at::empty_like(input); - - launch_residual_layer_norm((T*)inp_norm.data_ptr(), + T* inp_norm = + (T*)Context::Instance().GetWorkSpace() + torch::numel(input) + torch::numel(output); + T* intermediate = inp_norm + torch::numel(input); + launch_residual_layer_norm((T*)inp_norm, (T*)nullptr, (T*)input.data_ptr(), (T*)residual.data_ptr(), @@ -1016,7 +1014,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, Context::Instance().GetCurrentStream()); if (q_int8) { - quantized_gemm(output, (T*)inp_norm.data_ptr(), weight, q_scale, q_scale.size(0), bsz); + 
quantized_gemm(intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1031,8 +1029,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, &alpha, &gemm_beta, (T*)weight.data_ptr(), - (T*)inp_norm.data_ptr(), - (T*)output.data_ptr(), + inp_norm, + intermediate, #ifdef __HIP_PLATFORM_HCC__ rocblas_gemm_algo_standard); #else @@ -1040,20 +1038,45 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, #endif } if (act_func_type == ActivationFuncType::GELU) { - launch_bias_gelu((T*)output.data_ptr(), + launch_bias_gelu(intermediate, (T*)bias.data_ptr(), q_int8 ? weight.size(0) : weight.size(1), bsz, Context::Instance().GetCurrentStream()); } else if (act_func_type == ActivationFuncType::ReLU) { - launch_bias_relu((T*)output.data_ptr(), + launch_bias_relu(intermediate, (T*)bias.data_ptr(), q_int8 ? weight.size(0) : weight.size(1), bsz, Context::Instance().GetCurrentStream()); } + if (q_int8) { + quantized_gemm( + output.data_ptr(), intermediate, weight1, q_scale1, q_scale1.size(0), bsz); + } else { + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight1.size(1), + bsz, + weight1.size(0), + &alpha, + &gemm_beta, + (T*)weight1.data_ptr(), + intermediate, + (T*)output.data_ptr(), +#ifdef __HIP_PLATFORM_HCC__ + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + } - return inp_norm; + return torch::from_blob(inp_norm, input.sizes(), input.options()); } template @@ -1061,6 +1084,7 @@ std::vector ds_mlp_gemm(at::Tensor& input, at::Tensor& residual, at::Tensor& input_bias, at::Tensor& weight, + at::Tensor& weight1, at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta, @@ -1068,21 +1092,21 @@ std::vector ds_mlp_gemm(at::Tensor& input, bool preLayerNorm, bool mlp_after_attn, at::Tensor& q_scale, + at::Tensor& q_scale1, bool q_int8, int activation_type) { - auto input_cont = input.contiguous(); auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) + .dtype(input.options().dtype()) .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); int out_size = q_int8 ? 
weight.size(0) : weight.size(1); - auto output = at::from_blob((T*)Context::Instance().GetWorkSpace(), - {input_cont.size(0), input_cont.size(1), out_size}, + auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input), + {input.size(0), input.size(1), out_size}, options); - int bsz = input_cont.size(0) * input_cont.size(1); + int bsz = input.size(0) * input.size(1); auto act_func_type = static_cast(activation_type); auto res_add = mlp_unfused_cublas(output, @@ -1090,6 +1114,7 @@ std::vector ds_mlp_gemm(at::Tensor& input, residual, input_bias, weight, + weight1, bias, gamma, beta, @@ -1097,6 +1122,7 @@ std::vector ds_mlp_gemm(at::Tensor& input, preLayerNorm, mlp_after_attn, q_scale, + q_scale1, q_int8, act_func_type); @@ -1207,7 +1233,7 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, template at::Tensor& residual_add_bias(at::Tensor& hidden_state, - const at::Tensor& residual, + at::Tensor& residual, const at::Tensor& attention_output, const at::Tensor& attention_bias, const at::Tensor& final_bias, @@ -1240,7 +1266,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state, bsz, mp_size, Context::Instance().GetCurrentStream()); - return hidden_state; + return residual; } std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, @@ -1269,7 +1295,8 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, bsz, rotate_half, rotate_every_two, - Context::Instance().GetCurrentStream()); + Context::Instance().GetCurrentStream(), + Context::Instance().GetMaxTokenLenght()); else launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(), (__half*)key_cont.data_ptr(), @@ -1281,7 +1308,8 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, bsz, rotate_half, rotate_every_two, - Context::Instance().GetCurrentStream()); + Context::Instance().GetCurrentStream(), + Context::Instance().GetMaxTokenLenght()); return {query_cont, key_cont}; } diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index e72df0f14232..32d2df95be63 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -22,7 +22,8 @@ __global__ void bias_add_transform_0213(float* output, int rotary_dim, bool rotate_half, bool rotate_every_two, - int head_ext) + int head_ext, + int max_out_tokens) { int d0_stride = hidden_dim * seq_length; int d1_stride = hidden_dim; @@ -34,8 +35,8 @@ __global__ void bias_add_transform_0213(float* output, int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) int d3 = threadIdx.x; // Values (groups of 4) - int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : MAX_OUT_TOKES); - int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : MAX_OUT_TOKES); + int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens); + int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens); const float4* vals_vec = reinterpret_cast(vals); float4* output_vec = @@ -86,7 +87,8 @@ __global__ void bias_add_transform_0213(__half* output, // q int rotary_dim, bool rotate_half, bool rotate_every_two, - int head_ext) + int head_ext, + int max_out_tokens) { #if __CUDA_ARCH__ >= 700 @@ -101,8 +103,8 @@ __global__ void bias_add_transform_0213(__half* output, // q int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) int d3 = threadIdx.x; // Values (groups of 4) - int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : MAX_OUT_TOKES); - int d0_out_stride = hidden_dim * (cnt == 0 ? 
seq_length : MAX_OUT_TOKES); + int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens); + int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens); float4 vals_arr; float4 output_arr; @@ -165,7 +167,8 @@ void launch_bias_add_transform_0213(float* output, bool rotate_half, bool rotate_every_two, cudaStream_t stream, - int trans_count) + int trans_count, + int max_out_tokens) { hidden_dim >>= 2; int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; @@ -185,7 +188,8 @@ void launch_bias_add_transform_0213(float* output, rotary_dim >> 2, rotate_half, rotate_every_two, - head_ext); + head_ext, + max_out_tokens); } template void launch_bias_add_transform_0213(T* outputs, @@ -203,7 +207,8 @@ void launch_bias_add_transform_0213(T* outputs, bool rotate_half, bool rotate_every_two, cudaStream_t stream, - int trans_count); + int trans_count, + int max_out_tokens); template <> void launch_bias_add_transform_0213<__half>(__half* output, __half* k_cache, @@ -220,7 +225,8 @@ void launch_bias_add_transform_0213<__half>(__half* output, bool rotate_half, bool rotate_every_two, cudaStream_t stream, - int trans_count) + int trans_count, + int max_out_tokens) { hidden_dim >>= 3; int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1; @@ -239,7 +245,8 @@ void launch_bias_add_transform_0213<__half>(__half* output, rotary_dim >> 3, rotate_half, rotate_every_two, - head_ext); + head_ext, + max_out_tokens); } // Bias add diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index f314f4485d6f..fe0616a75c2e 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -12,6 +12,10 @@ Copyright 2022 The Microsoft DeepSpeed Team #include "cublas_v2.h" #include "cuda.h" +#define MEGABYTE (1024 * 1024) +#define GIGABYTE (1024 * 1024 * 1024) + +#define MAX_OUT_TOKENS 8192 #define WARP_SIZE 32 #define CUDA_CHECK(callstr) \ @@ -43,7 +47,13 @@ inline int DS_GET_BLOCKS(const int N) class Context { public: - Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0) + Context() + : _workspace(nullptr), + _seed(42), + _curr_offset(0), + _stream(0), + _free_memory_size(0), + _num_tokens(1) { if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) { auto message = std::string("Fail to create cublas handle."); @@ -75,24 +85,51 @@ class Context { return _ctx; } - void GenWorkSpace(size_t size) + void GenWorkSpace(const unsigned& num_layers, + const size_t& batch_size, + const size_t& hidden_dim, + const unsigned& mp_size, + const bool& external_cache, + const size_t& elem_size, + const unsigned& rank) { + size_t total_size; + if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); } + + size_t activation_size = 16 * hidden_dim * batch_size; + size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2; + _max_seq_len = + (((_free_memory_size - (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE) / + elem_size)) / + (activation_size + cache_size); + size_t workSpaceSize = (external_cache ? 
activation_size : (activation_size + cache_size)) * + _max_seq_len * elem_size; + _max_seq_len = std::min((size_t)MAX_OUT_TOKENS, _max_seq_len); + if (rank == 0 && !_workspace) + printf( + "Free memory : %lu (Bytes) Total memory: %lu (Bytes) Setting maximum total " + "tokens (input + output) to %lu \n", + _free_memory_size, + total_size, + _max_seq_len); if (!_workspace) { assert(_workspace == nullptr); - cudaMalloc(&_workspace, size); - } else if (_workSpaceSize < size) { + cudaMalloc(&_workspace, workSpaceSize); + } else if (_workSpaceSize < workSpaceSize) { cudaFree(_workspace); - cudaMalloc(&_workspace, size); + cudaMalloc(&_workspace, workSpaceSize); } if (!_workspace) { - size_t total_size, free_size; - cudaMemGetInfo(&free_size, &total_size); - printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", size, free_size, total_size); + printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", + workSpaceSize, + _free_memory_size, + total_size); throw std::runtime_error("Workspace is null."); } - _workSpaceSize = size; + _workSpaceSize = workSpaceSize; } + inline size_t GetMaxTokenLenght() const { return _max_seq_len; } cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; } @@ -105,7 +142,7 @@ class Context { return _token_length; } - inline void reset_tokens(unsigned initial_tokens = 0) + inline void reset_tokens(unsigned initial_tokens = 1) { _num_tokens = initial_tokens; } //_token_length = 0; } @@ -165,7 +202,11 @@ class Context { void* _workspace; uint64_t _seed; uint64_t _curr_offset; + size_t _workSpaceSize; + size_t _free_memory_size; + + size_t _max_seq_len; cudaEvent_t _comp1_event; cudaEvent_t _comp2_event; diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h index 0391da4806f6..6302ceb2935d 100644 --- a/csrc/transformer/inference/includes/inference_cuda_layers.h +++ b/csrc/transformer/inference/includes/inference_cuda_layers.h @@ -21,7 +21,6 @@ Copyright 2022 The Microsoft DeepSpeed Team #include #include -#define MAX_OUT_TOKES 512 #define MAX_WARP_NUM 32 #define WARP_SIZE 32 @@ -142,7 +141,8 @@ void launch_apply_rotary_pos_emb(T* mixed_query, unsigned batch, bool rotate_half, bool rotate_every_two, - cudaStream_t stream); + cudaStream_t stream, + int max_out_tokens); template void launch_moe_res_matmul(T* residual, @@ -178,4 +178,5 @@ void launch_bias_add_transform_0213(T* outputs, bool rotate_half, bool rotate_every_two, cudaStream_t stream, - int trans_count); + int trans_count, + int max_out_tokens); diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py old mode 100755 new mode 100644 index e7e4c16645f5..2fd8e7f4dcc0 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -279,6 +279,14 @@ def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): ###################### End of HF modeling_bloom addition ######################## + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + def compute_attention(qkv_out, input_mask): no_masking = input_mask is None @@ -369,6 +377,9 @@ def compute_attention(qkv_out, input_mask): return context_layer, 
presents[0], presents[1] # atten_output, key_layer, value_layer else: + #query = self._split_heads(query, self.num_heads, self.head_dim) + #key = self._split_heads(key, self.num_heads, self.head_dim) + #value = self._split_heads(value, self.num_heads, self.head_dim) # Note: This modification is added for the BLOOM-176B model and will be removed later! if config.bigscience_bloom: context_layer, presents = backup_attention(qkv_out, layer_past, alibi, input_mask, norm_factor) @@ -380,6 +391,8 @@ def compute_attention(qkv_out, input_mask): ) else 0 sliced_alibi = alibi[offset:batch_heads + offset, :, :] + +# attn_key_value = score_context_func( qkv_out, ((1 - input_mask).to(qkv_out.dype) * @@ -403,7 +416,6 @@ def selfAttention_fp(): vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \ inference_cuda_module.vector_matmul_fp32 - head_size = (attn_qkvw.shape[-1] // 3 // num_attention_heads_per_partition) if not config.pre_layer_norm: linear_func = inference_cuda_module.linear_layer_fp16 if config.fp16 else \ inference_cuda_module.linear_layer_fp32 @@ -411,25 +423,23 @@ def selfAttention_fp(): qkv_out = linear_func(input, attn_qkvw, attn_qkvb, - DeepSpeedTransformerInference.layer_id, - head_size) + DeepSpeedTransformerInference.layer_id) else: qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \ inference_cuda_module.qkv_gemm_fp32 - qkv_out = qkv_func( - input, - attn_qkvw, - attn_qkvw.scale, - (attn_qkvb if attn_qkvb is not None else norm_b), - norm_w, - norm_b, - config.epsilon, - (attn_qkvb is not None), - 1 if config.bigscience_bloom else - DeepSpeedTransformerInference.layer_id, - config.bigscience_bloom, - head_size, - config.q_int8) + qkv_out = qkv_func(input, + attn_qkvw, + attn_qkvw.scale, + (attn_qkvb if attn_qkvb is not None else norm_b), + norm_w, + norm_b, + config.epsilon, + (attn_qkvb is not None), + DeepSpeedTransformerInference.layer_id, + config.bigscience_bloom, + config.mp_size, + dist.get_rank() if dist.is_initialized() else 0, + config.q_int8) context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask) output = vector_matmul_func(context_layer, attn_ow, @@ -500,20 +510,18 @@ def __init__(self, self.config.layer_id = DeepSpeedSelfAttention.num_layers DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1 device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' - self.attn_qkvw = nn.Parameter(torch.empty( - self.config.hidden_size, - (self.config.hidden_size // self.config.mp_size) * 3, - dtype=data_type, - device=device), + qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 + self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, + qkv_size_per_partition, + dtype=data_type, + device=device), requires_grad=False) - self.attn_qkvb = nn.Parameter(torch.empty( - (self.config.hidden_size // self.config.mp_size) * 3, - dtype=data_type_fp, - device=device), + self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, + dtype=data_type_fp, + device=device), requires_grad=False) - - self.attn_ow = nn.Parameter(torch.empty(self.config.hidden_size // - self.config.mp_size, + out_size_per_partition = self.config.hidden_size // self.config.mp_size + self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition, self.config.hidden_size, dtype=data_type, device=device), @@ -618,10 +626,11 @@ def forward(ctx, config.pre_layer_norm, False) else: - intermediate, residual_add = mlp_gemm_func(input, + 
output, residual_add = mlp_gemm_func(input, residual, bias, inter_w, + output_w, inter_b, attn_nw, attn_nb, @@ -629,26 +638,22 @@ def forward(ctx, config.pre_layer_norm, config.mlp_after_attn, inter_w.scale, + output_w.scale, config.q_int8, config.mlp_act_func_type) - output = vector_matmul_func(intermediate, - output_w, - False, - output_w.scale, - config.q_int8) - output = residual_add_func( - output, # hidden state - residual if config.pre_layer_norm else residual_add, # residual - input, # attention output - bias if bias is not None else output_b, + residual_add_func( + output, # hidden state + residual if config.pre_layer_norm else residual_add, # residual + input, # attention output output_b, + bias if bias is not None else output_b, config.mp_size, # model parallel size config.mlp_after_attn, # whether mlp is after attention (GPTJ model architecture runs the MLP layer in parallel with attention) bias is not None, # whether bias addition is fused config.pre_layer_norm) # whether the layer norm is applied before attention if mp_group is not None and dist.get_world_size(group=mp_group) > 1: - dist.all_reduce(output, group=mp_group) - return output + dist.all_reduce(residual, group=mp_group) + return residual @staticmethod def backward(ctx, grad_output): @@ -678,22 +683,20 @@ def __init__(self, dtype=data_type_fp, device=device), requires_grad=False) + intm_size_per_partition = self.config.intermediate_size // self.config.mp_size self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size, - self.config.intermediate_size // - self.config.mp_size, + intm_size_per_partition, dtype=data_type, device=device), requires_grad=False) - self.inter_b = nn.Parameter(torch.empty(self.config.intermediate_size // - self.config.mp_size, + self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device), requires_grad=False) - self.output_w = nn.Parameter(torch.empty( - (self.config.intermediate_size // self.config.mp_size), - self.config.hidden_size, - dtype=data_type, - device=device), + self.output_w = nn.Parameter(torch.empty(intm_size_per_partition, + self.config.hidden_size, + dtype=data_type, + device=device), requires_grad=False) self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, From 48a8b96995de73f6b0f024dab97a2fd94b9eefa3 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 23 Sep 2022 02:19:18 +0500 Subject: [PATCH 11/12] fix bert issue & remove some dead code --- .../inference/transformer_inference.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index 2fd8e7f4dcc0..d0c368d8c66a 100644 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -278,15 +278,7 @@ def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): return context_layer, presents ###################### End of HF modeling_bloom addition ######################## - - def _split_heads(self, tensor, num_heads, attn_head_size): - """ - Splits hidden_size dim into attn_head_size and num_heads - """ - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) - + def compute_attention(qkv_out, input_mask): no_masking = input_mask is None @@ -377,9 +369,6 @@ def 
compute_attention(qkv_out, input_mask): return context_layer, presents[0], presents[1] # atten_output, key_layer, value_layer else: - #query = self._split_heads(query, self.num_heads, self.head_dim) - #key = self._split_heads(key, self.num_heads, self.head_dim) - #value = self._split_heads(value, self.num_heads, self.head_dim) # Note: This modification is added for the BLOOM-176B model and will be removed later! if config.bigscience_bloom: context_layer, presents = backup_attention(qkv_out, layer_past, alibi, input_mask, norm_factor) @@ -641,12 +630,13 @@ def forward(ctx, output_w.scale, config.q_int8, config.mlp_act_func_type) + residual = residual if config.pre_layer_norm else residual_add residual_add_func( output, # hidden state - residual if config.pre_layer_norm else residual_add, # residual + residual, # residual input, # attention output - output_b, bias if bias is not None else output_b, + output_b, config.mp_size, # model parallel size config.mlp_after_attn, # whether mlp is after attention (GPTJ model architecture runs the MLP layer in parallel with attention) bias is not None, # whether bias addition is fused From c1d83f99927ecba7869960d75d29117665f69780 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 23 Sep 2022 02:26:49 +0500 Subject: [PATCH 12/12] fix formating --- deepspeed/ops/transformer/inference/transformer_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index d0c368d8c66a..ff50d43d8518 100644 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -278,7 +278,7 @@ def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): return context_layer, presents ###################### End of HF modeling_bloom addition ######################## - + def compute_attention(qkv_out, input_mask): no_masking = input_mask is None
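
For reference, the central heuristic of PATCH 10 is that the maximum number
of total tokens (input + output) is no longer a fixed MAX_OUT_TOKES constant
but is derived from free GPU memory inside Context::GenWorkSpace. A minimal
standalone recap of that sizing, with the per-token activation and KV-cache
footprints, the safety margin and the 8192-token cap taken from the
inference_context.h hunk above (the function name and signature are
illustrative, not the DeepSpeed API; the real code also sizes the workspace
buffer from the same terms):

    #include <algorithm>
    #include <cstddef>

    static size_t max_total_tokens(size_t free_bytes,   // from cudaMemGetInfo
                                   size_t hidden_dim,
                                   size_t batch_size,
                                   size_t num_layers,
                                   size_t mp_size,
                                   size_t elem_size)    // sizeof(T)
    {
        const size_t MB = 1024 * 1024;
        const size_t GB = 1024 * MB;
        const size_t MAX_OUT_TOKENS = 8192;

        // Activations: ~16 * hidden_dim elements per token per batch element.
        size_t activation_size = 16 * hidden_dim * batch_size;
        // KV-cache: 2 * (hidden_dim / mp_size) elements per token per layer
        // per batch element.
        size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2;
        // Hold back a safety margin (500 MB, or 100 MB on small cards), then
        // divide what is left by the per-token footprint.
        size_t margin = (free_bytes > GB ? 500 : 100) * MB;
        size_t max_seq_len = ((free_bytes - margin) / elem_size) /
                             (activation_size + cache_size);
        return std::min(MAX_OUT_TOKENS, max_seq_len);
    }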