diff --git a/benchmarks/scripts/cccl/bench/search.py b/benchmarks/scripts/cccl/bench/search.py index e34113c70eb..d82ed5bf71c 100644 --- a/benchmarks/scripts/cccl/bench/search.py +++ b/benchmarks/scripts/cccl/bench/search.py @@ -49,7 +49,7 @@ def parse_arguments(): '--list-benches', action=argparse.BooleanOptionalAction, help="Show available benchmarks.") parser.add_argument('--num-shards', type=int, default=1, help='Split benchmarks into M pieces and only run one') parser.add_argument('--run-shard', type=int, default=0, help='Run shard N / M of benchmarks') - parser.add_argument('-P0', action=argparse.BooleanOptionalAction, help="Run P0 benchmarks (overwrites -R)") + parser.add_argument('-P0', action=argparse.BooleanOptionalAction, help="Run P0 benchmarks") return parser.parse_args() @@ -64,16 +64,18 @@ def run_benches(algnames, sub_space, seeker): print("#### ERROR exception occured while running {}: '{}'".format(algname, e)) +def filter_benchmarks_by_regex(benchmarks, R): + pattern = re.compile(R) + return list(filter(lambda x: pattern.match(x), benchmarks)) + + def filter_benchmarks(benchmarks, args): if args.run_shard >= args.num_shards: raise ValueError('run-shard must be less than num-shards') - R = args.R + algnames = filter_benchmarks_by_regex(benchmarks.keys(), args.R) if args.P0: - R = '^(?!.*segmented).*(scan|reduce|select|sort).*' - - pattern = re.compile(R) - algnames = list(filter(lambda x: pattern.match(x), benchmarks.keys())) + algnames = filter_benchmarks_by_regex(algnames, '^(?!.*segmented).*(scan|reduce|select|sort).*') algnames.sort() if args.num_shards > 1: