diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 30af889c34ac..158507e077e0 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -454,9 +454,8 @@ def transpose(data): dense_w.data = transpose(dense_w.data) def _transpose(x): - num_attention_heads_per_partition = transformer_config.heads // transformer_config.mp_size - attention_head_size = x.shape[-1] // num_attention_heads_per_partition - new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition, + attention_head_size = x.shape[-1] // transformer_config.heads + new_x_shape = x.size()[:-1] + (transformer_config.heads, attention_head_size) x_1 = x.view(*new_x_shape) (q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))