From b262d99fa74a146072c565c10b4265a18c8a9a2b Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Mon, 10 Feb 2020 21:51:54 +0000 Subject: [PATCH] clean up positional args --- .../opperf/nd_operations/unary_operators.py | 2 +- benchmark/opperf/utils/benchmark_utils.py | 2 +- benchmark/opperf/utils/ndarray_utils.py | 26 ++++++++++++++----- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/benchmark/opperf/nd_operations/unary_operators.py b/benchmark/opperf/nd_operations/unary_operators.py index af0c86175428..08075906fae5 100644 --- a/benchmark/opperf/nd_operations/unary_operators.py +++ b/benchmark/opperf/nd_operations/unary_operators.py @@ -68,7 +68,7 @@ def run_mx_unary_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='n profiler=profiler, inputs=[{"args": [(1024, 1024)], "num_outputs":1}, - "args": [(10000, 1)], + {"args": [(10000, 1)], "num_outputs":1}], warmup=warmup, runs=runs) diff --git a/benchmark/opperf/utils/benchmark_utils.py b/benchmark/opperf/utils/benchmark_utils.py index 5a8154bd81d1..aa526c418112 100644 --- a/benchmark/opperf/utils/benchmark_utils.py +++ b/benchmark/opperf/utils/benchmark_utils.py @@ -67,7 +67,7 @@ def _run_nd_operator_performance_test(op, inputs, run_backward, warmup, runs, kw op_benchmark_result = {op.__name__: []} logging.info("Begin Benchmark - {name}".format(name=op.__name__)) for idx, kwargs in enumerate(kwargs_list): - _, profiler_output = benchmark_helper_func(op, runs, [], **kwargs) + _, profiler_output = benchmark_helper_func(op, runs, **kwargs) # Add inputs used for profiling this operator into result profiler_output["inputs"] = inputs[idx] diff --git a/benchmark/opperf/utils/ndarray_utils.py b/benchmark/opperf/utils/ndarray_utils.py index 485b07c197e0..3f5dda8f036b 100644 --- a/benchmark/opperf/utils/ndarray_utils.py +++ b/benchmark/opperf/utils/ndarray_utils.py @@ -43,19 +43,25 @@ def nd_forward_backward_and_profile(op, runs, **kwargs): for _ in range(runs): with mx.autograd.record(): args = [] + # need to create a new dictionary because can't update dict while iterating + kwargs_new = dict() for key in kwargs: + # separate positional args from key-worded args if key.startswith("args"): - args.append(kwargs.pop(key)) + args.append(kwargs[key]) + else: + kwargs_new[key]=kwargs[key] + # check for positional args if len(args): - res = op(*args, **kwargs) + res = op(*args, **kwargs_new) else: - res = op(**kwargs) + res = op(**kwargs_new) res.backward() nd.waitall() return res -def nd_forward_and_profile(op, runs, *args, **kwargs): +def nd_forward_and_profile(op, runs, **kwargs): """Helper function to run a given NDArray operator (op) for 'runs' number of times with given args and kwargs. Executes ONLY forward pass. @@ -76,13 +82,19 @@ def nd_forward_and_profile(op, runs, *args, **kwargs): """ for _ in range(runs): args = [] + # need to create a new dictionary because can't update dict while iterating + kwargs_new = dict() for key in kwargs: + # separate positional args from key-worded args if key.startswith("args"): - args.append(kwargs.pop(key)) + args.append(kwargs[key]) + else: + kwargs_new[key]=kwargs[key] + # check for positional args if len(args): - res = op(*args, **kwargs) + res = op(*args, **kwargs_new) else: - res = op(**kwargs) + res = op(**kwargs_new) nd.waitall() return res