diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index 0d8ff554d1..e83d455237 100755 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -286,8 +286,12 @@ def main(): action="store_true", help="Use rescale_betas_zero_snr for controlling image brightness", ) + parser.add_argument("--optimize", action="store_true", help="Use optimized pipeline.") args = parser.parse_args() + if args.optimize and not args.use_habana: + raise ValueError("--optimize can only be used with --use-habana.") + # Select stable diffuson pipeline based on input sdxl_models = ["stable-diffusion-xl", "sdxl"] sd3_models = ["stable-diffusion-3"] @@ -302,6 +306,8 @@ def main(): scheduler = GaudiEulerDiscreteScheduler.from_pretrained( args.model_name_or_path, subfolder="scheduler", **kwargs ) + if args.optimize: + scheduler.hpu_opt = True elif args.scheduler == "euler_ancestral_discrete": scheduler = GaudiEulerAncestralDiscreteScheduler.from_pretrained( args.model_name_or_path, subfolder="scheduler", **kwargs @@ -417,14 +423,31 @@ def main(): pipeline = AutoPipelineForInpainting.from_pretrained(args.model_name_or_path, **kwargs) - else: + elif args.optimize: # Import SDXL pipeline + import habana_frameworks.torch.hpu as torch_hpu + + from optimum.habana.diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_mlperf import ( + StableDiffusionXLPipeline_HPU, + ) + + pipeline = StableDiffusionXLPipeline_HPU.from_pretrained( + args.model_name_or_path, + **kwargs, + ) + + pipeline.to(torch.device("hpu")) + pipeline.unet.set_default_attn_processor(pipeline.unet) + if args.use_hpu_graphs: + pipeline.unet = torch_hpu.wrap_in_hpu_graph(pipeline.unet) + else: from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( args.model_name_or_path, **kwargs, ) + if args.lora_id: pipeline.load_lora_weights(args.lora_id) diff --git a/optimum/habana/diffusers/models/attention_processor.py b/optimum/habana/diffusers/models/attention_processor.py index b0461a272b..097292115a 100755 --- a/optimum/habana/diffusers/models/attention_processor.py +++ b/optimum/habana/diffusers/models/attention_processor.py @@ -19,7 +19,7 @@ import torch import torch.nn.functional as F from diffusers.models.attention_processor import Attention -from diffusers.utils import USE_PEFT_BACKEND, logging +from diffusers.utils import deprecate, logging from diffusers.utils.import_utils import is_xformers_available from torch import nn @@ -107,8 +107,13 @@ def __call__( encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, + *args, + **kwargs, ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -132,16 +137,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - args = () if USE_PEFT_BACKEND else (scale,) - query = attn.to_q(hidden_states, *args) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -171,7 +175,7 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) + hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py index 26bfa7b69d..72297d37d4 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_mlperf.py @@ -170,6 +170,7 @@ def __init__( ) self.unet.set_default_attn_processor = set_default_attn_processor_hpu self.unet.forward = gaudi_unet_2d_condition_model_forward + self.quantized = False def run_unet( self, @@ -609,7 +610,6 @@ def __call__( self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: - timesteps = [t.item() for t in timesteps] if self.quantized: for i, t in enumerate(timesteps[0:-2]): if self.interrupt: @@ -666,7 +666,9 @@ def __call__( ) hb_profiler.step() else: - for i, t in enumerate(timesteps): + for i in range(num_inference_steps): + t = timesteps[0] + timesteps = torch.roll(timesteps, shifts=-1, dims=0) if self.interrupt: continue latents = self.run_unet(