diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 7f4917b5464c..2ca511c857ae 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -1084,7 +1084,7 @@ def unwrap_model(model):
 
                 # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
 
                 # time ids
                 def compute_time_ids(original_size, crops_coords_top_left):
@@ -1101,7 +1101,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
 
                 # Predict the noise residual
                 unet_added_conditions = {"time_ids": add_time_ids}
-                prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
+                prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
                 pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
                 unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
                 model_pred = unet(
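
For context, a minimal standalone sketch of why both casts matter. It uses plain PyTorch; the Linear layer, tensor shapes, and the bf16 choice are illustrative stand-ins and not taken from the script. Under mixed precision the UNet parameters live in weight_dtype, while the scheduler output and the precomputed prompt embeddings arrive in fp32, so the forward pass would otherwise mix dtypes.

import torch

# Stand-in for what `--mixed_precision bf16` resolves weight_dtype to in the script.
weight_dtype = torch.bfloat16

# Illustrative stand-ins: a low-precision model layer and an fp32 activation
# (e.g. the scheduler output or cached prompt embeddings).
layer = torch.nn.Linear(8, 8).to(dtype=weight_dtype)
fp32_input = torch.randn(2, 8)

try:
    layer(fp32_input)  # bf16 weights vs. fp32 input
except RuntimeError as err:
    print("without the cast:", err)

# The patch aligns the inputs with the parameter dtype before the UNet forward pass.
out = layer(fp32_input.to(dtype=weight_dtype))
print("with the cast:", out.dtype)  # torch.bfloat16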