From a5d55dea0508f4ce5d0fefcba07c4a7fd8bd25ab Mon Sep 17 00:00:00 2001
From: Chaitanya Prakash Bapat
Date: Tue, 10 Mar 2020 14:01:49 -0700
Subject: [PATCH] [OpPerf] Consolidate array manipulation related operators
 (#17487)

* add shape manipulation, array expanding ops

* split as alias of SliceChannel

* add rounding ops

* add profiler param to function description

* add params, improve readability of prepare op input logic, improve opperf readme

* fix index merge issue

* add join,split ops

* minor fixes in join,split

* fix if else logic issue, lint

* add comment on profiler, remove all completed ops, add res of 2 ops i missed in previous PR

* remove unreachable if statements
---
 benchmark/opperf/README.md                    |   8 +-
 benchmark/opperf/nd_operations/README.md      | 107 +------
 .../array_manipulation_operators.py           | 264 ++++++++++++++++++
 benchmark/opperf/opperf.py                    |  28 +-
 benchmark/opperf/rules/default_params.py      |  26 +-
 benchmark/opperf/utils/op_registry_utils.py   | 136 +++++++--
 benchmark/opperf/utils/profiler_utils.py      |   4 +-
 7 files changed, 427 insertions(+), 146 deletions(-)
 create mode 100644 benchmark/opperf/nd_operations/array_manipulation_operators.py

diff --git a/benchmark/opperf/README.md b/benchmark/opperf/README.md
index 241734fdd655..9e4fb6aefb6a 100644
--- a/benchmark/opperf/README.md
+++ b/benchmark/opperf/README.md
@@ -50,7 +50,8 @@ Hence, in this utility, we will build the functionality to allow users and devel
 Provided you have MXNet installed (any version >= 1.5.1), all you need to do to use the opperf utility is add the path to your cloned MXNet repository to the PYTHONPATH.
 
 Note:
-To install MXNet, refer [Installing MXNet page](https://mxnet.apache.org/versions/master/install/index.html)
+1. Currently, the opperf utility requires a cloned MXNet repo; it isn't supported with the PyPI binary yet. [Work in Progress]
+2. To install MXNet, refer to the [Installing MXNet page](https://mxnet.apache.org/versions/master/install/index.html)
 
 ```
 export PYTHONPATH=$PYTHONPATH:/path/to/incubator-mxnet/
@@ -72,6 +73,9 @@ python incubator-mxnet/benchmark/opperf/opperf.py --output-format json --output-
 
 3. **dtype** : By default, `float32`. You can override and set the global dtype for all operator benchmarks. Example: --dtype float64.
 
+4. **profiler** : `native` or `python`. By default, `native`. You can override and set the global profiler for all operator benchmarks. Example: --profiler 'python'.
+The native profiler uses MXNet's built-in C++ profiler, while the python profiler uses the Python `time` package. Generally, the native profiler is used by developers and the python profiler by users.
+
 ## Usecase 2 - Run benchmarks for all the operators in a specific category
 For example, to run benchmarks for all NDArray Broadcast Binary Operators (e.g., broadcast_add, broadcast_mod, broadcast_pow), you just run the following python script.
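That broadcast-binary script itself is unchanged by this patch, so the diff does not show it. For the categories this patch adds, an analogous driver would look like the sketch below (illustrative only, not part of the patch; `run_rearrange_operators_benchmarks` is the entry point the patch introduces in `benchmark/opperf/nd_operations/array_manipulation_operators.py`):

```python
#!/usr/bin/python
import mxnet as mx

# Entry point added by this patch; it benchmarks transpose, swapaxes,
# flip, depth_to_space and space_to_depth with default input values.
from benchmark.opperf.nd_operations.array_manipulation_operators import \
    run_rearrange_operators_benchmarks

print(run_rearrange_operators_benchmarks(ctx=mx.cpu(), dtype='float32',
                                         profiler='native'))
```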
@@ -117,6 +121,7 @@ add_res = run_performance_test(nd.add, run_backward=True, dtype='float32', ctx=m inputs=[{"lhs": (1024, 1024), "rhs": (1024, 1024)}], warmup=10, runs=25) +print(add_res) ``` Output for the above benchmark run, on a CPU machine, would look something like below: @@ -143,6 +148,7 @@ add_res = run_performance_test([nd.add, nd.subtract], run_backward=True, dtype=' inputs=[{"lhs": (1024, 1024), "rhs": (1024, 1024)}], warmup=10, runs=25) +print(add_res) ``` Output for the above benchmark run, on a CPU machine, would look something like below: diff --git a/benchmark/opperf/nd_operations/README.md b/benchmark/opperf/nd_operations/README.md index 95958662ae8c..1aabce3528a3 100644 --- a/benchmark/opperf/nd_operations/README.md +++ b/benchmark/opperf/nd_operations/README.md @@ -19,103 +19,10 @@ **NOTE:** This list is AUTOGENERATED when you run opperf.py utility -0. LogisticRegressionOutput -1. broadcast_axes -2. ravel_multi_index -3. multi_sgd_mom_update -4. smooth_l1 -5. scatter_nd -6. reshape -7. one_hot -8. linalg_potri -10. multi_sgd_update -12. Convolution_v1 -13. repeat -14. Custom -15. softmax_cross_entropy -16. SwapAxis -17. norm -18. Softmax -20. fill_element_0index -21. cast -22. UpSampling -23. BatchNorm_v1 -24. CTCLoss -25. LRN -26. cast_storage -27. pick -28. GridGenerator -29. sample_multinomial -30. Activation -31. LinearRegressionOutput -32. Pooling_v1 -34. Crop -35. ElementWiseSum -36. diag -37. Reshape -38. Pad -39. linalg_gemm2 -40. crop -43. RNN -45. SoftmaxOutput -46. linalg_extractdiag -48. SequenceLast -51. SequenceReverse -53. SVMOutput -54. linalg_trsm -55. where -56. SoftmaxActivation -58. slice -59. linalg_gelqf -60. softmin -61. linalg_gemm -62. BilinearSampler -64. choose_element_0index -65. tile -67. gather_nd -69. SequenceMask -70. reshape_like -71. slice_axis -72. stack -74. khatri_rao -75. multi_mp_sgd_update -76. linalg_sumlogdiag -77. broadcast_to -78. IdentityAttachKLSparseReg -80. SpatialTransformer -81. Concat -82. uniform -83. InstanceNorm -84. expand_dims -85. multi_mp_sgd_mom_update -86. reverse -87. add_n -88. clip -89. ctc_loss -90. shape_array -91. unravel_index -92. linalg_potrf -93. Cast -94. broadcast_like -95. Embedding -96. linalg_makediag -98. linalg_syrk -99. squeeze -101. ROIPooling -103. SliceChannel -104. slice_like -106. linalg_maketrian -108. pad -109. LayerNorm -110. split -111. MAERegressionOutput -112. Correlation -114. batch_take -115. L2Normalization -116. broadcast_axis -117. linalg_trmm -118. linalg_extracttrian -119. normal -120. take -121. MakeLoss -124. concat +0. preloaded_multi_sgd_update +1. multi_mp_sgd_mom_update +2. IdentityAttachKLSparseReg +3. unravel_index +4. mp_lamb_update_phase1 +5. mp_lamb_update_phase2 +6. scatter_nd diff --git a/benchmark/opperf/nd_operations/array_manipulation_operators.py b/benchmark/opperf/nd_operations/array_manipulation_operators.py new file mode 100644 index 000000000000..1bbc1840df85 --- /dev/null +++ b/benchmark/opperf/nd_operations/array_manipulation_operators.py @@ -0,0 +1,264 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+
+from mxnet import nd
+from benchmark.opperf.utils.benchmark_utils import run_performance_test
+from benchmark.opperf.utils.common_utils import merge_map_list
+from benchmark.opperf.rules.default_params import MX_OP_MODULE
+
+from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
+from benchmark.opperf.utils.op_registry_utils import get_all_rearrange_operators, \
+    get_all_shape_operators, get_all_expanding_operators, get_all_rounding_operators
+
+"""Performance benchmark tests for MXNet Array Manipulation Operators.
+
+Array Rearrange Operators
+1. transpose
+2. swapaxes (alias SwapAxis)
+3. flip (alias reverse)
+4. depth_to_space
+5. space_to_depth
+
+Array Shape Manipulation Operators
+1. split (alias SliceChannel)
+2. diag
+3. reshape
+4. reshape_like
+5. size_array
+6. shape_array
+
+Array Expanding Operators
+1. broadcast_axes (alias broadcast_axis)
+2. broadcast_to
+3. broadcast_like
+4. repeat
+5. tile
+6. pad
+7. expand_dims
+
+
+Array Rounding Operators
+1. round
+2. rint
+3. fix
+4. floor
+5. ceil
+6. trunc
+
+Array Join & Split Operators
+1. concat
+2. split
+3. stack
+
+"""
+
+
+def run_rearrange_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', int64_tensor='off', warmup=25, runs=100):
+    """Runs benchmarks with the given context and precision (dtype) for all the
+    rearrange operators in MXNet.
+
+    Parameters
+    ----------
+    ctx: mx.ctx
+        Context to run benchmarks
+    dtype: str, default 'float32'
+        Precision to use for benchmarks
+    profiler: str, default 'native'
+        Type of Profiler to use (native/python)
+    int64_tensor: str, default 'off'
+        Input tensor size to use for tests (if on, dimensions >= 2**32)
+    warmup: int, default 25
+        Number of times to run for warmup
+    runs: int, default 100
+        Number of runs to capture benchmark results
+
+    Returns
+    -------
+    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.
+
+    """
+    # Fetch all array rearrange operators
+    mx_rearrange_ops = get_all_rearrange_operators()
+
+    # Run benchmarks
+    mx_rearrange_op_results = run_op_benchmarks(mx_rearrange_ops, dtype, ctx, profiler, int64_tensor, warmup, runs)
+    return mx_rearrange_op_results
+
+
+def run_shape_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', int64_tensor='off', warmup=25, runs=100):
+    """Runs benchmarks with the given context and precision (dtype) for all the
+    array shape operators in MXNet.
+
+    Parameters
+    ----------
+    ctx: mx.ctx
+        Context to run benchmarks
+    dtype: str, default 'float32'
+        Precision to use for benchmarks
+    profiler: str, default 'native'
+        Type of Profiler to use (native/python)
+    int64_tensor: str, default 'off'
+        Input tensor size to use for tests (if on, dimensions >= 2**32)
+    warmup: int, default 25
+        Number of times to run for warmup
+    runs: int, default 100
+        Number of runs to capture benchmark results
+
+    Returns
+    -------
+    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.
+
+    """
+    # Fetch all array shape operators
+    mx_shape_ops = get_all_shape_operators()
+
+    # Run benchmarks
+    mx_shape_op_results = run_op_benchmarks(mx_shape_ops, dtype, ctx, profiler, int64_tensor, warmup, runs)
+    return mx_shape_op_results
+
+
+def run_expanding_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', int64_tensor='off', warmup=25, runs=100):
+    """Runs benchmarks with the given context and precision (dtype) for all the
+    array expanding operators in MXNet.
+
+    Parameters
+    ----------
+    ctx: mx.ctx
+        Context to run benchmarks
+    dtype: str, default 'float32'
+        Precision to use for benchmarks
+    profiler: str, default 'native'
+        Type of Profiler to use (native/python)
+    int64_tensor: str, default 'off'
+        Input tensor size to use for tests (if on, dimensions >= 2**32)
+    warmup: int, default 25
+        Number of times to run for warmup
+    runs: int, default 100
+        Number of runs to capture benchmark results
+
+    Returns
+    -------
+    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.
+
+    """
+    # Fetch all array expanding operators
+    mx_expanding_ops = get_all_expanding_operators()
+
+    # Run benchmarks
+    mx_expanding_op_results = run_op_benchmarks(mx_expanding_ops, dtype, ctx, profiler, int64_tensor, warmup, runs)
+    return mx_expanding_op_results
+
+
+def run_rounding_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', int64_tensor='off', warmup=25, runs=100):
+    """Runs benchmarks with the given context and precision (dtype) for all the
+    array rounding operators in MXNet.
+
+    Parameters
+    ----------
+    ctx: mx.ctx
+        Context to run benchmarks
+    dtype: str, default 'float32'
+        Precision to use for benchmarks
+    profiler: str, default 'native'
+        Type of Profiler to use (native/python)
+    int64_tensor: str, default 'off'
+        Input tensor size to use for tests (if on, dimensions >= 2**32)
+    warmup: int, default 25
+        Number of times to run for warmup
+    runs: int, default 100
+        Number of runs to capture benchmark results
+
+    Returns
+    -------
+    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.
+
+    """
+    # Fetch all array rounding operators
+    mx_rounding_ops = get_all_rounding_operators()
+
+    # Run benchmarks
+    mx_rounding_op_results = run_op_benchmarks(mx_rounding_ops, dtype, ctx, profiler, int64_tensor, warmup, runs)
+    return mx_rounding_op_results
+
+
+def run_join_split_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', int64_tensor='off', warmup=25, runs=100):
+    """Runs benchmarks with the given context and precision (dtype) for all the
+    join & split operators in MXNet.
+
+    Parameters
+    ----------
+    ctx: mx.ctx
+        Context to run benchmarks
+    dtype: str, default 'float32'
+        Precision to use for benchmarks
+    profiler: str, default 'native'
+        Type of Profiler to use (native/python)
+    int64_tensor: str, default 'off'
+        Input tensor size to use for tests (if on, dimensions >= 2**32)
+    warmup: int, default 25
+        Number of times to run for warmup
+    runs: int, default 100
+        Number of runs to capture benchmark results
+
+    Returns
+    -------
+    Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.
+ + """ + # backward not supported for all 3 ops - concat, stack, split + # concat + concat_benchmark_res = run_performance_test([getattr(MX_OP_MODULE, "concat")], + run_backward=False, + dtype=dtype, + ctx=ctx, + profiler=profiler, + inputs=[{"args0":nd.random_normal(shape=(100,100)), + "args1":nd.random_normal(shape=(100,100)), + "args2":nd.random_normal(shape=(100,100))} + ], + warmup=warmup, + runs=runs) + + # split + split_benchmark_res = run_performance_test([getattr(MX_OP_MODULE, "split")], + run_backward=False, + dtype=dtype, + ctx=ctx, + profiler=profiler, + inputs=[{"data": (1024, 1024), "num_outputs": 2}, + {"data": (10000, 1), "num_outputs": 1}, + {"data": (10000, 100), "num_outputs": 10} + ], + warmup=warmup, + runs=runs) + + # stack + stack_benchmark_res = run_performance_test([getattr(MX_OP_MODULE, "stack")], + run_backward=False, + dtype=dtype, + ctx=ctx, + profiler=profiler, + inputs=[{"args0":nd.random_normal(shape=(100,100)), + "args1":nd.random_normal(shape=(100,100)), + "args2":nd.random_normal(shape=(100,100))} + ], + warmup=warmup, + runs=runs) + mx_join_split_op_results = merge_map_list(concat_benchmark_res + split_benchmark_res + stack_benchmark_res) + return mx_join_split_op_results diff --git a/benchmark/opperf/opperf.py b/benchmark/opperf/opperf.py index c0ac7b7dcd98..47bd970f930d 100755 --- a/benchmark/opperf/opperf.py +++ b/benchmark/opperf/opperf.py @@ -40,11 +40,13 @@ run_convolution_operators_benchmarks, run_transpose_convolution_operators_benchmarks from benchmark.opperf.nd_operations.nn_basic_operators import run_nn_basic_operators_benchmarks from benchmark.opperf.nd_operations.nn_optimizer_operators import run_optimizer_operators_benchmarks -from benchmark.opperf.nd_operations.array_rearrange import run_rearrange_operators_benchmarks from benchmark.opperf.nd_operations.indexing_routines import run_indexing_routines_benchmarks from benchmark.opperf.nd_operations.nn_loss_operators import run_loss_operators_benchmarks from benchmark.opperf.nd_operations.linalg_operators import run_linalg_operators_benchmarks from benchmark.opperf.nd_operations.misc_operators import run_mx_misc_operators_benchmarks +from benchmark.opperf.nd_operations.array_manipulation_operators import run_rearrange_operators_benchmarks, \ + run_shape_operators_benchmarks, run_expanding_operators_benchmarks, run_rounding_operators_benchmarks, \ + run_join_split_operators_benchmarks from benchmark.opperf.utils.common_utils import merge_map_list, save_to_file from benchmark.opperf.utils.op_registry_utils import get_operators_with_no_benchmark, \ @@ -87,11 +89,23 @@ def run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='n # Run all Sorting and Searching operations benchmarks with default input values mxnet_operator_benchmark_results.append(run_sorting_searching_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) + # Run all Indexing routines benchmarks with default input values + mxnet_operator_benchmark_results.append(run_indexing_routines_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) + # Run all Array Rearrange operations benchmarks with default input values mxnet_operator_benchmark_results.append(run_rearrange_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) - # Run all Indexing routines benchmarks with default input values - 
mxnet_operator_benchmark_results.append(run_indexing_routines_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) + # Run all Array Shape Manipulation operations benchmarks with default input values + mxnet_operator_benchmark_results.append(run_shape_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) + + # Run all Array Expansion operations benchmarks with default input values + mxnet_operator_benchmark_results.append(run_expanding_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) + + # Run all Array Rounding operations benchmarks with default input values + mxnet_operator_benchmark_results.append(run_rounding_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) + + # Run all Array Join & Split operations benchmarks with default input values + mxnet_operator_benchmark_results.append(run_join_split_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) # ************************ MXNET NN OPERATOR BENCHMARKS **************************** @@ -109,13 +123,13 @@ def run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='n # Run all Optimizer operations benchmarks with default input values mxnet_operator_benchmark_results.append(run_optimizer_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) - + # Run all Transpose Convolution operations benchmarks with default input values mxnet_operator_benchmark_results.append(run_transpose_convolution_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) # Run all NN loss operations benchmarks with default input values mxnet_operator_benchmark_results.append(run_loss_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) - + # Run all Miscellaneous operations benchmarks with default input values mxnet_operator_benchmark_results.append(run_mx_misc_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler, int64_tensor=int64_tensor, warmup=warmup, runs=runs)) @@ -164,7 +178,7 @@ def main(): help='Use built-in CPP profiler (native) or Python' 'time module.' 'Valid Inputs - native, python') - + parser.add_argument('--int64-tensor', type=str, default='off', help='Run performance tests with large tensor input' 'data (dimension >= 2**32) or standard input data.' @@ -176,7 +190,7 @@ def main(): parser.add_argument('-r', '--runs', type=int, default=100, help='Number of runs to capture benchmark results.' 
-                             'Valid Inputs - positive integers')
+                        'Valid Inputs - positive integers')
 
     args = parser.parse_args()
 
     logging.info("Running MXNet operator benchmarks with the following options: {args}".format(args=args))
diff --git a/benchmark/opperf/rules/default_params.py b/benchmark/opperf/rules/default_params.py
index 81a1ee5859ec..2c527463aa51 100644
--- a/benchmark/opperf/rules/default_params.py
+++ b/benchmark/opperf/rules/default_params.py
@@ -107,10 +107,6 @@
 DEFAULT_LAM_SE_LARGE_TENSOR = [(2**32 + 1,)]
 DEFAULT_SHAPE_SU_LARGE_TENSOR = [(2**32,)]
 
-# For reduction operators
-# NOTE: Data used is DEFAULT_DATA
-DEFAULT_AXIS_SHAPE = [(), 0, (0, 1)]
-
 # For sorting and searching operators
 # NOTE: Data used is DEFAULT_DATA
 DEFAULT_AXIS = [0]
@@ -306,12 +302,21 @@
 DEFAULT_R2_LARGE_TENSOR = [(1,)]
 DEFAULT_DELTA_LARGE_TENSOR = [(2**16, 2**16), (2**32, 1), (2**25, 2**7)]
 
-# For rearrange operators
-# NOTE: Data needs to be a 4D tensor for operators like space_to_depth and depth_to_space
+# For array manipulation operators
+# NOTE: Data needs to be a 4D tensor for operators like space_to_depth, depth_to_space, etc.
 # Hence below we append 4d to mark the difference.
 # For depth_to_space, dimension 3 needs to be a multiple of 'block' and dimension 1 should be a multiple of `block^2`
 DEFAULT_DATA_4d = [(1, 4, 2, 4), (10, 25, 10, 100)]
 DEFAULT_BLOCK_SIZE = [2, 5]
+DEFAULT_NUM_OUTPUTS = [1]
+DEFAULT_PAD_WIDTH_4d = [(0, 0, 0, 0, 1, 1, 1, 1)]
+DEFAULT_MODE_4d = ["constant"]
+DEFAULT_REPEATS = [2]
+
+# broadcast_axis needs an input array with at least 1 dim of size 1;
+# since axis is 0 (default), size(dim0) = 1
+DEFAULT_DATA_DIM1 = [(1, 1024), (1, 1), (1, 100)]
+DEFAULT_SIZE = [2]
 
 DEFAULT_DATA_4d_LARGE_TENSOR = [(1, 4, 2, 2**29), (1,2**4,2**4,2**24)]
 DEFAULT_BLOCK_SIZE_LARGE_TENSOR = [2, 4]
@@ -416,7 +421,6 @@
         "p": DEFAULT_P,
         "k_nd": DEFAULT_K_ND,
         "p_nd": DEFAULT_P_ND,
-        "axis_shape": DEFAULT_AXIS_SHAPE,
         "axis": DEFAULT_AXIS,
         "weight" : DEFAULT_WEIGHT,
         "weight32" : DEFAULT_WEIGHT,
@@ -464,7 +468,13 @@
         "data_3d": DEFAULT_DATA_3d,
         "label_smce": DEFAULT_LABEL_SMCE,
         "label": DEFAULT_LABEL,
-        "index": DEFAULT_INDEX,
+        "num_outputs": DEFAULT_NUM_OUTPUTS,
+        "data_dim1": DEFAULT_DATA_DIM1,
+        "size": DEFAULT_SIZE,
+        "mode_4d": DEFAULT_MODE_4d,
+        "pad_width_4d": DEFAULT_PAD_WIDTH_4d,
+        "repeats": DEFAULT_REPEATS,
+        "reps": DEFAULT_REPEATS,
         "grid": DEFAULT_GRID,
         "data_bilinearsampler": DEFAULT_DATA_BILINEAR,
         "transform_type": DEFAULT_TRANSFORM_TYPE,
diff --git a/benchmark/opperf/utils/op_registry_utils.py b/benchmark/opperf/utils/op_registry_utils.py
index b27b8e4e73b5..65eb6aab2aac 100644
--- a/benchmark/opperf/utils/op_registry_utils.py
+++ b/benchmark/opperf/utils/op_registry_utils.py
@@ -112,8 +112,8 @@ def prepare_op_inputs(arg_params, arg_values):
 def prepare_op_inputs(op, arg_params, int64_tensor):
     inputs = []
 
-    # 4d tensor is needed only by following two ops
-    ops_4d = {'depth_to_space', 'space_to_depth'}
+    # 4d tensor is needed by the following ops
+    ops_4d = ['depth_to_space', 'space_to_depth', 'pad']
 
     # 3d tensor is needed by following ops
     ops_3d = {'CTCLoss', 'ctc_loss'}
@@ -135,6 +135,9 @@ def prepare_op_inputs(op, arg_params, int64_tensor):
     int_only = {'random_randint'}
     float_only = {'log_softmax', 'softmax', 'softmin'}
 
+    # the following ops need at least 1 dim of size 1
+    ops_dim1 = ['broadcast_axis', 'broadcast_like', 'broadcast_to', 'broadcast_axes']
+
     if int64_tensor == 'on':
         default_inputs = DEFAULTS_INPUTS_LARGE_TENSOR
         custom_data |= custom_data_int64
@@ -149,30 +152,43 @@
     # added logic for using float-only dtype as input for ops that take only floats
     # same for randint (which is the only op that takes only int as input)
     # rest all operators take int as well as float
-        if op in int_only and arg_name == "dtype":
-            arg_values[arg_name] = default_inputs["dtype_int"]
-        elif (op.startswith(('random','sample')) or op in float_only) and arg_name == "dtype":
-            arg_values[arg_name] = default_inputs["dtype_float"]
-        elif "NDArray" in arg_type and op == "ravel_multi_index":
-            arg_values[arg_name] = default_inputs["ravel_data"]
-        elif op in custom_data and arg_name + "_" + op.lower() in default_inputs:
-            arg_values[arg_name] = default_inputs[arg_name + "_" + op.lower()]
-        elif "NDArray" in arg_type and arg_name + "_nd" in default_inputs:
-            arg_values[arg_name] = default_inputs[arg_name + "_nd"]
-        elif "NDArray" in arg_type and op in ops_4d and arg_name + "_4d" in default_inputs:
-            arg_values[arg_name] = default_inputs[arg_name + "_4d"]
-        elif "NDArray" in arg_type and op in ops_3d and arg_name + "_3d" in default_inputs:
-            arg_values[arg_name] = default_inputs[arg_name + "_3d"]
-        elif "NDArray" in arg_type and op == 'softmax_cross_entropy':
-            arg_values[arg_name] = default_inputs[arg_name + "_smce"]
-        elif arg_name in default_inputs:
-            arg_values[arg_name] = default_inputs[arg_name]
-        elif "float" in arg_type and arg_name + "_float" in default_inputs:
-            arg_values[arg_name] = default_inputs[arg_name + "_float"]
-        elif "Shape" in arg_type and arg_name + "_shape" in default_inputs:
-            # This is for cases where in some ops 'axis' is Int in some ops a shape tuple.
-            # Ex: axis in sum is shape, axis in sort is int.
-            arg_values[arg_name] = default_inputs[arg_name + "_shape"]
+        if "NDArray" in arg_type:
+            if op in int_only and arg_name == "dtype":
+                arg_values[arg_name] = default_inputs["dtype_int"]
+            elif (op.startswith(('random','sample')) or op in float_only) and arg_name == "dtype":
+                arg_values[arg_name] = default_inputs["dtype_float"]
+            elif op == "ravel_multi_index":
+                arg_values[arg_name] = default_inputs["ravel_data"]
+            elif op in custom_data and arg_name + "_" + op.lower() in default_inputs:
+                arg_values[arg_name] = default_inputs[arg_name + "_" + op.lower()]
+            elif arg_name + "_nd" in default_inputs:
+                arg_values[arg_name] = default_inputs[arg_name + "_nd"]
+            elif op in ops_3d and arg_name + "_3d" in default_inputs:
+                arg_values[arg_name] = default_inputs[arg_name + "_3d"]
+            elif op == 'softmax_cross_entropy':
+                arg_values[arg_name] = default_inputs[arg_name + "_smce"]
+            elif op in ops_4d and arg_name + "_4d" in default_inputs:
+                arg_values[arg_name] = default_inputs[arg_name + "_4d"]
+            elif op in ops_dim1 and arg_name + "_dim1" in default_inputs:
+                arg_values[arg_name] = default_inputs[arg_name + "_dim1"]
+            # default case
+            elif arg_name in default_inputs:
+                arg_values[arg_name] = default_inputs[arg_name]
+        else:
+            # arg_type is not NDArray
+            if op in int_only and arg_name == "dtype":
+                arg_values[arg_name] = default_inputs["dtype_int"]
+            elif (op.startswith(('random','sample')) or op in float_only) and arg_name == "dtype":
+                arg_values[arg_name] = default_inputs["dtype_float"]
+            elif op in custom_data and arg_name + "_" + op.lower() in default_inputs:
+                arg_values[arg_name] = default_inputs[arg_name + "_" + op.lower()]
+            elif op in ops_4d and arg_name + "_4d" in default_inputs:
+                arg_values[arg_name] = default_inputs[arg_name + "_4d"]
+            elif op in ops_dim1 and arg_name + "_dim1" in default_inputs:
+                arg_values[arg_name] = default_inputs[arg_name + "_dim1"]
"_dim1"] + #default case + elif arg_name in DEFAULTS_INPUTS: + arg_values[arg_name] = DEFAULTS_INPUTS[arg_name] # Number of different inputs we want to use to test # the operator @@ -434,7 +450,8 @@ def get_all_rearrange_operators(): ------- {"operator_name": {"has_backward", "nd_op_handle", "params"}} """ - rearrange_ops = {'transpose','swapaxes','flip','depth_to_space','space_to_depth'} + rearrange_ops = ['transpose', 'swapaxes', 'flip', 'depth_to_space', + 'space_to_depth', 'SwapAxis', 'reverse'] # Get all mxnet operators mx_operators = _get_all_mxnet_operators() @@ -480,7 +497,7 @@ def get_all_indexing_routines(): """ indexing_routines = {'slice', 'slice_axis', 'slice_like', 'take', 'one_hot', 'where', 'ravel_multi_index', 'gather_nd', 'pick'} - + # Get all mxnet operators mx_operators = _get_all_mxnet_operators() @@ -512,6 +529,69 @@ def get_all_loss_operators(): return loss_mx_operators +def get_all_shape_operators(): + """Gets all array shape manipulation operators registered with MXNet. + + Returns + ------- + {"operator_name": {"has_backward", "nd_op_handle", "params"}} + """ + shape_ops = ['split', 'SliceChannel', 'diag', 'reshape', + 'reshape_like', 'size_array', 'shape_array'] + + # Get all mxnet operators + mx_operators = _get_all_mxnet_operators() + + # Filter for Array Shape Manipulation operators + shape_mx_operators = {} + for op_name, op_params in mx_operators.items(): + if op_name in shape_ops: + shape_mx_operators[op_name] = mx_operators[op_name] + return shape_mx_operators + + +def get_all_expanding_operators(): + """Gets all array expanding operators registered with MXNet. + + Returns + ------- + {"operator_name": {"has_backward", "nd_op_handle", "params"}} + """ + expanding_ops = ['broadcast_axes', 'broadcast_axis', 'broadcast_to', 'broadcast_like', + 'repeat', 'tile', 'pad', 'expand_dims'] + + # Get all mxnet operators + mx_operators = _get_all_mxnet_operators() + + # Filter for Array Expanding operators + expanding_mx_operators = {} + for op_name, op_params in mx_operators.items(): + if op_name in expanding_ops: + expanding_mx_operators[op_name] = mx_operators[op_name] + return expanding_mx_operators + + +def get_all_rounding_operators(): + """Gets all array rounding operators registered with MXNet. + + Returns + ------- + {"operator_name": {"has_backward", "nd_op_handle", "params"}} + """ + rounding_ops = ['round', 'rint', 'fix', 'floor', + 'ceil', 'trunc'] + + # Get all mxnet operators + mx_operators = _get_all_mxnet_operators() + + # Filter for Array Rounding operators + rounding_mx_operators = {} + for op_name, op_params in mx_operators.items(): + if op_name in rounding_ops: + rounding_mx_operators[op_name] = mx_operators[op_name] + return rounding_mx_operators + + def get_operators_with_no_benchmark(operators_with_benchmark): """Gets all MXNet operators with not benchmark. 
diff --git a/benchmark/opperf/utils/profiler_utils.py b/benchmark/opperf/utils/profiler_utils.py index 76ab90e5c631..1cb29a8fdec8 100644 --- a/benchmark/opperf/utils/profiler_utils.py +++ b/benchmark/opperf/utils/profiler_utils.py @@ -47,10 +47,10 @@ def _get_operator_profile(operator_name, operator_profile_results): # alias map : dictionary of the form {"alias" : "registered_name"} # allows to retrieve alias operator profile from the profiler results - # TODO handling - "identity" : "_copy" alias_map = {"broadcast_plus": "broadcast_add", "broadcast_minus": "broadcast_sub", "flatten": "Flatten", "max_axis": "max", "Custom": "CustomAddOne", "swapaxes": "SwapAxis", "flip": "reverse", "reshape": "Reshape", "crop": "slice", "sum_axis": "sum", "min_axis": "min", "ctc_loss": "CTCLoss", - "fill_element_0index": "TernaryOp", "identity": "_copy", "ElementWiseSum": "add_n", "choose_element_0index": "pick", "stop_gradient": "BlockGrad"} + "fill_element_0index": "TernaryOp", "identity": "_copy", "ElementWiseSum": "add_n", "choose_element_0index": "pick", "stop_gradient": "BlockGrad", + "broadcast_axes": "broadcast_axis"} op_name = None
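The `broadcast_axes` entry added to `alias_map` closes the same gap the other entries do: the native profiler records results under an operator's registered name (here `broadcast_axis`), so a benchmark invoked through the alias has to translate the name before reading its profile. A minimal sketch of that lookup (illustrative; `_resolve_registered_name` is a hypothetical name, not the actual helper in `profiler_utils.py`):

```python
def _resolve_registered_name(operator_name, alias_map):
    # Profile entries are keyed by the registered operator name, so map an
    # alias (e.g. 'broadcast_axes') to its registered counterpart
    # ('broadcast_axis') before looking up its results; non-aliases pass
    # through unchanged.
    return alias_map.get(operator_name, operator_name)


assert _resolve_registered_name(
    "broadcast_axes", {"broadcast_axes": "broadcast_axis"}) == "broadcast_axis"
assert _resolve_registered_name(
    "transpose", {"broadcast_axes": "broadcast_axis"}) == "transpose"
```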