diff --git a/Dockerfile b/Dockerfile
index eb9b0f43..3b2e60e2 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,6 +11,12 @@ RUN apt-get update -y
# on user input during build
ENV DEBIAN_FRONTEND noninteractive
+# Install OpenGL, X11, and ffmpeg libraries (commonly required by opencv-python and similar imaging tooling)
+RUN apt-get install -y libgl1-mesa-glx \
+                       ffmpeg \
+                       libsm6 \
+                       libxext6
+
# Install misc unix libraries
RUN apt-get install -y --no-install-recommends openssh-server \
openssh-client \
diff --git a/INSTALL.md b/INSTALL.md
index db15a629..a4ed4716 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -4,6 +4,8 @@ For users that wish to make use of Docker or another container orchestration pla
### Installation
+For users running Windows 10 or newer, an installation guide based on Docker and WSL is available in [this document](/documentation/DOCKER.md).
+
Clone the SimpleTuner repository and set up the python venv:
```bash
diff --git a/documentation/DOCKER.md b/documentation/DOCKER.md
index 0b39c7c3..e2eb4bbb 100644
--- a/documentation/DOCKER.md
+++ b/documentation/DOCKER.md
@@ -11,6 +11,11 @@ This Docker configuration provides a comprehensive environment for running the S
## Getting Started
+### Windows OS support via WSL (Experimental)
+
+The following guide was tested in a WSL2 distribution with Docker Engine installed.
+
### 1. Building the Container
Clone the repository and navigate to the directory containing the Dockerfile. Build the Docker image using:
@@ -68,6 +73,44 @@ If you want to add custom startup scripts or modify configurations, extend the e
If any capabilities cannot be achieved through this setup, please open a new issue.
+### Docker Compose
+
+For users who prefer `docker-compose.yaml`, the following template is provided for you to extend and customise to your needs.
+
+Once the stack is deployed, you can connect to the container and work inside it as described in the steps above.
+
+```bash
+docker compose up -d
+
+docker exec -it simpletuner /bin/bash
+```
+
+```yaml
+services:
+ simpletuner:
+ container_name: simpletuner
+ build:
+ context: [Path to the repository]/SimpleTuner
+ dockerfile: Dockerfile
+ ports:
+ - "[port to connect to the container]:22"
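+      # The image runs an SSH server, so map any free host port to container port 22, eg. "2222:22".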
+ volumes:
+ - "[path to your datasets]:/datasets"
+ - "[path to your configs]:/workspace/SimpleTuner/config"
+ environment:
+ HUGGING_FACE_HUB_TOKEN: [your hugging face token]
+      WANDB_TOKEN: [your WandB token]
+ command: ["tail", "-f", "/dev/null"]
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [gpu]
+```
+
+> ⚠️ Please be cautious when handling your WandB and Hugging Face tokens! It's advised not to commit them to version control, even in a private repository, to ensure they are not leaked. For production use-cases, a dedicated key-management service is recommended, but that is out of scope for this guide.
---
## Troubleshooting
diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md
index ece4e96b..4af8e56d 100644
--- a/documentation/quickstart/SD3.md
+++ b/documentation/quickstart/SD3.md
@@ -264,6 +264,33 @@ For more information, see the [dataloader](/documentation/DATALOADER.md) and [tu
## Notes & troubleshooting tips
+### Skip-layer guidance (SD3.5 Medium)
+
+StabilityAI recommends enabling SLG (skip-layer guidance) for SD 3.5 Medium inference. This does not impact training results, only the quality of validation samples.
+
+The following values are recommended for `config.json`:
+
+```json
+{
+ "--validation_guidance_skip_layers": [7, 8, 9],
+ "--validation_guidance_skip_layers_start": 0.01,
+ "--validation_guidance_skip_layers_stop": 0.2,
+ "--validation_guidance_skip_scale": 2.8,
+ "--validation_guidance": 4.0
+}
+```
+
+- `..skip_scale` determines how much to scale the positive prompt prediction during skip-layer guidance. The default value of 2.8 is safe for the base model's skip values of `7, 8, 9`, but it must be increased if more layers are skipped, doubling it for each additional layer.
+- `..skip_layers` specifies which layers to skip during the negative prompt prediction.
+- `..skip_layers_start` determines the fraction of the inference steps at which skip-layer guidance begins to be applied.
+- `..skip_layers_stop` determines the fraction of the total inference steps after which SLG is no longer applied.
+
+SLG can be applied for fewer steps to weaken its effect or to reduce its impact on inference speed.
+
+Extensive training of a LoRA or LyCORIS model seems to require adjusting these values, though it is not yet clear exactly how they should change.
+
+**A lower CFG value must be used during inference.**
+
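+Under the hood, these options correspond to the skip-layer guidance arguments of the SD3 pipeline. The following is a minimal sketch of standalone inference (assuming a recent diffusers release or the SD3 pipeline bundled in this repository; the checkpoint id, prompt, and step count are illustrative only):
+
+```py
+import torch
+from diffusers import StableDiffusion3Pipeline
+
+pipe = StableDiffusion3Pipeline.from_pretrained(
+    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
+).to("cuda")
+
+image = pipe(
+    prompt="a photo of a corgi wearing a top hat",
+    num_inference_steps=28,
+    guidance_scale=4.0,               # --validation_guidance
+    skip_guidance_layers=[7, 8, 9],   # --validation_guidance_skip_layers
+    skip_layer_guidance_scale=2.8,    # --validation_guidance_skip_scale
+    skip_layer_guidance_start=0.01,   # --validation_guidance_skip_layers_start
+    skip_layer_guidance_stop=0.2,     # --validation_guidance_skip_layers_stop
+).images[0]
+image.save("slg-sample.png")
+```
+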
### Model instability
The SD 3.5 Large 8B model has potential instabilities during training:
@@ -288,12 +315,14 @@ Some changes were made to SimpleTuner's SD3.5 support:
#### Stable configuration values
These options have been known to keep SD3.5 in-tact for as long as possible:
-- optimizer=optimi-stableadamw
-- learning_rate=1e-5
+- optimizer=adamw_bf16
+- flux_schedule_shift=1
+- learning_rate=1e-4
- batch_size=4 * 3 GPUs
-- max_grad_norm=0.01
+- max_grad_norm=0.1
- base_model_precision=int8-quanto
- No loss masking or dataset regularisation, as their contribution to this instability is unknown
+- `validation_guidance_skip_layers=[7,8,9]`
### Lowest VRAM config
diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py
index ee3b4f79..14e1b470 100644
--- a/helpers/configuration/cmd_args.py
+++ b/helpers/configuration/cmd_args.py
@@ -6,6 +6,7 @@
from typing import Dict, List, Optional, Tuple
import random
import time
+import json
import logging
import sys
import torch
@@ -148,6 +149,15 @@ def get_argument_parser():
" which has improved results in short experiments. Thanks to @mhirki for the contribution."
),
)
+ parser.add_argument(
+ "--flux_use_uniform_schedule",
+ action="store_true",
+ help=(
+ "Whether or not to use a uniform schedule with Flux instead of sigmoid."
+ " Using uniform sampling may help preserve more capabilities from the base model."
+ " Some tasks may not benefit from this."
+ ),
+ )
parser.add_argument(
"--flux_use_beta_schedule",
action="store_true",
@@ -1350,6 +1360,37 @@ def get_argument_parser():
" the default mode, provides the most benefit."
),
)
+ parser.add_argument(
+ "--validation_guidance_skip_layers",
+ type=str,
+ default=None,
+ help=(
+            "A JSON list of integer layer indices, eg. [7, 8, 9]."
+            " StabilityAI recommends [7, 8, 9] for Stable Diffusion 3.5 Medium."
+ ),
+ )
+ parser.add_argument(
+ "--validation_guidance_skip_layers_start",
+ type=float,
+ default=0.01,
+ help=("StabilityAI recommends a value of 0.01 for SLG start."),
+ )
+ parser.add_argument(
+ "--validation_guidance_skip_layers_stop",
+ type=float,
+        default=0.2,
+        help=("StabilityAI recommends a value of 0.2 for SLG stop."),
+ )
+ parser.add_argument(
+ "--validation_guidance_skip_scale",
+ type=float,
+ default=2.8,
+ help=(
+ "StabilityAI recommends a value of 2.8 for SLG guidance skip scaling."
+ " When adding more layers, you must increase the scale, eg. adding one more layer requires doubling"
+ " the value given."
+ ),
+ )
parser.add_argument(
"--allow_tf32",
action="store_true",
@@ -2391,4 +2432,15 @@ def parse_cmdline_args(input_args=None):
f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1"
)
+ if args.validation_guidance_skip_layers is not None:
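+        # The value arrives as a JSON string (from the CLI or config.json),
+        # eg. "[7, 8, 9]", and is decoded into a Python list here.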
+        try:
+ args.validation_guidance_skip_layers = json.loads(
+ args.validation_guidance_skip_layers
+ )
+ except Exception as e:
+ logger.error(f"Could not load skip layers: {e}")
+ raise
+
return args
diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py
index 9c582c30..ef2593dd 100644
--- a/helpers/data_backend/factory.py
+++ b/helpers/data_backend/factory.py
@@ -24,6 +24,8 @@
from tqdm import tqdm
import queue
from math import sqrt
+import pandas as pd
+import numpy as np
logger = logging.getLogger("DataBackendFactory")
if should_log():
@@ -48,6 +50,68 @@ def info_log(message):
logger.info(message)
+def check_column_values(
+    column_data, column_name, parquet_path, fallback_caption_column=False
+):
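+    """Validate a parquet column that may hold scalar strings or arrays of strings.
+
+    Raises ValueError when nulls, empty arrays, or empty strings are found and no
+    fallback caption column is configured; raises TypeError for unsupported dtypes.
+    """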
+ # Determine if the column contains arrays or scalar values
+ non_null_values = column_data.dropna()
+ if non_null_values.empty:
+ # All values are null
+ raise ValueError(
+ f"Parquet file {parquet_path} contains only null values in the '{column_name}' column."
+ )
+
+ first_non_null = non_null_values.iloc[0]
+ if isinstance(first_non_null, (list, tuple, np.ndarray, pd.Series)):
+ # Column contains arrays
+ # Check for null arrays
+ if column_data.isnull().any() and not fallback_caption_column:
+ raise ValueError(
+ f"Parquet file {parquet_path} contains null arrays in the '{column_name}' column."
+ )
+
+ # Check for empty arrays
+        empty_arrays = non_null_values.apply(lambda x: len(x) == 0)
+ if empty_arrays.any() and not fallback_caption_column:
+ raise ValueError(
+ f"Parquet file {parquet_path} contains empty arrays in the '{column_name}' column."
+ )
+
+ # Check for null elements within arrays
+        null_elements_in_arrays = non_null_values.apply(
+            lambda arr: any(pd.isnull(s) for s in arr)
+        )
+ if null_elements_in_arrays.any() and not fallback_caption_column:
+ raise ValueError(
+ f"Parquet file {parquet_path} contains null values within arrays in the '{column_name}' column."
+ )
+
+ # Check for empty strings within arrays
+        empty_strings_in_arrays = non_null_values.apply(
+            lambda arr: any(s == "" for s in arr)
+        )
+        if empty_strings_in_arrays.all() and not fallback_caption_column:
+            raise ValueError(
+                f"Parquet file {parquet_path} contains an empty string in every array in the '{column_name}' column."
+ )
+
+ elif isinstance(first_non_null, str):
+ # Column contains scalar strings
+ # Check for null values
+ if column_data.isnull().any() and not fallback_caption_column:
+ raise ValueError(
+ f"Parquet file {parquet_path} contains null values in the '{column_name}' column."
+ )
+
+ # Check for empty strings
+ if (column_data == "").any() and not fallback_caption_column:
+ raise ValueError(
+ f"Parquet file {parquet_path} contains empty strings in the '{column_name}' column."
+ )
+ else:
+ raise TypeError(
+ f"Unsupported data type in column '{column_name}'. Expected strings or arrays of strings."
+ )
+
+
def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
output = {"id": backend["id"], "config": {}}
if backend.get("dataset_type", None) == "text_embeds":
@@ -292,24 +356,23 @@ def configure_parquet_database(backend: dict, args, data_backend: BaseDataBacken
raise ValueError(
f"Parquet file {parquet_path} does not contain a column named '{filename_column}'."
)
- # Check for null values
- if df[caption_column].isnull().values.any() and not fallback_caption_column:
- raise ValueError(
- f"Parquet file {parquet_path} contains null values in the '{caption_column}' column, but no fallback_caption_column was set."
- )
- if df[filename_column].isnull().values.any():
- raise ValueError(
- f"Parquet file {parquet_path} contains null values in the '{filename_column}' column."
- )
- # Check for empty strings
- if (df[caption_column] == "").sum() > 0 and not fallback_caption_column:
- raise ValueError(
- f"Parquet file {parquet_path} contains empty strings in the '{caption_column}' column."
- )
- if (df[filename_column] == "").sum() > 0:
- raise ValueError(
- f"Parquet file {parquet_path} contains empty strings in the '{filename_column}' column."
- )
+
+ # Apply the function to the caption_column.
+ check_column_values(
+ df[caption_column],
+ caption_column,
+ parquet_path,
+ fallback_caption_column=fallback_caption_column
+ )
+
+ # Apply the function to the filename_column.
+ check_column_values(
+ df[filename_column],
+ filename_column,
+ parquet_path,
+ fallback_caption_column=False # Always check filename_column
+ )
+
# Store the database in StateTracker
StateTracker.set_parquet_database(
backend["id"],
diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py
index 584cd9c0..2850f986 100644
--- a/helpers/metadata/backends/parquet.py
+++ b/helpers/metadata/backends/parquet.py
@@ -150,11 +150,13 @@ def _extract_captions_to_fast_list(self):
if len(caption_column) > 0:
caption = [row[c] for c in caption_column]
else:
- caption = row[caption_column]
+ caption = row.get(caption_column)
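+            # Parquet may return multi-caption columns as numpy arrays or pandas
+            # Series; normalise them into plain lists of strings.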
+ if isinstance(caption, (numpy.ndarray, pd.Series)):
+ caption = [str(item) for item in caption if item is not None]
- if not caption and fallback_caption_column:
- caption = row[fallback_caption_column]
- if not caption:
+ if caption is None and fallback_caption_column:
+ caption = row.get(fallback_caption_column, None)
+ if caption is None or caption == "" or caption == []:
raise ValueError(
f"Could not locate caption for image {filename} in sampler_backend {self.id} with filename column {filename_column}, caption column {caption_column}, and a parquet database with {len(self.parquet_database)} entries."
)
@@ -162,7 +164,7 @@ def _extract_captions_to_fast_list(self):
caption = caption.decode("utf-8")
elif type(caption) == list:
caption = [c.strip() for c in caption if c.strip()]
- if caption:
+ elif type(caption) == str:
caption = caption.strip()
captions[filename] = caption
return captions
diff --git a/helpers/models/omnigen/pipeline.py b/helpers/models/omnigen/pipeline.py
new file mode 100644
index 00000000..fadbaf05
--- /dev/null
+++ b/helpers/models/omnigen/pipeline.py
@@ -0,0 +1,367 @@
+import os
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+import gc
+
+from PIL import Image
+import numpy as np
+import torch
+from huggingface_hub import snapshot_download
+from peft import LoraConfig, PeftModel
+from diffusers.models import AutoencoderKL
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from safetensors.torch import load_file
+
+from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
+
+# StateTracker resolves the accelerator device used when loading the VAE below.
+from helpers.training.state_tracker import StateTracker
+
+
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from OmniGen import OmniGenPipeline
+    >>> pipe = OmniGenPipeline.from_pretrained(
+ ... base_model
+ ... )
+ >>> prompt = "A woman holds a bouquet of flowers and faces the camera"
+ >>> image = pipe(
+ ... prompt,
+ ... guidance_scale=2.5,
+ ... num_inference_steps=50,
+ ... ).images[0]
+ >>> image.save("t2i.png")
+ ```
+"""
+
+
+
+class OmniGenPipeline:
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ model: OmniGen,
+ processor: OmniGenProcessor,
+ device: Union[str, torch.device],
+ ):
+ self.vae = vae
+ self.model = model
+ self.processor = processor
+ self.device = device
+
+ self.model.to(torch.bfloat16)
+ self.model.eval()
+ self.vae.eval()
+
+ self.model_cpu_offload = False
+
+ @classmethod
+ def from_pretrained(
+ cls, pretrained_model_name_or_path, vae_path: str = None, **kwargs
+ ):
+ if not os.path.exists(pretrained_model_name_or_path) or (
+ not os.path.exists(
+ os.path.join(pretrained_model_name_or_path, "model.safetensors")
+ )
+ and pretrained_model_name_or_path == "Shitao/OmniGen-v1"
+ ):
+ logger.info("Model not found, downloading...")
+ cache_folder = os.getenv("HF_HUB_CACHE")
+ pretrained_model_name_or_path = snapshot_download(
+ repo_id=pretrained_model_name_or_path,
+ cache_dir=cache_folder,
+ ignore_patterns=[
+ "flax_model.msgpack",
+ "rust_model.ot",
+ "tf_model.h5",
+ "model.pt",
+ ],
+ )
+ logger.info(f"Downloaded model to {pretrained_model_name_or_path}")
+ model = OmniGen.from_pretrained(pretrained_model_name_or_path)
+ processor = OmniGenProcessor.from_pretrained(pretrained_model_name_or_path)
+
+ if os.path.exists(os.path.join(pretrained_model_name_or_path, "vae")):
+ vae = AutoencoderKL.from_pretrained(
+ os.path.join(pretrained_model_name_or_path, "vae")
+ )
+ elif vae_path is not None:
+ vae = AutoencoderKL.from_pretrained(vae_path).to(
+ StateTracker.get_accelerator().device
+ )
+ else:
+ logger.info(
+ f"No VAE found in {pretrained_model_name_or_path}, downloading stabilityai/sdxl-vae from HF"
+ )
+ vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(
+ StateTracker.get_accelerator().device
+ )
+
+        if kwargs:
+            logger.info(
+                f"OmniGenPipeline received unexpected arguments: {kwargs.keys()}"
+            )
+
+        return cls(
+            vae, model, processor, device=StateTracker.get_accelerator().device
+        )
+
+ def merge_lora(self, lora_path: str):
+ model = PeftModel.from_pretrained(self.model, lora_path)
+ model.merge_and_unload()
+
+ self.model = model
+
+ def to(self, device: Union[str, torch.device]):
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.model.to(device)
+ self.vae.to(device)
+ self.device = device
+
+ def vae_encode(self, x, dtype):
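+        # Scale (and optionally shift) the encoded latents using the VAE config so
+        # they match the latent distribution the diffusion model expects.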
+ if self.vae.config.shift_factor is not None:
+ x = self.vae.encode(x).latent_dist.sample()
+ x = (x - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ else:
+ x = (
+ self.vae.encode(x)
+ .latent_dist.sample()
+ .mul_(self.vae.config.scaling_factor)
+ )
+ x = x.to(dtype)
+ return x
+
+ def move_to_device(self, data):
+ if isinstance(data, list):
+ return [x.to(self.device) for x in data]
+ return data.to(self.device)
+
+ def enable_model_cpu_offload(self):
+ self.model_cpu_offload = True
+ self.model.to("cpu")
+ self.vae.to("cpu")
+ torch.cuda.empty_cache() # Clear VRAM
+ gc.collect() # Run garbage collection to free system RAM
+
+ def disable_model_cpu_offload(self):
+ self.model_cpu_offload = False
+ self.model.to(self.device)
+ self.vae.to(self.device)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ input_images: Union[List[str], List[List[str]]] = None,
+ height: int = 1024,
+ width: int = 1024,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 3,
+ use_img_guidance: bool = True,
+ img_guidance_scale: float = 1.6,
+ max_input_image_size: int = 1024,
+ separate_cfg_infer: bool = True,
+ offload_model: bool = False,
+ use_kv_cache: bool = True,
+ offload_kv_cache: bool = True,
+ use_input_image_size_as_output: bool = False,
+ dtype: torch.dtype = torch.bfloat16,
+ seed: int = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ input_images (`List[str]` or `List[List[str]]`, *optional*):
+                The list of input images. We will replace the "<|image_i|>" placeholder in the prompt with the i-th image in the list.
+ height (`int`, *optional*, defaults to 1024):
+ The height in pixels of the generated image. The number must be a multiple of 16.
+ width (`int`, *optional*, defaults to 1024):
+ The width in pixels of the generated image. The number must be a multiple of 16.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 3.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ use_img_guidance (`bool`, *optional*, defaults to True):
+                Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
+ img_guidance_scale (`float`, *optional*, defaults to 1.6):
+                Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
+ max_input_image_size (`int`, *optional*, defaults to 1024): the maximum size of input image, which will be used to crop the input image to the maximum size
+            separate_cfg_infer (`bool`, *optional*, defaults to True):
+ Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
+ use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
+            offload_kv_cache (`bool`, *optional*, defaults to True): offload the cached key and value to cpu, which can save memory but slow down the generation slightly
+ offload_model (`bool`, *optional*, defaults to False): offload the model to cpu, which can save memory but slow down the generation
+ use_input_image_size_as_output (bool, defaults to False): whether to use the input image size as the output image size, which can be used for single-image input, e.g., image editing task
+ seed (`int`, *optional*):
+ A random seed for generating output.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
+ data type for the model
+ Examples:
+
+ Returns:
+ A list with the generated images.
+ """
+ # check inputs:
+ if use_input_image_size_as_output:
+ assert (
+ isinstance(prompt, str) and len(input_images) == 1
+            ), "if you want the output image to have the same size as the input image, please provide only one input image instead of multiple input images"
+ else:
+ assert (
+ height % 16 == 0 and width % 16 == 0
+ ), "The height and width must be a multiple of 16."
+ if input_images is None:
+ use_img_guidance = False
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ input_images = [input_images] if input_images is not None else None
+
+ # set model and processor
+ if max_input_image_size != self.processor.max_image_size:
+ self.processor = OmniGenProcessor(
+ self.processor.text_tokenizer, max_image_size=max_input_image_size
+ )
+ if offload_model:
+ self.enable_model_cpu_offload()
+ else:
+ self.disable_model_cpu_offload()
+
+ input_data = self.processor(
+ prompt,
+ input_images,
+ height=height,
+ width=width,
+ use_img_cfg=use_img_guidance,
+ separate_cfg_input=separate_cfg_infer,
+ use_input_image_size_as_output=use_input_image_size_as_output,
+ )
+        logger.debug(f"Input shapes: {input_data['attention_mask'][0].shape}")
+
+ num_prompt = len(prompt)
+ num_cfg = 2 if use_img_guidance else 1
+ if use_input_image_size_as_output:
+ if separate_cfg_infer:
+ height, width = input_data["input_pixel_values"][0][0].shape[-2:]
+ else:
+ height, width = input_data["input_pixel_values"][0].shape[-2:]
+ latent_size_h, latent_size_w = height // 8, width // 8
+
+ if seed is not None:
+ generator = torch.Generator(device=self.device).manual_seed(seed)
+ else:
+ generator = None
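+        # Sample the initial noise in latent space, then duplicate it below for each
+        # guidance branch so every branch denoises from the same starting point.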
+ latents = torch.randn(
+ num_prompt,
+ 4,
+ latent_size_h,
+ latent_size_w,
+ device=self.device,
+ generator=generator,
+ )
+ latents = torch.cat([latents] * (1 + num_cfg), 0).to(dtype)
+
+ if input_images is not None and self.model_cpu_offload:
+ self.vae.to(self.device)
+ input_img_latents = []
+ if separate_cfg_infer:
+ for temp_pixel_values in input_data["input_pixel_values"]:
+ temp_input_latents = []
+ for img in temp_pixel_values:
+ img = self.vae_encode(img.to(self.device), dtype)
+ temp_input_latents.append(img)
+ input_img_latents.append(temp_input_latents)
+ else:
+ for img in input_data["input_pixel_values"]:
+ img = self.vae_encode(img.to(self.device), dtype)
+ input_img_latents.append(img)
+ if input_images is not None and self.model_cpu_offload:
+ self.vae.to("cpu")
+ torch.cuda.empty_cache() # Clear VRAM
+ gc.collect() # Run garbage collection to free system RAM
+
+ model_kwargs = dict(
+ input_ids=self.move_to_device(input_data["input_ids"]),
+ input_img_latents=input_img_latents,
+ input_image_sizes=input_data["input_image_sizes"],
+ attention_mask=self.move_to_device(input_data["attention_mask"]),
+ position_ids=self.move_to_device(input_data["position_ids"]),
+ cfg_scale=guidance_scale,
+ img_cfg_scale=img_guidance_scale,
+ use_img_cfg=use_img_guidance,
+ use_kv_cache=use_kv_cache,
+ offload_model=offload_model,
+ )
+
+ if separate_cfg_infer:
+ func = self.model.forward_with_separate_cfg
+ else:
+ func = self.model.forward_with_cfg
+ self.model.to(dtype)
+
+ if self.model_cpu_offload:
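+            # With CPU offload enabled, keep parameters outside the deeper transformer
+            # layers (embeddings, heads, and layer 0) on the accelerator; the remaining
+            # layer weights stay in system RAM to reduce VRAM usage.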
+ for name, param in self.model.named_parameters():
+ if "layers" in name and "layers.0" not in name:
+ param.data = param.data.cpu()
+ else:
+ param.data = param.data.to(self.device)
+ for buffer_name, buffer in self.model.named_buffers():
+ setattr(self.model, buffer_name, buffer.to(self.device))
+ # else:
+ # self.model.to(self.device)
+
+ scheduler = OmniGenScheduler(num_steps=num_inference_steps)
+ samples = scheduler(
+ latents,
+ func,
+ model_kwargs,
+ use_kv_cache=use_kv_cache,
+ offload_kv_cache=offload_kv_cache,
+ )
+ samples = samples.chunk((1 + num_cfg), dim=0)[0]
+
+ if self.model_cpu_offload:
+ self.model.to("cpu")
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ self.vae.to(self.device)
+ samples = samples.to(torch.float32)
+ if self.vae.config.shift_factor is not None:
+ samples = (
+ samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
+ )
+ else:
+ samples = samples / self.vae.config.scaling_factor
+ samples = self.vae.decode(
+ samples.to(dtype=self.vae.dtype, device=self.vae.device)
+ ).sample
+
+ if self.model_cpu_offload:
+ self.vae.to("cpu")
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ output_samples = (samples * 0.5 + 0.5).clamp(0, 1) * 255
+ output_samples = (
+ output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
+ )
+ output_images = []
+ for i, sample in enumerate(output_samples):
+ output_images.append(Image.fromarray(sample))
+
+ torch.cuda.empty_cache() # Clear VRAM
+ gc.collect() # Run garbage collection to free system RAM
+ return output_images
diff --git a/helpers/models/sd3/pipeline.py b/helpers/models/sd3/pipeline.py
index 1bf3332f..653c2a6a 100644
--- a/helpers/models/sd3/pipeline.py
+++ b/helpers/models/sd3/pipeline.py
@@ -29,9 +29,12 @@
from diffusers.models.transformers import SD3Transformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
+ USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@@ -39,6 +42,7 @@
StableDiffusion3PipelineOutput,
)
+from diffusers.image_processor import PipelineImageInput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
@@ -76,7 +80,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
@@ -221,11 +225,17 @@ def __init__(
if hasattr(self, "transformer") and self.transformer is not None
else 128
)
+ self.patch_size = (
+ self.transformer.config.patch_size
+ if hasattr(self, "transformer") and self.transformer is not None
+ else 2
+ )
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
+ max_sequence_length: int = 256,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
@@ -238,7 +248,7 @@ def _get_t5_prompt_embeds(
if self.text_encoder_3 is None:
return torch.zeros(
(
- batch_size,
+ batch_size * num_images_per_prompt,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
@@ -249,7 +259,7 @@ def _get_t5_prompt_embeds(
text_inputs = self.tokenizer_3(
prompt,
padding="max_length",
- max_length=self.tokenizer_max_length,
+ max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
@@ -266,8 +276,8 @@ def _get_t5_prompt_embeds(
untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
)
logger.warning(
- "The following part of your input was truncated because CLIP can only handle sequences up to"
- f" {self.tokenizer_max_length} tokens: {removed_text}"
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
@@ -368,6 +378,8 @@ def encode_prompt(
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
clip_skip: Optional[int] = None,
+ max_sequence_length: int = 256,
+ lora_scale: Optional[float] = None,
):
r"""
@@ -413,9 +425,22 @@ def encode_prompt(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
@@ -448,6 +473,7 @@ def encode_prompt(
t5_prompt_embed = self._get_t5_prompt_embeds(
prompt=prompt_3,
num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
device=device,
)
@@ -520,6 +546,7 @@ def encode_prompt(
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
prompt=negative_prompt_3,
num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
device=device,
)
@@ -539,6 +566,16 @@ def encode_prompt(
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
)
+ if self.text_encoder is not None:
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
return (
prompt_embeds,
negative_prompt_embeds,
@@ -561,10 +598,15 @@ def check_inputs(
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
):
- if height % 8 != 0 or width % 8 != 0:
+ if (
+ height % (self.vae_scale_factor * self.patch_size) != 0
+ or width % (self.vae_scale_factor * self.patch_size) != 0
+ ):
raise ValueError(
- f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
@@ -647,6 +689,11 @@ def check_inputs(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
+ )
+
def prepare_latents(
self,
batch_size,
@@ -733,6 +780,11 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ skip_guidance_layers: List[int] = None,
+        skip_layer_guidance_scale: float = 2.8,
+        skip_layer_guidance_stop: float = 0.2,
+        skip_layer_guidance_start: float = 0.01,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -758,7 +810,7 @@ def __call__(
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 5.0):
+ guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
@@ -801,8 +853,8 @@ def __call__(
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
- of a plain tuple.
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
+ a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -816,12 +868,29 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+ skip_guidance_layers (`List[int]`, *optional*): A list of integers that specify layers to skip during guidance.
+ If not provided, all layers will be used for guidance. If provided, the guidance will only be applied
+                to the layers specified in the list. Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is
+                [7, 8, 9].
+            skip_layer_guidance_scale (`float`, *optional*): The scale of the guidance for the layers specified in
+ `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
+ with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers with
+ a scale of `1`.
+            skip_layer_guidance_stop (`float`, *optional*): The fraction of inference steps at which the guidance for the
+                layers specified in `skip_guidance_layers` will stop. The guidance will be applied to the layers specified
+                in `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
+                StabilityAI for Stable Diffusion 3.5 Medium is 0.2.
+            skip_layer_guidance_start (`float`, *optional*): The fraction of inference steps at which the guidance for the
+                layers specified in `skip_guidance_layers` will start. The guidance will be applied to the layers specified
+                in `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
+                StabilityAI for Stable Diffusion 3.5 Medium is 0.01.
Examples:
Returns:
- [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
- [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
@@ -843,9 +912,11 @@ def __call__(
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
@@ -860,6 +931,11 @@ def __call__(
device = self._execution_device
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None)
+ if self.joint_attention_kwargs is not None
+ else None
+ )
(
prompt_embeds,
negative_prompt_embeds,
@@ -880,9 +956,15 @@ def __call__(
device=device,
clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
)
if self.do_classifier_free_guidance:
+ if skip_guidance_layers is not None:
+ original_prompt_embeds = prompt_embeds
+ original_pooled_prompt_embeds = pooled_prompt_embeds
+ # we do not combine the inference if we skip guidance layers.
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat(
[negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
@@ -909,8 +991,6 @@ def __call__(
generator,
latents,
)
- latents = latents.to(self.transformer.device)
- timesteps = timesteps.to(self.transformer.device)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -921,22 +1001,20 @@ def __call__(
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2)
- if self.do_classifier_free_guidance
+ if self.do_classifier_free_guidance and skip_guidance_layers is None
else latents
)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
- hidden_states=latent_model_input.to(
- device=self.transformer.device, dtype=self.transformer.dtype
- ),
+ hidden_states=latent_model_input.to(device=self.transformer.device),
timestep=timestep,
encoder_hidden_states=prompt_embeds.to(
- device=self.transformer.device, dtype=self.transformer.dtype
+ device=self.transformer.device
),
pooled_projections=pooled_prompt_embeds.to(
- device=self.transformer.device, dtype=self.transformer.dtype
+ device=self.transformer.device
),
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
@@ -948,6 +1026,33 @@ def __call__(
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
+ should_skip_layers = (
+ True
+ if i > num_inference_steps * skip_layer_guidance_start
+ and i < num_inference_steps * skip_layer_guidance_stop
+ else False
+ )
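+                    # Skip-layer guidance (SLG): re-run the transformer on the positive
+                    # prompt while skipping `skip_guidance_layers`, then push the
+                    # prediction away from that degraded estimate, scaled by
+                    # `skip_layer_guidance_scale`.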
+ if skip_guidance_layers is not None and should_skip_layers:
+ noise_pred_skip_layers = self.transformer(
+ hidden_states=latent_model_input.to(
+ device=self.transformer.device,
+ ),
+ timestep=timestep,
+ encoder_hidden_states=original_prompt_embeds.to(
+ device=self.transformer.device,
+ ),
+ pooled_projections=original_pooled_prompt_embeds.to(
+ device=self.transformer.device,
+ ),
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ skip_layers=skip_guidance_layers,
+ )[0]
+ noise_pred = (
+ noise_pred
+ + (noise_pred_text - noise_pred_skip_layers)
+ * self._skip_layer_guidance_scale
+ )
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
@@ -1004,84 +1109,9 @@ def __call__(
return StableDiffusion3PipelineOutput(images=image)
-# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Callable, Dict, List, Optional, Union
-
-import PIL.Image
-import torch
-from transformers import (
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- T5EncoderModel,
- T5TokenizerFast,
-)
-
-from diffusers.image_processor import PipelineImageInput
-
-
-if is_torch_xla_available():
- import torch_xla.core.xla_model as xm
-
- XLA_AVAILABLE = True
-else:
- XLA_AVAILABLE = False
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-EXAMPLE_DOC_STRING = """
- Examples:
- ```py
- >>> import torch
-
- >>> from diffusers import AutoPipelineForImage2Image
- >>> from diffusers.utils import load_image
-
- >>> device = "cuda"
- >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"
- >>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
- >>> pipe = pipe.to(device)
-
- >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
- >>> init_image = load_image(url).resize((512, 512))
-
- >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
-
- >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0]
- ```
-"""
-
-
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor,
- generator: Optional[torch.Generator] = None,
- sample_mode: str = "sample",
+class StableDiffusion3Img2ImgPipeline(
+ DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin
):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
-class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
@@ -1164,6 +1194,7 @@ def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
+ max_sequence_length: int = 256,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
@@ -1176,7 +1207,7 @@ def _get_t5_prompt_embeds(
if self.text_encoder_3 is None:
return torch.zeros(
(
- batch_size,
+ batch_size * num_images_per_prompt,
self.tokenizer_max_length,
self.transformer.config.joint_attention_dim,
),
@@ -1187,7 +1218,7 @@ def _get_t5_prompt_embeds(
text_inputs = self.tokenizer_3(
prompt,
padding="max_length",
- max_length=self.tokenizer_max_length,
+ max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
@@ -1204,8 +1235,8 @@ def _get_t5_prompt_embeds(
untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
)
logger.warning(
- "The following part of your input was truncated because CLIP can only handle sequences up to"
- f" {self.tokenizer_max_length} tokens: {removed_text}"
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
@@ -1308,6 +1339,8 @@ def encode_prompt(
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
clip_skip: Optional[int] = None,
+ max_sequence_length: int = 256,
+ lora_scale: Optional[float] = None,
):
r"""
@@ -1353,9 +1386,22 @@ def encode_prompt(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
@@ -1388,6 +1434,7 @@ def encode_prompt(
t5_prompt_embed = self._get_t5_prompt_embeds(
prompt=prompt_3,
num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
device=device,
)
@@ -1460,6 +1507,7 @@ def encode_prompt(
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
prompt=negative_prompt_3,
num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
device=device,
)
@@ -1479,6 +1527,16 @@ def encode_prompt(
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
)
+ if self.text_encoder is not None:
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
return (
prompt_embeds,
negative_prompt_embeds,
@@ -1500,6 +1558,7 @@ def check_inputs(
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
):
if strength < 0 or strength > 1:
raise ValueError(
@@ -1586,6 +1645,11 @@ def check_inputs(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
+ )
+
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
@@ -1613,8 +1677,6 @@ def prepare_latents(
)
image = image.to(device=device, dtype=dtype)
- if image.shape[1] == self.vae.config.latent_channels:
- init_latents = image
batch_size = batch_size * num_images_per_prompt
if image.shape[1] == self.vae.config.latent_channels:
@@ -1676,6 +1738,10 @@ def prepare_latents(
def guidance_scale(self):
return self._guidance_scale
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
@property
def clip_skip(self):
return self._clip_skip
@@ -1719,9 +1785,11 @@ def __call__(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -1747,7 +1815,7 @@ def __call__(
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 5.0):
+ guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
@@ -1790,8 +1858,12 @@ def __call__(
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
- of a plain tuple.
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
+ a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -1801,12 +1873,13 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
- [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
- [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
@@ -1824,10 +1897,12 @@ def __call__(
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
@@ -1840,6 +1915,12 @@ def __call__(
device = self._execution_device
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None)
+ if self.joint_attention_kwargs is not None
+ else None
+ )
+
(
prompt_embeds,
negative_prompt_embeds,
@@ -1860,6 +1941,8 @@ def __call__(
device=device,
clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
)
if self.do_classifier_free_guidance:
@@ -1878,7 +1961,7 @@ def __call__(
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps, strength, device
)
- latent_timestep = timesteps[:1].repeat(batch_size * num_inference_steps)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 5. Prepare latent variables
if latents is None:
@@ -1916,6 +1999,7 @@ def __call__(
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
diff --git a/helpers/models/sd3/transformer.py b/helpers/models/sd3/transformer.py
new file mode 100644
index 00000000..943bdb7b
--- /dev/null
+++ b/helpers/models/sd3/transformer.py
@@ -0,0 +1,465 @@
+# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
+from diffusers.models.attention import JointTransformerBlock
+from diffusers.models.attention_processor import (
+ Attention,
+ AttentionProcessor,
+ FusedJointAttnProcessor2_0,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNormContinuous
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_version,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.models.embeddings import (
+ CombinedTimestepTextProjEmbeddings,
+ PatchEmbed,
+)
+from diffusers.models.modeling_outputs import (
+ Transformer2DModelOutput,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class SD3Transformer2DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
+):
+ """
+ The Transformer model introduced in Stable Diffusion 3.
+
+ Reference: https://arxiv.org/abs/2403.03206
+
+ Parameters:
+ sample_size (`int`): The width of the latent images. This is fixed during training since
+ it is used to learn a number of position embeddings.
+ patch_size (`int`): Patch size to turn the input data into small patches.
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+ out_channels (`int`, defaults to 16): Number of output channels.
+
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 128,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ num_layers: int = 18,
+ attention_head_dim: int = 64,
+ num_attention_heads: int = 18,
+ joint_attention_dim: int = 4096,
+ caption_projection_dim: int = 1152,
+ pooled_projection_dim: int = 2048,
+ out_channels: int = 16,
+ pos_embed_max_size: int = 96,
+ dual_attention_layers: Tuple[
+ int, ...
+ ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+ qk_norm: Optional[str] = None,
+ ):
+ super().__init__()
+ default_out_channels = in_channels
+ self.out_channels = (
+ out_channels if out_channels is not None else default_out_channels
+ )
+ self.inner_dim = (
+ self.config.num_attention_heads * self.config.attention_head_dim
+ )
+
+ self.pos_embed = PatchEmbed(
+ height=self.config.sample_size,
+ width=self.config.sample_size,
+ patch_size=self.config.patch_size,
+ in_channels=self.config.in_channels,
+ embed_dim=self.inner_dim,
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
+ )
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+ embedding_dim=self.inner_dim,
+ pooled_projection_dim=self.config.pooled_projection_dim,
+ )
+ self.context_embedder = nn.Linear(
+ self.config.joint_attention_dim, self.config.caption_projection_dim
+ )
+
+ # `attention_head_dim` is doubled to account for the mixing.
+        # It needs to be crafted when we get the actual checkpoints.
+ self.transformer_blocks = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ context_pre_only=i == num_layers - 1,
+ qk_norm=qk_norm,
+ use_dual_attention=True if i in dual_attention_layers else False,
+ )
+ for i in range(self.config.num_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
+ )
+ self.proj_out = nn.Linear(
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
+ )
+
+ self.gradient_checkpointing = False
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(
+ self, chunk_size: Optional[int] = None, dim: int = 0
+ ) -> None:
+ """
+        Enables [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) so that the feed-forward layers
+        are computed in smaller chunks.
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, each feed-forward layer is run in
+                chunks of size 1 over the dimension `dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
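+
+        Example (illustrative usage; chunking lowers peak activation memory at some cost in speed):
+
+            model.enable_forward_chunking(chunk_size=1, dim=1)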
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(
+ module: torch.nn.Module, chunk_size: int, dim: int
+ ):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+ def disable_forward_chunking(self):
+ def fn_recursive_feed_forward(
+ module: torch.nn.Module, chunk_size: int, dim: int
+ ):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, None, 0)
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str,
+ module: torch.nn.Module,
+ processors: Dict[str, AttentionProcessor],
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor or a dictionary of processors that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
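+        Example (illustrative; assumes the PyTorch 2.0 joint-attention processor shipped with diffusers):
+
+            from diffusers.models.attention_processor import JointAttnProcessor2_0
+
+            model.set_attn_processor(JointAttnProcessor2_0())
+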
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
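+
+        Example (illustrative; pair with `unfuse_qkv_projections()` to restore the original processors):
+
+            model.fuse_qkv_projections()
+            # ... run inference with fused projections ...
+            model.unfuse_qkv_projections()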
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError(
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
+ )
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedJointAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        This API is 🧪 experimental.
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ pooled_projections: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ block_controlnet_hidden_states: List = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ skip_layers: Optional[List[int]] = None,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`SD3Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep (`torch.LongTensor`):
+ Used to indicate denoising step.
+            block_controlnet_hidden_states (`list` of `torch.Tensor`):
+                A list of tensors that, if specified, are added to the residuals of the transformer blocks.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+ skip_layers (`list` of `int`, *optional*):
+ A list of layer indices to skip during the forward pass.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
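+
+        Example (shape sketch under the default configuration; every tensor below is a random placeholder):
+
+            batch = 2
+            out = model(
+                hidden_states=torch.randn(batch, 16, 128, 128),       # latent input (B, C, H, W)
+                encoder_hidden_states=torch.randn(batch, 154, 4096),  # joint text embeddings (any sequence length)
+                pooled_projections=torch.randn(batch, 2048),          # pooled text embeddings
+                timestep=torch.randint(0, 1000, (batch,)),
+            ).sample                                                  # (batch, 16, 128, 128)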
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if (
+ joint_attention_kwargs is not None
+ and joint_attention_kwargs.get("scale", None) is not None
+ ):
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ height, width = hidden_states.shape[-2:]
+
+ hidden_states = self.pos_embed(
+ hidden_states
+ ) # takes care of adding positional embeddings too.
+ temb = self.time_text_embed(timestep, pooled_projections)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ # Skip specified layers
+ if skip_layers is not None and index_block in skip_layers:
+ if (
+ block_controlnet_hidden_states is not None
+ and block.context_pre_only is False
+ ):
+ interval_control = len(self.transformer_blocks) // len(
+ block_controlnet_hidden_states
+ )
+ hidden_states = (
+ hidden_states
+ + block_controlnet_hidden_states[
+ index_block // interval_control
+ ]
+ )
+ continue
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = (
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ )
+ encoder_hidden_states, hidden_states = (
+ torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ )
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ )
+
+ # controlnet residual
+ if (
+ block_controlnet_hidden_states is not None
+ and block.context_pre_only is False
+ ):
+ interval_control = len(self.transformer_blocks) // len(
+ block_controlnet_hidden_states
+ )
+ hidden_states = (
+ hidden_states
+ + block_controlnet_hidden_states[index_block // interval_control]
+ )
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # unpatchify
+ patch_size = self.config.patch_size
+ height = height // patch_size
+ width = width // patch_size
+
+ hidden_states = hidden_states.reshape(
+ shape=(
+ hidden_states.shape[0],
+ height,
+ width,
+ patch_size,
+ patch_size,
+ self.out_channels,
+ )
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(
+ hidden_states.shape[0],
+ self.out_channels,
+ height * patch_size,
+ width * patch_size,
+ )
+ )
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/helpers/prompts.py b/helpers/prompts.py
index 3e2fc9dd..98845f98 100644
--- a/helpers/prompts.py
+++ b/helpers/prompts.py
@@ -5,6 +5,13 @@
from helpers.training.multi_process import _get_rank as get_rank
from helpers.training import image_file_extensions
+import numpy
+
+try:
+ import pandas as pd
+except ImportError:
+ raise ImportError("Pandas is required for the ParquetMetadataBackend.")
+
prompts = {
"alien_landscape": "Alien planet, strange rock formations, glowing plants, bizarre creatures, surreal atmosphere",
"alien_market": "Alien marketplace, bizarre creatures, exotic goods, vibrant colors, otherworldly atmosphere",
@@ -256,8 +263,10 @@ def prepare_instance_prompt_from_parquet(
)
if type(image_caption) == bytes:
image_caption = image_caption.decode("utf-8")
- if image_caption:
+ if type(image_caption) == str:
image_caption = image_caption.strip()
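+    # Parquet caption columns may hold list-like values (list, tuple, numpy array or pandas Series);
+    # normalise those to a plain list of stripped strings, dropping any null entries.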
+ if type(image_caption) in (list, tuple, numpy.ndarray, pd.Series):
+ image_caption = [str(item).strip() for item in image_caption if item is not None]
if prepend_instance_prompt:
if type(image_caption) == list:
image_caption = [instance_prompt + " " + x for x in image_caption]
@@ -436,17 +445,14 @@ def get_all_captions(
data_backend=data_backend,
)
elif caption_strategy == "parquet":
- try:
- caption = PromptHandler.prepare_instance_prompt_from_parquet(
- image_path,
- use_captions=use_captions,
- prepend_instance_prompt=prepend_instance_prompt,
- instance_prompt=instance_prompt,
- data_backend=data_backend,
- sampler_backend_id=data_backend.id,
- )
- except:
- continue
+ caption = PromptHandler.prepare_instance_prompt_from_parquet(
+ image_path,
+ use_captions=use_captions,
+ prepend_instance_prompt=prepend_instance_prompt,
+ instance_prompt=instance_prompt,
+ data_backend=data_backend,
+ sampler_backend_id=data_backend.id,
+ )
elif caption_strategy == "instanceprompt":
return [instance_prompt]
elif caption_strategy == "csv":
diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py
index f41600a2..b0031157 100644
--- a/helpers/publishing/metadata.py
+++ b/helpers/publishing/metadata.py
@@ -153,6 +153,15 @@ def _guidance_rescale(args):
return f"\n guidance_rescale={args.validation_guidance_rescale},"
+def _skip_layers(args):
+ if (
+ args.model_family.lower() not in ["sd3"]
+ or args.validation_guidance_skip_layers is None
+ ):
+ return ""
+ return f"\n skip_guidance_layers={args.validation_guidance_skip_layers},"
+
+
def _validation_resolution(args):
if args.validation_resolution == "" or args.validation_resolution is None:
return f"width=1024,\n" f" height=1024,"
@@ -185,7 +194,7 @@ def code_example(args, repo_id: str = None):
num_inference_steps={args.validation_num_inference_steps},
generator=torch.Generator(device={_torch_device()}).manual_seed(1641421826),
{_validation_resolution(args)}
- guidance_scale={args.validation_guidance},{_guidance_rescale(args)}
+    guidance_scale={args.validation_guidance},{_guidance_rescale(args)}{_skip_layers(args)}
).images[0]
image.save("output.png", format="PNG")
```
@@ -249,10 +258,38 @@ def flux_schedule_info(args):
output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
if args.flux_attention_masked_training:
output_args.append("flux_attention_masked_training")
- if args.model_type == "lora" and args.lora_type == "standard":
+ if (
+ args.model_type == "lora"
+ and args.lora_type == "standard"
+ and args.flux_lora_target is not None
+ ):
output_args.append(f"flux_lora_target={args.flux_lora_target}")
output_str = (
- f" (flux parameters={output_args})"
+ f" (extra parameters={output_args})"
+ if output_args
+ else " (no special parameters set)"
+ )
+
+ return output_str
+
+
+def sd3_schedule_info(args):
+ if args.model_family.lower() != "sd3":
+ return ""
+ output_args = []
+ if args.flux_schedule_auto_shift:
+ output_args.append("flux_schedule_auto_shift")
+ if args.flux_schedule_shift is not None:
+ output_args.append(f"shift={args.flux_schedule_shift}")
+ if args.flux_use_beta_schedule:
+ output_args.append(f"flux_beta_schedule_alpha={args.flux_beta_schedule_alpha}")
+ output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
+ if args.flux_use_uniform_schedule:
+        output_args.append("flux_use_uniform_schedule")
+ # if args.model_type == "lora" and args.lora_type == "standard":
+ # output_args.append(f"flux_lora_target={args.flux_lora_target}")
+ output_str = (
+ f" (extra parameters={output_args})"
if output_args
else " (no special parameters set)"
)
@@ -260,6 +297,13 @@ def flux_schedule_info(args):
return output_str
+def model_schedule_info(args):
+ if args.model_family == "flux":
+ return flux_schedule_info(args)
+ if args.model_family == "sd3":
+        return sd3_schedule_info(args)
+    return ""
+
+
def save_model_card(
repo_id: str,
images=None,
@@ -384,7 +428,7 @@ def save_model_card(
- Micro-batch size: {StateTracker.get_args().train_batch_size}
- Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps}
- Number of GPUs: {StateTracker.get_accelerator().num_processes}
-- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{flux_schedule_info(args=StateTracker.get_args())}
+- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())}
- Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr}
- Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''}
- Precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
diff --git a/helpers/training/adapter.py b/helpers/training/adapter.py
index 04b99069..caaebdd8 100644
--- a/helpers/training/adapter.py
+++ b/helpers/training/adapter.py
@@ -107,7 +107,11 @@ def load_lora_weights(dictionary, filename, loraKey="default", use_dora=False):
missing_keys = set(
[x + ".lora_A.weight" for x in lora_layers.keys()]
+ [x + ".lora_B.weight" for x in lora_layers.keys()]
- + ([x + ".lora_magnitude_vector.weight"] if use_dora else [])
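+        # DoRA adapters also store a per-layer magnitude vector; expect those keys for every LoRA layer.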
+ + (
+ [x + ".lora_magnitude_vector.weight" for x in lora_layers.keys()]
+ if use_dora
+ else []
+ )
)
for k, v in state_dict.items():
if "lora_A" in k:
diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py
index 78611dc8..5ac0c207 100644
--- a/helpers/training/diffusion_model.py
+++ b/helpers/training/diffusion_model.py
@@ -38,7 +38,7 @@ def load_diffusion_model(args, weight_dtype):
# Stable Diffusion 3 uses a Diffusion transformer.
logger.info("Loading Stable Diffusion 3 diffusion transformer..")
try:
- from diffusers import SD3Transformer2DModel
+ from helpers.models.sd3.transformer import SD3Transformer2DModel
except Exception as e:
logger.error(
f"Can not load SD3 model class. This release requires the latest version of Diffusers: {e}"
diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py
index 51fb85de..0186d834 100644
--- a/helpers/training/save_hooks.py
+++ b/helpers/training/save_hooks.py
@@ -11,6 +11,7 @@
from helpers.models.sdxl.pipeline import StableDiffusionXLPipeline
from helpers.training.state_tracker import StateTracker
from helpers.models.smoldit import SmolDiT2DModel, SmolDiTPipeline
+from helpers.models.sd3.transformer import SD3Transformer2DModel
import os
import logging
import shutil
@@ -27,7 +28,6 @@
from diffusers import (
UNet2DConditionModel,
StableDiffusion3Pipeline,
- SD3Transformer2DModel,
StableDiffusionPipeline,
FluxPipeline,
PixArtSigmaPipeline,
diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py
index a5e70d44..4a0eb031 100644
--- a/helpers/training/trainer.py
+++ b/helpers/training/trainer.py
@@ -2185,7 +2185,7 @@ def train(self):
if self.config.flow_matching:
if (
not self.config.flux_fast_schedule
- and not self.config.flux_use_beta_schedule
+ and not any([self.config.flux_use_beta_schedule, self.config.flux_use_uniform_schedule])
):
# imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF
# also used by: https://github.com/XLabs-AI/x-flux/tree/main
@@ -2197,6 +2197,11 @@ def train(self):
sigmas = apply_flux_schedule_shift(
self.config, self.noise_scheduler, sigmas, noise
)
+ elif self.config.flux_use_uniform_schedule:
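+                        # Uniform schedule: draw sigmas uniformly from [0, 1), then apply the configured shift.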
+ sigmas = torch.rand((bsz,), device=self.accelerator.device)
+ sigmas = apply_flux_schedule_shift(
+ self.config, self.noise_scheduler, sigmas, noise
+ )
elif self.config.flux_use_beta_schedule:
alpha = self.config.flux_beta_schedule_alpha
beta = self.config.flux_beta_schedule_beta
diff --git a/helpers/training/validation.py b/helpers/training/validation.py
index b9406a77..478d56af 100644
--- a/helpers/training/validation.py
+++ b/helpers/training/validation.py
@@ -969,6 +969,10 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True):
"vae": self.vae,
"safety_checker": None,
}
+ if self.args.model_family in ["sd3", "sdxl", "flux"]:
+ extra_pipeline_kwargs["text_encoder_2"] = None
+ if self.args.model_family in ["sd3"]:
+ extra_pipeline_kwargs["text_encoder_3"] = None
if type(pipeline_cls) is StableDiffusionXLPipeline:
del extra_pipeline_kwargs["safety_checker"]
del extra_pipeline_kwargs["text_encoder"]
@@ -1071,7 +1075,7 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True):
logger.error(e)
logger.error(traceback.format_exc())
continue
- return None
+ break
if self.args.validation_torch_compile:
if self.unet is not None and not is_compiled_module(self.unet):
logger.warning(
@@ -1192,6 +1196,23 @@ def validate_prompt(
else:
validation_resolution_width, validation_resolution_height = resolution
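+        # SD3 skip-layer guidance (SLG): forward the configured skip-layer settings to the validation pipeline call.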
+ if (
+ self.args.model_family == "sd3"
+ and type(self.args.validation_guidance_skip_layers) is list
+ ):
+ extra_validation_kwargs["skip_layer_guidance_start"] = float(
+ self.args.validation_guidance_skip_layers_start
+ )
+ extra_validation_kwargs["skip_layer_guidance_stop"] = float(
+ self.args.validation_guidance_skip_layers_stop
+ )
+ extra_validation_kwargs["skip_layer_guidance_scale"] = float(
+ self.args.validation_guidance_skip_scale
+ )
+ extra_validation_kwargs["skip_guidance_layers"] = list(
+ self.args.validation_guidance_skip_layers
+ )
+
if not self.flow_matching and self.args.model_family not in [
"deepfloyd",
"pixart_sigma",
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index aa6d5d26..7ae433b5 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -1,10 +1,12 @@
import unittest
+import pandas as pd
from unittest.mock import patch, Mock, MagicMock
from PIL import Image
from pathlib import Path
from helpers.multiaspect.dataset import MultiAspectDataset
from helpers.metadata.backends.discovery import DiscoveryMetadataBackend
from helpers.data_backend.base import BaseDataBackend
+from helpers.data_backend.factory import check_column_values
class TestMultiAspectDataset(unittest.TestCase):
@@ -82,5 +84,62 @@ def test_getitem_invalid_image(self):
self.dataset.__getitem__(self.image_metadata)
+class TestDataBackendFactory(unittest.TestCase):
+ def test_all_null(self):
+ column_data = pd.Series([None, None, None])
+ with self.assertRaises(ValueError) as context:
+ check_column_values(column_data, "test_column", "test_file.parquet")
+ self.assertIn("contains only null values", str(context.exception))
+
+ def test_arrays_with_nulls(self):
+ column_data = pd.Series([[1, 2], None, [3, 4]])
+ with self.assertRaises(ValueError) as context:
+ check_column_values(column_data, "test_column", "test_file.parquet")
+ self.assertIn("contains null arrays", str(context.exception))
+
+ def test_empty_arrays(self):
+ column_data = pd.Series([[1, 2], [], [3, 4]])
+ with self.assertRaises(ValueError) as context:
+ check_column_values(column_data, "test_column", "test_file.parquet")
+ self.assertIn("contains empty arrays", str(context.exception))
+
+ def test_null_elements_in_arrays(self):
+ column_data = pd.Series([[1, None], [2, 3], [3, 4]])
+ with self.assertRaises(ValueError) as context:
+ check_column_values(column_data, "test_column", "test_file.parquet")
+ self.assertIn("contains null values within arrays", str(context.exception))
+
+ def test_empty_strings_in_arrays(self):
+ column_data = pd.Series([["", ""], ["", ""], ["", ""]])
+ with self.assertRaises(ValueError) as context:
+ check_column_values(column_data, "test_column", "test_file.parquet")
+ self.assertIn("contains only empty strings within arrays", str(context.exception))
+
+ def test_scalar_strings_with_nulls(self):
+ column_data = pd.Series(["a", None, "b"])
+ with self.assertRaises(ValueError) as context:
+ check_column_values(column_data, "test_column", "test_file.parquet")
+ self.assertIn("contains null values", str(context.exception))
+
+ def test_scalar_strings_with_empty(self):
+ column_data = pd.Series(["a", "", "b"])
+ with self.assertRaises(ValueError) as context:
+ check_column_values(column_data, "test_column", "test_file.parquet")
+ self.assertIn("contains empty strings", str(context.exception))
+
+ def test_with_fallback_caption(self):
+ column_data = pd.Series([None, "", [None], [""]])
+ try:
+ check_column_values(column_data, "test_column", "test_file.parquet", fallback_caption_column=True)
+ except ValueError:
+ self.fail("check_column_values() raised ValueError unexpectedly with fallback_caption_column=True")
+
+ def test_invalid_data_type(self):
+ column_data = pd.Series([1, 2, 3])
+ with self.assertRaises(TypeError) as context:
+ check_column_values(column_data, "test_column", "test_file.parquet")
+ self.assertIn("Unsupported data type in column", str(context.exception))
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_model_card.py b/tests/test_model_card.py
new file mode 100644
index 00000000..b9c07d33
--- /dev/null
+++ b/tests/test_model_card.py
@@ -0,0 +1,283 @@
+import unittest
+from unittest.mock import MagicMock, patch
+import os
+import json
+
+from helpers.publishing.metadata import (
+ _negative_prompt,
+ _torch_device,
+ _model_imports,
+ _model_load,
+ _validation_resolution,
+ _skip_layers,
+ _guidance_rescale,
+)
+from helpers.publishing.metadata import *
+
+
+class TestMetadataFunctions(unittest.TestCase):
+ def setUp(self):
+ # Mock the args object
+ self.args = MagicMock()
+ self.args.lora_type = "standard"
+ self.args.model_type = "lora"
+ self.args.model_family = "sdxl"
+ self.args.validation_prompt = "A test prompt"
+ self.args.validation_negative_prompt = "A negative prompt"
+ self.args.validation_num_inference_steps = 50
+ self.args.validation_guidance = 7.5
+ self.args.validation_guidance_rescale = 0.7
+ self.args.validation_resolution = "512x512"
+ self.args.pretrained_model_name_or_path = "test-model"
+ self.args.output_dir = "test-output"
+ self.args.lora_rank = 4
+ self.args.lora_alpha = 1.0
+ self.args.lora_dropout = 0.0
+ self.args.lora_init_type = "kaiming_uniform"
+ self.args.model_card_note = "Test note"
+ self.args.validation_using_datasets = False
+ self.args.flow_matching_loss = "flow-matching"
+ self.args.flux_fast_schedule = False
+ self.args.flux_schedule_auto_shift = False
+ self.args.flux_schedule_shift = None
+ self.args.flux_guidance_value = None
+ self.args.flux_guidance_min = None
+ self.args.flux_guidance_max = None
+ self.args.flux_use_beta_schedule = False
+ self.args.flux_beta_schedule_alpha = None
+ self.args.flux_beta_schedule_beta = None
+ self.args.flux_attention_masked_training = False
+ self.args.flux_use_uniform_schedule = False
+ self.args.flux_lora_target = None
+ self.args.validation_guidance_skip_layers = None
+ self.args.validation_seed = 1234
+ self.args.validation_noise_scheduler = "ddim"
+ self.args.model_card_safe_for_work = True
+ self.args.learning_rate = 1e-4
+ self.args.max_grad_norm = 1.0
+ self.args.train_batch_size = 4
+ self.args.gradient_accumulation_steps = 1
+ self.args.optimizer = "AdamW"
+ self.args.optimizer_config = ""
+ self.args.mixed_precision = "fp16"
+ self.args.base_model_precision = "no_change"
+ self.args.enable_xformers_memory_efficient_attention = False
+
+ def test_model_imports(self):
+ self.args.lora_type = "standard"
+ self.args.model_type = "lora"
+ expected_output = "import torch\nfrom diffusers import DiffusionPipeline"
+ output = _model_imports(self.args)
+ self.assertEqual(output.strip(), expected_output.strip())
+
+ self.args.lora_type = "lycoris"
+ output = _model_imports(self.args)
+ self.assertIn("from lycoris import create_lycoris_from_weights", output)
+
+ def test_model_load(self):
+ self.args.pretrained_model_name_or_path = "pretrained-model"
+ self.args.output_dir = "output-dir"
+ self.args.lora_type = "standard"
+ self.args.model_type = "lora"
+
+ with patch(
+ "helpers.publishing.metadata.StateTracker.get_hf_username",
+ return_value="testuser",
+ ):
+ output = _model_load(self.args, repo_id="repo-id")
+ self.assertIn("pipeline.load_lora_weights", output)
+ self.assertIn("adapter_id = 'testuser/repo-id'", output)
+
+ self.args.lora_type = "lycoris"
+ output = _model_load(self.args)
+ self.assertIn("pytorch_lora_weights.safetensors", output)
+
+ def test_torch_device(self):
+ output = _torch_device()
+ expected_output = "'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'"
+ self.assertEqual(output.strip(), expected_output.strip())
+
+ def test_negative_prompt(self):
+ self.args.model_family = "sdxl"
+ output = _negative_prompt(self.args)
+ expected_output = "negative_prompt = 'A negative prompt'"
+ self.assertEqual(output.strip(), expected_output.strip())
+
+ output_in_call = _negative_prompt(self.args, in_call=True)
+ self.assertIn("negative_prompt=negative_prompt", output_in_call)
+
+ def test_guidance_rescale(self):
+ self.args.model_family = "sdxl"
+ output = _guidance_rescale(self.args)
+ expected_output = "\n guidance_rescale=0.7,"
+ self.assertEqual(output.strip(), expected_output.strip())
+
+ self.args.model_family = "flux"
+ output = _guidance_rescale(self.args)
+ self.assertEqual(output.strip(), "")
+
+ def test_skip_layers(self):
+ self.args.model_family = "sd3"
+ self.args.validation_guidance_skip_layers = 2
+ output = _skip_layers(self.args)
+ expected_output = "\n skip_guidance_layers=2,"
+ self.assertEqual(output.strip(), expected_output.strip())
+
+ self.args.model_family = "sdxl"
+ output = _skip_layers(self.args)
+ self.assertEqual(output.strip(), "")
+
+ def test_validation_resolution(self):
+ self.args.validation_resolution = "512x512"
+ output = _validation_resolution(self.args)
+ expected_output = "width=512,\n height=512,"
+ self.assertEqual(output.strip(), expected_output.strip())
+
+ self.args.validation_resolution = ""
+ output = _validation_resolution(self.args)
+ expected_output = "width=1024,\n height=1024,"
+ self.assertEqual(output.strip(), expected_output.strip())
+
+ def test_code_example(self):
+ with patch(
+ "helpers.publishing.metadata._model_imports",
+ return_value="import torch\nfrom diffusers import DiffusionPipeline",
+ ):
+ with patch(
+ "helpers.publishing.metadata._model_load", return_value="pipeline = ..."
+ ):
+ with patch(
+ "helpers.publishing.metadata._torch_device", return_value="'cuda'"
+ ):
+ with patch(
+ "helpers.publishing.metadata._negative_prompt",
+ return_value="negative_prompt = 'A negative prompt'",
+ ):
+ with patch(
+ "helpers.publishing.metadata._validation_resolution",
+ return_value="width=512,\n height=512,",
+ ):
+ output = code_example(self.args)
+ self.assertIn("import torch", output)
+ self.assertIn("pipeline = ...", output)
+ self.assertIn("pipeline.to('cuda')", output)
+
+ def test_model_type(self):
+ self.args.model_type = "lora"
+ self.args.lora_type = "standard"
+ output = model_type(self.args)
+ self.assertEqual(output, "standard PEFT LoRA")
+
+ self.args.lora_type = "lycoris"
+ output = model_type(self.args)
+ self.assertEqual(output, "LyCORIS adapter")
+
+ self.args.model_type = "full"
+ output = model_type(self.args)
+ self.assertEqual(output, "full rank finetune")
+
+ def test_lora_info(self):
+ self.args.model_type = "lora"
+ self.args.lora_type = "standard"
+ output = lora_info(self.args)
+ self.assertIn("LoRA Rank: 4", output)
+
+ self.args.lora_type = "lycoris"
+ # Mocking the file reading
+ lycoris_config = {"key": "value"}
+ with patch(
+ "builtins.open",
+ unittest.mock.mock_open(read_data=json.dumps(lycoris_config)),
+ ):
+ output = lora_info(self.args)
+ self.assertIn('"key": "value"', output)
+
+ def test_model_card_note(self):
+ output = model_card_note(self.args)
+ self.assertIn("Test note", output)
+
+ self.args.model_card_note = ""
+ output = model_card_note(self.args)
+ self.assertEqual(output.strip(), "")
+
+ def test_flux_schedule_info(self):
+ self.args.model_family = "flux"
+ output = flux_schedule_info(self.args)
+ self.assertIn("(no special parameters set)", output)
+
+ self.args.flux_fast_schedule = True
+ output = flux_schedule_info(self.args)
+ self.assertIn("flux_fast_schedule", output)
+
+ def test_sd3_schedule_info(self):
+ self.args.model_family = "sd3"
+ output = sd3_schedule_info(self.args)
+ self.assertIn("(no special parameters set)", output)
+
+ self.args.flux_schedule_auto_shift = True
+ output = sd3_schedule_info(self.args)
+ self.assertIn("flux_schedule_auto_shift", output)
+
+ def test_model_schedule_info(self):
+ with patch(
+ "helpers.publishing.metadata.flux_schedule_info", return_value="flux info"
+ ):
+ with patch(
+ "helpers.publishing.metadata.sd3_schedule_info", return_value="sd3 info"
+ ):
+ self.args.model_family = "flux"
+ output = model_schedule_info(self.args)
+ self.assertEqual(output, "flux info")
+
+ self.args.model_family = "sd3"
+ output = model_schedule_info(self.args)
+ self.assertEqual(output, "sd3 info")
+
+ def test_save_model_card(self):
+ # Mocking StateTracker methods
+ with patch(
+ "helpers.publishing.metadata.StateTracker.get_model_family",
+ return_value="sdxl",
+ ):
+ with patch(
+ "helpers.publishing.metadata.StateTracker.get_data_backends",
+ return_value={},
+ ):
+ with patch(
+ "helpers.publishing.metadata.StateTracker.get_epoch", return_value=1
+ ):
+ with patch(
+ "helpers.publishing.metadata.StateTracker.get_global_step",
+ return_value=1000,
+ ):
+ with patch(
+ "helpers.publishing.metadata.StateTracker.get_accelerator",
+ return_value=MagicMock(num_processes=1),
+ ):
+ with patch(
+ "helpers.publishing.metadata.code_example",
+ return_value="code example",
+ ):
+ with patch(
+ "builtins.open", unittest.mock.mock_open()
+ ) as mock_file:
+ save_model_card(
+ repo_id="test-repo",
+ images=None,
+ base_model="test-base-model",
+ train_text_encoder=True,
+ prompt="Test prompt",
+ validation_prompts=["Test prompt"],
+ validation_shortnames=["shortname"],
+ repo_folder="test-folder",
+ )
+ # Ensure the README.md was written
+ mock_file.assert_called_with(
+ os.path.join("test-folder", "README.md"),
+ "w",
+ encoding="utf-8",
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 5d93cf18..c54789d6 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -139,6 +139,7 @@ def test_stats_memory_used_none(
output_dir="output_dir",
flux_schedule_shift=3,
flux_schedule_auto_shift=False,
+ validation_guidance_skip_layers=None,
),
)
def test_misc_init(