diff --git a/optimum/habana/diffusers/pipelines/controlnet/pipeline_controlnet.py b/optimum/habana/diffusers/pipelines/controlnet/pipeline_controlnet.py index 22d174c41d..d358137f1e 100644 --- a/optimum/habana/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/optimum/habana/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -179,6 +179,7 @@ def __call__( clip_skip: Optional[int] = None, profiling_warmup_steps: Optional[int] = 0, profiling_steps: Optional[int] = 0, + **kwargs, ): r""" The call function to the pipeline for generation. @@ -438,7 +439,7 @@ def __call__( for j in self.progress_bar(range(num_batches)): # The throughput is calculated from the 3rd iteration # because compilation occurs in the first two iterations - if j == 2: + if j == kwargs.get("throughput_warmup_steps", 3): t1 = time.time() latents_batch = latents_batches[0] @@ -451,8 +452,6 @@ def __call__( t = timesteps[0] timesteps = torch.roll(timesteps, shifts=-1, dims=0) - capture = True if self.use_hpu_graphs and i < 2 else False - # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents_batch] * 2) if do_classifier_free_guidance else latents_batch @@ -484,7 +483,6 @@ def __call__( image, cond_scale, guess_mode, - capture, ) if guess_mode and do_classifier_free_guidance: @@ -504,7 +502,6 @@ def __call__( cross_attention_kwargs, down_block_res_samples, mid_block_res_sample, - capture, ) # perform guidance @@ -604,7 +601,6 @@ def unet_hpu( cross_attention_kwargs, down_block_additional_residuals, mid_block_additional_residual, - capture, ): if self.use_hpu_graphs: return self.unet_capture_replay( @@ -613,7 +609,6 @@ def unet_hpu( encoder_hidden_states, down_block_additional_residuals, mid_block_additional_residual, - capture, ) else: return self.unet( @@ -634,7 +629,6 @@ def unet_capture_replay( encoder_hidden_states, down_block_additional_residuals, mid_block_additional_residual, - capture, ): inputs = [ latent_model_input, @@ -647,7 +641,7 @@ def unet_capture_replay( h = self.ht.hpu.graphs.input_hash(inputs) cached = self.cache.get(h) - if capture: + if cached is None: # Capture the graph and cache it with self.ht.hpu.stream(self.hpu_stream): graph = self.ht.hpu.HPUGraph() @@ -689,7 +683,6 @@ def controlnet_hpu( controlnet_cond, conditioning_scale, guess_mode, - capture, ): if self.use_hpu_graphs: return self.controlnet_capture_replay( @@ -699,7 +692,6 @@ def controlnet_hpu( controlnet_cond, conditioning_scale, guess_mode, - capture, ) else: return self.controlnet( @@ -721,7 +713,6 @@ def controlnet_capture_replay( controlnet_cond, conditioning_scale, guess_mode, - capture, ): inputs = [ control_model_input, @@ -735,7 +726,7 @@ def controlnet_capture_replay( h = self.ht.hpu.graphs.input_hash(inputs) cached = self.cache.get(h) - if capture: + if cached is None: # Capture the graph and cache it with self.ht.hpu.stream(self.hpu_stream): graph = self.ht.hpu.HPUGraph() diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index c6e1789a43..a522002650 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -417,7 +417,7 @@ def __call__( for j in self.progress_bar(range(num_batches)): # The throughput is calculated from the 3rd iteration # because compilation occurs in the first two iterations - if j == 2: + if j == kwargs.get("throughput_warmup_steps", 3): t1 = time.time() latents_batch = latents_batches[0] @@ -429,8 +429,6 @@ def __call__( timestep = timesteps[0] timesteps = torch.roll(timesteps, shifts=-1, dims=0) - capture = True if self.use_hpu_graphs and i < 2 else False - # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch @@ -444,7 +442,6 @@ def __call__( text_embeddings_batch, timestep_cond, self.cross_attention_kwargs, - capture, ) # perform guidance @@ -547,11 +544,9 @@ def __call__( ) @torch.no_grad() - def unet_hpu( - self, latent_model_input, timestep, encoder_hidden_states, timestep_cond, cross_attention_kwargs, capture - ): + def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, timestep_cond, cross_attention_kwargs): if self.use_hpu_graphs: - return self.capture_replay(latent_model_input, timestep, encoder_hidden_states, capture) + return self.capture_replay(latent_model_input, timestep, encoder_hidden_states) else: return self.unet( latent_model_input, @@ -563,12 +558,12 @@ def unet_hpu( )[0] @torch.no_grad() - def capture_replay(self, latent_model_input, timestep, encoder_hidden_states, capture): + def capture_replay(self, latent_model_input, timestep, encoder_hidden_states): inputs = [latent_model_input, timestep, encoder_hidden_states, False] h = self.ht.hpu.graphs.input_hash(inputs) cached = self.cache.get(h) - if capture: + if cached is None: # Capture the graph and cache it with self.ht.hpu.stream(self.hpu_stream): graph = self.ht.hpu.HPUGraph() diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py index f1423ed7f5..6c32322c42 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -177,6 +177,7 @@ def __call__( callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, + **kwargs, ): r""" The call function to the pipeline for generation. @@ -327,7 +328,7 @@ def __call__( for j in self.progress_bar(range(num_batches)): # The throughput is calculated from the 3rd iteration # because compilation occurs in the first two iterations - if j == 2: + if j == kwargs.get("throughput_warmup_steps", 3): t1 = time.time() latents_batch = latents_batches[0] @@ -339,8 +340,6 @@ def __call__( timestep = timesteps[0] timesteps = torch.roll(timesteps, shifts=-1, dims=0) - capture = True if self.use_hpu_graphs and i < 2 else False - # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents_batch] * 2) if do_classifier_free_guidance else latents_batch @@ -353,7 +352,6 @@ def __call__( timestep, text_embeddings_batch, cross_attention_kwargs, - capture, ) # perform guidance @@ -443,9 +441,9 @@ def __call__( ) @torch.no_grad() - def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, cross_attention_kwargs, capture): + def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, cross_attention_kwargs): if self.use_hpu_graphs: - return self.capture_replay(latent_model_input, timestep, encoder_hidden_states, capture) + return self.capture_replay(latent_model_input, timestep, encoder_hidden_states) else: return self.unet( latent_model_input, @@ -456,12 +454,12 @@ def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, cross_at )[0] @torch.no_grad() - def capture_replay(self, latent_model_input, timestep, encoder_hidden_states, capture): + def capture_replay(self, latent_model_input, timestep, encoder_hidden_states): inputs = [latent_model_input, timestep, encoder_hidden_states, False] h = self.ht.hpu.graphs.input_hash(inputs) cached = self.cache.get(h) - if capture: + if cached is None: # Capture the graph and cache it with self.ht.hpu.stream(self.hpu_stream): graph = self.ht.hpu.HPUGraph() diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 594e0e0c30..a574746e38 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -233,6 +233,7 @@ def __call__( callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: int = None, + **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -438,7 +439,7 @@ def __call__( for j in self.progress_bar(range(num_batches)): # The throughput is calculated from the 3rd iteration # because compilation occurs in the first two iterations - if j == 2: + if j == kwargs.get("throughput_warmup_steps", 3): t1 = time.time() latents_batch = latents_batches[0] @@ -454,8 +455,6 @@ def __call__( timestep = timesteps[0] timesteps = torch.roll(timesteps, shifts=-1, dims=0) - capture = True if self.use_hpu_graphs and i < 2 else False - # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents_batch] * 2) if do_classifier_free_guidance else latents_batch @@ -473,7 +472,6 @@ def __call__( timestep, text_embeddings_batch, cross_attention_kwargs, - capture, class_labels=noise_level_input, ) @@ -574,11 +572,9 @@ def __call__( ) @torch.no_grad() - def unet_hpu( - self, latent_model_input, timestep, encoder_hidden_states, cross_attention_kwargs, capture, class_labels - ): + def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, cross_attention_kwargs, class_labels): if self.use_hpu_graphs: - return self.capture_replay(latent_model_input, timestep, encoder_hidden_states, capture, class_labels) + return self.capture_replay(latent_model_input, timestep, encoder_hidden_states, class_labels) else: return self.unet( latent_model_input, @@ -590,12 +586,12 @@ def unet_hpu( )[0] @torch.no_grad() - def capture_replay(self, latent_model_input, timestep, encoder_hidden_states, capture, class_labels): + def capture_replay(self, latent_model_input, timestep, encoder_hidden_states, class_labels): inputs = [latent_model_input, timestep, encoder_hidden_states, False, class_labels] h = self.ht.hpu.graphs.input_hash(inputs) cached = self.cache.get(h) - if capture: + if cached is None: # Capture the graph and cache it with self.ht.hpu.stream(self.hpu_stream): graph = self.ht.hpu.HPUGraph() diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 575dda7c28..04cf3b08f6 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -658,7 +658,7 @@ def __call__( for j in self.progress_bar(range(num_batches)): # The throughput is calculated from the 3rd iteration # because compilation occurs in the first two iterations - if j == 2: + if j == kwargs.get("throughput_warmup_steps", 3): t1 = time.time() latents_batch = latents_batches[0] @@ -674,8 +674,6 @@ def __call__( timestep = timesteps[0] timesteps = torch.roll(timesteps, shifts=-1, dims=0) - capture = True if self.use_hpu_graphs and j == 0 and i < 2 else False - # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch @@ -691,7 +689,6 @@ def __call__( timestep_cond, self.cross_attention_kwargs, added_cond_kwargs, - capture, ) # perform guidance @@ -801,7 +798,6 @@ def unet_hpu( timestep_cond, cross_attention_kwargs, added_cond_kwargs, - capture, ): if self.use_hpu_graphs: return self.capture_replay( @@ -811,7 +807,6 @@ def unet_hpu( timestep_cond, cross_attention_kwargs, added_cond_kwargs, - capture, ) else: return self.unet( @@ -833,7 +828,6 @@ def capture_replay( timestep_cond, cross_attention_kwargs, added_cond_kwargs, - capture, ): inputs = [ latent_model_input, diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index a8793c6d38..df74f9f0a9 100644 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -54,7 +54,7 @@ THROUGHPUT_BASELINE_BF16 = 1.019 THROUGHPUT_BASELINE_AUTOCAST = 0.389 else: - THROUGHPUT_BASELINE_BF16 = 0.309 + THROUGHPUT_BASELINE_BF16 = 0.412 THROUGHPUT_BASELINE_AUTOCAST = 0.114 TEXTUAL_INVERSION_THROUGHPUT = 59.13010439968039