diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h index 8e963461ec06..031dd952d386 100644 --- a/src/operator/contrib/bounding_box-inl.h +++ b/src/operator/contrib/bounding_box-inl.h @@ -785,7 +785,7 @@ void BipartiteMatchingForward(const nnvm::NodeAttrs& attrs, .get_with_shape(Shape2(batch_size, col), s); Shape<1> sort_index_shape = Shape1(dshape.Size()); index_t workspace_size = sort_index_shape.Size(); - workspace_size += ((sort_index_shape.Size() * sizeof(int32_t) - 1) / sizeof(DType)) * 2; + workspace_size += (sort_index_shape.Size() * 2 * sizeof(int32_t) - 1) / sizeof(DType) + 1; Tensor workspace = ctx.requested[0] .get_space_typed(Shape1(workspace_size), s); Tensor scores_copy(workspace.dptr_,