diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 17234edb116e3..a3964b28eab31 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -151,6 +151,11 @@ else() cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3) endif() +# ascend gather_op_npu unittest +if (WITH_ASCEND_CL) + cc_test(gather_op_npu_test SRCS gather_op_npu_test.cc DEPS gather_op tensor op_registry scope device_context enforce executor) +endif() + cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS}) if (WITH_PYTHON) cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind) diff --git a/paddle/fluid/operators/gather_op_npu.cc b/paddle/fluid/operators/gather_op_npu.cc new file mode 100644 index 0000000000000..2d7b5b93ad651 --- /dev/null +++ b/paddle/fluid/operators/gather_op_npu.cc @@ -0,0 +1,96 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/gather_op.h" +#include +#include +#include +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/kron_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class GatherOpNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *x = ctx.Input("X"); + auto *index = ctx.Input("Index"); + auto *out = ctx.Output("Out"); + + out->mutable_data(ctx.GetPlace()); + auto runner = NpuOpRunner("Gather", {*x, *index}, {*out}, + {{"validate_indices", true}}); + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + +template +class GatherGradOpNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *index = ctx.Input("Index"); + auto *x = ctx.Input("X"); + auto *dout = ctx.Input(framework::GradVarName("Out")); + auto *dx = ctx.Output(framework::GradVarName("X")); + + // step1: Unsqueeze index + const auto index_dims = index->dims(); + if (index_dims.size() == 1) { + framework::Tensor tmp_index = UnsqueezeTo(*index, 2); + index = &tmp_index; + } + + auto stream = + ctx.template device_context() + .stream(); + + // step2: ZerosLike x in device + Tensor *tmp_zerox = const_cast(x); + Tensor zeroslike_xout(x->type()); + zeroslike_xout.Resize(x->dims()); + zeroslike_xout.mutable_data(ctx.GetPlace()); + + auto runner_zeroslike = + NpuOpRunner("ZerosLike", {*x}, {zeroslike_xout}, {}); + runner_zeroslike.Run(stream); + tmp_zerox = &zeroslike_xout; + + // step3: scatter(x_grad) + dx->mutable_data(ctx.GetPlace()); + auto runner_scatter = NpuOpRunner("TensorScatterUpdate", + {*tmp_zerox, *index, *dout}, {*dx}, {}); + runner_scatter.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_NPU_KERNEL( + gather, ops::GatherOpNPUKernel, + ops::GatherOpNPUKernel); + +REGISTER_OP_NPU_KERNEL( + gather_grad, + ops::GatherGradOpNPUKernel, + ops::GatherGradOpNPUKernel); diff --git a/paddle/fluid/operators/gather_op_npu_test.cc b/paddle/fluid/operators/gather_op_npu_test.cc new file mode 100644 index 0000000000000..4cd46da6f26f8 --- /dev/null +++ b/paddle/fluid/operators/gather_op_npu_test.cc @@ -0,0 +1,169 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef _WIN32 +#include +#endif + +#include +#include // NOLINT +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/gather_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/string/printf.h" + +namespace f = paddle::framework; +namespace p = paddle::platform; +namespace m = paddle::operators::math; + +USE_OP(gather); +USE_OP_DEVICE_KERNEL(gather, NPU); +USE_OP(gather_grad); +USE_OP_DEVICE_KERNEL(gather_grad, NPU); + +template +void Compare(f::Scope* scope, const p::DeviceContext& ctx, + std::string op_type) { + // init + auto x = scope->Var("X"); + auto tensor_x = x->GetMutable(); + + auto index = scope->Var("Index"); + auto tensor_index = index->GetMutable(); + + std::vector init_x; + for (int64_t i = 1; i < 7; ++i) { + // 1,2,3,4,5,6 + init_x.push_back(static_cast(i)); + } + + // [[1, 2],[3, 4],[5, 6]] + TensorFromVector(init_x, ctx, tensor_x); + tensor_x->Resize(paddle::framework::make_ddim({3, 2})); + + std::vector init_index = {1, 2}; + paddle::framework::TensorFromVector(init_index, ctx, tensor_index); + tensor_index->Resize(paddle::framework::make_ddim({2})); + + ctx.Wait(); + + auto out = scope->Var("Out"); + auto tensor_out = out->GetMutable(); + + // run + f::AttributeMap attrs = {{"validate_indices", true}}; + auto op = f::OpRegistry::CreateOp( + op_type, {{"X", {"X"}}, {"Index", {"Index"}}}, {{"Out", {"Out"}}}, attrs); + + auto place = ctx.GetPlace(); + op->Run(*scope, place); + + std::vector out_vec; + TensorToVector(*tensor_out, ctx, &out_vec); + + ctx.Wait(); + + // ref:https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/tensor/manipulation/gather_cn.html#gather + for (int i = 0; i < static_cast(out_vec.size()); ++i) { + VLOG(3) << "out_vec[" << i << "] : " << out_vec[i]; + } + uint32_t expected_size = 4; + EXPECT_EQ((uint32_t)out_vec.size(), expected_size); + + // {3, 4, 5, 6} + std::vector expected_out_vec; + for (int64_t i = 3; i < 7; ++i) { + expected_out_vec.push_back(static_cast(i)); + } + for (uint32_t i = 0; i < out_vec.size(); i++) { + EXPECT_EQ(out_vec[i], expected_out_vec[i]); + } +} + +template +void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx, + std::string op_type) { + // init + auto index = scope->Var("Index"); + auto tensor_index = index->GetMutable(); + + auto x = scope->Var("X"); + auto tensor_x = x->GetMutable(); + + auto dout = scope->Var("DOut"); + auto tensor_dout = dout->GetMutable(); + + std::vector init_index = {0, 1, 2, 0}; + paddle::framework::TensorFromVector(init_index, ctx, tensor_index); + tensor_index->Resize(paddle::framework::make_ddim({2, 2})); + + std::vector init_x = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + TensorFromVector(init_x, ctx, tensor_x); + tensor_x->Resize(paddle::framework::make_ddim({3, 2})); + + std::vector init_dout = {5.0, 10.0}; + TensorFromVector(init_dout, ctx, tensor_dout); + tensor_dout->Resize(paddle::framework::make_ddim({2})); + + ctx.Wait(); + + auto dx = scope->Var("DX"); + auto tensor_dx = dx->GetMutable(); + + // run + f::AttributeMap attrs; + auto op = f::OpRegistry::CreateOp( + op_type, {{"X", {"X"}}, {"Index", {"Index"}}, {"Out@GRAD", {"DOut"}}}, + {{"X@GRAD", {"DX"}}}, attrs); + + auto place = ctx.GetPlace(); + op->Run(*scope, place); + + std::vector dx_vec; + TensorToVector(*tensor_dx, ctx, &dx_vec); + + ctx.Wait(); + + uint32_t expected_size = 3 * 2; + EXPECT_EQ((uint32_t)dx_vec.size(), expected_size); + + std::vector expected_dx_vec = {0.0, 5.0, 0.0, 0.0, 10.0, 0.0}; + for (uint32_t i = 0; i < dx_vec.size(); i++) { + VLOG(3) << "dx_vec[i]=" << dx_vec[i]; + EXPECT_EQ(dx_vec[i], expected_dx_vec[i]); + } +} + +TEST(gather, NPU_fp32) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + Compare(&scope, ctx, "gather"); +} + +TEST(gather, NPU_fp16) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + Compare(&scope, ctx, "gather"); +} + +TEST(gather_grad, NPU_fp32) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + CompareGrad(&scope, ctx, "gather_grad"); +} diff --git a/python/paddle/fluid/tests/unittests/npu/test_gather_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_gather_op_npu.py new file mode 100644 index 0000000000000..0fcb2bee658fa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_gather_op_npu.py @@ -0,0 +1,109 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest, _set_use_system_allocator +import paddle +import paddle.fluid as fluid + +paddle.enable_static() + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestGatherOp(OpTest): + def setUp(self): + self.set_npu() + self.op_type = "gather" + self.place = paddle.NPUPlace(0) + self.init_dtype() + self.init_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Index': OpTest.np_dtype_to_fluid_dtype(self.index) + } + self.attrs = {'validate_indices': True} + self.outputs = {'Out': self.out} + + def set_npu(self): + self.__class__.use_npu = True + + def init_input_output(self): + self.x = np.array([[1, 2], [3, 4], [5, 6]]).astype(self.dtype) + self.index = np.array([1, 2]).astype(np.int) + self.out = np.array([[3, 4], [5, 6]]).astype(self.dtype) + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestGatherAPI(unittest.TestCase): + def test_name(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[3, 2], dtype="float32") + index = paddle.static.data(name='index', shape=[1], dtype='int32') + + out = paddle.gather(x, index, name='gather') + self.assertEqual(('gather' in out.name), True) + + def test_static(self): + with paddle.static.program_guard(paddle.static.Program()): + + x_np = np.array([[1, 2], [3, 4], [5, 6]]).astype('float32') + index_np = np.array([1, 2]).astype('int32') + + x = paddle.static.data(name="x", shape=[3, 2], dtype='float32') + index = paddle.static.data(name="index", shape=[2], dtype='int32') + + z = paddle.gather(x, index) + + place = paddle.NPUPlace(0) + exe = paddle.static.Executor(place) + x_value, index_value, z_value = exe.run( + feed={"x": x_np, + "index": index_np}, fetch_list=[x, index, z]) + + z_expected = np.array([[3, 4], [5, 6]]) + self.assertEqual( + (x_value == x_np).all(), + True, + msg="x_value = {}, but expected {}".format(x_value, x_np)) + self.assertEqual( + (index_value == index_np).all(), + True, + msg="index_value = {}, but expected {}".format(index_value, + index_np)) + self.assertEqual( + (z_value == z_expected).all(), + True, + msg="z_value = {}, but expected {}".format(z_value, z_expected)) + + def test_backward(self): + # TODO(ascendrc): Test backward after add grad npu op implemented. + pass + + +if __name__ == '__main__': + unittest.main()