From 17b7a202dc9a3078a0a15012a0ec440f2a0e653d Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 1 Mar 2022 09:37:12 +0000
Subject: [PATCH 1/6] add gaussian random

---
 paddle/fluid/operators/amp/fp16_type_traits.h |  7 +++
 paddle/fluid/operators/distribution_helper.h  | 11 +++--
 paddle/fluid/operators/gaussian_random_op.cu  |  9 +++-
 .../kernels/primitive/compute_primitives.h    |  6 +++
 .../unittests/test_gaussian_random_op.py      | 44 ++++++++++++++++++-
 5 files changed, 70 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/amp/fp16_type_traits.h b/paddle/fluid/operators/amp/fp16_type_traits.h
index f7aa0de97598d..56aebe90788fb 100644
--- a/paddle/fluid/operators/amp/fp16_type_traits.h
+++ b/paddle/fluid/operators/amp/fp16_type_traits.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
@@ -32,6 +33,12 @@ class MPTypeTrait<platform::float16> {
   using Type = float;
 };
 
+template <>
+class MPTypeTrait<platform::bfloat16> {
+ public:
+  using Type = float;
+};
+
 }  // namespace details
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/distribution_helper.h b/paddle/fluid/operators/distribution_helper.h
index c13bf687af234..f5f5e59626343 100644
--- a/paddle/fluid/operators/distribution_helper.h
+++ b/paddle/fluid/operators/distribution_helper.h
@@ -23,6 +23,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/for_range.h"
@@ -194,12 +195,14 @@ __global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
   using SType = hiprandStatePhilox4_32_10_t;
 #endif
   size_t total_thread = GRID_NUM_X * BLOCK_NUM_X;
-  T args[kCount];
+  using MT = typename paddle::operators::details::MPTypeTrait<T>::Type;
+  MT args[kCount];
   T result[kCount];
   for (size_t i = idx; i < size; i += total_thread * kCount) {
-    kps::ElementwiseRandom<SType, T, kCount, 1, DistOp>(&args[0], dist, &state);
-    kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(&result[0], &args[0],
-                                                           trans);
+    kps::ElementwiseRandom<SType, MT, kCount, 1, DistOp>(&args[0], dist,
+                                                         &state);
+    kps::ElementwiseUnary<MT, T, kCount, 1, 1, TransformOp>(&result[0],
+                                                            &args[0], trans);
     kps::WriteData<T, kCount, 1, 1, true>(out_data + i, &result[0], size - i,
                                           1, stride, 1);
     __syncthreads();
diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu
index 21d827c79200c..298a9ceb3c36d 100644
--- a/paddle/fluid/operators/gaussian_random_op.cu
+++ b/paddle/fluid/operators/gaussian_random_op.cu
@@ -44,7 +44,8 @@ struct GaussianGenerator {
     thrust::minstd_rand rng;
     rng.seed(seed_);
     using MT = typename details::MPTypeTrait<T>::Type;
-    thrust::normal_distribution<MT> dist(mean_, std_);
+    thrust::normal_distribution<MT> dist(static_cast<MT>(mean_),
+                                         static_cast<MT>(std_));
     unsigned int new_n = n + offset_;
     rng.discard(new_n);
     MT out = dist(rng);
@@ -82,7 +83,8 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
     if (FLAGS_use_curand) {
       using MT = typename details::MPTypeTrait<T>::Type;
       distribution::normal_distribution<MT> dist;
-      distribution::normal_transform<MT> trans(mean, std);
+      distribution::normal_transform<MT> trans(static_cast<MT>(mean),
+                                               static_cast<MT>(std));
       distribution::distribution_and_transform<T>(dev_cxt, tensor, dist,
                                                   trans);
     } else {
@@ -139,11 +141,14 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
 REGISTER_OP_CUDA_KERNEL(
     gaussian_random,
     paddle::operators::GPUGaussianRandomKernel<paddle::platform::float16>,
+    paddle::operators::GPUGaussianRandomKernel<paddle::platform::bfloat16>,
     paddle::operators::GPUGaussianRandomKernel<float>,
     paddle::operators::GPUGaussianRandomKernel<double>);
 REGISTER_OP_CUDA_KERNEL(
     gaussian_random_batch_size_like,
     paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<
         paddle::platform::float16>,
+    paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<
+        paddle::platform::bfloat16>,
     paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<float>,
     paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<double>);
diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h
index 4f3c069f3b249..2a6d4244d9a5b 100644
--- a/paddle/phi/kernels/primitive/compute_primitives.h
+++ b/paddle/phi/kernels/primitive/compute_primitives.h
@@ -22,6 +22,7 @@
 #endif
 
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 
 namespace phi {
@@ -52,6 +53,11 @@ class MPTypeTrait<phi::dtype::float16> {
   using Type = float;
 };
 
+template <>
+class MPTypeTrait<phi::dtype::bfloat16> {
+ public:
+  using Type = float;
+};
 /**
  * @brief Will be used in BlockYReduce, get the index of reduce_num in shared
  * memory.
diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
index 31caf4bd6be98..52d1262ed501b 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
@@ -22,7 +22,7 @@
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 from paddle.fluid.executor import Executor
-from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.op_test import OpTest, convert_uint16_to_float
 import paddle
 
 
@@ -65,6 +65,48 @@ def verify_output(self, outs):
             "hist: " + str(hist) + " hist2: " + str(hist2))
 
 
+class TestGaussianRandomBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "gaussian_random"
+        self.set_attrs()
+        self.inputs = {}
+        self.use_mkldnn = False
+        self.attrs = {
+            "shape": [123, 92],
+            "mean": self.mean,
+            "std": self.std,
+            "seed": 10,
+            "dtype": paddle.fluid.core.VarDesc.VarType.BF16,
+            "use_mkldnn": self.use_mkldnn
+        }
+        paddle.seed(10)
+
+        self.outputs = {'Out': np.zeros((123, 92), dtype='float32')}
+
+    def set_attrs(self):
+        self.mean = 1.0
+        self.std = 2.
+
+    def test_check_output(self):
+        self.check_output_with_place_customized(
+            self.verify_output, place=core.CUDAPlace(0))
+
+    def verify_output(self, outs):
+        outs = convert_uint16_to_float(outs)
+        self.assertEqual(outs[0].shape, (123, 92))
+        hist, _ = np.histogram(outs[0], range=(-3, 5))
+        hist = hist.astype("float32")
+        hist /= float(outs[0].size)
+        data = np.random.normal(size=(123, 92), loc=1, scale=2)
+        hist2, _ = np.histogram(data, range=(-3, 5))
+        hist2 = hist2.astype("float32")
+        hist2 /= float(outs[0].size)
+        self.assertTrue(
+            np.allclose(
+                hist, hist2, rtol=0, atol=0.05),
+            "hist: " + str(hist) + " hist2: " + str(hist2))
+
+
 class TestMeanStdAreInt(TestGaussianRandomOp):
     def set_attrs(self):
         self.mean = 1

From 5f8d2d9062dece52955a8bc5f14ba1946c6c6290 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 1 Mar 2022 10:13:13 +0000
Subject: [PATCH 2/6] add full

---
 paddle/phi/kernels/gpu/full_kernel.cu | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/paddle/phi/kernels/gpu/full_kernel.cu b/paddle/phi/kernels/gpu/full_kernel.cu
index 1f756bfdbed30..7b3a49a8bdcb9 100644
--- a/paddle/phi/kernels/gpu/full_kernel.cu
+++ b/paddle/phi/kernels/gpu/full_kernel.cu
@@ -63,9 +63,11 @@ void FullLikeKernel(const Context& dev_ctx,
   auto value = val.to<float>();
   using CommonType = typename std::common_type<
       float,
-      typename std::conditional<std::is_same<T, phi::dtype::float16>::value,
-                                float,
-                                T>::type>::type;
+      typename std::conditional<
+          std::is_same<T, phi::dtype::float16>::value ||
+              std::is_same<T, phi::dtype::bfloat16>::value,
+          float,
+          T>::type>::type;
 
   auto common_type_value = static_cast<CommonType>(value);
 
@@ -110,6 +112,7 @@ PD_REGISTER_KERNEL(full,
                    int64_t,
                    bool,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
 
@@ -123,6 +126,7 @@ PD_REGISTER_KERNEL(full_like,
                    int,
                    int64_t,
                    bool,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }

From 45ffe2e2c8345ff1a0cd0bf243bb04e87cb80319 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Wed, 2 Mar 2022 01:43:16 +0000
Subject: [PATCH 3/6] refine reduce

---
 paddle/phi/kernels/gpu/reduce.h | 44 +++++++++++++++++++--------------
 1 file changed, 25 insertions(+), 19 deletions(-)

diff --git a/paddle/phi/kernels/gpu/reduce.h b/paddle/phi/kernels/gpu/reduce.h
index 9223a94c12aeb..289c2da94a163 100644
--- a/paddle/phi/kernels/gpu/reduce.h
+++ b/paddle/phi/kernels/gpu/reduce.h
@@ -1004,14 +1004,16 @@ template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
           typename TransformOp>
-static typename std::enable_if<!std::is_same<Tx, phi::dtype::float16>::value,
-                               void>::type
-CubTensorReduceImpl(const Tx* x_data,
-                    Ty* y_data,
-                    const TransformOp& transform,
-                    int reduce_num,
-                    const paddle::platform::Place& place,
-                    gpuStream_t stream) {
+static
+    typename std::enable_if<!std::is_same<Tx, phi::dtype::float16>::value &&
+                                !std::is_same<Tx, phi::dtype::bfloat16>::value,
+                            void>::type
+    CubTensorReduceImpl(const Tx* x_data,
+                        Ty* y_data,
+                        const TransformOp& transform,
+                        int reduce_num,
+                        const paddle::platform::Place& place,
+                        gpuStream_t stream) {
   auto reducer = ReduceOp<Ty>();
   cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
                                                                   transform);
@@ -1047,16 +1049,19 @@ template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
           typename TransformOp>
-static typename std::enable_if<std::is_same<Tx, phi::dtype::float16>::value,
-                               void>::type
-CubTensorReduceImpl(const Tx* x_data,
-                    Ty* y_data,
-                    const TransformOp& transform,
-                    int reduce_num,
-                    const paddle::platform::Place& place,
-                    gpuStream_t stream) {
-  PADDLE_THROW(phi::errors::InvalidArgument(
-      "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
+static
+    typename std::enable_if<std::is_same<Tx, phi::dtype::float16>::value ||
+                                std::is_same<Tx, phi::dtype::bfloat16>::value,
+                            void>::type
+    CubTensorReduceImpl(const Tx* x_data,
+                        Ty* y_data,
+                        const TransformOp& transform,
+                        int reduce_num,
+                        const paddle::platform::Place& place,
+                        gpuStream_t stream) {
+  PADDLE_THROW(
+      phi::errors::InvalidArgument("Tx should not be float16 or bfloat16 when "
+                                   "using cub::DeviceReduce::Reduce()."));
 }
 
 template <typename Tx,
   constexpr bool kIsTxFP16 = std::is_same<Tx, phi::dtype::float16>::value;
-  bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
+  constexpr bool kIsTxBF16 = std::is_same<Tx, phi::dtype::bfloat16>::value;
+  bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16 && !kIsTxBF16;
   if (use_cub_reduce) {
     CubTensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
         x_data, y_data, transform, config.reduce_num, x.place(), stream);

From 5a8cb1c4d78649d421f8cea8d1f6b2b737892df0 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Wed, 2 Mar 2022 04:10:10 +0000
Subject: [PATCH 4/6] refine code

---
 paddle/phi/kernels/gpu/full_kernel.cu         |  4 +-
 paddle/phi/kernels/gpu/reduce.h               | 44 ++++++++-----------
 .../kernels/primitive/compute_primitives.h    |  7 +--
 3 files changed, 22 insertions(+), 33 deletions(-)

diff --git a/paddle/phi/kernels/gpu/full_kernel.cu b/paddle/phi/kernels/gpu/full_kernel.cu
index 7b3a49a8bdcb9..a905979f08b5f 100644
--- a/paddle/phi/kernels/gpu/full_kernel.cu
+++ b/paddle/phi/kernels/gpu/full_kernel.cu
@@ -126,7 +126,7 @@ PD_REGISTER_KERNEL(full_like,
                    int,
                    int64_t,
                    bool,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16) {
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
diff --git a/paddle/phi/kernels/gpu/reduce.h b/paddle/phi/kernels/gpu/reduce.h
index 289c2da94a163..9223a94c12aeb 100644
--- a/paddle/phi/kernels/gpu/reduce.h
+++ b/paddle/phi/kernels/gpu/reduce.h
@@ -1004,16 +1004,14 @@ template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
           typename TransformOp>
-static
-    typename std::enable_if<!std::is_same<Tx, phi::dtype::float16>::value &&
-                                !std::is_same<Tx, phi::dtype::bfloat16>::value,
-                            void>::type
-    CubTensorReduceImpl(const Tx* x_data,
-                        Ty* y_data,
-                        const TransformOp& transform,
-                        int reduce_num,
-                        const paddle::platform::Place& place,
-                        gpuStream_t stream) {
+static typename std::enable_if<!std::is_same<Tx, phi::dtype::float16>::value,
+                               void>::type
+CubTensorReduceImpl(const Tx* x_data,
+                    Ty* y_data,
+                    const TransformOp& transform,
+                    int reduce_num,
+                    const paddle::platform::Place& place,
+                    gpuStream_t stream) {
   auto reducer = ReduceOp<Ty>();
   cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
                                                                   transform);
@@ -1049,19 +1047,16 @@ template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
           typename TransformOp>
-static
-    typename std::enable_if<std::is_same<Tx, phi::dtype::float16>::value ||
-                                std::is_same<Tx, phi::dtype::bfloat16>::value,
-                            void>::type
-    CubTensorReduceImpl(const Tx* x_data,
-                        Ty* y_data,
-                        const TransformOp& transform,
-                        int reduce_num,
-                        const paddle::platform::Place& place,
-                        gpuStream_t stream) {
-  PADDLE_THROW(
-      phi::errors::InvalidArgument("Tx should not be float16 or bfloat16 when "
-                                   "using cub::DeviceReduce::Reduce()."));
+static typename std::enable_if<std::is_same<Tx, phi::dtype::float16>::value,
+                               void>::type
+CubTensorReduceImpl(const Tx* x_data,
+                    Ty* y_data,
+                    const TransformOp& transform,
+                    int reduce_num,
+                    const paddle::platform::Place& place,
+                    gpuStream_t stream) {
+  PADDLE_THROW(phi::errors::InvalidArgument(
+      "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
 }
 
 template <typename Tx,
   constexpr bool kIsTxFP16 = std::is_same<Tx, phi::dtype::float16>::value;
-  constexpr bool kIsTxBF16 = std::is_same<Tx, phi::dtype::bfloat16>::value;
-  bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16 && !kIsTxBF16;
+  bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
   if (use_cub_reduce) {
     CubTensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
         x_data, y_data, transform, config.reduce_num, x.place(), stream);
diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h
index efc68c2bd94c1..632ad00f6d06e 100644
--- a/paddle/phi/kernels/primitive/compute_primitives.h
+++ b/paddle/phi/kernels/primitive/compute_primitives.h
@@ -22,7 +22,7 @@
 #endif
 
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
-#include "paddle/phi/common/bfloat16.h"
+// #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 
 namespace phi {
@@ -53,11 +53,6 @@ class MPTypeTrait<phi::dtype::float16> {
   using Type = float;
 };
 
-template <>
-class MPTypeTrait<phi::dtype::bfloat16> {
- public:
-  using Type = float;
-};
 /**
  * @brief Will be used in BlockYReduce, get the index of reduce_num in shared
  * memory.

From 4642361f4d4907f8d3a070790a8524d3b2f5069b Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Wed, 2 Mar 2022 06:18:51 +0000
Subject: [PATCH 5/6] refine gaussian_random unittest

---
 python/paddle/fluid/tests/unittests/test_gaussian_random_op.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
index 52d1262ed501b..738441a46d377 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
@@ -65,6 +65,8 @@ def verify_output(self, outs):
             "hist: " + str(hist) + " hist2: " + str(hist2))
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
 class TestGaussianRandomBF16Op(OpTest):
     def setUp(self):
         self.op_type = "gaussian_random"

From 4aa5b0622c2af384e372e132731db65f02e2db85 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Wed, 2 Mar 2022 06:54:39 +0000
Subject: [PATCH 6/6] add unittest for fill_any_like fill_constant

---
 .../tests/unittests/test_fill_any_like_op.py  | 21 ++++++++++++++++++-
 .../tests/unittests/test_fill_constant_op.py  | 21 +++++++++++++++++++
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py
index 5bc2d1cda180b..9be2e57ff0cba 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py
@@ -21,7 +21,7 @@
 import paddle.compat as cpt
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 
 
 class TestFillAnyLikeOp(OpTest):
@@ -47,6 +47,25 @@ def init(self):
         self.value = 0.0
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFillAnyLikeOpBfloat16(OpTest):
+    def setUp(self):
+        self.op_type = "fill_any_like"
+        self.dtype = np.uint16
+        self.value = 0.0
+        self.inputs = {'X': np.random.random((219, 232)).astype(np.float32)}
+        self.attrs = {'value': self.value, 'dtype': core.VarDesc.VarType.BF16}
+        self.outputs = {
+            'Out':
+            convert_float_to_uint16(self.value * np.ones_like(self.inputs["X"]))
+        }
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+
 class TestFillAnyLikeOpValue1(TestFillAnyLikeOp):
     def init(self):
         self.value = 1.0
diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
index 822c952893e11..15071b2b6aa69 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
@@ -83,6 +83,27 @@ def test_check_output(self):
         self.check_output()
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFillConstantBF16Op(OpTest):
+    def setUp(self):
+        '''Test fill_constant op with specified value
+        '''
+        self.op_type = "fill_constant"
+        self.dtype = np.uint16
+        self.inputs = {}
+        self.attrs = {
+            'shape': [123, 92],
+            'value': 3.8,
+            'dtype': core.VarDesc.VarType.BF16
+        }
+        self.outputs = {'Out': convert_float_to_uint16(np.full((123, 92), 3.8))}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+
 class TestFillConstantOpWithSelectedRows(unittest.TestCase):
     def check_with_place(self, place):
         scope = core.Scope()
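Note (not part of the patches above): the BF16 unit tests in this series store outputs as np.uint16 and go through convert_float_to_uint16 / convert_uint16_to_float because a bfloat16 value is simply the upper half of an IEEE float32 bit pattern. The standalone sketch below illustrates that round trip with plain numpy; the helper names float_to_bf16_uint16 and bf16_uint16_to_float are made up for illustration, and Paddle's own converters may round to nearest rather than truncate as this sketch does.

import numpy as np

def float_to_bf16_uint16(x):
    # Keep the upper 16 bits of the float32 bit pattern (plain truncation).
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_uint16_to_float(x):
    # Re-expand the stored upper half back into a float32 bit pattern.
    return (np.asarray(x, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)

vals = np.random.normal(loc=1.0, scale=2.0, size=(123, 92)).astype(np.float32)
restored = bf16_uint16_to_float(float_to_bf16_uint16(vals))
print(np.max(np.abs(restored - vals)))  # small truncation error, so the histogram checks still pass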