Skip to content

Commit

Permalink
Code format.
Browse files Browse the repository at this point in the history
  • Loading branch information
limin2021 committed Dec 15, 2021
1 parent 42b2533 commit 8f7675d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
13 changes: 7 additions & 6 deletions paddle/fluid/operators/scatter_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ REGISTER_OP_CUDA_KERNEL(scatter, ops::ScatterOpCUDAKernel<float>,
ops::ScatterOpCUDAKernel<double>,
ops::ScatterOpCUDAKernel<int>,
ops::ScatterOpCUDAKernel<int64_t>,
ops::ScatterOpCUDAKernel<paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(scatter_grad, ops::ScatterGradOpCUDAKernel<float>,
ops::ScatterGradOpCUDAKernel<double>,
ops::ScatterOpCUDAKernel<int>,
ops::ScatterOpCUDAKernel<int64_t>,
ops::ScatterGradOpCUDAKernel<paddle::platform::float16>);
ops::ScatterOpCUDAKernel<paddle::platform::float16>);

REGISTER_OP_CUDA_KERNEL(
scatter_grad, ops::ScatterGradOpCUDAKernel<float>,
ops::ScatterGradOpCUDAKernel<double>, ops::ScatterOpCUDAKernel<int>,
ops::ScatterOpCUDAKernel<int64_t>,
ops::ScatterGradOpCUDAKernel<paddle::platform::float16>);
14 changes: 8 additions & 6 deletions python/paddle/fluid/tests/unittests/test_scatter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,22 +276,22 @@ def setUp(self):
self.__class__.op_type = "scatter"
# compute grad in the following code handly.
self.__class__.no_need_check_grad = True
self.x_type='float16'
self.x_type = 'float16'
self.x_np = np.ones((3, 3)).astype(self.x_type)
self.index_np = np.array([1, 2]).astype("int32")
self.updates_np = np.random.random((2, 3)).astype(self.x_type)
self.output_np = np.copy(self.x_np)
self.output_np[self.index_np] = self.updates_np
self.dout_np = np.random.random((3, 3)).astype(self.x_type)

# compute ref_dx
self.ref_dx = np.copy(self.dout_np)
zero_np = np.zeros((2, 3)).astype(self.x_type)
self.ref_dx[self.index_np] = zero_np

def compute_ref_grad_updates(self):
ref_grad_updates = paddle.gather(paddle.to_tensor(self.dout_np),
paddle.to_tensor(self.index_np))
ref_grad_updates = paddle.gather(
paddle.to_tensor(self.dout_np), paddle.to_tensor(self.index_np))
return ref_grad_updates

def test_scatter_fp16(self):
Expand All @@ -304,8 +304,10 @@ def test_scatter_fp16(self):
[out_tensor], [paddle.to_tensor(self.dout_np)], retain_graph=True)
ref_grad_updates = self.compute_ref_grad_updates()
np.testing.assert_allclose(
ref_grad_updates.numpy(), updates_tensor.grad.numpy(),
rtol=1e-5, atol=1e-5)
ref_grad_updates.numpy(),
updates_tensor.grad.numpy(),
rtol=1e-5,
atol=1e-5)
np.testing.assert_allclose(
self.ref_dx, x_tensor.grad.numpy(), rtol=1e-5, atol=1e-5)

Expand Down
4 changes: 3 additions & 1 deletion python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,7 +1565,9 @@ def scatter(x, index, updates, overwrite=True, name=None):
if in_dygraph_mode():
return _C_ops.scatter(x, index, updates, 'overwrite', overwrite)

check_variable_and_dtype(x, 'dtype', ['float16', 'float32', 'float64', 'int32', 'int64'], 'scatter')
check_variable_and_dtype(
x, 'dtype', ['float32', 'float64', 'float16', 'int32', 'int64'],
'scatter')
check_type(overwrite, 'overwrite', bool, 'scatter')
helper = LayerHelper('scatter', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
Expand Down

0 comments on commit 8f7675d

Please sign in to comment.