diff --git a/examples/stable-diffusion/run_1x_bs16.sh b/examples/stable-diffusion/run_1x_bs16.sh index 5eefc1cc23..97f4f0ea43 100755 --- a/examples/stable-diffusion/run_1x_bs16.sh +++ b/examples/stable-diffusion/run_1x_bs16.sh @@ -2,7 +2,7 @@ python train_text_to_image_sdxl.py \ --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ --pretrained_vae_model_name_or_path stabilityai/sdxl-vae \ --dataset_name lambdalabs/pokemon-blip-captions \ - --resolution 512 \ + --resolution 1024 \ --center_crop \ --random_flip \ --proportion_empty_prompts=0.2 \ @@ -15,9 +15,10 @@ python train_text_to_image_sdxl.py \ --output_dir sdxl-pokemon-model \ --gaudi_config_name Habana/stable-diffusion \ --throughput_warmup_steps 3 \ + --dataloader_num_workers 8 \ --bf16 \ - --validation_prompt="a horse running on the beach during sunset" \ - --validation_epochs 48 \ - --use_hpu_graphs \ + --use_hpu_graphs_for_inference \ + --validation_prompt="a robotic cat with wings" \ + --validation_epochs 48\ --checkpointing_steps 2500 \ --cache_dir /root/software/data/pytorch/huggingface/sdxl 2>&1 | tee log_1x_bs16.txt diff --git a/examples/stable-diffusion/train_text_to_image_sdxl.py b/examples/stable-diffusion/train_text_to_image_sdxl.py index 9fa5ea2e75..df6e9868ee 100644 --- a/examples/stable-diffusion/train_text_to_image_sdxl.py +++ b/examples/stable-diffusion/train_text_to_image_sdxl.py @@ -40,6 +40,7 @@ from datasets import load_dataset from diffusers import ( AutoencoderKL, + DDPMScheduler, UNet2DConditionModel, ) from diffusers.optimization import get_scheduler @@ -469,7 +470,14 @@ def parse_args(input_args=None): " lazy mode." ), ) - parser.add_argument("--use_hpu_graphs", action="store_true", help="Use HPU graphs on HPU.") + parser.add_argument( + "--use_hpu_graphs_for_training", + action="store_true", + help="Use HPU graphs for training on HPU.") + parser.add_argument( + "--use_hpu_graphs_for_inference", + action="store_true", + help="Use HPU graphs for inference on HPU.") parser.add_argument( "--image_save_dir", @@ -679,7 +687,7 @@ def main(args): ) # Load scheduler and models - noise_scheduler = GaudiEulerDiscreteScheduler.from_pretrained( + noise_scheduler = DDPMScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler" ) @@ -965,7 +973,8 @@ def collate_fn(examples): unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, optimizer, train_dataloader, lr_scheduler ) - + if args.use_hpu_graphs_for_training: + unet = htcore.hpu.ModuleCacher(max_graphs=10)(model=unet, inplace=True) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: @@ -1045,7 +1054,6 @@ def unwrap_model(model): with accelerator.accumulate(unet): # Sample noise that we'll add to the latents - model_input = batch["model_input"].to(dtype=weight_dtype) noise = torch.randn_like(model_input) @@ -1226,7 +1234,7 @@ def compute_time_ids(original_size, crops_coords_top_left): revision=args.revision, variant=args.variant, use_habana=True, - use_hpu_graphs=args.use_hpu_graphs, + use_hpu_graphs=args.use_hpu_graphs_for_inference, gaudi_config=args.gaudi_config_name, ) if args.prediction_type is not None: @@ -1297,7 +1305,7 @@ def compute_time_ids(original_size, crops_coords_top_left): torch_dtype=weight_dtype, scheduler=noise_scheduler, use_habana=True, - use_hpu_graphs=args.use_hpu_graphs, + use_hpu_graphs_for_inference=args.use_hpu_graphs_for_inference, gaudi_config=args.gaudi_config_name, ) if args.prediction_type is not None: @@ -1322,7 +1330,7 @@ def compute_time_ids(original_size, crops_coords_top_left): image_save_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Saving images in {image_save_dir.resolve()}...") for i, image in enumerate(images): - image.save(image_save_dir / f"image_{i+1}.png") + image.save(image_save_dir / f"image_{epoch}_{i+1}.png") else: logger.warning("--output_type should be equal to 'pil' to save images in --image_save_dir.") diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 67784e8ff2..68928ece97 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -125,7 +125,6 @@ def __init__( scheduler, force_zeros_for_empty_prompt, ) - self.to(self._device) def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None): @@ -546,7 +545,9 @@ def __call__( # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device="cpu") timesteps = self.scheduler.timesteps.to(device) - self.scheduler.reset_timestep_dependent_params() + reset_timestep = getattr(self.scheduler, "reset_timestep_dependent_params", None) + if callable(reset_timestep): + self.scheduler.reset_timestep_dependent_params() # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels