Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def __call__(
clip_skip: Optional[int] = None,
profiling_warmup_steps: Optional[int] = 0,
profiling_steps: Optional[int] = 0,
**kwargs,
):
r"""
The call function to the pipeline for generation.
Expand Down Expand Up @@ -438,7 +439,7 @@ def __call__(
for j in self.progress_bar(range(num_batches)):
# Throughput timing starts once `throughput_warmup_steps` batches (default: 3) have run,
# because compilation occurs in the first iterations
if j == 2:
if j == kwargs.get("throughput_warmup_steps", 3):
t1 = time.time()

latents_batch = latents_batches[0]
Expand All @@ -451,8 +452,6 @@ def __call__(
t = timesteps[0]
timesteps = torch.roll(timesteps, shifts=-1, dims=0)

capture = True if self.use_hpu_graphs and i < 2 else False

# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents_batch] * 2) if do_classifier_free_guidance else latents_batch
Expand Down Expand Up @@ -484,7 +483,6 @@ def __call__(
image,
cond_scale,
guess_mode,
capture,
)

if guess_mode and do_classifier_free_guidance:
Expand All @@ -504,7 +502,6 @@ def __call__(
cross_attention_kwargs,
down_block_res_samples,
mid_block_res_sample,
capture,
)

# perform guidance
Expand Down Expand Up @@ -604,7 +601,6 @@ def unet_hpu(
cross_attention_kwargs,
down_block_additional_residuals,
mid_block_additional_residual,
capture,
):
if self.use_hpu_graphs:
return self.unet_capture_replay(
Expand All @@ -613,7 +609,6 @@ def unet_hpu(
encoder_hidden_states,
down_block_additional_residuals,
mid_block_additional_residual,
capture,
)
else:
return self.unet(
Expand All @@ -634,7 +629,6 @@ def unet_capture_replay(
encoder_hidden_states,
down_block_additional_residuals,
mid_block_additional_residual,
capture,
):
inputs = [
latent_model_input,
Expand All @@ -647,7 +641,7 @@ def unet_capture_replay(
h = self.ht.hpu.graphs.input_hash(inputs)
cached = self.cache.get(h)

if capture:
if cached is None:
# Capture the graph and cache it
with self.ht.hpu.stream(self.hpu_stream):
graph = self.ht.hpu.HPUGraph()
Expand Down Expand Up @@ -689,7 +683,6 @@ def controlnet_hpu(
controlnet_cond,
conditioning_scale,
guess_mode,
capture,
):
if self.use_hpu_graphs:
return self.controlnet_capture_replay(
Expand All @@ -699,7 +692,6 @@ def controlnet_hpu(
controlnet_cond,
conditioning_scale,
guess_mode,
capture,
)
else:
return self.controlnet(
Expand All @@ -721,7 +713,6 @@ def controlnet_capture_replay(
controlnet_cond,
conditioning_scale,
guess_mode,
capture,
):
inputs = [
control_model_input,
Expand All @@ -735,7 +726,7 @@ def controlnet_capture_replay(
h = self.ht.hpu.graphs.input_hash(inputs)
cached = self.cache.get(h)

if capture:
if cached is None:
# Capture the graph and cache it
with self.ht.hpu.stream(self.hpu_stream):
graph = self.ht.hpu.HPUGraph()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def __call__(
for j in self.progress_bar(range(num_batches)):
# Throughput timing starts once `throughput_warmup_steps` batches (default: 3) have run,
# because compilation occurs in the first iterations
if j == 2:
if j == kwargs.get("throughput_warmup_steps", 3):
t1 = time.time()

latents_batch = latents_batches[0]
Expand All @@ -429,8 +429,6 @@ def __call__(
timestep = timesteps[0]
timesteps = torch.roll(timesteps, shifts=-1, dims=0)

capture = True if self.use_hpu_graphs and i < 2 else False

# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch
Expand All @@ -444,7 +442,6 @@ def __call__(
text_embeddings_batch,
timestep_cond,
self.cross_attention_kwargs,
capture,
)

# perform guidance
Expand Down Expand Up @@ -547,11 +544,9 @@ def __call__(
)

@torch.no_grad()
def unet_hpu(
self, latent_model_input, timestep, encoder_hidden_states, timestep_cond, cross_attention_kwargs, capture
):
def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, timestep_cond, cross_attention_kwargs):
if self.use_hpu_graphs:
return self.capture_replay(latent_model_input, timestep, encoder_hidden_states, capture)
return self.capture_replay(latent_model_input, timestep, encoder_hidden_states)
else:
return self.unet(
latent_model_input,
Expand All @@ -563,12 +558,12 @@ def unet_hpu(
)[0]

@torch.no_grad()
def capture_replay(self, latent_model_input, timestep, encoder_hidden_states, capture):
def capture_replay(self, latent_model_input, timestep, encoder_hidden_states):
inputs = [latent_model_input, timestep, encoder_hidden_states, False]
h = self.ht.hpu.graphs.input_hash(inputs)
cached = self.cache.get(h)

if capture:
if cached is None:
# Capture the graph and cache it
with self.ht.hpu.stream(self.hpu_stream):
graph = self.ht.hpu.HPUGraph()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def __call__(
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
**kwargs,
):
r"""
The call function to the pipeline for generation.
Expand Down Expand Up @@ -327,7 +328,7 @@ def __call__(
for j in self.progress_bar(range(num_batches)):
# Throughput timing starts once `throughput_warmup_steps` batches (default: 3) have run,
# because compilation occurs in the first iterations
if j == 2:
if j == kwargs.get("throughput_warmup_steps", 3):
t1 = time.time()

latents_batch = latents_batches[0]
Expand All @@ -339,8 +340,6 @@ def __call__(
timestep = timesteps[0]
timesteps = torch.roll(timesteps, shifts=-1, dims=0)

capture = True if self.use_hpu_graphs and i < 2 else False

# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents_batch] * 2) if do_classifier_free_guidance else latents_batch
Expand All @@ -353,7 +352,6 @@ def __call__(
timestep,
text_embeddings_batch,
cross_attention_kwargs,
capture,
)

# perform guidance
Expand Down Expand Up @@ -443,9 +441,9 @@ def __call__(
)

@torch.no_grad()
def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, cross_attention_kwargs, capture):
def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, cross_attention_kwargs):
if self.use_hpu_graphs:
return self.capture_replay(latent_model_input, timestep, encoder_hidden_states, capture)
return self.capture_replay(latent_model_input, timestep, encoder_hidden_states)
else:
return self.unet(
latent_model_input,
Expand All @@ -456,12 +454,12 @@ def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, cross_at
)[0]

