Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions examples/text-generation/run_generation.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def get_input(ds, batch_size):

def generate(input_tokens, size=None, reduce_recompile=False, disable_profiling=False):
"""Generates sequences from the input sentences and returns them."""
profiler = disabled_profiler if disable_profiling else per_token_profiler

timer = HabanaGenerationTime()
timer.start()
Expand All @@ -579,6 +580,7 @@ def generate(input_tokens, size=None, reduce_recompile=False, disable_profiling=
lazy_mode=use_lazy_mode,
hpu_graphs=args.use_hpu_graphs,
ignore_eos=args.ignore_eos,
profiler=profiler,
).cpu()
outputs = outputs.tolist()
for i in range(len(outputs)):
Expand Down Expand Up @@ -626,6 +628,7 @@ def rounder(x):
# Benchmark over n_iterations iterations
N = len(input_sentences)

per_sequence_profiler.start()
if dyn_prompt_lens is None:
for i in range(args.n_iterations):
results = []
Expand All @@ -635,6 +638,7 @@ def rounder(x):
results.extend(generated)
print(f"Generating batch {b}/{N}")
b += 1
per_sequence_profiler.step()
else:
repeated_prompt_len = cycle(dyn_prompt_lens)
for i in range(args.n_iterations):
Expand All @@ -644,8 +648,10 @@ def rounder(x):
for sentence in input_sentences:
generated = generate(sentence, prompt_len, args.reduce_recompile)
results.extend(generated)
per_sequence_profiler.step()
timer.step()
duration = timer.last_duration
per_sequence_profiler.stop()
total_new_tokens_generated = args.n_iterations * args.batch_size * args.max_new_tokens
throughput = total_new_tokens_generated / duration

Expand Down Expand Up @@ -871,24 +877,24 @@ def rounder(x):
if dyn_prompt_lens is None:
for i in range(args.n_iterations):
generated, first_token_time, rest_token_time, e2e_latency = generate(None, args.reduce_recompile)
per_sequence_profiler.step()
first_token_latencies.append(first_token_time)
rest_token_latencies.append(rest_token_time)
e2e_latencies.append(e2e_latency)
per_sequence_profiler.step()
else:
repeated_prompt_len = cycle(dyn_prompt_lens)
for i in range(args.n_iterations):
prompt_len = next(repeated_prompt_len)
print("Generating for shape,", prompt_len)
generated, first_token_time, rest_token_time, e2e_latency = generate(prompt_len, args.reduce_recompile)
per_sequence_profiler.step()
first_token_latencies.append(first_token_time)
rest_token_latencies.append(rest_token_time)
e2e_latencies.append(e2e_latency)
per_sequence_profiler.step()
timer.step()
per_sequence_profiler.stop()
logger.info("Finished running generate")
duration = timer.last_duration
per_sequence_profiler.stop()
total_new_tokens_generated = args.n_iterations * args.batch_size * args.max_new_tokens
throughput = total_new_tokens_generated / duration
# Calculate average latencies
Expand Down Expand Up @@ -1045,7 +1051,7 @@ def generate_dataset(batch, disable_profiling=False):
timer.start()
for i, batch in enumerate(dataloader):
timer.step()
generate_dataset(batch)
generate_dataset(batch, disable_profiling=True)
timer.step()
duration = timer.last_duration
# The first three iterations take longer because of graph compilation
Expand All @@ -1054,15 +1060,14 @@ def generate_dataset(batch, disable_profiling=False):
torch_hpu.synchronize()
timer.step()
compilation_duration = timer.last_duration

total_new_tokens_generated = 0
duration = 0
separator = "-" * 50
logger.info("Running generate dataset...")

timer = HabanaGenerationTime()
timer.start()
per_sequence_profiler.start()

for i, batch in enumerate(dataloader):
timer.step()
prompt, outputs = generate_dataset(batch)
Expand All @@ -1079,8 +1084,8 @@ def generate_dataset(batch, disable_profiling=False):
if args.run_partial_dataset and args.n_iterations == i + 1:
break
per_sequence_profiler.step()
timer.step()
per_sequence_profiler.stop()
timer.step()

throughput = total_new_tokens_generated / duration
# Print Stats
Expand Down