From fd3a14c9d7d277165a06bd60af10a2a6ddb3e006 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 23 Jan 2025 23:41:49 +0800 Subject: [PATCH 001/221] Update sdxl_train_util.py Try force stay FP16 on text encoders --- library/sdxl_train_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index b74bea91a..1c4e1965f 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -107,10 +107,10 @@ def _load_target_model( text_encoder2 = pipe.text_encoder_2 # convert to fp32 for cache text_encoders outputs - if text_encoder1.dtype != torch.float32: - text_encoder1 = text_encoder1.to(dtype=torch.float32) - if text_encoder2.dtype != torch.float32: - text_encoder2 = text_encoder2.to(dtype=torch.float32) + #if text_encoder1.dtype != torch.float32: + # text_encoder1 = text_encoder1.to(dtype=torch.float32) + #if text_encoder2.dtype != torch.float32: + # text_encoder2 = text_encoder2.to(dtype=torch.float32) vae = pipe.vae unet = pipe.unet From 75a225d74b82598e9977b3ec26e96edfffe73010 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 23 Jan 2025 23:54:25 +0800 Subject: [PATCH 002/221] Update sdxl_gen_img.py --- sdxl_gen_img.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index d52f85a8f..b01ddbc2b 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -13,6 +13,7 @@ import os import random import re +import gc import diffusers import numpy as np @@ -1489,7 +1490,7 @@ def main(args): files = glob.glob(args.ckpt) if len(files) == 1: args.ckpt = files[0] - + gc.collect() (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype ) @@ -1657,7 +1658,7 @@ def __getattr__(self, item): text_encoder1.eval() text_encoder2.eval() unet.eval() - + gc.collect() # networkを組み込む if args.network_module: networks = [] From 68c65e0ae586a12b868baaa65a7eb80ae00e4f24 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 24 Jan 2025 00:17:40 +0800 Subject: [PATCH 003/221] Update sdxl_gen_img.py --- sdxl_gen_img.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index b01ddbc2b..ba853bf28 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1490,9 +1490,9 @@ def main(args): files = glob.glob(args.ckpt) if len(files) == 1: args.ckpt = files[0] - gc.collect() + device = get_preferred_device() (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) @@ -1622,7 +1622,7 @@ def __getattr__(self, item): # scheduler.config.clip_sample = True # deviceを決定する - device = get_preferred_device() + # custom pipelineをコピったやつを生成する if args.vae_slices: @@ -1649,16 +1649,15 @@ def __getattr__(self, item): if args.no_half_vae: logger.info("set vae_dtype to float32") vae_dtype = torch.float32 - vae.to(vae_dtype).to(device) + #vae.to(vae_dtype).to(device) vae.eval() - text_encoder1.to(dtype).to(device) - text_encoder2.to(dtype).to(device) - unet.to(dtype).to(device) + #text_encoder1.to(dtype).to(device) + #text_encoder2.to(dtype).to(device) + #unet.to(dtype).to(device) text_encoder1.eval() text_encoder2.eval() unet.eval() - gc.collect() # networkを組み込む if args.network_module: networks = [] From 8daa8b32836469bee34010d107df7e38d3059136 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 24 Jan 2025 00:19:09 +0800 Subject: [PATCH 004/221] Update sdxl_train_util.py --- library/sdxl_train_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 1c4e1965f..b74bea91a 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -107,10 +107,10 @@ def _load_target_model( text_encoder2 = pipe.text_encoder_2 # convert to fp32 for cache text_encoders outputs - #if text_encoder1.dtype != torch.float32: - # text_encoder1 = text_encoder1.to(dtype=torch.float32) - #if text_encoder2.dtype != torch.float32: - # text_encoder2 = text_encoder2.to(dtype=torch.float32) + if text_encoder1.dtype != torch.float32: + text_encoder1 = text_encoder1.to(dtype=torch.float32) + if text_encoder2.dtype != torch.float32: + text_encoder2 = text_encoder2.to(dtype=torch.float32) vae = pipe.vae unet = pipe.unet From 6231883ef6590f57665504f59c71464351784ccd Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 24 Jan 2025 00:30:06 +0800 Subject: [PATCH 005/221] Update sdxl_gen_img.py --- sdxl_gen_img.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index ba853bf28..f325ecd66 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1491,11 +1491,14 @@ def main(args): if len(files) == 1: args.ckpt = files[0] device = get_preferred_device() + logger.info(f"preferred device: {device}") (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) - + text_encoder1.to(dtype).to(device) + text_encoder2.to(dtype).to(device) + unet.to(dtype).to(device) # xformers、Hypernetwork対応 if not args.diffusers_xformers: mem_eff = not (args.xformers or args.sdpa) @@ -1649,12 +1652,9 @@ def __getattr__(self, item): if args.no_half_vae: logger.info("set vae_dtype to float32") vae_dtype = torch.float32 - #vae.to(vae_dtype).to(device) + vae.to(vae_dtype).to(device) vae.eval() - #text_encoder1.to(dtype).to(device) - #text_encoder2.to(dtype).to(device) - #unet.to(dtype).to(device) text_encoder1.eval() text_encoder2.eval() unet.eval() From 3ca8dc51e944448c76107a1082301353f2d49285 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 24 Jan 2025 02:37:59 +0800 Subject: [PATCH 006/221] Update sdxl_gen_img.py --- sdxl_gen_img.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index f325ecd66..bf6c89076 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1492,8 +1492,9 @@ def main(args): args.ckpt = files[0] device = get_preferred_device() logger.info(f"preferred device: {device}") + model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device, model_dtype ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) text_encoder1.to(dtype).to(device) @@ -3194,6 +3195,10 @@ def setup_parser() -> argparse.ArgumentParser: help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", ) + parser.add_argument("--full_fp16", action="store_true", help="Loading model in fp16") + parser.add_argument( + "--full_bf16", action="store_true", help="Loading model in bf16" + ) # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" From 8f850249178fd7a5a591bd27cf0c6fb4ee27aa42 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 24 Jan 2025 16:17:25 +0800 Subject: [PATCH 007/221] Create accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 3215 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 3215 insertions(+) create mode 100644 accel_sdxl_gen_img.py diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py new file mode 100644 index 000000000..bf6c89076 --- /dev/null +++ b/accel_sdxl_gen_img.py @@ -0,0 +1,3215 @@ +import itertools +import json +from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable +import glob +import importlib +import inspect +import time +import zipfile +from diffusers.utils import deprecate +from diffusers.configuration_utils import FrozenDict +import argparse +import math +import os +import random +import re +import gc + +import diffusers +import numpy as np + +import torch +from library.device_utils import init_ipex, clean_memory, get_preferred_device +init_ipex() + +import torchvision +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + # UNet2DConditionModel, + StableDiffusionPipeline, +) +from einops import rearrange +from tqdm import tqdm +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +import PIL +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import library.model_util as model_util +import library.train_util as train_util +import library.sdxl_model_util as sdxl_model_util +import library.sdxl_train_util as sdxl_train_util +from networks.lora import LoRANetwork +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.original_unet import FlashAttentionFunction +from networks.control_net_lllite import ControlNetLLLite +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +# その他の設定 +LATENT_CHANNELS = 4 +DOWNSAMPLING_FACTOR = 8 + +CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + logger.info("Enable memory efficient attention for U-Net") + + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + logger.info("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + logger.info("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) + + +# TODO common train_util.py +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + replace_vae_attn_to_memory_efficient() + elif xformers: + # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? + vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う + elif sdpa: + replace_vae_attn_to_sdpa() + + +def replace_vae_attn_to_memory_efficient(): + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, hidden_states, **kwargs): + q_bucket_size = 512 + k_bucket_size = 1024 + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_flash_attn + + +def replace_vae_attn_to_xformers(): + logger.info("VAE: Attention.forward has been replaced to xformers") + import xformers.ops + + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_xformers(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_xformers + + +def replace_vae_attn_to_sdpa(): + logger.info("VAE: Attention.forward has been replaced to sdpa") + + def forward_sdpa(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = torch.nn.functional.scaled_dot_product_attention( + query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + out = rearrange(out, "b n h d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_sdpa_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_sdpa(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_sdpa + + +# endregion + +# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 +# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 + + +class PipelineLike: + def __init__( + self, + device, + vae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + tokenizers: List[CLIPTokenizer], + unet: InferSdxlUNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + ): + super().__init__() + self.device = device + self.clip_skip = clip_skip + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.vae = vae + self.text_encoders = text_encoders + self.tokenizers = tokenizers + self.unet: InferSdxlUNet2DConditionModel = unet + self.scheduler = scheduler + self.safety_checker = None + + self.clip_vision_model: CLIPVisionModelWithProjection = None + self.clip_vision_processor: CLIPImageProcessor = None + self.clip_vision_strength = 0.0 + + # Textual Inversion + self.token_replacements_list = [] + for _ in range(len(self.text_encoders)): + self.token_replacements_list.append({}) + + # ControlNet # not supported yet + self.control_nets: List[ControlNetLLLite] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + + self.gradual_latent: GradualLatent = None + + # Textual Inversion + def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): + self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids + + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + + def get_token_replacer(self, tokenizer): + tokenizer_index = self.tokenizers.index(tokenizer) + token_replacements = self.token_replacements_list[tokenizer_index] + + def replace_tokens(tokens): + # logger.info("replace_tokens", tokens, "=>", token_replacements) + if isinstance(tokens, torch.Tensor): + tokens = tokens.tolist() + + new_tokens = [] + for token in tokens: + if token in token_replacements: + replacement = token_replacements[token] + new_tokens.extend(replacement) + else: + new_tokens.append(token) + return new_tokens + + return replace_tokens + + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets + + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + logger.info("gradual_latent is disabled") + self.gradual_latent = None + else: + logger.info(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 1024, + width: int = 1024, + original_height: int = None, + original_width: int = None, + original_height_negative: int = None, + original_width_negative: int = None, + crop_top: int = 0, + crop_left: int = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_guide_images=None, + **kwargs, + ): + # TODO support secondary prompt + num_images_per_prompt = 1 # fixed because already prompt is repeated + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + reginonal_network = " AND " in prompt[0] + + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + logger.info(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + tes_text_embs = [] + tes_uncond_embs = [] + tes_real_uncond_embs = [] + + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + token_replacer = self.get_token_replacer(tokenizer) + + # use last text_pool, because it is from text encoder 2 + text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( + tokenizer, + text_encoder, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_text_embs.append(text_embeddings) + tes_uncond_embs.append(uncond_embeddings) + + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + token_replacer, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_real_uncond_embs.append(real_uncond_embeddings) + + # concat text encoder outputs + text_embeddings = tes_text_embs[0] + uncond_embeddings = tes_uncond_embs[0] + for i in range(1, len(tes_text_embs)): + text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 + if do_classifier_free_guidance: + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + + if self.control_nets: + # ControlNetのhintにguide imageを流用する + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] + if isinstance(clip_guide_images[0], PIL.Image.Image): + clip_guide_images = [preprocess_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images) + if isinstance(clip_guide_images, list): + clip_guide_images = torch.stack(clip_guide_images) + + clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) + + # create size embs + if original_height is None: + original_height = height + if original_width is None: + original_width = width + if original_height_negative is None: + original_height_negative = original_height + if original_width_negative is None: + original_width_negative = original_width + if crop_top is None: + crop_top = 0 + if crop_left is None: + crop_left = 0 + emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) + uc_emb1 = sdxl_train_util.get_timestep_embedding( + torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 + ) + emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) + emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + + if reginonal_network: + # use last pool for conditioning + num_sub_prompts = len(text_pool) // batch_size + text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt + + if init_image is not None and self.clip_vision_model is not None: + logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) + pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) + + clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True) + clip_vision_embeddings = clip_vision_embeddings.image_embeds + + if len(clip_vision_embeddings) == 1 and batch_size > 1: + clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) + + clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength + assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" + text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) + + c_vector = torch.cat([text_pool, c_vector], dim=1) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[-2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + clean_memory() + init_latents = [] + for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): + init_latent_dist = self.vae.encode( + (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( + self.vae.dtype + ) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if self.control_net_enabled: + for control_net, _ in self.control_nets: + with torch.no_grad(): + control_net.set_cond_image(clip_guide_images) + else: + for control_net, _ in self.control_nets: + control_net.set_cond_image(None) + + each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) + + # # first, we downscale the latents to the half of the size + # # 最初に1/2に縮小する + # height, width = latents.shape[-2:] + # # latents = torch.nn.functional.interpolate(latents.float(), scale_factor=0.5, mode="bicubic", align_corners=False).to( + # # latents.dtype + # # ) + # latents = latents[:, :, ::2, ::2] + # current_scale = 0.5 + + # # how much to increase the scale at each step: .125 seems to work well (because it's 1/8?) + # # 各ステップに拡大率をどのくらい増やすか:.125がよさそう(たぶん1/8なので) + # scale_step = 0.125 + + # # timesteps at which to start increasing the scale: 1000 seems to be enough + # # 拡大を開始するtimesteps: 1000で十分そうである + # start_timesteps = 1000 + + # # how many steps to wait before increasing the scale again + # # small values leads to blurry images (because the latents are blurry after the upscale, so some denoising might be needed) + # # large values leads to flat images + + # # 何ステップごとに拡大するか + # # 小さいとボケる(拡大後のlatentsはボケた感じになるので、そこから数stepのdenoiseが必要と思われる) + # # 大きすぎると細部が書き込まれずのっぺりした感じになる + # every_n_steps = 5 + + # scale_step = input("scale step:") + # scale_step = float(scale_step) + # start_timesteps = input("start timesteps:") + # start_timesteps = int(start_timesteps) + # every_n_steps = input("every n steps:") + # every_n_steps = int(every_n_steps) + + # # for i, t in enumerate(tqdm(timesteps)): + # i = 0 + # last_step = 0 + # while i < len(timesteps): + # t = timesteps[i] + # print(f"[{i}] t={t}") + + # print(i, t, current_scale, latents.shape) + # if t < start_timesteps and current_scale < 1.0 and i % every_n_steps == 0: + # if i == last_step: + # pass + # else: + # print("upscale") + # current_scale = min(current_scale + scale_step, 1.0) + + # h = int(height * current_scale) // 8 * 8 + # w = int(width * current_scale) // 8 * 8 + + # latents = torch.nn.functional.interpolate(latents.float(), size=(h, w), mode="bicubic", align_corners=False).to( + # latents.dtype + # ) + # last_step = i + # i = max(0, i - every_n_steps + 1) + + # diff = timesteps[i] - timesteps[last_step] + # # resized_init_noise = torch.nn.functional.interpolate( + # # init_noise.float(), size=(h, w), mode="bicubic", align_corners=False + # # ).to(latents.dtype) + # # latents = self.scheduler.add_noise(latents, resized_init_noise, diff) + # latents = self.scheduler.add_noise(latents, torch.randn_like(latents), diff * 4) + # # latents += torch.randn_like(latents) / 100 * diff + # continue + + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + logger.info("gradual_latent is not supported for this scheduler. Ignoring.") + logger.info(f'{self.scheduler.__class__.__name__}') + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # disable control net if ratio is set + if self.control_nets and self.control_net_enabled: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + control_net.set_cond_image(None) + each_control_net_enabled[j] = False + + # predict the noise residual + # TODO Diffusers' ControlNet + # if self.control_nets and self.control_net_enabled: + # if reginonal_network: + # num_sub_and_neg_prompts = len(text_embeddings) // batch_size + # text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt + # else: + # text_emb_last = text_embeddings + + # # not working yet + # noise_pred = original_control_net.call_unet_and_control_net( + # i, + # num_latent_input, + # self.unet, + # self.control_nets, + # guided_hints, + # i / len(timesteps), + # latent_model_input, + # t, + # text_emb_last, + # ).sample + # else: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + i += 1 + + if return_latents: + return latents + + latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents.to(self.vae.dtype)).sample + else: + clean_memory() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode( + (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) + ).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + clean_memory() + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + return image + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + # keep break as separate token + text = text.replace("BREAK", "\\BREAK\\") + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + if word.strip() == "BREAK": + # pad until next multiple of tokenizer's max token length + pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) + logger.info(f"BREAK pad_len: {pad_len}") + for i in range(pad_len): + # v2のときEOSをつけるべきかどうかわからないぜ + # if i == 0: + # text_token.append(tokenizer.eos_token_id) + # else: + text_token.append(tokenizer.pad_token_id) + text_weight.append(1.0) + continue + + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + + token = token_replacer(token) # for Textual Inversion + + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + text_encoder: CLIPTextModel, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + pool = None + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + # -2 is same for Text Encoder 1 and 2 + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-2] + if pool is None: + pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-2] + pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) + return text_embeddings, pool + + +def get_weighted_text_embeddings( + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip=None, + token_replacer=None, + device=None, + **kwargs, +): + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + # split the prompts with "AND". each prompt must have the same number of splits + new_prompts = [] + for p in prompt: + new_prompts.extend(p.split(" AND ")) + prompt = new_prompts + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings, text_pool = get_unweighted_text_embeddings( + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + if uncond_prompt is not None: + uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( + text_encoder, + uncond_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + # →全体でいいんじゃないかな + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens + return text_embeddings, text_pool, None, None, prompt_tokens + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +# regular expression for dynamic prompt: +# starts and ends with "{" and "}" +# contains at least one variant divided by "|" +# optional framgments divided by "$$" at start +# if the first fragment is "E" or "e", enumerate all variants +# if the second fragment is a number or two numbers, repeat the variants in the range +# if the third fragment is a string, use it as a separator + +RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") + + +def handle_dynamic_prompt_variants(prompt, repeat_count): + founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) + if not founds: + return [prompt] + + # make each replacement for each variant + enumerating = False + replacers = [] + for found in founds: + # if "e$$" is found, enumerate all variants + found_enumerating = found.group(2) is not None + enumerating = enumerating or found_enumerating + + separator = ", " if found.group(6) is None else found.group(6) + variants = found.group(7).split("|") + + # parse count range + count_range = found.group(4) + if count_range is None: + count_range = [1, 1] + else: + count_range = count_range.split("-") + if len(count_range) == 1: + count_range = [int(count_range[0]), int(count_range[0])] + elif len(count_range) == 2: + count_range = [int(count_range[0]), int(count_range[1])] + else: + logger.warning(f"invalid count range: {count_range}") + count_range = [1, 1] + if count_range[0] > count_range[1]: + count_range = [count_range[1], count_range[0]] + if count_range[0] < 0: + count_range[0] = 0 + if count_range[1] > len(variants): + count_range[1] = len(variants) + + if found_enumerating: + # make function to enumerate all combinations + def make_replacer_enum(vari, cr, sep): + def replacer(): + values = [] + for count in range(cr[0], cr[1] + 1): + for comb in itertools.combinations(vari, count): + values.append(sep.join(comb)) + return values + + return replacer + + replacers.append(make_replacer_enum(variants, count_range, separator)) + else: + # make function to choose random combinations + def make_replacer_single(vari, cr, sep): + def replacer(): + count = random.randint(cr[0], cr[1]) + comb = random.sample(vari, count) + return [sep.join(comb)] + + return replacer + + replacers.append(make_replacer_single(variants, count_range, separator)) + + # make each prompt + if not enumerating: + # if not enumerating, repeat the prompt, replace each variant randomly + prompts = [] + for _ in range(repeat_count): + current = prompt + for found, replacer in zip(founds, replacers): + current = current.replace(found.group(0), replacer()[0], 1) + prompts.append(current) + else: + # if enumerating, iterate all combinations for previous prompts + prompts = [prompt] + + for found, replacer in zip(founds, replacers): + if found.group(2) is not None: + # make all combinations for existing prompts + new_prompts = [] + for current in prompts: + replecements = replacer() + for replecement in replecements: + new_prompts.append(current.replace(found.group(0), replecement, 1)) + prompts = new_prompts + + for found, replacer in zip(founds, replacers): + # make random selection for existing prompts + if found.group(2) is None: + for i in range(len(prompts)): + prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) + + return prompts + + +# endregion + +# def load_clip_l14_336(dtype): +# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") +# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) +# return text_encoder + + +class BatchDataBase(NamedTuple): + # バッチ分割が必要ないデータ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any + raw_prompt: str + + +class BatchDataExt(NamedTuple): + # バッチ分割が必要なデータ + width: int + height: int + original_width: int + original_height: int + original_width_negative: int + original_height_negative: int + crop_left: int + crop_top: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] + num_sub_prompts: int + + +class BatchData(NamedTuple): + return_latents: bool + base: BatchDataBase + ext: BatchDataExt + + +def main(args): + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + highres_fix = args.highres_fix_scale is not None + # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" + + # モデルを読み込む + if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] + device = get_preferred_device() + logger.info(f"preferred device: {device}") + model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device, model_dtype + ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) + text_encoder1.to(dtype).to(device) + text_encoder2.to(dtype).to(device) + unet.to(dtype).to(device) + # xformers、Hypernetwork対応 + if not args.diffusers_xformers: + mem_eff = not (args.xformers or args.sdpa) + replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) + + # tokenizerを読み込む + logger.info("loading tokenizer") + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # schedulerを用意する + sched_init_args = {} + has_steps_offset = True + has_clip_sample = True + scheduler_num_noises_per_step = 1 + + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + has_clip_sample = False + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_lms_discrete + has_clip_sample = False + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + has_clip_sample = False + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteSchedulerGL + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + has_clip_sample = False + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + has_clip_sample = False + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + has_clip_sample = False + has_steps_offset = False + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + has_clip_sample = False + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + has_clip_sample = False + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 + has_clip_sample = False + + # 警告を出さないようにする + if has_steps_offset: + sched_init_args["steps_offset"] = 1 + if has_clip_sample: + sched_init_args["clip_sample"] = False + + # samplerの乱数をあらかじめ指定するための処理 + + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 + + def reset_sampler_noises(self, noises): + self.sampler_noise_index = 0 + self.sampler_noises = noises + + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # logger.info("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None + + if noise == None: + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) + + self.sampler_noise_index += 1 + return noise + + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager + + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # ↓以下は結局PipeでFalseに設定されるので意味がなかった + # # clip_sample=Trueにする + # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # logger.info("set clip_sample to True") + # scheduler.config.clip_sample = True + + # deviceを決定する + + + # custom pipelineをコピったやつを生成する + if args.vae_slices: + from library.slicing_vae import SlicingAutoencoderKL + + sli_vae = SlicingAutoencoderKL( + act_fn="silu", + block_out_channels=(128, 256, 512, 512), + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + in_channels=3, + latent_channels=4, + layers_per_block=2, + norm_num_groups=32, + out_channels=3, + sample_size=512, + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + num_slices=args.vae_slices, + ) + sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする + vae = sli_vae + del sli_vae + + vae_dtype = dtype + if args.no_half_vae: + logger.info("set vae_dtype to float32") + vae_dtype = torch.float32 + vae.to(vae_dtype).to(device) + vae.eval() + + text_encoder1.eval() + text_encoder2.eval() + unet.eval() + # networkを組み込む + if args.network_module: + networks = [] + network_default_muls = [] + network_pre_calc = args.network_pre_calc + + # merge関連の引数を統合する + if args.network_merge: + network_merge = len(args.network_module) # all networks are merged + elif args.network_merge_n_models: + network_merge = args.network_merge_n_models + else: + network_merge = 0 + logger.info(f"network_merge: {network_merge}") + + for i, network_module in enumerate(args.network_module): + logger.info(f"import network module: {network_module}") + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights is None or len(args.network_weights) <= i: + raise ValueError("No weight. Weight is required.") + + network_weight = args.network_weights[i] + logger.info(f"load network weights from: {network_weight}") + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + logger.info(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs + ) + if network is None: + return + + mergeable = network.is_mergeable() + if network_merge and not mergeable: + logger.warning("network is not mergiable. ignore merge option.") + + if not mergeable or i >= network_merge: + # not merging + network.apply_to([text_encoder1, text_encoder2], unet) + info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい + logger.info(f"weights are loaded: {info}") + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + if network_pre_calc: + logger.info("backup original weights") + network.backup_weights() + + networks.append(network) + network_default_muls.append(network_mul) + else: + network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) + + else: + networks = [] + + # upscalerの指定があれば取得する + upscaler = None + if args.highres_fix_upscaler: + logger.info(f"import upscaler module: {args.highres_fix_upscaler}") + imported_module = importlib.import_module(args.highres_fix_upscaler) + + us_kwargs = {} + if args.highres_fix_upscaler_args: + for net_arg in args.highres_fix_upscaler_args.split(";"): + key, value = net_arg.split("=") + us_kwargs[key] = value + + logger.info("create upscaler") + upscaler = imported_module.create_upscaler(**us_kwargs) + upscaler.to(dtype).to(device) + + # ControlNetの処理 + control_nets: List[Tuple[ControlNetLLLite, float]] = [] + # if args.control_net_models: + # for i, model in enumerate(args.control_net_models): + # prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + # weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + # ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + # ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) + # prep = original_control_net.load_preprocess(prep_type) + # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + if args.control_net_lllite_models: + for i, model_file in enumerate(args.control_net_lllite_models): + logger.info(f"loading ControlNet-LLLite: {model_file}") + + from safetensors.torch import load_file + + state_dict = load_file(model_file) + mlp_dim = None + cond_emb_dim = None + for key, value in state_dict.items(): + if mlp_dim is None and "down.0.weight" in key: + mlp_dim = value.shape[0] + elif cond_emb_dim is None and "conditioning1.0" in key: + cond_emb_dim = value.shape[0] * 2 + if mlp_dim is not None and cond_emb_dim is not None: + break + assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" + + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier) + control_net.apply_to() + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_net.set_batch_cond_only(False, False) + control_nets.append((control_net, ratio)) + + if args.opt_channels_last: + logger.info(f"set optimizing: channels last") + text_encoder1.to(memory_format=torch.channels_last) + text_encoder2.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.to(memory_format=torch.channels_last) + # cn.unet.to(memory_format=torch.channels_last) + # cn.net.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + device, + vae, + [text_encoder1, text_encoder2], + [tokenizer1, tokenizer2], + unet, + scheduler, + args.clip_skip, + ) + pipe.set_control_nets(control_nets) + logger.info("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + + # Gradual Latent + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + + gradual_latent = GradualLatent( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # Textual Inversionを処理する + if args.textual_inversion_embeddings: + token_ids_embeds1 = [] + token_ids_embeds2 = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + + embeds1 = data["clip_l"] # text encoder 1 + embeds2 = data["clip_g"] # text encoder 2 + + num_vectors_per_token = embeds1.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens1 = tokenizer1.add_tokens(token_strings) + num_added_tokens2 = tokenizer2.add_tokens(token_strings) + assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, ( + f"tokenizer has same word to token string (filename): {embeds_file}" + + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" + ) + + token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) + token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + assert ( + min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 + ), f"token ids1 is not ordered" + assert ( + min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 + ), f"token ids2 is not ordered" + assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}" + assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... + pipe.add_token_replacement(1, token_ids2[0], token_ids2) + + token_ids_embeds1.append((token_ids1, embeds1)) + token_ids_embeds2.append((token_ids2, embeds2)) + + text_encoder1.resize_token_embeddings(len(tokenizer1)) + text_encoder2.resize_token_embeddings(len(tokenizer2)) + token_embeds1 = text_encoder1.get_input_embeddings().weight.data + token_embeds2 = text_encoder2.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds1: + for token_id, embed in zip(token_ids, embeds): + token_embeds1[token_id] = embed + for token_ids, embeds in token_ids_embeds2: + for token_id, embed in zip(token_ids, embeds): + token_embeds2[token_id] = embed + + # promptを取得する + if args.from_file is not None: + logger.info(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] + elif args.prompt is not None: + prompt_list = [args.prompt] + else: + prompt_list = [] + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] + else: + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() + + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + logger.info(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) + + return images + + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない場合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized + + if args.image_path is not None: + logger.info(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" + logger.info(f"loaded {len(init_images)} images for img2img") + + # CLIP Vision + if args.clip_vision_strength is not None: + logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) + vision_model.to(device, dtype) + processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) + + pipe.clip_vision_model = vision_model + pipe.clip_vision_processor = processor + pipe.clip_vision_strength = args.clip_vision_strength + logger.info(f"CLIP Vision model loaded.") + + else: + init_images = None + + if args.mask_path is not None: + logger.info(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" + logger.info(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None + + # promptがないとき、画像のPngInfoから取得する + if init_images is not None and len(prompt_list) == 0 and not args.interactive: + logger.info("get prompts from images' metadata") + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) + + # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l + + if mask_images is not None: + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l + + # 画像サイズにオプション指定があるときはリサイズする + if args.W is not None and args.H is not None: + # highres fix を考慮に入れる + w, h = args.W, args.H + if highres_fix: + w = int(w * args.highres_fix_scale + 0.5) + h = int(h * args.highres_fix_scale + 0.5) + + if init_images is not None: + logger.info(f"resize img2img source images to {w}*{h}") + init_images = resize_images(init_images, (w, h)) + if mask_images is not None: + logger.info(f"resize img2img mask images to {w}*{h}") + mask_images = resize_images(mask_images, (w, h)) + + regional_network = False + if networks and mask_images: + # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 + regional_network = True + logger.info("use mask as region") + + size = None + for i, network in enumerate(networks): + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: + np_mask = np.array(mask_images[0]) + + if args.network_regional_mask_max_color_codes: + # カラーコードでマスクを指定する + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] + size = np_mask.shape + else: + np_mask = np.full(size, 255, dtype=np.uint8) + mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) + network.set_region(i, i == len(networks) - 1, mask) + mask_images = None + + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) + + logger.info(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + logger.warning( + f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) + guide_images = None + else: + guide_images = None + + # seed指定時はseedを決めておく + if args.seed is not None: + # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう + random.seed(args.seed) + predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] + if len(predefined_seeds) == 1: + predefined_seeds[0] = args.seed + else: + predefined_seeds = None + + # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) + if args.W is None: + args.W = 1024 + if args.H is None: + args.H = 1024 + + # 画像生成のループ + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") + iter_seed = random.randint(0, 0x7FFFFFFF) + + # バッチ処理の関数 + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres_fixの処理 + if highres_fix and not highres_1st: + # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す + is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling + + logger.info("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + + def scale_and_round(x): + if x is None: + return None + return int(x * args.highres_fix_scale + 0.5) + + width_1st = scale_and_round(ext.width) + height_1st = scale_and_round(ext.height) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + original_width_1st = scale_and_round(ext.original_width) + original_height_1st = scale_and_round(ext.original_height) + original_width_negative_1st = scale_and_round(ext.original_width_negative) + original_height_negative_1st = scale_and_round(ext.original_height_negative) + crop_left_1st = scale_and_round(ext.crop_left) + crop_top_1st = scale_and_round(ext.crop_top) + + strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength + + ext_1st = BatchDataExt( + width_1st, + height_1st, + original_width_1st, + original_height_1st, + original_width_negative_1st, + original_height_negative_1st, + crop_left_1st, + crop_top_1st, + args.highres_fix_steps, + ext.scale, + ext.negative_scale, + strength_1st, + ext.network_muls, + ext.num_sub_prompts, + ) + batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする + images_1st = process_batch(batch_1st, True, True) + + # 2nd stageのバッチを作成して以下処理する + logger.info("process 2nd stage") + width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height + + if upscaler: + # upscalerを使って画像を拡大する + lowreso_imgs = None if is_1st_latent else images_1st + lowreso_latents = None if not is_1st_latent else images_1st + + # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents + batch_size = len(images_1st) + vae_batch_size = ( + batch_size + if args.vae_batch_size is None + else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) + ) + vae_batch_size = int(vae_batch_size) + images_1st = upscaler.upscale( + vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size + ) + + elif args.highres_fix_latents_upscaling: + # latentを拡大する + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + else: + # 画像をLANCZOSで拡大する + images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), + ( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + network_muls, + num_sub_prompts, + ), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + raw_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? + guide_images = [] + else: + guide_images = None + + # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + # 追加ネットワークの処理 + shared = {} + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + if regional_network: + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) + + if not regional_network and network_pre_calc: + for n in networks: + n.restore_weights() + for n in networks: + n.pre_calculation() + logger.info("pre-calculation... done") + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + original_height, + original_width, + original_height_negative, + original_width_negative, + crop_top, + crop_left, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + ) + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) + ): + if highres_fix: + seed -= 1 # record original seed + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + logger.error( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) + + return images + + # 画像生成のプロンプトが一周するまでのループ + prompt_index = 0 + global_step = 0 + batch_data = [] + while args.interactive or prompt_index < len(prompt_list): + if len(prompt_list) == 0: + # interactive + valid = False + while not valid: + logger.info("") + logger.info("Type prompt:") + try: + raw_prompt = input() + except EOFError: + break + + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + raw_prompt = prompt_list[prompt_index] + + # sd-dynamic-prompts like variants: + # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + + # repeat prompt + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + + if pi == 0 or len(raw_prompts) > 1: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + logger.info(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + logger.info(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + logger.info(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + logger.info(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + logger.info(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + logger.info(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + logger.info(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + logger.info(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + logger.info(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + logger.info(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + logger.info(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + logger.info(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + logger.info(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + logger.info(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + logger.info(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + logger.info(f"network mul: {network_muls}") + continue + + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + logger.info(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") + + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # prepare seed + if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) + else: + if predefined_seeds is not None: + if len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) + else: + logger.error("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed + else: + seed = None # 前のを消す + + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) + if args.interactive: + logger.info(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + logger.warning( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets: # 複数件の場合あり + c = len(control_nets) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." + else: + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), + BatchDataExt( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? + process_batch(batch_data, highres_fix) + batch_data.clear() + + batch_data.append(b1) + if len(batch_data) == args.batch_size: + prev_image = process_batch(batch_data, highres_fix)[0] + batch_data.clear() + + global_step += 1 + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data, highres_fix) + batch_data.clear() + + logger.info("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", + ) + parser.add_argument( + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument( + "--original_height", + type=int, + default=None, + help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width", + type=int, + default=None, + help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--original_height_negative", + type=int, + default=None, + help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width_negative", + type=int, + default=None, + help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" + ) + parser.add_argument( + "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", + ) + parser.add_argument( + "--vae_slices", + type=int, + default=None, + help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", + ) + parser.add_argument( + "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" + ) + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument( + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", + ) + parser.add_argument( + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", + ) + parser.add_argument( + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" + ) + parser.add_argument( + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" + ) + parser.add_argument( + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", + ) + parser.add_argument( + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" + ) + parser.add_argument( + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", + ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument( + "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" + ) + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", + ) + parser.add_argument( + "--highres_fix_strength", + type=float, + default=None, + help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", + ) + parser.add_argument( + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡大する", + ) + parser.add_argument( + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", + ) + parser.add_argument( + "--highres_fix_upscaler_args", + type=str, + default=None, + help="additional arguments for upscaler (key=value) / upscalerへの追加の引数", + ) + parser.add_argument( + "--highres_fix_disable_control_net", + action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) + + parser.add_argument( + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", + ) + + parser.add_argument( + "--control_net_lllite_models", + type=str, + default=None, + nargs="*", + help="ControlNet models to use / 使用するControlNetのモデル名", + ) + # parser.add_argument( + # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + # ) + # parser.add_argument( + # "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + # ) + parser.add_argument( + "--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率" + ) + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) + parser.add_argument( + "--clip_vision_strength", + type=float, + default=None, + help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", + ) + + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + parser.add_argument("--full_fp16", action="store_true", help="Loading model in fp16") + parser.add_argument( + "--full_bf16", action="store_true", help="Loading model in bf16" + ) + + # # parser.add_argument( + # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" + # ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + setup_logging(args, reset=True) + main(args) From 8b241c4c8b30ebd9c65f99bfe523541ccab6c697 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 24 Jan 2025 17:39:01 +0800 Subject: [PATCH 008/221] Update accel_sdxl_gen_img.py draft accel image gen --- accel_sdxl_gen_img.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index bf6c89076..7948efaaf 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1486,15 +1486,20 @@ def main(args): # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" # モデルを読み込む + logger.info("preparing accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + distributed_state = PartialState() + device = distributed_state.device + if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う files = glob.glob(args.ckpt) if len(files) == 1: args.ckpt = files[0] - device = get_preferred_device() + #device = get_preferred_device() logger.info(f"preferred device: {device}") - model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device, model_dtype + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) text_encoder1.to(dtype).to(device) @@ -1819,7 +1824,7 @@ def __getattr__(self, item): args.clip_skip, ) pipe.set_control_nets(control_nets) - logger.info("pipeline is ready.") + logger.info(f"pipeline on {device} is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() From 0770c7ba1bb5db64e82bb3cd160a02ff0d016064 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 24 Jan 2025 18:35:25 +0800 Subject: [PATCH 009/221] Update sdxl_gen_img.py testing logger --- sdxl_gen_img.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index bf6c89076..a01aed69c 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -2810,7 +2810,9 @@ def scale_and_round(x): num_sub_prompts, ), ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? + logger.info("Does this run? When number of prompts less than batch?") process_batch(batch_data, highres_fix) batch_data.clear() @@ -2820,6 +2822,7 @@ def scale_and_round(x): batch_data.clear() global_step += 1 + logger.info(f"Global Step: {global_step}") prompt_index += 1 From e5cf6b6d065385a6ff2bd4ab601243ad8c134d54 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 24 Jan 2025 19:28:38 +0800 Subject: [PATCH 010/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 7948efaaf..6d7ad988b 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2816,12 +2816,22 @@ def scale_and_round(x): ), ) if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? - process_batch(batch_data, highres_fix) + batch_data_split = np.array_split(batch_data, distributed_state.num_processes) + with torch.no_grad(): + with distributed_state.split_between_processes(batch_data_split) as batch_list: + logger.info(f"Loading batch of {len(batch_lists)} onto device {distributed_state.device}") + prev_image = process_batch(batch_list, highres_fix)[0] + accelerator.wait_for_everyone() batch_data.clear() batch_data.append(b1) - if len(batch_data) == args.batch_size: - prev_image = process_batch(batch_data, highres_fix)[0] + if len(batch_data) == args.batch_size*distributed_state.num_processes: + batch_data_split = np.array_split(batch_data, distributed_state.num_processes) + with torch.no_grad(): + with distributed_state.split_between_processes(batch_data_split) as batch_list: + logger.info(f"Loading batch of {len(batch_lists)} onto device {distributed_state.device}") + prev_image = process_batch(batch_list, highres_fix)[0] + accelerator.wait_for_everyone() batch_data.clear() global_step += 1 @@ -2829,7 +2839,12 @@ def scale_and_round(x): prompt_index += 1 if len(batch_data) > 0: - process_batch(batch_data, highres_fix) + batch_data_split = np.array_split(batch_data, distributed_state.num_processes) + with torch.no_grad(): + with distributed_state.split_between_processes(batch_data_split) as batch_list: + logger.info(f"Loading batch of {len(batch_lists)} onto device {distributed_state.device}") + prev_image = process_batch(batch_list, highres_fix)[0] + accelerator.wait_for_everyone() batch_data.clear() logger.info("done!") From a15d8f748bd9eb0c1c6a9506c819668061971e6b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 24 Jan 2025 23:58:23 +0800 Subject: [PATCH 011/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 6d7ad988b..5de9b282e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2816,20 +2816,27 @@ def scale_and_round(x): ), ) if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? + logger.info(f"When does this run?\n Loaded {len(batch_data)} prompts for {distributed_state.num_processes} devices.") batch_data_split = np.array_split(batch_data, distributed_state.num_processes) with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_lists)} onto device {distributed_state.device}") + + logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") + for i in range(batch_list): + logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() batch_data.clear() batch_data.append(b1) if len(batch_data) == args.batch_size*distributed_state.num_processes: + logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") batch_data_split = np.array_split(batch_data, distributed_state.num_processes) with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_lists)} onto device {distributed_state.device}") + logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") + for i in range(batch_list): + logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() batch_data.clear() @@ -2842,7 +2849,9 @@ def scale_and_round(x): batch_data_split = np.array_split(batch_data, distributed_state.num_processes) with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_lists)} onto device {distributed_state.device}") + logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") + for i in range(batch_list): + logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() batch_data.clear() From 66d4a62792f94ff259dbcc6981aff24a0a0bdcf2 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 00:22:07 +0800 Subject: [PATCH 012/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 5de9b282e..31b34eac5 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2822,7 +2822,7 @@ def scale_and_round(x): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") - for i in range(batch_list): + for i in len(batch_list): logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() @@ -2835,7 +2835,7 @@ def scale_and_round(x): with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") - for i in range(batch_list): + for i in len(batch_list): logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() @@ -2850,7 +2850,7 @@ def scale_and_round(x): with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") - for i in range(batch_list): + for i in len(batch_list): logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() From d1e365081a9907b41ac554e551a0da1e2feda637 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 00:22:59 +0800 Subject: [PATCH 013/221] Update sdxl_gen_img.py --- sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index a01aed69c..68a9a775f 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -2820,9 +2820,9 @@ def scale_and_round(x): if len(batch_data) == args.batch_size: prev_image = process_batch(batch_data, highres_fix)[0] batch_data.clear() - - global_step += 1 logger.info(f"Global Step: {global_step}") + global_step += 1 + prompt_index += 1 From 0933ee90f43d86dd8b0363161b86a809fd366f23 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 00:29:44 +0800 Subject: [PATCH 014/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 31b34eac5..66ac7f9c0 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2850,8 +2850,8 @@ def scale_and_round(x): with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") - for i in len(batch_list): - logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") + for i in len(batch_list): + logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() batch_data.clear() From f7fde60f4edbf7d75c3f02e0b4bcf7f684bd1d09 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 00:34:21 +0800 Subject: [PATCH 015/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 66ac7f9c0..1c7c871f3 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2860,8 +2860,8 @@ def scale_and_round(x): def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - + parser = train_network.setup_parser() + sdxl_train_util.add_sdxl_training_arguments(parser) add_logging_arguments(parser) parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") From 2c8cf941504ef7852e68516b343573a33fb3980b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 00:37:41 +0800 Subject: [PATCH 016/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 1c7c871f3..e926e466a 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -14,6 +14,7 @@ import random import re import gc +import sdxl_train_network import diffusers import numpy as np @@ -2860,8 +2861,8 @@ def scale_and_round(x): def setup_parser() -> argparse.ArgumentParser: - parser = train_network.setup_parser() - sdxl_train_util.add_sdxl_training_arguments(parser) + parser = sdxl_train_network.setup_parser() + #sdxl_train_util.add_sdxl_training_arguments(parser) add_logging_arguments(parser) parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") From f1fc65e61d71da81055bd21ab638c3896d4d15fd Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 00:39:22 +0800 Subject: [PATCH 017/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index e926e466a..7ed356208 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2863,7 +2863,7 @@ def scale_and_round(x): def setup_parser() -> argparse.ArgumentParser: parser = sdxl_train_network.setup_parser() #sdxl_train_util.add_sdxl_training_arguments(parser) - add_logging_arguments(parser) + #add_logging_arguments(parser) parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument( From 8a3548a380afe34e0495e5164bb17b6dfc849dcd Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 01:20:10 +0800 Subject: [PATCH 018/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 7ed356208..60c1b6e61 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1487,9 +1487,9 @@ def main(args): # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" # モデルを読み込む - logger.info("preparing accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process + logger.info("preparing pipes") + #accelerator = train_util.prepare_accelerator(args) + #is_main_process = accelerator.is_main_process distributed_state = PartialState() device = distributed_state.device @@ -1499,13 +1499,12 @@ def main(args): args.ckpt = files[0] #device = get_preferred_device() logger.info(f"preferred device: {device}") - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util.load_target_model( - args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype + model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) - text_encoder1.to(dtype).to(device) - text_encoder2.to(dtype).to(device) - unet.to(dtype).to(device) + # xformers、Hypernetwork対応 if not args.diffusers_xformers: mem_eff = not (args.xformers or args.sdpa) @@ -1662,6 +1661,9 @@ def __getattr__(self, item): vae.to(vae_dtype).to(device) vae.eval() + text_encoder1.to(dtype).to(device) + text_encoder2.to(dtype).to(device) + unet.to(dtype).to(device) text_encoder1.eval() text_encoder2.eval() unet.eval() @@ -2861,9 +2863,9 @@ def scale_and_round(x): def setup_parser() -> argparse.ArgumentParser: - parser = sdxl_train_network.setup_parser() + parser = argparse.ArgumentParser() #sdxl_train_util.add_sdxl_training_arguments(parser) - #add_logging_arguments(parser) + add_logging_arguments(parser) parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument( From 3534842076c04b897874b4fce175bc1a47c5e9b9 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 01:22:40 +0800 Subject: [PATCH 019/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 1 - 1 file changed, 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 60c1b6e61..ddc755e7d 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -14,7 +14,6 @@ import random import re import gc -import sdxl_train_network import diffusers import numpy as np From 319a25cd2c2d616841a954d0666b568612dd6759 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 01:26:28 +0800 Subject: [PATCH 020/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 1 + 1 file changed, 1 insertion(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index ddc755e7d..2c07778d3 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -14,6 +14,7 @@ import random import re import gc +from accelerate import PartialState import diffusers import numpy as np From e9cb46672cec64f423d5390597f189e37fc1afcf Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 01:28:41 +0800 Subject: [PATCH 021/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 2c07778d3..4bee12166 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -3231,7 +3231,11 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--full_bf16", action="store_true", help="Loading model in bf16" ) - + parser.add_argument( + "--lowram", + action="store_true", + help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込む等(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + ) # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) From ba91cce8656ed3c72ae2a4a90fd8501560ced9f7 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 01:34:38 +0800 Subject: [PATCH 022/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 1 + 1 file changed, 1 insertion(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 4bee12166..eacbf5865 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2834,6 +2834,7 @@ def scale_and_round(x): batch_data.append(b1) if len(batch_data) == args.batch_size*distributed_state.num_processes: logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") + logger.info(f"{batch_data}") batch_data_split = np.array_split(batch_data, distributed_state.num_processes) with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: From 7cfae3d593eab384d38282de7d2cd0bba78ecb7d Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 01:38:05 +0800 Subject: [PATCH 023/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index eacbf5865..c55488b89 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2835,6 +2835,8 @@ def scale_and_round(x): if len(batch_data) == args.batch_size*distributed_state.num_processes: logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") logger.info(f"{batch_data}") + test_split = np.array_split(batch_data, 2) + logger.info(f"{test_split}") batch_data_split = np.array_split(batch_data, distributed_state.num_processes) with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: From 0a16ab6331bf0fc3bd5a5fe0d1da7a1c3806302b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 02:31:14 +0800 Subject: [PATCH 024/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index c55488b89..9bed7cd0a 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2835,9 +2835,21 @@ def scale_and_round(x): if len(batch_data) == args.batch_size*distributed_state.num_processes: logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") logger.info(f"{batch_data}") - test_split = np.array_split(batch_data, 2) - logger.info(f"{test_split}") - batch_data_split = np.array_split(batch_data, distributed_state.num_processes) + batch_data_split = [] + batch_index = 0 + test_batch_data_split = [] + test_batch_index = 0 + for i in len(batch_data): + logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") + batch_data_split[batch_index].append(batch_data[i]) + test_batch_data_split[test_batch_index].append(batch_data[i]) + if i % args.batch_size == 0: + batch_index += 1 + if i % 4 == 0: + test_batch_index += 1 + logger.info(f"{batch_data_split}") + logger.info(f"{test_batch_data_split}") + with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") From 11814604f3c7bda1bfb947acd4c9832d2b2fd2f3 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 02:48:06 +0800 Subject: [PATCH 025/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 9bed7cd0a..8241f2f9e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2835,18 +2835,16 @@ def scale_and_round(x): if len(batch_data) == args.batch_size*distributed_state.num_processes: logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") logger.info(f"{batch_data}") - batch_data_split = [] + batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = 0 - test_batch_data_split = [] - test_batch_index = 0 - for i in len(batch_data): + test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(my_list), 3)] + #test_batch_index = 0 + for i while i < len(batch_data): logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") batch_data_split[batch_index].append(batch_data[i]) test_batch_data_split[test_batch_index].append(batch_data[i]) if i % args.batch_size == 0: batch_index += 1 - if i % 4 == 0: - test_batch_index += 1 logger.info(f"{batch_data_split}") logger.info(f"{test_batch_data_split}") From 9131e2f592a352f5fdfc67361046c41ce2c3ec3c Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 02:50:46 +0800 Subject: [PATCH 026/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 8241f2f9e..efc49863f 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2839,7 +2839,7 @@ def scale_and_round(x): batch_index = 0 test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(my_list), 3)] #test_batch_index = 0 - for i while i < len(batch_data): + for i in range(len(batch_data)): logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") batch_data_split[batch_index].append(batch_data[i]) test_batch_data_split[test_batch_index].append(batch_data[i]) From be3607088bf933cb93b8d8749e9f1d1cc120a208 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 02:52:21 +0800 Subject: [PATCH 027/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index efc49863f..26ec0d1fe 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2862,7 +2862,18 @@ def scale_and_round(x): prompt_index += 1 if len(batch_data) > 0: - batch_data_split = np.array_split(batch_data, distributed_state.num_processes) + batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] + batch_index = 0 + test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(my_list), 3)] + #test_batch_index = 0 + for i in range(len(batch_data)): + logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") + batch_data_split[batch_index].append(batch_data[i]) + test_batch_data_split[test_batch_index].append(batch_data[i]) + if i % args.batch_size == 0: + batch_index += 1 + logger.info(f"{batch_data_split}") + logger.info(f"{test_batch_data_split}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") From c0d3abd2bdb922f73bd0f7db64560d6762629588 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 02:54:46 +0800 Subject: [PATCH 028/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 26ec0d1fe..f16906a84 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2837,12 +2837,11 @@ def scale_and_round(x): logger.info(f"{batch_data}") batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = 0 - test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(my_list), 3)] + test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(batch_data), 3)] #test_batch_index = 0 for i in range(len(batch_data)): logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") batch_data_split[batch_index].append(batch_data[i]) - test_batch_data_split[test_batch_index].append(batch_data[i]) if i % args.batch_size == 0: batch_index += 1 logger.info(f"{batch_data_split}") @@ -2864,12 +2863,11 @@ def scale_and_round(x): if len(batch_data) > 0: batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = 0 - test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(my_list), 3)] + test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(batch_data), 3)] #test_batch_index = 0 for i in range(len(batch_data)): logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") batch_data_split[batch_index].append(batch_data[i]) - test_batch_data_split[test_batch_index].append(batch_data[i]) if i % args.batch_size == 0: batch_index += 1 logger.info(f"{batch_data_split}") From e45db09646698f03a941503b82ab48474e3f24ca Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 03:01:54 +0800 Subject: [PATCH 029/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index f16906a84..86da7a95c 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2835,7 +2835,7 @@ def scale_and_round(x): if len(batch_data) == args.batch_size*distributed_state.num_processes: logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") logger.info(f"{batch_data}") - batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] + batch_data_split = [][] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = 0 test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(batch_data), 3)] #test_batch_index = 0 @@ -2861,7 +2861,7 @@ def scale_and_round(x): prompt_index += 1 if len(batch_data) > 0: - batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] + batch_data_split = [][] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = 0 test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(batch_data), 3)] #test_batch_index = 0 From a5e3715064e7981d4fde53b9ca2f2807e6924581 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 03:07:29 +0800 Subject: [PATCH 030/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 86da7a95c..e07acd46f 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2835,15 +2835,16 @@ def scale_and_round(x): if len(batch_data) == args.batch_size*distributed_state.num_processes: logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") logger.info(f"{batch_data}") - batch_data_split = [][] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] - batch_index = 0 + batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] + batch_index = [] test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(batch_data), 3)] #test_batch_index = 0 for i in range(len(batch_data)): logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") - batch_data_split[batch_index].append(batch_data[i]) + batch_index.append(batch_data[i]) if i % args.batch_size == 0: - batch_index += 1 + batch_data_split.append(batch_index) + batch_index.clear() logger.info(f"{batch_data_split}") logger.info(f"{test_batch_data_split}") @@ -2861,15 +2862,16 @@ def scale_and_round(x): prompt_index += 1 if len(batch_data) > 0: - batch_data_split = [][] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] - batch_index = 0 + batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] + batch_index = [] test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(batch_data), 3)] #test_batch_index = 0 for i in range(len(batch_data)): logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") - batch_data_split[batch_index].append(batch_data[i]) + batch_index.append(batch_data[i]) if i % args.batch_size == 0: - batch_index += 1 + batch_data_split.append(batch_index) + batch_index.clear() logger.info(f"{batch_data_split}") logger.info(f"{test_batch_data_split}") with torch.no_grad(): From 85ae5272196948d490d41d85fcd19ae2de761d0b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 03:13:57 +0800 Subject: [PATCH 031/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index e07acd46f..8855e50e8 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2820,12 +2820,25 @@ def scale_and_round(x): ) if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? logger.info(f"When does this run?\n Loaded {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - batch_data_split = np.array_split(batch_data, distributed_state.num_processes) + logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") + logger.info(f"{batch_data}") + batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] + batch_index = [] + test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(batch_data), 3)] + #test_batch_index = 0 + for i in range(len(batch_data)): + logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") + batch_index.append(batch_data[i]) + if i % args.batch_size == 0: + batch_data_split.append(batch_index) + batch_index.clear() + logger.info(f"batch_data_split: {batch_data_split}") + logger.info(f"test_batch_data_split: {test_batch_data_split}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") - for i in len(batch_list): + for i in range(len(batch_list)): logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() @@ -2845,13 +2858,13 @@ def scale_and_round(x): if i % args.batch_size == 0: batch_data_split.append(batch_index) batch_index.clear() - logger.info(f"{batch_data_split}") - logger.info(f"{test_batch_data_split}") + logger.info(f"batch_data_split: {batch_data_split}") + logger.info(f"test_batch_data_split: {test_batch_data_split}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") - for i in len(batch_list): + for i in range(len(batch_list)): logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() @@ -2877,7 +2890,7 @@ def scale_and_round(x): with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") - for i in len(batch_list): + for i in range(len(batch_list)): logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() From cad34c5bb535100bb16c903dc9799c9e1551b225 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 03:17:30 +0800 Subject: [PATCH 032/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 8855e50e8..fbd738a8e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2891,7 +2891,7 @@ def scale_and_round(x): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") for i in range(len(batch_list)): - logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") + logger.info(f"Prompt {i}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() batch_data.clear() From 0c08de4bf28338344444ac5caebcf09fb91719a5 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 03:20:15 +0800 Subject: [PATCH 033/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index fbd738a8e..2de568ce3 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2839,7 +2839,7 @@ def scale_and_round(x): logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") for i in range(len(batch_list)): - logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") + logger.info(f"Prompt {i}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() batch_data.clear() @@ -2865,7 +2865,7 @@ def scale_and_round(x): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") for i in range(len(batch_list)): - logger.info(f"Prompt {i}: {batch_list[i].base.prompt}") + logger.info(f"Prompt {i}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list, highres_fix)[0] accelerator.wait_for_everyone() batch_data.clear() From 9550106016be71112a28d9bd33ffb0e0da2c61e8 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 03:25:38 +0800 Subject: [PATCH 034/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 2de568ce3..9273ec537 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2836,11 +2836,11 @@ def scale_and_round(x): logger.info(f"test_batch_data_split: {test_batch_data_split}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: - - logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") - for i in range(len(batch_list)): + logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") + logger.info(f"batch_list: {batch_list}") + for i in range(len(batch_list[0])): logger.info(f"Prompt {i}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list, highres_fix)[0] + prev_image = process_batch(batch_list[0], highres_fix)[0] accelerator.wait_for_everyone() batch_data.clear() @@ -2863,10 +2863,11 @@ def scale_and_round(x): with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") - for i in range(len(batch_list)): + logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") + logger.info(f"batch_list: {batch_list}") + for i in range(len(batch_list[0])): logger.info(f"Prompt {i}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list, highres_fix)[0] + prev_image = process_batch(batch_list[0], highres_fix)[0] accelerator.wait_for_everyone() batch_data.clear() @@ -2889,10 +2890,11 @@ def scale_and_round(x): logger.info(f"{test_batch_data_split}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_list)} prompts onto device {distributed_state.device}:") - for i in range(len(batch_list)): + logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") + logger.info(f"batch_list: {batch_list}") + for i in range(len(batch_list[0])): logger.info(f"Prompt {i}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list, highres_fix)[0] + prev_image = process_batch(batch_list[0], highres_fix)[0] accelerator.wait_for_everyone() batch_data.clear() From b927a7d345e5d66c872a2219ed5dd82900c98998 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 03:55:11 +0800 Subject: [PATCH 035/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 9273ec537..245c656de 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2824,24 +2824,21 @@ def scale_and_round(x): logger.info(f"{batch_data}") batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] - test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(batch_data), 3)] - #test_batch_index = 0 for i in range(len(batch_data)): - logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") batch_index.append(batch_data[i]) - if i % args.batch_size == 0: + if (i+1) % args.batch_size == 0: batch_data_split.append(batch_index) batch_index.clear() logger.info(f"batch_data_split: {batch_data_split}") - logger.info(f"test_batch_data_split: {test_batch_data_split}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") logger.info(f"batch_list: {batch_list}") for i in range(len(batch_list[0])): - logger.info(f"Prompt {i}: {batch_list[0][i].base.prompt}") + logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] - accelerator.wait_for_everyone() + distributed_state.wait_for_everyone() batch_data.clear() batch_data.append(b1) @@ -2850,25 +2847,32 @@ def scale_and_round(x): logger.info(f"{batch_data}") batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] - test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(batch_data), 3)] - #test_batch_index = 0 + test_batch_data_split = [] + test_batch_index = [] for i in range(len(batch_data)): - logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") batch_index.append(batch_data[i]) - if i % args.batch_size == 0: + test_batch_index.append(batch_data[i]) + if (i+1) % args.batch_size == 0: batch_data_split.append(batch_index) batch_index.clear() + if (i+1) % 4 == 0: + test_batch_data_split.append(batch_index) + test_batch_index.clear() logger.info(f"batch_data_split: {batch_data_split}") - logger.info(f"test_batch_data_split: {test_batch_data_split}") + for i in range(len(test_batch_data_split)): + logger.info(f"test_batch_data_split[{i}]: {test_batch_data_split[i]}") + for j in range(len(test_batch_data_split[i])): + logger.info(f"Prompt {j}: {test_batch_data_split[i][j].base.prompt}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") logger.info(f"batch_list: {batch_list}") for i in range(len(batch_list[0])): - logger.info(f"Prompt {i}: {batch_list[0][i].base.prompt}") + logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] - accelerator.wait_for_everyone() + distributed_state.wait_for_everyone() batch_data.clear() global_step += 1 @@ -2878,24 +2882,21 @@ def scale_and_round(x): if len(batch_data) > 0: batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] - test_batch_data_split = [batch_data[i:i+3] for i in range(0, len(batch_data), 3)] - #test_batch_index = 0 for i in range(len(batch_data)): - logger.info(f"Prompt {i}: {batch_data[i].base.prompt}\n{batch_data[i]}") + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") batch_index.append(batch_data[i]) - if i % args.batch_size == 0: + if (i+1) % args.batch_size == 0: batch_data_split.append(batch_index) batch_index.clear() logger.info(f"{batch_data_split}") - logger.info(f"{test_batch_data_split}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") logger.info(f"batch_list: {batch_list}") for i in range(len(batch_list[0])): - logger.info(f"Prompt {i}: {batch_list[0][i].base.prompt}") + logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] - accelerator.wait_for_everyone() + distributed_state.wait_for_everyone() batch_data.clear() logger.info("done!") From 77ca8713933a5349c03dcab7a9b0ed7063ef1e01 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 04:08:42 +0800 Subject: [PATCH 036/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 245c656de..f378a4578 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2854,10 +2854,12 @@ def scale_and_round(x): batch_index.append(batch_data[i]) test_batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: + logger.info(f"Loading {batch_index}") batch_data_split.append(batch_index) batch_index.clear() if (i+1) % 4 == 0: - test_batch_data_split.append(batch_index) + logger.info(f"Loading {test_batch_index}") + test_batch_data_split.append(test_batch_index) test_batch_index.clear() logger.info(f"batch_data_split: {batch_data_split}") for i in range(len(test_batch_data_split)): From 0ed9bbf23b5e1d0e949d63561c6c91013989394a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 04:19:30 +0800 Subject: [PATCH 037/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index f378a4578..df1b22ccf 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2828,7 +2828,7 @@ def scale_and_round(x): logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index) + batch_data_split.append(batch_index.copy()) batch_index.clear() logger.info(f"batch_data_split: {batch_data_split}") with torch.no_grad(): @@ -2855,11 +2855,11 @@ def scale_and_round(x): test_batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: logger.info(f"Loading {batch_index}") - batch_data_split.append(batch_index) + batch_data_split.append(batch_index.copy()) batch_index.clear() if (i+1) % 4 == 0: logger.info(f"Loading {test_batch_index}") - test_batch_data_split.append(test_batch_index) + test_batch_data_split.append(test_batch_index.copy()) test_batch_index.clear() logger.info(f"batch_data_split: {batch_data_split}") for i in range(len(test_batch_data_split)): @@ -2888,7 +2888,7 @@ def scale_and_round(x): logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index) + batch_data_split.append(batch_index.copy()) batch_index.clear() logger.info(f"{batch_data_split}") with torch.no_grad(): From ce9dfda59f9d8984801e66b2e2aa3858b8bca5f7 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 11:40:59 +0800 Subject: [PATCH 038/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 608 +++++++++++++++++++++--------------------- 1 file changed, 309 insertions(+), 299 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index df1b22ccf..d7ffc6282 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2440,272 +2440,275 @@ def scale_and_round(x): # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + if distributed_state.is_main_process: + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] if pi == 0 or len(raw_prompts) > 1: - # parse prompt: if prompt is not changed, skip parsing - width = args.W - height = args.H - original_width = args.original_width - original_height = args.original_height - original_width_negative = args.original_width_negative - original_height_negative = args.original_height_negative - crop_top = args.crop_top - crop_left = args.crop_left - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seed = None - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - - # Deep Shrink - ds_depth_1 = None # means no override - ds_timesteps_1 = args.ds_timesteps_1 - ds_depth_2 = args.ds_depth_2 - ds_timesteps_2 = args.ds_timesteps_2 - ds_ratio = args.ds_ratio - - # Gradual Latent - gl_timesteps = None # means no override - gl_ratio = args.gradual_latent_ratio - gl_every_n_steps = args.gradual_latent_every_n_steps - gl_ratio_step = args.gradual_latent_ratio_step - gl_s_noise = args.gradual_latent_s_noise - gl_unsharp_params = args.gradual_latent_unsharp_params - - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + if distributed_state.is_main_process: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - logger.info(f"width: {width}") - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - logger.info(f"height: {height}") - continue - - m = re.match(r"ow (\d+)", parg, re.IGNORECASE) - if m: - original_width = int(m.group(1)) - logger.info(f"original width: {original_width}") - continue - - m = re.match(r"oh (\d+)", parg, re.IGNORECASE) - if m: - original_height = int(m.group(1)) - logger.info(f"original height: {original_height}") - continue - - m = re.match(r"nw (\d+)", parg, re.IGNORECASE) - if m: - original_width_negative = int(m.group(1)) - logger.info(f"original width negative: {original_width_negative}") - continue - - m = re.match(r"nh (\d+)", parg, re.IGNORECASE) - if m: - original_height_negative = int(m.group(1)) - logger.info(f"original height negative: {original_height_negative}") - continue - - m = re.match(r"ct (\d+)", parg, re.IGNORECASE) - if m: - crop_top = int(m.group(1)) - logger.info(f"crop top: {crop_top}") - continue - - m = re.match(r"cl (\d+)", parg, re.IGNORECASE) - if m: - crop_left = int(m.group(1)) - logger.info(f"crop left: {crop_left}") - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - logger.info(f"steps: {steps}") - continue - - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - logger.info(f"seeds: {seeds}") - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - logger.info(f"scale: {scale}") - continue - - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - logger.info(f"negative scale: {negative_scale}") - continue - - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - logger.info(f"strength: {strength}") - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - logger.info(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - logger.info(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - logger.info(f"network mul: {network_muls}") - continue + if distributed_state.is_main_process: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + logger.info(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + logger.info(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + logger.info(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + logger.info(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + logger.info(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + logger.info(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + logger.info(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + logger.info(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + logger.info(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + logger.info(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + logger.info(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + logger.info(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + logger.info(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + logger.info(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + logger.info(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + logger.info(f"network mul: {network_muls}") + continue # Deep Shrink - m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 1 - ds_depth_1 = int(m.group(1)) - logger.info(f"deep shrink depth 1: {ds_depth_1}") - continue - - m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 1 - ds_timesteps_1 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") - continue - - m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 2 - ds_depth_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink depth 2: {ds_depth_2}") - continue - - m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 2 - ds_timesteps_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") - continue - - m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink ratio - ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink ratio: {ds_ratio}") - continue + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + logger.info(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - except ValueError as ex: - logger.error(f"Exception in parsing / 解析エラー: {parg}") - logger.error(f"{ex}") + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") # override Deep Shrink if ds_depth_1 is not None: @@ -2825,12 +2828,14 @@ def scale_and_round(x): batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") - batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() - logger.info(f"batch_data_split: {batch_data_split}") + if distributed_state.is_main_process: + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") + batch_index.append(batch_data[i]) + if (i+1) % args.batch_size == 0: + batch_data_split.append(batch_index.copy()) + batch_index.clear() + if distributed_state.is_main_process: + logger.info(f"batch_data_split: {batch_data_split}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") @@ -2839,33 +2844,35 @@ def scale_and_round(x): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] distributed_state.wait_for_everyone() - batch_data.clear() - - batch_data.append(b1) + if distributed_state.is_main_process: + batch_data.clear() + if distributed_state.is_main_process: + batch_data.append(b1) if len(batch_data) == args.batch_size*distributed_state.num_processes: - logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - logger.info(f"{batch_data}") - batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] - batch_index = [] - test_batch_data_split = [] - test_batch_index = [] - for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") - batch_index.append(batch_data[i]) - test_batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - logger.info(f"Loading {batch_index}") - batch_data_split.append(batch_index.copy()) - batch_index.clear() - if (i+1) % 4 == 0: - logger.info(f"Loading {test_batch_index}") - test_batch_data_split.append(test_batch_index.copy()) - test_batch_index.clear() - logger.info(f"batch_data_split: {batch_data_split}") - for i in range(len(test_batch_data_split)): - logger.info(f"test_batch_data_split[{i}]: {test_batch_data_split[i]}") - for j in range(len(test_batch_data_split[i])): - logger.info(f"Prompt {j}: {test_batch_data_split[i][j].base.prompt}") + if distributed_state.is_main_process: + logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") + logger.info(f"{batch_data}") + batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] + batch_index = [] + test_batch_data_split = [] + test_batch_index = [] + for i in range(len(batch_data)): + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") + batch_index.append(batch_data[i]) + test_batch_index.append(batch_data[i]) + if (i+1) % args.batch_size == 0: + logger.info(f"Loading {batch_index}") + batch_data_split.append(batch_index.copy()) + batch_index.clear() + if (i+1) % 4 == 0: + logger.info(f"Loading {test_batch_index}") + test_batch_data_split.append(test_batch_index.copy()) + test_batch_index.clear() + logger.info(f"batch_data_split: {batch_data_split}") + for i in range(len(test_batch_data_split)): + logger.info(f"test_batch_data_split[{i}]: {test_batch_data_split[i]}") + for j in range(len(test_batch_data_split[i])): + logger.info(f"Prompt {j}: {test_batch_data_split[i][j].base.prompt}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: @@ -2875,7 +2882,8 @@ def scale_and_round(x): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] distributed_state.wait_for_everyone() - batch_data.clear() + if distributed_state.is_main_process: + batch_data.clear() global_step += 1 @@ -2884,13 +2892,14 @@ def scale_and_round(x): if len(batch_data) > 0: batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] - for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") - batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() - logger.info(f"{batch_data_split}") + if distributed_state.is_main_process: + for i in range(len(batch_data)): + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") + batch_index.append(batch_data[i]) + if (i+1) % args.batch_size == 0: + batch_data_split.append(batch_index.copy()) + batch_index.clear() + logger.info(f"{batch_data_split}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") @@ -2899,7 +2908,8 @@ def scale_and_round(x): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] distributed_state.wait_for_everyone() - batch_data.clear() + if distributed_state.is_main_process: + batch_data.clear() logger.info("done!") From c710432f0a066a1b0529abc2a28febabc95e4801 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 11:56:04 +0800 Subject: [PATCH 039/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index d7ffc6282..2885f78e0 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -20,7 +20,7 @@ import numpy as np import torch -from library.device_utils import init_ipex, clean_memory, get_preferred_device +from library.device_utils import init_ipex, clean_memory, get_preferred_device, clean_memory_on_device init_ipex() import torchvision @@ -1499,11 +1499,13 @@ def main(args): args.ckpt = files[0] #device = get_preferred_device() logger.info(f"preferred device: {device}") + clean_memory_on_device(device) model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) + distributed_state.wait_for_everyone() # xformers、Hypernetwork対応 if not args.diffusers_xformers: @@ -1828,6 +1830,7 @@ def __getattr__(self, item): ) pipe.set_control_nets(control_nets) logger.info(f"pipeline on {device} is ready.") + distributed_state.wait_for_everyone() if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() @@ -2834,12 +2837,10 @@ def scale_and_round(x): if (i+1) % args.batch_size == 0: batch_data_split.append(batch_index.copy()) batch_index.clear() - if distributed_state.is_main_process: - logger.info(f"batch_data_split: {batch_data_split}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") - logger.info(f"batch_list: {batch_list}") + logger.info(f"batch_list:") for i in range(len(batch_list[0])): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] @@ -2851,7 +2852,6 @@ def scale_and_round(x): if len(batch_data) == args.batch_size*distributed_state.num_processes: if distributed_state.is_main_process: logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - logger.info(f"{batch_data}") batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] test_batch_data_split = [] @@ -2861,23 +2861,20 @@ def scale_and_round(x): batch_index.append(batch_data[i]) test_batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: - logger.info(f"Loading {batch_index}") batch_data_split.append(batch_index.copy()) batch_index.clear() if (i+1) % 4 == 0: - logger.info(f"Loading {test_batch_index}") test_batch_data_split.append(test_batch_index.copy()) test_batch_index.clear() - logger.info(f"batch_data_split: {batch_data_split}") for i in range(len(test_batch_data_split)): - logger.info(f"test_batch_data_split[{i}]: {test_batch_data_split[i]}") + logger.info(f"test_batch_data_split[{i}]:") for j in range(len(test_batch_data_split[i])): logger.info(f"Prompt {j}: {test_batch_data_split[i][j].base.prompt}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") - logger.info(f"batch_list: {batch_list}") + logger.info(f"batch_list:") for i in range(len(batch_list[0])): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] From 6955fb804f05650378f5a1d56e77bc3394fc6c32 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 12:26:50 +0800 Subject: [PATCH 040/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 451 +++++++++++++++++++++--------------------- 1 file changed, 225 insertions(+), 226 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 2885f78e0..db093707d 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2487,231 +2487,231 @@ def scale_and_round(x): logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: - if distributed_state.is_main_process: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - logger.info(f"width: {width}") - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - logger.info(f"height: {height}") - continue - - m = re.match(r"ow (\d+)", parg, re.IGNORECASE) - if m: - original_width = int(m.group(1)) - logger.info(f"original width: {original_width}") - continue - - m = re.match(r"oh (\d+)", parg, re.IGNORECASE) - if m: - original_height = int(m.group(1)) - logger.info(f"original height: {original_height}") - continue - - m = re.match(r"nw (\d+)", parg, re.IGNORECASE) - if m: - original_width_negative = int(m.group(1)) - logger.info(f"original width negative: {original_width_negative}") - continue - - m = re.match(r"nh (\d+)", parg, re.IGNORECASE) - if m: - original_height_negative = int(m.group(1)) - logger.info(f"original height negative: {original_height_negative}") - continue - - m = re.match(r"ct (\d+)", parg, re.IGNORECASE) - if m: - crop_top = int(m.group(1)) - logger.info(f"crop top: {crop_top}") - continue - - m = re.match(r"cl (\d+)", parg, re.IGNORECASE) - if m: - crop_left = int(m.group(1)) - logger.info(f"crop left: {crop_left}") - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - logger.info(f"steps: {steps}") - continue - - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - logger.info(f"seeds: {seeds}") - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - logger.info(f"scale: {scale}") - continue - - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - logger.info(f"negative scale: {negative_scale}") - continue - - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - logger.info(f"strength: {strength}") - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - logger.info(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - logger.info(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - logger.info(f"network mul: {network_muls}") - continue + + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + logger.info(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + logger.info(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + logger.info(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + logger.info(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + logger.info(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + logger.info(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + logger.info(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + logger.info(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + logger.info(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + logger.info(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + logger.info(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + logger.info(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + logger.info(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + logger.info(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + logger.info(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + logger.info(f"network mul: {network_muls}") + continue # Deep Shrink - m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 1 - ds_depth_1 = int(m.group(1)) - logger.info(f"deep shrink depth 1: {ds_depth_1}") - continue - - m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 1 - ds_timesteps_1 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") - continue - - m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 2 - ds_depth_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink depth 2: {ds_depth_2}") - continue - - m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 2 - ds_timesteps_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") - continue - - m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink ratio - ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink ratio: {ds_ratio}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + logger.info(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - except ValueError as ex: - logger.error(f"Exception in parsing / 解析エラー: {parg}") - logger.error(f"{ex}") + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") # override Deep Shrink if ds_depth_1 is not None: @@ -2827,12 +2827,11 @@ def scale_and_round(x): if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? logger.info(f"When does this run?\n Loaded {len(batch_data)} prompts for {distributed_state.num_processes} devices.") logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - logger.info(f"{batch_data}") batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] for i in range(len(batch_data)): if distributed_state.is_main_process: - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: batch_data_split.append(batch_index.copy()) @@ -2857,7 +2856,7 @@ def scale_and_round(x): test_batch_data_split = [] test_batch_index = [] for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}}") batch_index.append(batch_data[i]) test_batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: @@ -2891,7 +2890,7 @@ def scale_and_round(x): batch_index = [] if distributed_state.is_main_process: for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}\n{batch_data[i]}") + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: batch_data_split.append(batch_index.copy()) @@ -2900,7 +2899,7 @@ def scale_and_round(x): with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") - logger.info(f"batch_list: {batch_list}") + logger.info(f"batch_list:") for i in range(len(batch_list[0])): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] From 9027fe94a21e5da6140a44d48bce4eb83d865566 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 12:36:20 +0800 Subject: [PATCH 041/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index db093707d..5097ebfa5 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2856,7 +2856,7 @@ def scale_and_round(x): test_batch_data_split = [] test_batch_index = [] for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}}") + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") batch_index.append(batch_data[i]) test_batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: From 9a89208709e4e980c5daa416f5f20ac49fcdab0e Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 12:43:39 +0800 Subject: [PATCH 042/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 66 +++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 5097ebfa5..0765bab10 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2447,44 +2447,44 @@ def scale_and_round(x): raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] if pi == 0 or len(raw_prompts) > 1: - if distributed_state.is_main_process: + # parse prompt: if prompt is not changed, skip parsing - width = args.W - height = args.H - original_width = args.original_width - original_height = args.original_height - original_width_negative = args.original_width_negative - original_height_negative = args.original_height_negative - crop_top = args.crop_top - crop_left = args.crop_left - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seed = None - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None # Deep Shrink - ds_depth_1 = None # means no override - ds_timesteps_1 = args.ds_timesteps_1 - ds_depth_2 = args.ds_depth_2 - ds_timesteps_2 = args.ds_timesteps_2 - ds_ratio = args.ds_ratio + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio # Gradual Latent - gl_timesteps = None # means no override - gl_ratio = args.gradual_latent_ratio - gl_every_n_steps = args.gradual_latent_every_n_steps - gl_ratio_step = args.gradual_latent_ratio_step - gl_s_noise = args.gradual_latent_s_noise - gl_unsharp_params = args.gradual_latent_unsharp_params - - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: From 56e6956d678337ae3fafdfb8c796e560ca907f53 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 20:23:41 +0800 Subject: [PATCH 043/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 83 ++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 44 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 0765bab10..5d5cead26 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1501,10 +1501,13 @@ def main(args): logger.info(f"preferred device: {device}") clean_memory_on_device(device) model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype - ) - unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) + for pi in range(distributed_state.num_processes): + if pi == distributed_state.local_process_index: + logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype + ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) distributed_state.wait_for_everyone() # xformers、Hypernetwork対応 @@ -2443,9 +2446,7 @@ def scale_and_round(x): # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - if distributed_state.is_main_process: - raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] if pi == 0 or len(raw_prompts) > 1: # parse prompt: if prompt is not changed, skip parsing @@ -2555,7 +2556,7 @@ def scale_and_round(x): logger.info(f"scale: {scale}") continue - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) if m: # negative scale if m.group(1).lower() == "none": negative_scale = None @@ -2844,31 +2845,29 @@ def scale_and_round(x): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] distributed_state.wait_for_everyone() - if distributed_state.is_main_process: - batch_data.clear() - if distributed_state.is_main_process: - batch_data.append(b1) + batch_data.clear() + + batch_data.append(b1) if len(batch_data) == args.batch_size*distributed_state.num_processes: - if distributed_state.is_main_process: - logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] - batch_index = [] - test_batch_data_split = [] - test_batch_index = [] - for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") - batch_index.append(batch_data[i]) - test_batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() - if (i+1) % 4 == 0: - test_batch_data_split.append(test_batch_index.copy()) - test_batch_index.clear() - for i in range(len(test_batch_data_split)): - logger.info(f"test_batch_data_split[{i}]:") - for j in range(len(test_batch_data_split[i])): - logger.info(f"Prompt {j}: {test_batch_data_split[i][j].base.prompt}") + logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") + batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] + batch_index = [] + test_batch_data_split = [] + test_batch_index = [] + for i in range(len(batch_data)): + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") + batch_index.append(batch_data[i]) + test_batch_index.append(batch_data[i]) + if (i+1) % args.batch_size == 0: + batch_data_split.append(batch_index.copy()) + batch_index.clear() + if (i+1) % 4 == 0: + test_batch_data_split.append(test_batch_index.copy()) + test_batch_index.clear() + for i in range(len(test_batch_data_split)): + logger.info(f"test_batch_data_split[{i}]:") + for j in range(len(test_batch_data_split[i])): + logger.info(f"Prompt {j}: {test_batch_data_split[i][j].base.prompt}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: @@ -2878,8 +2877,7 @@ def scale_and_round(x): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] distributed_state.wait_for_everyone() - if distributed_state.is_main_process: - batch_data.clear() + batch_data.clear() global_step += 1 @@ -2888,14 +2886,12 @@ def scale_and_round(x): if len(batch_data) > 0: batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] - if distributed_state.is_main_process: - for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") - batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() - logger.info(f"{batch_data_split}") + for i in range(len(batch_data)): + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") + batch_index.append(batch_data[i]) + if (i+1) % args.batch_size == 0: + batch_data_split.append(batch_index.copy()) + batch_index.clear() with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") @@ -2904,8 +2900,7 @@ def scale_and_round(x): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] distributed_state.wait_for_everyone() - if distributed_state.is_main_process: - batch_data.clear() + batch_data.clear() logger.info("done!") From bd2dd6b885be504bbc4cd86a4e7595c9e1da0f05 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 20:30:36 +0800 Subject: [PATCH 044/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 5d5cead26..de636c6e3 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2556,7 +2556,7 @@ def scale_and_round(x): logger.info(f"scale: {scale}") continue - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) if m: # negative scale if m.group(1).lower() == "none": negative_scale = None From b369058040e8bed6733976341e720fb40fa70b66 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 20:47:42 +0800 Subject: [PATCH 045/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index de636c6e3..a53b992b2 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2843,9 +2843,9 @@ def scale_and_round(x): logger.info(f"batch_list:") for i in range(len(batch_list[0])): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list[0], highres_fix)[0] - distributed_state.wait_for_everyone() + prev_image = process_batch(batch_list[0], highres_fix)[0] batch_data.clear() + distributed_state.wait_for_everyone() batch_data.append(b1) if len(batch_data) == args.batch_size*distributed_state.num_processes: @@ -2876,8 +2876,9 @@ def scale_and_round(x): for i in range(len(batch_list[0])): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] - distributed_state.wait_for_everyone() batch_data.clear() + distributed_state.wait_for_everyone() + global_step += 1 @@ -2899,8 +2900,9 @@ def scale_and_round(x): for i in range(len(batch_list[0])): logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] - distributed_state.wait_for_everyone() batch_data.clear() + distributed_state.wait_for_everyone() + logger.info("done!") From fd5b11c634ed26bf42537a9035f75c3a98817232 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 20:53:25 +0800 Subject: [PATCH 046/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index a53b992b2..632707cbb 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2842,7 +2842,7 @@ def scale_and_round(x): logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") logger.info(f"batch_list:") for i in range(len(batch_list[0])): - logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] batch_data.clear() distributed_state.wait_for_everyone() @@ -2852,29 +2852,19 @@ def scale_and_round(x): logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] - test_batch_data_split = [] - test_batch_index = [] for i in range(len(batch_data)): logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") batch_index.append(batch_data[i]) - test_batch_index.append(batch_data[i]) if (i+1) % args.batch_size == 0: batch_data_split.append(batch_index.copy()) batch_index.clear() - if (i+1) % 4 == 0: - test_batch_data_split.append(test_batch_index.copy()) - test_batch_index.clear() - for i in range(len(test_batch_data_split)): - logger.info(f"test_batch_data_split[{i}]:") - for j in range(len(test_batch_data_split[i])): - logger.info(f"Prompt {j}: {test_batch_data_split[i][j].base.prompt}") with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") logger.info(f"batch_list:") for i in range(len(batch_list[0])): - logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] batch_data.clear() distributed_state.wait_for_everyone() @@ -2898,7 +2888,7 @@ def scale_and_round(x): logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") logger.info(f"batch_list:") for i in range(len(batch_list[0])): - logger.info(f"Prompt {i+1}: {batch_list[0][i].base.prompt}") + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") prev_image = process_batch(batch_list[0], highres_fix)[0] batch_data.clear() distributed_state.wait_for_everyone() From b7781f9b25841c483db34ff680b6dd40bf90fc2d Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 20:59:29 +0800 Subject: [PATCH 047/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 632707cbb..724258268 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2831,12 +2831,12 @@ def scale_and_round(x): batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] for i in range(len(batch_data)): - if distributed_state.is_main_process: - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") - batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() + logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") + batch_index.append(batch_data[i]) + if (i+1) % args.batch_size == 0: + batch_data_split.append(batch_index.copy()) + batch_index.clear() + with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") From 04739d7cab1b5884144dcab341f51813b6aadfe2 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 00:38:06 +0800 Subject: [PATCH 048/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 859 ++++++++++++++++++++---------------------- 1 file changed, 417 insertions(+), 442 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 724258268..14f5e88c4 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -15,6 +15,7 @@ import re import gc from accelerate import PartialState +from accelerate.utils import gather_object import diffusers import numpy as np @@ -81,7 +82,17 @@ 高速化のためのモジュール入れ替え """ +def get_batches(items, batch_size): + num_batches = (len(items) + batch_size - 1) // batch_size + batches = [] + for i in range(num_batches): + start_index = i * batch_size + end_index = min((i + 1) * batch_size, len(items)) + batch = items[start_index:end_index] + batches.append(batch) + + return batches def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): if mem_eff_attn: logger.info("Enable memory efficient attention for U-Net") @@ -2442,458 +2453,422 @@ def scale_and_round(x): # sd-dynamic-prompts like variants: # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) - raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + raw_prompts = [] + if distributed_state.is_main_process: + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + else: + distributed_state.wait_for_everyone() + raw_prompts = gather_object(raw_prompts) + if distributed_state.is_main_process: # repeat prompt - for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - if pi == 0 or len(raw_prompts) > 1: - - # parse prompt: if prompt is not changed, skip parsing - width = args.W - height = args.H - original_width = args.original_width - original_height = args.original_height - original_width_negative = args.original_width_negative - original_height_negative = args.original_height_negative - crop_top = args.crop_top - crop_left = args.crop_left - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seed = None - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - - # Deep Shrink - ds_depth_1 = None # means no override - ds_timesteps_1 = args.ds_timesteps_1 - ds_depth_2 = args.ds_depth_2 - ds_timesteps_2 = args.ds_timesteps_2 - ds_ratio = args.ds_ratio - - # Gradual Latent - gl_timesteps = None # means no override - gl_ratio = args.gradual_latent_ratio - gl_every_n_steps = args.gradual_latent_every_n_steps - gl_ratio_step = args.gradual_latent_ratio_step - gl_s_noise = args.gradual_latent_s_noise - gl_unsharp_params = args.gradual_latent_unsharp_params - - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - - for parg in prompt_args[1:]: - - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - logger.info(f"width: {width}") - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - logger.info(f"height: {height}") - continue - - m = re.match(r"ow (\d+)", parg, re.IGNORECASE) - if m: - original_width = int(m.group(1)) - logger.info(f"original width: {original_width}") - continue - - m = re.match(r"oh (\d+)", parg, re.IGNORECASE) - if m: - original_height = int(m.group(1)) - logger.info(f"original height: {original_height}") - continue - - m = re.match(r"nw (\d+)", parg, re.IGNORECASE) - if m: - original_width_negative = int(m.group(1)) - logger.info(f"original width negative: {original_width_negative}") - continue - - m = re.match(r"nh (\d+)", parg, re.IGNORECASE) - if m: - original_height_negative = int(m.group(1)) - logger.info(f"original height negative: {original_height_negative}") - continue - - m = re.match(r"ct (\d+)", parg, re.IGNORECASE) - if m: - crop_top = int(m.group(1)) - logger.info(f"crop top: {crop_top}") - continue - - m = re.match(r"cl (\d+)", parg, re.IGNORECASE) - if m: - crop_left = int(m.group(1)) - logger.info(f"crop left: {crop_left}") - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - logger.info(f"steps: {steps}") - continue - - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - logger.info(f"seeds: {seeds}") - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - logger.info(f"scale: {scale}") - continue - - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - logger.info(f"negative scale: {negative_scale}") - continue - - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - logger.info(f"strength: {strength}") - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - logger.info(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - logger.info(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - logger.info(f"network mul: {network_muls}") - continue - + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + if pi == 0 or len(raw_prompts) > 1: + + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + # Deep Shrink - m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 1 - ds_depth_1 = int(m.group(1)) - logger.info(f"deep shrink depth 1: {ds_depth_1}") - continue - - m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 1 - ds_timesteps_1 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") - continue - - m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 2 - ds_depth_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink depth 2: {ds_depth_2}") - continue - - m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 2 - ds_timesteps_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") - continue - - m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink ratio - ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink ratio: {ds_ratio}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - except ValueError as ex: - logger.error(f"Exception in parsing / 解析エラー: {parg}") - logger.error(f"{ex}") - - # override Deep Shrink - if ds_depth_1 is not None: - if ds_depth_1 < 0: - ds_depth_1 = args.ds_depth_1 or 3 - unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) - - # override Gradual Latent - if gl_timesteps is not None: - if gl_timesteps < 0: - gl_timesteps = args.gradual_latent_timesteps or 650 - if gl_unsharp_params is not None: - unsharp_params = gl_unsharp_params.split(",") - us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] - us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) - us_ksize = int(us_ksize) - else: - us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None - gradual_latent = GradualLatent( - gl_ratio, - gl_timesteps, - gl_every_n_steps, - gl_ratio_step, - gl_s_noise, - us_ksize, - us_sigma, - us_strength, - us_target_x, - ) - pipe.set_gradual_latent(gradual_latent) - - # prepare seed - if seeds is not None: # given in prompt - # 数が足りないなら前のをそのまま使う - if len(seeds) > 0: - seed = seeds.pop(0) - else: - if predefined_seeds is not None: - if len(predefined_seeds) > 0: - seed = predefined_seeds.pop(0) + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + logger.info(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + logger.info(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + logger.info(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + logger.info(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + logger.info(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + logger.info(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + logger.info(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + logger.info(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + logger.info(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + logger.info(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + logger.info(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + logger.info(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + logger.info(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + logger.info(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + logger.info(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + logger.info(f"network mul: {network_muls}") + continue + + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + logger.info(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") + + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) else: - logger.error("predefined seeds are exhausted") - seed = None - elif args.iter_same_seed: - seeds = iter_seed + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # prepare seed + if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) else: - seed = None # 前のを消す - - if seed is None: - seed = random.randint(0, 0x7FFFFFFF) - if args.interactive: - logger.info(f"seed: {seed}") - - # prepare init image, guide image and mask - init_image = mask_image = guide_image = None - - # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する - if init_images is not None: - init_image = init_images[global_step % len(init_images)] - - # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する - # 32単位に丸めたやつにresizeされるので踏襲する - if not highres_fix: - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - logger.warning( - f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" - ) - - if mask_images is not None: - mask_image = mask_images[global_step % len(mask_images)] - - if guide_images is not None: - if control_nets: # 複数件の場合あり - c = len(control_nets) - p = global_step % (len(guide_images) // c) - guide_image = guide_images[p * c : p * c + c] + if predefined_seeds is not None: + if len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) + else: + logger.error("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed + else: + seed = None # 前のを消す + + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) + if args.interactive: + logger.info(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + logger.warning( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets: # 複数件の場合あり + c = len(control_nets) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." else: - guide_image = guide_images[global_step % len(guide_images)] - - if regional_network: - num_sub_prompts = len(prompt.split(" AND ")) - assert ( - len(networks) <= num_sub_prompts - ), "Number of networks must be less than or equal to number of sub prompts." - else: - num_sub_prompts = None - - b1 = BatchData( - False, - BatchDataBase( - global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt - ), - BatchDataExt( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - tuple(network_muls) if network_muls else None, - num_sub_prompts, - ), - ) - if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? - logger.info(f"When does this run?\n Loaded {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] - batch_index = [] - for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") - batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() - - with torch.no_grad(): - with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") - logger.info(f"batch_list:") - for i in range(len(batch_list[0])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list[0], highres_fix)[0] - batch_data.clear() - distributed_state.wait_for_everyone() - - batch_data.append(b1) - if len(batch_data) == args.batch_size*distributed_state.num_processes: - logger.info(f"Collected {len(batch_data)} prompts for {distributed_state.num_processes} devices.") - batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] - batch_index = [] - for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") - batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() - - with torch.no_grad(): - with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") - logger.info(f"batch_list:") - for i in range(len(batch_list[0])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list[0], highres_fix)[0] - batch_data.clear() - distributed_state.wait_for_everyone() - - - global_step += 1 - - prompt_index += 1 - + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), + BatchDataExt( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + batch_data.extend(b1) + global_step += 1 + + prompt_index += 1 + else: + distributed_state.wait_for_everyone() + batch_data = gather_object(batch_data) + logger.info(f"Total prompts: {len(batch_data)}") if len(batch_data) > 0: + data_loader = get_batches(items=batch_data, batch_size=args.batch_size) + logger.info(f"Total batches: {len(batch_data)}") batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] - for i in range(len(batch_data)): - logger.info(f"Prompt {i+1}: {batch_data[i].base.prompt}") - batch_index.append(batch_data[i]) - if (i+1) % args.batch_size == 0: - batch_data_split.append(batch_index.copy()) - batch_index.clear() - with torch.no_grad(): - with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") - logger.info(f"batch_list:") - for i in range(len(batch_list[0])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list[0], highres_fix)[0] - batch_data.clear() - distributed_state.wait_for_everyone() - - + for i in range(len(data_loader)): + logger.info(f"Loading Batch {i+1} of {len(data_loader)}") + batch_data_split.append(data_loader[i]) + if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader) + continue + with torch.no_grad(): + with distributed_state.split_between_processes(batch_data_split) as batch_list: + logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") + logger.info(f"batch_list:") + for i in range(len(batch_list[0])): + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") + prev_image = process_batch(batch_list[0], highres_fix)[0] + batch_data_split.clear() + distributed_state.wait_for_everyone() logger.info("done!") From 70ffc73ff1cbfc82c9112d8bca308b7074761996 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 00:50:04 +0800 Subject: [PATCH 049/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 14f5e88c4..45ae068a9 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2858,7 +2858,7 @@ def scale_and_round(x): for i in range(len(data_loader)): logger.info(f"Loading Batch {i+1} of {len(data_loader)}") batch_data_split.append(data_loader[i]) - if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader) + if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader): continue with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: From 86d572cac5d4a8c6c90a241b57ba27b2f303e52c Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 00:54:15 +0800 Subject: [PATCH 050/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 45ae068a9..271d354d8 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2842,7 +2842,7 @@ def scale_and_round(x): num_sub_prompts, ), ) - batch_data.extend(b1) + batch_data.append(b1) global_step += 1 prompt_index += 1 From 8e5642907443707d6dd59f8c981d24c41f8401ad Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 01:08:12 +0800 Subject: [PATCH 051/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 271d354d8..aec137d2c 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2852,7 +2852,7 @@ def scale_and_round(x): logger.info(f"Total prompts: {len(batch_data)}") if len(batch_data) > 0: data_loader = get_batches(items=batch_data, batch_size=args.batch_size) - logger.info(f"Total batches: {len(batch_data)}") + logger.info(f"Total batches: {len(data_loader)}") batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] for i in range(len(data_loader)): @@ -2862,7 +2862,7 @@ def scale_and_round(x): continue with torch.no_grad(): with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.device}:") + logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.local_process_index}:") logger.info(f"batch_list:") for i in range(len(batch_list[0])): logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") From 8340fdce63bc1c5d4d19356e837b2255d00d3d08 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 01:24:09 +0800 Subject: [PATCH 052/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index aec137d2c..2964ffac0 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2456,8 +2456,6 @@ def scale_and_round(x): raw_prompts = [] if distributed_state.is_main_process: raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - else: - distributed_state.wait_for_everyone() raw_prompts = gather_object(raw_prompts) if distributed_state.is_main_process: @@ -2846,8 +2844,7 @@ def scale_and_round(x): global_step += 1 prompt_index += 1 - else: - distributed_state.wait_for_everyone() + distributed_state.wait_for_everyone() batch_data = gather_object(batch_data) logger.info(f"Total prompts: {len(batch_data)}") if len(batch_data) > 0: From e7f7ade84347ffdba96314fdca1dcd8b654f96f8 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 01:28:24 +0800 Subject: [PATCH 053/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 2964ffac0..05eb2e2ca 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2844,7 +2844,7 @@ def scale_and_round(x): global_step += 1 prompt_index += 1 - distributed_state.wait_for_everyone() + batch_data = gather_object(batch_data) logger.info(f"Total prompts: {len(batch_data)}") if len(batch_data) > 0: From 2b5a95eede2de1aca363961f9a29b7305b45bd0c Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 02:06:25 +0800 Subject: [PATCH 054/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 05eb2e2ca..dbb702c5c 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2852,20 +2852,28 @@ def scale_and_round(x): logger.info(f"Total batches: {len(data_loader)}") batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] - for i in range(len(data_loader)): - logger.info(f"Loading Batch {i+1} of {len(data_loader)}") - batch_data_split.append(data_loader[i]) - if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader): - continue - with torch.no_grad(): - with distributed_state.split_between_processes(batch_data_split) as batch_list: - logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.local_process_index}:") - logger.info(f"batch_list:") - for i in range(len(batch_list[0])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list[0], highres_fix)[0] - batch_data_split.clear() - distributed_state.wait_for_everyone() + with distributed_state.split_between_processes(data_loader) as batch_list: + logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.local_process_index}:") + logger.info(f"batch_list:") + for i in range(len(batch_list[0])): + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") + prev_image = process_batch(batch_list[0], highres_fix)[0] + + distributed_state.wait_for_everyone() + #for i in range(len(data_loader)): + # logger.info(f"Loading Batch {i+1} of {len(data_loader)}") + # batch_data_split.append(data_loader[i]) + # if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader): + # continue + # with torch.no_grad(): + # with distributed_state.split_between_processes(batch_data_split) as batch_list: + # logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.local_process_index}:") + # logger.info(f"batch_list:") + # for i in range(len(batch_list[0])): + # logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") + # prev_image = process_batch(batch_list[0], highres_fix)[0] + # batch_data_split.clear() + # distributed_state.wait_for_everyone() logger.info("done!") From 25321e14480ef039ab6e03d0d22f8d1fb3be0427 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 02:18:28 +0800 Subject: [PATCH 055/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index dbb702c5c..aeea43f7c 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2853,11 +2853,12 @@ def scale_and_round(x): batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] batch_index = [] with distributed_state.split_between_processes(data_loader) as batch_list: - logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.local_process_index}:") - logger.info(f"batch_list:") - for i in range(len(batch_list[0])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") - prev_image = process_batch(batch_list[0], highres_fix)[0] + for j in range(len(batch_list)): + logger.info(f"Loading batch of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") + logger.info(f"batch_list:") + for i in range(len(batch_list[j])): + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}") + prev_image = process_batch(batch_list[0], highres_fix)[0] distributed_state.wait_for_everyone() #for i in range(len(data_loader)): From 067096819f5604b6ae1a37321d22421be116e9c9 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 02:29:16 +0800 Subject: [PATCH 056/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index aeea43f7c..5b8ee0ff8 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2846,7 +2846,9 @@ def scale_and_round(x): prompt_index += 1 batch_data = gather_object(batch_data) - logger.info(f"Total prompts: {len(batch_data)}") + for pi in range(distributed_state.num_processes): + if pi == distributed_state.local_process_index: + logger.info(f"Total prompts: {len(batch_data)} for {distributed_state.local_process_index}") if len(batch_data) > 0: data_loader = get_batches(items=batch_data, batch_size=args.batch_size) logger.info(f"Total batches: {len(data_loader)}") From 251e9a7d2930c0a94024647a5bf520f057546219 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 02:30:56 +0800 Subject: [PATCH 057/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 5b8ee0ff8..e74447588 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2858,9 +2858,9 @@ def scale_and_round(x): for j in range(len(batch_list)): logger.info(f"Loading batch of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") logger.info(f"batch_list:") - for i in range(len(batch_list[j])): + for i in range(len(batch_list[j])): logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}") - prev_image = process_batch(batch_list[0], highres_fix)[0] + prev_image = process_batch(batch_list[j], highres_fix)[0] distributed_state.wait_for_everyone() #for i in range(len(data_loader)): From 7d04dcb1a0931e458181e3a3ffab51e34dad5e1b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 02:46:56 +0800 Subject: [PATCH 058/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index e74447588..67dd44933 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2457,7 +2457,9 @@ def scale_and_round(x): if distributed_state.is_main_process: raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) raw_prompts = gather_object(raw_prompts) - + for pi in range(distributed_state.num_processes): + if pi == distributed_state.local_process_index: + logger.info(f"Total raw prompts: {len(raw_prompts)} for {distributed_state.local_process_index}") if distributed_state.is_main_process: # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): From 0806954d67d89d72d4ec063e4be3ff4378796555 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 03:02:30 +0800 Subject: [PATCH 059/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 67dd44933..775c3c429 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2454,12 +2454,15 @@ def scale_and_round(x): # sd-dynamic-prompts like variants: # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) raw_prompts = [] + logger.info(f"Total raw prompts before load: {len(raw_prompts)} for {distributed_state.local_process_index}") if distributed_state.is_main_process: raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) raw_prompts = gather_object(raw_prompts) for pi in range(distributed_state.num_processes): if pi == distributed_state.local_process_index: logger.info(f"Total raw prompts: {len(raw_prompts)} for {distributed_state.local_process_index}") + if not distributed_state.is_main_process: + break if distributed_state.is_main_process: # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): From df9936eb13e77e09937b7bcbd50c626e9fca08ed Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 03:12:14 +0800 Subject: [PATCH 060/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 773 +++++++++++++++++++++--------------------- 1 file changed, 384 insertions(+), 389 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 775c3c429..67b528d3d 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2433,7 +2433,7 @@ def scale_and_round(x): prompt_index = 0 global_step = 0 batch_data = [] - while args.interactive or prompt_index < len(prompt_list): + while args.interactive or (prompt_index < len(prompt_list) and not distributed_state.is_main_process): if len(prompt_list) == 0: # interactive valid = False @@ -2453,403 +2453,398 @@ def scale_and_round(x): # sd-dynamic-prompts like variants: # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) - raw_prompts = [] - logger.info(f"Total raw prompts before load: {len(raw_prompts)} for {distributed_state.local_process_index}") - if distributed_state.is_main_process: - raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - raw_prompts = gather_object(raw_prompts) + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + for pi in range(distributed_state.num_processes): if pi == distributed_state.local_process_index: logger.info(f"Total raw prompts: {len(raw_prompts)} for {distributed_state.local_process_index}") - if not distributed_state.is_main_process: - break - if distributed_state.is_main_process: + # repeat prompt - for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - if pi == 0 or len(raw_prompts) > 1: - - # parse prompt: if prompt is not changed, skip parsing - width = args.W - height = args.H - original_width = args.original_width - original_height = args.original_height - original_width_negative = args.original_width_negative - original_height_negative = args.original_height_negative - crop_top = args.crop_top - crop_left = args.crop_left - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seed = None - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + if pi == 0 or len(raw_prompts) > 1: + + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + logger.info(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + logger.info(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + logger.info(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + logger.info(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + logger.info(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + logger.info(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + logger.info(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + logger.info(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + logger.info(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + logger.info(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + logger.info(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + logger.info(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + logger.info(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + logger.info(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + logger.info(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + logger.info(f"network mul: {network_muls}") + continue + # Deep Shrink - ds_depth_1 = None # means no override - ds_timesteps_1 = args.ds_timesteps_1 - ds_depth_2 = args.ds_depth_2 - ds_timesteps_2 = args.ds_timesteps_2 - ds_ratio = args.ds_ratio - + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + logger.info(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + # Gradual Latent - gl_timesteps = None # means no override - gl_ratio = args.gradual_latent_ratio - gl_every_n_steps = args.gradual_latent_every_n_steps - gl_ratio_step = args.gradual_latent_ratio_step - gl_s_noise = args.gradual_latent_s_noise - gl_unsharp_params = args.gradual_latent_unsharp_params - - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - - for parg in prompt_args[1:]: - - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - logger.info(f"width: {width}") - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - logger.info(f"height: {height}") - continue - - m = re.match(r"ow (\d+)", parg, re.IGNORECASE) - if m: - original_width = int(m.group(1)) - logger.info(f"original width: {original_width}") - continue - - m = re.match(r"oh (\d+)", parg, re.IGNORECASE) - if m: - original_height = int(m.group(1)) - logger.info(f"original height: {original_height}") - continue - - m = re.match(r"nw (\d+)", parg, re.IGNORECASE) - if m: - original_width_negative = int(m.group(1)) - logger.info(f"original width negative: {original_width_negative}") - continue - - m = re.match(r"nh (\d+)", parg, re.IGNORECASE) - if m: - original_height_negative = int(m.group(1)) - logger.info(f"original height negative: {original_height_negative}") - continue - - m = re.match(r"ct (\d+)", parg, re.IGNORECASE) - if m: - crop_top = int(m.group(1)) - logger.info(f"crop top: {crop_top}") - continue - - m = re.match(r"cl (\d+)", parg, re.IGNORECASE) - if m: - crop_left = int(m.group(1)) - logger.info(f"crop left: {crop_left}") - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - logger.info(f"steps: {steps}") - continue - - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - logger.info(f"seeds: {seeds}") - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - logger.info(f"scale: {scale}") - continue - - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - logger.info(f"negative scale: {negative_scale}") - continue - - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - logger.info(f"strength: {strength}") - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - logger.info(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - logger.info(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - logger.info(f"network mul: {network_muls}") - continue - - # Deep Shrink - m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 1 - ds_depth_1 = int(m.group(1)) - logger.info(f"deep shrink depth 1: {ds_depth_1}") - continue - - m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 1 - ds_timesteps_1 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") - continue - - m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 2 - ds_depth_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink depth 2: {ds_depth_2}") - continue - - m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 2 - ds_timesteps_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") - continue - - m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink ratio - ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink ratio: {ds_ratio}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - except ValueError as ex: - logger.error(f"Exception in parsing / 解析エラー: {parg}") - logger.error(f"{ex}") - - # override Deep Shrink - if ds_depth_1 is not None: - if ds_depth_1 < 0: - ds_depth_1 = args.ds_depth_1 or 3 - unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) - - # override Gradual Latent - if gl_timesteps is not None: - if gl_timesteps < 0: - gl_timesteps = args.gradual_latent_timesteps or 650 - if gl_unsharp_params is not None: - unsharp_params = gl_unsharp_params.split(",") - us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] - us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) - us_ksize = int(us_ksize) - else: - us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None - gradual_latent = GradualLatent( - gl_ratio, - gl_timesteps, - gl_every_n_steps, - gl_ratio_step, - gl_s_noise, - us_ksize, - us_sigma, - us_strength, - us_target_x, - ) - pipe.set_gradual_latent(gradual_latent) - - # prepare seed - if seeds is not None: # given in prompt - # 数が足りないなら前のをそのまま使う - if len(seeds) > 0: - seed = seeds.pop(0) + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") + + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) else: - if predefined_seeds is not None: - if len(predefined_seeds) > 0: - seed = predefined_seeds.pop(0) - else: - logger.error("predefined seeds are exhausted") - seed = None - elif args.iter_same_seed: - seeds = iter_seed - else: - seed = None # 前のを消す - - if seed is None: - seed = random.randint(0, 0x7FFFFFFF) - if args.interactive: - logger.info(f"seed: {seed}") - - # prepare init image, guide image and mask - init_image = mask_image = guide_image = None - - # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する - if init_images is not None: - init_image = init_images[global_step % len(init_images)] - - # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する - # 32単位に丸めたやつにresizeされるので踏襲する - if not highres_fix: - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - logger.warning( - f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" - ) - - if mask_images is not None: - mask_image = mask_images[global_step % len(mask_images)] - - if guide_images is not None: - if control_nets: # 複数件の場合あり - c = len(control_nets) - p = global_step % (len(guide_images) // c) - guide_image = guide_images[p * c : p * c + c] + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # prepare seed + if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) + else: + if predefined_seeds is not None: + if len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) else: - guide_image = guide_images[global_step % len(guide_images)] - - if regional_network: - num_sub_prompts = len(prompt.split(" AND ")) - assert ( - len(networks) <= num_sub_prompts - ), "Number of networks must be less than or equal to number of sub prompts." + logger.error("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed else: - num_sub_prompts = None - - b1 = BatchData( - False, - BatchDataBase( - global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt - ), - BatchDataExt( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - tuple(network_muls) if network_muls else None, - num_sub_prompts, - ), - ) - batch_data.append(b1) - global_step += 1 - - prompt_index += 1 - + seed = None # 前のを消す + + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) + if args.interactive: + logger.info(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + logger.warning( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets: # 複数件の場合あり + c = len(control_nets) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." + else: + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), + BatchDataExt( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + batch_data.append(b1) + global_step += 1 + + prompt_index += 1 + distributed_state.wait_for_everyone() batch_data = gather_object(batch_data) for pi in range(distributed_state.num_processes): if pi == distributed_state.local_process_index: From a832925553a736d0c626539ae2410a1ad2295a9d Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 03:44:55 +0800 Subject: [PATCH 061/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 67b528d3d..bbe55f33e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2856,7 +2856,7 @@ def scale_and_round(x): batch_index = [] with distributed_state.split_between_processes(data_loader) as batch_list: for j in range(len(batch_list)): - logger.info(f"Loading batch of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") + logger.info(f"Loading batch {j}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") logger.info(f"batch_list:") for i in range(len(batch_list[j])): logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}") From 2328128b1ca335b4b1edcb4ec42554b496ea1edc Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 03:49:59 +0800 Subject: [PATCH 062/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index bbe55f33e..98d334131 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2454,10 +2454,6 @@ def scale_and_round(x): # sd-dynamic-prompts like variants: # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - - for pi in range(distributed_state.num_processes): - if pi == distributed_state.local_process_index: - logger.info(f"Total raw prompts: {len(raw_prompts)} for {distributed_state.local_process_index}") # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): @@ -2846,14 +2842,9 @@ def scale_and_round(x): prompt_index += 1 distributed_state.wait_for_everyone() batch_data = gather_object(batch_data) - for pi in range(distributed_state.num_processes): - if pi == distributed_state.local_process_index: - logger.info(f"Total prompts: {len(batch_data)} for {distributed_state.local_process_index}") + if len(batch_data) > 0: data_loader = get_batches(items=batch_data, batch_size=args.batch_size) - logger.info(f"Total batches: {len(data_loader)}") - batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)] - batch_index = [] with distributed_state.split_between_processes(data_loader) as batch_list: for j in range(len(batch_list)): logger.info(f"Loading batch {j}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") From ebcca5512c9fcc366e33933c70652180c3f1a1fa Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 04:30:29 +0800 Subject: [PATCH 063/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 1 + 1 file changed, 1 insertion(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 98d334131..085f31bbb 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2433,6 +2433,7 @@ def scale_and_round(x): prompt_index = 0 global_step = 0 batch_data = [] + logger.info(f"args.interactive: {args.interactive}, (prompt_index: {prompt_index} < len(prompt_list): {len(prompt_list)} and not distributed_state.is_main_process: {not distributed_state.is_main_process})") while args.interactive or (prompt_index < len(prompt_list) and not distributed_state.is_main_process): if len(prompt_list) == 0: # interactive From 9f47400cb02521d24daa3a58be92fb775b05e692 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 04:42:02 +0800 Subject: [PATCH 064/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 085f31bbb..3d29cffe2 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2434,7 +2434,7 @@ def scale_and_round(x): global_step = 0 batch_data = [] logger.info(f"args.interactive: {args.interactive}, (prompt_index: {prompt_index} < len(prompt_list): {len(prompt_list)} and not distributed_state.is_main_process: {not distributed_state.is_main_process})") - while args.interactive or (prompt_index < len(prompt_list) and not distributed_state.is_main_process): + while args.interactive or (prompt_index < len(prompt_list) and (not distributed_state.is_main_process or distributed_state.num_processes == 1)): if len(prompt_list) == 0: # interactive valid = False From d310bc99ac91e09eae30af6b5d578920b8b3bd90 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 04:44:00 +0800 Subject: [PATCH 065/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 3d29cffe2..9a7c7eddd 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2848,7 +2848,7 @@ def scale_and_round(x): data_loader = get_batches(items=batch_data, batch_size=args.batch_size) with distributed_state.split_between_processes(data_loader) as batch_list: for j in range(len(batch_list)): - logger.info(f"Loading batch {j}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") + logger.info(f"Loading batch {j+1}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") logger.info(f"batch_list:") for i in range(len(batch_list[j])): logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}") From fecaa7d52b11694282d7922d0bbaf9f955cd3e66 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 13:24:05 +0800 Subject: [PATCH 066/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 9a7c7eddd..fcc17d6d7 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2851,7 +2851,7 @@ def scale_and_round(x): logger.info(f"Loading batch {j+1}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") logger.info(f"batch_list:") for i in range(len(batch_list[j])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}") + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\NNegative Prompt: {batch_list[j][i].base.negative_prompt}") prev_image = process_batch(batch_list[j], highres_fix)[0] distributed_state.wait_for_everyone() From ce620f52ab75c2698901b8726594e65304156e19 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 13:25:13 +0800 Subject: [PATCH 067/221] Create accel_sdxl_gen_img_v2.py --- accel_sdxl_gen_img_v2.py | 3261 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 3261 insertions(+) create mode 100644 accel_sdxl_gen_img_v2.py diff --git a/accel_sdxl_gen_img_v2.py b/accel_sdxl_gen_img_v2.py new file mode 100644 index 000000000..fcc17d6d7 --- /dev/null +++ b/accel_sdxl_gen_img_v2.py @@ -0,0 +1,3261 @@ +import itertools +import json +from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable +import glob +import importlib +import inspect +import time +import zipfile +from diffusers.utils import deprecate +from diffusers.configuration_utils import FrozenDict +import argparse +import math +import os +import random +import re +import gc +from accelerate import PartialState +from accelerate.utils import gather_object + +import diffusers +import numpy as np + +import torch +from library.device_utils import init_ipex, clean_memory, get_preferred_device, clean_memory_on_device +init_ipex() + +import torchvision +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + # UNet2DConditionModel, + StableDiffusionPipeline, +) +from einops import rearrange +from tqdm import tqdm +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor +import PIL +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import library.model_util as model_util +import library.train_util as train_util +import library.sdxl_model_util as sdxl_model_util +import library.sdxl_train_util as sdxl_train_util +from networks.lora import LoRANetwork +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.original_unet import FlashAttentionFunction +from networks.control_net_lllite import ControlNetLLLite +from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +# その他の設定 +LATENT_CHANNELS = 4 +DOWNSAMPLING_FACTOR = 8 + +CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + +def get_batches(items, batch_size): + num_batches = (len(items) + batch_size - 1) // batch_size + batches = [] + + for i in range(num_batches): + start_index = i * batch_size + end_index = min((i + 1) * batch_size, len(items)) + batch = items[start_index:end_index] + batches.append(batch) + + return batches +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + logger.info("Enable memory efficient attention for U-Net") + + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + logger.info("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + logger.info("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) + + +# TODO common train_util.py +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + replace_vae_attn_to_memory_efficient() + elif xformers: + # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? + vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う + elif sdpa: + replace_vae_attn_to_sdpa() + + +def replace_vae_attn_to_memory_efficient(): + logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, hidden_states, **kwargs): + q_bucket_size = 512 + k_bucket_size = 1024 + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_flash_attn + + +def replace_vae_attn_to_xformers(): + logger.info("VAE: Attention.forward has been replaced to xformers") + import xformers.ops + + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_xformers(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_xformers + + +def replace_vae_attn_to_sdpa(): + logger.info("VAE: Attention.forward has been replaced to sdpa") + + def forward_sdpa(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = torch.nn.functional.scaled_dot_product_attention( + query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + out = rearrange(out, "b n h d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_sdpa_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_sdpa(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_sdpa + + +# endregion + +# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 +# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 + + +class PipelineLike: + def __init__( + self, + device, + vae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + tokenizers: List[CLIPTokenizer], + unet: InferSdxlUNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + ): + super().__init__() + self.device = device + self.clip_skip = clip_skip + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.vae = vae + self.text_encoders = text_encoders + self.tokenizers = tokenizers + self.unet: InferSdxlUNet2DConditionModel = unet + self.scheduler = scheduler + self.safety_checker = None + + self.clip_vision_model: CLIPVisionModelWithProjection = None + self.clip_vision_processor: CLIPImageProcessor = None + self.clip_vision_strength = 0.0 + + # Textual Inversion + self.token_replacements_list = [] + for _ in range(len(self.text_encoders)): + self.token_replacements_list.append({}) + + # ControlNet # not supported yet + self.control_nets: List[ControlNetLLLite] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + + self.gradual_latent: GradualLatent = None + + # Textual Inversion + def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): + self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids + + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + + def get_token_replacer(self, tokenizer): + tokenizer_index = self.tokenizers.index(tokenizer) + token_replacements = self.token_replacements_list[tokenizer_index] + + def replace_tokens(tokens): + # logger.info("replace_tokens", tokens, "=>", token_replacements) + if isinstance(tokens, torch.Tensor): + tokens = tokens.tolist() + + new_tokens = [] + for token in tokens: + if token in token_replacements: + replacement = token_replacements[token] + new_tokens.extend(replacement) + else: + new_tokens.append(token) + return new_tokens + + return replace_tokens + + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets + + def set_gradual_latent(self, gradual_latent): + if gradual_latent is None: + logger.info("gradual_latent is disabled") + self.gradual_latent = None + else: + logger.info(f"gradual_latent is enabled: {gradual_latent}") + self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 1024, + width: int = 1024, + original_height: int = None, + original_width: int = None, + original_height_negative: int = None, + original_width_negative: int = None, + crop_top: int = 0, + crop_left: int = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_guide_images=None, + **kwargs, + ): + # TODO support secondary prompt + num_images_per_prompt = 1 # fixed because already prompt is repeated + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + reginonal_network = " AND " in prompt[0] + + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + logger.info(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + tes_text_embs = [] + tes_uncond_embs = [] + tes_real_uncond_embs = [] + + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + token_replacer = self.get_token_replacer(tokenizer) + + # use last text_pool, because it is from text encoder 2 + text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( + tokenizer, + text_encoder, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_text_embs.append(text_embeddings) + tes_uncond_embs.append(uncond_embeddings) + + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + token_replacer, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_real_uncond_embs.append(real_uncond_embeddings) + + # concat text encoder outputs + text_embeddings = tes_text_embs[0] + uncond_embeddings = tes_uncond_embs[0] + for i in range(1, len(tes_text_embs)): + text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 + if do_classifier_free_guidance: + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + + if self.control_nets: + # ControlNetのhintにguide imageを流用する + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] + if isinstance(clip_guide_images[0], PIL.Image.Image): + clip_guide_images = [preprocess_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images) + if isinstance(clip_guide_images, list): + clip_guide_images = torch.stack(clip_guide_images) + + clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) + + # create size embs + if original_height is None: + original_height = height + if original_width is None: + original_width = width + if original_height_negative is None: + original_height_negative = original_height + if original_width_negative is None: + original_width_negative = original_width + if crop_top is None: + crop_top = 0 + if crop_left is None: + crop_left = 0 + emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) + uc_emb1 = sdxl_train_util.get_timestep_embedding( + torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 + ) + emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) + emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + + if reginonal_network: + # use last pool for conditioning + num_sub_prompts = len(text_pool) // batch_size + text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt + + if init_image is not None and self.clip_vision_model is not None: + logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) + pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) + + clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True) + clip_vision_embeddings = clip_vision_embeddings.image_embeds + + if len(clip_vision_embeddings) == 1 and batch_size > 1: + clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) + + clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength + assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" + text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) + + c_vector = torch.cat([text_pool, c_vector], dim=1) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[-2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + clean_memory() + init_latents = [] + for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): + init_latent_dist = self.vae.encode( + (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( + self.vae.dtype + ) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if self.control_net_enabled: + for control_net, _ in self.control_nets: + with torch.no_grad(): + control_net.set_cond_image(clip_guide_images) + else: + for control_net, _ in self.control_nets: + control_net.set_cond_image(None) + + each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) + + # # first, we downscale the latents to the half of the size + # # 最初に1/2に縮小する + # height, width = latents.shape[-2:] + # # latents = torch.nn.functional.interpolate(latents.float(), scale_factor=0.5, mode="bicubic", align_corners=False).to( + # # latents.dtype + # # ) + # latents = latents[:, :, ::2, ::2] + # current_scale = 0.5 + + # # how much to increase the scale at each step: .125 seems to work well (because it's 1/8?) + # # 各ステップに拡大率をどのくらい増やすか:.125がよさそう(たぶん1/8なので) + # scale_step = 0.125 + + # # timesteps at which to start increasing the scale: 1000 seems to be enough + # # 拡大を開始するtimesteps: 1000で十分そうである + # start_timesteps = 1000 + + # # how many steps to wait before increasing the scale again + # # small values leads to blurry images (because the latents are blurry after the upscale, so some denoising might be needed) + # # large values leads to flat images + + # # 何ステップごとに拡大するか + # # 小さいとボケる(拡大後のlatentsはボケた感じになるので、そこから数stepのdenoiseが必要と思われる) + # # 大きすぎると細部が書き込まれずのっぺりした感じになる + # every_n_steps = 5 + + # scale_step = input("scale step:") + # scale_step = float(scale_step) + # start_timesteps = input("start timesteps:") + # start_timesteps = int(start_timesteps) + # every_n_steps = input("every n steps:") + # every_n_steps = int(every_n_steps) + + # # for i, t in enumerate(tqdm(timesteps)): + # i = 0 + # last_step = 0 + # while i < len(timesteps): + # t = timesteps[i] + # print(f"[{i}] t={t}") + + # print(i, t, current_scale, latents.shape) + # if t < start_timesteps and current_scale < 1.0 and i % every_n_steps == 0: + # if i == last_step: + # pass + # else: + # print("upscale") + # current_scale = min(current_scale + scale_step, 1.0) + + # h = int(height * current_scale) // 8 * 8 + # w = int(width * current_scale) // 8 * 8 + + # latents = torch.nn.functional.interpolate(latents.float(), size=(h, w), mode="bicubic", align_corners=False).to( + # latents.dtype + # ) + # last_step = i + # i = max(0, i - every_n_steps + 1) + + # diff = timesteps[i] - timesteps[last_step] + # # resized_init_noise = torch.nn.functional.interpolate( + # # init_noise.float(), size=(h, w), mode="bicubic", align_corners=False + # # ).to(latents.dtype) + # # latents = self.scheduler.add_noise(latents, resized_init_noise, diff) + # latents = self.scheduler.add_noise(latents, torch.randn_like(latents), diff * 4) + # # latents += torch.randn_like(latents) / 100 * diff + # continue + + enable_gradual_latent = False + if self.gradual_latent: + if not hasattr(self.scheduler, "set_gradual_latent_params"): + logger.info("gradual_latent is not supported for this scheduler. Ignoring.") + logger.info(f'{self.scheduler.__class__.__name__}') + else: + enable_gradual_latent = True + step_elapsed = 1000 + current_ratio = self.gradual_latent.ratio + + # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする + height, width = latents.shape[-2:] + org_dtype = latents.dtype + if org_dtype == torch.bfloat16: + latents = latents.float() + latents = torch.nn.functional.interpolate( + latents, scale_factor=current_ratio, mode="bicubic", align_corners=False + ).to(org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if self.gradual_latent.gaussian_blur_ksize: + latents = self.gradual_latent.apply_unshark_mask(latents) + + for i, t in enumerate(tqdm(timesteps)): + resized_size = None + if enable_gradual_latent: + # gradually upscale the latents / latentsを徐々にアップスケールする + if ( + t < self.gradual_latent.start_timesteps + and current_ratio < 1.0 + and step_elapsed >= self.gradual_latent.every_n_steps + ): + current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) + # make divisible by 8 because size of latents must be divisible at bottom of UNet + h = int(height * current_ratio) // 8 * 8 + w = int(width * current_ratio) // 8 * 8 + resized_size = (h, w) + self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) + step_elapsed = 0 + else: + self.scheduler.set_gradual_latent_params(None, None) + step_elapsed += 1 + + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # disable control net if ratio is set + if self.control_nets and self.control_net_enabled: + for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): + if not enabled or ratio >= 1.0: + continue + if ratio < i / len(timesteps): + logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") + control_net.set_cond_image(None) + each_control_net_enabled[j] = False + + # predict the noise residual + # TODO Diffusers' ControlNet + # if self.control_nets and self.control_net_enabled: + # if reginonal_network: + # num_sub_and_neg_prompts = len(text_embeddings) // batch_size + # text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt + # else: + # text_emb_last = text_embeddings + + # # not working yet + # noise_pred = original_control_net.call_unet_and_control_net( + # i, + # num_latent_input, + # self.unet, + # self.control_nets, + # guided_hints, + # i / len(timesteps), + # latent_model_input, + # t, + # text_emb_last, + # ).sample + # else: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + i += 1 + + if return_latents: + return latents + + latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents.to(self.vae.dtype)).sample + else: + clean_memory() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode( + (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) + ).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + clean_memory() + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + return image + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + # keep break as separate token + text = text.replace("BREAK", "\\BREAK\\") + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + if word.strip() == "BREAK": + # pad until next multiple of tokenizer's max token length + pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) + logger.info(f"BREAK pad_len: {pad_len}") + for i in range(pad_len): + # v2のときEOSをつけるべきかどうかわからないぜ + # if i == 0: + # text_token.append(tokenizer.eos_token_id) + # else: + text_token.append(tokenizer.pad_token_id) + text_weight.append(1.0) + continue + + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + + token = token_replacer(token) # for Textual Inversion + + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + text_encoder: CLIPTextModel, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + pool = None + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + # -2 is same for Text Encoder 1 and 2 + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-2] + if pool is None: + pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-2] + pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) + return text_embeddings, pool + + +def get_weighted_text_embeddings( + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip=None, + token_replacer=None, + device=None, + **kwargs, +): + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + # split the prompts with "AND". each prompt must have the same number of splits + new_prompts = [] + for p in prompt: + new_prompts.extend(p.split(" AND ")) + prompt = new_prompts + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings, text_pool = get_unweighted_text_embeddings( + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + if uncond_prompt is not None: + uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( + text_encoder, + uncond_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + # →全体でいいんじゃないかな + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens + return text_embeddings, text_pool, None, None, prompt_tokens + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +# regular expression for dynamic prompt: +# starts and ends with "{" and "}" +# contains at least one variant divided by "|" +# optional framgments divided by "$$" at start +# if the first fragment is "E" or "e", enumerate all variants +# if the second fragment is a number or two numbers, repeat the variants in the range +# if the third fragment is a string, use it as a separator + +RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") + + +def handle_dynamic_prompt_variants(prompt, repeat_count): + founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) + if not founds: + return [prompt] + + # make each replacement for each variant + enumerating = False + replacers = [] + for found in founds: + # if "e$$" is found, enumerate all variants + found_enumerating = found.group(2) is not None + enumerating = enumerating or found_enumerating + + separator = ", " if found.group(6) is None else found.group(6) + variants = found.group(7).split("|") + + # parse count range + count_range = found.group(4) + if count_range is None: + count_range = [1, 1] + else: + count_range = count_range.split("-") + if len(count_range) == 1: + count_range = [int(count_range[0]), int(count_range[0])] + elif len(count_range) == 2: + count_range = [int(count_range[0]), int(count_range[1])] + else: + logger.warning(f"invalid count range: {count_range}") + count_range = [1, 1] + if count_range[0] > count_range[1]: + count_range = [count_range[1], count_range[0]] + if count_range[0] < 0: + count_range[0] = 0 + if count_range[1] > len(variants): + count_range[1] = len(variants) + + if found_enumerating: + # make function to enumerate all combinations + def make_replacer_enum(vari, cr, sep): + def replacer(): + values = [] + for count in range(cr[0], cr[1] + 1): + for comb in itertools.combinations(vari, count): + values.append(sep.join(comb)) + return values + + return replacer + + replacers.append(make_replacer_enum(variants, count_range, separator)) + else: + # make function to choose random combinations + def make_replacer_single(vari, cr, sep): + def replacer(): + count = random.randint(cr[0], cr[1]) + comb = random.sample(vari, count) + return [sep.join(comb)] + + return replacer + + replacers.append(make_replacer_single(variants, count_range, separator)) + + # make each prompt + if not enumerating: + # if not enumerating, repeat the prompt, replace each variant randomly + prompts = [] + for _ in range(repeat_count): + current = prompt + for found, replacer in zip(founds, replacers): + current = current.replace(found.group(0), replacer()[0], 1) + prompts.append(current) + else: + # if enumerating, iterate all combinations for previous prompts + prompts = [prompt] + + for found, replacer in zip(founds, replacers): + if found.group(2) is not None: + # make all combinations for existing prompts + new_prompts = [] + for current in prompts: + replecements = replacer() + for replecement in replecements: + new_prompts.append(current.replace(found.group(0), replecement, 1)) + prompts = new_prompts + + for found, replacer in zip(founds, replacers): + # make random selection for existing prompts + if found.group(2) is None: + for i in range(len(prompts)): + prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) + + return prompts + + +# endregion + +# def load_clip_l14_336(dtype): +# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") +# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) +# return text_encoder + + +class BatchDataBase(NamedTuple): + # バッチ分割が必要ないデータ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any + raw_prompt: str + + +class BatchDataExt(NamedTuple): + # バッチ分割が必要なデータ + width: int + height: int + original_width: int + original_height: int + original_width_negative: int + original_height_negative: int + crop_left: int + crop_top: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] + num_sub_prompts: int + + +class BatchData(NamedTuple): + return_latents: bool + base: BatchDataBase + ext: BatchDataExt + + +def main(args): + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + highres_fix = args.highres_fix_scale is not None + # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" + + # モデルを読み込む + logger.info("preparing pipes") + #accelerator = train_util.prepare_accelerator(args) + #is_main_process = accelerator.is_main_process + distributed_state = PartialState() + device = distributed_state.device + + if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] + #device = get_preferred_device() + logger.info(f"preferred device: {device}") + clean_memory_on_device(device) + model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) + for pi in range(distributed_state.num_processes): + if pi == distributed_state.local_process_index: + logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype + ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) + distributed_state.wait_for_everyone() + + # xformers、Hypernetwork対応 + if not args.diffusers_xformers: + mem_eff = not (args.xformers or args.sdpa) + replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) + + # tokenizerを読み込む + logger.info("loading tokenizer") + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # schedulerを用意する + sched_init_args = {} + has_steps_offset = True + has_clip_sample = True + scheduler_num_noises_per_step = 1 + + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + has_clip_sample = False + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_lms_discrete + has_clip_sample = False + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + has_clip_sample = False + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteSchedulerGL + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + has_clip_sample = False + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + has_clip_sample = False + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + has_clip_sample = False + has_steps_offset = False + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + has_clip_sample = False + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + has_clip_sample = False + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 + has_clip_sample = False + + # 警告を出さないようにする + if has_steps_offset: + sched_init_args["steps_offset"] = 1 + if has_clip_sample: + sched_init_args["clip_sample"] = False + + # samplerの乱数をあらかじめ指定するための処理 + + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 + + def reset_sampler_noises(self, noises): + self.sampler_noise_index = 0 + self.sampler_noises = noises + + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # logger.info("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None + + if noise == None: + logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) + + self.sampler_noise_index += 1 + return noise + + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager + + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # ↓以下は結局PipeでFalseに設定されるので意味がなかった + # # clip_sample=Trueにする + # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # logger.info("set clip_sample to True") + # scheduler.config.clip_sample = True + + # deviceを決定する + + + # custom pipelineをコピったやつを生成する + if args.vae_slices: + from library.slicing_vae import SlicingAutoencoderKL + + sli_vae = SlicingAutoencoderKL( + act_fn="silu", + block_out_channels=(128, 256, 512, 512), + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + in_channels=3, + latent_channels=4, + layers_per_block=2, + norm_num_groups=32, + out_channels=3, + sample_size=512, + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + num_slices=args.vae_slices, + ) + sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする + vae = sli_vae + del sli_vae + + vae_dtype = dtype + if args.no_half_vae: + logger.info("set vae_dtype to float32") + vae_dtype = torch.float32 + vae.to(vae_dtype).to(device) + vae.eval() + + text_encoder1.to(dtype).to(device) + text_encoder2.to(dtype).to(device) + unet.to(dtype).to(device) + text_encoder1.eval() + text_encoder2.eval() + unet.eval() + # networkを組み込む + if args.network_module: + networks = [] + network_default_muls = [] + network_pre_calc = args.network_pre_calc + + # merge関連の引数を統合する + if args.network_merge: + network_merge = len(args.network_module) # all networks are merged + elif args.network_merge_n_models: + network_merge = args.network_merge_n_models + else: + network_merge = 0 + logger.info(f"network_merge: {network_merge}") + + for i, network_module in enumerate(args.network_module): + logger.info(f"import network module: {network_module}") + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights is None or len(args.network_weights) <= i: + raise ValueError("No weight. Weight is required.") + + network_weight = args.network_weights[i] + logger.info(f"load network weights from: {network_weight}") + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + logger.info(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs + ) + if network is None: + return + + mergeable = network.is_mergeable() + if network_merge and not mergeable: + logger.warning("network is not mergiable. ignore merge option.") + + if not mergeable or i >= network_merge: + # not merging + network.apply_to([text_encoder1, text_encoder2], unet) + info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい + logger.info(f"weights are loaded: {info}") + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + if network_pre_calc: + logger.info("backup original weights") + network.backup_weights() + + networks.append(network) + network_default_muls.append(network_mul) + else: + network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) + + else: + networks = [] + + # upscalerの指定があれば取得する + upscaler = None + if args.highres_fix_upscaler: + logger.info(f"import upscaler module: {args.highres_fix_upscaler}") + imported_module = importlib.import_module(args.highres_fix_upscaler) + + us_kwargs = {} + if args.highres_fix_upscaler_args: + for net_arg in args.highres_fix_upscaler_args.split(";"): + key, value = net_arg.split("=") + us_kwargs[key] = value + + logger.info("create upscaler") + upscaler = imported_module.create_upscaler(**us_kwargs) + upscaler.to(dtype).to(device) + + # ControlNetの処理 + control_nets: List[Tuple[ControlNetLLLite, float]] = [] + # if args.control_net_models: + # for i, model in enumerate(args.control_net_models): + # prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + # weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + # ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + # ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) + # prep = original_control_net.load_preprocess(prep_type) + # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + if args.control_net_lllite_models: + for i, model_file in enumerate(args.control_net_lllite_models): + logger.info(f"loading ControlNet-LLLite: {model_file}") + + from safetensors.torch import load_file + + state_dict = load_file(model_file) + mlp_dim = None + cond_emb_dim = None + for key, value in state_dict.items(): + if mlp_dim is None and "down.0.weight" in key: + mlp_dim = value.shape[0] + elif cond_emb_dim is None and "conditioning1.0" in key: + cond_emb_dim = value.shape[0] * 2 + if mlp_dim is not None and cond_emb_dim is not None: + break + assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" + + multiplier = ( + 1.0 + if not args.control_net_multipliers or len(args.control_net_multipliers) <= i + else args.control_net_multipliers[i] + ) + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier) + control_net.apply_to() + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_net.set_batch_cond_only(False, False) + control_nets.append((control_net, ratio)) + + if args.opt_channels_last: + logger.info(f"set optimizing: channels last") + text_encoder1.to(memory_format=torch.channels_last) + text_encoder2.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.to(memory_format=torch.channels_last) + # cn.unet.to(memory_format=torch.channels_last) + # cn.net.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + device, + vae, + [text_encoder1, text_encoder2], + [tokenizer1, tokenizer2], + unet, + scheduler, + args.clip_skip, + ) + pipe.set_control_nets(control_nets) + logger.info(f"pipeline on {device} is ready.") + distributed_state.wait_for_everyone() + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + + # Gradual Latent + if args.gradual_latent_timesteps is not None: + if args.gradual_latent_unsharp_params: + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + + gradual_latent = GradualLatent( + args.gradual_latent_ratio, + args.gradual_latent_timesteps, + args.gradual_latent_every_n_steps, + args.gradual_latent_ratio_step, + args.gradual_latent_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # Textual Inversionを処理する + if args.textual_inversion_embeddings: + token_ids_embeds1 = [] + token_ids_embeds2 = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + + embeds1 = data["clip_l"] # text encoder 1 + embeds2 = data["clip_g"] # text encoder 2 + + num_vectors_per_token = embeds1.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens1 = tokenizer1.add_tokens(token_strings) + num_added_tokens2 = tokenizer2.add_tokens(token_strings) + assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, ( + f"tokenizer has same word to token string (filename): {embeds_file}" + + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" + ) + + token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) + token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) + logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + assert ( + min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 + ), f"token ids1 is not ordered" + assert ( + min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 + ), f"token ids2 is not ordered" + assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}" + assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... + pipe.add_token_replacement(1, token_ids2[0], token_ids2) + + token_ids_embeds1.append((token_ids1, embeds1)) + token_ids_embeds2.append((token_ids2, embeds2)) + + text_encoder1.resize_token_embeddings(len(tokenizer1)) + text_encoder2.resize_token_embeddings(len(tokenizer2)) + token_embeds1 = text_encoder1.get_input_embeddings().weight.data + token_embeds2 = text_encoder2.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds1: + for token_id, embed in zip(token_ids, embeds): + token_embeds1[token_id] = embed + for token_ids, embeds in token_ids_embeds2: + for token_id, embed in zip(token_ids, embeds): + token_embeds2[token_id] = embed + + # promptを取得する + if args.from_file is not None: + logger.info(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] + elif args.prompt is not None: + prompt_list = [args.prompt] + else: + prompt_list = [] + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] + else: + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() + + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + logger.info(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) + + return images + + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない場合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized + + if args.image_path is not None: + logger.info(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" + logger.info(f"loaded {len(init_images)} images for img2img") + + # CLIP Vision + if args.clip_vision_strength is not None: + logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) + vision_model.to(device, dtype) + processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) + + pipe.clip_vision_model = vision_model + pipe.clip_vision_processor = processor + pipe.clip_vision_strength = args.clip_vision_strength + logger.info(f"CLIP Vision model loaded.") + + else: + init_images = None + + if args.mask_path is not None: + logger.info(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" + logger.info(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None + + # promptがないとき、画像のPngInfoから取得する + if init_images is not None and len(prompt_list) == 0 and not args.interactive: + logger.info("get prompts from images' metadata") + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) + + # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l + + if mask_images is not None: + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l + + # 画像サイズにオプション指定があるときはリサイズする + if args.W is not None and args.H is not None: + # highres fix を考慮に入れる + w, h = args.W, args.H + if highres_fix: + w = int(w * args.highres_fix_scale + 0.5) + h = int(h * args.highres_fix_scale + 0.5) + + if init_images is not None: + logger.info(f"resize img2img source images to {w}*{h}") + init_images = resize_images(init_images, (w, h)) + if mask_images is not None: + logger.info(f"resize img2img mask images to {w}*{h}") + mask_images = resize_images(mask_images, (w, h)) + + regional_network = False + if networks and mask_images: + # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 + regional_network = True + logger.info("use mask as region") + + size = None + for i, network in enumerate(networks): + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: + np_mask = np.array(mask_images[0]) + + if args.network_regional_mask_max_color_codes: + # カラーコードでマスクを指定する + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] + size = np_mask.shape + else: + np_mask = np.full(size, 255, dtype=np.uint8) + mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) + network.set_region(i, i == len(networks) - 1, mask) + mask_images = None + + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) + + logger.info(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + logger.warning( + f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" + ) + guide_images = None + else: + guide_images = None + + # seed指定時はseedを決めておく + if args.seed is not None: + # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう + random.seed(args.seed) + predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] + if len(predefined_seeds) == 1: + predefined_seeds[0] = args.seed + else: + predefined_seeds = None + + # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) + if args.W is None: + args.W = 1024 + if args.H is None: + args.H = 1024 + + # 画像生成のループ + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + logger.info(f"iteration {gen_iter+1}/{args.n_iter}") + iter_seed = random.randint(0, 0x7FFFFFFF) + + # バッチ処理の関数 + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres_fixの処理 + if highres_fix and not highres_1st: + # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す + is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling + + logger.info("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + + def scale_and_round(x): + if x is None: + return None + return int(x * args.highres_fix_scale + 0.5) + + width_1st = scale_and_round(ext.width) + height_1st = scale_and_round(ext.height) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + original_width_1st = scale_and_round(ext.original_width) + original_height_1st = scale_and_round(ext.original_height) + original_width_negative_1st = scale_and_round(ext.original_width_negative) + original_height_negative_1st = scale_and_round(ext.original_height_negative) + crop_left_1st = scale_and_round(ext.crop_left) + crop_top_1st = scale_and_round(ext.crop_top) + + strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength + + ext_1st = BatchDataExt( + width_1st, + height_1st, + original_width_1st, + original_height_1st, + original_width_negative_1st, + original_height_negative_1st, + crop_left_1st, + crop_top_1st, + args.highres_fix_steps, + ext.scale, + ext.negative_scale, + strength_1st, + ext.network_muls, + ext.num_sub_prompts, + ) + batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする + images_1st = process_batch(batch_1st, True, True) + + # 2nd stageのバッチを作成して以下処理する + logger.info("process 2nd stage") + width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height + + if upscaler: + # upscalerを使って画像を拡大する + lowreso_imgs = None if is_1st_latent else images_1st + lowreso_latents = None if not is_1st_latent else images_1st + + # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents + batch_size = len(images_1st) + vae_batch_size = ( + batch_size + if args.vae_batch_size is None + else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) + ) + vae_batch_size = int(vae_batch_size) + images_1st = upscaler.upscale( + vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size + ) + + elif args.highres_fix_latents_upscaling: + # latentを拡大する + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + else: + # 画像をLANCZOSで拡大する + images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image, _), + ( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + network_muls, + num_sub_prompts, + ), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + raw_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? + guide_images = [] + else: + guide_images = None + + # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, ( + _, + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + _, + ) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + raw_prompts.append(raw_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + # 追加ネットワークの処理 + shared = {} + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + if regional_network: + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) + + if not regional_network and network_pre_calc: + for n in networks: + n.restore_weights() + for n in networks: + n.pre_calculation() + logger.info("pre-calculation... done") + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + original_height, + original_width, + original_height_negative, + original_width_negative, + crop_top, + crop_left, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + ) + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) + ): + if highres_fix: + seed -= 1 # record original seed + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + logger.error( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) + + return images + + # 画像生成のプロンプトが一周するまでのループ + prompt_index = 0 + global_step = 0 + batch_data = [] + logger.info(f"args.interactive: {args.interactive}, (prompt_index: {prompt_index} < len(prompt_list): {len(prompt_list)} and not distributed_state.is_main_process: {not distributed_state.is_main_process})") + while args.interactive or (prompt_index < len(prompt_list) and (not distributed_state.is_main_process or distributed_state.num_processes == 1)): + if len(prompt_list) == 0: + # interactive + valid = False + while not valid: + logger.info("") + logger.info("Type prompt:") + try: + raw_prompt = input() + except EOFError: + break + + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + raw_prompt = prompt_list[prompt_index] + + # sd-dynamic-prompts like variants: + # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + + # repeat prompt + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + if pi == 0 or len(raw_prompts) > 1: + + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + + # Gradual Latent + gl_timesteps = None # means no override + gl_ratio = args.gradual_latent_ratio + gl_every_n_steps = args.gradual_latent_every_n_steps + gl_ratio_step = args.gradual_latent_ratio_step + gl_s_noise = args.gradual_latent_s_noise + gl_unsharp_params = args.gradual_latent_unsharp_params + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + logger.info(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + logger.info(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + logger.info(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + logger.info(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + logger.info(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + logger.info(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + logger.info(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + logger.info(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + logger.info(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + logger.info(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + logger.info(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + logger.info(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + logger.info(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + logger.info(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + logger.info(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + logger.info(f"network mul: {network_muls}") + continue + + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + logger.info(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + logger.info(f"deep shrink ratio: {ds_ratio}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + # Gradual Latent + m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent timesteps + gl_timesteps = int(m.group(1)) + logger.info(f"gradual latent timesteps: {gl_timesteps}") + continue + + m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio + gl_ratio = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio: {ds_ratio}") + continue + + m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent every n steps + gl_every_n_steps = int(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent every n steps: {gl_every_n_steps}") + continue + + m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent ratio step + gl_ratio_step = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent ratio step: {gl_ratio_step}") + continue + + m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) + if m: # gradual latent s noise + gl_s_noise = float(m.group(1)) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent s noise: {gl_s_noise}") + continue + + m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # gradual latent unsharp params + gl_unsharp_params = m.group(1) + gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override + logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") + + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + + # override Gradual Latent + if gl_timesteps is not None: + if gl_timesteps < 0: + gl_timesteps = args.gradual_latent_timesteps or 650 + if gl_unsharp_params is not None: + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) + else: + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None + gradual_latent = GradualLatent( + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, + ) + pipe.set_gradual_latent(gradual_latent) + + # prepare seed + if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) + else: + if predefined_seeds is not None: + if len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) + else: + logger.error("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed + else: + seed = None # 前のを消す + + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) + if args.interactive: + logger.info(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + logger.warning( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets: # 複数件の場合あり + c = len(control_nets) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." + else: + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase( + global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + ), + BatchDataExt( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + batch_data.append(b1) + global_step += 1 + + prompt_index += 1 + distributed_state.wait_for_everyone() + batch_data = gather_object(batch_data) + + if len(batch_data) > 0: + data_loader = get_batches(items=batch_data, batch_size=args.batch_size) + with distributed_state.split_between_processes(data_loader) as batch_list: + for j in range(len(batch_list)): + logger.info(f"Loading batch {j+1}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") + logger.info(f"batch_list:") + for i in range(len(batch_list[j])): + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\NNegative Prompt: {batch_list[j][i].base.negative_prompt}") + prev_image = process_batch(batch_list[j], highres_fix)[0] + + distributed_state.wait_for_everyone() + #for i in range(len(data_loader)): + # logger.info(f"Loading Batch {i+1} of {len(data_loader)}") + # batch_data_split.append(data_loader[i]) + # if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader): + # continue + # with torch.no_grad(): + # with distributed_state.split_between_processes(batch_data_split) as batch_list: + # logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.local_process_index}:") + # logger.info(f"batch_list:") + # for i in range(len(batch_list[0])): + # logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") + # prev_image = process_batch(batch_list[0], highres_fix)[0] + # batch_data_split.clear() + # distributed_state.wait_for_everyone() + logger.info("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + #sdxl_train_util.add_sdxl_training_arguments(parser) + add_logging_arguments(parser) + + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", + type=str, + default=None, + help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", + ) + parser.add_argument( + "--interactive", + action="store_true", + help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument( + "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" + ) + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument( + "--original_height", + type=int, + default=None, + help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width", + type=int, + default=None, + help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--original_height_negative", + type=int, + default=None, + help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width_negative", + type=int, + default=None, + help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", + ) + parser.add_argument( + "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" + ) + parser.add_argument( + "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", + ) + parser.add_argument( + "--vae_slices", + type=int, + default=None, + help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", + ) + parser.add_argument( + "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" + ) + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument( + "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", + ) + parser.add_argument( + "--opt_channels_last", + action="store_true", + help="set channels last option to model / モデルにchannels lastを指定し最適化する", + ) + parser.add_argument( + "--network_module", + type=str, + default=None, + nargs="*", + help="additional network module to use / 追加ネットワークを使う時そのモジュール名", + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" + ) + parser.add_argument( + "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" + ) + parser.add_argument( + "--network_merge_n_models", + type=int, + default=None, + help="merge this number of networks / この数だけネットワークをマージする", + ) + parser.add_argument( + "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" + ) + parser.add_argument( + "--network_pre_calc", + action="store_true", + help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", + ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument( + "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" + ) + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", + type=int, + default=28, + help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", + ) + parser.add_argument( + "--highres_fix_strength", + type=float, + default=None, + help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", + ) + parser.add_argument( + "--highres_fix_save_1st", + action="store_true", + help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡大する", + ) + parser.add_argument( + "--highres_fix_upscaler", + type=str, + default=None, + help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", + ) + parser.add_argument( + "--highres_fix_upscaler_args", + type=str, + default=None, + help="additional arguments for upscaler (key=value) / upscalerへの追加の引数", + ) + parser.add_argument( + "--highres_fix_disable_control_net", + action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) + + parser.add_argument( + "--negative_scale", + type=float, + default=None, + help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", + ) + + parser.add_argument( + "--control_net_lllite_models", + type=str, + default=None, + nargs="*", + help="ControlNet models to use / 使用するControlNetのモデル名", + ) + # parser.add_argument( + # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + # ) + # parser.add_argument( + # "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + # ) + parser.add_argument( + "--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率" + ) + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) + parser.add_argument( + "--clip_vision_strength", + type=float, + default=None, + help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", + ) + + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + + # gradual latent + parser.add_argument( + "--gradual_latent_timesteps", + type=int, + default=None, + help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", + ) + parser.add_argument( + "--gradual_latent_ratio", + type=float, + default=0.5, + help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", + ) + parser.add_argument( + "--gradual_latent_ratio_step", + type=float, + default=0.125, + help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", + ) + parser.add_argument( + "--gradual_latent_every_n_steps", + type=int, + default=3, + help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", + ) + parser.add_argument( + "--gradual_latent_s_noise", + type=float, + default=1.0, + help="s_noise for Gradual Latent / Gradual Latentのs_noise", + ) + parser.add_argument( + "--gradual_latent_unsharp_params", + type=str, + default=None, + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", + ) + parser.add_argument("--full_fp16", action="store_true", help="Loading model in fp16") + parser.add_argument( + "--full_bf16", action="store_true", help="Loading model in bf16" + ) + parser.add_argument( + "--lowram", + action="store_true", + help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込む等(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + ) + # # parser.add_argument( + # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" + # ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + setup_logging(args, reset=True) + main(args) From dd2a1f23a66caabb48b4f9e91ca0beac0e31b411 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 14:31:02 +0800 Subject: [PATCH 068/221] Update accel_sdxl_gen_img_v2.py --- accel_sdxl_gen_img_v2.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/accel_sdxl_gen_img_v2.py b/accel_sdxl_gen_img_v2.py index fcc17d6d7..5a4359ebe 100644 --- a/accel_sdxl_gen_img_v2.py +++ b/accel_sdxl_gen_img_v2.py @@ -1482,6 +1482,7 @@ class BatchDataExt(NamedTuple): class BatchData(NamedTuple): return_latents: bool + global_count: int base: BatchDataBase ext: BatchDataExt @@ -2127,7 +2128,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): logger.info("process 1st stage") batch_1st = [] - for _, base, ext in batch: + for _, global_count, base, ext in batch: def scale_and_round(x): if x is None: @@ -2164,7 +2165,7 @@ def scale_and_round(x): ext.network_muls, ext.num_sub_prompts, ) - batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + batch_1st.append(BatchData(is_1st_latent, global_count, base, ext_1st)) pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする images_1st = process_batch(batch_1st, True, True) @@ -2206,7 +2207,7 @@ def scale_and_round(x): batch_2nd = [] for i, (bd, image) in enumerate(zip(batch, images_1st)): - bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + bd_2nd = BatchData(False, bd.global_count, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) batch_2nd.append(bd_2nd) batch = batch_2nd @@ -2216,6 +2217,7 @@ def scale_and_round(x): # このバッチの情報を取り出す ( return_latents, + _, (step_first, _, _, _, init_image, mask_image, _, guide_image, _), ( width, @@ -2236,6 +2238,7 @@ def scale_and_round(x): ) = batch[0] noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + global_counter = [] prompts = [] negative_prompts = [] raw_prompts = [] @@ -2271,9 +2274,11 @@ def scale_and_round(x): all_guide_images_are_same = True for i, ( _, + globalcount, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), _, ) in enumerate(batch): + global_counter.append(globalcount) prompts.append(prompt) negative_prompts.append(negative_prompt) seeds.append(seed) @@ -2376,8 +2381,8 @@ def scale_and_round(x): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) + for i, (image, globalcount, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(images, global_counter, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): if highres_fix: seed -= 1 # record original seed @@ -2408,9 +2413,9 @@ def scale_and_round(x): else: fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + fln = f"im_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" image.save(os.path.join(args.outdir, fln), pnginfo=metadata) @@ -2817,6 +2822,7 @@ def scale_and_round(x): b1 = BatchData( False, + global_step, BatchDataBase( global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt ), From 884ea48600120815225336ea5985841b118363b7 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 15:02:10 +0800 Subject: [PATCH 069/221] Update accel_sdxl_gen_img_v2.py --- accel_sdxl_gen_img_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img_v2.py b/accel_sdxl_gen_img_v2.py index 5a4359ebe..2adc1f872 100644 --- a/accel_sdxl_gen_img_v2.py +++ b/accel_sdxl_gen_img_v2.py @@ -2857,7 +2857,7 @@ def scale_and_round(x): logger.info(f"Loading batch {j+1}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") logger.info(f"batch_list:") for i in range(len(batch_list[j])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\NNegative Prompt: {batch_list[j][i].base.negative_prompt}") + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\nNegative Prompt: {batch_list[j][i].base.negative_prompt}") prev_image = process_batch(batch_list[j], highres_fix)[0] distributed_state.wait_for_everyone() From 29ac2238a55db852cc2641c486d88b65a0f321fe Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 15:02:55 +0800 Subject: [PATCH 070/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index fcc17d6d7..20cf529cd 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2851,7 +2851,7 @@ def scale_and_round(x): logger.info(f"Loading batch {j+1}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") logger.info(f"batch_list:") for i in range(len(batch_list[j])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\NNegative Prompt: {batch_list[j][i].base.negative_prompt}") + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\nNegative Prompt: {batch_list[j][i].base.negative_prompt}") prev_image = process_batch(batch_list[j], highres_fix)[0] distributed_state.wait_for_everyone() From 13a40bdf8edf8bb7af548bb067f7ba5cca4add78 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 26 Jan 2025 15:11:58 +0800 Subject: [PATCH 071/221] Update accel_sdxl_gen_img_v2.py --- accel_sdxl_gen_img_v2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/accel_sdxl_gen_img_v2.py b/accel_sdxl_gen_img_v2.py index 2adc1f872..fae18d8e8 100644 --- a/accel_sdxl_gen_img_v2.py +++ b/accel_sdxl_gen_img_v2.py @@ -2417,6 +2417,9 @@ def scale_and_round(x): else: fln = f"im_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + logger.info(f"Saving Image: {fln}:\nPrompt: {prompt}") + if negative_prompt is not None: + logger.info(f"Negative Prompt: {negative_prompt}\n") image.save(os.path.join(args.outdir, fln), pnginfo=metadata) if not args.no_preview and not highres_1st and args.interactive: From c0b771baee283498c6022263a43205b3d9e519b8 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 27 Jan 2025 01:13:12 +0800 Subject: [PATCH 072/221] Update accel_sdxl_gen_img_v2.py --- accel_sdxl_gen_img_v2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/accel_sdxl_gen_img_v2.py b/accel_sdxl_gen_img_v2.py index fae18d8e8..d0a3d2fd1 100644 --- a/accel_sdxl_gen_img_v2.py +++ b/accel_sdxl_gen_img_v2.py @@ -2441,7 +2441,6 @@ def scale_and_round(x): prompt_index = 0 global_step = 0 batch_data = [] - logger.info(f"args.interactive: {args.interactive}, (prompt_index: {prompt_index} < len(prompt_list): {len(prompt_list)} and not distributed_state.is_main_process: {not distributed_state.is_main_process})") while args.interactive or (prompt_index < len(prompt_list) and (not distributed_state.is_main_process or distributed_state.num_processes == 1)): if len(prompt_list) == 0: # interactive From dba2f03937c6404e1f07faa3998cfa13536707c5 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 27 Jan 2025 01:16:53 +0800 Subject: [PATCH 073/221] Delete accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 3261 ----------------------------------------- 1 file changed, 3261 deletions(-) delete mode 100644 accel_sdxl_gen_img.py diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py deleted file mode 100644 index 20cf529cd..000000000 --- a/accel_sdxl_gen_img.py +++ /dev/null @@ -1,3261 +0,0 @@ -import itertools -import json -from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable -import glob -import importlib -import inspect -import time -import zipfile -from diffusers.utils import deprecate -from diffusers.configuration_utils import FrozenDict -import argparse -import math -import os -import random -import re -import gc -from accelerate import PartialState -from accelerate.utils import gather_object - -import diffusers -import numpy as np - -import torch -from library.device_utils import init_ipex, clean_memory, get_preferred_device, clean_memory_on_device -init_ipex() - -import torchvision -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DPMSolverSinglestepScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - DDIMScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - KDPM2DiscreteScheduler, - KDPM2AncestralDiscreteScheduler, - # UNet2DConditionModel, - StableDiffusionPipeline, -) -from einops import rearrange -from tqdm import tqdm -from torchvision import transforms -from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor -import PIL -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import library.model_util as model_util -import library.train_util as train_util -import library.sdxl_model_util as sdxl_model_util -import library.sdxl_train_util as sdxl_train_util -from networks.lora import LoRANetwork -from library.sdxl_original_unet import InferSdxlUNet2DConditionModel -from library.original_unet import FlashAttentionFunction -from networks.control_net_lllite import ControlNetLLLite -from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - -# scheduler: -SCHEDULER_LINEAR_START = 0.00085 -SCHEDULER_LINEAR_END = 0.0120 -SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = "scaled_linear" - -# その他の設定 -LATENT_CHANNELS = 4 -DOWNSAMPLING_FACTOR = 8 - -CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - -def get_batches(items, batch_size): - num_batches = (len(items) + batch_size - 1) // batch_size - batches = [] - - for i in range(num_batches): - start_index = i * batch_size - end_index = min((i + 1) * batch_size, len(items)) - batch = items[start_index:end_index] - batches.append(batch) - - return batches -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - logger.info("Enable memory efficient attention for U-Net") - - # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い - unet.set_use_memory_efficient_attention(False, True) - elif xformers: - logger.info("Enable xformers for U-Net") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - unet.set_use_memory_efficient_attention(True, False) - elif sdpa: - logger.info("Enable SDPA for U-Net") - unet.set_use_memory_efficient_attention(False, False) - unet.set_use_sdpa(True) - - -# TODO common train_util.py -def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - replace_vae_attn_to_memory_efficient() - elif xformers: - # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? - vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う - elif sdpa: - replace_vae_attn_to_sdpa() - - -def replace_vae_attn_to_memory_efficient(): - logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, hidden_states, **kwargs): - q_bucket_size = 512 - k_bucket_size = 1024 - - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_flash_attn_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_flash_attn(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_flash_attn - - -def replace_vae_attn_to_xformers(): - logger.info("VAE: Attention.forward has been replaced to xformers") - import xformers.ops - - def forward_xformers(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - query_proj = query_proj.contiguous() - key_proj = key_proj.contiguous() - value_proj = value_proj.contiguous() - out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_xformers_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_xformers(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_xformers - - -def replace_vae_attn_to_sdpa(): - logger.info("VAE: Attention.forward has been replaced to sdpa") - - def forward_sdpa(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = torch.nn.functional.scaled_dot_product_attention( - query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - out = rearrange(out, "b n h d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_sdpa_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_sdpa(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_sdpa - - -# endregion - -# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 -# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py -# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 - - -class PipelineLike: - def __init__( - self, - device, - vae: AutoencoderKL, - text_encoders: List[CLIPTextModel], - tokenizers: List[CLIPTokenizer], - unet: InferSdxlUNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - clip_skip: int, - ): - super().__init__() - self.device = device - self.clip_skip = clip_skip - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - self.vae = vae - self.text_encoders = text_encoders - self.tokenizers = tokenizers - self.unet: InferSdxlUNet2DConditionModel = unet - self.scheduler = scheduler - self.safety_checker = None - - self.clip_vision_model: CLIPVisionModelWithProjection = None - self.clip_vision_processor: CLIPImageProcessor = None - self.clip_vision_strength = 0.0 - - # Textual Inversion - self.token_replacements_list = [] - for _ in range(len(self.text_encoders)): - self.token_replacements_list.append({}) - - # ControlNet # not supported yet - self.control_nets: List[ControlNetLLLite] = [] - self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない - - self.gradual_latent: GradualLatent = None - - # Textual Inversion - def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): - self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids - - def set_enable_control_net(self, en: bool): - self.control_net_enabled = en - - def get_token_replacer(self, tokenizer): - tokenizer_index = self.tokenizers.index(tokenizer) - token_replacements = self.token_replacements_list[tokenizer_index] - - def replace_tokens(tokens): - # logger.info("replace_tokens", tokens, "=>", token_replacements) - if isinstance(tokens, torch.Tensor): - tokens = tokens.tolist() - - new_tokens = [] - for token in tokens: - if token in token_replacements: - replacement = token_replacements[token] - new_tokens.extend(replacement) - else: - new_tokens.append(token) - return new_tokens - - return replace_tokens - - def set_control_nets(self, ctrl_nets): - self.control_nets = ctrl_nets - - def set_gradual_latent(self, gradual_latent): - if gradual_latent is None: - logger.info("gradual_latent is disabled") - self.gradual_latent = None - else: - logger.info(f"gradual_latent is enabled: {gradual_latent}") - self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - height: int = 1024, - width: int = 1024, - original_height: int = None, - original_width: int = None, - original_height_negative: int = None, - original_width_negative: int = None, - crop_top: int = 0, - crop_left: int = 0, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_scale: float = None, - strength: float = 0.8, - # num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - vae_batch_size: float = None, - return_latents: bool = False, - # return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: Optional[int] = 1, - img2img_noise=None, - clip_guide_images=None, - **kwargs, - ): - # TODO support secondary prompt - num_images_per_prompt = 1 # fixed because already prompt is repeated - - if isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - reginonal_network = " AND " in prompt[0] - - vae_batch_size = ( - batch_size - if vae_batch_size is None - else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) - ) - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." - ) - - # get prompt text embeddings - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if not do_classifier_free_guidance and negative_scale is not None: - logger.info(f"negative_scale is ignored if guidance scalle <= 1.0") - negative_scale = None - - # get unconditional embeddings for classifier free guidance - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - tes_text_embs = [] - tes_uncond_embs = [] - tes_real_uncond_embs = [] - - for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): - token_replacer = self.get_token_replacer(tokenizer) - - # use last text_pool, because it is from text encoder 2 - text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( - tokenizer, - text_encoder, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - token_replacer=token_replacer, - device=self.device, - **kwargs, - ) - tes_text_embs.append(text_embeddings) - tes_uncond_embs.append(uncond_embeddings) - - if negative_scale is not None: - _, real_uncond_embeddings, _ = get_weighted_text_embeddings( - token_replacer, - prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 - uncond_prompt=[""] * batch_size, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - token_replacer=token_replacer, - device=self.device, - **kwargs, - ) - tes_real_uncond_embs.append(real_uncond_embeddings) - - # concat text encoder outputs - text_embeddings = tes_text_embs[0] - uncond_embeddings = tes_uncond_embs[0] - for i in range(1, len(tes_text_embs)): - text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 - if do_classifier_free_guidance: - uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 - - if do_classifier_free_guidance: - if negative_scale is None: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - - if self.control_nets: - # ControlNetのhintにguide imageを流用する - if isinstance(clip_guide_images, PIL.Image.Image): - clip_guide_images = [clip_guide_images] - if isinstance(clip_guide_images[0], PIL.Image.Image): - clip_guide_images = [preprocess_image(im) for im in clip_guide_images] - clip_guide_images = torch.cat(clip_guide_images) - if isinstance(clip_guide_images, list): - clip_guide_images = torch.stack(clip_guide_images) - - clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) - - # create size embs - if original_height is None: - original_height = height - if original_width is None: - original_width = width - if original_height_negative is None: - original_height_negative = original_height - if original_width_negative is None: - original_width_negative = original_width - if crop_top is None: - crop_top = 0 - if crop_left is None: - crop_left = 0 - emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) - uc_emb1 = sdxl_train_util.get_timestep_embedding( - torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 - ) - emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) - emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) - c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) - uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) - - if reginonal_network: - # use last pool for conditioning - num_sub_prompts = len(text_pool) // batch_size - text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt - - if init_image is not None and self.clip_vision_model is not None: - logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") - vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) - pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) - - clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True) - clip_vision_embeddings = clip_vision_embeddings.image_embeds - - if len(clip_vision_embeddings) == 1 and batch_size > 1: - clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) - - clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength - assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" - text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) - - c_vector = torch.cat([text_pool, c_vector], dim=1) - if do_classifier_free_guidance: - uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) - vector_embeddings = torch.cat([uc_vector, c_vector]) - else: - vector_embeddings = c_vector - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps, self.device) - - latents_dtype = text_embeddings.dtype - init_latents_orig = None - mask = None - - if init_image is None: - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_shape = ( - batch_size * num_images_per_prompt, - self.unet.in_channels, - height // 8, - width // 8, - ) - - if latents is None: - if self.device.type == "mps": - # randn does not exist on mps - latents = torch.randn( - latents_shape, - generator=generator, - device="cpu", - dtype=latents_dtype, - ).to(self.device) - else: - latents = torch.randn( - latents_shape, - generator=generator, - device=self.device, - dtype=latents_dtype, - ) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - timesteps = self.scheduler.timesteps.to(self.device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - else: - # image to tensor - if isinstance(init_image, PIL.Image.Image): - init_image = [init_image] - if isinstance(init_image[0], PIL.Image.Image): - init_image = [preprocess_image(im) for im in init_image] - init_image = torch.cat(init_image) - if isinstance(init_image, list): - init_image = torch.stack(init_image) - - # mask image to tensor - if mask_image is not None: - if isinstance(mask_image, PIL.Image.Image): - mask_image = [mask_image] - if isinstance(mask_image[0], PIL.Image.Image): - mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint - - # encode the init image into latents and scale the latents - init_image = init_image.to(device=self.device, dtype=latents_dtype) - if init_image.size()[-2:] == (height // 8, width // 8): - init_latents = init_image - else: - if vae_batch_size >= batch_size: - init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - else: - clean_memory() - init_latents = [] - for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): - init_latent_dist = self.vae.encode( - (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( - self.vae.dtype - ) - ).latent_dist - init_latents.append(init_latent_dist.sample(generator=generator)) - init_latents = torch.cat(init_latents) - - init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents - - if len(init_latents) == 1: - init_latents = init_latents.repeat((batch_size, 1, 1, 1)) - init_latents_orig = init_latents - - # preprocess mask - if mask_image is not None: - mask = mask_image.to(device=self.device, dtype=latents_dtype) - if len(mask) == 1: - mask = mask.repeat((batch_size, 1, 1, 1)) - - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(self.device) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 - - if self.control_nets: - # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) - if self.control_net_enabled: - for control_net, _ in self.control_nets: - with torch.no_grad(): - control_net.set_cond_image(clip_guide_images) - else: - for control_net, _ in self.control_nets: - control_net.set_cond_image(None) - - each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) - - # # first, we downscale the latents to the half of the size - # # 最初に1/2に縮小する - # height, width = latents.shape[-2:] - # # latents = torch.nn.functional.interpolate(latents.float(), scale_factor=0.5, mode="bicubic", align_corners=False).to( - # # latents.dtype - # # ) - # latents = latents[:, :, ::2, ::2] - # current_scale = 0.5 - - # # how much to increase the scale at each step: .125 seems to work well (because it's 1/8?) - # # 各ステップに拡大率をどのくらい増やすか:.125がよさそう(たぶん1/8なので) - # scale_step = 0.125 - - # # timesteps at which to start increasing the scale: 1000 seems to be enough - # # 拡大を開始するtimesteps: 1000で十分そうである - # start_timesteps = 1000 - - # # how many steps to wait before increasing the scale again - # # small values leads to blurry images (because the latents are blurry after the upscale, so some denoising might be needed) - # # large values leads to flat images - - # # 何ステップごとに拡大するか - # # 小さいとボケる(拡大後のlatentsはボケた感じになるので、そこから数stepのdenoiseが必要と思われる) - # # 大きすぎると細部が書き込まれずのっぺりした感じになる - # every_n_steps = 5 - - # scale_step = input("scale step:") - # scale_step = float(scale_step) - # start_timesteps = input("start timesteps:") - # start_timesteps = int(start_timesteps) - # every_n_steps = input("every n steps:") - # every_n_steps = int(every_n_steps) - - # # for i, t in enumerate(tqdm(timesteps)): - # i = 0 - # last_step = 0 - # while i < len(timesteps): - # t = timesteps[i] - # print(f"[{i}] t={t}") - - # print(i, t, current_scale, latents.shape) - # if t < start_timesteps and current_scale < 1.0 and i % every_n_steps == 0: - # if i == last_step: - # pass - # else: - # print("upscale") - # current_scale = min(current_scale + scale_step, 1.0) - - # h = int(height * current_scale) // 8 * 8 - # w = int(width * current_scale) // 8 * 8 - - # latents = torch.nn.functional.interpolate(latents.float(), size=(h, w), mode="bicubic", align_corners=False).to( - # latents.dtype - # ) - # last_step = i - # i = max(0, i - every_n_steps + 1) - - # diff = timesteps[i] - timesteps[last_step] - # # resized_init_noise = torch.nn.functional.interpolate( - # # init_noise.float(), size=(h, w), mode="bicubic", align_corners=False - # # ).to(latents.dtype) - # # latents = self.scheduler.add_noise(latents, resized_init_noise, diff) - # latents = self.scheduler.add_noise(latents, torch.randn_like(latents), diff * 4) - # # latents += torch.randn_like(latents) / 100 * diff - # continue - - enable_gradual_latent = False - if self.gradual_latent: - if not hasattr(self.scheduler, "set_gradual_latent_params"): - logger.info("gradual_latent is not supported for this scheduler. Ignoring.") - logger.info(f'{self.scheduler.__class__.__name__}') - else: - enable_gradual_latent = True - step_elapsed = 1000 - current_ratio = self.gradual_latent.ratio - - # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする - height, width = latents.shape[-2:] - org_dtype = latents.dtype - if org_dtype == torch.bfloat16: - latents = latents.float() - latents = torch.nn.functional.interpolate( - latents, scale_factor=current_ratio, mode="bicubic", align_corners=False - ).to(org_dtype) - - # apply unsharp mask / アンシャープマスクを適用する - if self.gradual_latent.gaussian_blur_ksize: - latents = self.gradual_latent.apply_unshark_mask(latents) - - for i, t in enumerate(tqdm(timesteps)): - resized_size = None - if enable_gradual_latent: - # gradually upscale the latents / latentsを徐々にアップスケールする - if ( - t < self.gradual_latent.start_timesteps - and current_ratio < 1.0 - and step_elapsed >= self.gradual_latent.every_n_steps - ): - current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) - # make divisible by 8 because size of latents must be divisible at bottom of UNet - h = int(height * current_ratio) // 8 * 8 - w = int(width * current_ratio) // 8 * 8 - resized_size = (h, w) - self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) - step_elapsed = 0 - else: - self.scheduler.set_gradual_latent_params(None, None) - step_elapsed += 1 - - # expand the latents if we are doing classifier free guidance - latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # disable control net if ratio is set - if self.control_nets and self.control_net_enabled: - for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): - if not enabled or ratio >= 1.0: - continue - if ratio < i / len(timesteps): - logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") - control_net.set_cond_image(None) - each_control_net_enabled[j] = False - - # predict the noise residual - # TODO Diffusers' ControlNet - # if self.control_nets and self.control_net_enabled: - # if reginonal_network: - # num_sub_and_neg_prompts = len(text_embeddings) // batch_size - # text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt - # else: - # text_emb_last = text_embeddings - - # # not working yet - # noise_pred = original_control_net.call_unet_and_control_net( - # i, - # num_latent_input, - # self.unet, - # self.control_nets, - # guided_hints, - # i / len(timesteps), - # latent_model_input, - # t, - # text_emb_last, - # ).sample - # else: - noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) - - # perform guidance - if do_classifier_free_guidance: - if negative_scale is None: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - else: - noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( - num_latent_input - ) # uncond is real uncond - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_uncond) - - negative_scale * (noise_pred_negative - noise_pred_uncond) - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - i += 1 - - if return_latents: - return latents - - latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents - if vae_batch_size >= batch_size: - image = self.vae.decode(latents.to(self.vae.dtype)).sample - else: - clean_memory() - images = [] - for i in tqdm(range(0, batch_size, vae_batch_size)): - images.append( - self.vae.decode( - (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) - ).sample - ) - image = torch.cat(images) - - image = (image / 2 + 0.5).clamp(0, 1) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - clean_memory() - - if output_type == "pil": - # image = self.numpy_to_pil(image) - image = (image * 255).round().astype("uint8") - image = [Image.fromarray(im) for im in image] - - return image - - # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - # keep break as separate token - text = text.replace("BREAK", "\\BREAK\\") - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - if word.strip() == "BREAK": - # pad until next multiple of tokenizer's max token length - pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) - logger.info(f"BREAK pad_len: {pad_len}") - for i in range(pad_len): - # v2のときEOSをつけるべきかどうかわからないぜ - # if i == 0: - # text_token.append(tokenizer.eos_token_id) - # else: - text_token.append(tokenizer.pad_token_id) - text_weight.append(1.0) - continue - - # tokenize and discard the starting and the ending token - token = tokenizer(word).input_ids[1:-1] - - token = token_replacer(token) # for Textual Inversion - - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights - - -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - text_encoder: CLIPTextModel, - text_input: torch.Tensor, - chunk_length: int, - clip_skip: int, - eos: int, - pad: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. - """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - pool = None - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos - - # -2 is same for Text Encoder 1 and 2 - enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out["hidden_states"][-2] - if pool is None: - pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided - if pool is not None: - pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out["hidden_states"][-2] - pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this - if pool is not None: - pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) - return text_embeddings, pool - - -def get_weighted_text_embeddings( - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - prompt: Union[str, List[str]], - uncond_prompt: Optional[Union[str, List[str]]] = None, - max_embeddings_multiples: Optional[int] = 1, - no_boseos_middle: Optional[bool] = False, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, - clip_skip=None, - token_replacer=None, - device=None, - **kwargs, -): - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - # split the prompts with "AND". each prompt must have the same number of splits - new_prompts = [] - for p in prompt: - new_prompts.extend(p.split(" AND ")) - prompt = new_prompts - - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) - else: - prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = tokenizer.bos_token_id - eos = tokenizer.eos_token_id - pad = tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) - - # get the embeddings - text_embeddings, text_pool = get_unweighted_text_embeddings( - text_encoder, - prompt_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) - if uncond_prompt is not None: - uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( - text_encoder, - uncond_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? - # →全体でいいんじゃないかな - if (not skip_parsing) and (not skip_weighting): - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - if uncond_prompt is not None: - previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - - if uncond_prompt is not None: - return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens - return text_embeddings, text_pool, None, None, prompt_tokens - - -def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -def preprocess_mask(mask): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask - - -# regular expression for dynamic prompt: -# starts and ends with "{" and "}" -# contains at least one variant divided by "|" -# optional framgments divided by "$$" at start -# if the first fragment is "E" or "e", enumerate all variants -# if the second fragment is a number or two numbers, repeat the variants in the range -# if the third fragment is a string, use it as a separator - -RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") - - -def handle_dynamic_prompt_variants(prompt, repeat_count): - founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) - if not founds: - return [prompt] - - # make each replacement for each variant - enumerating = False - replacers = [] - for found in founds: - # if "e$$" is found, enumerate all variants - found_enumerating = found.group(2) is not None - enumerating = enumerating or found_enumerating - - separator = ", " if found.group(6) is None else found.group(6) - variants = found.group(7).split("|") - - # parse count range - count_range = found.group(4) - if count_range is None: - count_range = [1, 1] - else: - count_range = count_range.split("-") - if len(count_range) == 1: - count_range = [int(count_range[0]), int(count_range[0])] - elif len(count_range) == 2: - count_range = [int(count_range[0]), int(count_range[1])] - else: - logger.warning(f"invalid count range: {count_range}") - count_range = [1, 1] - if count_range[0] > count_range[1]: - count_range = [count_range[1], count_range[0]] - if count_range[0] < 0: - count_range[0] = 0 - if count_range[1] > len(variants): - count_range[1] = len(variants) - - if found_enumerating: - # make function to enumerate all combinations - def make_replacer_enum(vari, cr, sep): - def replacer(): - values = [] - for count in range(cr[0], cr[1] + 1): - for comb in itertools.combinations(vari, count): - values.append(sep.join(comb)) - return values - - return replacer - - replacers.append(make_replacer_enum(variants, count_range, separator)) - else: - # make function to choose random combinations - def make_replacer_single(vari, cr, sep): - def replacer(): - count = random.randint(cr[0], cr[1]) - comb = random.sample(vari, count) - return [sep.join(comb)] - - return replacer - - replacers.append(make_replacer_single(variants, count_range, separator)) - - # make each prompt - if not enumerating: - # if not enumerating, repeat the prompt, replace each variant randomly - prompts = [] - for _ in range(repeat_count): - current = prompt - for found, replacer in zip(founds, replacers): - current = current.replace(found.group(0), replacer()[0], 1) - prompts.append(current) - else: - # if enumerating, iterate all combinations for previous prompts - prompts = [prompt] - - for found, replacer in zip(founds, replacers): - if found.group(2) is not None: - # make all combinations for existing prompts - new_prompts = [] - for current in prompts: - replecements = replacer() - for replecement in replecements: - new_prompts.append(current.replace(found.group(0), replecement, 1)) - prompts = new_prompts - - for found, replacer in zip(founds, replacers): - # make random selection for existing prompts - if found.group(2) is None: - for i in range(len(prompts)): - prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) - - return prompts - - -# endregion - -# def load_clip_l14_336(dtype): -# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") -# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) -# return text_encoder - - -class BatchDataBase(NamedTuple): - # バッチ分割が必要ないデータ - step: int - prompt: str - negative_prompt: str - seed: int - init_image: Any - mask_image: Any - clip_prompt: str - guide_image: Any - raw_prompt: str - - -class BatchDataExt(NamedTuple): - # バッチ分割が必要なデータ - width: int - height: int - original_width: int - original_height: int - original_width_negative: int - original_height_negative: int - crop_left: int - crop_top: int - steps: int - scale: float - negative_scale: float - strength: float - network_muls: Tuple[float] - num_sub_prompts: int - - -class BatchData(NamedTuple): - return_latents: bool - base: BatchDataBase - ext: BatchDataExt - - -def main(args): - if args.fp16: - dtype = torch.float16 - elif args.bf16: - dtype = torch.bfloat16 - else: - dtype = torch.float32 - - highres_fix = args.highres_fix_scale is not None - # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - - # モデルを読み込む - logger.info("preparing pipes") - #accelerator = train_util.prepare_accelerator(args) - #is_main_process = accelerator.is_main_process - distributed_state = PartialState() - device = distributed_state.device - - if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う - files = glob.glob(args.ckpt) - if len(files) == 1: - args.ckpt = files[0] - #device = get_preferred_device() - logger.info(f"preferred device: {device}") - clean_memory_on_device(device) - model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) - for pi in range(distributed_state.num_processes): - if pi == distributed_state.local_process_index: - logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype - ) - unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) - distributed_state.wait_for_everyone() - - # xformers、Hypernetwork対応 - if not args.diffusers_xformers: - mem_eff = not (args.xformers or args.sdpa) - replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) - replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) - - # tokenizerを読み込む - logger.info("loading tokenizer") - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - - # schedulerを用意する - sched_init_args = {} - has_steps_offset = True - has_clip_sample = True - scheduler_num_noises_per_step = 1 - - if args.sampler == "ddim": - scheduler_cls = DDIMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddim - elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddpm - elif args.sampler == "pndm": - scheduler_cls = PNDMScheduler - scheduler_module = diffusers.schedulers.scheduling_pndm - has_clip_sample = False - elif args.sampler == "lms" or args.sampler == "k_lms": - scheduler_cls = LMSDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_lms_discrete - has_clip_sample = False - elif args.sampler == "euler" or args.sampler == "k_euler": - scheduler_cls = EulerDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_discrete - has_clip_sample = False - elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteSchedulerGL - scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete - has_clip_sample = False - elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args["algorithm_type"] = args.sampler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep - has_clip_sample = False - elif args.sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep - has_clip_sample = False - has_steps_offset = False - elif args.sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_heun_discrete - has_clip_sample = False - elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": - scheduler_cls = KDPM2DiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete - has_clip_sample = False - elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": - scheduler_cls = KDPM2AncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete - scheduler_num_noises_per_step = 2 - has_clip_sample = False - - # 警告を出さないようにする - if has_steps_offset: - sched_init_args["steps_offset"] = 1 - if has_clip_sample: - sched_init_args["clip_sample"] = False - - # samplerの乱数をあらかじめ指定するための処理 - - # replace randn - class NoiseManager: - def __init__(self): - self.sampler_noises = None - self.sampler_noise_index = 0 - - def reset_sampler_noises(self, noises): - self.sampler_noise_index = 0 - self.sampler_noises = noises - - def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # logger.info("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) - if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): - noise = self.sampler_noises[self.sampler_noise_index] - if shape != noise.shape: - noise = None - else: - noise = None - - if noise == None: - logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") - noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) - - self.sampler_noise_index += 1 - return noise - - class TorchRandReplacer: - def __init__(self, noise_manager): - self.noise_manager = noise_manager - - def __getattr__(self, item): - if item == "randn": - return self.noise_manager.randn - if hasattr(torch, item): - return getattr(torch, item) - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - noise_manager = NoiseManager() - if scheduler_module is not None: - scheduler_module.torch = TorchRandReplacer(noise_manager) - - scheduler = scheduler_cls( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - **sched_init_args, - ) - - # ↓以下は結局PipeでFalseに設定されるので意味がなかった - # # clip_sample=Trueにする - # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # logger.info("set clip_sample to True") - # scheduler.config.clip_sample = True - - # deviceを決定する - - - # custom pipelineをコピったやつを生成する - if args.vae_slices: - from library.slicing_vae import SlicingAutoencoderKL - - sli_vae = SlicingAutoencoderKL( - act_fn="silu", - block_out_channels=(128, 256, 512, 512), - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], - in_channels=3, - latent_channels=4, - layers_per_block=2, - norm_num_groups=32, - out_channels=3, - sample_size=512, - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], - num_slices=args.vae_slices, - ) - sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする - vae = sli_vae - del sli_vae - - vae_dtype = dtype - if args.no_half_vae: - logger.info("set vae_dtype to float32") - vae_dtype = torch.float32 - vae.to(vae_dtype).to(device) - vae.eval() - - text_encoder1.to(dtype).to(device) - text_encoder2.to(dtype).to(device) - unet.to(dtype).to(device) - text_encoder1.eval() - text_encoder2.eval() - unet.eval() - # networkを組み込む - if args.network_module: - networks = [] - network_default_muls = [] - network_pre_calc = args.network_pre_calc - - # merge関連の引数を統合する - if args.network_merge: - network_merge = len(args.network_module) # all networks are merged - elif args.network_merge_n_models: - network_merge = args.network_merge_n_models - else: - network_merge = 0 - logger.info(f"network_merge: {network_merge}") - - for i, network_module in enumerate(args.network_module): - logger.info(f"import network module: {network_module}") - imported_module = importlib.import_module(network_module) - - network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - - net_kwargs = {} - if args.network_args and i < len(args.network_args): - network_args = args.network_args[i] - # TODO escape special chars - network_args = network_args.split(";") - for net_arg in network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value - - if args.network_weights is None or len(args.network_weights) <= i: - raise ValueError("No weight. Weight is required.") - - network_weight = args.network_weights[i] - logger.info(f"load network weights from: {network_weight}") - - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open - - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - logger.info(f"metadata for: {network_weight}: {metadata}") - - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs - ) - if network is None: - return - - mergeable = network.is_mergeable() - if network_merge and not mergeable: - logger.warning("network is not mergiable. ignore merge option.") - - if not mergeable or i >= network_merge: - # not merging - network.apply_to([text_encoder1, text_encoder2], unet) - info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - logger.info(f"weights are loaded: {info}") - - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) - - if network_pre_calc: - logger.info("backup original weights") - network.backup_weights() - - networks.append(network) - network_default_muls.append(network_mul) - else: - network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) - - else: - networks = [] - - # upscalerの指定があれば取得する - upscaler = None - if args.highres_fix_upscaler: - logger.info(f"import upscaler module: {args.highres_fix_upscaler}") - imported_module = importlib.import_module(args.highres_fix_upscaler) - - us_kwargs = {} - if args.highres_fix_upscaler_args: - for net_arg in args.highres_fix_upscaler_args.split(";"): - key, value = net_arg.split("=") - us_kwargs[key] = value - - logger.info("create upscaler") - upscaler = imported_module.create_upscaler(**us_kwargs) - upscaler.to(dtype).to(device) - - # ControlNetの処理 - control_nets: List[Tuple[ControlNetLLLite, float]] = [] - # if args.control_net_models: - # for i, model in enumerate(args.control_net_models): - # prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - # weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - # ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - - # ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) - # prep = original_control_net.load_preprocess(prep_type) - # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) - if args.control_net_lllite_models: - for i, model_file in enumerate(args.control_net_lllite_models): - logger.info(f"loading ControlNet-LLLite: {model_file}") - - from safetensors.torch import load_file - - state_dict = load_file(model_file) - mlp_dim = None - cond_emb_dim = None - for key, value in state_dict.items(): - if mlp_dim is None and "down.0.weight" in key: - mlp_dim = value.shape[0] - elif cond_emb_dim is None and "conditioning1.0" in key: - cond_emb_dim = value.shape[0] * 2 - if mlp_dim is not None and cond_emb_dim is not None: - break - assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" - - multiplier = ( - 1.0 - if not args.control_net_multipliers or len(args.control_net_multipliers) <= i - else args.control_net_multipliers[i] - ) - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - - control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier) - control_net.apply_to() - control_net.load_state_dict(state_dict) - control_net.to(dtype).to(device) - control_net.set_batch_cond_only(False, False) - control_nets.append((control_net, ratio)) - - if args.opt_channels_last: - logger.info(f"set optimizing: channels last") - text_encoder1.to(memory_format=torch.channels_last) - text_encoder2.to(memory_format=torch.channels_last) - vae.to(memory_format=torch.channels_last) - unet.to(memory_format=torch.channels_last) - if networks: - for network in networks: - network.to(memory_format=torch.channels_last) - - for cn in control_nets: - cn.to(memory_format=torch.channels_last) - # cn.unet.to(memory_format=torch.channels_last) - # cn.net.to(memory_format=torch.channels_last) - - pipe = PipelineLike( - device, - vae, - [text_encoder1, text_encoder2], - [tokenizer1, tokenizer2], - unet, - scheduler, - args.clip_skip, - ) - pipe.set_control_nets(control_nets) - logger.info(f"pipeline on {device} is ready.") - distributed_state.wait_for_everyone() - - if args.diffusers_xformers: - pipe.enable_xformers_memory_efficient_attention() - - # Deep Shrink - if args.ds_depth_1 is not None: - unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) - - # Gradual Latent - if args.gradual_latent_timesteps is not None: - if args.gradual_latent_unsharp_params: - us_params = args.gradual_latent_unsharp_params.split(",") - us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] - us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) - us_ksize = int(us_ksize) - else: - us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None - - gradual_latent = GradualLatent( - args.gradual_latent_ratio, - args.gradual_latent_timesteps, - args.gradual_latent_every_n_steps, - args.gradual_latent_ratio_step, - args.gradual_latent_s_noise, - us_ksize, - us_sigma, - us_strength, - us_target_x, - ) - pipe.set_gradual_latent(gradual_latent) - - # Textual Inversionを処理する - if args.textual_inversion_embeddings: - token_ids_embeds1 = [] - token_ids_embeds2 = [] - for embeds_file in args.textual_inversion_embeddings: - if model_util.is_safetensors(embeds_file): - from safetensors.torch import load_file - - data = load_file(embeds_file) - else: - data = torch.load(embeds_file, map_location="cpu") - - if "string_to_param" in data: - data = data["string_to_param"] - - embeds1 = data["clip_l"] # text encoder 1 - embeds2 = data["clip_g"] # text encoder 2 - - num_vectors_per_token = embeds1.size()[0] - token_string = os.path.splitext(os.path.basename(embeds_file))[0] - - token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] - - # add new word to tokenizer, count is num_vectors_per_token - num_added_tokens1 = tokenizer1.add_tokens(token_strings) - num_added_tokens2 = tokenizer2.add_tokens(token_strings) - assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, ( - f"tokenizer has same word to token string (filename): {embeds_file}" - + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" - ) - - token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) - token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) - logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") - assert ( - min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 - ), f"token ids1 is not ordered" - assert ( - min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 - ), f"token ids2 is not ordered" - assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}" - assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}" - - if num_vectors_per_token > 1: - pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... - pipe.add_token_replacement(1, token_ids2[0], token_ids2) - - token_ids_embeds1.append((token_ids1, embeds1)) - token_ids_embeds2.append((token_ids2, embeds2)) - - text_encoder1.resize_token_embeddings(len(tokenizer1)) - text_encoder2.resize_token_embeddings(len(tokenizer2)) - token_embeds1 = text_encoder1.get_input_embeddings().weight.data - token_embeds2 = text_encoder2.get_input_embeddings().weight.data - for token_ids, embeds in token_ids_embeds1: - for token_id, embed in zip(token_ids, embeds): - token_embeds1[token_id] = embed - for token_ids, embeds in token_ids_embeds2: - for token_id, embed in zip(token_ids, embeds): - token_embeds2[token_id] = embed - - # promptを取得する - if args.from_file is not None: - logger.info(f"reading prompts from {args.from_file}") - with open(args.from_file, "r", encoding="utf-8") as f: - prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] - elif args.prompt is not None: - prompt_list = [args.prompt] - else: - prompt_list = [] - - if args.interactive: - args.n_iter = 1 - - # img2imgの前処理、画像の読み込みなど - def load_images(path): - if os.path.isfile(path): - paths = [path] - else: - paths = ( - glob.glob(os.path.join(path, "*.png")) - + glob.glob(os.path.join(path, "*.jpg")) - + glob.glob(os.path.join(path, "*.jpeg")) - + glob.glob(os.path.join(path, "*.webp")) - ) - paths.sort() - - images = [] - for p in paths: - image = Image.open(p) - if image.mode != "RGB": - logger.info(f"convert image to RGB from {image.mode}: {p}") - image = image.convert("RGB") - images.append(image) - - return images - - def resize_images(imgs, size): - resized = [] - for img in imgs: - r_img = img.resize(size, Image.Resampling.LANCZOS) - if hasattr(img, "filename"): # filename属性がない場合があるらしい - r_img.filename = img.filename - resized.append(r_img) - return resized - - if args.image_path is not None: - logger.info(f"load image for img2img: {args.image_path}") - init_images = load_images(args.image_path) - assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - logger.info(f"loaded {len(init_images)} images for img2img") - - # CLIP Vision - if args.clip_vision_strength is not None: - logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") - vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) - vision_model.to(device, dtype) - processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) - - pipe.clip_vision_model = vision_model - pipe.clip_vision_processor = processor - pipe.clip_vision_strength = args.clip_vision_strength - logger.info(f"CLIP Vision model loaded.") - - else: - init_images = None - - if args.mask_path is not None: - logger.info(f"load mask for inpainting: {args.mask_path}") - mask_images = load_images(args.mask_path) - assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - logger.info(f"loaded {len(mask_images)} mask images for inpainting") - else: - mask_images = None - - # promptがないとき、画像のPngInfoから取得する - if init_images is not None and len(prompt_list) == 0 and not args.interactive: - logger.info("get prompts from images' metadata") - for img in init_images: - if "prompt" in img.text: - prompt = img.text["prompt"] - if "negative-prompt" in img.text: - prompt += " --n " + img.text["negative-prompt"] - prompt_list.append(prompt) - - # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) - l = [] - for im in init_images: - l.extend([im] * args.images_per_prompt) - init_images = l - - if mask_images is not None: - l = [] - for im in mask_images: - l.extend([im] * args.images_per_prompt) - mask_images = l - - # 画像サイズにオプション指定があるときはリサイズする - if args.W is not None and args.H is not None: - # highres fix を考慮に入れる - w, h = args.W, args.H - if highres_fix: - w = int(w * args.highres_fix_scale + 0.5) - h = int(h * args.highres_fix_scale + 0.5) - - if init_images is not None: - logger.info(f"resize img2img source images to {w}*{h}") - init_images = resize_images(init_images, (w, h)) - if mask_images is not None: - logger.info(f"resize img2img mask images to {w}*{h}") - mask_images = resize_images(mask_images, (w, h)) - - regional_network = False - if networks and mask_images: - # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 - regional_network = True - logger.info("use mask as region") - - size = None - for i, network in enumerate(networks): - if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: - np_mask = np.array(mask_images[0]) - - if args.network_regional_mask_max_color_codes: - # カラーコードでマスクを指定する - ch0 = (i + 1) & 1 - ch1 = ((i + 1) >> 1) & 1 - ch2 = ((i + 1) >> 2) & 1 - np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) - np_mask = np_mask.astype(np.uint8) * 255 - else: - np_mask = np_mask[:, :, i] - size = np_mask.shape - else: - np_mask = np.full(size, 255, dtype=np.uint8) - mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) - network.set_region(i, i == len(networks) - 1, mask) - mask_images = None - - prev_image = None # for VGG16 guided - if args.guide_image_path is not None: - logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") - guide_images = [] - for p in args.guide_image_path: - guide_images.extend(load_images(p)) - - logger.info(f"loaded {len(guide_images)} guide images for guidance") - if len(guide_images) == 0: - logger.warning( - f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" - ) - guide_images = None - else: - guide_images = None - - # seed指定時はseedを決めておく - if args.seed is not None: - # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう - random.seed(args.seed) - predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] - if len(predefined_seeds) == 1: - predefined_seeds[0] = args.seed - else: - predefined_seeds = None - - # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) - if args.W is None: - args.W = 1024 - if args.H is None: - args.H = 1024 - - # 画像生成のループ - os.makedirs(args.outdir, exist_ok=True) - max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples - - for gen_iter in range(args.n_iter): - logger.info(f"iteration {gen_iter+1}/{args.n_iter}") - iter_seed = random.randint(0, 0x7FFFFFFF) - - # バッチ処理の関数 - def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): - batch_size = len(batch) - - # highres_fixの処理 - if highres_fix and not highres_1st: - # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す - is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - - logger.info("process 1st stage") - batch_1st = [] - for _, base, ext in batch: - - def scale_and_round(x): - if x is None: - return None - return int(x * args.highres_fix_scale + 0.5) - - width_1st = scale_and_round(ext.width) - height_1st = scale_and_round(ext.height) - width_1st = width_1st - width_1st % 32 - height_1st = height_1st - height_1st % 32 - - original_width_1st = scale_and_round(ext.original_width) - original_height_1st = scale_and_round(ext.original_height) - original_width_negative_1st = scale_and_round(ext.original_width_negative) - original_height_negative_1st = scale_and_round(ext.original_height_negative) - crop_left_1st = scale_and_round(ext.crop_left) - crop_top_1st = scale_and_round(ext.crop_top) - - strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength - - ext_1st = BatchDataExt( - width_1st, - height_1st, - original_width_1st, - original_height_1st, - original_width_negative_1st, - original_height_negative_1st, - crop_left_1st, - crop_top_1st, - args.highres_fix_steps, - ext.scale, - ext.negative_scale, - strength_1st, - ext.network_muls, - ext.num_sub_prompts, - ) - batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) - - pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする - images_1st = process_batch(batch_1st, True, True) - - # 2nd stageのバッチを作成して以下処理する - logger.info("process 2nd stage") - width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height - - if upscaler: - # upscalerを使って画像を拡大する - lowreso_imgs = None if is_1st_latent else images_1st - lowreso_latents = None if not is_1st_latent else images_1st - - # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents - batch_size = len(images_1st) - vae_batch_size = ( - batch_size - if args.vae_batch_size is None - else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) - ) - vae_batch_size = int(vae_batch_size) - images_1st = upscaler.upscale( - vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size - ) - - elif args.highres_fix_latents_upscaling: - # latentを拡大する - org_dtype = images_1st.dtype - if images_1st.dtype == torch.bfloat16: - images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない - images_1st = torch.nn.functional.interpolate( - images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" - ) # , antialias=True) - images_1st = images_1st.to(org_dtype) - - else: - # 画像をLANCZOSで拡大する - images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] - - batch_2nd = [] - for i, (bd, image) in enumerate(zip(batch, images_1st)): - bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) - batch_2nd.append(bd_2nd) - batch = batch_2nd - - if args.highres_fix_disable_control_net: - pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする - - # このバッチの情報を取り出す - ( - return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image, _), - ( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - network_muls, - num_sub_prompts, - ), - ) = batch[0] - noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) - - prompts = [] - negative_prompts = [] - raw_prompts = [] - start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - noises = [ - torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - for _ in range(steps * scheduler_num_noises_per_step) - ] - seeds = [] - clip_prompts = [] - - if init_image is not None: # img2img? - i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - init_images = [] - - if mask_image is not None: - mask_images = [] - else: - mask_images = None - else: - i2i_noises = None - init_images = None - mask_images = None - - if guide_image is not None: # CLIP image guided? - guide_images = [] - else: - guide_images = None - - # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする - all_images_are_same = True - all_masks_are_same = True - all_guide_images_are_same = True - for i, ( - _, - (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), - _, - ) in enumerate(batch): - prompts.append(prompt) - negative_prompts.append(negative_prompt) - seeds.append(seed) - clip_prompts.append(clip_prompt) - raw_prompts.append(raw_prompt) - - if init_image is not None: - init_images.append(init_image) - if i > 0 and all_images_are_same: - all_images_are_same = init_images[-2] is init_image - - if mask_image is not None: - mask_images.append(mask_image) - if i > 0 and all_masks_are_same: - all_masks_are_same = mask_images[-2] is mask_image - - if guide_image is not None: - if type(guide_image) is list: - guide_images.extend(guide_image) - all_guide_images_are_same = False - else: - guide_images.append(guide_image) - if i > 0 and all_guide_images_are_same: - all_guide_images_are_same = guide_images[-2] is guide_image - - # make start code - torch.manual_seed(seed) - start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - # make each noises - for j in range(steps * scheduler_num_noises_per_step): - noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) - - if i2i_noises is not None: # img2img noise - i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - noise_manager.reset_sampler_noises(noises) - - # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する - if init_images is not None and all_images_are_same: - init_images = init_images[0] - if mask_images is not None and all_masks_are_same: - mask_images = mask_images[0] - if guide_images is not None and all_guide_images_are_same: - guide_images = guide_images[0] - - # ControlNet使用時はguide imageをリサイズする - if control_nets: - # TODO resampleのメソッド - guide_images = guide_images if type(guide_images) == list else [guide_images] - guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] - if len(guide_images) == 1: - guide_images = guide_images[0] - - # generate - if networks: - # 追加ネットワークの処理 - shared = {} - for n, m in zip(networks, network_muls if network_muls else network_default_muls): - n.set_multiplier(m) - if regional_network: - n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) - - if not regional_network and network_pre_calc: - for n in networks: - n.restore_weights() - for n in networks: - n.pre_calculation() - logger.info("pre-calculation... done") - - images = pipe( - prompts, - negative_prompts, - init_images, - mask_images, - height, - width, - original_height, - original_width, - original_height_negative, - original_width_negative, - crop_top, - crop_left, - steps, - scale, - negative_scale, - strength, - latents=start_code, - output_type="pil", - max_embeddings_multiples=max_embeddings_multiples, - img2img_noise=i2i_noises, - vae_batch_size=args.vae_batch_size, - return_latents=return_latents, - clip_prompts=clip_prompts, - clip_guide_images=guide_images, - ) - if highres_1st and not args.highres_fix_save_1st: # return images or latents - return images - - # save image - highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) - ): - if highres_fix: - seed -= 1 # record original seed - metadata = PngInfo() - metadata.add_text("prompt", prompt) - metadata.add_text("seed", str(seed)) - metadata.add_text("sampler", args.sampler) - metadata.add_text("steps", str(steps)) - metadata.add_text("scale", str(scale)) - if negative_prompt is not None: - metadata.add_text("negative-prompt", negative_prompt) - if negative_scale is not None: - metadata.add_text("negative-scale", str(negative_scale)) - if clip_prompt is not None: - metadata.add_text("clip-prompt", clip_prompt) - if raw_prompt is not None: - metadata.add_text("raw-prompt", raw_prompt) - metadata.add_text("original-height", str(original_height)) - metadata.add_text("original-width", str(original_width)) - metadata.add_text("original-height-negative", str(original_height_negative)) - metadata.add_text("original-width-negative", str(original_width_negative)) - metadata.add_text("crop-top", str(crop_top)) - metadata.add_text("crop-left", str(crop_left)) - - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" - else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) - - if not args.no_preview and not highres_1st and args.interactive: - try: - import cv2 - - for prompt, image in zip(prompts, images): - cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ - cv2.waitKey() - cv2.destroyAllWindows() - except ImportError: - logger.error( - "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" - ) - - return images - - # 画像生成のプロンプトが一周するまでのループ - prompt_index = 0 - global_step = 0 - batch_data = [] - logger.info(f"args.interactive: {args.interactive}, (prompt_index: {prompt_index} < len(prompt_list): {len(prompt_list)} and not distributed_state.is_main_process: {not distributed_state.is_main_process})") - while args.interactive or (prompt_index < len(prompt_list) and (not distributed_state.is_main_process or distributed_state.num_processes == 1)): - if len(prompt_list) == 0: - # interactive - valid = False - while not valid: - logger.info("") - logger.info("Type prompt:") - try: - raw_prompt = input() - except EOFError: - break - - valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 - if not valid: # EOF, end app - break - else: - raw_prompt = prompt_list[prompt_index] - - # sd-dynamic-prompts like variants: - # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) - raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - - # repeat prompt - for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - if pi == 0 or len(raw_prompts) > 1: - - # parse prompt: if prompt is not changed, skip parsing - width = args.W - height = args.H - original_width = args.original_width - original_height = args.original_height - original_width_negative = args.original_width_negative - original_height_negative = args.original_height_negative - crop_top = args.crop_top - crop_left = args.crop_left - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seed = None - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - - # Deep Shrink - ds_depth_1 = None # means no override - ds_timesteps_1 = args.ds_timesteps_1 - ds_depth_2 = args.ds_depth_2 - ds_timesteps_2 = args.ds_timesteps_2 - ds_ratio = args.ds_ratio - - # Gradual Latent - gl_timesteps = None # means no override - gl_ratio = args.gradual_latent_ratio - gl_every_n_steps = args.gradual_latent_every_n_steps - gl_ratio_step = args.gradual_latent_ratio_step - gl_s_noise = args.gradual_latent_s_noise - gl_unsharp_params = args.gradual_latent_unsharp_params - - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - - for parg in prompt_args[1:]: - - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - logger.info(f"width: {width}") - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - logger.info(f"height: {height}") - continue - - m = re.match(r"ow (\d+)", parg, re.IGNORECASE) - if m: - original_width = int(m.group(1)) - logger.info(f"original width: {original_width}") - continue - - m = re.match(r"oh (\d+)", parg, re.IGNORECASE) - if m: - original_height = int(m.group(1)) - logger.info(f"original height: {original_height}") - continue - - m = re.match(r"nw (\d+)", parg, re.IGNORECASE) - if m: - original_width_negative = int(m.group(1)) - logger.info(f"original width negative: {original_width_negative}") - continue - - m = re.match(r"nh (\d+)", parg, re.IGNORECASE) - if m: - original_height_negative = int(m.group(1)) - logger.info(f"original height negative: {original_height_negative}") - continue - - m = re.match(r"ct (\d+)", parg, re.IGNORECASE) - if m: - crop_top = int(m.group(1)) - logger.info(f"crop top: {crop_top}") - continue - - m = re.match(r"cl (\d+)", parg, re.IGNORECASE) - if m: - crop_left = int(m.group(1)) - logger.info(f"crop left: {crop_left}") - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - logger.info(f"steps: {steps}") - continue - - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - logger.info(f"seeds: {seeds}") - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - logger.info(f"scale: {scale}") - continue - - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - logger.info(f"negative scale: {negative_scale}") - continue - - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - logger.info(f"strength: {strength}") - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - logger.info(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - logger.info(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - logger.info(f"network mul: {network_muls}") - continue - - # Deep Shrink - m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 1 - ds_depth_1 = int(m.group(1)) - logger.info(f"deep shrink depth 1: {ds_depth_1}") - continue - - m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 1 - ds_timesteps_1 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") - continue - - m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 2 - ds_depth_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink depth 2: {ds_depth_2}") - continue - - m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 2 - ds_timesteps_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") - continue - - m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink ratio - ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink ratio: {ds_ratio}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - logger.info(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - except ValueError as ex: - logger.error(f"Exception in parsing / 解析エラー: {parg}") - logger.error(f"{ex}") - - # override Deep Shrink - if ds_depth_1 is not None: - if ds_depth_1 < 0: - ds_depth_1 = args.ds_depth_1 or 3 - unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) - - # override Gradual Latent - if gl_timesteps is not None: - if gl_timesteps < 0: - gl_timesteps = args.gradual_latent_timesteps or 650 - if gl_unsharp_params is not None: - unsharp_params = gl_unsharp_params.split(",") - us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] - us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) - us_ksize = int(us_ksize) - else: - us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None - gradual_latent = GradualLatent( - gl_ratio, - gl_timesteps, - gl_every_n_steps, - gl_ratio_step, - gl_s_noise, - us_ksize, - us_sigma, - us_strength, - us_target_x, - ) - pipe.set_gradual_latent(gradual_latent) - - # prepare seed - if seeds is not None: # given in prompt - # 数が足りないなら前のをそのまま使う - if len(seeds) > 0: - seed = seeds.pop(0) - else: - if predefined_seeds is not None: - if len(predefined_seeds) > 0: - seed = predefined_seeds.pop(0) - else: - logger.error("predefined seeds are exhausted") - seed = None - elif args.iter_same_seed: - seeds = iter_seed - else: - seed = None # 前のを消す - - if seed is None: - seed = random.randint(0, 0x7FFFFFFF) - if args.interactive: - logger.info(f"seed: {seed}") - - # prepare init image, guide image and mask - init_image = mask_image = guide_image = None - - # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する - if init_images is not None: - init_image = init_images[global_step % len(init_images)] - - # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する - # 32単位に丸めたやつにresizeされるので踏襲する - if not highres_fix: - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - logger.warning( - f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" - ) - - if mask_images is not None: - mask_image = mask_images[global_step % len(mask_images)] - - if guide_images is not None: - if control_nets: # 複数件の場合あり - c = len(control_nets) - p = global_step % (len(guide_images) // c) - guide_image = guide_images[p * c : p * c + c] - else: - guide_image = guide_images[global_step % len(guide_images)] - - if regional_network: - num_sub_prompts = len(prompt.split(" AND ")) - assert ( - len(networks) <= num_sub_prompts - ), "Number of networks must be less than or equal to number of sub prompts." - else: - num_sub_prompts = None - - b1 = BatchData( - False, - BatchDataBase( - global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt - ), - BatchDataExt( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - tuple(network_muls) if network_muls else None, - num_sub_prompts, - ), - ) - batch_data.append(b1) - global_step += 1 - - prompt_index += 1 - distributed_state.wait_for_everyone() - batch_data = gather_object(batch_data) - - if len(batch_data) > 0: - data_loader = get_batches(items=batch_data, batch_size=args.batch_size) - with distributed_state.split_between_processes(data_loader) as batch_list: - for j in range(len(batch_list)): - logger.info(f"Loading batch {j+1}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") - logger.info(f"batch_list:") - for i in range(len(batch_list[j])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\nNegative Prompt: {batch_list[j][i].base.negative_prompt}") - prev_image = process_batch(batch_list[j], highres_fix)[0] - - distributed_state.wait_for_everyone() - #for i in range(len(data_loader)): - # logger.info(f"Loading Batch {i+1} of {len(data_loader)}") - # batch_data_split.append(data_loader[i]) - # if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader): - # continue - # with torch.no_grad(): - # with distributed_state.split_between_processes(batch_data_split) as batch_list: - # logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.local_process_index}:") - # logger.info(f"batch_list:") - # for i in range(len(batch_list[0])): - # logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") - # prev_image = process_batch(batch_list[0], highres_fix)[0] - # batch_data_split.clear() - # distributed_state.wait_for_everyone() - logger.info("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - #sdxl_train_util.add_sdxl_training_arguments(parser) - add_logging_arguments(parser) - - parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") - parser.add_argument( - "--from_file", - type=str, - default=None, - help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", - ) - parser.add_argument( - "--interactive", - action="store_true", - help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", - ) - parser.add_argument( - "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" - ) - parser.add_argument( - "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" - ) - parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") - parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") - parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") - parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument( - "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" - ) - parser.add_argument( - "--use_original_file_name", - action="store_true", - help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", - ) - # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) - parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") - parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") - parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") - parser.add_argument( - "--original_height", - type=int, - default=None, - help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", - ) - parser.add_argument( - "--original_width", - type=int, - default=None, - help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", - ) - parser.add_argument( - "--original_height_negative", - type=int, - default=None, - help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", - ) - parser.add_argument( - "--original_width_negative", - type=int, - default=None, - help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", - ) - parser.add_argument( - "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" - ) - parser.add_argument( - "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" - ) - parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") - parser.add_argument( - "--vae_batch_size", - type=float, - default=None, - help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", - ) - parser.add_argument( - "--vae_slices", - type=int, - default=None, - help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", - ) - parser.add_argument( - "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" - ) - parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") - parser.add_argument( - "--sampler", - type=str, - default="ddim", - choices=[ - "ddim", - "pndm", - "lms", - "euler", - "euler_a", - "heun", - "dpm_2", - "dpm_2_a", - "dpmsolver", - "dpmsolver++", - "dpmsingle", - "k_lms", - "k_euler", - "k_euler_a", - "k_dpm_2", - "k_dpm_2_a", - ], - help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", - ) - parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", - ) - parser.add_argument( - "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" - ) - parser.add_argument( - "--vae", - type=str, - default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", - ) - parser.add_argument( - "--tokenizer_cache_dir", - type=str, - default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", - ) - # parser.add_argument("--replace_clip_l14_336", action='store_true', - # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") - parser.add_argument( - "--seed", - type=int, - default=None, - help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", - ) - parser.add_argument( - "--iter_same_seed", - action="store_true", - help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", - ) - parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") - parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") - parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") - parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") - parser.add_argument( - "--diffusers_xformers", - action="store_true", - help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", - ) - parser.add_argument( - "--opt_channels_last", - action="store_true", - help="set channels last option to model / モデルにchannels lastを指定し最適化する", - ) - parser.add_argument( - "--network_module", - type=str, - default=None, - nargs="*", - help="additional network module to use / 追加ネットワークを使う時そのモジュール名", - ) - parser.add_argument( - "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" - ) - parser.add_argument( - "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" - ) - parser.add_argument( - "--network_args", - type=str, - default=None, - nargs="*", - help="additional arguments for network (key=value) / ネットワークへの追加の引数", - ) - parser.add_argument( - "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" - ) - parser.add_argument( - "--network_merge_n_models", - type=int, - default=None, - help="merge this number of networks / この数だけネットワークをマージする", - ) - parser.add_argument( - "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" - ) - parser.add_argument( - "--network_pre_calc", - action="store_true", - help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", - ) - parser.add_argument( - "--network_regional_mask_max_color_codes", - type=int, - default=None, - help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", - ) - parser.add_argument( - "--textual_inversion_embeddings", - type=str, - default=None, - nargs="*", - help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", - ) - parser.add_argument( - "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" - ) - parser.add_argument( - "--max_embeddings_multiples", - type=int, - default=None, - help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", - ) - parser.add_argument( - "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" - ) - parser.add_argument( - "--highres_fix_scale", - type=float, - default=None, - help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", - ) - parser.add_argument( - "--highres_fix_steps", - type=int, - default=28, - help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", - ) - parser.add_argument( - "--highres_fix_strength", - type=float, - default=None, - help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", - ) - parser.add_argument( - "--highres_fix_save_1st", - action="store_true", - help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", - ) - parser.add_argument( - "--highres_fix_latents_upscaling", - action="store_true", - help="use latents upscaling for highres fix / highres fixでlatentで拡大する", - ) - parser.add_argument( - "--highres_fix_upscaler", - type=str, - default=None, - help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", - ) - parser.add_argument( - "--highres_fix_upscaler_args", - type=str, - default=None, - help="additional arguments for upscaler (key=value) / upscalerへの追加の引数", - ) - parser.add_argument( - "--highres_fix_disable_control_net", - action="store_true", - help="disable ControlNet for highres fix / highres fixでControlNetを使わない", - ) - - parser.add_argument( - "--negative_scale", - type=float, - default=None, - help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", - ) - - parser.add_argument( - "--control_net_lllite_models", - type=str, - default=None, - nargs="*", - help="ControlNet models to use / 使用するControlNetのモデル名", - ) - # parser.add_argument( - # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" - # ) - # parser.add_argument( - # "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" - # ) - parser.add_argument( - "--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率" - ) - parser.add_argument( - "--control_net_ratios", - type=float, - default=None, - nargs="*", - help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", - ) - parser.add_argument( - "--clip_vision_strength", - type=float, - default=None, - help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", - ) - - # Deep Shrink - parser.add_argument( - "--ds_depth_1", - type=int, - default=None, - help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", - ) - parser.add_argument( - "--ds_timesteps_1", - type=int, - default=650, - help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", - ) - parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") - parser.add_argument( - "--ds_timesteps_2", - type=int, - default=650, - help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", - ) - parser.add_argument( - "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" - ) - - # gradual latent - parser.add_argument( - "--gradual_latent_timesteps", - type=int, - default=None, - help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", - ) - parser.add_argument( - "--gradual_latent_ratio", - type=float, - default=0.5, - help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", - ) - parser.add_argument( - "--gradual_latent_ratio_step", - type=float, - default=0.125, - help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", - ) - parser.add_argument( - "--gradual_latent_every_n_steps", - type=int, - default=3, - help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", - ) - parser.add_argument( - "--gradual_latent_s_noise", - type=float, - default=1.0, - help="s_noise for Gradual Latent / Gradual Latentのs_noise", - ) - parser.add_argument( - "--gradual_latent_unsharp_params", - type=str, - default=None, - help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" - + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", - ) - parser.add_argument("--full_fp16", action="store_true", help="Loading model in fp16") - parser.add_argument( - "--full_bf16", action="store_true", help="Loading model in bf16" - ) - parser.add_argument( - "--lowram", - action="store_true", - help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込む等(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", - ) - # # parser.add_argument( - # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" - # ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - setup_logging(args, reset=True) - main(args) From 099cea2e24bf53941d2ad66d0c8195dfbf03f2ed Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 27 Jan 2025 01:17:14 +0800 Subject: [PATCH 074/221] Rename accel_sdxl_gen_img_v2.py to accel_sdxl_gen_img.py --- accel_sdxl_gen_img_v2.py => accel_sdxl_gen_img.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename accel_sdxl_gen_img_v2.py => accel_sdxl_gen_img.py (100%) diff --git a/accel_sdxl_gen_img_v2.py b/accel_sdxl_gen_img.py similarity index 100% rename from accel_sdxl_gen_img_v2.py rename to accel_sdxl_gen_img.py From 25283008d300bf568e84e8e62d580e1741a434fb Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 27 Jan 2025 12:29:38 +0800 Subject: [PATCH 075/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index d0a3d2fd1..447eb5b58 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2782,8 +2782,7 @@ def scale_and_round(x): if seed is None: seed = random.randint(0, 0x7FFFFFFF) - if args.interactive: - logger.info(f"seed: {seed}") + logger.info(f"seed: {seed}") # prepare init image, guide image and mask init_image = mask_image = guide_image = None @@ -2859,7 +2858,7 @@ def scale_and_round(x): logger.info(f"Loading batch {j+1}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") logger.info(f"batch_list:") for i in range(len(batch_list[j])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\nNegative Prompt: {batch_list[j][i].base.negative_prompt}") + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\nNegative Prompt: {batch_list[j][i].base.negative_prompt}\nSeed: {batch_list[j][i].base.seed}") prev_image = process_batch(batch_list[j], highres_fix)[0] distributed_state.wait_for_everyone() From e96ea1b841c7e6e3be3d760ba37ecb2cbc92f9ec Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 27 Jan 2025 20:41:12 +0800 Subject: [PATCH 076/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 447eb5b58..68c9de892 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2768,6 +2768,8 @@ def scale_and_round(x): # 数が足りないなら前のをそのまま使う if len(seeds) > 0: seed = seeds.pop(0) + if len(seeds) == 1: + seeds = None else: if predefined_seeds is not None: if len(predefined_seeds) > 0: From 4c7a7c83f3c6925e506bacb4f54059a22631cf51 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:51:51 +0800 Subject: [PATCH 077/221] Setting seeds to random after exhausted Popping last seeds value does not destroy last element in list, resulting in remaining repeats of prompt to have seed set identical to last seed provided in prompt --- gen_img.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index 9427a8940..b0a04a634 100644 --- a/gen_img.py +++ b/gen_img.py @@ -2835,8 +2835,12 @@ def scale_and_round(x): # prepare seed if seeds is not None: # given in prompt # num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う - if len(seeds) > 0: + # Previous implementation may result in unexpected behaiviour when number of seeds is lesss than number of repeats. Last seed is taken for rest of repeated prompts + if len(seeds) > 1: seed = seeds.pop(0) + elif len(seeds) == 1: + seed = seeds.pop(0) + seeds = None else: if args.iter_same_seed: seed = iter_seed @@ -2847,6 +2851,7 @@ def scale_and_round(x): seed = seed_random.randint(0, 2**32 - 1) if args.interactive: logger.info(f"seed: {seed}") + # logger.info(f"seed: {seed}") #debugging logger. Uncomment to verify if expected seed is added correctly. # prepare init image, guide image and mask init_image = mask_image = guide_image = None From 7dfc5c94e2966097b025cce914427391367d50ba Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:56:47 +0800 Subject: [PATCH 078/221] Update gen_img.py fix typo --- gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index b0a04a634..d3c0d86d8 100644 --- a/gen_img.py +++ b/gen_img.py @@ -2835,7 +2835,7 @@ def scale_and_round(x): # prepare seed if seeds is not None: # given in prompt # num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う - # Previous implementation may result in unexpected behaiviour when number of seeds is lesss than number of repeats. Last seed is taken for rest of repeated prompts + # Previous implementation may result in unexpected behaviour when number of seeds is lesss than number of repeats. Last seed is taken for rest of repeated prompts if len(seeds) > 1: seed = seeds.pop(0) elif len(seeds) == 1: From b1ca9d948589dad3d1b04718e18ec98b9fdc8eb0 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Tue, 28 Jan 2025 05:33:34 +0800 Subject: [PATCH 079/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 57 +++++++++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 15 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 68c9de892..1ffd0b6f7 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2441,6 +2441,7 @@ def scale_and_round(x): prompt_index = 0 global_step = 0 batch_data = [] + extinfo = [] while args.interactive or (prompt_index < len(prompt_list) and (not distributed_state.is_main_process or distributed_state.num_processes == 1)): if len(prompt_list) == 0: # interactive @@ -2766,10 +2767,11 @@ def scale_and_round(x): # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う - if len(seeds) > 0: + if len(seeds) > 1: seed = seeds.pop(0) - if len(seeds) == 1: - seeds = None + elif len(seeds) == 1: + seed = seeds.pop(0) + seeds = None else: if predefined_seeds is not None: if len(predefined_seeds) > 0: @@ -2847,23 +2849,48 @@ def scale_and_round(x): ), ) batch_data.append(b1) + extinfo.append(b1.ext) global_step += 1 prompt_index += 1 + batch_separated_list = [] + if distributed_state.is_main_process and len(batch_data) > 0: + unique_extinfo = list(set(extinfo)) + # splits list of prompts into sublists where BatchDataExt ext is identical + for i in range(len(unique_extinfo)): + templist = [] + res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]] + for index in res: + templist.append(batch_data[index]) + split_into_batches = get_batches(items=templist, batch_size=args.batch_size) + if(len(split_into_batches) % distributed_state.num_processes != 0): + #Distributes last round of batches across available processes if last round of batches less than availble processes and there is more than one prompt in the last batch + sublist = [] + for j in range(len(split_into_batches) % distributed_state.num_processes): + if len(split_into_batches) > 1 : + sublist.extend(split_into_batches.pop(-1)) + elif len(split_into_batches) == 1 : + sublist.extend(split_into_batches.pop(-1)) + listofbatches = [] + n, m = divmod(len(sublist), device) + split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)]) + batch_separated_list.append(split_into_batches) + distributed_state.wait_for_everyone() - batch_data = gather_object(batch_data) - + batch_data = gather_object(batch_separated_list) + del extinfo + if len(batch_data) > 0: - data_loader = get_batches(items=batch_data, batch_size=args.batch_size) - with distributed_state.split_between_processes(data_loader) as batch_list: - for j in range(len(batch_list)): - logger.info(f"Loading batch {j+1}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:") - logger.info(f"batch_list:") - for i in range(len(batch_list[j])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\nNegative Prompt: {batch_list[j][i].base.negative_prompt}\nSeed: {batch_list[j][i].base.seed}") - prev_image = process_batch(batch_list[j], highres_fix)[0] - - distributed_state.wait_for_everyone() + for batch_list in batch_data: + with distributed_state.split_between_processes(batch_list) as batches: + for j in range(len(batches)): + logger.info(f"Loading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:") + logger.info(f"batch_list:") + for i in range(len(batches[j])): + logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") + prev_image = process_batch(batch_list[j], highres_fix)[0] + + distributed_state.wait_for_everyone() #for i in range(len(data_loader)): # logger.info(f"Loading Batch {i+1} of {len(data_loader)}") # batch_data_split.append(data_loader[i]) From 78a64b03fcac071afdb13b85c008c08c25bffe15 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Tue, 28 Jan 2025 05:45:25 +0800 Subject: [PATCH 080/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 1ffd0b6f7..a3ded9c6e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2415,7 +2415,7 @@ def scale_and_round(x): elif args.sequential_file_name: fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" else: - fln = f"im_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + fln = f"im_{ts_str}_{globalcount:02d}_{highres_prefix}{i:03d}_{seed}.png" logger.info(f"Saving Image: {fln}:\nPrompt: {prompt}") if negative_prompt is not None: From 8d8e7e40599f784216739481300e973404c29801 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Tue, 28 Jan 2025 05:58:32 +0800 Subject: [PATCH 081/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index a3ded9c6e..f5711cf57 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2467,7 +2467,7 @@ def scale_and_round(x): # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - if pi == 0 or len(raw_prompts) > 1: + if pi == 0 # or len(raw_prompts) > 1: # parse prompt: if prompt is not changed, skip parsing width = args.W From 009d07a988afc12b7b8d68eb053040c1fe812d77 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Tue, 28 Jan 2025 06:00:39 +0800 Subject: [PATCH 082/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index f5711cf57..122f32ffd 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2467,7 +2467,7 @@ def scale_and_round(x): # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - if pi == 0 # or len(raw_prompts) > 1: + if pi == 0: # or len(raw_prompts) > 1: # parse prompt: if prompt is not changed, skip parsing width = args.W From 54ca0420e3bd172deddc51846b1e76bae2fd51d6 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Tue, 28 Jan 2025 06:08:17 +0800 Subject: [PATCH 083/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 122f32ffd..c7487ad99 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2467,8 +2467,12 @@ def scale_and_round(x): # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - if pi == 0: # or len(raw_prompts) > 1: - + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + if pi == 0 or len(raw_prompts) > 1: + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + if pi == 0: # parse prompt: if prompt is not changed, skip parsing width = args.W height = args.H @@ -2503,10 +2507,6 @@ def scale_and_round(x): gl_s_noise = args.gradual_latent_s_noise gl_unsharp_params = args.gradual_latent_unsharp_params - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - for parg in prompt_args[1:]: try: From 41b181c86a48743907e996aaf4d2639a7e81b80a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Tue, 28 Jan 2025 06:14:53 +0800 Subject: [PATCH 084/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index c7487ad99..9c5dfc1fa 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2415,7 +2415,7 @@ def scale_and_round(x): elif args.sequential_file_name: fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" else: - fln = f"im_{ts_str}_{globalcount:02d}_{highres_prefix}{i:03d}_{seed}.png" + fln = f"im_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" logger.info(f"Saving Image: {fln}:\nPrompt: {prompt}") if negative_prompt is not None: From c7caa88cc35fb1ad101b0ded5ab50146c3f75f05 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Tue, 28 Jan 2025 06:27:26 +0800 Subject: [PATCH 085/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 9c5dfc1fa..f0d1aaccc 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2469,6 +2469,18 @@ def scale_and_round(x): raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] + for parg in prompt_args[1:]: + + try: + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + logger.info(f"negative prompt: {negative_prompt}") + break + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(f"{ex}") if pi == 0 or len(raw_prompts) > 1: logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") @@ -2591,12 +2603,6 @@ def scale_and_round(x): logger.info(f"strength: {strength}") continue - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - logger.info(f"negative prompt: {negative_prompt}") - continue - m = re.match(r"c (.+)", parg, re.IGNORECASE) if m: # clip prompt clip_prompt = m.group(1) From b1b1c19be15c7e0636175790d60ba747d55267c4 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Tue, 28 Jan 2025 13:18:11 +0800 Subject: [PATCH 086/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index f0d1aaccc..316c9d0d7 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2870,7 +2870,7 @@ def scale_and_round(x): templist.append(batch_data[index]) split_into_batches = get_batches(items=templist, batch_size=args.batch_size) if(len(split_into_batches) % distributed_state.num_processes != 0): - #Distributes last round of batches across available processes if last round of batches less than availble processes and there is more than one prompt in the last batch + #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch sublist = [] for j in range(len(split_into_batches) % distributed_state.num_processes): if len(split_into_batches) > 1 : From 39a375139a1dd0585658237d7a0ca2daeb32cf81 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Wed, 29 Jan 2025 18:46:52 +0800 Subject: [PATCH 087/221] adding example generation --- fine_tune.py | 6 ++-- library/train_util.py | 59 ++++++++++++++++++++++++++++++++- sdxl_train.py | 5 +++ sdxl_train_network.py | 4 +-- sdxl_train_textual_inversion.py | 4 +-- train_db.py | 6 ++-- train_network.py | 9 ++--- train_textual_inversion.py | 8 +++-- 8 files changed, 83 insertions(+), 18 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index c79f97d25..c59ffa14a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -404,14 +404,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) - + example_tuple = (latents, batch["captions"]) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple ) # 指定ステップごとにモデルを保存 @@ -474,7 +474,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) is_main_process = accelerator.is_main_process if is_main_process: diff --git a/library/train_util.py b/library/train_util.py index 100ef475d..0df9c1fca 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5431,6 +5431,7 @@ def sample_images_common( tokenizer, text_encoder, unet, + example_tuple=None, prompt_replacement=None, controlnet=None, ): @@ -5527,7 +5528,18 @@ def sample_images_common( if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. with torch.no_grad(): + idx = 0 for prompt_dict in prompts: + if '__caption__' in prompt_dict.get("prompt") and example_tuple: + while example_tuple[1][idx] == '': + idx = (idx + 1) % len(example_tuple[1]) + if idx == 0: + break + prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', 'example_tuple[1][idx]') + prompt_dict["height"] = example_tuple[0].shape[2] * 8 + prompt_dict["width"] = example_tuple[0].shape[3] * 8 + prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) + idx = (idx + 1) % len(example_tuple[1]) sample_image_inference( accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet ) @@ -5558,6 +5570,42 @@ def sample_images_common( torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) +def draw_text_on_image(text, max_width, text_color="black"): + from PIL import ImageDraw, ImageFont, Image + import textwrap + + font = ImageFont.load_default() + space_width = font.getbbox(' ')[2] + font_size = 20 + + def wrap_text(text, font, max_width): + words = text.split(' ') + lines = [] + current_line = "" + for word in words: + test_line = current_line + word + " " + if font.getbbox(test_line)[2] <= max_width: + current_line = test_line + else: + lines.append(current_line) + current_line = word + " " + lines.append(current_line) + return lines + + lines = wrap_text(text, font, max_width - 10) + text_height = sum([font.getbbox(line)[3] - font.getbbox(line)[1] for line in lines]) + 20 + text_image = Image.new('RGB', (max_width, text_height), 'white') + text_draw = ImageDraw.Draw(text_image) + + y_text = 10 + for line in lines: + bbox = text_draw.textbbox((0, 0), line, font=font) + height = bbox[3] - bbox[1] + text_draw.text((10, y_text), line, font=font, fill=text_color) + y_text += font_size + + return text_image + def sample_image_inference( accelerator: Accelerator, @@ -5634,7 +5682,16 @@ def sample_image_inference( torch.cuda.empty_cache() image = pipeline.latents_to_image(latents)[0] - + if "original_lantent" in prompt_dict: + original_latent = prompt_dict.get("original_lantent") + original_image = pipeline.latents_to_image(original_latent)[0] + text_image = draw_text_on_image(f"caption: {prompt}", image.width * 2) + new_image = Image.new('RGB', (original_image.width + image.width, original_image.height + text_image.height)) + new_image.paste(original_image, (0, text_image.height)) + new_image.paste(image, (original_image.width, text_image.height)) + new_image.paste(text_image, (0, 0)) + image = new_image + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list # but adding 'enum' to the filename should be enough diff --git a/sdxl_train.py b/sdxl_train.py index b533b2749..7779d2267 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -740,6 +740,7 @@ def optimizer_hook(parameter: torch.Tensor): accelerator.backward(loss) + if not (args.fused_backward_pass or args.fused_optimizer_groups): if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] @@ -757,6 +758,8 @@ def optimizer_hook(parameter: torch.Tensor): for i in range(1, len(optimizers)): lr_schedulers[i].step() + + example_tuple = (latents, batch["captions"]) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) @@ -772,6 +775,7 @@ def optimizer_hook(parameter: torch.Tensor): [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet, + example_tuple, ) # 指定ステップごとにモデルを保存 @@ -854,6 +858,7 @@ def optimizer_hook(parameter: torch.Tensor): [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet, + example_tuple, ) is_main_process = accelerator.is_main_process diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 83969bb1d..3eafc152c 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -164,8 +164,8 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None): + sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple) def setup_parser() -> argparse.ArgumentParser: diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 5df739e28..de75a0aad 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -81,9 +81,9 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement): sdxl_train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): diff --git a/train_db.py b/train_db.py index e7cf3cde3..d4f558b1f 100644 --- a/train_db.py +++ b/train_db.py @@ -388,14 +388,14 @@ def train(args): optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) - + example_tuple = (latents, batch["captions"]) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple ) # 指定ステップごとにモデルを保存 @@ -459,7 +459,7 @@ def train(args): vae, ) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) is_main_process = accelerator.is_main_process if is_main_process: diff --git a/train_network.py b/train_network.py index 7bf125dca..2554b2302 100644 --- a/train_network.py +++ b/train_network.py @@ -131,8 +131,8 @@ def all_reduce_network(self, accelerator, network): if param.grad is not None: param.grad = accelerator.reduce(param.grad, reduction="mean") - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple) def train(self, args): session_id = random.randint(0, 2**32) @@ -1022,11 +1022,12 @@ def remove_model(old_ckpt_name): keys_scaled, mean_norm, maximum_norm = None, None, None # Checks if the accelerator has performed an optimization step behind the scenes + example_tuple = (latents, batch["captions"]) if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -1082,7 +1083,7 @@ def remove_model(old_ckpt_name): if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) # end of epoch diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 37349da7d..067d44ca2 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -122,9 +122,9 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond noise_pred = unet(noisy_latents, timesteps, text_conds).sample return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement): train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement + accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement ) def save_weights(self, file, updated_embs, save_dtype, metadata): @@ -627,6 +627,7 @@ def remove_model(old_ckpt_name): index_no_updates ] + example_tuple = (latents, captions) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) @@ -642,6 +643,7 @@ def remove_model(old_ckpt_name): tokenizer_or_list, text_encoder_or_list, unet, + example_tuple, prompt_replacement, ) @@ -714,7 +716,6 @@ def remove_model(old_ckpt_name): if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images( accelerator, args, @@ -725,6 +726,7 @@ def remove_model(old_ckpt_name): tokenizer_or_list, text_encoder_or_list, unet, + example_tuple, prompt_replacement, ) From 158af8d5dc2d11963887e2abdcd7f94c7ac11199 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Wed, 29 Jan 2025 19:15:49 +0800 Subject: [PATCH 088/221] Update train_util.py --- library/train_util.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0df9c1fca..665b71434 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5512,10 +5512,20 @@ def sample_images_common( prompt_dict = line_to_prompt_dict(prompt_dict) prompts[i] = prompt_dict assert isinstance(prompt_dict, dict) + if '__caption__' in prompts[i].get("prompt") and example_tuple: + while example_tuple[1][idx] == '': + idx = (idx + 1) % len(example_tuple[1]) + if idx == 0: + break + prompts[i]["prompt"] = prompt_dict.get("prompt").replace('__caption__', 'example_tuple[1][idx]') + prompts[i]["height"] = example_tuple[0].shape[2] * 8 + prompts[i]["width"] = example_tuple[0].shape[3] * 8 + prompts[i]["original_lantent"] = example_tuple[0][idx].unsqueeze(0) + idx = (idx + 1) % len(example_tuple[1]) + prompts[i]["enum"] = i + prompts[i].pop("subset", None) + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. - # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. - prompt_dict["enum"] = i - prompt_dict.pop("subset", None) # save random state to restore later rng_state = torch.get_rng_state() @@ -5530,16 +5540,6 @@ def sample_images_common( with torch.no_grad(): idx = 0 for prompt_dict in prompts: - if '__caption__' in prompt_dict.get("prompt") and example_tuple: - while example_tuple[1][idx] == '': - idx = (idx + 1) % len(example_tuple[1]) - if idx == 0: - break - prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', 'example_tuple[1][idx]') - prompt_dict["height"] = example_tuple[0].shape[2] * 8 - prompt_dict["width"] = example_tuple[0].shape[3] * 8 - prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) - idx = (idx + 1) % len(example_tuple[1]) sample_image_inference( accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet ) From 96ab0cc3756733c57a97b26081c157b57747e8f7 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Wed, 29 Jan 2025 22:43:16 +0800 Subject: [PATCH 089/221] Modify list of seed for dynamic prompt and add checks for filename to prevent overwriting --- gen_img.py | 51 +++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/gen_img.py b/gen_img.py index d3c0d86d8..8250663ce 100644 --- a/gen_img.py +++ b/gen_img.py @@ -2508,7 +2508,7 @@ def scale_and_round(x): metadata.add_text("crop-left", str(crop_left)) if filename is not None: - fln = filename + fln = first_available_filename(args.outdir, filename) #Checks to make sure is not existing file, else returns first available sequential filename else: if args.use_original_file_name and init_images is not None: if type(init_images) is list: @@ -2586,7 +2586,8 @@ def scale_and_round(x): negative_scale = args.negative_scale steps = args.steps seed = None - seeds = None + if pi == 0: + seeds = None strength = 0.8 if args.strength is None else args.strength negative_prompt = "" clip_prompt = None @@ -2670,6 +2671,8 @@ def scale_and_round(x): m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed + if pi > 0 and len(raw_prompts) > 1: #Bypass od 2nd loop for dynamic prompts + continue seeds = [int(d) for d in m.group(1).split(",")] logger.info(f"seeds: {seeds}") continue @@ -2802,7 +2805,11 @@ def scale_and_round(x): logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(f"{ex}") - # override Deep Shrink + # override filename to add index number if more than one image per prompt + if filename is not None and (args.images_per_prompt > 1 or len(raw_prompts) > 1): + filename = filename + "_%s" % pi + + # override Deep Shrink if ds_depth_1 is not None: if ds_depth_1 < 0: ds_depth_1 = args.ds_depth_1 or 3 @@ -2835,12 +2842,16 @@ def scale_and_round(x): # prepare seed if seeds is not None: # given in prompt # num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う - # Previous implementation may result in unexpected behaviour when number of seeds is lesss than number of repeats. Last seed is taken for rest of repeated prompts + # Previous implementation may result in unexpected behaviour when number of seeds is lesss than number of repeats. Last seed is taken for rest of repeated prompts. Add condition if last element is -1, to start randomizing seed. if len(seeds) > 1: seed = seeds.pop(0) elif len(seeds) == 1: - seed = seeds.pop(0) - seeds = None + if seeds[0] == -1: + seeds = None + else: + seed = seeds.pop(0) + + else: if args.iter_same_seed: seed = iter_seed @@ -2940,7 +2951,35 @@ def scale_and_round(x): logger.info("done!") +def first_available_filename(path, filename): + """ + Checks if filename is in use. + if filename is in use, appends a running number + e.g. filename = 'file.png': + + file.png + file_1.png + file_2.png + Runs in log(n) time where n is the number of existing files in sequence + """ + i = 1 + if not os.path.exists(os.path.join(path, filename)): + return filename + fileext = os.path.splitext(filename) + filename = fileext[0] + "_%s" + fileext[1] + # First do an exponential search + while os.path.exists(os.path.join(path,filename % i)): + i = i * 2 + + # Result lies somewhere in the interval (i/2..i] + # We call this interval (a..b] and narrow it down until a + 1 = b + a, b = (i // 2, i) + while a + 1 < b: + c = (a + b) // 2 # interval midpoint + a, b = (c, b) if os.path.exists(os.path.join(path,filename % c)) else (a, c) + + return filename % b def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() From 9ae34ac52763493ba80905ad628f85278c83ac9b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Wed, 29 Jan 2025 23:21:02 +0800 Subject: [PATCH 090/221] Update gen_img.py --- gen_img.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/gen_img.py b/gen_img.py index 8250663ce..ee106d14d 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1492,6 +1492,8 @@ def main(args): else: dtype = torch.float32 + device = get_preferred_device() + highres_fix = args.highres_fix_scale is not None # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" @@ -1521,9 +1523,10 @@ def main(args): if is_sdxl: if args.clip_skip is None: args.clip_skip = 2 - + + model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device, model_dtype ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) text_encoders = [text_encoder1, text_encoder2] @@ -3387,6 +3390,10 @@ def setup_parser() -> argparse.ArgumentParser: help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", ) + parser.add_argument("--full_fp16", action="store_true", help="Loading model in fp16") + parser.add_argument( + "--full_bf16", action="store_true", help="Loading model in bf16" + ) # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" From 0bb8a2c2878aec41f92b0af4a2ef892d11dbfa61 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 02:26:55 +0800 Subject: [PATCH 091/221] Test logging fix --- gen_img.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index ee106d14d..68bee2ecb 100644 --- a/gen_img.py +++ b/gen_img.py @@ -63,7 +63,7 @@ from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL from library.utils import setup_logging, add_logging_arguments -setup_logging() + import logging logger = logging.getLogger(__name__) @@ -1485,6 +1485,8 @@ def __call__(self, *args, **kwargs): def main(args): + logging.basicConfig() + setup_logging(args) if args.fp16: dtype = torch.float16 elif args.bf16: From 99bc4b6f48c840bda9d5c9a25d4dff30aca85833 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 02:38:19 +0800 Subject: [PATCH 092/221] Update train_util.py --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 665b71434..d4a0ec45b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5506,6 +5506,7 @@ def sample_images_common( os.makedirs(save_dir, exist_ok=True) # preprocess prompts + idx = 0 for i in range(len(prompts)): prompt_dict = prompts[i] if isinstance(prompt_dict, str): @@ -5538,7 +5539,6 @@ def sample_images_common( if distributed_state.num_processes <= 1: # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. with torch.no_grad(): - idx = 0 for prompt_dict in prompts: sample_image_inference( accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet From 0d263597953810ed2e654f13587bab6931486b95 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 02:41:31 +0800 Subject: [PATCH 093/221] Update gen_img.py --- gen_img.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index 68bee2ecb..1d74b8b50 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1485,7 +1485,11 @@ def __call__(self, *args, **kwargs): def main(args): - logging.basicConfig() + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) setup_logging(args) if args.fp16: dtype = torch.float16 From a602fde7aa3d367e51f50c8e1da7de23d3a23c12 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 02:46:32 +0800 Subject: [PATCH 094/221] Update utils.py --- library/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/utils.py b/library/utils.py index 49d46a546..54a8ca5a3 100644 --- a/library/utils.py +++ b/library/utils.py @@ -67,7 +67,7 @@ def setup_logging(args=None, log_level=None, reset=False): if handler is None: handler = logging.StreamHandler(sys.stdout) # same as print - handler.propagate = False + handler.propagate = True formatter = logging.Formatter( fmt="%(message)s", From 1a1496ca93673b0467823c0ba1c4230863cf44f5 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 02:51:35 +0800 Subject: [PATCH 095/221] Update gen_img.py --- gen_img.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/gen_img.py b/gen_img.py index 1d74b8b50..60d9f862a 100644 --- a/gen_img.py +++ b/gen_img.py @@ -63,7 +63,7 @@ from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL from library.utils import setup_logging, add_logging_arguments - +setup_logging() import logging logger = logging.getLogger(__name__) @@ -1485,12 +1485,7 @@ def __call__(self, *args, **kwargs): def main(args): - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - setup_logging(args) + if args.fp16: dtype = torch.float16 elif args.bf16: From 8db164b052787f3cb84e1581316381ba40744dee Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 02:58:03 +0800 Subject: [PATCH 096/221] Update gen_img.py --- gen_img.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index 60d9f862a..a6c616e9e 100644 --- a/gen_img.py +++ b/gen_img.py @@ -2811,7 +2811,8 @@ def scale_and_round(x): # override filename to add index number if more than one image per prompt if filename is not None and (args.images_per_prompt > 1 or len(raw_prompts) > 1): - filename = filename + "_%s" % pi + fileext = os.path.splitext(filename) + filename = fileext[0] + "_%s" + fileext[1] % pi # override Deep Shrink if ds_depth_1 is not None: From ec6ca91e6ac92cbdc1ce9d5b17b82dd82b27b86c Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 03:00:34 +0800 Subject: [PATCH 097/221] Update gen_img.py --- gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index a6c616e9e..119aaa4f8 100644 --- a/gen_img.py +++ b/gen_img.py @@ -2812,7 +2812,7 @@ def scale_and_round(x): # override filename to add index number if more than one image per prompt if filename is not None and (args.images_per_prompt > 1 or len(raw_prompts) > 1): fileext = os.path.splitext(filename) - filename = fileext[0] + "_%s" + fileext[1] % pi + filename = fileext[0] + "_%s" % pi + fileext[1] # override Deep Shrink if ds_depth_1 is not None: From 06f11807e5f4e0f9df2b0d223611bd4dc766f141 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 03:13:29 +0800 Subject: [PATCH 098/221] Update gen_img.py --- gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index 119aaa4f8..6bbc7d34b 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1752,7 +1752,7 @@ def __getattr__(self, item): logger.info(f"network_merge: {network_merge}") for i, network_module in enumerate(args.network_module): - logger.info("import network module: {network_module}") + logger.info(f"import network module: {network_module}") imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] From fa444fa4f0ec5973694f9af871b0050545d4b412 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 03:47:17 +0800 Subject: [PATCH 099/221] Update gen_img.py --- gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index 6bbc7d34b..f2a10b87a 100644 --- a/gen_img.py +++ b/gen_img.py @@ -2802,7 +2802,6 @@ def scale_and_round(x): m = re.match(r"f (.+)", parg, re.IGNORECASE) if m: # filename filename = m.group(1) - logger.info(f"filename: {filename}") continue except ValueError as ex: @@ -2813,6 +2812,7 @@ def scale_and_round(x): if filename is not None and (args.images_per_prompt > 1 or len(raw_prompts) > 1): fileext = os.path.splitext(filename) filename = fileext[0] + "_%s" % pi + fileext[1] + logger.info(f"filename: {filename}") # override Deep Shrink if ds_depth_1 is not None: From 6a653966a6834fa56a370c29a9a169181959fbda Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 03:49:55 +0800 Subject: [PATCH 100/221] Update gen_img.py --- gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index f2a10b87a..adddd98dc 100644 --- a/gen_img.py +++ b/gen_img.py @@ -2677,7 +2677,7 @@ def scale_and_round(x): if m: # seed if pi > 0 and len(raw_prompts) > 1: #Bypass od 2nd loop for dynamic prompts continue - seeds = [int(d) for d in m.group(1).split(",")] + seeds = [int(float(d)) for d in m.group(1).split(",")] logger.info(f"seeds: {seeds}") continue From c70ce285854863f7c6051a8a307e529da8e9e349 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 10:51:53 +0800 Subject: [PATCH 101/221] Update gen_img.py --- gen_img.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gen_img.py b/gen_img.py index adddd98dc..b429cbed7 100644 --- a/gen_img.py +++ b/gen_img.py @@ -2677,7 +2677,9 @@ def scale_and_round(x): if m: # seed if pi > 0 and len(raw_prompts) > 1: #Bypass od 2nd loop for dynamic prompts continue - seeds = [int(float(d)) for d in m.group(1).split(",")] + logger.info(f"{m}") + seeds = m.group(1).split(",") + seeds = [int(d.strip()) for d in seeds] logger.info(f"seeds: {seeds}") continue From 3decab58da8453b45a336f8d306cb335b3c5b882 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 11:00:12 +0800 Subject: [PATCH 102/221] Update train_util.py --- library/train_util.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index d4a0ec45b..8c1142972 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5518,7 +5518,7 @@ def sample_images_common( idx = (idx + 1) % len(example_tuple[1]) if idx == 0: break - prompts[i]["prompt"] = prompt_dict.get("prompt").replace('__caption__', 'example_tuple[1][idx]') + prompts[i]["prompt"] = prompt_dict.get("prompt").replace('__caption__', example_tuple[1][idx]) prompts[i]["height"] = example_tuple[0].shape[2] * 8 prompts[i]["width"] = example_tuple[0].shape[3] * 8 prompts[i]["original_lantent"] = example_tuple[0][idx].unsqueeze(0) @@ -5540,6 +5540,9 @@ def sample_images_common( # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. with torch.no_grad(): for prompt_dict in prompts: + if prompt_dict["prompt"] == '__caption__': + logger.info("No training prompts loaded, skipping '__caption__' prompt.") + continue sample_image_inference( accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet ) @@ -5553,6 +5556,9 @@ def sample_images_common( with torch.no_grad(): with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: for prompt_dict in prompt_dict_lists[0]: + if prompt_dict["prompt"] == '__caption__': + logger.info("No training prompts loaded, skipping '__caption__' prompt.") + continue sample_image_inference( accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet ) From 198aea95a6bcc07a1affa1e5e0f471193f43bc14 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 11:18:08 +0800 Subject: [PATCH 103/221] Update train_util.py --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 8c1142972..17269b5d6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5557,8 +5557,8 @@ def sample_images_common( with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: for prompt_dict in prompt_dict_lists[0]: if prompt_dict["prompt"] == '__caption__': - logger.info("No training prompts loaded, skipping '__caption__' prompt.") - continue + logger.info("No training prompts loaded, skipping '__caption__' prompt.") + continue sample_image_inference( accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet ) From c86739b7063166f35a0c64aae9eb8d2e8bd60995 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 30 Jan 2025 21:05:26 +0800 Subject: [PATCH 104/221] Update train_util.py --- library/train_util.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 17269b5d6..50afcc0f2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5513,18 +5513,23 @@ def sample_images_common( prompt_dict = line_to_prompt_dict(prompt_dict) prompts[i] = prompt_dict assert isinstance(prompt_dict, dict) - if '__caption__' in prompts[i].get("prompt") and example_tuple: + + if '__caption__' in prompt_dict.get("prompt") and example_tuple: + logger.info(f"Original prompt: {prompt_dict.get("prompt")}") while example_tuple[1][idx] == '': idx = (idx + 1) % len(example_tuple[1]) if idx == 0: break - prompts[i]["prompt"] = prompt_dict.get("prompt").replace('__caption__', example_tuple[1][idx]) - prompts[i]["height"] = example_tuple[0].shape[2] * 8 - prompts[i]["width"] = example_tuple[0].shape[3] * 8 - prompts[i]["original_lantent"] = example_tuple[0][idx].unsqueeze(0) + prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]') + prompt_dict["height"] = example_tuple[0].shape[2] * 8 + prompt_dict["width"] = example_tuple[0].shape[3] * 8 + prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) idx = (idx + 1) % len(example_tuple[1]) - prompts[i]["enum"] = i - prompts[i].pop("subset", None) + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + prompts[i] = prompt_dict + logger.info(f"Current prompt: {prompts[i].get("prompt")}") + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. From dee3c48276bd5f257372eb6aad7f6a877e340675 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 00:56:47 +0800 Subject: [PATCH 105/221] Update train_util.py --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 50afcc0f2..9dcfb26a6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5515,7 +5515,7 @@ def sample_images_common( assert isinstance(prompt_dict, dict) if '__caption__' in prompt_dict.get("prompt") and example_tuple: - logger.info(f"Original prompt: {prompt_dict.get("prompt")}") + logger.info(f"Original prompt: {prompt_dict.get('prompt')}") while example_tuple[1][idx] == '': idx = (idx + 1) % len(example_tuple[1]) if idx == 0: @@ -5528,7 +5528,7 @@ def sample_images_common( prompt_dict["enum"] = i prompt_dict.pop("subset", None) prompts[i] = prompt_dict - logger.info(f"Current prompt: {prompts[i].get("prompt")}") + logger.info(f"Current prompt: {prompts[i].get('prompt')}") # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. From 7e8daada4dc2b037871c0dc4d2bc547904354444 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 00:57:56 +0800 Subject: [PATCH 106/221] Update train_util.py --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 9dcfb26a6..99b93d929 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5520,7 +5520,7 @@ def sample_images_common( idx = (idx + 1) % len(example_tuple[1]) if idx == 0: break - prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]') + prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]}') prompt_dict["height"] = example_tuple[0].shape[2] * 8 prompt_dict["width"] = example_tuple[0].shape[3] * 8 prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) From a049c85e72bf25c2bd57f5977941f52c58005a6d Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 01:23:24 +0800 Subject: [PATCH 107/221] Update train_util.py --- library/train_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 99b93d929..811455983 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5520,7 +5520,8 @@ def sample_images_common( idx = (idx + 1) % len(example_tuple[1]) if idx == 0: break - prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]}') + prompt_dict["prompt"] = f"{example_tuple[1][idx]}" + logger.info(f"Replacement prompt: {example_tuple[1][idx]}") prompt_dict["height"] = example_tuple[0].shape[2] * 8 prompt_dict["width"] = example_tuple[0].shape[3] * 8 prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) From 12dd9f6b59f25eb7ba4b02f2c259368020b070aa Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 01:52:07 +0800 Subject: [PATCH 108/221] Update train_util.py --- library/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/train_util.py b/library/train_util.py index 811455983..b32fe9f0c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5516,6 +5516,7 @@ def sample_images_common( if '__caption__' in prompt_dict.get("prompt") and example_tuple: logger.info(f"Original prompt: {prompt_dict.get('prompt')}") + logger.info(f"Example tuple: {example_tuple}") while example_tuple[1][idx] == '': idx = (idx + 1) % len(example_tuple[1]) if idx == 0: From 5559d1b711dfef64e0b5e3064f2489544f40ba3c Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 01:56:30 +0800 Subject: [PATCH 109/221] Update train_network.py --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 2554b2302..5e5794516 100644 --- a/train_network.py +++ b/train_network.py @@ -1023,6 +1023,7 @@ def remove_model(old_ckpt_name): # Checks if the accelerator has performed an optimization step behind the scenes example_tuple = (latents, batch["captions"]) + logger.info(f"Example tuple: {example_tuple}") if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 From a130bb444b40eba94d5180cd68104ccddab377cf Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 03:10:20 +0800 Subject: [PATCH 110/221] Update train_util.py --- library/train_util.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b32fe9f0c..572166295 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5516,17 +5516,23 @@ def sample_images_common( if '__caption__' in prompt_dict.get("prompt") and example_tuple: logger.info(f"Original prompt: {prompt_dict.get('prompt')}") - logger.info(f"Example tuple: {example_tuple}") - while example_tuple[1][idx] == '': - idx = (idx + 1) % len(example_tuple[1]) - if idx == 0: - break - prompt_dict["prompt"] = f"{example_tuple[1][idx]}" - logger.info(f"Replacement prompt: {example_tuple[1][idx]}") + if len(example_tuple[1] > 1: + while example_tuple[1][idx] == '': + idx = (idx + 1) % len(example_tuple[1]) + if idx == 0: + break + prompt_dict["prompt"] = f"{example_tuple[1][idx]}" + logger.info(f"Replacement prompt: {example_tuple[1][idx]}") + prompt_dict["height"] = example_tuple[0].shape[2] * 8 + prompt_dict["width"] = example_tuple[0].shape[3] * 8 + prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) + idx = (idx + 1) % len(example_tuple[1]) + else: + prompt_dict["prompt"] = f"{example_tuple[1]}" + logger.info(f"Replacement prompt: {example_tuple[1]}") prompt_dict["height"] = example_tuple[0].shape[2] * 8 prompt_dict["width"] = example_tuple[0].shape[3] * 8 - prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) - idx = (idx + 1) % len(example_tuple[1]) + prompt_dict["original_lantent"] = example_tuple[0].unsqueeze(0) prompt_dict["enum"] = i prompt_dict.pop("subset", None) prompts[i] = prompt_dict From f73db401fa8f01ac58eb29757d4b0b91915ac5d6 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 03:12:06 +0800 Subject: [PATCH 111/221] Update train_network.py --- train_network.py | 1 - 1 file changed, 1 deletion(-) diff --git a/train_network.py b/train_network.py index 5e5794516..2554b2302 100644 --- a/train_network.py +++ b/train_network.py @@ -1023,7 +1023,6 @@ def remove_model(old_ckpt_name): # Checks if the accelerator has performed an optimization step behind the scenes example_tuple = (latents, batch["captions"]) - logger.info(f"Example tuple: {example_tuple}") if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 From 7f0ae28294b8cb94d51a789ebd39ea3cc3539834 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 03:20:18 +0800 Subject: [PATCH 112/221] Update train_util.py --- library/train_util.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 572166295..4c7b0b411 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5507,6 +5507,10 @@ def sample_images_common( # preprocess prompts idx = 0 + if example_tuple: + logger.info(f"len(example_tuple): {len(example_tuple)}") + logger.info(f"len(example_tuple[0]): {len(example_tuple[0])}") + logger.info(f"len(example_tuple[1]): {len(example_tuple[1])}") for i in range(len(prompts)): prompt_dict = prompts[i] if isinstance(prompt_dict, str): From 307099c4875037d8e38fbce6c9bcfae53affc0bf Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 03:21:37 +0800 Subject: [PATCH 113/221] Update train_util.py --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 4c7b0b411..83ae5122a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5520,7 +5520,7 @@ def sample_images_common( if '__caption__' in prompt_dict.get("prompt") and example_tuple: logger.info(f"Original prompt: {prompt_dict.get('prompt')}") - if len(example_tuple[1] > 1: + if len(example_tuple[1]) > 1: while example_tuple[1][idx] == '': idx = (idx + 1) % len(example_tuple[1]) if idx == 0: From a8ea5a6ac4b680b14a0bcffc6eaf32ec3f9c6442 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 03:25:34 +0800 Subject: [PATCH 114/221] Update train_util.py --- library/train_util.py | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 83ae5122a..4d644f9e7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5507,10 +5507,6 @@ def sample_images_common( # preprocess prompts idx = 0 - if example_tuple: - logger.info(f"len(example_tuple): {len(example_tuple)}") - logger.info(f"len(example_tuple[0]): {len(example_tuple[0])}") - logger.info(f"len(example_tuple[1]): {len(example_tuple[1])}") for i in range(len(prompts)): prompt_dict = prompts[i] if isinstance(prompt_dict, str): @@ -5520,23 +5516,18 @@ def sample_images_common( if '__caption__' in prompt_dict.get("prompt") and example_tuple: logger.info(f"Original prompt: {prompt_dict.get('prompt')}") - if len(example_tuple[1]) > 1: - while example_tuple[1][idx] == '': - idx = (idx + 1) % len(example_tuple[1]) - if idx == 0: - break - prompt_dict["prompt"] = f"{example_tuple[1][idx]}" - logger.info(f"Replacement prompt: {example_tuple[1][idx]}") - prompt_dict["height"] = example_tuple[0].shape[2] * 8 - prompt_dict["width"] = example_tuple[0].shape[3] * 8 - prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) - idx = (idx + 1) % len(example_tuple[1]) - else: - prompt_dict["prompt"] = f"{example_tuple[1]}" - logger.info(f"Replacement prompt: {example_tuple[1]}") - prompt_dict["height"] = example_tuple[0].shape[2] * 8 - prompt_dict["width"] = example_tuple[0].shape[3] * 8 - prompt_dict["original_lantent"] = example_tuple[0].unsqueeze(0) + + while example_tuple[1][idx] == '': + idx = (idx + 1) % len(example_tuple[1]) + if idx == 0: + break + prompt_dict["prompt"] = f"{example_tuple[1][idx]}" + logger.info(f"Replacement prompt: {example_tuple[1][idx]}") + prompt_dict["height"] = example_tuple[0].shape[2] * 8 + prompt_dict["width"] = example_tuple[0].shape[3] * 8 + prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) + idx = (idx + 1) % len(example_tuple[1]) + prompt_dict["enum"] = i prompt_dict.pop("subset", None) prompts[i] = prompt_dict From 793616fc61c7de9e22b2257260b3127aca231d9a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 14:15:32 +0800 Subject: [PATCH 115/221] Update train_util.py --- library/train_util.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 4d644f9e7..5fcb3b938 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5697,6 +5697,10 @@ def sample_image_inference( image = pipeline.latents_to_image(latents)[0] if "original_lantent" in prompt_dict: + if torch.cuda.is_available(): + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + original_latent = prompt_dict.get("original_lantent") original_image = pipeline.latents_to_image(original_latent)[0] text_image = draw_text_on_image(f"caption: {prompt}", image.width * 2) From e723e457ad65c2157aa75e48f9b4a929cce05c1d Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 14:18:45 +0800 Subject: [PATCH 116/221] Update train_util.py --- library/train_util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 5fcb3b938..c8021fec0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5524,7 +5524,9 @@ def sample_images_common( prompt_dict["prompt"] = f"{example_tuple[1][idx]}" logger.info(f"Replacement prompt: {example_tuple[1][idx]}") prompt_dict["height"] = example_tuple[0].shape[2] * 8 + logger.info(f"Original Image Height: {prompt_dict["height"]}") prompt_dict["width"] = example_tuple[0].shape[3] * 8 + logger.info(f"Original Image Width: {prompt_dict["width"]}") prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) idx = (idx + 1) % len(example_tuple[1]) From 6de0051eb24a04db40f89a8ff839a7364c0686c2 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 15:07:18 +0800 Subject: [PATCH 117/221] Update train_network.py --- train_network.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 2554b2302..5f2a31da5 100644 --- a/train_network.py +++ b/train_network.py @@ -1022,7 +1022,16 @@ def remove_model(old_ckpt_name): keys_scaled, mean_norm, maximum_norm = None, None, None # Checks if the accelerator has performed an optimization step behind the scenes - example_tuple = (latents, batch["captions"]) + # Collecting latents and caption lists from all processes + all_lists_of_latents = gather_object(latents) + all_lists_of_captions = gather_object(batch["captions"]) + all_latents = [] + all_captions = [] + for ilatents in all_lists_of_latents: + all_latents.extend(ilatents) + for icaptions in all_lists_of_captions: + all_captions.extend(icaptions) + example_tuple = (all_latents, all_captions) if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 From a979ea5a50adff57d149c16b0ea19f39a035ba42 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 16:37:47 +0800 Subject: [PATCH 118/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 316c9d0d7..78039b2a9 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2881,7 +2881,13 @@ def scale_and_round(x): n, m = divmod(len(sublist), device) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)]) batch_separated_list.append(split_into_batches) - + if distributed_state.num_processes > 1: + templist = [] + for i in range(distributed_state.num_processes): + templist.append(batch_separated_list[i :: distributed_state.num_processes]) + batch_separated_list = [] + for sub_batch_list in templist: + batch_separated_list.extend(sub_batch_list) distributed_state.wait_for_everyone() batch_data = gather_object(batch_separated_list) del extinfo From 3bb1d9d38984d363ef8c00bafa9570beea3121d0 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 20:02:57 +0800 Subject: [PATCH 119/221] Update train_util.py --- library/train_util.py | 135 +++++++++++++++++++++++------------------- 1 file changed, 73 insertions(+), 62 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c8021fec0..8b6638812 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5472,18 +5472,23 @@ def sample_images_common( text_encoder = accelerator.unwrap_model(text_encoder) # read prompts - if args.sample_prompts.endswith(".txt"): - with open(args.sample_prompts, "r", encoding="utf-8") as f: - lines = f.readlines() - prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] - elif args.sample_prompts.endswith(".toml"): - with open(args.sample_prompts, "r", encoding="utf-8") as f: - data = toml.load(f) - prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] - elif args.sample_prompts.endswith(".json"): - with open(args.sample_prompts, "r", encoding="utf-8") as f: - prompts = json.load(f) - + + if distributed_state.is_main_process: + # Load prompts into prompts list on main process only + if args.sample_prompts.endswith(".txt"): + with open(args.sample_prompts, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif args.sample_prompts.endswith(".toml"): + with open(args.sample_prompts, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif args.sample_prompts.endswith(".json"): + with open(args.sample_prompts, "r", encoding="utf-8") as f: + prompts = json.load(f) + else: + prompts = [] # Init empty prompts list for sub processes. + # schedulers: dict = {} cannot find where this is used default_scheduler = get_my_scheduler( sample_sampler=args.sample_sampler, @@ -5502,38 +5507,41 @@ def sample_images_common( clip_skip=args.clip_skip, ) pipeline.to(distributed_state.device) - save_dir = args.output_dir + "/sample" - os.makedirs(save_dir, exist_ok=True) + # preprocess prompts - idx = 0 - for i in range(len(prompts)): - prompt_dict = prompts[i] - if isinstance(prompt_dict, str): - prompt_dict = line_to_prompt_dict(prompt_dict) - prompts[i] = prompt_dict - assert isinstance(prompt_dict, dict) - - if '__caption__' in prompt_dict.get("prompt") and example_tuple: - logger.info(f"Original prompt: {prompt_dict.get('prompt')}") - - while example_tuple[1][idx] == '': + if distributed_state.is_main_process: + #Create output folder and preprocess prompts on main process only. + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + idx = 0 + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + if '__caption__' in prompt_dict.get("prompt") and example_tuple: + logger.info(f"Original prompt: {prompt_dict.get('prompt')}") + + while example_tuple[1][idx] == '': + idx = (idx + 1) % len(example_tuple[1]) + if idx == 0: + break + prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]}') + logger.info(f"Replacement prompt: {prompt_dict["prompt"]}") + prompt_dict["height"] = example_tuple[0].shape[2] * 8 + logger.info(f"Original Image Height: {prompt_dict["height"]}") + prompt_dict["width"] = example_tuple[0].shape[3] * 8 + logger.info(f"Original Image Width: {prompt_dict["width"]}") + prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) idx = (idx + 1) % len(example_tuple[1]) - if idx == 0: - break - prompt_dict["prompt"] = f"{example_tuple[1][idx]}" - logger.info(f"Replacement prompt: {example_tuple[1][idx]}") - prompt_dict["height"] = example_tuple[0].shape[2] * 8 - logger.info(f"Original Image Height: {prompt_dict["height"]}") - prompt_dict["width"] = example_tuple[0].shape[3] * 8 - logger.info(f"Original Image Width: {prompt_dict["width"]}") - prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) - idx = (idx + 1) % len(example_tuple[1]) - - prompt_dict["enum"] = i - prompt_dict.pop("subset", None) - prompts[i] = prompt_dict - logger.info(f"Current prompt: {prompts[i].get('prompt')}") + + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + prompts[i] = prompt_dict + logger.info(f"Current prompt: {prompts[i].get('prompt')}") # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. @@ -5546,32 +5554,34 @@ def sample_images_common( except Exception: pass - if distributed_state.num_processes <= 1: - # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. - with torch.no_grad(): - for prompt_dict in prompts: - if prompt_dict["prompt"] == '__caption__': + if distributed_state.num_processes > 1 and distributed_state.is_main_process: + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + prompts = [] + for prompt in per_process_prompts: + prompts.extend(prompt) + distributed_state.wait_for_everyone() + per_process_prompts = gather_object(prompts) + prompts = [] + for prompt in per_process_prompts: + prompts.extend(prompt) + + + + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + + + with torch.no_grad(): + with distributed_state.split_between_processes(prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists: + if '__caption__' in prompt_dict.get("prompt"): logger.info("No training prompts loaded, skipping '__caption__' prompt.") continue sample_image_inference( accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet ) - else: - # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) - # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. - per_process_prompts = [] # list of lists - for i in range(distributed_state.num_processes): - per_process_prompts.append(prompts[i :: distributed_state.num_processes]) - - with torch.no_grad(): - with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: - for prompt_dict in prompt_dict_lists[0]: - if prompt_dict["prompt"] == '__caption__': - logger.info("No training prompts loaded, skipping '__caption__' prompt.") - continue - sample_image_inference( - accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet - ) # clear pipeline and cache to reduce vram usage del pipeline @@ -5699,6 +5709,7 @@ def sample_image_inference( image = pipeline.latents_to_image(latents)[0] if "original_lantent" in prompt_dict: + #Prevent out of VRAM error if torch.cuda.is_available(): with torch.cuda.device(torch.cuda.current_device()): torch.cuda.empty_cache() From 216596719b8dca6b861043dc0709550f9da44525 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 21:37:57 +0800 Subject: [PATCH 120/221] Update train_util.py --- library/train_util.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 8b6638812..1824a6a55 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5556,23 +5556,18 @@ def sample_images_common( if distributed_state.num_processes > 1 and distributed_state.is_main_process: per_process_prompts = [] # list of lists + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. for i in range(distributed_state.num_processes): per_process_prompts.append(prompts[i :: distributed_state.num_processes]) prompts = [] + # Flattening prompts for simplicity for prompt in per_process_prompts: prompts.extend(prompt) distributed_state.wait_for_everyone() per_process_prompts = gather_object(prompts) prompts = [] - for prompt in per_process_prompts: - prompts.extend(prompt) - - - - # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) - # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. - - + with torch.no_grad(): with distributed_state.split_between_processes(prompts) as prompt_dict_lists: for prompt_dict in prompt_dict_lists: From 2cdaa33147e75e7f7dd662f060eed37690321ae6 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 21:47:15 +0800 Subject: [PATCH 121/221] Update train_network.py --- train_network.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/train_network.py b/train_network.py index 5f2a31da5..7563424bc 100644 --- a/train_network.py +++ b/train_network.py @@ -1023,14 +1023,8 @@ def remove_model(old_ckpt_name): # Checks if the accelerator has performed an optimization step behind the scenes # Collecting latents and caption lists from all processes - all_lists_of_latents = gather_object(latents) - all_lists_of_captions = gather_object(batch["captions"]) - all_latents = [] - all_captions = [] - for ilatents in all_lists_of_latents: - all_latents.extend(ilatents) - for icaptions in all_lists_of_captions: - all_captions.extend(icaptions) + all_latents = gather_object(latents) + all_captions = gather_object(batch["captions"]) example_tuple = (all_latents, all_captions) if accelerator.sync_gradients: progress_bar.update(1) From d179a48696f9ac79d97f34b1068f7e8ccc80712a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 01:28:21 +0800 Subject: [PATCH 122/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 78039b2a9..79453cfd4 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1846,6 +1846,8 @@ def __getattr__(self, item): pipe.set_control_nets(control_nets) logger.info(f"pipeline on {device} is ready.") distributed_state.wait_for_everyone() + pipes = gather_objects([pipe]) + unets = gather_objects([unet]) if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() @@ -2498,7 +2500,8 @@ def scale_and_round(x): negative_scale = args.negative_scale steps = args.steps seed = None - seeds = None + if pi == 0: + seeds = None strength = 0.8 if args.strength is None else args.strength negative_prompt = "" clip_prompt = None @@ -2578,7 +2581,11 @@ def scale_and_round(x): m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] + if pi > 0 and len(raw_prompts) > 1: #Bypass on 2nd loop for dynamic prompts + continue + logger.info(f"{m}") + seeds = m.group(1).split(",") + seeds = [int(float(d.strip())) for d in seeds] logger.info(f"seeds: {seeds}") continue @@ -2744,7 +2751,8 @@ def scale_and_round(x): if ds_depth_1 is not None: if ds_depth_1 < 0: ds_depth_1 = args.ds_depth_1 or 3 - unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + for unet in unets: + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) # override Gradual Latent if gl_timesteps is not None: @@ -2768,7 +2776,8 @@ def scale_and_round(x): us_strength, us_target_x, ) - pipe.set_gradual_latent(gradual_latent) + for pipe in pipes: + pipe.set_gradual_latent(gradual_latent) # prepare seed if seeds is not None: # given in prompt @@ -2776,8 +2785,12 @@ def scale_and_round(x): if len(seeds) > 1: seed = seeds.pop(0) elif len(seeds) == 1: - seed = seeds.pop(0) - seeds = None + if seeds[0] == -1: + seeds = None + else: + seed = seeds.pop(0) + + else: if predefined_seeds is not None: if len(predefined_seeds) > 0: @@ -2883,11 +2896,11 @@ def scale_and_round(x): batch_separated_list.append(split_into_batches) if distributed_state.num_processes > 1: templist = [] - for i in range(distributed_state.num_processes): - templist.append(batch_separated_list[i :: distributed_state.num_processes]) + for i in range(distributed_state.num_processes): + templist.append(batch_separated_list[i :: distributed_state.num_processes]) batch_separated_list = [] - for sub_batch_list in templist: - batch_separated_list.extend(sub_batch_list) + for sub_batch_list in templist: + batch_separated_list.extend(sub_batch_list) distributed_state.wait_for_everyone() batch_data = gather_object(batch_separated_list) del extinfo From 00234a0db75d3cdb025750ea071b36985b501ab8 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 01:30:33 +0800 Subject: [PATCH 123/221] Update train_util.py --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 1824a6a55..0afd6097d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5530,7 +5530,7 @@ def sample_images_common( if idx == 0: break prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]}') - logger.info(f"Replacement prompt: {prompt_dict["prompt"]}") + logger.info(f"Replacement prompt: {prompt_dict.get('prompt')}") prompt_dict["height"] = example_tuple[0].shape[2] * 8 logger.info(f"Original Image Height: {prompt_dict["height"]}") prompt_dict["width"] = example_tuple[0].shape[3] * 8 From f7a77b68a18eda6d0c28d9620fcdf419aad8d7af Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 01:33:04 +0800 Subject: [PATCH 124/221] Update train_util.py --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0afd6097d..4b3980c7a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5532,9 +5532,9 @@ def sample_images_common( prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]}') logger.info(f"Replacement prompt: {prompt_dict.get('prompt')}") prompt_dict["height"] = example_tuple[0].shape[2] * 8 - logger.info(f"Original Image Height: {prompt_dict["height"]}") + logger.info(f"Original Image Height: {prompt_dict['height']}") prompt_dict["width"] = example_tuple[0].shape[3] * 8 - logger.info(f"Original Image Width: {prompt_dict["width"]}") + logger.info(f"Original Image Width: {prompt_dict['width']}") prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) idx = (idx + 1) % len(example_tuple[1]) From 938ff562fd08d606de2e5c6d2a42dead100d786a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 01:40:37 +0800 Subject: [PATCH 125/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 79453cfd4..bd651c9b9 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1846,8 +1846,8 @@ def __getattr__(self, item): pipe.set_control_nets(control_nets) logger.info(f"pipeline on {device} is ready.") distributed_state.wait_for_everyone() - pipes = gather_objects([pipe]) - unets = gather_objects([unet]) + pipes = gather_object([pipe]) + unets = gather_object([unet]) if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() From ba8dd4de48f1b2c7d7537f7a7f359414a0abb986 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 01:50:11 +0800 Subject: [PATCH 126/221] Update train_util.py --- library/train_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/train_util.py b/library/train_util.py index 4b3980c7a..9f8a6deae 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -28,6 +28,7 @@ import hashlib import subprocess from io import BytesIO +from accelerate.utils import gather_object import toml from tqdm import tqdm From c14aa5bf33b07c105515a08a2a1ee9e55256dead Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 02:05:35 +0800 Subject: [PATCH 127/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index bd651c9b9..471e491f8 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2579,13 +2579,14 @@ def scale_and_round(x): logger.info(f"steps: {steps}") continue - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + m = re.match(r"d ([-?\d,]+)", parg, re.IGNORECASE) if m: # seed if pi > 0 and len(raw_prompts) > 1: #Bypass on 2nd loop for dynamic prompts continue logger.info(f"{m}") seeds = m.group(1).split(",") - seeds = [int(float(d.strip())) for d in seeds] + + seeds = [int(d) for d in seeds] logger.info(f"seeds: {seeds}") continue @@ -2783,12 +2784,12 @@ def scale_and_round(x): if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う if len(seeds) > 1: - seed = seeds.pop(0) + seed = abs(seeds.pop(0)) elif len(seeds) == 1: if seeds[0] == -1: seeds = None else: - seed = seeds.pop(0) + seed = abs(seeds.pop(0)) else: From 13f764dbbe3df762b6c349e9009bf00699f9fd3d Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 02:06:24 +0800 Subject: [PATCH 128/221] Update train_network.py --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 7563424bc..54949f218 100644 --- a/train_network.py +++ b/train_network.py @@ -39,6 +39,7 @@ apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments +from accelerate.utils import gather_object setup_logging() import logging From b5245204eb280ccb619cd590beeaa539bb7b2b9a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 02:16:42 +0800 Subject: [PATCH 129/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 471e491f8..0596e524c 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2787,7 +2787,9 @@ def scale_and_round(x): seed = abs(seeds.pop(0)) elif len(seeds) == 1: if seeds[0] == -1: + logger.error("predefined seeds are exhausted") seeds = None + seed = None else: seed = abs(seeds.pop(0)) From 56f080508c96995e410b2aa02cb04f811cf75b88 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 02:18:43 +0800 Subject: [PATCH 130/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 0596e524c..23cbbb1c7 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2583,9 +2583,7 @@ def scale_and_round(x): if m: # seed if pi > 0 and len(raw_prompts) > 1: #Bypass on 2nd loop for dynamic prompts continue - logger.info(f"{m}") seeds = m.group(1).split(",") - seeds = [int(d) for d in seeds] logger.info(f"seeds: {seeds}") continue @@ -2787,7 +2785,7 @@ def scale_and_round(x): seed = abs(seeds.pop(0)) elif len(seeds) == 1: if seeds[0] == -1: - logger.error("predefined seeds are exhausted") + logger.error("predefined seeds in prompt are exhausted") seeds = None seed = None else: From 5308a0bfc10268cbc8739627f1ac4d888a5486f8 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 02:22:25 +0800 Subject: [PATCH 131/221] Update train_util.py --- library/train_util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 9f8a6deae..44bba3f4d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5566,8 +5566,7 @@ def sample_images_common( for prompt in per_process_prompts: prompts.extend(prompt) distributed_state.wait_for_everyone() - per_process_prompts = gather_object(prompts) - prompts = [] + prompts = gather_object(prompts) with torch.no_grad(): with distributed_state.split_between_processes(prompts) as prompt_dict_lists: From 086e90dd4fc4469bfd0333e2e8cd8efc00079f47 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 05:16:12 +0800 Subject: [PATCH 132/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 23cbbb1c7..c923e2e6e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2382,7 +2382,8 @@ def scale_and_round(x): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + ds_str = time.strftime("%Y%m%d", time.localtime()) + ts_str = time.strftime("%H%M%S", time.localtime()) for i, (image, globalcount, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( zip(images, global_counter, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): @@ -2417,7 +2418,7 @@ def scale_and_round(x): elif args.sequential_file_name: fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" else: - fln = f"im_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + fln = f"im_{ds_str}_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" logger.info(f"Saving Image: {fln}:\nPrompt: {prompt}") if negative_prompt is not None: From e10eabfb83e843bfb29951f1c230cb6fc3604af8 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 11:49:08 +0800 Subject: [PATCH 133/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index c923e2e6e..a7cadcb4c 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1846,8 +1846,8 @@ def __getattr__(self, item): pipe.set_control_nets(control_nets) logger.info(f"pipeline on {device} is ready.") distributed_state.wait_for_everyone() - pipes = gather_object([pipe]) - unets = gather_object([unet]) + #pipes = gather_object([pipe]) + #unets = gather_object([unet]) if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() From eb7c0b563db3f9ef89ffea55512355c89ae18db2 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 11:58:17 +0800 Subject: [PATCH 134/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index a7cadcb4c..8a9bcd74c 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2897,14 +2897,18 @@ def scale_and_round(x): split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)]) batch_separated_list.append(split_into_batches) if distributed_state.num_processes > 1: + logger.info(f"batch_separated_list: {len(batch_separated_list)}") templist = [] for i in range(distributed_state.num_processes): templist.append(batch_separated_list[i :: distributed_state.num_processes]) + logger.info(f"templist: {len(templist)}") batch_separated_list = [] for sub_batch_list in templist: batch_separated_list.extend(sub_batch_list) + logger.info(f"batch_separated_list: {len(batch_separated_list)}") distributed_state.wait_for_everyone() batch_data = gather_object(batch_separated_list) + logger.info(f"batch_data: {len(batch_data)}") del extinfo if len(batch_data) > 0: From e607dc45ff7a74d34218b04cd64885a136ff1f03 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 12:04:49 +0800 Subject: [PATCH 135/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 8a9bcd74c..91c2a9ba8 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2874,6 +2874,7 @@ def scale_and_round(x): global_step += 1 prompt_index += 1 + batch_data = gather_object(batch_data) batch_separated_list = [] if distributed_state.is_main_process and len(batch_data) > 0: unique_extinfo = list(set(extinfo)) @@ -2908,7 +2909,7 @@ def scale_and_round(x): logger.info(f"batch_separated_list: {len(batch_separated_list)}") distributed_state.wait_for_everyone() batch_data = gather_object(batch_separated_list) - logger.info(f"batch_data: {len(batch_data)}") + logger.info(f"batch_data line 2912: {len(batch_data)}") del extinfo if len(batch_data) > 0: From aafcac67ebc005ee392d9f5eeabf304d8879ac1f Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 12:11:21 +0800 Subject: [PATCH 136/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 1 + 1 file changed, 1 insertion(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 91c2a9ba8..16d4a2540 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2875,6 +2875,7 @@ def scale_and_round(x): prompt_index += 1 batch_data = gather_object(batch_data) + logger.info(f"batch_data line 2878: {len(batch_data)}") batch_separated_list = [] if distributed_state.is_main_process and len(batch_data) > 0: unique_extinfo = list(set(extinfo)) From b6ba98753c25bf41071906680d9b72efc9a8f702 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 12:18:50 +0800 Subject: [PATCH 137/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 16d4a2540..be8a298d2 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2751,8 +2751,7 @@ def scale_and_round(x): if ds_depth_1 is not None: if ds_depth_1 < 0: ds_depth_1 = args.ds_depth_1 or 3 - for unet in unets: - unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) # override Gradual Latent if gl_timesteps is not None: @@ -2776,8 +2775,7 @@ def scale_and_round(x): us_strength, us_target_x, ) - for pipe in pipes: - pipe.set_gradual_latent(gradual_latent) + pipe.set_gradual_latent(gradual_latent) # prepare seed if seeds is not None: # given in prompt @@ -2877,8 +2875,10 @@ def scale_and_round(x): batch_data = gather_object(batch_data) logger.info(f"batch_data line 2878: {len(batch_data)}") batch_separated_list = [] + logger.info(f"Device {distributed_state.device}, distributed_state.is_main_process 2878: {distributed_state.is_main_process}") if distributed_state.is_main_process and len(batch_data) > 0: unique_extinfo = list(set(extinfo)) + logger.info(f"batch_data line 2880: {len(batch_data)}") # splits list of prompts into sublists where BatchDataExt ext is identical for i in range(len(unique_extinfo)): templist = [] @@ -2910,7 +2910,7 @@ def scale_and_round(x): logger.info(f"batch_separated_list: {len(batch_separated_list)}") distributed_state.wait_for_everyone() batch_data = gather_object(batch_separated_list) - logger.info(f"batch_data line 2912: {len(batch_data)}") + logger.info(f"batch_data line 2911: {len(batch_data)}") del extinfo if len(batch_data) > 0: From 53050380716e0fce4967604e1123affb273b2029 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 12:33:04 +0800 Subject: [PATCH 138/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index be8a298d2..336a91b68 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1510,7 +1510,7 @@ def main(args): if len(files) == 1: args.ckpt = files[0] #device = get_preferred_device() - logger.info(f"preferred device: {device}") + logger.info(f"preferred device: {device}, {distributed_state.is_main_process}") clean_memory_on_device(device) model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) for pi in range(distributed_state.num_processes): @@ -2898,6 +2898,7 @@ def scale_and_round(x): n, m = divmod(len(sublist), device) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)]) batch_separated_list.append(split_into_batches) + logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}") if distributed_state.num_processes > 1: logger.info(f"batch_separated_list: {len(batch_separated_list)}") templist = [] From 91986ec4a2340ee5c3c50f06126206af00087a6b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 12:38:30 +0800 Subject: [PATCH 139/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 63 ++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 336a91b68..bc6565118 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2876,39 +2876,40 @@ def scale_and_round(x): logger.info(f"batch_data line 2878: {len(batch_data)}") batch_separated_list = [] logger.info(f"Device {distributed_state.device}, distributed_state.is_main_process 2878: {distributed_state.is_main_process}") - if distributed_state.is_main_process and len(batch_data) > 0: - unique_extinfo = list(set(extinfo)) - logger.info(f"batch_data line 2880: {len(batch_data)}") - # splits list of prompts into sublists where BatchDataExt ext is identical - for i in range(len(unique_extinfo)): - templist = [] - res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]] - for index in res: - templist.append(batch_data[index]) - split_into_batches = get_batches(items=templist, batch_size=args.batch_size) - if(len(split_into_batches) % distributed_state.num_processes != 0): - #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch - sublist = [] - for j in range(len(split_into_batches) % distributed_state.num_processes): - if len(split_into_batches) > 1 : - sublist.extend(split_into_batches.pop(-1)) - elif len(split_into_batches) == 1 : - sublist.extend(split_into_batches.pop(-1)) - listofbatches = [] - n, m = divmod(len(sublist), device) - split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)]) - batch_separated_list.append(split_into_batches) - logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}") - if distributed_state.num_processes > 1: - logger.info(f"batch_separated_list: {len(batch_separated_list)}") + if len(batch_data) > 0: + if distributed_state.is_main_process: + unique_extinfo = list(set(extinfo)) + logger.info(f"batch_data line 2880: {len(batch_data)}") + # splits list of prompts into sublists where BatchDataExt ext is identical + for i in range(len(unique_extinfo)): templist = [] - for i in range(distributed_state.num_processes): - templist.append(batch_separated_list[i :: distributed_state.num_processes]) - logger.info(f"templist: {len(templist)}") - batch_separated_list = [] - for sub_batch_list in templist: - batch_separated_list.extend(sub_batch_list) + res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]] + for index in res: + templist.append(batch_data[index]) + split_into_batches = get_batches(items=templist, batch_size=args.batch_size) + if(len(split_into_batches) % distributed_state.num_processes != 0): + #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch + sublist = [] + for j in range(len(split_into_batches) % distributed_state.num_processes): + if len(split_into_batches) > 1 : + sublist.extend(split_into_batches.pop(-1)) + elif len(split_into_batches) == 1 : + sublist.extend(split_into_batches.pop(-1)) + listofbatches = [] + n, m = divmod(len(sublist), device) + split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)]) + batch_separated_list.append(split_into_batches) + logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}") + if distributed_state.num_processes > 1: logger.info(f"batch_separated_list: {len(batch_separated_list)}") + templist = [] + for i in range(distributed_state.num_processes): + templist.append(batch_separated_list[i :: distributed_state.num_processes]) + logger.info(f"templist: {len(templist)}") + batch_separated_list = [] + for sub_batch_list in templist: + batch_separated_list.extend(sub_batch_list) + logger.info(f"batch_separated_list: {len(batch_separated_list)}") distributed_state.wait_for_everyone() batch_data = gather_object(batch_separated_list) logger.info(f"batch_data line 2911: {len(batch_data)}") From 37331268e18b57e5c87b8ae7420b941219550d5e Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 12:47:31 +0800 Subject: [PATCH 140/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index bc6565118..9f437606a 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2879,9 +2879,10 @@ def scale_and_round(x): if len(batch_data) > 0: if distributed_state.is_main_process: unique_extinfo = list(set(extinfo)) - logger.info(f"batch_data line 2880: {len(batch_data)}") + logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {len(unique_extinfo)}") # splits list of prompts into sublists where BatchDataExt ext is identical for i in range(len(unique_extinfo)): + logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {i}") templist = [] res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]] for index in res: From 0e334aee60394b4174f34ee7969d594b776c48c1 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 12:52:55 +0800 Subject: [PATCH 141/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 9f437606a..aa3c5965f 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2445,7 +2445,7 @@ def scale_and_round(x): global_step = 0 batch_data = [] extinfo = [] - while args.interactive or (prompt_index < len(prompt_list) and (not distributed_state.is_main_process or distributed_state.num_processes == 1)): + while args.interactive or (prompt_index < len(prompt_list) and distributed_state.is_main_process): if len(prompt_list) == 0: # interactive valid = False @@ -2873,6 +2873,7 @@ def scale_and_round(x): prompt_index += 1 batch_data = gather_object(batch_data) + extinfo = gather_object(extinfo) logger.info(f"batch_data line 2878: {len(batch_data)}") batch_separated_list = [] logger.info(f"Device {distributed_state.device}, distributed_state.is_main_process 2878: {distributed_state.is_main_process}") From efb3722ff0b2824233d3e31f2eedf5f3fc7b247c Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 12:56:31 +0800 Subject: [PATCH 142/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index aa3c5965f..410f73d5c 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2898,8 +2898,8 @@ def scale_and_round(x): elif len(split_into_batches) == 1 : sublist.extend(split_into_batches.pop(-1)) listofbatches = [] - n, m = divmod(len(sublist), device) - split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)]) + n, m = divmod(len(sublist), distributed_state.num_processes) + split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) batch_separated_list.append(split_into_batches) logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}") if distributed_state.num_processes > 1: From 03232a61946a76b17b450800fffde3902ba3895e Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 21:33:46 +0800 Subject: [PATCH 143/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 72 ++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 410f73d5c..ca30bb507 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2877,41 +2877,43 @@ def scale_and_round(x): logger.info(f"batch_data line 2878: {len(batch_data)}") batch_separated_list = [] logger.info(f"Device {distributed_state.device}, distributed_state.is_main_process 2878: {distributed_state.is_main_process}") - if len(batch_data) > 0: - if distributed_state.is_main_process: - unique_extinfo = list(set(extinfo)) - logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {len(unique_extinfo)}") - # splits list of prompts into sublists where BatchDataExt ext is identical - for i in range(len(unique_extinfo)): - logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {i}") - templist = [] - res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]] - for index in res: - templist.append(batch_data[index]) - split_into_batches = get_batches(items=templist, batch_size=args.batch_size) - if(len(split_into_batches) % distributed_state.num_processes != 0): - #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch - sublist = [] - for j in range(len(split_into_batches) % distributed_state.num_processes): - if len(split_into_batches) > 1 : - sublist.extend(split_into_batches.pop(-1)) - elif len(split_into_batches) == 1 : - sublist.extend(split_into_batches.pop(-1)) - listofbatches = [] - n, m = divmod(len(sublist), distributed_state.num_processes) - split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) - batch_separated_list.append(split_into_batches) - logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}") - if distributed_state.num_processes > 1: - logger.info(f"batch_separated_list: {len(batch_separated_list)}") - templist = [] - for i in range(distributed_state.num_processes): - templist.append(batch_separated_list[i :: distributed_state.num_processes]) - logger.info(f"templist: {len(templist)}") - batch_separated_list = [] - for sub_batch_list in templist: - batch_separated_list.extend(sub_batch_list) - logger.info(f"batch_separated_list: {len(batch_separated_list)}") + if len(batch_data) > 0 and distributed_state.is_main_process: + unique_extinfo = list(set(extinfo)) + logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {len(unique_extinfo)}") + # splits list of prompts into sublists where BatchDataExt ext is identical + for i in range(len(unique_extinfo)): + logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {i}") + templist = [] + res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]] + for index in res: + templist.append(batch_data[index]) + split_into_batches = get_batches(items=templist, batch_size=args.batch_size) + if(len(split_into_batches) % distributed_state.num_processes != 0): + #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch + sublist = [] + for j in range(len(split_into_batches) % distributed_state.num_processes): + if len(split_into_batches) > 1 : + sublist.extend(split_into_batches.pop(-1)) + elif len(split_into_batches) == 1 : + sublist.extend(split_into_batches.pop(-1)) + split_into_batches = [] + n, m = divmod(len(sublist), distributed_state.num_processes) + split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) + batch_separated_list.extend(split_into_batches) + logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}") + if distributed_state.num_processes > 1: + logger.info(f"batch_separated_list: {len(batch_separated_list)}") + + temp_list = [] + for ext_batch in batch_separated_list: + for i in range(distributed_state.num_processes): + templist.append(ext_batch[i :: distributed_state.num_processes]) + logger.info(f"templist: {len(temp_list)}") + batch_separated_list = [] + for sub_batch_list in temp_list: + batch_separated_list.append(sub_batch_list) + logger.info(f"batch_separated_list: {len(batch_separated_list)}") + logger.info(f"sub_batch_list: {len(sub_batch_list)}") distributed_state.wait_for_everyone() batch_data = gather_object(batch_separated_list) logger.info(f"batch_data line 2911: {len(batch_data)}") From ff77d3afe1f306cbed35745f702e0df77c2b0df8 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 21:50:58 +0800 Subject: [PATCH 144/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index ca30bb507..115ec6128 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2907,7 +2907,7 @@ def scale_and_round(x): temp_list = [] for ext_batch in batch_separated_list: for i in range(distributed_state.num_processes): - templist.append(ext_batch[i :: distributed_state.num_processes]) + temp_list.append(ext_batch[i :: distributed_state.num_processes]) logger.info(f"templist: {len(temp_list)}") batch_separated_list = [] for sub_batch_list in temp_list: From cba5437c3910be2f0d0958366873ea71722b657b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 21:56:56 +0800 Subject: [PATCH 145/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 115ec6128..1d18e2b6b 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2899,7 +2899,7 @@ def scale_and_round(x): split_into_batches = [] n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) - batch_separated_list.extend(split_into_batches) + batch_separated_list.append(split_into_batches) logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}") if distributed_state.num_processes > 1: logger.info(f"batch_separated_list: {len(batch_separated_list)}") From 11404c1ac302cac54697f9f17933b2e19992dba8 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 22:18:52 +0800 Subject: [PATCH 146/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 1d18e2b6b..b50c87077 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2420,9 +2420,9 @@ def scale_and_round(x): else: fln = f"im_{ds_str}_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - logger.info(f"Saving Image: {fln}:\nPrompt: {prompt}") + #logger.info(f"Saving Image: {fln}:\nPrompt: {prompt}") if negative_prompt is not None: - logger.info(f"Negative Prompt: {negative_prompt}\n") + # logger.info(f"Negative Prompt: {negative_prompt}\n") image.save(os.path.join(args.outdir, fln), pnginfo=metadata) if not args.no_preview and not highres_1st and args.interactive: @@ -2923,10 +2923,10 @@ def scale_and_round(x): for batch_list in batch_data: with distributed_state.split_between_processes(batch_list) as batches: for j in range(len(batches)): - logger.info(f"Loading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:") + logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:") logger.info(f"batch_list:") for i in range(len(batches[j])): - logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") + logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") prev_image = process_batch(batch_list[j], highres_fix)[0] distributed_state.wait_for_everyone() From 107625b8334f65f9ed5b62b8c1885411d804a56e Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 22:22:38 +0800 Subject: [PATCH 147/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index b50c87077..20122cefd 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2915,12 +2915,12 @@ def scale_and_round(x): logger.info(f"batch_separated_list: {len(batch_separated_list)}") logger.info(f"sub_batch_list: {len(sub_batch_list)}") distributed_state.wait_for_everyone() - batch_data = gather_object(batch_separated_list) + batch_separated_list = gather_object(batch_separated_list) logger.info(f"batch_data line 2911: {len(batch_data)}") del extinfo - if len(batch_data) > 0: - for batch_list in batch_data: + if len(batch_separated_list) > 0: + for batch_list in batch_separated_list: with distributed_state.split_between_processes(batch_list) as batches: for j in range(len(batches)): logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:") From 63508160911936b97a356b09c855c05583fbf51b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 23:10:53 +0800 Subject: [PATCH 148/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 67 ++++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 20122cefd..b30edc11c 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2420,9 +2420,6 @@ def scale_and_round(x): else: fln = f"im_{ds_str}_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - #logger.info(f"Saving Image: {fln}:\nPrompt: {prompt}") - if negative_prompt is not None: - # logger.info(f"Negative Prompt: {negative_prompt}\n") image.save(os.path.join(args.outdir, fln), pnginfo=metadata) if not args.no_preview and not highres_1st and args.interactive: @@ -2443,7 +2440,7 @@ def scale_and_round(x): # 画像生成のプロンプトが一周するまでのループ prompt_index = 0 global_step = 0 - batch_data = [] + prompt_data_list = [] extinfo = [] while args.interactive or (prompt_index < len(prompt_list) and distributed_state.is_main_process): if len(prompt_list) == 0: @@ -2867,26 +2864,24 @@ def scale_and_round(x): num_sub_prompts, ), ) - batch_data.append(b1) + prompt_data_list.append(b1) extinfo.append(b1.ext) global_step += 1 prompt_index += 1 - batch_data = gather_object(batch_data) + prompt_data_list = gather_object(prompt_data_list) extinfo = gather_object(extinfo) - logger.info(f"batch_data line 2878: {len(batch_data)}") - batch_separated_list = [] - logger.info(f"Device {distributed_state.device}, distributed_state.is_main_process 2878: {distributed_state.is_main_process}") - if len(batch_data) > 0 and distributed_state.is_main_process: + + ext_separated_list_of_batches = [] + if len(prompt_data_list) > 0 and distributed_state.is_main_process: unique_extinfo = list(set(extinfo)) - logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {len(unique_extinfo)}") + logger.info(f"Device {distributed_state.device}, prompt_data_list line 2880: {len(prompt_data_list)}, len(unique_extinfo): {len(unique_extinfo)}") # splits list of prompts into sublists where BatchDataExt ext is identical for i in range(len(unique_extinfo)): - logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {i}") templist = [] - res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]] + res = [j for j, val in enumerate(prompt_data_list) if val.ext == unique_extinfo[i]] for index in res: - templist.append(batch_data[index]) + templist.append(prompt_data_list[index]) split_into_batches = get_batches(items=templist, batch_size=args.batch_size) if(len(split_into_batches) % distributed_state.num_processes != 0): #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch @@ -2899,28 +2894,34 @@ def scale_and_round(x): split_into_batches = [] n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) - batch_separated_list.append(split_into_batches) - logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}") + ext_separated_list_of_batches.append(split_into_batches) if distributed_state.num_processes > 1: - logger.info(f"batch_separated_list: {len(batch_separated_list)}") - temp_list = [] - for ext_batch in batch_separated_list: + + for x in range(len(ext_separated_list_of_batches)): + temp_list = [] for i in range(distributed_state.num_processes): - temp_list.append(ext_batch[i :: distributed_state.num_processes]) - logger.info(f"templist: {len(temp_list)}") - batch_separated_list = [] - for sub_batch_list in temp_list: - batch_separated_list.append(sub_batch_list) - logger.info(f"batch_separated_list: {len(batch_separated_list)}") - logger.info(f"sub_batch_list: {len(sub_batch_list)}") + temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) + ext_separated_list_of_batches[x] = [] + for batches in temp_list: + ext_separated_list_of_batches[x].extend(batches) + logger.info(f"templist: {len(temp_list)}, ext_separated_list_of_batches[x]: {len(ext_separated_list_of_batches[x])}") + + logger.info(f"ext_separated_list_of_batches: {len(ext_separated_list_of_batches)}") + count_prompts = 0 + for sub_batch_list in ext_separated_list_of_batches: + logger.info(f" sub_batch_list: {len(sub_batch_list)}") + for batches in sub_batch_list: + logger.info(f" batches: {len(batches)}") + count_prompts += len(batches) + logger.info(f"count_prompts: {count_prompts}") + distributed_state.wait_for_everyone() - batch_separated_list = gather_object(batch_separated_list) - logger.info(f"batch_data line 2911: {len(batch_data)}") + ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo - if len(batch_separated_list) > 0: - for batch_list in batch_separated_list: + if len(ext_separated_list_of_batches) > 0: + for batch_list in ext_separated_list_of_batches: with distributed_state.split_between_processes(batch_list) as batches: for j in range(len(batches)): logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:") @@ -2932,17 +2933,17 @@ def scale_and_round(x): distributed_state.wait_for_everyone() #for i in range(len(data_loader)): # logger.info(f"Loading Batch {i+1} of {len(data_loader)}") - # batch_data_split.append(data_loader[i]) + # prompt_data_list_split.append(data_loader[i]) # if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader): # continue # with torch.no_grad(): - # with distributed_state.split_between_processes(batch_data_split) as batch_list: + # with distributed_state.split_between_processes(prompt_data_list_split) as batch_list: # logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.local_process_index}:") # logger.info(f"batch_list:") # for i in range(len(batch_list[0])): # logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}") # prev_image = process_batch(batch_list[0], highres_fix)[0] - # batch_data_split.clear() + # prompt_data_list_split.clear() # distributed_state.wait_for_everyone() logger.info("done!") From eaf4a9a16487868cefc2d019dc19742de4b92e2f Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 23:17:03 +0800 Subject: [PATCH 149/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index b30edc11c..5ed67d2f1 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2900,6 +2900,7 @@ def scale_and_round(x): for x in range(len(ext_separated_list_of_batches)): temp_list = [] + logger.info(f"start: ext_separated_list_of_batches[x]: {len(ext_separated_list_of_batches[x])}") for i in range(distributed_state.num_processes): temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) ext_separated_list_of_batches[x] = [] @@ -2924,8 +2925,7 @@ def scale_and_round(x): for batch_list in ext_separated_list_of_batches: with distributed_state.split_between_processes(batch_list) as batches: for j in range(len(batches)): - logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:") - logger.info(f"batch_list:") + logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:") for i in range(len(batches[j])): logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") prev_image = process_batch(batch_list[j], highres_fix)[0] From cc6266db19d395d9b8344bb2a5254429a63ea93a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 23:36:54 +0800 Subject: [PATCH 150/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 5ed67d2f1..6bfa14aa9 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2883,21 +2883,25 @@ def scale_and_round(x): for index in res: templist.append(prompt_data_list[index]) split_into_batches = get_batches(items=templist, batch_size=args.batch_size) + sublist = [] if(len(split_into_batches) % distributed_state.num_processes != 0): #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch - sublist = [] - for j in range(len(split_into_batches) % distributed_state.num_processes): - if len(split_into_batches) > 1 : - sublist.extend(split_into_batches.pop(-1)) - elif len(split_into_batches) == 1 : - sublist.extend(split_into_batches.pop(-1)) - split_into_batches = [] - n, m = divmod(len(sublist), distributed_state.num_processes) - split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) + popnum = (len(split_into_batches) % distributed_state.num_processes + else: + #force distribution check on last round of batches + popnum = distributed_state.num_processes + + for j in range(popnum): + if len(split_into_batches) > 1 : + sublist.extend(split_into_batches.pop(-1)) + elif len(split_into_batches) == 1 : + sublist.extend(split_into_batches.pop(-1)) + split_into_batches = [] + + n, m = divmod(len(sublist), distributed_state.num_processes) + split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) ext_separated_list_of_batches.append(split_into_batches) if distributed_state.num_processes > 1: - - for x in range(len(ext_separated_list_of_batches)): temp_list = [] logger.info(f"start: ext_separated_list_of_batches[x]: {len(ext_separated_list_of_batches[x])}") From 4ff521ed60cad6a2d184e3b991429cdf829ae7f2 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 23:42:16 +0800 Subject: [PATCH 151/221] Update train_util.py --- library/train_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 44bba3f4d..ff2cbc3eb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5511,9 +5511,10 @@ def sample_images_common( # preprocess prompts + + save_dir = args.output_dir + "/sample" if distributed_state.is_main_process: #Create output folder and preprocess prompts on main process only. - save_dir = args.output_dir + "/sample" os.makedirs(save_dir, exist_ok=True) idx = 0 for i in range(len(prompts)): From 75859a6cab0ea416b4b1f0a119c83f283c97f2f4 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 23:45:58 +0800 Subject: [PATCH 152/221] Update train_util.py --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index ff2cbc3eb..51a75c3be 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4714,7 +4714,7 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False): for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: - logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + logger.info(f"loading model for process {accelerator.state.local_process_index+1}/{accelerator.state.num_processes}") text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( args, From 8967e2f0144ac8c158862162d5a5842fdff754c5 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 00:50:03 +0800 Subject: [PATCH 153/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 6bfa14aa9..15e9ebd2c 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2469,6 +2469,8 @@ def scale_and_round(x): raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] + if pi == 0 or len(raw_prompts) > 1: + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: @@ -2481,8 +2483,7 @@ def scale_and_round(x): except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(f"{ex}") - if pi == 0 or len(raw_prompts) > 1: - logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + if pi == 0: # parse prompt: if prompt is not changed, skip parsing @@ -2904,23 +2905,12 @@ def scale_and_round(x): if distributed_state.num_processes > 1: for x in range(len(ext_separated_list_of_batches)): temp_list = [] - logger.info(f"start: ext_separated_list_of_batches[x]: {len(ext_separated_list_of_batches[x])}") for i in range(distributed_state.num_processes): temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) ext_separated_list_of_batches[x] = [] for batches in temp_list: ext_separated_list_of_batches[x].extend(batches) - logger.info(f"templist: {len(temp_list)}, ext_separated_list_of_batches[x]: {len(ext_separated_list_of_batches[x])}") - - logger.info(f"ext_separated_list_of_batches: {len(ext_separated_list_of_batches)}") - count_prompts = 0 - for sub_batch_list in ext_separated_list_of_batches: - logger.info(f" sub_batch_list: {len(sub_batch_list)}") - for batches in sub_batch_list: - logger.info(f" batches: {len(batches)}") - count_prompts += len(batches) - logger.info(f"count_prompts: {count_prompts}") - + distributed_state.wait_for_everyone() ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo From 01430d03e00808be7d35c5d4748089afca99c82e Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 01:08:10 +0800 Subject: [PATCH 154/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 15e9ebd2c..379e123bb 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2887,7 +2887,7 @@ def scale_and_round(x): sublist = [] if(len(split_into_batches) % distributed_state.num_processes != 0): #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch - popnum = (len(split_into_batches) % distributed_state.num_processes + popnum = len(split_into_batches) % distributed_state.num_processes else: #force distribution check on last round of batches popnum = distributed_state.num_processes From 4594ef9c9d4f55b4d039c25d7d0be5b36eb382ab Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 01:29:25 +0800 Subject: [PATCH 155/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 1 - 1 file changed, 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 379e123bb..79663609a 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2876,7 +2876,6 @@ def scale_and_round(x): ext_separated_list_of_batches = [] if len(prompt_data_list) > 0 and distributed_state.is_main_process: unique_extinfo = list(set(extinfo)) - logger.info(f"Device {distributed_state.device}, prompt_data_list line 2880: {len(prompt_data_list)}, len(unique_extinfo): {len(unique_extinfo)}") # splits list of prompts into sublists where BatchDataExt ext is identical for i in range(len(unique_extinfo)): templist = [] From 6eabaf377568e5d76b33454c1ebf3c16ba716550 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 03:53:08 +0800 Subject: [PATCH 156/221] Update train_util.py --- library/train_util.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 51a75c3be..d81d6fac3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5531,11 +5531,12 @@ def sample_images_common( idx = (idx + 1) % len(example_tuple[1]) if idx == 0: break + prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]}') logger.info(f"Replacement prompt: {prompt_dict.get('prompt')}") - prompt_dict["height"] = example_tuple[0].shape[2] * 8 + prompt_dict["height"] = example_tuple[0][idx].shape[2] * 8 logger.info(f"Original Image Height: {prompt_dict['height']}") - prompt_dict["width"] = example_tuple[0].shape[3] * 8 + prompt_dict["width"] = example_tuple[0][idx].shape[3] * 8 logger.info(f"Original Image Width: {prompt_dict['width']}") prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) idx = (idx + 1) % len(example_tuple[1]) From 6fc38477c8008d18a59db8f27c789b612e3e9395 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 04:39:15 +0800 Subject: [PATCH 157/221] Update train_network.py --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 54949f218..e0ba83561 100644 --- a/train_network.py +++ b/train_network.py @@ -1026,6 +1026,7 @@ def remove_model(old_ckpt_name): # Collecting latents and caption lists from all processes all_latents = gather_object(latents) all_captions = gather_object(batch["captions"]) + logger.log(f"all_latents: {all_latents}") example_tuple = (all_latents, all_captions) if accelerator.sync_gradients: progress_bar.update(1) From 896ade6f9218ba4e79b2d9884a5c492376f20b01 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 04:46:53 +0800 Subject: [PATCH 158/221] Update train_network.py --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index e0ba83561..5d170926f 100644 --- a/train_network.py +++ b/train_network.py @@ -1026,7 +1026,8 @@ def remove_model(old_ckpt_name): # Collecting latents and caption lists from all processes all_latents = gather_object(latents) all_captions = gather_object(batch["captions"]) - logger.log(f"all_latents: {all_latents}") + logger.info(f"latents: {latents}") + logger.info(f"all_latents: {all_latents}") example_tuple = (all_latents, all_captions) if accelerator.sync_gradients: progress_bar.update(1) From aea7cada42cbb3373fc04bf98f7cae0a91d77d27 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 04:54:36 +0800 Subject: [PATCH 159/221] Update train_network.py --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 5d170926f..17a3b36e5 100644 --- a/train_network.py +++ b/train_network.py @@ -1024,7 +1024,7 @@ def remove_model(old_ckpt_name): # Checks if the accelerator has performed an optimization step behind the scenes # Collecting latents and caption lists from all processes - all_latents = gather_object(latents) + all_latents = gather_object([latents]) all_captions = gather_object(batch["captions"]) logger.info(f"latents: {latents}") logger.info(f"all_latents: {all_latents}") From 83a171d3dc304905c493effb63bfd0f85f909190 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 04:56:22 +0800 Subject: [PATCH 160/221] Update train_network.py --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 17a3b36e5..2accc40a5 100644 --- a/train_network.py +++ b/train_network.py @@ -1026,8 +1026,8 @@ def remove_model(old_ckpt_name): # Collecting latents and caption lists from all processes all_latents = gather_object([latents]) all_captions = gather_object(batch["captions"]) - logger.info(f"latents: {latents}") - logger.info(f"all_latents: {all_latents}") + #logger.info(f"latents: {latents}") + #logger.info(f"all_latents: {all_latents}") example_tuple = (all_latents, all_captions) if accelerator.sync_gradients: progress_bar.update(1) From 06681ce1a9006c3105c3a8dfd45a410f69d9a3ec Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 05:30:08 +0800 Subject: [PATCH 161/221] Update train_util.py --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index d81d6fac3..1b1c30cab 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5538,7 +5538,7 @@ def sample_images_common( logger.info(f"Original Image Height: {prompt_dict['height']}") prompt_dict["width"] = example_tuple[0][idx].shape[3] * 8 logger.info(f"Original Image Width: {prompt_dict['width']}") - prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) + prompt_dict["original_lantent"] = example_tuple[0][idx] idx = (idx + 1) % len(example_tuple[1]) prompt_dict["enum"] = i From 55179ad8094617d60a4f3810c41ff1a3e5fcdbf3 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 16:29:58 +0800 Subject: [PATCH 162/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 79663609a..260a7ffe9 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2916,12 +2916,13 @@ def scale_and_round(x): if len(ext_separated_list_of_batches) > 0: for batch_list in ext_separated_list_of_batches: - with distributed_state.split_between_processes(batch_list) as batches: - for j in range(len(batches)): - logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:") - for i in range(len(batches[j])): - logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") - prev_image = process_batch(batch_list[j], highres_fix)[0] + with torch.no_grad(): + with distributed_state.split_between_processes(batch_list) as batches: + for j in range(len(batches)): + logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:") + for i in range(len(batches[j])): + logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") + prev_image = process_batch(batch_list[j], highres_fix)[0] distributed_state.wait_for_everyone() #for i in range(len(data_loader)): From df365c5cd027f73a6b22be39a18e0a21d821df20 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 17:00:09 +0800 Subject: [PATCH 163/221] Update train_util.py --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1b1c30cab..0618b4aaf 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5574,8 +5574,8 @@ def sample_images_common( with distributed_state.split_between_processes(prompts) as prompt_dict_lists: for prompt_dict in prompt_dict_lists: if '__caption__' in prompt_dict.get("prompt"): - logger.info("No training prompts loaded, skipping '__caption__' prompt.") - continue + prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'Astronaut riding a horse on the moon') + logger.info("No training prompts loaded, replacing with placeholder 'Astronaut riding a horse on the moon' prompt.") sample_image_inference( accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet ) From b3b856393cdebc355ab8e631553b09736d5bfe30 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 3 Feb 2025 03:38:24 +0800 Subject: [PATCH 164/221] Update train_network.py --- train_network.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 2accc40a5..9364cc194 100644 --- a/train_network.py +++ b/train_network.py @@ -39,7 +39,7 @@ apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments -from accelerate.utils import gather_object +from accelerate.utils import gather_object, gather setup_logging() import logging @@ -1024,10 +1024,11 @@ def remove_model(old_ckpt_name): # Checks if the accelerator has performed an optimization step behind the scenes # Collecting latents and caption lists from all processes - all_latents = gather_object([latents]) + logger.info(f"latents.size: {latents.size()} before gather on device {accelerator.state.local_process_index}") + all_latents = gather(latents) all_captions = gather_object(batch["captions"]) #logger.info(f"latents: {latents}") - #logger.info(f"all_latents: {all_latents}") + logger.info(f"all_latents.size: {all_latents.size()}") example_tuple = (all_latents, all_captions) if accelerator.sync_gradients: progress_bar.update(1) From 639a60de8a2b456fda23c4ffc6aa2075353431b9 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 3 Feb 2025 04:05:55 +0800 Subject: [PATCH 165/221] Update train_network.py --- train_network.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/train_network.py b/train_network.py index 9364cc194..1a6ba5eb2 100644 --- a/train_network.py +++ b/train_network.py @@ -132,8 +132,8 @@ def all_reduce_network(self, accelerator, network): if param.grad is not None: param.grad = accelerator.reduce(param.grad, reduction="mean") - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None): - train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple) + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, latents_list=None): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, latents_list) def train(self, args): session_id = random.randint(0, 2**32) @@ -1025,16 +1025,22 @@ def remove_model(old_ckpt_name): # Checks if the accelerator has performed an optimization step behind the scenes # Collecting latents and caption lists from all processes logger.info(f"latents.size: {latents.size()} before gather on device {accelerator.state.local_process_index}") - all_latents = gather(latents) - all_captions = gather_object(batch["captions"]) - #logger.info(f"latents: {latents}") - logger.info(f"all_latents.size: {all_latents.size()}") - example_tuple = (all_latents, all_captions) + #Converts batch of latents into list of dicts containing individual latents, height and width to merge across processes + #Allows for different latent sizes + latents_list = [] + for idx in range(len(batch["captions"]): + latent_dict = {} + latent_dict["prompt"] = batch["captions"][idx] + latent_dict["height"] = latents.shape[2] * 8 + latent_dict["width"] = latents.shape[3] * 8 + latent_dict["original_lantent"] = latents[idx].unsqueeze(0) + latents_list.append(latent_dict) + latents_list = gather_object(latents_list) if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) + self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, latents_list) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -1090,7 +1096,7 @@ def remove_model(old_ckpt_name): if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, latents_list) # end of epoch From 9d5c051ada36d21cdfc89daab30c412db0a8e559 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 3 Feb 2025 04:13:53 +0800 Subject: [PATCH 166/221] Update train_util.py --- library/train_util.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0618b4aaf..b32b80d2a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5432,7 +5432,7 @@ def sample_images_common( tokenizer, text_encoder, unet, - example_tuple=None, + latents_list=None, prompt_replacement=None, controlnet=None, ): @@ -5524,22 +5524,22 @@ def sample_images_common( prompts[i] = prompt_dict assert isinstance(prompt_dict, dict) - if '__caption__' in prompt_dict.get("prompt") and example_tuple: + if '__caption__' in prompt_dict.get("prompt") and latents_list: logger.info(f"Original prompt: {prompt_dict.get('prompt')}") - while example_tuple[1][idx] == '': - idx = (idx + 1) % len(example_tuple[1]) + while latents_list[idx]["prompt"] == '': + idx = (idx + 1) % len(latents_list) if idx == 0: break - prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]}') + prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{latents_list[idx]["prompt"]}') logger.info(f"Replacement prompt: {prompt_dict.get('prompt')}") - prompt_dict["height"] = example_tuple[0][idx].shape[2] * 8 + prompt_dict["height"] = latents_list[idx]["height"] logger.info(f"Original Image Height: {prompt_dict['height']}") - prompt_dict["width"] = example_tuple[0][idx].shape[3] * 8 + prompt_dict["width"] = latents_list[idx]["width"] logger.info(f"Original Image Width: {prompt_dict['width']}") - prompt_dict["original_lantent"] = example_tuple[0][idx] - idx = (idx + 1) % len(example_tuple[1]) + prompt_dict["original_lantent"] = latents_list[idx]["original_lantent"] + idx = (idx + 1) % len(latents_list) prompt_dict["enum"] = i prompt_dict.pop("subset", None) From 07a2e3571f15a40248bce2ab7ccc9ff0fb314c0e Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 3 Feb 2025 04:48:34 +0800 Subject: [PATCH 167/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 260a7ffe9..11a756eaf 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2924,7 +2924,7 @@ def scale_and_round(x): logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") prev_image = process_batch(batch_list[j], highres_fix)[0] - distributed_state.wait_for_everyone() + distributed_state.wait_for_everyone() #for i in range(len(data_loader)): # logger.info(f"Loading Batch {i+1} of {len(data_loader)}") # prompt_data_list_split.append(data_loader[i]) From e250e7aad052ee81f86e12d60082c04b728f9348 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 3 Feb 2025 05:09:14 +0800 Subject: [PATCH 168/221] Update train_network.py --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 1a6ba5eb2..b498df92f 100644 --- a/train_network.py +++ b/train_network.py @@ -1028,7 +1028,7 @@ def remove_model(old_ckpt_name): #Converts batch of latents into list of dicts containing individual latents, height and width to merge across processes #Allows for different latent sizes latents_list = [] - for idx in range(len(batch["captions"]): + for idx in range(len(batch["captions"])): latent_dict = {} latent_dict["prompt"] = batch["captions"][idx] latent_dict["height"] = latents.shape[2] * 8 From 4c5ec694f77080e2e05987e42d12df7bbc882e0a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 3 Feb 2025 05:22:00 +0800 Subject: [PATCH 169/221] Update train_network.py --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index b498df92f..3a2243e64 100644 --- a/train_network.py +++ b/train_network.py @@ -1024,9 +1024,9 @@ def remove_model(old_ckpt_name): # Checks if the accelerator has performed an optimization step behind the scenes # Collecting latents and caption lists from all processes - logger.info(f"latents.size: {latents.size()} before gather on device {accelerator.state.local_process_index}") - #Converts batch of latents into list of dicts containing individual latents, height and width to merge across processes - #Allows for different latent sizes + + # Converts batch of latents into list of dicts containing individual latents, height and width to merge across processes + # Allows for different latent sizes latents_list = [] for idx in range(len(batch["captions"])): latent_dict = {} From e6433ec0350f67f061eec8386133cb043fb331c7 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 3 Feb 2025 16:37:44 +0800 Subject: [PATCH 170/221] Update train_network.py --- train_network.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/train_network.py b/train_network.py index 3a2243e64..6ec1553a7 100644 --- a/train_network.py +++ b/train_network.py @@ -132,8 +132,8 @@ def all_reduce_network(self, accelerator, network): if param.grad is not None: param.grad = accelerator.reduce(param.grad, reduction="mean") - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, latents_list=None): - train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, latents_list) + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple) def train(self, args): session_id = random.randint(0, 2**32) @@ -1023,24 +1023,13 @@ def remove_model(old_ckpt_name): keys_scaled, mean_norm, maximum_norm = None, None, None # Checks if the accelerator has performed an optimization step behind the scenes - # Collecting latents and caption lists from all processes + example_tuple = (latents, batch["captions"]) - # Converts batch of latents into list of dicts containing individual latents, height and width to merge across processes - # Allows for different latent sizes - latents_list = [] - for idx in range(len(batch["captions"])): - latent_dict = {} - latent_dict["prompt"] = batch["captions"][idx] - latent_dict["height"] = latents.shape[2] * 8 - latent_dict["width"] = latents.shape[3] * 8 - latent_dict["original_lantent"] = latents[idx].unsqueeze(0) - latents_list.append(latent_dict) - latents_list = gather_object(latents_list) if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, latents_list) + self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -1096,7 +1085,7 @@ def remove_model(old_ckpt_name): if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, latents_list) + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) # end of epoch From 6cd468a6b3e76ad806a0782f502be950dd8d6365 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 3 Feb 2025 16:50:35 +0800 Subject: [PATCH 171/221] Update train_util.py --- library/train_util.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b32b80d2a..1d4bf90d4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5432,7 +5432,7 @@ def sample_images_common( tokenizer, text_encoder, unet, - latents_list=None, + example_tuple=None, prompt_replacement=None, controlnet=None, ): @@ -5511,12 +5511,23 @@ def sample_images_common( # preprocess prompts - + if example_tuple: + latents_list = [] + for idx in range(len(example_tuple[1])): + latent_dict = {} + latent_dict["prompt"] = example_tuple[1][idx] + latent_dict["height"] = example_tuple[0].shape[2] * 8 + latent_dict["width"] = example_tuple[0].shape[3] * 8 + latent_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) + latents_list.append(latent_dict) + distributed_state.wait_for_everyone() + latents_list = gather_object(latents_list) save_dir = args.output_dir + "/sample" if distributed_state.is_main_process: #Create output folder and preprocess prompts on main process only. os.makedirs(save_dir, exist_ok=True) idx = 0 + for i in range(len(prompts)): prompt_dict = prompts[i] if isinstance(prompt_dict, str): @@ -5524,7 +5535,7 @@ def sample_images_common( prompts[i] = prompt_dict assert isinstance(prompt_dict, dict) - if '__caption__' in prompt_dict.get("prompt") and latents_list: + if '__caption__' in prompt_dict.get("prompt") and example_tuple: logger.info(f"Original prompt: {prompt_dict.get('prompt')}") while latents_list[idx]["prompt"] == '': From e8faf4f3f1c102a80ebe12c2a07288e6e529ca9a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 3 Feb 2025 19:25:13 +0800 Subject: [PATCH 172/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 124 +++++++++++++++++++++++------------------- 1 file changed, 68 insertions(+), 56 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 11a756eaf..c0a8c6c15 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2120,7 +2120,7 @@ def resize_images(imgs, size): iter_seed = random.randint(0, 0x7FFFFFFF) # バッチ処理の関数 - def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + def process_batch(batch: List[BatchData], distributed_state, highres_fix, highres_1st=False): batch_size = len(batch) # highres_fixの処理 @@ -2170,7 +2170,7 @@ def scale_and_round(x): batch_1st.append(BatchData(is_1st_latent, global_count, base, ext_1st)) pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする - images_1st = process_batch(batch_1st, True, True) + images_1st = process_batch(batch_1st, distributed_state, True, True) # 2nd stageのバッチを作成して以下処理する logger.info("process 2nd stage") @@ -2381,60 +2381,72 @@ def scale_and_round(x): return images # save image - highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ds_str = time.strftime("%Y%m%d", time.localtime()) - ts_str = time.strftime("%H%M%S", time.localtime()) - for i, (image, globalcount, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, global_counter, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) - ): - if highres_fix: - seed -= 1 # record original seed - metadata = PngInfo() - metadata.add_text("prompt", prompt) - metadata.add_text("seed", str(seed)) - metadata.add_text("sampler", args.sampler) - metadata.add_text("steps", str(steps)) - metadata.add_text("scale", str(scale)) - if negative_prompt is not None: - metadata.add_text("negative-prompt", negative_prompt) - if negative_scale is not None: - metadata.add_text("negative-scale", str(negative_scale)) - if clip_prompt is not None: - metadata.add_text("clip-prompt", clip_prompt) - if raw_prompt is not None: - metadata.add_text("raw-prompt", raw_prompt) - metadata.add_text("original-height", str(original_height)) - metadata.add_text("original-width", str(original_width)) - metadata.add_text("original-height-negative", str(original_height_negative)) - metadata.add_text("original-width-negative", str(original_width_negative)) - metadata.add_text("crop-top", str(crop_top)) - metadata.add_text("crop-left", str(crop_left)) - - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + distributed_state.wait_for_everyone() + all_images=gather_object(images) + all_global_counter = gather_object(global_counter) + all_prompts = gather_object(prompts) + all_negative_prompts = gather_object(negative_prompts) + all_seeds = gather_object(seeds) + all_clip_prompts = gather_object(clip_prompts) + all_raw_prompts = gather_object(raw_prompts) + all_init_images = gather_object(init_images) + if distributed_state.is_main_process: + + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ds_str = time.strftime("%Y%m%d", time.localtime()) + ts_str = time.strftime("%H%M%S", time.localtime()) + for i, (image, globalcount, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(all_images, all_global_counter, all_prompts, all_negative_prompts, all_seeds, all_clip_prompts, all_raw_prompts) + ): + if highres_fix: + seed -= 1 # record original seed + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" - else: - fln = f"im_{ds_str}_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) - - if not args.no_preview and not highres_1st and args.interactive: - try: - import cv2 - - for prompt, image in zip(prompts, images): - cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ - cv2.waitKey() - cv2.destroyAllWindows() - except ImportError: - logger.error( - "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" - ) - + fln = f"im_{ds_str}_{ts_str}_{globalcount:02d}_{highres_prefix}{i:03d}_{seed}.png" + logger.info(f"Saving image {global_count}: {fln}\nPrompt: {prompt}") + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + logger.error( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) + + distributed_state.wait_for_everyone() return images # 画像生成のプロンプトが一周するまでのループ @@ -2922,7 +2934,7 @@ def scale_and_round(x): logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:") for i in range(len(batches[j])): logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") - prev_image = process_batch(batch_list[j], highres_fix)[0] + prev_image = process_batch(batch_list[j], distributed_state, highres_fix)[0] distributed_state.wait_for_everyone() #for i in range(len(data_loader)): From debf4af8fe0ac3702d8fd87a781600bd92a5612b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Tue, 4 Feb 2025 01:52:05 +0800 Subject: [PATCH 173/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index c0a8c6c15..e28212c78 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2430,7 +2430,7 @@ def scale_and_round(x): fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" else: fln = f"im_{ds_str}_{ts_str}_{globalcount:02d}_{highres_prefix}{i:03d}_{seed}.png" - logger.info(f"Saving image {global_count}: {fln}\nPrompt: {prompt}") + logger.info(f"Saving image {globalcount}: {fln}\nPrompt: {prompt}") image.save(os.path.join(args.outdir, fln), pnginfo=metadata) if not args.no_preview and not highres_1st and args.interactive: From f36b545480dbe900ca4d21ae6685d3cc1e567e88 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 6 Feb 2025 19:14:05 +0800 Subject: [PATCH 174/221] Update train_network.py --- train_network.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 6ec1553a7..7105a1fa8 100644 --- a/train_network.py +++ b/train_network.py @@ -1028,8 +1028,9 @@ def remove_model(old_ckpt_name): if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - - self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) + if args.sample_every_n_steps is not None and steps % args.sample_every_n_steps != 0: + example_tuple = (latents, batch["captions"]) + self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: @@ -1085,7 +1086,9 @@ def remove_model(old_ckpt_name): if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) + if args.sample_every_n_epochs is not None and (epoch + 1)% args.sample_every_n_epochs != 0: + example_tuple = (latents, batch["captions"]) + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) # end of epoch From c1e387f9f9321b85c47e32f33f1351cb46ef679c Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 7 Feb 2025 01:18:43 +0800 Subject: [PATCH 175/221] Update train_network.py --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 7105a1fa8..e7c67eb4c 100644 --- a/train_network.py +++ b/train_network.py @@ -1028,7 +1028,7 @@ def remove_model(old_ckpt_name): if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - if args.sample_every_n_steps is not None and steps % args.sample_every_n_steps != 0: + if args.sample_every_n_steps is not None and global_step % args.sample_every_n_steps == 0: example_tuple = (latents, batch["captions"]) self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) From 8d14893598fd9dc400cc2144b27f0e7bd1f71106 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 7 Feb 2025 01:28:26 +0800 Subject: [PATCH 176/221] Update train_network.py --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index e7c67eb4c..9f4e3a872 100644 --- a/train_network.py +++ b/train_network.py @@ -1086,7 +1086,7 @@ def remove_model(old_ckpt_name): if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - if args.sample_every_n_epochs is not None and (epoch + 1)% args.sample_every_n_epochs != 0: + if args.sample_every_n_epochs is not None and (epoch + 1)% args.sample_every_n_epochs == 0: example_tuple = (latents, batch["captions"]) self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple) From 984441c014188545d8bd259c5c28fd0b19896c74 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 10:30:31 +0800 Subject: [PATCH 177/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index e28212c78..0c75731b3 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2389,7 +2389,10 @@ def scale_and_round(x): all_seeds = gather_object(seeds) all_clip_prompts = gather_object(clip_prompts) all_raw_prompts = gather_object(raw_prompts) - all_init_images = gather_object(init_images) + if init_images not None: + all_init_images = gather_object(init_images) + else: + all_init_images = None if distributed_state.is_main_process: highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" From 94c57044af1dca48a2cd147228ae206d6d953af0 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 10:32:41 +0800 Subject: [PATCH 178/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 0c75731b3..abcc5a058 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2389,7 +2389,7 @@ def scale_and_round(x): all_seeds = gather_object(seeds) all_clip_prompts = gather_object(clip_prompts) all_raw_prompts = gather_object(raw_prompts) - if init_images not None: + if init_images is not None: all_init_images = gather_object(init_images) else: all_init_images = None From a53aa0ad199ebd15e979c8c2907e314a36720641 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 10:38:22 +0800 Subject: [PATCH 179/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index abcc5a058..97ce5e366 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2916,14 +2916,14 @@ def scale_and_round(x): n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) ext_separated_list_of_batches.append(split_into_batches) - if distributed_state.num_processes > 1: - for x in range(len(ext_separated_list_of_batches)): - temp_list = [] - for i in range(distributed_state.num_processes): - temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) - ext_separated_list_of_batches[x] = [] - for batches in temp_list: - ext_separated_list_of_batches[x].extend(batches) + #if distributed_state.num_processes > 1: + # for x in range(len(ext_separated_list_of_batches)): + # temp_list = [] + # for i in range(distributed_state.num_processes): + # temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) + # ext_separated_list_of_batches[x] = [] + # for batches in temp_list: + # ext_separated_list_of_batches[x].extend(batches) distributed_state.wait_for_everyone() ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) From 03611217bfd217bebb7426bff6fdbed9831e7a25 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 10:46:46 +0800 Subject: [PATCH 180/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 97ce5e366..3b25d2c51 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2926,7 +2926,7 @@ def scale_and_round(x): # ext_separated_list_of_batches[x].extend(batches) distributed_state.wait_for_everyone() - ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) + # ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo if len(ext_separated_list_of_batches) > 0: From b56634093ce9f31c2d857646ff55a712b5255f06 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 10:53:17 +0800 Subject: [PATCH 181/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 3b25d2c51..d72753c97 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2889,7 +2889,7 @@ def scale_and_round(x): extinfo = gather_object(extinfo) ext_separated_list_of_batches = [] - if len(prompt_data_list) > 0 and distributed_state.is_main_process: + if len(prompt_data_list) > 0: unique_extinfo = list(set(extinfo)) # splits list of prompts into sublists where BatchDataExt ext is identical for i in range(len(unique_extinfo)): @@ -2916,14 +2916,14 @@ def scale_and_round(x): n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) ext_separated_list_of_batches.append(split_into_batches) - #if distributed_state.num_processes > 1: - # for x in range(len(ext_separated_list_of_batches)): - # temp_list = [] - # for i in range(distributed_state.num_processes): - # temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) - # ext_separated_list_of_batches[x] = [] - # for batches in temp_list: - # ext_separated_list_of_batches[x].extend(batches) + if distributed_state.num_processes > 1: + for x in range(len(ext_separated_list_of_batches)): + temp_list = [] + for i in range(distributed_state.num_processes): + temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) + ext_separated_list_of_batches[x] = [] + for batches in temp_list: + ext_separated_list_of_batches[x].extend(batches) distributed_state.wait_for_everyone() # ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) From c126db949ed79d7831642542c3ec43e9386b790a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 11:08:34 +0800 Subject: [PATCH 182/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index d72753c97..459ff496f 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2916,14 +2916,14 @@ def scale_and_round(x): n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) ext_separated_list_of_batches.append(split_into_batches) - if distributed_state.num_processes > 1: - for x in range(len(ext_separated_list_of_batches)): - temp_list = [] - for i in range(distributed_state.num_processes): - temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) - ext_separated_list_of_batches[x] = [] - for batches in temp_list: - ext_separated_list_of_batches[x].extend(batches) + #if distributed_state.num_processes > 1: + # for x in range(len(ext_separated_list_of_batches)): + # temp_list = [] + # for i in range(distributed_state.num_processes): + # temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) + # ext_separated_list_of_batches[x] = [] + # for batches in temp_list: + # ext_separated_list_of_batches[x].extend(batches) distributed_state.wait_for_everyone() # ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) From 6e30e5726b4404d9bfca16046350377898f6fd5d Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 12:51:18 +0800 Subject: [PATCH 183/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 459ff496f..e2fe2edaf 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2928,15 +2928,16 @@ def scale_and_round(x): distributed_state.wait_for_everyone() # ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo - + logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") if len(ext_separated_list_of_batches) > 0: for batch_list in ext_separated_list_of_batches: with torch.no_grad(): with distributed_state.split_between_processes(batch_list) as batches: for j in range(len(batches)): - logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:") + batchlogstr=f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:" for i in range(len(batches[j])): - logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") + batchlogstr += f"\nImage: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}" + logger.info(batchlogstr) prev_image = process_batch(batch_list[j], distributed_state, highres_fix)[0] distributed_state.wait_for_everyone() From 68a039f57bb5b803dc23e755c92b451c13fdd55a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 18:17:51 +0800 Subject: [PATCH 184/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index e2fe2edaf..6d28b945a 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2916,14 +2916,16 @@ def scale_and_round(x): n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) ext_separated_list_of_batches.append(split_into_batches) - #if distributed_state.num_processes > 1: - # for x in range(len(ext_separated_list_of_batches)): - # temp_list = [] - # for i in range(distributed_state.num_processes): - # temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) - # ext_separated_list_of_batches[x] = [] - # for batches in temp_list: - # ext_separated_list_of_batches[x].extend(batches) + logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") + if distributed_state.num_processes > 1: + for x in range(len(ext_separated_list_of_batches)): + temp_list = [] + for i in range(distributed_state.num_processes): + temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) + holder = [] + for batches in temp_list: + holder.extend(batches) + ext_separated_list_of_batches[x] = holder[:] distributed_state.wait_for_everyone() # ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) From 79962b8e8e8cd31a790a38560b0c8a2af02d4144 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 18:39:10 +0800 Subject: [PATCH 185/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 1 + 1 file changed, 1 insertion(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 6d28b945a..22cf9fa6e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2913,6 +2913,7 @@ def scale_and_round(x): sublist.extend(split_into_batches.pop(-1)) split_into_batches = [] + sublist.reverse() n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) ext_separated_list_of_batches.append(split_into_batches) From 1738b7f419cd52d09422b50e1baa89d9c3bb75ac Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 18:41:24 +0800 Subject: [PATCH 186/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 22cf9fa6e..ea93597cb 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2889,7 +2889,7 @@ def scale_and_round(x): extinfo = gather_object(extinfo) ext_separated_list_of_batches = [] - if len(prompt_data_list) > 0: + if len(prompt_data_list) > 0 and distributed_state.is_main_process: unique_extinfo = list(set(extinfo)) # splits list of prompts into sublists where BatchDataExt ext is identical for i in range(len(unique_extinfo)): @@ -2929,7 +2929,7 @@ def scale_and_round(x): ext_separated_list_of_batches[x] = holder[:] distributed_state.wait_for_everyone() - # ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) + ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") if len(ext_separated_list_of_batches) > 0: From 0105b920aba4390e30e5b4dc7a7e8d096e3fb338 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 19:17:45 +0800 Subject: [PATCH 187/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index ea93597cb..87f377bdb 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2897,7 +2897,8 @@ def scale_and_round(x): res = [j for j, val in enumerate(prompt_data_list) if val.ext == unique_extinfo[i]] for index in res: templist.append(prompt_data_list[index]) - split_into_batches = get_batches(items=templist, batch_size=args.batch_size) + split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy() + ''' sublist = [] if(len(split_into_batches) % distributed_state.num_processes != 0): #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch @@ -2916,7 +2917,9 @@ def scale_and_round(x): sublist.reverse() n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) + ''' ext_separated_list_of_batches.append(split_into_batches) + ''' logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") if distributed_state.num_processes > 1: for x in range(len(ext_separated_list_of_batches)): @@ -2927,11 +2930,11 @@ def scale_and_round(x): for batches in temp_list: holder.extend(batches) ext_separated_list_of_batches[x] = holder[:] - + ''' distributed_state.wait_for_everyone() ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo - logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") + #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") if len(ext_separated_list_of_batches) > 0: for batch_list in ext_separated_list_of_batches: with torch.no_grad(): From f06d23cb1f468bd2f32cb16a76332cf3b3c86aee Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 21:05:34 +0800 Subject: [PATCH 188/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 87f377bdb..0d50f20ef 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1513,13 +1513,12 @@ def main(args): logger.info(f"preferred device: {device}, {distributed_state.is_main_process}") clean_memory_on_device(device) model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) - for pi in range(distributed_state.num_processes): - if pi == distributed_state.local_process_index: - logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype - ) - unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) + + logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype + ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) distributed_state.wait_for_everyone() # xformers、Hypernetwork対応 @@ -2897,8 +2896,14 @@ def scale_and_round(x): res = [j for j, val in enumerate(prompt_data_list) if val.ext == unique_extinfo[i]] for index in res: templist.append(prompt_data_list[index]) + if distributed_state.num_processes > 1: + resorted_list = [] + for i in range(distributed_state.num_processes): + resorted_list.append(templist[i :: distributed_state.num_processes]) + for list in resorted_list: + templist.extend(list) split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy() - ''' + sublist = [] if(len(split_into_batches) % distributed_state.num_processes != 0): #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch @@ -2917,7 +2922,7 @@ def scale_and_round(x): sublist.reverse() n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) - ''' + ext_separated_list_of_batches.append(split_into_batches) ''' logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") From c0c17a1ceda63222ceb7b0c647c7ac4b24e02d50 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 21:22:40 +0800 Subject: [PATCH 189/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 0d50f20ef..ea05fbf65 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2900,8 +2900,8 @@ def scale_and_round(x): resorted_list = [] for i in range(distributed_state.num_processes): resorted_list.append(templist[i :: distributed_state.num_processes]) - for list in resorted_list: - templist.extend(list) + for list_of_prompts in resorted_list: + templist.extend(list_of_prompts) split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy() sublist = [] From fcb046b8a3476e17b362e3334f820bc499036f74 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 21:40:08 +0800 Subject: [PATCH 190/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index ea05fbf65..dea87ff26 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2937,6 +2937,15 @@ def scale_and_round(x): ext_separated_list_of_batches[x] = holder[:] ''' distributed_state.wait_for_everyone() + if distributed_state.is_main_process: + batchlogstr = "Running through ext_separated_list_of_batches:\n" + for x in range(len(ext_separated_list_of_batches)): + batchlogstr += f"Batch_ext {x} of {len(ext_separated_list_of_batches)} break:\n" + for y in range(len(ext_separated_list_of_batches[x])): + batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])} break:\n" + for z in range(len(ext_separated_list_of_batches[x][y])): + batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break:\n" + logger.info(batchlogstr) ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") From 3be7042b6d4a5b12d1b62e8cdb794d4f77497214 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 21:41:43 +0800 Subject: [PATCH 191/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index dea87ff26..7101f9082 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2938,7 +2938,7 @@ def scale_and_round(x): ''' distributed_state.wait_for_everyone() if distributed_state.is_main_process: - batchlogstr = "Running through ext_separated_list_of_batches:\n" + batchlogstr = "Running through ext_separated_list_of_batches Before Gather:\n" for x in range(len(ext_separated_list_of_batches)): batchlogstr += f"Batch_ext {x} of {len(ext_separated_list_of_batches)} break:\n" for y in range(len(ext_separated_list_of_batches[x])): @@ -2947,6 +2947,15 @@ def scale_and_round(x): batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break:\n" logger.info(batchlogstr) ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) + if distributed_state.is_main_process: + batchlogstr = "Running through ext_separated_list_of_batches After Gather:\n" + for x in range(len(ext_separated_list_of_batches)): + batchlogstr += f"Batch_ext {x} of {len(ext_separated_list_of_batches)} break:\n" + for y in range(len(ext_separated_list_of_batches[x])): + batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])} break:\n" + for z in range(len(ext_separated_list_of_batches[x][y])): + batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break:\n" + logger.info(batchlogstr) del extinfo #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") if len(ext_separated_list_of_batches) > 0: From 8af7ab1add6b6b2ee515ffca2ad0829c92821883 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 21:43:12 +0800 Subject: [PATCH 192/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 7101f9082..8e9764699 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2944,7 +2944,7 @@ def scale_and_round(x): for y in range(len(ext_separated_list_of_batches[x])): batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])} break:\n" for z in range(len(ext_separated_list_of_batches[x][y])): - batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break:\n" + batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" logger.info(batchlogstr) ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) if distributed_state.is_main_process: @@ -2954,7 +2954,7 @@ def scale_and_round(x): for y in range(len(ext_separated_list_of_batches[x])): batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])} break:\n" for z in range(len(ext_separated_list_of_batches[x][y])): - batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break:\n" + batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" logger.info(batchlogstr) del extinfo #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") From 19e6bf0979d624bfe3aae2c24b89892ca7c8ecc9 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 21:47:02 +0800 Subject: [PATCH 193/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 8e9764699..35d99c63e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2939,22 +2939,22 @@ def scale_and_round(x): distributed_state.wait_for_everyone() if distributed_state.is_main_process: batchlogstr = "Running through ext_separated_list_of_batches Before Gather:\n" - for x in range(len(ext_separated_list_of_batches)): - batchlogstr += f"Batch_ext {x} of {len(ext_separated_list_of_batches)} break:\n" - for y in range(len(ext_separated_list_of_batches[x])): - batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])} break:\n" - for z in range(len(ext_separated_list_of_batches[x][y])): - batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" + for x in range(len(ext_separated_list_of_batches)): + batchlogstr += f"Batch_ext {x} of {len(ext_separated_list_of_batches)} break:\n" + for y in range(len(ext_separated_list_of_batches[x])): + batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])} break:\n" + for z in range(len(ext_separated_list_of_batches[x][y])): + batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" logger.info(batchlogstr) ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) if distributed_state.is_main_process: batchlogstr = "Running through ext_separated_list_of_batches After Gather:\n" - for x in range(len(ext_separated_list_of_batches)): - batchlogstr += f"Batch_ext {x} of {len(ext_separated_list_of_batches)} break:\n" - for y in range(len(ext_separated_list_of_batches[x])): - batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])} break:\n" - for z in range(len(ext_separated_list_of_batches[x][y])): - batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" + for x in range(len(ext_separated_list_of_batches)): + batchlogstr += f"Batch_ext {x} of {len(ext_separated_list_of_batches)} break:\n" + for y in range(len(ext_separated_list_of_batches[x])): + batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])} break:\n" + for z in range(len(ext_separated_list_of_batches[x][y])): + batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" logger.info(batchlogstr) del extinfo #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") From c4dd71212bd2a4c403003ac19d55304cc5fad39b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 21:55:20 +0800 Subject: [PATCH 194/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 35d99c63e..27ed79aa1 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2945,16 +2945,16 @@ def scale_and_round(x): batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])} break:\n" for z in range(len(ext_separated_list_of_batches[x][y])): batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" - logger.info(batchlogstr) + logger.info(batchlogstr) ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) - if distributed_state.is_main_process: - batchlogstr = "Running through ext_separated_list_of_batches After Gather:\n" - for x in range(len(ext_separated_list_of_batches)): - batchlogstr += f"Batch_ext {x} of {len(ext_separated_list_of_batches)} break:\n" - for y in range(len(ext_separated_list_of_batches[x])): - batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])} break:\n" - for z in range(len(ext_separated_list_of_batches[x][y])): - batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" + + batchlogstr = "Running through ext_separated_list_of_batches After Gather:\n" + for x in range(len(ext_separated_list_of_batches)): + batchlogstr += f"Batch_ext {x} of {len(ext_separated_list_of_batches)}:\n" + for y in range(len(ext_separated_list_of_batches[x])): + batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])}:\n" + for z in range(len(ext_separated_list_of_batches[x][y])): + batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])}: {ext_separated_list_of_batches[x][y][z].global_count}\n" logger.info(batchlogstr) del extinfo #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") From 4a744c375c8d597cfc5d3ce86087c25038086dbe Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 21:57:01 +0800 Subject: [PATCH 195/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 27ed79aa1..8062bdbce 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2896,12 +2896,14 @@ def scale_and_round(x): res = [j for j, val in enumerate(prompt_data_list) if val.ext == unique_extinfo[i]] for index in res: templist.append(prompt_data_list[index]) + ''' if distributed_state.num_processes > 1: resorted_list = [] for i in range(distributed_state.num_processes): resorted_list.append(templist[i :: distributed_state.num_processes]) for list_of_prompts in resorted_list: templist.extend(list_of_prompts) + ''' split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy() sublist = [] From 0ee1083aec05fbc516f226bd8854dd9ecf60ca6d Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 22:06:37 +0800 Subject: [PATCH 196/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 1 - 1 file changed, 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 8062bdbce..6befc4303 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2921,7 +2921,6 @@ def scale_and_round(x): sublist.extend(split_into_batches.pop(-1)) split_into_batches = [] - sublist.reverse() n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) From 7aa333796de0dc0e4d2233a650cb5642c4da12f6 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 22:12:29 +0800 Subject: [PATCH 197/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 109 +++++++++++++++++++++--------------------- 1 file changed, 55 insertions(+), 54 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 6befc4303..df92717e0 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2380,6 +2380,7 @@ def scale_and_round(x): return images # save image + ''' distributed_state.wait_for_everyone() all_images=gather_object(images) all_global_counter = gather_object(global_counter) @@ -2393,62 +2394,62 @@ def scale_and_round(x): else: all_init_images = None if distributed_state.is_main_process: - - highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ds_str = time.strftime("%Y%m%d", time.localtime()) - ts_str = time.strftime("%H%M%S", time.localtime()) - for i, (image, globalcount, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(all_images, all_global_counter, all_prompts, all_negative_prompts, all_seeds, all_clip_prompts, all_raw_prompts) - ): - if highres_fix: - seed -= 1 # record original seed - metadata = PngInfo() - metadata.add_text("prompt", prompt) - metadata.add_text("seed", str(seed)) - metadata.add_text("sampler", args.sampler) - metadata.add_text("steps", str(steps)) - metadata.add_text("scale", str(scale)) - if negative_prompt is not None: - metadata.add_text("negative-prompt", negative_prompt) - if negative_scale is not None: - metadata.add_text("negative-scale", str(negative_scale)) - if clip_prompt is not None: - metadata.add_text("clip-prompt", clip_prompt) - if raw_prompt is not None: - metadata.add_text("raw-prompt", raw_prompt) - metadata.add_text("original-height", str(original_height)) - metadata.add_text("original-width", str(original_width)) - metadata.add_text("original-height-negative", str(original_height_negative)) - metadata.add_text("original-width-negative", str(original_width_negative)) - metadata.add_text("crop-top", str(crop_top)) - metadata.add_text("crop-left", str(crop_left)) - - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" + ''' + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ds_str = time.strftime("%Y%m%d", time.localtime()) + ts_str = time.strftime("%H%M%S", time.localtime()) + for i, (image, globalcount, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(all_images, all_global_counter, all_prompts, all_negative_prompts, all_seeds, all_clip_prompts, all_raw_prompts) + ): + if highres_fix: + seed -= 1 # record original seed + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" else: - fln = f"im_{ds_str}_{ts_str}_{globalcount:02d}_{highres_prefix}{i:03d}_{seed}.png" - logger.info(f"Saving image {globalcount}: {fln}\nPrompt: {prompt}") - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) - - if not args.no_preview and not highres_1st and args.interactive: - try: - import cv2 - - for prompt, image in zip(prompts, images): - cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ - cv2.waitKey() - cv2.destroyAllWindows() - except ImportError: - logger.error( - "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" - ) + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ds_str}_{ts_str}_{globalcount:02d}_{highres_prefix}{i:03d}_{seed}.png" + logger.info(f"Saving image {globalcount}: {fln}\nPrompt: {prompt}") + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + logger.error( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) - distributed_state.wait_for_everyone() + #distributed_state.wait_for_everyone() return images # 画像生成のプロンプトが一周するまでのループ From 22ed684fb3e066d542e94e4a39c61e84c76d666b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 22:18:49 +0800 Subject: [PATCH 198/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index df92717e0..0a83931d9 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2399,7 +2399,7 @@ def scale_and_round(x): ds_str = time.strftime("%Y%m%d", time.localtime()) ts_str = time.strftime("%H%M%S", time.localtime()) for i, (image, globalcount, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(all_images, all_global_counter, all_prompts, all_negative_prompts, all_seeds, all_clip_prompts, all_raw_prompts) + zip(images, global_counter, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): if highres_fix: seed -= 1 # record original seed From 375971daa6d6225f2f6445c41684418c9f671303 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 22:31:49 +0800 Subject: [PATCH 199/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 0a83931d9..d5a8d0643 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2949,15 +2949,6 @@ def scale_and_round(x): batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" logger.info(batchlogstr) ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) - - batchlogstr = "Running through ext_separated_list_of_batches After Gather:\n" - for x in range(len(ext_separated_list_of_batches)): - batchlogstr += f"Batch_ext {x} of {len(ext_separated_list_of_batches)}:\n" - for y in range(len(ext_separated_list_of_batches[x])): - batchlogstr += f" Batch {y} of {len(ext_separated_list_of_batches[x])}:\n" - for z in range(len(ext_separated_list_of_batches[x][y])): - batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])}: {ext_separated_list_of_batches[x][y][z].global_count}\n" - logger.info(batchlogstr) del extinfo #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") if len(ext_separated_list_of_batches) > 0: From 36617e95cdeca8c8a1931cb87cd9a60ffca323a4 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 23:27:32 +0800 Subject: [PATCH 200/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index d5a8d0643..1d475d350 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2398,6 +2398,8 @@ def scale_and_round(x): highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ds_str = time.strftime("%Y%m%d", time.localtime()) ts_str = time.strftime("%H%M%S", time.localtime()) + metadatas = [] + filenames = [] for i, (image, globalcount, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( zip(images, global_counter, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) ): @@ -2435,6 +2437,8 @@ def scale_and_round(x): fln = f"im_{ds_str}_{ts_str}_{globalcount:02d}_{highres_prefix}{i:03d}_{seed}.png" logger.info(f"Saving image {globalcount}: {fln}\nPrompt: {prompt}") image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + metadatas.append(metadata) + filenames.append("test_"+fln) if not args.no_preview and not highres_1st and args.interactive: try: @@ -2450,7 +2454,7 @@ def scale_and_round(x): ) #distributed_state.wait_for_everyone() - return images + return images, metadatas, filenames # 画像生成のプロンプトが一周するまでのループ prompt_index = 0 @@ -2960,7 +2964,18 @@ def scale_and_round(x): for i in range(len(batches[j])): batchlogstr += f"\nImage: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}" logger.info(batchlogstr) - prev_image = process_batch(batch_list[j], distributed_state, highres_fix)[0] + prev_image, prev_metadata, prev_filename = process_batch(batch_list[j], distributed_state, highres_fix)[0] + distributed_state.wait_for_everyone() + all_images = gather_object(prev_image) + all_metadatas = gather_object(prev_metadata) + all_filenames = gather_object(prev_filename) + if distributed_state.is_main_process: + for image, metadata, filename in zip(all_images, all_metadatas, all_filenames): + logger.info(f"Saving image: {fln}") + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + distributed_state.wait_for_everyone() + + distributed_state.wait_for_everyone() #for i in range(len(data_loader)): From 5a01f0695df9c7d73049ced1ac413632667e7bf9 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 23:33:49 +0800 Subject: [PATCH 201/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 1d475d350..b5f9c10bd 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2972,12 +2972,9 @@ def scale_and_round(x): if distributed_state.is_main_process: for image, metadata, filename in zip(all_images, all_metadatas, all_filenames): logger.info(f"Saving image: {fln}") - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + image.save(os.path.join(args.outdir, filename), pnginfo=metadata) distributed_state.wait_for_everyone() - - - distributed_state.wait_for_everyone() #for i in range(len(data_loader)): # logger.info(f"Loading Batch {i+1} of {len(data_loader)}") # prompt_data_list_split.append(data_loader[i]) From 95dd3677a0945e6e27c4b22390601616d73f1a5b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 00:02:31 +0800 Subject: [PATCH 202/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index b5f9c10bd..0f7304587 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2964,8 +2964,13 @@ def scale_and_round(x): for i in range(len(batches[j])): batchlogstr += f"\nImage: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}" logger.info(batchlogstr) - prev_image, prev_metadata, prev_filename = process_batch(batch_list[j], distributed_state, highres_fix)[0] + test = process_batch(batch_list[j], distributed_state, highres_fix)[0] + logger.info(f"test: {len(test)}") + prev_image = test[0] + prev_metadata = test[1] + prev_filename = test[2] distributed_state.wait_for_everyone() + all_images = gather_object(prev_image) all_metadatas = gather_object(prev_metadata) all_filenames = gather_object(prev_filename) From de8620f865fcc8b88c5c773b332adc0ecdd7185c Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 00:12:20 +0800 Subject: [PATCH 203/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 0f7304587..b4b33541a 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2454,7 +2454,7 @@ def scale_and_round(x): ) #distributed_state.wait_for_everyone() - return images, metadatas, filenames + return (images, metadatas, filenames) # 画像生成のプロンプトが一周するまでのループ prompt_index = 0 From d865824533711a85d129a54a2be3d06bda9764eb Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 00:21:00 +0800 Subject: [PATCH 204/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index b4b33541a..8c92e31e3 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2169,7 +2169,7 @@ def scale_and_round(x): batch_1st.append(BatchData(is_1st_latent, global_count, base, ext_1st)) pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする - images_1st = process_batch(batch_1st, distributed_state, True, True) + images_1st, _, _ = process_batch(batch_1st, distributed_state, True, True) # 2nd stageのバッチを作成して以下処理する logger.info("process 2nd stage") @@ -2377,7 +2377,7 @@ def scale_and_round(x): clip_guide_images=guide_images, ) if highres_1st and not args.highres_fix_save_1st: # return images or latents - return images + return images, None, None # save image ''' @@ -2454,7 +2454,7 @@ def scale_and_round(x): ) #distributed_state.wait_for_everyone() - return (images, metadatas, filenames) + return images, metadatas, filenames # 画像生成のプロンプトが一周するまでのループ prompt_index = 0 @@ -2964,11 +2964,10 @@ def scale_and_round(x): for i in range(len(batches[j])): batchlogstr += f"\nImage: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}" logger.info(batchlogstr) - test = process_batch(batch_list[j], distributed_state, highres_fix)[0] - logger.info(f"test: {len(test)}") - prev_image = test[0] - prev_metadata = test[1] - prev_filename = test[2] + prev_image, prev_metadata, prev_filename = process_batch(batch_list[j], distributed_state, highres_fix)[0] + logger.info(f"prev_image: {len(prev_image)}") + logger.info(f"prev_metadata: {len(prev_metadata)}") + logger.info(f"prev_filename: {len(prev_filename)}") distributed_state.wait_for_everyone() all_images = gather_object(prev_image) From 13150c7dee5e757d713eb58dbd0dae8b74a27d98 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 00:30:46 +0800 Subject: [PATCH 205/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 8c92e31e3..3e74fea0e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2435,7 +2435,7 @@ def scale_and_round(x): fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" else: fln = f"im_{ds_str}_{ts_str}_{globalcount:02d}_{highres_prefix}{i:03d}_{seed}.png" - logger.info(f"Saving image {globalcount}: {fln}\nPrompt: {prompt}") + logger.info(f"Saving image {globalcount}: {fln}\nPrompt: {prompt} on process {distributed_state.local_process_index}") image.save(os.path.join(args.outdir, fln), pnginfo=metadata) metadatas.append(metadata) filenames.append("test_"+fln) @@ -2454,6 +2454,7 @@ def scale_and_round(x): ) #distributed_state.wait_for_everyone() + logger.info(f"images: {len(images)} metadatas: {len(metadatas)} filenames: {len(filenames)} in process: {distributed_state.local_process_index}) return images, metadatas, filenames # 画像生成のプロンプトが一周するまでのループ From 213faba728d4ba9065d7e0ef1a807ba134d73da4 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 00:31:47 +0800 Subject: [PATCH 206/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 3e74fea0e..ef58de868 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2454,7 +2454,7 @@ def scale_and_round(x): ) #distributed_state.wait_for_everyone() - logger.info(f"images: {len(images)} metadatas: {len(metadatas)} filenames: {len(filenames)} in process: {distributed_state.local_process_index}) + logger.info(f"images: {len(images)} metadatas: {len(metadatas)} filenames: {len(filenames)} in process: {distributed_state.local_process_index}") return images, metadatas, filenames # 画像生成のプロンプトが一周するまでのループ From dfdfb6dfc8568455c4f7be0805fffcb1c01964cb Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 00:40:41 +0800 Subject: [PATCH 207/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index ef58de868..d3bedccdf 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2965,7 +2965,7 @@ def scale_and_round(x): for i in range(len(batches[j])): batchlogstr += f"\nImage: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}" logger.info(batchlogstr) - prev_image, prev_metadata, prev_filename = process_batch(batch_list[j], distributed_state, highres_fix)[0] + prev_image, prev_metadata, prev_filename = process_batch(batch_list[j], distributed_state, highres_fix) logger.info(f"prev_image: {len(prev_image)}") logger.info(f"prev_metadata: {len(prev_metadata)}") logger.info(f"prev_filename: {len(prev_filename)}") From 8cc503f57039b4f97b2cefc679b569dc7b25c67b Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 00:43:07 +0800 Subject: [PATCH 208/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index d3bedccdf..3b555ee3f 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2965,15 +2965,16 @@ def scale_and_round(x): for i in range(len(batches[j])): batchlogstr += f"\nImage: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}" logger.info(batchlogstr) - prev_image, prev_metadata, prev_filename = process_batch(batch_list[j], distributed_state, highres_fix) + coll_image, coll_metadata, coll_filename = process_batch(batch_list[j], distributed_state, highres_fix) logger.info(f"prev_image: {len(prev_image)}") logger.info(f"prev_metadata: {len(prev_metadata)}") logger.info(f"prev_filename: {len(prev_filename)}") distributed_state.wait_for_everyone() - all_images = gather_object(prev_image) - all_metadatas = gather_object(prev_metadata) - all_filenames = gather_object(prev_filename) + all_images = gather_object(coll_image) + all_metadatas = gather_object(coll_metadata) + all_filenames = gather_object(coll_filename) + prev_image = allimages[0] if distributed_state.is_main_process: for image, metadata, filename in zip(all_images, all_metadatas, all_filenames): logger.info(f"Saving image: {fln}") From a5b8c0bc44aa190ece4ebdbef5db294ebec9b12c Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 00:47:09 +0800 Subject: [PATCH 209/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 3b555ee3f..7f0f1449e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2977,7 +2977,7 @@ def scale_and_round(x): prev_image = allimages[0] if distributed_state.is_main_process: for image, metadata, filename in zip(all_images, all_metadatas, all_filenames): - logger.info(f"Saving image: {fln}") + logger.info(f"Saving image: {filename}") image.save(os.path.join(args.outdir, filename), pnginfo=metadata) distributed_state.wait_for_everyone() From 84934b3be30eb797d6ac0ba47e8ada97b4b0364a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 00:52:38 +0800 Subject: [PATCH 210/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 7f0f1449e..6cf51e73e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2960,12 +2960,12 @@ def scale_and_round(x): for batch_list in ext_separated_list_of_batches: with torch.no_grad(): with distributed_state.split_between_processes(batch_list) as batches: - for j in range(len(batches)): - batchlogstr=f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:" - for i in range(len(batches[j])): - batchlogstr += f"\nImage: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}" + for batch in batches: + batchlogstr=f"\nLoading batch {j+1}/{len(batches)} of {len(batch)} prompts onto device {distributed_state.local_process_index}:\nbatch_list:" + for i in range(len(batch)): + batchlogstr += f"\nImage: {batch[i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batch[i].base.prompt}\nNegative Prompt: {batch[i].base.negative_prompt}\nSeed: {batch[i].base.seed}" logger.info(batchlogstr) - coll_image, coll_metadata, coll_filename = process_batch(batch_list[j], distributed_state, highres_fix) + coll_image, coll_metadata, coll_filename = process_batch(batch, distributed_state, highres_fix) logger.info(f"prev_image: {len(prev_image)}") logger.info(f"prev_metadata: {len(prev_metadata)}") logger.info(f"prev_filename: {len(prev_filename)}") From a215e3ffe63812e5cc15690f8aff8b697b22c2e3 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 00:53:47 +0800 Subject: [PATCH 211/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 6cf51e73e..bfd7a28fb 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2966,9 +2966,9 @@ def scale_and_round(x): batchlogstr += f"\nImage: {batch[i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batch[i].base.prompt}\nNegative Prompt: {batch[i].base.negative_prompt}\nSeed: {batch[i].base.seed}" logger.info(batchlogstr) coll_image, coll_metadata, coll_filename = process_batch(batch, distributed_state, highres_fix) - logger.info(f"prev_image: {len(prev_image)}") - logger.info(f"prev_metadata: {len(prev_metadata)}") - logger.info(f"prev_filename: {len(prev_filename)}") + logger.info(f"coll_image: {len(coll_image)}") + logger.info(f"coll_metadata: {len(coll_metadata)}") + logger.info(f"coll_filename: {len(coll_filename)}") distributed_state.wait_for_everyone() all_images = gather_object(coll_image) From ebbe9490cac6963e5f428efeabbf52d84b149272 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 00:55:19 +0800 Subject: [PATCH 212/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index bfd7a28fb..54b4b44be 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2961,7 +2961,7 @@ def scale_and_round(x): with torch.no_grad(): with distributed_state.split_between_processes(batch_list) as batches: for batch in batches: - batchlogstr=f"\nLoading batch {j+1}/{len(batches)} of {len(batch)} prompts onto device {distributed_state.local_process_index}:\nbatch_list:" + batchlogstr=f"\nLoading batch of {len(batches)} of {len(batch)} prompts onto device {distributed_state.local_process_index}:\nbatch_list:" for i in range(len(batch)): batchlogstr += f"\nImage: {batch[i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batch[i].base.prompt}\nNegative Prompt: {batch[i].base.negative_prompt}\nSeed: {batch[i].base.seed}" logger.info(batchlogstr) From 58d6434587a00f75eb88ce075baba132e6c5f727 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 01:01:57 +0800 Subject: [PATCH 213/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 54b4b44be..d8c8f044b 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2974,7 +2974,7 @@ def scale_and_round(x): all_images = gather_object(coll_image) all_metadatas = gather_object(coll_metadata) all_filenames = gather_object(coll_filename) - prev_image = allimages[0] + prev_image = all_images[0] if distributed_state.is_main_process: for image, metadata, filename in zip(all_images, all_metadatas, all_filenames): logger.info(f"Saving image: {filename}") From 0a7408390813899ea5293061b98db2c76d63e748 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 01:21:05 +0800 Subject: [PATCH 214/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index d8c8f044b..0f2b3bd07 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2902,14 +2902,14 @@ def scale_and_round(x): res = [j for j, val in enumerate(prompt_data_list) if val.ext == unique_extinfo[i]] for index in res: templist.append(prompt_data_list[index]) - ''' + if distributed_state.num_processes > 1: resorted_list = [] for i in range(distributed_state.num_processes): resorted_list.append(templist[i :: distributed_state.num_processes]) for list_of_prompts in resorted_list: templist.extend(list_of_prompts) - ''' + split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy() sublist = [] @@ -2926,6 +2926,7 @@ def scale_and_round(x): elif len(split_into_batches) == 1 : sublist.extend(split_into_batches.pop(-1)) split_into_batches = [] + # sublist = sorted(sublist, key=lambda x: x.global_count) n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) From c6e6643e2da0b2b026a62157af1612d2e77c253f Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 01:31:00 +0800 Subject: [PATCH 215/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 0f2b3bd07..480f3dc60 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2907,6 +2907,7 @@ def scale_and_round(x): resorted_list = [] for i in range(distributed_state.num_processes): resorted_list.append(templist[i :: distributed_state.num_processes]) + templist = [] for list_of_prompts in resorted_list: templist.extend(list_of_prompts) @@ -2945,6 +2946,7 @@ def scale_and_round(x): ext_separated_list_of_batches[x] = holder[:] ''' distributed_state.wait_for_everyone() + ''' if distributed_state.is_main_process: batchlogstr = "Running through ext_separated_list_of_batches Before Gather:\n" for x in range(len(ext_separated_list_of_batches)): @@ -2954,6 +2956,7 @@ def scale_and_round(x): for z in range(len(ext_separated_list_of_batches[x][y])): batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" logger.info(batchlogstr) + ''' ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") @@ -2967,19 +2970,21 @@ def scale_and_round(x): batchlogstr += f"\nImage: {batch[i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batch[i].base.prompt}\nNegative Prompt: {batch[i].base.negative_prompt}\nSeed: {batch[i].base.seed}" logger.info(batchlogstr) coll_image, coll_metadata, coll_filename = process_batch(batch, distributed_state, highres_fix) - logger.info(f"coll_image: {len(coll_image)}") - logger.info(f"coll_metadata: {len(coll_metadata)}") - logger.info(f"coll_filename: {len(coll_filename)}") + #logger.info(f"coll_image: {len(coll_image)}") + #logger.info(f"coll_metadata: {len(coll_metadata)}") + #logger.info(f"coll_filename: {len(coll_filename)}") distributed_state.wait_for_everyone() all_images = gather_object(coll_image) all_metadatas = gather_object(coll_metadata) all_filenames = gather_object(coll_filename) prev_image = all_images[0] + ''' if distributed_state.is_main_process: for image, metadata, filename in zip(all_images, all_metadatas, all_filenames): logger.info(f"Saving image: {filename}") image.save(os.path.join(args.outdir, filename), pnginfo=metadata) + ''' distributed_state.wait_for_everyone() #for i in range(len(data_loader)): From d69619cb29484e86dd5d36d5468b017680464359 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 01:37:09 +0800 Subject: [PATCH 216/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 480f3dc60..057d6122e 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2946,7 +2946,7 @@ def scale_and_round(x): ext_separated_list_of_batches[x] = holder[:] ''' distributed_state.wait_for_everyone() - ''' + if distributed_state.is_main_process: batchlogstr = "Running through ext_separated_list_of_batches Before Gather:\n" for x in range(len(ext_separated_list_of_batches)): @@ -2956,7 +2956,7 @@ def scale_and_round(x): for z in range(len(ext_separated_list_of_batches[x][y])): batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" logger.info(batchlogstr) - ''' + ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") From b640c39b3586541c8ad1d8e36f1d2154817910bd Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 01:45:27 +0800 Subject: [PATCH 217/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 057d6122e..5fbf09064 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2927,7 +2927,13 @@ def scale_and_round(x): elif len(split_into_batches) == 1 : sublist.extend(split_into_batches.pop(-1)) split_into_batches = [] - # sublist = sorted(sublist, key=lambda x: x.global_count) + sublist = sorted(sublist, key=lambda x: x.global_count) + resorted_list = [] + for i in range(distributed_state.num_processes): + resorted_list.append(sublist[i :: distributed_state.num_processes]) + sublist = [] + for list_of_prompts in resorted_list: + sublist.extend(list_of_prompts) n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) From 70fd06dee6952ef8288cf399f40d4490abd9a97c Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 9 Feb 2025 01:59:39 +0800 Subject: [PATCH 218/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 5fbf09064..62a245ed7 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2434,7 +2434,7 @@ def scale_and_round(x): elif args.sequential_file_name: fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" else: - fln = f"im_{ds_str}_{ts_str}_{globalcount:02d}_{highres_prefix}{i:03d}_{seed}.png" + fln = f"im_{globalcount:02d}_{ds_str}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" logger.info(f"Saving image {globalcount}: {fln}\nPrompt: {prompt} on process {distributed_state.local_process_index}") image.save(os.path.join(args.outdir, fln), pnginfo=metadata) metadatas.append(metadata) From 586e89a3ab682e61371a2631074378e944716668 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Tue, 11 Feb 2025 00:35:19 +0800 Subject: [PATCH 219/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 51 ++++++------------------------------------- 1 file changed, 7 insertions(+), 44 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 62a245ed7..d6af05ada 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2908,51 +2908,14 @@ def scale_and_round(x): for i in range(distributed_state.num_processes): resorted_list.append(templist[i :: distributed_state.num_processes]) templist = [] - for list_of_prompts in resorted_list: - templist.extend(list_of_prompts) - - split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy() - - sublist = [] - if(len(split_into_batches) % distributed_state.num_processes != 0): - #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch - popnum = len(split_into_batches) % distributed_state.num_processes + for list_of_prompts in resorted_list: + templist.extend(get_batches(items=list_of_prompts, batch_size=args.batch_size).copy()) + split_into_batches = templist else: - #force distribution check on last round of batches - popnum = distributed_state.num_processes - - for j in range(popnum): - if len(split_into_batches) > 1 : - sublist.extend(split_into_batches.pop(-1)) - elif len(split_into_batches) == 1 : - sublist.extend(split_into_batches.pop(-1)) - split_into_batches = [] - sublist = sorted(sublist, key=lambda x: x.global_count) - resorted_list = [] - for i in range(distributed_state.num_processes): - resorted_list.append(sublist[i :: distributed_state.num_processes]) - sublist = [] - for list_of_prompts in resorted_list: - sublist.extend(list_of_prompts) - - n, m = divmod(len(sublist), distributed_state.num_processes) - split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) - + split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy() + ext_separated_list_of_batches.append(split_into_batches) - ''' - logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") - if distributed_state.num_processes > 1: - for x in range(len(ext_separated_list_of_batches)): - temp_list = [] - for i in range(distributed_state.num_processes): - temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) - holder = [] - for batches in temp_list: - holder.extend(batches) - ext_separated_list_of_batches[x] = holder[:] - ''' - distributed_state.wait_for_everyone() - + if distributed_state.is_main_process: batchlogstr = "Running through ext_separated_list_of_batches Before Gather:\n" for x in range(len(ext_separated_list_of_batches)): @@ -2962,7 +2925,7 @@ def scale_and_round(x): for z in range(len(ext_separated_list_of_batches[x][y])): batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" logger.info(batchlogstr) - + distributed_state.wait_for_everyone() ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") From b8d3c687df25c735e157e9451bc384bc08e6fc51 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Wed, 12 Feb 2025 11:43:05 +0800 Subject: [PATCH 220/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index d6af05ada..586a59eae 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1513,11 +1513,12 @@ def main(args): logger.info(f"preferred device: {device}, {distributed_state.is_main_process}") clean_memory_on_device(device) model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) - - logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype - ) + for pi in range(distributed_state.state.num_processes): + if pi == distributed_state.state.local_process_index: + logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype + ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) distributed_state.wait_for_everyone() From 9b9d20515a2dca2907f9e59fa58dae48a1ab6845 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 13 Feb 2025 01:23:10 +0800 Subject: [PATCH 221/221] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 586a59eae..0a30a2cd0 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1513,8 +1513,8 @@ def main(args): logger.info(f"preferred device: {device}, {distributed_state.is_main_process}") clean_memory_on_device(device) model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) - for pi in range(distributed_state.state.num_processes): - if pi == distributed_state.state.local_process_index: + for pi in range(distributed_state.num_processes): + if pi == distributed_state.local_process_index: logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype