From 56ba44b539956983b12a290beff941d0e1dc7b29 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Fri, 26 Jan 2024 15:37:59 +0530 Subject: [PATCH 01/34] update conversion script to handle motion adapter sdxl checkpoint --- .../convert_animatediff_motion_module_to_diffusers.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/scripts/convert_animatediff_motion_module_to_diffusers.py b/scripts/convert_animatediff_motion_module_to_diffusers.py index 9c5d236fd713..2595dc13bdbf 100644 --- a/scripts/convert_animatediff_motion_module_to_diffusers.py +++ b/scripts/convert_animatediff_motion_module_to_diffusers.py @@ -30,6 +30,7 @@ def get_args(): parser.add_argument("--output_path", type=str, required=True) parser.add_argument("--use_motion_mid_block", action="store_true") parser.add_argument("--motion_max_seq_length", type=int, default=32) + parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int) return parser.parse_args() @@ -43,9 +44,13 @@ def get_args(): conv_state_dict = convert_motion_module(state_dict) adapter = MotionAdapter( - use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length + block_out_channels=args.block_out_channels, + use_motion_mid_block=args.use_motion_mid_block, + motion_max_seq_length=args.motion_max_seq_length, ) # skip loading position embeddings adapter.load_state_dict(conv_state_dict, strict=False) adapter.save_pretrained(args.output_path) - adapter.save_pretrained(args.output_path, variant="fp16", torch_dtype=torch.float16) + + adapter.to(dtype=torch.float16) + adapter.save_pretrained(args.output_path, variant="fp16") From 7ae7bc84f438784c4571c73fb39c7be6f32f8f19 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Fri, 26 Jan 2024 17:27:02 +0530 Subject: [PATCH 02/34] add animatediff xl --- src/diffusers/models/unets/unet_3d_blocks.py | 2 + .../models/unets/unet_motion_model.py | 63 +- .../animatediff/pipeline_animatediff_sdxl.py | 1271 +++++++++++++++++ 3 files changed, 1327 insertions(+), 9 deletions(-) create mode 100644 src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 6c20b1175349..60a08160b55c 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -118,6 +118,7 @@ def get_down_block( raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") return CrossAttnDownBlockMotion( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -252,6 +253,7 @@ def get_up_block( raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") return CrossAttnUpBlockMotion( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 9654ae508215..2e03d753036b 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -183,6 +183,8 @@ def __init__( sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, down_block_types: Tuple[str, ...] 
= ( "CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion", @@ -203,6 +205,8 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, use_linear_projection: bool = False, num_attention_heads: Union[int, Tuple[int, ...]] = 8, motion_max_seq_length: int = 32, @@ -210,6 +214,12 @@ def __init__( use_motion_mid_block: int = True, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + addition_embed_type_num_heads: int = 64, ): super().__init__() @@ -231,9 +241,22 @@ def __init__( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + # input - conv_in_kernel = 3 - conv_out_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding @@ -253,6 +276,10 @@ def __init__( if encoder_hid_dim_type is None: self.encoder_hid_proj = None + if addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + # class embedding self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -260,6 +287,15 @@ def __init__( if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -269,7 +305,7 @@ def __init__( down_block = get_down_block( down_block_type, - num_layers=layers_per_block, + num_layers=layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, @@ -277,13 +313,14 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - 
cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, dual_cross_attention=False, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, + transformer_layers_per_block=transformer_layers_per_block[i], ) self.down_blocks.append(down_block) @@ -295,12 +332,14 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, + use_linear_projection=use_linear_projection, + transformer_layers_per_block=transformer_layers_per_block[-1], ) else: @@ -310,10 +349,12 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, + use_linear_projection=use_linear_projection, + transformer_layers_per_block=transformer_layers_per_block[-1], ) # count how many layers upsample the images @@ -322,6 +363,9 @@ def __init__( # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): @@ -340,7 +384,7 @@ def __init__( up_block = get_up_block( up_block_type, - num_layers=layers_per_block + 1, + num_layers=reversed_layers_per_block[i] + 1, in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, @@ -349,13 +393,14 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=False, resolution_idx=i, use_linear_projection=use_linear_projection, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -411,7 +456,7 @@ def from_unet2d( config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] # Need this for backwards compatibility with UNet2DConditionModel checkpoints - if not config.get("num_attention_heads"): + if config.get("attention_head_dim", None): config["num_attention_heads"] = config["attention_head_dim"] model = cls.from_config(config) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py new file mode 100644 index 000000000000..efb0f91625c8 --- /dev/null +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -0,0 +1,1271 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ImageProjection, MotionAdapter, UNet2DConditionModel, UNetMotionModel
+from ...models.attention_processor import (
+    AttnProcessor2_0,
+    FusedAttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import (
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
+from ...utils import (
+    USE_PEFT_BACKEND,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import AnimateDiffPipelineOutput
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import AnimateDiffXLPipeline, DDIMScheduler, MotionAdapter
+        >>> from diffusers.utils import export_to_gif
+
+        >>> # The checkpoint ids below are illustrative: use any SDXL base model together with a
+        >>> # motion adapter converted with scripts/convert_animatediff_motion_module_to_diffusers.py
+        >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-sdxl-beta")
+        >>> pipe = AnimateDiffXLPipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-xl-base-1.0", motion_adapter=adapter, torch_dtype=torch.float16
+        ... ).to("cuda")
+        >>> pipe.scheduler = DDIMScheduler.from_config(
+        ...     pipe.scheduler.config, beta_schedule="linear", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+        ... )
+
+        >>> output = pipe(prompt="a panda surfing in the ocean, high quality", num_frames=16)
+        >>> export_to_gif(output.frames[0], "animation.gif")
+        ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class AnimateDiffXLPipeline(
+    DiffusionPipeline,
+    FromSingleFileMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+    IPAdapterMixin,
+):
+    r"""
+    Pipeline for text-to-video generation using Stable Diffusion XL.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion XL uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+            specifically the
+            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+            variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`CLIPTokenizer`):
+            Second Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents.
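+        motion_adapter ([`MotionAdapter`]):
+            A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.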
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded video latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
+            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+            `stabilityai/stable-diffusion-xl-base-1.0`.
+    """
+
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "image_encoder",
+        "feature_extractor",
+    ]
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+        "negative_pooled_prompt_embeds",
+        "negative_add_time_ids",
+    ]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        motion_adapter: MotionAdapter,
+        scheduler: Union[
+            DDIMScheduler,
+            PNDMScheduler,
+            LMSDiscreteScheduler,
+            EulerDiscreteScheduler,
+            EulerAncestralDiscreteScheduler,
+            DPMSolverMultistepScheduler,
+        ],
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
+        force_zeros_for_empty_prompt: bool = True,
+    ):
+        super().__init__()
+
+        unet.config["down_block_types"] = [
+            "DownBlockMotion",
+            "CrossAttnDownBlockMotion",
+            "CrossAttnDownBlockMotion",
+        ]
+        unet.config["up_block_types"] = [
+            "CrossAttnUpBlockMotion",
+            "CrossAttnUpBlockMotion",
+            "UpBlockMotion",
+        ]
+        unet.config["mid_block_type"] = "UNetMidBlockCrossAttnMotion"
+
+        unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            unet=unet,
+            motion_adapter=motion_adapter,
+            scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+        )
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+        self.default_sample_size = self.unet.config.sample_size
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+ """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder, lora_scale)
+
+            if self.text_encoder_2 is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder_2, lora_scale)
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            # textual inversion: process multi-vector tokens if necessary
+            prompt_embeds_list = []
+            prompts = [prompt, prompt_2]
+            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+                text_inputs = tokenizer(
+                    prompt,
+                    padding="max_length",
+                    max_length=tokenizer.model_max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                text_input_ids = text_inputs.input_ids
+                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+                ):
+                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                    logger.warning(
+                        "The following part of your input was truncated because CLIP can only handle sequences up to"
+                        f" {tokenizer.model_max_length} tokens: {removed_text}"
+                    )
+
+                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+                # We are only ever interested in the pooled output of the final text encoder
+                pooled_prompt_embeds = prompt_embeds[0]
+                if clip_skip is None:
+                    prompt_embeds = prompt_embeds.hidden_states[-2]
+                else:
+                    # "2" because SDXL always indexes from the penultimate layer.
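+                    # e.g. clip_skip=1 selects hidden_states[-3], one layer earlier than the
+                    # default hidden_states[-2] used when clip_skip is None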
+                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+                prompt_embeds_list.append(prompt_embeds)
+
+            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+        # get unconditional embeddings for classifier free guidance
+        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+        elif do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
+            uncond_tokens: List[str]
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = [negative_prompt, negative_prompt_2]
+
+            negative_prompt_embeds_list = []
+            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+                max_length = prompt_embeds.shape[1]
+                uncond_input = tokenizer(
+                    negative_prompt,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                negative_prompt_embeds = text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )
+                # We are only ever interested in the pooled output of the final text encoder
+                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+        if self.text_encoder_2 is not None:
+            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+        else:
+            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            if self.text_encoder_2 is not None:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            else:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view(
+            bs_embed * num_videos_per_prompt, -1
+        )
+        if do_classifier_free_guidance:
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view(
+                bs_embed * num_videos_per_prompt, -1
+            )
+
+        if self.text_encoder is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
+        if self.text_encoder_2 is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler; it will be ignored by other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
+ ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. 
+ s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + """ + self.fusing_unet = False + self.fusing_vae = False + + if unet: + self.fusing_unet = True + self.unet.fuse_qkv_projections() + self.unet.set_attn_processor(FusedAttnProcessor2_0()) + + if vae: + if not isinstance(self.vae, AutoencoderKL): + raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") + + self.fusing_vae = True + self.vae.fuse_qkv_projections() + self.vae.set_attn_processor(FusedAttnProcessor2_0()) + + def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """Disable QKV projection fusion if enabled. + + + + This API is 🧪 experimental. + + + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. + + """ + if unet: + if not self.fusing_unet: + logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") + else: + self.unet.unfuse_qkv_projections() + self.fusing_unet = False + + if vae: + if not self.fusing_vae: + logger.warning("The VAE was not initially fused for QKV projections. 
Doing nothing.") + else: + self.vae.unfuse_qkv_projections() + self.fusing_vae = False + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + num_frames: int = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + 
negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
+ negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.AnimateDiffPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. 
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + # 0. 
Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = image_embeds.to(device) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. 
Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_videos_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + progress_bar.update() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return 
AnimateDiffPipelineOutput(images=image) From 01f5978a16acfe90e346a103e08abd1abbab8789 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Fri, 26 Jan 2024 17:40:06 +0530 Subject: [PATCH 03/34] handle addition_embed_type --- .../models/unets/unet_motion_model.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 2e03d753036b..095f0714e3f4 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -815,6 +815,30 @@ def forward( t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) + # ------------------------------------------------- (aryan): check this + aug_emb = None + + if self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb if aug_emb is None else emb + aug_emb + # ------------------------------------------------- emb = emb.repeat_interleave(repeats=num_frames, dim=0) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": From 256250085970b96f9f317b8c5aecad0641efd297 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Fri, 26 Jan 2024 17:40:29 +0530 Subject: [PATCH 04/34] fix output --- .../animatediff/pipeline_animatediff_sdxl.py | 72 ++++++++++++------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index efb0f91625c8..6d8117a8f1e5 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,6 +15,7 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np import torch from transformers import ( CLIPImageProcessor, @@ -69,6 +70,28 @@ """ +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ @@ -818,6 +841,20 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32 assert emb.shape == (w.shape[0], embedding_dim) return emb + # Copied from diffusers.pipelines.animatediff.AnimateDiffPipeline._retrieve_video_frames + def _retrieve_video_frames(self, latents, output_type, return_dict): + """Helper function to handle latents to output conversion.""" + if output_type == "latent": + return AnimateDiffPipelineOutput(frames=latents) + + video_tensor = self.decode_latents(latents) + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) + @property def guidance_scale(self): return self._guidance_scale @@ -1239,33 +1276,20 @@ def __call__( progress_bar.update() - if not output_type == "latent": - # make sure the VAE is in float32 mode, as it overflows in float16 - needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast - - if needs_upcasting: - self.upcast_vae() - latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - # cast back to fp16 if needed - if needs_upcasting: - self.vae.to(dtype=torch.float16) - else: - image = latents - - if not output_type == "latent": - # apply watermark if available - if self.watermark is not None: - image = self.watermark.apply_watermark(image) + video = self._retrieve_video_frames(latents, output_type, return_dict) - image = self.image_processor.postprocess(image, output_type=output_type) + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) # Offload all models self.maybe_free_model_hooks() - if not return_dict: - return (image,) - - return AnimateDiffPipelineOutput(images=image) + return video From 736a224e74cd784964b3d6e52091a7ad82670589 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Fri, 26 Jan 2024 17:47:58 +0530 Subject: [PATCH 05/34] update --- .../pipelines/animatediff/pipeline_animatediff_sdxl.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git 
a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 6d8117a8f1e5..19a00ecdc840 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -311,7 +311,7 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt def encode_prompt( self, prompt: str, @@ -340,7 +340,7 @@ def encode_prompt( device: (`torch.device`): torch device num_videos_per_prompt (`int`): - number of videos that should be generated per prompt + number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): @@ -841,7 +841,6 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32 assert emb.shape == (w.shape[0], embedding_dim) return emb - # Copied from diffusers.pipelines.animatediff.AnimateDiffPipeline._retrieve_video_frames def _retrieve_video_frames(self, latents, output_type, return_dict): """Helper function to handle latents to output conversion.""" if output_type == "latent": From 4a2b9de6b1ab581fa711eb0abbf5ec510de564d0 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Fri, 26 Jan 2024 17:54:58 +0530 Subject: [PATCH 06/34] add imports --- src/diffusers/__init__.py | 2 ++ src/diffusers/pipelines/__init__.py | 3 ++- src/diffusers/pipelines/animatediff/__init__.py | 2 ++ .../pipelines/animatediff/pipeline_animatediff_sdxl.py | 2 +- 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 931ccc8a4a9d..33250dd4e58d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -208,6 +208,7 @@ "AmusedInpaintPipeline", "AmusedPipeline", "AnimateDiffPipeline", + "AnimateDiffSDXLPipeline", "AnimateDiffVideoToVideoPipeline", "AudioLDM2Pipeline", "AudioLDM2ProjectionModel", @@ -570,6 +571,7 @@ AmusedInpaintPipeline, AmusedPipeline, AnimateDiffPipeline, + AnimateDiffSDXLPipeline, AnimateDiffVideoToVideoPipeline, AudioLDM2Pipeline, AudioLDM2ProjectionModel, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 9d11f78ceee2..50fa3ba0a194 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -111,6 +111,7 @@ _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"] _import_structure["animatediff"] = [ "AnimateDiffPipeline", + "AnimateDiffSDXLPipeline", "AnimateDiffVideoToVideoPipeline", ] _import_structure["audioldm"] = ["AudioLDMPipeline"] @@ -344,7 +345,7 @@ from ..utils.dummy_torch_and_transformers_objects import * else: from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline - from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline + from .animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffVideoToVideoPipeline from .audioldm import AudioLDMPipeline from .audioldm2 import ( AudioLDM2Pipeline, diff --git a/src/diffusers/pipelines/animatediff/__init__.py 
b/src/diffusers/pipelines/animatediff/__init__.py index 35b99a76fd21..ae6b67b1924c 100644 --- a/src/diffusers/pipelines/animatediff/__init__.py +++ b/src/diffusers/pipelines/animatediff/__init__.py @@ -22,6 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"] + _import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"] _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -33,6 +34,7 @@ else: from .pipeline_animatediff import AnimateDiffPipeline + from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline from .pipeline_output import AnimateDiffPipelineOutput diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 19a00ecdc840..a3b69f629329 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -152,7 +152,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class AnimateDiffXLPipeline( +class AnimateDiffSDXLPipeline( DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, From eb060e007961616a8483e36b8a117ffafbd5fe73 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Fri, 26 Jan 2024 17:55:25 +0530 Subject: [PATCH 07/34] make fix-copies --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e0d5c77d0e8c..da3a1531d5cb 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -92,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AnimateDiffSDXLPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From c01d2c29dffbd75524fa6a3437be91563db1b659 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Fri, 26 Jan 2024 18:06:27 +0530 Subject: [PATCH 08/34] add decode latents --- .../animatediff/pipeline_animatediff_sdxl.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index a3b69f629329..796499ba8627 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -571,6 +571,19 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 
self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature From 60364ea74a78b510c254ab45cdff32870bff6f7a Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 27 Jan 2024 07:06:03 +0530 Subject: [PATCH 09/34] update docstrings --- .../animatediff/pipeline_animatediff_sdxl.py | 77 +++++++++++++++---- 1 file changed, 60 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 796499ba8627..7613bca82b1b 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -66,6 +66,47 @@ EXAMPLE_DOC_STRING = """ Examples: ```py + import torch + from diffusers.models import MotionAdapter + from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler + from diffusers.utils import export_to_gif + + # TODO(aryan): ask team to move checkpoint to authors account + adapter = MotionAdapter.from_pretrained("a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16) + + model_id = "stabilityai/stable-diffusion-xl-base-1.0" + scheduler = DDIMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + clip_sample=False, + timestep_spacing="linspace", + beta_schedule="linear", + steps_offset=1, + ) + pipe = AnimateDiffSDXLPipeline.from_pretrained( + model_id, + motion_adapter=adapter, + scheduler=scheduler, + torch_dtype=torch.float16, + variant="fp16", + ).to("cuda") + + # enable memory savings + pipe.enable_vae_slicing() + pipe.enable_vae_tiling() + + output = pipe( + prompt="a panda surfing in the ocean, realistic, high quality", + negative_prompt="low quality, worst quality", + num_inference_steps=20, + guidance_scale=8, + width=1024, + height=1024, + num_frames=16, + ) + + frames = output.frames[0] + export_to_gif(frames, "animation.gif") ``` """ @@ -160,7 +201,7 @@ class AnimateDiffSDXLPipeline( IPAdapterMixin, ): r""" - Pipeline for text-to-image generation using Stable Diffusion XL. + Pipeline for text-to-video generation using Stable Diffusion XL. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) @@ -191,17 +232,14 @@ class AnimateDiffSDXLPipeline( tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the encoded image latents. 
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
`stabilityai/stable-diffusion-xl-base-1-0`.
- add_watermarker (`bool`, *optional*):
- Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
- watermark output images. If not defined, it will default to True if the package is installed, otherwise no
- watermarker will be used.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
@@ -945,23 +983,26 @@ def __call__(
Args:
prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
+ num_frames:
+ The number of video frames that are generated. Defaults to 16 frames which at 8 frames per second
+ amounts to 2 seconds of video.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ The height in pixels of the generated video. This is set to 1024 by default for the best results.
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ The width in pixels of the generated video. This is set to 1024 by default for the best results.
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
num_inference_steps (`int`, *optional*, defaults to 50):
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
@@ -979,13 +1020,13 @@ def __call__(
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ usually at the expense of lower video quality.
negative_prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ The prompt or prompts not to guide the video generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
@@ -996,7 +1037,7 @@ def __call__(
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
@@ -1013,9 +1054,10 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
- The output format of the generated image. Choose between
+ The output format of the generated video. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] instead
@@ -1233,6 +1275,7 @@ def __call__(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
+ # 10. Denoising loop
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1301,7 +1344,7 @@ def __call__(
if needs_upcasting:
self.vae.to(dtype=torch.float16)
- # Offload all models
+ # 11. Offload all models
self.maybe_free_model_hooks()
return video
From 0db8340a799f84031daa68101c9cf7fc0d66c396 Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Sat, 27 Jan 2024 07:06:18 +0530
Subject: [PATCH 10/34] add animatediff sdxl to docs
---
 docs/source/en/api/pipelines/animatediff.md | 54 +++++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md
index 481fc4cdb58d..1f250cff2577 100644
--- a/docs/source/en/api/pipelines/animatediff.md
+++ b/docs/source/en/api/pipelines/animatediff.md
@@ -101,6 +101,54 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you
+### AnimateDiffSDXLPipeline
+
+AnimateDiff can also be used with SDXL models. This is currently an experimental feature as only a beta release of the motion adapter checkpoint is available.
+ +```python +import torch +from diffusers.models import MotionAdapter +from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler +from diffusers.utils import export_to_gif + +# TODO(aryan): ask team to move checkpoint to authors account +adapter = MotionAdapter.from_pretrained("a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16) + +model_id = "stabilityai/stable-diffusion-xl-base-1.0" +scheduler = DDIMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + clip_sample=False, + timestep_spacing="linspace", + beta_schedule="linear", + steps_offset=1, +) +pipe = AnimateDiffSDXLPipeline.from_pretrained( + model_id, + motion_adapter=adapter, + scheduler=scheduler, + torch_dtype=torch.float16, + variant="fp16", +).to("cuda") + +# enable memory savings +pipe.enable_vae_slicing() +pipe.enable_vae_tiling() + +output = pipe( + prompt="a panda surfing in the ocean, realistic, high quality", + negative_prompt="low quality, worst quality", + num_inference_steps=20, + guidance_scale=8, + width=1024, + height=1024, + num_frames=16, +) + +frames = output.frames[0] +export_to_gif(frames, "animation.gif") +``` + ### AnimateDiffVideoToVideoPipeline AnimateDiff can also be used to generate visually similar videos or enable style/character/background or other edits starting from an initial video, allowing you to seamlessly explore creative possibilities. @@ -414,6 +462,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) - all - __call__ +## AnimateDiffSDXLPipeline + +[[autodoc]] AnimateDiffSDXLPipeline + - all + - __call__ + ## AnimateDiffVideoToVideoPipeline [[autodoc]] AnimateDiffVideoToVideoPipeline From bf2cd499a2dda404122c3855e9132b0bc9abde82 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 27 Jan 2024 07:07:06 +0530 Subject: [PATCH 11/34] remove unnecessary lines --- src/diffusers/models/unets/unet_motion_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 095f0714e3f4..b7d8886e1361 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -815,7 +815,6 @@ def forward( t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) - # ------------------------------------------------- (aryan): check this aug_emb = None if self.config.addition_embed_type == "text_time": @@ -838,7 +837,6 @@ def forward( aug_emb = self.add_embedding(add_embeds) emb = emb if aug_emb is None else emb + aug_emb - # ------------------------------------------------- emb = emb.repeat_interleave(repeats=num_frames, dim=0) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": From 389adaa292ebbd336a40172c16fff52510bccbfc Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 27 Jan 2024 07:08:35 +0530 Subject: [PATCH 12/34] update example --- .../animatediff/pipeline_animatediff_sdxl.py | 81 +++++++++---------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 7613bca82b1b..0c56b8732f8c 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -66,47 +66,46 @@ EXAMPLE_DOC_STRING = """ Examples: ```py - import torch - from diffusers.models import MotionAdapter - from diffusers import 
AnimateDiffSDXLPipeline, DDIMScheduler
- from diffusers.utils import export_to_gif
-
- # TODO(aryan): ask team to move checkpoint to authors account
- adapter = MotionAdapter.from_pretrained("a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16)
-
- model_id = "stabilityai/stable-diffusion-xl-base-1.0"
- scheduler = DDIMScheduler.from_pretrained(
- model_id,
- subfolder="scheduler",
- clip_sample=False,
- timestep_spacing="linspace",
- beta_schedule="linear",
- steps_offset=1,
- )
- pipe = AnimateDiffSDXLPipeline.from_pretrained(
- model_id,
- motion_adapter=adapter,
- scheduler=scheduler,
- torch_dtype=torch.float16,
- variant="fp16",
- ).to("cuda")
-
- # enable memory savings
- pipe.enable_vae_slicing()
- pipe.enable_vae_tiling()
-
- output = pipe(
- prompt="a panda surfing in the ocean, realistic, high quality",
- negative_prompt="low quality, worst quality",
- num_inference_steps=20,
- guidance_scale=8,
- width=1024,
- height=1024,
- num_frames=16,
- )
-
- frames = output.frames[0]
- export_to_gif(frames, "animation.gif")
+ >>> import torch
+ >>> from diffusers.models import MotionAdapter
+ >>> from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler
+ >>> from diffusers.utils import export_to_gif
+
+ >>> adapter = MotionAdapter.from_pretrained("a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16)
+
+ >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+ >>> scheduler = DDIMScheduler.from_pretrained(
+ ... model_id,
+ ... subfolder="scheduler",
+ ... clip_sample=False,
+ ... timestep_spacing="linspace",
+ ... beta_schedule="linear",
+ ... steps_offset=1,
+ ... )
+ >>> pipe = AnimateDiffSDXLPipeline.from_pretrained(
+ ... model_id,
+ ... motion_adapter=adapter,
+ ... scheduler=scheduler,
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... ).to("cuda")
+
+ >>> # enable memory savings
+ >>> pipe.enable_vae_slicing()
+ >>> pipe.enable_vae_tiling()
+
+ >>> output = pipe(
+ ... prompt="a panda surfing in the ocean, realistic, high quality",
+ ... negative_prompt="low quality, worst quality",
+ ... num_inference_steps=20,
+ ... guidance_scale=8,
+ ... width=1024,
+ ... height=1024,
+ ... num_frames=16,
+ ... )
+
+ >>> frames = output.frames[0]
+ >>> export_to_gif(frames, "animation.gif")
 ```
 """
From ba4f9f45881bb09af06bf2bd28861454b4e9f9de Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Sat, 27 Jan 2024 07:36:23 +0530
Subject: [PATCH 13/34] add test
---
 .../animatediff/test_animatediff_sdxl.py | 293 ++++++++++++++++++
 1 file changed, 293 insertions(+)
 create mode 100644 tests/pipelines/animatediff/test_animatediff_sdxl.py

diff --git a/tests/pipelines/animatediff/test_animatediff_sdxl.py b/tests/pipelines/animatediff/test_animatediff_sdxl.py
new file mode 100644
index 000000000000..152d458b4ef2
--- /dev/null
+++ b/tests/pipelines/animatediff/test_animatediff_sdxl.py
@@ -0,0 +1,293 @@
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+import diffusers
+from diffusers import (
+ AnimateDiffSDXLPipeline,
+ AutoencoderKL,
+ DDIMScheduler,
+ MotionAdapter,
+ UNet2DConditionModel,
+ UNetMotionModel,
+)
+from diffusers.utils import is_xformers_available, logging
+from diffusers.utils.testing_utils import torch_device
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+def to_np(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.detach().cpu().numpy()
+
+ return tensor
+
+
+class AnimateDiffPipelineSDXLFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = AnimateDiffSDXLPipeline
+ params = TEXT_TO_IMAGE_PARAMS
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
+
+ def get_dummy_components(self, time_cond_proj_dim=None):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(32, 64, 128),
+ layers_per_block=2,
+ time_cond_proj_dim=time_cond_proj_dim,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
+ # SD2-specific config below
+ attention_head_dim=(2, 4, 8),
+ use_linear_projection=True,
+ addition_embed_type="text_time",
+ addition_time_embed_dim=8,
+ transformer_layers_per_block=(1, 2, 4),
+ projection_class_embeddings_input_dim=80, # 6 * 8 + 32
+ cross_attention_dim=64,
+ norm_num_groups=1,
+ )
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="linear",
+ clip_sample=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=128,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ # SD2-specific config below
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ text_encoder_2 =
CLIPTextModelWithProjection(text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ motion_adapter = MotionAdapter(
+ block_out_channels=(32, 64, 128),
+ motion_layers_per_block=2,
+ motion_norm_num_groups=2,
+ motion_num_attention_heads=4,
+ use_motion_mid_block=False,
+ )
+
+ components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "motion_adapter": motion_adapter,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "feature_extractor": None,
+ "image_encoder": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 7.5,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_motion_unet_loading(self):
+ components = self.get_dummy_components()
+ pipe = AnimateDiffSDXLPipeline(**components)
+
+ assert isinstance(pipe.unet, UNetMotionModel)
+
+ @unittest.skip("Attention slicing is not enabled in this pipeline")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_inference_batch_single_identical(
+ self,
+ batch_size=2,
+ expected_max_diff=1e-4,
+ additional_params_copy_to_batched_inputs=["num_inference_steps"],
+ ):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for components in pipe.components.values():
+ if hasattr(components, "set_default_attn_processor"):
+ components.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(torch_device)
+ # Reset generator in case it has been used in self.get_dummy_inputs
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # batchify inputs
+ batched_inputs = {}
+ batched_inputs.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ if name == "prompt":
+ len_prompt = len(value)
+ batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
+ batched_inputs[name][-1] = 100 * "very long"
+
+ else:
+ batched_inputs[name] = batch_size * [value]
+
+ if "generator" in inputs:
+ batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_inputs["batch_size"] = batch_size
+
+ for arg in additional_params_copy_to_batched_inputs:
+ batched_inputs[arg] = inputs[arg]
+
+ output = pipe(**inputs)
+ output_batch = pipe(**batched_inputs)
+
+ assert output_batch[0].shape[0] == batch_size
+
+ max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
+ assert max_diff < expected_max_diff
+
+ @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ def test_to_device(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.to("cpu")
+ # pipeline creates a new motion UNet under the hood.
So we need to check the device from pipe.components
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cpu" for device in model_devices))
+
+ output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
+ self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+ pipe.to("cuda")
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cuda" for device in model_devices))
+
+ output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+ self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ pipe.to(torch_dtype=torch.float16)
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
+ def test_prompt_embeds(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ prompt = inputs.pop("prompt")
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = pipe.encode_prompt(prompt)
+
+ pipe(
+ **inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ )
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_without_offload = pipe(**inputs).frames[0]
+ output_without_offload = (
+ output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
+ )
+
+ pipe.enable_xformers_memory_efficient_attention()
+ inputs = self.get_dummy_inputs(torch_device)
+ output_with_offload = pipe(**inputs).frames[0]
+ output_with_offload = (
+ output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_with_offload
+ )
+
+ max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+ self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
From 1fa606f94062a07494c5276967598fa06d90257f Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Mon, 26 Feb 2024 23:16:00 +0530
Subject: [PATCH 14/34] revert conv_in conv_out kernel param
---
 src/diffusers/models/unets/unet_motion_model.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/unets/unet_motion_model.py
b/src/diffusers/models/unets/unet_motion_model.py index 55dacb3562ba..e5cf1270e06b 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -223,8 +223,6 @@ def __init__( encoder_hid_dim_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, - conv_in_kernel: int = 3, - conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, addition_embed_type_num_heads: int = 64, time_cond_proj_dim: Optional[int] = None, @@ -265,6 +263,8 @@ def __init__( raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") # input + conv_in_kernel = 3 + conv_out_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding @@ -345,7 +345,6 @@ def __init__( use_linear_projection=use_linear_projection, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, - use_linear_projection=use_linear_projection, transformer_layers_per_block=transformer_layers_per_block[-1], ) From 93fc8481ccd30a7df609263c33a5cef0b639764f Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Mon, 26 Feb 2024 23:17:54 +0530 Subject: [PATCH 15/34] remove unused param addition_embed_type_num_heads --- src/diffusers/models/unets/unet_motion_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index e5cf1270e06b..795bb19fbaca 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -224,7 +224,6 @@ def __init__( addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, projection_class_embeddings_input_dim: Optional[int] = None, - addition_embed_type_num_heads: int = 64, time_cond_proj_dim: Optional[int] = None, ): super().__init__() From 5bbd8efb5340fe3d42c00cbde6685e3009c911d0 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Mon, 26 Feb 2024 23:23:27 +0530 Subject: [PATCH 16/34] latest IPAdapter impl --- .../animatediff/pipeline_animatediff_sdxl.py | 52 ++++++++++++++++--- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 0c56b8732f8c..5bd6ea311b31 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -608,6 +608,41 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + image_embeds = ip_adapter_image_embeds + return image_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents @@ -963,6 +998,7 @@ def __call__( pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -1055,6 +1091,9 @@ def __call__( input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -1238,14 +1277,10 @@ def __call__( add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1) - if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) - image_embeds = image_embeds.to(device) # 8. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -1288,8 +1323,9 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - if ip_adapter_image is not None: + if ip_adapter_image is not None or ip_adapter_image_embeds: added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( latent_model_input, t, From 9f691274e2a15a5464899ecbb06c63bbf574053e Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Mon, 26 Feb 2024 23:23:48 +0530 Subject: [PATCH 17/34] make fix-copies --- .../pipelines/animatediff/pipeline_animatediff_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 5bd6ea311b31..8c98554e8491 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -444,7 +444,7 @@ def encode_prompt( prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - # textual inversion: procecss multi-vector tokens if necessary + # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): From 5919f372e37a335aaa748142a259eaf76a16db14 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Mon, 26 Feb 2024 23:27:34 +0530 Subject: [PATCH 18/34] fix return --- .../animatediff/pipeline_animatediff_sdxl.py | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 8c98554e8491..054094803f66 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -926,19 +926,6 @@ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32 assert emb.shape == (w.shape[0], embedding_dim) return emb - def _retrieve_video_frames(self, latents, output_type, return_dict): - """Helper function to handle latents to output conversion.""" - if output_type == "latent": - return AnimateDiffPipelineOutput(frames=latents) - - video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) - - if not return_dict: - return (video,) - - return AnimateDiffPipelineOutput(frames=video) - @property def guidance_scale(self): return self._guidance_scale @@ -1373,7 +1360,11 @@ def __call__( self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - video = self._retrieve_video_frames(latents, output_type, return_dict) + if output_type != "latent": + video_tensor = self.decode_latents(latents) + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + else: + video = latents # cast back to fp16 if needed if needs_upcasting: @@ -1382,4 +1373,7 @@ def __call__( # 11. 
         self.maybe_free_model_hooks()
 
-        return video
+        if not return_dict:
+            return (video,)
+
+        return AnimateDiffPipelineOutput(frames=video)

From a09355a01287877c2a8cab3e56d45e39d7dc0a69 Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Mon, 26 Feb 2024 23:31:10 +0530
Subject: [PATCH 19/34] add IPAdapterTesterMixin to tests

---
 tests/pipelines/animatediff/test_animatediff_sdxl.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/pipelines/animatediff/test_animatediff_sdxl.py b/tests/pipelines/animatediff/test_animatediff_sdxl.py
index 152d458b4ef2..f21f563dad86 100644
--- a/tests/pipelines/animatediff/test_animatediff_sdxl.py
+++ b/tests/pipelines/animatediff/test_animatediff_sdxl.py
@@ -17,7 +17,7 @@
 from diffusers.utils.testing_utils import torch_device
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
 
 
 def to_np(tensor):
@@ -27,7 +27,7 @@ def to_np(tensor):
     return tensor
 
 
-class AnimateDiffPipelineSDXLFastTests(PipelineTesterMixin, unittest.TestCase):
+class AnimateDiffPipelineSDXLFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = AnimateDiffSDXLPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

From 306902fdd75978f60eb1dcad63eb2416c4287811 Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Mon, 26 Feb 2024 23:38:11 +0530
Subject: [PATCH 20/34] fix return

---
 .../pipelines/animatediff/pipeline_animatediff_sdxl.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
index 054094803f66..1e61cfbd397c 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
@@ -1360,11 +1360,8 @@ def __call__(
             self.upcast_vae()
             latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
-        if output_type != "latent":
-            video_tensor = self.decode_latents(latents)
-            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
-        else:
-            video = latents
+        video_tensor = self.decode_latents(latents)
+        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
 
         # cast back to fp16 if needed
         if needs_upcasting:

From feb458aa6c4234b662667f6feb42544490dd15ee Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Sun, 24 Mar 2024 19:19:52 +0530
Subject: [PATCH 21/34] revert based on suggestion

---
 src/diffusers/models/unets/unet_motion_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 4c602b98e788..8b32ded9bb83 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -466,7 +466,7 @@ def from_unet2d(
             config["in_channels"] = motion_adapter.config["conv_in_channels"]
 
         # Need this for backwards compatibility with UNet2DConditionModel checkpoints
-        if config.get("attention_head_dim", None):
+        if not config.get("num_attention_heads"):
            config["num_attention_heads"] = config["attention_head_dim"]
 
         model = cls.from_config(config)
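Note: patches 16-21 above complete the `ip_adapter_image_embeds` plumbing. A minimal usage sketch of that path follows — the checkpoint ids, the reference-image URL, and the `[negative, positive]` concatenation layout are assumptions inferred from `encode_image`/`prepare_ip_adapter_image_embeds` as they stand at this point in the series (later patches adjust the helper's signature), not part of the patches themselves:

```python
import torch
from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler
from diffusers.models import MotionAdapter
from diffusers.utils import load_image

# Illustrative checkpoint ids; substitute whatever weights you actually use.
adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16
)
pipe = AnimateDiffSDXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

reference = load_image("https://example.com/reference.png")  # placeholder URL

# Encode the reference image once. `encode_image` returns conditional and
# unconditional embeddings; the final positional argument (output_hidden_states)
# is False for the plain ImageProjection-based IP-Adapter. With classifier-free
# guidance enabled, each entry is expected concatenated as [negative, positive].
image_embeds, negative_image_embeds = pipe.encode_image(reference, "cuda", 1, False)
ip_adapter_image_embeds = [torch.cat([negative_image_embeds, image_embeds])]

# Reuse the precomputed embeddings across calls instead of re-encoding the image.
frames = pipe(
    prompt="a panda surfing a wave, high quality",
    num_frames=16,
    ip_adapter_image_embeds=ip_adapter_image_embeds,
).frames[0]
```

Passing `ip_adapter_image` directly remains supported; the embeds argument only short-circuits the image-encoding step.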
From 7c3807daff72062a198e589ecd8612a73612c87d Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Sun, 24 Mar 2024 19:40:30 +0530
Subject: [PATCH 22/34] add freeinit

---
 .../animatediff/pipeline_animatediff_sdxl.py | 137 ++++++++++--------
 1 file changed, 75 insertions(+), 62 deletions(-)

diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
index 1e61cfbd397c..c11d4669a139 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
@@ -57,6 +57,7 @@
     unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
+from ..free_init_utils import FreeInitMixin
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import AnimateDiffPipelineOutput
 
@@ -198,6 +199,7 @@ class AnimateDiffSDXLPipeline(
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
     IPAdapterMixin,
+    FreeInitMixin,
 ):
     r"""
     Pipeline for text-to-video generation using Stable Diffusion XL.
@@ -1269,10 +1271,7 @@ def __call__(
                 ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt
             )
 
-        # 8. Denoising loop
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
-        # 8.1 Apply denoising_end
+        # 7.1 Apply denoising_end
         if (
             self.denoising_end is not None
             and isinstance(self.denoising_end, float)
@@ -1288,7 +1287,7 @@ def __call__(
             num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
             timesteps = timesteps[:num_inference_steps]
 
-        # 9. Optionally get Guidance Scale Embedding
+        # 8. Optionally get Guidance Scale Embedding
         timestep_cond = None
         if self.unet.config.time_cond_proj_dim is not None:
             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_videos_per_prompt)
@@ -1296,62 +1295,72 @@ def __call__(
                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
             ).to(device=device, dtype=latents.dtype)
 
-        # 10. Denoising loop
-        self._num_timesteps = len(timesteps)
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                if self.interrupt:
-                    continue
-
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
-                    added_cond_kwargs["image_embeds"] = image_embeds
-
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep_cond=timestep_cond,
-                    cross_attention_kwargs=self.cross_attention_kwargs,
-                    added_cond_kwargs=added_cond_kwargs,
-                    return_dict=False,
-                )[0]
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
-                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
-                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                    negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
+        for free_init_iter in range(num_free_init_iters):
+            if self.free_init_enabled:
+                latents, timesteps = self._apply_free_init(
+                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
+                )
 
-                progress_bar.update()
+            self._num_timesteps = len(timesteps)
+
+            # 9. Denoising loop
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    if self.interrupt:
+                        continue
+
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+                    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+                        added_cond_kwargs["image_embeds"] = image_embeds
+
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep_cond=timestep_cond,
+                        cross_attention_kwargs=self.cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                        return_dict=False,
+                    )[0]
+
+                    # perform guidance
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                        noise_pred = rescale_noise_cfg(
+                            noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+                        )
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                    if callback_on_step_end is not None:
+                        callback_kwargs = {}
+                        for k in callback_on_step_end_tensor_inputs:
+                            callback_kwargs[k] = locals()[k]
+                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                        add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                        negative_pooled_prompt_embeds = callback_outputs.pop(
+                            "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                        )
+                        add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+                        negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+
+                    progress_bar.update()
 
         # make sure the VAE is in float32 mode, as it overflows in float16
         needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
@@ -1360,8 +1369,12 @@ def __call__(
             self.upcast_vae()
             latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+        # 10. Post processing
+        if output_type == "latent":
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
 
         # cast back to fp16 if needed
         if needs_upcasting:
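With FreeInit wired into the pipeline by patch 22, callers opt in through the mixin's toggles. A minimal sketch — `enable_free_init`/`disable_free_init` come from `FreeInitMixin`; the checkpoint ids and prompt are illustrative:

```python
import torch
from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler
from diffusers.models import MotionAdapter
from diffusers.utils import export_to_gif

# Illustrative checkpoint ids; substitute whatever weights you actually use.
adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16
)
pipe = AnimateDiffSDXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, beta_schedule="linear", timestep_spacing="linspace"
)

# FreeInit reruns the denoising loop `num_iters` times, low-pass filtering the
# latent noise between iterations, so inference cost scales roughly linearly
# with `num_iters`.
pipe.enable_free_init(num_iters=3, use_fast_sampling=False)
frames = pipe(prompt="a panda surfing a wave, high quality", num_frames=16).frames[0]
pipe.disable_free_init()

export_to_gif(frames, "animation.gif")
```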
From dd996ad1e3f95863564ef20965970543c3be5e7e Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Sun, 24 Mar 2024 22:04:40 +0530
Subject: [PATCH 23/34] fix test_to_dtype test

---
 tests/pipelines/animatediff/test_animatediff_sdxl.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/pipelines/animatediff/test_animatediff_sdxl.py b/tests/pipelines/animatediff/test_animatediff_sdxl.py
index f21f563dad86..bbacd36c8048 100644
--- a/tests/pipelines/animatediff/test_animatediff_sdxl.py
+++ b/tests/pipelines/animatediff/test_animatediff_sdxl.py
@@ -17,7 +17,7 @@
 from diffusers.utils.testing_utils import torch_device
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin, SDFunctionTesterMixin
 
 
 def to_np(tensor):
@@ -27,7 +27,9 @@ def to_np(tensor):
     return tensor
 
 
-class AnimateDiffPipelineSDXLFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class AnimateDiffPipelineSDXLFastTests(
+    IPAdapterTesterMixin, PipelineTesterMixin, SDFunctionTesterMixin, unittest.TestCase
+):
     pipeline_class = AnimateDiffSDXLPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -132,7 +134,7 @@ def get_dummy_inputs(self, device, seed=0):
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 7.5,
-            "output_type": "pt",
+            "output_type": "np",
         }
         return inputs
 
@@ -235,7 +237,7 @@ def test_to_dtype(self):
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
 
-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

From 53f581580b5bf49fb2150a7217aeb1ab1dbdcb23 Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Sun, 24 Mar 2024 22:05:24 +0530
Subject: [PATCH 24/34] use StableDiffusionMixin instead of different helper methods

---
 .../animatediff/pipeline_animatediff_sdxl.py | 123 +-----------------
 1 file changed, 2 insertions(+), 121 deletions(-)

diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
index c11d4669a139..c8ec00cd0914 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
@@ -58,7 +58,7 @@
 )
 from ...utils.torch_utils import randn_tensor
 from ..free_init_utils import FreeInitMixin
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput
 
@@ -195,6 +195,7 @@ class AnimateDiffSDXLPipeline(
     DiffusionPipeline,
+    StableDiffusionMixin,
     FromSingleFileMixin,
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
     IPAdapterMixin,
@@ -317,39 +318,6 @@ def __init__(
 
         self.default_sample_size = self.unet.config.sample_size
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
-    def enable_vae_slicing(self):
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.vae.enable_slicing()
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
-    def disable_vae_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        self.vae.disable_slicing()
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
-    def enable_vae_tiling(self):
-        r"""
-        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-        processing larger images.
-        """
-        self.vae.enable_tiling()
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
-    def disable_vae_tiling(self):
-        r"""
-        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
-        computing decoding in one step.
- """ - self.vae.disable_tiling() - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt def encode_prompt( self, @@ -812,93 +780,6 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. - """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - - def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, - key, value) are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - - Args: - unet (`bool`, defaults to `True`): To apply fusion on the UNet. - vae (`bool`, defaults to `True`): To apply fusion on the VAE. - """ - self.fusing_unet = False - self.fusing_vae = False - - if unet: - self.fusing_unet = True - self.unet.fuse_qkv_projections() - self.unet.set_attn_processor(FusedAttnProcessor2_0()) - - if vae: - if not isinstance(self.vae, AutoencoderKL): - raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") - - self.fusing_vae = True - self.vae.fuse_qkv_projections() - self.vae.set_attn_processor(FusedAttnProcessor2_0()) - - def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): - """Disable QKV projection fusion if enabled. - - - - This API is 🧪 experimental. - - - - Args: - unet (`bool`, defaults to `True`): To apply fusion on the UNet. - vae (`bool`, defaults to `True`): To apply fusion on the VAE. - - """ - if unet: - if not self.fusing_unet: - logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") - else: - self.unet.unfuse_qkv_projections() - self.fusing_unet = False - - if vae: - if not self.fusing_vae: - logger.warning("The VAE was not initially fused for QKV projections. 
Doing nothing.") - else: - self.vae.unfuse_qkv_projections() - self.fusing_vae = False - # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): """ From 971852f4d5798decda5d90ad7b868dd65cc8def1 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Mon, 25 Mar 2024 03:03:54 +0530 Subject: [PATCH 25/34] fix progress bar iterations --- .../pipelines/animatediff/pipeline_animatediff_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index c8ec00cd0914..6e02e258f0cc 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -1186,7 +1186,7 @@ def __call__( self._num_timesteps = len(timesteps) # 9. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue From dc0fd88a747d358b2402bc55561f8ddcf4d89ee0 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Tue, 26 Mar 2024 11:02:42 +0530 Subject: [PATCH 26/34] apply suggestions from review --- .../animatediff/pipeline_animatediff_sdxl.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 6e02e258f0cc..1b0f1459488e 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -270,7 +270,7 @@ def __init__( text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, - unet: UNet2DConditionModel, + unet: Union[UNet2DConditionModel, UNetMotionModel], motion_adapter: MotionAdapter, scheduler: Union[ DDIMScheduler, @@ -286,19 +286,8 @@ def __init__( ): super().__init__() - unet.config["down_block_types"] = [ - "DownBlockMotion", - "CrossAttnDownBlockMotion", - "CrossAttnDownBlockMotion", - ] - unet.config["up_block_types"] = [ - "CrossAttnUpBlockMotion", - "CrossAttnUpBlockMotion", - "UpBlockMotion", - ] - unet.config["mid_block_type"] = "UNetMidBlockCrossAttnMotion" - - unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) self.register_modules( vae=vae, From 75bd4e81413139fcf3d14401d6c867136cfa398a Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Tue, 26 Mar 2024 11:04:42 +0530 Subject: [PATCH 27/34] hardcode flip_sin_to_cos and freq_shift --- src/diffusers/models/unets/unet_motion_model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 8b32ded9bb83..550d794721cc 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -191,8 +191,6 @@ def __init__( sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, down_block_types: Tuple[str, ...] 
= ( "CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion", @@ -283,7 +281,7 @@ def __init__( self.encoder_hid_proj = None if addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) # class embedding From 124d1c93a58586c9735f68f63491469e30949d70 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Tue, 26 Mar 2024 12:18:39 +0530 Subject: [PATCH 28/34] make fix-copies --- .../animatediff/pipeline_animatediff_sdxl.py | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 1b0f1459488e..fac727ec7c34 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -128,7 +128,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: outputs = torch.stack(outputs) elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") return outputs @@ -569,7 +569,7 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): @@ -593,13 +593,30 @@ def prepare_ip_adapter_image_embeds( [single_negative_image_embeds] * num_images_per_prompt, dim=0 ) - if self.do_classifier_free_guidance: + if do_classifier_free_guidance: single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = single_image_embeds.to(device) image_embeds.append(single_image_embeds) else: - image_embeds = ip_adapter_image_embeds + repeat_dims = [1] + image_embeds = [] + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + else: + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + image_embeds.append(single_image_embeds) + return image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents @@ -770,20 +787,22 @@ def upcast_vae(self): self.vae.decoder.mid_block.to(dtype) # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def 
get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.FloatTensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: - timesteps (`torch.Tensor`): - generate embedding vectors at these timesteps + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): - dimension of the embeddings to generate - dtype: - data type of the generated embeddings + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. Returns: - `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0 From 10817881adaeb4973dec0f961693a8ef70749266 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Mar 2024 15:36:57 +0530 Subject: [PATCH 29/34] fix ip adapter implementation --- .../pipelines/animatediff/pipeline_animatediff_sdxl.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index fac727ec7c34..064df338cc89 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -1157,7 +1157,11 @@ def __call__( if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( - ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, ) # 7.1 Apply denoising_end From 4772b0df0c1b29dddf350a83e92457eb33a7b601 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Mar 2024 15:44:55 +0530 Subject: [PATCH 30/34] fix last failing test --- tests/pipelines/animatediff/test_animatediff_sdxl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/animatediff/test_animatediff_sdxl.py b/tests/pipelines/animatediff/test_animatediff_sdxl.py index bbacd36c8048..753c00016fa5 100644 --- a/tests/pipelines/animatediff/test_animatediff_sdxl.py +++ b/tests/pipelines/animatediff/test_animatediff_sdxl.py @@ -17,7 +17,7 @@ from diffusers.utils.testing_utils import torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin, SDFunctionTesterMixin +from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, SDFunctionTesterMixin def to_np(tensor): @@ -28,7 +28,7 @@ def to_np(tensor): class AnimateDiffPipelineSDXLFastTests( - IPAdapterTesterMixin, PipelineTesterMixin, SDFunctionTesterMixin, unittest.TestCase + IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase ): pipeline_class = AnimateDiffSDXLPipeline params = TEXT_TO_IMAGE_PARAMS @@ -264,6 +264,9 @@ def test_prompt_embeds(self): pooled_prompt_embeds=pooled_prompt_embeds, 
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, ) + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), From 7df8faba7aea64270933112794932bfc11e114a2 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Mar 2024 15:45:18 +0530 Subject: [PATCH 31/34] make style --- .../animatediff/test_animatediff_sdxl.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/animatediff/test_animatediff_sdxl.py b/tests/pipelines/animatediff/test_animatediff_sdxl.py index 753c00016fa5..2db0139154e9 100644 --- a/tests/pipelines/animatediff/test_animatediff_sdxl.py +++ b/tests/pipelines/animatediff/test_animatediff_sdxl.py @@ -17,7 +17,12 @@ from diffusers.utils.testing_utils import torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, SDFunctionTesterMixin +from ..test_pipelines_common import ( + IPAdapterTesterMixin, + PipelineTesterMixin, + SDFunctionTesterMixin, + SDXLOptionalComponentsTesterMixin, +) def to_np(tensor): @@ -28,7 +33,11 @@ def to_np(tensor): class AnimateDiffPipelineSDXLFastTests( - IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase + IPAdapterTesterMixin, + SDFunctionTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, ): pipeline_class = AnimateDiffSDXLPipeline params = TEXT_TO_IMAGE_PARAMS @@ -264,7 +273,7 @@ def test_prompt_embeds(self): pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, ) - + def test_save_load_optional_components(self): self._test_save_load_optional_components() From 33d5a1831777dcc4ce9b2c2e2536816e75c57e6d Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 30 Apr 2024 09:14:40 +0530 Subject: [PATCH 32/34] Update docs/source/en/api/pipelines/animatediff.md Co-authored-by: Dhruv Nair --- docs/source/en/api/pipelines/animatediff.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md index d27550cc8291..08d7b23cfa90 100644 --- a/docs/source/en/api/pipelines/animatediff.md +++ b/docs/source/en/api/pipelines/animatediff.md @@ -112,7 +112,8 @@ from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler from diffusers.utils import export_to_gif # TODO(aryan): ask team to move checkpoint to authors account -adapter = MotionAdapter.from_pretrained("a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16) +adapter = MotionAdapter.from_pretrained(" +guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16) model_id = "stabilityai/stable-diffusion-xl-base-1.0" scheduler = DDIMScheduler.from_pretrained( From 43e873f7f7c53b67d9d35790af3dad3616ce78a0 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Tue, 30 Apr 2024 09:24:07 +0530 Subject: [PATCH 33/34] remove todo --- docs/source/en/api/pipelines/animatediff.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md index 08d7b23cfa90..425764541590 100644 --- a/docs/source/en/api/pipelines/animatediff.md +++ b/docs/source/en/api/pipelines/animatediff.md @@ -111,9 +111,7 @@ from 
diffusers.models import MotionAdapter from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler from diffusers.utils import export_to_gif -# TODO(aryan): ask team to move checkpoint to authors account -adapter = MotionAdapter.from_pretrained(" -guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16) +adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16) model_id = "stabilityai/stable-diffusion-xl-base-1.0" scheduler = DDIMScheduler.from_pretrained( From cf3ebc97f6ce084adbbb4a38c1ea8b5aa8cde4a7 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Tue, 30 Apr 2024 09:27:42 +0530 Subject: [PATCH 34/34] fix doc-builder errors --- .../animatediff/pipeline_animatediff_sdxl.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 064df338cc89..f15cd5dbebf7 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -72,7 +72,9 @@ >>> from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler >>> from diffusers.utils import export_to_gif - >>> adapter = MotionAdapter.from_pretrained("a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16) + >>> adapter = MotionAdapter.from_pretrained( + ... "a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16 + ... ) >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0" >>> scheduler = DDIMScheduler.from_pretrained( @@ -82,14 +84,14 @@ ... timestep_spacing="linspace", ... beta_schedule="linear", ... steps_offset=1, - >>> ) + ... ) >>> pipe = AnimateDiffSDXLPipeline.from_pretrained( ... model_id, ... motion_adapter=adapter, ... scheduler=scheduler, ... torch_dtype=torch.float16, ... variant="fp16", - >>> ).to("cuda") + ... ).to("cuda") >>> # enable memory savings >>> pipe.enable_vae_slicing() @@ -103,7 +105,7 @@ ... width=1024, ... height=1024, ... num_frames=16, - >>> ) + ... ) >>> frames = output.frames[0] >>> export_to_gif(frames, "animation.gif") @@ -164,8 +166,8 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): @@ -970,14 +972,14 @@ def __call__( ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. If not - provided, embeddings are computed from the `ip_adapter_image` input argument. + Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the + `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.AnimateDiffPipelineOutput`] instead - of a plain tuple. + Whether or not to return a [`~pipelines.stable_diffusion_xl.AnimateDiffPipelineOutput`] instead of a + plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in