diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 306feac322ea..fe83b7b3b4cf 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -1,5 +1,6 @@ import copy import torch +import tqdm import deepspeed import deepspeed.ops.transformer as transformer_inference from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, HFGPTJLayerPolicy, BLOOMLayerPolicy @@ -765,9 +766,11 @@ def replace_fn(child, _policy, layer_id=0): _replace_policy=policy) if checkpoint is not None: + pbar = tqdm.tqdm(total=len(checkpoint), + desc=f"Loading {len(checkpoint)} checkpoint shards") for i in range(len(checkpoint)): if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: - print(f"loading checkpoint ({i})") + pbar.update(1) sd = torch.load(checkpoint[i], map_location='cpu') load_model_with_checkpoint(replaced_module, sd, mp_replace) return replaced_module