diff --git a/paddle/phi/kernels/cpu/yolo_box_kernel.cc b/paddle/phi/kernels/cpu/yolo_box_kernel.cc index c80e99e9ea8bdc..dff9f544639347 100644 --- a/paddle/phi/kernels/cpu/yolo_box_kernel.cc +++ b/paddle/phi/kernels/cpu/yolo_box_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/yolo_box_util.h" namespace phi { @@ -35,6 +36,14 @@ void YoloBoxKernel(const Context& dev_ctx, float iou_aware_factor, DenseTensor* boxes, DenseTensor* scores) { + if (x.numel() == 0 || img_size.numel() == 0) { + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(boxes->dims())), 0, boxes); + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(scores->dims())), 0, scores); + return; + } + auto* input = &x; auto* imgsize = &img_size; float scale = scale_x_y; diff --git a/paddle/phi/kernels/gpu/yolo_box_kernel.cu b/paddle/phi/kernels/gpu/yolo_box_kernel.cu index 8616b8bb429556..5e68e424fc8196 100644 --- a/paddle/phi/kernels/gpu/yolo_box_kernel.cu +++ b/paddle/phi/kernels/gpu/yolo_box_kernel.cu @@ -18,6 +18,7 @@ #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/yolo_box_util.h" @@ -115,6 +116,14 @@ void YoloBoxKernel(const Context& dev_ctx, float iou_aware_factor, DenseTensor* boxes, DenseTensor* scores) { + if (x.numel() == 0 || img_size.numel() == 0) { + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(boxes->dims())), 0, boxes); + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(scores->dims())), 0, scores); + return; + } + auto* input = &x; float scale = scale_x_y; float bias = -0.5 * (scale - 1.); diff --git a/paddle/phi/kernels/xpu/yolo_box_kernel.cc b/paddle/phi/kernels/xpu/yolo_box_kernel.cc index cfb1e443650bc9..ce2493a6798be4 100644 --- a/paddle/phi/kernels/xpu/yolo_box_kernel.cc +++ b/paddle/phi/kernels/xpu/yolo_box_kernel.cc @@ -16,6 +16,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/yolo_box_util.h" namespace phi { @@ -34,6 +35,14 @@ void YoloBoxKernel(const Context& dev_ctx, float iou_aware_factor, DenseTensor* boxes, DenseTensor* scores) { + if (x.numel() == 0 || img_size.numel() == 0) { + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(boxes->dims())), 0, boxes); + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(scores->dims())), 0, scores); + return; + } + using XPUType = typename XPUTypeTrait::Type; int r = 0; auto* input = &x; diff --git a/test/legacy_test/test_yolo_box_op.py b/test/legacy_test/test_yolo_box_op.py index a7199b2ca5a5a6..fe6371bbb1ea24 100644 --- a/test/legacy_test/test_yolo_box_op.py +++ b/test/legacy_test/test_yolo_box_op.py @@ -56,7 +56,13 @@ def YoloBox(x, img_size, attrs): (anchors[i], anchors[(i + 1)]) for i in range(0, len(anchors), 2) ] anchors_s = np.array( - [((an_w / input_w), (an_h / input_h)) for (an_w, an_h) in anchors] + [ + ( + (an_w / input_w if input_w > 0 else 0), + (an_h / input_h if input_h > 0 else 0), + ) + for (an_w, an_h) in anchors + ] ) anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1)) anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1)) @@ -309,6 +315,23 @@ def initTestCase(self): self.iou_aware_factor = 0.5 +class TestYoloBoxOp_ZeroSize(TestYoloBoxOp): + def initTestCase(self): + self.__class__.op_type = "yolo_box" + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int(len(self.anchors) // 2) + self.batch_size = 32 + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.clip_bbox = False + self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13, 0) + self.imgsize_shape = (self.batch_size, 2) + self.scale_x_y = 1.0 + self.iou_aware = False + self.iou_aware_factor = 0.5 + + if __name__ == '__main__': paddle.enable_static() unittest.main()