From 21f149938bbf0f57b0119944e3865ebe5feace07 Mon Sep 17 00:00:00 2001 From: Jimmy <39@🇺🇸.com> Date: Tue, 29 Oct 2024 20:36:20 -0400 Subject: [PATCH 01/27] Fix multi-caption parquets crashing in multiple locations (Closes #1092) --- helpers/data_backend/factory.py | 99 +++++++++++++++++++++++----- helpers/metadata/backends/parquet.py | 12 ++-- helpers/prompts.py | 30 +++++---- tests/test_dataset.py | 59 +++++++++++++++++ 4 files changed, 165 insertions(+), 35 deletions(-) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index dba8d7c4..8b3fbf2f 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -24,6 +24,8 @@ from tqdm import tqdm import queue from math import sqrt +import pandas as pd +import numpy as np logger = logging.getLogger("DataBackendFactory") if should_log(): @@ -48,6 +50,68 @@ def info_log(message): logger.info(message) +def check_column_values(column_data, column_name, parquet_path, fallback_caption_column=False): + # Determine if the column contains arrays or scalar values + non_null_values = column_data.dropna() + if non_null_values.empty: + # All values are null + raise ValueError( + f"Parquet file {parquet_path} contains only null values in the '{column_name}' column." + ) + + first_non_null = non_null_values.iloc[0] + if isinstance(first_non_null, (list, tuple, np.ndarray, pd.Series)): + # Column contains arrays + # Check for null arrays + if column_data.isnull().any() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains null arrays in the '{column_name}' column." + ) + + # Check for empty arrays + empty_arrays = column_data.apply(lambda x: len(x) == 0) + if empty_arrays.any() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains empty arrays in the '{column_name}' column." + ) + + # Check for null elements within arrays + null_elements_in_arrays = column_data.apply( + lambda arr: any(pd.isnull(s) for s in arr) + ) + if null_elements_in_arrays.any() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains null values within arrays in the '{column_name}' column." + ) + + # Check for empty strings within arrays + empty_strings_in_arrays = column_data.apply( + lambda arr: any(s == "" for s in arr) + ) + if empty_strings_in_arrays.all() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains only empty strings within arrays in the '{column_name}' column." + ) + + elif isinstance(first_non_null, str): + # Column contains scalar strings + # Check for null values + if column_data.isnull().any() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains null values in the '{column_name}' column." + ) + + # Check for empty strings + if (column_data == "").any() and not fallback_caption_column: + raise ValueError( + f"Parquet file {parquet_path} contains empty strings in the '{column_name}' column." + ) + else: + raise TypeError( + f"Unsupported data type in column '{column_name}'. Expected strings or arrays of strings." + ) + + def init_backend_config(backend: dict, args: dict, accelerator) -> dict: output = {"id": backend["id"], "config": {}} if backend.get("dataset_type", None) == "text_embeds": @@ -292,24 +356,23 @@ def configure_parquet_database(backend: dict, args, data_backend: BaseDataBacken raise ValueError( f"Parquet file {parquet_path} does not contain a column named '{filename_column}'." 
) - # Check for null values - if df[caption_column].isnull().values.any() and not fallback_caption_column: - raise ValueError( - f"Parquet file {parquet_path} contains null values in the '{caption_column}' column, but no fallback_caption_column was set." - ) - if df[filename_column].isnull().values.any(): - raise ValueError( - f"Parquet file {parquet_path} contains null values in the '{filename_column}' column." - ) - # Check for empty strings - if (df[caption_column] == "").sum() > 0 and not fallback_caption_column: - raise ValueError( - f"Parquet file {parquet_path} contains empty strings in the '{caption_column}' column." - ) - if (df[filename_column] == "").sum() > 0: - raise ValueError( - f"Parquet file {parquet_path} contains empty strings in the '{filename_column}' column." - ) + + # Apply the function to the caption_column. + check_column_values( + df[caption_column], + caption_column, + parquet_path, + fallback_caption_column=fallback_caption_column + ) + + # Apply the function to the filename_column. + check_column_values( + df[filename_column], + filename_column, + parquet_path, + fallback_caption_column=False # Always check filename_column + ) + # Store the database in StateTracker StateTracker.set_parquet_database( backend["id"], diff --git a/helpers/metadata/backends/parquet.py b/helpers/metadata/backends/parquet.py index 584cd9c0..2850f986 100644 --- a/helpers/metadata/backends/parquet.py +++ b/helpers/metadata/backends/parquet.py @@ -150,11 +150,13 @@ def _extract_captions_to_fast_list(self): if len(caption_column) > 0: caption = [row[c] for c in caption_column] else: - caption = row[caption_column] + caption = row.get(caption_column) + if isinstance(caption, (numpy.ndarray, pd.Series)): + caption = [str(item) for item in caption if item is not None] - if not caption and fallback_caption_column: - caption = row[fallback_caption_column] - if not caption: + if caption is None and fallback_caption_column: + caption = row.get(fallback_caption_column, None) + if caption is None or caption == "" or caption == []: raise ValueError( f"Could not locate caption for image {filename} in sampler_backend {self.id} with filename column {filename_column}, caption column {caption_column}, and a parquet database with {len(self.parquet_database)} entries." 
) @@ -162,7 +164,7 @@ def _extract_captions_to_fast_list(self): caption = caption.decode("utf-8") elif type(caption) == list: caption = [c.strip() for c in caption if c.strip()] - if caption: + elif type(caption) == str: caption = caption.strip() captions[filename] = caption return captions diff --git a/helpers/prompts.py b/helpers/prompts.py index bac03d77..448f0e79 100644 --- a/helpers/prompts.py +++ b/helpers/prompts.py @@ -5,6 +5,13 @@ from helpers.training.multi_process import _get_rank as get_rank from helpers.training import image_file_extensions +import numpy + +try: + import pandas as pd +except ImportError: + raise ImportError("Pandas is required for the ParquetMetadataBackend.") + prompts = { "alien_landscape": "Alien planet, strange rock formations, glowing plants, bizarre creatures, surreal atmosphere", "alien_market": "Alien marketplace, bizarre creatures, exotic goods, vibrant colors, otherworldly atmosphere", @@ -256,8 +263,10 @@ def prepare_instance_prompt_from_parquet( ) if type(image_caption) == bytes: image_caption = image_caption.decode("utf-8") - if image_caption: + if type(image_caption) == str: image_caption = image_caption.strip() + if type(image_caption) in (list, tuple, numpy.ndarray, pd.Series): + image_caption = [str(item).strip() for item in image_caption if item is not None] if prepend_instance_prompt: if type(image_caption) == list: image_caption = [instance_prompt + " " + x for x in image_caption] @@ -436,17 +445,14 @@ def get_all_captions( data_backend=data_backend, ) elif caption_strategy == "parquet": - try: - caption = PromptHandler.prepare_instance_prompt_from_parquet( - image_path, - use_captions=use_captions, - prepend_instance_prompt=prepend_instance_prompt, - instance_prompt=instance_prompt, - data_backend=data_backend, - sampler_backend_id=data_backend.id, - ) - except: - continue + caption = PromptHandler.prepare_instance_prompt_from_parquet( + image_path, + use_captions=use_captions, + prepend_instance_prompt=prepend_instance_prompt, + instance_prompt=instance_prompt, + data_backend=data_backend, + sampler_backend_id=data_backend.id, + ) elif caption_strategy == "instanceprompt": return [instance_prompt] elif caption_strategy == "csv": diff --git a/tests/test_dataset.py b/tests/test_dataset.py index aa6d5d26..7ae433b5 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,10 +1,12 @@ import unittest +import pandas as pd from unittest.mock import patch, Mock, MagicMock from PIL import Image from pathlib import Path from helpers.multiaspect.dataset import MultiAspectDataset from helpers.metadata.backends.discovery import DiscoveryMetadataBackend from helpers.data_backend.base import BaseDataBackend +from helpers.data_backend.factory import check_column_values class TestMultiAspectDataset(unittest.TestCase): @@ -82,5 +84,62 @@ def test_getitem_invalid_image(self): self.dataset.__getitem__(self.image_metadata) +class TestDataBackendFactory(unittest.TestCase): + def test_all_null(self): + column_data = pd.Series([None, None, None]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains only null values", str(context.exception)) + + def test_arrays_with_nulls(self): + column_data = pd.Series([[1, 2], None, [3, 4]]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains null arrays", str(context.exception)) + + def test_empty_arrays(self): + column_data = 
pd.Series([[1, 2], [], [3, 4]]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains empty arrays", str(context.exception)) + + def test_null_elements_in_arrays(self): + column_data = pd.Series([[1, None], [2, 3], [3, 4]]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains null values within arrays", str(context.exception)) + + def test_empty_strings_in_arrays(self): + column_data = pd.Series([["", ""], ["", ""], ["", ""]]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains only empty strings within arrays", str(context.exception)) + + def test_scalar_strings_with_nulls(self): + column_data = pd.Series(["a", None, "b"]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains null values", str(context.exception)) + + def test_scalar_strings_with_empty(self): + column_data = pd.Series(["a", "", "b"]) + with self.assertRaises(ValueError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("contains empty strings", str(context.exception)) + + def test_with_fallback_caption(self): + column_data = pd.Series([None, "", [None], [""]]) + try: + check_column_values(column_data, "test_column", "test_file.parquet", fallback_caption_column=True) + except ValueError: + self.fail("check_column_values() raised ValueError unexpectedly with fallback_caption_column=True") + + def test_invalid_data_type(self): + column_data = pd.Series([1, 2, 3]) + with self.assertRaises(TypeError) as context: + check_column_values(column_data, "test_column", "test_file.parquet") + self.assertIn("Unsupported data type in column", str(context.exception)) + + if __name__ == "__main__": unittest.main() From 48cfc09db92ff64281ab7eb6ca9e41f59bb9e23f Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 6 Nov 2024 21:24:02 -0600 Subject: [PATCH 02/27] sd3: add skip layer guidance --- helpers/configuration/cmd_args.py | 8 + helpers/models/sd3/pipeline.py | 297 +++++++++++++++++++----------- helpers/training/validation.py | 8 + 3 files changed, 209 insertions(+), 104 deletions(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index ee3b4f79..54cb693a 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1350,6 +1350,14 @@ def get_argument_parser(): " the default mode, provides the most benefit." ), ) + parser.add_argument( + "--validation_guidance_skip_layers", + type=list, + default=None, + help=( + "StabilityAI recommends a value of [7, 8, 9] for Stable Diffusion 3.5 Medium." 
+ ), + ) parser.add_argument( "--allow_tf32", action="store_true", diff --git a/helpers/models/sd3/pipeline.py b/helpers/models/sd3/pipeline.py index 1bf3332f..a91aa784 100644 --- a/helpers/models/sd3/pipeline.py +++ b/helpers/models/sd3/pipeline.py @@ -29,9 +29,12 @@ from diffusers.models.transformers import SD3Transformer2DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( + USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, ) from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline @@ -39,6 +42,7 @@ StableDiffusion3PipelineOutput, ) +from diffusers.image_processor import PipelineImageInput if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -76,7 +80,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. @@ -221,11 +225,17 @@ def __init__( if hasattr(self, "transformer") and self.transformer is not None else 128 ) + self.patch_size = ( + self.transformer.config.patch_size + if hasattr(self, "transformer") and self.transformer is not None + else 2 + ) def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, + max_sequence_length: int = 256, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -238,7 +248,7 @@ def _get_t5_prompt_embeds( if self.text_encoder_3 is None: return torch.zeros( ( - batch_size, + batch_size * num_images_per_prompt, self.tokenizer_max_length, self.transformer.config.joint_attention_dim, ), @@ -249,7 +259,7 @@ def _get_t5_prompt_embeds( text_inputs = self.tokenizer_3( prompt, padding="max_length", - max_length=self.tokenizer_max_length, + max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_tensors="pt", @@ -266,8 +276,8 @@ def _get_t5_prompt_embeds( untruncated_ids[:, self.tokenizer_max_length - 1 : -1] ) logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] @@ -368,6 +378,8 @@ def encode_prompt( pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, clip_skip: Optional[int] = None, + max_sequence_length: int = 256, + lora_scale: Optional[float] = None, ): r""" @@ -413,9 +425,22 @@ def encode_prompt( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" device = device or self._execution_device + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) @@ -448,6 +473,7 @@ def encode_prompt( t5_prompt_embed = self._get_t5_prompt_embeds( prompt=prompt_3, num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, device=device, ) @@ -520,6 +546,7 @@ def encode_prompt( t5_negative_prompt_embed = self._get_t5_prompt_embeds( prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, device=device, ) @@ -539,6 +566,16 @@ def encode_prompt( [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 ) + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + return ( prompt_embeds, negative_prompt_embeds, @@ -561,10 +598,15 @@ def check_inputs( pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, ): - if height % 8 != 0 or width % 8 != 0: + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): raise ValueError( - f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." ) if callback_on_step_end_tensor_inputs is not None and not all( @@ -647,6 +689,11 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError( + f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}" + ) + def prepare_latents( self, batch_size, @@ -733,6 +780,11 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + skip_guidance_layers: List[int] = None, + skip_layer_guidance_scale: int = 2.8, + skip_layer_guidance_stop: int = 0.2, + skip_layer_guidance_start: int = 0.01, ): r""" Function invoked when calling the pipeline for generation. 
@@ -758,7 +810,7 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 5.0): + guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > @@ -801,8 +853,8 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. + Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of + a plain tuple. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -816,12 +868,29 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + skip_guidance_layers (`List[int]`, *optional*): A list of integers that specify layers to skip during guidance. + If not provided, all layers will be used for guidance. If provided, the guidance will only be applied + to the layers specified in the list. Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is + [7, 8, 9]. + skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in + `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers` + with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers with + a scale of `1`. + skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in + `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in + `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by + StabiltyAI for Stable Diffusion 3.5 Medium is 0.2. + skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in + `skip_guidance_layers` will start. The guidance will be applied to the layers specified in + `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by + StabiltyAI for Stable Diffusion 3.5 Medium is 0.01. Examples: Returns: - [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated images. """ @@ -843,9 +912,11 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale + self._skip_layer_guidance_scale = skip_layer_guidance_scale self._clip_skip = clip_skip self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False @@ -860,6 +931,11 @@ def __call__( device = self._execution_device + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) + if self.joint_attention_kwargs is not None + else None + ) ( prompt_embeds, negative_prompt_embeds, @@ -880,9 +956,15 @@ def __call__( device=device, clip_skip=self.clip_skip, num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, ) if self.do_classifier_free_guidance: + if skip_guidance_layers is not None: + original_prompt_embeds = prompt_embeds + original_pooled_prompt_embeds = pooled_prompt_embeds + # we do not combine the inference if we skip guidance layers. prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat( [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0 @@ -909,8 +991,6 @@ def __call__( generator, latents, ) - latents = latents.to(self.transformer.device) - timesteps = timesteps.to(self.transformer.device) # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -921,7 +1001,7 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents] * 2) - if self.do_classifier_free_guidance + if self.do_classifier_free_guidance and skip_guidance_layers is None else latents ) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -948,6 +1028,36 @@ def __call__( noise_pred = noise_pred_uncond + self.guidance_scale * ( noise_pred_text - noise_pred_uncond ) + should_skip_layers = ( + True + if i > num_inference_steps * skip_layer_guidance_start + and i < num_inference_steps * skip_layer_guidance_stop + else False + ) + if skip_guidance_layers is not None and should_skip_layers: + noise_pred_skip_layers = self.transformer( + hidden_states=latent_model_input.to( + device=self.transformer.device, + dtype=self.transformer.dtype, + ), + timestep=timestep, + encoder_hidden_states=original_prompt_embeds.to( + device=self.transformer.device, + dtype=self.transformer.dtype, + ), + pooled_projections=original_pooled_prompt_embeds.to( + device=self.transformer.device, + dtype=self.transformer.dtype, + ), + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + skip_layers=skip_guidance_layers, + )[0] + noise_pred = ( + noise_pred + + (noise_pred_text - noise_pred_skip_layers) + * self._skip_layer_guidance_scale + ) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype @@ -1004,84 +1114,9 @@ def __call__( return StableDiffusion3PipelineOutput(images=image) -# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Dict, List, Optional, Union - -import PIL.Image -import torch -from transformers import ( - CLIPTextModelWithProjection, - CLIPTokenizer, - T5EncoderModel, - T5TokenizerFast, -) - -from diffusers.image_processor import PipelineImageInput - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - - >>> from diffusers import AutoPipelineForImage2Image - >>> from diffusers.utils import load_image - - >>> device = "cuda" - >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers" - >>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16) - >>> pipe = pipe.to(device) - - >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - >>> init_image = load_image(url).resize((512, 512)) - - >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" - - >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0] - ``` -""" - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, - generator: Optional[torch.Generator] = None, - sample_mode: str = "sample", +class StableDiffusion3Img2ImgPipeline( + DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin ): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - -class StableDiffusion3Img2ImgPipeline(DiffusionPipeline): r""" Args: transformer ([`SD3Transformer2DModel`]): @@ -1164,6 +1199,7 @@ def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, + max_sequence_length: int = 256, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -1176,7 +1212,7 @@ def _get_t5_prompt_embeds( if self.text_encoder_3 is None: return torch.zeros( ( - batch_size, + batch_size * num_images_per_prompt, self.tokenizer_max_length, self.transformer.config.joint_attention_dim, ), @@ -1187,7 +1223,7 @@ def _get_t5_prompt_embeds( text_inputs = self.tokenizer_3( prompt, padding="max_length", - max_length=self.tokenizer_max_length, + max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_tensors="pt", @@ -1204,8 +1240,8 @@ def _get_t5_prompt_embeds( untruncated_ids[:, self.tokenizer_max_length - 1 : -1] ) logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: 
{removed_text}" + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] @@ -1308,6 +1344,8 @@ def encode_prompt( pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, clip_skip: Optional[int] = None, + max_sequence_length: int = 256, + lora_scale: Optional[float] = None, ): r""" @@ -1353,9 +1391,22 @@ def encode_prompt( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ device = device or self._execution_device + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) @@ -1388,6 +1439,7 @@ def encode_prompt( t5_prompt_embed = self._get_t5_prompt_embeds( prompt=prompt_3, num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, device=device, ) @@ -1460,6 +1512,7 @@ def encode_prompt( t5_negative_prompt_embed = self._get_t5_prompt_embeds( prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, device=device, ) @@ -1479,6 +1532,16 @@ def encode_prompt( [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 ) + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + return ( prompt_embeds, negative_prompt_embeds, @@ -1500,6 +1563,7 @@ def check_inputs( pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, ): if strength < 0 or strength > 1: raise ValueError( @@ -1586,6 +1650,11 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
) + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError( + f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}" + ) + def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(num_inference_steps * strength, num_inference_steps) @@ -1613,8 +1682,6 @@ def prepare_latents( ) image = image.to(device=device, dtype=dtype) - if image.shape[1] == self.vae.config.latent_channels: - init_latents = image batch_size = batch_size * num_images_per_prompt if image.shape[1] == self.vae.config.latent_channels: @@ -1676,6 +1743,10 @@ def prepare_latents( def guidance_scale(self): return self._guidance_scale + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + @property def clip_skip(self): return self._clip_skip @@ -1719,9 +1790,11 @@ def __call__( negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, ): r""" Function invoked when calling the pipeline for generation. @@ -1747,7 +1820,7 @@ def __call__( Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 5.0): + guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > @@ -1790,8 +1863,12 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. + Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of + a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, @@ -1801,12 +1878,13 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. 
Examples: Returns: - [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ @@ -1824,10 +1902,12 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False # 2. Define call parameters @@ -1840,6 +1920,12 @@ def __call__( device = self._execution_device + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) + if self.joint_attention_kwargs is not None + else None + ) + ( prompt_embeds, negative_prompt_embeds, @@ -1860,6 +1946,8 @@ def __call__( device=device, clip_skip=self.clip_skip, num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, ) if self.do_classifier_free_guidance: @@ -1878,7 +1966,7 @@ def __call__( timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, device ) - latent_timestep = timesteps[:1].repeat(batch_size * num_inference_steps) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 5. Prepare latent variables if latents is None: @@ -1916,6 +2004,7 @@ def __call__( timestep=timestep, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] diff --git a/helpers/training/validation.py b/helpers/training/validation.py index b9406a77..88c4498b 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -1192,6 +1192,14 @@ def validate_prompt( else: validation_resolution_width, validation_resolution_height = resolution + if ( + self.args.model_family == "sd3" + and type(self.args.validation_guidance_skip_layers) is list + ): + extra_validation_kwargs["skip_guidance_layers"] = list( + self.args.validation_guidance_skip_layers + ) + if not self.flow_matching and self.args.model_family not in [ "deepfloyd", "pixart_sigma", From f0aa07ec9fd3cf07d59940e3ceee921e261d8707 Mon Sep 17 00:00:00 2001 From: Yannik Date: Thu, 7 Nov 2024 15:22:56 +0100 Subject: [PATCH 03/27] Add WSL support --- Dockerfile | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Dockerfile b/Dockerfile index eb9b0f43..3b2e60e2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,6 +11,12 @@ RUN apt-get update -y # on user input during build ENV DEBIAN_FRONTEND noninteractive +# Install libg dependencies +RUN apt install libgl1-mesa-glx -y +RUN apt-get install 'ffmpeg'\ + 'libsm6'\ + 'libxext6' -y + # Install misc unix libraries RUN apt-get install -y --no-install-recommends openssh-server \ openssh-client \ From d664101e074fb83c38464376d2ca00d5d5838bdd Mon Sep 17 00:00:00 2001 From: Yannik Date: Thu, 7 Nov 2024 16:24:59 +0100 Subject: [PATCH 04/27] Updated the docker installation guide --- INSTALL.md | 1 + documentation/DOCKER.md | 42 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/INSTALL.md b/INSTALL.md index 
db15a629..0938d3a5 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -3,6 +3,7 @@ For users that wish to make use of Docker or another container orchestration platform, see [this document](/documentation/DOCKER.md) first. ### Installation +For users operating on windows. An Installation guide based on Docker and WSL is found here [this document](/documentation/DOCKER.md). Clone the SimpleTuner repository and set up the python venv: diff --git a/documentation/DOCKER.md b/documentation/DOCKER.md index 0b39c7c3..a3bdb17b 100644 --- a/documentation/DOCKER.md +++ b/documentation/DOCKER.md @@ -11,6 +11,11 @@ This Docker configuration provides a comprehensive environment for running the S ## Getting Started +### 0. Things to keep in mind when running on windows (WSL) + +The following guide was tested in a WSL2 Distro that has the Dockerengine installed and was also contained the SimpleTuner repository. + + ### 1. Building the Container Clone the repository and navigate to the directory containing the Dockerfile. Build the Docker image using: @@ -68,6 +73,43 @@ If you want to add custom startup scripts or modify configurations, extend the e If any capabilities cannot be achieved through this setup, please open a new issue. +### Docker Compose +If you want to use a `docker-compose.yaml` feel free to oriantate on the following template. + +Once the stack is deployed you can connect to the container and start operating in it as mentioned in the steps above. + +```bash +docker compose up -d + +docker exec -it simpletuner /bin/bash +``` + +```docker-compose.yaml +services: + simpletuner: + container_name: simpletuner + build: + context: [Path to the repository]/SimpleTuner + dockerfile: Dockerfile + ports: + - "[port to connect to the container]:22" + volumes: + - "[path to your datasets]:/datasets" + - "[path to your configs]:/workspace/SimpleTuner/config" + environment: + HUGGING_FACE_HUB_TOKEN: [your hugging face token] + WANDB_TOKEN: [your wanddb token] + command: ["tail", "-f", "/dev/null"] + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] +``` + + --- ## Troubleshooting From 810c58e1f45c8b789becb51dd1ec556b379bb174 Mon Sep 17 00:00:00 2001 From: Yannik Date: Thu, 7 Nov 2024 16:26:38 +0100 Subject: [PATCH 05/27] fixing a typo --- documentation/DOCKER.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/DOCKER.md b/documentation/DOCKER.md index a3bdb17b..907042ef 100644 --- a/documentation/DOCKER.md +++ b/documentation/DOCKER.md @@ -13,7 +13,7 @@ This Docker configuration provides a comprehensive environment for running the S ### 0. Things to keep in mind when running on windows (WSL) -The following guide was tested in a WSL2 Distro that has the Dockerengine installed and was also contained the SimpleTuner repository. +The following guide was tested in a WSL2 Distro that has the Dockerengine installed and also contained the SimpleTuner repository. ### 1. Building the Container From 716e669d89c2d76326592db3548f28ea80cf35bb Mon Sep 17 00:00:00 2001 From: Bagheera <59658056+bghira@users.noreply.github.com> Date: Thu, 7 Nov 2024 09:51:11 -0600 Subject: [PATCH 06/27] Update INSTALL.md --- INSTALL.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/INSTALL.md b/INSTALL.md index 0938d3a5..a4ed4716 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -3,7 +3,8 @@ For users that wish to make use of Docker or another container orchestration platform, see [this document](/documentation/DOCKER.md) first. 
 ### Installation
-For users operating on windows. An Installation guide based on Docker and WSL is found here [this document](/documentation/DOCKER.md).
+
+For users operating on Windows 10 or newer, an installation guide based on Docker and WSL is available in [this document](/documentation/DOCKER.md).
 
 Clone the SimpleTuner repository and set up the python venv:
 

From 48dd672dfdef15cc2dd6ba7c69544e47e756fe7c Mon Sep 17 00:00:00 2001
From: Bagheera <59658056+bghira@users.noreply.github.com>
Date: Thu, 7 Nov 2024 09:52:12 -0600
Subject: [PATCH 07/27] Update documentation/DOCKER.md

---
 documentation/DOCKER.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/documentation/DOCKER.md b/documentation/DOCKER.md
index 907042ef..d2737a41 100644
--- a/documentation/DOCKER.md
+++ b/documentation/DOCKER.md
@@ -11,7 +11,7 @@ This Docker configuration provides a comprehensive environment for running the S
 
 ## Getting Started
 
-### 0. Things to keep in mind when running on windows (WSL)
+### Windows OS support via WSL (Experimental)
 
 The following guide was tested in a WSL2 Distro that has the Dockerengine installed and also contained the SimpleTuner repository.
 

From c3e67e72efa9877be501f30e7a764db82cbcb57d Mon Sep 17 00:00:00 2001
From: Bagheera <59658056+bghira@users.noreply.github.com>
Date: Thu, 7 Nov 2024 09:52:59 -0600
Subject: [PATCH 08/27] Update documentation/DOCKER.md

---
 documentation/DOCKER.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/documentation/DOCKER.md b/documentation/DOCKER.md
index d2737a41..0d8ed706 100644
--- a/documentation/DOCKER.md
+++ b/documentation/DOCKER.md
@@ -13,7 +13,7 @@ This Docker configuration provides a comprehensive environment for running the S
 
 ### Windows OS support via WSL (Experimental)
 
-The following guide was tested in a WSL2 Distro that has the Dockerengine installed and also contained the SimpleTuner repository.
+The following guide was tested in a WSL2 distro that has Docker Engine installed.
 

From dc21d1be9057a87ff20bec3e16eada851b7f0389 Mon Sep 17 00:00:00 2001
From: Bagheera <59658056+bghira@users.noreply.github.com>
Date: Thu, 7 Nov 2024 09:53:57 -0600
Subject: [PATCH 09/27] Update documentation/DOCKER.md

---
 documentation/DOCKER.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/documentation/DOCKER.md b/documentation/DOCKER.md
index 0d8ed706..2c7b9c07 100644
--- a/documentation/DOCKER.md
+++ b/documentation/DOCKER.md
@@ -74,7 +74,8 @@ If you want to add custom startup scripts or modify configurations, extend the e
 
 If any capabilities cannot be achieved through this setup, please open a new issue.
 
 ### Docker Compose
-If you want to use a `docker-compose.yaml` feel free to oriantate on the following template.
+
+For users who prefer `docker-compose.yaml`, this template is provided for you to extend and customise for your needs.
 
 Once the stack is deployed you can connect to the container and start operating in it as mentioned in the steps above.
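
As a quick smoke test of the compose template above (a minimal sketch assuming the service keeps the `simpletuner` name from the template, and that the NVIDIA container toolkit is configured on the host), you can confirm the container actually sees the GPU before launching a training run:

```bash
# Bring the stack up in the background, then check GPU visibility inside the container.
docker compose up -d
docker exec -it simpletuner nvidia-smi
```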

From 3eafaeaedc93227c99fe3b572be018075d51bdb4 Mon Sep 17 00:00:00 2001
From: Bagheera <59658056+bghira@users.noreply.github.com>
Date: Thu, 7 Nov 2024 09:56:31 -0600
Subject: [PATCH 10/27] Update documentation/DOCKER.md

---
 documentation/DOCKER.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/documentation/DOCKER.md b/documentation/DOCKER.md
index 2c7b9c07..e2eb4bbb 100644
--- a/documentation/DOCKER.md
+++ b/documentation/DOCKER.md
@@ -110,7 +110,7 @@ services:
               capabilities: [gpu]
 ```
 
-
+> ⚠️ Please be cautious of handling your WandB and Hugging Face tokens! It's advised not to commit them even to a private version-control repository to ensure they are not leaked. For production use-cases, key management storage is recommended, but out of scope for this guide.
 ---
 
 ## Troubleshooting

From 2167674947be9c92bfe3205afc1564edb52a3954 Mon Sep 17 00:00:00 2001
From: bghira
Date: Thu, 7 Nov 2024 13:53:11 -0600
Subject: [PATCH 11/27] add custom sd3 transformer modeling code

---
 helpers/models/omnigen/pipeline.py  | 367 ++++++++++++++++++++++
 helpers/models/sd3/transformer.py   | 465 ++++++++++++++++++++++++++++
 helpers/training/adapter.py         |   6 +-
 helpers/training/diffusion_model.py |   2 +-
 helpers/training/save_hooks.py      |   2 +-
 5 files changed, 839 insertions(+), 3 deletions(-)
 create mode 100644 helpers/models/omnigen/pipeline.py
 create mode 100644 helpers/models/sd3/transformer.py

diff --git a/helpers/models/omnigen/pipeline.py b/helpers/models/omnigen/pipeline.py
new file mode 100644
index 00000000..fadbaf05
--- /dev/null
+++ b/helpers/models/omnigen/pipeline.py
@@ -0,0 +1,367 @@
+import os
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+import gc
+
+from PIL import Image
+import numpy as np
+import torch
+from huggingface_hub import snapshot_download
+from peft import LoraConfig, PeftModel
+from diffusers.models import AutoencoderKL
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from safetensors.torch import load_file
+from helpers.training.state_tracker import StateTracker
+
+from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
+
+
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> from OmniGen import OmniGenPipeline
+        >>> pipe = OmniGenPipeline.from_pretrained(
+        ...     base_model
+        ... )
+        >>> prompt = "A woman holds a bouquet of flowers and faces the camera"
+        >>> image = pipe(
+        ...     prompt,
+        ...     guidance_scale=2.5,
+        ...     num_inference_steps=50,
+        ... ).images[0]
+        >>> image.save("t2i.png")
+        ```
+"""
+
+
+class OmniGenPipeline:
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        model: OmniGen,
+        processor: OmniGenProcessor,
+        device: Union[str, torch.device] = None,
+    ):
+        self.vae = vae
+        self.model = model
+        self.processor = processor
+        if device is None:
+            # `from_pretrained` constructs the pipeline without an explicit
+            # device, so fall back to CUDA when it is available.
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
+
+        self.model.to(torch.bfloat16)
+        self.model.eval()
+        self.vae.eval()
+
+        self.model_cpu_offload = False
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path, vae_path: str = None, **kwargs
+    ):
+        if not os.path.exists(pretrained_model_name_or_path) or (
+            not os.path.exists(
+                os.path.join(pretrained_model_name_or_path, "model.safetensors")
+            )
+            and pretrained_model_name_or_path == "Shitao/OmniGen-v1"
+        ):
+            logger.info("Model not found, downloading...")
+            cache_folder = os.getenv("HF_HUB_CACHE")
+            pretrained_model_name_or_path = snapshot_download(
+                repo_id=pretrained_model_name_or_path,
+                cache_dir=cache_folder,
+                ignore_patterns=[
+                    "flax_model.msgpack",
+                    "rust_model.ot",
+                    "tf_model.h5",
+                    "model.pt",
+                ],
+            )
+            logger.info(f"Downloaded model to {pretrained_model_name_or_path}")
+        model = OmniGen.from_pretrained(pretrained_model_name_or_path)
+        processor = OmniGenProcessor.from_pretrained(pretrained_model_name_or_path)
+
+        if os.path.exists(os.path.join(pretrained_model_name_or_path, "vae")):
+            vae = AutoencoderKL.from_pretrained(
+                os.path.join(pretrained_model_name_or_path, "vae")
+            )
+        elif vae_path is not None:
+            vae = AutoencoderKL.from_pretrained(vae_path).to(
+                StateTracker.get_accelerator().device
+            )
+        else:
+            logger.info(
+                f"No VAE found in {pretrained_model_name_or_path}, downloading stabilityai/sdxl-vae from HF"
+            )
+            vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(
+                StateTracker.get_accelerator().device
+            )
+
+        if kwargs:
+            logger.warning(
+                f"OmniGenPipeline received unexpected arguments: {kwargs.keys()}"
+            )
+
+        return cls(vae, model, processor)
+
+    def merge_lora(self, lora_path: str):
+        model = PeftModel.from_pretrained(self.model, lora_path)
+        model.merge_and_unload()
+
+        self.model = model
+
+    def to(self, device: Union[str, torch.device]):
+        if isinstance(device, str):
+            device = torch.device(device)
+        self.model.to(device)
+        self.vae.to(device)
+        self.device = device
+
+    def vae_encode(self, x, dtype):
+        if self.vae.config.shift_factor is not None:
+            x = self.vae.encode(x).latent_dist.sample()
+            x = (x - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+        else:
+            x = (
+                self.vae.encode(x)
+                .latent_dist.sample()
+                .mul_(self.vae.config.scaling_factor)
+            )
+        x = x.to(dtype)
+        return x
+
+    def move_to_device(self, data):
+        if isinstance(data, list):
+            return [x.to(self.device) for x in data]
+        return data.to(self.device)
+
+    def enable_model_cpu_offload(self):
+        self.model_cpu_offload = True
+        self.model.to("cpu")
+        self.vae.to("cpu")
+        torch.cuda.empty_cache()  # Clear VRAM
+        gc.collect()  # Run garbage collection to free system RAM
+
+    def disable_model_cpu_offload(self):
+        self.model_cpu_offload = False
+        self.model.to(self.device)
+        self.vae.to(self.device)
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        input_images: Union[List[str], List[List[str]]] = None,
+        height: int = 1024,
+        width: int = 1024,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 3,
+        use_img_guidance: bool = True,
+        img_guidance_scale: float = 1.6,
+        max_input_image_size: int = 1024,
+        separate_cfg_infer: bool = True,
+        offload_model: bool = False,
+        use_kv_cache: bool = True,
+        offload_kv_cache: bool = True,
+        use_input_image_size_as_output: bool = False,
+        dtype: torch.dtype = torch.bfloat16,
+        seed: int = None,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            input_images (`List[str]` or `List[List[str]]`, *optional*):
+                The list of input images. We will replace the "<|image_i|>" in the prompt with the i-th image in the list.
+            height (`int`, *optional*, defaults to 1024):
+                The height in pixels of the generated image. The number must be a multiple of 16.
+            width (`int`, *optional*, defaults to 1024):
+                The width in pixels of the generated image. The number must be a multiple of 16.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 3.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            use_img_guidance (`bool`, *optional*, defaults to True):
+                Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
+            img_guidance_scale (`float`, *optional*, defaults to 1.6):
+                Defined as equation 3 in [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
+            max_input_image_size (`int`, *optional*, defaults to 1024): the maximum size of an input image, which will be used to crop the input image to this size
+            separate_cfg_infer (`bool`, *optional*, defaults to True):
+                Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
+            use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
+            offload_kv_cache (`bool`, *optional*, defaults to True): offload the cached key and value to cpu, which can save memory but slow down the generation slightly
+            offload_model (`bool`, *optional*, defaults to False): offload the model to cpu, which can save memory but slow down the generation
+            use_input_image_size_as_output (bool, defaults to False): whether to use the input image size as the output image size, which can be used for single-image input, e.g., an image editing task
+            seed (`int`, *optional*):
+                A random seed for generating output.
+            dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
+                data type for the model
+        Examples:
+
+        Returns:
+            A list with the generated images.
+        """
+        # check inputs:
+        if use_input_image_size_as_output:
+            assert (
+                isinstance(prompt, str) and len(input_images) == 1
+            ), "if you want to make sure the output image has the same size as the input image, please only input one image instead of multiple input images"
+        else:
+            assert (
+                height % 16 == 0 and width % 16 == 0
+            ), "The height and width must be a multiple of 16."
+ if input_images is None: + use_img_guidance = False + if isinstance(prompt, str): + prompt = [prompt] + input_images = [input_images] if input_images is not None else None + + # set model and processor + if max_input_image_size != self.processor.max_image_size: + self.processor = OmniGenProcessor( + self.processor.text_tokenizer, max_image_size=max_input_image_size + ) + if offload_model: + self.enable_model_cpu_offload() + else: + self.disable_model_cpu_offload() + + input_data = self.processor( + prompt, + input_images, + height=height, + width=width, + use_img_cfg=use_img_guidance, + separate_cfg_input=separate_cfg_infer, + use_input_image_size_as_output=use_input_image_size_as_output, + ) + print(f"Input shapes: {input_data['attention_mask'][0].shape}") + + num_prompt = len(prompt) + num_cfg = 2 if use_img_guidance else 1 + if use_input_image_size_as_output: + if separate_cfg_infer: + height, width = input_data["input_pixel_values"][0][0].shape[-2:] + else: + height, width = input_data["input_pixel_values"][0].shape[-2:] + latent_size_h, latent_size_w = height // 8, width // 8 + + if seed is not None: + generator = torch.Generator(device=self.device).manual_seed(seed) + else: + generator = None + latents = torch.randn( + num_prompt, + 4, + latent_size_h, + latent_size_w, + device=self.device, + generator=generator, + ) + latents = torch.cat([latents] * (1 + num_cfg), 0).to(dtype) + + if input_images is not None and self.model_cpu_offload: + self.vae.to(self.device) + input_img_latents = [] + if separate_cfg_infer: + for temp_pixel_values in input_data["input_pixel_values"]: + temp_input_latents = [] + for img in temp_pixel_values: + img = self.vae_encode(img.to(self.device), dtype) + temp_input_latents.append(img) + input_img_latents.append(temp_input_latents) + else: + for img in input_data["input_pixel_values"]: + img = self.vae_encode(img.to(self.device), dtype) + input_img_latents.append(img) + if input_images is not None and self.model_cpu_offload: + self.vae.to("cpu") + torch.cuda.empty_cache() # Clear VRAM + gc.collect() # Run garbage collection to free system RAM + + model_kwargs = dict( + input_ids=self.move_to_device(input_data["input_ids"]), + input_img_latents=input_img_latents, + input_image_sizes=input_data["input_image_sizes"], + attention_mask=self.move_to_device(input_data["attention_mask"]), + position_ids=self.move_to_device(input_data["position_ids"]), + cfg_scale=guidance_scale, + img_cfg_scale=img_guidance_scale, + use_img_cfg=use_img_guidance, + use_kv_cache=use_kv_cache, + offload_model=offload_model, + ) + + if separate_cfg_infer: + func = self.model.forward_with_separate_cfg + else: + func = self.model.forward_with_cfg + self.model.to(dtype) + + if self.model_cpu_offload: + for name, param in self.model.named_parameters(): + if "layers" in name and "layers.0" not in name: + param.data = param.data.cpu() + else: + param.data = param.data.to(self.device) + for buffer_name, buffer in self.model.named_buffers(): + setattr(self.model, buffer_name, buffer.to(self.device)) + # else: + # self.model.to(self.device) + + scheduler = OmniGenScheduler(num_steps=num_inference_steps) + samples = scheduler( + latents, + func, + model_kwargs, + use_kv_cache=use_kv_cache, + offload_kv_cache=offload_kv_cache, + ) + samples = samples.chunk((1 + num_cfg), dim=0)[0] + + if self.model_cpu_offload: + self.model.to("cpu") + torch.cuda.empty_cache() + gc.collect() + + self.vae.to(self.device) + samples = samples.to(torch.float32) + if self.vae.config.shift_factor is not None: + 
samples = ( + samples / self.vae.config.scaling_factor + self.vae.config.shift_factor + ) + else: + samples = samples / self.vae.config.scaling_factor + samples = self.vae.decode( + samples.to(dtype=self.vae.dtype, device=self.vae.device) + ).sample + + if self.model_cpu_offload: + self.vae.to("cpu") + torch.cuda.empty_cache() + gc.collect() + + output_samples = (samples * 0.5 + 0.5).clamp(0, 1) * 255 + output_samples = ( + output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() + ) + output_images = [] + for i, sample in enumerate(output_samples): + output_images.append(Image.fromarray(sample)) + + torch.cuda.empty_cache() # Clear VRAM + gc.collect() # Run garbage collection to free system RAM + return output_images diff --git a/helpers/models/sd3/transformer.py b/helpers/models/sd3/transformer.py new file mode 100644 index 00000000..943bdb7b --- /dev/null +++ b/helpers/models/sd3/transformer.py @@ -0,0 +1,465 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models.attention import JointTransformerBlock +from diffusers.models.attention_processor import ( + Attention, + AttentionProcessor, + FusedJointAttnProcessor2_0, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormContinuous +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.models.embeddings import ( + CombinedTimestepTextProjEmbeddings, + PatchEmbed, +) +from diffusers.models.modeling_outputs import ( + Transformer2DModelOutput, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SD3Transformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin +): + """ + The Transformer model introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + sample_size (`int`): The width of the latent images. This is fixed during training since + it is used to learn a number of position embeddings. + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 
+        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
+        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+        out_channels (`int`, defaults to 16): Number of output channels.
+
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 128,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        num_layers: int = 18,
+        attention_head_dim: int = 64,
+        num_attention_heads: int = 18,
+        joint_attention_dim: int = 4096,
+        caption_projection_dim: int = 1152,
+        pooled_projection_dim: int = 2048,
+        out_channels: int = 16,
+        pos_embed_max_size: int = 96,
+        dual_attention_layers: Tuple[
+            int, ...
+        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+        qk_norm: Optional[str] = None,
+    ):
+        super().__init__()
+        default_out_channels = in_channels
+        self.out_channels = (
+            out_channels if out_channels is not None else default_out_channels
+        )
+        self.inner_dim = (
+            self.config.num_attention_heads * self.config.attention_head_dim
+        )
+
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.config.in_channels,
+            embed_dim=self.inner_dim,
+            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
+        )
+        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+            embedding_dim=self.inner_dim,
+            pooled_projection_dim=self.config.pooled_projection_dim,
+        )
+        self.context_embedder = nn.Linear(
+            self.config.joint_attention_dim, self.config.caption_projection_dim
+        )
+
+        # `attention_head_dim` is doubled to account for the mixing.
+        # It needs to be crafted when we get the actual checkpoints.
+        self.transformer_blocks = nn.ModuleList(
+            [
+                JointTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=self.config.num_attention_heads,
+                    attention_head_dim=self.config.attention_head_dim,
+                    context_pre_only=i == num_layers - 1,
+                    qk_norm=qk_norm,
+                    use_dual_attention=True if i in dual_attention_layers else False,
+                )
+                for i in range(self.config.num_layers)
+            ]
+        )
+
+        self.norm_out = AdaLayerNormContinuous(
+            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
+        )
+        self.proj_out = nn.Linear(
+            self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
+        )
+
+        self.gradient_checkpointing = False
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+    def enable_forward_chunking(
+        self, chunk_size: Optional[int] = None, dim: int = 0
+    ) -> None:
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(
+            module: torch.nn.Module, chunk_size: int, dim: int
+        ):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(
+            module: torch.nn.Module, chunk_size: int, dim: int
+        ):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(
+            name: str,
+            module: torch.nn.Module,
+            processors: Dict[str, AttentionProcessor],
+        ):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
+    ):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the
+                processor for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError( + "`fuse_qkv_projections()` is not supported for models having added KV projections." + ) + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedJointAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, + block_controlnet_hidden_states: List = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + skip_layers: Optional[List[int]] = None, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`SD3Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep (`torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. 
+ skip_layers (`list` of `int`, *optional*): + A list of layer indices to skip during the forward pass. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if ( + joint_attention_kwargs is not None + and joint_attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed( + hidden_states + ) # takes care of adding positional embeddings too. + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + for index_block, block in enumerate(self.transformer_blocks): + # Skip specified layers + if skip_layers is not None and index_block in skip_layers: + if ( + block_controlnet_hidden_states is not None + and block.context_pre_only is False + ): + interval_control = len(self.transformer_blocks) // len( + block_controlnet_hidden_states + ) + hidden_states = ( + hidden_states + + block_controlnet_hidden_states[ + index_block // interval_control + ] + ) + continue + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + encoder_hidden_states, hidden_states = ( + torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + **ckpt_kwargs, + ) + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + ) + + # controlnet residual + if ( + block_controlnet_hidden_states is not None + and block.context_pre_only is False + ): + interval_control = len(self.transformer_blocks) // len( + block_controlnet_hidden_states + ) + hidden_states = ( + hidden_states + + block_controlnet_hidden_states[index_block // interval_control] + ) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.config.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=( + hidden_states.shape[0], + height, + width, + patch_size, + patch_size, + self.out_channels, + ) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=( + hidden_states.shape[0], + self.out_channels, + height * patch_size, + width * patch_size, + ) + ) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/helpers/training/adapter.py b/helpers/training/adapter.py index 04b99069..caaebdd8 100644 --- 
a/helpers/training/adapter.py +++ b/helpers/training/adapter.py @@ -107,7 +107,11 @@ def load_lora_weights(dictionary, filename, loraKey="default", use_dora=False): missing_keys = set( [x + ".lora_A.weight" for x in lora_layers.keys()] + [x + ".lora_B.weight" for x in lora_layers.keys()] - + ([x + ".lora_magnitude_vector.weight"] if use_dora else []) + + ( + [x + ".lora_magnitude_vector.weight" for x in lora_layers.keys()] + if use_dora + else [] + ) ) for k, v in state_dict.items(): if "lora_A" in k: diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py index 78611dc8..4ac29b30 100644 --- a/helpers/training/diffusion_model.py +++ b/helpers/training/diffusion_model.py @@ -38,7 +38,7 @@ def load_diffusion_model(args, weight_dtype): # Stable Diffusion 3 uses a Diffusion transformer. logger.info("Loading Stable Diffusion 3 diffusion transformer..") try: - from diffusers import SD3Transformer2DModel + from helpers.models.sd3 import SD3Transformer2DModel except Exception as e: logger.error( f"Can not load SD3 model class. This release requires the latest version of Diffusers: {e}" diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index 51fb85de..0186d834 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -11,6 +11,7 @@ from helpers.models.sdxl.pipeline import StableDiffusionXLPipeline from helpers.training.state_tracker import StateTracker from helpers.models.smoldit import SmolDiT2DModel, SmolDiTPipeline +from helpers.models.sd3.transformer import SD3Transformer2DModel import os import logging import shutil @@ -27,7 +28,6 @@ from diffusers import ( UNet2DConditionModel, StableDiffusion3Pipeline, - SD3Transformer2DModel, StableDiffusionPipeline, FluxPipeline, PixArtSigmaPipeline, From c9025800c075e721cbf0f8e27ca191aa14abcc6e Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 7 Nov 2024 13:57:37 -0600 Subject: [PATCH 12/27] add custom sd3 transformer modeling code (fix) --- helpers/training/diffusion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py index 4ac29b30..5ac0c207 100644 --- a/helpers/training/diffusion_model.py +++ b/helpers/training/diffusion_model.py @@ -38,7 +38,7 @@ def load_diffusion_model(args, weight_dtype): # Stable Diffusion 3 uses a Diffusion transformer. logger.info("Loading Stable Diffusion 3 diffusion transformer..") try: - from helpers.models.sd3 import SD3Transformer2DModel + from helpers.models.sd3.transformer import SD3Transformer2DModel except Exception as e: logger.error( f"Can not load SD3 model class. 
This release requires the latest version of Diffusers: {e}"

From e1b0c4fd8bc5f4cde3da01c4185a2cb10379e559 Mon Sep 17 00:00:00 2001
From: bghira
Date: Thu, 7 Nov 2024 14:08:21 -0600
Subject: [PATCH 13/27] use str type and load layers

---
 helpers/configuration/cmd_args.py | 14 +++++++++++++-
 tests/test_trainer.py             |  1 +
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py
index 54cb693a..43e3c69e 100644
--- a/helpers/configuration/cmd_args.py
+++ b/helpers/configuration/cmd_args.py
@@ -6,6 +6,7 @@
 from typing import Dict, List, Optional, Tuple
 import random
 import time
+import json
 import logging
 import sys
 import torch
@@ -1352,7 +1353,7 @@ def get_argument_parser():
     )
     parser.add_argument(
         "--validation_guidance_skip_layers",
-        type=list,
+        type=str,
         default=None,
         help=(
             "StabilityAI recommends a value of [7, 8, 9] for Stable Diffusion 3.5 Medium."
         ),
     )
@@ -2399,4 +2400,15 @@ def parse_cmdline_args(input_args=None):
             f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1"
         )
 
+    if args.validation_guidance_skip_layers is not None:
+        try:
+            import json
+
+            args.validation_guidance_skip_layers = json.loads(
+                args.validation_guidance_skip_layers
+            )
+        except Exception as e:
+            logger.error(f"Could not load skip layers: {e}")
+            raise
+
     return args
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 5d93cf18..c54789d6 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -139,6 +139,7 @@ def test_stats_memory_used_none(
             output_dir="output_dir",
             flux_schedule_shift=3,
             flux_schedule_auto_shift=False,
+            validation_guidance_skip_layers=None,
         ),
     )
     def test_misc_init(

From 0695db40ae9644390c4782f9f4506ce3851af920 Mon Sep 17 00:00:00 2001
From: bghira
Date: Thu, 7 Nov 2024 14:22:04 -0600
Subject: [PATCH 14/27] add more configuration values for SLG

---
 helpers/configuration/cmd_args.py | 23 +++++++++++++++++++++++
 helpers/training/validation.py    |  9 +++++++++
 2 files changed, 32 insertions(+)

diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py
index 43e3c69e..c51635df 100644
--- a/helpers/configuration/cmd_args.py
+++ b/helpers/configuration/cmd_args.py
@@ -1359,6 +1359,29 @@ def get_argument_parser():
             "StabilityAI recommends a value of [7, 8, 9] for Stable Diffusion 3.5 Medium."
         ),
     )
+    parser.add_argument(
+        "--validation_guidance_skip_layers_start",
+        type=float,
+        default=0.01,
+        help=("StabilityAI recommends a value of 0.01 for SLG start."),
+    )
+    parser.add_argument(
+        "--validation_guidance_skip_layers_stop",
+        type=float,
+        default=0.2,
+        help=("StabilityAI recommends a value of 0.2 for SLG stop."),
+    )
+    parser.add_argument(
+        "--validation_guidance_skip_scale",
+        type=float,
+        default=2.8,
+        help=(
+            "StabilityAI recommends a value of 2.8 for SLG guidance skip scaling."
+            " When adding more layers, you must increase the scale, e.g. adding one more layer requires doubling"
+            " the value given."
+ ), + ) + parser.add_argument( "--allow_tf32", action="store_true", diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 88c4498b..fd2f8ab6 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -1196,6 +1196,15 @@ def validate_prompt( self.args.model_family == "sd3" and type(self.args.validation_guidance_skip_layers) is list ): + extra_validation_kwargs["validation_guidance_skip_layers_start"] = ( + float(args.validation_guidance_skip_layers_start) + ) + extra_validation_kwargs["validation_guidance_skip_layers_stop"] = float( + args.validation_guidance_skip_layers_stop + ) + extra_validation_kwargs["validation_guidance_skip_scale"] = float( + args.validation_guidance_skip_scale + ) extra_validation_kwargs["skip_guidance_layers"] = list( self.args.validation_guidance_skip_layers ) From 0c7dcd5e869100064ca0628ae34cf5cf240f6a55 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 7 Nov 2024 14:23:23 -0600 Subject: [PATCH 15/27] sd3: add skip layer guidance to the quickstart --- documentation/quickstart/SD3.md | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md index ece4e96b..20f1272a 100644 --- a/documentation/quickstart/SD3.md +++ b/documentation/quickstart/SD3.md @@ -264,6 +264,26 @@ For more information, see the [dataloader](/documentation/DATALOADER.md) and [tu ## Notes & troubleshooting tips +### Skip-layer guidance (SD3.5 Medium) + +StabilityAI recommends enabling SLG (Skip-layer guidance) on SD 3.5 Medium inference. This doesn't impact training results, only the validation sample quality. + +The following values are recommended for `config.json`: + +```json +{ + "--validation_guidance_skip_layers": [7, 8, 9], + "--validation_guidance_skip_layer_start": 0.01, + "--validation_guidance_skip_layer_stop": 0.2, + "--validation_guidance_skip_scale": 2.8, + "--validation_guidance_scale": 4.0 +} +``` + +When adding more layers (eg. increasing to [6, 7, 8, 9]) the SLG scale should be doubled from 2.8 to 5.6 or greater. 
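+
+As a rough illustration (the helper below is hypothetical, not part of SimpleTuner), the option is passed as a JSON string on the command line and parsed with `json.loads`, and the skip scale doubles for each layer added beyond the three recommended above:
+
+```python
+import json
+
+def slg_scale(skip_layers: str, base_scale: float = 2.8, base_count: int = 3) -> float:
+    # Hypothetical helper demonstrating the doubling rule described above.
+    layers = json.loads(skip_layers)  # "[7, 8, 9]" -> [7, 8, 9]
+    return base_scale * (2 ** max(0, len(layers) - base_count))
+
+print(slg_scale("[7, 8, 9]"))     # 2.8
+print(slg_scale("[6, 7, 8, 9]"))  # 5.6
+```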
+ +**Lower CFG must be used during inference.** + ### Model instability The SD 3.5 Large 8B model has potential instabilities during training: @@ -288,12 +308,13 @@ Some changes were made to SimpleTuner's SD3.5 support: #### Stable configuration values These options have been known to keep SD3.5 in-tact for as long as possible: -- optimizer=optimi-stableadamw -- learning_rate=1e-5 +- optimizer=adamw_bf16 +- learning_rate=1e-4 - batch_size=4 * 3 GPUs -- max_grad_norm=0.01 +- max_grad_norm=0.1 - base_model_precision=int8-quanto - No loss masking or dataset regularisation, as their contribution to this instability is unknown +- `validation_guidance_skip_layers=[7,8,9]` ### Lowest VRAM config From e4a3a34287452ad23e112e9c1e57d5a4ed567892 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 7 Nov 2024 14:31:44 -0600 Subject: [PATCH 16/27] sd3: add skip layer guidance to the quickstart (typo) --- documentation/quickstart/SD3.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md index 20f1272a..069c1012 100644 --- a/documentation/quickstart/SD3.md +++ b/documentation/quickstart/SD3.md @@ -273,8 +273,8 @@ The following values are recommended for `config.json`: ```json { "--validation_guidance_skip_layers": [7, 8, 9], - "--validation_guidance_skip_layer_start": 0.01, - "--validation_guidance_skip_layer_stop": 0.2, + "--validation_guidance_skip_layers_start": 0.01, + "--validation_guidance_skip_layers_stop": 0.2, "--validation_guidance_skip_scale": 2.8, "--validation_guidance_scale": 4.0 } From 0dab649fd1b9f32ee004142556bd9ab637a81c45 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 7 Nov 2024 16:20:56 -0600 Subject: [PATCH 17/27] sd3: fix typo reference to validation args --- helpers/training/validation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/helpers/training/validation.py b/helpers/training/validation.py index fd2f8ab6..2c9f33f8 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -1197,13 +1197,13 @@ def validate_prompt( and type(self.args.validation_guidance_skip_layers) is list ): extra_validation_kwargs["validation_guidance_skip_layers_start"] = ( - float(args.validation_guidance_skip_layers_start) + float(self.args.validation_guidance_skip_layers_start) ) extra_validation_kwargs["validation_guidance_skip_layers_stop"] = float( - args.validation_guidance_skip_layers_stop + self.args.validation_guidance_skip_layers_stop ) extra_validation_kwargs["validation_guidance_skip_scale"] = float( - args.validation_guidance_skip_scale + self.args.validation_guidance_skip_scale ) extra_validation_kwargs["skip_guidance_layers"] = list( self.args.validation_guidance_skip_layers From 1614839fb876d1cc83242ff1b3e4474fbadeaaf4 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 7 Nov 2024 16:41:21 -0600 Subject: [PATCH 18/27] update args for sd3 pipeline --- helpers/training/validation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 2c9f33f8..089fe290 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -1196,13 +1196,13 @@ def validate_prompt( self.args.model_family == "sd3" and type(self.args.validation_guidance_skip_layers) is list ): - extra_validation_kwargs["validation_guidance_skip_layers_start"] = ( - float(self.args.validation_guidance_skip_layers_start) + extra_validation_kwargs["skip_layer_guidance_start"] = float( + 
self.args.validation_guidance_skip_layers_start ) - extra_validation_kwargs["validation_guidance_skip_layers_stop"] = float( + extra_validation_kwargs["skip_layer_guidance_stop"] = float( self.args.validation_guidance_skip_layers_stop ) - extra_validation_kwargs["validation_guidance_skip_scale"] = float( + extra_validation_kwargs["skip_layer_guidance_scale"] = float( self.args.validation_guidance_skip_scale ) extra_validation_kwargs["skip_guidance_layers"] = list( From 8b869518afa8033d7822d8f13cb82c55b41bb478 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 7 Nov 2024 16:55:36 -0600 Subject: [PATCH 19/27] sd3: do not cast inputs for quanto compat --- helpers/models/sd3/pipeline.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/helpers/models/sd3/pipeline.py b/helpers/models/sd3/pipeline.py index a91aa784..b04c9096 100644 --- a/helpers/models/sd3/pipeline.py +++ b/helpers/models/sd3/pipeline.py @@ -1008,15 +1008,13 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) noise_pred = self.transformer( - hidden_states=latent_model_input.to( - device=self.transformer.device, dtype=self.transformer.dtype - ), + hidden_states=latent_model_input.to(device=self.transformer.device), timestep=timestep, encoder_hidden_states=prompt_embeds.to( - device=self.transformer.device, dtype=self.transformer.dtype + device=self.transformer.device ), pooled_projections=pooled_prompt_embeds.to( - device=self.transformer.device, dtype=self.transformer.dtype + device=self.transformer.device ), joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, From 40907f59ee5c38be4c082a4a35711facb51d427f Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 7 Nov 2024 19:02:21 -0600 Subject: [PATCH 20/27] sd3: add shift value of 1 suggestion to quickstart --- documentation/quickstart/SD3.md | 1 + 1 file changed, 1 insertion(+) diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md index 069c1012..d313f973 100644 --- a/documentation/quickstart/SD3.md +++ b/documentation/quickstart/SD3.md @@ -309,6 +309,7 @@ Some changes were made to SimpleTuner's SD3.5 support: These options have been known to keep SD3.5 in-tact for as long as possible: - optimizer=adamw_bf16 +- flux_schedule_shift=1 - learning_rate=1e-4 - batch_size=4 * 3 GPUs - max_grad_norm=0.1 From e33d588a7e9f8a4cb9321012d58b85875518604e Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 7 Nov 2024 19:12:08 -0600 Subject: [PATCH 21/27] sd3: update SLG guidance doc --- documentation/quickstart/SD3.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md index d313f973..4af8e56d 100644 --- a/documentation/quickstart/SD3.md +++ b/documentation/quickstart/SD3.md @@ -276,11 +276,18 @@ The following values are recommended for `config.json`: "--validation_guidance_skip_layers_start": 0.01, "--validation_guidance_skip_layers_stop": 0.2, "--validation_guidance_skip_scale": 2.8, - "--validation_guidance_scale": 4.0 + "--validation_guidance": 4.0 } ``` -When adding more layers (eg. increasing to [6, 7, 8, 9]) the SLG scale should be doubled from 2.8 to 5.6 or greater. +- `..skip_scale` determines how much to scale the positive prompt prediction during skip-layer guidance. The default value of 2.8 is safe for the base model's skip value of `7, 8, 9` but will need to be increased if more layers are skipped, doubling it for each additional layer. 
+- `..skip_layers` tells which layers to skip during the negative prompt prediction. +- `..skip_layers_start` determine the fraction of the inference pipeline during which skip-layer guidance should begin to be applied. +- `..skip_layers_stop` will set the fraction of the total number of inference steps after which SLG will no longer be applied. + +SLG can be applied for fewer steps for a weaker effect or less reduction of inference speed. + +It seems that extensive training of a LoRA or LyCORIS model will require modification to these values, though it's not clear how exactly it changes. **Lower CFG must be used during inference.** From 25cad1c74747745a20ced9ef426fcc471cb5bf91 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 8 Nov 2024 11:06:04 +0000 Subject: [PATCH 22/27] sd3: fix cpu / gpu location mismatch and dtype mismatch for quanto --- helpers/models/sd3/pipeline.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/helpers/models/sd3/pipeline.py b/helpers/models/sd3/pipeline.py index b04c9096..653c2a6a 100644 --- a/helpers/models/sd3/pipeline.py +++ b/helpers/models/sd3/pipeline.py @@ -1036,16 +1036,13 @@ def __call__( noise_pred_skip_layers = self.transformer( hidden_states=latent_model_input.to( device=self.transformer.device, - dtype=self.transformer.dtype, ), timestep=timestep, encoder_hidden_states=original_prompt_embeds.to( device=self.transformer.device, - dtype=self.transformer.dtype, ), pooled_projections=original_pooled_prompt_embeds.to( device=self.transformer.device, - dtype=self.transformer.dtype, ), joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, From c4fef7e37d0b0cea963f71e597c42374131854a9 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 8 Nov 2024 14:28:34 +0000 Subject: [PATCH 23/27] flux and sd3 could use uniform sampling instead of beta or sigmoid --- helpers/configuration/cmd_args.py | 9 +++++++++ helpers/training/trainer.py | 7 ++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index c51635df..14e1b470 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -149,6 +149,15 @@ def get_argument_parser(): " which has improved results in short experiments. Thanks to @mhirki for the contribution." ), ) + parser.add_argument( + "--flux_use_uniform_schedule", + action="store_true", + help=( + "Whether or not to use a uniform schedule with Flux instead of sigmoid." + " Using uniform sampling may help preserve more capabilities from the base model." + " Some tasks may not benefit from this." 
+        ),
+    )
     parser.add_argument(
         "--flux_use_beta_schedule",
         action="store_true",
diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py
index a5e70d44..4a0eb031 100644
--- a/helpers/training/trainer.py
+++ b/helpers/training/trainer.py
@@ -2185,7 +2185,7 @@ def train(self):
                 if self.config.flow_matching:
                     if (
                         not self.config.flux_fast_schedule
-                        and not self.config.flux_use_beta_schedule
+                        and not any([self.config.flux_use_beta_schedule, self.config.flux_use_uniform_schedule])
                     ):
                         # imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF
                         # also used by: https://github.com/XLabs-AI/x-flux/tree/main
@@ -2197,6 +2197,11 @@ def train(self):
                         sigmas = apply_flux_schedule_shift(
                             self.config, self.noise_scheduler, sigmas, noise
                         )
+                    elif self.config.flux_use_uniform_schedule:
+                        sigmas = torch.rand((bsz,), device=self.accelerator.device)
+                        sigmas = apply_flux_schedule_shift(
+                            self.config, self.noise_scheduler, sigmas, noise
+                        )
                     elif self.config.flux_use_beta_schedule:
                         alpha = self.config.flux_beta_schedule_alpha
                         beta = self.config.flux_beta_schedule_beta

From e4e5097e555a4073c51342e4719aa8c3b9458ce1 Mon Sep 17 00:00:00 2001
From: bghira
Date: Fri, 8 Nov 2024 09:21:19 -0600
Subject: [PATCH 24/27] sd3: model card detail expansion

---
 helpers/publishing/metadata.py |  52 +++++-
 tests/test_model_card.py       | 287 +++++++++++++++++++++++++++++++++
 2 files changed, 335 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_model_card.py

diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py
index f41600a2..b0031157 100644
--- a/helpers/publishing/metadata.py
+++ b/helpers/publishing/metadata.py
@@ -153,6 +153,15 @@ def _guidance_rescale(args):
     return f"\n    guidance_rescale={args.validation_guidance_rescale},"
 
 
+def _skip_layers(args):
+    if (
+        args.model_family.lower() not in ["sd3"]
+        or args.validation_guidance_skip_layers is None
+    ):
+        return ""
+    return f"\n    skip_guidance_layers={args.validation_guidance_skip_layers},"
+
+
 def _validation_resolution(args):
     if args.validation_resolution == "" or args.validation_resolution is None:
         return f"width=1024,\n" f"    height=1024,"
@@ -185,7 +194,7 @@ def code_example(args, repo_id: str = None):
     num_inference_steps={args.validation_num_inference_steps},
     generator=torch.Generator(device={_torch_device()}).manual_seed(1641421826),
     {_validation_resolution(args)}
-    guidance_scale={args.validation_guidance},{_guidance_rescale(args)}
+    guidance_scale={args.validation_guidance},{_guidance_rescale(args)}{_skip_layers(args)}
 ).images[0]
 image.save("output.png", format="PNG")
 ```
@@ -249,10 +258,38 @@ def flux_schedule_info(args):
         output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
     if args.flux_attention_masked_training:
         output_args.append("flux_attention_masked_training")
-    if args.model_type == "lora" and args.lora_type == "standard":
+    if (
+        args.model_type == "lora"
+        and args.lora_type == "standard"
+        and args.flux_lora_target is not None
+    ):
         output_args.append(f"flux_lora_target={args.flux_lora_target}")
     output_str = (
-        f" (flux parameters={output_args})"
+        f" (extra parameters={output_args})"
         if output_args
         else " (no special parameters set)"
     )
 
     return output_str
 
 
+def sd3_schedule_info(args):
+    if args.model_family.lower() != "sd3":
+        return ""
+    output_args = []
+    if args.flux_schedule_auto_shift:
+        output_args.append("flux_schedule_auto_shift")
+    if args.flux_schedule_shift is not None:
+        output_args.append(f"shift={args.flux_schedule_shift}")
+    if args.flux_use_beta_schedule:
+        output_args.append(f"flux_beta_schedule_alpha={args.flux_beta_schedule_alpha}")
+        output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
+    if args.flux_use_uniform_schedule:
+        output_args.append("flux_use_uniform_schedule")
+    # if args.model_type == "lora" and args.lora_type == "standard":
+    #     output_args.append(f"flux_lora_target={args.flux_lora_target}")
+    output_str = (
+        f" (extra parameters={output_args})"
+        if output_args
+        else " (no special parameters set)"
+    )
+
+    return output_str
+
+
+def model_schedule_info(args):
+    if args.model_family == "flux":
+        return flux_schedule_info(args)
+    if args.model_family == "sd3":
+        return sd3_schedule_info(args)
+    return ""  # other model families have no special schedule parameters
+
+
 def save_model_card(
     repo_id: str,
     images=None,
@@ -384,7 +428,7 @@ def save_model_card(
 - Micro-batch size: {StateTracker.get_args().train_batch_size}
 - Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps}
 - Number of GPUs: {StateTracker.get_accelerator().num_processes}
-- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{flux_schedule_info(args=StateTracker.get_args())}
+- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())}
 - Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr}
 - Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''}
 - Precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
diff --git a/tests/test_model_card.py b/tests/test_model_card.py
new file mode 100644
index 00000000..0744e887
--- /dev/null
+++ b/tests/test_model_card.py
@@ -0,0 +1,287 @@
+import unittest
+from unittest.mock import MagicMock, patch
+import os
+import json
+
+# Assuming the functions are in a module named 'metadata.py'
+from helpers.publishing.metadata import (
+    _negative_prompt,
+    _torch_device,
+    _model_imports,
+    _model_load,
+    _validation_resolution,
+    _skip_layers,
+    _guidance_rescale,
+)
+from helpers.publishing.metadata import *
+
+# For demonstration purposes, I'll redefine the functions here.
+# In your actual test file, import them from your module as shown above.
+ + +class TestMetadataFunctions(unittest.TestCase): + def setUp(self): + # Mock the args object + self.args = MagicMock() + self.args.lora_type = "standard" + self.args.model_type = "lora" + self.args.model_family = "sdxl" + self.args.validation_prompt = "A test prompt" + self.args.validation_negative_prompt = "A negative prompt" + self.args.validation_num_inference_steps = 50 + self.args.validation_guidance = 7.5 + self.args.validation_guidance_rescale = 0.7 + self.args.validation_resolution = "512x512" + self.args.pretrained_model_name_or_path = "test-model" + self.args.output_dir = "test-output" + self.args.lora_rank = 4 + self.args.lora_alpha = 1.0 + self.args.lora_dropout = 0.0 + self.args.lora_init_type = "kaiming_uniform" + self.args.model_card_note = "Test note" + self.args.validation_using_datasets = False + self.args.flow_matching_loss = "flow-matching" + self.args.flux_fast_schedule = False + self.args.flux_schedule_auto_shift = False + self.args.flux_schedule_shift = None + self.args.flux_guidance_value = None + self.args.flux_guidance_min = None + self.args.flux_guidance_max = None + self.args.flux_use_beta_schedule = False + self.args.flux_beta_schedule_alpha = None + self.args.flux_beta_schedule_beta = None + self.args.flux_attention_masked_training = False + self.args.flux_use_uniform_schedule = False + self.args.flux_lora_target = None + self.args.validation_guidance_skip_layers = None + self.args.validation_seed = 1234 + self.args.validation_noise_scheduler = "ddim" + self.args.model_card_safe_for_work = True + self.args.learning_rate = 1e-4 + self.args.max_grad_norm = 1.0 + self.args.train_batch_size = 4 + self.args.gradient_accumulation_steps = 1 + self.args.optimizer = "AdamW" + self.args.optimizer_config = "" + self.args.mixed_precision = "fp16" + self.args.base_model_precision = "no_change" + self.args.enable_xformers_memory_efficient_attention = False + + def test_model_imports(self): + self.args.lora_type = "standard" + self.args.model_type = "lora" + expected_output = "import torch\nfrom diffusers import DiffusionPipeline" + output = _model_imports(self.args) + self.assertEqual(output.strip(), expected_output.strip()) + + self.args.lora_type = "lycoris" + output = _model_imports(self.args) + self.assertIn("from lycoris import create_lycoris_from_weights", output) + + def test_model_load(self): + self.args.pretrained_model_name_or_path = "pretrained-model" + self.args.output_dir = "output-dir" + self.args.lora_type = "standard" + self.args.model_type = "lora" + + with patch( + "helpers.publishing.metadata.StateTracker.get_hf_username", + return_value="testuser", + ): + output = _model_load(self.args, repo_id="repo-id") + self.assertIn("pipeline.load_lora_weights", output) + self.assertIn("adapter_id = 'testuser/repo-id'", output) + + self.args.lora_type = "lycoris" + output = _model_load(self.args) + self.assertIn("pytorch_lora_weights.safetensors", output) + + def test_torch_device(self): + output = _torch_device() + expected_output = "'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'" + self.assertEqual(output.strip(), expected_output.strip()) + + def test_negative_prompt(self): + self.args.model_family = "sdxl" + output = _negative_prompt(self.args) + expected_output = "negative_prompt = 'A negative prompt'" + self.assertEqual(output.strip(), expected_output.strip()) + + output_in_call = _negative_prompt(self.args, in_call=True) + self.assertIn("negative_prompt=negative_prompt", output_in_call) + + def 
test_guidance_rescale(self): + self.args.model_family = "sdxl" + output = _guidance_rescale(self.args) + expected_output = "\n guidance_rescale=0.7," + self.assertEqual(output.strip(), expected_output.strip()) + + self.args.model_family = "flux" + output = _guidance_rescale(self.args) + self.assertEqual(output.strip(), "") + + def test_skip_layers(self): + self.args.model_family = "sd3" + self.args.validation_guidance_skip_layers = 2 + output = _skip_layers(self.args) + expected_output = "\n skip_guidance_layers=2," + self.assertEqual(output.strip(), expected_output.strip()) + + self.args.model_family = "sdxl" + output = _skip_layers(self.args) + self.assertEqual(output.strip(), "") + + def test_validation_resolution(self): + self.args.validation_resolution = "512x512" + output = _validation_resolution(self.args) + expected_output = "width=512,\n height=512," + self.assertEqual(output.strip(), expected_output.strip()) + + self.args.validation_resolution = "" + output = _validation_resolution(self.args) + expected_output = "width=1024,\n height=1024," + self.assertEqual(output.strip(), expected_output.strip()) + + def test_code_example(self): + with patch( + "helpers.publishing.metadata._model_imports", + return_value="import torch\nfrom diffusers import DiffusionPipeline", + ): + with patch( + "helpers.publishing.metadata._model_load", return_value="pipeline = ..." + ): + with patch( + "helpers.publishing.metadata._torch_device", return_value="'cuda'" + ): + with patch( + "helpers.publishing.metadata._negative_prompt", + return_value="negative_prompt = 'A negative prompt'", + ): + with patch( + "helpers.publishing.metadata._validation_resolution", + return_value="width=512,\n height=512,", + ): + output = code_example(self.args) + self.assertIn("import torch", output) + self.assertIn("pipeline = ...", output) + self.assertIn("pipeline.to('cuda')", output) + + def test_model_type(self): + self.args.model_type = "lora" + self.args.lora_type = "standard" + output = model_type(self.args) + self.assertEqual(output, "standard PEFT LoRA") + + self.args.lora_type = "lycoris" + output = model_type(self.args) + self.assertEqual(output, "LyCORIS adapter") + + self.args.model_type = "full" + output = model_type(self.args) + self.assertEqual(output, "full rank finetune") + + def test_lora_info(self): + self.args.model_type = "lora" + self.args.lora_type = "standard" + output = lora_info(self.args) + self.assertIn("LoRA Rank: 4", output) + + self.args.lora_type = "lycoris" + # Mocking the file reading + lycoris_config = {"key": "value"} + with patch( + "builtins.open", + unittest.mock.mock_open(read_data=json.dumps(lycoris_config)), + ): + output = lora_info(self.args) + self.assertIn('"key": "value"', output) + + def test_model_card_note(self): + output = model_card_note(self.args) + self.assertIn("Test note", output) + + self.args.model_card_note = "" + output = model_card_note(self.args) + self.assertEqual(output.strip(), "") + + def test_flux_schedule_info(self): + self.args.model_family = "flux" + output = flux_schedule_info(self.args) + self.assertIn("(no special parameters set)", output) + + self.args.flux_fast_schedule = True + output = flux_schedule_info(self.args) + self.assertIn("flux_fast_schedule", output) + + def test_sd3_schedule_info(self): + self.args.model_family = "sd3" + output = sd3_schedule_info(self.args) + self.assertIn("(no special parameters set)", output) + + self.args.flux_schedule_auto_shift = True + output = sd3_schedule_info(self.args) + 
self.assertIn("flux_schedule_auto_shift", output) + + def test_model_schedule_info(self): + with patch( + "helpers.publishing.metadata.flux_schedule_info", return_value="flux info" + ): + with patch( + "helpers.publishing.metadata.sd3_schedule_info", return_value="sd3 info" + ): + self.args.model_family = "flux" + output = model_schedule_info(self.args) + self.assertEqual(output, "flux info") + + self.args.model_family = "sd3" + output = model_schedule_info(self.args) + self.assertEqual(output, "sd3 info") + + def test_save_model_card(self): + # Mocking StateTracker methods + with patch( + "helpers.publishing.metadata.StateTracker.get_model_family", + return_value="sdxl", + ): + with patch( + "helpers.publishing.metadata.StateTracker.get_data_backends", + return_value={}, + ): + with patch( + "helpers.publishing.metadata.StateTracker.get_epoch", return_value=1 + ): + with patch( + "helpers.publishing.metadata.StateTracker.get_global_step", + return_value=1000, + ): + with patch( + "helpers.publishing.metadata.StateTracker.get_accelerator", + return_value=MagicMock(num_processes=1), + ): + with patch( + "helpers.publishing.metadata.code_example", + return_value="code example", + ): + with patch( + "builtins.open", unittest.mock.mock_open() + ) as mock_file: + save_model_card( + repo_id="test-repo", + images=None, + base_model="test-base-model", + train_text_encoder=True, + prompt="Test prompt", + validation_prompts=["Test prompt"], + validation_shortnames=["shortname"], + repo_folder="test-folder", + ) + # Ensure the README.md was written + mock_file.assert_called_with( + os.path.join("test-folder", "README.md"), + "w", + encoding="utf-8", + ) + + +if __name__ == "__main__": + unittest.main() From 73c344119e1a311a1d1438f580c574ff07115678 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 8 Nov 2024 09:23:33 -0600 Subject: [PATCH 25/27] remove boilerplate template text --- tests/test_model_card.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_model_card.py b/tests/test_model_card.py index 0744e887..b9c07d33 100644 --- a/tests/test_model_card.py +++ b/tests/test_model_card.py @@ -3,7 +3,6 @@ import os import json -# Assuming the functions are in a module named 'metadata.py' from helpers.publishing.metadata import ( _negative_prompt, _torch_device, @@ -15,9 +14,6 @@ ) from helpers.publishing.metadata import * -# For demonstration purposes, I'll redefine the functions here. -# In your actual test file, import them from your module as shown above. - class TestMetadataFunctions(unittest.TestCase): def setUp(self): From 115ee0b2ba9e9c774d3633adc7c200ee9d8b55b8 Mon Sep 17 00:00:00 2001 From: Mikael Hirki Date: Sat, 9 Nov 2024 15:52:47 +0200 Subject: [PATCH 26/27] Revert early return in setup_pipeline back to a break. This fixes random validation errors with SD3.5 after commit 48cfc09db92ff64281ab7eb6ca9e41f59bb9e23f removed some earlier fixes. This also fixes torch.compile not getting called for the validation pipeline. Calling self.pipeline.to(self.inference_device) appears to have an unwanted side-effect: it moves additional text encoders to the accelerator device. In the case of SD3.5, I saw text_encoder_2 and text_encoder_3 getting moved to the GPU. This caused my RTX 3090 to go OOM when trying to generate validation images during training. Explicitly setting text_encoder_2 and text_encoder_3 to None in extra_pipeline_kwargs fixes this issue. 
--- helpers/training/validation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 089fe290..8b80acd5 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -968,6 +968,8 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True): "tokenizer": self.tokenizer_1, "vae": self.vae, "safety_checker": None, + "text_encoder_2": None, + "text_encoder_3": None, } if type(pipeline_cls) is StableDiffusionXLPipeline: del extra_pipeline_kwargs["safety_checker"] @@ -1071,7 +1073,7 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True): logger.error(e) logger.error(traceback.format_exc()) continue - return None + break if self.args.validation_torch_compile: if self.unet is not None and not is_compiled_module(self.unet): logger.warning( From 648c4ada916743b25ef60c7dde2cea9a3580edec Mon Sep 17 00:00:00 2001 From: Mikael Hirki Date: Sun, 10 Nov 2024 12:19:57 +0200 Subject: [PATCH 27/27] Apply suggested changes proposed by bghira. --- helpers/training/validation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 8b80acd5..478d56af 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -968,9 +968,11 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True): "tokenizer": self.tokenizer_1, "vae": self.vae, "safety_checker": None, - "text_encoder_2": None, - "text_encoder_3": None, } + if self.args.model_family in ["sd3", "sdxl", "flux"]: + extra_pipeline_kwargs["text_encoder_2"] = None + if self.args.model_family in ["sd3"]: + extra_pipeline_kwargs["text_encoder_3"] = None if type(pipeline_cls) is StableDiffusionXLPipeline: del extra_pipeline_kwargs["safety_checker"] del extra_pipeline_kwargs["text_encoder"]