From 87ced8b7ad549c5c33e2624d1b36fdce3d13f1a2 Mon Sep 17 00:00:00 2001
From: Haohuan Wang
Date: Fri, 26 Jul 2019 18:07:15 -0700
Subject: [PATCH] handle fix_gamma in tensorrt subgraph conversion correctly
 (#15645)

---
 .../subgraph/tensorrt/nnvm_to_onnx-inl.h      | 21 ++++--
 .../subgraph/tensorrt/nnvm_to_onnx.cc         | 29 ++++++++-
 src/operator/subgraph/tensorrt/tensorrt.cc    |  2 +-
 .../tensorrt/test_tensorrt_batchnorm.py       | 65 +++++++++++++++++++
 4 files changed, 109 insertions(+), 8 deletions(-)
 create mode 100644 tests/python/tensorrt/test_tensorrt_batchnorm.py

diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
index 4a88aee886db..f5bf8b7b8a1d 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
@@ -33,6 +33,8 @@
 
 #include <onnx/onnx_pb.h>
 
+#include <string>
+#include <unordered_map>
 #include <vector>
 
 namespace mxnet {
@@ -72,15 +74,12 @@ typedef void (*ConverterFunction)(NodeProto *node_proto,
                                   const NodeAttrs &attrs,
                                   const nnvm::IndexedGraph &ig,
                                   const array_view<IndexedGraph::NodeEntry> &inputs);
-
 // Forward declarations
-void ConvertConvolution(
-    NodeProto *node_proto,
+void ConvertConvolution(NodeProto *node_proto,
                         const NodeAttrs &attrs,
                         const nnvm::IndexedGraph &ig,
                         const array_view<IndexedGraph::NodeEntry> &inputs);
 
-
 void ConvertPooling(NodeProto *node_proto,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
@@ -142,7 +141,7 @@ void ConvertPad(NodeProto* node_proto,
                 const array_view<IndexedGraph::NodeEntry> &inputs);
 
 std::string ConvertNnvmGraphToOnnx(const nnvm::Graph &g,
-    const std::unordered_map<std::string, NDArray>* const params_map);
+    std::unordered_map<std::string, NDArray>* params_map);
 
 static const std::unordered_map<std::string, ConverterFunction> converter_map = {
   {"Activation", ConvertActivation},
@@ -160,6 +159,18 @@ static const std::unordered_map<std::string, ConverterFunction> converter_map =
   {"SoftmaxOutput", ConvertSoftmaxOutput}
 };
 
+typedef void (*PreprocessFunction)(const NodeAttrs &attrs,
+                                   const std::vector<nnvm::NodeEntry> &inputs,
+                                   std::unordered_map<std::string, NDArray> *params_map);
+
+void PreprocessBatchNorm(const NodeAttrs &attrs,
+                         const std::vector<nnvm::NodeEntry> &inputs,
+                         std::unordered_map<std::string, NDArray> *params_map);
+
+static const std::unordered_map<std::string, PreprocessFunction> preprocess_map = {
+  {"BatchNorm", PreprocessBatchNorm}
+};
+
 }  // namespace nnvm_to_onnx
 }  // namespace op
 }  // namespace mxnet
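The PreprocessFunction table added above mirrors the existing ConverterFunction registry: an op name keys into a static map, and any registered hook runs over the graph before ONNX serialization. For illustration only, a rough plain-Python sketch of the same table-driven dispatch (the function and variable names here are hypothetical, not MXNet API; the real implementation is the C++ loop in nnvm_to_onnx.cc below):

    # Hypothetical sketch of the preprocess dispatch pattern this patch adds.
    def preprocess_batch_norm(attrs, inputs, params_map):
        # When fix_gamma is set, force the gamma parameter feeding this
        # BatchNorm to 1, since ONNX/TensorRT has no fix_gamma notion.
        # BatchNorm inputs are ordered [data, gamma, beta, mean, var].
        if attrs.get("fix_gamma"):
            gamma_name = inputs[1]
            params_map[gamma_name] = 1.0

    preprocess_map = {"BatchNorm": preprocess_batch_norm}

    def run_preprocess(nodes, params_map):
        # One pass over all operator (non-variable) nodes, dispatching any
        # registered hook -- the same shape as the new loop in
        # ConvertNnvmGraphToOnnx.
        for node in nodes:
            hook = preprocess_map.get(node["op"])
            if hook is not None:
                hook(node["attrs"], node["inputs"], params_map)

Keeping preprocessing in its own table means that a future op needing parameter rewriting before conversion can be supported by registering one function, without touching the conversion loop itself.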
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
index da89c2b476ee..111995db907c 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
@@ -54,7 +54,7 @@ namespace nnvm_to_onnx {
 
 std::string ConvertNnvmGraphToOnnx(
     const nnvm::Graph& g,
-    const std::unordered_map<std::string, NDArray>* const params_map) {
+    std::unordered_map<std::string, NDArray>* params_map) {
   static std::atomic_ulong subgraph_count = { 0 };
 
@@ -88,8 +88,21 @@ std::string ConvertNnvmGraphToOnnx(
   auto placeholder_shapes = GetPlaceholderShapes(shape_inputs, ig);
   auto placeholder_dtypes = GetPlaceholderDTypes(dtype_inputs, ig);
   auto output_lookup = GetOutputLookup(ig);
-  uint32_t current_input = 0;
 
+  for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
+    const IndexedGraph::Node& node = ig[node_idx];
+    const nnvm::Node* source = node.source;
+    // If this is an op
+    if (!source->is_variable()) {
+      auto mightNeedPreprocessNode = preprocess_map.find(source->op()->name);
+      // If this op is registered in preprocess_map
+      if (mightNeedPreprocessNode != preprocess_map.end()) {
+        mightNeedPreprocessNode->second(source->attrs, source->inputs, params_map);
+      }
+    }
+  }
+
+  uint32_t current_input = 0;
   // Can't do a foreach over IndexedGraph since it doesn't implement begin(), etc.
   for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
     const IndexedGraph::Node& node = ig[node_idx];
@@ -630,6 +643,18 @@ void ConvertDropout(NodeProto* node_proto, const NodeAttrs& attrs,
   node_proto->set_op_type("Dropout");
 }
 
+void PreprocessBatchNorm(const NodeAttrs &attrs,
+                         const std::vector<nnvm::NodeEntry> &inputs,
+                         std::unordered_map<std::string, NDArray> *params_map) {
+  const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);
+  if (param.fix_gamma) {
+    // If fix_gamma is specified in MXNet, preprocess the params map so that
+    // the gamma associated with this batch norm layer is set to 1.
+    std::string gammaNodeName = inputs[batchnorm::kGamma].node->attrs.name;
+    (*params_map)[gammaNodeName] = 1.0f;
+  }
+}
+
 }  // namespace nnvm_to_onnx
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/subgraph/tensorrt/tensorrt.cc b/src/operator/subgraph/tensorrt/tensorrt.cc
index 30fcee007cfc..7652510ef412 100644
--- a/src/operator/subgraph/tensorrt/tensorrt.cc
+++ b/src/operator/subgraph/tensorrt/tensorrt.cc
@@ -272,7 +272,7 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx,
                << " instead of: " << max_batch_size;
     max_batch_size = in_shape[0][0];
   }
-  const auto& params_map = node_param.params_map;
+  std::unordered_map<std::string, NDArray> params_map = node_param.params_map;
   const auto& inputs_to_idx = node_param.inputs_to_idx;
   const auto& outputs_to_idx = node_param.outputs_to_idx;
   const auto& idx_g = graph.indexed_graph();
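Why the gamma rewrite matters: with fix_gamma=True, MXNet computes batch norm as if gamma were the constant 1, regardless of what the parameter dict stores, while an ONNX/TensorRT BatchNormalization always multiplies by the stored scale. The params_map is therefore taken mutably (and copied in tensorrt.cc) so the preprocessing pass can overwrite gamma before serialization. A minimal NumPy sketch of the discrepancy the patch removes, assuming the usual batch-norm formula y = gamma * (x - mean) / sqrt(var + eps) + beta and the zero gamma that the test below deliberately stores:

    import numpy as np

    x = np.random.uniform(size=(1, 1, 3, 3)).astype(np.float32)
    mean, var, beta, eps = 1.0, 1.0, 0.0, 1e-5
    stored_gamma = 0.0  # the test below stores zeros for trt_bn_test_bn_gamma

    # fix_gamma=True: MXNet computes as if gamma == 1, whatever is stored
    y_mxnet = 1.0 * (x - mean) / np.sqrt(var + eps) + beta

    # Without the gamma preprocessing, a converted engine would multiply by
    # the stored gamma and collapse the output to beta everywhere:
    y_unpatched = stored_gamma * (x - mean) / np.sqrt(var + eps) + beta

    print(np.allclose(y_mxnet, y_unpatched))  # False whenever x != mean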
diff --git a/tests/python/tensorrt/test_tensorrt_batchnorm.py b/tests/python/tensorrt/test_tensorrt_batchnorm.py
new file mode 100644
index 000000000000..62af3bbf329b
--- /dev/null
+++ b/tests/python/tensorrt/test_tensorrt_batchnorm.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.test_utils import assert_almost_equal
+
+def get_params():
+    arg_params = {}
+    aux_params = {}
+    arg_params["trt_bn_test_conv_weight"] = mx.nd.ones((1, 1, 3, 3))
+    arg_params["trt_bn_test_bn_gamma"] = mx.nd.zeros((1,))
+    arg_params["trt_bn_test_bn_beta"] = mx.nd.zeros((1,))
+    aux_params["trt_bn_test_bn_moving_mean"] = mx.nd.ones(1)
+    aux_params["trt_bn_test_bn_moving_var"] = mx.nd.ones(1)
+    return arg_params, aux_params
+
+def get_symbol():
+    data = mx.sym.Variable("data")
+    conv = mx.sym.Convolution(data=data, kernel=(3,3), no_bias=True, num_filter=1, num_group=1,
+                              name="trt_bn_test_conv")
+    bn = mx.sym.BatchNorm(data=conv, fix_gamma=True, use_global_stats=False, name="trt_bn_test_bn")
+    return bn
+
+def test_batch_norm_runs_correctly_with_fix_gamma():
+    arg_params, aux_params = get_params()
+    arg_params_trt, aux_params_trt = get_params()
+
+    sym = get_symbol()
+    sym_trt = get_symbol().get_backend_symbol("TensorRT")
+
+    mx.contrib.tensorrt.init_tensorrt_params(sym_trt, arg_params_trt, aux_params_trt)
+
+    executor = sym.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null', force_rebind=True)
+    executor.copy_params_from(arg_params, aux_params)
+
+    executor_trt = sym_trt.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null',
+                                       force_rebind=True)
+    executor_trt.copy_params_from(arg_params_trt, aux_params_trt)
+
+    input_data = mx.nd.random.uniform(low=0, high=1, shape=(1, 1, 3, 3))
+
+    y = executor.forward(is_train=False, data=input_data)
+    y_trt = executor_trt.forward(is_train=False, data=input_data)
+
+    print(y[0].asnumpy())
+    print(y_trt[0].asnumpy())
+    assert_almost_equal(y[0].asnumpy(), y_trt[0].asnumpy(), 1e-4, 1e-4)
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
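Because the test module ends with nose.runmodule(), it can be run directly with Python on a GPU build compiled with TensorRT support, or collected along with the other tests under tests/python/tensorrt. Without the preprocessing fix, the final assert_almost_equal fails: the TensorRT path multiplies by the stored zero gamma and collapses the output toward beta, as sketched above.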