diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index ebfbbec8aa66..538cc04759b0 100755
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -478,8 +478,16 @@ def _transpose(x):
                 attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                 attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
             else:
-                attn_block.attn_qkvw = mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw)
-                attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb)
+                if bigscience_bloom:
+                    attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
+                    attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
+                else:
+                    attn_block.attn_qkvw = mp_replace.qkv_copy(
+                        attn_block.attn_qkvw,
+                        qkvw)
+                    attn_block.attn_qkvb = mp_replace.qkv_copy(
+                        attn_block.attn_qkvb,
+                        qkvb)
                 attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                 attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
 
@@ -773,10 +781,10 @@ def replace_fn(child, _policy, layer_id=0):
                                      replace_fn=replace_fn,
                                      _replace_policy=policy)
 
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    rank = dist.get_rank() if dist.is_initialized() else 0
     if checkpoint_dict is not None:
         start_time = time.time()
-        rank = dist.get_rank() if dist.is_initialized() else 0
-        world_size = dist.get_world_size() if dist.is_initialized() else 1
         checkpoint = checkpoint_dict['checkpoints']
         ckpt_type = checkpoint_dict.get('parallelization', 'pp')
         ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size)
@@ -818,12 +826,22 @@ def replace_fn(child, _policy, layer_id=0):
         from collections import OrderedDict
         import json
 
-        ckpt_name = checkpoint_dict['type']
+        if checkpoint_dict is None:
+            ckpt_name = "ds_model"
+            try:
+                from transformers.models.bloom.modeling_bloom import BloomForCausalLM
+                if isinstance(model, BloomForCausalLM):
+                    ckpt_name = "bloom"
+            except ImportError:
+                ckpt_name = "ds_model"
+        else:
+            ckpt_name = checkpoint_dict['type']
         if dist.is_initialized():
             dist.barrier()
         transformer_name = get_transformer_name(replaced_module)
         non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt'
         ckpt_files = [non_tp_ckpt_name] * world_size
+        os.makedirs(save_mp_checkpoint_path, exist_ok=True)
         if not dist.is_initialized() or dist.get_rank() == 0:
             print("Saving tp-sharded checkpoints")
             torch.save(