Fix operators lying about their number of inputs #17049

Merged · 21 commits · Jan 15, 2020 · diff shown from 13 commits
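The hunks below follow one recurring convention, spelled out in the LeakyReLU and Pooling comments: a backward node's declared inputs are the forward op's inputs plus two entries per forward output (the saved output and its incoming head gradient). Below is a minimal standalone sketch of that counting rule; NumBackwardInputs and everything else in it are invented for illustration and are not MXNet code. Not every operator fits the rule, either; _backward_contrib_BilinearResize2D below, for instance, ends up declaring a single input.

#include <cstdint>
#include <iostream>

// Counting rule used by several of the fixed registrations below
// (illustrative only): the backward op re-reads the forward inputs and
// forward outputs, and receives one head gradient per forward output.
uint32_t NumBackwardInputs(uint32_t fwd_inputs, uint32_t fwd_outputs) {
  return fwd_inputs + 2 * fwd_outputs;
}

int main() {
  std::cout << NumBackwardInputs(2, 1) << "\n";  // PReLU: 2 inputs, 1 output -> 4
  std::cout << NumBackwardInputs(1, 2) << "\n";  // RReLU: 1 input, 2 outputs -> 5
  std::cout << NumBackwardInputs(1, 1) << "\n";  // other LeakyReLU modes -> 3
  return 0;
}
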
src/operator/contrib/bilinear_resize-inl.h (0 additions, 9 deletions)
@@ -328,15 +328,6 @@ inline uint16_t BilinearSampleOpNumInputs(const NodeAttrs& attrs) {
   }
 }
 
-inline uint16_t BilinearSampleOpNumBackwardInputs(const NodeAttrs& attrs) {
-  auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
-  if (param.mode == bilinear_resize::like) {
-    return 3;
-  } else {
-    return 1;
-  }
-}
-
 inline uint16_t BilinearSampleOpNumBackwardOutputs(const NodeAttrs& attrs) {
   auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
   if (param.mode == bilinear_resize::like) {

src/operator/contrib/bilinear_resize.cc (1 addition, 1 deletion)
@@ -232,7 +232,7 @@ for more details.
 
 NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
 .set_attr_parser(ParamParser<BilinearSampleParam>)
-.set_num_inputs(BilinearSampleOpNumBackwardInputs)
+.set_num_inputs(1)
 .set_num_outputs(BilinearSampleOpNumBackwardOutputs)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpBackward<cpu>);

src/operator/contrib/bounding_box.cc (1 addition, 1 deletion)
@@ -110,7 +110,7 @@ Examples::
 .add_arguments(BoxNMSParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_contrib_box_nms)
-.set_num_inputs(3)
+.set_num_inputs(4)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<BoxNMSParam>)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)

src/operator/contrib/roi_align.cc (1 addition, 0 deletions)
@@ -621,6 +621,7 @@ He, Kaiming, et al. "Mask R-CNN." ICCV, 2017
 
 
 NNVM_REGISTER_OP(_backward_ROIAlign)
+.set_num_inputs(2)
 .set_num_outputs(2)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr_parser(ParamParser<ROIAlignParam>)

src/operator/custom/custom.cc (1 addition, 1 deletion)
@@ -586,7 +586,7 @@ Please check the tutorial here: https://mxnet.incubator.apache.org/api/faq/new_o
 NNVM_REGISTER_OP(_backward_Custom)
 .set_num_inputs([](const NodeAttrs& attrs){
   const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
-  return params.bwd_idx.size();
+  return params.bwd_idx.size() + params.num_auxs;
 })
 .set_num_outputs([](const NodeAttrs& attrs){
   const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);

src/operator/image/image_random.cc (1 addition, 1 deletion)
@@ -185,7 +185,7 @@ NNVM_REGISTER_OP(_image_normalize)
 
 NNVM_REGISTER_OP(_backward_image_normalize)
 .set_attr_parser(ParamParser<NormalizeParam>)
-.set_num_inputs(1)
+.set_num_inputs(2)
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", NormalizeOpBackward<cpu>);

src/operator/leaky_relu.cc (13 additions, 0 deletions)
@@ -206,6 +206,19 @@ The following modified ReLU Activation functions are supported:
 });
 
 NNVM_REGISTER_OP(_backward_LeakyReLU)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  if (param.act_type == leakyrelu::kPReLU) {
+    // forward has 2 inputs and 1 output
+    return 2 + 2 * 1;
+  } else if (param.act_type == leakyrelu::kRReLU) {
+    // forward has 1 input and 2 outputs
+    return 1 + 2 * 2;
+  } else {
+    // forward has 1 input and 1 output
+    return 1 + 2 * 1;
+  }
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
   return param.act_type == leakyrelu::kPReLU ? 2 : 1;

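Worked out, the lambda above declares 4 inputs for PReLU (2 + 2 * 1), 5 for RReLU (1 + 2 * 2), and 3 for the remaining modes (1 + 2 * 1), matching the forward signatures described in its comments.
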
src/operator/nn/batch_norm.cc (1 addition, 0 deletions)
@@ -593,6 +593,7 @@ then set ``gamma`` to 1 and its gradient to 0.
 });
 
 NNVM_REGISTER_OP(_backward_BatchNorm)
+.set_num_inputs(8)
 .set_num_outputs(3)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)

src/operator/nn/concat.cc (8 additions, 0 deletions)
@@ -394,6 +394,14 @@ CONCAT_FORWARD_ATTRS
 .add_arguments(ConcatParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Concat)
+.set_num_inputs([](const NodeAttrs& attrs) {
+#if MXNET_USE_MKLDNN
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  return 1 + params.num_args;
+#else
+  return 1;
+#endif
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
   return params.num_args;

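Note the conditional here: the non-MKLDNN Concat backward consumes only the single output gradient (the split sizes come from the param), while the MKLDNN build declares 1 + params.num_args. That the extra entries are the forward inputs is an inference from the count, not something the diff states.
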
src/operator/nn/convolution.cc (4 additions, 0 deletions)
@@ -510,6 +510,10 @@ There are other options to tune the performance.
 .add_arguments(ConvolutionParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Convolution)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
+  return params.no_bias ? 3 : 4;
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
   return params.no_bias ? 2 : 3;

src/operator/nn/ctc_loss.cc (1 addition, 1 deletion)
@@ -130,7 +130,7 @@ information on the definition and the algorithm.
 
 NNVM_REGISTER_OP(_backward_ctc_loss)
 .set_attr_parser(ParamParser<CTCLossOpParam>)
-.set_num_inputs(1)
+.set_num_inputs(4)
 .set_num_outputs(CTCLossOpNumInputs)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", CTCLossOpBackward<cpu>);

src/operator/nn/deconvolution.cc (4 additions, 0 deletions)
@@ -445,6 +445,10 @@ NNVM_REGISTER_OP(Deconvolution)
 .add_arguments(DeconvolutionParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Deconvolution)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const DeconvolutionParam& params = nnvm::get<DeconvolutionParam>(attrs.parsed);
+  return params.no_bias ? 3 : 4;
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const DeconvolutionParam& params = nnvm::get<DeconvolutionParam>(attrs.parsed);
   return params.no_bias ? 2 : 3;

src/operator/nn/lrn.cc (1 addition, 0 deletions)
@@ -184,6 +184,7 @@ number of kernels in the layer.
 .add_arguments(LRNParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_LRN)
+.set_num_inputs(3)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<LRNParam>)
 #if MXNET_USE_MKLDNN == 1

src/operator/nn/pooling.cc (5 additions, 0 deletions)
@@ -454,6 +454,11 @@ For each window ``X``, the mathematical expression for Lp pooling is:
 .add_arguments(PoolingParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Pooling)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
+  // 1 input to fwd op and 2 * outputs from fwd op (fwd outputs and gradient inputs)
+  return 1 + 2 * GetNumOutputs(param);
+})
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<nnvm::FInplaceOption>(
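This is the same fwd-inputs + 2 * fwd-outputs convention as in LeakyReLU above. A lambda rather than a constant is needed because GetNumOutputs(param) can presumably exceed 1 when the forward op emits an auxiliary workspace output in some builds; that reading is an assumption, not stated in the diff.
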
src/operator/nn/upsampling.cc (8 additions, 0 deletions)
@@ -211,6 +211,14 @@ Example::
 });
 
 NNVM_REGISTER_OP(_backward_UpSampling)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const UpSamplingParam& param_ = nnvm::get<UpSamplingParam>(attrs.parsed);
+  if (param_.sample_type != up_enum::kNearest) {
+    return 3;
+  } else {
+    return 1;
+  }
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const UpSamplingParam& params = nnvm::get<UpSamplingParam>(attrs.parsed);
   return params.sample_type == up_enum::kNearest ? params.num_args : 2;

src/operator/operator_common.h (6 additions, 0 deletions)
@@ -376,6 +376,12 @@ inline nnvm::NodePtr MakeNode(
   if (p->op()->attr_parser != nullptr) {
     p->op()->attr_parser(&(p->attrs));
   }
+  if (inputs != nullptr) {
+    CHECK_EQ(p->num_inputs(), p->inputs.size())
+      << "Number of inputs to operator " << op_name << " (" << p->num_inputs()
+      << ") does not match the actual number of inputs provided to operator "
+      << name << " (" << p->inputs.size() << ").";
+  }
   return p;
 }
 
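This hunk is the enforcement half of the PR: MakeNode now fails fast when an operator's declared num_inputs disagrees with the inputs actually wired into the node, instead of misbehaving later inside a kernel. Below is a contrived, self-contained sketch of the mismatch the new CHECK_EQ catches; FakeOp, FakeNode, and CheckNode are invented names, not MXNet types.

#include <cassert>
#include <string>
#include <vector>

// Invented stand-ins for an operator registration and a graph node.
struct FakeOp {
  std::string name;
  int declared_num_inputs;  // what set_num_inputs(...) reported
};

struct FakeNode {
  FakeOp op;
  std::vector<std::string> inputs;  // what the graph actually wired in
};

// Mirrors the spirit of the new check in MakeNode: compare the declared
// count against the real edge count at construction time.
void CheckNode(const FakeNode& n) {
  assert(n.op.declared_num_inputs == static_cast<int>(n.inputs.size()) &&
         "operator lies about its number of inputs");
}

int main() {
  // Declares 1 input but is wired with 2: the check fires immediately,
  // pointing at the offending operator instead of a downstream crash.
  FakeNode bad{{"_backward_example", 1}, {"ograd", "data"}};
  CheckNode(bad);
  return 0;
}
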
src/operator/rnn.cc (14 additions, 0 deletions)
@@ -403,6 +403,20 @@ The definition of GRU here is slightly different from paper but compatible with
 .add_arguments(RNNParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_RNN)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  int ret = 5;
+  if (params.state_outputs) {
+    ret += 2;
+  }
+  if (params.mode == rnn_enum::kLstm) {
+    ++ret;
+    if (params.state_outputs) {
+      ret += 2;
+    }
+  }
+  return ret;
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
   return GetNumInputArguments(params);

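Worked through, the lambda gives: vanilla RNN/GRU without state outputs, 5; with state outputs, 5 + 2 = 7; LSTM without state outputs, 5 + 1 = 6; LSTM with state outputs, 5 + 2 + 1 + 2 = 10.
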
src/operator/softmax_output.cc (1 addition, 0 deletions)
@@ -258,6 +258,7 @@ NNVM_REGISTER_OP(SoftmaxOutput)
 NNVM_REGISTER_OP(SoftmaxOutput).add_alias("Softmax");
 
 NNVM_REGISTER_OP(_backward_SoftmaxOutput)
+.set_num_inputs(2)
 .set_num_outputs(2)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){