From 68ba0e43fd87fd46748180c260ecb9ba251c5b38 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Thu, 17 Feb 2022 12:05:42 +0000
Subject: [PATCH 1/9] add scale gather sum

---
 paddle/fluid/operators/gather_op.cc                |  6 +-
 paddle/fluid/operators/gather_op.cu                |  6 +-
 paddle/fluid/operators/math/blas_impl.h            |  8 +++
 .../operators/math/selected_rows_functor.cu        |  2 +
 paddle/fluid/operators/sum_op.cu                   |  3 +-
 .../platform/device/gpu/gpu_primitives.h           | 69 +++++++++++++++++++
 paddle/pten/kernels/gpu/scale_kernel.cu            |  1 +
 7 files changed, 90 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc
index d35b066be85e7..e70215db2d394 100644
--- a/paddle/fluid/operators/gather_op.cc
+++ b/paddle/fluid/operators/gather_op.cc
@@ -201,12 +201,14 @@ REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
 REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
                        ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
                        ops::GatherOpKernel<uint8_t>,
-                       ops::GatherOpKernel<int64_t>);
+                       ops::GatherOpKernel<int64_t>,
+                       ops::GatherOpKernel<paddle::platform::bfloat16>);
 REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
                        ops::GatherGradientOpKernel<double>,
                        ops::GatherGradientOpKernel<int>,
                        ops::GatherGradientOpKernel<uint8_t>,
-                       ops::GatherGradientOpKernel<int64_t>);
+                       ops::GatherGradientOpKernel<int64_t>,
+                       ops::GatherGradientOpKernel<paddle::platform::bfloat16>);
 REGISTER_OP_VERSION(gather)
     .AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC",
                    paddle::framework::compatible::OpVersionDesc().NewInput(
diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu
index 19568835a6e96..a502a13040949 100644
--- a/paddle/fluid/operators/gather_op.cu
+++ b/paddle/fluid/operators/gather_op.cu
@@ -130,9 +130,11 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
                         ops::GatherOpCUDAKernel<double>,
                         ops::GatherOpCUDAKernel<int64_t>,
                         ops::GatherOpCUDAKernel<int>,
-                        ops::GatherOpCUDAKernel<plat::float16>);
+                        ops::GatherOpCUDAKernel<plat::float16>,
+                        ops::GatherOpCUDAKernel<plat::bfloat16>);
 REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
                         ops::GatherGradOpCUDAKernel<double>,
                         ops::GatherGradOpCUDAKernel<int64_t>,
                         ops::GatherGradOpCUDAKernel<int>,
-                        ops::GatherGradOpCUDAKernel<plat::float16>);
+                        ops::GatherGradOpCUDAKernel<plat::float16>,
+                        ops::GatherGradOpCUDAKernel<plat::bfloat16>);
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index 8e0075c42eb2c..3fc8f8007c669 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -77,6 +77,14 @@ struct CBlas<pten::dtype::bfloat16> {
         "Blas VCOPY do not supported on CPU with bfloat16,"
         " please check your code"));
   }
+
+  template <typename... ARGS>
+  static void VADD(int n, const pten::dtype::bfloat16 *x,
+                   const pten::dtype::bfloat16 *y, pten::dtype::bfloat16 *z) {
+    for (int i = 0; i < n; ++i) {
+      z[i] = x[i] + y[i];
+    }
+  }
 };
 
 #ifdef PADDLE_WITH_MKLML
diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu
index d2caf82c93a52..7c9bfb28cfb40 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cu
+++ b/paddle/fluid/operators/math/selected_rows_functor.cu
@@ -16,6 +16,7 @@ limitations under the License.
 */
 #include <vector>
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/pten/kernels/funcs/math_function.h"
@@ -440,6 +441,7 @@ template struct MergeAdd<platform::CUDADeviceContext, double>;
 template struct MergeAdd<platform::CUDADeviceContext, int>;
 template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
 template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;
+template struct MergeAdd<platform::CUDADeviceContext, platform::bfloat16>;
 template struct MergeAdd<platform::CUDADeviceContext, platform::complex<float>>;
 template struct MergeAdd<platform::CUDADeviceContext, platform::complex<double>>;
 
diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu
index ce152f4450811..47a51769b59bb 100644
--- a/paddle/fluid/operators/sum_op.cu
+++ b/paddle/fluid/operators/sum_op.cu
@@ -258,4 +258,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SumKernel<paddle::platform::CUDADeviceContext, double>,
     ops::SumKernel<paddle::platform::CUDADeviceContext, int>,
     ops::SumKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::SumKernel<paddle::platform::CUDADeviceContext, plat::float16>);
+    ops::SumKernel<paddle::platform::CUDADeviceContext, plat::float16>,
+    ops::SumKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>);
diff --git a/paddle/fluid/platform/device/gpu/gpu_primitives.h b/paddle/fluid/platform/device/gpu/gpu_primitives.h
index 3e070da546b2a..9969069124f47 100644
--- a/paddle/fluid/platform/device/gpu/gpu_primitives.h
+++ b/paddle/fluid/platform/device/gpu/gpu_primitives.h
@@ -20,6 +20,7 @@ limitations under the License. */
 #include <hip/hip_runtime.h>
 #endif
 #include <stdio.h>
 
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
@@ -149,6 +150,74 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
 #endif
 #endif
 
+#ifdef PADDLE_CUDA_BF16
+// NOTE(zhangbo): CUDA does not have atomicCAS for __nv_bfloat16.
+inline static __device__ uint32_t bf16_add_to_low_half(uint32_t val, float x) {
+  bfloat16 low_half;
+  // the bfloat16 in the lower 16 bits
+  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
+  low_half = static_cast<bfloat16>(static_cast<float>(low_half) + x);
+  return (val & 0xFFFF0000u) | low_half.x;
+}
+
+inline static __device__ uint32_t bf16_add_to_high_half(uint32_t val,
+                                                        float x) {
+  bfloat16 high_half;
+  // the bfloat16 in the higher 16 bits
+  high_half.x = static_cast<uint16_t>(val >> 16);
+  high_half = static_cast<bfloat16>(static_cast<float>(high_half) + x);
+  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
+}
+
+#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+static __device__ __forceinline__ bfloat16 CUDABF16ToPDBF16(__nv_bfloat16 x) {
+  return *reinterpret_cast<bfloat16 *>(&x);
+}
+
+static __device__ __forceinline__ __nv_bfloat16 PDBF16ToCUDABF16(bfloat16 x) {
+  return *reinterpret_cast<__nv_bfloat16 *>(&x);
+}
+
+CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
+  return CUDABF16ToPDBF16(atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
+                                    PDBF16ToCUDABF16(val)));
+}
+#else
+CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
+  // the packed bfloat16 value may exist in the lower or the higher 16 bits
+  // of the aligned 32-bit word that contains the address.
+  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
+      reinterpret_cast<char *>(address) -
+      (reinterpret_cast<uintptr_t>(address) & 0x02));
+  float val_f = static_cast<float>(val);
+  uint32_t old = *address_as_ui;
+  uint32_t sum;
+  uint32_t newval;
+  uint32_t assumed;
+  if (((uintptr_t)address & 0x02) == 0) {
+    // the bfloat16 value stays in the lower 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ui, assumed,
+                      bf16_add_to_low_half(assumed, val_f));
+    } while (old != assumed);
+    bfloat16 ret;
+    ret.x = old & 0xFFFFu;
+    return ret;
+  } else {
+    // the bfloat16 value stays in the higher 16 bits of the address.
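+    // Retry the compare-and-swap until no other thread has modified the
+    // 32-bit word between reading `old` and writing the updated value back.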
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ui, assumed,
+                      bf16_add_to_high_half(assumed, val_f));
+    } while (old != assumed);
+    bfloat16 ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
+#endif
+#endif
 CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
   float *real = reinterpret_cast<float *>(address);
   float *imag = real + 1;
diff --git a/paddle/pten/kernels/gpu/scale_kernel.cu b/paddle/pten/kernels/gpu/scale_kernel.cu
index e1cf78224a19d..da88c23ef8199 100644
--- a/paddle/pten/kernels/gpu/scale_kernel.cu
+++ b/paddle/pten/kernels/gpu/scale_kernel.cu
@@ -70,6 +70,7 @@ PT_REGISTER_KERNEL(scale,
                    float,
                    double,
                    pten::dtype::float16,
+                   pten::dtype::bfloat16,
                    uint8_t,
                    int8_t,
                    int16_t,

From 259fedfe128e3bd7e0b58c1506bae523f8e47dc4 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Fri, 18 Feb 2022 03:40:05 +0000
Subject: [PATCH 2/9] refine CUDA_ATOMIC_WRAPPER ADD for bf16

---
 paddle/fluid/platform/device/gpu/gpu_primitives.h | 2 --
 1 file changed, 2 deletions(-)

diff --git a/paddle/fluid/platform/device/gpu/gpu_primitives.h b/paddle/fluid/platform/device/gpu/gpu_primitives.h
index 9969069124f47..9aaf6de32ad42 100644
--- a/paddle/fluid/platform/device/gpu/gpu_primitives.h
+++ b/paddle/fluid/platform/device/gpu/gpu_primitives.h
@@ -150,7 +150,6 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
 #endif
 #endif
 
-#ifdef PADDLE_CUDA_BF16
 // NOTE(zhangbo): CUDA does not have atomicCAS for __nv_bfloat16.
 inline static __device__ uint32_t bf16_add_to_low_half(uint32_t val, float x) {
   bfloat16 low_half;
@@ -216,7 +215,6 @@ CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
   }
 }
 #endif
-#endif
 CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
   float *real = reinterpret_cast<float *>(address);
   float *imag = real + 1;

From 4998403ebcae7b293b55c5e9b1cb69a07c47f1de Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Fri, 18 Feb 2022 12:29:14 +0000
Subject: [PATCH 3/9] add gather unittest

---
 .../paddle/fluid/tests/unittests/op_test.py       |  7 ++++-
 .../fluid/tests/unittests/test_gather_op.py       | 27 ++++++++++++++++++-
 2 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 85423df3d3828..7b075239bbb65 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -480,7 +480,12 @@ def _append_ops(self, block):
 
         op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
         "infer datatype from inputs and outputs for this test case"
-        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
+        if self.is_bfloat16_op():
+            self.dtype = np.uint16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.uint16
+        else:
+            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
         inputs = append_input_output(block, op_proto, self.inputs, True,
                                      self.dtype)
         outputs = append_input_output(block, op_proto, self.outputs, False,
diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py
index 83b39a62f152d..03b0558160067 100644
--- a/python/paddle/fluid/tests/unittests/test_gather_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gather_op.py
@@ -16,7 +16,7 @@
 
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.fluid as fluid
 from paddle.framework import core
@@ -117,6 +117,31 @@ def config(self):
         self.index_type = "int32"
 
 
+class TestGatherBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "gather"
+        self.dtype = np.uint16
+        self.config()
+        xnp = np.random.random(self.x_shape).astype(np.float32)
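+        # convert_float_to_uint16 (imported from op_test above) rounds the
+        # fp32 data to bfloat16 and returns the raw uint16 bit patterns that
+        # the bfloat16 gather kernel consumes.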
+ self.inputs = { + 'X': convert_float_to_uint16(xnp), + 'Index': np.array(self.index).astype(self.index_type) + } + self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place) + + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (10, 20) + self.index = [1, 3, 5] + self.index_type = "int32" + + class TestGatherOp1(OpTest): def setUp(self): self.op_type = "gather" From 3053a151537f834681bdf01d3543edddf3117bee Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 18 Feb 2022 12:51:19 +0000 Subject: [PATCH 4/9] solve conflict --- paddle/fluid/operators/math/blas_impl.h | 1868 ----------------------- 1 file changed, 1868 deletions(-) delete mode 100644 paddle/fluid/operators/math/blas_impl.h diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h deleted file mode 100644 index 3fc8f8007c669..0000000000000 --- a/paddle/fluid/operators/math/blas_impl.h +++ /dev/null @@ -1,1868 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include "paddle/pten/backends/cpu/cpu_context.h" -#ifdef PADDLE_WITH_MKLML -#include -#endif - -#include -#include -#include -#include - -#include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/complex.h" -#include "paddle/pten/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { -namespace math { -namespace detail { - -template -static void axpy(int n, const T alpha, const T *x, const int incx, T *y, - const int incy) { - // Y = Y + alpha * X - while (n-- > 0) { - *y += alpha * *x; - y = y + incy; - x = x + incx; - } -} -} // namespace detail - -template -struct CBlas; - -template <> -struct CBlas { - template - static void VCOPY(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "Blas VCOPY do not supported on CPU, please check your code")); - } -}; - -template <> -struct CBlas { - template - static void VCOPY(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "Blas VCOPY do not supported on CPU, please check your code")); - } -}; - -template <> -struct CBlas { - template - static void AXPY(ARGS... args) { - detail::axpy(args...); - } - - template - static void VCOPY(ARGS... args) { - PADDLE_THROW(platform::errors::Unimplemented( - "Blas VCOPY do not supported on CPU with bfloat16," - " please check your code")); - } - - template - static void VADD(int n, const pten::dtype::bfloat16 *x, - const pten::dtype::bfloat16 *y, pten::dtype::bfloat16 *z) { - for (int i = 0; i < n; ++i) { - z[i] = x[i] + y[i]; - } - } -}; - -#ifdef PADDLE_WITH_MKLML -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - platform::dynload::cblas_sgemm(args...); - } - - template - static float *GEMM_ALLOC(ARGS... args) { - return platform::dynload::cblas_sgemm_alloc(args...); - } - - template - static void GEMM_PACK(ARGS... 
args) { - platform::dynload::cblas_sgemm_pack(args...); - } - - template - static void GEMM_COMPUTE(ARGS... args) { - platform::dynload::cblas_sgemm_compute(args...); - } - - template - static void GEMM_FREE(ARGS... args) { - platform::dynload::cblas_sgemm_free(args...); - } - -#ifdef PADDLE_WITH_LIBXSMM - template - static void SMM_GEMM(ARGS... args) { - libxsmm_sgemm(args...); - } -#endif - - template - static void AXPY(ARGS... args) { - platform::dynload::cblas_saxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - platform::dynload::cblas_scopy(args...); - } - - template - static void GEMV(ARGS... args) { - platform::dynload::cblas_sgemv(args...); - } - - template - static float DOT(ARGS... args) { - return platform::dynload::cblas_sdot(args...); - } - - template - static void SCAL(ARGS... args) { - platform::dynload::cblas_sscal(args...); - } - - template - static float ASUM(ARGS... args) { - return platform::dynload::cblas_sasum(args...); - } - - template - static void GEMM_BATCH(ARGS... args) { - platform::dynload::cblas_sgemm_batch(args...); - } - - template - static void VADD(ARGS... args) { - platform::dynload::vsAdd(args...); - } - - template - static void VSUB(ARGS... args) { - platform::dynload::vsSub(args...); - } - - template - static void VMUL(ARGS... args) { - platform::dynload::vsMul(args...); - } - - template - static void VDIV(ARGS... args) { - platform::dynload::vsDiv(args...); - } - - template - static void VEXP(ARGS... args) { - platform::dynload::vsExp(args...); - } - - template - static void VSQUARE(ARGS... args) { - platform::dynload::vsSqr(args...); - } - - template - static void VPOW(ARGS... args) { - platform::dynload::vsPowx(args...); - } - - template - static void VINV(ARGS... args) { - platform::dynload::vsInv(args...); - } - - template - static void VMERF(ARGS... args) { - platform::dynload::vmsErf(args...); - } -#if !defined(_WIN32) - template - static void CSRMM(ARGS... args) { - platform::dynload::mkl_scsrmm(args...); - } -#endif - - template - static void TRSM(ARGS... args) { - platform::dynload::cblas_strsm(args...); - } -}; - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - platform::dynload::cblas_dgemm(args...); - } - - template - static double *GEMM_ALLOC(ARGS... args) { - return platform::dynload::cblas_dgemm_alloc(args...); - } - - template - static void GEMM_PACK(ARGS... args) { - platform::dynload::cblas_dgemm_pack(args...); - } - - template - static void GEMM_COMPUTE(ARGS... args) { - platform::dynload::cblas_dgemm_compute(args...); - } - - template - static void GEMM_FREE(ARGS... args) { - platform::dynload::cblas_dgemm_free(args...); - } - -#ifdef PADDLE_WITH_LIBXSMM - template - static void SMM_GEMM(ARGS... args) { - libxsmm_dgemm(args...); - } -#endif - - template - static void AXPY(ARGS... args) { - platform::dynload::cblas_daxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - platform::dynload::cblas_dcopy(args...); - } - - template - static void GEMV(ARGS... args) { - platform::dynload::cblas_dgemv(args...); - } - - template - static double DOT(ARGS... args) { - return platform::dynload::cblas_ddot(args...); - } - - template - static void SCAL(ARGS... args) { - platform::dynload::cblas_dscal(args...); - } - - template - static double ASUM(ARGS... args) { - return platform::dynload::cblas_dasum(args...); - } - - template - static void GEMM_BATCH(ARGS... args) { - platform::dynload::cblas_dgemm_batch(args...); - } - - template - static void VADD(ARGS... 
args) { - platform::dynload::vdAdd(args...); - } - - template - static void VSUB(ARGS... args) { - platform::dynload::vdSub(args...); - } - - template - static void VMUL(ARGS... args) { - platform::dynload::vdMul(args...); - } - - template - static void VDIV(ARGS... args) { - platform::dynload::vdDiv(args...); - } - - template - static void VEXP(ARGS... args) { - platform::dynload::vdExp(args...); - } - - template - static void VSQUARE(ARGS... args) { - platform::dynload::vdSqr(args...); - } - - template - static void VPOW(ARGS... args) { - platform::dynload::vdPowx(args...); - } - - template - static void VINV(ARGS... args) { - platform::dynload::vdInv(args...); - } - - template - static void VMERF(ARGS... args) { - platform::dynload::vmdErf(args...); - } -#if !defined(_WIN32) - template - static void CSRMM(ARGS... args) { - platform::dynload::mkl_dcsrmm(args...); - } -#endif - - template - static void TRSM(ARGS... args) { - platform::dynload::cblas_dtrsm(args...); - } -}; - -template <> -struct CBlas> { - template - static void AXPY(int n, const paddle::platform::complex alpha, - const paddle::platform::complex *X, const int incX, - paddle::platform::complex *Y, const int incY) { - platform::dynload::cblas_caxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void VCOPY(ARGS... args) { - platform::dynload::cblas_ccopy(args...); - } - - // the libmklml_intel.so paddle used has no vcAdd, vcSub, - // vcMul, vcDiv apis before rebuild from source - // so replace with the raw operator methods - /* - template - static void VADD(ARGS... args) { - platform::dynload::vcAdd(args...); - } - - template - static void VSUB(ARGS... args) { - platform::dynload::vcSub(args...); - } - - template - static void VMUL(ARGS... args) { - platform::dynload::vcMul(args...); - } - - template - static void VDIV(ARGS... 
args) { - platform::dynload::vcDiv(args...); - } - */ - - template - static void VADD(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] + b[i]; - } - } - - template - static void VSUB(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] - b[i]; - } - } - - template - static void VMUL(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] * b[i]; - } - } - template - static void VDIV(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] / b[i]; - } - } - - template - static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - const paddle::platform::complex *X, int incx, - paddle::platform::complex beta, - paddle::platform::complex *Y, int incy) { - const void *a_ = (const void *)(A); - const void *x_ = (const void *)(X); - void *y_ = static_cast(Y); - platform::dynload::cblas_cgemv(layout, trans, M, N, &alpha, a_, lda, x_, - incx, &beta, y_, incy); - } - - template - static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a, - CBLAS_TRANSPOSE trans_b, int M, int N, int K, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - const paddle::platform::complex *B, int ldb, - paddle::platform::complex beta, - paddle::platform::complex *C, int ldc) { - const void *a_ = (const void *)(A); - const void *b_ = (const void *)(B); - void *c_ = static_cast(C); - platform::dynload::cblas_cgemm(layout, trans_a, trans_b, M, N, K, &alpha, - a_, lda, b_, ldb, &beta, c_, ldc); - } - - static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - paddle::platform::complex *B, int ldb) { - const void *a_ = (const void *)(A); - void *b_ = static_cast(B); - platform::dynload::cblas_ctrsm(layout, side, uplo, trans_a, diag, M, N, - &alpha, a_, lda, b_, ldb); - } - - template - static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a, - CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K, - paddle::platform::complex *alpha, - const paddle::platform::complex **A, - const int *lda, - const paddle::platform::complex **B, - const int *ldb, paddle::platform::complex *beta, - paddle::platform::complex **C, const int *ldc, - int group_count, int *group_size) { - const void **A_void = (const void **)(&(*A)); - const void **B_void = (const void **)(&(*B)); - void **C_void = reinterpret_cast(C); - - platform::dynload::cblas_cgemm_batch(layout, trans_a, trans_b, M, N, K, - alpha, A_void, lda, B_void, ldb, beta, - C_void, ldc, group_count, group_size); - } - - template - static void GEMM_EX(ARGS... args) { - platform::dynload::cblas_cgemm_batch(args...); - } -}; - -template <> -struct CBlas> { - template - static void AXPY(int n, const paddle::platform::complex alpha, - const paddle::platform::complex *X, const int incX, - paddle::platform::complex *Y, const int incY) { - platform::dynload::cblas_zaxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void VCOPY(ARGS... 
args) { - platform::dynload::cblas_zcopy(args...); - } - - // the libmklml_intel.so paddle used has no vzAdd, vzSub, - // vzMul, vzDiv apis before rebuild from source - // so replace with the raw operator methods - /* - template - static void VADD(ARGS... args) { - platform::dynload::vzAdd(args...); - } - - template - static void VSUB(ARGS... args) { - platform::dynload::vzSub(args...); - } - - template - static void VMUL(ARGS... args) { - platform::dynload::vzMul(args...); - } - - template - static void VDIV(ARGS... args) { - platform::dynload::vzDiv(args...); - } - */ - - template - static void VADD(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] + b[i]; - } - } - - template - static void VSUB(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] - b[i]; - } - } - - template - static void VMUL(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] * b[i]; - } - } - template - static void VDIV(int n, const paddle::platform::complex *a, - const paddle::platform::complex *b, - paddle::platform::complex *y) { - for (int i = 0; i < n; ++i) { - y[i] = a[i] / b[i]; - } - } - - template - static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - const paddle::platform::complex *X, int incx, - paddle::platform::complex beta, - paddle::platform::complex *Y, int incy) { - const void *a_ = (const void *)(A); - const void *x_ = (const void *)(X); - void *y_ = static_cast(Y); - platform::dynload::cblas_zgemv(layout, trans, M, N, &alpha, a_, lda, x_, - incx, &beta, y_, incy); - } - - template - static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a, - CBLAS_TRANSPOSE trans_b, int M, int N, int K, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - const paddle::platform::complex *B, int ldb, - paddle::platform::complex beta, - paddle::platform::complex *C, int ldc) { - const void *a_ = (const void *)(A); - const void *b_ = (const void *)(B); - void *c_ = static_cast(C); - platform::dynload::cblas_zgemm(layout, trans_a, trans_b, M, N, K, &alpha, - a_, lda, b_, ldb, &beta, c_, ldc); - } - - static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N, - paddle::platform::complex alpha, - const paddle::platform::complex *A, int lda, - paddle::platform::complex *B, int ldb) { - const void *a_ = (const void *)(A); - void *b_ = static_cast(B); - platform::dynload::cblas_ztrsm(layout, side, uplo, trans_a, diag, M, N, - &alpha, a_, lda, b_, ldb); - } - - template - static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a, - CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K, - paddle::platform::complex *alpha, - const paddle::platform::complex **A, - const int *lda, - const paddle::platform::complex **B, - const int *ldb, - paddle::platform::complex *beta, - paddle::platform::complex **C, const int *ldc, - int group_count, int *group_size) { - const void **A_void = (const void **)(&(*A)); - const void **B_void = (const void **)(&(*B)); - void **C_void = reinterpret_cast(C); - - platform::dynload::cblas_zgemm_batch(layout, trans_a, trans_b, M, N, K, - alpha, A_void, lda, 
B_void, ldb, beta, - C_void, ldc, group_count, group_size); - } - - template - static void GEMM_EX(ARGS... args) { - platform::dynload::cblas_zgemm_batch(args...); - } -}; - -#else - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - cblas_sgemm(args...); - } - - template - static void AXPY(ARGS... args) { - cblas_saxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - cblas_scopy(args...); - } - - template - static void GEMV(ARGS... args) { - cblas_sgemv(args...); - } - - template - static void TRSM(ARGS... args) { - cblas_strsm(args...); - } -}; - -template <> -struct CBlas { - template - static void GEMM(ARGS... args) { - cblas_dgemm(args...); - } - - template - static void AXPY(ARGS... args) { - cblas_daxpy(args...); - } - - template - static void VCOPY(ARGS... args) { - cblas_dcopy(args...); - } - - template - static void GEMV(ARGS... args) { - cblas_dgemv(args...); - } - - template - static void TRSM(ARGS... args) { - cblas_dtrsm(args...); - } -}; - -template <> -struct CBlas> { - template - static void VCOPY(ARGS... args) { - cblas_ccopy(args...); - } - - template - static void AXPY(int n, const paddle::platform::complex alpha, - const paddle::platform::complex *X, const int incX, - paddle::platform::complex *Y, const int incY) { - cblas_caxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, - const int M, const int N, - const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - const paddle::platform::complex *X, const int incX, - const paddle::platform::complex beta, - paddle::platform::complex *Y, const int incY) { - cblas_cgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); - } - - template - static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - const paddle::platform::complex *B, const int ldb, - const paddle::platform::complex beta, - paddle::platform::complex *C, const int ldc) { - cblas_cgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, - C, ldc); - } - - static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side, - const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA, - const CBLAS_DIAG diag, const int M, const int N, - const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - paddle::platform::complex *B, const int ldb) { - cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); - } -}; - -template <> -struct CBlas> { - template - static void VCOPY(ARGS... 
args) { - cblas_zcopy(args...); - } - - template - static void AXPY(int n, const paddle::platform::complex alpha, - const paddle::platform::complex *X, const int incX, - paddle::platform::complex *Y, const int incY) { - cblas_zaxpy(n, &alpha, X, incX, Y, incY); - } - - template - static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, - const int M, const int N, - const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - const paddle::platform::complex *X, const int incX, - const paddle::platform::complex beta, - paddle::platform::complex *Y, const int incY) { - cblas_zgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); - } - - template - static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - const paddle::platform::complex *B, const int ldb, - const paddle::platform::complex beta, - paddle::platform::complex *C, const int ldc) { - cblas_zgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, - C, ldc); - } - - static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side, - const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA, - const CBLAS_DIAG diag, const int M, const int N, - const paddle::platform::complex alpha, - const paddle::platform::complex *A, const int lda, - paddle::platform::complex *B, const int ldb) { - cblas_ztrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); - } -}; - -#endif - -template <> -struct CBlas { - static void GEMM(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 GEMM not supported on CPU, please check your code")); - } - - static void SMM_GEMM(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 SMM_GEMM not supported on CPU, please check your code")); - } - static void VMUL(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 VMUL not supported on CPU, please check your code")); - } - static void VEXP(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 VEXP not supported on CPU, please check your code")); - } - static void VSQUARE(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 VSQUARE not supported on CPU, please check your code")); - } - static void VPOW(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 VPOW not supported on CPU, please check your code")); - } - static void DOT(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 DOT not supported on CPU, please check your code")); - }; - static void SCAL(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 SCAL not supported on CPU, please check your code")); - }; - static void ASUM(...) { - PADDLE_THROW(platform::errors::Unimplemented( - "float16 ASUM not supported on CPU, please check your code")); - }; -#ifdef PADDLE_WITH_MKLML - static void GEMM_BATCH(...) 
{ - PADDLE_THROW(platform::errors::Unimplemented( - "float16 GEMM_BATCH not supported on CPU, please check your code")); - } -#endif -}; - -#ifdef PADDLE_WITH_MKLML -template <> -template -T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, - const int M, const int N, - const int K) const { - return CBlas::GEMM_ALLOC(id, M, N, K); -} -template <> -template -T *Blas::GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, - const int N, const int K) const { - return CBlas::GEMM_ALLOC(id, M, N, K); -} - -template <> -template -void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, - const CBLAS_TRANSPOSE trans, - int M, int N, int K, - const T alpha, const T *src, - const int ld, T *dst) const { - CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); -} -template <> -template -void Blas::GEMM_PACK(const CBLAS_IDENTIFIER id, - const CBLAS_TRANSPOSE trans, int M, - int N, int K, const T alpha, - const T *src, const int ld, - T *dst) const { - CBlas::GEMM_PACK(CblasRowMajor, id, trans, M, N, K, alpha, src, ld, dst); -} - -template <> -template -void Blas::GEMM_COMPUTE( - int transA, int transB, int M, int N, int K, const T *A, const int lda, - const T *B, const int ldb, T beta, T *C, const int ldc) const { - CBlas::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, - beta, C, ldc); -} -template <> -template -void Blas::GEMM_COMPUTE(int transA, int transB, int M, int N, - int K, const T *A, const int lda, - const T *B, const int ldb, T beta, - T *C, const int ldc) const { - CBlas::GEMM_COMPUTE(CblasRowMajor, transA, transB, M, N, K, A, lda, B, ldb, - beta, C, ldc); -} - -template <> -template -void Blas::GEMM_FREE(T *data) const { - CBlas::GEMM_FREE(data); -} -template <> -template -void Blas::GEMM_FREE(T *data) const { - CBlas::GEMM_FREE(data); -} -#endif - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, - int N, int K, T alpha, const T *A, - const T *B, T beta, T *C) const { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T *A, const T *B, T beta, - T *C) const { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} - -template <> -template -void Blas::GEMM(bool transA, bool transB, int M, - int N, int K, T alpha, const T *A, - int lda, const T *B, int ldb, - T beta, T *C, int ldc) const { - CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, - transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, - lda, B, ldb, beta, C, ldc); -} -template <> -template -void Blas::GEMM(bool transA, bool transB, int M, int N, int K, - T alpha, const T *A, int lda, const T *B, - int ldb, T beta, T *C, int ldc) const { - CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, - transB == false ? 
CblasNoTrans : CblasTrans, M, N, K, alpha, A, - lda, B, ldb, beta, C, ldc); -} - -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, - int N, int K, T alpha, const T *A, - int lda, const T *B, int ldb, - T beta, T *C, int ldc) const { - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} -template <> -template -void Blas::GEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T *A, int lda, const T *B, - int ldb, T beta, T *C, int ldc) const { - CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} - -template -template -void Blas::MatMul(const framework::Tensor &mat_a, bool trans_a, - const framework::Tensor &mat_b, bool trans_b, - T alpha, framework::Tensor *mat_out, - T beta) const { - auto dim_a = mat_a.dims(); - auto dim_b = mat_b.dims(); - auto dim_out = mat_out->dims(); - PADDLE_ENFORCE_EQ( - dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, true, - platform::errors::InvalidArgument( - "The input and output of matmul should be matrix, the dim size must " - "be 2," - "but received dim size input_a:%d, input_b:%d, output:%d", - dim_a.size(), dim_b.size(), dim_out.size())); - PADDLE_ENFORCE_EQ( - mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), true, - platform::errors::InvalidArgument("The places of matrices in the matmul " - "should be same, please check your " - "code.")); - - int M = dim_out[0]; - int N = dim_out[1]; - int K = !trans_a ? dim_a[1] : dim_a[0]; - - CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !trans_b ? CblasNoTrans : CblasTrans; - - this->GEMM(transA, transB, M, N, K, alpha, mat_a.data(), mat_b.data(), - beta, mat_out->data()); -} - -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, - T *y) const { - CBlas::AXPY(n, alpha, x, 1, y, 1); -} -template <> -template -void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - CBlas::AXPY(n, alpha, x, 1, y, 1); -} - -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - CBlas::VCOPY(n, x, 1, y, 1); -} -template <> -template -void Blas::VCOPY(int n, const T *x, T *y) const { - CBlas::VCOPY(n, x, 1, y, 1); -} - -template <> -template -void Blas::VADD(int n, const T *x, const T *y, - T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VADD(n, x, y, z); -#else - if (x == z) { - this->template AXPY(n, (T)(1.), y, z); - } else { - this->template VCOPY(n, y, z); - this->template AXPY(n, (T)(1.), x, z); - } -#endif -} -template <> -template -void Blas::VADD(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VADD(n, x, y, z); -#else - if (x == z) { - this->template AXPY(n, (T)(1.), y, z); - } else { - this->template VCOPY(n, y, z); - this->template AXPY(n, (T)(1.), x, z); - } -#endif -} - -template <> -template -void Blas::VSUB(int n, const T *x, const T *y, - T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSUB(n, x, y, z); -#else - // try to find if openblas support vsub - for (int i = 0; i < n; ++i) { - z[i] = x[i] - y[i]; - } -#endif -} -template <> -template -void Blas::VSUB(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSUB(n, x, y, z); -#else - // try to find if openblas support vsub - for (int i = 0; i < n; ++i) { - z[i] = x[i] - y[i]; - } -#endif -} - -template <> -template -void Blas::VMUL(int n, const T *x, const T *y, - T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VMUL(n, x, y, 
z); -#else - // try to find if openblas support vmul - for (int i = 0; i < n; ++i) { - z[i] = x[i] * y[i]; - } -#endif -} -template <> -template -void Blas::VMUL(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VMUL(n, x, y, z); -#else - // try to find if openblas support vmul - for (int i = 0; i < n; ++i) { - z[i] = x[i] * y[i]; - } -#endif -} - -template <> -template -void Blas::VDIV(int n, const T *x, const T *y, - T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VDIV(n, x, y, z); -#else - // try to find if openblas support vdiv - for (int i = 0; i < n; ++i) { - z[i] = x[i] / y[i]; - } -#endif -} -template <> -template -void Blas::VDIV(int n, const T *x, const T *y, T *z) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VDIV(n, x, y, z); -#else - // try to find if openblas support vdiv - for (int i = 0; i < n; ++i) { - z[i] = x[i] / y[i]; - } -#endif -} - -template <> -template -void Blas::VEXP(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VEXP(n, x, y); -#else - // try to find if openblas support vexp - for (int i = 0; i < n; ++i) { - y[i] = std::exp(x[i]); - } -#endif -} -template <> -template -void Blas::VEXP(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VEXP(n, x, y); -#else - // try to find if openblas support vexp - for (int i = 0; i < n; ++i) { - y[i] = std::exp(x[i]); - } -#endif -} - -template <> -template -void Blas::VSQUARE(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSQUARE(n, x, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = x[i] * x[i]; - } -#endif -} -template <> -template -void Blas::VSQUARE(int n, const T *x, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VSQUARE(n, x, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = x[i] * x[i]; - } -#endif -} - -template <> -template -void Blas::VPOW(int n, const T *x, T a, - T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VPOW(n, x, a, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = std::pow(x[i], a); - } -#endif -} -template <> -template -void Blas::VPOW(int n, const T *x, T a, T *y) const { -#ifdef PADDLE_WITH_MKLML - CBlas::VPOW(n, x, a, y); -#else - for (int i = 0; i < n; ++i) { - y[i] = std::pow(x[i], a); - } -#endif -} - -template <> -template -T Blas::DOT(int n, const T *x, const T *y) const { -#ifdef PADDLE_WITH_MKLML - return CBlas::DOT(n, x, 1, y, 1); -#else - // try to find if openblas support cblas_dot - T sum = 0; - for (int i = 0; i < n; ++i) { - sum += x[i] * y[i]; - } - return sum; -#endif -} -template <> -template -T Blas::DOT(int n, const T *x, const T *y) const { -#ifdef PADDLE_WITH_MKLML - return CBlas::DOT(n, x, 1, y, 1); -#else - // try to find if openblas support cblas_dot - T sum = 0; - for (int i = 0; i < n; ++i) { - sum += x[i] * y[i]; - } - return sum; -#endif -} - -template <> -template -void Blas::SCAL(int n, const T a, T *x) const { -#ifdef PADDLE_WITH_MKLML - CBlas::SCAL(n, a, x, 1); -#else - // try to find if openblas support cblas_scal - for (int i = 0; i < n; ++i) { - x[i] = a * x[i]; - } -#endif -} -template <> -template -void Blas::SCAL(int n, const T a, T *x) const { -#ifdef PADDLE_WITH_MKLML - CBlas::SCAL(n, a, x, 1); -#else - // try to find if openblas support cblas_scal - for (int i = 0; i < n; ++i) { - x[i] = a * x[i]; - } -#endif -} - -template <> -template -T Blas::ASUM(int n, T *x, int inc) const { - auto sum = static_cast(0.0); -#ifdef PADDLE_WITH_MKLML - sum = CBlas::ASUM(n, x, inc); -#else - // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum - for (int c = 0; c 
< n; ++c) { - sum += x[c]; - } -#endif - return sum; -} -template <> -template -T Blas::ASUM(int n, T *x, int inc) const { - auto sum = static_cast(0.0); -#ifdef PADDLE_WITH_MKLML - sum = CBlas::ASUM(n, x, inc); -#else - // TODO(jczaja): check if openblas does provide cblas_sasum/cblas_dasum - for (int c = 0; c < n; ++c) { - sum += x[c]; - } -#endif - return sum; -} - -template <> -template -void Blas::GEMV(bool trans_a, int M, int N, T alpha, - const T *A, const T *B, T beta, - T *C) const { - CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; - CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); -} -template <> -template -void Blas::GEMV(bool trans_a, int M, int N, T alpha, - const T *A, const T *B, T beta, T *C) const { - CBLAS_TRANSPOSE transA = !trans_a ? CblasNoTrans : CblasTrans; - CBlas::GEMV(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); -} - -template <> -template -void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T *A, const T *B, T beta, T *C, int batchCount, - int64_t strideA, int64_t strideB) const { - PADDLE_ENFORCE_NOT_NULL( - A, platform::errors::InvalidArgument("Pointer A should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - B, platform::errors::InvalidArgument("Pointer B should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - C, platform::errors::InvalidArgument("Pointer C should not be null.")); -#ifdef PADDLE_WITH_MKLML - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA]; - b_array[k] = &B[k * strideB]; - c_array[k] = &C[k * M * N]; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, - a_array.data(), &lda, b_array.data(), &ldb, &beta, - c_array.data(), &ldc, 1 /* group_count */, &batchCount); -#else - for (int k = 0; k < batchCount; ++k) { - auto *Ak = &A[k * strideA]; - auto *Bk = &B[k * strideB]; - auto *Ck = &C[k * M * N]; - this->template GEMM(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck); - } -#endif -} -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, T alpha, const T *A, const T *B, - T beta, T *C, int batchCount, - int64_t strideA, - int64_t strideB) const { - PADDLE_ENFORCE_NOT_NULL( - A, platform::errors::InvalidArgument("Pointer A should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - B, platform::errors::InvalidArgument("Pointer B should not be null.")); - PADDLE_ENFORCE_NOT_NULL( - C, platform::errors::InvalidArgument("Pointer C should not be null.")); -#ifdef PADDLE_WITH_MKLML - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? 
N : K; - int ldc = N; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA]; - b_array[k] = &B[k * strideB]; - c_array[k] = &C[k * M * N]; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, - a_array.data(), &lda, b_array.data(), &ldb, &beta, - c_array.data(), &ldc, 1 /* group_count */, &batchCount); -#else - for (int k = 0; k < batchCount; ++k) { - auto *Ak = &A[k * strideA]; - auto *Bk = &B[k * strideB]; - auto *Ck = &C[k * M * N]; - this->template GEMM(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck); - } -#endif -} - -template <> -template -void Blas::BatchedGEMM( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const { -#ifdef PADDLE_WITH_MKLML - const int lda = (std::max)((transA == CblasNoTrans) ? K : M, 1); - const int ldb = (std::max)((transB == CblasNoTrans) ? N : K, 1); - const int ldc = (std::max)(N, 1); - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, A, - &lda, B, &ldb, &beta, C, &ldc, 1 /* group_count */, - &batchCount); -#else - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], B[k], beta, - C[k]); - } -#endif -} -template <> -template -void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, int M, int N, - int K, T alpha, const T **A, - const T **B, T beta, T **C, - int batchCount) const { -#ifdef PADDLE_WITH_MKLML - const int lda = (std::max)((transA == CblasNoTrans) ? K : M, 1); - const int ldb = (std::max)((transB == CblasNoTrans) ? N : K, 1); - const int ldc = (std::max)(N, 1); - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &M, &N, &K, &alpha, A, - &lda, B, &ldb, &beta, C, &ldc, 1 /* group_count */, - &batchCount); -#else - for (int k = 0; k < batchCount; ++k) { - this->template GEMM(transA, transB, M, N, K, alpha, A[k], B[k], beta, - C[k]); - } -#endif -} - -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ - !defined(PADDLE_WITH_HIP) // @{ Group Blas MKLML: BatchedGEMMWithHead -template <> -template -void Blas::BatchedGEMMWithHead( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int W1, int H1, int W2, - int H2, T alpha, const T *A, const T *B, T beta, T *C, int batchCount, - int64_t strideA, int64_t strideB, int64_t head_number, - bool split_b_vertical) const { - int lda = (transA == CblasNoTrans) ? W1 : H1; - int ldb = (transB == CblasNoTrans) ? W2 : H2; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - - if (split_b_vertical) { - int ldc = W2; - int sub_width = W2 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? 
i * (W2 / head_number) - : i * (W2 / head_number) * H2; - int sub_matC_offset = i * W2 / head_number; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &sub_width, - &H2, &alpha, a_array.data(), &lda, b_array.data(), - &ldb, &beta, c_array.data(), &ldc, - 1 /* group_count */, &batchCount); - } - - } else { - PADDLE_ENFORCE_EQ( - W1, H2, - platform::errors::InvalidArgument( - "The fisrt matrix width should be same as second matrix height," - "but received fisrt matrix width %d" - ", second matrix height %d", - W1, H2)); - int ldc = W2 * head_number; - int sub_width = W1 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? i * (W1 / head_number) * W2 - : i * (W1 / head_number); - int sub_matC_offset = i * W2; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &W2, - &sub_width, &alpha, a_array.data(), &lda, - b_array.data(), &ldb, &beta, c_array.data(), &ldc, - 1 /* group_count */, &batchCount); - } - } -} -template <> -template -void Blas::BatchedGEMMWithHead( - CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int W1, int H1, int W2, - int H2, T alpha, const T *A, const T *B, T beta, T *C, int batchCount, - int64_t strideA, int64_t strideB, int64_t head_number, - bool split_b_vertical) const { - int lda = (transA == CblasNoTrans) ? W1 : H1; - int ldb = (transB == CblasNoTrans) ? W2 : H2; - auto a_array = std::vector(batchCount); - auto b_array = std::vector(batchCount); - auto c_array = std::vector(batchCount); - - if (split_b_vertical) { - int ldc = W2; - int sub_width = W2 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? i * (W2 / head_number) - : i * (W2 / head_number) * H2; - int sub_matC_offset = i * W2 / head_number; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &sub_width, - &H2, &alpha, a_array.data(), &lda, b_array.data(), - &ldb, &beta, c_array.data(), &ldc, - 1 /* group_count */, &batchCount); - } - - } else { - PADDLE_ENFORCE_EQ( - W1, H2, - platform::errors::InvalidArgument( - "The fisrt matrix width should be same as second matrix height," - "but received fisrt matrix width %d" - ", second matrix height %d", - W1, H2)); - int ldc = W2 * head_number; - int sub_width = W1 / head_number; - - for (int i = 0; i < head_number; i++) { - int sub_matA_offset = (transA == CblasNoTrans) - ? i * (W1 / head_number) - : i * (W1 / head_number) * H1; - int sub_matB_offset = (transB == CblasNoTrans) - ? 
i * (W1 / head_number) * W2 - : i * (W1 / head_number); - int sub_matC_offset = i * W2; - for (int k = 0; k < batchCount; ++k) { - a_array[k] = &A[k * strideA] + sub_matA_offset; - b_array[k] = &B[k * strideB] + sub_matB_offset; - c_array[k] = &C[k * H1 * head_number * W2] + sub_matC_offset; - } - - CBlas::GEMM_BATCH(CblasRowMajor, &transA, &transB, &H1, &W2, - &sub_width, &alpha, a_array.data(), &lda, - b_array.data(), &ldb, &beta, c_array.data(), &ldc, - 1 /* group_count */, &batchCount); - } - } -} -#endif // @} End Group Blas MKLML: BatchedGEMMWithHead - -template -template -void Blas::MatMul(const int M, const int N, const int K, - const T *A, const T *B, T *C) const { - this->template GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, - static_cast(1), A, K, B, N, static_cast(0), C, - N); -} - -template <> -template -void Blas::MatMul(const int M, const int N, - const int K, const T *A, - const T *B, T *C) const { -#ifdef PADDLE_WITH_LIBXSMM - // Refer to https://github.com/hfp/libxsmm/blob/master/README.md - // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; - - // Since the matrix is very small, - // so the unit of calculation is already very fast, - // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, - // use xsmm directly. - // Note: SMM use ColMajor - const char transa = 'N'; - const char transb = 'N'; - const T alpha = static_cast(1); - const T beta = static_cast(0); - CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, - C, &N); - return; -#endif - - CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, - static_cast(1), A, K, B, N, static_cast(0), C, N); -} -template <> -template -void Blas::MatMul(const int M, const int N, const int K, - const T *A, const T *B, T *C) const { -#ifdef PADDLE_WITH_LIBXSMM - // Refer to https://github.com/hfp/libxsmm/blob/master/README.md - // But the threshold is custom constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20; - - // Since the matrix is very small, - // so the unit of calculation is already very fast, - // and the if( M*N*K < LIBXSMM_THRESHOLD) would be overhead, - // use xsmm directly. - // Note: SMM use ColMajor - const char transa = 'N'; - const char transb = 'N'; - const T alpha = static_cast(1); - const T beta = static_cast(0); - CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &N, A, &K, &beta, - C, &N); - return; -#endif - - CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, - static_cast(1), A, K, B, N, static_cast(0), C, N); -} - -template -template -void Blas::MatMul(const framework::Tensor &mat_a, - const MatDescriptor &dim_a, - const framework::Tensor &mat_b, - const MatDescriptor &dim_b, T alpha, - framework::Tensor *mat_out, T beta) const { - MatMul(mat_a.data(), dim_a, mat_b.data(), dim_b, alpha, - mat_out->data(), beta); -} - -template -template -void Blas::MatMul(const T *mat_a, const MatDescriptor &dim_a, - const T *mat_b, const MatDescriptor &dim_b, - T alpha, T *mat_out, T beta) const { - PADDLE_ENFORCE_EQ( - dim_a.width_, dim_b.height_, - platform::errors::InvalidArgument( - "The fisrt matrix width should be same as second matrix height," - "but received fisrt matrix width %d" - ", second matrix height %d", - dim_a.width_, dim_b.height_)); - - CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !dim_b.trans_ ? 
CblasNoTrans : CblasTrans; - if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { - this->template GEMM(transA, transB, dim_a.height_, dim_b.width_, - dim_a.width_, alpha, mat_a, mat_b, beta, mat_out); - } else { - PADDLE_ENFORCE_EQ( - dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || - dim_b.batch_size_ == 0, - true, platform::errors::InvalidArgument( - "dim_a.batch_size should be equal to dim_b.batch_size, or " - "one of dim_a.batch_size and dim_b.batch_size should be 0. " - "But got dim_a.batch_size = %d, dim_b.batch_size = %d.", - dim_a.batch_size_, dim_b.batch_size_)); - this->template BatchedGEMM( - transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, mat_a, - mat_b, beta, mat_out, - dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, - dim_a.stride_, dim_b.stride_); - } -} - -#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ - !defined(PADDLE_WITH_HIP) -// @{ Group Blas MKLML: MatMulWithHead -/* - * Multiple two matrixes with multiple heads - * - * A new parameter, i.e head_number is added compared to normal MatMul. - * The head_number describes the number of heads a matrix is vertically - * split. - * - * When user calls this API, the multiplication of two big matrixes is split - * into multiplication of several (head_number_) small matrixes. e.g. if Mat A - * is [3, 24] and Mat B is [24, 4], when multiple A and B with head_number as - * 4, Mat A will be split as 4 matrix of [3, 6] and Mat B will be - * (horizontally) split as 4 matrix of [6, 4]. The result of final matrix - * will be 4 matrix of [3, 4], i.e. [3, 16]. - * Another example is A is [3, 8], B is [2, 16], head_number is 4. In this - * case, A will be split as [3, 2], B will be (vertically) split as - * [2, 4]. The final result will be 4 matrix of 4 matrix of [3,4], i.e. [3, 16] - */ -template -template -void Blas::MatMulWithHead(const framework::Tensor &mat_a, - const MatDescriptor &dim_a, - const framework::Tensor &mat_b, - const MatDescriptor &dim_b, T alpha, - int head_number, - framework::Tensor *mat_out, T beta, - bool mat_b_split_vertical) const { - PADDLE_ENFORCE_EQ( - dim_a.width_ % head_number, 0, - platform::errors::InvalidArgument( - "The first input width must be some times the head number" - "but received first input width %d" - ", head_number %d", - dim_a.width_, head_number)); - PADDLE_ENFORCE_GE(head_number, 1, - platform::errors::InvalidArgument( - "The head number should be greater equal 1," - "but received head number %d", - head_number)); - PADDLE_ENFORCE_LE( - head_number, dim_a.width_, - platform::errors::InvalidArgument( - "The head number should be less equal first input width," - "but received first input width %d" - ", head_number %d", - dim_a.width_, head_number)); - CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !dim_b.trans_ ? 
-#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
-    !defined(PADDLE_WITH_HIP)
-// @{ Group Blas MKLML: MatMulWithHead
-/*
- * Multiple two matrixes with multiple heads
- *
- * A new parameter, i.e head_number is added compared to normal MatMul.
- * The head_number describes the number of heads a matrix is vertically
- * split.
- *
- * When user calls this API, the multiplication of two big matrixes is split
- * into multiplication of several (head_number_) small matrixes. e.g. if Mat A
- * is [3, 24] and Mat B is [24, 4], when multiple A and B with head_number as
- * 4, Mat A will be split as 4 matrix of [3, 6] and Mat B will be
- * (horizontally) split as 4 matrix of [6, 4]. The result of final matrix
- * will be 4 matrix of [3, 4], i.e. [3, 16].
- * Another example is A is [3, 8], B is [2, 16], head_number is 4. In this
- * case, A will be split as [3, 2], B will be (vertically) split as
- * [2, 4]. The final result will be 4 matrix of [3, 4], i.e. [3, 16]
- */
-template <typename DeviceContext>
-template <typename T>
-void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
-                                         const MatDescriptor &dim_a,
-                                         const framework::Tensor &mat_b,
-                                         const MatDescriptor &dim_b, T alpha,
-                                         int head_number,
-                                         framework::Tensor *mat_out, T beta,
-                                         bool mat_b_split_vertical) const {
-  PADDLE_ENFORCE_EQ(
-      dim_a.width_ % head_number, 0,
-      platform::errors::InvalidArgument(
-          "The first input width must be some times the head number"
-          "but received first input width %d"
-          ", head_number %d",
-          dim_a.width_, head_number));
-  PADDLE_ENFORCE_GE(head_number, 1,
-                    platform::errors::InvalidArgument(
-                        "The head number should be greater equal 1,"
-                        "but received head number %d",
-                        head_number));
-  PADDLE_ENFORCE_LE(
-      head_number, dim_a.width_,
-      platform::errors::InvalidArgument(
-          "The head number should be less equal first input width,"
-          "but received first input width %d"
-          ", head_number %d",
-          dim_a.width_, head_number));
-  CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
-  CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
-
-  if (mat_b_split_vertical) {
-    PADDLE_ENFORCE_EQ(
-        dim_b.height_, dim_a.width_ / head_number,
-        platform::errors::InvalidArgument(
-            "The second input height should be equal than first input width,"
-            "but received second input height %d, first input width %d",
-            dim_b.height_, dim_a.width_ / head_number));
-    PADDLE_ENFORCE_EQ(
-        dim_a.width_ % head_number, 0,
-        platform::errors::InvalidArgument(
-            "The second input width should be some times the head number"
-            "but received second input width %d"
-            ", head_number %d",
-            dim_b.width_, head_number));
-  }
-
-  if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
-    int lda = !dim_a.trans_ ? dim_a.width_ : dim_a.height_;
-    int ldb = !dim_b.trans_ ? dim_b.width_ : dim_b.height_;
-    int sub_matA_offset;
-    int sub_matB_offset;
-    int sub_matC_offset;
-    int sub_mat_M = dim_a.height_;
-    int sub_mat_N;
-    int sub_mat_K;
-    int ldc;
-
-    for (int i = 0; i < head_number; i++) {
-      sub_matA_offset = dim_a.trans_
-                            ? i * (dim_a.width_ / head_number) * dim_a.height_
-                            : i * (dim_a.width_ / head_number);
-      if (mat_b_split_vertical) {
-        sub_matB_offset = dim_b.trans_
-                              ? i * (dim_b.width_ / head_number) * dim_b.height_
-                              : i * (dim_b.width_ / head_number);
-        sub_matC_offset = i * dim_b.width_ / head_number;
-
-        sub_mat_N = dim_b.width_ / head_number;
-        sub_mat_K = dim_b.height_;
-
-        ldc = dim_b.width_;
-      } else {
-        sub_matB_offset =
-            dim_b.trans_ ? i * (dim_b.height_ / head_number)
-                         : i * (dim_b.height_ / head_number) * dim_b.width_;
-        sub_matC_offset = i * dim_b.width_;
-
-        sub_mat_N = dim_b.width_;
-        sub_mat_K = dim_a.width_ / head_number;
-
-        ldc = head_number * dim_b.width_;
-      }
-
-      this->template GEMM<T>(transA, transB, sub_mat_M, sub_mat_N, sub_mat_K,
-                             alpha, mat_a.data<T>() + sub_matA_offset, lda,
-                             mat_b.data<T>() + sub_matB_offset, ldb, beta,
-                             mat_out->data<T>() + sub_matC_offset, ldc);
-    }
-  } else {
-    PADDLE_ENFORCE_EQ(
-        (dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 ||
-         dim_b.batch_size_ == 0),
-        true,
-        platform::errors::InvalidArgument(
-            "The first input batch size should be equal than second input,"
-            "either two input batch size is 0, but received first input batch "
-            "size"
-            " %d, second input batch size %d",
-            dim_a.batch_size_, dim_b.batch_size_));
-
-    this->template BatchedGEMMWithHead<T>(
-        transA, transB, dim_a.width_, dim_a.height_, dim_b.width_,
-        dim_b.height_, alpha, mat_a.data<T>(), mat_b.data<T>(), beta,
-        mat_out->data<T>(),
-        dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
-        dim_a.stride_, dim_b.stride_, head_number, mat_b_split_vertical);
-  }
-}
-#endif  // @} End Group Blas MKLML: MatMulWithHead
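The MatMulWithHead block above splits A column-wise into head_number pieces, splits B accordingly, and concatenates the per-head products. A minimal NumPy sketch of the first example in its comment ([3, 24] x [24, 4] with 4 heads giving [3, 16]); an illustration of the layout only, not the Paddle API:

# Head-split matmul from the comment: A [3, 24], B [24, 4], 4 heads -> [3, 16].
import numpy as np

head_number = 4
A = np.random.rand(3, 24).astype(np.float32)
B = np.random.rand(24, 4).astype(np.float32)

sub_k = A.shape[1] // head_number            # 6 columns of A / rows of B per head
outs = []
for i in range(head_number):
    A_i = A[:, i * sub_k:(i + 1) * sub_k]    # [3, 6]
    B_i = B[i * sub_k:(i + 1) * sub_k, :]    # [6, 4]
    outs.append(A_i @ B_i)                   # [3, 4] per head

out = np.concatenate(outs, axis=1)           # heads laid out side by side: [3, 16]
assert out.shape == (3, 16)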
-
-template <typename DeviceContext>
-template <typename T>
-void Blas<DeviceContext>::VINV(int n, const T *a, T *y) const {
-#ifdef PADDLE_WITH_MKLML
-  CBlas<T>::VINV(n, a, y);
-#else
-  for (int i = 0; i < n; ++i) {
-    y[i] = 1.0 / a[i];
-  }
-#endif
-}
-
-template <>
-template <typename T>
-void Blas<platform::CPUDeviceContext>::VMERF(int n, const T *a, T *y,
-                                             int64_t mode) const {
-#ifdef PADDLE_WITH_MKLML
-  CBlas<T>::VMERF(n, a, y, mode);
-#else
-  for (int i = 0; i < n; ++i) {
-    y[i] = std::erf(a[i]);
-  }
-#endif
-}
-template <>
-template <typename T>
-void Blas<pten::CPUContext>::VMERF(int n, const T *a, T *y,
-                                   int64_t mode) const {
-#ifdef PADDLE_WITH_MKLML
-  CBlas<T>::VMERF(n, a, y, mode);
-#else
-  for (int i = 0; i < n; ++i) {
-    y[i] = std::erf(a[i]);
-  }
-#endif
-}
-
-#ifdef PADDLE_WITH_MKLML
-template <>
-template <typename T>
-void Blas<platform::CPUDeviceContext>::CSRMM(
-    const char *transa, const int *m, const int *n, const int *k,
-    const T *alpha, const char *matdescra, const T *val, const int *indx,
-    const int *pntrb, const int *pntre, const T *b, const int *ldb,
-    const T *beta, T *c, const int *ldc) const {
-  CBlas<T>::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b,
-                  ldb, beta, c, ldc);
-}
-template <>
-template <typename T>
-void Blas<pten::CPUContext>::CSRMM(const char *transa, const int *m,
-                                   const int *n, const int *k, const T *alpha,
-                                   const char *matdescra, const T *val,
-                                   const int *indx, const int *pntrb,
-                                   const int *pntre, const T *b, const int *ldb,
-                                   const T *beta, T *c, const int *ldc) const {
-  CBlas<T>::CSRMM(transa, m, n, k, alpha, matdescra, val, indx, pntrb, pntre, b,
-                  ldb, beta, c, ldc);
-}
-#endif
-
-template <>
-template <typename T>
-void Blas<platform::CPUDeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
-                                            CBLAS_TRANSPOSE transA,
-                                            CBLAS_DIAG diag, int M, int N,
-                                            T alpha, const T *A, int lda, T *B,
-                                            int ldb) const {
-  CBlas<T>::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda,
-                 B, ldb);
-}
-template <>
-template <typename T>
-void Blas<pten::CPUContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
-                                  CBLAS_TRANSPOSE transA, CBLAS_DIAG diag,
-                                  int M, int N, T alpha, const T *A, int lda,
-                                  T *B, int ldb) const {
-  CBlas<T>::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda,
-                 B, ldb);
-}
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle

From 13134f0f013e8c11fe7efce932c2f863ac566b37 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Fri, 18 Feb 2022 12:52:53 +0000
Subject: [PATCH 5/9] add scale unittest

---
 .../fluid/tests/unittests/test_scale_op.py   | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/tests/unittests/test_scale_op.py b/python/paddle/fluid/tests/unittests/test_scale_op.py
index c1ce032f50612..2bdd8802ef4fc 100644
--- a/python/paddle/fluid/tests/unittests/test_scale_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scale_op.py
@@ -16,7 +16,7 @@
 
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
@@ -153,6 +153,23 @@ def test_check_grad(self):
             place, ["X"], "Out", max_relative_error=0.05)
 
 
+class TestScaleBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "scale"
+        self.dtype = np.uint16
+        self.attrs = {'scale': -2.3}
+        x = np.random.random((10, 10)).astype(np.float32)
+        out = x * np.float32(self.attrs['scale'])
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestScaleFp16OpSelectedRows(TestScaleOpSelectedRows):
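TestScaleBF16Op above builds both the input and the expected output in float32 and converts them with convert_float_to_uint16 from op_test. Roughly, bfloat16 is the upper half of a float32 bit pattern, so the conversion keeps the high 16 bits. The sketch below is an illustrative reimplementation using plain truncation (the real helper also rounds), not the helper the tests import:

# Minimal bf16 <-> uint16 round-trip, mirroring how the test data is prepared.
import numpy as np

def float32_to_bf16_bits(x):
    # Reinterpret float32 as uint32 and keep the high 16 bits (truncation).
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(b):
    # Put the 16 bits back into the high half of a float32 word.
    return (b.astype(np.uint32) << 16).view(np.float32)

x = np.random.random((10, 10)).astype(np.float32)
xb = float32_to_bf16_bits(x)
# Round-trip error stays within bfloat16's coarse precision.
assert np.allclose(bf16_bits_to_float32(xb), x, atol=0.01)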
From cd5d04dae4b8520868c9299e2bcb42d3c8993526 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Fri, 18 Feb 2022 12:53:18 +0000
Subject: [PATCH 6/9] add sum unittest

---
 .../fluid/tests/unittests/test_sum_op.py     | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py
index eddccd4ff24f1..85880ef171e21 100644
--- a/python/paddle/fluid/tests/unittests/test_sum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sum_op.py
@@ -298,6 +298,32 @@ def test_w_is_selected_rows(self):
     globals()[cls_name] = TestSumFp16Case
 
 
+#----------- test bf16 -----------
+class TestSumBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "sum"
+        self.init_kernel_type()
+        x0 = np.random.random((3, 40)).astype(np.float32)
+        x1 = np.random.random((3, 40)).astype(np.float32)
+        x2 = np.random.random((3, 40)).astype(np.float32)
+        y = x0 + x1 + x2
+        self.inputs = {
+            "X": [("x0", convert_float_to_uint16(x0)),
+                  ("x1", convert_float_to_uint16(x1)),
+                  ("x2", convert_float_to_uint16(x2))]
+        }
+        self.outputs = {'Out': convert_float_to_uint16(y)}
+
+    def init_kernel_type(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['x0'], 'Out')
+
+
 class API_Test_Add_n(unittest.TestCase):
     def test_api(self):
         with fluid.program_guard(fluid.Program(), fluid.Program()):

From 21d6652edbb450387564d6431cc74d5fc9567955 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Fri, 18 Feb 2022 13:32:43 +0000
Subject: [PATCH 7/9] solve conflict

---
 paddle/pten/kernels/funcs/blas/blas_impl.h | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/paddle/pten/kernels/funcs/blas/blas_impl.h b/paddle/pten/kernels/funcs/blas/blas_impl.h
index 5c93011ab500a..8ac9b2e64337f 100644
--- a/paddle/pten/kernels/funcs/blas/blas_impl.h
+++ b/paddle/pten/kernels/funcs/blas/blas_impl.h
@@ -76,6 +76,16 @@ struct CBlas {
         "Blas VCOPY do not supported on CPU with bfloat16,"
         " please check your code"));
   }
+
+  template <typename... ARGS>
+  static void VADD(int n,
+                   const pten::dtype::bfloat16 *x,
+                   const pten::dtype::bfloat16 *y,
+                   pten::dtype::bfloat16 *z) {
+    for (int i = 0; i < n; ++i) {
+      z[i] = x[i] + y[i];
+    }
+  }
 };
 
 #ifdef PADDLE_WITH_MKLML

From 433eb1cc905c9a98f9c973c6a1c8f5911563cc6c Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Mon, 21 Feb 2022 06:09:22 +0000
Subject: [PATCH 8/9] refine gather unittest

---
 python/paddle/fluid/tests/unittests/test_gather_op.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py
index 73414f9a2f573..1a451eaf2a99a 100644
--- a/python/paddle/fluid/tests/unittests/test_gather_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gather_op.py
@@ -130,12 +130,10 @@ def setUp(self):
         self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
 
     def test_check_output(self):
-        place = core.CUDAPlace(0)
-        self.check_output_with_place(place)
+        self.check_output()
 
     def test_check_grad(self):
-        place = core.CUDAPlace(0)
-        self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=0.5)
+        self.check_grad(['X'], 'Out', numeric_grad_delta=0.5)
 
     def config(self):
         """
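The bf16 tests in this series use deliberately loose gradient settings (numeric_grad_delta=0.5 above, max_relative_error=0.05 in the low-precision scale tests). A plausible reason, though not stated in the patches, is bfloat16's coarse resolution: with only 7 explicit mantissa bits the relative spacing between neighbouring values is about 2**-7, so a tiny finite-difference step would vanish entirely. A small sketch using the same truncating round-trip as before:

# Perturbations far below bf16's spacing are lost after the round-trip.
import numpy as np

def to_bf16_and_back(v):
    x = np.asarray([v], dtype=np.float32)
    return float(((x.view(np.uint32) >> 16) << 16).view(np.float32)[0])

print(to_bf16_and_back(1.0 + 1e-3))   # 1.0       -- the 1e-3 step disappears
print(to_bf16_and_back(1.0 + 1e-2))   # 1.0078125 -- resolvable, but coarse
# Hence the coarse numeric_grad_delta and relaxed error tolerances.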
From 4d39865b124ace47dc7e5194879ecdba7a078075 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 1 Mar 2022 02:28:07 +0000
Subject: [PATCH 9/9] refine unittest

---
 .../paddle/fluid/tests/unittests/test_gather_op.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py
index 1a451eaf2a99a..978a3d86d882a 100644
--- a/python/paddle/fluid/tests/unittests/test_gather_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gather_op.py
@@ -123,11 +123,15 @@ def setUp(self):
         self.dtype = np.uint16
         self.config()
         xnp = np.random.random(self.x_shape).astype(np.float32)
+        axis_np = np.array(self.axis).astype(self.axis_type)
+        index_np = np.array(self.index).astype(self.index_type)
         self.inputs = {
             'X': convert_float_to_uint16(xnp),
-            'Index': np.array(self.index).astype(self.index_type)
+            'Index': index_np,
+            'Axis': axis_np
         }
-        self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
+        out = gather_numpy(self.inputs['X'], index_np, axis_np[0])
+        self.outputs = {'Out': out}
 
     def test_check_output(self):
         self.check_output()
@@ -139,9 +143,11 @@ def config(self):
         """
         For multi-dimension input
         """
-        self.x_shape = (10, 20)
+        self.x_shape = (3, 88, 3)
         self.index = [1, 3, 5]
         self.index_type = "int32"
+        self.axis = [1]
+        self.axis_type = "int32"
 
 
 class TestGatherOp1(OpTest):