From b97b4c1444732414fcf4f5bbac5e8515146ef7e4 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Date: Mon, 18 Nov 2024 23:13:36 -0400
Subject: [PATCH] [advanced flux training] bug fix + reduce memory cost as in
 #9829 (#9838)

* memory improvement as done here: https://github.com/huggingface/diffusers/pull/9829

* fix bug

* fix bug

* style

---------

Co-authored-by: Sayak Paul
---
 .../train_dreambooth_lora_flux_advanced.py         | 14 ++++++++++++--
 examples/dreambooth/train_dreambooth_lora_flux.py  |  6 +++++-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index bf726e65c94b..112884609901 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -2154,6 +2154,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
+                    elems_to_repeat = 1
                     if freeze_text_encoder:
                         prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
                             prompts, text_encoders, tokenizers
@@ -2168,17 +2169,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                             max_sequence_length=args.max_sequence_length,
                             add_special_tokens=add_special_tokens_t5,
                         )
+                else:
+                    elems_to_repeat = len(prompts)
 
                 if not freeze_text_encoder:
                     prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=[None, None],
-                        text_input_ids_list=[tokens_one, tokens_two],
+                        text_input_ids_list=[
+                            tokens_one.repeat(elems_to_repeat, 1),
+                            tokens_two.repeat(elems_to_repeat, 1),
+                        ],
                         max_sequence_length=args.max_sequence_length,
                         device=accelerator.device,
                         prompt=prompts,
                     )
-
                 # Convert images to latent space
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
@@ -2371,6 +2376,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         epoch=epoch,
                         torch_dtype=weight_dtype,
                     )
+                    images = None
+                    del pipeline
+
                     if freeze_text_encoder:
                         del text_encoder_one, text_encoder_two
                         free_memory()
@@ -2448,6 +2456,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 commit_message="End of training",
                 ignore_patterns=["step_*", "epoch_*"],
             )
+        images = None
+        del pipeline
 
     accelerator.end_training()
 
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 2c1126109a36..f73269a48967 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1648,11 +1648,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                             prompt=prompts,
                         )
                 else:
+                    elems_to_repeat = len(prompts)
                     if args.train_text_encoder:
                         prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
                             text_encoders=[text_encoder_one, text_encoder_two],
                             tokenizers=[None, None],
-                            text_input_ids_list=[tokens_one, tokens_two],
+                            text_input_ids_list=[
+                                tokens_one.repeat(elems_to_repeat, 1),
+                                tokens_two.repeat(elems_to_repeat, 1),
+                            ],
                             max_sequence_length=args.max_sequence_length,
                             device=accelerator.device,
                             prompt=args.instance_prompt,
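
Notes on the changes above (illustrative, not part of the patch):

The bug fix is the .repeat(elems_to_repeat, 1) tiling. When no per-image custom prompts are provided, the shared instance prompt is tokenized once, so tokens_one and tokens_two have shape [1, seq_len] while encode_prompt needs one row per image in the batch; elems_to_repeat = len(prompts) tiles the token ids to the batch size. When custom prompts are provided, the tokens are already batched and elems_to_repeat = 1 makes the repeat a no-op. A minimal shape-only sketch, using made-up token ids and a hypothetical batch size rather than the scripts' real tokenizer output:

    import torch

    # Token ids for one shared instance prompt: shape [1, seq_len] (illustrative values).
    tokens_one = torch.tensor([[101, 2023, 2003, 1037, 3899, 102]])

    # Hypothetical number of images in the batch that share this prompt.
    elems_to_repeat = 4

    # Tile along the batch dimension so the text encoder sees one row per image.
    batched_tokens = tokens_one.repeat(elems_to_repeat, 1)
    print(batched_tokens.shape)  # torch.Size([4, 6])

The memory-cost reduction follows https://github.com/huggingface/diffusers/pull/9829: after validation and after the final upload, the generated images and the inference pipeline are released (images = None, del pipeline) so their memory can be reclaimed before training continues.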