diff --git a/README.md b/README.md
index c0acbccfae..ac99b6211f 100644
--- a/README.md
+++ b/README.md
@@ -251,6 +251,7 @@ The following model architectures, tasks and device distributions have been vali
| Stable Diffusion XL |
[fine-tuning](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#fine-tuning-for-stable-diffusion-xl) | Single card | [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) |
| Stable Diffusion Depth2img | | Single card | [depth-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) |
| LDM3D | | Single card | [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) |
+| FLUX.1 | [fine-tuning](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#dreambooth-lora-fine-tuning-with-flux1-dev) | Single card | [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) |
| Text to Video | | Single card | [text-to-video generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-to-video) |
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 546f52d30d..86c4ce91b9 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -114,6 +114,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| Stable Diffusion XL | [fine-tuning](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#fine-tuning-for-stable-diffusion-xl) | Single card | [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) |
| Stable Diffusion Depth2img | | Single card | [depth-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) |
| LDM3D | | Single card | [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) |
+| FLUX.1 | [fine-tuning](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion/training#dreambooth-lora-fine-tuning-with-flux1-dev) | Single card | [text-to-image generation](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) |
| Text to Video | | Single card | [text-to-video generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-to-video) |
- PyTorch Image Models/TIMM:
diff --git a/examples/stable-diffusion/training/README.md b/examples/stable-diffusion/training/README.md
index 3d35a1623e..a10c194066 100644
--- a/examples/stable-diffusion/training/README.md
+++ b/examples/stable-diffusion/training/README.md
@@ -25,17 +25,28 @@ This directory contains scripts that showcase how to perform training/fine-tunin
The `textual_inversion.py` script shows how to implement the training procedure on Habana Gaudi.
-### Cat toy example
+### Cat Toy Example
-Let's get our dataset. For this example, we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .
+In the examples below, we will use a set of cat images from the following dataset:
+[https://huggingface.co/datasets/diffusers/cat_toy_example](https://huggingface.co/datasets/diffusers/cat_toy_example)
-Let's first download it locally:
+Let's first download this dataset locally:
```python
from huggingface_hub import snapshot_download
+from pathlib import Path
+import shutil
-local_dir = "./cat"
-snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes")
+local_dir = './cat'
+snapshot_download(
+ 'diffusers/cat_toy_example',
+ local_dir=local_dir,
+ repo_type='dataset',
+ ignore_patterns='.gitattributes',
+)
+cache_dir = Path(local_dir, '.cache')
+if cache_dir.is_dir():
+ shutil.rmtree(cache_dir)
```
This will be our training data.
@@ -150,10 +161,20 @@ image = pipe(prompt=prompt, prompt_2=prompt_2, num_inference_steps=50, guidance_
image.save(f"cat-backpack_p1and2.png")
```
+> [!NOTE]
+> Change `--resolution` to 768 if you are using [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.
+
+> [!NOTE]
+> As described in [the official paper](https://arxiv.org/abs/2208.01618), only one embedding vector is used for the placeholder token,
+> e.g. `""`. However, one can also add multiple embedding vectors for the placeholder token to increase the number of fine-tuneable
+> parameters. This can help the model to learn more complex details. To use multiple embedding vectors, you can define `--num_vectors` to
+> a number larger than one, e.g.: `--num_vectors 5`. The saved textual inversion vectors will then be larger in size compared to the default case.
+
## ControlNet Training
-ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models ](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala. It is a type of model for controlling StableDiffusion by conditioning the model with an additional input image.
+ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models ](https://huggingface.co/papers/2302.05543)
+by Lvmin Zhang and Maneesh Agrawala. It is a type of model for controlling StableDiffusion by conditioning the model with an additional input image.
This example is adapted from [controlnet example in the diffusers repository](https://github.com/huggingface/diffusers/tree/main/examples/controlnet#training).
First, download the conditioning images as shown below:
@@ -203,7 +224,8 @@ python ../../gaudi_spawn.py --use_mpi --world_size 8 train_controlnet.py \
### Inference
-Once you have trained a model as described right above, inference can be done simply using the `GaudiStableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
+Once you have trained a model as described right above, inference can be done simply using the `GaudiStableDiffusionPipeline`.
+Make sure to include the `placeholder_token` in your prompt.
```python
from diffusers import ControlNetModel, UniPCMultistepScheduler
@@ -241,7 +263,7 @@ image.save("./output.png")
## Fine-Tuning for Stable Diffusion XL
-The `train_text_to_image_sdxl.py` script shows how to implement the fine-tuning of Stable Diffusion models on Habana Gaudi.
+The `train_text_to_image_sdxl.py` script shows how to implement the fine-tuning of Stable Diffusion XL models on Gaudi.
### Requirements
@@ -252,6 +274,7 @@ pip install -r requirements.txt
### Single-card Training
+To train Stable Diffusion XL on a single Gaudi card, use:
```bash
python train_text_to_image_sdxl.py \
--pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
@@ -283,7 +306,9 @@ python train_text_to_image_sdxl.py \
```
-### Multi-card Training
+### Multi-Card Training
+
+To train Stable Diffusion XL on a multi-card Gaudi system, use:
```bash
PT_HPU_RECIPE_CACHE_CONFIG=/tmp/stdxl_recipe_cache,True,1024 \
python ../../gaudi_spawn.py --world_size 8 --use_mpi train_text_to_image_sdxl.py \
@@ -315,7 +340,9 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_text_to_image_sdxl.py
--adjust_throughput
```
-### Single-card Training on Gaudi1
+### Single-Card Training on Gaudi1
+
+To train Stable Diffusion XL on a single Gaudi1 card, use:
```bash
python train_text_to_image_sdxl.py \
--pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
@@ -342,53 +369,52 @@ python train_text_to_image_sdxl.py \
```
> [!NOTE]
-> There is a known issue that in the first 2 steps, graph compilation takes longer than 10 seconds. This will be fixed in a future release.
+> There is a known issue that in the first 2 steps, graph compilation takes longer than 10 seconds.
+> This will be fixed in a future release.
> [!NOTE]
> `--mediapipe` only works on Gaudi2.
## DreamBooth
-DreamBooth is a method to personalize text-to-image models like Stable Diffusion given just a few (3~5) images of a subject. The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
-### Dog toy example
+DreamBooth is a technique for personalizing text-to-image models like Stable Diffusion using only a few images (typically 3-5)
+of a specific subject. The `train_dreambooth.py` script demonstrates how to implement this training process and adapt it for
+Stable Diffusion.
-Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+### Dog Toy Example
-Let's first download it locally:
+For DreamBooth examples we will use a set of dog images from the following dataset:
+[https://huggingface.co/datasets/diffusers/dog-example](https://huggingface.co/datasets/diffusers/dog-example).
+
+Let's first download this dataset locally:
```python
-import os
from huggingface_hub import snapshot_download
+from pathlib import Path
+import shutil
-local_dir = "./dog"
+local_dir = './dog'
snapshot_download(
- "diffusers/dog-example",
- local_dir=local_dir, repo_type="dataset",
- ignore_patterns=".gitattributes",
+ 'diffusers/dog-example',
+ local_dir=local_dir,
+ repo_type='dataset',
+ ignore_patterns='.gitattributes',
)
-
-# check if .cache folder exists and remove it.
-cache_folder = os.path.join(local_dir, ".cache")
-if os.path.exists(cache_folder):
- import shutil
- shutil.rmtree(cache_folder)
+cache_dir = Path(local_dir, '.cache')
+if cache_dir.is_dir():
+ shutil.rmtree(cache_dir)
```
-### Full model finetune
-And launch the multi-card training using:
-```bash
-
-export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export INSTANCE_DIR="dog"
-export CLASS_DIR="path-to-class-images"
-export OUTPUT_DIR="out"
+### Full Model Fine-Tuning
+To launch the multi-card Stable Diffusion training, use:
+```bash
python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \
- --pretrained_model_name_or_path=$MODEL_NAME \
- --instance_data_dir=$INSTANCE_DIR \
- --output_dir=$OUTPUT_DIR \
- --class_data_dir=$CLASS_DIR \
+ --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
+ --instance_data_dir="dog" \
+ --output_dir="dog_sd" \
+ --class_data_dir="path-to-class-images" \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
@@ -405,31 +431,29 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/stable-diffusion \
full
-
```
-Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
-According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
-### PEFT model finetune
-We provide example for dreambooth to use lora/lokr/loha/oft to finetune unet or text encoder.
+Prior preservation is used to prevent overfitting and language drift. For more details, refer to the original paper.
+In this process, we first generate images using the model with a class prompt and then use those images during training
+alongside our data. According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior
+preservation, with 200-300 images being effective in most cases. The `num_class_images` flag controls how many images
+are generated with the class prompt. You can place existing images in the `class_data_dir`, and the training script will
+generate any additional images needed to meet the `num_class_images` requirement during training.
-**___Note: When using peft method we can use a much higher learning rate compared to vanilla dreambooth. Here we
-use *1e-4* instead of the usual *5e-6*.___**
+### PEFT Model Fine-Tuning
-Launch the multi-card training using:
-```bash
-
-export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export INSTANCE_DIR="dog"
-export CLASS_DIR="path-to-class-images"
-export OUTPUT_DIR="out"
+We provide DreamBooth examples demonstrating how to use LoRA, LoKR, LoHA, and OFT adapters to fine-tune the
+UNet or text encoder.
+To run the multi-card training, use:
+```bash
python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \
- --pretrained_model_name_or_path=$MODEL_NAME \
- --instance_data_dir=$INSTANCE_DIR \
- --output_dir=$OUTPUT_DIR \
- --class_data_dir=$CLASS_DIR \
- --with_prior_preservation --prior_loss_weight=1.0 \
+ --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
+ --instance_data_dir="dog" \
+ --output_dir="dog_sd" \
+ --class_data_dir="path-to-class-images" \
+ --with_prior_preservation \
+ --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \
@@ -445,50 +469,73 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \
--use_hpu_graphs_for_inference \
--gaudi_config_name Habana/stable-diffusion \
lora --unet_r 8 --unet_alpha 8
-
```
-Similar command could be applied to loha, lokr, oft.
-You could check each adapter specific args by "--help", like you could use following command to check oft specific args.
+> [!NOTE]
+> When using PEFT method we can use a much higher learning rate compared to vanilla dreambooth.
+> Here we use `1e-4` instead of the usual `5e-6`
+
+Similar command could be applied with `loha`, `lokr`, or `oft` adapters.
+
+You could check each adapter's specific arguments with `--help`, for example:
```bash
python3 train_dreambooth.py oft --help
-
```
+> [!NOTE]
+> Currently, the `oft` adapter is not supported in HPU graph mode, as it triggers `torch.inverse`,
+> causing a CPU fallback that is incompatible with HPU graph capturing.
-**___Note: oft could not work with hpu graphs mode. since "torch.inverse" need to fallback to cpu.
-there's error like "cpu fallback is not supported during hpu graph capturing"___**
-
-
-You could use text_to_image_generation.py to generate picture using the peft adapter like
+After training completes, you can use `text_to_image_generation.py` sample for inference as follows:
```bash
python ../text_to_image_generation.py \
--model_name_or_path CompVis/stable-diffusion-v1-4 \
+ --unet_adapter_name_or_path dog_sd/unet \
--prompts "a sks dog" \
--num_images_per_prompt 5 \
--batch_size 1 \
--image_save_dir /tmp/stable_diffusion_images \
--use_habana \
--use_hpu_graphs \
- --unet_adapter_name_or_path out/unet \
--gaudi_config Habana/stable-diffusion \
--bf16
```
-### DreamBooth training example for Stable Diffusion XL
-You could use the dog images as example as well.
-You can launch training using:
+### DreamBooth LoRA Fine-Tuning with Stable Diffusion XL
+
+We can use the same `dog` dataset for the following examples.
+
+To launch Stable Diffusion XL LoRA training on a multi-card Gaudi system, use:"
```bash
-export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
-export INSTANCE_DIR="dog"
-export OUTPUT_DIR="lora-trained-xl"
-export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
+python train_dreambooth_lora_sdxl.py \
+ --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
+ --instance_data_dir="dog" \
+ --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
+ --output_dir="lora-trained-xl" \
+ --mixed_precision="bf16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=1e-4 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed=0 \
+ --use_hpu_graphs_for_inference \
+ --use_hpu_graphs_for_training \
+ --gaudi_config_name Habana/stable-diffusion
+```
+To launch Stable Diffusion XL LoRA training on a multi-card Gaudi system, use:"
+```bash
python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth_lora_sdxl.py \
- --pretrained_model_name_or_path=$MODEL_NAME \
- --instance_data_dir=$INSTANCE_DIR \
- --pretrained_vae_model_name_or_path=$VAE_PATH \
- --output_dir=$OUTPUT_DIR \
+ --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
+ --instance_data_dir="dog" \
+ --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
+ --output_dir="lora-trained-xl" \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
@@ -504,21 +551,149 @@ python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth_lora_sdxl.
--use_hpu_graphs_for_inference \
--use_hpu_graphs_for_training \
--gaudi_config_name Habana/stable-diffusion
-
```
+> [!NOTE]
+> To use DeepSpeed instead of MPI, replace `--use_mpi` with `--deepspeed` in the previous example
+
+After training completes, you can run inference with a simple python script like this:
+```python
+import torch
+from optimum.habana import GaudiConfig
+from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline
+
+pipe = GaudiStableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ use_hpu_graphs=True,
+ use_habana=True,
+ gaudi_config="Habana/stable-diffusion",
+)
+pipe.load_lora_weights("lora-trained-xl")
-You could use text_to_image_generation.py to generate picture using the peft adapter like
+prompt = "A photo of sks dog in a bucket"
+image = pipe(
+ prompt,
+ height=1024,
+ width=1024,
+ guidance_scale=3.5,
+ num_inference_steps=30,
+ max_sequence_length=512,
+).images[0]
+image.save("sdxl-lora.png")
+```
+Alternatively, you could directly use `text_to_image_generation.py` sample for inference as follows:
```bash
python ../text_to_image_generation.py \
--model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
+ --lora_id lora-trained-xl \
--prompts "A picture of a sks dog in a bucket" \
--num_images_per_prompt 5 \
--batch_size 1 \
--image_save_dir /tmp/stable_diffusion_xl_images \
--use_habana \
--use_hpu_graphs \
- --lora_id lora-trained-xl \
+ --gaudi_config Habana/stable-diffusion \
+ --bf16
+```
+
+### DreamBooth LoRA Fine-Tuning with FLUX.1-dev
+
+We can use the same `dog` dataset for the following examples.
+
+To launch FLUX.1-dev LoRA training on a single Gaudi card, use:"
+```bash
+python train_dreambooth_lora_flux.py \
+ --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
+ --dataset="dog" \
+ --prompt="a photo of sks dog" \
+ --output_dir="dog_lora_flux" \
+ --mixed_precision="bf16" \
+ --weighting_scheme="none" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --learning_rate=1e-4 \
+ --guidance_scale=1 \
+ --report_to="tensorboard" \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --cache_latents \
+ --rank=4 \
+ --max_train_steps=500 \
+ --seed="0" \
+ --use_hpu_graphs_for_inference \
+ --use_hpu_graphs_for_training \
+ --gaudi_config_name="Habana/stable-diffusion"
+```
+
+To launch FLUX.1-dev LoRA training on a multi-card Gaudi system, use:"
+```bash
+python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth_lora_flux.py \
+ --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
+ --dataset="dog" \
+ --prompt="a photo of sks dog" \
+ --output_dir="dog_lora_flux" \
+ --mixed_precision="bf16" \
+ --weighting_scheme="none" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --learning_rate=1e-4 \
+ --guidance_scale=1 \
+ --report_to="tensorboard" \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --cache_latents \
+ --rank=4 \
+ --max_train_steps=500 \
+ --seed="0" \
+ --use_hpu_graphs_for_inference \
+ --use_hpu_graphs_for_training \
+ --gaudi_config_name="Habana/stable-diffusion"
+```
+> [!NOTE]
+> To use DeepSpeed instead of MPI, replace `--use_mpi` with `--use_deepspeed` in the previous example
+
+After training completes, you can run inference on Gaudi system with a simple python script like this:
+```python
+import torch
+from optimum.habana import GaudiConfig
+from optimum.habana.diffusers import GaudiFluxPipeline
+
+pipe = GaudiFluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+ use_hpu_graphs=True,
+ use_habana=True,
+ gaudi_config="Habana/stable-diffusion",
+)
+pipe.load_lora_weights("dog_lora_flux")
+
+prompt = "A photo of sks dog in a bucket"
+image = pipe(
+ prompt,
+ height=1024,
+ width=1024,
+ guidance_scale=3.5,
+ num_inference_steps=30,
+).images[0]
+image.save("flux-dev.png")
+```
+
+Alternatively, you could directly use `text_to_image_generation.py` sample for inference as follows:
+```bash
+python ../text_to_image_generation.py \
+ --model_name_or_path "black-forest-labs/FLUX.1-dev" \
+ --lora_id dog_lora_flux \
+ --prompts "A picture of a sks dog in a bucket" \
+ --num_images_per_prompt 5 \
+ --batch_size 1 \
+ --image_save_dir /tmp/flux_images \
+ --use_habana \
+ --use_hpu_graphs \
--gaudi_config Habana/stable-diffusion \
--bf16
```
diff --git a/examples/stable-diffusion/training/requirements.txt b/examples/stable-diffusion/training/requirements.txt
index 7fb1748675..bf92040ae8 100644
--- a/examples/stable-diffusion/training/requirements.txt
+++ b/examples/stable-diffusion/training/requirements.txt
@@ -1,2 +1,3 @@
imagesize
peft == 0.10.0
+sentencepiece
diff --git a/examples/stable-diffusion/training/textual_inversion.py b/examples/stable-diffusion/training/textual_inversion.py
old mode 100644
new mode 100755
diff --git a/examples/stable-diffusion/training/train_controlnet.py b/examples/stable-diffusion/training/train_controlnet.py
old mode 100644
new mode 100755
diff --git a/examples/stable-diffusion/training/train_dreambooth.py b/examples/stable-diffusion/training/train_dreambooth.py
old mode 100644
new mode 100755
diff --git a/examples/stable-diffusion/training/train_dreambooth_lora_flux.py b/examples/stable-diffusion/training/train_dreambooth_lora_flux.py
new file mode 100755
index 0000000000..68b5320d19
--- /dev/null
+++ b/examples/stable-diffusion/training/train_dreambooth_lora_flux.py
@@ -0,0 +1,1142 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+"""
+Training script for FLUX LORA DreamBooth to Text-to-Image Diffusion Models
+Adapted from the following sources:
+https://github.com/huggingface/diffusers/blob/v0.31.0/examples/dreambooth/train_dreambooth_lora_flux.py
+https://github.com/huggingface/diffusers/tree/main/examples/research_projects/flux_lora_quantization
+"""
+
+import argparse
+import copy
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from pathlib import Path
+
+import diffusers
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
+from datasets import load_dataset
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ FluxPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import T5EncoderModel
+
+from optimum.habana import GaudiConfig
+from optimum.habana.accelerate import GaudiAccelerator
+from optimum.habana.accelerate.utils.dataclasses import GaudiDistributedType
+from optimum.habana.utils import set_seed
+
+
+if is_wandb_available():
+ pass
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.31.0")
+
+logger = get_logger(__name__)
+warnings.simplefilter(action="ignore", category=FutureWarning)
+warnings.simplefilter(action="ignore", category=UserWarning)
+
+
+def save_model_card(
+ repo_id: str,
+ base_model: str = None,
+ instance_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+
+ model_description = f"""
+# Flux DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the
+[Gaudi Flux diffusers trainer](https://github.com/huggingface/optimum-habana/blob/main/examples/stable-diffusion/training/README.md).
+
+Was LoRA for the text encoder enabled? False.
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+For more details, including weighting, merging and fusing LoRAs, check the
+[documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "flux",
+ "flux-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default=None,
+ help="Path to dataset used for training.",
+ )
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ default=None,
+ help="Prompt to use with training dataset (if dataset itself does not have text prompt feature).",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=77,
+ help="Used for reading the embeddings. Needs to be the same as used during `compute_embeddings.py`.",
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="flux-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+ help="the FLUX.1 dev variant is a guidance distilled model",
+ )
+
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ choices=["AdamW"],
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer.",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "bf16"],
+ help=("Choose whether to use bf16 (bfloat16) mixed precision or not."),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--gaudi_config_name",
+ type=str,
+ default=None,
+ help="Local path to the Gaudi configuration file or its name on the Hugging Face Hub.",
+ )
+ parser.add_argument(
+ "--use_hpu_graphs_for_training",
+ action="store_true",
+ help="Use HPU graphs for training on HPU.",
+ )
+ parser.add_argument(
+ "--use_hpu_graphs_for_inference",
+ action="store_true",
+ help="Use HPU graphs for inference on HPU.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ def __init__(
+ self,
+ pretrained_model_name_or_path,
+ instance_data_root,
+ instance_prompt,
+ size=1024,
+ max_sequence_length=77,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.max_sequence_length = max_sequence_length
+
+ self.instance_data_root = instance_data_root
+ self.instance_prompt = instance_prompt
+ self.pretrained_model_name_or_path = pretrained_model_name_or_path
+
+ # Load dataset and create text embeddings
+ dataset = load_dataset(self.instance_data_root, split="train")
+ image_hashes, image_prompts = [], {}
+
+ for sample in dataset:
+ image_hash = self.generate_image_hash(sample["image"])
+ image_hashes.append(image_hash)
+ text = sample["text"] if "text" in dataset.features else self.instance_prompt
+ image_prompts[image_hash] = text
+
+ all_prompts = list(image_prompts.values())
+
+ pipeline = self.load_text_encoder_pipeline()
+ all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = self.compute_embeddings(
+ pipeline, all_prompts, args.max_sequence_length
+ )
+ self.instance_images = [sample["image"] for sample in dataset]
+ self.image_hashes = image_hashes
+
+ # Image transformations
+ self.pixel_values = self.apply_image_transformations(
+ instance_images=self.instance_images, size=size, center_crop=center_crop
+ )
+
+ # Map hashes to embeddings.
+ self.data_dict = {}
+ for prompt, pooled_prompt, text_id, image_hash in zip(
+ all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids, self.image_hashes
+ ):
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.convert_to_torch_tensor(
+ embeddings=[prompt, pooled_prompt, text_id]
+ )
+ self.data_dict.update({image_hash: (prompt_embeds, pooled_prompt_embeds, text_ids)})
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ image_hash = self.image_hashes[index % self.num_instance_images]
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.data_dict[image_hash]
+ example["instance_images"] = instance_image
+ example["prompt_embeds"] = prompt_embeds
+ example["pooled_prompt_embeds"] = pooled_prompt_embeds
+ example["text_ids"] = text_ids
+ return example
+
+ def apply_image_transformations(self, instance_images, size, center_crop):
+ pixel_values = []
+
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ pixel_values.append(image)
+
+ return pixel_values
+
+ def convert_to_torch_tensor(self, embeddings: list):
+ prompt_embeds = embeddings[0]
+ pooled_prompt_embeds = embeddings[1]
+ text_ids = embeddings[2]
+ prompt_embeds = np.array(prompt_embeds).reshape(self.max_sequence_length, prompt_embeds.shape[-1])
+ pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(pooled_prompt_embeds.shape[-1])
+ text_ids = np.array(text_ids).reshape(self.max_sequence_length, 3)
+ return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds), torch.from_numpy(text_ids)
+
+ def generate_image_hash(self, image):
+ return insecure_hashlib.sha256(image.tobytes()).hexdigest()
+
+ def load_text_encoder_pipeline(self):
+ id = self.pretrained_model_name_or_path
+ text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", device_map="auto")
+ pipeline = FluxPipeline.from_pretrained(
+ id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced"
+ )
+ return pipeline
+
+ @torch.no_grad()
+ def compute_embeddings(self, pipeline, prompts, max_sequence_length):
+ all_prompt_embeds = []
+ all_pooled_prompt_embeds = []
+ all_text_ids = []
+ for prompt in tqdm(prompts, desc="Encoding prompts."):
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = pipeline.encode_prompt(
+ prompt=prompt,
+ prompt_2=None,
+ max_sequence_length=max_sequence_length,
+ )
+ all_prompt_embeds.append(prompt_embeds)
+ all_pooled_prompt_embeds.append(pooled_prompt_embeds)
+ all_text_ids.append(text_ids)
+ return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids
+
+
+def collate_fn(examples):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompt_embeds = [example["prompt_embeds"] for example in examples]
+ pooled_prompt_embeds = [example["pooled_prompt_embeds"] for example in examples]
+ text_ids = [example["text_ids"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ prompt_embeds = torch.stack(prompt_embeds)
+ pooled_prompt_embeds = torch.stack(pooled_prompt_embeds)
+ text_ids = torch.stack(text_ids)[0] # just 2D tensor
+
+ batch = {
+ "pixel_values": pixel_values,
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ "text_ids": text_ids,
+ }
+ return batch
+
+
+def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+
+def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+
+def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name)
+ gaudi_config.use_torch_autocast = gaudi_config.use_torch_autocast or args.mixed_precision == "bf16"
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
+ accelerator = GaudiAccelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ force_autocast=gaudi_config.use_torch_autocast,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ dtype = torch.float32
+ if args.mixed_precision == "bf16":
+ dtype = torch.bfloat16
+
+ transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=dtype,
+ )
+ transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ vae.to(accelerator.device, dtype=weight_dtype)
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ # now we will add new LoRA weights to the attention layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ model = unwrap_model(model)
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ FluxPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ text_encoder_lora_layers=None,
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ if not accelerator.distributed_type == GaudiDistributedType.DEEPSPEED:
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+ else:
+ transformer_ = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=dtype,
+ )
+ transformer_ = prepare_model_for_kbit_training(transformer_, use_gradient_checkpointing=False)
+ transformer_.add_adapter(transformer_lora_config)
+
+ lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if args.optimizer.lower() == "adamw":
+ if gaudi_config.use_fused_adam:
+ from habana_frameworks.torch.hpex.optimizers import FusedAdamW
+
+ optimizer_class = FusedAdamW
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+ else:
+ raise ValueError(f"{args.optimizer} optimizer is not supported.")
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ args.pretrained_model_name_or_path,
+ args.dataset,
+ args.prompt,
+ size=args.resolution,
+ max_sequence_length=args.max_sequence_length,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ vae_config_shift_factor = vae.config.shift_factor
+ vae_config_scaling_factor = vae.config.scaling_factor
+ vae_config_block_out_channels = vae.config.block_out_channels
+ if args.cache_latents:
+ latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=weight_dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+ del vae
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer.to(device=accelerator.device, dtype=dtype)
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-flux-dev-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the mos recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ global_step = int(path.split("-")[1])
+ kwargs = {"step": global_step}
+ accelerator.load_state(os.path.join(args.output_dir, path), **kwargs)
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ with accelerator.accumulate(models_to_accumulate):
+ # Convert images to latent space
+ if args.cache_latents:
+ model_input = latents_cache[step].sample()
+ else:
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
+
+ latent_image_ids = _prepare_latent_image_ids(
+ model_input.shape[0],
+ model_input.shape[2] // 2,
+ model_input.shape[3] // 2,
+ accelerator.device,
+ weight_dtype,
+ )
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ packed_noisy_model_input = _pack_latents(
+ noisy_model_input,
+ batch_size=model_input.shape[0],
+ num_channels_latents=model_input.shape[1],
+ height=model_input.shape[2],
+ width=model_input.shape[3],
+ )
+
+ # handle guidance
+ if unwrap_model(transformer).config.guidance_embeds:
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
+ guidance = guidance.expand(model_input.shape[0])
+ else:
+ guidance = None
+
+ # Predict the noise
+ prompt_embeds = batch["prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype)
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype)
+ text_ids = batch["text_ids"].to(device=accelerator.device, dtype=weight_dtype)
+
+ model_pred = transformer(
+ hidden_states=packed_noisy_model_input,
+ timestep=timesteps / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ return_dict=False,
+ )[0]
+
+ # model_pred = FluxPipeline._unpack_latents(
+ model_pred = _unpack_latents(
+ model_pred,
+ height=model_input.shape[2] * vae_scale_factor,
+ width=model_input.shape[3] * vae_scale_factor,
+ vae_scale_factor=vae_scale_factor,
+ )
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
+ target = noise - model_input
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ params_to_clip = transformer.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process or accelerator.distributed_type == GaudiDistributedType.DEEPSPEED:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ FluxPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ text_encoder_lora_layers=None,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=None,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py b/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py
old mode 100644
new mode 100755
diff --git a/examples/stable-diffusion/training/train_text_to_image_sdxl.py b/examples/stable-diffusion/training/train_text_to_image_sdxl.py
old mode 100644
new mode 100755