diff --git a/docs/source/package_reference/stable_diffusion_pipeline.mdx b/docs/source/package_reference/stable_diffusion_pipeline.mdx index 611c28514b..e493bdb8d3 100644 --- a/docs/source/package_reference/stable_diffusion_pipeline.mdx +++ b/docs/source/package_reference/stable_diffusion_pipeline.mdx @@ -40,6 +40,33 @@ To get the most out of it, it should be associated with a scheduler that is opti - all +# GaudiStableDiffusionXLPipeline + +The `GaudiStableDiffusionXLPipeline` class enables text-to-image generation on HPUs using SDXL models. +It inherits from the `GaudiDiffusionPipeline` class that is the parent to any kind of diffuser pipeline. + +To get the most out of it, it should be associated with a scheduler that is optimized for HPUs like `GaudiDDIMScheduler`. +Recommended schedulers are `GaudiEulerDiscreteScheduler` for SDXL base and `GaudiEulerAncestralDiscreteScheduler` for SDXL turbo. + + +## GaudiStableDiffusionXLPipeline + +[[autodoc]] diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.GaudiStableDiffusionXLPipeline + - __call__ + + +## GaudiEulerDiscreteScheduler + +[[autodoc]] diffusers.schedulers.scheduling_euler_discrete.GaudiEulerDiscreteScheduler + - all + + +## GaudiEulerAncestralDiscreteScheduler + +[[autodoc]] diffusers.schedulers.scheduling_euler_ancestral_discrete.GaudiEulerAncestralDiscreteScheduler + - all + + # GaudiStableDiffusionUpscalePipeline The `GaudiStableDiffusionUpscalePipeline` is used to enhance the resolution of input images by a factor of 4 on HPUs. 
@@ -47,4 +74,4 @@ It inherits from the `GaudiDiffusionPipeline` class that is the parent to any ki [[autodoc]] diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.GaudiStableDiffusionUpscalePipeline - - __call__ \ No newline at end of file + - __call__ diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md index 21b407f3e8..f14ffc7b87 100644 --- a/examples/stable-diffusion/README.md +++ b/examples/stable-diffusion/README.md @@ -115,6 +115,90 @@ python text_to_image_generation.py \ > - use [the latest checkpoint](https://huggingface.co/Intel/ldm3d-4c) for generating improved results > - use [the pano checkpoint](https://huggingface.co/Intel/ldm3d-pano) to generate panoramic view +### Stable Diffusion XL (SDXL) + +Stable Diffusion XL was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://arxiv.org/pdf/2307.01952.pdf) by the Stability AI team. + +Here is how to generate SDXL images with a single prompt: +```bash +python text_to_image_generation.py \ + --model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ + --prompts "Sailing ship painting by Van Gogh" \ + --num_images_per_prompt 20 \ + --batch_size 4 \ + --image_save_dir /tmp/stable_diffusion_xl_images \ + --scheduler euler_discrete \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` + +> HPU graphs are recommended when generating images by batches to get the fastest possible generations. +> The first batch of images entails a performance penalty. All subsequent batches will be generated much faster. +> You can enable this mode with `--use_hpu_graphs`. 
+ +Here is how to generate SDXL images with several prompts: +```bash +python text_to_image_generation.py \ + --model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ + --prompts "Sailing ship painting by Van Gogh" "A shiny flying horse taking off" \ + --num_images_per_prompt 20 \ + --batch_size 8 \ + --image_save_dir /tmp/stable_diffusion_xl_images \ + --scheduler euler_discrete \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` + +SDXL combines a second text encoder (OpenCLIP ViT-bigG/14) with the original text encoder to significantly +increase the number of parameters. Here is how to generate images with several prompts for both `prompt` +and `prompt_2` (2nd text encoder), as well as their negative prompts: +```bash +python text_to_image_generation.py \ + --model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ + --prompts "Sailing ship painting by Van Gogh" "A shiny flying horse taking off" \ + --prompts_2 "Red tone" "Blue tone" \ + --negative_prompts "Low quality" "Sketch" \ + --negative_prompts_2 "Clouds" "Clouds" \ + --num_images_per_prompt 20 \ + --batch_size 8 \ + --image_save_dir /tmp/stable_diffusion_xl_images \ + --scheduler euler_discrete \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` + +> HPU graphs are recommended when generating images by batches to get the fastest possible generations. +> The first batch of images entails a performance penalty. All subsequent batches will be generated much faster. +> You can enable this mode with `--use_hpu_graphs`. + +### SDXL-Turbo +SDXL-Turbo is a distilled version of SDXL 1.0, trained for real-time synthesis. 
+ +Here is how to generate images with multiple prompts: +```bash +python text_to_image_generation.py \ + --model_name_or_path stabilityai/sdxl-turbo \ + --prompts "Sailing ship painting by Van Gogh" "A shiny flying horse taking off" \ + --num_images_per_prompt 20 \ + --batch_size 8 \ + --image_save_dir /tmp/stable_diffusion_xl_turbo_images \ + --scheduler euler_ancestral_discrete \ + --use_habana \ + --use_hpu_graphs \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` + +> HPU graphs are recommended when generating images by batches to get the fastest possible generations. +> The first batch of images entails a performance penalty. All subsequent batches will be generated much faster. +> You can enable this mode with `--use_hpu_graphs`. + ## Textual Inversion diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index b9f7ebb1a3..972a497191 100644 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -20,7 +20,11 @@ import torch -from optimum.habana.diffusers import GaudiDDIMScheduler +from optimum.habana.diffusers import ( + GaudiDDIMScheduler, + GaudiEulerAncestralDiscreteScheduler, + GaudiEulerDiscreteScheduler, +) from optimum.habana.utils import set_seed @@ -49,6 +53,14 @@ def main(): help="Path to pre-trained model", ) + parser.add_argument( + "--scheduler", + default="ddim", + choices=["euler_discrete", "euler_ancestral_discrete", "ddim"], + type=str, + help="Name of scheduler", + ) + # Pipeline arguments parser.add_argument( "--prompts", @@ -57,12 +69,29 @@ def main(): default="An image of a squirrel in Picasso style", help="The prompt or prompts to guide the image generation.", ) + parser.add_argument( + "--prompts_2", + type=str, + nargs="*", + default=None, + help="The second prompt or prompts to guide the image generation (applicable to SDXL).", + ) parser.add_argument( "--num_images_per_prompt", type=int, 
default=1, help="The number of images to generate per prompt." ) parser.add_argument("--batch_size", type=int, default=1, help="The number of images in a batch.") - parser.add_argument("--height", type=int, default=512, help="The height in pixels of the generated images.") - parser.add_argument("--width", type=int, default=512, help="The width in pixels of the generated images.") + parser.add_argument( + "--height", + type=int, + default=0, + help="The height in pixels of the generated images (0=default from model config).", + ) + parser.add_argument( + "--width", + type=int, + default=0, + help="The width in pixels of the generated images (0=default from model config).", + ) parser.add_argument( "--num_inference_steps", type=int, @@ -89,6 +118,13 @@ def main(): default=None, help="The prompt or prompts not to guide the image generation.", ) + parser.add_argument( + "--negative_prompts_2", + type=str, + nargs="*", + default=None, + help="The second prompt or prompts not to guide the image generation (applicable to SDXL).", + ) parser.add_argument( "--eta", type=float, @@ -139,13 +175,28 @@ def main(): args = parser.parse_args() - if args.ldm3d: - from optimum.habana.diffusers import GaudiStableDiffusionLDM3DPipeline as GaudiStableDiffusionPipeline + # Set image resolution + res = {} + if args.width > 0 and args.height > 0: + res["width"] = args.width + res["height"] = args.height + + # Import selected pipeline + sdxl_models = ["stable-diffusion-xl-base-1.0", "sdxl-turbo"] - if args.model_name_or_path == "runwayml/stable-diffusion-v1-5": - args.model_name_or_path = "Intel/ldm3d-4c" + if any(model in args.model_name_or_path for model in sdxl_models): + from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline + + sdxl = True else: - from optimum.habana.diffusers import GaudiStableDiffusionPipeline + if args.ldm3d: + from optimum.habana.diffusers import GaudiStableDiffusionLDM3DPipeline as GaudiStableDiffusionPipeline + + if args.model_name_or_path == 
"runwayml/stable-diffusion-v1-5": + args.model_name_or_path = "Intel/ldm3d-4c" + else: + from optimum.habana.diffusers import GaudiStableDiffusionPipeline + sdxl = False # Setup logging logging.basicConfig( @@ -156,36 +207,63 @@ def main(): logger.setLevel(logging.INFO) # Initialize the scheduler and the generation pipeline - scheduler = GaudiDDIMScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler") + if args.scheduler == "euler_discrete": + scheduler = GaudiEulerDiscreteScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler") + elif args.scheduler == "euler_ancestral_discrete": + scheduler = GaudiEulerAncestralDiscreteScheduler.from_pretrained( + args.model_name_or_path, subfolder="scheduler" + ) + else: + scheduler = GaudiDDIMScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler") + kwargs = { "scheduler": scheduler, "use_habana": args.use_habana, "use_hpu_graphs": args.use_hpu_graphs, "gaudi_config": args.gaudi_config_name, } + if args.bf16: kwargs["torch_dtype"] = torch.bfloat16 - pipeline = GaudiStableDiffusionPipeline.from_pretrained( - args.model_name_or_path, - **kwargs, - ) # Set seed before running the model set_seed(args.seed) # Generate images - outputs = pipeline( - prompt=args.prompts, - num_images_per_prompt=args.num_images_per_prompt, - batch_size=args.batch_size, - height=args.height, - width=args.width, - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - negative_prompt=args.negative_prompts, - eta=args.eta, - output_type=args.output_type, - ) + if sdxl: + pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( + args.model_name_or_path, + **kwargs, + ) + outputs = pipeline( + prompt=args.prompts, + prompt_2=args.prompts_2, + num_images_per_prompt=args.num_images_per_prompt, + batch_size=args.batch_size, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + negative_prompt=args.negative_prompts, + 
negative_prompt_2=args.negative_prompts_2, + eta=args.eta, + output_type=args.output_type, + **res, + ) + else: + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + args.model_name_or_path, + **kwargs, + ) + outputs = pipeline( + prompt=args.prompts, + num_images_per_prompt=args.num_images_per_prompt, + batch_size=args.batch_size, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + negative_prompt=args.negative_prompts, + eta=args.eta, + output_type=args.output_type, + **res, + ) # Save the pipeline in the specified directory if not None if args.pipeline_save_dir is not None: diff --git a/optimum/habana/diffusers/__init__.py b/optimum/habana/diffusers/__init__.py index 29a7810113..42f9c08e37 100644 --- a/optimum/habana/diffusers/__init__.py +++ b/optimum/habana/diffusers/__init__.py @@ -2,4 +2,5 @@ from .pipelines.stable_diffusion.pipeline_stable_diffusion import GaudiStableDiffusionPipeline from .pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d import GaudiStableDiffusionLDM3DPipeline from .pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import GaudiStableDiffusionUpscalePipeline -from .schedulers import GaudiDDIMScheduler +from .pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import GaudiStableDiffusionXLPipeline +from .schedulers import GaudiDDIMScheduler, GaudiEulerAncestralDiscreteScheduler, GaudiEulerDiscreteScheduler diff --git a/optimum/habana/diffusers/pipelines/pipeline_utils.py b/optimum/habana/diffusers/pipelines/pipeline_utils.py index a2d9139101..e71f9832e1 100644 --- a/optimum/habana/diffusers/pipelines/pipeline_utils.py +++ b/optimum/habana/diffusers/pipelines/pipeline_utils.py @@ -51,6 +51,8 @@ }, "optimum.habana.diffusers.schedulers": { "GaudiDDIMScheduler": ["save_pretrained", "from_pretrained"], + "GaudiEulerDiscreteScheduler": ["save_pretrained", "from_pretrained"], + "GaudiEulerAncestralDiscreteScheduler": ["save_pretrained", "from_pretrained"], }, } @@ -112,7 +114,7 
@@ def __init__( if bf16_full_eval: logger.warning( "`use_torch_autocast` is True in the given Gaudi configuration but " - "`torch_dtype=torch.blfloat16` was given. Disabling mixed precision and continuing in bf16 only." + "`torch_dtype=torch.bfloat16` was given. Disabling mixed precision and continuing in bf16 only." ) self.gaudi_config.use_torch_autocast = False else: diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 6c5e8fecbd..c6e1789a43 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -368,6 +368,7 @@ def __call__( # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device="cpu") timesteps = self.scheduler.timesteps.to(device) + self.scheduler.reset_timestep_dependent_params() # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -459,7 +460,7 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_batch = self.scheduler.step( - noise_pred, latents_batch, **extra_step_kwargs, return_dict=False + noise_pred, timestep, latents_batch, **extra_step_kwargs, return_dict=False )[0] if not self.use_hpu_graphs: @@ -489,8 +490,6 @@ def __call__( image = latents_batch outputs["images"].append(image) - self.scheduler.reset_timestep_dependent_params() - if not self.use_hpu_graphs: self.htcore.mark_step() diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py index 3bcd156d1f..f1423ed7f5 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -285,6 +285,7 @@ def __call__( 
# 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device="cpu") timesteps = self.scheduler.timesteps.to(device) + self.scheduler.reset_timestep_dependent_params() # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -362,7 +363,7 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_batch = self.scheduler.step( - noise_pred, latents_batch, **extra_step_kwargs, return_dict=False + noise_pred, timestep, latents_batch, **extra_step_kwargs, return_dict=False )[0] if not self.use_hpu_graphs: @@ -380,8 +381,6 @@ def __call__( image = latents_batch outputs["images"].append(image) - self.scheduler.reset_timestep_dependent_params() - if not self.use_hpu_graphs: self.htcore.mark_step() diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index e04873a5d3..594e0e0c30 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -381,6 +381,7 @@ def __call__( # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device="cpu") timesteps = self.scheduler.timesteps.to(device) + self.scheduler.reset_timestep_dependent_params() # 5. 
Add noise to image noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) @@ -483,7 +484,7 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_batch = self.scheduler.step( - noise_pred, latents_batch, **extra_step_kwargs, return_dict=False + noise_pred, timestep, latents_batch, **extra_step_kwargs, return_dict=False )[0] if not self.use_hpu_graphs: @@ -516,8 +517,6 @@ def __call__( image = latents_batch outputs["images"].append(image) - self.scheduler.reset_timestep_dependent_params() - if not self.use_hpu_graphs: self.htcore.mark_step() diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py new file mode 100644 index 0000000000..cb1320a1da --- /dev/null +++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -0,0 +1,876 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +from dataclasses import dataclass +from math import ceil +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import BaseOutput, deprecate +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from optimum.utils import logging + +from ....transformers.gaudi_configuration import GaudiConfig +from ....utils import speed_metrics +from ..pipeline_utils import GaudiDiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class GaudiStableDiffusionXLPipelineOutput(BaseOutput): + images: Union[List[PIL.Image.Image], np.ndarray] + throughput: float + + +class GaudiStableDiffusionXLPipeline(GaudiDiffusionPipeline, StableDiffusionXLPipeline): + """ + Pipeline for text-to-image generation using Stable Diffusion XL on Gaudi devices + Adapted from: https://github.com/huggingface/diffusers/blob/v0.23.1/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L96 + + Extends the [`StableDiffusionXLPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline) class: + - Generation is performed by batches + - Two `mark_step()` were added to add support for lazy mode + - Added support for HPU graphs + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. 
Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1.0`. + use_habana (bool, defaults to `False`): + Whether to use Gaudi (`True`) or CPU (`False`). + use_hpu_graphs (bool, defaults to `False`): + Whether to use HPU graphs or not. + gaudi_config (Union[str, [`GaudiConfig`]], defaults to `None`): + Gaudi configuration to use. Can be a string to download it from the Hub. + Or a previously initialized config can be passed. + bf16_full_eval (bool, defaults to `False`): + Whether to use full bfloat16 evaluation instead of 32-bit. 
+ This will be faster and save memory compared to fp32/mixed precision but can harm generated images. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + use_habana: bool = False, + use_hpu_graphs: bool = False, + gaudi_config: Union[str, GaudiConfig] = None, + bf16_full_eval: bool = False, + ): + GaudiDiffusionPipeline.__init__( + self, + use_habana, + use_hpu_graphs, + gaudi_config, + bf16_full_eval, + ) + + StableDiffusionXLPipeline.__init__( + self, + vae, + text_encoder, + text_encoder_2, + tokenizer, + tokenizer_2, + unet, + scheduler, + force_zeros_for_empty_prompt, + ) + + self.to(self._device) + + def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (num_images, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != num_images: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {num_images}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + # torch.randn is broken on HPU so running it on CPU + rand_device = "cpu" if device.type == "hpu" else device + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) + for i in range(num_images) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @classmethod + def _split_inputs_into_batches( + cls, + batch_size, + latents, + prompt_embeds, + negative_prompt_embeds, + add_text_embeds, + negative_pooled_prompt_embeds, + add_time_ids, + negative_add_time_ids, + ): + # Use torch.split to generate num_batches batches of size batch_size + latents_batches = list(torch.split(latents, batch_size)) + prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size)) + if negative_prompt_embeds is not None: + negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size)) + if add_text_embeds is not None: + add_text_embeds_batches = list(torch.split(add_text_embeds, batch_size)) + if negative_pooled_prompt_embeds is not None: + negative_pooled_prompt_embeds_batches = list(torch.split(negative_pooled_prompt_embeds, batch_size)) + if add_time_ids is not None: + add_time_ids_batches = list(torch.split(add_time_ids, batch_size)) + if negative_add_time_ids is not None: + negative_add_time_ids_batches = list(torch.split(negative_add_time_ids, batch_size)) + + # If the last batch has less samples than batch_size, pad it with dummy samples + num_dummy_samples = 0 + if latents_batches[-1].shape[0] < batch_size: + 
num_dummy_samples = batch_size - latents_batches[-1].shape[0] + # Pad latents_batches + sequence_to_stack = (latents_batches[-1],) + tuple( + torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + latents_batches[-1] = torch.vstack(sequence_to_stack) + # Pad prompt_embeds_batches + sequence_to_stack = (prompt_embeds_batches[-1],) + tuple( + torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack) + # Pad negative_prompt_embeds_batches if necessary + if negative_prompt_embeds is not None: + sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple( + torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack) + # Pad add_text_embeds_batches if necessary + if add_text_embeds is not None: + sequence_to_stack = (add_text_embeds_batches[-1],) + tuple( + torch.zeros_like(add_text_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + add_text_embeds_batches[-1] = torch.vstack(sequence_to_stack) + # Pad negative_pooled_prompt_embeds_batches if necessary + if negative_pooled_prompt_embeds is not None: + sequence_to_stack = (negative_pooled_prompt_embeds_batches[-1],) + tuple( + torch.zeros_like(negative_pooled_prompt_embeds_batches[-1][0][None, :]) + for _ in range(num_dummy_samples) + ) + negative_pooled_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack) + # Pad add_time_ids_batches if necessary + if add_time_ids is not None: + sequence_to_stack = (add_time_ids_batches[-1],) + tuple( + torch.zeros_like(add_time_ids_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + add_time_ids_batches[-1] = torch.vstack(sequence_to_stack) + # Pad negative_add_time_ids_batches if necessary + if negative_add_time_ids is not None: + sequence_to_stack = (negative_add_time_ids_batches[-1],) + tuple( + 
torch.zeros_like(negative_add_time_ids_batches[-1][0][None, :]) for _ in range(num_dummy_samples) + ) + negative_add_time_ids_batches[-1] = torch.vstack(sequence_to_stack) + + # Stack batches in the same tensor + latents_batches = torch.stack(latents_batches) + + if negative_prompt_embeds is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + for i, (negative_prompt_embeds_batch, prompt_embeds_batch) in enumerate( + zip(negative_prompt_embeds_batches, prompt_embeds_batches[:]) + ): + prompt_embeds_batches[i] = torch.cat([negative_prompt_embeds_batch, prompt_embeds_batch]) + prompt_embeds_batches = torch.stack(prompt_embeds_batches) + + if add_text_embeds is not None: + if negative_pooled_prompt_embeds is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + for i, (negative_pooled_prompt_embeds_batch, add_text_embeds_batch) in enumerate( + zip(negative_pooled_prompt_embeds_batches, add_text_embeds_batches[:]) + ): + add_text_embeds_batches[i] = torch.cat( + [negative_pooled_prompt_embeds_batch, add_text_embeds_batch] + ) + add_text_embeds_batches = torch.stack(add_text_embeds_batches) + else: + add_text_embeds_batches = None + + if add_time_ids is not None: + if negative_add_time_ids is not None: + # For classifier free guidance, we need to do two forward passes. 
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    denoising_end: Optional[float] = None,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    batch_size: int = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    original_size: Optional[Tuple[int, int]] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Optional[Tuple[int, int]] = None,
    negative_original_size: Optional[Tuple[int, int]] = None,
    negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
    negative_target_size: Optional[Tuple[int, int]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "negative_pooled_prompt_embeds",
        "negative_add_time_ids",
    ],
    **kwargs,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders.
        height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
            The height in pixels of the generated image. This is set to 1024 by default for the best results.
            Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
            The width in pixels of the generated image. This is set to 1024 by default for the best results.
            Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        denoising_end (`float`, *optional*):
            When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
            completed before it is intentionally prematurely terminated. As a result, the returned sample will
            still retain a substantial amount of noise as determined by the discrete timesteps selected by the
            scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
            "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
            Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        batch_size (`int`, *optional*, defaults to 1):
            The number of images in a batch.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"`, the
            latents are returned without being decoded by the VAE.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a
            [`~diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.GaudiStableDiffusionXLPipelineOutput`]
            instead of the plain list of generated images.
        callback (`Callable`, *optional*):
            Deprecated, use `callback_on_step_end` instead. When given, it is called as
            `callback(step_idx, timestep, latents_batch)` during the denoising loop.
        callback_steps (`int`, *optional*, defaults to 1):
            Deprecated, use `callback_on_step_end` instead. `callback` is only invoked at loop indices `i` such
            that `i % callback_steps == 0`.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_rescale` is defined as `φ` in equation 16. of
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
            Guidance rescale factor should fix overexposure when using zero terminal SNR.
        original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
            `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
            explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
            `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
            `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            For most cases, `target_size` should be set to the desired height and width of the generated image. If
            not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
            section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            To negatively condition the generation process based on a specific image resolution. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            To negatively condition the generation process based on a target image resolution. It should be as same
            as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        clip_skip (`int`, *optional*):
            Forwarded unchanged to `encode_prompt` as its `clip_skip` argument.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Returns:
        [`~diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.GaudiStableDiffusionXLPipelineOutput`] or `list`:
            [`~diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.GaudiStableDiffusionXLPipelineOutput`]
            (generated images and measured throughput) if `return_dict` is True, otherwise the list of generated
            images is returned directly.
    """

    # `callback` and `callback_steps` are named parameters, so Python never puts them in `kwargs`.
    # Popping them from `kwargs` (as was done before) always produced `None` and silently discarded
    # whatever the caller passed; use the parameters directly instead.
    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )
    # Only warn when a non-default value was requested, otherwise every call would emit the notice.
    if callback_steps is not None and callback_steps != 1:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )

    with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast):
        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )
        # An explicit `None` means "every step" for the legacy callback
        if callback_steps is None:
            callback_steps = 1

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            num_prompts = 1
        elif prompt is not None and isinstance(prompt, list):
            num_prompts = len(prompt)
        else:
            num_prompts = prompt_embeds.shape[0]
        num_batches = ceil((num_images_per_prompt * num_prompts) / batch_size)
        logger.info(
            f"{num_prompts} prompt(s) received, {num_images_per_prompt} generation(s) per prompt,"
            f" {batch_size} sample(s) per batch, {num_batches} total batch(es)."
        )
        if num_batches < 3:
            logger.warning("The first two iterations are slower so it is recommended to feed more batches.")

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device="cpu")
        timesteps = self.scheduler.timesteps.to(device)
        self.scheduler.reset_timestep_dependent_params()

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            num_prompts * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        prompt_embeds = prompt_embeds.to(device)
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        if negative_pooled_prompt_embeds is not None:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(num_prompts * num_images_per_prompt, 1)
        negative_add_time_ids = negative_add_time_ids.to(device).repeat(num_prompts * num_images_per_prompt, 1)

        # 7.5 Split into batches (HPU-specific step)
        (
            latents_batches,
            text_embeddings_batches,
            add_text_embeddings_batches,
            add_time_ids_batches,
            num_dummy_samples,
        ) = self._split_inputs_into_batches(
            batch_size,
            latents,
            prompt_embeds,
            negative_prompt_embeds,
            add_text_embeds,
            negative_pooled_prompt_embeds,
            add_time_ids,
            negative_add_time_ids,
        )
        outputs = {
            "images": [],
        }
        t0 = time.time()
        t1 = t0

        # 8. Denoising
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 8.1 Apply denoising_end
        if (
            self.denoising_end is not None
            and isinstance(self.denoising_end, float)
            and self.denoising_end > 0
            and self.denoising_end < 1
        ):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        # 8.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                num_prompts * num_images_per_prompt
            )
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        self._num_timesteps = len(timesteps)

        # 8.3 Denoising loop
        for j in self.progress_bar(range(num_batches)):
            # The throughput is calculated from the 3rd iteration
            # because compilation occurs in the first two iterations
            if j == 2:
                t1 = time.time()

            # Take the first batch of every input and roll the stacks so that the next batch comes first
            latents_batch = latents_batches[0]
            latents_batches = torch.roll(latents_batches, shifts=-1, dims=0)
            text_embeddings_batch = text_embeddings_batches[0]
            text_embeddings_batches = torch.roll(text_embeddings_batches, shifts=-1, dims=0)
            add_text_embeddings_batch = add_text_embeddings_batches[0]
            add_text_embeddings_batches = torch.roll(add_text_embeddings_batches, shifts=-1, dims=0)
            add_time_ids_batch = add_time_ids_batches[0]
            add_time_ids_batches = torch.roll(add_time_ids_batches, shifts=-1, dims=0)

            for i in range(num_inference_steps):
                timestep = timesteps[0]
                timesteps = torch.roll(timesteps, shifts=-1, dims=0)

                # HPU graphs are only captured during the first two steps of the first batch
                capture = bool(self.use_hpu_graphs and j == 0 and i < 2)

                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch
                )
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)

                # predict the noise residual
                added_cond_kwargs = {"text_embeds": add_text_embeddings_batch, "time_ids": add_time_ids_batch}
                noise_pred = self.unet_hpu(
                    latent_model_input,
                    timestep,
                    text_embeddings_batch,
                    timestep_cond,
                    self.cross_attention_kwargs,
                    added_cond_kwargs,
                    capture,
                )

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(
                        noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents_batch = self.scheduler.step(
                    noise_pred, timestep, latents_batch, **extra_step_kwargs, return_dict=False
                )[0]

                if not self.use_hpu_graphs:
                    self.htcore.mark_step()

                if callback_on_step_end is not None:
                    # Keep this as an explicit loop: inside a comprehension `locals()` would refer to the
                    # comprehension's own scope on Python < 3.12.
                    # NOTE(review): this hands the callback the pipeline-level tensors (e.g. the full
                    # `latents`), not the per-batch ones used in this loop - confirm this is intended.
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)

                    latents_batch = callback_outputs.pop("latents", latents_batch)
                    _prompt_embeds = callback_outputs.pop("prompt_embeds", None)
                    _negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", None)
                    if _prompt_embeds is not None and _negative_prompt_embeds is not None:
                        text_embeddings_batch = torch.cat([_negative_prompt_embeds, _prompt_embeds])
                    _add_text_embeds = callback_outputs.pop("add_text_embeds", None)
                    _negative_pooled_prompt_embeds = callback_outputs.pop("negative_pooled_prompt_embeds", None)
                    if _add_text_embeds is not None and _negative_pooled_prompt_embeds is not None:
                        add_text_embeddings_batch = torch.cat([_negative_pooled_prompt_embeds, _add_text_embeds])
                    _add_time_ids = callback_outputs.pop("add_time_ids", None)
                    _negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", None)
                    if _add_time_ids is not None and _negative_add_time_ids is not None:
                        # Negative (unconditional) part first, like every other concatenated input
                        add_time_ids_batch = torch.cat([_negative_add_time_ids, _add_time_ids])

                # call the (deprecated) callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        # Pass the latents currently being denoised, not the initial full latents tensor
                        callback(step_idx, timestep, latents_batch)

            if not output_type == "latent":
                # Post-processing
                image = self.vae.decode(latents_batch / self.vae.config.scaling_factor, return_dict=False)[0]
            else:
                image = latents_batch

            outputs["images"].append(image)

            if not self.use_hpu_graphs:
                self.htcore.mark_step()

        speed_metrics_prefix = "generation"
        speed_measures = speed_metrics(
            split=speed_metrics_prefix,
            start_time=t0,
            # `t1 == t0` means fewer than 3 batches were run, so no warmup iterations can be excluded
            num_samples=num_batches * batch_size if t1 == t0 else (num_batches - 2) * batch_size,
            num_steps=num_batches,
            start_time_after_warmup=t1,
        )
        logger.info(f"Speed metrics: {speed_measures}")

        # Remove dummy generations if needed
        if num_dummy_samples > 0:
            outputs["images"][-1] = outputs["images"][-1][:-num_dummy_samples]

        # Process generated images: flatten the per-batch results into a single list of images
        generated_batches = outputs["images"]
        outputs["images"] = []
        for image in generated_batches:
            if not output_type == "latent":
                # apply watermark if available
                if self.watermark is not None:
                    image = self.watermark.apply_watermark(image)

                image = self.image_processor.postprocess(image, output_type=output_type)

            outputs["images"].extend(image)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return outputs["images"]

        return GaudiStableDiffusionXLPipelineOutput(
            images=outputs["images"],
            throughput=speed_measures[f"{speed_metrics_prefix}_samples_per_second"],
        )


