@@ -36,7 +36,7 @@
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig, get_peft_model_state_dict
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
@@ -52,7 +52,12 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 
 
@@ -858,11 +863,6 @@ def main(args):
     )
     unet.add_adapter(lora_config)
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(unet, dtype=torch.float32)
-
     # Also move the alpha and sigma noise schedules to accelerator.device.
     alpha_schedule = alpha_schedule.to(accelerator.device)
     sigma_schedule = sigma_schedule.to(accelerator.device)
@@ -887,13 +887,31 @@ def save_model_hook(models, weights, output_dir):
     def load_model_hook(models, input_dir):
         # load the LoRA into the model
         unet_ = accelerator.unwrap_model(unet)
-        lora_state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(input_dir)
-        StableDiffusionXLPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
+        unet_state_dict = {
+            f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+        }
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
 
         for _ in range(len(models)):
             # pop models so that they are not loaded again
             models.pop()
 
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            cast_training_params(unet_, dtype=torch.float32)
+
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1092,6 +1110,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
+
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
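
The `cast_training_params` block that reappears here applies the usual fp16 mixed-precision recipe for LoRA training: frozen base weights stay in fp16 while only the trainable adapter parameters are upcast to fp32. A minimal, self-contained sketch of that recipe follows; the model id and LoRA hyperparameters are illustrative placeholders rather than values taken from this script.

import torch
from diffusers import UNet2DConditionModel
from diffusers.training_utils import cast_training_params
from peft import LoraConfig

# Frozen base weights are kept in fp16 to save memory.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)
unet.requires_grad_(False)

# Attach a LoRA adapter; its parameters are the only trainable ones.
unet.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))

# Upcast only the trainable (LoRA) parameters to fp32; the frozen fp16 base weights are untouched.
cast_training_params(unet, dtype=torch.float32)

trainable_params = [p for p in unet.parameters() if p.requires_grad]
assert all(p.dtype == torch.float32 for p in trainable_params)
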