From 88c76da8a36b6707084551fa6d26a0368d6d8568 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 4 Dec 2019 13:41:28 -0800 Subject: [PATCH 01/18] Add a check for number of inputs --- src/operator/operator_common.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index f6af58bce995..710a141779d6 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -376,6 +376,12 @@ inline nnvm::NodePtr MakeNode( if (p->op()->attr_parser != nullptr) { p->op()->attr_parser(&(p->attrs)); } + if (inputs != nullptr) { + CHECK_EQ(p->num_inputs(), p->inputs.size()) + << "Number of inputs to operator " << op_name << " (" << p->num_inputs() + << ") does not match the actual number of inputs provided to operator " + << name << " (" << p->inputs.size() << ")."; + } return p; } From 68b1a9df52d84a62af13d3100ffeca4eb8fc11e8 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 4 Dec 2019 11:29:12 -0800 Subject: [PATCH 02/18] Fix num inputs for backward_Deconvolution --- src/operator/nn/deconvolution.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index bbcec53e933d..f0a6f8841419 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -445,6 +445,10 @@ NNVM_REGISTER_OP(Deconvolution) .add_arguments(DeconvolutionParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_Deconvolution) +.set_num_inputs([](const NodeAttrs& attrs) { + const DeconvolutionParam& params = nnvm::get(attrs.parsed); + return params.no_bias ? 3 : 4; +}) .set_num_outputs([](const NodeAttrs& attrs) { const DeconvolutionParam& params = nnvm::get(attrs.parsed); return params.no_bias ? 
2 : 3; From db79a63dd9b833164a1b84ce29dc02b18265a946 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 4 Dec 2019 11:21:24 -0800 Subject: [PATCH 03/18] Fix number of inputs to backward ROIAlign --- src/operator/contrib/roi_align.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/contrib/roi_align.cc b/src/operator/contrib/roi_align.cc index eeb9f622c2be..38b889b587c1 100644 --- a/src/operator/contrib/roi_align.cc +++ b/src/operator/contrib/roi_align.cc @@ -621,6 +621,7 @@ He, Kaiming, et al. "Mask R-CNN." ICCV, 2017 NNVM_REGISTER_OP(_backward_ROIAlign) +.set_num_inputs(2) .set_num_outputs(2) .set_attr("TIsBackward", true) .set_attr_parser(ParamParser) From c149273f3dda864dc3d04e9a59f3110597f79f3c Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 10 Dec 2019 13:39:56 -0800 Subject: [PATCH 04/18] Fix number of inputs to backward_SoftmaxOutput --- src/operator/softmax_output.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 0bf6e2a014a6..194930f7864a 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -258,6 +258,7 @@ NNVM_REGISTER_OP(SoftmaxOutput) NNVM_REGISTER_OP(SoftmaxOutput).add_alias("Softmax"); NNVM_REGISTER_OP(_backward_SoftmaxOutput) +.set_num_inputs(2) .set_num_outputs(2) .set_attr("TIsBackward", true) .set_attr("FInplaceOption", [](const NodeAttrs& attrs){ From 9b115577613e7353a05b3caad5104449c35ae7c8 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 11 Dec 2019 10:32:12 -0800 Subject: [PATCH 05/18] Fix more operators lying about their number of inputs --- src/operator/contrib/bilinear_resize-inl.h | 9 --------- src/operator/contrib/bilinear_resize.cc | 2 +- src/operator/custom/custom.cc | 2 +- src/operator/image/image_random.cc | 2 +- src/operator/nn/ctc_loss.cc | 2 +- src/operator/nn/lrn.cc | 1 + 6 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/operator/contrib/bilinear_resize-inl.h 
b/src/operator/contrib/bilinear_resize-inl.h index 2167f2558a05..0db9494748a0 100644 --- a/src/operator/contrib/bilinear_resize-inl.h +++ b/src/operator/contrib/bilinear_resize-inl.h @@ -328,15 +328,6 @@ inline uint16_t BilinearSampleOpNumInputs(const NodeAttrs& attrs) { } } -inline uint16_t BilinearSampleOpNumBackwardInputs(const NodeAttrs& attrs) { - auto& param = nnvm::get(attrs.parsed); - if (param.mode == bilinear_resize::like) { - return 3; - } else { - return 1; - } -} - inline uint16_t BilinearSampleOpNumBackwardOutputs(const NodeAttrs& attrs) { auto& param = nnvm::get(attrs.parsed); if (param.mode == bilinear_resize::like) { diff --git a/src/operator/contrib/bilinear_resize.cc b/src/operator/contrib/bilinear_resize.cc index 0351e479290d..399a5a79bd56 100644 --- a/src/operator/contrib/bilinear_resize.cc +++ b/src/operator/contrib/bilinear_resize.cc @@ -232,7 +232,7 @@ for more details. NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D) .set_attr_parser(ParamParser) -.set_num_inputs(BilinearSampleOpNumBackwardInputs) +.set_num_inputs(1) .set_num_outputs(BilinearSampleOpNumBackwardOutputs) .set_attr("TIsBackward", true) .set_attr("FCompute", BilinearSampleOpBackward); diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc index 9130830b900c..3c4843c33395 100644 --- a/src/operator/custom/custom.cc +++ b/src/operator/custom/custom.cc @@ -586,7 +586,7 @@ Please check the tutorial here: https://mxnet.incubator.apache.org/api/faq/new_o NNVM_REGISTER_OP(_backward_Custom) .set_num_inputs([](const NodeAttrs& attrs){ const CustomParam& params = nnvm::get(attrs.parsed); - return params.bwd_idx.size(); + return params.bwd_idx.size() + params.num_auxs; }) .set_num_outputs([](const NodeAttrs& attrs){ const CustomParam& params = nnvm::get(attrs.parsed); diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc index 0c4603ecc475..aa387e683bfd 100644 --- a/src/operator/image/image_random.cc +++ 
b/src/operator/image/image_random.cc @@ -185,7 +185,7 @@ NNVM_REGISTER_OP(_image_normalize) NNVM_REGISTER_OP(_backward_image_normalize) .set_attr_parser(ParamParser) -.set_num_inputs(1) +.set_num_inputs(2) .set_num_outputs(1) .set_attr("TIsBackward", true) .set_attr("FCompute", NormalizeOpBackward); diff --git a/src/operator/nn/ctc_loss.cc b/src/operator/nn/ctc_loss.cc index aba76fb0c452..096ef8c0d7b4 100644 --- a/src/operator/nn/ctc_loss.cc +++ b/src/operator/nn/ctc_loss.cc @@ -130,7 +130,7 @@ information on the definition and the algorithm. NNVM_REGISTER_OP(_backward_ctc_loss) .set_attr_parser(ParamParser) -.set_num_inputs(1) +.set_num_inputs(4) .set_num_outputs(CTCLossOpNumInputs) .set_attr("TIsBackward", true) .set_attr("FCompute", CTCLossOpBackward); diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc index 41337352df63..14967912e3c9 100644 --- a/src/operator/nn/lrn.cc +++ b/src/operator/nn/lrn.cc @@ -184,6 +184,7 @@ number of kernels in the layer. .add_arguments(LRNParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_LRN) +.set_num_inputs(3) .set_num_outputs(1) .set_attr_parser(ParamParser) #if MXNET_USE_MKLDNN == 1 From ffbeded8e63d7a55482743f3309c851223079445 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 1 Nov 2019 13:35:32 -0700 Subject: [PATCH 06/18] Fix input number of backward NMS --- src/operator/contrib/bounding_box.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/contrib/bounding_box.cc b/src/operator/contrib/bounding_box.cc index 8b1d53506c47..906daa480786 100644 --- a/src/operator/contrib/bounding_box.cc +++ b/src/operator/contrib/bounding_box.cc @@ -110,7 +110,7 @@ Examples:: .add_arguments(BoxNMSParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_contrib_box_nms) -.set_num_inputs(3) +.set_num_inputs(4) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("TIsBackward", true) From 3574d605d15fdfd2eed6c14adff4c712cbfa03b4 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 31 Oct 2019 
11:32:19 -0700 Subject: [PATCH 07/18] Fixes --- src/operator/nn/batch_norm.cc | 1 + src/operator/nn/convolution.cc | 4 ++++ src/operator/nn/pooling.cc | 5 +++++ 3 files changed, 10 insertions(+) diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 04e45d4acfed..ea1c76965a9b 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -593,6 +593,7 @@ then set ``gamma`` to 1 and its gradient to 0. }); NNVM_REGISTER_OP(_backward_BatchNorm) +.set_num_inputs(8) .set_num_outputs(3) .set_attr("TIsBackward", true) .set_attr("FInferStorageType", BatchNormStorageType) diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 6d9f84ffc510..36ee4e0c50d3 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -510,6 +510,10 @@ There are other options to tune the performance. .add_arguments(ConvolutionParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_Convolution) +.set_num_inputs([](const NodeAttrs& attrs) { + const ConvolutionParam& params = nnvm::get(attrs.parsed); + return params.no_bias ? 3 : 4; +}) .set_num_outputs([](const NodeAttrs& attrs) { const ConvolutionParam& params = nnvm::get(attrs.parsed); return params.no_bias ? 
2 : 3; diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 485fc1345dfd..65e490c30571 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -454,6 +454,11 @@ For each window ``X``, the mathematical expression for Lp pooling is: .add_arguments(PoolingParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_Pooling) +.set_num_inputs([](const NodeAttrs& attrs) { + const PoolingParam ¶m = nnvm::get(attrs.parsed); + // 1 input gradient, 1 input to fwd op and outputs from fwd op + return 2 + GetNumOutputs(param); +}) .set_num_outputs(1) .set_attr("TIsBackward", true) .set_attr( From 48de72227f1dda05d6b3b5c8332360bc6da0ed1a Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 5 Dec 2019 11:05:07 -0800 Subject: [PATCH 08/18] Fix dropout, RNN and upsampling backward number of inputs --- src/operator/nn/upsampling.cc | 8 ++++++++ src/operator/rnn.cc | 14 ++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/operator/nn/upsampling.cc b/src/operator/nn/upsampling.cc index d36b2598ce82..8000106fd6c4 100644 --- a/src/operator/nn/upsampling.cc +++ b/src/operator/nn/upsampling.cc @@ -211,6 +211,14 @@ Example:: }); NNVM_REGISTER_OP(_backward_UpSampling) +.set_num_inputs([](const NodeAttrs& attrs) { + const UpSamplingParam& param_ = nnvm::get(attrs.parsed); + if (param_.sample_type != up_enum::kNearest) { + return 3; + } else { + return 1; + } +}) .set_num_outputs([](const NodeAttrs& attrs) { const UpSamplingParam& params = nnvm::get(attrs.parsed); return params.sample_type == up_enum::kNearest ? 
params.num_args : 2; diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 6d568c81bc1c..50e7317ea349 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -403,6 +403,20 @@ The definition of GRU here is slightly different from paper but compatible with .add_arguments(RNNParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_RNN) +.set_num_inputs([](const NodeAttrs& attrs) { + const RNNParam& params = nnvm::get(attrs.parsed); + int ret = 5; + if (params.state_outputs) { + ret += 2; + } + if (params.mode == rnn_enum::kLstm) { + ++ret; + if (params.state_outputs) { + ret += 2; + } + } + return ret; +}) .set_num_outputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); return GetNumInputArguments(params); From a823eec45760f429232ca15e4ad85bcb3095c678 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 13 Dec 2019 14:11:56 -0800 Subject: [PATCH 09/18] Fix LeakyRelu number of inputs --- src/operator/leaky_relu.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc index 49ba95d306f4..f250057ecfc4 100644 --- a/src/operator/leaky_relu.cc +++ b/src/operator/leaky_relu.cc @@ -206,6 +206,19 @@ The following modified ReLU Activation functions are supported: }); NNVM_REGISTER_OP(_backward_LeakyReLU) +.set_num_outputs([](const NodeAttrs& attrs) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + if (param.act_type == leakyrelu::kPReLU) { + // forward has 2 inputs and 1 output + return 2 + 2 * 1; + } else if (param.act_type == leakyrelu::kRReLU) { + // forward has 1 input and 2 outputs + return 1 + 2 * 2; + } else { + // forward has 1 input and 1 output + return 1 + 2 * 1; + } +}) .set_num_outputs([](const NodeAttrs& attrs) { const LeakyReLUParam& param = nnvm::get(attrs.parsed); return param.act_type == leakyrelu::kPReLU ? 
2 : 1; From 4cca439db4f4278627cfdd0ce8e4d657ee5c5810 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 16 Dec 2019 08:53:23 -0800 Subject: [PATCH 10/18] Actually fix LeakyRelu --- src/operator/leaky_relu.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc index f250057ecfc4..c2414ad74600 100644 --- a/src/operator/leaky_relu.cc +++ b/src/operator/leaky_relu.cc @@ -206,7 +206,7 @@ The following modified ReLU Activation functions are supported: }); NNVM_REGISTER_OP(_backward_LeakyReLU) -.set_num_outputs([](const NodeAttrs& attrs) { +.set_num_inputs([](const NodeAttrs& attrs) { const LeakyReLUParam& param = nnvm::get(attrs.parsed); if (param.act_type == leakyrelu::kPReLU) { // forward has 2 inputs and 1 output From 2cc89256e869b57c61278f14d7917fb48c8de87b Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 16 Dec 2019 13:00:32 -0800 Subject: [PATCH 11/18] Fix pooling and concat --- src/operator/nn/concat.cc | 4 ++++ src/operator/nn/pooling.cc | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index 1eeef7db5cb5..be781e1a4e77 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -394,6 +394,10 @@ CONCAT_FORWARD_ATTRS .add_arguments(ConcatParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_Concat) +.set_num_inputs([](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + return 1 + params.num_args; +}) .set_num_outputs([](const NodeAttrs& attrs) { const ConcatParam& params = nnvm::get(attrs.parsed); return params.num_args; diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 65e490c30571..1161009e9d0a 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -456,8 +456,8 @@ For each window ``X``, the mathematical expression for Lp pooling is: NNVM_REGISTER_OP(_backward_Pooling) .set_num_inputs([](const NodeAttrs& attrs) { const 
PoolingParam ¶m = nnvm::get(attrs.parsed); - // 1 input gradient, 1 input to fwd op and outputs from fwd op - return 2 + GetNumOutputs(param); + // 1 input to fwd op and 2 * outputs from fwd op (fwd outputs and gradient inputs) + return 1 + 2 * GetNumOutputs(param); }) .set_num_outputs(1) .set_attr("TIsBackward", true) From e2eaae8f98a8744d581b021ceb1f5d0834b07e2f Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 17 Dec 2019 08:36:02 -0800 Subject: [PATCH 12/18] Fix Concat (attempt 2) --- src/operator/nn/concat.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index be781e1a4e77..c5ae45e5f2a2 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -395,8 +395,12 @@ CONCAT_FORWARD_ATTRS NNVM_REGISTER_OP(_backward_Concat) .set_num_inputs([](const NodeAttrs& attrs) { +#if MXNET_USE_MKLDNN const ConcatParam& params = nnvm::get(attrs.parsed); return 1 + params.num_args; +#else + return 1; +#endif }) .set_num_outputs([](const NodeAttrs& attrs) { const ConcatParam& params = nnvm::get(attrs.parsed); From 2b97e558dad895905f5977b015a74debdc1a9c9d Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 2 Jan 2020 11:24:50 -0800 Subject: [PATCH 13/18] Fix from review --- src/operator/nn/concat.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index c5ae45e5f2a2..0d57ed860d51 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -395,7 +395,7 @@ CONCAT_FORWARD_ATTRS NNVM_REGISTER_OP(_backward_Concat) .set_num_inputs([](const NodeAttrs& attrs) { -#if MXNET_USE_MKLDNN +#if MXNET_USE_MKLDNN == 1 const ConcatParam& params = nnvm::get(attrs.parsed); return 1 + params.num_args; #else From d48b0379747c80731257381008b33f6786075f88 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 14 Jan 2020 20:44:31 -0800 Subject: [PATCH 14/18] Incorporate Dick's changes --- src/operator/operator_common.h | 40 
++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index d50d6778f400..7ae61248343e 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -359,12 +359,29 @@ inline bool dispatch_fallback(StorageTypeVector* stypes, DispatchMode* dispatch) return true; } +inline std::vectorCreateNodeEntries( + nnvm::NodePtr pNode, + const std::vector* pOgrads = nullptr, + const std::vector* pInputs = nullptr) { + if (pOgrads) + pNode->inputs.insert(pNode->inputs.end(), pOgrads->begin(), pOgrads->end()); + + if (pInputs) + pNode->inputs.insert(pNode->inputs.end(), pInputs->begin(), pInputs->end()); + + std::vector ret; + for (uint32_t i = 0; i < pNode->num_outputs(); ++i) + ret.emplace_back(nnvm::NodeEntry{pNode, i, 0}); + + return ret; +} + // make a new node with operator op_name. Inputs are not filled. inline nnvm::NodePtr MakeNode( const char* op_name, const std::string& name, - std::vector const * inputs, - std::unordered_map const * dict, - nnvm::NodePtr const * fwd_node) { + std::vector const * inputs = nullptr, + std::unordered_map const * dict = nullptr, + nnvm::NodePtr const * fwd_node = nullptr) { auto p = nnvm::Node::Create(); p->attrs.op = nnvm::Op::Get(op_name); p->attrs.name = name; @@ -381,7 +398,6 @@ inline nnvm::NodePtr MakeNode( << "Number of inputs to operator " << op_name << " (" << p->num_inputs() << ") does not match the actual number of inputs provided to operator " << name << " (" << p->inputs.size() << ")."; - } return p; } @@ -401,11 +417,8 @@ inline std::vector MakeGradNode( const std::unordered_map& dict) { auto p = MakeNode(op_name, n->attrs.name + "_backward", &inputs, &dict, &n); - std::vector ret; - for (uint32_t i = 0; i < p->num_outputs(); ++i) { - ret.emplace_back(p, i, 0); - } - return ret; + + return CreateNodeEntries(p); } // quick helper to make gradient nodes that simply pass back zero. 
could be used in output ops. @@ -452,13 +465,8 @@ inline std::vector MakeNonlossGradNode( return MakeZeroGradNodes(n, ograds); auto p = MakeNode(op_name, n->attrs.name + "_backward", nullptr, &dict, &n); - p->inputs.insert(p->inputs.end(), ograds.begin(), ograds.end()); - p->inputs.insert(p->inputs.end(), inputs.begin(), inputs.end()); - std::vector ret; - for (uint32_t i = 0; i < p->num_outputs(); ++i) { - ret.emplace_back(p, i, 0); - } - return ret; + + return CreateNodeEntries(p, &ograds, &inputs); } /*! \brief Parse keyword arguments as PType arguments and save to parsed */ From 51bb16b775667599a20401d1806c7fb873bd4e58 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 14 Jan 2020 20:36:44 -0800 Subject: [PATCH 15/18] Add guard to MakeNonlossGradNode --- src/operator/operator_common.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 7ae61248343e..29e31503da01 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -369,6 +369,13 @@ inline std::vectorCreateNodeEntries( if (pInputs) pNode->inputs.insert(pNode->inputs.end(), pInputs->begin(), pInputs->end()); + if (!pNode->is_variable()) { + CHECK_EQ(pNode->num_inputs(), pNode->inputs.size()) + << "Number of inputs to operator " << pNode->op()->name << " (" << pNode->num_inputs() + << ") does not match the actual number of inputs provided to operator " + << pNode->attrs.name << " (" << pNode->inputs.size() << ")."; + } + std::vector ret; for (uint32_t i = 0; i < pNode->num_outputs(); ++i) ret.emplace_back(nnvm::NodeEntry{pNode, i, 0}); From 5807913aec82a2073b61b929321077f523cf6d03 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 14 Jan 2020 21:23:41 -0800 Subject: [PATCH 16/18] Fix --- src/operator/operator_common.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 29e31503da01..929182630857 100644 --- 
a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -405,6 +405,7 @@ inline nnvm::NodePtr MakeNode( << "Number of inputs to operator " << op_name << " (" << p->num_inputs() << ") does not match the actual number of inputs provided to operator " << name << " (" << p->inputs.size() << ")."; + } return p; } From 1206e0b77ae6490b1321267b3959428924b413c8 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 15 Jan 2020 08:56:46 -0800 Subject: [PATCH 17/18] Fix backward of SoftmaxActivation --- src/operator/nn/softmax_activation.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/nn/softmax_activation.cc b/src/operator/nn/softmax_activation.cc index 9e5a3ab8f6a2..4779b7732688 100644 --- a/src/operator/nn/softmax_activation.cc +++ b/src/operator/nn/softmax_activation.cc @@ -67,6 +67,7 @@ Example:: .add_arguments(SoftmaxActivationParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_SoftmaxActivation) +.set_num_inputs(2) .set_num_outputs(1) .set_attr("TIsBackward", true) .set_attr("FInplaceOption", [](const NodeAttrs& attrs){ From 3540a0e4dbb44d3547fdd41dcde62d3eee477d6d Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 14 Jan 2020 22:09:11 -0800 Subject: [PATCH 18/18] Fix backward of np_prod and norm --- src/operator/numpy/np_broadcast_reduce_op_value.cc | 2 +- src/operator/tensor/broadcast_reduce_norm_value.cc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index ad6ac9504d52..4bc8e6737bcf 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -238,7 +238,7 @@ NNVM_REGISTER_OP(_np_prod) .set_attr("FGradient", ReduceGrad{"_backward_np_prod"}); NNVM_REGISTER_OP(_backward_np_prod) -.set_num_inputs(1) +.set_num_inputs(3) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("TIsBackward", true) diff --git 
a/src/operator/tensor/broadcast_reduce_norm_value.cc b/src/operator/tensor/broadcast_reduce_norm_value.cc
index 9acc157f8eca..557c4d9e7746 100644
--- a/src/operator/tensor/broadcast_reduce_norm_value.cc
+++ b/src/operator/tensor/broadcast_reduce_norm_value.cc
@@ -105,6 +105,7 @@ Examples::
 .add_arguments(NormParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_norm)
+.set_num_inputs(3)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NormParam>)
 .set_attr<bool>("TIsBackward", true)