diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc index 0fc9f20703af..81d96bb70670 100644 --- a/src/operator/nn/mkldnn/mkldnn_reshape.cc +++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc @@ -127,7 +127,8 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs, const NDArray &output) { // For mkldnn non-supported input, it shouldn't hold mkldnn memory, so let's simply fallback to // naive implement. - if (input.shape().ndim() > 4 || !SupportMKLDNNQuantize(input.dtype())) { + const int input_ndims = input.shape().ndim(); + if ((input_ndims < 1 || input_ndims > 4) || !SupportMKLDNNQuantize(input.dtype())) { if (req != kWriteInplace) { FallBackCompute(UnaryOp::IdentityCompute, attrs, ctx, {input}, {req}, {output}); }