diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h
index b3851ca43b72..d79f7c2e107d 100644
--- a/csrc/transformer/inference/includes/inference_context.h
+++ b/csrc/transformer/inference/includes/inference_context.h
@@ -161,6 +161,16 @@ class Context {
     inline size_t GetMaxTokenLenght() const { return _max_seq_len; }
 
     cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
 
+    inline void advance_tokens()
+    {
+        if (_num_tokens >= _max_seq_len)
+            printf(
+                "Requesting to generate more tokens (%d) than max-seq-len allowed by cache (%d)\n",
+                _num_tokens,
+                _max_seq_len);
+        assert(_num_tokens < _max_seq_len);
+        _num_tokens++;
+    }
     size_t get_workspace_size() const { return _workSpaceSize; }
     void* GetWorkSpace() { return _workspace; }
 
@@ -182,7 +192,7 @@ class Context {
 
     inline unsigned current_tokens() const { return _num_tokens; }
 
-    inline void advance_tokens() { _num_tokens++; }
+    //inline void advance_tokens() { _num_tokens++; }
 
     cudaStream_t GetCommStream(bool async_op = false)
     {
diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py
index 49b7c81698a1..d313b16b6f31 100644
--- a/deepspeed/model_implementations/transformers/ds_transformer.py
+++ b/deepspeed/model_implementations/transformers/ds_transformer.py
@@ -68,7 +68,7 @@ def __init__(self,
                                    merge_count,
                                    mlp_extra_grouping)
 
-        device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
+        device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
         self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size,
                                                dtype=data_type,
                                                device=device),
diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index 1ae5bd5ccf9f..7b9f14756347 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -23,7 +23,7 @@ def forward(self, input):
 
 
 class LinearLayer(nn.Module):
-    def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None):
+    def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
         super(LinearLayer, self).__init__()
         if weight is not None:
             self.weight = weight
@@ -34,9 +34,10 @@ def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None):
                             dtype=dtype,
                             device=torch.cuda.current_device()))
             self.bias = Parameter(
-                torch.empty(weight_shape[0],
-                            dtype=dtype,
-                            device=torch.cuda.current_device()))
+                torch.empty(
+                    weight_shape[0],
+                    dtype=dtype,
+                    device=torch.cuda.current_device())) if bias is not None else None
 
     def forward(self, input):
         output = torch.matmul(input, self.weight.transpose(-1, -2))
@@ -67,3 +68,27 @@ def __init__(self, weight_shape, dtype=torch.float):
 
     def forward(self, input):
         return F.embedding(input, self.weight)
+
+class OPTEmbedding(EmbeddingLayer):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    """
+    def __init__(self, weight_shape):
+        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately. Other models don't have this hack
+        self.offset = 2
+        super().__init__(weight_shape)
+
+    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+        """`input_ids_shape` is expected to be [bsz x seqlen]."""
+        attention_mask = attention_mask.long()
+
+        # create positions depending on attention_mask
+        positions = (torch.cumsum(attention_mask,
+                                  dim=1).type_as(attention_mask) *
+                     attention_mask).long() - 1
+
+        # cut positions if `past_key_values_length` is > 0
+        positions = positions[:, past_key_values_length:]
+
+        return super().forward(positions + self.offset)
\ No newline at end of file
diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py
index f577a1a0e1bc..c67c6b8a18aa 100644
--- a/deepspeed/module_inject/load_checkpoint.py
+++ b/deepspeed/module_inject/load_checkpoint.py
@@ -1,7 +1,7 @@
 from torch import nn
 import deepspeed.ops.transformer as transformer_inference
 from ..runtime.zero import GatheredParameters
-from .layers import LinearLayer, Normalize, EmbeddingLayer
+from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
 import torch
 import gc
 
@@ -11,7 +11,9 @@ def load_model_with_checkpoint(r_module,
                                mp_replace,
                                ckpt_type,
                                weight_quantizer=None,
-                               rank=0):
+                               rank=0,
+                               transformer_config=None,
+                               param_names=None):
     error_msgs = []
 
     def transpose(data):
@@ -138,37 +140,93 @@ def load_parameters(module, prefix):
             for n, child in module.named_children():
                 load_parameters(child, prefix + n + '.')
         else:
-            module.norm_w.data.copy_(sd[0][prefix + 'input_layernorm.' + 'weight'])
-            module.norm_b.data.copy_(sd[0][prefix + 'input_layernorm.' + 'bias'])
-            module.attention.attn_qkvw = mp_replace.copy(module.attention.attn_qkvw,
-                weight_quantizer.quantize(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']) if weight_quantizer.q_int8 else \
-                weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.query_key_value.' + 'weight'])))
-            module.attention.attn_qkvb = mp_replace.copy(
-                module.attention.attn_qkvb.data,
-                sd[0][prefix + 'self_attention.query_key_value.' + 'bias'])
-            module.attention.attn_ow = mp_replace.copy(module.attention.attn_ow,
-                weight_quantizer.quantize(sd[0][prefix + 'self_attention.dense.' + 'weight']) if weight_quantizer.q_int8 else \
-                weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.dense.' + 'weight'])))
-            module.attention.attn_ob = mp_replace.copy(
-                module.attention.attn_ob.data,
-                sd[0][prefix + 'self_attention.dense.' + 'bias'])
-            module.mlp.attn_nw.data.copy_(sd[0][prefix + 'post_attention_layernorm.' +
-                                                'weight'])
-            module.mlp.attn_nb.data.copy_(sd[0][prefix + 'post_attention_layernorm.' +
-                                                'bias'])
-            module.mlp.inter_w = mp_replace.copy(module.mlp.inter_w,
-                weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']) if weight_quantizer.q_int8 else \
-                weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight'])))
-            module.mlp.inter_b = mp_replace.copy(
-                module.mlp.inter_b.data,
-                sd[0][prefix + 'mlp.dense_h_to_4h.' + 'bias'])
-            module.mlp.output_w = mp_replace.copy(module.mlp.output_w,
-                weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']) if weight_quantizer.q_int8 else \
-                weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight'])))
-            module.mlp.output_b = mp_replace.copy(
-                module.mlp.output_b.data,
-                sd[0][prefix + 'mlp.dense_4h_to_h.' + 'bias'])
-
+            def maybe_copy(module, dst_name, src_name, qkv=False):
+                if src_name in sd[0]:
+                    dst = getattr(module, dst_name)
+                    if len(dst.shape) == 1:
+                        if qkv:
+                            dst = mp_replace.qkv_copy(
+                                dst,
+                                (sd[0][src_name]).contiguous())
+                        else:
+                            dst = mp_replace.copy(dst, sd[0][src_name])
+                    else:
+                        if qkv:
+                            dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, sd[0][src_name] if weight_quantizer.q_int8 else \
+                                ((transpose(sd[0][src_name])).contiguous())))
+                        else:
+                            dst = weight_quantizer.quantize(mp_replace.copy(dst, sd[0][src_name] if weight_quantizer.q_int8 else \
+                                transpose(sd[0][src_name])))
+                    setattr(module, dst_name, dst)
+
+            def maybe_copy1(module, dst_name, src_names, qkv=False):
+                if src_names[0] in sd[0]:
+                    q = sd[0][src_names[0]]
+                    k = sd[0][src_names[1]]
+                    v = sd[0][src_names[2]]
+                    qkv_data = torch.cat((q, k, v), dim=0)
+                    dst = getattr(module, dst_name)
+                    if len(dst.shape) == 1:
+                        if qkv:
+                            dst = mp_replace.qkv_copy(dst,
+                                                      (qkv_data).contiguous())
+                        else:
+                            dst = mp_replace.copy(dst, qkv_data)
+                    else:
+                        if qkv:
+                            dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, qkv_data if weight_quantizer.q_int8 else \
+                                ((transpose(qkv_data)).contiguous())))
+                        else:
+                            dst = weight_quantizer.quantize(mp_replace.copy(dst, qkv_data if weight_quantizer.q_int8 else \
+                                transpose(qkv_data)))
+                    setattr(module, dst_name, dst)
+
+            if len(param_names) == 12:
+                qkv_w, qkv_b, attn_ow, attn_ob, \
+                mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
+                inp_normw, inp_normb, attn_nw, attn_nb = param_names
+            elif len(param_names) < 12:
+                q_w, k_w, v_w, attn_ow, \
+                mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
+                inp_normw, inp_normb = param_names
+            else:
+                q_w, q_b, k_w, k_b, v_w, v_b, attn_ow, attn_ob, \
+                mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
+                inp_normw, inp_normb, attn_nw, attn_nb = param_names
+
+            maybe_copy(module, 'norm_w', prefix + inp_normw)
+            maybe_copy(module, 'norm_b', prefix + inp_normb)
+            if len(param_names) == 12:
+                maybe_copy(module.attention, 'attn_qkvw', prefix + qkv_w, qkv=True)
+                maybe_copy(module.attention, 'attn_qkvb', prefix + qkv_b, qkv=True)
+            elif len(param_names) < 12:
+                maybe_copy1(module.attention,
+                            'attn_qkvw',
+                            [prefix + q_w,
+                             prefix + k_w,
+                             prefix + v_w])
+            else:
+                maybe_copy1(module.attention,
+                            'attn_qkvw',
+                            [prefix + q_w,
+                             prefix + k_w,
+                             prefix + v_w])
+                maybe_copy1(module.attention,
+                            'attn_qkvb',
+                            [prefix + q_b,
+                             prefix + k_b,
+                             prefix + v_b])
+            maybe_copy(module.attention, 'attn_ow', prefix + attn_ow)
+            if len(param_names) > 12:
+                maybe_copy(module.attention, 'attn_ob', prefix + attn_ob)
+            maybe_copy(module.mlp, 'attn_nw', prefix + attn_nw)
+            maybe_copy(module.mlp, 'attn_nb', prefix + attn_nb)
+            maybe_copy(module.mlp, 'inter_w', prefix + mlp_intw)
+            maybe_copy(module.mlp, 'inter_b', prefix + mlp_intb)
+            maybe_copy(module.mlp, 'output_w', prefix + mlp_ow)
+            maybe_copy(module.mlp, 'output_b', prefix + mlp_ob)
+
+    try:
+        import transformers
+        OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
+    except:
+        OPTLearnedPositionalEmbedding = None
     layer_policies = {
         nn.Linear: load,
         nn.Embedding: load,
@@ -176,7 +234,9 @@ def load_parameters(module, prefix):
         EmbeddingLayer: load,
         LinearLayer: load,
         Normalize: load,
-        transformer_inference.DeepSpeedTransformerInference: load_transformer_layer
+        transformer_inference.DeepSpeedTransformerInference: load_transformer_layer,
+        OPTLearnedPositionalEmbedding: load,
+        OPTEmbedding: load
     }
 
     all_ds_ids = {}
@@ -210,6 +270,9 @@ def load_module_recursive(module, prefix='', level=0):
                 elif child.__class__ is nn.Linear:
                     child = LinearLayer(weight=child.weight, bias=child.bias)
                     setattr(module, name, child)
+                elif child.__class__ is OPTLearnedPositionalEmbedding:
+                    child = OPTEmbedding(weight_shape=ds_shape)
+                    setattr(module, name, child)
                 else:
                     ds_id = None
                     if hasattr(child.weight, 'ds_id'):
@@ -222,6 +285,7 @@ def load_module_recursive(module, prefix='', level=0):
                     layer_policies[child.__class__](child, prefix + name + '.')
             else:
+
                 load_module_recursive(
                     child,
                     prefix if level == 0 and ckpt_type == 'pp' else prefix + name + '.',
@@ -229,13 +293,13 @@ def load_module_recursive(module, prefix='', level=0):
 
     load_module_recursive(r_module)
 
-    #XXX: hack to tie embedding w. lm_head for BLOOM, need to revist soon
     embedding_weight = None
     for n, p in r_module.named_parameters():
-        if "word_embeddings." in n:
+        if "word_embeddings." in n or 'embed_tokens.' in n:
             embedding_weight = p
-    assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist"
-    r_module.lm_head.weight = embedding_weight
+    if hasattr(r_module, 'lm_head'):
+        if embedding_weight is not None:
+            r_module.lm_head.weight = embedding_weight
 
     for sd_ in sd:
         del sd_
     sd = None
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
old mode 100755
new mode 100644
index 52f8af00f6f4..e161484e5276
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -187,6 +187,9 @@ def quantize(self, inputs, qkv=True, count=1, parallel_dim=0):
         return out
 
 
+transformer_config_g = None
+selected_policy_g = None
+
 def _module_match(module):
     for policy in generic_policies:
         policy = policy()
@@ -285,7 +288,6 @@ def _replace_module(module, policy):
                 print(f"**** found and replaced {name} w. {type(new_module)}")
             setattr(module, name, new_module)
 
-
 def replace_transformer_layer(orig_layer_impl,
                               model,
                               policy=None,
@@ -361,6 +363,9 @@ def replace_with_policy(child,
                             inference=False,
                             layer_id=0):
         policy = policy_cls(child, inference=inference)
+        global selected_policy_g
+        if selected_policy_g is None:
+            selected_policy_g = policy
         if not policy.cuda_graph_supported:
             # policy says cuda graph is not supported raise an error if set
             assert not enable_cuda_graph, "cuda graph is not supported with this model, please disable"
@@ -472,6 +477,9 @@ def replace_with_policy(child,
             bigscience_bloom=bigscience_bloom,
             max_out_tokens=max_out_tokens,
             scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx)
+        global transformer_config_g
+        if transformer_config_g is None:
+            transformer_config_g = transformer_config
 
         if quantize and quantize_settings is not None:
             (quantization_scales,
@@ -978,6 +986,8 @@ def replace_fn(child, _policy, layer_id=0):
                         mp_replace,
                         ckpt_type,
                         quantizer,
+                        transformer_config=transformer_config_g,
+                        param_names=selected_policy_g.get_param_names(),
                     )
                     pbar.update(1)
         else:
@@ -1007,7 +1017,9 @@ def replace_fn(child, _policy, layer_id=0):
                                                mp_replace,
                                                ckpt_type,
                                                quantizer,
-                                               int(rank % tp_split_size))
+                                               int(rank % tp_split_size),
+                                               transformer_config=transformer_config_g,
+                                               param_names=selected_policy_g.get_param_names())
                     sds = [None for _ in sds]
                     gc.collect()
 
@@ -1027,7 +1039,9 @@ def replace_fn(child, _policy, layer_id=0):
                                            mp_replace,
                                            ckpt_type,
                                            quantizer,
-                                           int(rank % tp_split_size))
+                                           int(rank % tp_split_size),
+                                           transformer_config=transformer_config_g,
+                                           param_names=selected_policy_g.get_param_names())
                 sds = [None for _ in sds]
                 gc.collect()
             print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")
diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py
index 9dcb4ace234e..306f799871c2 100755
--- a/deepspeed/module_inject/replace_policy.py
+++ b/deepspeed/module_inject/replace_policy.py
@@ -293,7 +293,19 @@ def layerNorm(self):
                self.client_module.ln_2.bias, \
                self.client_module.ln_1.weight, \
                self.client_module.ln_1.bias
-
+    def get_param_names(self):
+        return 'attention.query_key_value.weight', \
+                'attention.query_key_value.bias', \
+                'attention.dense.weight', \
+                'attention.dense.bias', \
+                'mlp.dense_h_to_4h.weight', \
+                'mlp.dense_h_to_4h.bias', \
+                'mlp.dense_4h_to_h.weight', \
+                'mlp.dense_4h_to_h.bias', \
+                'input_layernorm.weight', \
+                'input_layernorm.bias', \
+                'post_attention_layernorm.weight', \
+                'post_attention_layernorm.bias',
 
 class HFGPTJLayerPolicy(TransformerPolicy):
     _orig_layer_class = None
@@ -339,6 +351,17 @@ def layerNorm(self):
                self.client_module.ln_1.weight, \
                self.client_module.ln_1.bias
 
+    def get_param_names(self):
+        return 'attn.q_proj.weight', \
+                'attn.k_proj.weight', \
+                'attn.v_proj.weight', \
+                'attn.out_proj.weight', \
+                'mlp.fc_in.weight', \
+                'mlp.fc_in.bias', \
+                'mlp.fc_out.weight', \
+                'mlp.fc_out.bias', \
+                'ln_1.weight', \
+                'ln_1.bias',
 
 class MegatronLayerPolicy(TransformerPolicy):
     _orig_layer_class = None
@@ -501,6 +524,19 @@ def layerNorm(self):
                self.client_module.input_layernorm.weight, \
                self.client_module.input_layernorm.bias
 
+    def get_param_names(self):
+        return 'self_attention.query_key_value.weight', \
+                'self_attention.query_key_value.bias', \
+                'self_attention.dense.weight', \
+                'self_attention.dense.bias', \
+                'mlp.dense_h_to_4h.weight', \
+                'mlp.dense_h_to_4h.bias', \
+                'mlp.dense_4h_to_h.weight', \
+                'mlp.dense_4h_to_h.bias', \
+                'input_layernorm.weight', \
+                'input_layernorm.bias', \
+                'post_attention_layernorm.weight', \
+                'post_attention_layernorm.bias',
 
 class GPTNEOXLayerPolicy(TransformerPolicy):
     _orig_layer_class = None
@@ -555,6 +591,19 @@ def layerNorm(self):
                self.client_module.input_layernorm.weight, \
                self.client_module.input_layernorm.bias
 
+    def get_param_names(self):
+        return 'attention.query_key_value.weight', \
+                'attention.query_key_value.bias', \
+                'attention.dense.weight', \
+                'attention.dense.bias', \
+                'mlp.dense_h_to_4h.weight', \
+                'mlp.dense_h_to_4h.bias', \
+                'mlp.dense_4h_to_h.weight', \
+                'mlp.dense_4h_to_h.bias', \
+                'input_layernorm.weight', \
+                'input_layernorm.bias', \
+                'post_attention_layernorm.weight', \
+                'post_attention_layernorm.bias',
 
 class HFOPTLayerPolicy(TransformerPolicy):
     _orig_layer_class = None
@@ -611,7 +660,23 @@ def layerNorm(self):
                self.client_module.final_layer_norm.bias, \
                self.client_module.self_attn_layer_norm.weight, \
                self.client_module.self_attn_layer_norm.bias
-
+    def get_param_names(self):
+        return 'self_attn.q_proj.weight', \
+                'self_attn.q_proj.bias', \
+                'self_attn.k_proj.weight', \
+                'self_attn.k_proj.bias', \
+                'self_attn.v_proj.weight', \
+                'self_attn.v_proj.bias', \
+                'self_attn.out_proj.weight', \
+                'self_attn.out_proj.bias', \
+                'fc1.weight', \
+                'fc1.bias', \
+                'fc2.weight', \
+                'fc2.bias', \
+                'final_layer_norm.weight', \
+                'final_layer_norm.bias', \
+                'self_attn_layer_norm.weight', \
+                'self_attn_layer_norm.bias',
 
 # transformer-based policies
 replace_policies = [
diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py
index 610bd882ecf4..d056b20106bc 100644
--- a/deepspeed/ops/transformer/inference/ds_attention.py
+++ b/deepspeed/ops/transformer/inference/ds_attention.py
@@ -394,7 +394,7 @@ def __init__(self,
         data_type_fp = torch.half if config.fp16 else torch.float
         self.config.layer_id = DeepSpeedSelfAttention.num_layers
         DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
-        device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
+        device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
         qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
         self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
                                                   qkv_size_per_partition,
@@ -484,3 +484,4 @@ def forward(self,
                                   alibi)
 
         return output
+
diff --git a/deepspeed/ops/transformer/inference/ds_mlp.py b/deepspeed/ops/transformer/inference/ds_mlp.py
index 4f1c705c55ea..38fb684dca72 100644
--- a/deepspeed/ops/transformer/inference/ds_mlp.py
+++ b/deepspeed/ops/transformer/inference/ds_mlp.py
@@ -96,7 +96,7 @@ def __init__(self,
         self.config = config
         data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
         data_type_fp = torch.half if config.fp16 else torch.float
-        device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
+        device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
         self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size,
                                                 dtype=data_type_fp,
                                                 device=device),
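
Usage sketch (not part of the diff): a minimal, hedged example of how the OPT path added above is expected to be exercised. The model name, prompt, and generation settings are illustrative assumptions; only the policy selection and the parameter-name mapping plumbing come from this patch, and the checkpoint-json path would additionally require passing checkpoint= to init_inference.

    import torch
    import deepspeed
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Any Hugging Face OPT checkpoint; "facebook/opt-1.3b" is only an example.
    model_name = "facebook/opt-1.3b"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.half)

    # replace_with_kernel_inject=True routes through replace_transformer_layer(), which
    # with this patch records the selected policy (selected_policy_g) and transformer
    # config (transformer_config_g); when a checkpoint json is supplied as well,
    # load_model_with_checkpoint() maps HF parameter names via the policy's
    # get_param_names() instead of the removed BLOOM-specific copy block.
    engine = deepspeed.init_inference(model,
                                      mp_size=1,
                                      dtype=torch.half,
                                      replace_with_kernel_inject=True)

    prompt = "DeepSpeed is"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = engine.module.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))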