diff --git a/plugin/warpctc/warpctc-inl.h b/plugin/warpctc/warpctc-inl.h index 37677d21fd14..9fcbedce74f1 100644 --- a/plugin/warpctc/warpctc-inl.h +++ b/plugin/warpctc/warpctc-inl.h @@ -254,7 +254,7 @@ class WarpCTCProp : public OperatorProperty { CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]"; const mxnet::TShape &dshape = in_shape->at(0); if (dshape.ndim() == 0) return false; - mxnet::TShape label_shape(dshape.ndim() - 1); + mxnet::TShape label_shape(dshape.ndim() - 1, 1); label_shape[0] = param_.label_length * (dshape[0] / param_.input_length); SHAPE_ASSIGN_CHECK(*in_shape, warpctc_enum::kLabel, label_shape);