[Training] fix training resuming problem when using FP16 (SDXL LoRA DreamBooth) #6514
```diff
@@ -56,6 +56,7 @@
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.peft_utils import delete_adapter_layers


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
```
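The helper imported here is used later in `load_model_hook`. As a self-contained illustration of what it does (a toy `torch.nn.Module` stands in for the SDXL UNet, and `peft`'s `inject_adapter_in_model` stands in for diffusers' adapter loading, so the module and adapter names below are assumptions): injecting LoRA a second time leaves the module with an extra adapter such as `"default_1"`, which `delete_adapter_layers` removes again.

```python
import torch
from peft import LoraConfig, inject_adapter_in_model

from diffusers.utils.peft_utils import delete_adapter_layers


class ToyModel(torch.nn.Module):
    # Toy stand-in for the UNet: one Linear layer that LoRA can target.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)


model = ToyModel()
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["proj"])

inject_adapter_in_model(lora_config, model, adapter_name="default")    # the adapter being trained
inject_adapter_in_model(lora_config, model, adapter_name="default_1")  # extra adapter, e.g. created while loading

delete_adapter_layers(model, "default_1")  # drop the extra adapter; "default" stays intact
```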
```diff
@@ -996,17 +997,6 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)

-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
-
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
```
```diff
@@ -1058,17 +1048,44 @@ def load_model_hook(models, input_dir):
                raise ValueError(f"unexpected save model: {model.__class__}")

        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
-
-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
-
-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        # We pass the config here to ensure parity between `unet_lora_config` and
+        # the `LoraConfig` that's inferred in `load_lora_into_unet`.
+        LoraLoaderMixin.load_lora_into_unet(
+            lora_state_dict, network_alphas=network_alphas, unet=unet_, _config=unet_lora_config
+        )
+        # Remove the newly created adapter as we don't need it.
+        delete_adapter_layers(unet_, "default_1")
+
+        if args.train_text_encoder:
+            text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
+            LoraLoaderMixin.load_lora_into_text_encoder(
+                text_encoder_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=text_encoder_one_,
+                _config=text_lora_config,
+            )
+            delete_adapter_layers(text_encoder_one_, "default_1")
+
+            text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
+            LoraLoaderMixin.load_lora_into_text_encoder(
+                text_encoder_2_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=text_encoder_two_,
+                _config=text_lora_config,
+            )
+            delete_adapter_layers(text_encoder_two_, "default_1")
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`.
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
```
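For context on when `load_model_hook` runs: `accelerate` invokes the registered pre-hooks from `Accelerator.save_state` and `Accelerator.load_state`, and `--resume_from_checkpoint` goes through `load_state`, which is why the re-upcast to fp32 has to happen inside the hook. Below is a minimal, self-contained sketch of that mechanism; the hook bodies and the checkpoint path are placeholders, not the script's real hooks.

```python
from accelerate import Accelerator

accelerator = Accelerator()


def save_model_hook(models, weights, output_dir):
    # In the training script this serializes the LoRA layers in diffusers format.
    print(f"saving {len(models)} model(s) to {output_dir}")


def load_model_hook(models, input_dir):
    # In the training script this restores the LoRA layers and re-upcasts trainable params to fp32.
    print(f"loading {len(models)} model(s) from {input_dir}")


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

accelerator.save_state("checkpoint-demo")  # triggers save_model_hook
accelerator.load_state("checkpoint-demo")  # triggers load_model_hook (the resume path)
```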
```diff
@@ -1083,6 +1100,17 @@ def load_model_hook(models, input_dir):
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

    if args.train_text_encoder:
```
Comment on lines +1108 to +1117 (Member, Author): We do it just before assigning the parameters to the optimizer to avoid any consequences.

Follow-up comment (Member, Author): In a follow-up PR, I can wrap this utility into a function and move it to …
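A minimal sketch of the refactor proposed in the comment above, assuming the upcast loop is factored out verbatim; the name `cast_training_params` and its `dtype` argument are assumptions, not an API that exists in this PR.

```python
import torch


def cast_training_params(models, dtype=torch.float32):
    """Hypothetical helper: upcast only the trainable (LoRA) parameters to `dtype`."""
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            # only upcast trainable parameters (LoRA) into the requested dtype
            if param.requires_grad:
                param.data = param.to(dtype)
```

With such a helper, the two identical blocks (in `load_model_hook` and just before the optimizer is created) would each collapse to a single `cast_training_params(models)` call guarded by the `fp16` check.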
```diff
@@ -1500,6 +1528,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                        else unet_lora_parameters
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
```
See the related review thread: https://github.com/huggingface/diffusers/pull/6514/files#r1447020705