@torch.no_grad()
def unet_hpu(
    self,
    latent_model_input,
    timestep,
    encoder_hidden_states,
    timestep_cond,
    cross_attention_kwargs,
    added_cond_kwargs,
    capture,
):
    """
    Run one UNet forward pass and return the predicted noise.

    When `self.use_hpu_graphs` is set the call is delegated to `capture_replay`, otherwise the UNet is
    called directly and the first element of its output is returned. `capture` is only used in the
    HPU-graph case.
    """
    if self.use_hpu_graphs:
        return self.capture_replay(
            latent_model_input,
            timestep,
            encoder_hidden_states,
            timestep_cond,
            cross_attention_kwargs,
            added_cond_kwargs,
            capture,
        )
    else:
        return self.unet(
            latent_model_input,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            timestep_cond=timestep_cond,
            cross_attention_kwargs=cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]


@torch.no_grad()
def capture_replay(
    self,
    latent_model_input,
    timestep,
    encoder_hidden_states,
    timestep_cond,
    cross_attention_kwargs,
    added_cond_kwargs,
    capture,
):
    """
    Run the UNet through an HPU graph.

    If `capture` is true, a new graph is captured around the UNet call and stored in `self.cache` under
    the hash of the inputs; the freshly computed output is returned. Otherwise the graph cached for this
    input hash is replayed with the new inputs copied into it and its output is returned.
    """
    inputs = [
        latent_model_input,
        timestep,
        encoder_hidden_states,
        timestep_cond,
        cross_attention_kwargs,
        added_cond_kwargs,
    ]
    h = self.ht.hpu.graphs.input_hash(inputs)
    cached = self.cache.get(h)

    if capture:
        # Capture the graph and cache it
        with self.ht.hpu.stream(self.hpu_stream):
            graph = self.ht.hpu.HPUGraph()
            graph.capture_begin()

            outputs = self.unet(
                sample=inputs[0],
                timestep=inputs[1],
                encoder_hidden_states=inputs[2],
                timestep_cond=inputs[3],
                cross_attention_kwargs=inputs[4],
                added_cond_kwargs=inputs[5],
                return_dict=False,
            )[0]

            graph.capture_end()
            self.cache[h] = self.ht.hpu.graphs.CachedParams(inputs, outputs, graph)
        return outputs

    # Replay the cached graph with updated inputs
    # NOTE(review): `cached` is None if no graph was captured for this input hash, in which case the
    # next line fails - presumably the capture steps always see the same input signature; confirm.
    self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
    cached.graph.replay()
    self.ht.core.hpu.default_stream().synchronize()

    return cached.graph_outputs
