diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 011b66858710..4ae9e5529d0e 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -483,8 +483,8 @@ def _transpose(x): attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w) attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b) else: - attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw) - attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb) + attn_block.attn_qkvw = mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw) + attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb) attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w) attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)