diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 64467b12e632..0ed0d6567ee1 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -26,7 +26,7 @@
 from ..module_inject.auto_tp import AutoTP
 
 from ..module_inject.replace_policy import generic_policies
-from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor
+from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor
 
 DS_INFERENCE_ENABLED = False
 from torch import nn
@@ -187,6 +187,9 @@ def build_alibi_tensor(self):
         if hasattr(self.module, 'transformer'):
             if hasattr(self.module.transformer, 'build_alibi_tensor'):
                 self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor
+            if hasattr(self.module.transformer, 'build_mpt_alibi_tensor'):
+                self.module.transformer.build_mpt_alibi_tensor_orig = self.module.transformer.build_mpt_alibi_tensor
+                self.module.transformer.__class__.build_mpt_alibi_tensor = build_mpt_alibi_tensor
 
     def build_attn_bias(self):
         if hasattr(self.module, 'transformer'):
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index d02cd3b2934c..b3d46de84da2 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -296,21 +296,6 @@ def _replace(self, child, name, conv_linear_layer):
         if getattr(child, "replaced", False) == True:
             return
         weight_shape = child.weight.shape
-        if name == 'attn.Wqkv' and self.module._get_name() == 'MPTBlock':
-            # MPT block qkv weight's allocation is different from other models, it's [3,num_head,head_dim,hidden_size]
-            # instead of [num_head,3,head_dim,hidden_size]
-            new_weight = torch.empty((
-                weight_shape[0] // self.mp_size,
-                weight_shape[1],
-            ),
-                                     device=child.weight.device,
-                                     dtype=child.weight.dtype)
-            reversed_dim = True
-            mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group, out_dim=0)
-            # todo: can we remove new tensor allocation if we use strided copy?
-            mp_replace.strided_copy(new_weight, child.weight.data, num_splits=3, int8=reversed_dim)
-            setattr(child, "replaced", True)
-            return LinearLayer(weight=new_weight.to(get_accelerator().current_device_name()), bias=None)
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
         if name in self.all_reduce_linears:
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
diff --git a/deepspeed/module_inject/auto_tp_model_utils.py b/deepspeed/module_inject/auto_tp_model_utils.py
index d31dfd17a2a9..ce52cd6a6250 100644
--- a/deepspeed/module_inject/auto_tp_model_utils.py
+++ b/deepspeed/module_inject/auto_tp_model_utils.py
@@ -76,3 +76,18 @@ def build_mpt_atten_bias_tensor(self,
         offset = dist.get_rank() * num_heads_per_rank
         attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
     return attn_bias, attention_mask
+
+
+def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None) -> torch.Tensor:
+    r"""
+    Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
+    relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
+    the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
+    https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
+    """
+    alibi = self.build_mpt_alibi_tensor_orig(num_heads, sequence_length, alibi_bias_max, device)
+    if dist.is_initialized():
+        num_heads_per_rank = int(num_heads / dist.get_world_size())
+        offset = dist.get_rank() * num_heads_per_rank
+        alibi = alibi[offset:num_heads_per_rank + offset, :, :]
+    return alibi
diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py
index 30a5bb75db23..f25f9f8ddd07 100644
--- a/deepspeed/module_inject/fusedqkv_utils.py
+++ b/deepspeed/module_inject/fusedqkv_utils.py
@@ -16,7 +16,7 @@ def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
 
 
 def require_tp_fused_qkvw(name, mp_size):
-    fused_qkvw_name_list = ['qkv_proj', 'query_key_value']
+    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv']
     if mp_size == 1:
         return False
 
@@ -33,6 +33,8 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
         'CodeGenBlock': 'codegentype',
         'BloomBlock': 'bloomtype',
         'GLMBlock': 'glmtype',
+        "MPTBlock": 'glmtype',
+        "MptBlock": 'glmtype',
     }
 
     def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
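For reference, the tensor-parallel slicing performed by the new `build_mpt_alibi_tensor` wrapper can be sketched in isolation: each rank calls the model's original alibi builder (saved as `build_mpt_alibi_tensor_orig` by the engine patch above) and then keeps only the contiguous block of heads it owns. The snippet below is a minimal illustration, not part of the patch; `slice_alibi_for_rank`, `rank`, and `world_size` are hypothetical stand-ins for `dist.get_rank()` / `dist.get_world_size()`, and the assumed `[num_heads, 1, seq_len]` alibi layout follows the MPT code linked in the docstring.

```python
import torch


def slice_alibi_for_rank(alibi: torch.Tensor, num_heads: int, rank: int, world_size: int) -> torch.Tensor:
    # Keep only the block of head slopes owned by this tensor-parallel rank,
    # mirroring the offset arithmetic in build_mpt_alibi_tensor above.
    num_heads_per_rank = num_heads // world_size
    offset = rank * num_heads_per_rank
    return alibi[offset:offset + num_heads_per_rank, :, :]


# Toy check: 8 heads split across 2 ranks -> each rank keeps 4 head slopes.
full_alibi = torch.randn(8, 1, 16)  # assumed [num_heads, 1, seq_len] layout
assert slice_alibi_for_rank(full_alibi, num_heads=8, rank=1, world_size=2).shape == (4, 1, 16)
```

Slicing the alibi (and attention-bias) heads per rank keeps the bias consistent with the qkv weight partitioning done by AutoTP, which is why the same `offset:offset + num_heads_per_rank` indexing appears in both `build_mpt_atten_bias_tensor` and `build_mpt_alibi_tensor`.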