Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Opperf] Add optimizer update operator benchmarks to opperf #15522

Merged
merged 12 commits into from
Jul 28, 2019
17 changes: 1 addition & 16 deletions benchmark/opperf/nd_operations/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,18 @@
0. LogisticRegressionOutput
1. broadcast_axes
2. ravel_multi_index
3. multi_sgd_mom_update
4. smooth_l1
5. scatter_nd
6. reshape
7. one_hot
8. linalg_potri
9. mp_sgd_update
10. multi_sgd_update
11. signum_update
12. Convolution_v1
13. repeat
14. Custom
15. softmax_cross_entropy
16. SwapAxis
17. norm
18. Softmax
19. rmspropalex_update
20. fill_element_0index
21. cast
22. UpSampling
Expand All @@ -52,19 +47,16 @@
30. Activation
31. LinearRegressionOutput
32. Pooling_v1
33. ftml_update
34. Crop
35. ElementWiseSum
36. diag
37. Reshape
38. Pad
39. linalg_gemm2
40. crop
41. rmsprop_update
43. RNN
45. SoftmaxOutput
46. linalg_extractdiag
47. sgd_mom_update
48. SequenceLast
50. flip
51. SequenceReverse
Expand All @@ -73,13 +65,11 @@
54. linalg_trsm
55. where
56. SoftmaxActivation
57. signsgd_update
58. slice
59. linalg_gelqf
60. softmin
61. linalg_gemm
62. BilinearSampler
63. mp_sgd_mom_update
64. choose_element_0index
65. tile
66. space_to_depth
Expand All @@ -89,7 +79,6 @@
71. slice_axis
72. stack
74. khatri_rao
75. multi_mp_sgd_update
76. linalg_sumlogdiag
77. broadcast_to
78. IdentityAttachKLSparseReg
Expand All @@ -98,7 +87,6 @@
82. uniform
83. InstanceNorm
84. expand_dims
85. multi_mp_sgd_mom_update
86. reverse
87. add_n
88. clip
Expand All @@ -114,7 +102,6 @@
98. linalg_syrk
99. squeeze
101. ROIPooling
102. ftrl_update
103. SliceChannel
104. slice_like
105. depth_to_space
Expand All @@ -132,6 +119,4 @@
119. normal
120. take
121. MakeLoss
122. sgd_update
123. adam_update
124. concat
124. concat
68 changes: 68 additions & 0 deletions benchmark/opperf/nd_operations/nn_optimizer_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
from benchmark.opperf.utils.op_registry_utils import get_all_optimizer_operators

"""Performance benchmark tests for MXNet Neural Network Optimizer Update Operators.

1. Stochastic Gradient Descent (SGD)
1.1 multi_sgd_mom_update
1.2 mp_sgd_update
1.3 multi_sgd_update
1.4 sgd_mom_update
1.5 signsgd_update
1.6 mp_sgd_mom_update
1.7 multi_mp_sgd_update
1.8 multi_mp_sgd_mom_update
1.9 sgd_update
3. signum_update
4. rmspropalex_update
5. ftml_update
6. rmsprop_update
7. ftrl_update
8. adam_update
"""


def run_optimizer_operators_benchmarks(ctx=mx.cpu(), dtype='float32', warmup=25, runs=100):
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved
"""Runs benchmarks with the given context and precision (dtype) for all the neural network
optimizer update operators in MXNet.

Parameters
----------
ctx: mx.ctx
Context to run benchmarks
dtype: str, default 'float32'
Precision to use for benchmarks
warmup: int, default 25
Number of times to run for warmup
runs: int, default 100
Number of runs to capture benchmark results

Returns
-------
Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.

"""
# Fetch all optimizer operators
mx_optimizer_ops = get_all_optimizer_operators()

# Run benchmarks
mx_optimizer_op_results = run_op_benchmarks(mx_optimizer_ops, dtype, ctx, warmup, runs)
return mx_optimizer_op_results
3 changes: 3 additions & 0 deletions benchmark/opperf/opperf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from benchmark.opperf.nd_operations.nn_conv_operators import run_pooling_operators_benchmarks, \
run_convolution_operators_benchmarks, run_transpose_convolution_operators_benchmarks
from benchmark.opperf.nd_operations.nn_basic_operators import run_nn_basic_operators_benchmarks
from benchmark.opperf.nd_operations.nn_optimizer_operators import run_optimizer_operators_benchmarks

from benchmark.opperf.utils.common_utils import merge_map_list, save_to_file
from benchmark.opperf.utils.op_registry_utils import get_operators_with_no_benchmark, \
Expand Down Expand Up @@ -92,6 +93,8 @@ def run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32'):
# Run all Convolution operations benchmarks with default input values
mxnet_operator_benchmark_results.append(run_convolution_operators_benchmarks(ctx=ctx, dtype=dtype))

