diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index acb35325dd9e..c3d65fa037c2 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -174,7 +174,7 @@ __global__ void fused_bias_residual(float* input, float* attnbias, int total_count, int intermediate_size, - int mp_size, + float mp_scale, bool preln) { float4* input_cast = reinterpret_cast(input); @@ -191,10 +191,10 @@ __global__ void fused_bias_residual(float* input, float4 bias_data = bias_cast[offset % intermediate_size]; float4 attn_bias = attnbias_cast[offset % intermediate_size]; if (preln) { - data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x); - data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y); - data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z); - data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w); + data.x = (data.x + res_vec.x + bias_data.x + attn_bias.x) * mp_scale + (out.x); + data.y = (data.y + res_vec.y + bias_data.y + attn_bias.y) * mp_scale + (out.y); + data.z = (data.z + res_vec.z + bias_data.z + attn_bias.z) * mp_scale + (out.z); + data.w = (data.w + res_vec.w + bias_data.w + attn_bias.w) * mp_scale + (out.w); } else { data.x = data.x + out.x + bias_data.x; data.y = data.y + out.y + bias_data.y; @@ -212,7 +212,7 @@ __global__ void fused_bias_residual(__half* input, __half* attn_bias, int total_count, int intermediate_size, - int mp_size, + float mp_scale, bool preln) { #ifdef HALF_PRECISION_AVAILABLE @@ -257,13 +257,13 @@ __global__ void fused_bias_residual(__half* input, if (preln) { low_data.x = - (low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x)); + (low_data.x + low_res.x + (low_bias.x + attn_low_bias.x)) * mp_scale + low_out.x; low_data.y = - (low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y)); - high_data.x = (high_data.x + high_res.x) * mp_size + - (high_out.x + (high_bias.x + attn_high_bias.x)); - high_data.y = (high_data.y + high_res.y) * mp_size + - (high_out.y + (high_bias.y + attn_high_bias.y)); + (low_data.y + low_res.y + (low_bias.y + attn_low_bias.y)) * mp_scale + low_out.y; + high_data.x = (high_data.x + high_res.x + (high_bias.x + attn_high_bias.x)) * mp_scale + + high_out.x; + high_data.y = (high_data.y + high_res.y + (high_bias.y + attn_high_bias.y)) * mp_scale + + high_out.y; } else { low_data.x = (low_data.x + low_out.x + low_bias.x); low_data.y = (low_data.y + low_out.y + low_bias.y); @@ -310,7 +310,7 @@ __global__ void gptj_residual_add(float* input, float* attnbias, int total_count, int intermediate_size, - float mp_size) + float mp_scale) { float4* input_cast = reinterpret_cast(input); float4* output_cast = reinterpret_cast(output); @@ -332,10 +332,10 @@ __global__ void gptj_residual_add(float* input, data.z += attn_bias.z; data.w += attn_bias.w; } - data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x); - data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y); - data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z); - data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w); + data.x = data.x * mp_scale + (out.x + res_vec.x + bias_data.x); + data.y = data.y * mp_scale + (out.y + res_vec.y + bias_data.y); + data.z = data.z * mp_scale + (out.z + res_vec.z + bias_data.z); + data.w = data.w * mp_scale + (out.w + res_vec.w + bias_data.w); output_cast[offset] = data; } @@ -348,7 +348,7 @@ __global__ void 
gptj_residual_add(__half* input, __half* attn_bias, int total_count, int intermediate_size, - float mp_size) + float mp_scale) { #if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__) @@ -395,10 +395,10 @@ __global__ void gptj_residual_add(__half* input, high_data.y += attn_high_bias.y; } - low_data.x = low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x)); - low_data.y = low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y)); - high_data.x = high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x)); - high_data.y = high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y)); + low_data.x = low_data.x * mp_scale + (low_out.x + low_res.x + (low_bias.x)); + low_data.y = low_data.y * mp_scale + (low_out.y + low_res.y + (low_bias.y)); + high_data.x = high_data.x * mp_scale + (high_out.x + high_res.x + (high_bias.x)); + high_data.y = high_data.y * mp_scale + (high_out.y + high_res.y + (high_bias.y)); vals_half[0] = __float22half2_rn(low_data); vals_half[1] = __float22half2_rn(high_data); diff --git a/csrc/transformer/inference/csrc/normalize.cu b/csrc/transformer/inference/csrc/normalize.cu index 7f3cfc118631..22c23011ede8 100644 --- a/csrc/transformer/inference/csrc/normalize.cu +++ b/csrc/transformer/inference/csrc/normalize.cu @@ -88,6 +88,7 @@ __global__ void fused_bias_residual_layer_norm(__half* output, int row_stride) { #ifdef HALF_PRECISION_AVAILABLE + int iteration_stride = blockDim.x; int iterations = row_stride / iteration_stride; diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 2c300c0b6c92..360beaa10284 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -11,27 +11,36 @@ std::array gemm_algos = std::array({99, 99, 99}); template at::Tensor ds_softmax(at::Tensor& attn_scores, at::Tensor& attn_mask, + at::Tensor& alibi, bool triangular, bool recompute, bool local_attention, int window_size, - bool async_op) + bool async_op, + float layer_scale, + int head_offset, + int mp_size) { auto attn_scores_c = attn_scores.contiguous(); int bsz = attn_scores_c.size(0); int seq_len = attn_scores_c.size(1); int len = attn_scores_c.sizes().size(); - if (len > 3) seq_len = attn_scores_c.size(2); + if (len > 2) seq_len = attn_scores_c.size(2); int soft_len = attn_scores_c.size(2); if (len > 3) soft_len = attn_scores_c.size(3); int heads = 1; - if (len > 3) heads = attn_scores_c.size(1); + if (len > 1) heads = attn_scores_c.size(1); + + int mask_stride = 1; + if (attn_mask.sizes().size() > 2) mask_stride = attn_mask.size(2); launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(), (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), + (alibi.sizes().size() > 1 ? 
(T*)alibi.data_ptr() : nullptr), + layer_scale, triangular, recompute, local_attention, @@ -40,7 +49,9 @@ at::Tensor ds_softmax(at::Tensor& attn_scores, heads, seq_len, soft_len, - 1.0, + head_offset, + mask_stride, + mp_size, Context::Instance().GetCurrentStream(async_op)); return attn_scores_c; @@ -123,6 +134,8 @@ void attention_unfused(at::Tensor& prev_key_cont, float gemm_beta = 0.0; auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options); int k = prev_value_cont.size(2) / heads; + int mask_stride = heads; + if (attn_mask.sizes().size() > 2 && attn_mask.size(2) == 1) mask_stride *= seq_len; cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), soft_len, @@ -144,8 +157,22 @@ void attention_unfused(at::Tensor& prev_key_cont, #else CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif - attn_score = ds_softmax( - attn_score, attn_mask, triangular, recompute, local_attention, window_size, false); + launch_attn_softmax_v2((T*)attn_score.data_ptr(), + (T*)(attn_mask.sizes().size() > 1 ? attn_mask.data_ptr() : nullptr), + (T*)nullptr, + 1.0, + triangular, + recompute, + local_attention, + window_size, + bsz, + heads, + seq_len, + soft_len, + 0, + mask_stride, + 1, + Context::Instance().GetCurrentStream(false)); alpha = 1.0; cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), k, @@ -225,6 +252,8 @@ std::vector ds_softmax_context1(at::Tensor& query, template void ds_softmax_internal(T* attn_scores, at::Tensor& attn_mask, + at::Tensor& alibi, + float& layer_scale, bool triangular, bool recompute, bool local_attention, @@ -234,8 +263,12 @@ void ds_softmax_internal(T* attn_scores, int soft_len, int heads) { + int mask_stride = 1; + if (attn_mask.sizes().size() > 2) mask_stride = attn_mask.size(2); launch_attn_softmax_v2((T*)attn_scores, (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), + (alibi.sizes().size() > 1 ? (T*)alibi.data_ptr() : nullptr), + layer_scale, triangular, recompute, local_attention, @@ -244,7 +277,9 @@ void ds_softmax_internal(T* attn_scores, heads, seq_len, soft_len, - 1.0, + 0, + mask_stride, + 1, at::cuda::getCurrentCUDAStream()); } @@ -263,9 +298,12 @@ void attention_unfused(T* prev_key_cont, bool triangular, bool recompute, bool local_attention, - int window_size) + int window_size, + at::Tensor& alibi, + int layer_id) { - float alpha = norm_factor * norm_factor; + float layer_scale = alibi.sizes().size() > 1 ? 
std::max(1, layer_id) : 1.0; + float alpha = norm_factor * norm_factor / layer_scale; float gemm_beta = 0.0; T* workspace = (T*)output + bsz * seq_len * heads * k; @@ -292,6 +330,8 @@ void attention_unfused(T* prev_key_cont, #endif ds_softmax_internal(workspace, attn_mask, + alibi, + layer_scale, triangular, recompute, local_attention, @@ -336,7 +376,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, int window_size, bool no_masking, unsigned layer_id, - unsigned num_layers) + unsigned num_layers, + at::Tensor& alibi) { unsigned bsz = query_key_value.size(0); unsigned seq_len = query_key_value.size(1); @@ -410,7 +451,9 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, (triangular && is_prompt), is_prompt, local_attention, - window_size); + window_size, + alibi, + layer_id); launch_transform4d_0213((T*)output.data_ptr(), temp_buf, bsz, @@ -506,7 +549,7 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, { int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); - workspace += (3 * input.size(0) * MAX_OUT_TOKES * input.size(2)); + workspace += (3 * bsz * input.size(2)); ds_layernorm_internal(workspace, input, gamma, beta, epsilon); // cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream()); diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu index bf3c8bc90049..dcc90c3b5cbb 100644 --- a/csrc/transformer/inference/csrc/softmax.cu +++ b/csrc/transformer/inference/csrc/softmax.cu @@ -28,6 +28,8 @@ namespace cg = cooperative_groups; __global__ void attn_softmax_v2(__half* vals, __half* mask, + __half* alibi, + float layer_scale, bool triangular, bool recompute, bool local_attention, @@ -36,7 +38,9 @@ __global__ void attn_softmax_v2(__half* vals, int heads, int sequence_length, int num_seq, - float scale, + int head_offset, + int mask_stride, + int mp_size, int iterations, int reduceWidth) { @@ -47,8 +51,7 @@ __global__ void attn_softmax_v2(__half* vals, float2 low_data[MAX_REG_SIZE]; float2 high_data[MAX_REG_SIZE]; - - __half2 h_scale = __float2half2_rn(scale); + const __half zero_h = __float2half(0.f); int wid = threadIdx.x >> 5; int lane = threadIdx.x & 0x1f; @@ -60,11 +63,15 @@ __global__ void attn_softmax_v2(__half* vals, __shared__ float partialSum[MAX_WARP_NUM]; int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); + int batch_idx = iter_offset / (num_seq * heads); + int alibi_offset = batch_idx * heads * mp_size + head_offset; + int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride); if (iter_offset < total_count) { vals += (iter_offset * sequence_length); - int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); + alibi_offset = (alibi_offset + ((iter_offset / num_seq) % heads)) * sequence_length; + mask_offset = mask_offset * sequence_length; int seq_id = iter_offset % num_seq; int seq_id4 = seq_id >> 2; @@ -76,47 +83,67 @@ __global__ void attn_softmax_v2(__half* vals, (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; float max_val = minus_infinity; - + // if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset); for (int i = 0; i < iterations; i++) { int data_id = i * (reduceWidth << 2) + (seq_lane << 2); if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && data_id < sequence_length) { if ((sequence_length - data_id) >= 4) { - low_data[i].x = data_id > window_stride ? 
__half2float(vals[data_id]) - : minus_infinity; + low_data[i].x = data_id > window_stride + ? __half2float(vals[data_id]) * layer_scale + : minus_infinity; low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && (data_id + 1) > window_stride) - ? __half2float(vals[data_id + 1]) + ? __half2float(vals[data_id + 1]) * layer_scale : minus_infinity; high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) && (data_id + 2) > window_stride) - ? __half2float(vals[data_id + 2]) + ? __half2float(vals[data_id + 2]) * layer_scale : minus_infinity; high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) && (data_id + 3) > window_stride) - ? __half2float(vals[data_id + 3]) + ? __half2float(vals[data_id + 3]) * layer_scale : minus_infinity; - if (mask && recompute) { + if (alibi) { + low_data[i].x = low_data[i].x + __half2float(alibi[data_id + alibi_offset]); + low_data[i].y = + low_data[i].y + __half2float(alibi[data_id + alibi_offset + 1]); + high_data[i].x = + high_data[i].x + __half2float(alibi[data_id + alibi_offset + 2]); + high_data[i].y = + high_data[i].y + __half2float(alibi[data_id + alibi_offset + 3]); + } + if (mask) { low_data[i].x += __half2float(mask[data_id + mask_offset]); low_data[i].y += __half2float(mask[data_id + mask_offset + 1]); high_data[i].x += __half2float(mask[data_id + mask_offset + 2]); high_data[i].y += __half2float(mask[data_id + mask_offset + 3]); } } else { - low_data[i].x = data_id > window_stride ? __half2float(vals[data_id]) - : minus_infinity; + low_data[i].x = data_id > window_stride + ? __half2float(vals[data_id]) * layer_scale + : minus_infinity; low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) && (data_id + 1) > window_stride) && (data_id + 1) < sequence_length) - ? __half2float(vals[data_id + 1]) + ? __half2float(vals[data_id + 1]) * layer_scale : minus_infinity; high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) && (data_id + 2) > window_stride) && (data_id + 2) < sequence_length) - ? __half2float(vals[data_id + 2]) + ? 
__half2float(vals[data_id + 2]) * layer_scale : minus_infinity; + if (alibi) { + low_data[i].x = low_data[i].x + __half2float(alibi[data_id + alibi_offset]); + if ((data_id + 1) < sequence_length) + low_data[i].y = + low_data[i].y + __half2float(alibi[data_id + alibi_offset + 1]); + if ((data_id + 2) < sequence_length) + high_data[i].x = + high_data[i].x + __half2float(alibi[data_id + alibi_offset + 2]); + } high_data[i].y = minus_infinity; - if (mask && recompute) { + if (mask) { low_data[i].x += __half2float(mask[data_id + mask_offset]); if ((data_id + 1) < sequence_length) low_data[i].y += __half2float(mask[data_id + mask_offset + 1]); @@ -187,14 +214,16 @@ __global__ void attn_softmax_v2(__half* vals, if (data_id < sequence_length) { if ((sequence_length - data_id) >= 4) { - vals[data_id] = low_data[i].x / sum; - vals[data_id + 1] = low_data[i].y / sum; - vals[data_id + 2] = high_data[i].x / sum; - vals[data_id + 3] = high_data[i].y / sum; + vals[data_id] = __float2half(low_data[i].x / sum); + vals[data_id + 1] = __float2half(low_data[i].y / sum); + vals[data_id + 2] = __float2half(high_data[i].x / sum); + vals[data_id + 3] = __float2half(high_data[i].y / sum); } else { - vals[data_id] = low_data[i].x / sum; - if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum; - if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum; + vals[data_id] = __float2half(low_data[i].x / sum); + if ((data_id + 1) < sequence_length) + vals[data_id + 1] = __float2half(low_data[i].y / sum); + if ((data_id + 2) < sequence_length) + vals[data_id + 2] = __float2half(high_data[i].x / sum); } } } @@ -204,6 +233,8 @@ __global__ void attn_softmax_v2(__half* vals, __global__ void attn_softmax_v2(float* vals, float* attn_mask, + float* alibi, + float layer_scale, bool triangular, bool recompute, bool local_attention, @@ -212,7 +243,9 @@ __global__ void attn_softmax_v2(float* vals, int heads, int sequence_length, int num_seq, - float scale, + int head_offset, + int mask_stride, + int mp_size, int iterations, int reduceWidth) { @@ -234,7 +267,10 @@ __global__ void attn_softmax_v2(float* vals, if (iter_offset < total_count) { vals += (iter_offset * sequence_length); - int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); + int batch_idx = iter_offset / (num_seq * heads); + int alibi_offset = batch_idx * heads * mp_size + head_offset; + int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride); + int seq_id = iter_offset % num_seq; int seq_id4 = seq_id >> 2; @@ -371,6 +407,8 @@ __global__ void attn_softmax_v2(float* vals, template void launch_attn_softmax_v2(T* vals, T* mask, + T* alibi, + float layer_scale, bool triangular, bool recompute, bool local_attention, @@ -379,7 +417,9 @@ void launch_attn_softmax_v2(T* vals, int heads, int num_seq, int sequence_length, - float scale, + int head_offset, + int mask_stride, + int mp_size, cudaStream_t stream) { int total_count = batch_size * heads * num_seq; @@ -390,26 +430,31 @@ void launch_attn_softmax_v2(T* vals, const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1; if (sequence_length <= 32768) - attn_softmax_v2<<>>( - vals, - mask, - triangular, - recompute, - local_attention, - window_size, - total_count, - (triangular ? 
(heads * batch_size) : heads), - sequence_length, - num_seq, - scale, - iterations, - reduce_width); + attn_softmax_v2<<>>(vals, + mask, + alibi, + layer_scale, + triangular, + recompute, + local_attention, + window_size, + total_count, + heads, + sequence_length, + num_seq, + head_offset, + mask_stride, + mp_size, + iterations, + reduce_width); else throw std::runtime_error("Unsupport Seq_Length!"); } template void launch_attn_softmax_v2(float* vals, float* mask, + float* alibi, + float layer_scale, bool triangular, bool recompute, bool local_attention, @@ -418,10 +463,14 @@ template void launch_attn_softmax_v2(float* vals, int heads, int num_seq, int sequence_length, - float scale, + int head_offset, + int mask_stride, + int mp_size, cudaStream_t stream); template void launch_attn_softmax_v2(__half* vals, __half* mask, + __half* alibi, + float layer_scale, bool triangular, bool recompute, bool local_attention, @@ -430,5 +479,7 @@ template void launch_attn_softmax_v2(__half* vals, int heads, int num_seq, int sequence_length, - float scale, + int head_offset, + int mask_stride, + int mp_size, cudaStream_t stream); diff --git a/csrc/transformer/inference/includes/context.h b/csrc/transformer/inference/includes/context.h index 21f0b3cfe07b..75dd625eab8f 100644 --- a/csrc/transformer/inference/includes/context.h +++ b/csrc/transformer/inference/includes/context.h @@ -89,6 +89,7 @@ class Context { cudaMalloc(&_workspace, size); } + if (!_workspace) { throw std::runtime_error("Workspace is null."); } _workSpaceSize = size; } diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index a48b2d7f06cc..c2bb30126cd6 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -17,7 +17,7 @@ #include #include -#define MAX_OUT_TOKES 1024 +#define MAX_OUT_TOKES 128 #define MAX_WARP_NUM 32 #define WARP_SIZE 32 @@ -28,6 +28,8 @@ template void launch_attn_softmax_v2(T* vals, T* mask, + T* alibi, + float layer_scale, bool triangular, bool recompute, bool local_attention, @@ -36,7 +38,9 @@ void launch_attn_softmax_v2(T* vals, int heads, int num_seq, int sequence_length, - float scale, + int offset, + int mask_stride, + int mp_size, cudaStream_t stream); // Fused bias add with gelu activation diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 25229bde06d8..7a18f98a49e8 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -24,7 +24,7 @@ from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from .module_inject import replace_transformer_layer, revert_transformer_layer -from .utils import log_dist +from .utils import log_dist, OnDevice from .comm.comm import init_distributed from .runtime import zero diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index b5e68aaf297f..ea29282da744 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -18,8 +18,11 @@ from ..pipe import PipelineModule from ..moe.utils import has_moe_layers from ..moe.layer import MoE +from ..runtime.zero import GatheredParameters +from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing DS_INFERENCE_ENABLED = False +from torch import nn class InferenceEngine(Module): @@ -96,7 +99,7 @@ def __init__(self, assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \ "If you want to use cuda graph, please upgrade torch to at least v1.10" - if 
self.checkpoint: + if self.checkpoint and not replace_with_kernel_inject: self._load_checkpoint(self.checkpoint) # convert model to intended dtype @@ -117,14 +120,16 @@ def __init__(self, if self.injection_dict: for client_module, injection_policy in self.injection_dict.items(): - self._apply_injection_policy(client_module, - injection_policy, - return_tuple, - replace_with_kernel_inject, - moe, - moe_experts, - moe_type, - training_mp_size) + self._apply_injection_policy( + client_module, + injection_policy, + return_tuple, + replace_with_kernel_inject, + moe, + moe_experts, + moe_type, + training_mp_size, + self.checkpoint if replace_with_kernel_inject else None) elif replace_method == 'auto': self._apply_injection_policy( return_tuple=return_tuple, @@ -132,7 +137,8 @@ def __init__(self, moe=moe, moe_experts=moe_experts, moe_type=moe_type, - training_mp_size=training_mp_size) + training_mp_size=training_mp_size, + checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None) device = torch.cuda.current_device() logger.info(f"Place model to device: {device}") @@ -230,6 +236,73 @@ def _validate_args(self, mpu): raise ValueError( f"injection_dict must be None or a dict, got: {self.injection_dict}") + def load_model_with_checkpoint(self, r_module): + self.mp_replace = ReplaceWithTensorSlicing( + mp_group=self.mp_group, + mp_size=self.mp_world_size) #, out_dim=0, in_dim=1) + error_msgs = [] + + def load(module, state_dict, prefix): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + if len(list(module.parameters())) > 0 and list( + module.parameters())[0].numel() == 0: + with GatheredParameters(list(module.parameters(recurse=False)), + modifier_rank=0): + if dist.get_rank() == 0: + module._load_from_state_dict(*args) + else: + if hasattr(module, 'weight'): + if 'query_key_value' in prefix: + module.weight = self.mp_replace.qkv_copy( + module.weight.data, + state_dict[prefix + 'weight']) + else: + module.weight = self.mp_replace.copy( + module.weight.data, + state_dict[prefix + 'weight']) + else: + module.norm.weight = self.mp_replace.copy( + module.norm.weight.data, + state_dict[prefix + 'weight']) + if prefix + 'bias' in self.key_list: + if hasattr(module, 'norm'): + module.norm.bias = self.mp_replace.copy( + module.norm.bias, + state_dict[prefix + 'bias']) + else: + data = state_dict[prefix + 'bias'] + data = data.to(torch.cuda.current_device()) + module.bias = self.mp_replace.copy(module.bias, data) + + layer_policies = { + nn.Linear: load, + nn.Embedding: load, + nn.LayerNorm: load, + LinearLayer: load, + LinearAllreduce: load + } + + def load_module_recursive(module, prefix='', level=0): + for name, child in module.named_children(): + if child.__class__ in layer_policies: + checking_key = prefix + name + '.' 
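
For reference, a minimal stand-alone sketch (hypothetical module and key names) of the prefix check and copy that load_module_recursive performs for each child above: a parameter is copied only when the current checkpoint shard actually carries a key under that child's prefix, so a sharded checkpoint can be applied one file at a time.

    import torch
    from torch import nn

    # Hypothetical illustration of the per-child load above: copy a parameter
    # from the shard only if its key is present under the child's prefix.
    def load_param(module: nn.Module, attr: str, sd: dict, prefix: str) -> None:
        key = prefix + attr
        if key in sd:
            getattr(module, attr).data.copy_(sd[key])

    linear = nn.Linear(4, 4)
    shard = {
        "transformer.h.0.self_attention.dense.weight": torch.randn(4, 4),
        "transformer.h.0.self_attention.dense.bias": torch.zeros(4),
    }
    prefix = "transformer.h.0.self_attention.dense."
    if any(prefix in key for key in shard):  # same check as `checking_key` above
        load_param(linear, "weight", shard, prefix)
        load_param(linear, "bias", shard, prefix)
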
+ if not any(checking_key in item for item in self.key_list): + continue + if len(list(child.parameters())) > 0 and list( + child.parameters())[0].numel() == 0: + if len(child.weight.ds_shape) == 1: + child = Normalize(dim=child.weight.ds_shape[-1], + dtype=child.weight.dtype, + eps=child.eps) + setattr(module, name, child) + load(child, self.sd, prefix + name + '.') + else: + load_module_recursive(child, + prefix if level == 0 else prefix + name + '.', + level + 1) + + load_module_recursive(r_module) + def _apply_injection_policy(self, client_module=None, injection_policy=None, @@ -238,8 +311,10 @@ def _apply_injection_policy(self, moe=False, moe_experts=1, moe_type='standard', - training_mp_size=1): - + training_mp_size=1, + checkpoint_dir=None): + checkpoint = SDLoaderFactory.get_sd_loader_json( + checkpoint_dir) if checkpoint_dir is not None else None replace_transformer_layer(client_module, self.module, triangular_masking=self.triangular_masking, @@ -261,7 +336,8 @@ def _apply_injection_policy(self, moe=moe, moe_experts=moe_experts, moe_type=moe_type, - training_mp_size=training_mp_size) + training_mp_size=training_mp_size, + checkpoint=checkpoint) def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, @@ -303,34 +379,47 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): else: sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir) - mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() + if type(sd_loader) is list: + self.sd = torch.load(sd_loader[0], map_location='cpu') + self.key_list = list(self.sd.keys()) - load_path, checkpoint, quantize_config = sd_loader.load(self.mp_world_size, - mp_rank, - is_pipe_parallel=is_pipe_parallel, - quantize=(self.dtype is torch.int8), - quantize_groups=self.quantize_groups, - mlp_extra_grouping=self.mlp_extra_grouping) + self.load_model_with_checkpoint(self.module) - self.quantization_scales, self.quantize_merge_count = quantize_config + for i in range(1, len(sd_loader)): + if not dist.is_initialized() or dist.get_rank() == 0: + print(f"loading checkpoint ({i})") + self.sd = torch.load(sd_loader[i], map_location='cuda') + self.key_list = list(self.sd.keys()) + self.load_model_with_checkpoint(self.module) + else: + mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() - moe, _ = has_moe_layers(self.module) - if moe: - from deepspeed.runtime.engine import DeepSpeedEngine - old_moe_load = False - if not isinstance(checkpoint['num_experts'], list): - old_moe_load = True - DeepSpeedEngine.load_moe_state_dict( - load_dir, - tag, + load_path, checkpoint, quantize_config = sd_loader.load(self.mp_world_size, + mp_rank, + is_pipe_parallel=is_pipe_parallel, + quantize=(self.dtype is torch.int8), + quantize_groups=self.quantize_groups, + mlp_extra_grouping=self.mlp_extra_grouping) + + self.quantization_scales, self.quantize_merge_count = quantize_config + + moe, _ = has_moe_layers(self.module) + if moe: + from deepspeed.runtime.engine import DeepSpeedEngine + old_moe_load = False + if not isinstance(checkpoint['num_experts'], list): + old_moe_load = True + DeepSpeedEngine.load_moe_state_dict( + load_dir, + tag, + state_dict=checkpoint[self._choose_module_key(checkpoint)], + old_moe_load=old_moe_load, + model=self.module, + mpu=self.mpu) + + self.module.load_state_dict( state_dict=checkpoint[self._choose_module_key(checkpoint)], - old_moe_load=old_moe_load, - model=self.module, - mpu=self.mpu) - - self.module.load_state_dict( - 
state_dict=checkpoint[self._choose_module_key(checkpoint)], - strict=load_module_strict) + strict=load_module_strict) def _choose_module_key(self, sd): assert not ('module' in sd and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed" @@ -349,6 +438,8 @@ def _convert_to_dtype(self): self.quantize_groups) elif self.dtype == torch.half: self.module.half() + elif self.dtype == torch.bfloat16: + self.module.bfloat16() elif self.dtype == torch.float: self.module.float() @@ -396,6 +487,7 @@ def forward(self, *inputs, **kwargs): *inputs: Variable length input list **kwargs: variable length keyword arguments """ + if self.mp_world_size > 1: if self.mpu is None: for input in inputs: @@ -410,7 +502,6 @@ def forward(self, *inputs, **kwargs): if not kwargs[k].is_contiguous(): kwargs[k] = kwargs[k].contiguous() dist.broadcast(kwargs[k], 0) - outputs = self.model_orig_fwd(*inputs, **kwargs) else: if self.enable_cuda_graph: diff --git a/deepspeed/module_inject/__init__.py b/deepspeed/module_inject/__init__.py index 7fe1a3b36b2e..6082d0aff862 100755 --- a/deepspeed/module_inject/__init__.py +++ b/deepspeed/module_inject/__init__.py @@ -1,3 +1,4 @@ -from .replace_module import replace_transformer_layer, revert_transformer_layer +from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing from .module_quantize import quantize_transformer_layer from .replace_policy import DSPolicy, HFBertLayerPolicy +from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py new file mode 100644 index 000000000000..fa71c8d7d4d0 --- /dev/null +++ b/deepspeed/module_inject/layers.py @@ -0,0 +1,71 @@ +import torch +import deepspeed +import deepspeed.ops.transformer as transformer_inference +from deepspeed import comm as dist +from torch import nn +from torch.nn import functional as F + +from torch.nn.parameter import Parameter + + +class LinearAllreduce(nn.Module): + def __init__(self, weight, bias=None, mp_group=None): + super(LinearAllreduce, self).__init__() + self.weight = weight + self.bias = bias + self.mp_group = mp_group + + def forward(self, input): + output = torch.matmul(input, self.weight.transpose(-1, -2)) + if self.mp_group is not None: + dist.all_reduce(output, group=self.mp_group) + if self.bias is not None: + output += self.bias + return output + + +class LinearLayer(nn.Module): + def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None): + super(LinearLayer, self).__init__() + if weight is not None: + self.weight = weight + self.bias = bias + else: + self.weight = Parameter( + torch.empty(weight_shape, + dtype=dtype, + device=torch.cuda.current_device())) + self.bias = Parameter( + torch.empty(weight_shape[0], + dtype=dtype, + device=torch.cuda.current_device())) + + def forward(self, input): + output = torch.matmul(input, self.weight.transpose(-1, -2)) + if self.bias is not None: + output += self.bias + return output + + +class Normalize(nn.Module): + def __init__(self, dim, dtype=torch.float, eps=1e-5): + super(Normalize, self).__init__() + self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(torch.cuda.current_device()) + self.weight = self.norm.weight + self.bias = self.norm.bias + + def forward(self, input): + return self.norm(input) + + +class EmbeddingLayer(nn.Module): + def __init__(self, weight_shape, dtype=torch.float): + super(EmbeddingLayer, self).__init__() + self.weight = Parameter( + 
torch.empty(weight_shape[0], + weight_shape[1], + dtype=dtype, + device=torch.cuda.current_device())) + + def forward(self, input): + return F.embedding(input, self.weight) diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py new file mode 100644 index 000000000000..f6722deb582b --- /dev/null +++ b/deepspeed/module_inject/load_checkpoint.py @@ -0,0 +1,128 @@ +import deepspeed +import torch +from torch import nn +from torch.nn import functional as F +import deepspeed.ops.transformer as transformer_inference +from ..runtime.zero import GatheredParameters +from .layers import LinearAllreduce, LinearLayer, Normalize, EmbeddingLayer + + +def load_model_with_checkpoint(r_module, sd, mp_replace): + error_msgs = [] + + def transpose(data): + data1 = data.transpose(-1, -2).reshape(-1) + data.reshape(-1).copy_(data1) + data1 = None + return data.reshape(data.shape[-1], data.shape[-2]) + + def load(module, prefix): + args = (sd, prefix, {}, True, [], [], error_msgs) + + if len(list(module.parameters())) > 0 and list( + module.parameters())[0].numel() == 0: + with GatheredParameters(list(module.parameters(recurse=False)), + modifier_rank=0): + module._load_from_sd(*args) + else: + if hasattr(module, 'weight'): + module.weight = mp_replace.copy(module.weight.data, + sd[prefix + 'weight']) + if prefix + 'bias' in sd.keys(): + module.bias = mp_replace.copy(module.bias.data, sd[prefix + 'bias']) + + def load_transformer_layer(module, prefix): + module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight']) + module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias']) + module.attention.attn_qkvw = mp_replace.copy( + module.attention.attn_qkvw.data, + transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight'])) + module.attention.attn_qkvb = mp_replace.copy( + module.attention.attn_qkvb.data, + sd[prefix + 'self_attention.query_key_value.' + 'bias']) + module.attention.attn_ow = mp_replace.copy( + module.attention.attn_ow.data, + transpose(sd[prefix + 'self_attention.dense.' + 'weight'])) + module.attention.attn_ob = mp_replace.copy( + module.attention.attn_ob.data, + sd[prefix + 'self_attention.dense.' + 'bias']) + module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' + + 'weight']) + module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' + 'bias']) + module.mlp.inter_w = mp_replace.copy( + module.mlp.inter_w.data, + transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight'])) + module.mlp.inter_b = mp_replace.copy(module.mlp.inter_b.data, + sd[prefix + 'mlp.dense_h_to_4h.' + 'bias']) + module.mlp.output_w = mp_replace.copy( + module.mlp.output_w.data, + transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight'])) + module.mlp.output_b = mp_replace.copy(module.mlp.output_b.data, + sd[prefix + 'mlp.dense_4h_to_h.' + 'bias']) + + layer_policies = { + nn.Linear: load, + nn.Embedding: load, + nn.LayerNorm: load, + EmbeddingLayer: load, + LinearLayer: load, + Normalize: load, + transformer_inference.DeepSpeedTransformerInference: load_transformer_layer + } + + all_ds_ids = {} + + def load_module_recursive(module, prefix='', level=0): + for name, child in module.named_children(): + if child.__class__ in layer_policies: + checking_key = prefix + name + '.' 
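
For orientation, these are the per-block checkpoint keys that load_transformer_layer above consumes (the block prefix is only illustrative): the linear projection weights are passed through transpose() before the tensor-parallel copy, while the LayerNorm parameters and the biases are copied directly.

    # Example key layout for one BLOOM block (prefix is illustrative):
    prefix = "h.0."
    consumed_keys = [
        prefix + "input_layernorm.weight", prefix + "input_layernorm.bias",
        prefix + "self_attention.query_key_value.weight",
        prefix + "self_attention.query_key_value.bias",
        prefix + "self_attention.dense.weight", prefix + "self_attention.dense.bias",
        prefix + "post_attention_layernorm.weight",
        prefix + "post_attention_layernorm.bias",
        prefix + "mlp.dense_h_to_4h.weight", prefix + "mlp.dense_h_to_4h.bias",
        prefix + "mlp.dense_4h_to_h.weight", prefix + "mlp.dense_4h_to_h.bias",
    ]
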
+ if not any(checking_key in item for item in sd.keys()): + if hasattr(child, 'weight') and \ + (hasattr(child.weight, 'ds_id') and \ + child.weight.ds_id in all_ds_ids): + prefix1 = all_ds_ids[child.weight.ds_id] + if child.__class__ is nn.Linear: + child = LinearLayer(weight=all_ds_ids[child.weight.ds_id]) + setattr(module, name, child) + continue + child_params = list(child.parameters()) + if len(child_params) > 0 and (child_params[0].numel() == 0 + or child_params[0].is_meta): + if child.weight.is_meta: + ds_shape = child.weight.shape + else: + ds_shape = child.weight.ds_shape + + if child.__class__ is nn.LayerNorm: + child = Normalize(dim=ds_shape[-1], + dtype=child.weight.dtype, + eps=child.eps) + setattr(module, name, child) + else: + ds_id = None + if hasattr(child.weight, 'ds_id'): + ds_id = child.weight.ds_id + child = EmbeddingLayer(weight_shape=ds_shape, + dtype=child.weight.dtype) + if ds_id is not None: + all_ds_ids[ds_id] = child.weight + setattr(module, name, child) + + layer_policies[child.__class__](child, prefix + name + '.') + else: + load_module_recursive(child, + prefix if level == 0 else prefix + name + '.', + level + 1) + + load_module_recursive(r_module) + + #XXX: hack to tie embedding w. lm_head for BLOOM, need to revist soon + embedding_weight = None + for n, p in r_module.named_parameters(): + if "word_embeddings." in n: + embedding_weight = p + assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist" + r_module.lm_head.weight = embedding_weight + + del sd + sd = None diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 16afaefab495..306feac322ea 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -2,49 +2,29 @@ import torch import deepspeed import deepspeed.ops.transformer as transformer_inference -from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, HFGPTJLayerPolicy +from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, HFGPTJLayerPolicy, BLOOMLayerPolicy from .replace_policy import replace_policies from ..constants import INFERENCE_GENERIC_MODE, INFERENCE_SPECIALIZED_MODE from ..runtime.weight_quantizer import WeightQuantization from torch import nn from deepspeed import comm as dist +from torch import nn +from torch.nn import functional as F - -class LinearAllreduce(nn.Module): - def __init__(self, weight, bias=None, mp_group=None): - super(LinearAllreduce, self).__init__() - self.weight = weight - self.bias = bias - self.mp_group = mp_group - - def forward(self, input): - output = torch.matmul(input, self.weight) - if self.mp_group is not None: - dist.all_reduce(output, group=self.mp_group) - if self.bias is not None: - output += self.bias - return output - - -class LinearLayer(nn.Module): - def __init__(self, weight, bias=None): - super(LinearLayer, self).__init__() - self.weight = weight - self.bias = bias - - def forward(self, input): - output = torch.matmul(input, self.weight) - if self.bias is not None: - output += self.bias - return output +from ..runtime.zero import GatheredParameters +from .layers import LinearAllreduce, LinearLayer, Normalize, EmbeddingLayer +from .load_checkpoint import load_model_with_checkpoint class ReplaceWithTensorSlicing: - def __init__(self, mp_group=None): + def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0): if mp_group is not None: self.gpu_index = dist.get_rank(group=mp_group) else: self.gpu_index = 0 + self.out_dim = out_dim + self.in_dim = 
in_dim + self.mp_size = mp_size def merge_assert(self, dim1, dim2): assert dim1 > dim2, \ @@ -57,68 +37,87 @@ def qkv_copy(self, dst, src): return src src_shape = src.shape dst_shape = dst.shape - - src_split = torch.split(src.data, src.shape[-1] // 3, dim=-1) + if self.out_dim == 0: + src_split = torch.split(src.data, + src_shape[self.out_dim] // self.mp_size, + dim=0) + else: + src_split = torch.split(src.data, src.shape[-1] // 3, dim=-1) if (len(src_shape) == 2 and len(dst_shape) == 2): - if src_shape[1] == dst_shape[1]: - return torch.nn.Parameter(src) - - self.merge_assert(src_shape[1], dst_shape[1]) - qkv_size = dst_shape[1] // 3 - qkv_split = [torch.split(src_s, qkv_size, dim=1) for src_s in src_split] - - weight_split = [ - torch.cat([qkv_s[i] for qkv_s in qkv_split], - axis=1) for i in range(len(qkv_split[0])) - ] - dst.data.copy_(weight_split[self.gpu_index].to( - torch.cuda.current_device()).contiguous()) + if src_shape[self.out_dim] == dst_shape[self.out_dim]: + return torch.nn.parameter.Parameter(src) + if self.out_dim == 1: + self.merge_assert(src_shape[self.out_dim], dst_shape[self.out_dim]) + qkv_size = dst_shape[self.out_dim] // 3 + qkv_split = [ + torch.split(src_s, + qkv_size, + dim=self.out_dim) for src_s in src_split + ] + + weight_split = [ + torch.cat([qkv_s[i] for qkv_s in qkv_split], + axis=self.out_dim) for i in range(len(qkv_split[0])) + ] + dst.data.copy_(weight_split[self.gpu_index].to( + torch.cuda.current_device()).contiguous()) + else: + dst.data.copy_(src_split[self.gpu_index].to( + torch.cuda.current_device()).contiguous()) else: if src_shape[0] == dst_shape[0]: - return torch.nn.Parameter(src) - - qkv_size = dst_shape[0] // 3 - qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split] - bias_split = [ - torch.cat([qkv_s[i] for qkv_s in qkv_split], - axis=0) for i in range(len(qkv_split[0])) - ] - dst.data.copy_(bias_split[self.gpu_index].to( - torch.cuda.current_device()).contiguous()) + return torch.nn.parameter.Parameter(src) + if self.out_dim == 1: + qkv_size = dst_shape[0] // 3 + qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split] + bias_split = [ + torch.cat([qkv_s[i] for qkv_s in qkv_split], + axis=0) for i in range(len(qkv_split[0])) + ] + dst.data.copy_(bias_split[self.gpu_index].to( + torch.cuda.current_device()).contiguous()) + else: + dst.data.copy_(src_split[self.gpu_index].to( + torch.cuda.current_device()).contiguous()) - return torch.nn.Parameter(dst) + return torch.nn.parameter.Parameter(dst) def copy(self, dst, src): if src is None: return src - src_shape = src.shape dst_shape = dst.shape - if (len(src_shape) == 2 and len(dst_shape) == 2): if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]: - return torch.nn.Parameter(src) - - if src_shape[0] != dst_shape[0]: - self.merge_assert(src_shape[0], dst_shape[0]) - weight_split = torch.split(src, dst_shape[0]) + dst.data.copy_(src) else: - self.merge_assert(src_shape[1], dst_shape[1]) - weight_split = torch.split(src.data, dst_shape[1], dim=1) - - dst.data.copy_(weight_split[self.gpu_index].to( - torch.cuda.current_device()).contiguous()) + if src_shape[self.in_dim] != dst_shape[self.in_dim]: + self.merge_assert(src_shape[self.in_dim], dst_shape[self.in_dim]) + weight_split = torch.split( + src, + dst_shape[self.in_dim], + dim=self.in_dim)[self.gpu_index].to( + torch.cuda.current_device()).contiguous() + else: + self.merge_assert(src_shape[self.out_dim], dst_shape[self.out_dim]) + weight_split = torch.split( + src.data, + dst_shape[self.out_dim], 
+ dim=self.out_dim)[self.gpu_index].to( + torch.cuda.current_device()).contiguous() + dst.data.copy_(weight_split.contiguous()) else: if src_shape[0] == dst_shape[0]: - return torch.nn.Parameter(src) - - bias_split = torch.split(src.data, dst_shape[-1]) - dst.data.copy_(bias_split[self.gpu_index].to( - torch.cuda.current_device()).contiguous()) + dst.data.copy_(src) + else: + bias_split = torch.split(src.data, + dst_shape[-1])[self.gpu_index].to( + torch.cuda.current_device()).contiguous() + dst.data.copy_(bias_split) - return torch.nn.Parameter(dst) + return torch.nn.parameter.Parameter(dst, requires_grad=False) def replace_transformer_layer(orig_layer_impl, @@ -147,7 +146,8 @@ def replace_transformer_layer(orig_layer_impl, linear_layer_setting=None, moe=False, moe_experts=1, - moe_type='standard'): + moe_type='standard', + checkpoint=None): """ Replace bert-style transformer layers with DeepSpeed's transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, @@ -182,6 +182,9 @@ def replace_transformer_layer(orig_layer_impl, Returns: Updated nn.module with replaced transformer layers """ + mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group, + mp_size=mp_size) #, out_dim=0, in_dim=1) + def replace_with_policy(child, policy_cls, triangular_masking, @@ -243,7 +246,6 @@ def replace_with_policy(child, _res_4hh_w = _res_4hh_w.half() _res_coef = _res_coef.half() - mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group) #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group) if inference: @@ -267,6 +269,7 @@ def replace_with_policy(child, else: rotary_dim = config.rotary_dim if hasattr(config, 'rotary_dim') else child.attention.rotary_ndims \ if hasattr(child, 'attention') and hasattr(child.attention,'rotary_ndims') else -1 + bigscience_bloom = policy_cls is BLOOMLayerPolicy transformer_config = transformer_inference.DeepSpeedInferenceConfig( hidden_size=hidden_size, heads=num_attention_heads, @@ -291,7 +294,8 @@ def replace_with_policy(child, 'window_size') else 1), rotary_dim=rotary_dim, mlp_after_attn=(rotary_dim is None or rotary_dim < 0), - training_mp_size=training_mp_size) + training_mp_size=training_mp_size, + bigscience_bloom=bigscience_bloom) if quantize and quantize_settings is not None: (quantization_scales, @@ -359,66 +363,115 @@ def transpose(data): data.to(torch.cuda.current_device()) return data + attn_block = new_module.attention + mpl_block = new_module.mlp + if attn_linear_layer: - qkvw.data = transpose(qkvw.data) - dense_w.data = transpose(dense_w.data) + if qkvw.numel() == 0 or qkvw.is_meta: + if qkvw.is_meta or qkvw.ds_tensor.numel( + ) < attn_block.attn_qkvw.numel(): + pass + else: + with GatheredParameters([qkvw, + dense_w, + qkvb, + dense_b], + modifier_rank=0): + qkvw = transpose(qkvw.data) + dense_w = transpose(dense_w.data) + qkvb = qkvb.data + dense_b = dense_b.data + else: + qkvw.data = transpose(qkvw.data) + dense_w.data = transpose(dense_w.data) + + def _transpose(x): + num_attention_heads_per_partition = transformer_config.heads // transformer_config.mp_size + attention_head_size = x.shape[-1] // num_attention_heads_per_partition + new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition, + attention_head_size) + x_1 = x.view(*new_x_shape) + (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1)) + if len(q.shape) > 2: + return torch.cat((q.reshape(q.shape[0], + -1), + k.reshape(q.shape[0], + -1), + v.reshape(q.shape[0], + -1)), + dim=-1).reshape(x.shape) + 
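
A minimal stand-alone sketch (hypothetical names) of the sliced copy that ReplaceWithTensorSlicing.copy implements above: the full tensor is split by the destination's size along the partitioned dimension, and this rank keeps only its own chunk.

    import torch

    def sliced_copy(dst: torch.Tensor, src: torch.Tensor, gpu_index: int, dim: int):
        # Split by the *destination* size along `dim`, then keep this rank's shard,
        # mirroring torch.split(...)[self.gpu_index] in copy() above.
        shard = torch.split(src, dst.shape[dim], dim=dim)[gpu_index]
        dst.data.copy_(shard.contiguous())
        return dst

    full_weight = torch.arange(24.0).reshape(4, 6)  # e.g. a full [out, in] weight
    local_weight = torch.empty(4, 3)                # mp_size = 2 halves the sliced dim
    sliced_copy(local_weight, full_weight, gpu_index=0, dim=1)
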
else: + return torch.cat((q.reshape(-1), + k.reshape(-1), + v.reshape(-1)), + dim=-1).reshape(x.shape) if megatron_v2: new_module.config.rotate_half = True new_module.config.rotate_every_two = False - def _transpose(x): - num_attention_heads_per_partition = transformer_config.heads // transformer_config.mp_size - attention_head_size = x.shape[-1] // num_attention_heads_per_partition - new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition, - attention_head_size) - x_1 = x.view(*new_x_shape) - (q, - k, - v) = torch.split(x_1, - (x_1.shape[-1] // 3), - dim=(x_1.dim() - 1)) - if len(q.shape) > 2: - return torch.cat((q.reshape(q.shape[0], - -1), - k.reshape(q.shape[0], - -1), - v.reshape(q.shape[0], - -1)), - dim=-1).reshape(x.shape) - else: - return torch.cat((q.reshape(-1), - k.reshape(-1), - v.reshape(-1)), - dim=-1).reshape(x.shape) - - qkvw = torch.nn.Parameter(_transpose(qkvw).contiguous()) - qkvb = torch.nn.Parameter(_transpose(qkvb).contiguous()) + # Note: this part needs to be added for BLOOM architecture + qkvw = torch.nn.parameter.Parameter(_transpose(qkvw).contiguous()) + qkvb = torch.nn.parameter.Parameter(_transpose(qkvb).contiguous()) - dense_b = dense_b if dense_b is None else dense_b * ( - transformer_config.training_mp_size / transformer_config.mp_size) - _4hh_b = _4hh_b * (transformer_config.training_mp_size / - transformer_config.mp_size) + # NOTE: This part caused instability in the multi-GPU inference! + # TODO: This needs to be incorporated in the kernels. + #dense_b = dense_b if dense_b is None else dense_b * ( + # transformer_config.training_mp_size / transformer_config.mp_size) + #_4hh_b = _4hh_b * (transformer_config.training_mp_size / + # transformer_config.mp_size) if mlp_linear_layer: - _h4h_w = [transpose(moe_w1.data) - for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data) - _4hh_w = [transpose(moe_w1.data) - for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data) + if not moe and (_4hh_w.numel() == 0 or _4hh_w.is_meta): + if _4hh_w.is_meta or _4hh_w.ds_tensor.numel( + ) < mpl_block.inter_w.numel(): + pass + else: + with GatheredParameters([_h4h_w, + _4hh_w, + _4hh_b, + _h4h_b], + modifier_rank=0): + _h4h_w = transpose(_h4h_w.data) + _4hh_w = transpose(_4hh_w.data) + _h4h_b = _h4h_b.data + _4hh_b = _4hh_b.data + else: + _h4h_w = [transpose(moe_w1.data) + for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data) + _4hh_w = [transpose(moe_w1.data) + for moe_w1 in _4hh_w] if moe else transpose(_4hh_w.data) if moe and moe_type == 'residual': _res_h4h_w.data = transpose(_res_h4h_w.data) _res_4hh_w.data = transpose(_res_4hh_w.data) _res_coef.data = transpose(_res_coef.data) - attn_block = new_module.attention - attn_block.attn_qkvw = mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw) - attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb) + if qkvw.is_meta or qkvw.numel() == 0 or qkvw.is_meta: + if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel(): + pass + else: + with GatheredParameters([attn_qkvw, + attn_qkvb, + attn_ow, + attn_ob], + modifier_rank=0): + attn_block.attn_qkvw = mp_replace.copy( + attn_block.attn_qkvw, + qkvw) + attn_block.attn_qkvb = mp_replace.copy( + attn_block.attn_qkvb, + qkvb) + + attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w) + attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b) + else: + attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw) + attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb) - attn_block.attn_ow = 
mp_replace.copy(attn_block.attn_ow, dense_w) - attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b) + attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w) + attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b) - mpl_block = new_module.mlp if moe: gpu_index = dist.get_rank() gpu_index = 0 @@ -448,20 +501,68 @@ def _transpose(x): torch.cuda.current_device()) new_module.res_coef.data = _res_coef.to(torch.cuda.current_device()) else: - mpl_block.inter_w.data = mp_replace.copy(mpl_block.inter_w, _h4h_w) - mpl_block.inter_b.data = mp_replace.copy(mpl_block.inter_b, _h4h_b) - mpl_block.output_w.data = mp_replace.copy(mpl_block.output_w, _4hh_w) - mpl_block.output_b.data = mp_replace.copy(mpl_block.output_b, _4hh_b) + + if _4hh_w.numel() == 0 or _4hh_w.is_meta: + if _4hh_w.is_meta or _4hh_w.ds_tensor.numel( + ) < mpl_block.inter_w.numel(): + pass + else: + with GatheredParameters([_h4h_w, + _4hh_w, + _4hh_w, + _4hh_b], + modifier_rank=0): + mpl_block.inter_w = mp_replace.copy( + mpl_block.inter_w, + _h4h_w) + mpl_block.inter_b = mp_replace.copy( + mpl_block.inter_b, + _h4h_b) + mpl_block.output_w = mp_replace.copy( + mpl_block.output_w, + _4hh_w) + mpl_block.output_b = mp_replace.copy( + mpl_block.output_b, + _4hh_b) + else: + mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w) + mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b) + mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w) + mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b) + if attn_nw is None: new_module.mlp.attn_nw = attn_nw - else: - new_module.mlp.attn_nw.data = attn_nw.to(torch.cuda.current_device()) - if attn_nb is None: new_module.mlp.attn_nb = attn_nb else: - new_module.mlp.attn_nb.data = attn_nb.to(torch.cuda.current_device()) - new_module.norm_w.data = input_nw.to(torch.cuda.current_device()) - new_module.norm_b.data = input_nb.to(torch.cuda.current_device()) + if attn_nw.is_meta or attn_nw.numel() == 0: + if attn_nw.is_meta or attn_nw.ds_tensor.numel( + ) < new_module.mlp.attn_nw.numel(): + pass + else: + with GatheredParameters([attn_nw, attn_nb], modifier_rank=0): + new_module.mlp.attn_nw.data.copy_( + attn_nw.to(torch.cuda.current_device())) + new_module.mlp.attn_nb.data.copy_( + attn_nb.to(torch.cuda.current_device())) + else: + new_module.mlp.attn_nw.data.copy_( + attn_nw.to(torch.cuda.current_device())) + new_module.mlp.attn_nb.data.copy_( + attn_nb.to(torch.cuda.current_device())) + + if input_nw.is_meta or input_nw.numel() == 0: + if input_nw.is_meta or input_nw.ds_tensor.numel( + ) < new_module.norm_w.numel(): + pass + else: + with GatheredParameters([input_nw, input_nb], modifier_rank=0): + new_module.norm_w.data.copy_( + input_nw.to(torch.cuda.current_device())) + new_module.norm_b.data.copy_( + input_nb.to(torch.cuda.current_device())) + else: + new_module.norm_w.data.copy_(input_nw.to(torch.cuda.current_device())) + new_module.norm_b.data.copy_(input_nb.to(torch.cuda.current_device())) else: transformer_config = deepspeed.DeepSpeedTransformerConfig( batch_size=micro_batch_size if micro_batch_size > 0 else 1, @@ -502,58 +603,88 @@ def _transpose(x): def replace_wo_policy(module, all_reduce_linears): def _replace(child, name, conv_linear_layer): mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group) + z_inference = (len(list(child.parameters())) > 0) and (list( + child.parameters())[0].numel() == 0) + if z_inference: + weight_shape = child.weight.ds_shape + else: + weight_shape = child.weight.shape if name in 
all_reduce_linears: - new_weight = torch.empty( - (child.weight.shape[0] - if conv_linear_layer else child.weight.shape[1] // mp_size, - child.weight.shape[1] - if conv_linear_layer else child.weight.shape[0]), - device=child.weight.device, - dtype=torch.half if fp16 else torch.float) - if not conv_linear_layer: - child.weight.data.view(-1).copy_( - child.weight.data.transpose(-1, - -2).contiguous().view(-1)) - child.weight.data = child.weight.data.reshape( - child.weight.data.shape[-1], - child.weight.data.shape[-2]) - data = mp_replace.copy(new_weight, - child.weight.data).to(torch.cuda.current_device()) + new_weight = torch.empty(( + weight_shape[1] if conv_linear_layer else weight_shape[0], + (weight_shape[0] if conv_linear_layer else weight_shape[1]) // + mp_size, + ), + device=child.weight.device, + dtype=child.weight.dtype) + if z_inference: + with deepspeed.zero.GatheredParameters(child.weight, + modifier_rank=0): + data = child.weight.data.to(new_weight.device) + if conv_linear_layer: + data = data.transpose(-1, -2).contiguous() + data = mp_replace.copy(new_weight, data) + child.weight.ds_tensor = torch.empty(1) + else: + if conv_linear_layer: + child.weight.data = child.weight.data.transpose(-1, + -2).contiguous() + data = mp_replace.copy(new_weight, child.weight.data) + new_bias = torch.empty((weight_shape[0]), + device=child.weight.device, + dtype=child.weight.dtype) + + if z_inference: + with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0): + new_bias.data.copy_(child.bias.data) + else: + new_bias.data.copy_(child.bias.data) return LinearAllreduce(data, child.bias if child.bias is None else \ - child.bias.to(torch.cuda.current_device()), mp_group) + torch.nn.parameter.Parameter(new_bias.to(torch.cuda.current_device())), mp_group) else: - new_weight = torch.empty( - (child.weight.shape[0] // - mp_size if conv_linear_layer else child.weight.shape[1], - child.weight.shape[1] - if conv_linear_layer else child.weight.shape[0] // mp_size), - device=child.weight.device, - dtype=torch.half if fp16 else torch.float) - if not conv_linear_layer: - child.weight.data.view(-1).copy_( - child.weight.data.transpose(-1, - -2).contiguous().view(-1)) - child.weight.data = child.weight.data.reshape( - child.weight.data.shape[-1], - child.weight.data.shape[-2]) - data = mp_replace.copy(new_weight, child.weight.data) - new_bias = torch.empty((child.weight.shape[1] // mp_size), + new_weight = torch.empty(( + (weight_shape[1] if conv_linear_layer else weight_shape[0]) // + mp_size, + weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1], + ), + device=child.weight.device, + dtype=child.weight.dtype) + if z_inference: + with deepspeed.zero.GatheredParameters(child.weight, + modifier_rank=0): + data = child.weight.data.to(new_weight.device) + if conv_linear_layer: + data = data.transpose(-1, -2).contiguous() + data = mp_replace.copy(new_weight, data) + child.weight.ds_tensor = torch.empty(1) + else: + if conv_linear_layer: + child.weight.data = child.weight.data.transpose(-1, + -2).contiguous() + data = mp_replace.copy(new_weight, child.weight.data) + + new_bias = torch.empty((weight_shape[0] // mp_size), device=child.weight.device, - dtype=torch.half if fp16 else torch.float) - bias_data = None if child.bias is None else mp_replace.copy( - new_bias, - child.bias.data).to(torch.cuda.current_device()) + dtype=child.weight.dtype) + if z_inference: + with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0): + bias_data = None if child.bias is None else 
mp_replace.copy( + new_bias, + child.bias.data).to(torch.cuda.current_device()) + else: + bias_data = None if child.bias is None else mp_replace.copy( + new_bias, + child.bias.data).to(torch.cuda.current_device()) return LinearLayer(data.to(torch.cuda.current_device()), bias_data) def _slice_embedding(child, name, conv_linear_layer): mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group) - new_weight = torch.empty((child.weight.shape[0], - child.weight.shape[1] // mp_size), + new_weight = torch.empty((weight_shape[0], + weight_shape[1] // mp_size), device=child.weight.device, dtype=child.weight.dtype) - data = mp_replace.copy(new_weight, child.weight.data) - new_embedding = nn.Embedding(child.weight.shape[0], - child.weight.shape[1] // mp_size) + data = mp_replace.copy(new_weight, child.weight.ds_tensor.data) + new_embedding = nn.Embedding(weight_shape[0], weight_shape[1] // mp_size) new_embedding.weight.data.copy_(data) return new_embedding @@ -570,6 +701,8 @@ def update_mp_params(child): child.all_head_size = child.all_head_size // mp_size if hasattr(child, 'embed_dim'): child.embed_dim = child.embed_dim // mp_size + if hasattr(child, 'hidden_size'): + child.hidden_size = child.hidden_size // mp_size conv_linear_layer = False if linear_layer_setting is not None: @@ -626,10 +759,18 @@ def replace_fn(child, _policy, layer_id=0): return new_module - return replace_module(model=model, - orig_class=orig_layer_impl, - replace_fn=replace_fn, - _replace_policy=policy) + replaced_module = replace_module(model=model, + orig_class=orig_layer_impl, + replace_fn=replace_fn, + _replace_policy=policy) + + if checkpoint is not None: + for i in range(len(checkpoint)): + if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: + print(f"loading checkpoint ({i})") + sd = torch.load(checkpoint[i], map_location='cpu') + load_model_with_checkpoint(replaced_module, sd, mp_replace) + return replaced_module def revert_transformer_layer(orig_layer_impl, model, config, preln=False): diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index f4054c633810..59393b1d8477 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -321,6 +321,45 @@ def layerNorm(self): self.client_module.ln_1.bias +class BLOOMLayerPolicy(DSPolicy): + _orig_layer_class = None + + def __init__(self, client_module, inference=True): + super().__init__(inference, linear_layer=True) + self.client_module = client_module + try: + import transformers + BLOOMLayerPolicy._orig_layer_class = transformers.models.bloom.modeling_bloom.BloomBlock + except: + BLOOMLayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.self_attention.hidden_size, \ + self.client_module.self_attention.num_heads + + def attention(self): + return self.linear_layer, \ + self.client_module.self_attention.query_key_value.weight, \ + self.client_module.self_attention.query_key_value.bias, \ + self.client_module.self_attention.dense.weight, \ + self.client_module.self_attention.dense.bias, \ + self.scale_attention, \ + self.is_megatron_v2 + + def mlp(self): + return self.linear_layer, \ + self.client_module.mlp.dense_h_to_4h.weight, \ + self.client_module.mlp.dense_h_to_4h.bias, \ + self.client_module.mlp.dense_4h_to_h.weight, \ + self.client_module.mlp.dense_4h_to_h.bias + + def layerNorm(self): + return self.client_module.post_attention_layernorm.weight, \ + self.client_module.post_attention_layernorm.bias, \ + 
self.client_module.input_layernorm.weight, \ + self.client_module.input_layernorm.bias + + class GPTNEOXLayerPolicy(DSPolicy): _orig_layer_class = None version = 0 @@ -383,4 +422,5 @@ def layerNorm(self): HFGPTJLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy, + BLOOMLayerPolicy, ] diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index 4fa162bc3491..342f4e1b595f 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -13,6 +13,7 @@ from deepspeed import comm as dist # Cuda modules will be imported if needed inference_cuda_module = None +minus_inf = -10000.0 class TransformerConfig(): @@ -48,6 +49,7 @@ class DeepSpeedInferenceConfig(TransformerConfig): scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation. return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor + bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture. """ def __init__(self, hidden_size=-1, @@ -70,7 +72,8 @@ def __init__(self, rotate_every_two=True, return_tuple=True, mlp_after_attn=True, - training_mp_size=1): + training_mp_size=1, + bigscience_bloom=False): super(DeepSpeedInferenceConfig, self).__init__( hidden_size, @@ -95,6 +98,7 @@ def __init__(self, self.mlp_after_attn = mlp_after_attn self.specialized_mode = False self.training_mp_size = training_mp_size + self.bigscience_bloom = bigscience_bloom @classmethod def from_dict(cls, json_object): @@ -136,7 +140,8 @@ def forward(ctx, q_groups, merge_count, qkv_merging, - score_context_func): + score_context_func, + alibi): def _transpose_for_scores(x, key=False, reshape=False): attention_head_size = x.shape[-1] // num_attention_heads_per_partition new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition, @@ -156,6 +161,123 @@ def _transpose_for_context(x): (hidden_size_per_partition,) return x.view(*new_x_layer_shape).contiguous() + ########### This part is taken/modified from the HF modeling_bloom.py ################ + # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py + + def split_tensor_along_last_dim(tensor, + num_partitions, + contiguous_split_chunks=True): + """Split a tensor along its last dimension. + + Args: + tensor: ([`torch.tensor`], *required*): + input tensor to split + num_partitions ([`int`], *required*): + number of partitions to split the tensor + contiguous_split_chunks ([`bool`], *optional*, default=`True`): + If True, make each chunk contiguous in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + numerator, denominator = tensor.size()[last_dim], num_partitions + if not (numerator % denominator == 0): + raise ValueError(f"{numerator} is not divisible by {denominator}") + last_dim_size = numerator // denominator + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. 
+ if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): + head_dim = hidden_size_per_partition // num_attention_heads_per_partition + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + num_attention_heads_per_partition, + 3 * head_dim) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + (query_layer, + key_layer, + value_layer) = split_tensor_along_last_dim(mixed_x_layer, + 3) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim] + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1) + value_layer = torch.cat((past_value.type_as(value_layer), + value_layer), + dim=1) + + presents = (key_layer, value_layer) + + # [batch_size, head_dim, q_length, k_length] + output_size = (query_layer.size(0), + query_layer.size(2), + query_layer.size(1), + key_layer.size(1)) + # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim] + query_layer = query_layer.transpose(1, + 0).reshape( + output_size[2], + output_size[0] * output_size[1], + -1) + # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim] + key_layer = key_layer.transpose(1, + 0).reshape(output_size[3], + output_size[0] * output_size[1], + -1) + + # Raw attention scores. [batch_size * num_heads, q_length, k_length] + matmul_result = torch.matmul(query_layer.transpose(1, + 0), + key_layer.transpose(1, + 0).transpose(1, + 2)) + # change view to [batch_size, num_heads, q_length, k_length] + attention_scores = matmul_result.view(*output_size) + + offset = dist.get_rank( + ) * num_attention_heads_per_partition if dist.is_initialized() else 0 + attention_probs = inference_cuda_module.softmax_fp16( + attention_scores, + ((1 - input_mask).half() * + minus_inf) if input_mask.dtype == torch.int64 else input_mask, + alibi, + (config.triangular_masking and (attention_scores.shape[-2] > 1)), + False, + False, + 1, + False, + 1 / (norm_factor * norm_factor), + offset, + config.mp_size) + # change view [batch_size x num_heads, q_length, k_length] + attention_probs_reshaped = attention_probs.view(*matmul_result.shape) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = torch.bmm( + attention_probs_reshaped, + value_layer.transpose(1, + 2).reshape(-1, + value_layer.size(1), + value_layer.size(3))) + + # change view [batch_size, num_heads, q_length, head_dim] + context_layer = context_layer.view( + context_layer.size(0) // num_attention_heads_per_partition, + num_attention_heads_per_partition, + context_layer.size(1), + context_layer.shape[-1]) + + context_layer = _transpose_for_context(context_layer) + + return context_layer, presents + + ###################### End of HF modeling_bloom addition ######################## + def compute_attention(qkv_out, input_mask): no_masking = input_mask is None @@ -208,7 +330,8 @@ def compute_attention(qkv_out, input_mask): mixed_query, key_layer, torch.empty(1), - input_mask, + ((1 - input_mask).half() * + minus_inf) if input_mask.dtype == torch.int64 else input_mask, value_layer, torch.empty(1), num_attention_heads_per_partition, @@ -223,7 +346,8 @@ def compute_attention(qkv_out, input_mask): mixed_query, (key_layer if unfused_mode else past_key.type_as(key_layer)), key_layer, - input_mask, + ((1 - input_mask).half() * + minus_inf) if input_mask.dtype == 
torch.int64 else input_mask, (value_layer if unfused_mode else past_value.type_as(value_layer)), value_layer, @@ -244,23 +368,35 @@ def compute_attention(qkv_out, input_mask): return context_layer, presents[0], presents[1] # atten_output, key_layer, value_layer else: - attn_key_value = score_context_func( - qkv_out, - input_mask, - config.rotary_dim, - config.rotate_half, - config.rotate_every_two, - num_attention_heads_per_partition, - (1 / norm_factor if config.scale_attention else 1.0), - config.triangular_masking, - config.local_attention, - config.window_size, - no_masking, - config.layer_id, - DeepSpeedTransformerInference.layer_id) + # Note: This modification is added for the BLOOM-176B model and will be removed later! + if config.bigscience_bloom: + context_layer, presents = backup_attention(qkv_out, layer_past, alibi, input_mask, norm_factor) + return context_layer, presents[0], presents[1] #key_layer, value_layer + else: + if alibi is not None: + batch_heads = qkv_out.shape[0] * num_attention_heads_per_partition + offset = dist.get_rank() * batch_heads if dist.is_initialized( + ) else 0 + sliced_alibi = alibi[offset:batch_heads + offset, :, :] - context_layer, key_layer, value_layer = attn_key_value - return context_layer, key_layer, value_layer + attn_key_value = score_context_func( + qkv_out, + ((1 - input_mask).to(qkv_out.dtype) * + minus_inf) if input_mask.dtype == torch.int64 else input_mask, + config.rotary_dim, + config.rotate_half, + config.rotate_every_two, + num_attention_heads_per_partition, + (1 / norm_factor if config.scale_attention else 1.0), + config.triangular_masking, + config.local_attention, + config.window_size, + no_masking, + config.layer_id, + DeepSpeedTransformerInference.layer_id, + sliced_alibi if alibi is not None else torch.empty(1)) + context_layer, key_layer, value_layer = attn_key_value + return context_layer, key_layer, value_layer def selfAttention_fp(): vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \ @@ -276,16 +412,16 @@ def selfAttention_fp(): else: qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \ inference_cuda_module.qkv_gemm_fp32 - - qkv_out = qkv_func(input, - attn_qkvw, - (attn_qkvb if attn_qkvb is not None else norm_b), - norm_w, - norm_b, - config.epsilon, - (attn_qkvb is not None), - DeepSpeedTransformerInference.layer_id) - + qkv_out = qkv_func( + input, + attn_qkvw, + (attn_qkvb if attn_qkvb is not None else norm_b), + norm_w, + norm_b, + config.epsilon, + (attn_qkvb is not None), + 1 if config.bigscience_bloom else + DeepSpeedTransformerInference.layer_id) context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask) output = vector_matmul_func(context_layer, attn_ow, False) @@ -347,19 +483,30 @@ def __init__(self, qkv_merging=False): super(DeepSpeedSelfAttention, self).__init__() self.config = config + data_type = torch.half if config.fp16 else torch.float self.config.layer_id = DeepSpeedSelfAttention.num_layers DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1 + device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' self.attn_qkvw = nn.Parameter( - torch.Tensor(self.config.hidden_size, - (self.config.hidden_size // self.config.mp_size) * 3)) + torch.empty(self.config.hidden_size, + (self.config.hidden_size // self.config.mp_size) * 3, + dtype=data_type, + device=device)) self.attn_qkvb = nn.Parameter( - torch.Tensor((self.config.hidden_size // self.config.mp_size) * 3)) + 
torch.empty((self.config.hidden_size // self.config.mp_size) * 3, + dtype=data_type, + device=device)) self.attn_ow = nn.Parameter( - torch.Tensor(self.config.hidden_size // self.config.mp_size, - self.config.hidden_size)) + torch.empty(self.config.hidden_size // self.config.mp_size, + self.config.hidden_size, + dtype=data_type, + device=device)) - self.attn_ob = nn.Parameter(torch.Tensor(self.config.hidden_size)) + self.attn_ob = nn.Parameter( + torch.empty(self.config.hidden_size, + dtype=data_type, + device=device)) self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size @@ -389,7 +536,8 @@ def forward(self, encoder_attention_mask=None, output_attentions=False, norm_w=None, - norm_b=None): + norm_b=None, + alibi=None): output = DeepSpeedSelfAttentionFunction.apply( input, input_mask, @@ -414,7 +562,8 @@ def forward(self, self.q_groups, self.merge_count, self.qkv_merging, - self.score_context_func) + self.score_context_func, + alibi) return output @@ -482,6 +631,7 @@ def forward(ctx, config.pre_layer_norm, config.mlp_after_attn) output = vector_matmul_func(intermediate, output_w, False) + inference_cuda_module.residual_add( output, residual if config.pre_layer_norm else residual_add, @@ -513,17 +663,34 @@ def __init__(self, super(DeepSpeedMLP, self).__init__() self.config = config - self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size)) - self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size)) + data_type = torch.half if config.fp16 else torch.float + device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + self.attn_nw = nn.Parameter( + torch.empty(self.config.hidden_size, + dtype=data_type, + device=device)) + self.attn_nb = nn.Parameter( + torch.empty(self.config.hidden_size, + dtype=data_type, + device=device)) self.inter_w = nn.Parameter( - torch.Tensor(self.config.hidden_size, - self.config.intermediate_size // self.config.mp_size)) + torch.empty(self.config.hidden_size, + self.config.intermediate_size // self.config.mp_size, + dtype=data_type, + device=device)) self.inter_b = nn.Parameter( - torch.Tensor(self.config.intermediate_size // self.config.mp_size)) + torch.empty(self.config.intermediate_size // self.config.mp_size, + dtype=data_type, + device=device)) self.output_w = nn.Parameter( - torch.Tensor((self.config.intermediate_size // self.config.mp_size), - self.config.hidden_size)) - self.output_b = nn.Parameter(torch.Tensor(self.config.hidden_size)) + torch.empty((self.config.intermediate_size // self.config.mp_size), + self.config.hidden_size, + dtype=data_type, + device=device)) + self.output_b = nn.Parameter( + torch.empty(self.config.hidden_size, + dtype=data_type, + device=device)) # used for quantization self.q_scales = q_scales @@ -595,6 +762,7 @@ def __init__(self, self.config.layer_id = DeepSpeedTransformerInference.layer_id DeepSpeedTransformerInference.layer_id += 1 + data_type = torch.half if config.fp16 else torch.float global inference_cuda_module if inference_cuda_module is None: builder = op_builder.InferenceBuilder() @@ -615,8 +783,15 @@ def __init__(self, merge_count, mlp_extra_grouping) - self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size)) - self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size)) + device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + self.norm_w = nn.Parameter( + torch.empty(self.config.hidden_size, + dtype=data_type, + device=device)) + 
self.norm_b = nn.Parameter( + torch.empty(self.config.hidden_size, + dtype=data_type, + device=device)) self.layer_past = None def forward(self, @@ -632,6 +807,7 @@ def forward(self, encoder_hidden_states=None, encoder_attention_mask=None, use_cache=False, + alibi=None, output_attentions=False): get_present = (get_present or get_key_value or use_cache) input_mask = input_mask if attention_mask is None else attention_mask @@ -658,10 +834,10 @@ def forward(self, encoder_attention_mask, output_attentions, self.norm_w, - self.norm_b) + self.norm_b, + alibi) presents = (key, value) - self.layer_past = presents - + self.layer_past = presents if layer_past is None else None output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob) if not self.config.pre_layer_norm: @@ -673,8 +849,6 @@ def forward(self, self.config.epsilon) output = output.to(input_type) - #print(f'[{deepspeed.comm.get_rank()}] {self.config.layer_id}: {output.norm()}') - #exit() if get_present: output = (output, presents) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6fd84a6aa073..7b6669d961fd 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -443,6 +443,7 @@ def __getattr__(self, name): """ Pass through attributes defined in the model if they are not overridden by ds-engine. """ + _module = {} if "module" in self.__dict__: _module = self.__dict__['module'] @@ -944,6 +945,7 @@ def is_replicated(p): @staticmethod def __check_params(model: Module, dtype: torch.dtype) -> None: + return if not all(param.dtype == dtype for param in model.parameters()) and dist.get_rank() == 0: raise ValueError( diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index 09887aaa275c..35ccb8d5fd02 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -22,6 +22,8 @@ def get_sd_loader_json(json_file): sd_type = data['type'] ckpt_list = data['checkpoints'] version = data['version'] + if 'BLOOM' in sd_type or 'Bloom' in sd_type: + return ckpt_list return SDLoaderFactory.get_sd_loader(ckpt_list, sd_type, version) @staticmethod diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index efe18c264316..fe7461002932 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -1,5 +1,6 @@ from .logging import logger, log_dist #from .distributed import init_distributed +from .init_on_device import OnDevice from .groups import * from .nvtx import instrument_w_nvtx from deepspeed.runtime.dataloader import RepeatingLoader diff --git a/deepspeed/utils/init_on_device.py b/deepspeed/utils/init_on_device.py new file mode 100644 index 000000000000..1b51efe60384 --- /dev/null +++ b/deepspeed/utils/init_on_device.py @@ -0,0 +1,81 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' +import torch +from typing import Callable +from torch import Tensor +from packaging import version as pkg_version + + +class OnDevice(object): +    """ + Create modules/tensors with specific devices and dtypes. Examples: + + Create MyModule which consists of many different sub-modules and parameters. In this case we can create + MyModule as a collection of 'meta' tensors by passing `device='meta'` or we can create the module _directly_ + on a CUDA device by passing `device=f'cuda:{local_rank}'` (where `local_rank` is the local GPU id). 
+ + with OnDevice(dtype=torch.float16, device='meta'): + model = MyModel() + + with OnDevice(dtype=torch.float16, device=f'cuda:{local_rank}'): + model = MyModel() + + """ + + _orig_torch_empty = torch.empty + _orig_torch_zeros = torch.zeros + _orig_torch_ones = torch.ones + _orig_torch_full = torch.full + + def __init__(self, dtype, device="meta", enabled=True): + self.dtype = dtype + self.enabled = enabled + self.device = device + + if device == "meta": + if pkg_version.parse('1.10') > pkg_version.parse(torch.__version__): + raise NotImplementedError( + "Meta tensor support is not available, please upgrade to torch 1.10+" + ) + + def fp_tensor_constructor(self, + fn: Callable, + target_fp_dtype: torch.dtype) -> Callable: + def wrapped_fn(*args, **kwargs) -> Tensor: + if kwargs.get("device", None) is None: + kwargs['device'] = self.device + tensor: Tensor = fn(*args, **kwargs) + if tensor.is_floating_point(): + tensor = tensor.to(target_fp_dtype) + return tensor + + return wrapped_fn + + def get_new_tensor_fn_for_dtype(self, dtype: torch.dtype) -> Callable: + def new_tensor(cls, *args) -> Tensor: + tensor = OnDevice._orig_torch_empty(0, device=self.device).new_empty(*args) + if tensor.is_floating_point(): + tensor = tensor.to(dtype) + return tensor + + return new_tensor + + def __enter__(self): + if not self.enabled: + return + torch.Tensor.__old_new__ = torch.Tensor.__new__ + torch.Tensor.__new__ = self.get_new_tensor_fn_for_dtype(self.dtype) + torch.empty = self.fp_tensor_constructor(self._orig_torch_empty, self.dtype) + torch.zeros = self.fp_tensor_constructor(self._orig_torch_zeros, self.dtype) + torch.ones = self.fp_tensor_constructor(self._orig_torch_ones, self.dtype) + torch.full = self.fp_tensor_constructor(self._orig_torch_full, self.dtype) + + def __exit__(self, exc_type, exc_value, traceback): + if not self.enabled: + return + torch.Tensor.__new__ = torch.Tensor.__old_new__ + torch.empty = self._orig_torch_empty + torch.zeros = self._orig_torch_zeros + torch.ones = self._orig_torch_ones + torch.full = self._orig_torch_full diff --git a/tests/unit/test_inference.py b/tests/unit/test_inference.py index c3dbc49055b2..006fe6cc884f 100644 --- a/tests/unit/test_inference.py +++ b/tests/unit/test_inference.py @@ -52,6 +52,7 @@ def lm_eval_imports(): "distilgpt2", "Norod78/hebrew-bad_wiki-gpt_neo-tiny", "EleutherAI/gpt-j-6B", + "bigscience/bloom-350m", ] _all_models = HfApi().list_models() diff --git a/tests/unit/test_init_on_device.py b/tests/unit/test_init_on_device.py new file mode 100644 index 000000000000..2306d315f9e4 --- /dev/null +++ b/tests/unit/test_init_on_device.py @@ -0,0 +1,19 @@ +import torch +import pytest +from .simple_model import SimpleModel +from deepspeed import OnDevice +from packaging import version as pkg_version + + +@pytest.mark.parametrize('device', ['meta', 'cuda:0']) +def test_on_device(device): + if device == "meta" and pkg_version.parse( + torch.__version__) < pkg_version.parse("1.10"): + pytest.skip("meta tensors only became stable after torch 1.10") + + with OnDevice(dtype=torch.half, device=device): + model = SimpleModel(4) + + for p in model.parameters(): + assert p.device == torch.device(device) + assert p.dtype == torch.half
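Note on the mask handling introduced above: several call sites now convert an HF-style integer 0/1 attention mask into an additive float mask via `(1 - input_mask) * minus_inf` before handing it to the fused softmax. A minimal sketch of that conversion in plain PyTorch, outside the fused kernels (the shapes below are illustrative assumptions, not the kernel's actual layout):

import torch

minus_inf = -10000.0  # same sentinel value the inference module defines

# 1 = attend, 0 = padding (HF-style int64 mask)
input_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.int64)

# Mirrors `((1 - input_mask).half() * minus_inf)` from compute_attention():
# kept positions become 0.0, padded positions become -10000.0.
additive_mask = (1 - input_mask).half() * minus_inf

scores = torch.zeros(1, 4)  # stand-in attention scores
probs = torch.softmax(scores + additive_mask.float(), dim=-1)
# the padded position receives (near-)zero probability after the softmax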