diff --git a/src/diffusers/pipelines/animatediff/context_utils.py b/src/diffusers/pipelines/animatediff/context_utils.py
new file mode 100644
index 000000000000..2b2422b3ff72
--- /dev/null
+++ b/src/diffusers/pipelines/animatediff/context_utils.py
@@ -0,0 +1,186 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Generator, List, Optional
+
+import torch
+
+
+class ContextScheduler:
+    def __init__(self, length: int = 16, stride: int = 3, overlap: int = 4, loop: bool = False, type: str = "uniform_constant") -> None:
+        self.length = length
+        self.stride = stride
+        self.overlap = overlap
+        self.loop = loop
+        self.type = type
+
+    def __call__(self, num_frames: int, step_index: int, num_inference_steps: int, generator: Optional[torch.Generator] = None) -> Generator[List[int], None, None]:
+        if self.type == "uniform_original_v1":
+            return uniform_original_v1(num_frames, step_index, self.length, self.stride, self.overlap, self.loop)
+        elif self.type == "uniform_original_v2":
+            return uniform_original_v2(num_frames, step_index, self.length, self.stride, self.overlap, self.loop)
+        elif self.type == "uniform_constant":
+            return uniform_constant(num_frames, step_index, self.length, self.stride, self.overlap, self.loop)
+        elif self.type == "simple_overlap":
+            return simple_overlap(num_frames, self.length, self.overlap, self.loop)
+        else:
+            raise ValueError(f"Unknown context scheduler type: {self.type}")
+
+
+def ordered_halving(val: int) -> float:
+    bin_str = f"{val:064b}"
+    bin_flip = bin_str[::-1]
+    as_int = int(bin_flip, 2)
+    final = as_int / (1 << 64)
+    return final
+
+
+def _is_sorted(l: List[int]) -> bool:
+    return all([l[i] < l[i + 1] for i in range(len(l) - 1)])
+
+
+def uniform_original_v1(
+    num_frames: int,
+    step: int,
+    length: int,
+    stride: int,
+    overlap: int,
+    loop: bool,
+):
+    if num_frames <= length:
+        yield list(range(num_frames))
+        return
+
+    stride = min(stride, int(math.ceil(math.log2(num_frames / length))) + 1)
+    strides = [1 << i for i in range(stride)]
+    pad = int(round(num_frames * ordered_halving(step)))
+
+    for s in strides:
+        start_index = int(ordered_halving(step) * s) + pad
+        end_index = num_frames + pad + (0 if loop else -overlap)
+        step_size = length * s - overlap
+
+        for j in range(start_index, end_index, step_size):
+            context_indices = [(j + s * i) % num_frames for i in range(length)]
+            yield context_indices
+
+
+def uniform_original_v2(
+    num_frames: int,
+    step: int,
+    length: int,
+    stride: int,
+    overlap: int,
+    loop: bool,
+):
+    if num_frames <= length:
+        yield list(range(num_frames))
+        return
+
+    stride = min(stride, int(math.ceil(math.log2(num_frames / length))) + 1)
+    strides = [1 << i for i in range(stride)]
+    pad = int(round(num_frames * ordered_halving(step)))
+
+    for s in strides:
+        start_index = int(ordered_halving(step) * s) + pad
+        end_index = num_frames + pad - overlap
+        step_size = length * s - overlap
+
+        for j in range(start_index, end_index, step_size):
+            if length * s > num_frames:
+                yield [e % num_frames for e in range(j, j + num_frames, s)]
+                continue
+
+            j = j % num_frames
+
+            if j > (j + length * s) % num_frames and not loop:
+                yield [e for e in range(j, num_frames, s)]
+                j_stop = (j + length * s) % num_frames
+                yield [e for e in range(0, j_stop, s)]
+                continue
+
+            yield [(j + i * s) % num_frames for i in range(length)]
+
+
+def uniform_constant(
+    num_frames: int,
+    step: int,
+    length: int,
+    stride: int,
+    overlap: int,
+    loop: bool,
+):
+    if num_frames <= length:
+        yield list(range(num_frames))
+        return
+
+    stride = min(stride, int(math.ceil(math.log2(num_frames / length))) + 1)
+    strides = [1 << i for i in range(stride)]
+
+    for s in strides:
+        pad = int(round(num_frames * ordered_halving(step)))
+        for j in range(
+            int(ordered_halving(step) * s) + pad,
+            num_frames + pad + (0 if loop else -overlap),
+            (length * s - overlap),
+        ):
+            skip_window = False
+            prev_val = -1
+            context_window = []
+
+            for i in range(length):
+                e = (j + i * s) % num_frames
+                if not loop and e < prev_val:
+                    skip_window = True
+                    break
+                context_window.append(e)
+                prev_val = e
+
+            if skip_window:
+                continue
+
+            yield context_window
+
+
+def simple_overlap(num_frames: int, length: int, overlap: int, loop: bool) -> Generator[List[int], None, None]:
+    if num_frames <= length:
+        yield list(range(num_frames))
+        return
+
+    for i in range(0, num_frames, length - overlap):
+        context_indices = [j % num_frames for j in range(i, i + length)]
+        if not loop and not _is_sorted(context_indices):
+            continue
+        yield context_indices
+
+
+# def uniform_schedule(num_frames: int, length: int, stride: int, overlap: int, loop: bool, generator: Optional[torch.Generator] = None) -> Generator[List[int], None, None]:
+#     if num_frames <= length:
+#         yield list(range(num_frames))
+#         return
+
+#     stride = min(stride, int(math.ceil(math.log2(num_frames / length)) + 1))
+#     strides = [1 << i for i in range(stride)]
+
+#     for s in strides:
+#         start_index = int(torch.randint(0, s, (1,), generator=generator).item())
+#         end_index = num_frames + (0 if loop else -overlap)
+#         step_size = length * s - overlap
+
+#         for index in range(start_index, end_index, step_size):
+#             context_indices = [(index + i * s) % num_frames for i in range(length)]
+#             if not loop and not _is_sorted(context_indices):
+#                 continue
+#             yield context_indices
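Note (not part of the patch): the easiest way to understand the window layouts produced by the generators above is to print them. Below is a minimal sketch that only assumes the new module added in this diff; the frame count, window length, overlap, and step values are arbitrary choices for illustration.

from diffusers.pipelines.animatediff.context_utils import ContextScheduler

# Windows produced by the "uniform_constant" type change with the denoising step index.
scheduler = ContextScheduler(length=16, stride=3, overlap=4, loop=False, type="uniform_constant")
for step_index in range(2):
    windows = list(scheduler(num_frames=32, step_index=step_index, num_inference_steps=25))
    print(step_index, windows)

# "simple_overlap" slides a fixed window instead; with loop=False the window that
# would wrap around the end of the video is skipped.
for window in ContextScheduler(length=16, overlap=4, type="simple_overlap")(32, 0, 25):
    print(window)  # [0..15], then [12..27]

With a 16-frame window and an overlap of 4, consecutive windows share frames, which is what allows the denoising loop further down to average overlapping noise predictions.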
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index 2c8b44a2b4ee..9c4443e67f3f 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
@@ -43,6 +43,7 @@
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from .context_utils import ContextScheduler
 from .pipeline_output import AnimateDiffPipelineOutput
@@ -401,18 +402,30 @@ def prepare_ip_adapter_image_embeds(
         return image_embeds
 
     # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    def decode_latents(self, latents, decode_batch_size):
         latents = 1 / self.vae.config.scaling_factor * latents
 
         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
 
-        image = self.vae.decode(latents).sample
-        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        video = []
+        for i in range(0, batch_size * num_frames, decode_batch_size):
+            batch_latents = latents[i : i + decode_batch_size]
+            batch_video = self.vae.decode(batch_latents).sample
+            video.append(batch_video)
+        video = torch.cat(video, dim=0)
+        video = video.reshape(batch_size, num_frames, *video.shape[1:]).permute(0, 2, 1, 3, 4)
+
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
 
+        # image = self.vae.decode(latents).sample
+        # video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        # # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        # video = video.float()
+        # return video
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -569,12 +582,14 @@ def __call__(
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        context_scheduler: Optional[ContextScheduler] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        decode_batch_size: int = 16,
         **kwargs,
     ):
         r"""
@@ -699,6 +714,16 @@ def __call__(
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
+
+        # If the user provides a custom context scheduler, the pipeline runs in "long context" mode; otherwise it
+        # runs in "normal" mode. In long-context mode, the list of prompts is interpolated across the number of
+        # frames so that different prompts condition different temporal spans of a single video, instead of the
+        # default behaviour of generating a separate video for each prompt in the list.
+        _context_scheduler_provided = context_scheduler is not None
+        if not _context_scheduler_provided:
+            context_scheduler = ContextScheduler(length=num_frames, loop=False, type="uniform_constant")
+        else:
+            batch_size = 1
 
         device = self._execution_device
 
@@ -706,22 +731,69 @@ def __call__(
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt,
-            device,
-            num_videos_per_prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-            clip_skip=self.clip_skip,
-        )
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        prompt_embedding_map = {}
+        for i, p in enumerate(prompt if isinstance(prompt, list) else [prompt]):
+            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+                p,
+                device,
+                num_videos_per_prompt,
+                self.do_classifier_free_guidance,
+                negative_prompt[i] if isinstance(negative_prompt, list) else negative_prompt,
+                prompt_embeds=None,
+                negative_prompt_embeds=None,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+            prompt_embedding_map[i] = (prompt_embeds, negative_prompt_embeds)
+        prompt_embedding_linspace = torch.linspace(0, len(prompt_embedding_map) - 1, num_frames, device=device, dtype=prompt_embeds.dtype)
+
+        def get_prompt_embedding(frame_indices: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+            pe_list, npe_list = [], []
+
+            if _context_scheduler_provided:
+                middle_frame = frame_indices[len(frame_indices) // 2]
+                prompt_position = prompt_embedding_linspace[middle_frame].item()
+                prompt_index = int(prompt_position)
+                alpha = prompt_position - prompt_index
+
+                if alpha == 0 or prompt_index + 1 not in prompt_embedding_map:
+                    pe, npe = prompt_embedding_map[prompt_index]
+                else:
+                    # interpolate between the closest left and right prompt embeddings
+                    left_pe, left_npe = prompt_embedding_map[prompt_index]
+                    right_pe, right_npe = prompt_embedding_map[prompt_index + 1]
+                    pe = left_pe * (1 - alpha) + right_pe * alpha
+                    npe = left_npe * (1 - alpha) + right_npe * alpha
+
+                pe_list.append(pe)
+                npe_list.append(npe)
+
+                # mean of all prompt embeds in given context indices
+                # total_pe = torch.zeros((1, *prompt_embeds.shape[1:]), device=device, dtype=prompt_embeds.dtype)
+                # total_npe = torch.zeros((1, *prompt_embeds.shape[1:]), device=device, dtype=prompt_embeds.dtype)
+                # for frame_index in frame_indices:
+                #     prompt_index = int(prompt_embedding_linspace[frame_index].item())
+                #     if prompt_index in prompt_embedding_map.keys():
+                #         pe, npe = prompt_embedding_map[prompt_index]
+                #     else:
+                #         alpha = prompt_embedding_linspace[frame_index] - prompt_index
+                #         pe, npe = prompt_embedding_map[prompt_index]
+                #         next_pe, next_npe = prompt_embedding_map[prompt_index + 1]
+                #         pe = pe * (1 - alpha) + next_pe * alpha
+                #         npe = npe * (1 - alpha) + next_npe * alpha
+                #     total_pe += pe
+                #     total_npe += npe
+
+                # pe_list.append(total_pe / len(frame_indices))
+                # npe_list.append(total_npe / len(frame_indices))
+            else:
+                for (prompt_embed, negative_prompt_embed) in prompt_embedding_map.values():
+                    pe_list.append(prompt_embed)
+                    npe_list.append(negative_prompt_embed)
+
+            pe_list = torch.cat(pe_list)
+            npe_list = torch.cat(npe_list)
+            return pe_list, npe_list
 
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -773,26 +845,41 @@
         # 8. Denoising loop
         with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    added_cond_kwargs=added_cond_kwargs,
-                ).sample
+                latent_batch_size = latents.shape[0] * (2 if self.do_classifier_free_guidance else 1)
+                total_noise_preds = torch.zeros((latent_batch_size, *latents.shape[1:]), device=device, dtype=latents.dtype)
+                total_counts = torch.zeros((1, 1, num_frames, 1, 1), device=device, dtype=latents.dtype)
+
+                for context_indices in context_scheduler(num_frames, i, num_inference_steps, generator):
+                    # expand the latents if we are doing classifier free guidance
+                    context_latents = latents[:, :, context_indices]
+                    latent_model_input = torch.cat([context_latents] * 2) if self.do_classifier_free_guidance else context_latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    context_prompt_embeds, context_negative_prompt_embeds = get_prompt_embedding(context_indices)
+                    if self.do_classifier_free_guidance:
+                        context_prompt_embeds = torch.cat([context_negative_prompt_embeds, context_prompt_embeds])
+
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=context_prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                    ).sample
+
+                    total_noise_preds[:, :, context_indices] += noise_pred
+                    total_counts[:, :, context_indices] += 1
 
                 # perform guidance
                 if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred_uncond, noise_pred_text = (total_noise_preds / total_counts).chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
+                else:
+                    noise_pred = total_noise_preds / total_counts
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
@@ -814,7 +901,7 @@ def __call__(
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents)
+            video_tensor = self.decode_latents(latents, decode_batch_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
 
         # 10. Offload all models
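Usage note (not part of the patch): a hedged sketch of how the long-context path added above could be driven end to end. The checkpoint names, frame count, and scheduler settings are assumptions chosen for illustration, not values taken from this diff; only `context_scheduler` and `decode_batch_size` come from the new pipeline signature.

import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.pipelines.animatediff.context_utils import ContextScheduler
from diffusers.utils import export_to_gif

# Illustrative checkpoints; any AnimateDiff-compatible base model and motion adapter should work the same way.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear", clip_sample=False)
pipe.to("cuda")

# Passing a context scheduler switches the pipeline into long-context mode: the prompts below
# condition different temporal spans of a single 64-frame video instead of separate videos.
output = pipe(
    prompt=["a sunrise over a calm ocean", "a stormy ocean at night"],
    negative_prompt=["bad quality, worse quality", "bad quality, worse quality"],
    num_frames=64,
    guidance_scale=7.5,
    num_inference_steps=25,
    context_scheduler=ContextScheduler(length=16, overlap=4, type="uniform_constant"),
    decode_batch_size=8,
)
export_to_gif(output.frames[0], "long_context.gif")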