Skip to content

Commit

Permalink
[bf16] pten matmul cuda kernel support bf16 (#39485)
Browse files Browse the repository at this point in the history
* pten matmul cuda kernel support bf16

* fix pten kernel name

* add matmul_grad bf16 kernel

* add emptylike bf16 kernel

* fix compile

* suppport rocm

* fix error

* fix rocm

* add bf16 header file

* fix compile
  • Loading branch information
zhiqiu authored Feb 16, 2022
1 parent f31c242 commit d5a0d31
Show file tree
Hide file tree
Showing 12 changed files with 659 additions and 101 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/framework/pten_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,9 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
}

KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(op_proto_->type(), GetInputArgsNames(),
GetAttrsArgsNames(), GetOutputArgsNames());
return KernelSignature(pten::TransToPtenKernelName(op_proto_->type()),
GetInputArgsNames(), GetAttrsArgsNames(),
GetOutputArgsNames());
}

std::once_flag kernel_sig_map_init_flag;
Expand Down
243 changes: 243 additions & 0 deletions paddle/fluid/operators/math/blas_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,102 @@ inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 8000
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta,
platform::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"cublas fp16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A,
CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on cuda <= 11"));

#endif // CUDA_VERSION >= 11000
}

template <>
template <>
inline void Blas<pten::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, int N,
int K, platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 80,
platform::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb, A,
CUDA_R_16BF, lda, &h_beta, C, CUDA_R_16BF, N, CUDA_R_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on cuda <= 11"));

#endif // CUDA_VERSION >= 11000
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
Expand Down Expand Up @@ -1208,6 +1304,42 @@ inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
}
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMV(
bool trans_a, int M, int N, platform::bfloat16 alpha,
const platform::bfloat16 *A, const platform::bfloat16 *B,
platform::bfloat16 beta, platform::bfloat16 *C) const {
// Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve
// it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}

template <>
template <>
inline void Blas<pten::GPUContext>::GEMV(bool trans_a, int M, int N,
platform::bfloat16 alpha,
const platform::bfloat16 *A,
const platform::bfloat16 *B,
platform::bfloat16 beta,
platform::bfloat16 *C) const {
// Because cublas doesn't support bfloat gemv, we use cublasHgemm to achieve
// it.
if (trans_a) {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::bfloat16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
Expand Down Expand Up @@ -1306,6 +1438,91 @@ void Blas<pten::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 9010
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb,
strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc,
strideC, batchCount, CUBLAS_COMPUTE_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= "
"11"));
#endif // CUDA_VERSION >= 11000
}

template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 *A,
const platform::bfloat16 *B, platform::bfloat16 beta, platform::bfloat16 *C,
int batchCount, int64_t strideA, int64_t strideB) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = M * N;

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasGemmStridedBatchedEx(
handle, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_16BF, ldb,
strideB, A, CUDA_R_16BF, lda, strideA, &h_beta, C, CUDA_R_16BF, ldc,
strideC, batchCount, CUBLAS_COMPUTE_32F, algo));
});
#else
// raise error
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= "
"11"));
#endif // CUDA_VERSION >= 11000
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
Expand Down Expand Up @@ -1356,6 +1573,32 @@ inline void Blas<pten::GPUContext>::BatchedGEMM(
}
}

template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}

template <>
template <>
inline void Blas<pten::GPUContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::bfloat16 alpha, const platform::bfloat16 **A,
const platform::bfloat16 **B, platform::bfloat16 beta,
platform::bfloat16 **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::bfloat16>(transA, transB, M, N, K, alpha,
A[k], B[k], beta, C[k]);
}
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
Expand Down
Loading

0 comments on commit d5a0d31

Please sign in to comment.