From 6308ff70363d25b6a11318d62ca5fea45b28cb3e Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Wed, 24 Feb 2021 20:28:05 +0000
Subject: [PATCH 1/3] fix bias-add precision and indexing, and add
 layer-norm-eps as a configurable parameter for the transformer

---
 csrc/includes/ds_transformer_cuda.h       |  1 +
 csrc/includes/normalize_layer.h           |  9 +++++++--
 csrc/transformer/dropout_kernels.cu       | 16 ++++++++--------
 csrc/transformer/ds_transformer_cuda.cpp  |  5 +++++
 csrc/transformer/transform_kernels.cu     | 18 ++++++++++++++----
 deepspeed/module_inject/replace_module.py |  1 +
 deepspeed/ops/transformer/transformer.py  |  3 +++
 7 files changed, 39 insertions(+), 14 deletions(-)
 mode change 100644 => 100755 csrc/transformer/dropout_kernels.cu

diff --git a/csrc/includes/ds_transformer_cuda.h b/csrc/includes/ds_transformer_cuda.h
index dbae797a8ecd..cdd65b4a7da7 100755
--- a/csrc/includes/ds_transformer_cuda.h
+++ b/csrc/includes/ds_transformer_cuda.h
@@ -42,6 +42,7 @@ class BertTransformerLayer {
                          int seq_length,
                          float attn_dropout_ratio,
                          float hidden_output_dropout_ratio,
+                         float layer_norm_eps,
                          bool pre_or_postLayerNorm,
                          const std::vector<std::array<int, 3>>& gemm_algos,
                          bool attn_dropout_checkpoint,
diff --git a/csrc/includes/normalize_layer.h b/csrc/includes/normalize_layer.h
index bfe84636ddb9..e18e01a33893 100644
--- a/csrc/includes/normalize_layer.h
+++ b/csrc/includes/normalize_layer.h
@@ -18,11 +18,16 @@ class Normalize_Layer {
         float epsilon;
         bool training;
         bool useMean;
-        Config(uint32_t batch, uint32_t seq, uint32_t h, bool training, bool useMean = true)
+        Config(uint32_t batch,
+               uint32_t seq,
+               uint32_t h,
+               float epsilon = 1e-12,
+               bool training = true,
+               bool useMean = true)
             : batchSize(batch),
               seqLength(seq),
               hiddenDim(h),
-              epsilon(1e-12),
+              epsilon(epsilon),
               training(training),
               useMean(useMean)
         {
diff --git a/csrc/transformer/dropout_kernels.cu b/csrc/transformer/dropout_kernels.cu
old mode 100644
new mode 100755
index 6b0655b788eb..98f2ac22fddd
--- a/csrc/transformer/dropout_kernels.cu
+++ b/csrc/transformer/dropout_kernels.cu
@@ -493,7 +493,7 @@ __global__ void dropout_kernel(const int N,
         m[3] = (uint8_t)(rand.w > ratio);

         float4 x_data = Xdata_cast[j];
-        float4 b_data = bias_cast[tid];
+        float4 b_data = bias_cast[j % (dim / unroll_factor)];

         x_data.x += b_data.x;
         x_data.y += b_data.y;
@@ -515,7 +515,7 @@ __global__ void dropout_kernel(const int N,
         float* rand_data = &(rand.x);
         int k = 0;
         for (int i = high_index; i < N; i++) {
-            float x_data = Xdata[i] + bias[threadIdx.x % dim];
+            float x_data = Xdata[i] + bias[i % dim];
             uint8_t m = (uint8_t)(rand_data[k++] > ratio);
             Xdata[i] = x_data * scale * m;
             mask[i] = m;
@@ -553,7 +553,7 @@ __global__ void dropout_kernel(const int N,
         __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

         data_f = Xdata_cast[j];
-        bias_f = bias_cast[tid];
+        bias_f = bias_cast[j % (dim / unroll_factor)];

         float2 data_h_0 = __half22float2(data_h[0]);
         float2 data_h_1 = __half22float2(data_h[1]);
@@ -595,7 +595,7 @@ __global__ void dropout_kernel(const int N,
         float* rand_data = &(rand.x);
         int k = 0;
         for (int i = high_index; i < N; i++) {
-            float x_data = (float)Xdata[i] + (float)bias[threadIdx.x % dim];
+            float x_data = (float)Xdata[i] + (float)bias[i % dim];
             uint8_t m = (uint8_t)(rand_data[k++] > ratio);
             Xdata[i] = __float2half(x_data * scale * m);
             mask[i] = m;
@@ -678,7 +678,7 @@ __global__ void dropout_kernel(const int N,
         m[3] = (uint8_t)(rand.w > ratio);

         float4 out_data;
-        float4 b_data = bias_cast[tid];
+        float4 b_data = bias_cast[j % (dim / unroll_factor)];
         float4 res_data = residual_cast[j];
         float4 inp_data = input_cast[j];
@@ -707,7 +707,7 @@ __global__ void dropout_kernel(const int N,
         float* rand_data = &(rand.x);
         int k = 0;
         for (int i = high_index; i < N; i++) {
-            float x_data = input[i] + bias[threadIdx.x % dim];
+            float x_data = input[i] + bias[i % dim];
             uint8_t m = (uint8_t)(rand_data[k++] > ratio);
             x_data = x_data * scale * m;
             x_data += residual[i];
@@ -758,7 +758,7 @@ __global__ void dropout_kernel(const int N,
         float2 input_f;
         __half2* input_h = reinterpret_cast<__half2*>(&input_f);

-        bias_f = bias_cast[tid];
+        bias_f = bias_cast[j % (dim / unroll_factor)];
         residual_f = residual_cast[j];
         input_f = input_cast[j];
@@ -813,7 +813,7 @@ __global__ void dropout_kernel(const int N,
         float* rand_data = &(rand.x);
         int k = 0;
         for (int i = high_index; i < N; i++) {
-            float x_data = (float)input[i] + (float)bias[threadIdx.x % dim];
+            float x_data = (float)input[i] + (float)bias[i % dim];
             uint8_t m = (uint8_t)(rand_data[k++] > ratio);
             x_data = x_data * scale * m;
             x_data += (float)residual[i];
diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp
index 2bb96fa99d67..b4e28915011b 100755
--- a/csrc/transformer/ds_transformer_cuda.cpp
+++ b/csrc/transformer/ds_transformer_cuda.cpp
@@ -52,6 +52,7 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
                                               int seq_length,
                                               float attn_prob_dropout_ratio,
                                               float hidden_output_dropout_ratio,
+                                              float layer_norm_eps,
                                               bool pre_or_postLayerNorm,
                                               const std::vector<std::array<int, 3>>& gemm_algos,
                                               bool attn_dropout_checkpoint,
@@ -83,11 +84,13 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
       _attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                            seq_length,
                                                            hidden_size,
+                                                           layer_norm_eps,
                                                            true,
                                                            !normalize_invertible)),
       _layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                       seq_length,
                                                       hidden_size,
+                                                      layer_norm_eps,
                                                       true,
                                                       !normalize_invertible)),
       _ff1(typename FeedForward<T>::Config(batch_size * seq_length,
@@ -583,6 +586,7 @@ int create_transformer_layer(int layer_id,
                             int intermediate_size,
                             float attn_dropout_ratio,
                             float hidden_dropout_ratio,
+                            float layer_norm_eps,
                             int seed,
                             bool pre_or_postLayerNorm,
                             bool test_gemm,
@@ -603,6 +607,7 @@ int create_transformer_layer(int layer_id,
                             init_seq_length,
                             attn_dropout_ratio,
                             hidden_dropout_ratio,
+                            layer_norm_eps,
                             pre_or_postLayerNorm,
                             Context::Instance().GetGemmAlgos(),
                             attn_dropout_checkpoint,
diff --git a/csrc/transformer/transform_kernels.cu b/csrc/transformer/transform_kernels.cu
index 691b877771b9..5c08f401593a 100644
--- a/csrc/transformer/transform_kernels.cu
+++ b/csrc/transformer/transform_kernels.cu
@@ -259,11 +259,21 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
     bias_arr = bias_vec[d3];
     vals_arr = vals_vec[d3];

+    float2 bias_arr_f[4];
+    float2 vals_arr_f[4];
+#pragma unroll
+    for (int l = 0; l < 4; l++) {
+        bias_arr_f[l] = __half22float2(bias_half[l]);
+        vals_arr_f[l] = __half22float2(vals_half[l]);
+        vals_arr_f[l].x += bias_arr_f[l].x;
+        vals_arr_f[l].y += bias_arr_f[l].y;
+        output_half[l] = __float22half2_rn(vals_arr_f[l]);
+    }

-    output_half[0] = vals_half[0] + bias_half[0];
-    output_half[1] = vals_half[1] + bias_half[1];
-    output_half[2] = vals_half[2] + bias_half[2];
-    output_half[3] = vals_half[3] + bias_half[3];
+    // output_half[0] = vals_half[0] + bias_half[0];
+    // output_half[1] = vals_half[1] + bias_half[1];
+    // output_half[2] = vals_half[2] + bias_half[2];
+    // output_half[3] = vals_half[3] + bias_half[3];

     output_vec[d3] = output_arr;
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 6b25c96a6d14..de014640ad6a 100755
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -38,6 +38,7 @@ def replace_fn(child):
             hidden_dropout_ratio=bert_config.hidden_dropout_prob,
             num_hidden_layers=bert_config.num_hidden_layers,
             initializer_range=bert_config.initializer_range,
+            layer_norm_eps=bert_config.layer_norm_eps,
             seed=seed,
             fp16=fp16,
             pre_layer_norm=preln,
diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py
index f0979f2e3f2a..0238eed144e1 100755
--- a/deepspeed/ops/transformer/transformer.py
+++ b/deepspeed/ops/transformer/transformer.py
@@ -101,6 +101,7 @@ def __init__(self,
                  hidden_dropout_ratio=-1,
                  num_hidden_layers=-1,
                  initializer_range=-1,
+                 layer_norm_eps=1e-12,
                  local_rank=-1,
                  seed=-1,
                  fp16=False,
@@ -130,6 +131,7 @@ def __init__(self,
         self.gelu_checkpoint = gelu_checkpoint  # True: if higher batch size is required
         self.adjust_init_range = adjust_init_range
         self.test_gemm = False
+        self.layer_norm_eps = layer_norm_eps
         self.training = training
         self.is_grad_enabled = True
         self.attn_dropout_checkpoint = attn_dropout_checkpoint
@@ -553,6 +555,7 @@ def __init__(self, config, initial_weights=None, initial_biases=None):
                                             self.config.intermediate_size,
                                             self.config.attn_dropout_ratio,
                                             self.config.hidden_dropout_ratio,
+                                            self.config.layer_norm_eps,
                                             self.config.seed,
                                             self.config.pre_layer_norm,
                                             self.config.test_gemm,
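The indexing fix in patch 1 replaces `bias_cast[tid]` with `bias_cast[j % (dim / unroll_factor)]` in every vectorized dropout path: these kernels walk the flattened [rows x dim] activations with a grid-stride loop, so one thread touches many positions and the bias offset must be derived from the element index `j`, not the thread id. A minimal sketch of the corrected pattern (hypothetical kernel, not the DeepSpeed source; assumes `total` and `dim` are multiples of the float4 width of 4):

```cuda
#include <cuda_runtime.h>

// Hypothetical illustration of the fixed indexing rule: the bias vector of
// length `dim` repeats once per row, i.e. once every dim/4 float4 loads.
__global__ void bias_add_vec4(float* x, const float* bias, int total, int dim)
{
    float4* x_cast = reinterpret_cast<float4*>(x);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    const int stride = blockDim.x * gridDim.x;

    // Grid-stride loop: each thread handles several float4 elements, so the
    // bias index must depend on j (the element), not on the thread id.
    for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < total / 4; j += stride) {
        float4 v = x_cast[j];
        float4 b = bias_cast[j % (dim / 4)];
        v.x += b.x;
        v.y += b.y;
        v.z += b.z;
        v.w += b.w;
        x_cast[j] = v;
    }
}
```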
From 3b8af456fbd6522fb2877c8fb70099d48a2c123a Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Wed, 24 Feb 2021 21:05:15 +0000
Subject: [PATCH 2/3] add ACC_HALF config

---
 csrc/transformer/transform_kernels.cu | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/csrc/transformer/transform_kernels.cu b/csrc/transformer/transform_kernels.cu
index 5c08f401593a..e423e16af3be 100644
--- a/csrc/transformer/transform_kernels.cu
+++ b/csrc/transformer/transform_kernels.cu
@@ -259,6 +259,13 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
     bias_arr = bias_vec[d3];
     vals_arr = vals_vec[d3];

+#if __ACC_HALF__
+    output_half[0] = vals_half[0] + bias_half[0];
+    output_half[1] = vals_half[1] + bias_half[1];
+    output_half[2] = vals_half[2] + bias_half[2];
+    output_half[3] = vals_half[3] + bias_half[3];
+#else
     float2 bias_arr_f[4];
     float2 vals_arr_f[4];
 #pragma unroll
@@ -269,12 +276,7 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
         vals_arr_f[l].y += bias_arr_f[l].y;
         output_half[l] = __float22half2_rn(vals_arr_f[l]);
     }
-
-    // output_half[0] = vals_half[0] + bias_half[0];
-    // output_half[1] = vals_half[1] + bias_half[1];
-    // output_half[2] = vals_half[2] + bias_half[2];
-    // output_half[3] = vals_half[3] + bias_half[3];
-
+#endif
     output_vec[d3] = output_arr;

 #endif
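Patch 2 keeps the fp32 up-convert/round-trip from patch 1 as the default and restores the native half adds only behind an `__ACC_HALF__` guard. A sketch of the two accumulation paths the macro selects (hypothetical helper, shown with the `defined()` form that patch 3 settles on; `__hadd2`, `__half22float2`, and `__float22half2_rn` are standard cuda_fp16.h intrinsics):

```cuda
#include <cuda_fp16.h>

// Hypothetical helper mirroring the two paths in bias_add_transform_0213.
// Native half2 arithmetic requires a device of compute capability 5.3+.
__device__ __half2 bias_add_half2(__half2 val, __half2 bias)
{
#if defined(__ACC_HALF__)
    // Fast path: add directly in half2; cheaper, but accumulates in fp16.
    return __hadd2(val, bias);
#else
    // Default path: widen to fp32, add, and round back to half exactly once.
    float2 v = __half22float2(val);
    float2 b = __half22float2(bias);
    v.x += b.x;
    v.y += b.y;
    return __float22half2_rn(v);
#endif
}
```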
From 9c9f80f1366a6d82a8dca6e88e98cf3e7f7adbfc Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Wed, 24 Feb 2021 21:09:45 +0000
Subject: [PATCH 3/3] use defined() to check whether ACC_HALF is defined

---
 csrc/transformer/transform_kernels.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
 mode change 100644 => 100755 csrc/transformer/transform_kernels.cu

diff --git a/csrc/transformer/transform_kernels.cu b/csrc/transformer/transform_kernels.cu
old mode 100644
new mode 100755
index e423e16af3be..7d8a27eeeb43
--- a/csrc/transformer/transform_kernels.cu
+++ b/csrc/transformer/transform_kernels.cu
@@ -260,7 +260,7 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
     bias_arr = bias_vec[d3];
     vals_arr = vals_vec[d3];

-#if __ACC_HALF__
+#if defined(__ACC_HALF__)
     output_half[0] = vals_half[0] + bias_half[0];
     output_half[1] = vals_half[1] + bias_half[1];
     output_half[2] = vals_half[2] + bias_half[2];
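Patch 3's `defined()` check matters because the preprocessor evaluates an undefined identifier in `#if` as 0 (often with a -Wundef warning), so `#if __ACC_HALF__` silently falls back to the fp32 path whenever the macro is absent. Note the semantic difference: with `defined()`, even `-D__ACC_HALF__=0` selects the half path. A minimal illustration (the `ACC_PATH` macro is hypothetical):

```cuda
// With -D__ACC_HALF__ (any value, including 0), ACC_PATH is "half";
// without the flag, it is "fp32". Under "#if __ACC_HALF__" the half
// path would additionally require a nonzero macro value.
#if defined(__ACC_HALF__)
#define ACC_PATH "half accumulation"
#else
#define ACC_PATH "fp32 accumulation"
#endif
```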