From 860a529a340b2367fcb8d75b1d874ee8dca664c3 Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Sat, 16 Dec 2023 05:49:30 +0530
Subject: [PATCH 01/15] copy animatediff pipeline

---
 .../animatediff/pipeline_animatediff_xl.py    | 749 ++++++++++++++++++
 1 file changed, 749 insertions(+)
 create mode 100644 src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py

diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py
new file mode 100644
index 000000000000..68b358f7645c
--- /dev/null
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py
@@ -0,0 +1,749 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...models.unet_motion_model import MotionAdapter
+from ...schedulers import (
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
+        >>> from diffusers.utils import export_to_gif
+
+        >>> adapter = MotionAdapter.from_pretrained("diffusers/motion-adapter")
+        >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)
+        >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False)
+        >>> output = pipe(prompt="A corgi walking in the park")
+        >>> frames = output.frames[0]
+        >>> export_to_gif(frames, "animation.gif")
+        ```
+"""
+
+
+def tensor2vid(video: torch.Tensor, processor, output_type="np"):
+    # Based on:
+    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
+
+    batch_size, channels, num_frames, height, width = video.shape
+    outputs = []
+    for batch_idx in range(batch_size):
+        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+        batch_output = processor.postprocess(batch_vid, output_type)
+
+        outputs.append(batch_output)
+
+    return outputs
+
+
+@dataclass
+class AnimateDiffPipelineOutput(BaseOutput):
+ frames: Union[torch.Tensor, np.ndarray] + + +class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + motion_adapter: MotionAdapter, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            lora_scale (`float`, *optional*):
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers.
Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + 
def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_frames: Optional[int] = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. 
Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or + `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + ) + if do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 7 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + return AnimateDiffPipelineOutput(frames=latents) + + # Post-processing + video_tensor = self.decode_latents(latents) + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) From 408ec395cd0cd8d5429bbe5094bc030b1f51db30 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 16 Dec 2023 05:59:40 +0530 Subject: [PATCH 02/15] update --- .../animatediff/pipeline_animatediff_xl.py | 349 ++++++++++++------ 1 file changed, 235 insertions(+), 114 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py index 68b358f7645c..2dfc7511f1af 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py +++ 
b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py @@ -21,7 +21,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import IPAdapterMixin, LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter @@ -73,11 +73,11 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): @dataclass -class AnimateDiffPipelineOutput(BaseOutput): +class AnimateDiffXLPipelineOutput(BaseOutput): frames: Union[torch.Tensor, np.ndarray] -class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): +class AnimateDiffXLPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): r""" Pipeline for text-to-video generation. @@ -143,16 +143,20 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): @@ -162,9 +166,12 @@ def encode_prompt( Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders device: (`torch.device`): torch device - num_images_per_prompt (`int`): + num_videos_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not @@ -172,6 +179,9 @@ def encode_prompt( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, text embeddings will be generated from `prompt` input argument. @@ -179,104 +189,118 @@ def encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ + device = device or self._execution_device + # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - 
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. 
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): + if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" @@ -284,46 +308,75 @@ def encode_prompt( " the batch size of `prompt`." 
) else: - uncond_tokens = negative_prompt + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None + negative_prompt_embeds_list.append(negative_prompt_embeds) - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view( + bs_embed * num_videos_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, 
num_videos_per_prompt).view( + bs_embed * num_videos_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) - return prompt_embeds, negative_prompt_embeds + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): @@ -453,16 +506,20 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs def check_inputs( self, prompt, + prompt_2, height, width, callback_steps, negative_prompt=None, + negative_prompt_2=None, prompt_embeds=None, negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, ): if height % 8 != 0 or width % 8 != 0: @@ -473,6 +530,7 @@ def check_inputs( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) + if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -485,18 +543,30 @@ def check_inputs( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: @@ -506,6 +576,16 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None @@ -535,19 +615,24 @@ def prepare_latents( @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]] = None, + prompt: Optional[Union[str, List[str]]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, num_frames: Optional[int] = 16, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + denoising_end: Optional[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -562,6 +647,9 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated video. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): @@ -572,12 +660,22 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality videos at the expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. @@ -595,7 +693,15 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or `np.array`. @@ -629,7 +735,17 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, ) # 2. 
Define call parameters @@ -650,7 +766,12 @@ def __call__( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( prompt, device, num_videos_per_prompt, @@ -730,7 +851,7 @@ def __call__( callback(i, t, latents) if output_type == "latent": - return AnimateDiffPipelineOutput(frames=latents) + return AnimateDiffXLPipelineOutput(frames=latents) # Post-processing video_tensor = self.decode_latents(latents) @@ -746,4 +867,4 @@ def __call__( if not return_dict: return (video,) - return AnimateDiffPipelineOutput(frames=video) + return AnimateDiffXLPipelineOutput(frames=video) From 85edc0633d8f2ceb585a46d4b93757fac259f83b Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 16 Dec 2023 06:19:07 +0530 Subject: [PATCH 03/15] update --- .../animatediff/pipeline_animatediff_xl.py | 87 +++++++++++++++++-- 1 file changed, 81 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py index 2dfc7511f1af..df40014f1011 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py @@ -14,7 +14,7 @@ import inspect from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -612,6 +612,22 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + @torch.no_grad() def __call__( self, @@ -640,6 +656,9 @@ def __call__( callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, ): r""" The call function to the pipeline for generation. @@ -720,6 +739,21 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. 
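
As a quick illustration of what the `_get_add_time_ids` helper above packs together, here is a self-contained sketch; the 256/1280 embedding widths are typical SDXL values used as assumptions, not read from this pipeline:

```py
import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)

# The three tuples are flattened into six numbers that the UNet later embeds as
# "time ids" (SDXL micro-conditioning).
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=torch.float32)
print(add_time_ids)  # tensor([[1024., 1024., 0., 0., 1024., 1024.]])

# Sanity check mirroring the helper: with addition_time_embed_dim=256 and
# text_encoder_2.config.projection_dim=1280 the UNet expects 6 * 256 + 1280 = 2816 features.
addition_time_embed_dim, projection_dim = 256, 1280
expected_add_embed_dim = addition_time_embed_dim * add_time_ids.shape[-1] + projection_dim
print(expected_add_embed_dim)  # 2816
```
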
Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + Examples: Returns: @@ -757,6 +791,7 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @@ -766,6 +801,7 @@ def __call__( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) + ( prompt_embeds, negative_prompt_embeds, @@ -816,21 +852,55 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7 Add image embeds for IP-Adapter + + # 7. Add image embeds for IP-Adapter added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None - # Denoising loop + # 8. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + # 9. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 9.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + added_cond_kwargs.update({"text_embeds": add_text_embeds, "time_ids": add_time_ids}) + ts = torch.tensor([t], dtype=latent_model_input.dtype, device=latent_model_input.device) + if do_classifier_free_guidance: + ts = ts.repeat(2) + # predict the noise residual noise_pred = self.unet( latent_model_input, - t, + ts, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, @@ -850,10 +920,15 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float32 and latents.dtype == torch.float16: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + if output_type == "latent": return AnimateDiffXLPipelineOutput(frames=latents) - # Post-processing + # 10. Post-processing video_tensor = self.decode_latents(latents) if output_type == "pt": @@ -861,7 +936,7 @@ def __call__( else: video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) - # Offload all models + # 11. Offload all models self.maybe_free_model_hooks() if not return_dict: From 8922d760fb727182c41511f69f47dc3758867061 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sat, 16 Dec 2023 06:21:23 +0530 Subject: [PATCH 04/15] fix copied from comment --- src/diffusers/models/unet_motion_model.py | 2 +- src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 0bbc573e7df1..a3b239e6fb21 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -94,7 +94,7 @@ def __init__( Args: block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each UNet block. + The tuple of output channels for each UNet block. motion_layers_per_block (`int`, *optional*, defaults to 2): The number of motion layers per UNet block. 
motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1): diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py index df40014f1011..596dda917d44 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py @@ -143,7 +143,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt def encode_prompt( self, prompt: str, From 7c751d8a830ca3ce511db7267d4c3a4ce787a0d2 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Mon, 18 Dec 2023 07:17:23 +0530 Subject: [PATCH 05/15] add missing upcast_vae function to pipeline --- .../animatediff/pipeline_animatediff_xl.py | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py index 596dda917d44..ea9c597f1005 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py @@ -23,6 +23,12 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter from ...schedulers import ( @@ -124,8 +130,8 @@ def __init__( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], - feature_extractor: CLIPImageProcessor = None, - image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: Optional[CLIPImageProcessor] = None, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, ): super().__init__() unet = UNetMotionModel.from_unet2d(unet, motion_adapter) @@ -628,6 +634,26 @@ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, d add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + @torch.no_grad() def __call__( self, From 37fff5f4a217cf8d5483987bd9d8c9fecd5857a4 Mon Sep 17 
00:00:00 2001 From: a-r-r-o-w Date: Mon, 18 Dec 2023 07:18:44 +0530 Subject: [PATCH 06/15] add adapter xl, for now, to make ckpt loadable --- src/diffusers/models/attention.py | 78 ++++++++++--- src/diffusers/models/transformer_temporal.py | 6 + src/diffusers/models/unet_motion_model.py | 114 ++++++++++++++++++- 3 files changed, 179 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 08faaaf3e5bf..db4dcc67b781 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -92,6 +92,20 @@ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: return x +class TemporalSelfAttention(Attention): + def __init__( + self, temporal_position_encoding: bool = False, temporal_position_encoding_max_len: int = 32, *args, **kwargs + ): + super().__init__(*args, **kwargs) + self.pos_embed = ( + SinusoidalPositionalEmbedding( + embed_dim=kwargs["query_dim"], max_seq_length=temporal_position_encoding_max_len + ) + if temporal_position_encoding + else None + ) + + @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): r""" @@ -148,6 +162,8 @@ def __init__( attention_type: str = "default", positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, + temporal_position_encoding: Optional[bool] = False, + temporal_position_encoding_max_len: int = 32, ): super().__init__() self.only_cross_attention = only_cross_attention @@ -182,15 +198,28 @@ def __init__( else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) + if attention_type == "temporal": + self.attn1 = TemporalSelfAttention( + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + else: + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) # 2. 
Cross-Attn if cross_attention_dim is not None or double_self_attention: @@ -202,15 +231,28 @@ def __init__( if self.use_ada_layer_norm else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) ) - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none + if attention_type == "temporal": + self.attn2 = TemporalSelfAttention( + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None self.attn2 = None diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index 26e899a9b908..dd6963435576 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -85,6 +85,9 @@ def __init__( double_self_attention: bool = True, positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, + attention_type: str = "default", + temporal_position_encoding: Optional[bool] = False, + temporal_position_encoding_max_len: int = 32, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -111,6 +114,9 @@ def __init__( norm_elementwise_affine=norm_elementwise_affine, positional_embeddings=positional_embeddings, num_positional_embeddings=num_positional_embeddings, + attention_type=attention_type, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) for d in range(num_layers) ] diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index a3b239e6fb21..1e31859fa3df 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -58,6 +58,10 @@ def __init__( activation_fn: str = "geglu", norm_num_groups: int = 32, max_seq_length: int = 32, + positional_embeddings: Optional[str] = None, + attention_type: str = "default", + temporal_position_encoding: Optional[bool] = False, + temporal_position_encoding_max_len: int = 32, ): super().__init__() self.motion_modules = nn.ModuleList([]) @@ -72,8 +76,11 @@ def __init__( attention_bias=attention_bias, num_attention_heads=num_attention_heads, attention_head_dim=in_channels // num_attention_heads, - positional_embeddings="sinusoidal", + positional_embeddings=positional_embeddings, num_positional_embeddings=max_seq_length, + attention_type=attention_type, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) ) @@ -125,6 +132,103 @@ def __init__( num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, layers_per_block=motion_layers_per_block, + 
positional_embeddings="sinusoidal", + ) + ) + + if use_motion_mid_block: + self.mid_block = MotionModules( + in_channels=block_out_channels[-1], + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads, + layers_per_block=motion_mid_block_layers_per_block, + max_seq_length=motion_max_seq_length, + positional_embeddings="sinusoidal", + ) + else: + self.mid_block = None + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, channel in enumerate(reversed_block_out_channels): + output_channel = reversed_block_out_channels[i] + up_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block + 1, + positional_embeddings="sinusoidal", + ) + ) + + self.down_blocks = nn.ModuleList(down_blocks) + self.up_blocks = nn.ModuleList(up_blocks) + + def forward(self, sample): + pass + + +class MotionAdapterXL(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + motion_layers_per_block: int = 2, + motion_mid_block_layers_per_block: int = 1, + motion_num_attention_heads: int = 8, + motion_norm_num_groups: int = 32, + motion_max_seq_length: int = 32, + use_motion_mid_block: bool = True, + temporal_position_encoding: Optional[bool] = False, + temporal_position_encoding_max_len: int = 32, + ): + """Container to store AnimateDiff Motion Modules + + Args: + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each UNet block. + motion_layers_per_block (`int`, *optional*, defaults to 2): + The number of motion layers per UNet block. + motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1): + The number of motion layers in the middle UNet block. + motion_num_attention_heads (`int`, *optional*, defaults to 8): + The number of heads to use in each attention layer of the motion module. + motion_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use in each group normalization layer of the motion module. + motion_max_seq_length (`int`, *optional*, defaults to 32): + The maximum sequence length to use in the motion module. + use_motion_mid_block (`bool`, *optional*, defaults to True): + Whether to use a motion module in the middle of the UNet. 
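
A hypothetical construction sketch for this container: the class is defined only within this patch, so the import path and the printed layout are assumptions derived from the constructor above rather than a published API:

```py
# Assumes the MotionAdapterXL class added in this patch is importable from the edited module.
from diffusers.models.unet_motion_model import MotionAdapterXL

adapter = MotionAdapterXL(
    block_out_channels=(320, 640, 1280, 1280),
    motion_layers_per_block=2,
    motion_mid_block_layers_per_block=1,
    motion_num_attention_heads=8,
    motion_max_seq_length=32,
    use_motion_mid_block=True,
    temporal_position_encoding=True,
    temporal_position_encoding_max_len=32,
)
# One MotionModules entry per block_out_channels entry, each built with attention_type="temporal".
print(len(adapter.down_blocks), len(adapter.up_blocks))  # 4 4
```
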
+ """ + + super().__init__() + down_blocks = [] + up_blocks = [] + + for i, channel in enumerate(block_out_channels): + output_channel = block_out_channels[i] + down_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block, + positional_embeddings=None, + attention_type="temporal", + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) ) @@ -138,6 +242,10 @@ def __init__( num_attention_heads=motion_num_attention_heads, layers_per_block=motion_mid_block_layers_per_block, max_seq_length=motion_max_seq_length, + positional_embeddings=None, + attention_type="temporal", + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) else: self.mid_block = None @@ -156,6 +264,10 @@ def __init__( num_attention_heads=motion_num_attention_heads, max_seq_length=motion_max_seq_length, layers_per_block=motion_layers_per_block + 1, + positional_embeddings=None, + attention_type="temporal", + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) ) From 84abe4896c72764bb41b8aa0ba1a8bb157eb7f28 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Mon, 18 Dec 2023 07:52:19 +0530 Subject: [PATCH 07/15] add missing constructor params --- .../pipelines/animatediff/pipeline_animatediff_xl.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py index ea9c597f1005..9d4a5293ccfc 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py @@ -18,7 +18,13 @@ import numpy as np import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin @@ -119,7 +125,9 @@ def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, motion_adapter: MotionAdapter, scheduler: Union[ @@ -139,7 +147,9 @@ def __init__( self.register_modules( vae=vae, text_encoder=text_encoder, + text_encoder_2=text_encoder_2, tokenizer=tokenizer, + tokenizer_2=tokenizer_2, unet=unet, motion_adapter=motion_adapter, scheduler=scheduler, From 3db810891035f69ebe70ff0cba266b6a5acc7449 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Tue, 19 Dec 2023 06:14:17 +0530 Subject: [PATCH 08/15] fix --- .../animatediff/pipeline_animatediff_xl.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py index 9d4a5293ccfc..7b60e33bc765 100644 --- 
a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py @@ -801,6 +801,9 @@ def __call__( height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor + original_size = original_size or (height, width) + target_size = target_size or (height, width) + num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct @@ -844,13 +847,17 @@ def __call__( pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - do_classifier_free_guidance, - negative_prompt, + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) @@ -893,19 +900,19 @@ def __call__( added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None # 8. Prepare added time ids & embeddings - add_text_embeds = pooled_prompt_embeds - add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype - ) + # add_text_embeds = pooled_prompt_embeds + # add_time_ids = self._get_add_time_ids( + # original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + # ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + # add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + # add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) - add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device) + # add_text_embeds = add_text_embeds.to(device) + # add_time_ids = add_time_ids.to(device) # 9. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -928,7 +935,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - added_cond_kwargs.update({"text_embeds": add_text_embeds, "time_ids": add_time_ids}) + # added_cond_kwargs.update({"text_embeds": add_text_embeds, "time_ids": add_time_ids}) ts = torch.tensor([t], dtype=latent_model_input.dtype, device=latent_model_input.device) if do_classifier_free_guidance: ts = ts.repeat(2) From 4fefc71aab7d2ffbbc7d94bc4f43545116cc50a4 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Tue, 19 Dec 2023 22:08:23 +0530 Subject: [PATCH 09/15] apply suggestions from review Co-Authored-By: Dhruv Nair --- src/diffusers/models/attention.py | 78 ++++---------- src/diffusers/models/embeddings.py | 1 + src/diffusers/models/transformer_temporal.py | 6 -- src/diffusers/models/unet_motion_model.py | 108 ------------------- 4 files changed, 19 insertions(+), 174 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index db4dcc67b781..08faaaf3e5bf 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -92,20 +92,6 @@ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: return x -class TemporalSelfAttention(Attention): - def __init__( - self, temporal_position_encoding: bool = False, temporal_position_encoding_max_len: int = 32, *args, **kwargs - ): - super().__init__(*args, **kwargs) - self.pos_embed = ( - SinusoidalPositionalEmbedding( - embed_dim=kwargs["query_dim"], max_seq_length=temporal_position_encoding_max_len - ) - if temporal_position_encoding - else None - ) - - @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): r""" @@ -162,8 +148,6 @@ def __init__( attention_type: str = "default", positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, - temporal_position_encoding: Optional[bool] = False, - temporal_position_encoding_max_len: int = 32, ): super().__init__() self.only_cross_attention = only_cross_attention @@ -198,28 +182,15 @@ def __init__( else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - if attention_type == "temporal": - self.attn1 = TemporalSelfAttention( - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - else: - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) # 2. 
Cross-Attn if cross_attention_dim is not None or double_self_attention: @@ -231,28 +202,15 @@ def __init__( if self.use_ada_layer_norm else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) ) - if attention_type == "temporal": - self.attn2 = TemporalSelfAttention( - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - else: - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None self.attn2 = None diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 73abc9869230..963981014c3d 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -296,6 +296,7 @@ class SinusoidalPositionalEmbedding(nn.Module): """ def __init__(self, embed_dim: int, max_seq_length: int = 32): + print(embed_dim) super().__init__() position = torch.arange(max_seq_length).unsqueeze(1) div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index dd6963435576..26e899a9b908 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -85,9 +85,6 @@ def __init__( double_self_attention: bool = True, positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, - attention_type: str = "default", - temporal_position_encoding: Optional[bool] = False, - temporal_position_encoding_max_len: int = 32, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -114,9 +111,6 @@ def __init__( norm_elementwise_affine=norm_elementwise_affine, positional_embeddings=positional_embeddings, num_positional_embeddings=num_positional_embeddings, - attention_type=attention_type, - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) for d in range(num_layers) ] diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 1e31859fa3df..23acafe95bf9 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -59,9 +59,6 @@ def __init__( norm_num_groups: int = 32, max_seq_length: int = 32, positional_embeddings: Optional[str] = None, - attention_type: str = "default", - temporal_position_encoding: Optional[bool] = False, - temporal_position_encoding_max_len: int = 32, ): super().__init__() self.motion_modules = nn.ModuleList([]) @@ -78,9 +75,6 @@ def __init__( attention_head_dim=in_channels // num_attention_heads, positional_embeddings=positional_embeddings, 
num_positional_embeddings=max_seq_length, - attention_type=attention_type, - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, ) ) @@ -176,108 +170,6 @@ def forward(self, sample): pass -class MotionAdapterXL(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), - motion_layers_per_block: int = 2, - motion_mid_block_layers_per_block: int = 1, - motion_num_attention_heads: int = 8, - motion_norm_num_groups: int = 32, - motion_max_seq_length: int = 32, - use_motion_mid_block: bool = True, - temporal_position_encoding: Optional[bool] = False, - temporal_position_encoding_max_len: int = 32, - ): - """Container to store AnimateDiff Motion Modules - - Args: - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each UNet block. - motion_layers_per_block (`int`, *optional*, defaults to 2): - The number of motion layers per UNet block. - motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1): - The number of motion layers in the middle UNet block. - motion_num_attention_heads (`int`, *optional*, defaults to 8): - The number of heads to use in each attention layer of the motion module. - motion_norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups to use in each group normalization layer of the motion module. - motion_max_seq_length (`int`, *optional*, defaults to 32): - The maximum sequence length to use in the motion module. - use_motion_mid_block (`bool`, *optional*, defaults to True): - Whether to use a motion module in the middle of the UNet. - """ - - super().__init__() - down_blocks = [] - up_blocks = [] - - for i, channel in enumerate(block_out_channels): - output_channel = block_out_channels[i] - down_blocks.append( - MotionModules( - in_channels=output_channel, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=None, - activation_fn="geglu", - attention_bias=False, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - layers_per_block=motion_layers_per_block, - positional_embeddings=None, - attention_type="temporal", - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, - ) - ) - - if use_motion_mid_block: - self.mid_block = MotionModules( - in_channels=block_out_channels[-1], - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=None, - activation_fn="geglu", - attention_bias=False, - num_attention_heads=motion_num_attention_heads, - layers_per_block=motion_mid_block_layers_per_block, - max_seq_length=motion_max_seq_length, - positional_embeddings=None, - attention_type="temporal", - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, - ) - else: - self.mid_block = None - - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, channel in enumerate(reversed_block_out_channels): - output_channel = reversed_block_out_channels[i] - up_blocks.append( - MotionModules( - in_channels=output_channel, - norm_num_groups=motion_norm_num_groups, - cross_attention_dim=None, - activation_fn="geglu", - attention_bias=False, - num_attention_heads=motion_num_attention_heads, - max_seq_length=motion_max_seq_length, - layers_per_block=motion_layers_per_block + 1, - 
positional_embeddings=None, - attention_type="temporal", - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, - ) - ) - - self.down_blocks = nn.ModuleList(down_blocks) - self.up_blocks = nn.ModuleList(up_blocks) - - def forward(self, sample): - pass - - class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a From 173c505235cfcfc76a578250811716fdd30fd005 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 21 Dec 2023 16:40:41 +0530 Subject: [PATCH 10/15] update --- src/diffusers/models/embeddings.py | 1 - src/diffusers/models/unet_2d_condition.py | 1 + src/diffusers/models/unet_3d_blocks.py | 2 + src/diffusers/models/unet_motion_model.py | 64 ++++++++++++++++++++--- 4 files changed, 60 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 963981014c3d..73abc9869230 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -296,7 +296,6 @@ class SinusoidalPositionalEmbedding(nn.Module): """ def __init__(self, embed_dim: int, max_seq_length: int = 32): - print(embed_dim) super().__init__() position = torch.arange(max_seq_length).unsqueeze(1) div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index ddf533d3bd3b..4329f4db4303 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -273,6 +273,7 @@ def __init__( raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: for layer_number_per_block in transformer_layers_per_block: if isinstance(layer_number_per_block, list): diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index e9c505c347b0..5f66c32ac4a0 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -118,6 +118,7 @@ def get_down_block( raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") return CrossAttnDownBlockMotion( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -252,6 +253,7 @@ def get_up_block( raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") return CrossAttnUpBlockMotion( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 23acafe95bf9..ca91db261966 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -187,6 +187,8 @@ def __init__( sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, down_block_types: Tuple[str, ...] 
= ( "CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion", @@ -207,6 +209,8 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, use_linear_projection: bool = False, num_attention_heads: Union[int, Tuple[int, ...]] = 8, motion_max_seq_length: int = 32, @@ -214,6 +218,10 @@ def __init__( use_motion_mid_block: int = True, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + projection_class_embeddings_input_dim: Optional[int] = None, + addition_embed_type_num_heads: int = 64, ): super().__init__() @@ -235,6 +243,21 @@ def __init__( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + # input conv_in_kernel = 3 conv_out_kernel = 3 @@ -257,6 +280,10 @@ def __init__( if encoder_hid_dim_type is None: self.encoder_hid_proj = None + if addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + # class embedding self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -264,6 +291,15 @@ def __init__( if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -273,7 +309,7 @@ def __init__( down_block = get_down_block( down_block_type, - num_layers=layers_per_block, + num_layers=layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, @@ -281,13 +317,14 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, 
dual_cross_attention=False, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, + transformer_layers_per_block=transformer_layers_per_block[i], ) self.down_blocks.append(down_block) @@ -296,10 +333,11 @@ def __init__( self.mid_block = UNetMidBlockCrossAttnMotion( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, @@ -311,10 +349,11 @@ def __init__( self.mid_block = UNetMidBlock2DCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, @@ -326,6 +365,9 @@ def __init__( # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): @@ -344,7 +386,7 @@ def __init__( up_block = get_up_block( up_block_type, - num_layers=layers_per_block + 1, + num_layers=reversed_layers_per_block[i] + 1, in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, @@ -353,13 +395,14 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, + cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=False, resolution_idx=i, use_linear_projection=use_linear_projection, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -442,7 +485,14 @@ def from_unet2d( model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) - model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) + + have = {} + for x in model.mid_block.attentions.state_dict().keys(): + if x in unet.mid_block.attentions.state_dict().keys(): + have[x] = unet.mid_block.attentions.state_dict()[x].reshape( + model.mid_block.attentions.state_dict()[x].shape + ) + model.mid_block.attentions.load_state_dict(have) if unet.conv_norm_out is not None: model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) From df323f97a3fba47fec88c39f98f08f4a7f88e6be Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 21 Dec 2023 16:42:50 +0530 Subject: [PATCH 11/15] remove debug code --- src/diffusers/models/unet_motion_model.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git 
a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index ca91db261966..a5c6fd472e17 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -485,14 +485,7 @@ def from_unet2d( model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) - - have = {} - for x in model.mid_block.attentions.state_dict().keys(): - if x in unet.mid_block.attentions.state_dict().keys(): - have[x] = unet.mid_block.attentions.state_dict()[x].reshape( - model.mid_block.attentions.state_dict()[x].shape - ) - model.mid_block.attentions.load_state_dict(have) + model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) if unet.conv_norm_out is not None: model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) From 2acaa3f90976e7c029e51fac10bd5fb125f3ca52 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 21 Dec 2023 18:06:44 +0530 Subject: [PATCH 12/15] add missing linear proj --- src/diffusers/models/unet_motion_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index a5c6fd472e17..35991ea23d76 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -341,6 +341,7 @@ def __init__( num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, + use_linear_projection=use_linear_projection, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, ) @@ -357,6 +358,7 @@ def __init__( num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, + use_linear_projection=use_linear_projection, ) # count how many layers upsample the images From 171e9180e6ff619cc719dd02ad3df4f6c5dced57 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 21 Dec 2023 18:13:30 +0530 Subject: [PATCH 13/15] fix bug with config parameter attention_head_dim --- src/diffusers/models/unet_motion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index 35991ea23d76..f530fd83a264 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -460,7 +460,7 @@ def from_unet2d( config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] # Need this for backwards compatibility with UNet2DConditionModel checkpoints - if not config.get("num_attention_heads"): + if config.get("attention_head_dim", None): config["num_attention_heads"] = config["attention_head_dim"] model = cls.from_config(config) From ec3432f589dbec627249ffa2021681ba2e33335e Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 21 Dec 2023 18:26:04 +0530 Subject: [PATCH 14/15] add temporary code to handle tensor size mismatch --- src/diffusers/models/unet_motion_model.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index f530fd83a264..a96ee3cb1e52 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -220,6 +220,8 @@ def __init__( encoder_hid_dim_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, + 
conv_in_kernel: int = 3, + conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, addition_embed_type_num_heads: int = 64, ): @@ -259,8 +261,6 @@ def __init__( raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") # input - conv_in_kernel = 3 - conv_out_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding @@ -487,7 +487,16 @@ def from_unet2d( model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) - model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) + # model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) + + # TODO(aryan): Fix size mismatch + have = {} + for x in model.mid_block.attentions.state_dict().keys(): + if x in unet.mid_block.attentions.state_dict().keys(): + have[x] = unet.mid_block.attentions.state_dict()[x].reshape( + model.mid_block.attentions.state_dict()[x].shape + ) + model.mid_block.attentions.load_state_dict(have) if unet.conv_norm_out is not None: model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) From b5be7a15ff6f76d62facd377f7c2c27386294095 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Mon, 25 Dec 2023 03:21:13 +0530 Subject: [PATCH 15/15] handle text_embeds and time_ids --- src/diffusers/models/unet_motion_model.py | 24 ++++++++++++++++++- .../animatediff/pipeline_animatediff_xl.py | 21 ++++++++-------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/unet_motion_model.py b/src/diffusers/models/unet_motion_model.py index a96ee3cb1e52..7076c7645e61 100644 --- a/src/diffusers/models/unet_motion_model.py +++ b/src/diffusers/models/unet_motion_model.py @@ -830,12 +830,34 @@ def forward( t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb if aug_emb is None else emb + aug_emb emb = emb.repeat_interleave(repeats=num_frames, dim=0) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 
) image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py index 7b60e33bc765..6631851019bd 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_xl.py @@ -861,6 +861,7 @@ def __call__( lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) + # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes @@ -897,22 +898,22 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else {} # 8. Prepare added time ids & embeddings - # add_text_embeds = pooled_prompt_embeds - # add_time_ids = self._get_add_time_ids( - # original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype - # ) + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - # add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - # add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) - # add_text_embeds = add_text_embeds.to(device) - # add_time_ids = add_time_ids.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -935,7 +936,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - # added_cond_kwargs.update({"text_embeds": add_text_embeds, "time_ids": add_time_ids}) + added_cond_kwargs.update({"text_embeds": add_text_embeds, "time_ids": add_time_ids}) ts = torch.tensor([t], dtype=latent_model_input.dtype, device=latent_model_input.device) if do_classifier_free_guidance: ts = ts.repeat(2)
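
To make the new `text_time` conditioning path concrete, a standalone sketch of how `text_embeds` and `time_ids` are folded into the time embedding is shown below; the 256/1280/2816 widths are SDXL-typical assumptions rather than values taken from this UNet's config:

```py
import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

addition_time_embed_dim = 256   # assumed width of each projected time id
projection_dim = 1280           # assumed width of the pooled text embedding
time_embed_dim = 1280           # assumed width of the UNet time embedding

add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
add_embedding = TimestepEmbedding(6 * addition_time_embed_dim + projection_dim, time_embed_dim)

text_embeds = torch.randn(2, projection_dim)  # pooled prompt embeds for a CFG batch of 2
time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * 2).float()

# Flatten, sinusoidally embed, reshape back per sample, then fuse with the pooled text embeds.
time_embeds = add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
aug_emb = add_embedding(add_embeds)
print(aug_emb.shape)  # torch.Size([2, 1280]) -- added onto `emb` before it is repeated per frame
```
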