diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index deb932a494..8004791443 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -698,6 +698,8 @@ def generate_dataset(batch): f"Output: {tokenizer.batch_decode(outputs, skip_special_tokens=True)[:args.batch_size*args.num_return_sequences]}" ) print(separator) + if args.run_partial_dataset and args.n_iterations == i + 1: + break t_end = time.time() throughput = total_new_tokens_generated / duration