Commit 5854b98

Fix infer shape partial after unknown shape changed to -1 (apache#14869)
* change check and shape_is_known

* revert some changes

* revert

* revert

* revert

* add test

* add more tests

* update test dot

* fix test

* update reduce axes

* fix lint

* update check

* fix lint

* address comments

* remove invalid test case

* run all tests

* update test case
roywei authored and reminisce committed May 21, 2019
1 parent aa55e3d commit 5854b98
Showing 14 changed files with 149 additions and 60 deletions.
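The diff below makes two kinds of changes. First, the rank checks drop their unsigned literals (3U becomes 3, 2U becomes 2, and so on): with the new shape definition, ndim() returns -1 for an unknown rank, and in a mixed comparison such as CHECK_GE(ndim(), 3U) that -1 is converted to a huge unsigned value, so the old check silently passed for unknown shapes (and tripped signed/unsigned lint warnings). Second, the shape-inference functions are reworked for partial inference: they bail out only when the rank itself is unknown (ndim_is_known), tolerate individual -1 dimensions (dim_size_is_known), and report success only once the output shape is fully known (shape_is_known). The sketch below illustrates the assumed semantics of those helpers and the unsigned-comparison pitfall; ShapeSketch is a hypothetical stand-in, not the real mxnet::TShape.

#include <cstdio>

// Hypothetical stand-in for mxnet::TShape, only to illustrate the
// "-1 means unknown" convention this commit works with.
struct ShapeSketch {
  int ndim_;       // -1 = number of dimensions unknown
  int dims_[8];    // a -1 entry = that dimension's size unknown
  int ndim() const { return ndim_; }
  int operator[](int i) const { return dims_[i]; }
};

// Assumed semantics of the helpers used throughout the diff.
bool ndim_is_known(const ShapeSketch& s) { return s.ndim() != -1; }
bool dim_size_is_known(const ShapeSketch& s, int i) { return s[i] != -1; }
bool shape_is_known(const ShapeSketch& s) {
  if (!ndim_is_known(s)) return false;
  for (int i = 0; i < s.ndim(); ++i)
    if (!dim_size_is_known(s, i)) return false;
  return true;
}

int main() {
  ShapeSketch unknown{-1, {0}};
  // Old-style unsigned check: -1 converts to a huge unsigned value and the test passes.
  printf("%d\n", static_cast<unsigned int>(unknown.ndim()) >= 3U);   // prints 1
  // New signed check fails as intended for an unknown rank.
  printf("%d\n", unknown.ndim() >= 3);                               // prints 0
  ShapeSketch partial{3, {4, -1, 5}};
  // Rank is known but one dimension is not: partial inference can proceed,
  // yet the shape does not count as fully known.
  printf("%d %d\n", ndim_is_known(partial), shape_is_known(partial));  // prints 1 0
  return 0;
}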
2 changes: 1 addition & 1 deletion src/operator/instance_norm-inl.h
@@ -113,7 +113,7 @@ class InstanceNormOp : public Operator {
CHECK_EQ(in_data.size(), 3U);
CHECK_EQ(out_data.size(), 3U);

- CHECK_GE(in_data[instance_norm::kData].ndim(), 3U)
+ CHECK_GE(in_data[instance_norm::kData].ndim(), 3)
<< "InstanceNorm only supports input tensors of rank > 2.";

Stream<xpu> *s = ctx.get_stream<xpu>();
12 changes: 6 additions & 6 deletions src/operator/l2_normalization-inl.h
@@ -102,7 +102,7 @@ class L2NormalizationOp : public Operator {
norm = F<mxnet::op::mshadow_op::square_root>(norm);
out = data / broadcast<0>(norm, out.shape_);
} else if (param_.mode == l2_normalization::kChannel) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
@@ -120,7 +120,7 @@ class L2NormalizationOp : public Operator {
norm = F<mxnet::op::mshadow_op::square_root>(norm);
out = data / broadcast_with_axis(norm, 0, orig_shape[1]);
} else if (param_.mode == l2_normalization::kSpatial) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
@@ -174,7 +174,7 @@ class L2NormalizationOp : public Operator {
(grad_out - data * broadcast<0>(temp, data.shape_)) /
broadcast<0>(norm, data.shape_));
} else if (param_.mode == l2_normalization::kChannel) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
@@ -193,7 +193,7 @@ class L2NormalizationOp : public Operator {
(grad_out - data * broadcast_with_axis(temp, 0, orig_shape[1])) /
broadcast_with_axis(norm, 0, orig_shape[1]));
} else if (param_.mode == l2_normalization::kSpatial) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
@@ -273,12 +273,12 @@ class L2NormalizationProp : public OperatorProperty {
if (param_.mode == l2_normalization::kInstance) {
out_shape->push_back(Shape1(dshape[0]));
} else if (param_.mode == l2_normalization::kChannel) {
- CHECK_GE(dshape.ndim(), 3U) << "At lease 3 dimensions required in channel mode";
+ CHECK_GE(dshape.ndim(), 3) << "At lease 3 dimensions required in channel mode";
mxnet::TShape norm_shape = dshape;
norm_shape[1] = 1;
out_shape->push_back(norm_shape);
} else if (param_.mode == l2_normalization::kSpatial) {
- CHECK_GE(dshape.ndim(), 3U) << "At lease 3 dimensions required in spatial mode";
+ CHECK_GE(dshape.ndim(), 3) << "At lease 3 dimensions required in spatial mode";
out_shape->push_back(Shape2(dshape[0], dshape[1]));
} else {
return false;
4 changes: 2 additions & 2 deletions src/operator/l2_normalization.cc
@@ -70,7 +70,7 @@ class L2NormalizationOpCPU : public L2NormalizationOp<cpu, DType> {
}
}
} else if (this->param_.mode == l2_normalization::kChannel) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<cpu, 3, DType> data = in_data[l2_normalization::kData]
@@ -94,7 +94,7 @@ class L2NormalizationOpCPU : public L2NormalizationOp<cpu, DType> {
}
}
} else if (this->param_.mode == l2_normalization::kSpatial) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<cpu, 3, DType> data = in_data[l2_normalization::kData]
36 changes: 18 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -63,15 +63,15 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP
mkldnn::memory::dims strides(param.conv_param.kernel.ndim());
mkldnn::memory::dims padding(param.conv_param.kernel.ndim());
if (param.conv_param.kernel.ndim() == 1) {
- CHECK_GE(param.conv_param.stride.ndim(), 1U);
- CHECK_GE(param.conv_param.pad.ndim(), 1U);
- CHECK_GE(param.conv_param.dilate.ndim(), 1U);
+ CHECK_GE(param.conv_param.stride.ndim(), 1);
+ CHECK_GE(param.conv_param.pad.ndim(), 1);
+ CHECK_GE(param.conv_param.dilate.ndim(), 1);
strides[0] = param.conv_param.stride[0];
padding[0] = param.conv_param.pad[0];
} else if (param.conv_param.kernel.ndim() == 2) {
- CHECK_GE(param.conv_param.stride.ndim(), 2U);
- CHECK_GE(param.conv_param.pad.ndim(), 2U);
- CHECK_GE(param.conv_param.dilate.ndim(), 2U);
+ CHECK_GE(param.conv_param.stride.ndim(), 2);
+ CHECK_GE(param.conv_param.pad.ndim(), 2);
+ CHECK_GE(param.conv_param.dilate.ndim(), 2);
strides[0] = param.conv_param.stride[0];
strides[1] = param.conv_param.stride[1];
padding[0] = param.conv_param.pad[0];
@@ -169,15 +169,15 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
mkldnn::memory::dims strides(param.kernel.ndim());
mkldnn::memory::dims padding(param.kernel.ndim());
if (param.kernel.ndim() == 1) {
- CHECK_GE(param.stride.ndim(), 1U);
- CHECK_GE(param.pad.ndim(), 1U);
- CHECK_GE(param.dilate.ndim(), 1U);
+ CHECK_GE(param.stride.ndim(), 1);
+ CHECK_GE(param.pad.ndim(), 1);
+ CHECK_GE(param.dilate.ndim(), 1);
strides[0] = param.stride[0];
padding[0] = param.pad[0];
} else if (param.kernel.ndim() == 2) {
- CHECK_GE(param.stride.ndim(), 2U);
- CHECK_GE(param.pad.ndim(), 2U);
- CHECK_GE(param.dilate.ndim(), 2U);
+ CHECK_GE(param.stride.ndim(), 2);
+ CHECK_GE(param.pad.ndim(), 2);
+ CHECK_GE(param.dilate.ndim(), 2);
strides[0] = param.stride[0];
strides[1] = param.stride[1];
padding[0] = param.pad[0];
@@ -237,15 +237,15 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
mkldnn::memory::dims strides(param.kernel.ndim());
mkldnn::memory::dims padding(param.kernel.ndim());
if (param.kernel.ndim() == 1) {
- CHECK_GE(param.stride.ndim(), 1U);
- CHECK_GE(param.pad.ndim(), 1U);
- CHECK_GE(param.dilate.ndim(), 1U);
+ CHECK_GE(param.stride.ndim(), 1);
+ CHECK_GE(param.pad.ndim(), 1);
+ CHECK_GE(param.dilate.ndim(), 1);
strides[0] = param.stride[0];
padding[0] = param.pad[0];
} else if (param.kernel.ndim() == 2) {
- CHECK_GE(param.stride.ndim(), 2U);
- CHECK_GE(param.pad.ndim(), 2U);
- CHECK_GE(param.dilate.ndim(), 2U);
+ CHECK_GE(param.stride.ndim(), 2);
+ CHECK_GE(param.pad.ndim(), 2);
+ CHECK_GE(param.dilate.ndim(), 2);
strides[0] = param.stride[0];
strides[1] = param.stride[1];
padding[0] = param.pad[0];
18 changes: 9 additions & 9 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -90,9 +90,9 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
- CHECK_GE(param.stride.ndim(), 2U);
- CHECK_GE(param.pad.ndim(), 2U);
- CHECK_GE(param.dilate.ndim(), 2U);
+ CHECK_GE(param.stride.ndim(), 2);
+ CHECK_GE(param.pad.ndim(), 2);
+ CHECK_GE(param.dilate.ndim(), 2);
mkldnn::memory::dims strides{0, 0};
strides[0] = param.stride[0];
strides[1] = param.stride[1];
@@ -128,9 +128,9 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
- CHECK_GE(param.stride.ndim(), 2U);
- CHECK_GE(param.pad.ndim(), 2U);
- CHECK_GE(param.dilate.ndim(), 2U);
+ CHECK_GE(param.stride.ndim(), 2);
+ CHECK_GE(param.pad.ndim(), 2);
+ CHECK_GE(param.dilate.ndim(), 2);
mkldnn::memory::dims strides{0, 0};
strides[0] = param.stride[0];
strides[1] = param.stride[1];
@@ -153,9 +153,9 @@ GetDeconvBwdWeightsImpl(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
- CHECK_GE(param.stride.ndim(), 2U);
- CHECK_GE(param.pad.ndim(), 2U);
- CHECK_GE(param.dilate.ndim(), 2U);
+ CHECK_GE(param.stride.ndim(), 2);
+ CHECK_GE(param.pad.ndim(), 2);
+ CHECK_GE(param.dilate.ndim(), 2);
mkldnn::memory::dims strides{0, 0};
strides[0] = param.stride[0];
strides[1] = param.stride[1];
4 changes: 2 additions & 2 deletions src/operator/nn/pooling.cc
@@ -106,11 +106,11 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
CHECK(param.pad[0] == 0 && param.pad[1] == 0 && param.pad[2] == 0)
<< "Same pooling convention disables the use of pad parameter.";
}
- CHECK_GE(dshape.ndim(), 3U)
+ CHECK_GE(dshape.ndim(), 3)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< " Or 4D in (batch, channel, y, x) "
<< " Or 5D in (batch, channel, d, y, x)";
- CHECK_LE(dshape.ndim(), 5U)
+ CHECK_LE(dshape.ndim(), 5)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< " Or 4D in (batch, channel, y, x) "
<< " Or 5D in (batch, channel, d, y, x)";
4 changes: 2 additions & 2 deletions src/operator/pooling_v1-inl.h
@@ -243,9 +243,9 @@ class PoolingV1Prop : public OperatorProperty {
mxnet::ShapeVector *aux_shape) const override {
CHECK_EQ(in_shape->size(), 1U);
const mxnet::TShape &dshape = (*in_shape)[0];
- CHECK_GE(dshape.ndim(), 4U) << "Pooling: Input data should be 4D in (batch, channel, y, x) "
+ CHECK_GE(dshape.ndim(), 4) << "Pooling: Input data should be 4D in (batch, channel, y, x) "
<< "Or 5D in (batch, channel, d, y, x)";
- CHECK_LE(dshape.ndim(), 5U) << "Pooling: Input data should be 4D in (batch, channel, y, x) "
+ CHECK_LE(dshape.ndim(), 5) << "Pooling: Input data should be 4D in (batch, channel, y, x) "
<< "Or 5D in (batch, channel, d, y, x)";
mxnet::TShape oshape = dshape;
if (dshape.ndim() == -1) return false;
8 changes: 4 additions & 4 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -304,12 +304,12 @@ inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
- if (!shape_is_known((*in_attrs)[0])) return false;
+ if (!ndim_is_known((*in_attrs)[0])) return false;
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
ReduceAxesShapeImpl((*in_attrs)[0], param.axis,
param.keepdims, param.exclude));
- return true;
+ return shape_is_known((*out_attrs)[0]);
}

inline bool ReduceMinMaxAxesShape(const nnvm::NodeAttrs& attrs,
@@ -1373,7 +1373,7 @@ inline bool PickOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
const mxnet::TShape& ishape = (*in_attrs)[0];
- if (ishape.ndim() == 0) return false;
+ if (!ndim_is_known(ishape)) return false;
const PickParam& param = nnvm::get<PickParam>(attrs.parsed);
if (!param.axis) LOG(FATAL)
<< "axis=None is not supported by pick yet. Must specify an axis.";
@@ -1387,7 +1387,7 @@ inline bool PickOpShape(const nnvm::NodeAttrs& attrs,
SHAPE_ASSIGN_CHECK(*in_attrs, 1,
ReduceAxisShapeImpl(ishape, param.axis, false));
}
- return true;
+ return shape_is_known((*out_attrs)[0]);
}

inline bool PickOpType(const nnvm::NodeAttrs& attrs,
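The ReduceAxesShape and PickOpShape changes above show the return-value pattern this commit applies across operators: give up early only when the input rank is unknown, propagate whatever dimensions are known (unknown ones stay -1), and return true only once the output shape is fully inferred, so the framework can run the shape pass again when more information arrives. A minimal sketch of that convention, using a plain std::vector<int> as a toy shape rather than MXNet's types:

#include <cstdio>
#include <vector>

using ShapeVec = std::vector<int>;   // toy shape: a -1 entry means "size unknown"

bool shape_is_known(const ShapeVec& s) {
  for (int d : s) if (d == -1) return false;
  return true;
}

// Sketch of the inference contract assumed here: propagate what is known,
// return true only when the output is fully determined.
bool ReduceLikeShape(const ShapeVec& in, ShapeVec* out, int axis) {
  if (in.empty()) return false;        // rank unknown in this toy encoding: nothing to do yet
  *out = in;
  out->erase(out->begin() + axis);     // drop the reduced axis (keepdims=false)
  return shape_is_known(*out);         // partially known output => false, retry later
}

int main() {
  ShapeVec out;
  bool done = ReduceLikeShape({2, -1, 5}, &out, 0);          // input (2, ?, 5), reduce axis 0
  printf("done=%d out=(%d, %d)\n", done, out[0], out[1]);    // done=0 out=(-1, 5)
  return 0;
}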
1 change: 1 addition & 0 deletions src/operator/tensor/control_flow_op.h
@@ -175,6 +175,7 @@ inline bool WhereOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 3U)
<< "where operator takes 3 arguments (" << in_attrs->size() << " given)";
CHECK_EQ(out_attrs->size(), 1U);
+ if (!mxnet::shape_is_known((*in_attrs)[0])) return false;

mxnet::TShape tshape((*in_attrs)[1]);
if (!shape_assign(&tshape, (*in_attrs)[2])) return false;
15 changes: 13 additions & 2 deletions src/operator/tensor/dot-inl.h
@@ -1207,6 +1207,9 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& lshape = (*in_attrs)[0];
mxnet::TShape& rshape = (*in_attrs)[1];
+ if (!ndim_is_known(lshape) || !ndim_is_known(rshape)) return false;
+ CHECK_GT(lshape.ndim(), 0) << "scalar tensor is not supported by this operator.";
+ CHECK_GT(rshape.ndim(), 0) << "scalar tensor is not supported by this operator.";
if (lshape.ndim() == 1 && rshape.ndim() == 1) {
CHECK(!param.transpose_a && !param.transpose_b) << "Cannot transpose vectors";
CHECK_EQ(lshape[0], rshape[0]) << "dot shape error: " << lshape << " X " << rshape;
@@ -1243,7 +1246,8 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape oshape(buf.begin(), buf.end());
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
}
- return true;
+ // return true if output shape is fully inferred
+ return shape_is_known((*out_attrs)[0]);
}

template<typename xpu>
@@ -1479,7 +1483,13 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
mxnet::TShape& lshape = (*in_attrs)[0];
mxnet::TShape& rshape = (*in_attrs)[1];
+ // return false if lhs and rhs both have fully unknown shape
+ if (!ndim_is_known(lshape) || !ndim_is_known(rshape)) return false;
if (lshape.ndim() == 3 && rshape.ndim() == 3) {
+ // only partially infer shape if last dim of lhs and second dim of rhs is known
+ bool last_dim_known = dim_size_is_known(lshape, 2);
+ bool second_dim_known = dim_size_is_known(rshape, 1);
+ if ( !last_dim_known || !second_dim_known) return false;
CHECK(lshape[0] == rshape[0])
<< "batch_dot shape error(batch_size must be equal): " << lshape << " X " << rshape
<< " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b;
@@ -1495,7 +1505,8 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,
LOG(FATAL) << "batch_dot currently only support 3D*3D array"
<< lshape << " v.s. " << rshape;
}
- return true;
+ // return true if output shape is fully inferred
+ return shape_is_known((*out_attrs)[0]);
}

} // namespace op
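For batch_dot, the updated BatchDotShape defers inference unless the contracted dimensions (the last dimension of lhs and the second dimension of rhs) are known, presumably so they can be validated against each other before an output shape is committed. With lshape = (4, 3, -1) and rshape = (4, -1, 5), for example, the function now returns false rather than proceeding with unknown inner dimensions. A toy illustration of that decision for the non-transposed case (hypothetical driver code, not part of MXNet):

#include <cstdio>

int main() {
  int lshape[3] = {4, 3, -1};   // lhs: batch 4, 3 rows, inner dimension unknown
  int rshape[3] = {4, -1, 5};   // rhs: batch 4, inner dimension unknown, 5 columns
  bool last_dim_known   = lshape[2] != -1;
  bool second_dim_known = rshape[1] != -1;
  if (!last_dim_known || !second_dim_known) {
    // Mirrors the new early return above: the contracted dimensions cannot be
    // checked for consistency yet, so shape inference is deferred.
    printf("defer inference\n");
  } else {
    printf("output shape: (%d, %d, %d)\n", lshape[0], lshape[1], rshape[2]);
  }
  return 0;
}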
4 changes: 2 additions & 2 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -452,7 +452,7 @@ void BinaryBroadcastComputeSparseEx(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
- CHECK_LE(inputs[1].shape().ndim(), 2U)
+ CHECK_LE(inputs[1].shape().ndim(), 2)
<< "input dense matrix should have less than or equal to 2 dimensions";
if (req[0] == kNullOp) return;
const NDArray& lhs = inputs[0];
@@ -488,7 +488,7 @@ void BinaryBroadcastComputeDenseEx(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
- CHECK_LE(inputs[1].shape().ndim(), 2U)
+ CHECK_LE(inputs[1].shape().ndim(), 2)
<< "input dense matrix should have less than or equal to 2 dimensions";
if (req[0] == kNullOp) return;
const NDArray& lhs = inputs[0];
4 changes: 2 additions & 2 deletions src/operator/tensor/indexing_op.h
@@ -145,7 +145,7 @@ inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *out_attrs) {
using namespace mshadow;
const mxnet::TShape &dshape = (*in_attrs)[embedding::kData];
- if (!shape_is_known(dshape)) return false;
+ if (!ndim_is_known(dshape)) return false;
const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
SHAPE_ASSIGN_CHECK(*in_attrs, embedding::kWeight, Shape2(param.input_dim,
param.output_dim));
@@ -1075,7 +1075,7 @@ inline bool BatchTakeOpShape(const nnvm::NodeAttrs& attrs,
SHAPE_ASSIGN_CHECK(*in_attrs, 1, (*out_attrs)[0]);
}
if ((*in_attrs)[0].ndim() == 0) return false;
- CHECK_GE((*in_attrs)[0].ndim(), 2U) << "Data array must have at least 2 dimensional";
+ CHECK_GE((*in_attrs)[0].ndim(), 2) << "Data array must have at least 2 dimensional";
if ((*out_attrs)[0].ndim() == 0) return false;
CHECK_EQ((*in_attrs)[0].Size()/(*in_attrs)[0][(*in_attrs)[0].ndim()-1],
(*out_attrs)[0].Size())
2 changes: 1 addition & 1 deletion src/operator/tensor/matrix_op-inl.h
@@ -344,7 +344,7 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& shp = (*in_attrs)[0];
- CHECK_LE(shp.ndim(), 6U) << "Transpose support at most 6 dimensions";
+ CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
mxnet::TShape ret(shp.ndim(), -1);
if (param.axes.ndim() == 0) {
for (int i = 0; i < shp.ndim(); ++i) {