From 335b9a0876a5fa48577e9f568d82ca8faf396c00 Mon Sep 17 00:00:00 2001
From: sxjscience
Date: Mon, 11 Jul 2016 15:58:48 +0800
Subject: [PATCH] Accelerate batched_gemm in GPU using CuBLAS version

Fix lint
Adding workspace
Fix build error
Fix build error on GPU
Fix comment
Update comment
---
 mshadow/cuda/tensor_gpu-inl.cuh |  21 +++++
 mshadow/dot_engine-inl.h        | 152 ++++++++++++++++----------------
 mshadow/tensor.h                |  19 +++-
 mshadow/tensor_cpu-inl.h        |  44 +++++++++
 4 files changed, 156 insertions(+), 80 deletions(-)

diff --git a/mshadow/cuda/tensor_gpu-inl.cuh b/mshadow/cuda/tensor_gpu-inl.cuh
index 8ca577d437d3..ffee3c1b77dd 100644
--- a/mshadow/cuda/tensor_gpu-inl.cuh
+++ b/mshadow/cuda/tensor_gpu-inl.cuh
@@ -194,6 +194,27 @@ inline void MapReduceKeepDim1(expr::Plan<DstExp, DType> dst,
       <<<dimGrid, dimBlock, 0, stream_>>>(dst, plan, scale, pshape);
 }
 
+template<int x_bits, typename DType>
+__global__ void GetBatchedViewKernel(DType **dst, DType *src, int num, int stride) {
+  const int x_size = 1 << x_bits;
+  const int start = threadIdx.x;
+  // Copy the addresses of src to dst every stride steps
+  for (int i = start; i < num; i += x_size) {
+    dst[i] = src + i * stride;
+  }
+}
+
+template<typename DType>
+inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
+                           Stream<gpu> *stream) {
+  cudaStream_t stream_ = Stream<gpu>::GetStream(stream);
+  dim3 dimBlock(kBaseThreadNum);
+  dim3 dimGrid(1);
+  CheckLaunchParam(dimGrid, dimBlock, "GetBatchedView");
+  GetBatchedViewKernel<kBaseThreadBits, DType>
+    <<<dimGrid, dimBlock, 0, stream_>>> (dst, src, num, stride);
+}
+
 template<int x_bits, typename DstPlan, typename SrcPlan1, typename SrcPlan2>
 __global__ void SoftmaxGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label, index_t xmax) {
   const unsigned x_size = 1 << x_bits;
diff --git a/mshadow/dot_engine-inl.h b/mshadow/dot_engine-inl.h
index 01def25f3d41..1fbeaaf034d1 100644
--- a/mshadow/dot_engine-inl.h
+++ b/mshadow/dot_engine-inl.h
@@ -10,7 +10,37 @@
 #include "./base.h"
 #include "./extension/implicit_gemm.h"
 
+#ifdef __CUDACC__
+#include "./cuda/tensor_gpu-inl.cuh"
+#endif  // #ifdef __CUDACC__
+
 namespace mshadow {
+ /*!
+* \brief CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride
+* \param dst 2D pointer
+* \param src 1D pointer
+* \param num number of batches
+* \param stride size of each batch
+* \param stream
+*/
+template<typename Device, typename DType>
+inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
+                           Stream<Device> *stream);
+template<typename DType>
+inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
+                           Stream<cpu> *stream) {
+  for (int i = 0; i < num; i++) {
+    dst[i] = src + i * stride;
+  }
+}
+#ifdef __CUDACC__
+template<typename DType>
+inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
+                           Stream<gpu> *stream) {
+  cuda::GetBatchedView(dst, src, num, stride, stream);
+}
+#endif  // #ifdef __CUDACC__
+
 namespace expr {
 //---------------------------------------------------------------------
 // Matrix Multiplications, depends on BLAS Engine
@@ -42,7 +72,8 @@ struct BLASEngine {
                                   bool transa, bool transb,
                                   int m, int n, int k, DType alpha,
                                   const DType *A, int lda, const DType *B, int ldb,
-                                  DType beta, DType *C, int ldc, int batch_count) {
+                                  DType beta, DType *C, int ldc, int batch_count,
+                                  DType **workspace) {
     LOG(FATAL) << "Not implmented!";
   }
   inline static void gemv(Stream<Device> *stream,
@@ -116,7 +147,8 @@ struct BLASEngine<cpu, float> {
                                   bool transa, bool transb,
                                   int m, int n, int k, float alpha,
                                   const float *A, int lda, const float *B, int ldb,
-                                  float beta, float *C, int ldc, int batch_count) {
+                                  float beta, float *C, int ldc, int batch_count,
+                                  float **workspace) {
     for (int i = 0; i < batch_count; ++i) {
       gemm(stream, transa, transb, m, n, k, alpha,
            A + i * m * k, lda, B + i * k * n, ldb,
            beta, C + i * m * n, ldc);
     }
@@ -193,7 +225,8 @@ struct BLASEngine<cpu, double> {
                                   bool transa, bool transb,
                                   int m, int n, int k, double alpha,
                                   const double *A, int lda, const double *B, int ldb,
-                                  double beta, double *C, int ldc, int batch_count) {
+                                  double beta, double *C, int ldc, int batch_count,
+                                  double **workspace) {
     for (int i = 0; i < batch_count; ++i) {
       gemm(stream, transa, transb, m, n, k, alpha,
            A + i * m * k, lda, B + i * k * n, ldb,
            beta, C + i * m * n, ldc);
     }
@@ -255,7 +288,8 @@ struct BLASEngine<cpu, float> {
                                   bool transa, bool transb,
                                   int m, int n, int k, float alpha,
                                   const float *A, int lda, const float *B, int ldb,
-                                  float beta, float *C, int ldc, int batch_count) {
+                                  float beta, float *C, int ldc, int batch_count,
+                                  float **workspace) {
     for (int i = 0; i < batch_count; ++i) {
       gemm(stream, transa, transb, m, n, k, alpha,
            A + i * m * k, lda, B + i * k * n, ldb,
            beta, C + i * m * n, ldc);
     }
@@ -324,7 +358,8 @@ struct BLASEngine<cpu, double> {
                                   bool transa, bool transb,
                                   int m, int n, int k, double alpha,
                                   const double *A, int lda, const double *B, int ldb,
-                                  double beta, double *C, int ldc, int batch_count) {
+                                  double beta, double *C, int ldc, int batch_count,
+                                  double **workspace) {
     for (int i = 0; i < batch_count; ++i) {
       gemm(stream, transa, transb, m, n, k, alpha,
            A + i * m * k, lda, B + i * k * n, ldb,
            beta, C + i * m * n, ldc);
     }
@@ -424,7 +459,8 @@ struct BLASEngine<gpu, half::half_t> {
                                   bool transa, bool transb,
                                   int m, int n, int k, half::half_t alpha,
                                   const half::half_t *A, int lda, const half::half_t *B, int ldb,
-                                  half::half_t beta, half::half_t *C, int ldc, int batch_count) {
+                                  half::half_t beta, half::half_t *C, int ldc, int batch_count,
+                                  half::half_t **workspace) {
     for (int i = 0; i < batch_count; ++i) {
       gemm(stream, transa, transb, m, n, k, alpha,
            A + i * m * k, lda, B + i * k * n, ldb,
            beta, C + i * m * n, ldc);
     }
@@ -491,12 +527,27 @@ struct BLASEngine<gpu, float> {
                                   bool transa, bool transb,
                                   int m, int n, int k, float alpha,
                                   const float *A, int lda, const float *B, int ldb,
-                                  float beta, float *C, int ldc, int batch_count) {
+                                  float beta, float *C, int ldc, int batch_count,
+                                  float **workspace) {
+#if defined(__CUDACC__) && CUDA_VERSION >= 4010
+    // Cast DType* to DType** using workspace as a buffer
+    GetBatchedView(workspace, const_cast<float*>(A), batch_count, m * k, stream);
+    GetBatchedView(workspace + batch_count,
+                   const_cast<float*>(B), batch_count, k * n, stream);
+    GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
+    cublasStatus_t err = cublasSgemmBatched(Stream<gpu>::GetBlasHandle(stream),
+                GetT(transa), GetT(transb), m, n, k, &alpha,
+                (const float**)workspace, lda,
+                (const float**)(workspace + batch_count), ldb,
+                &beta, workspace + 2 * batch_count, ldc, batch_count);
+    CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmBatched fail";
+#else
     for (int i = 0; i < batch_count; ++i) {
       gemm(stream, transa, transb, m, n, k, alpha,
            A + i * m * k, lda, B + i * k * n, ldb,
            beta, C + i * m * n, ldc);
     }
+#endif  // defined(__CUDACC__) && CUDA_VERSION >= 4010
   }
   inline static void gemv(Stream<gpu> *stream,
                           bool trans, int m, int n, float alpha,
@@ -575,12 +626,27 @@ struct BLASEngine<gpu, double> {
                                   bool transa, bool transb,
                                   int m, int n, int k, double alpha,
                                   const double *A, int lda, const double *B, int ldb,
-                                  double beta, double *C, int ldc, int batch_count) {
+                                  double beta, double *C, int ldc, int batch_count,
+                                  double **workspace) {
+#if defined(__CUDACC__) && CUDA_VERSION >= 4010
+    // Cast DType* to DType** using workspace as a buffer
+    GetBatchedView(workspace, const_cast<double*>(A), batch_count, m * k, stream);
+    GetBatchedView(workspace + batch_count,
+                   const_cast<double*>(B), batch_count, k * n, stream);
+    GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
+    cublasStatus_t err = cublasDgemmBatched(Stream<gpu>::GetBlasHandle(stream),
+                GetT(transa), GetT(transb), m, n, k, &alpha,
+                (const double**)workspace, lda,
+                (const double**)(workspace + batch_count), ldb,
+                &beta, workspace + 2 * batch_count, ldc, batch_count);
+    CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmBatched fail";
+#else
     for (int i = 0; i < batch_count; ++i) {
       gemm(stream, transa, transb, m, n, k, alpha,
            A + i * m * k, lda, B + i * k * n, ldb,
            beta, C + i * m * n, ldc);
     }
+#endif  // defined(__CUDACC__) && CUDA_VERSION >= 4010
   }
   inline static void gemv(Stream<gpu> *stream,
                           bool trans, int m, int n, double alpha,
@@ -638,9 +704,6 @@ struct BLASEngine<gpu, double> {
 inline static Shape<2> GetShape(const Shape<2> &shape, bool transpose) {
   return transpose ? Shape2(shape[1], shape[0]) : shape;
 }
-inline static Shape<3> GetBatchedShape(const Shape<3> &shape, bool transpose) {
-  return transpose ? Shape3(shape[0], shape[2], shape[1]) : shape;
-}
 // dst = dot(lhs[.T], rhs[.T])
 template<typename SV, typename xpu,
          bool transpose_left, bool transpose_right, typename DType>
@@ -732,73 +795,6 @@ struct DotEngine<SV, xpu, 2, 2, 2, transpose_left, transpose_right, DType> {
     }
   }
 };
-// dst = batched_dot(lhs[.T], rhs[.T])
-template<typename SV, typename xpu,
-         bool transpose_left, bool transpose_right, typename DType>
-struct DotEngine<SV, xpu, 3, 3, 3, transpose_left, transpose_right, DType> {
-  inline static void Eval(Tensor<xpu, 3, DType> *p_dst,
-                          const Tensor<xpu, 3, DType> &lhs,
-                          const Tensor<xpu, 3, DType> &rhs,
-                          DType scale) {
-    Tensor<xpu, 3, DType> &dst = *p_dst;
-    // set kernel stream
-    // if there is no stream, crush
-    BLASEngine<xpu, DType>::SetStream(dst.stream_);
-    Shape<3> sleft = GetBatchedShape(lhs.shape_, transpose_left);
-    Shape<3> sright = GetBatchedShape(rhs.shape_, transpose_right);
-    CHECK(dst.size(0) == sleft[0] && dst.size(0) == sright[0])
-      << "batch_dot-gemm: batchsize must be equal."
- << "dst: " << dst.shape_ << "\n" - << "lhs: " << sleft << "\n" - << "rhs: " << sright << "\n"; - CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1]) - << "batch_dot-gemm: matrix shape mismatch" - << "dst: " << dst.shape_ << "\n" - << "lhs: " << sleft << "\n" - << "rhs: " << sright << "\n"; - // use column major argument to compatible with most BLAS - if (sleft[1] == 1) { - // For (batch, 1, K) gemm (batch, K, N), we can use (batch, N, K) gemv (batch, K) - BLASEngine::batched_gemv - (dst.stream_, - transpose_right, - rhs.size(2), rhs.size(1), scale * SV::AlphaBLAS(), - rhs.dptr_, rhs.stride_, - lhs.dptr_, 1, SV::BetaBLAS(), - dst.dptr_, 1, dst.size(0)); - } else if (sleft[2] == 1 && (SV::BetaBLAS() == 0.0f || SV::BetaBLAS() == 1.0f)) { - // For (batch, M, 1) gemm (batch, 1, N) + Beta = 0, we can use (batch, M) ger (batch, N) - if (SV::BetaBLAS() == 0.0f) { - dst = DType(0); - } - BLASEngine::batched_ger - (dst.stream_, sright[2], sleft[1], scale * SV::AlphaBLAS(), - rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_, dst.size(0)); - } else if (sright[2] == 1) { - // For (batch, M, K) gemm (batch, K, 1), we can use (batch, M, K) gemv (batch, K) - BLASEngine::batched_gemv - (dst.stream_, - !transpose_left, - lhs.size(2), lhs.size(1), scale * SV::AlphaBLAS(), - lhs.dptr_, lhs.stride_, - rhs.dptr_, 1, SV::BetaBLAS(), - dst.dptr_, 1, dst.size(0)); - } else { - // For general case, use gemm - BLASEngine::batched_gemm - (dst.stream_, - transpose_right, transpose_left, - transpose_right ? rhs.size(1) : rhs.size(2), - transpose_left ? lhs.size(2) : lhs.size(1), - transpose_right ? rhs.size(2) : rhs.size(1), - DType(scale * SV::AlphaBLAS()), - rhs.dptr_, rhs.stride_, - lhs.dptr_, lhs.stride_, - DType(SV::BetaBLAS()), - dst.dptr_, dst.stride_, dst.size(0)); - } - } -}; } // namespace expr } // namespace mshadow #endif // MSHADOW_DOT_ENGINE_INL_H_ diff --git a/mshadow/tensor.h b/mshadow/tensor.h index 680a0be8c9c7..ac094dff8f81 100644 --- a/mshadow/tensor.h +++ b/mshadow/tensor.h @@ -855,17 +855,32 @@ template *dst, const expr::Exp &exp, DType scale = 1); - /*! * \brief CPU/GPU: 1 dimension vector dot * \param dst Length 1 vector, used to hold the result. * \param lhs Left operand vector - * \param rhs right operand vector + * \param rhs Right operand vector */ template inline void VectorDot(Tensor dst, const Tensor &lhs, const Tensor &rhs); +/*! 
+ * \brief CPU/GPU: dst = alpha * op(lhs) op(rhs) + beta * dst
+ * \param dst Length 3 tensor, used to hold the result
+ * \param lhs Left operand vector
+ * \param rhs Right operand vector
+ * \param alpha multiplier of op(lhs)op(rhs)
+ * \param beta multiplier of dst
+ * \param workspace Workspace for casting DType* to DType** (batched-view), must have size >= 3 * batch_size
+ */
+template<bool transpose_left, bool transpose_right, typename Device, typename DType>
+inline void BatchGEMM(Tensor<Device, 3, DType> dst,
+                      const Tensor<Device, 3, DType> &lhs,
+                      const Tensor<Device, 3, DType> &rhs,
+                      DType alpha,
+                      DType beta,
+                      Tensor<Device, 1, DType*> workspace);
 }  // namespace mshadow
 // include headers
 #include "./stream_gpu-inl.h"
diff --git a/mshadow/tensor_cpu-inl.h b/mshadow/tensor_cpu-inl.h
index aa68d97f56c7..703573b898d8 100644
--- a/mshadow/tensor_cpu-inl.h
+++ b/mshadow/tensor_cpu-inl.h
@@ -449,5 +449,49 @@ inline void VectorDot(Tensor<Device, 1, DType> dst,
   mshadow::expr::BLASEngine<Device, DType>::dot(
       lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_);
 }
+
+template<bool transpose_left, bool transpose_right, typename Device, typename DType>
+inline void BatchGEMM(Tensor<Device, 3, DType> dst,
+                      const Tensor<Device, 3, DType> &lhs,
+                      const Tensor<Device, 3, DType> &rhs,
+                      DType alpha,
+                      DType beta,
+                      Tensor<Device, 1, DType*> workspace) {
+  int batch_size = dst.shape_[0];
+  expr::BLASEngine<Device, DType>::SetStream(dst.stream_);
+  Shape<3> sleft = transpose_left ? Shape3(lhs.shape_[0], lhs.shape_[2], lhs.shape_[1])
+                                  : lhs.shape_;
+  Shape<3> sright = transpose_right ? Shape3(rhs.shape_[0], rhs.shape_[2], rhs.shape_[1])
+                                    : rhs.shape_;
+  CHECK_EQ(dst.CheckContiguous(), true);
+  CHECK_EQ(lhs.CheckContiguous(), true);
+  CHECK_EQ(rhs.CheckContiguous(), true);
+  CHECK(sleft[0] == batch_size && sright[0] == batch_size)
+    << "BatchGEMM: batchsize must be equal."
+    << "dst: " << dst.shape_ << "\n"
+    << "lhs: " << sleft << "\n"
+    << "rhs: " << sright << "\n";
+  CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1])
+    << "BatchGEMM: matrix shape mismatch"
+    << "dst: " << dst.shape_ << "\n"
+    << "lhs: " << sleft << "\n"
+    << "rhs: " << sright << "\n";
+  CHECK(workspace.size(0) >= 3 * batch_size)
+    << "Workspace Size must be bigger than " << 3 * batch_size;
+  CHECK_EQ(workspace.CheckContiguous(), true);
+  // use column major argument to compatible with most BLAS
+  expr::BLASEngine<Device, DType>::batched_gemm
+      (dst.stream_,
+       transpose_right, transpose_left,
+       transpose_right ? rhs.size(1) : rhs.size(2),
+       transpose_left ? lhs.size(2) : lhs.size(1),
+       transpose_right ? rhs.size(2) : rhs.size(1),
+       alpha,
+       rhs.dptr_, rhs.stride_,
+       lhs.dptr_, lhs.stride_,
+       beta,
+       dst.dptr_, dst.stride_, batch_size,
+       workspace.dptr_);
+}
 }  // namespace mshadow
 #endif  // MSHADOW_TENSOR_CPU_INL_H_
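
Usage note (a minimal sketch, not part of the patch): the new BatchGEMM takes a 1-D workspace of DType* with at least 3 * batch_size entries. On GPU the workspace is filled on-device with per-batch pointers by GetBatchedView and then handed to cublasSgemmBatched / cublasDgemmBatched, which consume arrays of matrix pointers; this avoids a host-to-device pointer-table copy. The example below is hypothetical (tensor names and sizes are made up) and assumes a Stream<gpu> created with a cuBLAS handle (e.g. via NewStream<gpu>) and compilation with nvcc so the CUDA_VERSION >= 4010 path is taken.

    // Hypothetical example: dst[i] = 1.0f * lhs[i] * rhs[i] for every batch i.
    #include "mshadow/tensor.h"

    void BatchedMatMulExample(mshadow::Stream<mshadow::gpu> *stream) {
      using namespace mshadow;
      const int batch = 8, m = 32, k = 16, n = 24;  // made-up sizes
      Tensor<gpu, 3, float> lhs(Shape3(batch, m, k));
      Tensor<gpu, 3, float> rhs(Shape3(batch, k, n));
      Tensor<gpu, 3, float> dst(Shape3(batch, m, n));
      // Workspace holds 3 * batch device pointers (batched views of lhs, rhs and dst).
      Tensor<gpu, 1, float*> workspace(Shape1(3 * batch));
      lhs.set_stream(stream); rhs.set_stream(stream);
      dst.set_stream(stream); workspace.set_stream(stream);
      // Allocate without padding: BatchGEMM checks that all tensors are contiguous.
      AllocSpace(&lhs, false); AllocSpace(&rhs, false);
      AllocSpace(&dst, false); AllocSpace(&workspace, false);
      // ... fill lhs and rhs ...
      // dst = 1.0f * lhs * rhs + 0.0f * dst, neither operand transposed.
      BatchGEMM<false, false>(dst, lhs, rhs, 1.0f, 0.0f, workspace);
      FreeSpace(&lhs); FreeSpace(&rhs); FreeSpace(&dst); FreeSpace(&workspace);
    }

On CPU (or with CUDA older than 4010) the same call falls back to the per-batch gemm loop, so the workspace is simply ignored there but must still be passed.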