diff --git a/paddle/fluid/operators/gather_op_xpu.cc b/paddle/fluid/operators/gather_op_xpu.cc new file mode 100644 index 0000000000000..ae3d0f2633bb1 --- /dev/null +++ b/paddle/fluid/operators/gather_op_xpu.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/operators/gather_op.h" +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/op_version_registry.h" +namespace paddle { +namespace operators { + +template +class GatherOpXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_xpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet("This kernel only runs on XPU.")); + + auto *x = ctx.Input("X"); + auto *index = ctx.Input("Index"); + auto *output = ctx.Output("Out"); + if (ctx.HasInput("Axis")) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Now, it doesn't support XPU with Axis.")); + } + + output->mutable_data(ctx.GetPlace()); + if (x->numel() == 0) return; + // check index type is INT32 + const auto &index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ( + index_type_match, true, + platform::errors::InvalidArgument( + "XPU only support INT32, it holds %s, but desires to be %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32))); + + const auto index_dims = index->dims(); + if (index_dims.size() == 2) { + PADDLE_ENFORCE_EQ( + index_dims[1], 1, + platform::errors::InvalidArgument( + "The last dim of index should be 1 when it is 2D, but we get %d", + index_dims[1])); + } else { + PADDLE_ENFORCE_EQ( + index_dims.size(), 1, + platform::errors::InvalidArgument( + "The index should be 1D, when it is not 2D, but we get %d", + index_dims.size())); + } + int slice_size = x->numel() / x->dims()[0]; + auto &dev_ctx = ctx.template device_context(); + int r = + xpu::gather(dev_ctx.x_context(), x->data(), index->data(), + index->dims()[0], slice_size, output->data()); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External("XPU kernel error! error code=%d", r)); + } +}; + +template +class GatherGradOpXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_xpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet("This kernel only runs on XPU.")); + + auto *index = ctx.Input("Index"); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *dout = ctx.Input(framework::GradVarName("Out")); + auto &dev_ctx = ctx.template device_context(); + + if (ctx.HasInput("Axis")) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Now, it doesn't support XPU with Axis.")); + } + + dx->mutable_data(ctx.GetPlace()); + const int zero = 0; + int r_dx = xpu::memset(dev_ctx.x_context(), dx->data(), zero, + dx->numel() * sizeof(T)); + PADDLE_ENFORCE_EQ( + r_dx, xpu::Error_t::SUCCESS, + platform::errors::External("XPU kernel error! error code=%d", r_dx)); + + if (dout->numel() == 0) { + return; + } + bool overwrite = ctx.Attr("overwrite"); + // check index type is INT32 + const auto &index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ( + index_type_match, true, + platform::errors::InvalidArgument( + "XPU only support INT32, it holds %s, but desires to be %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32))); + + const auto index_dims = index->dims(); + if (index_dims.size() == 2) { + PADDLE_ENFORCE_EQ( + index_dims[1], 1, + platform::errors::InvalidArgument( + "The last dim of index should be 1 when it is 2D, but we get %d", + index_dims[1])); + } else { + PADDLE_ENFORCE_EQ( + index_dims.size(), 1, + platform::errors::InvalidArgument( + "The index should be 1D, when it is not 2D, but we get %d", + index_dims.size())); + } + + int index_size = index_dims[0]; + int slice_size = dout->numel() / dout->dims()[0]; + + int r = xpu::scatter(dev_ctx.x_context(), dout->data(), + index->data(), index_size, slice_size, + dx->data(), overwrite); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External("XPU kernel error! error code=%d", r)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL(gather, ops::GatherOpXPUKernel); +REGISTER_OP_XPU_KERNEL(gather_grad, ops::GatherGradOpXPUKernel); +#endif diff --git a/python/paddle/fluid/tests/unittests/xpu/test_gather_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_gather_op_xpu.py new file mode 100644 index 0000000000000..9bea33e484e19 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_gather_op_xpu.py @@ -0,0 +1,154 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import sys +sys.path.append("..") +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid + + +def gather_numpy(x, index, axis): + x_transpose = np.swapaxes(x, 0, axis) + tmp_gather = x_transpose[index, ...] + gather = np.swapaxes(tmp_gather, 0, axis) + return gather + + +class TestGatherOp(OpTest): + def setUp(self): + self.op_type = "gather" + self.config() + xnp = np.random.random(self.x_shape).astype(self.x_type) + self.inputs = { + 'X': xnp, + 'Index': np.array(self.index).astype(self.index_type) + } + self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (10, 20) + self.x_type = "float64" + self.index = [1, 3, 5] + self.index_type = "int32" + + +class TestXPUGatherOp(OpTest): + def setUp(self): + self.op_type = "gather" + self.dtype = np.float32 + self.attrs = {'use_xpu': True} + + self.config() + xnp = np.random.random(self.x_shape).astype(self.x_type) + self.inputs = { + 'X': xnp, + 'Index': np.array(self.index).astype(self.index_type) + } + self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} + + def test_check_output(self): + if self.dtype == np.float32 and paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + if self.dtype == np.float32 and paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (10, 20) + self.x_type = self.dtype + self.index = [1, 3, 5] + self.index_type = "int32" + + +class TestCase1(TestXPUGatherOp): + def config(self): + """ + For one dimension input + """ + self.x_shape = (100) + self.x_type = "float32" + self.index = [1, 3, 5] + self.index_type = "int32" + + +class TestCase2(TestXPUGatherOp): + def config(self): + """ + For int64_t index type + """ + self.x_shape = (100) + self.x_type = "float32" + self.index = [1, 3, 5] + self.index_type = "int32" + + +class TestCase3(TestXPUGatherOp): + def config(self): + """ + For other input type + """ + self.x_shape = (10, 20) + self.x_type = "float32" + self.index = [1, 3, 5] + self.index_type = "int32" + + +class TestCase4(TestXPUGatherOp): + def config(self): + self.x_shape = (10, 20) + self.attrs = {'use_xpu': True, 'overwrite': False} + self.x_type = "float32" + self.index = [1, 1] + self.index_type = "int32" + + +class TestCase5(TestXPUGatherOp): + def config(self): + self.x_shape = (10, 20) + self.attrs = {'use_xpu': True, 'overwrite': False} + self.x_type = "float32" + self.index = [1, 1, 3] + self.index_type = "int32" + + +class TestCase6(TestXPUGatherOp): + def config(self): + self.x_shape = (10, 20) + self.attrs = {'use_xpu': True, 'overwrite': True} + self.x_type = "float32" + self.index = [1, 3] + self.index_type = "int32" + + +if __name__ == "__main__": + unittest.main()