Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
clean up positional args
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaiBapchya committed Feb 10, 2020
1 parent 6f282c0 commit b262d99
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion benchmark/opperf/nd_operations/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def run_mx_unary_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='n
profiler=profiler,
inputs=[{"args": [(1024, 1024)],
"num_outputs":1},
"args": [(10000, 1)],
{"args": [(10000, 1)],
"num_outputs":1}],
warmup=warmup,
runs=runs)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/opperf/utils/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _run_nd_operator_performance_test(op, inputs, run_backward, warmup, runs, kw
op_benchmark_result = {op.__name__: []}
logging.info("Begin Benchmark - {name}".format(name=op.__name__))
for idx, kwargs in enumerate(kwargs_list):
_, profiler_output = benchmark_helper_func(op, runs, [], **kwargs)
_, profiler_output = benchmark_helper_func(op, runs, **kwargs)

# Add inputs used for profiling this operator into result
profiler_output["inputs"] = inputs[idx]
Expand Down
26 changes: 19 additions & 7 deletions benchmark/opperf/utils/ndarray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,25 @@ def nd_forward_backward_and_profile(op, runs, **kwargs):
for _ in range(runs):
with mx.autograd.record():
args = []
# need to create a new dictionary because can't update dict while iterating
kwargs_new = dict()
for key in kwargs:
# separate positional args from key-worded args
if key.startswith("args"):
args.append(kwargs.pop(key))
args.append(kwargs[key])
else:
kwargs_new[key]=kwargs[key]
# check for positional args
if len(args):
res = op(*args, **kwargs)
res = op(*args, **kwargs_new)
else:
res = op(**kwargs)
res = op(**kwargs_new)
res.backward()
nd.waitall()
return res


def nd_forward_and_profile(op, runs, *args, **kwargs):
def nd_forward_and_profile(op, runs, **kwargs):
"""Helper function to run a given NDArray operator (op) for 'runs' number of times with
given args and kwargs. Executes ONLY forward pass.
Expand All @@ -76,13 +82,19 @@ def nd_forward_and_profile(op, runs, *args, **kwargs):
"""
for _ in range(runs):
args = []
# need to create a new dictionary because can't update dict while iterating
kwargs_new = dict()
for key in kwargs:
# separate positional args from key-worded args
if key.startswith("args"):
args.append(kwargs.pop(key))
args.append(kwargs[key])
else:
kwargs_new[key]=kwargs[key]
# check for positional args
if len(args):
res = op(*args, **kwargs)
res = op(*args, **kwargs_new)
else:
res = op(**kwargs)
res = op(**kwargs_new)
nd.waitall()
return res

Expand Down

0 comments on commit b262d99

Please sign in to comment.