diff --git a/3rdparty/onnx-tensorrt b/3rdparty/onnx-tensorrt index 3d8ee049970e..f1c7aa63d88d 160000 --- a/3rdparty/onnx-tensorrt +++ b/3rdparty/onnx-tensorrt @@ -1 +1 @@ -Subproject commit 3d8ee049970e81ff4935cc7f36b653c0b27bcbbc +Subproject commit f1c7aa63d88d8d8ef70490f2ebb6b33f7450218b diff --git a/CMakeLists.txt b/CMakeLists.txt index 3e3de2053477..23609e5ec243 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -187,6 +187,7 @@ if(USE_TENSORRT) include_directories(${ONNX_PATH}) include_directories(3rdparty/onnx-tensorrt/) include_directories(3rdparty/) + include_directories(3rdparty/onnx-tensorrt/third_party/onnx/) add_definitions(-DMXNET_USE_TENSORRT=1) add_definitions(-DONNX_NAMESPACE=onnx) diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt index 255da316041f..f4844115c0fd 100644 --- a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt +++ b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt @@ -18,7 +18,7 @@ # # Dockerfile to run MXNet on Ubuntu 16.04 for CPU -FROM nvidia/cuda:9.0-cudnn7-devel +FROM nvidia/cuda:10.0-cudnn7-devel WORKDIR /work/deps diff --git a/ci/docker/install/tensorrt.sh b/ci/docker/install/tensorrt.sh index 61e73ef9a62f..1950cad0b52f 100755 --- a/ci/docker/install/tensorrt.sh +++ b/ci/docker/install/tensorrt.sh @@ -26,7 +26,7 @@ pip3 install gluoncv==0.2.0 pushd . cd .. apt-get update -apt-get install -y automake libtool +apt-get install -y automake libtool zip git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git cd protobuf ./autogen.sh @@ -41,7 +41,7 @@ popd # Install TensorRT echo "TensorRT build enabled. Installing TensorRT." -wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0_1-1_amd64.deb +wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda10.0_1-1_amd64.deb dpkg -i tensorrt.deb apt-get update apt-get install -y --allow-downgrades libnvinfer-dev diff --git a/ci/docker/install/ubuntu_core.sh b/ci/docker/install/ubuntu_core.sh index 4382aa6aefd0..fc903e5c8899 100755 --- a/ci/docker/install/ubuntu_core.sh +++ b/ci/docker/install/ubuntu_core.sh @@ -22,6 +22,10 @@ set -ex apt-get update || true + +# Avoid interactive package installers such as tzdata. +export DEBIAN_FRONTEND=noninteractive + apt-get install -y \ apt-transport-https \ build-essential \ @@ -41,10 +45,11 @@ apt-get install -y \ unzip \ wget - -# Ubuntu 14.04 -if [[ $(lsb_release -r | grep 14.04) ]]; then - apt-get install -y cmake3 -else - apt-get install -y cmake -fi +# Note: we specify an exact cmake version to work around a cmake 3.10 CUDA 10 issue. +# Reference: https://github.com/clab/dynet/issues/1457 +mkdir /opt/cmake && cd /opt/cmake +wget -nv https://cmake.org/files/v3.12/cmake-3.12.4-Linux-x86_64.sh +sh cmake-3.12.4-Linux-x86_64.sh --prefix=/opt/cmake --skip-license +ln -s /opt/cmake/bin/cmake /usr/local/bin/cmake +rm cmake-3.12.4-Linux-x86_64.sh +cmake --version diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index 6c3f999480d3..787c9d704b42 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -602,23 +602,23 @@ build_ubuntu_gpu_tensorrt() { cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/ cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/ - rm -rf build - make \ - DEV=1 \ - ENABLE_TESTCOVERAGE=1 \ - USE_BLAS=openblas \ - USE_CUDA=1 \ - USE_CUDA_PATH=/usr/local/cuda \ - USE_CUDNN=1 \ - USE_OPENCV=0 \ - USE_MKLDNN=0 \ - USE_DIST_KVSTORE=0 \ - USE_TENSORRT=1 \ - USE_JEMALLOC=0 \ - USE_GPERFTOOLS=0 \ - ONNX_NAMESPACE=onnx \ - CUDA_ARCH="-gencode arch=compute_70,code=compute_70" \ - -j$(nproc) + cd /work/build + cmake -DUSE_CUDA=1 \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DUSE_CUDNN=1 \ + -DUSE_OPENCV=1 \ + -DUSE_TENSORRT=1 \ + -DUSE_OPENMP=0 \ + -DUSE_MKLDNN=0 \ + -DUSE_MKL_IF_AVAILABLE=OFF \ + -DENABLE_TESTCOVERAGE=ON \ + -DCUDA_ARCH_NAME=Manual \ + -DCUDA_ARCH_BIN=$CI_CMAKE_CUDA_ARCH_BIN \ + -G Ninja \ + /work/mxnet + + ninja -v } build_ubuntu_gpu_mkldnn() { diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index e812c4e24feb..7e8453968da5 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -34,7 +34,7 @@ mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/li mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0' mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' -mx_tensorrt_lib = 'lib/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' +mx_tensorrt_lib = 'build/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*' mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/cpp-package/example/*' diff --git a/src/executor/onnx_to_tensorrt.cc b/src/executor/onnx_to_tensorrt.cc index e3a4ae868ce2..c37b856f9d62 100644 --- a/src/executor/onnx_to_tensorrt.cc +++ b/src/executor/onnx_to_tensorrt.cc @@ -28,7 +28,7 @@ #include "./onnx_to_tensorrt.h" -#include +#include #include #include diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc index d26704c35cf5..762dc0de9db5 100644 --- a/src/executor/tensorrt_pass.cc +++ b/src/executor/tensorrt_pass.cc @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include "../operator/contrib/nnvm_to_onnx-inl.h" #include "./exec_pass.h" diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc index ec35fee98a96..92bdcab90395 100644 --- a/src/executor/trt_graph_executor.cc +++ b/src/executor/trt_graph_executor.cc @@ -21,7 +21,7 @@ #include "trt_graph_executor.h" -#include +#include #include #include "./onnx_to_tensorrt.h" #include "../operator/contrib/tensorrt-inl.h" diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h index 011ffe6b7ddb..e0c4d9369e6e 100644 --- a/src/operator/contrib/nnvm_to_onnx-inl.h +++ b/src/operator/contrib/nnvm_to_onnx-inl.h @@ -37,7 +37,7 @@ #include #include -#include +#include #include #include diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc index 784384e94e1e..ccb6e04b0a47 100644 --- a/src/operator/contrib/nnvm_to_onnx.cc +++ b/src/operator/contrib/nnvm_to_onnx.cc @@ -62,15 +62,22 @@ namespace nnvm_to_onnx { op::ONNXParam ConvertNnvmGraphToOnnx( const nnvm::Graph& g, std::unordered_map* const shared_buffer) { - op::ONNXParam onnx_param; - op::nnvm_to_onnx::NameToIdx_t onnx_input_map; - op::nnvm_to_onnx::InferenceMap_t onnx_output_map; + + static std::atomic_ulong subgraph_count = { 0 }; + + op::ONNXParam onnx_param; + op::nnvm_to_onnx::NameToIdx_t onnx_input_map; + op::nnvm_to_onnx::InferenceMap_t onnx_output_map; const nnvm::IndexedGraph& ig = g.indexed_graph(); const auto& storage_types = g.GetAttr("storage_type"); const auto& dtypes = g.GetAttr("dtype"); const auto& shape_inputs = g.GetAttr("shape_inputs"); + // TODO(kellens): At the moment this check always passes no matter the weight dtypes used in your + // graph. We should first iterate over datatypes by name and ensure they're valid types + // (fp16 or fp32) and that they're uniform. Then ensure later conversions set tensor types + // correctly in ONNX. for (auto& e : storage_types) { if (e != mshadow::kFloat32) { LOG(FATAL) << "ONNX converter does not support types other than float32 " @@ -79,9 +86,23 @@ op::ONNXParam ConvertNnvmGraphToOnnx( } ModelProto model_proto; - // Need to determine IR versions and features to support - model_proto.set_ir_version(static_cast(2)); + + // We're currently serializing our models in ONNX 3, opset 8 as it is best supported by the + // currently linked version of the onnx-tensorrt library. + // More information on ONNX versions and opsets can be found at: + // https://github.com/onnx/onnx/blob/master/docs/IR.md + + auto opset_proto = model_proto.add_opset_import(); + const int64 onnx_opset = 8; + const int64 onnx_major_version = 3; + + // Declare our ONNX versions in our protobuf model. + opset_proto->set_version(onnx_opset); + model_proto.set_ir_version(onnx_major_version); + GraphProto* graph_proto = model_proto.mutable_graph(); + auto subgraph_name_id = subgraph_count.fetch_add(1); + graph_proto->set_name("MXNetTRTSubgraph" + std::to_string(subgraph_name_id)); std::unordered_map placeholder_shapes = GetPlaceholderShapes(shape_inputs, ig); @@ -176,6 +197,20 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs, // const bool no_bias = conv_param.no_bias; const dmlc::optional layout = conv_param.layout; + // dilations + AttributeProto* const dilations = node_proto->add_attribute(); + dilations->set_name("dilations"); + dilations->set_type(AttributeProto::INTS); + for (const dim_t kval : dilate) { + dilations->add_ints(static_cast(kval)); + } + + // group + AttributeProto* const group = node_proto->add_attribute(); + group->set_name("group"); + group->set_type(AttributeProto::INT); + group->set_i(static_cast(num_group)); + // kernel shape AttributeProto* const kernel_shape = node_proto->add_attribute(); kernel_shape->set_name("kernel_shape"); @@ -195,14 +230,6 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs, pads->add_ints(static_cast(kval)); } - // dilations - AttributeProto* const dilations = node_proto->add_attribute(); - dilations->set_name("dilations"); - dilations->set_type(AttributeProto::INTS); - for (const dim_t kval : dilate) { - dilations->add_ints(static_cast(kval)); - } - // strides AttributeProto* const strides = node_proto->add_attribute(); strides->set_name("strides"); @@ -210,12 +237,6 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs, for (const dim_t kval : stride) { strides->add_ints(static_cast(kval)); } - - // group - AttributeProto* const group = node_proto->add_attribute(); - group->set_name("group"); - group->set_type(AttributeProto::INT); - group->set_i(static_cast(num_group)); } // end ConvertConvolution void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs, @@ -250,8 +271,12 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs, AttributeProto* const pads = node_proto->add_attribute(); pads->set_name("pads"); pads->set_type(AttributeProto::INTS); - for (int kval : pad) { - pads->add_ints(static_cast(kval)); + + // Convert from MXNet symetric pads to ONNX non-symetric by running through padding twice. + for (int i =0; i < 2; i++) { + for (dim_t kval : pad) { + pads->add_ints(static_cast(kval)); + } } // strides @@ -315,11 +340,6 @@ void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs, beta->set_type(AttributeProto::FLOAT); beta->set_f(1.0f); - AttributeProto* const broadcast = node_proto->add_attribute(); - broadcast->set_name("broadcast"); - broadcast->set_type(AttributeProto::INT); - broadcast->set_i(1); - AttributeProto* const transA = node_proto->add_attribute(); transA->set_name("transA"); transA->set_type(AttributeProto::INT); @@ -371,11 +391,6 @@ void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs, epsilon->set_type(AttributeProto::FLOAT); epsilon->set_f(static_cast(param.eps)); - AttributeProto* const is_test = node_proto->add_attribute(); - is_test->set_name("is_test"); - is_test->set_type(AttributeProto::INT); - is_test->set_i(1); - AttributeProto* const momentum = node_proto->add_attribute(); momentum->set_name("momentum"); momentum->set_type(AttributeProto::FLOAT); @@ -384,31 +399,16 @@ void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs, AttributeProto* const spatial = node_proto->add_attribute(); spatial->set_name("spatial"); spatial->set_type(AttributeProto::INT); - spatial->set_i(1); - - AttributeProto* const consumed = node_proto->add_attribute(); - consumed->set_name("consumed_inputs"); - consumed->set_type(AttributeProto::INTS); - - for (int i = 0; i < 5; i++) { - int val = (i < 3) ? 0 : 1; - consumed->add_ints(static_cast(val)); - } + // MXNet computes mean and variance per feature for batchnorm. Enabling spatial mode + // (default in ONNX3) implies running batchnorm on all spatial features so we need to explicitly + // disable this for MXNet's BatchNorm. + spatial->set_i(0); } void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/, const nnvm::IndexedGraph& /*ig*/, const array_view& /*inputs*/) { node_proto->set_op_type("Add"); - AttributeProto* const axis = node_proto->add_attribute(); - axis->set_name("axis"); - axis->set_type(AttributeProto::INT); - axis->set_i(1); - - AttributeProto* const broadcast = node_proto->add_attribute(); - broadcast->set_name("broadcast"); - broadcast->set_type(AttributeProto::INT); - broadcast->set_i(0); // 1 } std::unordered_map GetPlaceholderShapes( @@ -461,32 +461,40 @@ void ConvertPlaceholder( void ConvertConstant( GraphProto* const graph_proto, const std::string& node_name, std::unordered_map* const shared_buffer) { - NodeProto* const node_proto = graph_proto->add_node(); - node_proto->set_name(node_name); - node_proto->add_output(node_name); - node_proto->set_op_type("Constant"); + TensorProto* const initializer_proto = graph_proto->add_initializer(); + + // Create initializer for constants + initializer_proto->set_name(node_name); + // TODO(kellens): convert to fp16 if needed. + initializer_proto->set_data_type(TensorProto_DataType_FLOAT); const NDArray nd = shared_buffer->find(node_name)->second; const TBlob& blob = nd.data(); const TShape shape = blob.shape_; - const int32_t size = shape.Size(); + for (auto& dim : shape) { + initializer_proto->add_dims(static_cast(dim)); + } + + auto size = shape.Size(); + // TODO(kellens): Note hard coded float32 size assumed. std::shared_ptr shared_data_ptr(new float[size]); float* const data_ptr = shared_data_ptr.get(); nd.SyncCopyToCPU(static_cast(data_ptr), size); - AttributeProto* const tensor_attr = node_proto->add_attribute(); - tensor_attr->set_name("value"); - tensor_attr->set_type(AttributeProto::TENSOR); - - TensorProto* const tensor_proto = tensor_attr->mutable_t(); - tensor_proto->set_data_type(TensorProto_DataType_FLOAT); - for (auto& dim : shape) { - tensor_proto->add_dims(static_cast(dim)); + for (size_t blob_idx = 0; blob_idx < size; ++blob_idx) { + initializer_proto->add_float_data(data_ptr[blob_idx]); } - for (int blob_idx = 0; blob_idx < size; ++blob_idx) { - tensor_proto->add_float_data(data_ptr[blob_idx]); + // Create inputs for constants. + ValueInfoProto* const input_proto = graph_proto->add_input(); + input_proto->set_name(node_name); + + // TODO(kellens): (fp16 support) + input_proto->mutable_type()->mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + for (auto& dim : shape) { + auto new_dim = input_proto->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim(); + new_dim->set_dim_value(static_cast(dim)); } } diff --git a/tests/python/tensorrt/test_resnet18.py b/tests/python/tensorrt/test_resnet18.py new file mode 100644 index 000000000000..fff3ac5dd768 --- /dev/null +++ b/tests/python/tensorrt/test_resnet18.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from mxnet.gluon.model_zoo import vision +from mxnet.test_utils import assert_almost_equal +import mxnet as mx +import numpy as np +import os + +batch_shape = (1, 3, 224, 224) +url = 'https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true' +model_file_name = 'resnet18_v2_trt_test' + + +def get_image(image_url): + fname = mx.test_utils.download(image_url, fname=image_url.split('/')[-1].split('?')[0]) + img = mx.image.imread(fname) + img = mx.image.imresize(img, 224, 224) # Resize + img = img.transpose((2, 0, 1)) # Channel first + img = img.expand_dims(axis=0) # Batchify + img = mx.nd.cast(img, dtype=np.float32) + return img/255.0 + + +def test_tensorrt_resnet18_feature_vect(): + print("downloading sample input") + input_data = get_image(url) + gluon_resnet18 = vision.resnet18_v2(pretrained=True) + gluon_resnet18.hybridize() + gluon_resnet18.forward(input_data) + gluon_resnet18.export(model_file_name) + sym, arg_params, aux_params = mx.model.load_checkpoint(model_file_name, 0) + + os.environ['MXNET_USE_TENSORRT'] = '0' + executor = sym.simple_bind(ctx=mx.gpu(), data=batch_shape, grad_req='null', force_rebind=True) + executor.copy_params_from(arg_params, aux_params) + y = executor.forward(is_train=False, data=input_data) + + os.environ['MXNET_USE_TENSORRT'] = '1' + all_params = arg_params + all_params.update(aux_params) + executor = mx.contrib.tensorrt.tensorrt_bind(sym, ctx=mx.gpu(), all_params=all_params, data=batch_shape, + grad_req='null', force_rebind=True) + y_trt = executor.forward(is_train=False, data=input_data) + + no_trt_output = y[0].asnumpy()[0] + trt_output = y_trt[0].asnumpy()[0] + assert_almost_equal(no_trt_output, trt_output, 1e-4, 1e-4) + + +if __name__ == '__main__': + import nose + + nose.runmodule()