Skip to content

Commit

Permalink
Only download checkpoint on local rank 0
Browse files: browse the repository at this point in the history
  • Loading branch information
corystephenson-db committed Jun 3, 2024
1 parent d833434 commit 6a7c96f
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions diffusion/evaluation/generate_images.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -79,9 +79,11 @@ def __init__(self,

# Download the model checkpoint if needed
if self.load_path is not None:
- get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True)
- # Load the model
- state_dict = torch.load(self.local_checkpoint_path)
+ if dist.get_local_rank() == 0:
+ get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True)
+ with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path):
+ # Load the model
+ state_dict = torch.load(self.local_checkpoint_path)
for key in list(state_dict['state']['model'].keys()):
if 'val_metrics.' in key:
del state_dict['state']['model'][key]
Expand Down

0 comments on commit 6a7c96f

Please sign in to comment.