69 changes: 49 additions & 20 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -56,6 +56,7 @@
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.peft_utils import delete_adapter_layers


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -996,17 +997,6 @@ def main(args):
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)

# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
@@ -1058,17 +1048,44 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
# We pass the config here to ensure parity between `unet_lora_config` and
# the `LoraConfig` that's inferred in `load_lora_into_unet`.
LoraLoaderMixin.load_lora_into_unet(
lora_state_dict, network_alphas=network_alphas, unet=unet_, _config=unet_lora_config
)
# Remove the newly created adapter as we don't need it.
delete_adapter_layers(unet_, "default_1")

text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)
if args.train_text_encoder:
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alphas=network_alphas,
text_encoder=text_encoder_one_,
_config=text_lora_config,
)
delete_adapter_layers(text_encoder_one_, "default_1")

text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_2_state_dict,
network_alphas=network_alphas,
text_encoder=text_encoder_two_,
_config=text_lora_config,
)
delete_adapter_layers(text_encoder_two_, "default_1")

# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`.
if args.mixed_precision == "fp16":
models = [unet_]
if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1083,6 +1100,17 @@ def load_model_hook(models, input_dir):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)

# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
Comment on lines +1108 to +1117 (Member Author): We do it just before assigning the parameters to the optimizer to avoid any consequences.

(Member Author): In a follow-up PR, I can wrap this utility into a function and move it to training_utils.py. It's shared by a number of scripts.
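
A minimal sketch of what that shared helper could look like (hypothetical name, signature, and location; not part of this PR):

```python
import torch


def cast_training_params(models, dtype=torch.float32):
    """Upcast the trainable (LoRA) parameters of the given models to `dtype`."""
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            # only upcast trainable parameters (LoRA) into the requested dtype
            if param.requires_grad:
                param.data = param.to(dtype)
```

Each training script would then call it once for the UNet (and the text encoders, when they are trained) instead of repeating the nested loop.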


unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

if args.train_text_encoder:
@@ -1500,6 +1528,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
else unet_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
22 changes: 14 additions & 8 deletions src/diffusers/loaders/lora.py
@@ -373,7 +373,7 @@ def _optionally_disable_offloading(cls, _pipeline):

@classmethod
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _config=None, _pipeline=None
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -446,8 +446,11 @@ def load_lora_into_unet(
if "lora_B" in key:
rank[key] = val.shape[1]

lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
lora_config = LoraConfig(**lora_config_kwargs)
if _config is None:
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
lora_config = LoraConfig(**lora_config_kwargs)
else:
lora_config = _config

# adapter_name
if adapter_name is None:
@@ -490,6 +493,7 @@ def load_lora_into_text_encoder(
lora_scale=1.0,
low_cpu_mem_usage=None,
adapter_name=None,
_config=None,
_pipeline=None,
):
"""
@@ -578,11 +582,13 @@ def load_lora_into_text_encoder(
if USE_PEFT_BACKEND:
from peft import LoraConfig

lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
)

lora_config = LoraConfig(**lora_config_kwargs)
if _config is None:
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
)
lora_config = LoraConfig(**lora_config_kwargs)
else:
lora_config = _config

# adapter_name
if adapter_name is None:
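
For illustration, a hedged end-to-end sketch of the new `_config` escape hatch (the checkpoint path is a placeholder, `_config` is a private argument introduced here, and the `LoraConfig` must match the one the LoRA was trained with):

```python
import torch
from peft import LoraConfig

from diffusers import StableDiffusionXLPipeline
from diffusers.loaders import LoraLoaderMixin

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Config used at training time; passing it avoids re-inferring rank/alpha from the state dict.
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

state_dict, network_alphas = LoraLoaderMixin.lora_state_dict("path/to/lora_checkpoint")
# With _config provided, the get_peft_kwargs inference path is skipped entirely.
LoraLoaderMixin.load_lora_into_unet(
    state_dict, network_alphas=network_alphas, unet=pipe.unet, _config=unet_lora_config
)
```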