diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index bf726e65c94b..112884609901 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -2154,6 +2154,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: + elems_to_repeat = 1 if freeze_text_encoder: prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( prompts, text_encoders, tokenizers @@ -2168,17 +2169,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): max_sequence_length=args.max_sequence_length, add_special_tokens=add_special_tokens_t5, ) + else: + elems_to_repeat = len(prompts) if not freeze_text_encoder: prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], tokenizers=[None, None], - text_input_ids_list=[tokens_one, tokens_two], + text_input_ids_list=[ + tokens_one.repeat(elems_to_repeat, 1), + tokens_two.repeat(elems_to_repeat, 1), + ], max_sequence_length=args.max_sequence_length, device=accelerator.device, prompt=prompts, ) - # Convert images to latent space if args.cache_latents: model_input = latents_cache[step].sample() @@ -2371,6 +2376,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): epoch=epoch, torch_dtype=weight_dtype, ) + images = None + del pipeline + if freeze_text_encoder: del text_encoder_one, text_encoder_two free_memory() @@ -2448,6 +2456,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): commit_message="End of training", ignore_patterns=["step_*", "epoch_*"], ) + images = None + del pipeline accelerator.end_training() diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 2c1126109a36..f73269a48967 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1648,11 +1648,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): prompt=prompts, ) else: + elems_to_repeat = len(prompts) if args.train_text_encoder: prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], tokenizers=[None, None], - text_input_ids_list=[tokens_one, tokens_two], + text_input_ids_list=[ + tokens_one.repeat(elems_to_repeat, 1), + tokens_two.repeat(elems_to_repeat, 1), + ], max_sequence_length=args.max_sequence_length, device=accelerator.device, prompt=args.instance_prompt,