From ea3b2d14d97ab70078e80492c541bc06bdbbefe7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=8B=E7=A1=95?= Date: Mon, 19 Aug 2024 20:12:32 +0800 Subject: [PATCH] Fix dtype error --- examples/text_to_image/train_text_to_image_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 7f4917b5464c..2ca511c857ae 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1084,7 +1084,7 @@ def unwrap_model(model): # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype) # time ids def compute_time_ids(original_size, crops_coords_top_left): @@ -1101,7 +1101,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Predict the noise residual unet_added_conditions = {"time_ids": add_time_ids} - prompt_embeds = batch["prompt_embeds"].to(accelerator.device) + prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype) pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) model_pred = unet(