diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc index 93cb17464f1b..38038f2a4618 100644 --- a/src/operator/subgraph/build_subgraph.cc +++ b/src/operator/subgraph/build_subgraph.cc @@ -430,7 +430,7 @@ void SortEntries(const std::unordered_map& entry } /*! - * \brief Given a subgraph, find the output entries of a subgraph. + * \brief Given a subgraph, find the input entries of a subgraph. * \param g pointer to the whole graph * \param simple_nods vector of simple nodes in top sorted order * \param subgraph_nodes vector of pointers of simples of a subgraph. diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc index d82f7544a091..4f5bdcb8561c 100644 --- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc +++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc @@ -74,7 +74,9 @@ std::tuple, auto trt_logger = std::unique_ptr(new TRT_Logger(verbosity)); auto trt_builder = InferObject(nvinfer1::createInferBuilder(*trt_logger)); - auto trt_network = InferObject(trt_builder->createNetwork()); + const auto explicitBatch = 1U << static_cast( + nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + auto trt_network = InferObject(trt_builder->createNetworkV2(explicitBatch)); auto trt_parser = InferObject(nvonnxparser::createParser(*trt_network, *trt_logger)); ::ONNX_NAMESPACE::ModelProto parsed_model; // We check for a valid parse, but the main effect is the side effect diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h index 16cc13006d59..b35a1715000e 100644 --- a/src/operator/subgraph/tensorrt/tensorrt-inl.h +++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h @@ -267,6 +267,23 @@ class TensorrtProperty : public SubgraphProperty { return std::make_shared(); } + void PrePartition(const nnvm::Graph& g, + const std::vector>& options_map) override { + auto& in_arg_names = g.GetAttr>("in_arg_names"); + auto& 
in_aux_names = g.GetAttr>("in_aux_names"); + NDArray **in_args_ptr = g.GetAttr("in_args"); + NDArray **in_aux_ptr = g.GetAttr("in_aux"); + in_args_dict.clear(); + in_aux_dict.clear(); + // we trust the Python API, len(in_arg_names) == len(in_args_ptr) + for (unsigned i = 0; i < in_arg_names.size(); ++i) { + in_args_dict[in_arg_names[i]] = in_args_ptr[i]; + } + for (unsigned i = 0; i < in_aux_names.size(); ++i) { + in_aux_dict[in_aux_names[i]] = in_aux_ptr[i]; + } + } + nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym, const int subgraph_id) const override { nnvm::ObjectPtr n = nnvm::Node::Create(); @@ -280,16 +297,33 @@ class TensorrtProperty : public SubgraphProperty { n->attrs.op = Op::Get("_TensorRT"); CHECK(n->attrs.op); n->attrs.subgraphs.emplace_back(std::make_shared(new_sym)); + + // Mapping subgraph params with NDArrays + TRTParam param; std::ostringstream params_oss; - for (auto &e : new_sym.ListInputNames(nnvm::Symbol::kAll)) { - params_oss << e << ";"; + for (auto &param_name : new_sym.ListInputNames(nnvm::Symbol::kAll)) { + NDArray *cache = nullptr; /* must start null: stays unset when the name is in neither dict */ + auto it_args = in_args_dict.find(param_name); + if (it_args != in_args_dict.end()) { + cache = it_args->second; + } else { + auto it_aux = in_aux_dict.find(param_name); + if (it_aux != in_aux_dict.end()) { + cache = it_aux->second; + } + } + if (cache != nullptr) { + param.params_map.emplace(param_name, cache->Copy(Context())); + param.params_map[param_name].WaitToRead(); + params_oss << param_name << ";"; + } } auto tensorrt_params_names = params_oss.str(); - tensorrt_params_names.pop_back(); - n->attrs.dict["subgraph_params_names"] = tensorrt_params_names; - TRTParam param; + if (!tensorrt_params_names.empty()) { + tensorrt_params_names.pop_back(); + } n->attrs.parsed = param; - n->op()->attr_parser(&(n->attrs)); + n->attrs.dict["subgraph_params_names"] = tensorrt_params_names; return n; } @@ -328,6 +362,8 @@ class TensorrtProperty : public SubgraphProperty { } subgraph_node->attrs.parsed = 
std::move(_params); } + + std::unordered_map in_args_dict, in_aux_dict; }; diff --git a/src/operator/subgraph/tensorrt/tensorrt.cu b/src/operator/subgraph/tensorrt/tensorrt.cu index 4a5b23b3a9f7..826f9a5876b6 100644 --- a/src/operator/subgraph/tensorrt/tensorrt.cu +++ b/src/operator/subgraph/tensorrt/tensorrt.cu @@ -56,12 +56,12 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx, param.bindings->at(i) = outputs[p.first].dptr_; } } - const int batch_size = static_cast(inputs[0].shape_[0]); - param.trt_executor->enqueue(batch_size, param.bindings->data(), cuda_s, nullptr); + param.trt_executor->enqueueV2(param.bindings->data(), cuda_s, nullptr); } NNVM_REGISTER_OP(_TensorRT) -.set_attr("FStatefulCompute", TRTCompute); +.set_attr("FStatefulCompute", TRTCompute) +.set_attr("FGradient", MakeZeroGradNodes); } // namespace op } // namespace mxnet