diff --git a/examples/stable-diffusion/training/train_text_to_image_sdxl.py b/examples/stable-diffusion/training/train_text_to_image_sdxl.py index 881669414b..6035e3cf47 100644 --- a/examples/stable-diffusion/training/train_text_to_image_sdxl.py +++ b/examples/stable-diffusion/training/train_text_to_image_sdxl.py @@ -799,7 +799,8 @@ def main(args): for idx, dt in enumerate(dataset['train']): dt['image'].save(f'{args.mediapipe}/{idx}.jpg') f.write(dt['text'] + '\n') - torch.distributed.barrier() + if accelerator.distributed_type != GaudiDistributedType.NO: + torch.distributed.barrier() from media_pipe_imgdir import get_dataset_for_pipeline dt = get_dataset_for_pipeline(args.mediapipe) dataset = {'train': dt}