Skip to content

Commit

Permalink
【PaddlePaddle Hackathon 4】No.63 : add lerp bf16 support (#53078)
Browse files Browse the repository at this point in the history
* add lerp bf16 support

* fix bug

* Update test_lerp_op.py

modify the input dtype

* modify the test_lerp_op.py

* Update test_lerp_op.py

* fix bug of import

* add user_defined_grads

* Update test_lerp_op.py

* fix bug of grad

* fix bug of grad

* fix bug of grad

* add the check for bfloat16 dtype
  • Loading branch information
longranger2 authored Jul 3, 2023
1 parent c4d5ec6 commit ce31a72
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 7 deletions.
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/lerp_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -279,5 +279,6 @@ PD_REGISTER_KERNEL(lerp_grad,
ALL_LAYOUT,
phi::LerpGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/lerp_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,6 @@ PD_REGISTER_KERNEL(lerp,
ALL_LAYOUT,
phi::LerpKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double) {}
15 changes: 9 additions & 6 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4512,9 +4512,9 @@ def lerp(x, y, weight, name=None):
lerp(x, y, weight) = x + weight * (y - x).
Args:
x (Tensor): An N-D Tensor with starting points, the data type is float16, float32, float64.
y (Tensor): An N-D Tensor with ending points, the data type is float16, float32, float64.
weight (float|Tensor): The weight for the interpolation formula. When weight is Tensor, the data type is float16, float32, float64.
x (Tensor): An N-D Tensor with starting points, the data type is bfloat16, float16, float32, float64.
y (Tensor): An N-D Tensor with ending points, the data type is bfloat16, float16, float32, float64.
weight (float|Tensor): The weight for the interpolation formula. When weight is Tensor, the data type is bfloat16, float16, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Expand All @@ -4539,13 +4539,16 @@ def lerp(x, y, weight, name=None):
return _C_ops.lerp(x, y, weight)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'lerp'
x, 'x', ['uint16', 'float16', 'float32', 'float64'], 'lerp'
)
check_variable_and_dtype(
y, 'y', ['float16', 'float32', 'float64'], 'lerp'
y, 'y', ['uint16', 'float16', 'float32', 'float64'], 'lerp'
)
check_variable_and_dtype(
weight, 'weight', ['float16', 'float32', 'float64'], 'lerp'
weight,
'weight',
['uint16', 'float16', 'float32', 'float64'],
'lerp',
)

helper = LayerHelper('lerp', **locals())
Expand Down
62 changes: 61 additions & 1 deletion test/legacy_test/test_lerp_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import unittest

import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle.fluid import core
Expand Down Expand Up @@ -220,5 +220,65 @@ def test_x_y_broadcast_w(self):
paddle.enable_static()


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestLerpBF16(TestLerp):
def setUp(self):
self.op_type = "lerp"
self.python_api = paddle.lerp
self.dtype = np.uint16
self.init_shape()
self.init_xyshape()
self.init_wshape()
x = np.arange(1.0, 101.0).astype("float32").reshape(self.xshape)
y = np.full(100, 10.0).astype("float32").reshape(self.yshape)
w = np.random.random(self.wshape).astype("float32")
self.init_grad(w)
self.inputs = {
'X': convert_float_to_uint16(x),
'Y': convert_float_to_uint16(y),
'Weight': convert_float_to_uint16(w),
}
self.outputs = {'Out': convert_float_to_uint16(x + w * (y - x))}

def init_shape(self):
self.shape = [100]

def init_xyshape(self):
self.xshape = self.shape
self.yshape = self.shape

def init_wshape(self):
self.wshape = [1]

def init_grad(self, w):
self.x_grad = (
np.ones(self.xshape)
* (1 - w)
/ (np.prod(self.xshape) / np.prod(self.wshape))
)
self.y_grad = (
np.ones(self.yshape)
* w
/ (np.prod(self.yshape) / np.prod(self.wshape))
)

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['X', 'Y'],
'Out',
user_defined_grads=[self.x_grad, self.y_grad],
)


if __name__ == "__main__":
unittest.main()

0 comments on commit ce31a72

Please sign in to comment.