[bf16] add bf16 kernel: gaussian_random fill_constant fill_any_like #40027

Merged
merged 11 commits on Mar 7, 2022
3 changes: 2 additions & 1 deletion paddle/fluid/operators/gaussian_random_op.cu
@@ -45,7 +45,8 @@ struct GaussianGenerator {
     thrust::minstd_rand rng;
     rng.seed(seed_);
     using MT = typename details::MPTypeTrait<T>::Type;
-    thrust::normal_distribution<MT> dist(mean_, std_);
+    thrust::normal_distribution<MT> dist(static_cast<MT>(mean_),
+                                         static_cast<MT>(std_));
     unsigned int new_n = n + offset_;
     rng.discard(new_n);
     MT out = dist(rng);
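Note: the pattern used throughout this PR is to do intermediate math in a promoted type. MPTypeTrait<T>::Type resolves to float for low-precision types such as bfloat16, which is why the distribution parameters now need explicit static_cast<MT> calls. Below is a minimal sketch of what such a promotion trait looks like; the names are illustrative, and the real trait ships in paddle/phi/common/amp_type_traits.h.

#include <cstdint>

struct bfloat16 { uint16_t bits; };  // stand-in for phi::dtype::bfloat16

// Sketch of an MPTypeTrait-style promotion trait: full-precision types
// compute as themselves; bf16 intermediates are promoted to float.
template <typename T>
struct MPTypeTraitSketch {
  using Type = T;
};

template <>
struct MPTypeTraitSketch<bfloat16> {
  using Type = float;
};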
9 changes: 6 additions & 3 deletions paddle/phi/kernels/funcs/distribution_helper.h
@@ -23,6 +23,7 @@ limitations under the License. */

 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/generator.h"
 #include "paddle/phi/core/hostdevice.h"
@@ -255,11 +256,13 @@ __global__ void DistributionKernel(size_t size,
   using SType = hiprandStatePhilox4_32_10_t;
 #endif
   size_t total_thread = GRID_NUM_X * BLOCK_NUM_X;
-  T args[kCount];
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  MT args[kCount];
   T result[kCount];
   for (size_t i = idx; i < size; i += total_thread * kCount) {
-    kps::ElementwiseRandom<SType, T, kCount, 1, DistOp>(&args[0], dist, &state);
-    kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(
+    kps::ElementwiseRandom<SType, MT, kCount, 1, DistOp>(
+        &args[0], dist, &state);
+    kps::ElementwiseUnary<MT, T, kCount, 1, 1, TransformOp>(
         &result[0], &args[0], trans);
     kps::WriteData<T, T, kCount, 1, 1, true>(
         out_data + i, &result[0], size - i, 1, stride, 1);
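The kernel change above follows a compute-in-MT, store-as-T scheme: ElementwiseRandom and ElementwiseUnary operate on MT buffers (float when T is bfloat16), and values are narrowed to T only when WriteData stores them. Below is a host-side C++ sketch of the same scheme, assuming a simple truncating float-to-bf16 conversion in place of Paddle's kps primitives.

#include <cstdint>
#include <cstring>
#include <random>
#include <vector>

struct bf16 { uint16_t bits; };  // stand-in for phi::dtype::bfloat16

// Truncate a float32 to its top 16 bits (assumption: no rounding step).
inline bf16 FloatToBf16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return bf16{static_cast<uint16_t>(u >> 16)};
}

// Sample in float (the promoted MT type); narrow to bf16 only at the store.
void FillGaussianBf16(std::vector<bf16>& out, float mean, float stddev) {
  std::mt19937 rng(10);
  std::normal_distribution<float> dist(mean, stddev);
  for (auto& v : out) {
    v = FloatToBf16(dist(rng));
  }
}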
10 changes: 7 additions & 3 deletions paddle/phi/kernels/gpu/full_kernel.cu
@@ -63,9 +63,11 @@ void FullLikeKernel(const Context& dev_ctx,
   auto value = val.to<float>();
   using CommonType = typename std::common_type<
       float,
-      typename std::conditional<std::is_same<T, phi::dtype::float16>::value,
-                                float,
-                                T>::type>::type;
+      typename std::conditional<
+          std::is_same<T, phi::dtype::float16>::value ||
+              std::is_same<T, phi::dtype::bfloat16>::value,
+          float,
+          T>::type>::type;
 
   auto common_type_value = static_cast<CommonType>(value);
 
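The widened conditional matters because std::common_type is not specialized for Paddle's custom half-precision types; mapping bfloat16 (like float16) to float first yields a CommonType wide enough to range-check the fill value before it is narrowed to T. A compile-time sketch of the type computation, using a placeholder bfloat16 that is never instantiated:

#include <type_traits>

struct bfloat16;  // placeholder for phi::dtype::bfloat16

template <typename T>
using CommonType = typename std::common_type<
    float,
    typename std::conditional<std::is_same<T, bfloat16>::value,
                              float,
                              T>::type>::type;

static_assert(std::is_same<CommonType<bfloat16>, float>::value,
              "bf16 fill values are range-checked in float");
static_assert(std::is_same<CommonType<double>, double>::value,
              "wider types are left alone");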
Expand Down Expand Up @@ -110,6 +112,7 @@ PD_REGISTER_KERNEL(full,
int64_t,
bool,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

@@ -123,6 +126,7 @@ PD_REGISTER_KERNEL(full_like,
                    int,
                    int64_t,
                    bool,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
13 changes: 8 additions & 5 deletions paddle/phi/kernels/gpu/gaussian_random_kernel.cu
@@ -18,8 +18,8 @@
 #include <thrust/host_vector.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
@@ -46,8 +46,9 @@ struct GaussianGenerator {
   __host__ __device__ T operator()(const unsigned int n) const {
     thrust::minstd_rand rng;
     rng.seed(seed_);
-    using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-    thrust::normal_distribution<MT> dist(mean_, std_);
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    thrust::normal_distribution<MT> dist(static_cast<MT>(mean_),
+                                         static_cast<MT>(std_));
     unsigned int new_n = n + offset_;
     rng.discard(new_n);
     MT out = dist(rng);
@@ -83,9 +84,10 @@ void GaussianRandomKernel(const Context& dev_ctx,
 
   if (gen_cuda->GetIsInitPy() && seed_flag) {
     if (FLAGS_use_curand) {
-      using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
+      using MT = typename phi::dtype::MPTypeTrait<T>::Type;
       funcs::normal_distribution<MT> dist;
-      funcs::normal_transform<MT> trans(mean, std);
+      funcs::normal_transform<MT> trans(static_cast<MT>(mean),
+                                        static_cast<MT>(std));
       funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
     } else {
       auto seed_offset = gen_cuda->IncrementOffset(1);
@@ -110,5 +112,6 @@ PD_REGISTER_KERNEL(gaussian_random,
                    ALL_LAYOUT,
                    phi::GaussianRandomKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double) {}
1 change: 1 addition & 0 deletions paddle/phi/kernels/primitive/compute_primitives.h
@@ -22,6 +22,7 @@
 #endif
 
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+// #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 
 namespace phi {
21 changes: 20 additions & 1 deletion python/paddle/fluid/tests/unittests/test_fill_any_like_op.py
@@ -21,7 +21,7 @@
 import paddle.compat as cpt
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 
 
 class TestFillAnyLikeOp(OpTest):
@@ -47,6 +47,25 @@ def init(self):
         self.value = 0.0
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFillAnyLikeOpBfloat16(OpTest):
+    def setUp(self):
+        self.op_type = "fill_any_like"
+        self.dtype = np.uint16
+        self.value = 0.0
+        self.inputs = {'X': np.random.random((219, 232)).astype(np.float32)}
+        self.attrs = {'value': self.value, 'dtype': core.VarDesc.VarType.BF16}
+        self.outputs = {
+            'Out':
+            convert_float_to_uint16(self.value * np.ones_like(self.inputs["X"]))
+        }
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+
 class TestFillAnyLikeOpValue1(TestFillAnyLikeOp):
     def init(self):
         self.value = 1.0
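The new test sets self.dtype = np.uint16 because numpy has no native bfloat16 dtype; bf16 tensors travel through the test harness as raw 16-bit patterns, the upper half of a float32. Below is a sketch of the bit-level relationship the convert_float_to_uint16 / convert_uint16_to_float helpers are assumed to implement, shown with plain truncation (the real helpers may also round):

#include <cassert>
#include <cstdint>
#include <cstring>

// cf. convert_float_to_uint16: keep sign, exponent, and 7 mantissa bits.
uint16_t Bf16BitsFromFloat(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);
}

// cf. convert_uint16_to_float: the lost mantissa bits read back as zero.
float FloatFromBf16Bits(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  // 1.0 is exactly representable in bf16, so the round trip is lossless.
  assert(FloatFromBf16Bits(Bf16BitsFromFloat(1.0f)) == 1.0f);
  return 0;
}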
21 changes: 21 additions & 0 deletions python/paddle/fluid/tests/unittests/test_fill_constant_op.py
@@ -83,6 +83,27 @@ def test_check_output(self):
         self.check_output()
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFillConstantBF16Op(OpTest):
+    def setUp(self):
+        '''Test fill_constant op with specified value
+        '''
+        self.op_type = "fill_constant"
+        self.dtype = np.uint16
+        self.inputs = {}
+        self.attrs = {
+            'shape': [123, 92],
+            'value': 3.8,
+            'dtype': core.VarDesc.VarType.BF16
+        }
+        self.outputs = {'Out': convert_float_to_uint16(np.full((123, 92), 3.8))}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+
 class TestFillConstantOpWithSelectedRows(unittest.TestCase):
     def check_with_place(self, place):
         scope = core.Scope()
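Note that 3.8 is not exactly representable in bf16: as a float32 it is 0x40733333, and keeping only the top 16 bits gives 0x4073, i.e. 3.796875. That is why the expected output is built with convert_float_to_uint16(np.full((123, 92), 3.8)) rather than being compared against 3.8 directly.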
46 changes: 45 additions & 1 deletion python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
@@ -22,7 +22,7 @@
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 from paddle.fluid.executor import Executor
-from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.op_test import OpTest, convert_uint16_to_float
 import paddle
 
 
@@ -65,6 +65,50 @@ def verify_output(self, outs):
             "hist: " + str(hist) + " hist2: " + str(hist2))
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestGaussianRandomBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "gaussian_random"
+        self.set_attrs()
+        self.inputs = {}
+        self.use_mkldnn = False
+        self.attrs = {
+            "shape": [123, 92],
+            "mean": self.mean,
+            "std": self.std,
+            "seed": 10,
+            "dtype": paddle.fluid.core.VarDesc.VarType.BF16,
+            "use_mkldnn": self.use_mkldnn
+        }
+        paddle.seed(10)
+
+        self.outputs = {'Out': np.zeros((123, 92), dtype='float32')}
+
+    def set_attrs(self):
+        self.mean = 1.0
+        self.std = 2.
+
+    def test_check_output(self):
+        self.check_output_with_place_customized(
+            self.verify_output, place=core.CUDAPlace(0))
+
+    def verify_output(self, outs):
+        outs = convert_uint16_to_float(outs)
+        self.assertEqual(outs[0].shape, (123, 92))
+        hist, _ = np.histogram(outs[0], range=(-3, 5))
+        hist = hist.astype("float32")
+        hist /= float(outs[0].size)
+        data = np.random.normal(size=(123, 92), loc=1, scale=2)
+        hist2, _ = np.histogram(data, range=(-3, 5))
+        hist2 = hist2.astype("float32")
+        hist2 /= float(outs[0].size)
+        self.assertTrue(
+            np.allclose(
+                hist, hist2, rtol=0, atol=0.05),
+            "hist: " + str(hist) + " hist2: " + str(hist2))
+
+
 class TestMeanStdAreInt(TestGaussianRandomOp):
     def set_attrs(self):
         self.mean = 1