 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from torchvision import transforms
 from torchvision.transforms.functional import crop
@@ -51,8 +51,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
@@ -629,14 +634,6 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(models, dtype=torch.float32)
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -693,18 +690,34 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, _ = LoraLoaderMixin.lora_state_dict(input_dir)
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
 
-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
 
-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            cast_training_params(models, dtype=torch.float32)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -725,6 +738,13 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        cast_training_params(models, dtype=torch.float32)
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
        try:
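
For reference, the UNet branch of the revised load hook can be exercised on its own. Below is a minimal sketch (not part of the diff), assuming a UNet that already has its LoRA adapter attached and a checkpoint directory written by the matching save hook; the helper name reload_unet_lora is made up for illustration, and it mirrors the key-prefix stripping, diffusers-to-PEFT key conversion, and fp32 upcast shown above.

import torch
from diffusers.loaders import LoraLoaderMixin
from diffusers.training_utils import cast_training_params
from diffusers.utils import convert_unet_state_dict_to_peft
from peft import set_peft_model_state_dict


def reload_unet_lora(unet, checkpoint_dir):
    # Read the serialized LoRA weights (diffusers key format) from the checkpoint folder.
    lora_state_dict, _ = LoraLoaderMixin.lora_state_dict(checkpoint_dir)
    # Keep only the UNet entries and strip the "unet." prefix.
    unet_sd = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
    # Convert diffusers-style LoRA keys to PEFT-style keys and load them into the existing adapter.
    unet_sd = convert_unet_state_dict_to_peft(unet_sd)
    incompatible = set_peft_model_state_dict(unet, unet_sd, adapter_name="default")
    if incompatible is not None and getattr(incompatible, "unexpected_keys", None):
        print(f"unexpected keys when loading LoRA weights: {incompatible.unexpected_keys}")
    # Under mixed precision the base weights are fp16/bf16, so upcast the trainable LoRA
    # parameters back to fp32 after loading, as the hook does for unet_.
    cast_training_params([unet], dtype=torch.float32)
    return unet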