From 3921b3f0a107c372fd768778e7d73b24066292e7 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Wed, 13 Jul 2022 03:31:25 +0000 Subject: [PATCH 01/21] update unstack_op --- paddle/fluid/operators/unstack_op.cc | 53 ++----------------- paddle/phi/api/yaml/legacy_api.yaml | 10 ++++ paddle/phi/api/yaml/legacy_backward.yaml | 10 ++++ paddle/phi/infermeta/backward.cc | 41 ++++++++++++++ paddle/phi/infermeta/backward.h | 4 ++ python/paddle/fluid/layers/nn.py | 4 +- .../fluid/tests/unittests/test_unstack_op.py | 5 +- python/paddle/tensor/manipulation.py | 4 +- 8 files changed, 77 insertions(+), 54 deletions(-) mode change 100644 => 100755 paddle/fluid/operators/unstack_op.cc mode change 100644 => 100755 paddle/phi/api/yaml/legacy_api.yaml mode change 100644 => 100755 paddle/phi/api/yaml/legacy_backward.yaml mode change 100644 => 100755 paddle/phi/infermeta/backward.cc mode change 100644 => 100755 paddle/phi/infermeta/backward.h mode change 100644 => 100755 python/paddle/fluid/tests/unittests/test_unstack_op.py diff --git a/paddle/fluid/operators/unstack_op.cc b/paddle/fluid/operators/unstack_op.cc old mode 100644 new mode 100755 index 76fe2ac77d9d8..8b1c33e07e672 --- a/paddle/fluid/operators/unstack_op.cc +++ b/paddle/fluid/operators/unstack_op.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/infermeta/backward.h" namespace paddle { namespace operators { @@ -63,51 +64,6 @@ class UnStackGradOpMaker : public framework::SingleGradOpMaker { class UnStackGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), - 0, - platform::errors::InvalidArgument( - "The Inputs(Y@Grad) of unstack operator are empty.")); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), - "Output", - "X", - "UnStackGrad"); - auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y")); - for (size_t i = 1; i < input_dims.size(); ++i) { - PADDLE_ENFORCE_EQ( - input_dims[i], - input_dims[0], - platform::errors::InvalidArgument( - "The dimensions of all Inputs(Y@Grad) must be the same," - "but received Inputs(Y@Grad)'s %d-th dimension is %d, " - "Inputs(Y@Grad)'s 0-th to %d-th dimension is %d.", - i, - input_dims[i], - i - 1, - input_dims[0])); - } - - int axis = ctx->Attrs().Get("axis"); - int rank = input_dims[0].size(); - PADDLE_ENFORCE_GE(axis, - -(rank + 1), - platform::errors::InvalidArgument( - "The attribute axis is out of range, it must be " - "inside [-(rank+1), rank+1), where rank = %d", - rank)); - PADDLE_ENFORCE_LT(axis, - rank + 1, - platform::errors::InvalidArgument( - "The attribute axis is out of range, it must be " - "inside [-(rank+1), rank+1), where rank = %d", - rank)); - if (axis < 0) axis += (rank + 1); - - auto vec = phi::vectorize(input_dims[0]); - vec.insert(vec.begin() + axis, input_dims.size()); - ctx->SetOutputDim(framework::GradVarName("X"), phi::make_ddim(vec)); - } }; } // namespace operators @@ -119,12 +75,13 @@ namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(unstack, UnStackInferMetaFunctor, PD_INFER_META(phi::UnStackInferMeta)); - +DELCARE_INFER_SHAPE_FUNCTOR(unstack_grad, + UnStackGradInferMetaFunctor, + PT_INFER_META(phi::UnStackGradInferMeta)); REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker, ops::UnStackGradOpMaker, ops::UnStackGradOpMaker, UnStackInferMetaFunctor); - -REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp); +REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp, UnStackGradInferMetaFunctor); diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml old mode 100644 new mode 100755 index c307fc7a19d5d..568dd9fe7501a --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2205,6 +2205,16 @@ func : unique data_type : x +# unstack +- api : unstack + args : (Tensor x, int axis, int num) + output : Tensor[]{num} + infer_meta : + func : UnStackInferMeta + kernel : + func : unstack + backward : unstack_grad + - api : unsqueeze args : (Tensor x, IntArray axis) output : Tensor(out), Tensor(xshape) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml old mode 100644 new mode 100755 index f01598e643420..f792da7d24d98 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -2201,6 +2201,16 @@ func : unfold_grad no_need_buffer : x +- backward_api : unstack_grad + forward : unstack (Tensor x, int axis, int num) -> Tensor[](out) + args : (Tensor[] x, int axis) + output : Tensor(x_grad) + infer_meta : + func : UnStackGradInferMeta + param : [x, axis] + kernel : + func : unstack_grad + - backward_api : unsqueeze_double_grad forward : unsqueeze_grad(Tensor xshape, Tensor grad_out, IntArray axes) -> Tensor(grad_x) args : (Tensor grad_x_grad, IntArray axes) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc old mode 100644 new mode 100755 index f59ea5549bd71..6a555e97d687b --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -661,4 +661,45 @@ void StackGradInferMeta(const MetaTensor& out_grad, } } +void UnStackGradInferMeta(const std::vector& x, + int axis, + MetaTensor* x_grad) { + std::vector input_dims(x.size()); + for(int i = 0; i < x.size(); ++i){ + input_dims[i] = x[i]->dims(); + } + for (size_t i = 1; i < input_dims.size(); ++i) { + PADDLE_ENFORCE_EQ( + input_dims[i], + input_dims[0], + phi::errors::InvalidArgument( + "The dimensions of all Inputs(Y@Grad) must be the same," + "but received Inputs(Y@Grad)'s %d-th dimension is %d, " + "Inputs(Y@Grad)'s 0-th to %d-th dimension is %d.", + i, + input_dims[i], + i - 1, + input_dims[0])); + } + + int rank = input_dims[0].size(); + PADDLE_ENFORCE_GE(axis, + -(rank + 1), + phi::errors::InvalidArgument( + "The attribute axis is out of range, it must be " + "inside [-(rank+1), rank+1), where rank = %d", + rank)); + PADDLE_ENFORCE_LT(axis, + rank + 1, + phi::errors::InvalidArgument( + "The attribute axis is out of range, it must be " + "inside [-(rank+1), rank+1), where rank = %d", + rank)); + if (axis < 0) axis += (rank + 1); + + auto vec = phi::vectorize(input_dims[0]); + vec.insert(vec.begin() + axis, input_dims.size()); + x_grad->set_dims(phi::make_ddim(vec)); +} + } // namespace phi diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h old mode 100644 new mode 100755 index 0e7ed640d8ffb..c7f025e1105ef --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -261,4 +261,8 @@ void StackGradInferMeta(const MetaTensor& out_grad, int axis, std::vector x_grad); +void UnStackGradInferMeta(const std::vector& x, + int axis, + MetaTensor* x_grad); + } // namespace phi diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 050d6bfcb6bbb..2608a509d1b0a 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10685,12 +10685,12 @@ def unstack(x, axis=0, num=None): y = paddle.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5] """ - if _non_static_mode(): + if in_dygraph_mode(): if num == None: num = x.shape[axis] if num == 0: return [] - return _C_ops.unstack(x, num, 'axis', int(axis), 'num', num) + return _C_ops.final_state_unstack(x, num, 'axis', int(axis), 'num', num) helper = LayerHelper('unstack', **locals()) if num is None: diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py old mode 100644 new mode 100755 index 730a74dc54c5a..690e58db88eb0 --- a/python/paddle/fluid/tests/unittests/test_unstack_op.py +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -37,6 +37,7 @@ def setUp(self): self.initDefaultParameters() self.initParameters() self.op_type = 'unstack' + self.python_api = paddle.fluid.layers.unstack self.x = np.random.random(size=self.input_dim).astype(self.dtype) outs = np.split(self.x, self.input_dim[self.axis], self.axis) @@ -52,10 +53,10 @@ def setUp(self): self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], self.get_y_names()) + self.check_grad(['X'], self.get_y_names(), check_eager=True) class TestStackOp3(TestUnStackOpBase): diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index c445402412e16..ad7c052e346b8 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -454,12 +454,12 @@ def unstack(x, axis=0, num=None): y = paddle.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5] """ - if _non_static_mode(): + if in_dygraph_mode(): if num == None: num = x.shape[axis] if num == 0: return [] - return _C_ops.unstack(x, num, 'axis', int(axis), 'num', num) + return _C_ops.final_state_unstack(x, num, 'axis', int(axis), 'num', num) helper = LayerHelper('unstack', **locals()) if num is None: From 1ccb8f042ffcc00bd7cecddcfe0269e68ee84351 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Wed, 13 Jul 2022 04:00:46 +0000 Subject: [PATCH 02/21] update unstack_op --- paddle/phi/infermeta/backward.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 6a555e97d687b..e40e4d315992b 100755 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -665,7 +665,7 @@ void UnStackGradInferMeta(const std::vector& x, int axis, MetaTensor* x_grad) { std::vector input_dims(x.size()); - for(int i = 0; i < x.size(); ++i){ + for(size_t i = 0; i < x.size(); ++i){ input_dims[i] = x[i]->dims(); } for (size_t i = 1; i < input_dims.size(); ++i) { From e872d80fb2a62fafaa992bdf6b344389d8576eb3 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Wed, 13 Jul 2022 04:21:46 +0000 Subject: [PATCH 03/21] update unstack_op --- paddle/fluid/operators/unstack_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/unstack_op.cc b/paddle/fluid/operators/unstack_op.cc index 8b1c33e07e672..f1eb99a1c92e9 100755 --- a/paddle/fluid/operators/unstack_op.cc +++ b/paddle/fluid/operators/unstack_op.cc @@ -75,9 +75,9 @@ namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(unstack, UnStackInferMetaFunctor, PD_INFER_META(phi::UnStackInferMeta)); -DELCARE_INFER_SHAPE_FUNCTOR(unstack_grad, +DECLARE_INFER_SHAPE_FUNCTOR(unstack_grad, UnStackGradInferMetaFunctor, - PT_INFER_META(phi::UnStackGradInferMeta)); + PD_INFER_META(phi::UnStackGradInferMeta)); REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker, From 9c9e750e179f98ef3f9b52b802986acf07b4cc23 Mon Sep 17 00:00:00 2001 From: ShiningZhang Date: Wed, 13 Jul 2022 08:09:17 +0000 Subject: [PATCH 04/21] fix unstack test --- paddle/fluid/operators/unstack_op.cc | 4 ++-- paddle/phi/api/yaml/legacy_backward.yaml | 4 ++-- paddle/phi/infermeta/backward.cc | 2 +- python/paddle/fluid/layers/nn.py | 2 +- python/paddle/fluid/tests/unittests/test_unstack_op.py | 3 ++- python/paddle/tensor/manipulation.py | 2 +- 6 files changed, 9 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/unstack_op.cc b/paddle/fluid/operators/unstack_op.cc index 8b1c33e07e672..f1eb99a1c92e9 100755 --- a/paddle/fluid/operators/unstack_op.cc +++ b/paddle/fluid/operators/unstack_op.cc @@ -75,9 +75,9 @@ namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(unstack, UnStackInferMetaFunctor, PD_INFER_META(phi::UnStackInferMeta)); -DELCARE_INFER_SHAPE_FUNCTOR(unstack_grad, +DECLARE_INFER_SHAPE_FUNCTOR(unstack_grad, UnStackGradInferMetaFunctor, - PT_INFER_META(phi::UnStackGradInferMeta)); + PD_INFER_META(phi::UnStackGradInferMeta)); REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker, diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index f792da7d24d98..dd889f30ed3d9 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -2203,11 +2203,11 @@ - backward_api : unstack_grad forward : unstack (Tensor x, int axis, int num) -> Tensor[](out) - args : (Tensor[] x, int axis) + args : (Tensor[] out_grad, int axis) output : Tensor(x_grad) infer_meta : func : UnStackGradInferMeta - param : [x, axis] + param : [out_grad, axis] kernel : func : unstack_grad diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 6a555e97d687b..e40e4d315992b 100755 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -665,7 +665,7 @@ void UnStackGradInferMeta(const std::vector& x, int axis, MetaTensor* x_grad) { std::vector input_dims(x.size()); - for(int i = 0; i < x.size(); ++i){ + for(size_t i = 0; i < x.size(); ++i){ input_dims[i] = x[i]->dims(); } for (size_t i = 1; i < input_dims.size(); ++i) { diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 2608a509d1b0a..fcc3ecd9a4ec0 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10690,7 +10690,7 @@ def unstack(x, axis=0, num=None): num = x.shape[axis] if num == 0: return [] - return _C_ops.final_state_unstack(x, num, 'axis', int(axis), 'num', num) + return _C_ops.final_state_unstack(x, axis, num) helper = LayerHelper('unstack', **locals()) if num is None: diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py index 690e58db88eb0..5cd944aa85166 100755 --- a/python/paddle/fluid/tests/unittests/test_unstack_op.py +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -15,6 +15,7 @@ from op_test import OpTest import numpy as np import unittest +import paddle class TestUnStackOpBase(OpTest): @@ -37,7 +38,7 @@ def setUp(self): self.initDefaultParameters() self.initParameters() self.op_type = 'unstack' - self.python_api = paddle.fluid.layers.unstack + self.python_api = paddle.unstack self.x = np.random.random(size=self.input_dim).astype(self.dtype) outs = np.split(self.x, self.input_dim[self.axis], self.axis) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 0b15c86ccd6af..99dbe0ae51edf 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -459,7 +459,7 @@ def unstack(x, axis=0, num=None): num = x.shape[axis] if num == 0: return [] - return _C_ops.final_state_unstack(x, num, 'axis', int(axis), 'num', num) + return _C_ops.final_state_unstack(x, axis, num) helper = LayerHelper('unstack', **locals()) if num is None: From 95ba10ae933d659f36f5e12caff4c3fe416b1f2b Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Wed, 13 Jul 2022 09:04:26 +0000 Subject: [PATCH 05/21] update unstack --- python/paddle/fluid/layers/nn.py | 7 +++++++ python/paddle/tensor/manipulation.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index fcc3ecd9a4ec0..8cab4aa2aef78 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10692,6 +10692,13 @@ def unstack(x, axis=0, num=None): return [] return _C_ops.final_state_unstack(x, axis, num) + if _non_static_mode(): + if num == None: + num = x.shape[axis] + if num == 0: + return [] + return _C_ops.unstack(x, num, 'axis', int(axis), 'num', num) + helper = LayerHelper('unstack', **locals()) if num is None: if axis is None or x.shape[axis] <= 0: diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 99dbe0ae51edf..e59eca6d0f6d3 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -461,6 +461,13 @@ def unstack(x, axis=0, num=None): return [] return _C_ops.final_state_unstack(x, axis, num) + if _non_static_mode(): + if num == None: + num = x.shape[axis] + if num == 0: + return [] + return _C_ops.unstack(x, num, 'axis', int(axis), 'num', num) + helper = LayerHelper('unstack', **locals()) if num is None: if axis is None or x.shape[axis] <= 0: From f5e8a0cd8487ac0c00a1dfaefef5144c75e366ee Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Wed, 13 Jul 2022 11:29:50 +0000 Subject: [PATCH 06/21] update with remote --- paddle/phi/api/yaml/legacy_backward.yaml | 4 ++-- python/paddle/fluid/tests/unittests/test_unstack_op.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index f792da7d24d98..dd889f30ed3d9 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -2203,11 +2203,11 @@ - backward_api : unstack_grad forward : unstack (Tensor x, int axis, int num) -> Tensor[](out) - args : (Tensor[] x, int axis) + args : (Tensor[] out_grad, int axis) output : Tensor(x_grad) infer_meta : func : UnStackGradInferMeta - param : [x, axis] + param : [out_grad, axis] kernel : func : unstack_grad diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py index 690e58db88eb0..5cd944aa85166 100755 --- a/python/paddle/fluid/tests/unittests/test_unstack_op.py +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -15,6 +15,7 @@ from op_test import OpTest import numpy as np import unittest +import paddle class TestUnStackOpBase(OpTest): @@ -37,7 +38,7 @@ def setUp(self): self.initDefaultParameters() self.initParameters() self.op_type = 'unstack' - self.python_api = paddle.fluid.layers.unstack + self.python_api = paddle.unstack self.x = np.random.random(size=self.input_dim).astype(self.dtype) outs = np.split(self.x, self.input_dim[self.axis], self.axis) From 694d59088a80d11e7991324206b4268679fd7406 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Thu, 14 Jul 2022 06:56:35 +0000 Subject: [PATCH 07/21] fix unstack_test.py --- python/paddle/fluid/tests/unittests/test_unstack_op.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py index 5cd944aa85166..bb28bdeba79d3 100755 --- a/python/paddle/fluid/tests/unittests/test_unstack_op.py +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -46,9 +46,12 @@ def setUp(self): del new_shape[self.axis] y_names = self.get_y_names() tmp = [] + tmp_names = [] for i in range(self.input_dim[self.axis]): tmp.append((y_names[i], np.reshape(outs[i], new_shape))) + tmp_names.append(y_names[i]) + self.python_out_sig = tmp_names self.inputs = {'X': self.x} self.outputs = {'Y': tmp} self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]} From 2fe8dcd3b46736214a774843a36e2d1d75b9f485 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Thu, 14 Jul 2022 07:15:47 +0000 Subject: [PATCH 08/21] temp_save_change_nms_op --- paddle/fluid/operators/detection/nms_op.cc | 8 -- paddle/phi/kernels/cpu/nms_kernel.cc | 80 +++++++++++++++++ .../kernels/gpu/nms_kernel.cu} | 90 ++++++++----------- .../nms_op.h => phi/kernels/nms_kernel.h} | 42 +++++---- 4 files changed, 142 insertions(+), 78 deletions(-) mode change 100644 => 100755 paddle/fluid/operators/detection/nms_op.cc create mode 100755 paddle/phi/kernels/cpu/nms_kernel.cc rename paddle/{fluid/operators/detection/nms_op.cu => phi/kernels/gpu/nms_kernel.cu} (56%) mode change 100644 => 100755 rename paddle/{fluid/operators/detection/nms_op.h => phi/kernels/nms_kernel.h} (56%) mode change 100644 => 100755 diff --git a/paddle/fluid/operators/detection/nms_op.cc b/paddle/fluid/operators/detection/nms_op.cc old mode 100644 new mode 100755 index 3c5feaa656a32..3286c04466473 --- a/paddle/fluid/operators/detection/nms_op.cc +++ b/paddle/fluid/operators/detection/nms_op.cc @@ -132,13 +132,6 @@ static void NMS(const T* boxes_data, template class NMSKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* boxes = context.Input("Boxes"); - Tensor* output = context.Output("KeepBoxesIdxs"); - int64_t* output_data = output->mutable_data(context.GetPlace()); - auto threshold = context.template Attr("iou_threshold"); - NMS(boxes->data(), output_data, threshold, boxes->dims()[0]); - } }; } // namespace operators @@ -152,4 +145,3 @@ REGISTER_OPERATOR( ops::NMSOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(nms, ops::NMSKernel, ops::NMSKernel); diff --git a/paddle/phi/kernels/cpu/nms_kernel.cc b/paddle/phi/kernels/cpu/nms_kernel.cc new file mode 100755 index 0000000000000..482dd66914baa --- /dev/null +++ b/paddle/phi/kernels/cpu/nms_kernel.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/trace_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/diagonal.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +static void NMS(const T* boxes_data, + int64_t* output_data, + float threshold, + int64_t num_boxes) { + auto num_masks = CeilDivide(num_boxes, 64); + std::vector masks(num_masks, 0); + + for (int64_t i = 0; i < num_boxes; ++i) { + if (masks[i / 64] & 1ULL << (i % 64)) continue; + T box_1[4]; + for (int k = 0; k < 4; ++k) { + box_1[k] = boxes_data[i * 4 + k]; + } + for (int64_t j = i + 1; j < num_boxes; ++j) { + if (masks[j / 64] & 1ULL << (j % 64)) continue; + T box_2[4]; + for (int k = 0; k < 4; ++k) { + box_2[k] = boxes_data[j * 4 + k]; + } + bool is_overlap = CalculateIoU(box_1, box_2, threshold); + if (is_overlap) { + masks[j / 64] |= 1ULL << (j % 64); + } + } + } + + int64_t output_data_idx = 0; + for (int64_t i = 0; i < num_boxes; ++i) { + if (masks[i / 64] & 1ULL << (i % 64)) continue; + output_data[output_data_idx++] = i; + } + + for (; output_data_idx < num_boxes; ++output_data_idx) { + output_data[output_data_idx] = 0; + } +} + +template +void NMSKernel(const Context& dev_ctx, + const DenseTensor& boxes, + float threshold, + DenseTensor* output){ + int64_t* output_data = dev_ctx.template Alloc(output); + NMS(boxes->data(), output_data, threshold, boxes->dims()[0]); +} + +} // namespace phi + +PD_REGISTER_KERNEL(nms, + CPU, + ALL_LAYOUT, + phi::NMSKernel, + float, + double) {} \ No newline at end of file diff --git a/paddle/fluid/operators/detection/nms_op.cu b/paddle/phi/kernels/gpu/nms_kernel.cu old mode 100644 new mode 100755 similarity index 56% rename from paddle/fluid/operators/detection/nms_op.cu rename to paddle/phi/kernels/gpu/nms_kernel.cu index 935d13cfd4a47..5dc8a51b05478 --- a/paddle/fluid/operators/detection/nms_op.cu +++ b/paddle/phi/kernels/gpu/nms_kernel.cu @@ -1,28 +1,25 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/operators/detection/nms_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" - -static const int64_t threadsPerBlock = sizeof(int64_t) * 8; - -namespace paddle { -namespace operators { - -using framework::Tensor; +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nms_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/diagonal.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" + +namespace phi { template static __global__ void NMS(const T* boxes_data, @@ -53,15 +50,12 @@ static __global__ void NMS(const T* boxes_data, } } -template -class NMSCudaKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* boxes = context.Input("Boxes"); - Tensor* output = context.Output("KeepBoxesIdxs"); - auto* output_data = output->mutable_data(context.GetPlace()); - - auto threshold = context.template Attr("iou_threshold"); +template +void NMSKernel(const Context& dev_ctx, + const DenseTensor& boxes, + float threshold, + DenseTensor* output){ + auto* output_data = dev_ctx.template Alloc(output); const int64_t num_boxes = boxes->dims()[0]; const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); @@ -69,19 +63,19 @@ class NMSCudaKernel : public framework::OpKernel { dim3 grid(blocks_per_line, blocks_per_line); auto mask_data = - memory::Alloc(context.cuda_device_context(), + memory::Alloc(dev_ctx.cuda_device_context(), num_boxes * blocks_per_line * sizeof(uint64_t)); uint64_t* mask_dev = reinterpret_cast(mask_data->ptr()); - NMS<<>>( + NMS<<>>( boxes->data(), threshold, num_boxes, mask_dev); std::vector mask_host(num_boxes * blocks_per_line); - memory::Copy(platform::CPUPlace(), + memory::Copy(phi::CPUPlace(), mask_host.data(), - context.GetPlace(), + dev_ctx.GetPlace(), mask_dev, num_boxes * blocks_per_line * sizeof(uint64_t), - context.cuda_device_context().stream()); + dev_ctx.cuda_device_context().stream()); std::vector remv(blocks_per_line); @@ -100,19 +94,13 @@ class NMSCudaKernel : public framework::OpKernel { } } } - memory::Copy(context.GetPlace(), + memory::Copy(dev_ctx.GetPlace(), output_data, - platform::CPUPlace(), + phi::CPUPlace(), output_host, sizeof(int64_t) * num_boxes, - context.cuda_device_context().stream()); - } -}; - -} // namespace operators -} // namespace paddle + dev_ctx.cuda_device_context().stream()); +} -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(nms, - ops::NMSCudaKernel, - ops::NMSCudaKernel); +} +PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {} \ No newline at end of file diff --git a/paddle/fluid/operators/detection/nms_op.h b/paddle/phi/kernels/nms_kernel.h old mode 100644 new mode 100755 similarity index 56% rename from paddle/fluid/operators/detection/nms_op.h rename to paddle/phi/kernels/nms_kernel.h index f5cd1c9203784..24307138db1fc --- a/paddle/fluid/operators/detection/nms_op.h +++ b/paddle/phi/kernels/nms_kernel.h @@ -1,24 +1,23 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/hostdevice.h" -namespace paddle { -namespace operators { +namespace phi { HOSTDEVICE static inline int64_t CeilDivide(int64_t n, int64_t m) { return (n + m - 1) / m; @@ -48,5 +47,10 @@ HOSTDEVICE inline bool CalculateIoU(const T* const box_1, return inter_area / union_area > threshold; } -} // namespace operators -} // namespace paddle +template +void NMSKernel(const Context& dev_ctx, + const DenseTensor& boxes, + float threshold, + DenseTensor* output); + +} // namespace phi \ No newline at end of file From 840b5ea24637b0f2465216370efd82963680e181 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Thu, 14 Jul 2022 09:04:36 +0000 Subject: [PATCH 09/21] add nms test --- paddle/fluid/operators/detection/nms_op.cc | 32 +++++++------------ paddle/phi/api/yaml/legacy_api.yaml | 9 ++++++ paddle/phi/infermeta/unary.cc | 15 +++++++++ paddle/phi/infermeta/unary.h | 4 +++ .../fluid/tests/unittests/test_nms_op.py | 4 ++- python/paddle/vision/ops.py | 3 ++ 6 files changed, 45 insertions(+), 22 deletions(-) mode change 100644 => 100755 paddle/phi/infermeta/unary.cc mode change 100644 => 100755 paddle/phi/infermeta/unary.h mode change 100644 => 100755 python/paddle/fluid/tests/unittests/test_nms_op.py mode change 100644 => 100755 python/paddle/vision/ops.py diff --git a/paddle/fluid/operators/detection/nms_op.cc b/paddle/fluid/operators/detection/nms_op.cc index 3286c04466473..2760117247794 100755 --- a/paddle/fluid/operators/detection/nms_op.cc +++ b/paddle/fluid/operators/detection/nms_op.cc @@ -11,8 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#include "paddle/fluid/operators/detection/nms_op.h" +#pragma once +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/for_range.h" +#include "paddle/fluid/framework/infershape_utils.h" #include @@ -65,24 +69,6 @@ class NMSOpMaker : public framework::OpProtoAndCheckerMaker { class NMSOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Boxes"), "Input", "Boxes", "NMS"); - OP_INOUT_CHECK( - ctx->HasOutput("KeepBoxesIdxs"), "Output", "KeepBoxesIdxs", "NMS"); - - auto boxes_dim = ctx->GetInputDim("Boxes"); - PADDLE_ENFORCE_EQ(boxes_dim.size(), - 2, - platform::errors::InvalidArgument( - "The Input Boxes must be 2-dimention " - "whose shape must be [N, 4] " - "N is the number of boxes " - "in last dimension in format [x1, x2, y1, y2]. ")); - auto num_boxes = boxes_dim[0]; - - ctx->SetOutputDim("KeepBoxesIdxs", {num_boxes}); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -138,10 +124,14 @@ class NMSKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(nms, + NMSInferMetaFunctor, + PD_INFER_META(phi::NMSInferMeta)); REGISTER_OPERATOR( nms, ops::NMSOp, ops::NMSOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker, + NMSInferMetaFunctor); diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 14315b0088c91..28d4325a4f591 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1474,6 +1474,15 @@ optional : weight backward : nll_loss_grad +- api : nms + args : (Tensor x, float threshold) + output : Tensor(out) + infer_meta : + func : NMSInferMeta + kernel : + func : nms + data_type : x + - api : norm args : (Tensor x, int axis, float epsilon, bool is_test) output : Tensor(out), Tensor(norm) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc old mode 100644 new mode 100755 index 0048f130adf62..4455b7e16351a --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1312,6 +1312,21 @@ void NanmedianInferMeta(const MetaTensor& x, out->set_dims(make_ddim(out_dim)); } +void NMSInferMeta(const MetaTensor& x, + float threshold, + MetaTensor* out){ + auto boxes_dim = x.dims(); + PADDLE_ENFORCE_EQ(boxes_dim.size(), + 2, + platform::errors::InvalidArgument( + "The Input Boxes must be 2-dimention " + "whose shape must be [N, 4] " + "N is the number of boxes " + "in last dimension in format [x1, x2, y1, y2]. ")); + auto num_boxes = boxes_dim[0]; + out->set_dims(framework::make_ddim(num_boxes)); +} + void NormInferMeta(const MetaTensor& x, int axis, float epsilon, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h old mode 100644 new mode 100755 index 0b9298cfd362f..b611f877883dc --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -186,6 +186,10 @@ void NanmedianInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* median_index); +void NMSInferMeta(const MetaTensor& x, + float threshold, + MetaTensor* out); + void NormInferMeta(const MetaTensor& x, int axis, float epsilon, diff --git a/python/paddle/fluid/tests/unittests/test_nms_op.py b/python/paddle/fluid/tests/unittests/test_nms_op.py old mode 100644 new mode 100755 index f3c253d45c0de..7196d0dea220a --- a/python/paddle/fluid/tests/unittests/test_nms_op.py +++ b/python/paddle/fluid/tests/unittests/test_nms_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np from op_test import OpTest +import paddle def iou(box_a, box_b): @@ -71,6 +72,7 @@ class TestNMSOp(OpTest): def setUp(self): self.op_type = 'nms' + self.python_api = paddle.vision.ops.nms self.dtype = np.float64 self.init_dtype_type() boxes = np.random.rand(32, 4).astype(self.dtype) @@ -86,7 +88,7 @@ def init_dtype_type(self): pass def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) if __name__ == "__main__": diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py old mode 100644 new mode 100755 index 7febf4f740ea2..0b64f823a6427 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -1471,6 +1471,9 @@ def nms(boxes, """ def _nms(boxes, iou_threshold): + if in_dygraph_mode(): + return _C_ops.final_state_nms(boxes, iou_threshold) + if _non_static_mode(): return _C_ops.nms(boxes, 'iou_threshold', iou_threshold) From 00e9eb00c5f24744868d158056368159fe85d1bf Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Fri, 15 Jul 2022 02:50:13 +0000 Subject: [PATCH 10/21] update nms fix --- paddle/fluid/operators/detection/nms_op.cc | 40 +--------------------- paddle/phi/infermeta/unary.cc | 4 +-- paddle/phi/kernels/cpu/nms_kernel.cc | 13 +++---- 3 files changed, 8 insertions(+), 49 deletions(-) diff --git a/paddle/fluid/operators/detection/nms_op.cc b/paddle/fluid/operators/detection/nms_op.cc index 2760117247794..b1b687c11556e 100755 --- a/paddle/fluid/operators/detection/nms_op.cc +++ b/paddle/fluid/operators/detection/nms_op.cc @@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once + #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/infermeta/unary.h" #include "paddle/fluid/framework/op_registry.h" @@ -77,44 +77,6 @@ class NMSOp : public framework::OperatorWithKernel { } }; -template -static void NMS(const T* boxes_data, - int64_t* output_data, - float threshold, - int64_t num_boxes) { - auto num_masks = CeilDivide(num_boxes, 64); - std::vector masks(num_masks, 0); - - for (int64_t i = 0; i < num_boxes; ++i) { - if (masks[i / 64] & 1ULL << (i % 64)) continue; - T box_1[4]; - for (int k = 0; k < 4; ++k) { - box_1[k] = boxes_data[i * 4 + k]; - } - for (int64_t j = i + 1; j < num_boxes; ++j) { - if (masks[j / 64] & 1ULL << (j % 64)) continue; - T box_2[4]; - for (int k = 0; k < 4; ++k) { - box_2[k] = boxes_data[j * 4 + k]; - } - bool is_overlap = CalculateIoU(box_1, box_2, threshold); - if (is_overlap) { - masks[j / 64] |= 1ULL << (j % 64); - } - } - } - - int64_t output_data_idx = 0; - for (int64_t i = 0; i < num_boxes; ++i) { - if (masks[i / 64] & 1ULL << (i % 64)) continue; - output_data[output_data_idx++] = i; - } - - for (; output_data_idx < num_boxes; ++output_data_idx) { - output_data[output_data_idx] = 0; - } -} - template class NMSKernel : public framework::OpKernel { public: diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 4455b7e16351a..20cddcf2e624a 100755 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1318,13 +1318,13 @@ void NMSInferMeta(const MetaTensor& x, auto boxes_dim = x.dims(); PADDLE_ENFORCE_EQ(boxes_dim.size(), 2, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The Input Boxes must be 2-dimention " "whose shape must be [N, 4] " "N is the number of boxes " "in last dimension in format [x1, x2, y1, y2]. ")); auto num_boxes = boxes_dim[0]; - out->set_dims(framework::make_ddim(num_boxes)); + out->set_dims(phi::make_ddim({num_boxes})); } void NormInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/cpu/nms_kernel.cc b/paddle/phi/kernels/cpu/nms_kernel.cc index 482dd66914baa..81fd476962971 100755 --- a/paddle/phi/kernels/cpu/nms_kernel.cc +++ b/paddle/phi/kernels/cpu/nms_kernel.cc @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - -#include "paddle/phi/kernels/trace_kernel.h" - +#include "paddle/phi/kernels/nms_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/diagonal.h" @@ -25,7 +22,7 @@ namespace phi { template static void NMS(const T* boxes_data, - int64_t* output_data, + T* output_data, float threshold, int64_t num_boxes) { auto num_masks = CeilDivide(num_boxes, 64); @@ -66,8 +63,8 @@ void NMSKernel(const Context& dev_ctx, const DenseTensor& boxes, float threshold, DenseTensor* output){ - int64_t* output_data = dev_ctx.template Alloc(output); - NMS(boxes->data(), output_data, threshold, boxes->dims()[0]); + auto output_data = dev_ctx.template Alloc(output); + NMS(boxes.data(), output_data, threshold, boxes.dims()[0]); } } // namespace phi @@ -77,4 +74,4 @@ PD_REGISTER_KERNEL(nms, ALL_LAYOUT, phi::NMSKernel, float, - double) {} \ No newline at end of file + double) {} From e4cbf9d9db78019c2d479612ece343e2f3c098cf Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Fri, 15 Jul 2022 03:08:48 +0000 Subject: [PATCH 11/21] update unstack_op --- paddle/phi/infermeta/backward.cc | 9 +++++---- paddle/phi/infermeta/backward.h | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index e40e4d315992b..48d9af586f9b7 100755 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -661,12 +661,12 @@ void StackGradInferMeta(const MetaTensor& out_grad, } } -void UnStackGradInferMeta(const std::vector& x, +void UnStackGradInferMeta(const std::vector& out_grad, int axis, MetaTensor* x_grad) { - std::vector input_dims(x.size()); - for(size_t i = 0; i < x.size(); ++i){ - input_dims[i] = x[i]->dims(); + std::vector input_dims(out_grad.size()); + for(size_t i = 0; i < out_grad.size(); ++i){ + input_dims[i] = out_grad[i]->dims(); } for (size_t i = 1; i < input_dims.size(); ++i) { PADDLE_ENFORCE_EQ( @@ -700,6 +700,7 @@ void UnStackGradInferMeta(const std::vector& x, auto vec = phi::vectorize(input_dims[0]); vec.insert(vec.begin() + axis, input_dims.size()); x_grad->set_dims(phi::make_ddim(vec)); + x_grad->set_dtype(out_grad[0]->dtype()); } } // namespace phi diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index c7f025e1105ef..365be4adc1ff4 100755 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -261,7 +261,7 @@ void StackGradInferMeta(const MetaTensor& out_grad, int axis, std::vector x_grad); -void UnStackGradInferMeta(const std::vector& x, +void UnStackGradInferMeta(const std::vector& out_grad, int axis, MetaTensor* x_grad); From 9135bf8fa6c0de5665ae3fc4af3f9a31e97803f5 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Mon, 18 Jul 2022 04:13:13 +0000 Subject: [PATCH 12/21] temp save change --- paddle/phi/kernels/gpu/nms_kernel.cu | 31 ++++++++++++---------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/paddle/phi/kernels/gpu/nms_kernel.cu b/paddle/phi/kernels/gpu/nms_kernel.cu index 5dc8a51b05478..419d1c85db0ec 100755 --- a/paddle/phi/kernels/gpu/nms_kernel.cu +++ b/paddle/phi/kernels/gpu/nms_kernel.cu @@ -13,11 +13,14 @@ // limitations under the License. #include "paddle/phi/kernels/nms_kernel.h" - +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/diagonal.h" -#include "paddle/phi/kernels/funcs/reduce_function.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/memory/memcpy.h" + +static const int64_t threadsPerBlock = sizeof(int64_t) * 8; namespace phi { @@ -58,30 +61,23 @@ void NMSKernel(const Context& dev_ctx, auto* output_data = dev_ctx.template Alloc(output); const int64_t num_boxes = boxes->dims()[0]; const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); - dim3 block(threadsPerBlock); dim3 grid(blocks_per_line, blocks_per_line); - auto mask_data = - memory::Alloc(dev_ctx.cuda_device_context(), - num_boxes * blocks_per_line * sizeof(uint64_t)); + paddle::memory::Alloc(dev_ctx, num_boxes * blocks_per_line * sizeof(uint64_t)); uint64_t* mask_dev = reinterpret_cast(mask_data->ptr()); - NMS<<>>( + NMS<<>>( boxes->data(), threshold, num_boxes, mask_dev); - std::vector mask_host(num_boxes * blocks_per_line); - memory::Copy(phi::CPUPlace(), + paddle::memory::Copy(phi::CPUPlace(), mask_host.data(), dev_ctx.GetPlace(), mask_dev, num_boxes * blocks_per_line * sizeof(uint64_t), - dev_ctx.cuda_device_context().stream()); - + dev_ctx.stream()); std::vector remv(blocks_per_line); - std::vector keep_boxes_idxs(num_boxes); int64_t* output_host = keep_boxes_idxs.data(); - int64_t last_box_num = 0; for (int64_t i = 0; i < num_boxes; ++i) { auto remv_element_id = i / threadsPerBlock; @@ -94,13 +90,12 @@ void NMSKernel(const Context& dev_ctx, } } } - memory::Copy(dev_ctx.GetPlace(), + paddle::memory::Copy(dev_ctx.GetPlace(), output_data, phi::CPUPlace(), output_host, sizeof(int64_t) * num_boxes, - dev_ctx.cuda_device_context().stream()); + dev_ctx.stream()); } - } -PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {} \ No newline at end of file +PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {} From a692a334cdf28a1165ed2fb5e844f4e8935f556b Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Mon, 18 Jul 2022 12:24:44 +0000 Subject: [PATCH 13/21] finish fix nms_op --- paddle/fluid/operators/detection/CMakeLists.txt | 2 +- paddle/fluid/operators/detection/nms_op.cc | 1 - paddle/phi/kernels/gpu/nms_kernel.cu | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index c05c39e88d74a..0f40e6596c93f 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -82,7 +82,7 @@ detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu) detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc) -detection_library(nms_op SRCS nms_op.cc nms_op.cu) +detection_library(nms_op SRCS nms_op.cc) if(WITH_GPU OR WITH_ROCM) set(TMPDEPS memory) diff --git a/paddle/fluid/operators/detection/nms_op.cc b/paddle/fluid/operators/detection/nms_op.cc index b1b687c11556e..3373fa5c2fd54 100755 --- a/paddle/fluid/operators/detection/nms_op.cc +++ b/paddle/fluid/operators/detection/nms_op.cc @@ -79,7 +79,6 @@ class NMSOp : public framework::OperatorWithKernel { template class NMSKernel : public framework::OpKernel { - public: }; } // namespace operators diff --git a/paddle/phi/kernels/gpu/nms_kernel.cu b/paddle/phi/kernels/gpu/nms_kernel.cu index 419d1c85db0ec..34da385fdf38b 100755 --- a/paddle/phi/kernels/gpu/nms_kernel.cu +++ b/paddle/phi/kernels/gpu/nms_kernel.cu @@ -59,7 +59,7 @@ void NMSKernel(const Context& dev_ctx, float threshold, DenseTensor* output){ auto* output_data = dev_ctx.template Alloc(output); - const int64_t num_boxes = boxes->dims()[0]; + const int64_t num_boxes = boxes.dims()[0]; const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); dim3 block(threadsPerBlock); dim3 grid(blocks_per_line, blocks_per_line); @@ -67,7 +67,7 @@ void NMSKernel(const Context& dev_ctx, paddle::memory::Alloc(dev_ctx, num_boxes * blocks_per_line * sizeof(uint64_t)); uint64_t* mask_dev = reinterpret_cast(mask_data->ptr()); NMS<<>>( - boxes->data(), threshold, num_boxes, mask_dev); + boxes.data(), threshold, num_boxes, mask_dev); std::vector mask_host(num_boxes * blocks_per_line); paddle::memory::Copy(phi::CPUPlace(), mask_host.data(), From de372d6cc1f310bf651c1b3a493bca8e03537fcb Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Wed, 20 Jul 2022 02:11:00 +0000 Subject: [PATCH 14/21] pass nms test --- paddle/phi/kernels/gpu/nms_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/nms_kernel.cu b/paddle/phi/kernels/gpu/nms_kernel.cu index 34da385fdf38b..fe70034fe3b54 100755 --- a/paddle/phi/kernels/gpu/nms_kernel.cu +++ b/paddle/phi/kernels/gpu/nms_kernel.cu @@ -58,7 +58,7 @@ void NMSKernel(const Context& dev_ctx, const DenseTensor& boxes, float threshold, DenseTensor* output){ - auto* output_data = dev_ctx.template Alloc(output); + auto* output_data = dev_ctx.template Alloc(output); const int64_t num_boxes = boxes.dims()[0]; const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); dim3 block(threadsPerBlock); From 3717a30f05d8374927763d88bcc5ac763a287387 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Thu, 21 Jul 2022 06:42:12 +0000 Subject: [PATCH 15/21] fix CI --- python/paddle/fluid/tests/unittests/test_nms_op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_nms_op.py b/python/paddle/fluid/tests/unittests/test_nms_op.py index 7196d0dea220a..a81a46e1140e8 100755 --- a/python/paddle/fluid/tests/unittests/test_nms_op.py +++ b/python/paddle/fluid/tests/unittests/test_nms_op.py @@ -79,10 +79,12 @@ def setUp(self): boxes[:, 2] = boxes[:, 0] + boxes[:, 2] boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + paddle.disable_static() self.inputs = {'Boxes': boxes} self.attrs = {'iou_threshold': 0.5} out_py = nms(boxes, self.attrs['iou_threshold']) self.outputs = {'KeepBoxesIdxs': out_py} + paddle.enable_static() def init_dtype_type(self): pass From 58cafb9beca129f5542f09c7202213033cf069c3 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Fri, 22 Jul 2022 02:18:42 +0000 Subject: [PATCH 16/21] fix ops test --- paddle/phi/kernels/cpu/nms_kernel.cc | 4 ++-- python/paddle/fluid/tests/unittests/test_ops_nms.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/cpu/nms_kernel.cc b/paddle/phi/kernels/cpu/nms_kernel.cc index 81fd476962971..c26aae5fb31a8 100755 --- a/paddle/phi/kernels/cpu/nms_kernel.cc +++ b/paddle/phi/kernels/cpu/nms_kernel.cc @@ -22,7 +22,7 @@ namespace phi { template static void NMS(const T* boxes_data, - T* output_data, + int64_t* output_data, float threshold, int64_t num_boxes) { auto num_masks = CeilDivide(num_boxes, 64); @@ -63,7 +63,7 @@ void NMSKernel(const Context& dev_ctx, const DenseTensor& boxes, float threshold, DenseTensor* output){ - auto output_data = dev_ctx.template Alloc(output); + auto output_data = dev_ctx.template Alloc(output); NMS(boxes.data(), output_data, threshold, boxes.dims()[0]); } diff --git a/python/paddle/fluid/tests/unittests/test_ops_nms.py b/python/paddle/fluid/tests/unittests/test_ops_nms.py index c775a47bd2472..32ee7e9358f47 100644 --- a/python/paddle/fluid/tests/unittests/test_ops_nms.py +++ b/python/paddle/fluid/tests/unittests/test_ops_nms.py @@ -88,6 +88,7 @@ def tearDown(self): self.temp_dir.cleanup() def test_nms(self): + paddle.disable_static() for device in self.devices: for dtype in self.dtypes: boxes, scores, category_idxs, categories = gen_args( @@ -103,8 +104,10 @@ def test_nms(self): self.assertTrue( np.array_equal(out.numpy(), out_py), "paddle out: {}\n py out: {}\n".format(out, out_py)) + paddle.enable_static() def test_multiclass_nms_dynamic(self): + paddle.disable_static() for device in self.devices: for dtype in self.dtypes: boxes, scores, category_idxs, categories = gen_args( @@ -121,8 +124,10 @@ def test_multiclass_nms_dynamic(self): self.assertTrue( np.array_equal(out.numpy(), out_py), "paddle out: {}\n py out: {}\n".format(out, out_py)) + paddle.enable_static() def test_multiclass_nms_static(self): + paddle.disable_static() for device in self.devices: for dtype in self.dtypes: paddle.enable_static() @@ -160,8 +165,10 @@ def test_multiclass_nms_static(self): self.assertTrue( np.array_equal(out, out_py), "paddle out: {}\n py out: {}\n".format(out, out_py)) + paddle.enable_static() def test_multiclass_nms_dynamic_to_static(self): + paddle.disable_static() for device in self.devices: for dtype in self.dtypes: paddle.set_device(device) @@ -196,7 +203,7 @@ def fun(x): np.array_equal(origin, res), "origin out: {}\n inference model out: {}\n".format( origin, res)) - + paddle.enable_static() if __name__ == '__main__': unittest.main() From 0939969300f0e27259c1e66b989c2f6dc3856c8d Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Thu, 28 Jul 2022 07:52:48 +0000 Subject: [PATCH 17/21] save change --- python/paddle/fluid/layers/nn.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 8cab4aa2aef78..fb39061d4a45d 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10685,12 +10685,6 @@ def unstack(x, axis=0, num=None): y = paddle.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5] """ - if in_dygraph_mode(): - if num == None: - num = x.shape[axis] - if num == 0: - return [] - return _C_ops.final_state_unstack(x, axis, num) if _non_static_mode(): if num == None: From 4daaf39aa9be875ab42edb4470dcb8f9f1c4ca21 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Thu, 28 Jul 2022 11:46:38 +0000 Subject: [PATCH 18/21] fix code style --- paddle/fluid/operators/detection/nms_op.cc | 14 +-- paddle/fluid/operators/unstack_op.cc | 6 +- paddle/phi/infermeta/backward.cc | 3 +- paddle/phi/infermeta/backward.h | 1 - paddle/phi/infermeta/unary.cc | 24 +++--- paddle/phi/infermeta/unary.h | 4 +- paddle/phi/kernels/cpu/nms_kernel.cc | 18 ++-- paddle/phi/kernels/gpu/nms_kernel.cu | 85 ++++++++++--------- paddle/phi/kernels/nms_kernel.h | 8 +- python/paddle/fluid/layers/nn.py | 1 + .../fluid/tests/unittests/test_ops_nms.py | 1 + python/paddle/tensor/manipulation.py | 1 + 12 files changed, 81 insertions(+), 85 deletions(-) mode change 100755 => 100644 paddle/fluid/operators/detection/nms_op.cc mode change 100755 => 100644 paddle/fluid/operators/unstack_op.cc mode change 100755 => 100644 paddle/phi/infermeta/backward.cc mode change 100755 => 100644 paddle/phi/infermeta/unary.cc mode change 100755 => 100644 paddle/phi/infermeta/unary.h mode change 100755 => 100644 paddle/phi/kernels/cpu/nms_kernel.cc mode change 100755 => 100644 paddle/phi/kernels/gpu/nms_kernel.cu mode change 100755 => 100644 paddle/phi/kernels/nms_kernel.h diff --git a/paddle/fluid/operators/detection/nms_op.cc b/paddle/fluid/operators/detection/nms_op.cc old mode 100755 new mode 100644 index 3373fa5c2fd54..03680538f778e --- a/paddle/fluid/operators/detection/nms_op.cc +++ b/paddle/fluid/operators/detection/nms_op.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" +#include + +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/for_range.h" -#include "paddle/fluid/framework/infershape_utils.h" - -#include +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -69,6 +69,7 @@ class NMSOpMaker : public framework::OpProtoAndCheckerMaker { class NMSOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; + protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -78,8 +79,7 @@ class NMSOp : public framework::OperatorWithKernel { }; template -class NMSKernel : public framework::OpKernel { -}; +class NMSKernel : public framework::OpKernel {}; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/unstack_op.cc b/paddle/fluid/operators/unstack_op.cc old mode 100755 new mode 100644 index f1eb99a1c92e9..d1cfbd2b90260 --- a/paddle/fluid/operators/unstack_op.cc +++ b/paddle/fluid/operators/unstack_op.cc @@ -20,8 +20,8 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -84,4 +84,6 @@ REGISTER_OPERATOR(unstack, ops::UnStackGradOpMaker, ops::UnStackGradOpMaker, UnStackInferMetaFunctor); -REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp, UnStackGradInferMetaFunctor); +REGISTER_OPERATOR(unstack_grad, + ops::UnStackGradOp, + UnStackGradInferMetaFunctor); diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc old mode 100755 new mode 100644 index cee9bf9f03e6c..90483f1be208b --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/infermeta/backward.h" - #include "paddle/phi/common/type_traits.h" #include "paddle/phi/kernels/funcs/axis_utils.h" @@ -747,7 +746,7 @@ void UnStackGradInferMeta(const std::vector& out_grad, int axis, MetaTensor* x_grad) { std::vector input_dims(out_grad.size()); - for(size_t i = 0; i < out_grad.size(); ++i){ + for (size_t i = 0; i < out_grad.size(); ++i) { input_dims[i] = out_grad[i]->dims(); } for (size_t i = 1; i < input_dims.size(); ++i) { diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 5685e4713ecdc..7e95801dbd78e 100755 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -15,7 +15,6 @@ limitations under the License. */ #pragma once #include - #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/multiary.h" diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc old mode 100755 new mode 100644 index 097eebf9fbd3c..3721dcbf09d52 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1663,19 +1663,17 @@ void NanmedianInferMeta(const MetaTensor& x, out->set_dims(make_ddim(out_dim)); } -void NMSInferMeta(const MetaTensor& x, - float threshold, - MetaTensor* out){ - auto boxes_dim = x.dims(); - PADDLE_ENFORCE_EQ(boxes_dim.size(), - 2, - phi::errors::InvalidArgument( - "The Input Boxes must be 2-dimention " - "whose shape must be [N, 4] " - "N is the number of boxes " - "in last dimension in format [x1, x2, y1, y2]. ")); - auto num_boxes = boxes_dim[0]; - out->set_dims(phi::make_ddim({num_boxes})); +void NMSInferMeta(const MetaTensor& x, float threshold, MetaTensor* out) { + auto boxes_dim = x.dims(); + PADDLE_ENFORCE_EQ(boxes_dim.size(), + 2, + phi::errors::InvalidArgument( + "The Input Boxes must be 2-dimention " + "whose shape must be [N, 4] " + "N is the number of boxes " + "in last dimension in format [x1, x2, y1, y2]. ")); + auto num_boxes = boxes_dim[0]; + out->set_dims(phi::make_ddim({num_boxes})); } void NormInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h old mode 100755 new mode 100644 index 0ff80ccc9fc89..e790b1c85bb44 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -222,9 +222,7 @@ void NanmedianInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* median_index); -void NMSInferMeta(const MetaTensor& x, - float threshold, - MetaTensor* out); +void NMSInferMeta(const MetaTensor& x, float threshold, MetaTensor* out); void NormInferMeta(const MetaTensor& x, int axis, diff --git a/paddle/phi/kernels/cpu/nms_kernel.cc b/paddle/phi/kernels/cpu/nms_kernel.cc old mode 100755 new mode 100644 index c26aae5fb31a8..7e656b14f1fc5 --- a/paddle/phi/kernels/cpu/nms_kernel.cc +++ b/paddle/phi/kernels/cpu/nms_kernel.cc @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/nms_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" + #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/diagonal.h" #include "paddle/phi/kernels/funcs/eigen/common.h" @@ -60,18 +61,13 @@ static void NMS(const T* boxes_data, template void NMSKernel(const Context& dev_ctx, - const DenseTensor& boxes, - float threshold, - DenseTensor* output){ - auto output_data = dev_ctx.template Alloc(output); - NMS(boxes.data(), output_data, threshold, boxes.dims()[0]); + const DenseTensor& boxes, + float threshold, + DenseTensor* output) { + auto output_data = dev_ctx.template Alloc(output); + NMS(boxes.data(), output_data, threshold, boxes.dims()[0]); } } // namespace phi -PD_REGISTER_KERNEL(nms, - CPU, - ALL_LAYOUT, - phi::NMSKernel, - float, - double) {} +PD_REGISTER_KERNEL(nms, CPU, ALL_LAYOUT, phi::NMSKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/nms_kernel.cu b/paddle/phi/kernels/gpu/nms_kernel.cu old mode 100755 new mode 100644 index fe70034fe3b54..5a52cb33662fc --- a/paddle/phi/kernels/gpu/nms_kernel.cu +++ b/paddle/phi/kernels/gpu/nms_kernel.cu @@ -13,12 +13,13 @@ // limitations under the License. #include "paddle/phi/kernels/nms_kernel.h" + +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/memory/memcpy.h" static const int64_t threadsPerBlock = sizeof(int64_t) * 8; @@ -55,47 +56,47 @@ static __global__ void NMS(const T* boxes_data, template void NMSKernel(const Context& dev_ctx, - const DenseTensor& boxes, - float threshold, - DenseTensor* output){ - auto* output_data = dev_ctx.template Alloc(output); - const int64_t num_boxes = boxes.dims()[0]; - const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); - dim3 block(threadsPerBlock); - dim3 grid(blocks_per_line, blocks_per_line); - auto mask_data = - paddle::memory::Alloc(dev_ctx, num_boxes * blocks_per_line * sizeof(uint64_t)); - uint64_t* mask_dev = reinterpret_cast(mask_data->ptr()); - NMS<<>>( - boxes.data(), threshold, num_boxes, mask_dev); - std::vector mask_host(num_boxes * blocks_per_line); - paddle::memory::Copy(phi::CPUPlace(), - mask_host.data(), - dev_ctx.GetPlace(), - mask_dev, - num_boxes * blocks_per_line * sizeof(uint64_t), - dev_ctx.stream()); - std::vector remv(blocks_per_line); - std::vector keep_boxes_idxs(num_boxes); - int64_t* output_host = keep_boxes_idxs.data(); - int64_t last_box_num = 0; - for (int64_t i = 0; i < num_boxes; ++i) { - auto remv_element_id = i / threadsPerBlock; - auto remv_bit_id = i % threadsPerBlock; - if (!(remv[remv_element_id] & 1ULL << remv_bit_id)) { - output_host[last_box_num++] = i; - uint64_t* current_mask = mask_host.data() + i * blocks_per_line; - for (auto j = remv_element_id; j < blocks_per_line; ++j) { - remv[j] |= current_mask[j]; - } + const DenseTensor& boxes, + float threshold, + DenseTensor* output) { + auto* output_data = dev_ctx.template Alloc(output); + const int64_t num_boxes = boxes.dims()[0]; + const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); + dim3 block(threadsPerBlock); + dim3 grid(blocks_per_line, blocks_per_line); + auto mask_data = paddle::memory::Alloc( + dev_ctx, num_boxes * blocks_per_line * sizeof(uint64_t)); + uint64_t* mask_dev = reinterpret_cast(mask_data->ptr()); + NMS<<>>( + boxes.data(), threshold, num_boxes, mask_dev); + std::vector mask_host(num_boxes * blocks_per_line); + paddle::memory::Copy(phi::CPUPlace(), + mask_host.data(), + dev_ctx.GetPlace(), + mask_dev, + num_boxes * blocks_per_line * sizeof(uint64_t), + dev_ctx.stream()); + std::vector remv(blocks_per_line); + std::vector keep_boxes_idxs(num_boxes); + int64_t* output_host = keep_boxes_idxs.data(); + int64_t last_box_num = 0; + for (int64_t i = 0; i < num_boxes; ++i) { + auto remv_element_id = i / threadsPerBlock; + auto remv_bit_id = i % threadsPerBlock; + if (!(remv[remv_element_id] & 1ULL << remv_bit_id)) { + output_host[last_box_num++] = i; + uint64_t* current_mask = mask_host.data() + i * blocks_per_line; + for (auto j = remv_element_id; j < blocks_per_line; ++j) { + remv[j] |= current_mask[j]; } } - paddle::memory::Copy(dev_ctx.GetPlace(), - output_data, - phi::CPUPlace(), - output_host, - sizeof(int64_t) * num_boxes, - dev_ctx.stream()); -} + } + paddle::memory::Copy(dev_ctx.GetPlace(), + output_data, + phi::CPUPlace(), + output_host, + sizeof(int64_t) * num_boxes, + dev_ctx.stream()); } +} // namespace phi PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {} diff --git a/paddle/phi/kernels/nms_kernel.h b/paddle/phi/kernels/nms_kernel.h old mode 100755 new mode 100644 index 24307138db1fc..e8511f4c4a49f --- a/paddle/phi/kernels/nms_kernel.h +++ b/paddle/phi/kernels/nms_kernel.h @@ -49,8 +49,8 @@ HOSTDEVICE inline bool CalculateIoU(const T* const box_1, template void NMSKernel(const Context& dev_ctx, - const DenseTensor& boxes, - float threshold, - DenseTensor* output); + const DenseTensor& boxes, + float threshold, + DenseTensor* output); -} // namespace phi \ No newline at end of file +} // namespace phi diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 03d21035bd510..a1bfbaab2b169 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14,6 +14,7 @@ """ All layers just related to the neural network. """ + from __future__ import print_function import os diff --git a/python/paddle/fluid/tests/unittests/test_ops_nms.py b/python/paddle/fluid/tests/unittests/test_ops_nms.py index fdf9899d5a1fb..18b022b98d1ce 100644 --- a/python/paddle/fluid/tests/unittests/test_ops_nms.py +++ b/python/paddle/fluid/tests/unittests/test_ops_nms.py @@ -221,5 +221,6 @@ def test_matrix_nms_dynamic(self): keep_top_k=100, ) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 4053cc43b30ee..4828b20ba769c 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -23,6 +23,7 @@ from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..fluid.layers import utils import numpy as np + # TODO: define functions to manipulate a tensor from ..fluid.layers.nn import _elementwise_op_in_dygraph from ..fluid.dygraph.inplace_utils import inplace_apis_in_dygraph_only From 586e2f43696c345f99ac272bd12c719af08dc22d Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Thu, 28 Jul 2022 11:49:46 +0000 Subject: [PATCH 19/21] fix code style --- python/paddle/fluid/layers/nn.py | 1 - python/paddle/tensor/manipulation.py | 1 - 2 files changed, 2 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a1bfbaab2b169..03d21035bd510 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14,7 +14,6 @@ """ All layers just related to the neural network. """ - from __future__ import print_function import os diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 4828b20ba769c..4053cc43b30ee 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -23,7 +23,6 @@ from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..fluid.layers import utils import numpy as np - # TODO: define functions to manipulate a tensor from ..fluid.layers.nn import _elementwise_op_in_dygraph from ..fluid.dygraph.inplace_utils import inplace_apis_in_dygraph_only From b429e429289cbfe90f4bc5d42e0a8fae2a47e3dd Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Fri, 29 Jul 2022 08:47:45 +0000 Subject: [PATCH 20/21] fix ci and codestyle --- paddle/phi/api/yaml/legacy_api.yaml | 26 +++++++++---------- paddle/phi/api/yaml/legacy_backward.yaml | 24 ++++++++--------- .../fluid/tests/unittests/test_ops_nms.py | 8 ------ .../fluid/tests/unittests/test_unstack_op.py | 7 +++++ 4 files changed, 32 insertions(+), 33 deletions(-) diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index f9314cc02e4b8..836344f39c123 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -857,7 +857,7 @@ func : FrameInferMeta kernel : func : frame - backward : frame_grad + backward : frame_grad - api : frobenius_norm args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) @@ -2180,7 +2180,7 @@ kernel : func : spectralnorm data_type : weight - backward : spectral_norm_grad + backward : spectral_norm_grad - api : split args : (Tensor x, IntArray num_or_sections, Scalar(int) axis) @@ -2468,16 +2468,6 @@ func : unique data_type : x -# unstack -- api : unstack - args : (Tensor x, int axis, int num) - output : Tensor[]{num} - infer_meta : - func : UnStackInferMeta - kernel : - func : unstack - backward : unstack_grad - - api : unique_consecutive args : (Tensor x, bool return_inverse, bool return_counts, int[] axis, int dtype) output : Tensor(out), Tensor(index), Tensor(counts) @@ -2498,6 +2488,16 @@ intermediate : xshape backward : unsqueeze_grad +# unstack +- api : unstack + args : (Tensor x, int axis, int num) + output : Tensor[]{num} + infer_meta : + func : UnStackInferMeta + kernel : + func : unstack + backward : unstack_grad + # viterbi_decode - api : viterbi_decode args : (Tensor input, Tensor transition, Tensor length, bool include_bos_eos_tag) @@ -2561,7 +2561,7 @@ kernel: func: broadcast_tensors backward: broadcast_tensors_grad - + # dirichlet - api: dirichlet args: (Tensor alpha) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 99fd15b00710a..e629d4d66de5f 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -894,10 +894,10 @@ forward : grid_sample (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners) -> Tensor(out) args : (Tensor x, Tensor grid, Tensor out_grad, str mode, str padding_mode, bool align_corners) output : Tensor(x_grad), Tensor(grid_grad) - infer_meta : + infer_meta : func : GeneralBinaryGradInferMeta param : [x, grid] - kernel : + kernel : func : grid_sample_grad data_type : x @@ -2420,16 +2420,6 @@ func : unfold_grad no_need_buffer : x -- backward_api : unstack_grad - forward : unstack (Tensor x, int axis, int num) -> Tensor[](out) - args : (Tensor[] out_grad, int axis) - output : Tensor(x_grad) - infer_meta : - func : UnStackGradInferMeta - param : [out_grad, axis] - kernel : - func : unstack_grad - - backward_api : unsqueeze_double_grad forward : unsqueeze_grad(Tensor xshape, Tensor grad_out, IntArray axes) -> Tensor(grad_x) args : (Tensor grad_x_grad, IntArray axes) @@ -2449,6 +2439,16 @@ inplace : (out_grad -> x_grad) backward : unsqueeze_double_grad +- backward_api : unstack_grad + forward : unstack (Tensor x, int axis, int num) -> Tensor[](out) + args : (Tensor[] out_grad, int axis) + output : Tensor(x_grad) + infer_meta : + func : UnStackGradInferMeta + param : [out_grad, axis] + kernel : + func : unstack_grad + - backward_api : warpctc_grad forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times) -> Tensor(loss), Tensor(warpctcgrad) args : (Tensor logits, Tensor logits_length, Tensor warpctcgrad, Tensor loss_grad, int blank, bool norm_by_times) diff --git a/python/paddle/fluid/tests/unittests/test_ops_nms.py b/python/paddle/fluid/tests/unittests/test_ops_nms.py index 18b022b98d1ce..3d6f2b717f261 100644 --- a/python/paddle/fluid/tests/unittests/test_ops_nms.py +++ b/python/paddle/fluid/tests/unittests/test_ops_nms.py @@ -88,7 +88,6 @@ def tearDown(self): self.temp_dir.cleanup() def test_nms(self): - paddle.disable_static() for device in self.devices: for dtype in self.dtypes: boxes, scores, category_idxs, categories = gen_args( @@ -104,10 +103,8 @@ def test_nms(self): self.assertTrue( np.array_equal(out.numpy(), out_py), "paddle out: {}\n py out: {}\n".format(out, out_py)) - paddle.enable_static() def test_multiclass_nms_dynamic(self): - paddle.disable_static() for device in self.devices: for dtype in self.dtypes: boxes, scores, category_idxs, categories = gen_args( @@ -124,10 +121,8 @@ def test_multiclass_nms_dynamic(self): self.assertTrue( np.array_equal(out.numpy(), out_py), "paddle out: {}\n py out: {}\n".format(out, out_py)) - paddle.enable_static() def test_multiclass_nms_static(self): - paddle.disable_static() for device in self.devices: for dtype in self.dtypes: paddle.enable_static() @@ -165,10 +160,8 @@ def test_multiclass_nms_static(self): self.assertTrue( np.array_equal(out, out_py), "paddle out: {}\n py out: {}\n".format(out, out_py)) - paddle.enable_static() def test_multiclass_nms_dynamic_to_static(self): - paddle.disable_static() for device in self.devices: for dtype in self.dtypes: paddle.set_device(device) @@ -203,7 +196,6 @@ def fun(x): np.array_equal(origin, res), "origin out: {}\n inference model out: {}\n".format( origin, res)) - paddle.enable_static() def test_matrix_nms_dynamic(self): for device in self.devices: diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py index bb28bdeba79d3..d8342e1020bd1 100755 --- a/python/paddle/fluid/tests/unittests/test_unstack_op.py +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -87,5 +87,12 @@ def initParameters(self): self.axis = 2 +class TestStackOp7(TestUnStackOpBase): + + def initParameters(self): + self.input_dim = (0, 0, 0) + self.axis = 1 + + if __name__ == '__main__': unittest.main() From 8c578bc6b8829bcba5148d2eb23d2c99f974a850 Mon Sep 17 00:00:00 2001 From: HexToString <506181616@qq.com> Date: Fri, 29 Jul 2022 11:02:11 +0000 Subject: [PATCH 21/21] fix ci --- python/paddle/fluid/tests/unittests/test_unstack_op.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py index d8342e1020bd1..bb28bdeba79d3 100755 --- a/python/paddle/fluid/tests/unittests/test_unstack_op.py +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -87,12 +87,5 @@ def initParameters(self): self.axis = 2 -class TestStackOp7(TestUnStackOpBase): - - def initParameters(self): - self.input_dim = (0, 0, 0) - self.axis = 1 - - if __name__ == '__main__': unittest.main()