This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-703] Update to TensorRT 5, ONNX IR 3. Fix inference bugs. #13310

Merged: 3 commits, Jan 16, 2019
Changes from all commits
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -187,6 +187,7 @@ if(USE_TENSORRT)
include_directories(${ONNX_PATH})
include_directories(3rdparty/onnx-tensorrt/)
include_directories(3rdparty/)
include_directories(3rdparty/onnx-tensorrt/third_party/onnx/)
add_definitions(-DMXNET_USE_TENSORRT=1)
add_definitions(-DONNX_NAMESPACE=onnx)

2 changes: 1 addition & 1 deletion ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -18,7 +18,7 @@
#
# Dockerfile to run MXNet on Ubuntu 16.04 for CPU

FROM nvidia/cuda:9.0-cudnn7-devel
FROM nvidia/cuda:10.0-cudnn7-devel

WORKDIR /work/deps

4 changes: 2 additions & 2 deletions ci/docker/install/tensorrt.sh
@@ -26,7 +26,7 @@ pip3 install gluoncv==0.2.0
pushd .
cd ..
apt-get update
apt-get install -y automake libtool
apt-get install -y automake libtool zip
git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git
cd protobuf
./autogen.sh
@@ -41,7 +41,7 @@ popd

# Install TensorRT
echo "TensorRT build enabled. Installing TensorRT."
wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0_1-1_amd64.deb
wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda10.0_1-1_amd64.deb
dpkg -i tensorrt.deb
apt-get update
apt-get install -y --allow-downgrades libnvinfer-dev
19 changes: 12 additions & 7 deletions ci/docker/install/ubuntu_core.sh
@@ -22,6 +22,10 @@

set -ex
apt-get update || true

# Avoid interactive package installers such as tzdata.
export DEBIAN_FRONTEND=noninteractive

apt-get install -y \
apt-transport-https \
build-essential \
@@ -41,10 +45,11 @@ apt-get install -y \
unzip \
wget


# Ubuntu 14.04
if [[ $(lsb_release -r | grep 14.04) ]]; then
apt-get install -y cmake3
else
apt-get install -y cmake
fi
# Note: we specify an exact cmake version to work around a cmake 3.10 CUDA 10 issue.
# Reference: https://github.com/clab/dynet/issues/1457
mkdir /opt/cmake && cd /opt/cmake
wget -nv https://cmake.org/files/v3.12/cmake-3.12.4-Linux-x86_64.sh
sh cmake-3.12.4-Linux-x86_64.sh --prefix=/opt/cmake --skip-license
ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake
rm cmake-3.12.4-Linux-x86_64.sh
cmake --version
34 changes: 17 additions & 17 deletions ci/docker/runtime_functions.sh
@@ -602,23 +602,23 @@ build_ubuntu_gpu_tensorrt() {
cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/
cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/

rm -rf build
make \
DEV=1 \
ENABLE_TESTCOVERAGE=1 \
USE_BLAS=openblas \
USE_CUDA=1 \
USE_CUDA_PATH=/usr/local/cuda \
USE_CUDNN=1 \
USE_OPENCV=0 \
USE_MKLDNN=0 \
USE_DIST_KVSTORE=0 \
USE_TENSORRT=1 \
USE_JEMALLOC=0 \
USE_GPERFTOOLS=0 \
ONNX_NAMESPACE=onnx \
CUDA_ARCH="-gencode arch=compute_70,code=compute_70" \
-j$(nproc)
cd /work/build
Contributor: Can you use a signal handler to get stack traces?

Contributor Author: Is that an extra build flag? What would I have to add?
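
For context, a minimal generic sketch of how a signal handler can dump a stack trace on a crash (a POSIX/glibc illustration only, not MXNet's actual build flag or mechanism):

#include <execinfo.h>
#include <signal.h>
#include <unistd.h>

// On a fatal signal, write the current call stack to stderr, then re-raise the
// signal so the process still terminates with the original status.
// (Build with -g -rdynamic so frames resolve to symbol names.)
static void crash_handler(int sig) {
  void* frames[64];
  const int depth = backtrace(frames, 64);
  backtrace_symbols_fd(frames, depth, STDERR_FILENO);
  signal(sig, SIG_DFL);
  raise(sig);
}

int main() {
  signal(SIGSEGV, crash_handler);  // install for segmentation faults
  volatile int* p = nullptr;
  return *p;  // deliberate crash: the handler prints a backtrace first
}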

cmake -DUSE_CUDA=1 \
-DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
-DCMAKE_C_COMPILER_LAUNCHER=ccache \
-DUSE_CUDNN=1 \
-DUSE_OPENCV=1 \
-DUSE_TENSORRT=1 \
-DUSE_OPENMP=0 \
-DUSE_MKLDNN=0 \
-DUSE_MKL_IF_AVAILABLE=OFF \
-DENABLE_TESTCOVERAGE=ON \
-DCUDA_ARCH_NAME=Manual \
-DCUDA_ARCH_BIN=$CI_CMAKE_CUDA_ARCH_BIN \
-G Ninja \
/work/mxnet

ninja -v
}

build_ubuntu_gpu_mkldnn() {
2 changes: 1 addition & 1 deletion ci/jenkins/Jenkins_steps.groovy
@@ -34,7 +34,7 @@ mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/li
mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0'
mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_tensorrt_lib = 'lib/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
mx_tensorrt_lib = 'build/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*'
mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/cpp-package/example/*'

2 changes: 1 addition & 1 deletion src/executor/onnx_to_tensorrt.cc
@@ -28,7 +28,7 @@

#include "./onnx_to_tensorrt.h"

#include <onnx/onnx.pb.h>
#include <onnx/onnx_pb.h>

#include <NvInfer.h>
#include <google/protobuf/io/coded_stream.h>
2 changes: 1 addition & 1 deletion src/executor/tensorrt_pass.cc
@@ -31,7 +31,7 @@
#include <mxnet/op_attr_types.h>
#include <mxnet/operator.h>
#include <nnvm/graph_attr_types.h>
#include <onnx/onnx.pb.h>
#include <onnx/onnx_pb.h>

#include "../operator/contrib/nnvm_to_onnx-inl.h"
#include "./exec_pass.h"
2 changes: 1 addition & 1 deletion src/executor/trt_graph_executor.cc
@@ -21,7 +21,7 @@

#include "trt_graph_executor.h"

#include <onnx/onnx.pb.h>
#include <onnx/onnx_pb.h>
#include <NvInfer.h>
#include "./onnx_to_tensorrt.h"
#include "../operator/contrib/tensorrt-inl.h"
2 changes: 1 addition & 1 deletion src/operator/contrib/nnvm_to_onnx-inl.h
@@ -37,7 +37,7 @@
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>

#include <onnx/onnx.pb.h>
#include <onnx/onnx_pb.h>

#include <algorithm>
#include <iostream>
138 changes: 73 additions & 65 deletions src/operator/contrib/nnvm_to_onnx.cc
@@ -62,15 +62,22 @@ namespace nnvm_to_onnx {
op::ONNXParam ConvertNnvmGraphToOnnx(
const nnvm::Graph& g,
std::unordered_map<std::string, NDArray>* const shared_buffer) {
op::ONNXParam onnx_param;
op::nnvm_to_onnx::NameToIdx_t onnx_input_map;
op::nnvm_to_onnx::InferenceMap_t onnx_output_map;

static std::atomic_ulong subgraph_count = { 0 };

op::ONNXParam onnx_param;
op::nnvm_to_onnx::NameToIdx_t onnx_input_map;
op::nnvm_to_onnx::InferenceMap_t onnx_output_map;

const nnvm::IndexedGraph& ig = g.indexed_graph();
const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shape_inputs = g.GetAttr<ShapeVector>("shape_inputs");

// TODO(kellens): At the moment this check always passes no matter the weight dtypes used in your
// graph. We should first iterate over datatypes by name and ensure they're valid types
// (fp16 or fp32) and that they're uniform. Then ensure later conversions set tensor types
// correctly in ONNX.
for (auto& e : storage_types) {
if (e != mshadow::kFloat32) {
LOG(FATAL) << "ONNX converter does not support types other than float32 "
@@ -79,9 +86,23 @@ op::ONNXParam ConvertNnvmGraphToOnnx(
}
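
As a rough sketch of the check that TODO describes (illustration only, not code from this PR; the standalone dtype codes mirror mshadow's convention and are an assumption here):

#include <stdexcept>
#include <vector>

// Verify that every dtype in the graph is float32 or float16 and that all
// entries agree, rejecting anything else before attempting ONNX serialization.
// Integer codes follow mshadow's convention (0 = float32, 2 = float16).
void CheckUniformFloatDtypes(const std::vector<int>& dtypes) {
  constexpr int kFloat32 = 0;
  constexpr int kFloat16 = 2;
  int expected = -1;
  for (const int t : dtypes) {
    if (t != kFloat32 && t != kFloat16) {
      throw std::runtime_error("ONNX conversion supports only float32/float16 tensors");
    }
    if (expected == -1) expected = t;
    if (t != expected) {
      throw std::runtime_error("ONNX conversion requires a uniform tensor dtype");
    }
  }
}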

ModelProto model_proto;
// Need to determine IR versions and features to support
model_proto.set_ir_version(static_cast<int64>(2));

// We're currently serializing our models in ONNX 3, opset 8 as it is best supported by the
// currently linked version of the onnx-tensorrt library.
// More information on ONNX versions and opsets can be found at:
// https://github.com/onnx/onnx/blob/master/docs/IR.md

auto opset_proto = model_proto.add_opset_import();
Contributor: Is copy OK?

Contributor Author: Yeah, I believe so. Proto editing in C++ is a little strange, but I've seen this pattern in several places. I've run model validation, and it certainly failed for me right away if I did not have a properly set opset proto associated with my model proto for ONNX v3 models.
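
For illustration, a minimal sketch of that version bookkeeping (not code from this PR; it assumes the generated ONNX protobuf headers used elsewhere in this change, and onnx/checker.h for the validation step):

#include <onnx/onnx_pb.h>

// Stamp the IR version and a default-domain opset onto a model proto.
void StampOnnxVersions(onnx::ModelProto* model_proto) {
  onnx::OperatorSetIdProto* opset_proto = model_proto->add_opset_import();
  opset_proto->set_domain("");     // "" selects the default ai.onnx operator set
  opset_proto->set_version(8);     // opset 8, as serialized by this PR
  model_proto->set_ir_version(3);  // ONNX IR version 3
}

// After the graph itself has been filled in, the model can be validated with
//   onnx::checker::check_model(model_proto);   // from onnx/checker.h, if available
// which throws if the opset_import / ir_version bookkeeping is missing or
// inconsistent with the operators used.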

const int64 onnx_opset = 8;
const int64 onnx_major_version = 3;

// Declare our ONNX versions in our protobuf model.
opset_proto->set_version(onnx_opset);
model_proto.set_ir_version(onnx_major_version);

GraphProto* graph_proto = model_proto.mutable_graph();
auto subgraph_name_id = subgraph_count.fetch_add(1);
graph_proto->set_name("MXNetTRTSubgraph" + std::to_string(subgraph_name_id));

std::unordered_map<std::string, TShape> placeholder_shapes =
GetPlaceholderShapes(shape_inputs, ig);
@@ -176,6 +197,20 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
// const bool no_bias = conv_param.no_bias;
const dmlc::optional<int> layout = conv_param.layout;

// dilations
AttributeProto* const dilations = node_proto->add_attribute();
dilations->set_name("dilations");
dilations->set_type(AttributeProto::INTS);
for (const dim_t kval : dilate) {
dilations->add_ints(static_cast<int64>(kval));
}

// group
AttributeProto* const group = node_proto->add_attribute();
group->set_name("group");
group->set_type(AttributeProto::INT);
group->set_i(static_cast<int64>(num_group));

// kernel shape
AttributeProto* const kernel_shape = node_proto->add_attribute();
kernel_shape->set_name("kernel_shape");
@@ -195,27 +230,13 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
pads->add_ints(static_cast<int64>(kval));
}

// dilations
AttributeProto* const dilations = node_proto->add_attribute();
dilations->set_name("dilations");
dilations->set_type(AttributeProto::INTS);
for (const dim_t kval : dilate) {
dilations->add_ints(static_cast<int64>(kval));
}

// strides
AttributeProto* const strides = node_proto->add_attribute();
strides->set_name("strides");
strides->set_type(AttributeProto::INTS);
for (const dim_t kval : stride) {
strides->add_ints(static_cast<int64>(kval));
}

// group
AttributeProto* const group = node_proto->add_attribute();
group->set_name("group");
group->set_type(AttributeProto::INT);
group->set_i(static_cast<int64>(num_group));
} // end ConvertConvolution

void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
@@ -250,8 +271,12 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
AttributeProto* const pads = node_proto->add_attribute();
pads->set_name("pads");
pads->set_type(AttributeProto::INTS);
for (int kval : pad) {
pads->add_ints(static_cast<int64>(kval));

// Convert from MXNet symmetric pads to ONNX non-symmetric by running through padding twice.
for (int i = 0; i < 2; i++) {
Contributor: This was already done by Roshani if I remember it correctly. Are we duplicating code here?

From my understanding, we would only need ONNX -> TensorRT, since the MXNet-to-ONNX part should be handled by the existing solution. It feels like we are developing two parallel approaches here, but I could be mixing things up.

Contributor Author (@KellenSunderland, Nov 25, 2018): Yeah, I want to have a conversation with them offline; I've already brought this up a few times. The trick here is we need to export to ONNX from native code. This is more or less 100% duplication (minus tests). I'd be tempted to politely suggest we migrate all ONNX stuff to native code and core functionality so that, for example, I can export to ONNX and Java users running inference can import from ONNX. I think this feature is important enough that it shouldn't be Python-only.

Contributor: Agree. Shall we wait with all that until after 1.4, or merge this PR to upgrade TensorRT and do the refactor afterwards?

Contributor Author: Wait until after 1.4 for sure; it's a big change that will require a lot of discussion.

Contributor Author: Spoke to Roshani about this, and we agreed that we'd keep both implementations for the time being. We think the support level is better in Python, so that will be the main place to do imports and exports. This implementation will try to duplicate their export logic, but it will only include the operators that TRT can make use of.

Contributor: Sounds good, thanks for checking back :)

for (dim_t kval : pad) {
pads->add_ints(static_cast<int64>(kval));
}
}
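
For illustration (not part of this diff): MXNet stores one symmetric pad value per spatial axis, while ONNX expects begin pads followed by end pads, so the double pass turns an MXNet pad of (1, 2) into ONNX pads = [1, 2, 1, 2]. A standalone equivalent of the loop above:

#include <cstdint>
#include <vector>

// Duplicate MXNet's symmetric per-axis padding into ONNX's
// [x1_begin, x2_begin, ..., x1_end, x2_end, ...] layout.
std::vector<int64_t> SymmetricToOnnxPads(const std::vector<int64_t>& pad) {
  std::vector<int64_t> pads;
  for (int i = 0; i < 2; ++i) {  // first pass emits begin pads, second emits end pads
    pads.insert(pads.end(), pad.begin(), pad.end());
  }
  return pads;
}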

// strides
@@ -315,11 +340,6 @@ void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs,
beta->set_type(AttributeProto::FLOAT);
beta->set_f(1.0f);

AttributeProto* const broadcast = node_proto->add_attribute();
broadcast->set_name("broadcast");
broadcast->set_type(AttributeProto::INT);
broadcast->set_i(1);

AttributeProto* const transA = node_proto->add_attribute();
transA->set_name("transA");
transA->set_type(AttributeProto::INT);
@@ -371,11 +391,6 @@ void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
epsilon->set_type(AttributeProto::FLOAT);
epsilon->set_f(static_cast<float>(param.eps));

AttributeProto* const is_test = node_proto->add_attribute();
is_test->set_name("is_test");
is_test->set_type(AttributeProto::INT);
is_test->set_i(1);

AttributeProto* const momentum = node_proto->add_attribute();
momentum->set_name("momentum");
momentum->set_type(AttributeProto::FLOAT);
@@ -384,31 +399,16 @@ void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
AttributeProto* const spatial = node_proto->add_attribute();
spatial->set_name("spatial");
spatial->set_type(AttributeProto::INT);
spatial->set_i(1);

AttributeProto* const consumed = node_proto->add_attribute();
consumed->set_name("consumed_inputs");
consumed->set_type(AttributeProto::INTS);

for (int i = 0; i < 5; i++) {
int val = (i < 3) ? 0 : 1;
consumed->add_ints(static_cast<int64>(val));
}
// MXNet computes mean and variance per feature for batchnorm. Enabling spatial mode
// (default in ONNX3) implies running batchnorm on all spatial features so we need to explicitly
// disable this for MXNet's BatchNorm.
spatial->set_i(0);
KellenSunderland marked this conversation as resolved.
}

void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
const nnvm::IndexedGraph& /*ig*/,
const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
node_proto->set_op_type("Add");
AttributeProto* const axis = node_proto->add_attribute();
axis->set_name("axis");
axis->set_type(AttributeProto::INT);
axis->set_i(1);

AttributeProto* const broadcast = node_proto->add_attribute();
broadcast->set_name("broadcast");
broadcast->set_type(AttributeProto::INT);
broadcast->set_i(0); // 1
}

std::unordered_map<std::string, TShape> GetPlaceholderShapes(
@@ -461,32 +461,40 @@ void ConvertPlaceholder(
void ConvertConstant(
GraphProto* const graph_proto, const std::string& node_name,
std::unordered_map<std::string, NDArray>* const shared_buffer) {
NodeProto* const node_proto = graph_proto->add_node();
node_proto->set_name(node_name);
node_proto->add_output(node_name);
node_proto->set_op_type("Constant");
TensorProto* const initializer_proto = graph_proto->add_initializer();

// Create initializer for constants
initializer_proto->set_name(node_name);
// TODO(kellens): convert to fp16 if needed.
initializer_proto->set_data_type(TensorProto_DataType_FLOAT);

const NDArray nd = shared_buffer->find(node_name)->second;
const TBlob& blob = nd.data();
const TShape shape = blob.shape_;
const int32_t size = shape.Size();

for (auto& dim : shape) {
initializer_proto->add_dims(static_cast<int64>(dim));
}

auto size = shape.Size();
// TODO(kellens): Note hard coded float32 size assumed.
std::shared_ptr<float> shared_data_ptr(new float[size]);
float* const data_ptr = shared_data_ptr.get();
nd.SyncCopyToCPU(static_cast<void*>(data_ptr), size);

AttributeProto* const tensor_attr = node_proto->add_attribute();
tensor_attr->set_name("value");
tensor_attr->set_type(AttributeProto::TENSOR);

TensorProto* const tensor_proto = tensor_attr->mutable_t();
tensor_proto->set_data_type(TensorProto_DataType_FLOAT);
for (auto& dim : shape) {
tensor_proto->add_dims(static_cast<int64>(dim));
for (size_t blob_idx = 0; blob_idx < size; ++blob_idx) {
initializer_proto->add_float_data(data_ptr[blob_idx]);
}

for (int blob_idx = 0; blob_idx < size; ++blob_idx) {
tensor_proto->add_float_data(data_ptr[blob_idx]);
// Create inputs for constants.
ValueInfoProto* const input_proto = graph_proto->add_input();
input_proto->set_name(node_name);

// TODO(kellens): (fp16 support)
input_proto->mutable_type()->mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
for (auto& dim : shape) {
auto new_dim = input_proto->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim();
new_dim->set_dim_value(static_cast<int64>(dim));
}
}
