diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 74c563afceb1..7012a3c22f50 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -50,7 +50,7 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, "Needed input:[data, parameters, state], got in_shape->size(): " << in_shape->size(); } const TShape &dshape = (*in_shape)[rnn_enum::kData]; - if (dshape.ndim() == 0) return false; + if (!mxnet::ndim_is_known(dshape)) return false; CHECK_EQ(dshape.ndim(), 3U) \ << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]"; // data: [sequence len, batch, input dimension]