From d3fb9818e98e8b92c7abf37cbfae0ddb0e650533 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Fri, 15 Mar 2019 12:56:04 -0700
Subject: [PATCH] [Numpy] Change semantics of ndim for operators in
 `src/operator/contrib` (#14409)

* Initial commit

* Address comments
---
 src/operator/contrib/adamw-inl.h              |  5 +--
 .../contrib/adaptive_avg_pooling-inl.h        |  6 ++--
 src/operator/contrib/bilinear_resize-inl.h    |  2 +-
 src/operator/contrib/boolean_mask.cc          |  2 +-
 src/operator/contrib/bounding_box-inl.h       |  4 ++-
 src/operator/contrib/count_sketch-inl.h       |  2 +-
 .../contrib/deformable_convolution-inl.h      | 14 ++++--
 src/operator/contrib/dgl_graph.cc             | 32 ++++++-------------
 src/operator/contrib/fft-inl.h                |  2 +-
 src/operator/contrib/ifft-inl.h               |  2 +-
 src/operator/contrib/index_copy-inl.h         |  3 +-
 src/operator/contrib/multi_proposal-inl.h     |  2 +-
 src/operator/contrib/nnvm_to_onnx.cc          |  3 +-
 src/operator/contrib/optimizer_op.cc          |  2 +-
 src/operator/contrib/proposal-inl.h           |  2 +-
 src/operator/contrib/quadratic_op-inl.h       |  2 +-
 src/operator/contrib/sync_batch_norm-inl.h    |  2 +-
 src/operator/contrib/transformer-inl.h        |  4 ++-
 18 files changed, 41 insertions(+), 50 deletions(-)

diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h
index 07feaefe87aa..6ae9e46b7def 100644
--- a/src/operator/contrib/adamw-inl.h
+++ b/src/operator/contrib/adamw-inl.h
@@ -87,8 +87,9 @@ inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs,
                                mxnet::ShapeVector *out_attrs) {
   CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
   CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
-  // rescale_grad.shape = (1,)
-  SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mshadow::Shape1(1));
+  // rescale_grad.shape = ()
+  SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mxnet::TShape());
+  // TODO(@reminisce): change "none" behavior in ElemwiseAttr
   return ElemwiseAttr<mxnet::TShape, shape_is_none, shape_assign, true, shape_string, n_in, n_out>(
       attrs, in_attrs, out_attrs, mxnet::TShape());
 }
diff --git a/src/operator/contrib/adaptive_avg_pooling-inl.h b/src/operator/contrib/adaptive_avg_pooling-inl.h
index 0d66de0a5692..eedab78db0c5 100644
--- a/src/operator/contrib/adaptive_avg_pooling-inl.h
+++ b/src/operator/contrib/adaptive_avg_pooling-inl.h
@@ -48,9 +48,9 @@ namespace mxnet {
 namespace op {
 
 struct AdaptiveAvgPoolParam : public dmlc::Parameter<AdaptiveAvgPoolParam> {
-  mxnet::TShape output_size;
+  mxnet::Tuple<int> output_size;
   DMLC_DECLARE_PARAMETER(AdaptiveAvgPoolParam) {
-    DMLC_DECLARE_FIELD(output_size).set_default(mxnet::TShape())
+    DMLC_DECLARE_FIELD(output_size).set_default(mxnet::Tuple<int>())
     .describe("int (output size) or a tuple of int for output (height, width).");
   }
 };
@@ -125,7 +125,7 @@ static bool AdaptiveAvgPoolOpInferShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_shape->size(), 1U) << "Output:[data]";
   const AdaptiveAvgPoolParam& param = nnvm::get<AdaptiveAvgPoolParam>(attrs.parsed);
   mxnet::TShape dshape(in_shape->at(0));
-  if (dshape.ndim() == 0) return false;
+  if (mxnet::op::shape_is_none(dshape)) return false;
   if (param.output_size.ndim() == 0) {
     dshape[2] = 1;
     dshape[3] = 1;
diff --git a/src/operator/contrib/bilinear_resize-inl.h b/src/operator/contrib/bilinear_resize-inl.h
index 46c8e1aa7c0d..ce9c6c83504c 100644
--- a/src/operator/contrib/bilinear_resize-inl.h
+++ b/src/operator/contrib/bilinear_resize-inl.h
@@ -134,7 +134,7 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_shape->size(), 1U) << "Output:[data]";
   const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
   mxnet::TShape dshape(in_shape->at(0));
-  if (dshape.ndim() == 0) return false;
+  if (mxnet::op::shape_is_none(dshape)) return false;
   if (param.scale_height.has_value()) {
     dshape[2] = static_cast<int>(param.scale_height.value() * in_shape->at(0)[2]);
   } else {
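Note: the hunks above and below all make the same change. A test of the form
`dshape.ndim() == 0` stops being a reliable "shape unknown" check once numpy
semantics are adopted, because a 0-dim shape is then a legitimate scalar
shape; such call sites switch to mxnet::op::shape_is_none(). A minimal
stand-in that illustrates the intended convention, assuming ndim() == -1
marks an unknown rank and a -1 entry marks an unknown dimension (sketch only;
the canonical helper lives in src/operator/operator_common.h):

    #include <mxnet/tuple.h>

    // Sketch only; illustrates the semantics, not the exact upstream code.
    inline bool ShapeIsNoneSketch(const mxnet::TShape& s) {
      if (s.ndim() == -1) return true;  // rank itself is unknown
      for (int i = 0; i < s.ndim(); ++i) {
        if (s[i] == -1) return true;    // some dimension is unknown
      }
      return false;  // fully known; a 0-dim scalar shape counts as known
    }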
diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc
index e22c493d5e2c..06d8439e23a0 100644
--- a/src/operator/contrib/boolean_mask.cc
+++ b/src/operator/contrib/boolean_mask.cc
@@ -121,7 +121,7 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs,
   const NDArray &out = outputs[0];
   CHECK_EQ(axis, 0) << "Not supported yet";
   CHECK_EQ(data.shape()[axis], idx.shape()[0]);
-  CHECK_EQ(idx.shape().ndim(), 1U);
+  CHECK_EQ(idx.shape().ndim(), 1U);  // idx is required to be 1-d.
   // count the number of 1s in `idx`, so that we could know the output dimension
   size_t idx_size = idx.shape()[0];
   std::vector<int32_t> prefix_sum(idx_size, 0);
diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h
index 37c4297ff49d..059327ef8334 100644
--- a/src/operator/contrib/bounding_box-inl.h
+++ b/src/operator/contrib/bounding_box-inl.h
@@ -94,7 +94,9 @@ inline bool BoxNMSShape(const nnvm::NodeAttrs& attrs,
   const BoxNMSParam& param = nnvm::get<BoxNMSParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 2U);
-  if (in_attrs->at(0).ndim() == 0U && out_attrs->at(0).ndim() == 0U) {
+  // TODO(@junrushao1994): verify with Joshua Z. Zhang about this operator
+  if (mxnet::op::shape_is_none(in_attrs->at(0))
+      && mxnet::op::shape_is_none(out_attrs->at(0))) {
     return false;
   }
 
diff --git a/src/operator/contrib/count_sketch-inl.h b/src/operator/contrib/count_sketch-inl.h
index f3a294f6ad46..3ea93e63d6fc 100644
--- a/src/operator/contrib/count_sketch-inl.h
+++ b/src/operator/contrib/count_sketch-inl.h
@@ -151,7 +151,7 @@ class CountSketchProp : public OperatorProperty {
     CHECK_EQ(in_shape->size(), 3) <<"Input:[data, h, s]";
     const mxnet::TShape &dshape = (*in_shape)[CountSketch::kData];
     // require data to be known
-    if (dshape.ndim() == 0) return false;
+    if (mxnet::op::shape_is_none(dshape)) return false;
 
     out_shape->clear();
     if (dshape.ndim() == 4) {
diff --git a/src/operator/contrib/deformable_convolution-inl.h b/src/operator/contrib/deformable_convolution-inl.h
index f50641fca6d6..3e96cad1c859 100644
--- a/src/operator/contrib/deformable_convolution-inl.h
+++ b/src/operator/contrib/deformable_convolution-inl.h
@@ -69,11 +69,11 @@ struct DeformableConvolutionParam : public dmlc::Parameter<DeformableConvolutionParam> {
   dmlc::optional<int> layout;
   DMLC_DECLARE_PARAMETER(DeformableConvolutionParam) {
     DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (h, w) or (d, h, w)");
-    DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape())
+    DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0))
         .describe("Convolution stride: (h, w) or (d, h, w). Defaults to 1 for each dimension.");
-    DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape())
+    DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape(0))
         .describe("Convolution dilate: (h, w) or (d, h, w). Defaults to 1 for each dimension.");
-    DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape())
+    DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0))
         .describe("Zero pad for convolution: (h, w) or (d, h, w). Defaults to no padding.");
     DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000)
         .describe("Convolution filter(channel) number");
@@ -347,9 +347,9 @@ class DeformableConvolutionProp : public OperatorProperty {
     param_.Init(kwargs);
     if (param_.kernel.ndim() == 2) {
       param_.layout = param_.layout ? param_.layout.value() : mshadow::kNCHW;
-      if (param_.stride.ndim() == 0) param_.stride = Shape2(1, 1);
-      if (param_.dilate.ndim() == 0) param_.dilate = Shape2(1, 1);
-      if (param_.pad.ndim() == 0) param_.pad = Shape2(0, 0);
+      if (mxnet::op::shape_is_none(param_.stride)) param_.stride = Shape2(1, 1);
+      if (mxnet::op::shape_is_none(param_.dilate)) param_.dilate = Shape2(1, 1);
+      if (mxnet::op::shape_is_none(param_.pad)) param_.pad = Shape2(0, 0);
     } else {
       LOG(FATAL) << "not implemented";
     }
@@ -371,7 +371,7 @@ class DeformableConvolutionProp : public OperatorProperty {
     out_shape->resize(1, mxnet::TShape());
     const mxnet::TShape &dshp = (*in_shape)[conv::kData];
     const mxnet::TShape &oshp = (*in_shape)[conv::kOffset];
-    if (dshp.ndim() == 0) return false;
+    if (mxnet::op::shape_is_none(dshp)) return false;
     if (param_.kernel.ndim() == 2) {
       // 2d conv
       CHECK_EQ(dshp.ndim(), 4U) \
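Note: the set_default() changes in DeformableConvolutionParam follow from the
same convention. Once a default-constructed mxnet::TShape means "unknown", it
can no longer double as "stride/dilate/pad not supplied", so the defaults
become mxnet::TShape(0), a known empty tuple. A sketch of the distinction,
under the assumed constructor behavior of the new convention:

    #include <mxnet/tuple.h>

    void SketchDefaults() {
      mxnet::TShape unknown;   // default-constructed: "not inferred yet" (ndim() == -1)
      mxnet::TShape unset(0);  // known empty tuple: "user did not set it"
      // Init() then expands the empty tuple, e.g. stride -> Shape2(1, 1),
      // pad -> Shape2(0, 0), as in the hunk above.
      (void)unknown; (void)unset;
    }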
diff --git a/src/operator/contrib/dgl_graph.cc b/src/operator/contrib/dgl_graph.cc
index f19af84ce9c6..02ef2cee1caa 100644
--- a/src/operator/contrib/dgl_graph.cc
+++ b/src/operator/contrib/dgl_graph.cc
@@ -265,9 +265,7 @@ static bool CSRNeighborUniformSampleShape(const nnvm::NodeAttrs& attrs,
   out_shape[0] = params.max_num_vertices + 1;
   for (size_t i = 0; i < num_subgraphs; i++) {
     SHAPE_ASSIGN_CHECK(*out_attrs, i, out_shape);
-    success = success &&
-              out_attrs->at(i).ndim() != 0U &&
-              out_attrs->at(i).Size() != 0U;
+    success = success && !mxnet::op::shape_is_none(out_attrs->at(i));
   }
   // sub_csr
   mxnet::TShape out_csr_shape(2);
@@ -275,18 +273,14 @@ static bool CSRNeighborUniformSampleShape(const nnvm::NodeAttrs& attrs,
   out_csr_shape[1] = in_attrs->at(0)[1];
   for (size_t i = 0; i < num_subgraphs; i++) {
     SHAPE_ASSIGN_CHECK(*out_attrs, i + num_subgraphs, out_csr_shape);
-    success = success &&
-              out_attrs->at(i + num_subgraphs).ndim() != 0U &&
-              out_attrs->at(i + num_subgraphs).Size() != 0U;
+    success = success && !mxnet::op::shape_is_none(out_attrs->at(i + num_subgraphs));
   }
   // sub_layer
   mxnet::TShape out_layer_shape(1);
   out_layer_shape[0] = params.max_num_vertices;
   for (size_t i = 0; i < num_subgraphs; i++) {
     SHAPE_ASSIGN_CHECK(*out_attrs, i + 2*num_subgraphs, out_layer_shape);
-    success = success &&
-              out_attrs->at(i + 2*num_subgraphs).ndim() != 0U &&
-              out_attrs->at(i + 2*num_subgraphs).Size() != 0U;
+    success = success && !mxnet::op::shape_is_none(out_attrs->at(i + 2 * num_subgraphs));
   }
 
   return success;
@@ -323,9 +317,7 @@ static bool CSRNeighborNonUniformSampleShape(const nnvm::NodeAttrs& attrs,
   out_shape[0] = params.max_num_vertices + 1;
   for (size_t i = 0; i < num_subgraphs; i++) {
     SHAPE_ASSIGN_CHECK(*out_attrs, i, out_shape);
-    success = success &&
-              out_attrs->at(i).ndim() != 0U &&
-              out_attrs->at(i).Size() != 0U;
+    success = success && !mxnet::op::shape_is_none(out_attrs->at(i));
   }
   // sub_csr
   mxnet::TShape out_csr_shape(2);
@@ -333,27 +325,21 @@ static bool CSRNeighborNonUniformSampleShape(const nnvm::NodeAttrs& attrs,
   out_csr_shape[1] = in_attrs->at(0)[1];
   for (size_t i = 0; i < num_subgraphs; i++) {
     SHAPE_ASSIGN_CHECK(*out_attrs, i + num_subgraphs, out_csr_shape);
-    success = success &&
-              out_attrs->at(i + num_subgraphs).ndim() != 0U &&
-              out_attrs->at(i + num_subgraphs).Size() != 0U;
+    success = success && !mxnet::op::shape_is_none(out_attrs->at(i + num_subgraphs));
   }
   // sub_probability
   mxnet::TShape out_prob_shape(1);
   out_prob_shape[0] = params.max_num_vertices;
   for (size_t i = 0; i < num_subgraphs; i++) {
     SHAPE_ASSIGN_CHECK(*out_attrs, i + 2*num_subgraphs, out_prob_shape);
-    success = success &&
-              out_attrs->at(i + 2*num_subgraphs).ndim() != 0U &&
-              out_attrs->at(i + 2*num_subgraphs).Size() != 0U;
+    success = success && !mxnet::op::shape_is_none(out_attrs->at(i + 2 * num_subgraphs));
   }
   // sub_layer
   mxnet::TShape out_layer_shape(1);
   out_layer_shape[0] = params.max_num_vertices;
   for (size_t i = 0; i < num_subgraphs; i++) {
     SHAPE_ASSIGN_CHECK(*out_attrs, i + 3*num_subgraphs, out_prob_shape);
-    success = success &&
-              out_attrs->at(i + 3*num_subgraphs).ndim() != 0U &&
-              out_attrs->at(i + 3*num_subgraphs).Size() != 0U;
+    success = success && !mxnet::op::shape_is_none(out_attrs->at(i + 3 * num_subgraphs));
   }
 
   return success;
@@ -1199,7 +1185,7 @@ inline bool EdgeIDShape(const nnvm::NodeAttrs& attrs,
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
   SHAPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
   SHAPE_ASSIGN_CHECK(*in_attrs, 2, out_attrs->at(0));
-  return out_attrs->at(0).ndim() != 0U && out_attrs->at(0).Size() != 0U;
+  return !mxnet::op::shape_is_none(out_attrs->at(0));
 }
 
 inline bool EdgeIDType(const nnvm::NodeAttrs& attrs,
@@ -1357,7 +1343,7 @@ inline bool DGLAdjacencyShape(const nnvm::NodeAttrs& attrs,
 
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
   SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
-  return out_attrs->at(0).ndim() != 0U && out_attrs->at(0).Size() != 0U;
+  return !mxnet::op::shape_is_none(out_attrs->at(0));
 }
 
 inline bool DGLAdjacencyType(const nnvm::NodeAttrs& attrs,
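Note: dgl_graph.cc repeats one mechanical rewrite throughout: the two-part
test `ndim() != 0U && Size() != 0U` collapses into a single !shape_is_none()
call. Side by side, using ShapeIsNoneSketch from the note above (sketch only):

    inline bool OutputKnownOld(const mxnet::TShape& shp) {
      // Pre-numpy test; misreads 0-dim scalars as unknown.
      return shp.ndim() != 0U && shp.Size() != 0U;
    }
    inline bool OutputKnownNew(const mxnet::TShape& shp) {
      // Post-numpy test; only genuinely unknown shapes fail.
      return !ShapeIsNoneSketch(shp);
    }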
diff --git a/src/operator/contrib/fft-inl.h b/src/operator/contrib/fft-inl.h
index 247f6290c02a..a5471b4ba2e2 100644
--- a/src/operator/contrib/fft-inl.h
+++ b/src/operator/contrib/fft-inl.h
@@ -241,7 +241,7 @@ class FFTProp : public OperatorProperty {
     CHECK_EQ(in_shape->size(), 1) <<"Input:[data]";
     const mxnet::TShape &dshape = (*in_shape)[fft::kData];
     // require data to be known
-    if (dshape.ndim() == 0) return false;
+    if (mxnet::op::shape_is_none(dshape)) return false;
 
     out_shape->clear();
     if (dshape.ndim() == 4) {
diff --git a/src/operator/contrib/ifft-inl.h b/src/operator/contrib/ifft-inl.h
index e53c0f60fa9e..7d8422e838b1 100644
--- a/src/operator/contrib/ifft-inl.h
+++ b/src/operator/contrib/ifft-inl.h
@@ -231,7 +231,7 @@ class IFFTProp : public OperatorProperty {
     CHECK_EQ(in_shape->size(), 1) <<"Input:[data]";
     const mxnet::TShape &dshape = (*in_shape)[ifft::kData];
     // require data to be known
-    if (dshape.ndim() == 0) return false;
+    if (mxnet::op::shape_is_none(dshape)) return false;
 
     out_shape->clear();
     if (dshape.ndim() == 4) {
diff --git a/src/operator/contrib/index_copy-inl.h b/src/operator/contrib/index_copy-inl.h
index 903dee13272b..35f88916da20 100644
--- a/src/operator/contrib/index_copy-inl.h
+++ b/src/operator/contrib/index_copy-inl.h
@@ -76,8 +76,7 @@ inline bool IndexCopyShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->at(1)[0], in_attrs->at(2)[0]);
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
   SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
-  return out_attrs->at(0).ndim() != 0U &&
-         out_attrs->at(0).Size() != 0U;
+  return !mxnet::op::shape_is_none(out_attrs->at(0));
 }
 
 }  // namespace op
diff --git a/src/operator/contrib/multi_proposal-inl.h b/src/operator/contrib/multi_proposal-inl.h
index 4b9a41c2fa87..a9afb8e4114e 100644
--- a/src/operator/contrib/multi_proposal-inl.h
+++ b/src/operator/contrib/multi_proposal-inl.h
@@ -108,7 +108,7 @@ class MultiProposalProp : public OperatorProperty {
     using namespace mshadow;
     CHECK_EQ(in_shape->size(), 3) << "Input:[cls_prob, bbox_pred, im_info]";
    const mxnet::TShape &dshape = in_shape->at(proposal::kClsProb);
-    if (dshape.ndim() == 0) return false;
+    if (mxnet::op::shape_is_none(dshape)) return false;
     Shape<4> bbox_pred_shape;
     bbox_pred_shape = Shape4(dshape[0], dshape[1] * 2, dshape[2], dshape[3]);
     SHAPE_ASSIGN_CHECK(*in_shape, proposal::kBBoxPred,
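Note: IndexCopyShape and EdgeIDShape above run SHAPE_ASSIGN_CHECK in both
directions (outputs from inputs, then inputs from outputs). The macro fills an
unknown slot from a known one and fails on a mismatch, so applying it both
ways propagates whichever side happens to be known. A sketch of that mechanic
(illustrative; the real macro lives in src/operator/operator_common.h):

    inline bool ShapeAssignSketch(mxnet::TShape* dst, const mxnet::TShape& src) {
      if (ShapeIsNoneSketch(src)) return true;                   // nothing to propagate
      if (ShapeIsNoneSketch(*dst)) { *dst = src; return true; }  // fill the unknown slot
      return *dst == src;                                        // both known: must agree
    }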
diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc
index 0417a085616a..0c8bd79490e3 100644
--- a/src/operator/contrib/nnvm_to_onnx.cc
+++ b/src/operator/contrib/nnvm_to_onnx.cc
@@ -417,7 +417,8 @@ std::unordered_map<std::string, TShape> GetPlaceholderShapes(
   for (uint32_t i = 0; i < shape_inputs.size(); ++i) {
     std::string name = ig[ig.input_nodes()[i]].source->attrs.name;
     mxnet::TShape shp = shape_inputs[i];
-    if (shp.ndim() > 0) {
+    if (!mxnet::op::shape_is_none(shp)) {
+      // TODO(@reminisce): confirm
       placeholder_shapes.emplace(name, shp);
     }
   }
diff --git a/src/operator/contrib/optimizer_op.cc b/src/operator/contrib/optimizer_op.cc
index 9f948bad81b6..83bbcdab833d 100644
--- a/src/operator/contrib/optimizer_op.cc
+++ b/src/operator/contrib/optimizer_op.cc
@@ -45,7 +45,7 @@ inline bool GroupAdagradShape(const nnvm::NodeAttrs &attrs,
   SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
   SHAPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
 
-  return out_attrs->at(0).ndim() != 0U && out_attrs->at(0).Size() != 0U &&
+  return !mxnet::op::shape_is_none(out_attrs->at(0)) &&
          (in_attrs->at(0)[0] == in_attrs->at(1)[0]) &&
          (in_attrs->at(0)[0] == in_attrs->at(2)[0]);
 }
diff --git a/src/operator/contrib/proposal-inl.h b/src/operator/contrib/proposal-inl.h
index 9908ca96ec5f..21e9fe198e63 100644
--- a/src/operator/contrib/proposal-inl.h
+++ b/src/operator/contrib/proposal-inl.h
@@ -106,7 +106,7 @@ class ProposalProp : public OperatorProperty {
     using namespace mshadow;
     CHECK_EQ(in_shape->size(), 3) << "Input:[cls_prob, bbox_pred, im_info]";
     const mxnet::TShape &dshape = in_shape->at(proposal::kClsProb);
-    if (dshape.ndim() == 0) return false;
+    if (mxnet::op::shape_is_none(dshape)) return false;
     Shape<4> bbox_pred_shape;
     bbox_pred_shape = Shape4(dshape[0], dshape[1] * 2, dshape[2], dshape[3]);
     SHAPE_ASSIGN_CHECK(*in_shape, proposal::kBBoxPred,
diff --git a/src/operator/contrib/quadratic_op-inl.h b/src/operator/contrib/quadratic_op-inl.h
index e679fedc8e57..a7aca63de17a 100644
--- a/src/operator/contrib/quadratic_op-inl.h
+++ b/src/operator/contrib/quadratic_op-inl.h
@@ -60,7 +60,7 @@ inline bool QuadraticOpShape(const nnvm::NodeAttrs& attrs,
 
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
   SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
-  return out_attrs->at(0).ndim() != 0U && out_attrs->at(0).Size() != 0U;
+  return !mxnet::op::shape_is_none(out_attrs->at(0));
 }
 
 inline bool QuadraticOpType(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h
index 1e6ab25db0e2..cd1a3285fe06 100644
--- a/src/operator/contrib/sync_batch_norm-inl.h
+++ b/src/operator/contrib/sync_batch_norm-inl.h
@@ -482,7 +482,7 @@ class SyncBatchNormProp : public OperatorProperty {
     using namespace mshadow;
     CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
     const mxnet::TShape &dshape = in_shape->at(0);
-    if (dshape.ndim() == 0) return false;
+    if (mxnet::op::shape_is_none(dshape)) return false;
     in_shape->at(1) = mxnet::TShape(Shape1(dshape[1]));
     in_shape->at(2) = mxnet::TShape(Shape1(dshape[1]));
     out_shape->clear();
diff --git a/src/operator/contrib/transformer-inl.h b/src/operator/contrib/transformer-inl.h
index 01faf244aff9..da3d14e33cf4 100644
--- a/src/operator/contrib/transformer-inl.h
+++ b/src/operator/contrib/transformer-inl.h
@@ -41,7 +41,9 @@ static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs,
                                const std::vector<OpReqType>& req,
                                const std::vector<TBlob>& outputs) {
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  double sqrt_dim = std::sqrt(static_cast<double>(inputs[0].shape_[inputs[0].ndim() - 1]));
+  CHECK_GE(inputs[0].ndim(), 1);
+  int last_idx = inputs[0].ndim() - 1;
+  double sqrt_dim = std::sqrt(static_cast<double>(inputs[0].shape_[last_idx]));
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
       mxnet_op::Kernel<op_with_req<mshadow_op::div, Req>, xpu>::Launch(
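Note: the transformer change is the one spot where the new semantics need an
explicit guard rather than a different predicate. Under numpy semantics a
0-dim (scalar) input is legal, and `inputs[0].ndim() - 1` would then index
shape_[-1]; CHECK_GE makes the at-least-1-d precondition explicit before the
last dimension is read. The kernel itself just scales by 1/sqrt(d_last); a
standalone sketch of that arithmetic (the real code dispatches over dtypes
with MSHADOW_TYPE_SWITCH):

    #include <cmath>

    inline double DivSqrtLastDimSketch(double x, int last_dim_size) {
      return x / std::sqrt(static_cast<double>(last_dim_size));
    }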