diff --git a/flagai/model/base_model.py b/flagai/model/base_model.py index 46367e28..c385c52a 100644 --- a/flagai/model/base_model.py +++ b/flagai/model/base_model.py @@ -77,9 +77,11 @@ def from_pretrain(cls, config_path = os.path.join(download_path, "config.json") checkpoint_path = os.path.join(download_path, "pytorch_model.bin") - def load_local(checkpoint_path): + def load_local(checkpoint_path, only_download_config=False): model = cls.init_from_json(config_path, **kwargs) model.to(device) + if only_download_config: + return model if os.getenv('ENV_TYPE') != 'deepspeed+mpu': if os.path.exists(checkpoint_path): model.load_weights(checkpoint_path) @@ -146,7 +148,7 @@ def load_diffusion_local(yaml_path, only_download_config=False, **kwargs): It is fine when checkpoint_path does not exist, for the case that only_download_config=True At that time the model will not be loaded. """ - return load_local(checkpoint_path) + return load_local(checkpoint_path, only_download_config=only_download_config) try: model_id = _get_model_id(model_name) diff --git a/flagai/trainer.py b/flagai/trainer.py index 50309849..17703d73 100644 --- a/flagai/trainer.py +++ b/flagai/trainer.py @@ -335,9 +335,9 @@ def get_dataloader(self, dataset, collate_fn, shuffle=False): shuffle=shuffle) else: if self.env_type == 'deepspeed+mpu': - rank = mpu.get_model_parallel_src_rank() + rank = mpu.get_data_parallel_rank() print("*"*80) - print("local rank",self.rank, "model rank", rank) + print("local rank",self.rank, "data parallel rank", rank) print("*"*80) sampler = torch.utils.data.distributed.DistributedSampler( dataset,