From 66b13ad14560fb003c2a0743942a1be308e6e18b Mon Sep 17 00:00:00 2001 From: Andrey Chernykh Date: Mon, 26 Sep 2022 12:10:55 +0700 Subject: [PATCH 1/4] Fix issue with corrupted output on long generation for GPT --- csrc/transformer/inference/csrc/pt_binding.cpp | 2 +- csrc/transformer/inference/includes/inference_context.h | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 65549cdcd71a..f3bd4bd40f3a 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -461,7 +461,7 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1); size_t value_offset = bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim; - T* temp_buf = (T*)output.data_ptr() + at::numel(output); + T* temp_buf = (T*)kv_cache+value_offset+value_offset; launch_bias_add_transform_0213((T*)query_cont, kv_cache, kv_cache + value_offset, diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index fe0616a75c2e..2423643b677e 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -97,12 +97,13 @@ class Context { if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); } size_t activation_size = 16 * hidden_dim * batch_size; + size_t temp_size = 2 * hidden_dim * batch_size; size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2; _max_seq_len = (((_free_memory_size - (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE) / elem_size)) / (activation_size + cache_size); - size_t workSpaceSize = (external_cache ? activation_size : (activation_size + cache_size)) * + size_t workSpaceSize = ((external_cache ? activation_size : (activation_size + cache_size)) + temp_size) * _max_seq_len * elem_size; _max_seq_len = std::min((size_t)MAX_OUT_TOKENS, _max_seq_len); if (rank == 0 && !_workspace) From 8d8be855c5a1eede761bc95a47eedcf1321da4af Mon Sep 17 00:00:00 2001 From: Andrey Chernykh Date: Wed, 28 Sep 2022 18:18:10 +0700 Subject: [PATCH 2/4] Move workspace allocation to TransformerInference --- .../transformer/inference/csrc/pt_binding.cpp | 33 ++++++------------- .../inference/includes/inference_context.h | 22 +++++++++---- .../inference/transformer_inference.py | 15 +++++++++ 3 files changed, 40 insertions(+), 30 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index f3bd4bd40f3a..9ba4d2a0f8dd 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -103,13 +103,15 @@ at::Tensor ds_softmax(at::Tensor& attn_scores, template void allocate_workspace(size_t hidden_dim, size_t batch_size, + size_t prompt_length, unsigned num_layers, + unsigned num_heads, unsigned mp_size = 1, bool external_cache = false, unsigned rank = 0) { Context::Instance().GenWorkSpace( - num_layers, batch_size, hidden_dim, mp_size, external_cache, sizeof(T), rank); + num_layers, num_heads, batch_size, prompt_length, hidden_dim, mp_size, external_cache, sizeof(T), rank); } template @@ -352,15 +354,10 @@ void attention_unfused(T* prev_key_cont, float layer_scale = alibi.sizes().size() > 1 ? 
std::max(1, layer_id) : 1.0; float alpha = norm_factor * norm_factor / layer_scale; float gemm_beta = 0.0; - T* workspace; - if (seq_len == 1) { - workspace = (T*)output + bsz * seq_len * heads * k; - } else { - // If we are doing the prompt, switch to the tail workspace - T* scratch = (T*)Context::Instance().GetWorkSpace(); - workspace = scratch + ((Context::Instance().get_workspace_size() / sizeof(T)) - - bsz * heads * seq_len * soft_len); - } + // Always use the tail workspace + T* scratch = (T*)Context::Instance().GetWorkSpace(); + T *workspace = scratch + ((Context::Instance().get_workspace_size() / sizeof(T)) - + bsz * heads * seq_len * soft_len); cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), @@ -461,7 +458,7 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1); size_t value_offset = bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim; - T* temp_buf = (T*)kv_cache+value_offset+value_offset; + T* temp_buf = (T*)output.data_ptr() + at::numel(output); launch_bias_add_transform_0213((T*)query_cont, kv_cache, kv_cache + value_offset, @@ -733,11 +730,6 @@ std::vector ds_qkv_gemm(at::Tensor& input, int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); int out_size = q_int8 ? weight.size(0) : weight.size(1); - if (!workspace) - cublasSetStream(Context::Instance().GetCublasHandle(), - Context::Instance().GetCurrentStream()); - allocate_workspace(input.size(2), input.size(0), num_layers, mp_size, external_cache, rank); - workspace = (T*)Context::Instance().GetWorkSpace(); auto options = at::TensorOptions() .dtype(input.options().dtype()) @@ -846,13 +838,6 @@ at::Tensor ds_linear_layer(at::Tensor& input, int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); - // Reallocate memory if we received a new prompt - if (!workspace || input.size(1) != 1) { - cublasSetStream(Context::Instance().GetCublasHandle(), - Context::Instance().GetCurrentStream()); - allocate_workspace(input.size(2), input.size(0), num_layers); - workspace = (T*)Context::Instance().GetWorkSpace(); - } auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); float alpha = (T)1.0; @@ -1425,4 +1410,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) &einsum_sec_sm_ecm<__half>, "DeepSpeed vector-MM with fp16 (CUDA)"); m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)"); + m.def("allocate_workspace_fp32", &allocate_workspace, "DeepSpeed memory allocation for GPT inference with fp32 (CUDA)"); + m.def("allocate_workspace_fp16", &allocate_workspace<__half>, "DeepSpeed memory allocation for GPT inference with fp16 (CUDA)"); } diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index 2423643b677e..91a3812eda5b 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -86,7 +86,9 @@ class Context { } void GenWorkSpace(const unsigned& num_layers, + const unsigned& num_heads, const size_t& batch_size, + const size_t& prompt_len, const size_t& hidden_dim, const unsigned& mp_size, const bool& external_cache, @@ -97,15 +99,21 @@ class Context { if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); } size_t 
activation_size = 16 * hidden_dim * batch_size; - size_t temp_size = 2 * hidden_dim * batch_size; + size_t temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size; size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2; - _max_seq_len = - (((_free_memory_size - (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE) / - elem_size)) / - (activation_size + cache_size); - size_t workSpaceSize = ((external_cache ? activation_size : (activation_size + cache_size)) + temp_size) * - _max_seq_len * elem_size; + size_t minimal_requirements = temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE; + if (_free_memory_size < minimal_requirements) { + printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", + minimal_requirements, + _free_memory_size, + total_size); + throw std::runtime_error("Workspace can't be allocated, no enough memory."); + } + + _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) / (activation_size + cache_size); _max_seq_len = std::min((size_t)MAX_OUT_TOKENS, _max_seq_len); + size_t workSpaceSize = ((external_cache ? activation_size : (activation_size + cache_size))) * + _max_seq_len * elem_size + temp_size; if (rank == 0 && !_workspace) printf( "Free memory : %lu (Bytes) Total memory: %lu (Bytes) Setting maximum total " diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index ff50d43d8518..e1da6a430c3a 100644 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -799,6 +799,8 @@ def __init__(self, device=device), requires_grad=False) self.layer_past = None + self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \ + inference_cuda_module.allocate_workspace_fp16 def forward( self, @@ -820,6 +822,19 @@ def forward( # This needs to be redesigned later! layer_head_mask=None, past_key_value=None): + # Allocate memory only on first layer forward + if self.config.layer_id == 0: + self.allocate_workspace( + self.config.hidden_size, + input.size()[0], + input.size()[1], + DeepSpeedTransformerInference.layer_id, + self.config.heads, + self.config.mp_size, + self.config.bigscience_bloom, + dist.get_rank() if dist.is_initialized() else 0 + ) + get_present = (get_present or get_key_value or use_cache) input_mask = input_mask if attention_mask is None else attention_mask From da40919891533bb9edfd03380248790afdcdde0f Mon Sep 17 00:00:00 2001 From: Andrey Chernykh Date: Thu, 29 Sep 2022 07:41:56 +0700 Subject: [PATCH 3/4] Update attention_unfused required memory size --- csrc/transformer/inference/csrc/pt_binding.cpp | 5 +---- csrc/transformer/inference/includes/inference_context.h | 9 +++++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 9ba4d2a0f8dd..cfaf91de8a0e 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -354,10 +354,7 @@ void attention_unfused(T* prev_key_cont, float layer_scale = alibi.sizes().size() > 1 ? 
std::max(1, layer_id) : 1.0; float alpha = norm_factor * norm_factor / layer_scale; float gemm_beta = 0.0; - // Always use the tail workspace - T* scratch = (T*)Context::Instance().GetWorkSpace(); - T *workspace = scratch + ((Context::Instance().get_workspace_size() / sizeof(T)) - - bsz * heads * seq_len * soft_len); + T* workspace = (T*)Context::Instance().GetAttentionUnfusedWorkspace(); cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index 52388a5b00b2..612eaf31b0c9 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -55,7 +55,8 @@ class Context { _curr_offset(0), _stream(0), _free_memory_size(0), - _num_tokens(1) + _num_tokens(1), + _attention_unfused_workspace_offset(0) { if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) { auto message = std::string("Fail to create cublas handle."); @@ -101,7 +102,7 @@ class Context { if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); } size_t activation_size = 16 * hidden_dim * batch_size; - size_t temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size; + size_t temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size / mp_size; size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2; size_t minimal_requirements = temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE; if (_free_memory_size < minimal_requirements) { @@ -139,6 +140,7 @@ class Context { throw std::runtime_error("Workspace is null."); } _workSpaceSize = workSpaceSize; + _attention_unfused_workspace_offset = workSpaceSize - temp_size; } inline size_t GetMaxTokenLenght() const { return _max_seq_len; } @@ -146,6 +148,7 @@ class Context { size_t get_workspace_size() const { return _workSpaceSize; } void* GetWorkSpace() { return _workspace; } + void* GetAttentionUnfusedWorkspace() { return _workspace + _attention_unfused_workspace_offset; } inline unsigned new_token(unsigned layer_id) { @@ -211,6 +214,8 @@ class Context { cudaEvent_t _comm_event; void* _workspace; + // offset from _workspace for attention unfused memory + size_t _attention_unfused_workspace_offset; uint64_t _seed; uint64_t _curr_offset; From 693ad1781e43531a45a714be25e59da5e726210f Mon Sep 17 00:00:00 2001 From: Andrey Chernykh Date: Thu, 29 Sep 2022 09:00:46 +0700 Subject: [PATCH 4/4] fix format issues --- .../transformer/inference/csrc/pt_binding.cpp | 19 +++++++++++++++---- .../inference/includes/inference_context.h | 17 ++++++++++++----- .../inference/transformer_inference.py | 18 ++++++++---------- 3 files changed, 35 insertions(+), 19 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index cfaf91de8a0e..45ed6e6bdd22 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -110,8 +110,15 @@ void allocate_workspace(size_t hidden_dim, bool external_cache = false, unsigned rank = 0) { - Context::Instance().GenWorkSpace( - num_layers, num_heads, batch_size, prompt_length, hidden_dim, mp_size, external_cache, sizeof(T), rank); + Context::Instance().GenWorkSpace(num_layers, + num_heads, + batch_size, + prompt_length, + hidden_dim, + mp_size, + external_cache, + sizeof(T), + rank); } 
template @@ -1407,6 +1414,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) &einsum_sec_sm_ecm<__half>, "DeepSpeed vector-MM with fp16 (CUDA)"); m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)"); - m.def("allocate_workspace_fp32", &allocate_workspace, "DeepSpeed memory allocation for GPT inference with fp32 (CUDA)"); - m.def("allocate_workspace_fp16", &allocate_workspace<__half>, "DeepSpeed memory allocation for GPT inference with fp16 (CUDA)"); + m.def("allocate_workspace_fp32", + &allocate_workspace, + "DeepSpeed memory allocation for GPT inference with fp32 (CUDA)"); + m.def("allocate_workspace_fp16", + &allocate_workspace<__half>, + "DeepSpeed memory allocation for GPT inference with fp16 (CUDA)"); } diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index 612eaf31b0c9..2fc1e7082662 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -104,7 +104,8 @@ class Context { size_t activation_size = 16 * hidden_dim * batch_size; size_t temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size / mp_size; size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2; - size_t minimal_requirements = temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE; + size_t minimal_requirements = + temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE; if (_free_memory_size < minimal_requirements) { printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n", minimal_requirements, @@ -113,10 +114,13 @@ class Context { throw std::runtime_error("Workspace can't be allocated, no enough memory."); } - _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) / (activation_size + cache_size); + _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) / + (activation_size + cache_size); _max_seq_len = std::min((size_t)MAX_OUT_TOKENS, _max_seq_len); - size_t workSpaceSize = ((external_cache ? activation_size : (activation_size + cache_size))) * - _max_seq_len * elem_size + temp_size; + size_t workSpaceSize = + ((external_cache ? 
activation_size : (activation_size + cache_size))) * _max_seq_len * + elem_size + + temp_size; if (rank == 0 && !_workspace) printf( "Free memory : %lu (Bytes) Total memory: %lu (Bytes) Setting maximum total " @@ -148,7 +152,10 @@ class Context { size_t get_workspace_size() const { return _workSpaceSize; } void* GetWorkSpace() { return _workspace; } - void* GetAttentionUnfusedWorkspace() { return _workspace + _attention_unfused_workspace_offset; } + void* GetAttentionUnfusedWorkspace() + { + return _workspace + _attention_unfused_workspace_offset; + } inline unsigned new_token(unsigned layer_id) { diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index e1da6a430c3a..cc7ed35a33a5 100644 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -824,16 +824,14 @@ def forward( past_key_value=None): # Allocate memory only on first layer forward if self.config.layer_id == 0: - self.allocate_workspace( - self.config.hidden_size, - input.size()[0], - input.size()[1], - DeepSpeedTransformerInference.layer_id, - self.config.heads, - self.config.mp_size, - self.config.bigscience_bloom, - dist.get_rank() if dist.is_initialized() else 0 - ) + self.allocate_workspace(self.config.hidden_size, + input.size()[0], + input.size()[1], + DeepSpeedTransformerInference.layer_id, + self.config.heads, + self.config.mp_size, + self.config.bigscience_bloom, + dist.get_rank() if dist.is_initialized() else 0) get_present = (get_present or get_key_value or use_cache) input_mask = input_mask if attention_mask is None else attention_mask
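
For reference, the workspace sizing that GenWorkSpace converges on after these four patches can be summarized with the short Python sketch below. It is a minimal illustration only: the helper name workspace_size is hypothetical, the MAX_OUT_TOKENS value is a placeholder for the constant defined in the DeepSpeed headers, and activation_size/cache_size are element counts while temp_size is already in bytes, mirroring inference_context.h.

    MEGABYTE = 1024 * 1024
    GIGABYTE = 1024 * MEGABYTE
    MAX_OUT_TOKENS = 1024  # placeholder; the real cap comes from the DeepSpeed headers

    def workspace_size(num_layers, num_heads, batch_size, prompt_len, hidden_dim,
                       mp_size, external_cache, elem_size, free_memory_size):
        # Per-token activation buffer and KV-cache footprint, in elements.
        activation_size = 16 * hidden_dim * batch_size
        cache_size = num_layers * batch_size * (hidden_dim // mp_size) * 2
        # Scratch for the unfused attention score matrix, in bytes, sized for the prompt.
        temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size // mp_size
        reserve = (500 if free_memory_size > GIGABYTE else 100) * MEGABYTE
        if free_memory_size < temp_size + reserve:
            raise RuntimeError("Workspace can't be allocated, not enough memory.")
        # Longest total sequence the remaining memory can hold, capped at MAX_OUT_TOKENS.
        max_seq_len = ((free_memory_size - temp_size - reserve) // elem_size
                       // (activation_size + cache_size))
        max_seq_len = min(MAX_OUT_TOKENS, max_seq_len)
        per_token = activation_size if external_cache else activation_size + cache_size
        return per_token * max_seq_len * elem_size + temp_size

As a rough example of the new temp_size term: with fp16 (elem_size = 2), batch 1, 16 heads, a 512-token prompt and mp_size = 1, the attention scratch comes to about 16 * 512 * 512 * 2 ≈ 8 MB, reserved once at the tail of the workspace instead of being carved out of the output tensor.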