csrc/includes/ds_transformer_cuda.h (1 addition, 0 deletions)
@@ -42,6 +42,7 @@ class BertTransformerLayer {
int seq_length,
float attn_dropout_ratio,
float hidden_output_dropout_ratio,
+ float layer_norm_eps,
bool pre_or_postLayerNorm,
const std::vector<std::array<int, 3>>& gemm_algos,
bool attn_dropout_checkpoint,
csrc/includes/normalize_layer.h (7 additions, 2 deletions)
@@ -18,11 +18,16 @@ class Normalize_Layer {
float epsilon;
bool training;
bool useMean;
- Config(uint32_t batch, uint32_t seq, uint32_t h, bool training, bool useMean = true)
+ Config(uint32_t batch,
+        uint32_t seq,
+        uint32_t h,
+        float epsilon = 1e-12,
+        bool training = true,
+        bool useMean = true)
: batchSize(batch),
seqLength(seq),
hiddenDim(h),
- epsilon(1e-12),
+ epsilon(epsilon),
training(training),
useMean(useMean)
{
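With this hunk, `epsilon` becomes a constructor argument of `Normalize_Layer<T>::Config` (still defaulting to the previously hard-coded 1e-12) rather than a fixed constant. A minimal usage sketch under that assumption; the helper name, include path, and argument values are illustrative and not part of the PR:

```cpp
#include <cstdint>
#include "normalize_layer.h"  // the DeepSpeed header changed above; path assumed

// Hypothetical helper: builds a layer-norm config with an explicit epsilon
// instead of relying on the former hard-coded 1e-12.
Normalize_Layer<float>::Config make_ln_config(uint32_t batch,
                                              uint32_t seq,
                                              uint32_t hidden,
                                              float eps)
{
    return Normalize_Layer<float>::Config(batch,
                                          seq,
                                          hidden,
                                          eps,    // layer_norm_eps (default stays 1e-12)
                                          true,   // training
                                          true);  // useMean
}
```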
csrc/transformer/dropout_kernels.cu (8 additions, 8 deletions; mode 100644 → 100755)
@@ -493,7 +493,7 @@ __global__ void dropout_kernel(const int N,
m[3] = (uint8_t)(rand.w > ratio);

float4 x_data = Xdata_cast[j];
- float4 b_data = bias_cast[tid];
+ float4 b_data = bias_cast[j % (dim / unroll_factor)];

x_data.x += b_data.x;
x_data.y += b_data.y;
@@ -515,7 +515,7 @@ __global__ void dropout_kernel(const int N,
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
- float x_data = Xdata[i] + bias[threadIdx.x % dim];
+ float x_data = Xdata[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = x_data * scale * m;
mask[i] = m;
@@ -553,7 +553,7 @@ __global__ void dropout_kernel(const int N,
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

data_f = Xdata_cast[j];
- bias_f = bias_cast[tid];
+ bias_f = bias_cast[j % (dim / unroll_factor)];

float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
@@ -595,7 +595,7 @@ __global__ void dropout_kernel(const int N,
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
- float x_data = (float)Xdata[i] + (float)bias[threadIdx.x % dim];
+ float x_data = (float)Xdata[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = __float2half(x_data * scale * m);
mask[i] = m;
@@ -678,7 +678,7 @@ __global__ void dropout_kernel(const int N,
m[3] = (uint8_t)(rand.w > ratio);

float4 out_data;
- float4 b_data = bias_cast[tid];
+ float4 b_data = bias_cast[j % (dim / unroll_factor)];
float4 res_data = residual_cast[j];
float4 inp_data = input_cast[j];

@@ -707,7 +707,7 @@ __global__ void dropout_kernel(const int N,
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
- float x_data = input[i] + bias[threadIdx.x % dim];
+ float x_data = input[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += residual[i];
@@ -758,7 +758,7 @@ __global__ void dropout_kernel(const int N,
float2 input_f;
__half2* input_h = reinterpret_cast<__half2*>(&input_f);

- bias_f = bias_cast[tid];
+ bias_f = bias_cast[j % (dim / unroll_factor)];
residual_f = residual_cast[j];
input_f = input_cast[j];

@@ -813,7 +813,7 @@ __global__ void dropout_kernel(const int N,
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
- float x_data = (float)input[i] + (float)bias[threadIdx.x % dim];
+ float x_data = (float)input[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += (float)residual[i];
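Every hunk in this file fixes the same indexing bug: the bias has length `dim` (the hidden size) and is broadcast across rows, so its index must come from the element's position within the row, not from the thread id, which is only correct when `blockDim.x` happens to equal `dim` and each thread touches exactly one element. A minimal standalone sketch of the corrected indexing (not the DeepSpeed kernel itself; the kernel name, launch shape, and float-only data type are assumptions for illustration):

```cuda
#include <cuda_runtime.h>

// Adds a per-column bias to a flattened [rows, dim] tensor with a grid-stride
// loop over float4 vectors. Assumes dim % 4 == 0 and 16-byte-aligned pointers.
__global__ void bias_add_vec4(float* x, const float* bias, int total, int dim)
{
    float4* x4 = reinterpret_cast<float4*>(x);
    const float4* bias4 = reinterpret_cast<const float4*>(bias);
    const int n_vec = total / 4;  // float4 elements in the whole tensor
    const int dim_vec = dim / 4;  // float4 elements per row

    for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < n_vec;
         j += blockDim.x * gridDim.x) {
        float4 v = x4[j];
        float4 b = bias4[j % dim_vec];  // column group of vector j, not the thread id
        v.x += b.x;
        v.y += b.y;
        v.z += b.z;
        v.w += b.w;
        x4[j] = v;
    }

    // Scalar tail (non-empty only when total % 4 != 0): element i sits in column i % dim.
    for (int i = n_vec * 4 + blockIdx.x * blockDim.x + threadIdx.x; i < total;
         i += blockDim.x * gridDim.x) {
        x[i] += bias[i % dim];
    }
}
```

A launch such as `bias_add_vec4<<<grid, block>>>(x, bias, rows * dim, dim);` works for any `blockDim.x`, which is exactly why indexing the bias with `tid` or `threadIdx.x` was unsafe.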
csrc/transformer/ds_transformer_cuda.cpp (5 additions, 0 deletions)
@@ -52,6 +52,7 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
int seq_length,
float attn_prob_dropout_ratio,
float hidden_output_dropout_ratio,
+ float layer_norm_eps,
bool pre_or_postLayerNorm,
const std::vector<std::array<int, 3>>& gemm_algos,
bool attn_dropout_checkpoint,
@@ -83,11 +84,13 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
_attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
+ layer_norm_eps,
true,
!normalize_invertible)),
_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
+ layer_norm_eps,
true,
!normalize_invertible)),
_ff1(typename FeedForward<T>::Config(batch_size * seq_length,
@@ -583,6 +586,7 @@ int create_transformer_layer(int layer_id,
int intermediate_size,
float attn_dropout_ratio,
float hidden_dropout_ratio,
+ float layer_norm_eps,
int seed,
bool pre_or_postLayerNorm,
bool test_gemm,
@@ -603,6 +607,7 @@ int create_transformer_layer(int layer_id,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
+ layer_norm_eps,
pre_or_postLayerNorm,
Context::Instance().GetGemmAlgos(),
attn_dropout_checkpoint,
csrc/transformer/transform_kernels.cu (13 additions, 1 deletion; mode 100644 → 100755)
@@ -260,11 +260,23 @@ __global__ void bias_add_transform_0213<__half>(__half* output,
bias_arr = bias_vec[d3];
vals_arr = vals_vec[d3];

+ #if defined(__ACC_HALF__)
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];

+ #else
+     float2 bias_arr_f[4];
+     float2 vals_arr_f[4];
+ #pragma unroll
+     for (int l = 0; l < 4; l++) {
+         bias_arr_f[l] = __half22float2(bias_half[l]);
+         vals_arr_f[l] = __half22float2(vals_half[l]);
+         vals_arr_f[l].x += bias_arr_f[l].x;
+         vals_arr_f[l].y += bias_arr_f[l].y;
+         output_half[l] = __float22half2_rn(vals_arr_f[l]);
+     }
+ #endif
output_vec[d3] = output_arr;

#endif
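The new `#else` branch widens each `__half2` operand to `float2`, performs the add in single precision, and rounds back to half once; defining `__ACC_HALF__` opts back into the all-half path. A minimal sketch of the same trade-off in isolation (a hypothetical helper, not the kernel above):

```cuda
#include <cuda_fp16.h>

// Adds two __half2 values either directly in half precision (when __ACC_HALF__
// is defined) or by round-tripping through float2, mirroring the #else branch
// in the hunk above.
__device__ __forceinline__ __half2 add_half2(__half2 a, __half2 b)
{
#if defined(__ACC_HALF__)
    return __hadd2(a, b);              // accumulate directly in half
#else
    float2 af = __half22float2(a);     // widen to float
    float2 bf = __half22float2(b);
    af.x += bf.x;
    af.y += bf.y;
    return __float22half2_rn(af);      // round back to half once
#endif
}
```

The float path costs a few conversions per element but keeps the arithmetic in single precision, which is the safer default when half accumulation is not explicitly requested.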
deepspeed/module_inject/replace_module.py (1 addition, 0 deletions)
@@ -38,6 +38,7 @@ def replace_fn(child):
hidden_dropout_ratio=bert_config.hidden_dropout_prob,
num_hidden_layers=bert_config.num_hidden_layers,
initializer_range=bert_config.initializer_range,
+ layer_norm_eps=bert_config.layer_norm_eps,
seed=seed,
fp16=fp16,
pre_layer_norm=preln,
deepspeed/ops/transformer/transformer.py (3 additions, 0 deletions)
@@ -101,6 +101,7 @@ def __init__(self,
hidden_dropout_ratio=-1,
num_hidden_layers=-1,
initializer_range=-1,
+ layer_norm_eps=1e-12,
local_rank=-1,
seed=-1,
fp16=False,
@@ -130,6 +131,7 @@ def __init__(self,
self.gelu_checkpoint = gelu_checkpoint # True: if higher batch size is required
self.adjust_init_range = adjust_init_range
self.test_gemm = False
+ self.layer_norm_eps = layer_norm_eps
self.training = training
self.is_grad_enabled = True
self.attn_dropout_checkpoint = attn_dropout_checkpoint
@@ -553,6 +555,7 @@ def __init__(self, config, initial_weights=None, initial_biases=None):
self.config.intermediate_size,
self.config.attn_dropout_ratio,
self.config.hidden_dropout_ratio,
+ self.config.layer_norm_eps,
self.config.seed,
self.config.pre_layer_norm,
self.config.test_gemm,