79 changes: 58 additions & 21 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -34,7 +34,7 @@
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
from peft import LoraConfig
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -56,6 +56,7 @@
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.state_dict_utils import convert_state_dict_to_peft, convert_unet_state_dict_to_peft


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -996,17 +997,6 @@ def main(args):
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)

# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
@@ -1058,17 +1048,52 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
Comment on lines -1067 to -1072 (Member Author):

We cannot use load_lora_into_unet() and load_lora_into_text_encoder() here, for the following reason (described only for the UNet, but it applies to the text encoders too).

  • We call add_adapter() once on unet at the beginning of training. This creates an adapter config inside the UNet.
  • Then, when loading an intermediate checkpoint in the accelerate hook, we call load_lora_into_unet(). That internally calls inject_adapter_in_model() again, with a config inferred from the provided state dict, so it creates a second adapter. This is undesirable.
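
As an illustrative sanity check only (not part of this PR; the helper name is hypothetical), one could confirm after accelerator.load_state(...) that each LoRA layer still holds only the adapter created by add_adapter():

from peft.tuners.lora import LoraLayer

def adapters_per_lora_layer(model):
    # Hypothetical diagnostic: map each LoRA-wrapped module to the adapter names it holds.
    return {
        name: sorted(module.lora_A.keys())
        for name, module in model.named_modules()
        if isinstance(module, LoraLayer) and len(module.lora_A) > 0
    }

# Usage inside load_model_hook, referring to the unet_ popped from `models`:
# with the set_peft_model_state_dict() path every entry should be ["default"];
# a second inject_adapter_in_model() call would surface extra adapter names here.
# assert all(names == ["default"] for names in adapters_per_lora_layer(unet_).values())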

unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)

text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)
if args.train_text_encoder:
# Do we need to call `scale_lora_layers()` here?
text_encoder_state_dict = {
f'{k.replace("text_encoder.", "")}': v
for k, v in lora_state_dict.items()
if k.startswith("text_encoder.")
}
text_encoder_state_dict = convert_state_dict_to_peft(
convert_state_dict_to_diffusers(text_encoder_state_dict)
)
set_peft_model_state_dict(text_encoder_one_, text_encoder_state_dict, adapter_name="default")

text_encoder_2_state_dict = {
f'{k.replace("text_encoder_2.", "")}': v
for k, v in lora_state_dict.items()
if k.startswith("text_encoder_2.")
}
text_encoder_2_state_dict = convert_state_dict_to_peft(
convert_state_dict_to_diffusers(text_encoder_2_state_dict)
)
set_peft_model_state_dict(text_encoder_two_, text_encoder_2_state_dict, adapter_name="default")

# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`.
if args.mixed_precision == "fp16":
models = [unet_]
if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1083,6 +1108,17 @@ def load_model_hook(models, input_dir):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)

# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
Comment on lines +1108 to +1117 (Member Author):

We do it just before the parameters are handed to the optimizer, so the float32 upcast is already in place when the LoRA parameter lists and the optimizer are built.

In a follow-up PR, I can wrap this utility into a function and move it to training_utils.py. It's shared by a number of scripts.
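
A rough sketch of what that shared utility could look like (the name cast_training_params and its signature are placeholders, not an API that exists in this PR):

import torch

def cast_training_params(models, dtype=torch.float32):
    # Hypothetical helper: upcast only the trainable (LoRA) parameters of the given models.
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)

# Possible call site in this script:
# cast_training_params([unet, text_encoder_one, text_encoder_two], dtype=torch.float32)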


unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

if args.train_text_encoder:
@@ -1500,6 +1536,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
else unet_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
1 change: 0 additions & 1 deletion src/diffusers/loaders/lora.py
@@ -581,7 +581,6 @@ def load_lora_into_text_encoder(
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
)

lora_config = LoraConfig(**lora_config_kwargs)

# adapter_name