# Run all Optimizer operations benchmarks with default input values
mxnet_operator_benchmark_results.append(run_optimizer_operators_benchmarks(ctx=ctx, dtype=dtype))
# Run all Transpose Convolution operations benchmarks with default input values
mxnet_operator_benchmark_results.append(run_transpose_convolution_operators_benchmarks(ctx=ctx, dtype=dtype))

Expand Down
40 changes: 38 additions & 2 deletions benchmark/opperf/rules/default_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,25 @@
# NOTE: Data used is DEFAULT_DATA
DEFAULT_AXIS = [0]

# For optimizer operators
DEFAULT_WEIGHT = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_GRAD = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_MOM = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_MEAN = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_VAR = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_N = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_LR = [[0.1,0.5,0.9]]
DEFAULT_GAMMA_1 = [[0.1,0.5,0.9]]
DEFAULT_GAMMA_2 = [[0.1,0.5,0.9]]
DEFAULT_EPSILON = [[1e-08]]
DEFAULT_BETA_1 = [[0.1,0.5,0.9]]
DEFAULT_BETA_2 = [[0.1,0.5,0.9]]
DEFAULT_T = [[1,5]]
DEFAULT_RESCALE_GRAD = [[0.4, 0.77]]
DEFAULT_CLIP_GRADIENT = [[-1.0,0.8]]
DEFAULT_CLIP_WEIGHTS = [[-1.0,0.8]]
DEFAULT_LAZY_UPDATE = [[0,1]]

# Default Inputs. MXNet Op Param Name to Default Input mapping
DEFAULTS_INPUTS = {"data": DEFAULT_DATA,
"lhs": DEFAULT_LHS,
Expand All @@ -81,7 +100,23 @@
"k_nd": DEFAULT_K_ND,
"p_nd": DEFAULT_P_ND,
"axis_shape": DEFAULT_AXIS_SHAPE,
"axis": DEFAULT_AXIS}
"axis": DEFAULT_AXIS,
"weight" : DEFAULT_WEIGHT,
"grad" : DEFAULT_GRAD,
"mean" : DEFAULT_MEAN,
"var" : DEFAULT_VAR,
"mom" : DEFAULT_MOM,
"n" : DEFAULT_N,
"lr" : DEFAULT_LR,
"gamma1" : DEFAULT_GAMMA_1,
"gamma2" : DEFAULT_GAMMA_2,
"epsilon" : DEFAULT_EPSILON,
"beta1" : DEFAULT_BETA_1,
"beta2" : DEFAULT_BETA_2,
"t" : DEFAULT_T,
"rescale_grad" : DEFAULT_RESCALE_GRAD,
"clip_grad" : DEFAULT_CLIP_GRADIENT,
"lazy_update" : DEFAULT_LAZY_UPDATE}

# These are names of MXNet operator parameters that is of type NDArray.
# We maintain this list to automatically recognize these parameters are to be
Expand All @@ -90,4 +125,5 @@
# can just say shape of the tensor, and we automatically create Tensors.
PARAMS_OF_TYPE_NDARRAY = ["lhs", "rhs", "data", "base", "exp",
"mu", "sigma", "lam", "alpha", "beta", "gamma", "k", "p",
"low", "high", "weight", "bias", "moving_mean", "moving_var"]
"low", "high", "weight", "bias", "moving_mean", "moving_var",
"weight", "grad", "mean", "var", "mom", "n"]
22 changes: 22 additions & 0 deletions benchmark/opperf/utils/op_registry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,28 @@ def get_all_reduction_operators():
return reduction_mx_operators


def get_all_optimizer_operators():
"""Gets all Optimizer operators registered with MXNet.

Returns
-------
{"operator_name": {"has_backward", "nd_op_handle", "params"}}
"""
optimizer_ops = ['multi_sgd_mom_update', 'mp_sgd_update', 'multi_sgd_update', 'signum_update',
'rmspropalex_update', 'ftml_update', 'rmsprop_update', 'sgd_mom_update', 'signsgd_update',
'mp_sgd_mom_update', 'multi_mp_sgd_update', 'multi_mp_sgd_mom_update', 'ftrl_update', 'sgd_update',
'adam_update']

# Get all mxnet operators
mx_operators = _get_all_mxnet_operators()

# Filter for Optimizer operators
optimizer_mx_operators = {}
for op_name, op_params in mx_operators.items():
if op_name in optimizer_ops and op_name not in unique_ops:
optimizer_mx_operators[op_name] = mx_operators[op_name]
return optimizer_mx_operators

def get_all_sorting_searching_operators():
"""Gets all Sorting and Searching operators registered with MXNet.

Expand Down