From 0b8948cc9fb8edf8cd80f44e125248e6aa050281 Mon Sep 17 00:00:00 2001 From: Piotr Bielak Date: Fri, 22 Aug 2025 17:52:23 +0300 Subject: [PATCH] Respect `--dataset_max_samples` when using `--mlcommons_dataset` The current implementation of `run_generation.py` does not take into account the value of `--dataset_max_samples` when using an MLCommons dataset (via the `--mlcommons_dataset` argument). Limiting the number of dataset samples allows for faster debugging. --- examples/text-generation/run_generation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index ec64f3f6aa..8b9ee162f2 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -550,6 +550,9 @@ def get_input(ds, batch_size): ds = get_ds(args) input_sentences = get_input(ds, args.batch_size) + if args.dataset_max_samples > 0: + input_sentences = input_sentences[: args.dataset_max_samples] + def generate(input_tokens, size=None, reduce_recompile=False, disable_profiling=False): """Generates sequences from the input sentences and returns them.""" @@ -649,6 +652,9 @@ def rounder(x): acc_file = [] num_token = 0 for i, idx in enumerate(ds.index): + if args.dataset_max_samples > 0 and i >= args.dataset_max_samples: + break + pred = results[i] eos_token_id = 2 try: