[bf16] add bf16 kernel: scale gather sum (#39683)
* add scale gather sum

* refine CUDA_ATOMIC_WRAPPER ADD for bf16

* add gather unittest

* solve conflict

* add scale unittest

* add sum unittest

* solve conflict

* refine gather unittest

* refine unittest
zhangbo9674 authored Mar 1, 2022
1 parent 9de7989 commit 6d26b33
Showing 10 changed files with 164 additions and 8 deletions.
6 changes: 4 additions & 2 deletions paddle/fluid/operators/gather_op.cc
@@ -201,12 +201,14 @@ REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
ops::GatherOpKernel<uint8_t>,
-ops::GatherOpKernel<int64_t>);
+ops::GatherOpKernel<int64_t>,
+ops::GatherOpKernel<phi::dtype::bfloat16>);
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
ops::GatherGradientOpKernel<double>,
ops::GatherGradientOpKernel<int>,
ops::GatherGradientOpKernel<uint8_t>,
-ops::GatherGradientOpKernel<int64_t>);
+ops::GatherGradientOpKernel<int64_t>,
+ops::GatherGradientOpKernel<phi::dtype::bfloat16>);
REGISTER_OP_VERSION(gather)
.AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
6 changes: 4 additions & 2 deletions paddle/fluid/operators/gather_op.cu
@@ -130,9 +130,11 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
ops::GatherOpCUDAKernel<double>,
ops::GatherOpCUDAKernel<int64_t>,
ops::GatherOpCUDAKernel<int>,
-ops::GatherOpCUDAKernel<plat::float16>);
+ops::GatherOpCUDAKernel<plat::float16>,
+ops::GatherOpCUDAKernel<plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
ops::GatherGradOpCUDAKernel<double>,
ops::GatherGradOpCUDAKernel<int64_t>,
ops::GatherGradOpCUDAKernel<int>,
-ops::GatherGradOpCUDAKernel<plat::float16>);
+ops::GatherGradOpCUDAKernel<plat::float16>,
+ops::GatherGradOpCUDAKernel<plat::bfloat16>);
2 changes: 2 additions & 0 deletions paddle/fluid/operators/math/selected_rows_functor.cu
@@ -16,6 +16,7 @@ limitations under the License. */
#include <vector>

#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/math_function.h"
@@ -445,6 +446,7 @@ template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;
template struct MergeAdd<platform::CUDADeviceContext, platform::bfloat16>;
template struct MergeAdd<platform::CUDADeviceContext, platform::complex<float>>;
template struct MergeAdd<platform::CUDADeviceContext,
platform::complex<double>>;
3 changes: 2 additions & 1 deletion paddle/fluid/operators/sum_op.cu
@@ -258,4 +258,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::SumKernel<paddle::platform::CUDADeviceContext, double>,
ops::SumKernel<paddle::platform::CUDADeviceContext, int>,
ops::SumKernel<paddle::platform::CUDADeviceContext, int64_t>,
-ops::SumKernel<paddle::platform::CUDADeviceContext, plat::float16>);
+ops::SumKernel<paddle::platform::CUDADeviceContext, plat::float16>,
+ops::SumKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>);
67 changes: 67 additions & 0 deletions paddle/fluid/platform/device/gpu/gpu_primitives.h
@@ -20,6 +20,7 @@ limitations under the License. */
#include <hip/hip_runtime.h>
#endif
#include <stdio.h>
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

@@ -244,6 +245,72 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
#endif
#endif

// NOTE(zhangbo): CUDA does not provide atomicCAS for __nv_bfloat16, so the fallback below emulates atomicAdd with a 32-bit atomicCAS loop.
inline static __device__ uint32_t bf16_add_to_low_half(uint32_t val, float x) {
bfloat16 low_half;
// the bfloat16 stored in the lower 16 bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<bfloat16>(static_cast<float>(low_half) + x);
return (val & 0xFFFF0000u) | low_half.x;
}

inline static __device__ uint32_t bf16_add_to_high_half(uint32_t val, float x) {
bfloat16 high_half;
// the bfloat16 stored in the higher 16 bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<bfloat16>(static_cast<float>(high_half) + x);
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ __forceinline__ bfloat16 CUDABF16ToPDBF16(__nv_bfloat16 x) {
return *reinterpret_cast<bfloat16 *>(&x);
}

static __device__ __forceinline__ __nv_bfloat16 PDBF16ToCUDABF16(bfloat16 x) {
return *reinterpret_cast<__nv_bfloat16 *>(&x);
}

CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
return CUDABF16ToPDBF16(atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
PDBF16ToCUDABF16(val)));
}
#else
CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
// The packed bfloat16 value may live in either the lower or the higher 16 bits
// of the aligned 32-bit word.
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t sum;
uint32_t newval;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// the bfloat16 value sits in the lower 16 bits of the word.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed,
bf16_add_to_low_half(assumed, val_f));
} while (old != assumed);
bfloat16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// the bfloat16 value sits in the higher 16 bits of the word.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed,
bf16_add_to_high_half(assumed, val_f));
} while (old != assumed);
bfloat16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif

CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
float *real = reinterpret_cast<float *>(address);
float *imag = real + 1;
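A note for readers of the fallback path above: when no native bfloat16 atomicAdd exists (CUDA < 11 or compute capability < 8.0), the wrapper CAS-loops on the aligned 32-bit word that contains the 16-bit value, using bf16_add_to_low_half / bf16_add_to_high_half to rebuild the word. The Python sketch below (illustrative only, not Paddle code; the helper names are made up, and it assumes bfloat16 is the upper 16 bits of a float32 with truncating conversion, which may differ in rounding from the device cast) mimics that bit manipulation on the host:

import struct

def f32_to_bf16_bits(f):
    # keep the upper 16 bits of the float32 pattern (truncating conversion)
    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16

def bf16_bits_to_f32(bits):
    # re-expand a 16-bit bfloat16 payload to a float32 value
    return struct.unpack('<f', struct.pack('<I', (bits & 0xFFFF) << 16))[0]

def add_to_low_half(word, x):
    # update the bfloat16 stored in the low 16 bits of the 32-bit word
    low = bf16_bits_to_f32(word & 0xFFFF)
    return (word & 0xFFFF0000) | f32_to_bf16_bits(low + x)

def add_to_high_half(word, x):
    # update the bfloat16 stored in the high 16 bits of the 32-bit word
    high = bf16_bits_to_f32(word >> 16)
    return (word & 0x0000FFFF) | (f32_to_bf16_bits(high + x) << 16)

# two bfloat16 values packed into one aligned 32-bit word: 1.5 (low) and -2.0 (high)
word = f32_to_bf16_bits(1.5) | (f32_to_bf16_bits(-2.0) << 16)
word = add_to_low_half(word, 0.25)   # low half  -> 1.75
word = add_to_high_half(word, 1.0)   # high half -> -1.0
print(bf16_bits_to_f32(word & 0xFFFF), bf16_bits_to_f32(word >> 16))  # 1.75 -1.0

On the GPU the same read-modify-write is made atomic by retrying atomicCAS until the word was not changed between the read and the swap, which is exactly the do/while loop in CUDA_ATOMIC_WRAPPER(Add, bfloat16).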
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/scale_kernel.cu
@@ -70,6 +70,7 @@ PD_REGISTER_KERNEL(scale,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
uint8_t,
int8_t,
int16_t,
7 changes: 6 additions & 1 deletion python/paddle/fluid/tests/unittests/op_test.py
@@ -482,7 +482,12 @@ def _append_ops(self, block):

op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
"infer datatype from inputs and outputs for this test case"
-        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
+        if self.is_bfloat16_op():
+            self.dtype = np.uint16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.uint16
+        else:
+            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
inputs = append_input_output(block, op_proto, self.inputs, True,
self.dtype)
outputs = append_input_output(block, op_proto, self.outputs, False,
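The np.uint16 dtype used above reflects how bfloat16 tensors are fed to OpTest: the unit tests prepare inputs with convert_float_to_uint16, so the arrays hold raw bfloat16 bit patterns rather than floating-point values. A rough stand-in for that helper (an assumption for illustration; the real implementation lives in op_test.py and may round instead of truncate):

import numpy as np

def convert_float_to_uint16_sketch(arr):
    # reinterpret the float32 bits and keep the upper 16, i.e. truncate to bfloat16
    arr = np.ascontiguousarray(arr, dtype=np.float32)
    return (arr.view(np.uint32) >> 16).astype(np.uint16)

x = np.random.random((2, 3)).astype(np.float32)
print(convert_float_to_uint16_sketch(x).dtype)  # uint16, one bfloat16 payload per element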
35 changes: 34 additions & 1 deletion python/paddle/fluid/tests/unittests/test_gather_op.py
@@ -16,7 +16,7 @@

import unittest
import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
from paddle.framework import core
@@ -117,6 +117,39 @@ def config(self):
self.index_type = "int32"


class TestGatherBF16Op(OpTest):
    def setUp(self):
        self.op_type = "gather"
        self.dtype = np.uint16
        self.config()
        xnp = np.random.random(self.x_shape).astype(np.float32)
        axis_np = np.array(self.axis).astype(self.axis_type)
        index_np = np.array(self.index).astype(self.index_type)
        self.inputs = {
            'X': convert_float_to_uint16(xnp),
            'Index': index_np,
            'Axis': axis_np
        }
        out = gather_numpy(self.inputs['X'], index_np, axis_np[0])
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', numeric_grad_delta=0.5)

    def config(self):
        """
        For multi-dimension input
        """
        self.x_shape = (3, 88, 3)
        self.index = [1, 3, 5]
        self.index_type = "int32"
        self.axis = [1]
        self.axis_type = "int32"


class TestGatherOp1(OpTest):
def setUp(self):
self.op_type = "gather"
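Since gather only moves elements and performs no arithmetic on them, the reference output for TestGatherBF16Op can be built directly from the uint16 bit patterns and check_output needs no bf16-specific tolerance. A minimal numpy illustration (gather_numpy in this file may be implemented differently; np.take along the gather axis is assumed here):

import numpy as np

x_bits = np.random.randint(0, 2**16, size=(3, 88, 3), dtype=np.uint16)  # stand-in bf16 payloads
out = np.take(x_bits, [1, 3, 5], axis=1)  # gather only reindexes, so the result is bit-exact
assert out.shape == (3, 3, 3)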
19 changes: 18 additions & 1 deletion python/paddle/fluid/tests/unittests/test_scale_op.py
@@ -16,7 +16,7 @@

import unittest
import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
@@ -153,6 +153,23 @@ def test_check_grad(self):
place, ["X"], "Out", max_relative_error=0.05)


class TestScaleBF16Op(OpTest):
    def setUp(self):
        self.op_type = "scale"
        self.dtype = np.uint16
        self.attrs = {'scale': -2.3}
        x = np.random.random((10, 10)).astype(np.float32)
        out = x * np.float32(self.attrs['scale'])
        self.inputs = {'X': convert_float_to_uint16(x)}
        self.outputs = {'Out': convert_float_to_uint16(out)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', numeric_grad_delta=0.8)


@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestScaleFp16OpSelectedRows(TestScaleOpSelectedRows):
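The relatively loose numeric_grad_delta values in the bf16 tests follow from the format itself: bfloat16 keeps only 8 significand bits, so each stored value can be off by up to about 2**-7 (roughly 0.8%) relative under truncation. A quick numpy check (a sketch, assuming the truncating conversion described above):

import numpy as np

x = np.random.uniform(1.0, 2.0, (10, 10)).astype(np.float32)
# round-trip float32 -> bfloat16 bits -> float32 via truncation
x_bf16 = ((x.view(np.uint32) >> 16) << 16).view(np.float32)
print(np.abs((x_bf16 - x) / x).max())  # under one percent, bounded by 2**-7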
26 changes: 26 additions & 0 deletions python/paddle/fluid/tests/unittests/test_sum_op.py
Expand Up @@ -298,6 +298,32 @@ def test_w_is_selected_rows(self):
globals()[cls_name] = TestSumFp16Case


#----------- test bf16 -----------
class TestSumBF16Op(OpTest):
    def setUp(self):
        self.op_type = "sum"
        self.init_kernel_type()
        x0 = np.random.random((3, 40)).astype(np.float32)
        x1 = np.random.random((3, 40)).astype(np.float32)
        x2 = np.random.random((3, 40)).astype(np.float32)
        y = x0 + x1 + x2
        self.inputs = {
            "X": [("x0", convert_float_to_uint16(x0)),
                  ("x1", convert_float_to_uint16(x1)),
                  ("x2", convert_float_to_uint16(x2))]
        }
        self.outputs = {'Out': convert_float_to_uint16(y)}

    def init_kernel_type(self):
        self.dtype = np.uint16

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['x0'], 'Out', numeric_grad_delta=0.5)


class API_Test_Add_n(unittest.TestCase):
def test_api(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
