diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 122af23865b8..8f92e3b44295 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -56,6 +56,7 @@ from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -1007,6 +1008,11 @@ def main(args): if param.requires_grad: param.data = param.to(torch.float32) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: @@ -1017,13 +1023,13 @@ def save_model_hook(models, weights, output_dir): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) @@ -1048,11 +1054,11 @@ def load_model_hook(models, input_dir): while len(models) > 0: model = models.pop() - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1621,16 +1627,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) + unet = unwrap_model(unet) unet = unet.to(torch.float32) unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) if args.train_text_encoder: - text_encoder_one = accelerator.unwrap_model(text_encoder_one) + text_encoder_one = unwrap_model(text_encoder_one) text_encoder_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(text_encoder_one.to(torch.float32)) ) - text_encoder_two = accelerator.unwrap_model(text_encoder_two) + text_encoder_two = unwrap_model(text_encoder_two) text_encoder_2_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(text_encoder_two.to(torch.float32)) )