No.54: Implement float16 data type support for the Paddle allclose and isclose operators #51168

Merged Mar 10, 2023 · 33 commits (changes shown from 29 commits)

Commits
09c2f1f
fp16 support
mrcangye Mar 3, 2023
6b8b365
Update test_allclose_op.py
mrcangye Mar 3, 2023
3a7e42d
fp16 support
mrcangye Mar 6, 2023
b282bb7
Merge branch 'fp16' of github.com:mrcangye/Paddle into fp16
mrcangye Mar 6, 2023
33de245
fp16 support 03062101
mrcangye Mar 6, 2023
140f125
fp16 support 03062127
mrcangye Mar 6, 2023
28e8285
Merge branch 'PaddlePaddle:develop' into fp16
mrcangye Mar 6, 2023
48624a7
Update isclose_kernel.cu
mrcangye Mar 6, 2023
72dabe3
Update allclose_kernel.cu
mrcangye Mar 6, 2023
4ba5d0d
Update allclose_kernel.cu
mrcangye Mar 6, 2023
cdff335
Update isclose_kernel_impl.h
mrcangye Mar 6, 2023
4a08cd1
fp16 support 03070833
mrcangye Mar 7, 2023
a5219a4
fp16 support 03070842
mrcangye Mar 7, 2023
ee7e029
fp16 support 03070850
mrcangye Mar 7, 2023
4512f4c
fp16 support 03071325
mrcangye Mar 7, 2023
7c49976
fp16 support 03071345
mrcangye Mar 7, 2023
b56c66d
fp16 support 03071351
mrcangye Mar 7, 2023
6766642
fp16 support 03071351
mrcangye Mar 7, 2023
9e99cb4
fp16 support 03071548
mrcangye Mar 7, 2023
3ec9289
fp16 support 03071720
mrcangye Mar 7, 2023
818aef2
fp16 support 03071746
mrcangye Mar 7, 2023
2e3e240
fp16 support 03071921
mrcangye Mar 7, 2023
e56bd7c
fp16 support 03081357
mrcangye Mar 8, 2023
7d0aa4f
fp16 support 03081406
mrcangye Mar 8, 2023
50c62f7
fp16 support 03081409
mrcangye Mar 8, 2023
66489c4
fp16 support 03081529
mrcangye Mar 8, 2023
380b15a
fp16 support 03081528
mrcangye Mar 8, 2023
f32f140
fp16 support 03082041
mrcangye Mar 8, 2023
0079024
fp16 support 03082043
mrcangye Mar 8, 2023
7dbe9db
update allclose_kernel.cu, isclose_kernel.cu
mrcangye Mar 9, 2023
8e38fe6
update allclose_kernel.cu, isclose_kernel.cu
mrcangye Mar 9, 2023
db57854
Update allclose_kernel.cu
mrcangye Mar 9, 2023
c41e60f
Update allclose_kernel.cu
mrcangye Mar 9, 2023
21 changes: 15 additions & 6 deletions paddle/phi/kernels/gpu/allclose_kernel.cu
@@ -16,6 +16,8 @@
 
 #include "glog/logging.h"
 
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"
 
@@ -31,14 +33,16 @@ __global__ void AllcloseCUDAKernel(const T* in_data,
                                    bool* out_data) {
   unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
   bool val;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
-    const T a = in_data[i], b = other_data[i];
+    const MPType a = static_cast<MPType>(in_data[i]);
+    const MPType b = static_cast<MPType>(other_data[i]);
     if (isnan(a) || isnan(b)) {
       val = equal_nan && isnan(a) == isnan(b);
     } else {
-      T left = (a > b ? a - b : b - a);
-      T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
-      T diff = (left > right ? left - right : right - left);
+      MPType left = (a > b ? a - b : b - a);
+      MPType right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+      MPType diff = (left > right ? left - right : right - left);
       val = a == b || left <= right || diff <= 1e-15;
     }
     if (!val) *out_data = false;
 
@@ -92,7 +96,12 @@ void AllCloseKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    allclose, GPU, ALL_LAYOUT, phi::AllCloseKernel, float, double) {
+PD_REGISTER_KERNEL(allclose,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AllCloseKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }

Contributor (review comment, on the kernel registration): BF16 registration also needs to be added.
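The switch from T to MPType is about numerics, not just compilation: float16 tops out at 65504 and the default atol of 1e-8 is not even representable in it, so intermediates like a - b or atol + rtol * b can overflow or underflow if evaluated in half precision. A quick NumPy illustration of both failure modes (our own sketch, not Paddle code):

import numpy as np

# |a - b| overflows float16 (max finite value 65504) but not float32,
# which is what MPTypeTrait<float16>::Type widens float16 to.
a, b = np.float16(40000.0), np.float16(-40000.0)
print(a - b)                          # inf: the subtraction overflows in fp16
print(np.float32(a) - np.float32(b))  # 80000.0 in the widened type

# The default atol=1e-8 underflows to zero in float16 entirely.
print(np.float16(1e-8))               # 0.0
print(np.float32(1e-8))               # 1e-08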
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/isclose_kernel.cu
@@ -15,8 +15,14 @@
 #include "paddle/phi/kernels/isclose_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/isclose_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    isclose, GPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {}
+PD_REGISTER_KERNEL(isclose,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::IscloseKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}

Contributor (review comment): BF16 registration is needed here as well.

Contributor (reply): BF16 support is deliberately left out of this task; it will be covered by a separate one.
11 changes: 7 additions & 4 deletions paddle/phi/kernels/impl/isclose_kernel_impl.h
@@ -18,6 +18,7 @@
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
 
@@ -109,14 +110,16 @@ __global__ void IscloseCUDAKernel(const T* in_data,
                                   bool* out_data) {
   unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
   bool val;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
-    const T a = in_data[i], b = other_data[i];
+    const MPType a = static_cast<MPType>(in_data[i]);
+    const MPType b = static_cast<MPType>(other_data[i]);
     if (isnan(a) || isnan(b)) {
       val = equal_nan && isnan(a) == isnan(b);
     } else {
-      T left = (a > b ? a - b : b - a);
-      T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
-      T diff = (left > right ? left - right : right - left);
+      MPType left = (a > b ? a - b : b - a);
+      MPType right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+      MPType diff = (left > right ? left - right : right - left);
       val = a == b || left <= right || diff <= 1e-15;
     }
     out_data[i] = val;
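Read side by side, both CUDA kernels now perform the same per-element comparison in the widened type. A scalar NumPy sketch of one loop iteration, assuming rtol and atol are also applied in MPType precision (isclose_ref is a hypothetical helper name, not a Paddle API):

import numpy as np

def isclose_ref(x, y, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Scalar sketch of one IscloseCUDAKernel iteration after this change."""
    # MPTypeTrait<float16>::Type is float32; wider float types map to themselves.
    mp = np.float32 if x.dtype == np.float16 else x.dtype.type
    a, b = mp(x), mp(y)
    if np.isnan(a) or np.isnan(b):
        return bool(equal_nan and np.isnan(a) == np.isnan(b))
    left = a - b if a > b else b - a
    right = mp(atol) + (mp(rtol) * b if b > 0 else -mp(rtol) * b)
    diff = left - right if left > right else right - left
    return bool(a == b or left <= right or diff <= 1e-15)

print(isclose_ref(np.float16(10.1), np.float16(10.0), rtol=0.02))  # True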
33 changes: 32 additions & 1 deletion python/paddle/fluid/tests/unittests/test_allclose_op.py
@@ -18,6 +18,7 @@
 from op_test import OpTest
 
 import paddle
+import paddle.fluid.core as core
 
 
 class TestAllcloseOp(OpTest):
 
@@ -134,7 +135,7 @@ def test_x_dtype():
         with paddle.static.program_guard(
             paddle.static.Program(), paddle.static.Program()
         ):
-            x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
+            x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32')
             y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
             result = paddle.allclose(x, y)
 
@@ -170,6 +171,36 @@ def test_equal_nan():
         self.assertRaises(TypeError, test_equal_nan)
 
 
+class TestAllcloseOpFp16(unittest.TestCase):
+    def test_fp16(self):
+        x_data = np.random.rand(10, 10).astype('float16')
+        y_data = np.random.rand(10, 10).astype('float16')
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=[10, 10], name='x', dtype='float16')
+            y = paddle.static.data(shape=[10, 10], name='y', dtype='float16')
+            out = paddle.allclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])
+
+
+class TestAllcloseOpFloat16(TestAllcloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float16")
+        self.other = np.array([10]).astype("float16")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, check_eager=True)
+
+
 class TestAllcloseOpFloat32(TestAllcloseOp):
     def set_args(self):
         self.input = np.array([10.1]).astype("float32")

Contributor (review comment, on TestAllcloseOpFp16): Please inherit from the OpTest framework for this test.

Contributor (reply): A unit test that inherits from OpTest has been added; this class covers the static-graph branch.
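Outside the OpTest harness, the new fp16 path can also be sanity-checked against NumPy on the same values widened to float32. A sketch assuming a CUDA build (tolerances chosen arbitrarily; exact-boundary cases could in principle differ because the kernel allows an extra 1e-15 slack):

import numpy as np
import paddle

if paddle.is_compiled_with_cuda():
    x_np = np.random.rand(10, 10).astype('float16')
    y_np = (x_np.astype('float32') + 1e-3).astype('float16')
    x, y = paddle.to_tensor(x_np), paddle.to_tensor(y_np)
    got = bool(paddle.allclose(x, y, rtol=1e-2, atol=1e-3))
    # NumPy sees the identical float16 values, just widened to float32.
    want = bool(np.allclose(x_np.astype('float32'), y_np.astype('float32'),
                            rtol=1e-2, atol=1e-3))
    assert got == want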
33 changes: 32 additions & 1 deletion python/paddle/fluid/tests/unittests/test_isclose_op.py
@@ -18,6 +18,7 @@
 from op_test import OpTest
 
 import paddle
+import paddle.fluid.core as core
 
 
 class TestIscloseOp(OpTest):
 
@@ -166,7 +167,7 @@ def test_x_dtype():
         with paddle.static.program_guard(
             paddle.static.Program(), paddle.static.Program()
         ):
-            x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
+            x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32')
             y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
             result = paddle.isclose(x, y)
 
@@ -203,6 +204,36 @@ def test_equal_nan():
         self.assertRaises(TypeError, test_equal_nan)
 
 
+class TestIscloseOpFp16(unittest.TestCase):
+    def test_fp16(self):
+        x_data = np.random.rand(10, 10).astype('float16')
+        y_data = np.random.rand(10, 10).astype('float16')
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=[10, 10], name='x', dtype='float16')
+            y = paddle.static.data(shape=[10, 10], name='y', dtype='float16')
+            out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])
+
+
+class TestIscloseOpFloat16(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float16")
+        self.other = np.array([10]).astype("float16")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, check_eager=True)
+
+
 class TestIscloseOpFloat32(TestIscloseOp):
     def set_args(self):
         self.input = np.array([10.1]).astype("float32")

Contributor (review comment, on TestIscloseOpFp16): Please write the test using the OpTest framework.
24 changes: 16 additions & 8 deletions python/paddle/tensor/logic.py
@@ -362,8 +362,8 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     two tensors are elementwise equal within a tolerance.
 
     Args:
-        x(Tensor): The input tensor, it's data type should be float32, float64..
-        y(Tensor): The input tensor, it's data type should be float32, float64..
+        x(Tensor): The input tensor, its data type should be float16, float32 or float64.
+        y(Tensor): The input tensor, its data type should be float16, float32 or float64.
         rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
         atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
         equal_nan(equalnantype, optional): ${equal_nan_comment}.
 
@@ -402,8 +402,12 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     if in_dygraph_mode():
         return _C_ops.allclose(x, y, rtol, atol, equal_nan)
     else:
-        check_variable_and_dtype(x, "input", ['float32', 'float64'], 'allclose')
-        check_variable_and_dtype(y, "input", ['float32', 'float64'], 'allclose')
+        check_variable_and_dtype(
+            x, "input", ['float16', 'float32', 'float64'], 'allclose'
+        )
+        check_variable_and_dtype(
+            y, "input", ['float16', 'float32', 'float64'], 'allclose'
+        )
         check_type(rtol, 'rtol', float, 'allclose')
         check_type(atol, 'atol', float, 'allclose')
         check_type(equal_nan, 'equal_nan', bool, 'allclose')
 
@@ -990,8 +994,8 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     two tensors are elementwise equal within a tolerance.
 
     Args:
-        x(Tensor): The input tensor, it's data type should be float32, float64.
-        y(Tensor): The input tensor, it's data type should be float32, float64.
+        x(Tensor): The input tensor, its data type should be float16, float32 or float64.
+        y(Tensor): The input tensor, its data type should be float16, float32 or float64.
         rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
         atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
         equal_nan(equalnantype, optional): If :math:`True` , then two :math:`NaNs` will be compared as equal. Default: :math:`False` .
 
@@ -1028,8 +1032,12 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     if in_dygraph_mode():
         return _C_ops.isclose(x, y, rtol, atol, equal_nan)
     else:
-        check_variable_and_dtype(x, "input", ['float32', 'float64'], 'isclose')
-        check_variable_and_dtype(y, "input", ['float32', 'float64'], 'isclose')
+        check_variable_and_dtype(
+            x, "input", ['float16', 'float32', 'float64'], 'isclose'
+        )
+        check_variable_and_dtype(
+            y, "input", ['float16', 'float32', 'float64'], 'isclose'
+        )
         check_type(rtol, 'rtol', float, 'isclose')
         check_type(atol, 'atol', float, 'isclose')
         check_type(equal_nan, 'equal_nan', bool, 'isclose')
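With the dtype checks relaxed, float16 inputs now work end to end; a minimal dygraph sketch (assuming a CUDA build, since the new fp16 kernels are GPU-only):

import numpy as np
import paddle

if paddle.is_compiled_with_cuda():
    x = paddle.to_tensor(np.array([10.1], dtype='float16'))
    y = paddle.to_tensor(np.array([10.0], dtype='float16'))
    # 10.1 rounds to ~10.102 in float16, so |x - y| is ~0.102:
    # a 2% relative tolerance passes where 1% would not.
    print(paddle.allclose(x, y, rtol=0.02, atol=0.0))  # True
    print(paddle.isclose(x, y, rtol=0.02, atol=0.0))   # [True]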