From 6a7c96fded26b06f493e142b5063c44f2932449b Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Mon, 3 Jun 2024 18:31:25 +0000 Subject: [PATCH] Only download checkpoint on local rank 0 --- diffusion/evaluation/generate_images.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/diffusion/evaluation/generate_images.py b/diffusion/evaluation/generate_images.py index fc6674a6..c45c9b7f 100644 --- a/diffusion/evaluation/generate_images.py +++ b/diffusion/evaluation/generate_images.py @@ -79,9 +79,11 @@ def __init__(self, # Download the model checkpoint if needed if self.load_path is not None: - get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True) - # Load the model - state_dict = torch.load(self.local_checkpoint_path) + if dist.get_local_rank() == 0: + get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True) + with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path): + # Load the model + state_dict = torch.load(self.local_checkpoint_path) for key in list(state_dict['state']['model'].keys()): if 'val_metrics.' in key: del state_dict['state']['model'][key]