This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add an utility for operator benchmarks #14977

Merged: 30 commits, Jun 12, 2019
Changes from 11 commits

Commits (30), all by sandeep-krishnamurthy:
703e140 Initial end to end working skeleton (May 16, 2019)
058f194 Add skeleton for all other operator benchmarks (May 16, 2019)
d0df2d7 Add Gluon Conv2D benchmarks (May 16, 2019)
8374365 Add readme and user guide, example result (May 17, 2019)
0862509 Add licence headers to all files (May 17, 2019)
79891ed fix RAT licence check issues (May 17, 2019)
06e5c1c Add ability to group list of operators with same inputs to benchmark.… (May 20, 2019)
6f9865f Add comparison operator tests and more arithmetic operators (May 20, 2019)
91f9c14 Remove Gluon block and instead use only low level NDArray operators (May 20, 2019)
0fc26a1 Add GEMM operators (May 20, 2019)
a12e29e Add logical operations (May 21, 2019)
7626b1e Add support to export results as markdown (May 21, 2019)
f9b6cd7 Add ability to query MXNet operator registry for operators and run be… (May 23, 2019)
976ea47 Delete duplicate arithmetic, logical, comparison operator benchmarks.… (May 24, 2019)
e8cc1e8 Add binary elementwise operator benchmarks (May 24, 2019)
c2bf1a3 Adding basic logging mechanisms (May 25, 2019)
eef953c Address review comments (Jun 3, 2019)
a6d3d98 Few formatting issues resolved (Jun 3, 2019)
3b8b2b4 Add unary operators. Remove stale todo files (Jun 3, 2019)
180ca72 Fix sanity tests (Jun 3, 2019)
f45b39a Remove mention of hypothesis (Jun 5, 2019)
790be23 Add random sampling operator benchmarks. (Jun 7, 2019)
60ba9b2 Add all activation operator benchmarks (Jun 7, 2019)
0374e76 Add Pooling operator benchmarks (Jun 7, 2019)
3752441 Add Convolution operator benchmarks (Jun 8, 2019)
bceee47 Add Reduction operator benchmarks (Jun 10, 2019)
a63b53c Add an utility to get list of operator not benchmarked (Jun 10, 2019)
afb4803 Autogenerate list of operators to cover (Jun 10, 2019)
7cb3ef7 Add basic nn operators - FC, dropout, batchnorm (Jun 11, 2019)
5fbc925 Add CPU result file (Jun 11, 2019)
16 changes: 16 additions & 0 deletions benchmark/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
174 changes: 174 additions & 0 deletions benchmark/opperf/README.md
@@ -0,0 +1,174 @@
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

# MXNet Operator Performance Benchmarks

A Python utility for benchmarking and profiling individual MXNet operator execution.

With this utility, for each MXNet operator you can get the following details:

**Timing**
1. Forward execution time
2. Backward execution time
3. Time spent for memory management

**Memory**
1. Total memory allocated

# Motivation

Benchmarks are usually done end-to-end for a given network architecture, for example ResNet-50 benchmarks on ImageNet data. This is a good measurement of the overall performance and health of a deep learning framework. However, it is important to note the following factors:
1. Users rely on many operators that are not part of a standard network like ResNet. Example: tensor manipulation operators like mean, max, topk, argmax, sort etc.
2. A standard network architecture like ResNet-50 is made up of many operators, e.g. Convolution2D, Softmax, Dense and more. Consider the following scenarios:
    1. We improve the performance of the Convolution2D operator, but due to a bug, Softmax performance goes down. End-to-end benchmarks may still look fine, so we miss the performance degradation of a single operator, which can accumulate and become hard to trace.
    2. You need to see which operator in a given network takes the most time and plan optimization work accordingly. With end-to-end benchmarks, it is hard to get such fine-grained numbers at the operator level.
3. We need to know how different operators perform on different hardware infrastructure (Ex: CPU with MKLDNN, GPU with NVIDIA CUDA and cuDNN). With these details, we can plan optimization work at the operator level, which can significantly boost end-to-end performance.
4. You want nightly performance tests across all operators in a deep learning framework to catch regressions early.
5. We can integrate this framework with a CI/CD system to run per-operator performance tests for PRs. Example: when a PR modifies the kernel of TransposeConv2D, we can run benchmarks of the TransposeConv2D operator to verify performance.

Hence, this utility provides the functionality for users and developers of deep learning frameworks to easily run benchmarks for individual operators.

# How to use

## Prerequisites

This utility uses the MXNet profiler under the hood to fetch compute and memory metrics. Hence, you need to build MXNet with the `USE_PROFILER=1` flag.

Make sure to build the flavor of MXNet on which you would like to measure operator performance, for example with/without MKL, or with CUDA 9 or 10.1.
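
If you want to quickly verify that profiling works in your build, a minimal sketch like the one below can help (this uses the standard `mx.profiler` API; the tensor shape is arbitrary):

```
import mxnet as mx

# Profile operators and memory, aggregating statistics across calls
mx.profiler.set_config(profile_all=True, aggregate_stats=True)

mx.profiler.set_state('run')
a = mx.nd.ones((1024, 1024))
b = (a + 1).asnumpy()  # asnumpy() forces the computation to complete
mx.profiler.set_state('stop')

print(mx.profiler.dumps())  # aggregated operator time and memory summary
```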

## Use Case 1 - Run benchmarks for all the operators

The command below runs benchmarks for all MXNet (NDArray) operators with default inputs and saves the final result as JSON in the given file.

```
python incubator-mxnet/benchmark/opperf/opperf.py --output-format json --output-file mxnet_operator_benchmark_results.json
```

**Other Supported Options:**

1. **output-format** : `json` for JSON, `md` for Markdown, or `csv` for CSV file output.

2. **ctx** : `cpu` or `gpu`. By default, `cpu` on a CPU machine and `gpu(0)` on a GPU machine. You can override and set the global context for all operator benchmarks. Example: `--ctx gpu(2)`.
> **Review comment (Contributor):** What if there are multiple GPUs? Does the profiler generate results per device?
>
> **Reply (Author):** Since these are single-operator benchmarks, each run executes on only one device, so the profiler output covers only that device.


3. **dtype** : By default, `float32`. You can override and set the global dtype for all operator benchmarks. Example: `--dtype float64`. A combined invocation example is shown below.
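
An illustrative invocation combining these options (the output file name is arbitrary):

```
python incubator-mxnet/benchmark/opperf/opperf.py --ctx gpu --dtype float64 --output-format md --output-file mxnet_operator_benchmark_results.md
```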

## Use Case 2 - Run benchmarks for all the operators in a specific category

For example, to run benchmarks for all NDArray arithmetic operators, run the following Python script.

```
#! /usr/bin/python
from benchmark.opperf.tensor_operations.arithmetic_operations import run_arithmetic_operators_benchmarks

# Run all Arithmetic operations benchmarks with default input values
print(run_arithmetic_operators_benchmarks())
```

Output for the above benchmark run, on a CPU machine, would look something like this:

```
{'subtract': [{'avg_time_forward_broadcast_sub': 5.5137,
               'avg_time_mem_alloc_cpu/0': 207618.0469,
               'avg_time_backward_broadcast_sub': 7.2976,
               'inputs': {'lhs': (1024, 1024), 'rhs': (1024, 1024)}}],
 'add': [{'avg_time_mem_alloc_cpu/0': 207618.0469,
          'avg_time_forward_broadcast_add': 4.309,
          'avg_time_backward_broadcast_add': 5.6063,
          'inputs': {'lhs': (1024, 1024), 'rhs': (1024, 1024)}}],
 'multiply': [{'avg_time_backward_broadcast_mul': 19.1712,
               'avg_time_mem_alloc_cpu/0': 207618.0469,
               'avg_time_forward_broadcast_mul': 6.4855,
               'inputs': {'lhs': (1024, 1024), 'rhs': (1024, 1024)}}]}
```

## Use Case 3 - Run benchmarks for a specific operator
For example, to run benchmarks for the `nd.add` operator in MXNet, run the following Python script.

```
#! /usr/bin/python
import mxnet as mx
from mxnet import nd

from benchmark.opperf.utils.benchmark_utils import run_performance_test

add_res = run_performance_test(nd.add, run_backward=True, dtype='float32', ctx=mx.cpu(),
                               inputs=[{"lhs": (1024, 1024),
                                        "rhs": (1024, 1024)}],
                               warmup=10, runs=25)
```

Output for the above benchmark run, on a CPU machine, would look something like this:

```
{'add': [{'avg_time_mem_alloc_cpu/0': 102760.4453,
          'avg_time_forward_broadcast_add': 4.0372,
          'avg_time_backward_broadcast_add': 5.3841,
          'inputs': {'lhs': (1024, 1024), 'rhs': (1024, 1024)}}]}

```

## Use Case 3.1 - Run benchmarks for a group of operators with the same inputs
For example, to run benchmarks for the `nd.add` and `nd.sub` operators in MXNet with the same set of inputs, run the following Python script.

```
#! /usr/bin/python
import mxnet as mx
from mxnet import nd

from benchmark.opperf.utils.benchmark_utils import run_performance_test

add_res = run_performance_test([nd.add, nd.sub], run_backward=True, dtype='float32', ctx=mx.cpu(),
                               inputs=[{"lhs": (1024, 1024),
                                        "rhs": (1024, 1024)}],
                               warmup=10, runs=25)
```

Output for the above benchmark run, on a CPU machine, would look something like this:

```
{'add': [{'avg_time_mem_alloc_cpu/0': 102760.4453,
          'avg_time_forward_broadcast_add': 4.0372,
          'avg_time_backward_broadcast_add': 5.3841,
          'inputs': {'lhs': (1024, 1024), 'rhs': (1024, 1024)}}],
 'subtract': [{'avg_time_forward_broadcast_sub': 5.5137,
               'avg_time_mem_alloc_cpu/0': 207618.0469,
               'avg_time_backward_broadcast_sub': 7.2976,
               'inputs': {'lhs': (1024, 1024), 'rhs': (1024, 1024)}}]}

```
# How does it work under the hood?

Under the hood, the utility executes each NDArray operator on randomly generated data and uses the MXNet profiler to get a summary of the operator execution:
1. Memory
2. Computation time (forward, backward)

A simplified sketch of this flow is shown below.

See the design proposal document for more details - https://cwiki.apache.org/confluence/display/MXNET/MXNet+Operator+Benchmarks

# TODO

All contributions are welcome. Below is the list of desired features:

1. Cover all MXNet operators.
2. Enhance MXNet profiler with additional APIs to programmatically fetch and process profiler data.
sandeep-krishnamurthy marked this conversation as resolved.
Show resolved Hide resolved
3. Integration with a CI/CD system to run operator benchmarks for PR builds and nightly builds.
4. Dashboards and other modes of presentation of results for analyzing and planning tasks such as operator performance improvements.
5. Integration with tools such as [Hypothesis](https://hypothesis.readthedocs.io/en/latest/) for randomized tensor shape generation for profiling to identify bottlenecks in the operators (a speculative sketch follows this list).
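
A speculative sketch of item 5, assuming Hypothesis is installed; `profile_add`, the shape ranges, and the run counts are all illustrative:

```
import mxnet as mx
from hypothesis import given, settings, strategies as st
from mxnet import nd

from benchmark.opperf.utils.benchmark_utils import run_performance_test

# Random 2D shapes up to (1024, 1024)
shapes = st.tuples(st.integers(min_value=1, max_value=1024),
                   st.integers(min_value=1, max_value=1024))

@settings(max_examples=5, deadline=None)
@given(shape=shapes)
def profile_add(shape):
    print(run_performance_test(nd.add, run_backward=True, dtype='float32',
                               ctx=mx.cpu(),
                               inputs=[{"lhs": shape, "rhs": shape}],
                               warmup=2, runs=5))

profile_add()
```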
16 changes: 16 additions & 0 deletions benchmark/opperf/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
16 changes: 16 additions & 0 deletions benchmark/opperf/custom_operations/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
67 changes: 67 additions & 0 deletions benchmark/opperf/custom_operations/custom_operations.py
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx

"""
MXNet's Custom Operator Benchmark Tests.

It does a simple element wise addition to make sure computation
is not too much and we can observe custom operator logistics overhead.
"""


# 1. Define a custom operator - element-wise addition of 1
class CustomAddOne(mx.operator.CustomOp):
def forward(self, is_train, req, in_data, out_data, aux):
self.assign(out_data[0], req[0], in_data[0] + 1)

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
self.assign(in_grad[0], req[0], out_grad[0])


@mx.operator.register("CustomAddOne")
class CustomAddOneProp(mx.operator.CustomOpProp):
def __init__(self):
super(CustomAddOneProp, self).__init__(need_top_grad=True)

def list_arguments(self):
return ['in']

def list_outputs(self):
return ['output']

def infer_shape(self, in_shape):
# inputs, outputs, aux
return [in_shape[0]], [in_shape[0]], []

def create_operator(self, ctx, shapes, dtypes):
return CustomAddOne()


"""Helps to benchmark MXNet's Custom Op for Element wise addition on a (1000, 1) tensor.
Performs both forward and backward operation.

This test mainly uncovers core custom op overhead in MXNet.

Benchmark will be done on the following operation:
native_add -> native_add -> native_add -> CUSTOM_ADD -> native_add -> native_add -> native_add

By default run on 'float32' precision.
"""

# TODO
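
# A minimal timing sketch for the chain described above (a hedged sketch of
# the eventual benchmark, not the final implementation; the shape and run
# counts are assumptions):
import time


def run_custom_add_chain(warmup=10, runs=25):
    data = mx.nd.ones((1000, 1))
    data.attach_grad()

    def chain(x):
        # native_add x3 -> CUSTOM_ADD -> native_add x3
        x = x + 1
        x = x + 1
        x = x + 1
        x = mx.nd.Custom(x, op_type='CustomAddOne')
        x = x + 1
        x = x + 1
        x = x + 1
        return x

    # Warm-up runs, excluded from timing
    for _ in range(warmup):
        with mx.autograd.record():
            out = chain(data)
        out.backward()
    mx.nd.waitall()

    start = time.time()
    for _ in range(runs):
        with mx.autograd.record():
            out = chain(data)
        out.backward()
    mx.nd.waitall()
    return (time.time() - start) / runs * 1000  # avg ms per run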
16 changes: 16 additions & 0 deletions benchmark/opperf/nn_operations/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
30 changes: 30 additions & 0 deletions benchmark/opperf/nn_operations/activation_operations.py
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

""" Performance benchmark Tests for MXNet NDArray NN Activation Operators.

TODO

1. LeakyRelu
2. PRelu
3. Activation (Sigmoid)
4. Activation (Softmax)
5. Activation (Log_Softmax)
6. Activation (tanh)
7. Elu
8. Selu
"""
26 changes: 26 additions & 0 deletions benchmark/opperf/nn_operations/basic_operations.py
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

""" Performance benchmark tests for MXNet NDArray NN Operators

TODO

1. FullyConnected (Basic)
2. Flatten
3. Embedding

"""