diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_utils.py b/nemo/collections/nlp/modules/common/megatron/megatron_utils.py index 68437921f930..d610f5b61c24 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_utils.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_utils.py @@ -21,7 +21,7 @@ import wget from torch.hub import _get_torch_home -from nemo.utils import get_rank, logging +from nemo.utils import logging __all__ = [ "get_megatron_lm_model", @@ -203,7 +203,7 @@ def _download(path: str, url: str): if url is None: return None - if get_rank.is_global_rank_zero() and not os.path.exists(path): + if (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0) and not os.path.exists(path): os.makedirs(MEGATRON_CACHE, exist_ok=True) logging.info(f"Downloading from {url} to {path}") downloaded_path = wget.download(url)