Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Multithreaded Inference Support #16654

Merged
Merged
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
a6c95ef
Add cached op threadsafe version with corresponding C APIs, CPP Packa…
anirudh2290 Oct 19, 2019
5304b7c
Fix download cmd in runtime_functions
anirudh2290 Oct 28, 2019
9e3eced
Add CI changes
anirudh2290 Oct 29, 2019
1359ec8
Add stage
anirudh2290 Oct 29, 2019
b9b4b94
Fix lint
anirudh2290 Oct 30, 2019
6e8ff59
Change to DEFAULT for C API
anirudh2290 Oct 30, 2019
58a0790
Fix mxnet_unit_tests path
anirudh2290 Oct 30, 2019
24a888d
export correct LD_LIBRARY_PATH
anirudh2290 Oct 31, 2019
76b5076
Add cpp include dirs
anirudh2290 Oct 31, 2019
29ad64f
Build test with USE_CPP_PACKAGE
anirudh2290 Oct 31, 2019
d5b67e4
Add cached op threadsafe version with corresponding C APIs, CPP Packa…
anirudh2290 Oct 19, 2019
4b36e27
Fix download cmd in runtime_functions
anirudh2290 Oct 28, 2019
62d8979
Merge
anirudh2290 Oct 31, 2019
54a67b6
Merge remote
anirudh2290 Oct 31, 2019
4043ec1
change mkldnn lib name
anirudh2290 Nov 1, 2019
26bf63a
Add static_alloc, static_Shape support
anirudh2290 Nov 1, 2019
9cb121f
Address review comments
anirudh2290 Nov 5, 2019
6f7ac93
Make GetCachedOpThreadSafeState similar to cached_op
anirudh2290 Nov 5, 2019
db72a3e
Address review comments: comments for locking strategy
anirudh2290 Nov 7, 2019
fd6fa6d
multithreaded inference tutorial
anirudh2290 Nov 8, 2019
a8eb875
[Estimator] handle composite metrics in estimator (#16676)
szha Oct 31, 2019
437f1c7
[Estimator] refactor estimator to allow overriding evaluate/fit of a …
szha Oct 31, 2019
2a97260
Pointwise fusion for GPU (#15167)
ptrendx Nov 1, 2019
4f5a909
fix install dir (#16690)
TaoLv Nov 1, 2019
5b901e9
[numpy] add numpy operator : append (#16564)
JiangZhaoh Nov 1, 2019
b3c4f90
Initializer.__eq__ (#16680)
leezu Nov 1, 2019
2e7dd2b
fix binary dependencies in CD and nightly (#16693)
TaoLv Nov 1, 2019
954b63b
[MKL-DNN] Add mxnet mkldnn cmake tutorial (#16688)
xinyu-intel Nov 1, 2019
e5b5366
Revert "[MKLDNN]Fix reorder2default (#16602)" (#16697)
ZhennanQin Nov 1, 2019
0198d80
[Estimator] refactor estimator and clarify docs (#16694)
szha Nov 1, 2019
d104745
Eliminate common expressions (#15657)
ptrendx Nov 1, 2019
b1aba6a
Backport of #16711, #16737, #16408 to 1.6 branch (#16763)
ptrendx Nov 8, 2019
0b833a2
Add example and documentation for multi threaded inference
anirudh2290 Nov 14, 2019
23a02e5
merge changes
anirudh2290 Nov 14, 2019
4435f7d
merge changes
anirudh2290 Nov 14, 2019
453f4e5
Add LICENSE
anirudh2290 Nov 14, 2019
7e5d3ad
Add get_model.py
anirudh2290 Nov 14, 2019
c6ae1b8
Add license for README
anirudh2290 Nov 14, 2019
84e2ef3
Refactor cached op and cached op threadsafe
anirudh2290 Nov 20, 2019
b8270ac
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
anirudh2290 Nov 20, 2019
6bde360
Add limitation
anirudh2290 Nov 20, 2019
13074e2
Add tests for naive engine
anirudh2290 Dec 3, 2019
3e17e5f
Merge changes
anirudh2290 Dec 4, 2019
2ec6adb
Add latest test changes
anirudh2290 Dec 7, 2019
b96a603
Thread Safety tests in NaiveEngine mode
anirudh2290 Jan 6, 2020
b088f9c
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
anirudh2290 Jan 6, 2020
978b762
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
anirudh2290 Jan 6, 2020
c45ab34
Thread Safety tests update
anirudh2290 Jan 6, 2020
3204fc3
Update thread safety tests, add unsupported use cases
anirudh2290 Jan 7, 2020
a76e7c5
Changes to doc and refactor
anirudh2290 Jan 10, 2020
b3f5e7e
Fix todo owner, indentation and mx_float->float
anirudh2290 Jan 13, 2020
4ccfbd5
Refactor cached op code, remove num_threads arg from example
anirudh2290 Jan 14, 2020
2898766
Merge jenkins ci groovy file
anirudh2290 Jan 14, 2020
8c4e917
Fix lint
anirudh2290 Jan 14, 2020
dfd8a9e
Fix warning
anirudh2290 Jan 14, 2020
79d1878
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
anirudh2290 Jan 15, 2020
cd64b33
Add back cython, required for unix-gpu build
anirudh2290 Jan 15, 2020
3345d8f
Fix for windows
anirudh2290 Jan 15, 2020
0b91267
Add bulking support for thread safe cached op version
anirudh2290 Jan 16, 2020
edd4fdf
Add support for subgraph testing
anirudh2290 Jan 16, 2020
8e7e085
import mxnet before calling get_backend_symbol
anirudh2290 Jan 16, 2020
800847d
Fix symbol json name
anirudh2290 Jan 16, 2020
36ae782
Refactor DynamicForward
anirudh2290 Jan 18, 2020
d942e0c
Add comments
anirudh2290 Jan 21, 2020
2524a24
Add DMLC_ATTRIBUTE_UNUSED
anirudh2290 Jan 21, 2020
231f7b1
Fix use_naive_run issue
anirudh2290 Jan 22, 2020
a663063
Fix lint
anirudh2290 Jan 22, 2020
ad90150
Revert unittest_cpp to old test since it doesnt test thread safety
anirudh2290 Jan 22, 2020
662ab93
Fix doc
anirudh2290 Jan 23, 2020
7288128
Merge
anirudh2290 Jan 31, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ if(USE_MKLDNN)
list(APPEND mxnet_LINKER_LIBS dnnl)
endif()

if(USE_CPP_PACKAGE)
add_definitions(-DMXNET_USE_CPP_PACKAGE=1)
endif()

# Allow Cuda compiles outside of src tree to find things in 'src' and 'include'
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
Expand Down Expand Up @@ -825,7 +829,6 @@ if(MSVC AND USE_MXNET_LIB_NAMING)
set_target_properties(mxnet PROPERTIES OUTPUT_NAME "libmxnet")
endif()

add_subdirectory(tests)
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved

include(GNUInstallDirs)
install(TARGETS ${MXNET_INSTALL_TARGETS}
Expand Down Expand Up @@ -887,6 +890,7 @@ endif()
if(BUILD_CPP_EXAMPLES)
add_subdirectory(example/image-classification/predict-cpp)
endif()
add_subdirectory(tests)

# ---[ Linter target
if(MSVC)
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ $(BIN) :
# CPP Package
ifeq ($(USE_CPP_PACKAGE), 1)
include cpp-package/cpp-package.mk
CFLAGS += -DMXNET_USE_CPP_PACKAGE=1
endif

include mkldnn.mk
Expand Down
34 changes: 34 additions & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,25 @@ build_ubuntu_gpu_cuda101_cudnn7() {
CUDA_ARCH="$CI_CUDA_COMPUTE_CAPABILITIES" \
USE_SIGNAL_HANDLER=1 \
-j$(nproc)
}

build_ubuntu_gpu_cuda101_cudnn7_mkldnn_cpp_test() {
set -ex
build_ccache_wrappers
make \
DEV=1 \
USE_BLAS=openblas \
USE_MKLDNN=1 \
USE_CUDA=1 \
USE_CUDA_PATH=/usr/local/cuda \
USE_CUDNN=1 \
USE_TVM_OP=0 \
USE_CPP_PACKAGE=1 \
USE_DIST_KVSTORE=1 \
CUDA_ARCH="$CI_CUDA_COMPUTE_CAPABILITIES" \
USE_SIGNAL_HANDLER=1 \
-j$(nproc)
make test USE_CPP_PACKAGE=1 -j$(nproc)
make cython PYTHON=python2
make cython PYTHON=python3
}
Expand Down Expand Up @@ -1212,6 +1230,8 @@ unittest_ubuntu_cpugpu_perl() {

unittest_cpp() {
set -ex
export PYTHONPATH=./python/
python3 -c "import mxnet as mx; mx.test_utils.download_model(\"imagenet1k-resnet-18\"); mx.test_utils.download_model(\"imagenet1k-resnet-152\"); mx.test_utils.download_model(\"imagenet1k-resnet-50\");"
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved
build/tests/mxnet_unit_tests
}

Expand Down Expand Up @@ -1366,6 +1386,20 @@ integrationtest_ubuntu_gpu_cpp_package() {
cpp-package/tests/ci_test.sh
}

integrationtest_ubuntu_gpu_capi_cpp_package() {
set -ex
export PYTHONPATH=./python/
export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
python3 -c "import mxnet as mx; mx.test_utils.download_model(\"imagenet1k-resnet-18\"); mx.test_utils.download_model(\"imagenet1k-resnet-152\"); mx.test_utils.download_model(\"imagenet1k-resnet-50\");"
build/tests/cpp/mxnet_unit_tests --gtest_filter="ThreadSafety.*"
build/tests/cpp/mxnet_unit_tests --gtest_filter="ThreadSafety.*" --thread-safety-with-cpu
# Also run thread safety tests in NaiveEngine mode
export MXNET_ENGINE_TYPE=NaiveEngine
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved
build/tests/cpp/mxnet_unit_tests --gtest_filter="ThreadSafety.*"
build/tests/cpp/mxnet_unit_tests --gtest_filter="ThreadSafety.*" --thread-safety-with-cpu
unset MXNET_ENGINE_TYPE
}

integrationtest_ubuntu_cpu_dist_kvstore() {
set -ex
pushd .
Expand Down
29 changes: 29 additions & 0 deletions ci/jenkins/Jenkins_steps.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/l
mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
mx_lib_cpp_capi = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libmkldnn.so.1, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so, build/tests/cpp/mxnet_unit_tests'
mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/cpp-package/example/*'

Expand Down Expand Up @@ -261,6 +262,20 @@ def compile_unix_full_gpu() {
}]
}

def compile_unix_full_gpu_mkldnn_cpp_test() {
return ['GPU: CUDA10.1+cuDNN7+MKLDNN+CPPTEST': {
node(NODE_LINUX_CPU) {
ws('workspace/build-gpu-mkldnn-cpp') {
timeout(time: max_time, unit: 'MINUTES') {
utils.init_git()
utils.docker_run('ubuntu_build_cuda', 'build_ubuntu_gpu_cuda101_cudnn7_mkldnn_cpp_test', false)
utils.pack_lib('gpu_mkldnn_cpp_test', mx_lib_cpp_capi)
}
}
}
}]
}

def compile_unix_full_gpu_no_tvm_op() {
return ['GPU: CUDA10.1+cuDNN7 TVM_OP OFF': {
node(NODE_LINUX_CPU) {
Expand Down Expand Up @@ -1010,6 +1025,20 @@ def test_unix_cpp_package_gpu() {
}]
}

def test_unix_capi_cpp_package() {
return ['capi-cpp-package GPU': {
node(NODE_LINUX_GPU) {
ws('workspace/it-capi-cpp-package') {
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init('gpu_mkldnn_cpp_test', mx_lib_cpp_capi)
utils.docker_run('ubuntu_gpu_cu101', 'integrationtest_ubuntu_gpu_capi_cpp_package', true)
utils.publish_test_coverage()
}
}
}
}]
}

def test_unix_scala_cpu() {
return ['Scala: CPU': {
node(NODE_LINUX_CPU) {
Expand Down
2 changes: 2 additions & 0 deletions ci/jenkins/Jenkinsfile_unix_gpu
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ core_logic: {
custom_steps.compile_unix_int64_gpu(),
custom_steps.compile_unix_full_gpu_no_tvm_op(),
custom_steps.compile_unix_cmake_gpu_no_tvm_op(),
custom_steps.compile_unix_full_gpu_mkldnn_cpp_test()
])

utils.parallel_stage('Tests', [
Expand All @@ -64,6 +65,7 @@ core_logic: {
custom_steps.test_unix_distributed_kvstore_gpu(),
custom_steps.test_static_python_gpu(),
custom_steps.test_unix_python3_gpu_no_tvm_op(),
custom_steps.test_unix_capi_cpp_package(),

// Disabled due to: https://github.com/apache/incubator-mxnet/issues/11407
//custom_steps.test_unix_caffe_gpu()
Expand Down
2 changes: 1 addition & 1 deletion cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ inline NDArray::NDArray(const mx_float *data, const Shape &shape,
CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(),
context.GetDeviceId(), false, &handle),
0);
MXNDArraySyncCopyFromCPU(handle, data, shape.Size());
CHECK_EQ(MXNDArraySyncCopyFromCPU(handle, data, shape.Size()), 0);
blob_ptr_ = std::make_shared<NDBlob>(handle);
}
inline NDArray::NDArray(const std::vector<mx_float> &data, const Shape &shape,
Expand Down
2 changes: 2 additions & 0 deletions cpp-package/include/mxnet-cpp/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ class Symbol {
*unnamed (empty string).
*/
std::vector<std::string> ListArguments() const;
/*! \return lists all argument names and aux states of the symbol */
std::vector<std::string> ListInputs() const;
/*! \return get the descriptions of outputs for this symbol */
std::vector<std::string> ListOutputs() const;
/*! \return get the descriptions of auxiliary data for this symbol */
Expand Down
12 changes: 12 additions & 0 deletions cpp-package/include/mxnet-cpp/symbol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ inline std::vector<std::string> Symbol::ListArguments() const {
}
return ret;
}

inline std::vector<std::string> Symbol::ListInputs() const {
std::vector<std::string> ret;
mx_uint size;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctypes.c_uint instead of typedef

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you elaborate ?

const char **sarr;
NNSymbolListInputNames(GetHandle(), 0, &size, &sarr);
for (mx_uint i = 0; i < size; ++i) {
ret.push_back(std::string(sarr[i]));
}
return ret;
}

inline std::vector<std::string> Symbol::ListOutputs() const {
std::vector<std::string> ret;
mx_uint size;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
--
layout: page_api
title: Multi Threaded Inference
action: Get Started
action_url: /get_started
permalink: /api/cpp/docs/tutorials/multi_threaded_inference
is_tutorial: true
tag: cpp
--
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

## Multi Threaded Inference API

A long standing request from MXNet users has been to invoke parallel inference on a model from multiple threads while sharing the parameters.
With this use case in mind, the threadsafe version of CachedOp was added to provide a way for customers to do multi-threaded inference for MXNet users.
This doc attempts to do the following:
1. Discuss the current state of thread safety in MXNet
2. Explain how one can use C API and thread safe version of cached op, along with CPP package to achieve iultithreaded inference. This will be useful for end users as well as frontend developers of different language bindings
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved
3. Discuss the limitations of the above approach
4. Future Work

## Current state of Thread Safety in MXNet

Examining the current state of thread safety in MXNet we can arrive to the following conclusion:

1. MXNet Dependency Engine is thread safe (except for WaitToRead invoked inside a spawned thread. Please see Limitations section)
2. Graph Executor which is Module/Symbolic/C Predict API backend is not thread safe
eric-haibin-lin marked this conversation as resolved.
Show resolved Hide resolved
3. Cached Op (Gluon Backend) is not thread safe

The CachedOpThreadSafe and corresponding C APIs were added to address point 3 above and provide a way
for MXNet users to do multi-threaded inference.

```
/*!
* \brief create cached operator, allows to choose thread_safe version
* of cachedop
*/
MXNET_DLL int MXCreateCachedOpEX(SymbolHandle handle,
int num_flags,
const char** keys,
const char** vals,
CachedOpHandle *out,
bool thread_safe DEFAULT(false));
```

## Multithreaded inference in MXNet with C API and CPP Package

### Prerequisites
To complete this tutorial you need to:
- Learn the basics about [MXNet C++ API](/api/cpp)
- Build MXNet from source with make/cmake
- Build the multi-threaded inference example

### Setup the MXNet C++ API
To use the C++ API in MXNet, you need to build MXNet from source with C++ package. Please follow the [built from source guide](/get_started/ubuntu_setup.html), and [C++ Package documentation](/api/cpp)
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved
The summary of those two documents is that you need to build MXNet from source with `USE_CPP_PACKAGE` flag set to 1. For example: `make -j USE_CPP_PACKAGE=1 USE_CUDA=1 USE_CUDNN=1`.
This example requires a build with CUDA and CUDNN.

### Build the example
If you have built mxnet from source with make, then do the following:

```bash
$ cd example/multi_threaded_inference
$ make
```

If you have built mxnet from source with cmake, please uncomment the specific lines for cmake build or set the following environment variables: `MKLDNN_BUILD_DIR (default is $(MXNET_ROOT)/3rdparty/mkldnn/build)`, `MKLDNN_INCLUDE_DIR (default is $(MXNET_ROOT)/3rdparty/mkldnn/include)`, `MXNET_LIB_DIR (default is $(MXNET_ROOT)/lib)`.

### Download the model and run multi threaded inference example
To download a model use the `get_model.py` script. This downloads a model to run inference.

```python
python3 get_model.py --model <model_name>
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved
```
e.g.
```python
python3 get_model.py --model imagenet1k-inception-bn
```
Only the supported models with `get_model.py` work with multi threaded inference.
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved

To run the multi threaded inference example:

First export `LD_LIBRARY_PATH`:

```bash
$ export LD_LIBRARY_PATH=<MXNET_LIB_DIR>:$LD_LIBRARY_PATH
```

```bash
$ ./multi_threaded_inference [model_name] [num_threads] [is_gpu] [file_names]
```
e.g.

```bash
./multi_threaded_inference imagenet1k-inception-bn 2 1 grace_hopper.jpg dog.jpg
```

The above script spawns 2 threads, shares the same cachedop and params among two threads, and runs inference on GPU. It returns the inference results in the order in which files are provided.

NOTE: This example is to demonstrate the multi-threaded-inference with cached op. The inference results work well only with specific models (e.g. imagenet1k-inception-bn). The results may not necessarily be very accurate because of different preprocessing step required etc.
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved

### Code walkthrough multi-threaded inference with CachedOp

The multi threaded inference example (`multi_threaded_inference.cc`) involves the following steps:

1. Parse arguments and load input image into ndarray
2. Prepare input data and load parameters, copying data to a specific context
3. Preparing arguments to pass to the CachedOp and calling C API to **create cached op**
4. Prepare lambda function which will run in spawned threads. Call C API to **invoke cached op** within the lambda function.
5. Spawn multiple threads and wait for all threads to complete.
6. Post process data to obtain inference results and cleanup.

### Step 1: Parse arguments and load input image into ndarray

[https://github.com/apache/incubator-mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L299-L341](multi_threaded_inference.cc#L299-L341)

The above code parses arguments, loads the image file into a ndarray with a specific shape. There are a few things that are set by default and not configurable. For example, `static_alloc` and `static_shape` are by default set to true.


### Step 2: Prepare input data and load parameters, copying data to a specific context

[https://github.com/apache/incubator-mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L147-L205](multi_threaded_inference.cc#L147-L205)

The above code loads params and copies input data and params to specific context.

### Step 3: Preparing arguments to pass to the CachedOp and calling C API to create cached op

[https://github.com/apache/incubator-mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L207-L233](multi_threaded_inference.cc#L207-233)

The above code prepares `flag_key_cstrs` and `flag_val_cstrs` to be passed the Cached op.
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved
The C API call is made with `MXCreateCachedOpEX`. This will lead to creation of thread safe cached
op since the `thread_safe` (which is the last parameter to `MXCreateCachedOpEX`) is set to
true. When this is set to false, it will invoke CachedOp instead of CachedOpThreadSafe.


### Step 4: Prepare lambda function which will run in spawned threads

[https://github.com/apache/incubator-mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L248-L262](multi_threaded_inference.cc#L248-262)

The above creates the lambda function taking the thread number as the argument.
If `random_sleep` is set it will sleep for a random number (secs) generated between 0 to 5 seconds.
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved
Following this, it invokes `MXInvokeCachedOpEx`(from the hdl it determines whether to invoke cached op threadsafe version or not).
When this is set to false, it will invoke CachedOp instead of CachedOpThreadSafe.

### Step 5: Spawn multiple threads and wait for all threads to complete

[https://github.com/anirudh2290/apache/incubator-mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L264-L276](multi_threaded_inference.cc#L264-L276)

Spawns multiple threads, joins and waits to wait for all ops to complete.
anirudh2290 marked this conversation as resolved.
Show resolved Hide resolved
The other alternative is to wait in the thread on the output ndarray and remove the WaitAll after join.

### Step 6: Post process data to obtain inference results and cleanup

[https://github.com/apache/incubator-/mxnet/example/multi_threaded_inference/multi_threaded_inference.cc#L286-L293](multi_threaded_inference.cc#L286-293)

The above code outputs results for different threads and cleans up the thread safe cached op.

## Current Limitations

1. Only operators tested with the existing model coverage are supported. Other operators and operator types (stateful operators, custom operators are not supported. Existing model coverage is as follows (this list will keep growing as we test more models with different model types):

|Models Tested|MKLDNN|CUDNN|NO-CUDNN|
| --- | --- | --- | --- |
| imagenet1k-resnet-18 | Yes | Yes | Yes |
| imagenet1k-resnet-152 | Yes | Yes | Yes |
| imagenet1k-resnet-50 | Yes | Yes | Yes |

2. Only dense storage types are supported currently.
3. Multi GPU Inference not supported currently.
4. Instantiating multiple instances of SymbolBlockThreadSafe is not supported. Can run parallel inference only on one model per process.
5. dynamic shapes not supported in thread safe cached op.
6. Bulking of ops is not supported.
7. This only supports inference use cases currently, training use cases are not supported.
8. Graph rewrites with subgraph API currently not supported.
9. There is currently no frontend API support to run multi threaded inference. Users can use CreateCachedOpEX and InvokeCachedOp in combination with
the CPP frontend to run multi-threaded inference as of today.
10. Multi threaded inference with threaded engine with Module/Symbolic API and C Predict API are not currently supported.
11. Exception thrown with `wait_to_read` in individual threads can cause issues. Calling invoke from each thread and calling WaitAll after thread joins should still work fine.
12. Tested only on environments supported by CI. This means that MacOS is not supported.

## Future Work

Future work includes Increasing model coverage and addressing most of the limitations mentioned under Current Limitations except the training use case.
For more updates, please subscribe to discussion activity on RFC: https://github.com/apache/incubator-mxnet/issues/16431.
Loading