Commit f907a86

add cpu kernel
zhangbo9674 committed Feb 9, 2022
1 parent 160725a commit f907a86
Showing 6 changed files with 43 additions and 31 deletions.
16 changes: 12 additions & 4 deletions paddle/fluid/operators/squeeze_op.cc
@@ -393,7 +393,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
                        paddle::platform::complex<float>>,
     ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex<double>>);
+                       paddle::platform::complex<double>>,
+    ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
+                       paddle::platform::bfloat16>);
 REGISTER_OP_CPU_KERNEL(
     squeeze_grad,
     ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -406,7 +408,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
                            paddle::platform::complex<float>>,
     ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
-                           paddle::platform::complex<double>>);
+                           paddle::platform::complex<double>>,
+    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
+                           paddle::platform::bfloat16>);
 
 REGISTER_OP_CPU_KERNEL(
     squeeze2, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, float>,
@@ -419,7 +423,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
                         paddle::platform::complex<float>>,
     ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex<double>>);
+                        paddle::platform::complex<double>>,
+    ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
+                        paddle::platform::bfloat16>);
 
 REGISTER_OP_CPU_KERNEL(
     squeeze2_grad,
@@ -433,4 +439,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
                             paddle::platform::complex<float>>,
     ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
-                            paddle::platform::complex<double>>);
+                            paddle::platform::complex<double>>,
+    ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
+                            paddle::platform::bfloat16>);
23 changes: 13 additions & 10 deletions paddle/fluid/operators/stack_op.cc
@@ -173,13 +173,16 @@ REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
                   ops::StackGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
 
-REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel<plat::CPUDeviceContext, float>,
-                       ops::StackKernel<plat::CPUDeviceContext, double>,
-                       ops::StackKernel<plat::CPUDeviceContext, int>,
-                       ops::StackKernel<plat::CPUDeviceContext, int64_t>);
-
-REGISTER_OP_CPU_KERNEL(stack_grad,
-                       ops::StackGradKernel<plat::CPUDeviceContext, float>,
-                       ops::StackGradKernel<plat::CPUDeviceContext, double>,
-                       ops::StackGradKernel<plat::CPUDeviceContext, int>,
-                       ops::StackGradKernel<plat::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    stack, ops::StackKernel<plat::CPUDeviceContext, float>,
+    ops::StackKernel<plat::CPUDeviceContext, double>,
+    ops::StackKernel<plat::CPUDeviceContext, int>,
+    ops::StackKernel<plat::CPUDeviceContext, int64_t>,
+    ops::StackKernel<plat::CPUDeviceContext, paddle::platform::bfloat16>);
+
+REGISTER_OP_CPU_KERNEL(
+    stack_grad, ops::StackGradKernel<plat::CPUDeviceContext, float>,
+    ops::StackGradKernel<plat::CPUDeviceContext, double>,
+    ops::StackGradKernel<plat::CPUDeviceContext, int>,
+    ops::StackGradKernel<plat::CPUDeviceContext, int64_t>,
+    ops::StackGradKernel<plat::CPUDeviceContext, paddle::platform::bfloat16>);
16 changes: 12 additions & 4 deletions paddle/fluid/operators/unsqueeze_op.cc
@@ -366,7 +366,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
                          paddle::platform::complex<float>>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex<double>>);
+                         paddle::platform::complex<double>>,
+    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
+                         paddle::platform::bfloat16>);
 REGISTER_OP_CPU_KERNEL(
     unsqueeze_grad,
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -379,7 +381,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
                              paddle::platform::complex<float>>,
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex<double>>);
+                             paddle::platform::complex<double>>,
+    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::bfloat16>);
 REGISTER_OP_CPU_KERNEL(
     unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
@@ -391,7 +395,9 @@ REGISTER_OP_CPU_KERNEL(
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
                          paddle::platform::complex<float>>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
-                         paddle::platform::complex<double>>);
+                         paddle::platform::complex<double>>,
+    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
+                         paddle::platform::bfloat16>);
 REGISTER_OP_CPU_KERNEL(
     unsqueeze2_grad,
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
@@ -404,4 +410,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
                               paddle::platform::complex<float>>,
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::complex<double>>);
+                              paddle::platform::complex<double>>,
+    ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
+                              paddle::platform::bfloat16>);
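
Note: the pattern is identical in all three operator files. paddle::platform::bfloat16 is appended to the type list of REGISTER_OP_CPU_KERNEL, so the existing kernel templates are instantiated for one more element type. bfloat16 is essentially float32 with the low 16 mantissa bits dropped, which is also why the Python tests below store their data as uint16 via convert_float_to_uint16. A NumPy-only sketch of that truncation (an illustration of the idea, not code from this commit):

import numpy as np

def float32_to_bfloat16_bits(x):
    # Keep only the upper 16 bits of each float32 value (sign, exponent,
    # 7 mantissa bits) -- roughly what convert_float_to_uint16 does.
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bfloat16_bits_to_float32(bits):
    # Round-trip back to float32 by re-attaching 16 zero bits.
    bits = np.asarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.random.random((1, 3, 1, 40)).astype(np.float32)
x_bf16 = bfloat16_bits_to_float32(float32_to_bfloat16_bits(x))
print(np.max(np.abs(x - x_bf16)))  # small but nonzero: bfloat16 keeps ~3 decimal digits
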
6 changes: 2 additions & 4 deletions python/paddle/fluid/tests/unittests/test_squeeze_op.py
@@ -62,12 +62,10 @@ def setUp(self):
         self.outputs = {"Out": convert_float_to_uint16(out)}
 
     def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            self.check_output_with_place(core.CUDAPlace(0))
+        self.check_output()
 
     def test_check_grad(self):
-        if core.is_compiled_with_cuda():
-            self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")
+        self.check_grad(["X"], "Out")
 
     def init_test_case(self):
         self.ori_shape = (1, 3, 1, 40)
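
For context, the fragments above belong to a bfloat16 OpTest case; since NumPy has no native bfloat16 dtype, inputs and expected outputs are packed as uint16 with convert_float_to_uint16. A rough sketch of what such a test class looks like after this change (the class name and shapes are illustrative assumptions, not copied from the file):

import numpy as np
from op_test import OpTest, convert_float_to_uint16


class TestSqueezeBF16Op(OpTest):  # illustrative name
    def setUp(self):
        self.op_type = "squeeze"
        self.dtype = np.uint16
        self.init_test_case()
        x = np.random.random(self.ori_shape).astype("float32")
        out = x.reshape(self.new_shape)
        self.inputs = {"X": convert_float_to_uint16(x)}
        self.attrs = {"axes": list(self.axes)}
        self.outputs = {"Out": convert_float_to_uint16(out)}

    def test_check_output(self):
        # With the CPU kernel registered, the default place list
        # (which includes CPUPlace) can be checked directly.
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")

    def init_test_case(self):
        self.ori_shape = (1, 3, 1, 40)
        self.axes = (0, 2)
        self.new_shape = (3, 40)
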
7 changes: 2 additions & 5 deletions python/paddle/fluid/tests/unittests/test_stack_op.py
@@ -128,13 +128,10 @@ def setUp(self):
         self.attrs = {'axis': self.axis}
 
     def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            self.check_output_with_place(core.CUDAPlace(0))
+        self.check_output()
 
     def test_check_grad(self):
-        if core.is_compiled_with_cuda():
-            self.check_grad_with_place(
-                core.CUDAPlace(0), self.get_x_names(), 'Y')
+        self.check_grad(self.get_x_names(), 'Y')
 
 
 class TestStackAPIWithLoDTensorArray(unittest.TestCase):
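
The stack test follows the same pattern, except that the operator takes a list of inputs, which is why the gradient check passes self.get_x_names() and the output name 'Y'. A sketch of the corresponding setup (names, counts, and shapes are illustrative assumptions):

import numpy as np
from op_test import OpTest, convert_float_to_uint16


class TestStackBF16Op(OpTest):  # illustrative name
    num_inputs = 4
    input_dim = (5, 6, 7)
    axis = 0

    def get_x_names(self):
        return ['x{}'.format(i) for i in range(self.num_inputs)]

    def setUp(self):
        self.op_type = 'stack'
        self.dtype = np.uint16
        xs = [np.random.random(self.input_dim).astype('float32')
              for _ in range(self.num_inputs)]
        # A duplicable input is passed as a list of (name, array) pairs.
        self.inputs = {
            'X': list(zip(self.get_x_names(),
                          [convert_float_to_uint16(x) for x in xs]))
        }
        self.outputs = {'Y': convert_float_to_uint16(np.stack(xs, axis=self.axis))}
        self.attrs = {'axis': self.axis}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(self.get_x_names(), 'Y')
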
6 changes: 2 additions & 4 deletions python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
@@ -61,12 +61,10 @@ def setUp(self):
         self.outputs = {"Out": convert_float_to_uint16(out)}
 
     def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            self.check_output_with_place(core.CUDAPlace(0))
+        self.check_output()
 
     def test_check_grad(self):
-        if core.is_compiled_with_cuda():
-            self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")
+        self.check_grad(["X"], "Out")
 
     def init_test_case(self):
         self.ori_shape = (3, 40)
