No.54: Implement float16 data type support for the Paddle allclose and isclose operators #51168
GPU kernel registration for isclose:

@@ -15,8 +15,14 @@
 #include "paddle/phi/kernels/isclose_kernel.h"

 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/isclose_kernel_impl.h"

-PD_REGISTER_KERNEL(
-    isclose, GPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {}
+PD_REGISTER_KERNEL(isclose,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::IscloseKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
Review comment: BF16 registration is needed as well.
Author reply: BF16 support is not being added in this task for now; there will be a separate task for it.
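For context, a minimal dynamic-graph usage sketch of what this registration enables (assuming a CUDA build of Paddle; the tensor values are illustrative, not from the PR):

```python
import paddle

# With the float16 GPU kernel registered, isclose/allclose accept
# float16 inputs directly on a CUDA device.
paddle.set_device('gpu')
x = paddle.to_tensor([10.1, 3.0], dtype='float16')
y = paddle.to_tensor([10.0, 3.0], dtype='float16')

print(paddle.isclose(x, y, rtol=0.01, atol=0.0))   # element-wise boolean Tensor
print(paddle.allclose(x, y, rtol=0.01, atol=0.0))  # single boolean Tensor
```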
Unit test changes for allclose:
@@ -18,6 +18,7 @@
 from op_test import OpTest

+import paddle
 import paddle.fluid.core as core


 class TestAllcloseOp(OpTest):
@@ -134,7 +135,7 @@ def test_x_dtype():
         with paddle.static.program_guard(
             paddle.static.Program(), paddle.static.Program()
         ):
-            x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
+            x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32')
             y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
             result = paddle.allclose(x, y)

(float16 is now an accepted input dtype, so this negative dtype test switches x to int32, which is still rejected.)
@@ -170,6 +171,36 @@ def test_equal_nan():
         self.assertRaises(TypeError, test_equal_nan)


+class TestAllcloseOpFp16(unittest.TestCase):
+    def test_fp16(self):
+        x_data = np.random.rand(10, 10).astype('float16')
+        y_data = np.random.rand(10, 10).astype('float16')
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=[10, 10], name='x', dtype='float16')
+            y = paddle.static.data(shape=[10, 10], name='y', dtype='float16')
+            out = paddle.allclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])

Review comment (on TestAllcloseOpFp16): Please write this test by inheriting from the OpTest framework.
Author reply: A unit test inheriting from OpTest has been added; this one exercises the static-graph branch.
+class TestAllcloseOpFloat16(TestAllcloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float16")
+        self.other = np.array([10]).astype("float16")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, check_eager=True)
+
+
+class TestAllcloseOpFloat32(TestAllcloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float32")
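These OpTest subclasses are checked against a NumPy-computed reference. As a hedged sketch of the closeness rule involved (this mirrors numpy.isclose's documented formula; the helper name expected_isclose is ours, for illustration only):

```python
import numpy as np

def expected_isclose(x, y, rtol, atol, equal_nan=False):
    # numpy's documented rule: |x - y| <= atol + rtol * |y|,
    # with optional NaN == NaN handling via equal_nan.
    return np.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)

# The values from TestAllcloseOpFloat16 above:
x = np.array([10.1]).astype("float16")
y = np.array([10]).astype("float16")
print(expected_isclose(x, y, rtol=0.01, atol=0.0).all())
```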
Unit test changes for isclose:
@@ -18,6 +18,7 @@
 from op_test import OpTest

+import paddle
 import paddle.fluid.core as core


 class TestIscloseOp(OpTest):
@@ -166,7 +167,7 @@ def test_x_dtype():
         with paddle.static.program_guard(
             paddle.static.Program(), paddle.static.Program()
         ):
-            x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
+            x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32')
             y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
             result = paddle.isclose(x, y)

(Same adjustment as in the allclose tests: float16 is now accepted, so int32 keeps the negative dtype test valid.)
@@ -203,6 +204,36 @@ def test_equal_nan():
         self.assertRaises(TypeError, test_equal_nan)


+class TestIscloseOpFp16(unittest.TestCase):
+    def test_fp16(self):
+        x_data = np.random.rand(10, 10).astype('float16')
+        y_data = np.random.rand(10, 10).astype('float16')
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=[10, 10], name='x', dtype='float16')
+            y = paddle.static.data(shape=[10, 10], name='y', dtype='float16')
+            out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])

Review comment (on TestIscloseOpFp16): Please write the test using the OpTest framework.
+class TestIscloseOpFloat16(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float16")
+        self.other = np.array([10]).astype("float16")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, check_eager=True)
+
+
+class TestIscloseOpFloat32(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float32")
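The FP16 OpTest variants above only run when the build has CUDA and the device supports float16. As a standalone sketch of that gating pattern (the class name and tensor values here are made up for illustration):

```python
import unittest

import numpy as np
import paddle
import paddle.fluid.core as core


class TestFp16Gating(unittest.TestCase):
    def test_runs_only_when_fp16_supported(self):
        if not core.is_compiled_with_cuda():
            return  # CPU-only build: nothing to check
        place = core.CUDAPlace(0)
        if not core.is_float16_supported(place):
            return  # GPU present but no float16 support
        paddle.set_device('gpu')
        x = paddle.to_tensor(np.ones([2, 2], dtype='float16'))
        y = paddle.to_tensor(np.ones([2, 2], dtype='float16'))
        self.assertTrue(bool(paddle.allclose(x, y).numpy()))


if __name__ == '__main__':
    unittest.main()
```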
Review comment: BF16 registration needs to be added.