From 5854b98c0a2c3755af2a8df3235af0266f7bfcd5 Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 21 May 2019 11:40:46 -0700 Subject: [PATCH] Fix infer shape partial after unknown shape changed to -1 (#14869) * change check and shape_is_known * rever some changes * revert * revert * revert * add test * add more tests * update test dot * fix test * update reduce axes * fix lint * update check * fix lint * address comments * remove invalid test case * run all tests * update test case --- src/operator/instance_norm-inl.h | 2 +- src/operator/l2_normalization-inl.h | 12 +-- src/operator/l2_normalization.cc | 4 +- src/operator/nn/mkldnn/mkldnn_convolution.cc | 36 +++---- .../nn/mkldnn/mkldnn_deconvolution.cc | 18 ++-- src/operator/nn/pooling.cc | 4 +- src/operator/pooling_v1-inl.h | 4 +- src/operator/tensor/broadcast_reduce_op.h | 8 +- src/operator/tensor/control_flow_op.h | 1 + src/operator/tensor/dot-inl.h | 15 ++- .../tensor/elemwise_binary_broadcast_op.h | 4 +- src/operator/tensor/indexing_op.h | 4 +- src/operator/tensor/matrix_op-inl.h | 2 +- tests/python/unittest/test_infer_shape.py | 95 +++++++++++++++++-- 14 files changed, 149 insertions(+), 60 deletions(-) diff --git a/src/operator/instance_norm-inl.h b/src/operator/instance_norm-inl.h index b7e579e2d066..c71cbe043afd 100644 --- a/src/operator/instance_norm-inl.h +++ b/src/operator/instance_norm-inl.h @@ -113,7 +113,7 @@ class InstanceNormOp : public Operator { CHECK_EQ(in_data.size(), 3U); CHECK_EQ(out_data.size(), 3U); - CHECK_GE(in_data[instance_norm::kData].ndim(), 3U) + CHECK_GE(in_data[instance_norm::kData].ndim(), 3) << "InstanceNorm only supports input tensors of rank > 2."; Stream *s = ctx.get_stream(); diff --git a/src/operator/l2_normalization-inl.h b/src/operator/l2_normalization-inl.h index 975e81f78c25..210d91823075 100644 --- a/src/operator/l2_normalization-inl.h +++ b/src/operator/l2_normalization-inl.h @@ -102,7 +102,7 @@ class L2NormalizationOp : public Operator { norm = F(norm); out = data / broadcast<0>(norm, out.shape_); } else if (param_.mode == l2_normalization::kChannel) { - CHECK_GE(orig_shape.ndim(), 3U); + CHECK_GE(orig_shape.ndim(), 3); Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], orig_shape.ProdShape(2, orig_shape.ndim())); Tensor data = in_data[l2_normalization::kData] @@ -120,7 +120,7 @@ class L2NormalizationOp : public Operator { norm = F(norm); out = data / broadcast_with_axis(norm, 0, orig_shape[1]); } else if (param_.mode == l2_normalization::kSpatial) { - CHECK_GE(orig_shape.ndim(), 3U); + CHECK_GE(orig_shape.ndim(), 3); Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], orig_shape.ProdShape(2, orig_shape.ndim())); Tensor data = in_data[l2_normalization::kData] @@ -174,7 +174,7 @@ class L2NormalizationOp : public Operator { (grad_out - data * broadcast<0>(temp, data.shape_)) / broadcast<0>(norm, data.shape_)); } else if (param_.mode == l2_normalization::kChannel) { - CHECK_GE(orig_shape.ndim(), 3U); + CHECK_GE(orig_shape.ndim(), 3); Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], orig_shape.ProdShape(2, orig_shape.ndim())); Tensor data = out_data[l2_normalization::kOut] @@ -193,7 +193,7 @@ class L2NormalizationOp : public Operator { (grad_out - data * broadcast_with_axis(temp, 0, orig_shape[1])) / broadcast_with_axis(norm, 0, orig_shape[1])); } else if (param_.mode == l2_normalization::kSpatial) { - CHECK_GE(orig_shape.ndim(), 3U); + CHECK_GE(orig_shape.ndim(), 3); Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], orig_shape.ProdShape(2, orig_shape.ndim())); Tensor data = out_data[l2_normalization::kOut] @@ -273,12 +273,12 @@ class L2NormalizationProp : public OperatorProperty { if (param_.mode == l2_normalization::kInstance) { out_shape->push_back(Shape1(dshape[0])); } else if (param_.mode == l2_normalization::kChannel) { - CHECK_GE(dshape.ndim(), 3U) << "At lease 3 dimensions required in channel mode"; + CHECK_GE(dshape.ndim(), 3) << "At lease 3 dimensions required in channel mode"; mxnet::TShape norm_shape = dshape; norm_shape[1] = 1; out_shape->push_back(norm_shape); } else if (param_.mode == l2_normalization::kSpatial) { - CHECK_GE(dshape.ndim(), 3U) << "At lease 3 dimensions required in spatial mode"; + CHECK_GE(dshape.ndim(), 3) << "At lease 3 dimensions required in spatial mode"; out_shape->push_back(Shape2(dshape[0], dshape[1])); } else { return false; diff --git a/src/operator/l2_normalization.cc b/src/operator/l2_normalization.cc index 92307af814d2..cbe2caeb394e 100644 --- a/src/operator/l2_normalization.cc +++ b/src/operator/l2_normalization.cc @@ -70,7 +70,7 @@ class L2NormalizationOpCPU : public L2NormalizationOp { } } } else if (this->param_.mode == l2_normalization::kChannel) { - CHECK_GE(orig_shape.ndim(), 3U); + CHECK_GE(orig_shape.ndim(), 3); Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], orig_shape.ProdShape(2, orig_shape.ndim())); Tensor data = in_data[l2_normalization::kData] @@ -94,7 +94,7 @@ class L2NormalizationOpCPU : public L2NormalizationOp { } } } else if (this->param_.mode == l2_normalization::kSpatial) { - CHECK_GE(orig_shape.ndim(), 3U); + CHECK_GE(orig_shape.ndim(), 3); Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], orig_shape.ProdShape(2, orig_shape.ndim())); Tensor data = in_data[l2_normalization::kData] diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index d32a6a343d7d..a394edeef841 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -63,15 +63,15 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP mkldnn::memory::dims strides(param.conv_param.kernel.ndim()); mkldnn::memory::dims padding(param.conv_param.kernel.ndim()); if (param.conv_param.kernel.ndim() == 1) { - CHECK_GE(param.conv_param.stride.ndim(), 1U); - CHECK_GE(param.conv_param.pad.ndim(), 1U); - CHECK_GE(param.conv_param.dilate.ndim(), 1U); + CHECK_GE(param.conv_param.stride.ndim(), 1); + CHECK_GE(param.conv_param.pad.ndim(), 1); + CHECK_GE(param.conv_param.dilate.ndim(), 1); strides[0] = param.conv_param.stride[0]; padding[0] = param.conv_param.pad[0]; } else if (param.conv_param.kernel.ndim() == 2) { - CHECK_GE(param.conv_param.stride.ndim(), 2U); - CHECK_GE(param.conv_param.pad.ndim(), 2U); - CHECK_GE(param.conv_param.dilate.ndim(), 2U); + CHECK_GE(param.conv_param.stride.ndim(), 2); + CHECK_GE(param.conv_param.pad.ndim(), 2); + CHECK_GE(param.conv_param.dilate.ndim(), 2); strides[0] = param.conv_param.stride[0]; strides[1] = param.conv_param.stride[1]; padding[0] = param.conv_param.pad[0]; @@ -169,15 +169,15 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData( mkldnn::memory::dims strides(param.kernel.ndim()); mkldnn::memory::dims padding(param.kernel.ndim()); if (param.kernel.ndim() == 1) { - CHECK_GE(param.stride.ndim(), 1U); - CHECK_GE(param.pad.ndim(), 1U); - CHECK_GE(param.dilate.ndim(), 1U); + CHECK_GE(param.stride.ndim(), 1); + CHECK_GE(param.pad.ndim(), 1); + CHECK_GE(param.dilate.ndim(), 1); strides[0] = param.stride[0]; padding[0] = param.pad[0]; } else if (param.kernel.ndim() == 2) { - CHECK_GE(param.stride.ndim(), 2U); - CHECK_GE(param.pad.ndim(), 2U); - CHECK_GE(param.dilate.ndim(), 2U); + CHECK_GE(param.stride.ndim(), 2); + CHECK_GE(param.pad.ndim(), 2); + CHECK_GE(param.dilate.ndim(), 2); strides[0] = param.stride[0]; strides[1] = param.stride[1]; padding[0] = param.pad[0]; @@ -237,15 +237,15 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights( mkldnn::memory::dims strides(param.kernel.ndim()); mkldnn::memory::dims padding(param.kernel.ndim()); if (param.kernel.ndim() == 1) { - CHECK_GE(param.stride.ndim(), 1U); - CHECK_GE(param.pad.ndim(), 1U); - CHECK_GE(param.dilate.ndim(), 1U); + CHECK_GE(param.stride.ndim(), 1); + CHECK_GE(param.pad.ndim(), 1); + CHECK_GE(param.dilate.ndim(), 1); strides[0] = param.stride[0]; padding[0] = param.pad[0]; } else if (param.kernel.ndim() == 2) { - CHECK_GE(param.stride.ndim(), 2U); - CHECK_GE(param.pad.ndim(), 2U); - CHECK_GE(param.dilate.ndim(), 2U); + CHECK_GE(param.stride.ndim(), 2); + CHECK_GE(param.pad.ndim(), 2); + CHECK_GE(param.dilate.ndim(), 2); strides[0] = param.stride[0]; strides[1] = param.stride[1]; padding[0] = param.pad[0]; diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index 4da48fa3f83c..aec5d13c5de9 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -90,9 +90,9 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl( auto weight_md = GetWeightDesc(weights, param.num_group); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.stride.ndim(), 2U); - CHECK_GE(param.pad.ndim(), 2U); - CHECK_GE(param.dilate.ndim(), 2U); + CHECK_GE(param.stride.ndim(), 2); + CHECK_GE(param.pad.ndim(), 2); + CHECK_GE(param.dilate.ndim(), 2); mkldnn::memory::dims strides{0, 0}; strides[0] = param.stride[0]; strides[1] = param.stride[1]; @@ -128,9 +128,9 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl( auto weight_md = GetWeightDesc(weights, param.num_group); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.stride.ndim(), 2U); - CHECK_GE(param.pad.ndim(), 2U); - CHECK_GE(param.dilate.ndim(), 2U); + CHECK_GE(param.stride.ndim(), 2); + CHECK_GE(param.pad.ndim(), 2); + CHECK_GE(param.dilate.ndim(), 2); mkldnn::memory::dims strides{0, 0}; strides[0] = param.stride[0]; strides[1] = param.stride[1]; @@ -153,9 +153,9 @@ GetDeconvBwdWeightsImpl( auto weight_md = GetWeightDesc(weights, param.num_group); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.stride.ndim(), 2U); - CHECK_GE(param.pad.ndim(), 2U); - CHECK_GE(param.dilate.ndim(), 2U); + CHECK_GE(param.stride.ndim(), 2); + CHECK_GE(param.pad.ndim(), 2); + CHECK_GE(param.dilate.ndim(), 2); mkldnn::memory::dims strides{0, 0}; strides[0] = param.stride[0]; strides[1] = param.stride[1]; diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 3e081c9a0552..870557756128 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -106,11 +106,11 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, CHECK(param.pad[0] == 0 && param.pad[1] == 0 && param.pad[2] == 0) << "Same pooling convention disables the use of pad parameter."; } - CHECK_GE(dshape.ndim(), 3U) + CHECK_GE(dshape.ndim(), 3) << "Pooling: Input data should be 3D in (batch, channel, x)" << " Or 4D in (batch, channel, y, x) " << " Or 5D in (batch, channel, d, y, x)"; - CHECK_LE(dshape.ndim(), 5U) + CHECK_LE(dshape.ndim(), 5) << "Pooling: Input data should be 3D in (batch, channel, x)" << " Or 4D in (batch, channel, y, x) " << " Or 5D in (batch, channel, d, y, x)"; diff --git a/src/operator/pooling_v1-inl.h b/src/operator/pooling_v1-inl.h index 4241b08a0c5e..efd211312093 100644 --- a/src/operator/pooling_v1-inl.h +++ b/src/operator/pooling_v1-inl.h @@ -243,9 +243,9 @@ class PoolingV1Prop : public OperatorProperty { mxnet::ShapeVector *aux_shape) const override { CHECK_EQ(in_shape->size(), 1U); const mxnet::TShape &dshape = (*in_shape)[0]; - CHECK_GE(dshape.ndim(), 4U) << "Pooling: Input data should be 4D in (batch, channel, y, x) " + CHECK_GE(dshape.ndim(), 4) << "Pooling: Input data should be 4D in (batch, channel, y, x) " << "Or 5D in (batch, channel, d, y, x)"; - CHECK_LE(dshape.ndim(), 5U) << "Pooling: Input data should be 4D in (batch, channel, y, x) " + CHECK_LE(dshape.ndim(), 5) << "Pooling: Input data should be 4D in (batch, channel, y, x) " << "Or 5D in (batch, channel, d, y, x)"; mxnet::TShape oshape = dshape; if (dshape.ndim() == -1) return false; diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index f7d9f13fd869..c33447254644 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -304,12 +304,12 @@ inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - if (!shape_is_known((*in_attrs)[0])) return false; + if (!ndim_is_known((*in_attrs)[0])) return false; const ReduceAxesParam& param = nnvm::get(attrs.parsed); SHAPE_ASSIGN_CHECK(*out_attrs, 0, ReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims, param.exclude)); - return true; + return shape_is_known((*out_attrs)[0]); } inline bool ReduceMinMaxAxesShape(const nnvm::NodeAttrs& attrs, @@ -1373,7 +1373,7 @@ inline bool PickOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 2); CHECK_EQ(out_attrs->size(), 1); const mxnet::TShape& ishape = (*in_attrs)[0]; - if (ishape.ndim() == 0) return false; + if (!ndim_is_known(ishape)) return false; const PickParam& param = nnvm::get(attrs.parsed); if (!param.axis) LOG(FATAL) << "axis=None is not supported by pick yet. Must specify an axis."; @@ -1387,7 +1387,7 @@ inline bool PickOpShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*in_attrs, 1, ReduceAxisShapeImpl(ishape, param.axis, false)); } - return true; + return shape_is_known((*out_attrs)[0]); } inline bool PickOpType(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/tensor/control_flow_op.h b/src/operator/tensor/control_flow_op.h index 96696b244bc3..8fda3344d8f1 100644 --- a/src/operator/tensor/control_flow_op.h +++ b/src/operator/tensor/control_flow_op.h @@ -175,6 +175,7 @@ inline bool WhereOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 3U) << "where operator takes 3 arguments (" << in_attrs->size() << " given)"; CHECK_EQ(out_attrs->size(), 1U); + if (!mxnet::shape_is_known((*in_attrs)[0])) return false; mxnet::TShape tshape((*in_attrs)[1]); if (!shape_assign(&tshape, (*in_attrs)[2])) return false; diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index f81eb9c04f3a..77e8e36bbef8 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -1207,6 +1207,9 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& lshape = (*in_attrs)[0]; mxnet::TShape& rshape = (*in_attrs)[1]; + if (!ndim_is_known(lshape) || !ndim_is_known(rshape)) return false; + CHECK_GT(lshape.ndim(), 0) << "scalar tensor is not supported by this operator."; + CHECK_GT(rshape.ndim(), 0) << "scalar tensor is not supported by this operator."; if (lshape.ndim() == 1 && rshape.ndim() == 1) { CHECK(!param.transpose_a && !param.transpose_b) << "Cannot transpose vectors"; CHECK_EQ(lshape[0], rshape[0]) << "dot shape error: " << lshape << " X " << rshape; @@ -1243,7 +1246,8 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs, mxnet::TShape oshape(buf.begin(), buf.end()); SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); } - return true; + // return true if output shape is fully inferred + return shape_is_known((*out_attrs)[0]); } template @@ -1479,7 +1483,13 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs, const DotParam& param = nnvm::get(attrs.parsed); mxnet::TShape& lshape = (*in_attrs)[0]; mxnet::TShape& rshape = (*in_attrs)[1]; + // return false if lhs and rhs both have fully unknown shape + if (!ndim_is_known(lshape) || !ndim_is_known(rshape)) return false; if (lshape.ndim() == 3 && rshape.ndim() == 3) { + // only partially infer shape if last dim of lhs and second dim of rhs is known + bool last_dim_known = dim_size_is_known(lshape, 2); + bool second_dim_known = dim_size_is_known(rshape, 1); + if ( !last_dim_known || !second_dim_known) return false; CHECK(lshape[0] == rshape[0]) << "batch_dot shape error(batch_size must be equal): " << lshape << " X " << rshape << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; @@ -1495,7 +1505,8 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs, LOG(FATAL) << "batch_dot currently only support 3D*3D array" << lshape << " v.s. " << rshape; } - return true; + // return true if output shape is fully inferred + return shape_is_known((*out_attrs)[0]); } } // namespace op diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 73019fa8389b..f84767dd4b2f 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -452,7 +452,7 @@ void BinaryBroadcastComputeSparseEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); - CHECK_LE(inputs[1].shape().ndim(), 2U) + CHECK_LE(inputs[1].shape().ndim(), 2) << "input dense matrix should have less than or equal to 2 dimensions"; if (req[0] == kNullOp) return; const NDArray& lhs = inputs[0]; @@ -488,7 +488,7 @@ void BinaryBroadcastComputeDenseEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); - CHECK_LE(inputs[1].shape().ndim(), 2U) + CHECK_LE(inputs[1].shape().ndim(), 2) << "input dense matrix should have less than or equal to 2 dimensions"; if (req[0] == kNullOp) return; const NDArray& lhs = inputs[0]; diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index e8c5e884588b..84b6a65dd29e 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -145,7 +145,7 @@ inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *out_attrs) { using namespace mshadow; const mxnet::TShape &dshape = (*in_attrs)[embedding::kData]; - if (!shape_is_known(dshape)) return false; + if (!ndim_is_known(dshape)) return false; const ParamType& param = nnvm::get(attrs.parsed); SHAPE_ASSIGN_CHECK(*in_attrs, embedding::kWeight, Shape2(param.input_dim, param.output_dim)); @@ -1075,7 +1075,7 @@ inline bool BatchTakeOpShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*in_attrs, 1, (*out_attrs)[0]); } if ((*in_attrs)[0].ndim() == 0) return false; - CHECK_GE((*in_attrs)[0].ndim(), 2U) << "Data array must have at least 2 dimensional"; + CHECK_GE((*in_attrs)[0].ndim(), 2) << "Data array must have at least 2 dimensional"; if ((*out_attrs)[0].ndim() == 0) return false; CHECK_EQ((*in_attrs)[0].Size()/(*in_attrs)[0][(*in_attrs)[0].ndim()-1], (*out_attrs)[0].Size()) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index e9f3a40afed7..50cb1ae4d1df 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -344,7 +344,7 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& shp = (*in_attrs)[0]; - CHECK_LE(shp.ndim(), 6U) << "Transpose support at most 6 dimensions"; + CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; mxnet::TShape ret(shp.ndim(), -1); if (param.axes.ndim() == 0) { for (int i = 0; i < shp.ndim(); ++i) { diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py index 612861bd8303..2bf7e8bf9d71 100644 --- a/tests/python/unittest/test_infer_shape.py +++ b/tests/python/unittest/test_infer_shape.py @@ -162,13 +162,90 @@ def test_shape_completely_unknown(): assert out_shapes[0] is None +def test_dot_partial_shape(): + x = mx.sym.Variable("x") + y = mx.sym.Variable("y") + z = mx.sym.dot(x, y) + # batch size(first dim) of lhs unknown + _, result_shape, _ = z.infer_shape_partial(x=(0, 3, 4), y=(4, 5)) + assert result_shape == [(0, 3, 5)] + with mx.np_compat(True): + _, result_shape, _ = z.infer_shape_partial(x=(-1, 3, 4), y=(4, 5)) + assert result_shape == [(-1, 3, 5)] + + +def test_batch_dot_partial_shape(): + x = mx.sym.Variable("x") + y = mx.sym.Variable("y") + z = mx.sym.batch_dot(x, y) + # lhs and rhs batch size unknown + _, result_shape, _ = z.infer_shape_partial(x=(0, 3, 4), y=(0, 4, 5)) + assert result_shape == [(0, 3, 5)] + # rhs second dim unknown + _, result_shape, _ = z.infer_shape_partial(x=(0, 3, 4), y=(0, 0, 5)) + assert result_shape == [()] + with mx.np_compat(True): + _, result_shape, _ = z.infer_shape_partial(x=(-1, 3, 4), y=(-1, 4, 5)) + assert result_shape == [(-1, 3, 5)] + _, result_shape, _ = z.infer_shape_partial(x=(-1, 3, 4), y=(-1, -1, 5)) + assert result_shape == [None] + + +def test_embedding_partial_shape(): + # testing embedding with batch size unknown + x = mx.sym.Variable("x") + w = mx.sym.Variable("w") + y = mx.sym.Embedding(data=x, weight=w, input_dim=100, output_dim=10) + _, result_shape, _ = y.infer_shape_partial(x=(0, 5), w=(100, 10)) + assert result_shape == [(0, 5, 10)] + with mx.np_compat(True): + _, result_shape, _ = y.infer_shape_partial(x=(-1, 5), w=(100, 10)) + assert result_shape == [(-1, 5, 10)] + + +def test_transpose_partial_shape(): + # test converting tensor shape + # from channels first to channels last + # with batch size unknown + axes = [0, 3, 2, 1] + x = mx.sym.Variable("x") + y = mx.sym.transpose(x, axes=axes) + _, result, _ = y.infer_shape_partial(x=(0, 3, 224, 224)) + assert result == [(0, 224, 224, 3)] + + with mx.np_compat(True): + _, result, _ = y.infer_shape_partial(x=(-1, 3, 224, 224)) + assert result == [(-1, 224, 224, 3)] + + +def test_pick_partial_shape(): + x = mx.sym.Variable("x") + index = mx.sym.Variable("index") + y = mx.sym.pick(x, index, axis=1) + # batch size unknown + _, result, _ = y.infer_shape_partial(x=(0, 3, 3), index=(0, 3,)) + assert result == [(0, 3)] + with mx.np_compat(True): + _, result, _ = y.infer_shape_partial(x=(-1, 3, 3), index=(-1, 3,)) + assert result == [(-1, 3)] + + +def test_where_partial_shape(): + x = mx.sym.Variable("x") + y = mx.sym.Variable("y") + cond = mx.sym.Variable("cond") + where_op = mx.sym.where(cond, x, y) + # condition must be fully known to infer shape + _, result, _ = where_op.infer_shape_partial(cond=(0, 2), x=(0, 2), y =(0, 2)) + assert result == [()] + _, result, _ = where_op.infer_shape_partial(cond=(0,), x=(2, 2), y =(2, 2)) + assert result == [()] + with mx.np_compat(True): + _, result, _ = where_op.infer_shape_partial(cond=(-1, 2), x=(-1, 2), y =(-1, 2)) + assert result == [None] + _, result, _ = where_op.infer_shape_partial(cond=(-1,), x=(2, 2), y=(2, 2)) + assert result == [None] + if __name__ == "__main__": - test_mlp2_infer_shape() - test_mlp2_infer_error() - test_backward_infer() - test_incomplete_infer_elewise() - test_incomplete_infer_mlp() - test_incomplete_infer_slicechannel() - test_incomplete_infer_convolution() - test_incomplete_infer_concat() - test_shape_completely_unknown() + import nose + nose.runmodule()