From 068dd509f0fb7b8eada1b0d2196a845204c7d4a7 Mon Sep 17 00:00:00 2001 From: standardAI Date: Sat, 9 Mar 2024 19:43:07 +0300 Subject: [PATCH 1/4] Add properties and `IPAdapterTesterMixin` tests for `StableDiffusionPanoramaPipeline` --- .../pipeline_stable_diffusion_panorama.py | 20 +++++++++++++++++++ .../test_stable_diffusion_panorama.py | 6 ++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index feda710e0049..fbc628f27737 100644 --- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -600,6 +600,26 @@ def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, c views.append((h_start, h_end, w_start, w_end)) return views + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def clip_skip(self): + return self._clip_skip + + @property + def do_classifier_free_guidance(self): + return False + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( diff --git a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py index aa7212b0f9ff..dfb083d2e8ac 100644 --- a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py @@ -32,14 +32,16 @@ from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin enable_full_determinism() @skip_mps -class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionPanoramaPipelineFastTests( + IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionPanoramaPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS From 74473584dacb72765ac0dd82d6e022810100419f Mon Sep 17 00:00:00 2001 From: standardAI Date: Mon, 11 Mar 2024 19:06:57 +0300 Subject: [PATCH 2/4] Update torch manual seed to use `torch.Generator(device=device)` --- .../stable_diffusion_panorama/test_stable_diffusion_panorama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py index dfb083d2e8ac..9ea178657e65 100644 --- a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py @@ -98,7 +98,7 @@ def get_dummy_components(self): return components def get_dummy_inputs(self, device, seed=0): - generator = torch.manual_seed(seed) + generator = torch.Generator(device=device).manual_seed(seed) inputs = { "prompt": "a photo of the dolomites", "generator": generator, From ac7067d67b54235134cb150919e547524270f9db Mon Sep 17 00:00:00 2001 From: standardAI Date: Tue, 12 Mar 2024 13:15:08 +0300 Subject: [PATCH 3/4] =?UTF-8?q?Refactor=20=F0=9F=93=9E=F0=9F=94=99=20to=20?= =?UTF-8?q?support=20`callback=5Fon=5Fstep=5Fend`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pipeline_stable_diffusion_panorama.py | 232 ++++++++++++++++-- 1 file changed, 213 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index fbc628f27737..4d9e9a4a7e8e 100644 --- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -13,7 +13,7 @@ import copy import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection @@ -59,6 +59,66 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionPanoramaPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin ): @@ -97,6 +157,7 @@ class StableDiffusionPanoramaPipeline( model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -461,10 +522,23 @@ def decode_latents(self, latents): image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image - def decode_latents_with_padding(self, latents, padding=8): - # Add padding to latents for circular inference - # padding is the number of latents to add on each side - # it would slightly increase the memory usage, but remove the boundary artifacts + def decode_latents_with_padding(self, latents: torch.Tensor, padding: int = 8) -> torch.Tensor: + """ + Decode the given latents with padding for circular inference. + + Args: + latents (torch.Tensor): The input latents to decode. + padding (int, optional): The number of latents to add on each side for padding. Defaults to 8. + + Returns: + torch.Tensor: The decoded image with padding removed. + + Notes: + - The padding is added to remove boundary artifacts and improve the output quality. + - This would slightly increase the memory usage. + - The padding pixels are then removed from the decoded image. + + """ latents = 1 / self.vae.config.scaling_factor * latents latents_left = latents[..., :padding] latents_right = latents[..., -padding:] @@ -580,9 +654,60 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, circular_padding=False): - # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) - # if panorama's height/width < window_size, num_blocks of height/width should return 1 + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + def get_views( + self, + panorama_height: int, + panorama_width: int, + window_size: int = 64, + stride: int = 8, + circular_padding: bool = False, + ) -> List[Tuple[int, int, int, int]]: + """ + Generates a list of views based on the given parameters. + Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113). + If panorama's height/width < window_size, num_blocks of height/width should return 1. + + Args: + panorama_height (int): The height of the panorama. + panorama_width (int): The width of the panorama. + window_size (int, optional): The size of the window. Defaults to 64. + stride (int, optional): The stride value. Defaults to 8. + circular_padding (bool, optional): Whether to apply circular padding. Defaults to False. + + Returns: + List[Tuple[int, int, int, int]]: A list of tuples representing the views. Each tuple contains + four integers representing the start and end coordinates of the window in the panorama. + + """ panorama_height /= 8 panorama_width /= 8 num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1 @@ -604,6 +729,10 @@ def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, c def guidance_scale(self): return self._guidance_scale + @property + def guidance_rescale(self): + return self._guidance_rescale + @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @@ -620,6 +749,10 @@ def do_classifier_free_guidance(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -628,6 +761,7 @@ def __call__( height: Optional[int] = 512, width: Optional[int] = 2048, num_inference_steps: int = 50, + timesteps: List[int] = None, guidance_scale: float = 7.5, view_batch_size: int = 1, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -641,11 +775,13 @@ def __call__( ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, circular_padding: bool = False, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs: Any, ): r""" The call function to the pipeline for generation. @@ -661,6 +797,9 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + The timesteps at which to generate the images. If not specified, then the default + timestep spacing strategy of the scheduler is used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -700,16 +839,12 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescaling factor for the guidance embeddings. A value of 0.0 means no rescaling is applied. circular_padding (`bool`, *optional*, defaults to `False`): If set to `True`, circular padding is applied to ensure there are no stitching artifacts. Circular padding allows the model to seamlessly generate a transition from the rightmost part of the image to @@ -717,6 +852,15 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: @@ -726,6 +870,22 @@ def __call__( second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor @@ -741,8 +901,15 @@ def __call__( negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -788,8 +955,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -822,12 +988,23 @@ def __call__( else None ) + # 7.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + # 8. Denoising loop # Each denoising step also includes refinement of the latents with respect to the # views. num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue count.zero_() value.zero_() @@ -883,6 +1060,7 @@ def __call__( latent_model_input, t, encoder_hidden_states=prompt_embeds_input, + timestep_cond=timestep_cond, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, ).sample @@ -892,6 +1070,12 @@ def __call__( noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + # compute the previous noisy sample x_t -> x_t-1 latents_denoised_batch = self.scheduler.step( noise_pred, t, latents_for_view, **extra_step_kwargs @@ -921,6 +1105,16 @@ def __call__( # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 latents = torch.where(count > 0, value / count, value) + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() @@ -928,7 +1122,7 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) - if not output_type == "latent": + if output_type != "latent": if circular_padding: image = self.decode_latents_with_padding(latents) else: From b78683ed6f3e040533ed384c58d7abd3e8ffb441 Mon Sep 17 00:00:00 2001 From: standardAI Date: Wed, 20 Mar 2024 08:33:50 +0300 Subject: [PATCH 4/4] make fix-copies --- .../pipeline_stable_diffusion_panorama.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index 4d9e9a4a7e8e..cd5189b85e68 100644 --- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -655,20 +655,22 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.FloatTensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: - timesteps (`torch.Tensor`): - generate embedding vectors at these timesteps + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): - dimension of the embeddings to generate - dtype: - data type of the generated embeddings + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. Returns: - `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0