diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu
index be3ca5f68fbf3..c8dae35822544 100644
--- a/paddle/fluid/operators/scatter_op.cu
+++ b/paddle/fluid/operators/scatter_op.cu
@@ -111,9 +111,10 @@ REGISTER_OP_CUDA_KERNEL(scatter, ops::ScatterOpCUDAKernel,
                         ops::ScatterOpCUDAKernel,
                         ops::ScatterOpCUDAKernel,
                         ops::ScatterOpCUDAKernel,
-                        ops::ScatterOpCUDAKernel);
-REGISTER_OP_CUDA_KERNEL(scatter_grad, ops::ScatterGradOpCUDAKernel,
-                        ops::ScatterGradOpCUDAKernel,
-                        ops::ScatterOpCUDAKernel,
-                        ops::ScatterOpCUDAKernel,
-                        ops::ScatterGradOpCUDAKernel);
+                        ops::ScatterOpCUDAKernel);
+
+REGISTER_OP_CUDA_KERNEL(
+    scatter_grad, ops::ScatterGradOpCUDAKernel,
+    ops::ScatterGradOpCUDAKernel, ops::ScatterOpCUDAKernel,
+    ops::ScatterOpCUDAKernel,
+    ops::ScatterGradOpCUDAKernel);
diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py
index 1f8daacf01a9a..ad542da781670 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -276,22 +276,22 @@ def setUp(self):
         self.__class__.op_type = "scatter"
         # compute grad in the following code handly.
         self.__class__.no_need_check_grad = True
-        self.x_type='float16'
+        self.x_type = 'float16'
         self.x_np = np.ones((3, 3)).astype(self.x_type)
         self.index_np = np.array([1, 2]).astype("int32")
         self.updates_np = np.random.random((2, 3)).astype(self.x_type)
         self.output_np = np.copy(self.x_np)
         self.output_np[self.index_np] = self.updates_np
         self.dout_np = np.random.random((3, 3)).astype(self.x_type)
-
+
         # compute ref_dx
         self.ref_dx = np.copy(self.dout_np)
         zero_np = np.zeros((2, 3)).astype(self.x_type)
         self.ref_dx[self.index_np] = zero_np

     def compute_ref_grad_updates(self):
-        ref_grad_updates = paddle.gather(paddle.to_tensor(self.dout_np),
-                                         paddle.to_tensor(self.index_np))
+        ref_grad_updates = paddle.gather(
+            paddle.to_tensor(self.dout_np), paddle.to_tensor(self.index_np))
         return ref_grad_updates

     def test_scatter_fp16(self):
@@ -304,8 +304,10 @@ def test_scatter_fp16(self):
             [out_tensor], [paddle.to_tensor(self.dout_np)], retain_graph=True)
         ref_grad_updates = self.compute_ref_grad_updates()
         np.testing.assert_allclose(
-            ref_grad_updates.numpy(), updates_tensor.grad.numpy(),
-            rtol=1e-5, atol=1e-5)
+            ref_grad_updates.numpy(),
+            updates_tensor.grad.numpy(),
+            rtol=1e-5,
+            atol=1e-5)
         np.testing.assert_allclose(
             self.ref_dx, x_tensor.grad.numpy(), rtol=1e-5, atol=1e-5)

diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 5288b2bd8cf87..75fb262ebdb6d 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -1565,7 +1565,9 @@ def scatter(x, index, updates, overwrite=True, name=None):
     if in_dygraph_mode():
         return _C_ops.scatter(x, index, updates, 'overwrite', overwrite)

-    check_variable_and_dtype(x, 'dtype', ['float16', 'float32', 'float64', 'int32', 'int64'], 'scatter')
+    check_variable_and_dtype(
+        x, 'dtype', ['float32', 'float64', 'float16', 'int32', 'int64'],
+        'scatter')
     check_type(overwrite, 'overwrite', bool, 'scatter')
     helper = LayerHelper('scatter', **locals())
     out = helper.create_variable_for_type_inference(x.dtype)