From 4652e249df103be0366315f9afc32e66d1cecbcb Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 14 Oct 2022 21:50:52 +0500 Subject: [PATCH 1/7] fix checkpoint loading when it is a dictionary --- deepspeed/inference/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 6f9d90e3f743..d51cfda79001 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -420,7 +420,7 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): if is_pipe_parallel: raise RuntimeError( 'pipeline parallelism is currently not supported in inference.') - if os.path.isdir(load_dir): + if not isinstance(load_dir, dict) and os.path.isdir(load_dir): if tag is None: latest_path = os.path.join(load_dir, "latest") if os.path.isfile(latest_path): From 89c64a92fc2628eb9cbb1815d3ab843423da33d0 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Wed, 23 Nov 2022 22:44:20 +0000 Subject: [PATCH 2/7] fix some issues with saving ckpt & int8 inference --- csrc/transformer/inference/csrc/pt_binding.cpp | 3 +-- csrc/transformer/inference/includes/inference_context.h | 2 +- deepspeed/module_inject/load_checkpoint.py | 5 +++-- deepspeed/module_inject/replace_module.py | 8 ++++++-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 7c9ab0bfd3ea..63efd2913387 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1293,7 +1293,6 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, } else { ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon); } - if (q_int8) { quantized_gemm(intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz); } else { @@ -1407,7 +1406,7 @@ std::vector ds_mlp_gemm(at::Tensor& input, q_int8, act_func_type); - return {output, res_add}; + return {output, output}; //res_add}; } template diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index b3851ca43b72..a03da882f778 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -106,7 +106,7 @@ class Context { const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128); const int effective_head_size = (head_size > 128) ? 
head_size : padded_head_size; - size_t activation_size = 16 * (num_heads * effective_head_size) * batch_size; + size_t activation_size = 32 * (num_heads * effective_head_size) * batch_size; // Other sequence length dimension is added when the final workSpaceSize is calculated size_t temp_size = batch_size * num_heads * max_out_tokens * 2; size_t cache_size = diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index f577a1a0e1bc..a9a1fc12750f 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -208,7 +208,7 @@ def load_module_recursive(module, prefix='', level=0): eps=child.eps) setattr(module, name, child) elif child.__class__ is nn.Linear: - child = LinearLayer(weight=child.weight, bias=child.bias) + child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias) setattr(module, name, child) else: ds_id = None @@ -224,7 +224,8 @@ def load_module_recursive(module, prefix='', level=0): else: load_module_recursive( child, - prefix if level == 0 and ckpt_type == 'pp' else prefix + name + '.', + #prefix if level == 0 and ckpt_type == 'pp' else \ + prefix + name + '.', level + 1) load_module_recursive(r_module) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 8bf2268064ff..f491ea4bb724 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -553,6 +553,10 @@ def _transpose(x): if qkvw.is_meta or qkvw.numel() == 0 or qkvw.is_meta: if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel(): + if qkvb is None: + attn_block.attn_qkvb = None + if dense_b is None: + attn_block.attn_ob = None pass else: with GatheredParameters([ @@ -996,7 +1000,7 @@ def replace_fn(child, _policy, layer_id=0): if transformer_name not in k }), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') - config = json.dumps({ + new_config = json.dumps({ 'type': ckpt_name, 'base_dir': @@ -1020,7 +1024,7 @@ def replace_fn(child, _policy, layer_id=0): }) with open(f"{config.save_mp_checkpoint_path}/ds-inference_config.json", "w") as cfg: - cfg.write(config) + cfg.write(new_config) rep_sd = replaced_module.state_dict() for n, p in replaced_module.named_parameters(): From 90c5297c74cbc3cb01386b8a837194b599f61117 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 24 Nov 2022 03:22:47 +0000 Subject: [PATCH 3/7] fix quantized-inference & add generic support of checkpoint loading --- .../transformer/inference/csrc/pt_binding.cpp | 143 ++++++++++++------ .../inference/includes/inference_context.h | 2 +- deepspeed/module_inject/layers.py | 27 +++- deepspeed/module_inject/load_checkpoint.py | 136 ++++++++++++----- deepspeed/module_inject/replace_module.py | 36 +++-- deepspeed/module_inject/replace_policy.py | 75 +++++++++ deepspeed/ops/transformer/inference/ds_mlp.py | 3 + 7 files changed, 327 insertions(+), 95 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 63efd2913387..55e0ada4ab7b 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -763,11 +763,18 @@ void quantized_gemm(void* output, at::Tensor& weight, at::Tensor& qscale, int groups, - int bsz) + int bsz, + int hidden_size) { - T* weight16 = (T*)Context::Instance().GetWorkSpace() + - 12 * Context::Instance().GetMaxTokenLenght() * weight.size(1); - + T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * 
hidden_size * bsz; + + // auto options = at::TensorOptions() + // .dtype(at::kHalf) + // .layout(at::kStrided) + // .device(at::kCUDA) + // .requires_grad(false); + // auto tmp = torch::empty(weight.sizes(), options); + // T* weight16 = (T*)tmp.data_ptr(); launch_dequantize(weight16, (int8_t*)weight.data_ptr(), (float*)qscale.data_ptr(), @@ -814,7 +821,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, ds_layer_norm_internal(workspace, input, gamma, beta, epsilon); if (q_int8) { - quantized_gemm(output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz); + quantized_gemm( + output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz, input.size(2)); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1202,15 +1210,19 @@ at::Tensor ds_vector_matmul(at::Tensor& input, .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); - int out_size = q_int8 ? weight.size(0) : weight.size(1); int bsz = input.size(0) * input.size(1); T* workspace = (T*)Context::Instance().GetWorkSpace(); auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); if (q_int8) { - quantized_gemm( - output.data_ptr(), (T*)input.data_ptr(), weight, q_scale, q_scale.size(0), bsz); + quantized_gemm(output.data_ptr(), + (T*)input.data_ptr(), + weight, + q_scale, + q_scale.size(0), + bsz, + input.size(2)); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1294,7 +1306,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon); } if (q_int8) { - quantized_gemm(intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz); + quantized_gemm( + intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz, input.size(2)); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1331,8 +1344,13 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, Context::Instance().GetCurrentStream()); } if (q_int8) { - quantized_gemm( - output.data_ptr(), intermediate, weight1, q_scale1, q_scale1.size(0), bsz); + quantized_gemm(output.data_ptr(), + intermediate, + weight1, + q_scale1, + q_scale1.size(0), + bsz, + input.size(2)); } else { float alpha = (T)1.0; float gemm_beta = (T)0.0; @@ -1406,7 +1424,7 @@ std::vector ds_mlp_gemm(at::Tensor& input, q_int8, act_func_type); - return {output, output}; //res_add}; + return {output, output}; // res_add}; } template @@ -1448,64 +1466,97 @@ std::vector ds_mlp_gemm_int8(at::Tensor& input, template at::Tensor fused_gemm_gelu(at::Tensor& input, at::Tensor& weight, + at::Tensor& weight_scale, at::Tensor& bias, at::Tensor& weight_out, + at::Tensor& weight_out_scale, const float epsilon, bool preLayerNorm, + bool q_int8, bool async_op) { - auto input_cont = input.contiguous(); auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) + .dtype(input.options().dtype()) .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); + q_int8 = true; + + int intm_dim = q_int8 ? 
weight.size(0) : weight.size(1); + + // auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input), + // {input.size(0), input.size(1), out_size}, + // options); + // T* intermediate = (T*)input.data_ptr() + torch::numel(input); + auto intermediate = at::empty({input.size(0), input.size(1), intm_dim}, options); + + int bsz = input.size(0) * input.size(1); - auto intermediate = - at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options); - int bsz = input_cont.size(0) * input_cont.size(1); float alpha = (T)1.0; float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)intermediate.data_ptr(), + if (q_int8) { + quantized_gemm(intermediate.data_ptr(), + (T*)input.data_ptr(), + weight, + weight_scale, + weight_scale.size(0), + bsz, + input.size(2)); + } else { + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + intm_dim, + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input.data_ptr(), + (T*)intermediate.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ - rocblas_gemm_algo_standard); + rocblas_gemm_algo_standard); #else - CUBLAS_GEMM_DEFAULT_TENSOR_OP); + CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif + } launch_bias_gelu((T*)intermediate.data_ptr(), (T*)bias.data_ptr(), - weight.size(1), + intm_dim, bsz, Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight_out.size(1), - bsz, - intermediate.size(2), - &alpha, - &gemm_beta, - (T*)weight_out.data_ptr(), - (T*)intermediate.data_ptr(), - (T*)output.data_ptr(), + q_int8 = true; + int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1); + auto output = at::empty({input.size(0), input.size(1), out_size}, options); + if (q_int8) { + quantized_gemm(output.data_ptr(), + (T*)intermediate.data_ptr(), + weight_out, + weight_out_scale, + weight_out_scale.size(0), + bsz, + input.size(2)); + } else { + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + out_size, + bsz, + intm_dim, + &alpha, + &gemm_beta, + (T*)weight_out.data_ptr(), + (T*)intermediate.data_ptr(), + (T*)output.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ - rocblas_gemm_algo_standard); + rocblas_gemm_algo_standard); #else - CUBLAS_GEMM_DEFAULT_TENSOR_OP); + CUBLAS_GEMM_DEFAULT_TENSOR_OP); #endif + } // cudaEventRecord(Context::Instance().GetCompEvent(2), // Context::Instance().GetCurrentStream(true)); return output; diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h index a03da882f778..b3851ca43b72 100644 --- a/csrc/transformer/inference/includes/inference_context.h +++ b/csrc/transformer/inference/includes/inference_context.h @@ -106,7 +106,7 @@ class Context { const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128); const int effective_head_size = (head_size > 128) ? 
head_size : padded_head_size; - size_t activation_size = 32 * (num_heads * effective_head_size) * batch_size; + size_t activation_size = 16 * (num_heads * effective_head_size) * batch_size; // Other sequence length dimension is added when the final workSpaceSize is calculated size_t temp_size = batch_size * num_heads * max_out_tokens * 2; size_t cache_size = diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 1ae5bd5ccf9f..695cc3ce63aa 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -23,7 +23,7 @@ def forward(self, input): class LinearLayer(nn.Module): - def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None): + def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None): super(LinearLayer, self).__init__() if weight is not None: self.weight = weight @@ -67,3 +67,28 @@ def __init__(self, weight_shape, dtype=torch.float): def forward(self, input): return F.embedding(input, self.weight) + + +class OPTEmbedding(EmbeddingLayer): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + def __init__(self, weight_shape): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(weight_shape) + + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = (torch.cumsum(attention_mask, + dim=1).type_as(attention_mask) * + attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index a9a1fc12750f..2f72e3f4308e 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -1,7 +1,7 @@ from torch import nn import deepspeed.ops.transformer as transformer_inference from ..runtime.zero import GatheredParameters -from .layers import LinearLayer, Normalize, EmbeddingLayer +from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding import torch import gc @@ -11,7 +11,8 @@ def load_model_with_checkpoint(r_module, mp_replace, ckpt_type, weight_quantizer=None, - rank=0): + rank=0, + param_names=None): error_msgs = [] def transpose(data): @@ -138,37 +139,95 @@ def load_parameters(module, prefix): for n, child in module.named_children(): load_parameters(child, prefix + n + '.') else: - module.norm_w.data.copy_(sd[0][prefix + 'input_layernorm.' + 'weight']) - module.norm_b.data.copy_(sd[0][prefix + 'input_layernorm.' + 'bias']) - module.attention.attn_qkvw = mp_replace.copy(module.attention.attn_qkvw, - weight_quantizer.quantize(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']) if weight_quantizer.q_int8 else \ - weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']))) - module.attention.attn_qkvb = mp_replace.copy( - module.attention.attn_qkvb.data, - sd[0][prefix + 'self_attention.query_key_value.' + 'bias']) - module.attention.attn_ow = mp_replace.copy(module.attention.attn_ow, - weight_quantizer.quantize(sd[0][prefix + 'self_attention.dense.' 
+ 'weight']) if weight_quantizer.q_int8 else \ - weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.dense.' + 'weight']))) - module.attention.attn_ob = mp_replace.copy( - module.attention.attn_ob.data, - sd[0][prefix + 'self_attention.dense.' + 'bias']) - module.mlp.attn_nw.data.copy_(sd[0][prefix + 'post_attention_layernorm.' + - 'weight']) - module.mlp.attn_nb.data.copy_(sd[0][prefix + 'post_attention_layernorm.' + - 'bias']) - module.mlp.inter_w = mp_replace.copy(module.mlp.inter_w, - weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']) if weight_quantizer.q_int8 else \ - weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']))) - module.mlp.inter_b = mp_replace.copy( - module.mlp.inter_b.data, - sd[0][prefix + 'mlp.dense_h_to_4h.' + 'bias']) - module.mlp.output_w = mp_replace.copy(module.mlp.output_w, - weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']) if weight_quantizer.q_int8 else \ - weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']))) - module.mlp.output_b = mp_replace.copy( - module.mlp.output_b.data, - sd[0][prefix + 'mlp.dense_4h_to_h.' + 'bias']) + def maybe_copy(module, dst_name, src_name, qkv=False): + if src_name in sd[0]: + dst = getattr(module, dst_name) + if len(dst.shape) == 1: + if qkv: + dst = mp_replace.qkv_copy(dst, + (sd[0][src_name]).contiguous()) + else: + dst = mp_replace.copy(dst, sd[0][src_name]) + else: + if qkv: + dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, sd[0][src_name] if weight_quantizer.q_int8 else \ + ((transpose(sd[0][src_name])).contiguous()))) + else: + dst = weight_quantizer.quantize(mp_replace.copy(dst, sd[0][src_name] if weight_quantizer.q_int8 else \ + transpose(sd[0][src_name]))) + setattr(module, dst_name, dst) + + def maybe_copy1(module, dst_name, src_names, qkv=False): + if src_names[0] in sd[0]: + q = sd[0][src_names[0]] + k = sd[0][src_names[1]] + v = sd[0][src_names[2]] + qkv_data = torch.cat((q, k, v), dim=0) + dst = getattr(module, dst_name) + if len(dst.shape) == 1: + if qkv: + dst = mp_replace.qkv_copy(dst, (qkv_data).contiguous()) + else: + dst = mp_replace.copy(dst, qkv_data) + else: + if qkv: + dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, qkv_data if weight_quantizer.q_int8 else \ + ((transpose(qkv_data)).contiguous()))) + else: + dst = weight_quantizer.quantize(mp_replace.copy(dst, qkv_data if weight_quantizer.q_int8 else \ + transpose(qkv_data))) + setattr(module, dst_name, dst) + + if len(param_names) == 12: + qkv_w, qkv_b, attn_ow, attn_ob, \ + mlp_intw, mlp_intb, mlp_ow, mlp_ob, \ + inp_normw, inp_normb, attn_nw, attn_nb = param_names + elif len(param_names) < 12: + q_w, k_w, v_w, attn_ow, \ + mlp_intw, mlp_intb, mlp_ow, mlp_ob, \ + inp_normw, inp_normb = param_names + else: + q_w, q_b, k_w, k_b, v_w, v_b, attn_ow, attn_ob, \ + mlp_intw, mlp_intb, mlp_ow, mlp_ob, \ + inp_normw, inp_normb, attn_nw, attn_nb = param_names + maybe_copy(module, 'norm_w', prefix + inp_normw) + maybe_copy(module, 'norm_b', prefix + inp_normb) + if len(param_names) == 12: + maybe_copy(module.attention, 'attn_qkvw', prefix + qkv_w, qkv=True) + maybe_copy(module.attention, 'attn_qkvb', prefix + qkv_b, qkv=True) + elif len(param_names) < 12: + maybe_copy1(module.attention, + 'attn_qkvw', + [prefix + q_w, + prefix + k_w, + prefix + v_w]) + else: + maybe_copy1(module.attention, + 'attn_qkvw', + [prefix + q_w, + prefix + k_w, + prefix + v_w]) + maybe_copy1(module.attention, + 'attn_qkvb', + [prefix + q_b, + 
prefix + k_b, + prefix + v_b]) + maybe_copy(module.attention, 'attn_ow', prefix + attn_ow) + if len(param_names) > 12: + maybe_copy(module.attention, 'attn_ob', prefix + attn_ob) + maybe_copy(module.mlp, 'attn_nw', prefix + attn_nw) + maybe_copy(module.mlp, 'attn_nb', prefix + attn_nb) + maybe_copy(module.mlp, 'inter_w', prefix + mlp_intw) + maybe_copy(module.mlp, 'inter_b', prefix + mlp_intb) + maybe_copy(module.mlp, 'output_w', prefix + mlp_ow) + maybe_copy(module.mlp, 'output_b', prefix + mlp_ob) + + try: + import transformers + OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding + except: + OPTLearnedPositionalEmbedding = None layer_policies = { nn.Linear: load, nn.Embedding: load, @@ -176,7 +235,9 @@ def load_parameters(module, prefix): EmbeddingLayer: load, LinearLayer: load, Normalize: load, - transformer_inference.DeepSpeedTransformerInference: load_transformer_layer + transformer_inference.DeepSpeedTransformerInference: load_transformer_layer, + OPTLearnedPositionalEmbedding: load, + OPTEmbedding: load } all_ds_ids = {} @@ -208,7 +269,11 @@ def load_module_recursive(module, prefix='', level=0): eps=child.eps) setattr(module, name, child) elif child.__class__ is nn.Linear: - child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias) + child = LinearLayer(weight_shape=child.weight.shape, + bias=child.bias) + setattr(module, name, child) + elif child.__class__ is OPTLearnedPositionalEmbedding: + child = OPTEmbedding(weight_shape=ds_shape) setattr(module, name, child) else: ds_id = None @@ -236,7 +301,8 @@ def load_module_recursive(module, prefix='', level=0): if "word_embeddings." in n: embedding_weight = p assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist" - r_module.lm_head.weight = embedding_weight + if embedding_weight is not None: + r_module.lm_head.weight = embedding_weight for sd_ in sd: del sd_ sd = None diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index f491ea4bb724..e2e666230d70 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -286,6 +286,9 @@ def _replace_module(module, policy): setattr(module, name, new_module) +selected_policy_g = None + + def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, @@ -325,6 +328,10 @@ def replace_with_policy(child, inference=False, layer_id=0): policy = policy_cls(child, inference=inference) + + global selected_policy_g + if selected_policy_g is None: + selected_policy_g = policy if not policy.cuda_graph_supported: # policy says cuda graph is not supported raise an error if set assert not config.enable_cuda_graph, "cuda graph is not supported with this model, please disable" @@ -915,6 +922,7 @@ def replace_fn(child, _policy, layer_id=0): mp_replace, ckpt_type, quantizer, + param_names=selected_policy_g.get_param_names(), ) pbar.update(1) else: @@ -939,12 +947,14 @@ def replace_fn(child, _policy, layer_id=0): torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files ] - load_model_with_checkpoint(replaced_module, - sds, - mp_replace, - ckpt_type, - quantizer, - int(rank % tp_split_size)) + load_model_with_checkpoint( + replaced_module, + sds, + mp_replace, + ckpt_type, + quantizer, + int(rank % tp_split_size), + param_names=selected_policy_g.get_param_names()) sds = [None for _ in sds] gc.collect() @@ -959,12 +969,14 @@ def replace_fn(child, _policy, layer_id=0): checkpoint["non_tp"][i] ) if base_dir1 else 
checkpoint["non_tp"][i] sds = [torch.load(ckpt_file, map_location='cpu')] - load_model_with_checkpoint(replaced_module, - sds, - mp_replace, - ckpt_type, - quantizer, - int(rank % tp_split_size)) + load_model_with_checkpoint( + replaced_module, + sds, + mp_replace, + ckpt_type, + quantizer, + int(rank % tp_split_size), + param_names=selected_policy_g.get_param_names()) sds = [None for _ in sds] gc.collect() print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index 9dcb4ace234e..8ab4ab9011da 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -139,6 +139,9 @@ def layerNorm(self): """ raise NotImplementedError + def get_param_names(self): + raise NotImplementedError + class HFBertLayerPolicy(TransformerPolicy): def __init__(self, client_module, inference=False): @@ -294,6 +297,20 @@ def layerNorm(self): self.client_module.ln_1.weight, \ self.client_module.ln_1.bias + def get_param_names(self): + return 'attention.query_key_value.weight', \ + 'attention.query_key_value.bias', \ + 'attention.dense.weight', \ + 'attention.dense.bias', \ + 'mlp.dense_h_to_4h.weight', \ + 'mlp.dense_h_to_4h.bias', \ + 'mlp.dense_4h_to_h.weight', \ + 'mlp.dense_4h_to_h.bias', \ + 'input_layernorm.weight', \ + 'input_layernorm.bias', \ + 'post_attention_layernorm.weight', \ + 'post_attention_layernorm.bias', + class HFGPTJLayerPolicy(TransformerPolicy): _orig_layer_class = None @@ -339,6 +356,18 @@ def layerNorm(self): self.client_module.ln_1.weight, \ self.client_module.ln_1.bias + def get_param_names(self): + return 'attn.q_proj.weight', \ + 'attn.k_proj.weight', \ + 'attn.v_proj.weight', \ + 'attn.out_proj.weight', \ + 'mlp.fc_in.weight', \ + 'mlp.fc_in.bias', \ + 'mlp.fc_out.weight', \ + 'mlp.fc_out.bias', \ + 'ln_1.weight', \ + 'ln_1.bias', + class MegatronLayerPolicy(TransformerPolicy): _orig_layer_class = None @@ -501,6 +530,20 @@ def layerNorm(self): self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.bias + def get_param_names(self): + return 'self_attention.query_key_value.weight', \ + 'self_attention.query_key_value.bias', \ + 'self_attention.dense.weight', \ + 'self_attention.dense.bias', \ + 'mlp.dense_h_to_4h.weight', \ + 'mlp.dense_h_to_4h.bias', \ + 'mlp.dense_4h_to_h.weight', \ + 'mlp.dense_4h_to_h.bias', \ + 'input_layernorm.weight', \ + 'input_layernorm.bias', \ + 'post_attention_layernorm.weight', \ + 'post_attention_layernorm.bias', + class GPTNEOXLayerPolicy(TransformerPolicy): _orig_layer_class = None @@ -555,6 +598,20 @@ def layerNorm(self): self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.bias + def get_param_names(self): + return 'attention.query_key_value.weight', \ + 'attention.query_key_value.bias', \ + 'attention.dense.weight', \ + 'attention.dense.bias', \ + 'mlp.dense_h_to_4h.weight', \ + 'mlp.dense_h_to_4h.bias', \ + 'mlp.dense_4h_to_h.weight', \ + 'mlp.dense_4h_to_h.bias', \ + 'input_layernorm.weight', \ + 'input_layernorm.bias', \ + 'post_attention_layernorm.weight', \ + 'post_attention_layernorm.bias', + class HFOPTLayerPolicy(TransformerPolicy): _orig_layer_class = None @@ -612,6 +669,24 @@ def layerNorm(self): self.client_module.self_attn_layer_norm.weight, \ self.client_module.self_attn_layer_norm.bias + def get_param_names(self): + return 'self_attn.q_proj.weight', \ + 'self_attn.q_proj.bias', \ + 'self_attn.k_proj.weight', \ + 
'self_attn.k_proj.bias', \ + 'self_attn.v_proj.weight', \ + 'self_attn.v_proj.bias', \ + 'self_attn.out_proj.weight', \ + 'self_attn.out_proj.bias', \ + 'fc1.weight', \ + 'fc1.bias', \ + 'fc2.weight', \ + 'fc2.bias', \ + 'final_layer_norm.weight', \ + 'final_layer_norm.bias', \ + 'self_attn_layer_norm.weight', \ + 'self_attn_layer_norm.bias', + # transformer-based policies replace_policies = [ diff --git a/deepspeed/ops/transformer/inference/ds_mlp.py b/deepspeed/ops/transformer/inference/ds_mlp.py index 4f1c705c55ea..86275f0381dc 100644 --- a/deepspeed/ops/transformer/inference/ds_mlp.py +++ b/deepspeed/ops/transformer/inference/ds_mlp.py @@ -41,10 +41,13 @@ def forward(ctx, if attn_nw is None: output = fused_gemm_gelu(residual_norm, inter_w, + inter_w.scale, inter_b, output_w, + output_w.scale, config.epsilon, config.pre_layer_norm, + config.q_int8, False) else: output, residual_add = mlp_gemm_func(input, From 20f30ce21eb58575c48578ea3c4f0bb77f4aef58 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 24 Nov 2022 06:37:30 +0000 Subject: [PATCH 4/7] remove int8 hard-coded flag --- csrc/transformer/inference/csrc/pt_binding.cpp | 2 -- deepspeed/module_inject/load_checkpoint.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 55e0ada4ab7b..bb983b39cca5 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1480,7 +1480,6 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); - q_int8 = true; int intm_dim = q_int8 ? weight.size(0) : weight.size(1); @@ -1528,7 +1527,6 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, bsz, Context::Instance().GetCurrentStream()); - q_int8 = true; int out_size = q_int8 ? 
weight_out.size(0) : weight_out.size(1); auto output = at::empty({input.size(0), input.size(1), out_size}, options); if (q_int8) { diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index 2f72e3f4308e..e6c580fce711 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -201,7 +201,8 @@ def maybe_copy1(module, dst_name, src_names, qkv=False): 'attn_qkvw', [prefix + q_w, prefix + k_w, - prefix + v_w]) + prefix + v_w], + qkv=True) else: maybe_copy1(module.attention, 'attn_qkvw', From eef04b3995bb2e6d594b14f6a9edab6de4cdb0ac Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 25 Nov 2022 04:59:03 +0000 Subject: [PATCH 5/7] fix mlp return tensors --- csrc/transformer/inference/csrc/pt_binding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index bb983b39cca5..556b4b133132 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1424,7 +1424,7 @@ std::vector ds_mlp_gemm(at::Tensor& input, q_int8, act_func_type); - return {output, output}; // res_add}; + return {output, res_add}; } template From 40f100087eb2ba6d61d5f3ca066a0a0c0aee383c Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Mon, 28 Nov 2022 21:31:07 +0000 Subject: [PATCH 6/7] fix several issue to load checkpoints of GPT-J, GPT-NEOX, and OPT with different TP-size --- .../transformer/inference/csrc/pt_binding.cpp | 1 + .../transformers/ds_transformer.py | 3 +- deepspeed/module_inject/layers.py | 6 +- deepspeed/module_inject/load_checkpoint.py | 126 ++++++++++++------ deepspeed/module_inject/replace_module.py | 32 +++-- deepspeed/module_inject/replace_policy.py | 26 ++-- .../ops/transformer/inference/ds_attention.py | 3 +- deepspeed/ops/transformer/inference/ds_mlp.py | 2 +- 8 files changed, 134 insertions(+), 65 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 556b4b133132..619ed6923398 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1343,6 +1343,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, bsz, Context::Instance().GetCurrentStream()); } + if (q_int8) { quantized_gemm(output.data_ptr(), intermediate, diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py index 49b7c81698a1..d281f8eba39c 100644 --- a/deepspeed/model_implementations/transformers/ds_transformer.py +++ b/deepspeed/model_implementations/transformers/ds_transformer.py @@ -68,7 +68,7 @@ def __init__(self, merge_count, mlp_extra_grouping) - device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu' self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device), @@ -131,7 +131,6 @@ def forward( if (self.config.fp16 or self.config.q_int8) \ and input.dtype == torch.float: input = input.half() - with torch.no_grad(): attention_output, key, value, context_outputtn_ctx, inp_norm = \ self.attention(input, diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 695cc3ce63aa..cb738ec0c792 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -33,10 +33,12 @@ def 
__init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None): torch.empty(weight_shape, dtype=dtype, device=torch.cuda.current_device())) + self.bias = Parameter( torch.empty(weight_shape[0], dtype=dtype, - device=torch.cuda.current_device())) + device=torch.cuda.current_device())) \ + if bias is not None else None def forward(self, input): output = torch.matmul(input, self.weight.transpose(-1, -2)) @@ -57,7 +59,7 @@ def forward(self, input): class EmbeddingLayer(nn.Module): - def __init__(self, weight_shape, dtype=torch.float): + def __init__(self, weight_shape, dtype=torch.half): super(EmbeddingLayer, self).__init__() self.weight = Parameter( torch.empty(weight_shape[0], diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index e6c580fce711..cfc2c50bc9be 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -12,14 +12,17 @@ def load_model_with_checkpoint(r_module, ckpt_type, weight_quantizer=None, rank=0, - param_names=None): + param_names=None, + transformer_config=None, + megatron_v2=False): error_msgs = [] def transpose(data): - data = data.contiguous() - data1 = data.transpose(-1, -2).reshape(-1) - data.reshape(-1).copy_(data1) - data1 = None + with torch.no_grad(): + data = data.contiguous() + data1 = data.transpose(-1, -2).reshape(-1) + data.reshape(-1).copy_(data1) + data1 = None return data.reshape(data.shape[-1], data.shape[-2]) def load(module, prefix): @@ -88,7 +91,7 @@ def load_parameters(module, prefix): else: assert tmp_data.dtype != torch.int8, \ '''Merging of the checkpoints are not supported when using INT8 checkpoint! \ - Please use a as many GPUs as TP-size for the checkpoint''' + Please use a as many GPUs as TP-size for the checkpoint''' all_data = [ sd[j][prefix + n] if type(sd[j][prefix + n]) is list else @@ -140,22 +143,55 @@ def load_parameters(module, prefix): load_parameters(child, prefix + n + '.') else: - def maybe_copy(module, dst_name, src_name, qkv=False): + def _transpose(x): + heads = transformer_config.heads // mp_replace.mp_size + attention_head_size = x.shape[-1] // heads + new_x_shape = x.size()[:-1] + (heads, attention_head_size) + x_1 = x.view(*new_x_shape) + (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1)) + if len(q.shape) > 2: + return torch.cat((q.reshape(q.shape[0], + -1), + k.reshape(q.shape[0], + -1), + v.reshape(q.shape[0], + -1)), + dim=-1).reshape(x.shape) + else: + return torch.cat((q.reshape(-1), + k.reshape(-1), + v.reshape(-1)), + dim=-1).reshape(x.shape) + + def maybe_copy(module, + dst_name, + src_name, + qkv=False, + megatron_v2=False, + split_qkv=False): if src_name in sd[0]: dst = getattr(module, dst_name) + tmp = sd[0][src_name].cuda() if len(dst.shape) == 1: - if qkv: - dst = mp_replace.qkv_copy(dst, - (sd[0][src_name]).contiguous()) + if split_qkv: + dst = mp_replace.qkv_copy(dst, tmp) else: - dst = mp_replace.copy(dst, sd[0][src_name]) + dst = mp_replace.copy(dst, tmp) + if qkv and megatron_v2: + dst = torch.nn.parameter.Parameter( + _transpose(dst).contiguous()) else: - if qkv: - dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, sd[0][src_name] if weight_quantizer.q_int8 else \ - ((transpose(sd[0][src_name])).contiguous()))) + if split_qkv: + dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, tmp if weight_quantizer.q_int8 else \ + (transpose(tmp).contiguous()))) else: - dst = weight_quantizer.quantize(mp_replace.copy(dst, sd[0][src_name] if weight_quantizer.q_int8 else \ - 
transpose(sd[0][src_name]))) + dst = weight_quantizer.quantize(mp_replace.copy(dst, tmp if weight_quantizer.q_int8 else \ + transpose(tmp))) + if qkv and megatron_v2: + scale1 = dst.scale + dst = torch.nn.parameter.Parameter( + _transpose(dst).contiguous()) + dst.scale = scale1 setattr(module, dst_name, dst) def maybe_copy1(module, dst_name, src_names, qkv=False): @@ -167,55 +203,68 @@ def maybe_copy1(module, dst_name, src_names, qkv=False): dst = getattr(module, dst_name) if len(dst.shape) == 1: if qkv: - dst = mp_replace.qkv_copy(dst, (qkv_data).contiguous()) + dst = mp_replace.qkv_copy(dst, + (qkv_data.cuda()).contiguous()) else: - dst = mp_replace.copy(dst, qkv_data) + dst = mp_replace.copy(dst, qkv_data.cuda()) else: if qkv: - dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, qkv_data if weight_quantizer.q_int8 else \ - ((transpose(qkv_data)).contiguous()))) + dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, qkv_data.cuda() if weight_quantizer.q_int8 else \ + ((transpose(qkv_data.cuda())).contiguous()))) else: - dst = weight_quantizer.quantize(mp_replace.copy(dst, qkv_data if weight_quantizer.q_int8 else \ - transpose(qkv_data))) + dst = weight_quantizer.quantize(mp_replace.copy(dst, qkv_data.cuda() if weight_quantizer.q_int8 else \ + transpose(qkv_data.cuda()))) setattr(module, dst_name, dst) - if len(param_names) == 12: + if len(param_names) == 14: qkv_w, qkv_b, attn_ow, attn_ob, \ mlp_intw, mlp_intb, mlp_ow, mlp_ob, \ - inp_normw, inp_normb, attn_nw, attn_nb = param_names - elif len(param_names) < 12: + inp_normw, inp_normb, attn_nw, attn_nb, _, split_qkv = param_names + elif len(param_names) < 14: q_w, k_w, v_w, attn_ow, \ mlp_intw, mlp_intb, mlp_ow, mlp_ob, \ - inp_normw, inp_normb = param_names + inp_normw, inp_normb, _, split_qkv = param_names else: q_w, q_b, k_w, k_b, v_w, v_b, attn_ow, attn_ob, \ mlp_intw, mlp_intb, mlp_ow, mlp_ob, \ - inp_normw, inp_normb, attn_nw, attn_nb = param_names + inp_normw, inp_normb, attn_nw, attn_nb, _, split_qkv = param_names maybe_copy(module, 'norm_w', prefix + inp_normw) maybe_copy(module, 'norm_b', prefix + inp_normb) - if len(param_names) == 12: - maybe_copy(module.attention, 'attn_qkvw', prefix + qkv_w, qkv=True) - maybe_copy(module.attention, 'attn_qkvb', prefix + qkv_b, qkv=True) - elif len(param_names) < 12: + if len(param_names) == 14: + maybe_copy(module.attention, + 'attn_qkvw', + prefix + qkv_w, + qkv=True, + megatron_v2=megatron_v2, + split_qkv=split_qkv) + maybe_copy(module.attention, + 'attn_qkvb', + prefix + qkv_b, + qkv=True, + megatron_v2=megatron_v2, + split_qkv=split_qkv) + elif len(param_names) < 14: maybe_copy1(module.attention, 'attn_qkvw', [prefix + q_w, prefix + k_w, prefix + v_w], - qkv=True) + qkv=split_qkv) else: maybe_copy1(module.attention, 'attn_qkvw', [prefix + q_w, prefix + k_w, - prefix + v_w]) + prefix + v_w], + qkv=split_qkv) maybe_copy1(module.attention, 'attn_qkvb', [prefix + q_b, prefix + k_b, - prefix + v_b]) + prefix + v_b], + qkv=split_qkv) maybe_copy(module.attention, 'attn_ow', prefix + attn_ow) - if len(param_names) > 12: + if len(param_names) >= 14: maybe_copy(module.attention, 'attn_ob', prefix + attn_ob) maybe_copy(module.mlp, 'attn_nw', prefix + attn_nw) maybe_copy(module.mlp, 'attn_nb', prefix + attn_nb) @@ -263,7 +312,6 @@ def load_module_recursive(module, prefix='', level=0): ds_shape = child.weight.shape else: ds_shape = child.weight.ds_shape - if child.__class__ is nn.LayerNorm: child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, @@ -290,7 +338,7 @@ def 
load_module_recursive(module, prefix='', level=0): else: load_module_recursive( child, - #prefix if level == 0 and ckpt_type == 'pp' else \ + prefix if (level == 0 and ckpt_type == 'pp') and param_names[-2] else \ prefix + name + '.', level + 1) @@ -299,9 +347,9 @@ def load_module_recursive(module, prefix='', level=0): #XXX: hack to tie embedding w. lm_head for BLOOM, need to revist soon embedding_weight = None for n, p in r_module.named_parameters(): - if "word_embeddings." in n: + if "word_embeddings." in n or "embed_tokens." in n: embedding_weight = p - assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist" + if embedding_weight is not None: r_module.lm_head.weight = embedding_weight for sd_ in sd: diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index e2e666230d70..7233d9adfa6f 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -36,13 +36,13 @@ def qkv_copy(self, dst, src): return src src_shape = src.shape dst_shape = dst.shape + if self.out_dim == 0: src_split = torch.split(src.data, src_shape[self.out_dim] // self.mp_size, dim=0) else: src_split = torch.split(src.data, src.shape[-1] // 3, dim=-1) - if (len(src_shape) == 2 and len(dst_shape) == 2): if src_shape[self.out_dim] == dst_shape[self.out_dim]: return torch.nn.parameter.Parameter(src) @@ -54,7 +54,6 @@ def qkv_copy(self, dst, src): qkv_size, dim=self.out_dim) for src_s in src_split ] - weight_split = [ torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=self.out_dim) for i in range(len(qkv_split[0])) @@ -137,8 +136,7 @@ def get_transformer_name(replaced_module): class GroupQuantizer: - def __init__(self, q_int8=True, num_groups=32, group_size=32, num_bits=8): - self.num_groups = num_groups + def __init__(self, q_int8=True, group_size=1, num_bits=8): self.group_size = group_size self.num_bits = num_bits self.q_int8 = q_int8 @@ -149,8 +147,9 @@ def quantize(self, inputs, qkv=True, count=1, parallel_dim=0): inputs.scale = torch.empty(1) return inputs q_range = 2**self.num_bits + num_groups = inputs.shape[0] // self.group_size inputs = inputs.to(torch.cuda.current_device()) - input_flat = inputs.reshape(self.num_groups, -1).contiguous() + input_flat = inputs.reshape(num_groups, -1).contiguous() input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float() input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float() scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range) @@ -160,7 +159,7 @@ def quantize(self, inputs, qkv=True, count=1, parallel_dim=0): #print(inputs.shape) inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim) input_flat = [ - inputs_split[i].reshape(self.num_groups, + inputs_split[i].reshape(num_groups, -1).contiguous() for i in range(2) ] input_min = [ @@ -182,7 +181,7 @@ def quantize(self, inputs, qkv=True, count=1, parallel_dim=0): out.scale = torch.cat([scale.squeeze().unsqueeze(0), scale1[0], scale1[1]], - dim=0).reshape(self.num_groups, + dim=0).reshape(num_groups, -1).contiguous() return out @@ -287,6 +286,8 @@ def _replace_module(module, policy): selected_policy_g = None +megatron_v2_g = False +transformer_config_g = None def replace_transformer_layer(orig_layer_impl, @@ -328,7 +329,6 @@ def replace_with_policy(child, inference=False, layer_id=0): policy = policy_cls(child, inference=inference) - global selected_policy_g if selected_policy_g is None: selected_policy_g = policy @@ -347,6 +347,8 @@ def replace_with_policy(child, 
moe = True attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention() + global megatron_v2_g + megatron_v2_g = megatron_v2 if not moe or config.moe.type == 'standard': mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp() else: @@ -446,6 +448,8 @@ def replace_with_policy(child, bigscience_bloom=bigscience_bloom, max_out_tokens=config.max_out_tokens, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) + global transformer_config_g + transformer_config_g = transformer_config if moe: new_module = transformer_inference.DeepSpeedMoEInference( @@ -923,7 +927,8 @@ def replace_fn(child, _policy, layer_id=0): ckpt_type, quantizer, param_names=selected_policy_g.get_param_names(), - ) + transformer_config=transformer_config_g, + megatron_v2=megatron_v2_g) pbar.update(1) else: import gc @@ -954,7 +959,9 @@ def replace_fn(child, _policy, layer_id=0): ckpt_type, quantizer, int(rank % tp_split_size), - param_names=selected_policy_g.get_param_names()) + param_names=selected_policy_g.get_param_names(), + transformer_config=transformer_config_g, + megatron_v2=megatron_v2_g) sds = [None for _ in sds] gc.collect() @@ -976,7 +983,9 @@ def replace_fn(child, _policy, layer_id=0): ckpt_type, quantizer, int(rank % tp_split_size), - param_names=selected_policy_g.get_param_names()) + param_names=selected_policy_g.get_param_names(), + transformer_config=transformer_config_g, + megatron_v2=megatron_v2_g) sds = [None for _ in sds] gc.collect() print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") @@ -1002,6 +1011,7 @@ def replace_fn(child, _policy, layer_id=0): non_tp_ckpt_name = f'non-tp.pt' ckpt_files = [non_tp_ckpt_name] os.makedirs(config.save_mp_checkpoint_path, exist_ok=True) + if not dist.is_initialized() or dist.get_rank() == 0: print("Saving tp-sharded checkpoints") torch.save( diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index 8ab4ab9011da..c5b64455baf6 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -309,7 +309,9 @@ def get_param_names(self): 'input_layernorm.weight', \ 'input_layernorm.bias', \ 'post_attention_layernorm.weight', \ - 'post_attention_layernorm.bias', + 'post_attention_layernorm.bias', \ + False, \ + True class HFGPTJLayerPolicy(TransformerPolicy): @@ -366,7 +368,9 @@ def get_param_names(self): 'mlp.fc_out.weight', \ 'mlp.fc_out.bias', \ 'ln_1.weight', \ - 'ln_1.bias', + 'ln_1.bias', \ + False, \ + True class MegatronLayerPolicy(TransformerPolicy): @@ -542,7 +546,9 @@ def get_param_names(self): 'input_layernorm.weight', \ 'input_layernorm.bias', \ 'post_attention_layernorm.weight', \ - 'post_attention_layernorm.bias', + 'post_attention_layernorm.bias', \ + True, \ + False class GPTNEOXLayerPolicy(TransformerPolicy): @@ -610,7 +616,9 @@ def get_param_names(self): 'input_layernorm.weight', \ 'input_layernorm.bias', \ 'post_attention_layernorm.weight', \ - 'post_attention_layernorm.bias', + 'post_attention_layernorm.bias', \ + False, \ + False class HFOPTLayerPolicy(TransformerPolicy): @@ -625,9 +633,9 @@ def __init__(self, client_module, inference=True): try: import transformers HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer - if isinstance(DSPolicy.hf_model_config, + if isinstance(TransformerPolicy.hf_model_config, transformers.models.opt.configuration_opt.OPTConfig): - self.pre_attn_norm = self.hf_model_config.do_layer_norm_before + self.pre_attn_norm = 
TransformerPolicy.hf_model_config.do_layer_norm_before except: HFOPTLayerPolicy._orig_layer_class = None @@ -682,10 +690,12 @@ def get_param_names(self): 'fc1.bias', \ 'fc2.weight', \ 'fc2.bias', \ + 'self_attn_layer_norm.weight', \ + 'self_attn_layer_norm.bias', \ 'final_layer_norm.weight', \ 'final_layer_norm.bias', \ - 'self_attn_layer_norm.weight', \ - 'self_attn_layer_norm.bias', + True, \ + True # transformer-based policies diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py index 610bd882ecf4..d9df8e98a3de 100644 --- a/deepspeed/ops/transformer/inference/ds_attention.py +++ b/deepspeed/ops/transformer/inference/ds_attention.py @@ -331,7 +331,6 @@ def selfAttention_fp(): False, attn_ow.scale, config.q_int8) - return output, key_layer, value_layer, context_layer, qkv_out[-1] def selfAttention_int8(): @@ -394,7 +393,7 @@ def __init__(self, data_type_fp = torch.half if config.fp16 else torch.float self.config.layer_id = DeepSpeedSelfAttention.num_layers DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1 - device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu' qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, qkv_size_per_partition, diff --git a/deepspeed/ops/transformer/inference/ds_mlp.py b/deepspeed/ops/transformer/inference/ds_mlp.py index 86275f0381dc..277ba1818286 100644 --- a/deepspeed/ops/transformer/inference/ds_mlp.py +++ b/deepspeed/ops/transformer/inference/ds_mlp.py @@ -99,7 +99,7 @@ def __init__(self, self.config = config data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float data_type_fp = torch.half if config.fp16 else torch.float - device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu' self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), From 25e9bbffc5dcc44eb021203104e0d86b46dde5a5 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 6 Dec 2022 23:50:01 +0500 Subject: [PATCH 7/7] add more comments & description for checkpoint-loading module --- deepspeed/module_inject/load_checkpoint.py | 46 +++++++------ deepspeed/module_inject/replace_policy.py | 75 +++++++++++++++------- 2 files changed, 78 insertions(+), 43 deletions(-) diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index cfc2c50bc9be..13b743f6f781 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -163,6 +163,8 @@ def _transpose(x): v.reshape(-1)), dim=-1).reshape(x.shape) + # This checks if the parameter exits in the checkpoint file and maybe copies it into the corresponding destination tensor. + # Note that not all parameters are saved in one checkpoint, that's why we always need to check if they exist! def maybe_copy(module, dst_name, src_name, @@ -194,7 +196,8 @@ def maybe_copy(module, dst.scale = scale1 setattr(module, dst_name, dst) - def maybe_copy1(module, dst_name, src_names, qkv=False): + # Extending the maybe_copy function for when the q, k, and v are in separate parameters! 
+ def maybe_copy_qkv(module, dst_name, src_names, split_qkv=False): if src_names[0] in sd[0]: q = sd[0][src_names[0]] k = sd[0][src_names[1]] @@ -202,13 +205,13 @@ def maybe_copy1(module, dst_name, src_names, qkv=False): qkv_data = torch.cat((q, k, v), dim=0) dst = getattr(module, dst_name) if len(dst.shape) == 1: - if qkv: + if split_qkv: dst = mp_replace.qkv_copy(dst, (qkv_data.cuda()).contiguous()) else: dst = mp_replace.copy(dst, qkv_data.cuda()) else: - if qkv: + if split_qkv: dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, qkv_data.cuda() if weight_quantizer.q_int8 else \ ((transpose(qkv_data.cuda())).contiguous()))) else: @@ -228,6 +231,7 @@ def maybe_copy1(module, dst_name, src_names, qkv=False): q_w, q_b, k_w, k_b, v_w, v_b, attn_ow, attn_ob, \ mlp_intw, mlp_intb, mlp_ow, mlp_ob, \ inp_normw, inp_normb, attn_nw, attn_nb, _, split_qkv = param_names + maybe_copy(module, 'norm_w', prefix + inp_normw) maybe_copy(module, 'norm_b', prefix + inp_normb) if len(param_names) == 14: @@ -244,25 +248,25 @@ def maybe_copy1(module, dst_name, src_names, qkv=False): megatron_v2=megatron_v2, split_qkv=split_qkv) elif len(param_names) < 14: - maybe_copy1(module.attention, - 'attn_qkvw', - [prefix + q_w, - prefix + k_w, - prefix + v_w], - qkv=split_qkv) + maybe_copy_qkv(module.attention, + 'attn_qkvw', + [prefix + q_w, + prefix + k_w, + prefix + v_w], + split_qkv=split_qkv) else: - maybe_copy1(module.attention, - 'attn_qkvw', - [prefix + q_w, - prefix + k_w, - prefix + v_w], - qkv=split_qkv) - maybe_copy1(module.attention, - 'attn_qkvb', - [prefix + q_b, - prefix + k_b, - prefix + v_b], - qkv=split_qkv) + maybe_copy_qkv(module.attention, + 'attn_qkvw', + [prefix + q_w, + prefix + k_w, + prefix + v_w], + split_qkv=split_qkv) + maybe_copy_qkv(module.attention, + 'attn_qkvb', + [prefix + q_b, + prefix + k_b, + prefix + v_b], + split_qkv=split_qkv) maybe_copy(module.attention, 'attn_ow', prefix + attn_ow) if len(param_names) >= 14: maybe_copy(module.attention, 'attn_ob', prefix + attn_ob) diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index c5b64455baf6..4dd9e5b0855e 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -92,15 +92,19 @@ class TransformerPolicy(DSPolicy): hf_model_config = None def __init__( - self, - inference=True, - linear_layer=True, - scale_attention=True, - megatron_v2=False, - # the type of activation function used in MLP - mlp_act_func_type=ActivationFuncType.GELU, - # applies layer norm before attention if `pre_attn_norm` is set to True - pre_attn_norm=True): + self, + inference=True, + linear_layer=True, + scale_attention=True, + megatron_v2=False, + # the type of activation function used in MLP + mlp_act_func_type=ActivationFuncType.GELU, + # applies layer norm before attention if `pre_attn_norm` is set to True + pre_attn_norm=True, + # this flag shows whether or not using prefix in loading the checkpoint + use_load_prefix=False, + # whether or not the qkv is stored in the split-format + split_qkv=True): super().__init__() self.inference = inference self.linear_layer = linear_layer @@ -108,6 +112,7 @@ def __init__( self.is_megatron_v2 = megatron_v2 self.mlp_act_func_type = mlp_act_func_type self.pre_attn_norm = pre_attn_norm + self.load_prefix = False def attention(self): """ @@ -140,6 +145,28 @@ def layerNorm(self): raise NotImplementedError def get_param_names(self): + """ + Returns all the transformer parameter names to + be loaded from checkpoint files. 
The order of + the names is as follows: + 1. Attention weights and biases; + 2. MLP weights and biases; + 3. LayerNorm weights and biases; + In addition to the parameter names, we require two + more parameters to help read the the data correctly + from the checkpoint and split the qkv heads in the + right order: + 1. `use_load_prefix` (Default: False): this specifies + whether we need to use the name of first abstraction + layer of the model for searching the parameter's name + in a checkpoint file. For more information of how this + is used please see + https://github.com/microsoft/DeepSpeed/blob/fix-ckpt-loading/deepspeed/module_inject/load_checkpoint.py#L341 + 2. `split_qkv` (Default: True): we use this flag when splitting + the qkv parameter into heads. If it is False, it means the heads + of q, k, and v are stored together and needs to split in the + DeepSpeed-Inference API. + """ raise NotImplementedError @@ -310,8 +337,8 @@ def get_param_names(self): 'input_layernorm.bias', \ 'post_attention_layernorm.weight', \ 'post_attention_layernorm.bias', \ - False, \ - True + self.use_load_prefix, \ + self.split_qkv class HFGPTJLayerPolicy(TransformerPolicy): @@ -369,8 +396,8 @@ def get_param_names(self): 'mlp.fc_out.bias', \ 'ln_1.weight', \ 'ln_1.bias', \ - False, \ - True + self.use_load_prefix, \ + self.split_qkv class MegatronLayerPolicy(TransformerPolicy): @@ -496,7 +523,11 @@ def layerNorm(self): class BLOOMLayerPolicy(TransformerPolicy): _orig_layer_class = None - def __init__(self, client_module, inference=True): + def __init__(self, + client_module, + inference=True, + use_load_prefix=True, + split_qkv=False): super().__init__(inference, linear_layer=True) self.client_module = client_module try: @@ -547,15 +578,15 @@ def get_param_names(self): 'input_layernorm.bias', \ 'post_attention_layernorm.weight', \ 'post_attention_layernorm.bias', \ - True, \ - False + self.use_load_prefix, \ + self.split_qkv class GPTNEOXLayerPolicy(TransformerPolicy): _orig_layer_class = None version = 0 - def __init__(self, client_module, inference=True, megatron_v2=True): + def __init__(self, client_module, inference=True, megatron_v2=True, split_qkv=False): super().__init__(inference, megatron_v2=megatron_v2) self.client_module = client_module if GPTNEOXLayerPolicy._orig_layer_class is None: @@ -617,14 +648,14 @@ def get_param_names(self): 'input_layernorm.bias', \ 'post_attention_layernorm.weight', \ 'post_attention_layernorm.bias', \ - False, \ - False + self.use_load_prefix, \ + self.split_qkv class HFOPTLayerPolicy(TransformerPolicy): _orig_layer_class = None - def __init__(self, client_module, inference=True): + def __init__(self, client_module, inference=True, use_load_prefix=True): super().__init__(inference, linear_layer=True, mlp_act_func_type=ActivationFuncType.ReLU, @@ -694,8 +725,8 @@ def get_param_names(self): 'self_attn_layer_norm.bias', \ 'final_layer_norm.weight', \ 'final_layer_norm.bias', \ - True, \ - True + self.use_load_prefix, \ + self.split_qkv # transformer-based policies