From 921199bab49d8238f50367049f8ccbe1cd32bd58 Mon Sep 17 00:00:00 2001 From: Daniel Socek Date: Sat, 14 Sep 2024 11:12:40 +0800 Subject: [PATCH] Add flux fine-tuning script for Gaudi Signed-off-by: Daniel Socek --- README.md | 1 + docs/source/index.mdx | 1 + examples/stable-diffusion/training/README.md | 337 +++-- .../training/requirements.txt | 1 + .../training/textual_inversion.py | 0 .../training/train_controlnet.py | 0 .../training/train_dreambooth.py | 0 .../training/train_dreambooth_lora_flux.py | 1142 +++++++++++++++++ .../training/train_dreambooth_lora_sdxl.py | 0 .../training/train_text_to_image_sdxl.py | 0 10 files changed, 1401 insertions(+), 81 deletions(-) mode change 100644 => 100755 examples/stable-diffusion/training/textual_inversion.py mode change 100644 => 100755 examples/stable-diffusion/training/train_controlnet.py mode change 100644 => 100755 examples/stable-diffusion/training/train_dreambooth.py create mode 100755 examples/stable-diffusion/training/train_dreambooth_lora_flux.py mode change 100644 => 100755 examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py mode change 100644 => 100755 examples/stable-diffusion/training/train_text_to_image_sdxl.py diff --git a/README.md b/README.md index c0acbccfae..ac99b6211f 100644 --- a/README.md +++ b/README.md @@ -251,6 +251,7 @@ The following model architectures, tasks and device distributions have been vali | Stable Diffusion XL |
  • [fine-tuning](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#fine-tuning-for-stable-diffusion-xl)
  • |
  • Single card
  • |
  • [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)
  • | | Stable Diffusion Depth2img | |
  • Single card
  • |
  • [depth-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)
  • | | LDM3D | |
  • Single card
  • |
  • [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)
  • | +| FLUX.1 |
  • [fine-tuning](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#dreambooth-lora-fine-tuning-with-flux1-dev)
  • |
  • Single card
  • |
  • [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)
  • | | Text to Video | |
  • Single card
  • |
  • [text-to-video generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-to-video)
  • | diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 546f52d30d..86c4ce91b9 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -114,6 +114,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be | Stable Diffusion XL |
  • [fine-tuning](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#fine-tuning-for-stable-diffusion-xl)
  • |
  • Single card
  • |
  • [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)
  • | | Stable Diffusion Depth2img | |
  • Single card
  • |
  • [depth-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)
  • | | LDM3D | |
  • Single card
  • |
  • [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)
  • | +| FLUX.1 |
  • [fine-tuning](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#dreambooth-lora-fine-tuning-with-flux1-dev)
  • |
  • Single card
  • |
  • [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)
  • | | Text to Video | |
  • Single card
  • |
  • [text-to-video generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-to-video)
  • | - PyTorch Image Models/TIMM: diff --git a/examples/stable-diffusion/training/README.md b/examples/stable-diffusion/training/README.md index 3d35a1623e..a10c194066 100644 --- a/examples/stable-diffusion/training/README.md +++ b/examples/stable-diffusion/training/README.md @@ -25,17 +25,28 @@ This directory contains scripts that showcase how to perform training/fine-tunin The `textual_inversion.py` script shows how to implement the training procedure on Habana Gaudi. -### Cat toy example +### Cat Toy Example -Let's get our dataset. For this example, we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example . +In the examples below, we will use a set of cat images from the following dataset: +[https://huggingface.co/datasets/diffusers/cat_toy_example](https://huggingface.co/datasets/diffusers/cat_toy_example) -Let's first download it locally: +Let's first download this dataset locally: ```python from huggingface_hub import snapshot_download +from pathlib import Path +import shutil -local_dir = "./cat" -snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes") +local_dir = './cat' +snapshot_download( + 'diffusers/cat_toy_example', + local_dir=local_dir, + repo_type='dataset', + ignore_patterns='.gitattributes', +) +cache_dir = Path(local_dir, '.cache') +if cache_dir.is_dir(): + shutil.rmtree(cache_dir) ``` This will be our training data. @@ -150,10 +161,20 @@ image = pipe(prompt=prompt, prompt_2=prompt_2, num_inference_steps=50, guidance_ image.save(f"cat-backpack_p1and2.png") ``` +> [!NOTE] +> Change `--resolution` to 768 if you are using [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model. + +> [!NOTE] +> As described in [the official paper](https://arxiv.org/abs/2208.01618), only one embedding vector is used for the placeholder token, +> e.g. `""`. However, one can also add multiple embedding vectors for the placeholder token to increase the number of fine-tuneable +> parameters. This can help the model to learn more complex details. To use multiple embedding vectors, you can define `--num_vectors` to +> a number larger than one, e.g.: `--num_vectors 5`. The saved textual inversion vectors will then be larger in size compared to the default case. + ## ControlNet Training -ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models ](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala. It is a type of model for controlling StableDiffusion by conditioning the model with an additional input image. +ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models ](https://huggingface.co/papers/2302.05543) +by Lvmin Zhang and Maneesh Agrawala. It is a type of model for controlling StableDiffusion by conditioning the model with an additional input image. This example is adapted from [controlnet example in the diffusers repository](https://github.com/huggingface/diffusers/tree/main/examples/controlnet#training). First, download the conditioning images as shown below: @@ -203,7 +224,8 @@ python ../../gaudi_spawn.py --use_mpi --world_size 8 train_controlnet.py \ ### Inference -Once you have trained a model as described right above, inference can be done simply using the `GaudiStableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt. +Once you have trained a model as described right above, inference can be done simply using the `GaudiStableDiffusionPipeline`. +Make sure to include the `placeholder_token` in your prompt. ```python from diffusers import ControlNetModel, UniPCMultistepScheduler @@ -241,7 +263,7 @@ image.save("./output.png") ## Fine-Tuning for Stable Diffusion XL -The `train_text_to_image_sdxl.py` script shows how to implement the fine-tuning of Stable Diffusion models on Habana Gaudi. +The `train_text_to_image_sdxl.py` script shows how to implement the fine-tuning of Stable Diffusion XL models on Gaudi. ### Requirements @@ -252,6 +274,7 @@ pip install -r requirements.txt ### Single-card Training +To train Stable Diffusion XL on a single Gaudi card, use: ```bash python train_text_to_image_sdxl.py \ --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ @@ -283,7 +306,9 @@ python train_text_to_image_sdxl.py \ ``` -### Multi-card Training +### Multi-Card Training + +To train Stable Diffusion XL on a multi-card Gaudi system, use: ```bash PT_HPU_RECIPE_CACHE_CONFIG=/tmp/stdxl_recipe_cache,True,1024 \ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_text_to_image_sdxl.py \ @@ -315,7 +340,9 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_text_to_image_sdxl.py --adjust_throughput ``` -### Single-card Training on Gaudi1 +### Single-Card Training on Gaudi1 + +To train Stable Diffusion XL on a single Gaudi1 card, use: ```bash python train_text_to_image_sdxl.py \ --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ @@ -342,53 +369,52 @@ python train_text_to_image_sdxl.py \ ``` > [!NOTE] -> There is a known issue that in the first 2 steps, graph compilation takes longer than 10 seconds. This will be fixed in a future release. +> There is a known issue that in the first 2 steps, graph compilation takes longer than 10 seconds. +> This will be fixed in a future release. > [!NOTE] > `--mediapipe` only works on Gaudi2. ## DreamBooth -DreamBooth is a method to personalize text-to-image models like Stable Diffusion given just a few (3~5) images of a subject. The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for Stable Diffusion. -### Dog toy example +DreamBooth is a technique for personalizing text-to-image models like Stable Diffusion using only a few images (typically 3-5) +of a specific subject. The `train_dreambooth.py` script demonstrates how to implement this training process and adapt it for +Stable Diffusion. -Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. +### Dog Toy Example -Let's first download it locally: +For DreamBooth examples we will use a set of dog images from the following dataset: +[https://huggingface.co/datasets/diffusers/dog-example](https://huggingface.co/datasets/diffusers/dog-example). + +Let's first download this dataset locally: ```python -import os from huggingface_hub import snapshot_download +from pathlib import Path +import shutil -local_dir = "./dog" +local_dir = './dog' snapshot_download( - "diffusers/dog-example", - local_dir=local_dir, repo_type="dataset", - ignore_patterns=".gitattributes", + 'diffusers/dog-example', + local_dir=local_dir, + repo_type='dataset', + ignore_patterns='.gitattributes', ) - -# check if .cache folder exists and remove it. -cache_folder = os.path.join(local_dir, ".cache") -if os.path.exists(cache_folder): - import shutil - shutil.rmtree(cache_folder) +cache_dir = Path(local_dir, '.cache') +if cache_dir.is_dir(): + shutil.rmtree(cache_dir) ``` -### Full model finetune -And launch the multi-card training using: -```bash - -export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export INSTANCE_DIR="dog" -export CLASS_DIR="path-to-class-images" -export OUTPUT_DIR="out" +### Full Model Fine-Tuning +To launch the multi-card Stable Diffusion training, use: +```bash python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --instance_data_dir=$INSTANCE_DIR \ - --output_dir=$OUTPUT_DIR \ - --class_data_dir=$CLASS_DIR \ + --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ + --instance_data_dir="dog" \ + --output_dir="dog_sd" \ + --class_data_dir="path-to-class-images" \ --with_prior_preservation --prior_loss_weight=1.0 \ --instance_prompt="a photo of sks dog" \ --class_prompt="a photo of dog" \ @@ -405,31 +431,29 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \ --use_hpu_graphs_for_inference \ --gaudi_config_name Habana/stable-diffusion \ full - ``` -Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data. -According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time. -### PEFT model finetune -We provide example for dreambooth to use lora/lokr/loha/oft to finetune unet or text encoder. +Prior preservation is used to prevent overfitting and language drift. For more details, refer to the original paper. +In this process, we first generate images using the model with a class prompt and then use those images during training +alongside our data. According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior +preservation, with 200-300 images being effective in most cases. The `num_class_images` flag controls how many images +are generated with the class prompt. You can place existing images in the `class_data_dir`, and the training script will +generate any additional images needed to meet the `num_class_images` requirement during training. -**___Note: When using peft method we can use a much higher learning rate compared to vanilla dreambooth. Here we -use *1e-4* instead of the usual *5e-6*.___** +### PEFT Model Fine-Tuning -Launch the multi-card training using: -```bash - -export MODEL_NAME="CompVis/stable-diffusion-v1-4" -export INSTANCE_DIR="dog" -export CLASS_DIR="path-to-class-images" -export OUTPUT_DIR="out" +We provide DreamBooth examples demonstrating how to use LoRA, LoKR, LoHA, and OFT adapters to fine-tune the +UNet or text encoder. +To run the multi-card training, use: +```bash python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --instance_data_dir=$INSTANCE_DIR \ - --output_dir=$OUTPUT_DIR \ - --class_data_dir=$CLASS_DIR \ - --with_prior_preservation --prior_loss_weight=1.0 \ + --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ + --instance_data_dir="dog" \ + --output_dir="dog_sd" \ + --class_data_dir="path-to-class-images" \ + --with_prior_preservation \ + --prior_loss_weight=1.0 \ --instance_prompt="a photo of sks dog" \ --class_prompt="a photo of dog" \ --resolution=512 \ @@ -445,50 +469,73 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \ --use_hpu_graphs_for_inference \ --gaudi_config_name Habana/stable-diffusion \ lora --unet_r 8 --unet_alpha 8 - ``` -Similar command could be applied to loha, lokr, oft. -You could check each adapter specific args by "--help", like you could use following command to check oft specific args. +> [!NOTE] +> When using PEFT method we can use a much higher learning rate compared to vanilla dreambooth. +> Here we use `1e-4` instead of the usual `5e-6` + +Similar command could be applied with `loha`, `lokr`, or `oft` adapters. + +You could check each adapter's specific arguments with `--help`, for example: ```bash python3 train_dreambooth.py oft --help - ``` +> [!NOTE] +> Currently, the `oft` adapter is not supported in HPU graph mode, as it triggers `torch.inverse`, +> causing a CPU fallback that is incompatible with HPU graph capturing. -**___Note: oft could not work with hpu graphs mode. since "torch.inverse" need to fallback to cpu. -there's error like "cpu fallback is not supported during hpu graph capturing"___** - - -You could use text_to_image_generation.py to generate picture using the peft adapter like +After training completes, you can use `text_to_image_generation.py` sample for inference as follows: ```bash python ../text_to_image_generation.py \ --model_name_or_path CompVis/stable-diffusion-v1-4 \ + --unet_adapter_name_or_path dog_sd/unet \ --prompts "a sks dog" \ --num_images_per_prompt 5 \ --batch_size 1 \ --image_save_dir /tmp/stable_diffusion_images \ --use_habana \ --use_hpu_graphs \ - --unet_adapter_name_or_path out/unet \ --gaudi_config Habana/stable-diffusion \ --bf16 ``` -### DreamBooth training example for Stable Diffusion XL -You could use the dog images as example as well. -You can launch training using: +### DreamBooth LoRA Fine-Tuning with Stable Diffusion XL + +We can use the same `dog` dataset for the following examples. + +To launch Stable Diffusion XL LoRA training on a multi-card Gaudi system, use:" ```bash -export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0" -export INSTANCE_DIR="dog" -export OUTPUT_DIR="lora-trained-xl" -export VAE_PATH="madebyollin/sdxl-vae-fp16-fix" +python train_dreambooth_lora_sdxl.py \ + --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \ + --instance_data_dir="dog" \ + --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \ + --output_dir="lora-trained-xl" \ + --mixed_precision="bf16" \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --learning_rate=1e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed=0 \ + --use_hpu_graphs_for_inference \ + --use_hpu_graphs_for_training \ + --gaudi_config_name Habana/stable-diffusion +``` +To launch Stable Diffusion XL LoRA training on a multi-card Gaudi system, use:" +```bash python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth_lora_sdxl.py \ - --pretrained_model_name_or_path=$MODEL_NAME \ - --instance_data_dir=$INSTANCE_DIR \ - --pretrained_vae_model_name_or_path=$VAE_PATH \ - --output_dir=$OUTPUT_DIR \ + --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \ + --instance_data_dir="dog" \ + --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \ + --output_dir="lora-trained-xl" \ --mixed_precision="bf16" \ --instance_prompt="a photo of sks dog" \ --resolution=1024 \ @@ -504,21 +551,149 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth_lora_sdxl. --use_hpu_graphs_for_inference \ --use_hpu_graphs_for_training \ --gaudi_config_name Habana/stable-diffusion - ``` +> [!NOTE] +> To use DeepSpeed instead of MPI, replace `--use_mpi` with `--deepspeed` in the previous example + +After training completes, you can run inference with a simple python script like this: +```python +import torch +from optimum.habana import GaudiConfig +from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline + +pipe = GaudiStableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.bfloat16, + use_hpu_graphs=True, + use_habana=True, + gaudi_config="Habana/stable-diffusion", +) +pipe.load_lora_weights("lora-trained-xl") -You could use text_to_image_generation.py to generate picture using the peft adapter like +prompt = "A photo of sks dog in a bucket" +image = pipe( + prompt, + height=1024, + width=1024, + guidance_scale=3.5, + num_inference_steps=30, + max_sequence_length=512, +).images[0] +image.save("sdxl-lora.png") +``` +Alternatively, you could directly use `text_to_image_generation.py` sample for inference as follows: ```bash python ../text_to_image_generation.py \ --model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ + --lora_id lora-trained-xl \ --prompts "A picture of a sks dog in a bucket" \ --num_images_per_prompt 5 \ --batch_size 1 \ --image_save_dir /tmp/stable_diffusion_xl_images \ --use_habana \ --use_hpu_graphs \ - --lora_id lora-trained-xl \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` + +### DreamBooth LoRA Fine-Tuning with FLUX.1-dev + +We can use the same `dog` dataset for the following examples. + +To launch FLUX.1-dev LoRA training on a single Gaudi card, use:" +```bash +python train_dreambooth_lora_flux.py \ + --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \ + --dataset="dog" \ + --prompt="a photo of sks dog" \ + --output_dir="dog_lora_flux" \ + --mixed_precision="bf16" \ + --weighting_scheme="none" \ + --resolution=1024 \ + --train_batch_size=1 \ + --learning_rate=1e-4 \ + --guidance_scale=1 \ + --report_to="tensorboard" \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --cache_latents \ + --rank=4 \ + --max_train_steps=500 \ + --seed="0" \ + --use_hpu_graphs_for_inference \ + --use_hpu_graphs_for_training \ + --gaudi_config_name="Habana/stable-diffusion" +``` + +To launch FLUX.1-dev LoRA training on a multi-card Gaudi system, use:" +```bash +python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth_lora_flux.py \ + --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \ + --dataset="dog" \ + --prompt="a photo of sks dog" \ + --output_dir="dog_lora_flux" \ + --mixed_precision="bf16" \ + --weighting_scheme="none" \ + --resolution=1024 \ + --train_batch_size=1 \ + --learning_rate=1e-4 \ + --guidance_scale=1 \ + --report_to="tensorboard" \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --cache_latents \ + --rank=4 \ + --max_train_steps=500 \ + --seed="0" \ + --use_hpu_graphs_for_inference \ + --use_hpu_graphs_for_training \ + --gaudi_config_name="Habana/stable-diffusion" +``` +> [!NOTE] +> To use DeepSpeed instead of MPI, replace `--use_mpi` with `--use_deepspeed` in the previous example + +After training completes, you can run inference on Gaudi system with a simple python script like this: +```python +import torch +from optimum.habana import GaudiConfig +from optimum.habana.diffusers import GaudiFluxPipeline + +pipe = GaudiFluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, + use_hpu_graphs=True, + use_habana=True, + gaudi_config="Habana/stable-diffusion", +) +pipe.load_lora_weights("dog_lora_flux") + +prompt = "A photo of sks dog in a bucket" +image = pipe( + prompt, + height=1024, + width=1024, + guidance_scale=3.5, + num_inference_steps=30, +).images[0] +image.save("flux-dev.png") +``` + +Alternatively, you could directly use `text_to_image_generation.py` sample for inference as follows: +```bash +python ../text_to_image_generation.py \ + --model_name_or_path "black-forest-labs/FLUX.1-dev" \ + --lora_id dog_lora_flux \ + --prompts "A picture of a sks dog in a bucket" \ + --num_images_per_prompt 5 \ + --batch_size 1 \ + --image_save_dir /tmp/flux_images \ + --use_habana \ + --use_hpu_graphs \ --gaudi_config Habana/stable-diffusion \ --bf16 ``` diff --git a/examples/stable-diffusion/training/requirements.txt b/examples/stable-diffusion/training/requirements.txt index 7fb1748675..bf92040ae8 100644 --- a/examples/stable-diffusion/training/requirements.txt +++ b/examples/stable-diffusion/training/requirements.txt @@ -1,2 +1,3 @@ imagesize peft == 0.10.0 +sentencepiece diff --git a/examples/stable-diffusion/training/textual_inversion.py b/examples/stable-diffusion/training/textual_inversion.py old mode 100644 new mode 100755 diff --git a/examples/stable-diffusion/training/train_controlnet.py b/examples/stable-diffusion/training/train_controlnet.py old mode 100644 new mode 100755 diff --git a/examples/stable-diffusion/training/train_dreambooth.py b/examples/stable-diffusion/training/train_dreambooth.py old mode 100644 new mode 100755 diff --git a/examples/stable-diffusion/training/train_dreambooth_lora_flux.py b/examples/stable-diffusion/training/train_dreambooth_lora_flux.py new file mode 100755 index 0000000000..68b5320d19 --- /dev/null +++ b/examples/stable-diffusion/training/train_dreambooth_lora_flux.py @@ -0,0 +1,1142 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +""" +Training script for FLUX LORA DreamBooth to Text-to-Image Diffusion Models +Adapted from the following sources: +https://github.com/huggingface/diffusers/blob/v0.31.0/examples/dreambooth/train_dreambooth_lora_flux.py +https://github.com/huggingface/diffusers/tree/main/examples/research_projects/flux_lora_quantization +""" + +import argparse +import copy +import logging +import math +import os +import random +import shutil +import warnings +from pathlib import Path + +import diffusers +import numpy as np +import torch +import torch.utils.checkpoint +import transformers +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration +from datasets import load_dataset +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import T5EncoderModel + +from optimum.habana import GaudiConfig +from optimum.habana.accelerate import GaudiAccelerator +from optimum.habana.accelerate.utils.dataclasses import GaudiDistributedType +from optimum.habana.utils import set_seed + + +if is_wandb_available(): + pass + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.31.0") + +logger = get_logger(__name__) +warnings.simplefilter(action="ignore", category=FutureWarning) +warnings.simplefilter(action="ignore", category=UserWarning) + + +def save_model_card( + repo_id: str, + base_model: str = None, + instance_prompt=None, + repo_folder=None, +): + widget_dict = [] + + model_description = f""" +# Flux DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the +[Gaudi Flux diffusers trainer](https://github.com/huggingface/optimum-habana/blob/main/examples/stable-diffusion/training/README.md). + +Was LoRA for the text encoder enabled? False. + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +For more details, including weighting, merging and fusing LoRAs, check the +[documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux", + "flux-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset", + type=str, + default=None, + help="Path to dataset used for training.", + ) + parser.add_argument( + "--prompt", + type=str, + default=None, + help="Prompt to use with training dataset (if dataset itself does not have text prompt feature).", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--max_sequence_length", + type=int, + default=77, + help="Used for reading the embeddings. Needs to be the same as used during `compute_embeddings.py`.", + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + choices=["AdamW"], + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer.", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "bf16"], + help=("Choose whether to use bf16 (bfloat16) mixed precision or not."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--gaudi_config_name", + type=str, + default=None, + help="Local path to the Gaudi configuration file or its name on the Hugging Face Hub.", + ) + parser.add_argument( + "--use_hpu_graphs_for_training", + action="store_true", + help="Use HPU graphs for training on HPU.", + ) + parser.add_argument( + "--use_hpu_graphs_for_inference", + action="store_true", + help="Use HPU graphs for inference on HPU.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +class DreamBoothDataset(Dataset): + def __init__( + self, + pretrained_model_name_or_path, + instance_data_root, + instance_prompt, + size=1024, + max_sequence_length=77, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.max_sequence_length = max_sequence_length + + self.instance_data_root = instance_data_root + self.instance_prompt = instance_prompt + self.pretrained_model_name_or_path = pretrained_model_name_or_path + + # Load dataset and create text embeddings + dataset = load_dataset(self.instance_data_root, split="train") + image_hashes, image_prompts = [], {} + + for sample in dataset: + image_hash = self.generate_image_hash(sample["image"]) + image_hashes.append(image_hash) + text = sample["text"] if "text" in dataset.features else self.instance_prompt + image_prompts[image_hash] = text + + all_prompts = list(image_prompts.values()) + + pipeline = self.load_text_encoder_pipeline() + all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = self.compute_embeddings( + pipeline, all_prompts, args.max_sequence_length + ) + self.instance_images = [sample["image"] for sample in dataset] + self.image_hashes = image_hashes + + # Image transformations + self.pixel_values = self.apply_image_transformations( + instance_images=self.instance_images, size=size, center_crop=center_crop + ) + + # Map hashes to embeddings. + self.data_dict = {} + for prompt, pooled_prompt, text_id, image_hash in zip( + all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids, self.image_hashes + ): + prompt_embeds, pooled_prompt_embeds, text_ids = self.convert_to_torch_tensor( + embeddings=[prompt, pooled_prompt, text_id] + ) + self.data_dict.update({image_hash: (prompt_embeds, pooled_prompt_embeds, text_ids)}) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.pixel_values[index % self.num_instance_images] + image_hash = self.image_hashes[index % self.num_instance_images] + prompt_embeds, pooled_prompt_embeds, text_ids = self.data_dict[image_hash] + example["instance_images"] = instance_image + example["prompt_embeds"] = prompt_embeds + example["pooled_prompt_embeds"] = pooled_prompt_embeds + example["text_ids"] = text_ids + return example + + def apply_image_transformations(self, instance_images, size, center_crop): + pixel_values = [] + + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + image = train_transforms(image) + pixel_values.append(image) + + return pixel_values + + def convert_to_torch_tensor(self, embeddings: list): + prompt_embeds = embeddings[0] + pooled_prompt_embeds = embeddings[1] + text_ids = embeddings[2] + prompt_embeds = np.array(prompt_embeds).reshape(self.max_sequence_length, prompt_embeds.shape[-1]) + pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(pooled_prompt_embeds.shape[-1]) + text_ids = np.array(text_ids).reshape(self.max_sequence_length, 3) + return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds), torch.from_numpy(text_ids) + + def generate_image_hash(self, image): + return insecure_hashlib.sha256(image.tobytes()).hexdigest() + + def load_text_encoder_pipeline(self): + id = self.pretrained_model_name_or_path + text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", device_map="auto") + pipeline = FluxPipeline.from_pretrained( + id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced" + ) + return pipeline + + @torch.no_grad() + def compute_embeddings(self, pipeline, prompts, max_sequence_length): + all_prompt_embeds = [] + all_pooled_prompt_embeds = [] + all_text_ids = [] + for prompt in tqdm(prompts, desc="Encoding prompts."): + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = pipeline.encode_prompt( + prompt=prompt, + prompt_2=None, + max_sequence_length=max_sequence_length, + ) + all_prompt_embeds.append(prompt_embeds) + all_pooled_prompt_embeds.append(pooled_prompt_embeds) + all_text_ids.append(text_ids) + return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids + + +def collate_fn(examples): + pixel_values = [example["instance_images"] for example in examples] + prompt_embeds = [example["prompt_embeds"] for example in examples] + pooled_prompt_embeds = [example["pooled_prompt_embeds"] for example in examples] + text_ids = [example["text_ids"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + prompt_embeds = torch.stack(prompt_embeds) + pooled_prompt_embeds = torch.stack(pooled_prompt_embeds) + text_ids = torch.stack(text_ids)[0] # just 2D tensor + + batch = { + "pixel_values": pixel_values, + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "text_ids": text_ids, + } + return batch + + +def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + +def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + +def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name) + gaudi_config.use_torch_autocast = gaudi_config.use_torch_autocast or args.mixed_precision == "bf16" + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) + accelerator = GaudiAccelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + force_autocast=gaudi_config.use_torch_autocast, + kwargs_handlers=[kwargs], + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + dtype = torch.float32 + if args.mixed_precision == "bf16": + dtype = torch.bfloat16 + + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + torch_dtype=dtype, + ) + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + vae.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + model = unwrap_model(model) + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + FluxPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=None, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + if not accelerator.distributed_type == GaudiDistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + torch_dtype=dtype, + ) + transformer_ = prepare_model_for_kbit_training(transformer_, use_gradient_checkpointing=False) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = FluxPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if args.optimizer.lower() == "adamw": + if gaudi_config.use_fused_adam: + from habana_frameworks.torch.hpex.optimizers import FusedAdamW + + optimizer_class = FusedAdamW + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + else: + raise ValueError(f"{args.optimizer} optimizer is not supported.") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + args.pretrained_model_name_or_path, + args.dataset, + args.prompt, + size=args.resolution, + max_sequence_length=args.max_sequence_length, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + vae_config_block_out_channels = vae.config.block_out_channels + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + del vae + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer.to(device=accelerator.device, dtype=dtype) + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux-dev-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + global_step = int(path.split("-")[1]) + kwargs = {"step": global_step} + accelerator.load_state(os.path.join(args.output_dir, path), **kwargs) + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + with accelerator.accumulate(models_to_accumulate): + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) + + latent_image_ids = _prepare_latent_image_ids( + model_input.shape[0], + model_input.shape[2] // 2, + model_input.shape[3] // 2, + accelerator.device, + weight_dtype, + ) + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + packed_noisy_model_input = _pack_latents( + noisy_model_input, + batch_size=model_input.shape[0], + num_channels_latents=model_input.shape[1], + height=model_input.shape[2], + width=model_input.shape[3], + ) + + # handle guidance + if unwrap_model(transformer).config.guidance_embeds: + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + + # Predict the noise + prompt_embeds = batch["prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype) + pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype) + text_ids = batch["text_ids"].to(device=accelerator.device, dtype=weight_dtype) + + model_pred = transformer( + hidden_states=packed_noisy_model_input, + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + + # model_pred = FluxPipeline._unpack_latents( + model_pred = _unpack_latents( + model_pred, + height=model_input.shape[2] * vae_scale_factor, + width=model_input.shape[3] * vae_scale_factor, + vae_scale_factor=vae_scale_factor, + ) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or accelerator.distributed_type == GaudiDistributedType.DEEPSPEED: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + FluxPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_lora_layers=None, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + base_model=args.pretrained_model_name_or_path, + instance_prompt=None, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py b/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py old mode 100644 new mode 100755 diff --git a/examples/stable-diffusion/training/train_text_to_image_sdxl.py b/examples/stable-diffusion/training/train_text_to_image_sdxl.py old mode 100644 new mode 100755