Merged
4 changes: 2 additions & 2 deletions examples/text_to_image/train_text_to_image_sdxl.py
@@ -1084,7 +1084,7 @@ def unwrap_model(model):

# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
- noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
Member

Could we do this conditionally, only when the dtype of noisy_model_input doesn't match vae.dtype?

Contributor Author

@leisuzz leisuzz Aug 20, 2024

Thanks for your feedback. I tested this several times and found that the dtype error actually occurs in the UNet call:

model_pred = unet(
    noisy_model_input,
    timesteps,
    prompt_embeds,
    added_cond_kwargs=unet_added_conditions,
    return_dict=False,
)[0]

So I don't think this affects the VAE. Also, model_pred is cast with model_pred.float() in the loss computation.
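
For reference, the conditional cast suggested in the review comment could look roughly like the sketch below. The helper name `cast_if_needed` is hypothetical and not part of the training script; it assumes only that its argument exposes a `.dtype` attribute and a `.to(dtype=...)` method, as `torch.Tensor` does.

```python
def cast_if_needed(tensor, target_dtype):
    """Return `tensor` in `target_dtype`, casting only when the dtypes differ.

    Hypothetical helper for illustration. It relies only on the `.dtype`
    attribute and the `.to(dtype=...)` method that torch.Tensor provides.
    """
    if tensor.dtype != target_dtype:
        # Dtypes differ (e.g. float32 input, float16 weights): cast.
        return tensor.to(dtype=target_dtype)
    # Dtypes already match: return the input unchanged.
    return tensor
```

In the training loop this would be used as `noisy_model_input = cast_if_needed(noise_scheduler.add_noise(model_input, noise, timesteps), weight_dtype)`. Note that `torch.Tensor.to` already returns the same tensor when the dtype and device match, so the unconditional `.to(dtype=weight_dtype)` in the diff should behave the same way in that case.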


# time ids
def compute_time_ids(original_size, crops_coords_top_left):
@@ -1101,7 +1101,7 @@ def compute_time_ids(original_size, crops_coords_top_left):

# Predict the noise residual
unet_added_conditions = {"time_ids": add_time_ids}
- prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
+ prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
model_pred = unet(