From 5389c8c59c78b23820a59105feb0c8effbc7c256 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Fri, 12 Jul 2024 20:53:05 +0200 Subject: [PATCH 01/13] add the controlnet pipeline for pixart alpha --- src/diffusers/models/__init__.py | 2 + .../models/controlnet_pixart_alpha.py | 292 +++++ .../pipelines/controlnet_pixart/__init__.py | 67 + .../pipeline_pixart_alpha_controlnet.py | 1092 +++++++++++++++++ 4 files changed, 1453 insertions(+) create mode 100644 src/diffusers/models/controlnet_pixart_alpha.py create mode 100644 src/diffusers/pipelines/controlnet_pixart/__init__.py create mode 100644 src/diffusers/pipelines/controlnet_pixart/pipeline_pixart_alpha_controlnet.py diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 39dc149ff6d1..5643286614cb 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -36,6 +36,7 @@ _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"] _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] + _import_structure["controlnet_pixart_alpha"] = ["PixArtControlNetAdapterBlock", "PixArtControlNetAdapterModel", "PixArtControlNetTransformerModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] @@ -82,6 +83,7 @@ from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel + from .controlnet_pixart_alpha import PixArtControlNetAdapterBlock, PixArtControlNetAdapterModel, PixArtControlNetTransformerModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( diff --git a/src/diffusers/models/controlnet_pixart_alpha.py b/src/diffusers/models/controlnet_pixart_alpha.py new file mode 100644 index 000000000000..fbbd75b14ca3 --- /dev/null +++ b/src/diffusers/models/controlnet_pixart_alpha.py @@ -0,0 +1,292 @@ +from typing import Any, Dict, Optional + +import torch +from torch import nn + +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models import PixArtTransformer2DModel +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.modeling_outputs import Transformer2DModelOutput + +class PixArtControlNetAdapterBlock(nn.Module): + def __init__( + self, + block_index, + + # taken from PixArtTransformer2DModel + num_attention_heads: int = 16, + attention_head_dim: int = 72, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = 1152, + attention_bias: bool = True, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: Optional[int] = 1000, + upcast_attention: bool = False, + norm_type: str = "ada_norm_single", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + attention_type: Optional[str] = "default", + ): + super().__init__() + + self.block_index = block_index + self.inner_dim = num_attention_heads * attention_head_dim + + # the first block has a zero before layer + if self.block_index == 0: + self.before_proj = nn.Linear(self.inner_dim, self.inner_dim) + nn.init.zeros_(self.before_proj.weight) + 
nn.init.zeros_(self.before_proj.bias) + + self.transformer_block = BasicTransformerBlock( + self.inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + + self.after_proj = nn.Linear(self.inner_dim, self.inner_dim) + nn.init.zeros_(self.after_proj.weight) + nn.init.zeros_(self.after_proj.bias) + + def train(self, mode: bool = True): + self.transformer_block.train(mode) + + if self.block_index == 0: + self.before_proj.train(mode) + + self.after_proj.train(mode) + + def forward( + self, + hidden_states: torch.Tensor, + controlnet_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ): + if self.block_index == 0: + controlnet_states = self.before_proj(controlnet_states) + controlnet_states = hidden_states + controlnet_states + + controlnet_states_down = self.transformer_block( + hidden_states=controlnet_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + added_cond_kwargs=added_cond_kwargs, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + class_labels=None + ) + + controlnet_states_left = self.after_proj(controlnet_states_down) + + return controlnet_states_left, controlnet_states_down + +class PixArtControlNetAdapterModel(ModelMixin, ConfigMixin): + # N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 ControlNet-Transformer + @register_to_config + def __init__(self, num_layers = 13) -> None: + super().__init__() + + self.num_layers = num_layers + + self.controlnet_blocks = nn.ModuleList( + [ + PixArtControlNetAdapterBlock(block_index=i) + for i in range(num_layers) + ] + ) + + @classmethod + def from_transformer(cls, transformer: PixArtTransformer2DModel): + control_net = PixArtControlNetAdapterModel() + + # copied the specified number of blocks from the transformer + for depth in range(control_net.num_layers): + control_net.controlnet_blocks[depth].transformer_block.load_state_dict(transformer.transformer_blocks[depth].state_dict()) + + return control_net + + def train(self, mode: bool = True): + for block in self.controlnet_blocks: + block.train(mode) + +class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin): + def __init__( + self, + transformer: PixArtTransformer2DModel, + controlnet: PixArtControlNetAdapterModel, + blocks_num=13, + init_from_transformer=False, + training=False + ): + super().__init__() + + self.blocks_num = blocks_num + self.gradient_checkpointing = False + self.register_to_config(**transformer.config) + self.training = training + + if init_from_transformer: + # copies the specified number of blocks from the transformer + controlnet.from_transformer(transformer, self.blocks_num) + + self.transformer = transformer + self.controlnet = controlnet + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + 
hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + controlnet_cond: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + + if self.transformer.use_additional_conditions and added_cond_kwargs is None: + raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.") + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size = hidden_states.shape[0] + height, width = ( + hidden_states.shape[-2] // self.transformer.config.patch_size, + hidden_states.shape[-1] // self.transformer.config.patch_size, + ) + hidden_states = self.transformer.pos_embed(hidden_states) + + timestep, embedded_timestep = self.transformer.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.transformer.caption_projection is not None: + encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + controlnet_states_down = None + if controlnet_cond is not None: + controlnet_states_down = self.transformer.pos_embed(controlnet_cond) + + # 2. 
Blocks + for block_index, block in enumerate(self.transformer.transformer_blocks): + if self.training and self.gradient_checkpointing: + # rc todo: for training and gradient checkpointing + print("Gradient checkpointing is not supported for the controlnet transformer model, yet.") + exit(1) + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + None, + **ckpt_kwargs, + ) + else: + # the control nets are only used for the blocks 1 to self.blocks_num + if block_index > 0 and block_index <= self.blocks_num and controlnet_states_down is not None: + controlnet_states_left, controlnet_states_down = self.controlnet.controlnet_blocks[block_index - 1]( + hidden_states=hidden_states, # used only in the first block + controlnet_states=controlnet_states_down, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + added_cond_kwargs=added_cond_kwargs, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask + ) + + hidden_states = hidden_states + controlnet_states_left + + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=None, + ) + + # 3. 
Output + shift, scale = ( + self.transformer.scale_shift_table[None] + embedded_timestep[:, None].to(self.transformer.scale_shift_table.device) + ).chunk(2, dim=1) + hidden_states = self.transformer.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) + hidden_states = self.transformer.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.transformer.config.patch_size, self.transformer.config.patch_size, self.transformer.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.transformer.out_channels, height * self.transformer.config.patch_size, width * self.transformer.config.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + diff --git a/src/diffusers/pipelines/controlnet_pixart/__init__.py b/src/diffusers/pipelines/controlnet_pixart/__init__.py new file mode 100644 index 000000000000..3df6d67761e9 --- /dev/null +++ b/src/diffusers/pipelines/controlnet_pixart/__init__.py @@ -0,0 +1,67 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pixart_alpha_controlnet"] = ["PixArtAlphaControlnetPipeline"] + # _import_structure["pipeline_controlnet_pixart_sigama"] = ["PixArtSigmaControlnetPipeline"] +try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + pass + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + pass + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/controlnet_pixart/pipeline_pixart_alpha_controlnet.py b/src/diffusers/pipelines/controlnet_pixart/pipeline_pixart_alpha_controlnet.py new file mode 100644 index 000000000000..17075d532288 --- /dev/null +++ b/src/diffusers/pipelines/controlnet_pixart/pipeline_pixart_alpha_controlnet.py @@ -0,0 +1,1092 @@ +# Copyright 2024 
PixArt-Alpha Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer +import PIL +import numpy as np + +from diffusers.image_processor import PixArtImageProcessor, PipelineImageInput, VaeImageProcessor +from diffusers.models import AutoencoderKL, PixArtTransformer2DModel, PixArtControlNetAdapterModel, PixArtControlNetTransformerModel +from diffusers.schedulers import DPMSolverMultistepScheduler +from diffusers.utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines import DiffusionPipeline, ImagePipelineOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PixArtAlphaPipeline + + >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too. + >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." 
+ >>> image = pipe(prompt).images[0] + ``` +""" + +ASPECT_RATIO_1024_BIN = { + "0.25": [512.0, 2048.0], + "0.28": [512.0, 1856.0], + "0.32": [576.0, 1792.0], + "0.33": [576.0, 1728.0], + "0.35": [576.0, 1664.0], + "0.4": [640.0, 1600.0], + "0.42": [640.0, 1536.0], + "0.48": [704.0, 1472.0], + "0.5": [704.0, 1408.0], + "0.52": [704.0, 1344.0], + "0.57": [768.0, 1344.0], + "0.6": [768.0, 1280.0], + "0.68": [832.0, 1216.0], + "0.72": [832.0, 1152.0], + "0.78": [896.0, 1152.0], + "0.82": [896.0, 1088.0], + "0.88": [960.0, 1088.0], + "0.94": [960.0, 1024.0], + "1.0": [1024.0, 1024.0], + "1.07": [1024.0, 960.0], + "1.13": [1088.0, 960.0], + "1.21": [1088.0, 896.0], + "1.29": [1152.0, 896.0], + "1.38": [1152.0, 832.0], + "1.46": [1216.0, 832.0], + "1.67": [1280.0, 768.0], + "1.75": [1344.0, 768.0], + "2.0": [1408.0, 704.0], + "2.09": [1472.0, 704.0], + "2.4": [1536.0, 640.0], + "2.5": [1600.0, 640.0], + "3.0": [1728.0, 576.0], + "4.0": [2048.0, 512.0], +} + +ASPECT_RATIO_512_BIN = { + "0.25": [256.0, 1024.0], + "0.28": [256.0, 928.0], + "0.32": [288.0, 896.0], + "0.33": [288.0, 864.0], + "0.35": [288.0, 832.0], + "0.4": [320.0, 800.0], + "0.42": [320.0, 768.0], + "0.48": [352.0, 736.0], + "0.5": [352.0, 704.0], + "0.52": [352.0, 672.0], + "0.57": [384.0, 672.0], + "0.6": [384.0, 640.0], + "0.68": [416.0, 608.0], + "0.72": [416.0, 576.0], + "0.78": [448.0, 576.0], + "0.82": [448.0, 544.0], + "0.88": [480.0, 544.0], + "0.94": [480.0, 512.0], + "1.0": [512.0, 512.0], + "1.07": [512.0, 480.0], + "1.13": [544.0, 480.0], + "1.21": [544.0, 448.0], + "1.29": [576.0, 448.0], + "1.38": [576.0, 416.0], + "1.46": [608.0, 416.0], + "1.67": [640.0, 384.0], + "1.75": [672.0, 384.0], + "2.0": [704.0, 352.0], + "2.09": [736.0, 352.0], + "2.4": [768.0, 320.0], + "2.5": [800.0, 320.0], + "3.0": [864.0, 288.0], + "4.0": [1024.0, 256.0], +} + +ASPECT_RATIO_256_BIN = { + "0.25": [128.0, 512.0], + "0.28": [128.0, 464.0], + "0.32": [144.0, 448.0], + "0.33": [144.0, 432.0], + "0.35": [144.0, 416.0], + "0.4": [160.0, 400.0], + "0.42": [160.0, 384.0], + "0.48": [176.0, 368.0], + "0.5": [176.0, 352.0], + "0.52": [176.0, 336.0], + "0.57": [192.0, 336.0], + "0.6": [192.0, 320.0], + "0.68": [208.0, 304.0], + "0.72": [208.0, 288.0], + "0.78": [224.0, 288.0], + "0.82": [224.0, 272.0], + "0.88": [240.0, 272.0], + "0.94": [240.0, 256.0], + "1.0": [256.0, 256.0], + "1.07": [256.0, 240.0], + "1.13": [272.0, 240.0], + "1.21": [272.0, 224.0], + "1.29": [288.0, 224.0], + "1.38": [288.0, 208.0], + "1.46": [304.0, 208.0], + "1.67": [320.0, 192.0], + "1.75": [336.0, 192.0], + "2.0": [352.0, 176.0], + "2.09": [368.0, 176.0], + "2.4": [384.0, 160.0], + "2.5": [400.0, 160.0], + "3.0": [432.0, 144.0], + "4.0": [512.0, 128.0], +} + +def get_closest_hw(width, height, image_size): + if image_size == 1024: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif image_size == 512: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + else: + raise ValueError("Invalid image size") + + height, width = PixArtImageProcessor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + return width, height + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. 
Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtAlphaControlnetPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Alpha. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`PixArtTransformer2DModel`]): + A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. 
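+        controlnet ([`PixArtControlNetAdapterModel`]):
+            A ControlNet adapter that supplies control-image conditioning; the pipeline wraps it together with
+            `transformer` into a `PixArtControlNetTransformerModel` used for denoising.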
+ scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + controlnet: PixArtControlNetAdapterModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + # change to the controlnet transformer model + transformer = PixArtControlNetTransformerModel( + transformer=transformer, controlnet=controlnet + ) + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, controlnet=controlnet + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 120, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." 
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.controlnet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = 
negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + image = None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + if image is not None: + self.check_image(image, prompt, prompt_embeds) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." 
+ caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # based on pipeline_pixart_inpaiting.py + def prepare_image_latents(self, image, device, dtype): + image = image.to(device=device, dtype=dtype) + + image_latents = self.vae.encode(image).latent_dist.sample() + image_latents = image_latents * self.vae.config.scaling_factor + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + # rc todo: controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + # rc todo: control_guidance_start = 0.0, + # rc todo: control_guidance_end = 1.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 120, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. 
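+            image (`PipelineImageInput`, *optional*):
+                The ControlNet conditioning image. It is preprocessed, encoded into latents with the VAE, and passed
+                to the ControlNet adapter blocks as `controlnet_cond`.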
+ width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." 
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + image, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 4.1 Prepare image + image_latents = None + if image is not None: + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.transformer.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + image_latents = self.prepare_image_latents(image, device, self.transformer.controlnet.dtype) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. 
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.transformer.config.sample_size == 128: + resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + controlnet_cond=image_latents, + # rc todo: controlnet_conditioning_scale=1.0, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + if num_inference_steps == 1: + # For DMD one step sampling: https://arxiv.org/abs/2311.18828 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == 
"latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) From abcc7703a23eae9420719ffd200358e7bb417e46 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Fri, 12 Jul 2024 21:08:19 +0200 Subject: [PATCH 02/13] import structure for the pixart alpha controlnet pipeline --- src/diffusers/pipelines/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1d5fd5c2d094..7dd3bfb34ce4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -23,6 +23,7 @@ "controlnet_hunyuandit": [], "controlnet_sd3": [], "controlnet_xs": [], + "controlnet_pixart": [], "deprecated": [], "latent_diffusion": [], "ledits_pp": [], @@ -155,6 +156,12 @@ "StableDiffusionXLControlNetXSPipeline", ] ) + _import_structure["controlnet_pixart"].extend( + [ + "PixArtAlphaControlnetPipeline", + "get_closest_hw", + ] + ) _import_structure["controlnet_hunyuandit"].extend( [ "HunyuanDiTControlNetPipeline", @@ -441,6 +448,10 @@ StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline, ) + from .controlnet_pixart import ( + PixArtAlphaControlnetPipeline, + get_closest_hw + ) from .deepfloyd_if import ( IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, From eaa2e21204b1ae5f177cf1e82d212f6fc9d2bda0 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Fri, 12 Jul 2024 22:35:47 +0200 Subject: [PATCH 03/13] use PixArtImageProcessor --- .../controlnet_pixart/pipeline_pixart_alpha_controlnet.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/controlnet_pixart/pipeline_pixart_alpha_controlnet.py b/src/diffusers/pipelines/controlnet_pixart/pipeline_pixart_alpha_controlnet.py index 17075d532288..62ef31a71962 100644 --- a/src/diffusers/pipelines/controlnet_pixart/pipeline_pixart_alpha_controlnet.py +++ b/src/diffusers/pipelines/controlnet_pixart/pipeline_pixart_alpha_controlnet.py @@ -305,9 +305,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) + self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt def encode_prompt( From 3e2ec9a0a2dec79fc898e140152109c23807306b Mon Sep 17 00:00:00 2001 From: raci0399 Date: Mon, 22 Jul 2024 19:00:52 +0200 Subject: [PATCH 04/13] moved the pixart controlnet in examples --- examples/pixart/.gitignore | 1 + .../pixart}/controlnet_pixart_alpha.py | 0 .../pipeline_pixart_alpha_controlnet.py | 3 +- .../run_pixart_alpha_controlnet_pipeline.py | 79 +++++++++++++++++++ src/diffusers/models/__init__.py | 2 - src/diffusers/pipelines/__init__.py | 6 -- .../pipelines/controlnet_pixart/__init__.py | 67 ---------------- 7 files changed, 82 insertions(+), 76 deletions(-) create mode 100644 examples/pixart/.gitignore rename {src/diffusers/models => examples/pixart}/controlnet_pixart_alpha.py (100%) rename {src/diffusers/pipelines/controlnet_pixart => examples/pixart}/pipeline_pixart_alpha_controlnet.py (99%) create mode 100644 examples/pixart/run_pixart_alpha_controlnet_pipeline.py delete mode 100644 
src/diffusers/pipelines/controlnet_pixart/__init__.py diff --git a/examples/pixart/.gitignore b/examples/pixart/.gitignore new file mode 100644 index 000000000000..ba281509fa15 --- /dev/null +++ b/examples/pixart/.gitignore @@ -0,0 +1 @@ +images/ \ No newline at end of file diff --git a/src/diffusers/models/controlnet_pixart_alpha.py b/examples/pixart/controlnet_pixart_alpha.py similarity index 100% rename from src/diffusers/models/controlnet_pixart_alpha.py rename to examples/pixart/controlnet_pixart_alpha.py diff --git a/src/diffusers/pipelines/controlnet_pixart/pipeline_pixart_alpha_controlnet.py b/examples/pixart/pipeline_pixart_alpha_controlnet.py similarity index 99% rename from src/diffusers/pipelines/controlnet_pixart/pipeline_pixart_alpha_controlnet.py rename to examples/pixart/pipeline_pixart_alpha_controlnet.py index 62ef31a71962..8b109a85533c 100644 --- a/src/diffusers/pipelines/controlnet_pixart/pipeline_pixart_alpha_controlnet.py +++ b/examples/pixart/pipeline_pixart_alpha_controlnet.py @@ -24,7 +24,8 @@ import numpy as np from diffusers.image_processor import PixArtImageProcessor, PipelineImageInput, VaeImageProcessor -from diffusers.models import AutoencoderKL, PixArtTransformer2DModel, PixArtControlNetAdapterModel, PixArtControlNetTransformerModel +from diffusers.models import AutoencoderKL, PixArtTransformer2DModel +from controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel from diffusers.schedulers import DPMSolverMultistepScheduler from diffusers.utils import ( BACKENDS_MAPPING, diff --git a/examples/pixart/run_pixart_alpha_controlnet_pipeline.py b/examples/pixart/run_pixart_alpha_controlnet_pipeline.py new file mode 100644 index 000000000000..8cf9afaccadd --- /dev/null +++ b/examples/pixart/run_pixart_alpha_controlnet_pipeline.py @@ -0,0 +1,79 @@ +# pip install transformers SentencePiece torchvision controlnet-aux + +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as TF + +from controlnet_pixart_alpha import PixArtControlNetAdapterModel +from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline +from diffusers.utils import load_image + +from diffusers.image_processor import PixArtImageProcessor + +from controlnet_aux import HEDdetector + +controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet" + +weight_dtype = torch.float16 +image_size = 1024 + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.manual_seed(0) + +# load controlnet +controlnet = PixArtControlNetAdapterModel.from_pretrained( + controlnet_repo_id, + torch_dtype=weight_dtype, + use_safetensors=True, +).to(device) + +pipe = PixArtAlphaControlnetPipeline.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + controlnet=controlnet, + torch_dtype=weight_dtype, + use_safetensors=True, +).to(device) + +images_path = "images" +control_image_file = "0_7.jpg" + +# prompt = "cinematic photo of superman in action . 35mm photograph, film, bokeh, professional, 4k, highly detailed" +# prompt = "yellow modern car, city in background, beautiful rainy day" +# prompt = "modern villa, clear sky, suny day . 35mm photograph, film, bokeh, professional, 4k, highly detailed" +# prompt = "robot dog toy in park . 35mm photograph, film, bokeh, professional, 4k, highly detailed" +# prompt = "purple car, on highway, beautiful sunny day" +# prompt = "realistical photo of a loving couple standing in the open kitchen of the living room, cooking ." 
+prompt = "battleship in space, galaxy in background" + +control_image_name = control_image_file.split('.')[0] + +control_image = load_image(f"{images_path}/{control_image_file}") +print(control_image.size) +height, width = control_image.size + +hed = HEDdetector.from_pretrained("lllyasviel/Annotators") + +condition_transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB')), + T.CenterCrop([image_size, image_size]), +]) + +control_image = condition_transform(control_image) +hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size) + +hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg") + +# run pipeline +with torch.no_grad(): + out = pipe( + prompt=prompt, + image=hed_edge, + num_inference_steps=14, + guidance_scale=4.5, + height=image_size, + width=image_size, + ) + + out.images[0].save(f"{images_path}//{control_image_name}_output.jpg") + \ No newline at end of file diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 5643286614cb..39dc149ff6d1 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -36,7 +36,6 @@ _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"] _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] - _import_structure["controlnet_pixart_alpha"] = ["PixArtControlNetAdapterBlock", "PixArtControlNetAdapterModel", "PixArtControlNetTransformerModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] @@ -83,7 +82,6 @@ from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel - from .controlnet_pixart_alpha import PixArtControlNetAdapterBlock, PixArtControlNetAdapterModel, PixArtControlNetTransformerModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7dd3bfb34ce4..aca452adf018 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -156,12 +156,6 @@ "StableDiffusionXLControlNetXSPipeline", ] ) - _import_structure["controlnet_pixart"].extend( - [ - "PixArtAlphaControlnetPipeline", - "get_closest_hw", - ] - ) _import_structure["controlnet_hunyuandit"].extend( [ "HunyuanDiTControlNetPipeline", diff --git a/src/diffusers/pipelines/controlnet_pixart/__init__.py b/src/diffusers/pipelines/controlnet_pixart/__init__.py deleted file mode 100644 index 3df6d67761e9..000000000000 --- a/src/diffusers/pipelines/controlnet_pixart/__init__.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import TYPE_CHECKING - -from ...utils import ( - DIFFUSERS_SLOW_IMPORT, - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_flax_available, - is_torch_available, - is_transformers_available, -) - - -_dummy_objects = {} -_import_structure = {} - -try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects # noqa F403 - - 
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["pipeline_pixart_alpha_controlnet"] = ["PixArtAlphaControlnetPipeline"] - # _import_structure["pipeline_controlnet_pixart_sigama"] = ["PixArtSigmaControlnetPipeline"] -try: - if not (is_transformers_available() and is_flax_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_flax_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) -else: - pass - - -if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() - - except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * - else: - from .pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline - - try: - if not (is_transformers_available() and is_flax_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 - else: - pass - - -else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) From 98f63fb7938a005d71eeba64eb8dadea1fe5f71c Mon Sep 17 00:00:00 2001 From: raci0399 Date: Mon, 22 Jul 2024 19:02:24 +0200 Subject: [PATCH 05/13] rollback changes --- src/diffusers/pipelines/__init__.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index aca452adf018..1d5fd5c2d094 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -23,7 +23,6 @@ "controlnet_hunyuandit": [], "controlnet_sd3": [], "controlnet_xs": [], - "controlnet_pixart": [], "deprecated": [], "latent_diffusion": [], "ledits_pp": [], @@ -442,10 +441,6 @@ StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline, ) - from .controlnet_pixart import ( - PixArtAlphaControlnetPipeline, - get_closest_hw - ) from .deepfloyd_if import ( IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, From e381b407a2454ed5c0f6614826edb6c37d571f80 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Mon, 22 Jul 2024 19:28:55 +0200 Subject: [PATCH 06/13] training script --- examples/pixart/.gitignore | 3 +- examples/pixart/requirements.txt | 6 + .../run_pixart_alpha_controlnet_pipeline.py | 2 - .../pixart/train_controlnet_hf_diffusers.sh | 23 + examples/pixart/train_pixart_controlnet_hf.py | 1106 +++++++++++++++++ 5 files changed, 1137 insertions(+), 3 deletions(-) create mode 100644 examples/pixart/requirements.txt create mode 100755 examples/pixart/train_controlnet_hf_diffusers.sh create mode 100644 examples/pixart/train_pixart_controlnet_hf.py diff --git a/examples/pixart/.gitignore b/examples/pixart/.gitignore index ba281509fa15..4be0fcb237f5 100644 --- a/examples/pixart/.gitignore +++ b/examples/pixart/.gitignore @@ -1 +1,2 @@ -images/ \ No newline at end of file +images/ +output/ \ No newline at end of file diff --git a/examples/pixart/requirements.txt b/examples/pixart/requirements.txt new file mode 100644 index 000000000000..2b307927ee9f --- /dev/null +++ b/examples/pixart/requirements.txt @@ -0,0 +1,6 @@ +transformers +SentencePiece +torchvision +controlnet-aux +datasets +# wandb \ No 
newline at end of file diff --git a/examples/pixart/run_pixart_alpha_controlnet_pipeline.py b/examples/pixart/run_pixart_alpha_controlnet_pipeline.py index 8cf9afaccadd..3e8036b4f4ad 100644 --- a/examples/pixart/run_pixart_alpha_controlnet_pipeline.py +++ b/examples/pixart/run_pixart_alpha_controlnet_pipeline.py @@ -1,5 +1,3 @@ -# pip install transformers SentencePiece torchvision controlnet-aux - import torch import torchvision.transforms as T import torchvision.transforms.functional as TF diff --git a/examples/pixart/train_controlnet_hf_diffusers.sh b/examples/pixart/train_controlnet_hf_diffusers.sh new file mode 100755 index 000000000000..0abd88f19e18 --- /dev/null +++ b/examples/pixart/train_controlnet_hf_diffusers.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# run +# accelerate config + +# check with +# accelerate env + +export MODEL_DIR="PixArt-alpha/PixArt-XL-2-512x512" +export OUTPUT_DIR="output/pixart-controlnet-hf-diffusers-test" + +accelerate launch ./train_pixart_controlnet_hf.py --mixed_precision="fp16" \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --output_dir=$OUTPUT_DIR \ + --dataset_name=fusing/fill50k \ + --resolution=512 \ + --learning_rate=1e-5 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --report_to="wandb" \ + --seed=42 \ + --dataloader_num_workers=8 +# --lr_scheduler="cosine" --lr_warmup_steps=0 \ diff --git a/examples/pixart/train_pixart_controlnet_hf.py b/examples/pixart/train_pixart_controlnet_hf.py new file mode 100644 index 000000000000..9049c8fa4a1e --- /dev/null +++ b/examples/pixart/train_pixart_controlnet_hf.py @@ -0,0 +1,1106 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fine-tuning script for Stable Diffusion for text2image with HuggingFace diffusers.""" + +import argparse +import logging +import math +import gc +import os +import sys +import random +import shutil +from pathlib import Path + +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +import accelerate +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from PIL import Image +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler +from diffusers.models import PixArtTransformer2DModel +from transformers import T5EncoderModel, T5Tokenizer +from diffusers.optimization import get_scheduler +from diffusers.training_utils import compute_snr +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + +script_dir = os.path.dirname(os.path.abspath(__file__)) +subfolder_path = os.path.join(script_dir, 'pipeline') +sys.path.insert(0, subfolder_path) + +from controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel +from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.29.2") + +logger = get_logger(__name__, log_level="INFO") + +def log_validation(vae, transformer, controlnet, tokenizer, scheduler, text_encoder, args, accelerator, weight_dtype, step, is_final_validation=False): + if weight_dtype == torch.float16 or weight_dtype == torch.bfloat16: + raise ValueError("Validation is not supported with mixed precision training, disable validation and use the validation script, that will generate images from the saved checkpoints.") + + if not is_final_validation: + logger.info(f"Running validation step {step} ... ") + + controlnet = accelerator.unwrap_model(controlnet) + pipeline = PixArtAlphaControlnetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + controlnet=controlnet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + else: + logger.info("Running validation - final ... 
") + + controlnet = PixArtControlNetAdapterModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + + pipeline = PixArtAlphaControlnetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + controlnet=controlnet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = Image.open(validation_image).convert("RGB") + validation_image = validation_image.resize((args.resolution, args.resolution)) + + images = [] + + for _ in range(args.num_validation_images): + image = pipeline( + prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator + ).images[0] + images.append(image) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) + + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + logger.info("Validation done!!") + + return image_logs + +def save_model_card(repo_id: str, image_logs=None, base_model=str, dataset_name=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += 
f"prompt: {validation_prompt}\n" + images = [validation_image] + images + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + model_description = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="openrail++", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "pixart-alpha", + "pixart-alpha-diffusers", + "text-to-image", + "diffusers", + "controlnet", + "diffusers-training", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from the transformer.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + nargs="+", + default=None, + help="One or more prompts to be evaluated every `--validation_steps`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." 
+ ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="pixart-controlnet", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. 
" + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + # ----Diffusion Training Arguments---- + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." 
+ ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + + parser.add_argument( + "--tracker_project_name", + type=str, + default="pixart_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + return args + +def main(): + args = parse_args() + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo(repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token).repo_id + + # See Section 3.1. of the paper. + max_length = 120 + + # For mixed precision training we cast all non-trainable weigths (vae, text_encoder) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. 
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", torch_dtype=weight_dtype) + tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, torch_dtype=weight_dtype) + + text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, torch_dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.to(accelerator.device) + + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant, torch_dtype=weight_dtype) + vae.requires_grad_(False) + vae.to(accelerator.device) + + transformer = PixArtTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer") + transformer.to(accelerator.device) + transformer.requires_grad_(False) + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet = PixArtControlNetAdapterModel.from_pretrained(args.controlnet_model_name_or_path) + else: + logger.info("Initializing controlnet weights from transformer.") + controlnet = PixArtControlNetAdapterModel.from_transformer(transformer) + + transformer.to(dtype=weight_dtype) + + controlnet.to(accelerator.device) + controlnet.train() + + def unwrap_model(model, keep_fp32_wrapper=True): + model = accelerator.unwrap_model(model, keep_fp32_wrapper=keep_fp32_wrapper) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # 10. Handle saving and loading of checkpoints + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for _, model in enumerate(models): + if isinstance(model, PixArtControlNetTransformerModel): + print(f"Saving model {model.__class__.__name__} to {output_dir}") + model.controlnet.save_pretrained(os.path.join(output_dir, "controlnet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + # rc todo: test and load the controlenet adapter and transformer + raise ValueError("load model hook not tested") + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + if isinstance(model, PixArtControlNetTransformerModel): + load_model = PixArtControlNetAdapterModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. 
If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + transformer.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Transformer loaded as datatype {unwrap_model(controlnet).dtype}. The trainable parameters should be in torch.float32." + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + controlnet.enable_gradient_checkpointing() + + if args.scale_lr: + args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`") + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + params_to_optimize = controlnet.parameters() + optimizer = optimizer_cls( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. 
+ if args.image_column is None: + image_column = column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True, proportion_empty_prompts=0., max_length=120): + captions = [] + for caption in examples[caption_column]: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer(captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt") + return inputs.input_ids, inputs.attention_mask + + # Preprocessing the datasets. 
+ train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + + conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]] + examples["conditioning_pixel_values"] = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["input_ids"], examples['prompt_attention_mask'] = tokenize_captions(examples, proportion_empty_prompts=args.proportion_empty_prompts, max_length=max_length) + + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.stack([example["input_ids"] for example in examples]) + prompt_attention_mask = torch.stack([example["prompt_attention_mask"] for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "input_ids": input_ids, + 'prompt_attention_mask': prompt_attention_mask + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + controlnet_transformer = PixArtControlNetTransformerModel(transformer, controlnet, training=True) + controlnet_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(controlnet_transformer, optimizer, train_dataloader, lr_scheduler) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers(args.tracker_project_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + latent_channels = transformer.config.in_channels + for epoch in range(first_epoch, args.num_train_epochs): + controlnet_transformer.controlnet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Convert control images to latent space + controlnet_image_latents = vae.encode(batch["conditioning_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + controlnet_image_latents = controlnet_image_latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + prompt_embeds = text_encoder(batch["input_ids"], attention_mask=batch['prompt_attention_mask'])[0] + prompt_attention_mask = batch['prompt_attention_mask'] + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if getattr(transformer, 'module', transformer).config.sample_size == 128: + resolution = torch.tensor([args.resolution, args.resolution]).repeat(bsz, 1) + aspect_ratio = torch.tensor([float(args.resolution / args.resolution)]).repeat(bsz, 1) + resolution = resolution.to(dtype=weight_dtype, device=latents.device) + aspect_ratio = aspect_ratio.to(dtype=weight_dtype, device=latents.device) + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # Predict the noise residual and compute loss + model_pred = controlnet_transformer(noisy_latents, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timesteps, + controlnet_cond=controlnet_image_latents, + added_cond_kwargs=added_cond_kwargs, + return_dict=False + )[0] + + if transformer.config.out_channels // 2 == latent_channels: + model_pred = model_pred.chunk(2, dim=1)[0] + else: + model_pred = model_pred + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. 
+ snr = compute_snr(noise_scheduler, timesteps) + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = (torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr) + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = controlnet_transformer.controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints") + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + log_validation(vae, transformer, controlnet_transformer.controlnet, tokenizer, noise_scheduler, text_encoder, args, accelerator, weight_dtype, global_step, is_final_validation=False) + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + controlnet = unwrap_model(controlnet_transformer.controlnet, keep_fp32_wrapper=False) + controlnet.save_pretrained(os.path.join(args.output_dir, "controlnet")) + + image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation(vae, transformer, controlnet, tokenizer, noise_scheduler, text_encoder, args, accelerator, weight_dtype, global_step, is_final_validation=True) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + 
dataset_name=args.dataset_name, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + +if __name__ == "__main__": + main() From 2e6ca8cfc48d0987f570aca680209abb9a050cf1 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Tue, 23 Jul 2024 11:55:01 +0200 Subject: [PATCH 07/13] moved pipepile to comunity folder --- .../pipeline_pixart_alpha_controlnet.py | 6 +++++- .../run_pixart_alpha_controlnet_pipeline.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) rename examples/{pixart => community}/pipeline_pixart_alpha_controlnet.py (99%) rename examples/{pixart => community}/run_pixart_alpha_controlnet_pipeline.py (93%) diff --git a/examples/pixart/pipeline_pixart_alpha_controlnet.py b/examples/community/pipeline_pixart_alpha_controlnet.py similarity index 99% rename from examples/pixart/pipeline_pixart_alpha_controlnet.py rename to examples/community/pipeline_pixart_alpha_controlnet.py index 8b109a85533c..3884a1a6a3ca 100644 --- a/examples/pixart/pipeline_pixart_alpha_controlnet.py +++ b/examples/community/pipeline_pixart_alpha_controlnet.py @@ -16,6 +16,8 @@ import inspect import re import urllib.parse as ul +import sys +import os from typing import Callable, List, Optional, Tuple, Union import torch @@ -25,7 +27,6 @@ from diffusers.image_processor import PixArtImageProcessor, PipelineImageInput, VaeImageProcessor from diffusers.models import AutoencoderKL, PixArtTransformer2DModel -from controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel from diffusers.schedulers import DPMSolverMultistepScheduler from diffusers.utils import ( BACKENDS_MAPPING, @@ -38,6 +39,9 @@ from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines import DiffusionPipeline, ImagePipelineOutput +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel + logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): diff --git a/examples/pixart/run_pixart_alpha_controlnet_pipeline.py b/examples/community/run_pixart_alpha_controlnet_pipeline.py similarity index 93% rename from examples/pixart/run_pixart_alpha_controlnet_pipeline.py rename to examples/community/run_pixart_alpha_controlnet_pipeline.py index 3e8036b4f4ad..b15eae2b8ca2 100644 --- a/examples/pixart/run_pixart_alpha_controlnet_pipeline.py +++ b/examples/community/run_pixart_alpha_controlnet_pipeline.py @@ -1,8 +1,9 @@ +import sys +import os import torch import torchvision.transforms as T import torchvision.transforms.functional as TF -from controlnet_pixart_alpha import PixArtControlNetAdapterModel from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline from diffusers.utils import load_image @@ -10,6 +11,9 @@ from controlnet_aux import HEDdetector +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel + controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet" weight_dtype = torch.float16 From a6ed8f7a465bad13a4f4c16c6cdd722527425006 Mon Sep 17 00:00:00 2001 From: raci0399 Date: Wed, 24 Jul 2024 21:02:14 +0200 Subject: [PATCH 08/13] readme section for the pixart controlnet model and pipeline --- examples/community/README.md | 94 +++++++++++++++++++++++++++++++++++- 1 
file changed, 93 insertions(+), 1 deletion(-)

diff --git a/examples/community/README.md b/examples/community/README.md
index f467ee38de3b..df1ecabe4384 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -71,6 +71,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 | Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
 | FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
 | AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
+| PIXART-α Controlnet pipeline | Implementation of the ControlNet model for PixArt-α and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
 
@@ -4286,4 +4287,95 @@ grid_image.save(grid_dir + "sample.png")
 
 `pag_scale` : guidance scale of PAG (ex: 5.0)
 
-`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])
\ No newline at end of file
+`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])
+
+# PIXART-α Controlnet pipeline
+
+[Project](https://pixart-alpha.github.io/) / [GitHub](https://github.com/PixArt-alpha/PixArt-alpha/blob/master/asset/docs/pixart_controlnet.md)
+
+This is the implementation of the ControlNet model and the pipeline for the PixArt-α model, adapted to use HuggingFace Diffusers.
+
+## Example Usage
+
+This example uses the PixArt HED ControlNet model, converted from the ControlNet model trained by the authors of the paper.
+
+```py
+import sys
+import os
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+
+from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
+from diffusers.utils import load_image
+
+from diffusers.image_processor import PixArtImageProcessor
+
+from controlnet_aux import HEDdetector
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel
+
+controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"
+
+weight_dtype = torch.float16
+image_size = 1024
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+torch.manual_seed(0)
+
+# load controlnet
+controlnet = PixArtControlNetAdapterModel.from_pretrained(
+    controlnet_repo_id,
+    torch_dtype=weight_dtype,
+    use_safetensors=True,
+).to(device)
+
+pipe = PixArtAlphaControlnetPipeline.from_pretrained(
+    "PixArt-alpha/PixArt-XL-2-1024-MS",
+    controlnet=controlnet,
+    torch_dtype=weight_dtype,
+    use_safetensors=True,
+).to(device)
+
+images_path = "images"
+control_image_file = "0_7.jpg"
+
+prompt = "battleship in space, galaxy in background"
+
+control_image_name = control_image_file.split('.')[0]
+
+control_image = load_image(f"{images_path}/{control_image_file}")
+print(control_image.size)
+# PIL returns (width, height)
+width, height = control_image.size
+
+hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
+
+condition_transform = T.Compose([
+    T.Lambda(lambda img: img.convert('RGB')),
+    T.CenterCrop([image_size, image_size]),
+])
+
+control_image = condition_transform(control_image)
+hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)
+
+hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")
+
+# run pipeline
+with torch.no_grad():
+    out = pipe(
+        prompt=prompt,
+        image=hed_edge,
+        num_inference_steps=14,
+        guidance_scale=4.5,
+        height=image_size,
+        width=image_size,
+    )
+
+    out.images[0].save(f"{images_path}/{control_image_name}_output.jpg")
+
+```
+
+In the folder `examples/pixart` there is also a script that can be used to train new models.
+Please check the script `train_controlnet_hf_diffusers.sh` to see how to start the training.
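+
+The adapter itself can be created from the base transformer: `PixArtControlNetAdapterModel.from_transformer` copies the first transformer blocks into the adapter, and `PixArtControlNetTransformerModel` wraps the two models so that the adapter residuals are added to the early blocks during the forward pass. The snippet below is only a minimal sketch of that wiring for training (data loading, the noise scheduler and the optimization loop are omitted, and the imports assume the same `sys.path` setup as the example above):
+
+```py
+import os
+import sys
+
+from diffusers.models import PixArtTransformer2DModel
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel
+
+# load the base PixArt-α transformer; it is kept frozen in this sketch
+transformer = PixArtTransformer2DModel.from_pretrained(
+    "PixArt-alpha/PixArt-XL-2-1024-MS",
+    subfolder="transformer",
+)
+transformer.requires_grad_(False)
+
+# copy the first transformer blocks into the trainable adapter
+controlnet = PixArtControlNetAdapterModel.from_transformer(transformer)
+
+# wrap both models; the forward pass adds the adapter residuals to the early blocks
+controlnet_transformer = PixArtControlNetTransformerModel(transformer, controlnet, training=True)
+```
+
+Only the adapter is meant to be optimized: the training script clips gradients for `controlnet_transformer.controlnet.parameters()` and saves just the adapter with `save_pretrained`, leaving the base transformer weights untouched.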
\ No newline at end of file From 0d1ed9dcc50a71ce6e3a5416dc7b6d7af4a101ea Mon Sep 17 00:00:00 2001 From: raci0399 Date: Wed, 16 Oct 2024 19:28:29 +0200 Subject: [PATCH 09/13] moved the pixart controlnet pipeline under research_projects --- examples/{ => research_projects}/pixart/.gitignore | 0 .../{ => research_projects}/pixart/controlnet_pixart_alpha.py | 0 .../pixart}/pipeline_pixart_alpha_controlnet.py | 0 examples/{ => research_projects}/pixart/requirements.txt | 0 .../pixart/train_controlnet_hf_diffusers.sh | 0 .../{ => research_projects}/pixart/train_pixart_controlnet_hf.py | 0 6 files changed, 0 insertions(+), 0 deletions(-) rename examples/{ => research_projects}/pixart/.gitignore (100%) rename examples/{ => research_projects}/pixart/controlnet_pixart_alpha.py (100%) rename examples/{community => research_projects/pixart}/pipeline_pixart_alpha_controlnet.py (100%) rename examples/{ => research_projects}/pixart/requirements.txt (100%) rename examples/{ => research_projects}/pixart/train_controlnet_hf_diffusers.sh (100%) rename examples/{ => research_projects}/pixart/train_pixart_controlnet_hf.py (100%) diff --git a/examples/pixart/.gitignore b/examples/research_projects/pixart/.gitignore similarity index 100% rename from examples/pixart/.gitignore rename to examples/research_projects/pixart/.gitignore diff --git a/examples/pixart/controlnet_pixart_alpha.py b/examples/research_projects/pixart/controlnet_pixart_alpha.py similarity index 100% rename from examples/pixart/controlnet_pixart_alpha.py rename to examples/research_projects/pixart/controlnet_pixart_alpha.py diff --git a/examples/community/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py similarity index 100% rename from examples/community/pipeline_pixart_alpha_controlnet.py rename to examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py diff --git a/examples/pixart/requirements.txt b/examples/research_projects/pixart/requirements.txt similarity index 100% rename from examples/pixart/requirements.txt rename to examples/research_projects/pixart/requirements.txt diff --git a/examples/pixart/train_controlnet_hf_diffusers.sh b/examples/research_projects/pixart/train_controlnet_hf_diffusers.sh similarity index 100% rename from examples/pixart/train_controlnet_hf_diffusers.sh rename to examples/research_projects/pixart/train_controlnet_hf_diffusers.sh diff --git a/examples/pixart/train_pixart_controlnet_hf.py b/examples/research_projects/pixart/train_pixart_controlnet_hf.py similarity index 100% rename from examples/pixart/train_pixart_controlnet_hf.py rename to examples/research_projects/pixart/train_pixart_controlnet_hf.py From 9b89559314ca68bc1cb46a5b318b78430541a945 Mon Sep 17 00:00:00 2001 From: junsongc Date: Fri, 25 Oct 2024 17:21:47 +0800 Subject: [PATCH 10/13] make style && make quality; --- .../run_pixart_alpha_controlnet_pipeline.py | 24 +-- .../pixart/controlnet_pixart_alpha.py | 78 +++++---- .../pipeline_pixart_alpha_controlnet.py | 35 ++-- .../pixart/train_pixart_controlnet_hf.py | 164 +++++++++++++----- 4 files changed, 198 insertions(+), 103 deletions(-) diff --git a/examples/community/run_pixart_alpha_controlnet_pipeline.py b/examples/community/run_pixart_alpha_controlnet_pipeline.py index b15eae2b8ca2..9ff922f42b57 100644 --- a/examples/community/run_pixart_alpha_controlnet_pipeline.py +++ b/examples/community/run_pixart_alpha_controlnet_pipeline.py @@ -1,19 +1,18 @@ -import sys import os +import sys + import torch import 
torchvision.transforms as T -import torchvision.transforms.functional as TF - +from controlnet_aux import HEDdetector from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline -from diffusers.utils import load_image -from diffusers.image_processor import PixArtImageProcessor +from diffusers.utils import load_image -from controlnet_aux import HEDdetector sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel + controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet" weight_dtype = torch.float16 @@ -48,7 +47,7 @@ # prompt = "realistical photo of a loving couple standing in the open kitchen of the living room, cooking ." prompt = "battleship in space, galaxy in background" -control_image_name = control_image_file.split('.')[0] +control_image_name = control_image_file.split(".")[0] control_image = load_image(f"{images_path}/{control_image_file}") print(control_image.size) @@ -56,10 +55,12 @@ hed = HEDdetector.from_pretrained("lllyasviel/Annotators") -condition_transform = T.Compose([ - T.Lambda(lambda img: img.convert('RGB')), - T.CenterCrop([image_size, image_size]), -]) +condition_transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB")), + T.CenterCrop([image_size, image_size]), + ] +) control_image = condition_transform(control_image) hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size) @@ -78,4 +79,3 @@ ) out.images[0].save(f"{images_path}//{control_image_name}_output.jpg") - \ No newline at end of file diff --git a/examples/research_projects/pixart/controlnet_pixart_alpha.py b/examples/research_projects/pixart/controlnet_pixart_alpha.py index fbbd75b14ca3..1f210ab71816 100644 --- a/examples/research_projects/pixart/controlnet_pixart_alpha.py +++ b/examples/research_projects/pixart/controlnet_pixart_alpha.py @@ -3,17 +3,17 @@ import torch from torch import nn -from diffusers.models.attention import BasicTransformerBlock -from diffusers.models import PixArtTransformer2DModel from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.modeling_utils import ModelMixin +from diffusers.models import PixArtTransformer2DModel +from diffusers.models.attention import BasicTransformerBlock from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin + class PixArtControlNetAdapterBlock(nn.Module): def __init__( - self, + self, block_index, - # taken from PixArtTransformer2DModel num_attention_heads: int = 16, attention_head_dim: int = 72, @@ -55,7 +55,7 @@ def __init__( attention_type=attention_type, ) - self.after_proj = nn.Linear(self.inner_dim, self.inner_dim) + self.after_proj = nn.Linear(self.inner_dim, self.inner_dim) nn.init.zeros_(self.after_proj.weight) nn.init.zeros_(self.after_proj.bias) @@ -90,50 +90,51 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, - class_labels=None + class_labels=None, ) controlnet_states_left = self.after_proj(controlnet_states_down) return controlnet_states_left, controlnet_states_down + class PixArtControlNetAdapterModel(ModelMixin, ConfigMixin): # N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 ControlNet-Transformer @register_to_config - def __init__(self, num_layers = 13) -> None: + def __init__(self, num_layers=13) -> None: super().__init__() self.num_layers = num_layers 
self.controlnet_blocks = nn.ModuleList( - [ - PixArtControlNetAdapterBlock(block_index=i) - for i in range(num_layers) - ] + [PixArtControlNetAdapterBlock(block_index=i) for i in range(num_layers)] ) @classmethod def from_transformer(cls, transformer: PixArtTransformer2DModel): control_net = PixArtControlNetAdapterModel() - + # copied the specified number of blocks from the transformer for depth in range(control_net.num_layers): - control_net.controlnet_blocks[depth].transformer_block.load_state_dict(transformer.transformer_blocks[depth].state_dict()) + control_net.controlnet_blocks[depth].transformer_block.load_state_dict( + transformer.transformer_blocks[depth].state_dict() + ) return control_net def train(self, mode: bool = True): for block in self.controlnet_blocks: block.train(mode) - + + class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin): def __init__( - self, - transformer: PixArtTransformer2DModel, - controlnet: PixArtControlNetAdapterModel, - blocks_num=13, - init_from_transformer=False, - training=False + self, + transformer: PixArtTransformer2DModel, + controlnet: PixArtControlNetAdapterModel, + blocks_num=13, + init_from_transformer=False, + training=False, ): super().__init__() @@ -141,14 +142,14 @@ def __init__( self.gradient_checkpointing = False self.register_to_config(**transformer.config) self.training = training - + if init_from_transformer: # copies the specified number of blocks from the transformer controlnet.from_transformer(transformer, self.blocks_num) self.transformer = transformer self.controlnet = controlnet - + def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value @@ -165,10 +166,9 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ): - if self.transformer.use_additional_conditions and added_cond_kwargs is None: raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.") - + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. @@ -243,15 +243,17 @@ def custom_forward(*inputs): else: # the control nets are only used for the blocks 1 to self.blocks_num if block_index > 0 and block_index <= self.blocks_num and controlnet_states_down is not None: - controlnet_states_left, controlnet_states_down = self.controlnet.controlnet_blocks[block_index - 1]( - hidden_states=hidden_states, # used only in the first block + controlnet_states_left, controlnet_states_down = self.controlnet.controlnet_blocks[ + block_index - 1 + ]( + hidden_states=hidden_states, # used only in the first block controlnet_states=controlnet_states_down, encoder_hidden_states=encoder_hidden_states, timestep=timestep, added_cond_kwargs=added_cond_kwargs, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask + encoder_attention_mask=encoder_attention_mask, ) hidden_states = hidden_states + controlnet_states_left @@ -268,7 +270,8 @@ def custom_forward(*inputs): # 3. 
Output shift, scale = ( - self.transformer.scale_shift_table[None] + embedded_timestep[:, None].to(self.transformer.scale_shift_table.device) + self.transformer.scale_shift_table[None] + + embedded_timestep[:, None].to(self.transformer.scale_shift_table.device) ).chunk(2, dim=1) hidden_states = self.transformer.norm_out(hidden_states) # Modulation @@ -278,15 +281,26 @@ def custom_forward(*inputs): # unpatchify hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.transformer.config.patch_size, self.transformer.config.patch_size, self.transformer.out_channels) + shape=( + -1, + height, + width, + self.transformer.config.patch_size, + self.transformer.config.patch_size, + self.transformer.out_channels, + ) ) hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) output = hidden_states.reshape( - shape=(-1, self.transformer.out_channels, height * self.transformer.config.patch_size, width * self.transformer.config.patch_size) + shape=( + -1, + self.transformer.out_channels, + height * self.transformer.config.patch_size, + width * self.transformer.config.patch_size, + ) ) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) - diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py index 3884a1a6a3ca..8753fbea93a7 100644 --- a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py +++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py @@ -14,21 +14,22 @@ import html import inspect +import os import re -import urllib.parse as ul import sys -import os +import urllib.parse as ul from typing import Callable, List, Optional, Tuple, Union +import numpy as np +import PIL import torch from transformers import T5EncoderModel, T5Tokenizer -import PIL -import numpy as np -from diffusers.image_processor import PixArtImageProcessor, PipelineImageInput, VaeImageProcessor +from diffusers.image_processor import PipelineImageInput, PixArtImageProcessor from diffusers.models import AutoencoderKL, PixArtTransformer2DModel +from diffusers.pipelines import DiffusionPipeline, ImagePipelineOutput from diffusers.schedulers import DPMSolverMultistepScheduler -from diffusers.utils import ( +from diffusers.utils import ( BACKENDS_MAPPING, deprecate, is_bs4_available, @@ -37,11 +38,12 @@ replace_example_docstring, ) from diffusers.utils.torch_utils import randn_tensor -from diffusers.pipelines import DiffusionPipeline, ImagePipelineOutput + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel + logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): @@ -174,6 +176,7 @@ "4.0": [512.0, 128.0], } + def get_closest_hw(width, height, image_size): if image_size == 1024: aspect_ratio_bin = ASPECT_RATIO_1024_BIN @@ -181,11 +184,12 @@ def get_closest_hw(width, height, image_size): aspect_ratio_bin = ASPECT_RATIO_512_BIN else: raise ValueError("Invalid image size") - + height, width = PixArtImageProcessor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) return width, height + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -300,12 +304,15 @@ def __init__( super().__init__() # change to the controlnet transformer model - transformer = PixArtControlNetTransformerModel( - 
transformer=transformer, controlnet=controlnet - ) + transformer = PixArtControlNetTransformerModel(transformer=transformer, controlnet=controlnet) self.register_modules( - tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, controlnet=controlnet + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) @@ -472,7 +479,7 @@ def prepare_extra_step_kwargs(self, generator, eta): if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs - + def check_inputs( self, prompt, @@ -480,7 +487,7 @@ def check_inputs( width, negative_prompt, callback_steps, - image = None, + image=None, prompt_embeds=None, negative_prompt_embeds=None, prompt_attention_mask=None, diff --git a/examples/research_projects/pixart/train_pixart_controlnet_hf.py b/examples/research_projects/pixart/train_pixart_controlnet_hf.py index 9049c8fa4a1e..827bd4309bd6 100644 --- a/examples/research_projects/pixart/train_pixart_controlnet_hf.py +++ b/examples/research_projects/pixart/train_pixart_controlnet_hf.py @@ -15,36 +15,36 @@ """Fine-tuning script for Stable Diffusion for text2image with HuggingFace diffusers.""" import argparse +import gc import logging import math -import gc import os -import sys import random import shutil +import sys from pathlib import Path +import accelerate import datasets import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers -import accelerate from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo, upload_folder -from PIL import Image from packaging import version +from PIL import Image from torchvision import transforms from tqdm.auto import tqdm +from transformers import T5EncoderModel, T5Tokenizer import diffusers from diffusers import AutoencoderKL, DDPMScheduler from diffusers.models import PixArtTransformer2DModel -from transformers import T5EncoderModel, T5Tokenizer from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, is_wandb_available, make_image_grid @@ -52,13 +52,15 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module + script_dir = os.path.dirname(os.path.abspath(__file__)) -subfolder_path = os.path.join(script_dir, 'pipeline') +subfolder_path = os.path.join(script_dir, "pipeline") sys.path.insert(0, subfolder_path) from controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline + if is_wandb_available(): import wandb @@ -67,9 +69,24 @@ logger = get_logger(__name__, log_level="INFO") -def log_validation(vae, transformer, controlnet, tokenizer, scheduler, text_encoder, args, accelerator, weight_dtype, step, is_final_validation=False): + +def log_validation( + vae, + transformer, + controlnet, + tokenizer, + scheduler, + text_encoder, + args, + accelerator, + weight_dtype, + step, + is_final_validation=False, +): if weight_dtype == torch.float16 or weight_dtype == torch.bfloat16: - raise ValueError("Validation is not supported with mixed precision training, disable validation and 
use the validation script, that will generate images from the saved checkpoints.") + raise ValueError( + "Validation is not supported with mixed precision training, disable validation and use the validation script, that will generate images from the saved checkpoints." + ) if not is_final_validation: logger.info(f"Running validation step {step} ... ") @@ -91,7 +108,7 @@ def log_validation(vae, transformer, controlnet, tokenizer, scheduler, text_enco logger.info("Running validation - final ... ") controlnet = PixArtControlNetAdapterModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) - + pipeline = PixArtAlphaControlnetPipeline.from_pretrained( args.pretrained_model_name_or_path, controlnet=controlnet, @@ -130,7 +147,7 @@ def log_validation(vae, transformer, controlnet, tokenizer, scheduler, text_enco for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") validation_image = validation_image.resize((args.resolution, args.resolution)) - + images = [] for _ in range(args.num_validation_images): @@ -187,6 +204,7 @@ def log_validation(vae, transformer, controlnet, tokenizer, scheduler, text_enco return image_logs + def save_model_card(repo_id: str, image_logs=None, base_model=str, dataset_name=str, repo_folder=None): img_str = "" if image_logs is not None: @@ -307,7 +325,7 @@ def parse_args(): default=None, help="One or more prompts to be evaluated every `--validation_steps`." " Provide either a matching number of `--validation_image`s, a single `--validation_image`" - " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.", ) parser.add_argument( "--validation_image", @@ -531,9 +549,9 @@ def parse_args(): " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" ), ) - + args = parser.parse_args() - + # Sanity checks if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Need either a dataset name or a training folder.") @@ -543,6 +561,7 @@ def parse_args(): return args + def main(): args = parse_args() @@ -551,7 +570,7 @@ def main(): "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." " Please use `huggingface-cli login` to authenticate with the Hub." ) - + logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) @@ -565,7 +584,6 @@ def main(): if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - import wandb # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -593,7 +611,9 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: - repo_id = create_repo(repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token).repo_id + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id # See Section 3.1. of the paper. max_length = 120 @@ -607,14 +627,26 @@ def main(): weight_dtype = torch.bfloat16 # Load scheduler, tokenizer and models. 
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", torch_dtype=weight_dtype) - tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, torch_dtype=weight_dtype) + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler", torch_dtype=weight_dtype + ) + tokenizer = T5Tokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, torch_dtype=weight_dtype + ) - text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, torch_dtype=weight_dtype) + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, torch_dtype=weight_dtype + ) text_encoder.requires_grad_(False) text_encoder.to(accelerator.device) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant, torch_dtype=weight_dtype) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) vae.requires_grad_(False) vae.to(accelerator.device) @@ -656,7 +688,7 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): # rc todo: test and load the controlenet adapter and transformer raise ValueError("load model hook not tested") - + for i in range(len(models)): # pop models so that they are not loaded again model = models.pop() @@ -666,7 +698,7 @@ def load_model_hook(models, input_dir): model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) - del load_model + del load_model accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -700,14 +732,18 @@ def load_model_hook(models, input_dir): controlnet.enable_gradient_checkpointing() if args.scale_lr: - args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) # Initialize the optimizer if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: - raise ImportError("Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`") + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) optimizer_cls = bnb.optim.AdamW8bit else: @@ -768,7 +804,7 @@ def load_model_hook(models, input_dir): raise ValueError( f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" ) - + if args.conditioning_image_column is None: conditioning_image_column = column_names[2] logger.info(f"conditioning image column defaulting to {conditioning_image_column}") @@ -781,7 +817,7 @@ def load_model_hook(models, input_dir): # Preprocessing the datasets. # We need to tokenize input captions and transform the images. 
- def tokenize_captions(examples, is_train=True, proportion_empty_prompts=0., max_length=120): + def tokenize_captions(examples, is_train=True, proportion_empty_prompts=0.0, max_length=120): captions = [] for caption in examples[caption_column]: if random.random() < proportion_empty_prompts: @@ -823,7 +859,9 @@ def preprocess_train(examples): conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]] examples["conditioning_pixel_values"] = [conditioning_image_transforms(image) for image in conditioning_images] - examples["input_ids"], examples['prompt_attention_mask'] = tokenize_captions(examples, proportion_empty_prompts=args.proportion_empty_prompts, max_length=max_length) + examples["input_ids"], examples["prompt_attention_mask"] = tokenize_captions( + examples, proportion_empty_prompts=args.proportion_empty_prompts, max_length=max_length + ) return examples @@ -847,7 +885,7 @@ def collate_fn(examples): "pixel_values": pixel_values, "conditioning_pixel_values": conditioning_pixel_values, "input_ids": input_ids, - 'prompt_attention_mask': prompt_attention_mask + "prompt_attention_mask": prompt_attention_mask, } # DataLoaders creation: @@ -875,8 +913,10 @@ def collate_fn(examples): # Prepare everything with our `accelerator`. controlnet_transformer = PixArtControlNetTransformerModel(transformer, controlnet, training=True) - controlnet_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(controlnet_transformer, optimizer, train_dataloader, lr_scheduler) - + controlnet_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet_transformer, optimizer, train_dataloader, lr_scheduler + ) + # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: @@ -948,14 +988,18 @@ def collate_fn(examples): latents = latents * vae.config.scaling_factor # Convert control images to latent space - controlnet_image_latents = vae.encode(batch["conditioning_pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + controlnet_image_latents = vae.encode( + batch["conditioning_pixel_values"].to(dtype=weight_dtype) + ).latent_dist.sample() controlnet_image_latents = controlnet_image_latents * vae.config.scaling_factor - + # Sample noise that we'll add to the latents noise = torch.randn_like(latents) if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) bsz = latents.shape[0] # Sample a random timestep for each image @@ -967,8 +1011,8 @@ def collate_fn(examples): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - prompt_embeds = text_encoder(batch["input_ids"], attention_mask=batch['prompt_attention_mask'])[0] - prompt_attention_mask = batch['prompt_attention_mask'] + prompt_embeds = text_encoder(batch["input_ids"], attention_mask=batch["prompt_attention_mask"])[0] + prompt_attention_mask = batch["prompt_attention_mask"] # Get the target for loss depending on the prediction type if args.prediction_type is not None: # set prediction_type of scheduler if defined @@ -983,7 +1027,7 @@ def collate_fn(examples): # Prepare micro-conditions. 
added_cond_kwargs = {"resolution": None, "aspect_ratio": None} - if getattr(transformer, 'module', transformer).config.sample_size == 128: + if getattr(transformer, "module", transformer).config.sample_size == 128: resolution = torch.tensor([args.resolution, args.resolution]).repeat(bsz, 1) aspect_ratio = torch.tensor([float(args.resolution / args.resolution)]).repeat(bsz, 1) resolution = resolution.to(dtype=weight_dtype, device=latents.device) @@ -991,20 +1035,21 @@ def collate_fn(examples): added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} # Predict the noise residual and compute loss - model_pred = controlnet_transformer(noisy_latents, + model_pred = controlnet_transformer( + noisy_latents, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, timestep=timesteps, controlnet_cond=controlnet_image_latents, added_cond_kwargs=added_cond_kwargs, - return_dict=False + return_dict=False, )[0] if transformer.config.out_channels // 2 == latent_channels: model_pred = model_pred.chunk(2, dim=1)[0] else: model_pred = model_pred - + if args.snr_gamma is None: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") else: @@ -1015,7 +1060,9 @@ def collate_fn(examples): if noise_scheduler.config.prediction_type == "v_prediction": # Velocity objective requires that we add one to SNR values before we divide by them. snr = snr + 1 - mse_loss_weights = (torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr) + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights @@ -1054,7 +1101,9 @@ def collate_fn(examples): num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] - logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints") + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: @@ -1067,7 +1116,19 @@ def collate_fn(examples): logger.info(f"Saved state to {save_path}") if args.validation_prompt is not None and global_step % args.validation_steps == 0: - log_validation(vae, transformer, controlnet_transformer.controlnet, tokenizer, noise_scheduler, text_encoder, args, accelerator, weight_dtype, global_step, is_final_validation=False) + log_validation( + vae, + transformer, + controlnet_transformer.controlnet, + tokenizer, + noise_scheduler, + text_encoder, + args, + accelerator, + weight_dtype, + global_step, + is_final_validation=False, + ) logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -1080,11 +1141,23 @@ def collate_fn(examples): if accelerator.is_main_process: controlnet = unwrap_model(controlnet_transformer.controlnet, keep_fp32_wrapper=False) controlnet.save_pretrained(os.path.join(args.output_dir, "controlnet")) - + image_logs = None if args.validation_prompt is not None: - image_logs = log_validation(vae, transformer, controlnet, tokenizer, noise_scheduler, text_encoder, args, accelerator, weight_dtype, global_step, is_final_validation=True) - + image_logs = log_validation( + vae, + transformer, + controlnet, + tokenizer, + 
noise_scheduler, + text_encoder, + args, + accelerator, + weight_dtype, + global_step, + is_final_validation=True, + ) + if args.push_to_hub: save_model_card( repo_id, @@ -1102,5 +1175,6 @@ def collate_fn(examples): accelerator.end_training() + if __name__ == "__main__": main() From 5957191352c6bc241692c0ff1d2896e19d8b4029 Mon Sep 17 00:00:00 2001 From: junsongc Date: Fri, 25 Oct 2024 17:27:22 +0800 Subject: [PATCH 11/13] make style && make quality; --- .../pixart/pipeline_pixart_alpha_controlnet.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py index 8753fbea93a7..b9d3511f761e 100644 --- a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py +++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py @@ -25,6 +25,7 @@ import torch from transformers import T5EncoderModel, T5Tokenizer +from controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel from diffusers.image_processor import PipelineImageInput, PixArtImageProcessor from diffusers.models import AutoencoderKL, PixArtTransformer2DModel from diffusers.pipelines import DiffusionPipeline, ImagePipelineOutput @@ -40,10 +41,6 @@ from diffusers.utils.torch_utils import randn_tensor -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel - - logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): From 995a79f0127dfbffef536dd56f9e1bb8e695a212 Mon Sep 17 00:00:00 2001 From: junsongc Date: Fri, 25 Oct 2024 17:36:01 +0800 Subject: [PATCH 12/13] make style && make quality; --- .../run_pixart_alpha_controlnet_pipeline.py | 10 ++-------- .../pixart/controlnet_pixart_alpha.py | 1 + .../pixart/pipeline_pixart_alpha_controlnet.py | 4 +--- .../pixart/train_pixart_controlnet_hf.py | 14 +++++--------- 4 files changed, 9 insertions(+), 20 deletions(-) diff --git a/examples/community/run_pixart_alpha_controlnet_pipeline.py b/examples/community/run_pixart_alpha_controlnet_pipeline.py index 9ff922f42b57..0014c590541b 100644 --- a/examples/community/run_pixart_alpha_controlnet_pipeline.py +++ b/examples/community/run_pixart_alpha_controlnet_pipeline.py @@ -1,16 +1,10 @@ -import os -import sys - import torch import torchvision.transforms as T from controlnet_aux import HEDdetector -from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline from diffusers.utils import load_image - - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel +from examples.research_projects.pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel +from examples.research_projects.pixart.pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet" diff --git a/examples/research_projects/pixart/controlnet_pixart_alpha.py b/examples/research_projects/pixart/controlnet_pixart_alpha.py index 1f210ab71816..b7f5a427e52e 100644 --- a/examples/research_projects/pixart/controlnet_pixart_alpha.py +++ b/examples/research_projects/pixart/controlnet_pixart_alpha.py @@ -8,6 +8,7 @@ from diffusers.models.attention import BasicTransformerBlock from diffusers.models.modeling_outputs import Transformer2DModelOutput from 
diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.torch_utils import is_torch_version class PixArtControlNetAdapterBlock(nn.Module): diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py index b9d3511f761e..aace66f9c18e 100644 --- a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py +++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py @@ -14,18 +14,16 @@ import html import inspect -import os import re -import sys import urllib.parse as ul from typing import Callable, List, Optional, Tuple, Union import numpy as np import PIL import torch +from controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel from transformers import T5EncoderModel, T5Tokenizer -from controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel from diffusers.image_processor import PipelineImageInput, PixArtImageProcessor from diffusers.models import AutoencoderKL, PixArtTransformer2DModel from diffusers.pipelines import DiffusionPipeline, ImagePipelineOutput diff --git a/examples/research_projects/pixart/train_pixart_controlnet_hf.py b/examples/research_projects/pixart/train_pixart_controlnet_hf.py index 827bd4309bd6..995a20dfa28e 100644 --- a/examples/research_projects/pixart/train_pixart_controlnet_hf.py +++ b/examples/research_projects/pixart/train_pixart_controlnet_hf.py @@ -21,7 +21,6 @@ import os import random import shutil -import sys from pathlib import Path import accelerate @@ -38,6 +37,7 @@ from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image +from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline from torchvision import transforms from tqdm.auto import tqdm from transformers import T5EncoderModel, T5Tokenizer @@ -51,14 +51,10 @@ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module - - -script_dir = os.path.dirname(os.path.abspath(__file__)) -subfolder_path = os.path.join(script_dir, "pipeline") -sys.path.insert(0, subfolder_path) - -from controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel -from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline +from examples.research_projects.pixart.controlnet_pixart_alpha import ( + PixArtControlNetAdapterModel, + PixArtControlNetTransformerModel, +) if is_wandb_available(): From 146dec284f240533b814c222713a38027469c7fc Mon Sep 17 00:00:00 2001 From: raci0399 Date: Sun, 27 Oct 2024 08:50:37 +0100 Subject: [PATCH 13/13] moved the file to research_projects folder --- .../pixart}/run_pixart_alpha_controlnet_pipeline.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{community => research_projects/pixart}/run_pixart_alpha_controlnet_pipeline.py (100%) diff --git a/examples/community/run_pixart_alpha_controlnet_pipeline.py b/examples/research_projects/pixart/run_pixart_alpha_controlnet_pipeline.py similarity index 100% rename from examples/community/run_pixart_alpha_controlnet_pipeline.py rename to examples/research_projects/pixart/run_pixart_alpha_controlnet_pipeline.py