diff --git a/optimum/habana/diffusers/schedulers/__init__.py b/optimum/habana/diffusers/schedulers/__init__.py index e65dcd5120..37eb80b1a6 100644 --- a/optimum/habana/diffusers/schedulers/__init__.py +++ b/optimum/habana/diffusers/schedulers/__init__.py @@ -1 +1,3 @@ from .scheduling_ddim import GaudiDDIMScheduler +from .scheduling_euler_ancestral_discrete import GaudiEulerAncestralDiscreteScheduler +from .scheduling_euler_discrete import GaudiEulerDiscreteScheduler diff --git a/optimum/habana/diffusers/schedulers/scheduling_ddim.py b/optimum/habana/diffusers/schedulers/scheduling_ddim.py index 440c15268a..d15420853f 100644 --- a/optimum/habana/diffusers/schedulers/scheduling_ddim.py +++ b/optimum/habana/diffusers/schedulers/scheduling_ddim.py @@ -120,14 +120,24 @@ def reset_timestep_dependent_params(self): self.alpha_prod_t_prev_list = [] self.variance_list = [] - def get_params(self): + def get_params(self, timestep: Optional[int] = None): + """ + Initialize the time-dependent parameters, and retrieve the time-dependent + parameters at each timestep. The tensors are rolled in a separate function + at the end of the scheduler step in case parameters are retrieved multiple + times in a timestep, e.g., when scaling model inputs and in the scheduler step. + + Args: + timestep (`int`, optional): + The current discrete timestep in the diffusion chain. Optionally used to + initialize parameters in cases which start in the middle of the + denoising schedule (e.g. for image-to-image). 
def roll_params(self):
    """
    Shift the cached time-dependent parameters by one position.

    After the call, the first entry of `alpha_prod_t_list`, `alpha_prod_t_prev_list` and `variance_list`
    is the value that was previously second; the former first entry wraps around to the end.

    Raises:
        ValueError: If the time-dependent parameters have not been set yet.
    """
    if not self.are_timestep_dependent_params_set:
        raise ValueError("Time-dependent parameters should be set first.")

    for attr_name in ("alpha_prod_t_list", "alpha_prod_t_prev_list", "variance_list"):
        rolled = torch.roll(getattr(self, attr_name), shifts=-1, dims=0)
        setattr(self, attr_name, rolled)
class GaudiEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
    """
    Extends [Diffusers' EulerAncestralDiscreteScheduler](https://huggingface.co/docs/diffusers/en/api/schedulers/euler_ancestral) to run optimally on Gaudi:
    - All time-dependent parameters are generated at the beginning
    - At each time step, tensors are rolled to update the values of the time-dependent parameters

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
            Diffusion.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        super().__init__(
            num_train_timesteps,
            beta_start,
            beta_end,
            beta_schedule,
            trained_betas,
            prediction_type,
            timestep_spacing,
            steps_offset,
        )

        self._initial_timestep = None
        self.reset_timestep_dependent_params()

    def reset_timestep_dependent_params(self):
        """
        Forget the cached time-dependent parameters so that they are recomputed by the next `get_params` call.
        """
        self.are_timestep_dependent_params_set = False
        self.sigma_t_list = []
        self.sigma_up_t_list = []
        self.sigma_down_t_list = []

    def get_params(self, timestep: Union[float, torch.FloatTensor]):
        """
        Initialize the time-dependent parameters, and retrieve the time-dependent
        parameters at each timestep. The tensors are rolled in a separate function
        at the end of the scheduler step in case parameters are retrieved multiple
        times in a timestep, e.g., when scaling model inputs and in the scheduler step.

        Args:
            timestep (`float`):
                The current discrete timestep in the diffusion chain. Optionally used to
                initialize parameters in cases which start in the middle of the
                denoising schedule (e.g. for image-to-image)

        Returns:
            `tuple`: `(sigma, sigma_up, sigma_down)` for the current step, i.e. the first entry of each cached
            parameter tensor.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        if not self.are_timestep_dependent_params_set:
            # Consecutive sigma pairs (sigma_from -> sigma_to) starting at the current step index
            sigmas_from = self.sigmas[self.step_index : -1]
            sigmas_to = self.sigmas[(self.step_index + 1) :]

            # Same formulas as the per-step computation, evaluated element-wise for all steps at once
            sigmas_up = (sigmas_to**2 * (sigmas_from**2 - sigmas_to**2) / sigmas_from**2) ** 0.5
            sigmas_down = (sigmas_to**2 - sigmas_up**2) ** 0.5

            # Clone so that the cached values do not share storage with `self.sigmas`
            self.sigma_t_list = sigmas_from.clone()
            self.sigma_up_t_list = sigmas_up
            self.sigma_down_t_list = sigmas_down
            self.are_timestep_dependent_params_set = True

        sigma = self.sigma_t_list[0]
        sigma_up = self.sigma_up_t_list[0]
        sigma_down = self.sigma_down_t_list[0]

        return sigma, sigma_up, sigma_down

    def roll_params(self):
        """
        Roll tensors to update the values of the time-dependent parameters at each timestep.

        Raises:
            ValueError: If the time-dependent parameters have not been set yet.
        """
        if not self.are_timestep_dependent_params_set:
            raise ValueError("Time-dependent parameters should be set first.")

        self.sigma_t_list = torch.roll(self.sigma_t_list, shifts=-1, dims=0)
        self.sigma_up_t_list = torch.roll(self.sigma_up_t_list, shifts=-1, dims=0)
        self.sigma_down_t_list = torch.roll(self.sigma_down_t_list, shifts=-1, dims=0)

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """

        sigma, _, _ = self.get_params(timestep)
        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer (or integer tensor), or if the configured `prediction_type`
                is unknown.
            NotImplementedError: If the configured `prediction_type` is `sample`.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `GaudiEulerAncestralDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        sigma, sigma_up, sigma_down = self.get_params(timestep)

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        prev_sample = sample + derivative * dt

        # torch.randn is broken on HPU so running it on CPU, then moving the noise to the device of the
        # model output (whatever that device is, so that the addition below never mixes devices)
        noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator)
        noise = noise.to(model_output.device)

        prev_sample = prev_sample + noise * sigma_up

        # upon completion increase step index by one
        self._step_index += 1
        self.roll_params()

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )
# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.configuration_utils import register_to_config
from diffusers.schedulers import EulerDiscreteScheduler
from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteSchedulerOutput

from optimum.utils import logging


logger = logging.get_logger(__name__)


class GaudiEulerDiscreteScheduler(EulerDiscreteScheduler):
    """
    Extends [Diffusers' EulerDiscreteScheduler](https://huggingface.co/docs/diffusers/api/schedulers#diffusers.EulerDiscreteScheduler) to run optimally on Gaudi:
    - All time-dependent parameters are generated at the beginning
    - At each time step, tensors are rolled to update the values of the time-dependent parameters

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        interpolation_type (`str`, defaults to `"linear"`, *optional*):
            The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
            `"linear"` or `"log_linear"`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
            Diffusion.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        interpolation_type: str = "linear",
        use_karras_sigmas: Optional[bool] = False,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        # Keyword arguments keep every value bound to the right parameter even if the parent
        # signature gains or reorders arguments in a later release.
        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            prediction_type=prediction_type,
            interpolation_type=interpolation_type,
            use_karras_sigmas=use_karras_sigmas,
            timestep_spacing=timestep_spacing,
            steps_offset=steps_offset,
        )
        self._initial_timestep = None
        self.reset_timestep_dependent_params()

    def reset_timestep_dependent_params(self):
        """
        Drop the cached time-dependent parameters so that the next `get_params` call rebuilds them.
        """
        self.are_timestep_dependent_params_set = False
        self.sigma_list = []
        self.sigma_next_list = []

    def get_params(self, timestep: Union[float, torch.FloatTensor]):
        """
        Initialize the time-dependent parameters on first use and return the ones of the current step.
        Rolling happens separately in `roll_params`, so calling this several times within one step
        (when scaling the model input and in `step`) returns the same values.

        Args:
            timestep (`float` or `torch.FloatTensor`):
                The current discrete timestep in the diffusion chain. It is only used to initialize the
                step index when it is not set yet.

        Returns:
            `tuple`:
                `(sigma, sigma_next)` for the current step, i.e. the first entry of each cached tensor.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        if not self.are_timestep_dependent_params_set:
            # Both slices have the same length: entry i of `sigma_next_list` follows entry i of `sigma_list`
            sigmas = self.sigmas[self.step_index : -1]
            sigmas_next = self.sigmas[(self.step_index + 1) :]

            self.sigma_list = torch.stack(list(sigmas))
            self.sigma_next_list = torch.stack(list(sigmas_next))
            self.are_timestep_dependent_params_set = True

        return self.sigma_list[0], self.sigma_next_list[0]

    def roll_params(self):
        """
        Roll tensors to update the values of the time-dependent parameters at each timestep.

        Raises:
            `ValueError`: If the time-dependent parameters have not been set by `get_params` yet.
        """
        if not self.are_timestep_dependent_params_set:
            raise ValueError("Time-dependent parameters should be set first.")

        # Shifting by -1 brings the values of the next step to index 0, which is what `get_params` reads
        self.sigma_list = torch.roll(self.sigma_list, shifts=-1, dims=0)
        self.sigma_next_list = torch.roll(self.sigma_next_list, shifts=-1, dims=0)

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Divides the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        sigma, _ = self.get_params(timestep)
        sample = sample / ((sigma**2 + 1) ** 0.5)
        # `step` warns when this flag has not been set
        self.is_scale_input_called = True
        return sample

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, defaults to 0.0):
                Amount of extra noise injected before the step. It is divided by the number of sigma intervals and
                capped at `2**0.5 - 1` to obtain the per-step factor `gamma`; `0.0` disables the injection.
            s_tmin (`float`, defaults to 0.0):
                Lower bound of the sigma range in which extra noise is injected.
            s_tmax (`float`, defaults to `inf`):
                Upper bound of the sigma range in which extra noise is injected.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator. The noise is drawn on CPU, so it is used with `device="cpu"`.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            `ValueError`: If `timestep` is an integer index, or if `prediction_type` is not supported.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        sigma, sigma_next = self.get_params(timestep)

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        device = model_output.device

        # torch.randn is broken on HPU so running it on CPU
        noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator)
        # Always move the noise next to the model output (no-op on CPU): restricting the transfer to HPU
        # would leave the noise on CPU for any other accelerator and break the addition below.
        noise = noise.to(device)

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        dt = sigma_next - sigma_hat

        prev_sample = sample + derivative * dt

        # upon completion increase step index by one
        self._step_index += 1
        self.roll_params()

        if not return_dict:
            return (prev_sample,)

        return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
class GaudiStableDiffusionXLPipelineTester(TestCase):
    """
    Tests the StableDiffusionXLPipeline for Gaudi.
    """

    # Shape of every image produced from the dummy components
    _IMAGE_SHAPE = (64, 64, 3)
    _PROMPT = "A painting of a squirrel eating a burger"

    def get_dummy_components(self, time_cond_proj_dim=None):
        """Return tiny, deterministically initialized components for an SDXL pipeline."""
        # The generator is re-seeded before each randomly initialized group of models
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(2, 4),
            layers_per_block=2,
            time_cond_proj_dim=time_cond_proj_dim,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            # attention and additional-conditioning settings
            attention_head_dim=(2, 4),
            use_linear_projection=True,
            addition_embed_type="text_time",
            addition_time_embed_dim=8,
            transformer_layers_per_block=(1, 2),
            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
            cross_attention_dim=64,
            norm_num_groups=1,
        )
        scheduler = GaudiEulerDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            steps_offset=1,
            beta_schedule="scaled_linear",
            timestep_spacing="leading",
        )

        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        )

        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            # activation and projection settings
            hidden_act="gelu",
            projection_dim=32,
        )
        # Both text encoders share one config and are created back to back after a single seeding
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        return {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "text_encoder_2": text_encoder_2,
            "tokenizer_2": tokenizer_2,
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return the call arguments for a short, seeded generation."""
        return {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": torch.Generator(device=device).manual_seed(seed),
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "output_type": "np",
        }

    def _make_pipeline(self, gaudi_config=None, **pipeline_kwargs):
        """Build a pipeline from the dummy components; a default `GaudiConfig()` is used when none is given."""
        components = self.get_dummy_components()
        if gaudi_config is None:
            gaudi_config = GaudiConfig()

        pipeline = GaudiStableDiffusionXLPipeline(
            use_habana=True,
            gaudi_config=gaudi_config,
            **pipeline_kwargs,
            **components,
        )
        pipeline.set_progress_bar_config(disable=None)
        return pipeline

    def _check_images(self, images, expected_count):
        """Assert the number of generated images and the shape of the last one."""
        self.assertEqual(len(images), expected_count)
        self.assertEqual(images[-1].shape, self._IMAGE_SHAPE)

    def test_stable_diffusion_xl_euler(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        sd_pipe = self._make_pipeline(GaudiConfig(use_torch_autocast=False))

        image = sd_pipe(**self.get_dummy_inputs(device)).images[0]
        bottom_right = image[-3:, -3:, -1]

        self.assertEqual(image.shape, (64, 64, 3))
        reference = np.array([0.5552, 0.5569, 0.4725, 0.4348, 0.4994, 0.4632, 0.5142, 0.5012, 0.47])

        self.assertLess(np.abs(bottom_right.flatten() - reference).max(), 1e-2)

    def test_stable_diffusion_xl_euler_ancestral(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        sd_pipe = self._make_pipeline(GaudiConfig(use_torch_autocast=False))
        # Swap in the ancestral scheduler, built from the config of the default one
        sd_pipe.scheduler = GaudiEulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)

        image = sd_pipe(**self.get_dummy_inputs(device)).images[0]
        bottom_right = image[-3:, -3:, -1]

        self.assertEqual(image.shape, (64, 64, 3))
        reference = np.array([0.4675, 0.5173, 0.4611, 0.4067, 0.5250, 0.4674, 0.5446, 0.5094, 0.4791])
        self.assertLess(np.abs(bottom_right.flatten() - reference).max(), 1e-2)

    @parameterized.expand(["pil", "np", "latent"])
    def test_stable_diffusion_xl_output_types(self, output_type):
        sd_pipe = self._make_pipeline()

        num_prompts = 2
        num_images_per_prompt = 3

        outputs = sd_pipe(
            num_prompts * [self._PROMPT],
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=2,
            output_type=output_type,
        )

        self.assertEqual(len(outputs.images), num_prompts * num_images_per_prompt)

    def test_stable_diffusion_xl_num_images_per_prompt(self):
        sd_pipe = self._make_pipeline()
        prompt = self._PROMPT

        # Default of one image per prompt, single prompt
        images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images

        self.assertEqual(len(images), 1)
        self.assertEqual(images[0].shape, (64, 64, 3))

        # Default of one image per prompt, several prompts
        num_prompts = 3
        images = sd_pipe([prompt] * num_prompts, num_inference_steps=2, output_type="np").images
        self._check_images(images, num_prompts)

        # Several images for a single prompt
        num_images_per_prompt = 2
        images = sd_pipe(
            prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
        ).images
        self._check_images(images, num_images_per_prompt)

        # Several images for several prompts
        num_prompts = 2
        images = sd_pipe(
            [prompt] * num_prompts,
            num_inference_steps=2,
            output_type="np",
            num_images_per_prompt=num_images_per_prompt,
        ).images
        self._check_images(images, num_prompts * num_images_per_prompt)

    def test_stable_diffusion_xl_batch_sizes(self):
        sd_pipe = self._make_pipeline()
        prompt = self._PROMPT
        batch_size = 3

        # batch_size > 1 dividing the total number of generated images, single prompt
        num_images_per_prompt = batch_size**2
        images = sd_pipe(
            prompt,
            num_inference_steps=2,
            output_type="np",
            batch_size=batch_size,
            num_images_per_prompt=num_images_per_prompt,
        ).images
        self._check_images(images, num_images_per_prompt)

        # Same with several prompts
        num_prompts = 3
        images = sd_pipe(
            [prompt] * num_prompts,
            num_inference_steps=2,
            output_type="np",
            batch_size=batch_size,
            num_images_per_prompt=num_images_per_prompt,
        ).images
        self._check_images(images, num_prompts * num_images_per_prompt)

        # batch_size not dividing the total number of generated images, single prompt
        num_images_per_prompt = 7
        images = sd_pipe(
            prompt,
            num_inference_steps=2,
            output_type="np",
            batch_size=batch_size,
            num_images_per_prompt=num_images_per_prompt,
        ).images
        self._check_images(images, num_images_per_prompt)

        # Same with several prompts
        num_prompts = 2
        images = sd_pipe(
            [prompt] * num_prompts,
            num_inference_steps=2,
            output_type="np",
            batch_size=batch_size,
            num_images_per_prompt=num_images_per_prompt,
        ).images
        self._check_images(images, num_prompts * num_images_per_prompt)

    def test_stable_diffusion_xl_bf16(self):
        """Test that stable diffusion works with bf16"""
        sd_pipe = self._make_pipeline()

        generator = torch.Generator(device="cpu").manual_seed(0)
        image = sd_pipe([self._PROMPT], generator=generator, num_inference_steps=2, output_type="np").images[0]

        self.assertEqual(image.shape, (64, 64, 3))

    def test_stable_diffusion_xl_default(self):
        sd_pipe = self._make_pipeline("Habana/stable-diffusion")

        generator = torch.Generator(device="cpu").manual_seed(0)
        images = sd_pipe(
            [self._PROMPT] * 2,
            generator=generator,
            num_inference_steps=2,
            output_type="np",
            batch_size=3,
            num_images_per_prompt=5,
        ).images

        self._check_images(images, 10)

    def test_stable_diffusion_xl_hpu_graphs(self):
        sd_pipe = self._make_pipeline("Habana/stable-diffusion", use_hpu_graphs=True)

        generator = torch.Generator(device="cpu").manual_seed(0)
        images = sd_pipe(
            [self._PROMPT] * 2,
            generator=generator,
            num_inference_steps=2,
            output_type="np",
            batch_size=3,
            num_images_per_prompt=5,
        ).images

        self._check_images(images, 10)