5 changes: 4 additions & 1 deletion deepspeed/inference/engine.py
@@ -26,7 +26,7 @@
 from ..module_inject.auto_tp import AutoTP
 
 from ..module_inject.replace_policy import generic_policies
-from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor
+from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor
 
 DS_INFERENCE_ENABLED = False
 from torch import nn
@@ -187,6 +187,9 @@ def build_alibi_tensor(self):
         if hasattr(self.module, 'transformer'):
             if hasattr(self.module.transformer, 'build_alibi_tensor'):
                 self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor
+            if hasattr(self.module.transformer, 'build_mpt_alibi_tensor'):
+                self.module.transformer.build_mpt_alibi_tensor_orig = self.module.transformer.build_mpt_alibi_tensor
+                self.module.transformer.__class__.build_mpt_alibi_tensor = build_mpt_alibi_tensor
 
     def build_attn_bias(self):
         if hasattr(self.module, 'transformer'):
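The new lines above stash the model's original `build_mpt_alibi_tensor` as `build_mpt_alibi_tensor_orig` on the instance and rebind the name on the class, so later calls still resolve through normal method lookup and the tensor-parallel replacement can call back into the original. A minimal sketch of that patching pattern, using a hypothetical `ToyTransformer` rather than a real MPT module:

```python
# Minimal sketch of the patching pattern used above (hypothetical ToyTransformer,
# not an actual MPT module).
class ToyTransformer:
    def build_mpt_alibi_tensor(self, num_heads):
        return list(range(num_heads))          # stand-in for the real alibi tensor


def sharded_build_mpt_alibi_tensor(self, num_heads):
    # Calls back into the saved original, then keeps a slice per "rank".
    full = self.build_mpt_alibi_tensor_orig(num_heads)
    return full[:num_heads // 2]               # pretend world_size == 2, rank == 0


transformer = ToyTransformer()
if hasattr(transformer, 'build_mpt_alibi_tensor'):
    # Save the original bound method on the instance...
    transformer.build_mpt_alibi_tensor_orig = transformer.build_mpt_alibi_tensor
    # ...and patch the class so future calls bind self and hit the wrapper.
    transformer.__class__.build_mpt_alibi_tensor = sharded_build_mpt_alibi_tensor

print(transformer.build_mpt_alibi_tensor(8))   # [0, 1, 2, 3]
```

Patching the class rather than the instance lets the replacement stay a plain function that receives `self`, which matches how the module-level `build_mpt_alibi_tensor` in auto_tp_model_utils.py is written.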
15 changes: 0 additions & 15 deletions deepspeed/module_inject/auto_tp.py
@@ -296,21 +296,6 @@ def _replace(self, child, name, conv_linear_layer):
         if getattr(child, "replaced", False) == True:
             return
         weight_shape = child.weight.shape
-        if name == 'attn.Wqkv' and self.module._get_name() == 'MPTBlock':
-            # MPT block qkv weight's allocation is different from other models, it's [3,num_head,head_dim,hidden_size]
-            # instead of [num_head,3,head_dim,hidden_size]
-            new_weight = torch.empty((
-                weight_shape[0] // self.mp_size,
-                weight_shape[1],
-            ),
-                                     device=child.weight.device,
-                                     dtype=child.weight.dtype)
-            reversed_dim = True
-            mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group, out_dim=0)
-            # todo: can we remove new tensor allocation if we use strided copy?
-            mp_replace.strided_copy(new_weight, child.weight.data, num_splits=3, int8=reversed_dim)
-            setattr(child, "replaced", True)
-            return LinearLayer(weight=new_weight.to(get_accelerator().current_device_name()), bias=None)
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
         if name in self.all_reduce_linears:
            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
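The deleted branch special-cased MPTBlock's fused `attn.Wqkv` inside `_replace`; that handling moves to the generic fused-QKV path in fusedqkv_utils.py below. Its comment contrasts two row layouts for the fused weight, which the following toy snippet makes concrete (toy sizes only; nothing is claimed about real MPT checkpoints beyond what the comment states):

```python
# Toy illustration of the two fused-QKV row layouts named in the deleted comment.
import torch

num_head, head_dim, hidden_size = 4, 2, 8   # toy sizes; hidden_size = num_head * head_dim
flat = torch.arange(3 * hidden_size * hidden_size, dtype=torch.float32).view(3 * hidden_size, hidden_size)

# MPT-style ordering: all Q rows, then all K rows, then all V rows.
mpt_view = flat.view(3, num_head, head_dim, hidden_size)
# "Other models" ordering: rows grouped per head as (q, k, v) for head 0, head 1, ...
other_view = flat.view(num_head, 3, head_dim, hidden_size)

# Keeping the first half of the heads touches different rows in each layout, which is
# why the fused weight needs a layout-aware (strided) split under tensor parallelism.
mpt_half = mpt_view[:, :num_head // 2]      # Q/K/V rows for heads 0..1 in the MPT layout
other_half = other_view[:num_head // 2]     # q,k,v rows for heads 0..1 in the other layout
assert not torch.equal(mpt_half.reshape(-1, hidden_size), other_half.reshape(-1, hidden_size))
```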
15 changes: 15 additions & 0 deletions deepspeed/module_inject/auto_tp_model_utils.py
@@ -76,3 +76,18 @@ def build_mpt_atten_bias_tensor(self,
         offset = dist.get_rank() * num_heads_per_rank
         attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
     return attn_bias, attention_mask
+
+
+def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None) -> torch.Tensor:
+    r"""
+    Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
+    relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
+    the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
+    https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
+    """
+    alibi = self.build_mpt_alibi_tensor_orig(num_heads, sequence_length, alibi_bias_max, device)
+    if dist.is_initialized():
+        num_heads_per_rank = int(num_heads / dist.get_world_size())
+        offset = dist.get_rank() * num_heads_per_rank
+        alibi = alibi[offset:num_heads_per_rank + offset, :, :]
+    return alibi
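In `build_mpt_alibi_tensor`, each rank keeps a contiguous block of heads, so concatenating the per-rank slices along the head dimension recovers the full alibi tensor. A small self-contained check of that partitioning, simulating ranks in a loop instead of calling `torch.distributed` and assuming a [num_heads, 1, seq_len] alibi shape for illustration:

```python
# Standalone check of the per-rank head slicing done in build_mpt_alibi_tensor,
# simulating ranks in a loop instead of using torch.distributed.
import torch

num_heads, seq_len, world_size = 8, 16, 4
# Toy alibi tensor with an assumed [num_heads, 1, seq_len] layout.
alibi = torch.arange(num_heads, dtype=torch.float32).view(num_heads, 1, 1).expand(num_heads, 1, seq_len)

num_heads_per_rank = num_heads // world_size
shards = []
for rank in range(world_size):
    offset = rank * num_heads_per_rank
    shards.append(alibi[offset:offset + num_heads_per_rank, :, :])  # same slice as above

# Each rank keeps a contiguous block of heads; together they cover the full tensor.
assert all(s.shape == (num_heads_per_rank, 1, seq_len) for s in shards)
assert torch.equal(torch.cat(shards, dim=0), alibi)
```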
4 changes: 3 additions & 1 deletion deepspeed/module_inject/fusedqkv_utils.py
@@ -16,7 +16,7 @@ def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
 
 
 def require_tp_fused_qkvw(name, mp_size):
-    fused_qkvw_name_list = ['qkv_proj', 'query_key_value']
+    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv']
 
     if mp_size == 1:
         return False
@@ -33,6 +33,8 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
         'CodeGenBlock': 'codegentype',
         'BloomBlock': 'bloomtype',
         'GLMBlock': 'glmtype',
+        "MPTBlock": 'glmtype',
+        "MptBlock": 'glmtype',
     }
 
     def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
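Mapping `MPTBlock` and `MptBlock` to `'glmtype'`, together with adding `attn.Wqkv` to `require_tp_fused_qkvw`, sends MPT's fused QKV weight through the same splitting path as `GLMBlock`. The sketch below shows one way a fused [Q; K; V] weight can be sharded per rank; it assumes a q/k/v-stacked layout along dim 0 and is only an illustration of the idea, not a reproduction of DeepSpeed's `prepare_tp_fused_qkvw` internals:

```python
# Rough sketch of fused-QKV sharding for a q/k/v-stacked weight; illustrative only,
# this does not reproduce DeepSpeed's actual helper.
import torch

def shard_fused_qkv(wqkv: torch.Tensor, mp_size: int, rank: int) -> torch.Tensor:
    """Shard a fused weight laid out as [Q; K; V] along dim 0 (assumed layout)."""
    q, k, v = torch.chunk(wqkv, 3, dim=0)       # split the stacked Q, K, V blocks
    rows = q.shape[0] // mp_size                # output rows each rank keeps per block
    pick = lambda t: t[rank * rows:(rank + 1) * rows]
    # Each rank keeps its slice of Q, K and V and re-fuses them, so the sharded
    # layer still produces a contiguous [q_shard; k_shard; v_shard] output.
    return torch.cat([pick(q), pick(k), pick(v)], dim=0)

hidden = 8
wqkv = torch.randn(3 * hidden, hidden)
shards = [shard_fused_qkv(wqkv, mp_size=2, rank=r) for r in range(2)]
assert all(s.shape == (3 * hidden // 2, hidden) for s in shards)
```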