[bf16] add bf16 kernel: scale gather sum #39683

Merged 17 commits on Mar 1, 2022
Changes from 11 commits
6 changes: 4 additions & 2 deletions paddle/fluid/operators/gather_op.cc
@@ -201,12 +201,14 @@ REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
ops::GatherOpKernel<uint8_t>,
ops::GatherOpKernel<int64_t>);
ops::GatherOpKernel<int64_t>,
ops::GatherOpKernel<phi::dtype::bfloat16>);
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
ops::GatherGradientOpKernel<double>,
ops::GatherGradientOpKernel<int>,
ops::GatherGradientOpKernel<uint8_t>,
ops::GatherGradientOpKernel<int64_t>);
ops::GatherGradientOpKernel<int64_t>,
ops::GatherGradientOpKernel<phi::dtype::bfloat16>);
REGISTER_OP_VERSION(gather)
.AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
6 changes: 4 additions & 2 deletions paddle/fluid/operators/gather_op.cu
@@ -130,9 +130,11 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
ops::GatherOpCUDAKernel<double>,
ops::GatherOpCUDAKernel<int64_t>,
ops::GatherOpCUDAKernel<int>,
ops::GatherOpCUDAKernel<plat::float16>);
ops::GatherOpCUDAKernel<plat::float16>,
ops::GatherOpCUDAKernel<plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
ops::GatherGradOpCUDAKernel<double>,
ops::GatherGradOpCUDAKernel<int64_t>,
ops::GatherGradOpCUDAKernel<int>,
ops::GatherGradOpCUDAKernel<plat::float16>);
ops::GatherGradOpCUDAKernel<plat::float16>,
ops::GatherGradOpCUDAKernel<plat::bfloat16>);
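
These registrations make the gather forward and backward kernels accept bfloat16 tensors on both CPU and GPU. For orientation, here is a minimal numpy model of what the kernels compute (float32 stands in for bf16; the scatter-add in the backward pass is the kind of accumulation that motivates the bfloat16 atomicAdd added further down):

```python
import numpy as np

# Forward: out[i, ...] = x[index[i], ...] (gather along axis 0).
x = np.arange(12, dtype=np.float32).reshape(4, 3)
index = np.array([1, 3], dtype=np.int32)
out = x[index]

# Backward: scatter the upstream gradient back into the picked rows; rows
# selected more than once must accumulate, which GPU kernels typically do
# with atomicAdd.
grad_out = np.ones_like(out)
grad_x = np.zeros_like(x)
np.add.at(grad_x, index, grad_out)
print(out)     # rows 1 and 3 of x
print(grad_x)  # ones in rows 1 and 3, zeros elsewhere
```
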
2 changes: 2 additions & 0 deletions paddle/fluid/operators/math/selected_rows_functor.cu
@@ -16,6 +16,7 @@ limitations under the License. */
#include <vector>

#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/math_function.h"
@@ -436,6 +437,7 @@ template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;
template struct MergeAdd<platform::CUDADeviceContext, platform::bfloat16>;
template struct MergeAdd<platform::CUDADeviceContext, platform::complex<float>>;
template struct MergeAdd<platform::CUDADeviceContext,
platform::complex<double>>;
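
MergeAdd collapses duplicate row ids in a SelectedRows tensor by summing their value rows; the new template instantiation simply makes that available for bfloat16 (it is exercised, for example, when sum receives SelectedRows inputs). A rough host-side model of the merge, with float32 standing in for bf16 and illustrative variable names:

```python
import numpy as np

# SelectedRows: a list of row ids plus a dense value block, one value row
# per id. MergeAdd produces unique ids (shown sorted here) and sums the
# value rows that share an id.
rows = np.array([3, 1, 3, 0], dtype=np.int64)
values = np.arange(8, dtype=np.float32).reshape(4, 2)

merged_rows = np.unique(rows)                       # [0, 1, 3]
merged = np.zeros((merged_rows.size, values.shape[1]), dtype=values.dtype)
np.add.at(merged, np.searchsorted(merged_rows, rows), values)
print(merged_rows)
print(merged)   # the two rows tagged with id 3 are summed together
```
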
3 changes: 2 additions & 1 deletion paddle/fluid/operators/sum_op.cu
@@ -258,4 +258,5 @@ REGISTER_OP_CUDA_KERNEL(
ops::SumKernel<paddle::platform::CUDADeviceContext, double>,
ops::SumKernel<paddle::platform::CUDADeviceContext, int>,
ops::SumKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SumKernel<paddle::platform::CUDADeviceContext, plat::float16>);
ops::SumKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::SumKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>);
67 changes: 67 additions & 0 deletions paddle/fluid/platform/device/gpu/gpu_primitives.h
@@ -20,6 +20,7 @@ limitations under the License. */
#include <hip/hip_runtime.h>
#endif
#include <stdio.h>
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

@@ -149,6 +150,72 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
#endif
#endif

// NOTE(zhangbo): CUDA does not provide atomicCAS for __nv_bfloat16, so the fallback CASes the enclosing 32-bit word.
inline static __device__ uint32_t bf16_add_to_low_half(uint32_t val, float x) {
bfloat16 low_half;
// the bfloat16 value stored in the lower 16 bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<bfloat16>(static_cast<float>(low_half) + x);
return (val & 0xFFFF0000u) | low_half.x;
}

inline static __device__ uint32_t bf16_add_to_high_half(uint32_t val, float x) {
bfloat16 high_half;
// the bfloat16 value stored in the upper 16 bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half = static_cast<bfloat16>(static_cast<float>(high_half) + x);
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ __forceinline__ bfloat16 CUDABF16ToPDBF16(__nv_bfloat16 x) {
return *reinterpret_cast<bfloat16 *>(&x);
}

static __device__ __forceinline__ __nv_bfloat16 PDBF16ToCUDABF16(bfloat16 x) {
return *reinterpret_cast<__nv_bfloat16 *>(&x);
}

CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
return CUDABF16ToPDBF16(atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
PDBF16ToCUDABF16(val)));
}
#else
CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
// The packed bfloat16 value may live in either the lower or the upper 16
// bits of the aligned 32-bit word.
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t sum;
uint32_t newval;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// the bfloat16 value sits in the lower 16 bits of the word.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed,
bf16_add_to_low_half(assumed, val_f));
} while (old != assumed);
bfloat16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// the bfloat16 value sits in the upper 16 bits of the word.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed,
bf16_add_to_high_half(assumed, val_f));
} while (old != assumed);
bfloat16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif

CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
float *real = reinterpret_cast<float *>(address);
float *imag = real + 1;
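
When a native __nv_bfloat16 atomicAdd is unavailable (the #else branch above), the wrapper emulates a 16-bit atomic add by CAS-ing the aligned 32-bit word that holds the value and rewriting only the relevant half. The bit manipulation can be checked on the host; below is a numpy sketch mirroring bf16_add_to_low_half / bf16_add_to_high_half (the names and the plain-truncation float32-to-bf16 cast are simplifications, not the in-tree helpers):

```python
import numpy as np

def f32_to_bf16_bits(x):
    # Keep the upper 16 bits of the float32 pattern (truncation; the real
    # bfloat16 cast may round to nearest even).
    return np.uint32(np.float32(x).view(np.uint32) >> 16)

def bf16_bits_to_f32(b):
    return np.uint32(np.uint32(b) << 16).view(np.float32)

def bf16_add_to_low_half(word, x):
    # Add x to the bfloat16 stored in the lower 16 bits, keep the upper half.
    low = bf16_bits_to_f32(word & np.uint32(0xFFFF))
    return (word & np.uint32(0xFFFF0000)) | f32_to_bf16_bits(low + np.float32(x))

def bf16_add_to_high_half(word, x):
    # Same idea for the bfloat16 stored in the upper 16 bits.
    high = bf16_bits_to_f32(word >> np.uint32(16))
    return (word & np.uint32(0xFFFF)) | np.uint32(f32_to_bf16_bits(high + np.float32(x)) << 16)

# Two bf16 values (1.5 in the low half, 2.5 in the high half) packed into one
# 32-bit word, with 1.0 added atomically-in-spirit to each half.
word = np.uint32(f32_to_bf16_bits(2.5) << 16) | f32_to_bf16_bits(1.5)
word = bf16_add_to_low_half(word, 1.0)
word = bf16_add_to_high_half(word, 1.0)
print(bf16_bits_to_f32(word & np.uint32(0xFFFF)))   # 2.5
print(bf16_bits_to_f32(word >> np.uint32(16)))      # 3.5
```
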
10 changes: 10 additions & 0 deletions paddle/phi/kernels/funcs/blas/blas_impl.h
@@ -76,6 +76,16 @@ struct CBlas<phi::dtype::bfloat16> {
"Blas VCOPY do not supported on CPU with bfloat16,"
" please check your code"));
}

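// Elementwise fallback: this CBlas specialization has no dedicated bfloat16
// VADD routine, so the addition is done in a plain loop.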
template <typename... ARGS>
static void VADD(int n,
const phi::dtype::bfloat16 *x,
const phi::dtype::bfloat16 *y,
phi::dtype::bfloat16 *z) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] + y[i];
}
}
};

#ifdef PADDLE_WITH_MKLML
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/scale_kernel.cu
@@ -70,6 +70,7 @@ PT_REGISTER_KERNEL(scale,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
uint8_t,
int8_t,
int16_t,
7 changes: 6 additions & 1 deletion python/paddle/fluid/tests/unittests/op_test.py
@@ -482,7 +482,12 @@ def _append_ops(self, block):

op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
"infer datatype from inputs and outputs for this test case"
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
if self.is_bfloat16_op():
self.dtype = np.uint16
self.__class__.dtype = self.dtype
self.output_dtype = np.uint16
else:
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
inputs = append_input_output(block, op_proto, self.inputs, True,
self.dtype)
outputs = append_input_output(block, op_proto, self.outputs, False,
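
Because numpy has no native bfloat16 dtype, the op tests carry bf16 data as np.uint16 bit patterns produced by convert_float_to_uint16, and _append_ops now picks np.uint16 directly for bf16 ops instead of inferring the dtype. A rough sketch of that representation, and of why the bf16 tests below need loose tolerances (the helper names here are illustrative, not the op_test ones):

```python
import numpy as np

def to_bf16_bits(a):
    # Illustrative stand-in for convert_float_to_uint16: keep the upper 16
    # bits of each float32 (the in-tree helper may round rather than truncate).
    return (np.asarray(a, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def from_bf16_bits(b):
    return (np.asarray(b, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)

x = np.random.random((3, 4)).astype(np.float32)
roundtrip = from_bf16_bits(to_bf16_bits(x))
# bfloat16 keeps only 7 explicit mantissa bits, so the round-trip error can
# reach about 2**-7 relative, which is why the bf16 tests below pass large
# numeric_grad_delta values (0.5, 0.8) to check_grad.
print(np.max(np.abs(x - roundtrip) / np.maximum(np.abs(x), 1e-6)))
```
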
29 changes: 28 additions & 1 deletion python/paddle/fluid/tests/unittests/test_gather_op.py
@@ -16,7 +16,7 @@

import unittest
import numpy as np
from op_test import OpTest
from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
from paddle.framework import core
@@ -117,6 +117,33 @@ def config(self):
self.index_type = "int32"


class TestGatherBF16Op(OpTest):
def setUp(self):
self.op_type = "gather"
self.dtype = np.uint16
self.config()
xnp = np.random.random(self.x_shape).astype(np.float32)
self.inputs = {
'X': convert_float_to_uint16(xnp),
'Index': np.array(self.index).astype(self.index_type)
}
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out', numeric_grad_delta=0.5)

def config(self):
"""
For multi-dimension input
"""
self.x_shape = (10, 20)
self.index = [1, 3, 5]
self.index_type = "int32"


class TestGatherOp1(OpTest):
def setUp(self):
self.op_type = "gather"
19 changes: 18 additions & 1 deletion python/paddle/fluid/tests/unittests/test_scale_op.py
@@ -16,7 +16,7 @@

import unittest
import numpy as np
from op_test import OpTest
from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
@@ -153,6 +153,23 @@ def test_check_grad(self):
place, ["X"], "Out", max_relative_error=0.05)


class TestScaleBF16Op(OpTest):
def setUp(self):
self.op_type = "scale"
self.dtype = np.uint16
self.attrs = {'scale': -2.3}
x = np.random.random((10, 10)).astype(np.float32)
out = x * np.float32(self.attrs['scale'])
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': convert_float_to_uint16(out)}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out', numeric_grad_delta=0.8)


@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestScaleFp16OpSelectedRows(TestScaleOpSelectedRows):
26 changes: 26 additions & 0 deletions python/paddle/fluid/tests/unittests/test_sum_op.py
@@ -298,6 +298,32 @@ def test_w_is_selected_rows(self):
globals()[cls_name] = TestSumFp16Case


#----------- test bf16 -----------
class TestSumBF16Op(OpTest):
def setUp(self):
self.op_type = "sum"
self.init_kernel_type()
x0 = np.random.random((3, 40)).astype(np.float32)
x1 = np.random.random((3, 40)).astype(np.float32)
x2 = np.random.random((3, 40)).astype(np.float32)
y = x0 + x1 + x2
self.inputs = {
"X": [("x0", convert_float_to_uint16(x0)),
("x1", convert_float_to_uint16(x1)),
("x2", convert_float_to_uint16(x2))]
}
self.outputs = {'Out': convert_float_to_uint16(y)}

def init_kernel_type(self):
self.dtype = np.uint16

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['x0'], 'Out', numeric_grad_delta=0.5)


class API_Test_Add_n(unittest.TestCase):
def test_api(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):