@torch.no_grad()
def capture_replay(self, latent_model_input, timestep, encoder_hidden_states, capture):
def capture_replay(self, latent_model_input, timestep, encoder_hidden_states):
inputs = [latent_model_input, timestep, encoder_hidden_states, False]
h = self.ht.hpu.graphs.input_hash(inputs)
cached = self.cache.get(h)

if capture:
if cached is None:
# Capture the graph and cache it
with self.ht.hpu.stream(self.hpu_stream):
graph = self.ht.hpu.HPUGraph()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def __call__(
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: int = None,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -438,7 +439,7 @@ def __call__(
for j in self.progress_bar(range(num_batches)):
# Throughput timing starts once `throughput_warmup_steps` batches (default: 3) have run,
# because compilation occurs in the first iterations
if j == 2:
if j == kwargs.get("throughput_warmup_steps", 3):
t1 = time.time()

latents_batch = latents_batches[0]
Expand All @@ -454,8 +455,6 @@ def __call__(
timestep = timesteps[0]
timesteps = torch.roll(timesteps, shifts=-1, dims=0)

capture = True if self.use_hpu_graphs and i < 2 else False

# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents_batch] * 2) if do_classifier_free_guidance else latents_batch
Expand All @@ -473,7 +472,6 @@ def __call__(
timestep,
text_embeddings_batch,
cross_attention_kwargs,
capture,
class_labels=noise_level_input,
)

Expand Down Expand Up @@ -574,11 +572,9 @@ def __call__(
)

@torch.no_grad()
def unet_hpu(
self, latent_model_input, timestep, encoder_hidden_states, cross_attention_kwargs, capture, class_labels
):
def unet_hpu(self, latent_model_input, timestep, encoder_hidden_states, cross_attention_kwargs, class_labels):
if self.use_hpu_graphs:
return self.capture_replay(latent_model_input, timestep, encoder_hidden_states, capture, class_labels)
return self.capture_replay(latent_model_input, timestep, encoder_hidden_states, class_labels)
else:
return self.unet(
latent_model_input,
Expand All @@ -590,12 +586,12 @@ def unet_hpu(
)[0]

@torch.no_grad()
def capture_replay(self, latent_model_input, timestep, encoder_hidden_states, capture, class_labels):
def capture_replay(self, latent_model_input, timestep, encoder_hidden_states, class_labels):
inputs = [latent_model_input, timestep, encoder_hidden_states, False, class_labels]
h = self.ht.hpu.graphs.input_hash(inputs)
cached = self.cache.get(h)

if capture:
if cached is None:
# Capture the graph and cache it
with self.ht.hpu.stream(self.hpu_stream):
graph = self.ht.hpu.HPUGraph()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def __call__(
for j in self.progress_bar(range(num_batches)):
# Throughput timing starts once `throughput_warmup_steps` batches (default: 3) have run,
# because compilation occurs in the first iterations
if j == 2:
if j == kwargs.get("throughput_warmup_steps", 3):
t1 = time.time()

latents_batch = latents_batches[0]
Expand All @@ -674,8 +674,6 @@ def __call__(
timestep = timesteps[0]
timesteps = torch.roll(timesteps, shifts=-1, dims=0)

capture = True if self.use_hpu_graphs and j == 0 and i < 2 else False

# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents_batch] * 2) if self.do_classifier_free_guidance else latents_batch
Expand All @@ -691,7 +689,6 @@ def __call__(
timestep_cond,
self.cross_attention_kwargs,
added_cond_kwargs,
capture,
)

# perform guidance
Expand Down Expand Up @@ -801,7 +798,6 @@ def unet_hpu(
timestep_cond,
cross_attention_kwargs,
added_cond_kwargs,
capture,
):
if self.use_hpu_graphs:
return self.capture_replay(
Expand All @@ -811,7 +807,6 @@ def unet_hpu(
timestep_cond,
cross_attention_kwargs,
added_cond_kwargs,
capture,
)
else:
return self.unet(
Expand All @@ -833,7 +828,6 @@ def capture_replay(
timestep_cond,
cross_attention_kwargs,
added_cond_kwargs,
capture,
):
inputs = [
latent_model_input,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
THROUGHPUT_BASELINE_BF16 = 1.019
THROUGHPUT_BASELINE_AUTOCAST = 0.389
else:
THROUGHPUT_BASELINE_BF16 = 0.309
THROUGHPUT_BASELINE_BF16 = 0.412
THROUGHPUT_BASELINE_AUTOCAST = 0.114

TEXTUAL_INVERSION_THROUGHPUT = 59.13010439968039
Expand Down