Merge pull request apache#150 from sxjscience/accelerate_batched_gemm
Accelerate batched_gemm in GPU using CuBLAS version
sxjscience committed Jul 24, 2016
2 parents 76c12ef + 335b9a0 commit 44d61f8
Showing 4 changed files with 156 additions and 80 deletions.
21 changes: 21 additions & 0 deletions mshadow/cuda/tensor_gpu-inl.cuh
@@ -194,6 +194,27 @@ inline void MapReduceKeepDim1(expr::Plan<DstExp, DType> dst,
<<<dimGrid, dimBlock, 0, stream>>>(dst, plan, scale, pshape);
}

template<int x_bits, typename DType>
__global__ void GetBatchedViewKernel(DType **dst, DType *src, int num, int stride) {
const int x_size = 1 << x_bits;
const int start = threadIdx.x;
// Fill dst with pointers into src: dst[i] = src + i * stride
for (int i = start; i < num; i += x_size) {
dst[i] = src + i * stride;
}
}

template<typename DType>
inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
Stream<gpu> *stream) {
cudaStream_t stream_ = Stream<gpu>::GetStream(stream);
dim3 dimBlock(kBaseThreadNum);
dim3 dimGrid(1);
CheckLaunchParam(dimGrid, dimBlock, "GetBatchedView");
GetBatchedViewKernel<kBaseThreadBits, DType>
<<<dimGrid, dimBlock, 0, stream_>>> (dst, src, num, stride);
}

template<int x_bits, typename DType, typename DstPlan, typename SrcPlan1, typename SrcPlan2>
__global__ void SoftmaxGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label, index_t xmax) {
const unsigned x_size = 1 << x_bits;
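Note on the kernel above: cublasSgemmBatched / cublasDgemmBatched take arrays of device pointers, one entry per matrix in the batch, rather than a single contiguous buffer, so GetBatchedView materializes that pointer array on the device (dst[i] = src + i * stride). A minimal stand-alone sketch of the same idea, outside mshadow's Stream and thread-count machinery (FillBatchedView and the buffer names are illustrative, not part of this commit):

#include <cuda_runtime.h>

// Sketch only: build a device-side array of pointers into a contiguous buffer
// holding batch_count matrices of `stride` elements each.
__global__ void FillBatchedView(float **dst, float *src, int num, int stride) {
  for (int i = threadIdx.x; i < num; i += blockDim.x) {
    dst[i] = src + i * stride;  // address of the i-th matrix
  }
}

void BuildBatchedView(float *d_data, float **d_ptrs, int batch_count, int stride) {
  FillBatchedView<<<1, 256>>>(d_ptrs, d_data, batch_count, stride);
  // d_ptrs can now be passed as the A/B/C pointer-array argument of
  // cublasSgemmBatched once the kernel has completed on this stream.
}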
152 changes: 74 additions & 78 deletions mshadow/dot_engine-inl.h
@@ -10,7 +10,37 @@
#include "./base.h"
#include "./extension/implicit_gemm.h"

#ifdef __CUDACC__
#include "./cuda/tensor_gpu-inl.cuh"
#endif // #ifdef __CUDACC__

namespace mshadow {
/*!
* \brief CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride
* \param dst 2D pointer
* \param src 1D pointer
* \param num number of batches
* \param stride size of each batch
 * \param stream stream used to launch the GPU copy kernel (unused by the CPU version)
*/
template<typename Device, typename DType>
inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
Stream<Device> *stream);
template<typename DType>
inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
Stream<cpu> *stream) {
for (int i = 0; i < num; i++) {
dst[i] = src + i * stride;
}
}
#ifdef __CUDACC__
template<typename DType>
inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
Stream<gpu> *stream) {
cuda::GetBatchedView(dst, src, num, stride, stream);
}
#endif // #ifdef __CUDACC__

namespace expr {
//---------------------------------------------------------------------
// Matrix Multiplications, depends on BLAS Engine
@@ -42,7 +72,8 @@ struct BLASEngine {
bool transa, bool transb,
int m, int n, int k, DType alpha,
const DType *A, int lda, const DType *B, int ldb,
DType beta, DType *C, int ldc, int batch_count) {
DType beta, DType *C, int ldc, int batch_count,
DType **workspace) {
LOG(FATAL) << "Not implemented!";
}
inline static void gemv(Stream<Device> *stream,
@@ -116,7 +147,8 @@ struct BLASEngine<cpu, float> {
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count) {
float beta, float *C, int ldc, int batch_count,
float **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
@@ -193,7 +225,8 @@ struct BLASEngine<cpu, double> {
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count) {
double beta, double *C, int ldc, int batch_count,
double **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
@@ -255,7 +288,8 @@ struct BLASEngine<cpu, float> {
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count) {
float beta, float *C, int ldc, int batch_count,
float **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
@@ -324,7 +358,8 @@ struct BLASEngine<cpu, double> {
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count) {
double beta, double *C, int ldc, int batch_count,
double **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
@@ -424,7 +459,8 @@ struct BLASEngine<gpu, half::half_t> {
bool transa, bool transb,
int m, int n, int k, half::half_t alpha,
const half::half_t *A, int lda, const half::half_t *B, int ldb,
half::half_t beta, half::half_t *C, int ldc, int batch_count) {
half::half_t beta, half::half_t *C, int ldc, int batch_count,
half::half_t **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
@@ -491,12 +527,27 @@ struct BLASEngine<gpu, float> {
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count) {
float beta, float *C, int ldc, int batch_count,
float **workspace) {
#if defined(__CUDACC__) && CUDA_VERSION >= 4010
// Cast DType* to DType** using workspace as a buffer
GetBatchedView(workspace, const_cast<float*>(A), batch_count, m * k, stream);
GetBatchedView(workspace + batch_count,
const_cast<float*>(B), batch_count, k * n, stream);
GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
cublasStatus_t err = cublasSgemmBatched(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha,
(const float**)workspace, lda,
(const float**)(workspace + batch_count), ldb,
&beta, workspace + 2 * batch_count, ldc, batch_count);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmBatched fail";
#else
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
}
inline static void gemv(Stream<gpu> *stream,
bool trans, int m, int n, float alpha,
@@ -575,12 +626,27 @@ struct BLASEngine<gpu, double> {
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count) {
double beta, double *C, int ldc, int batch_count,
double **workspace) {
#if defined(__CUDACC__) && CUDA_VERSION >= 4010
// Cast DType* to DType** using workspace as a buffer
GetBatchedView(workspace, const_cast<double*>(A), batch_count, m * k, stream);
GetBatchedView(workspace + batch_count,
const_cast<double*>(B), batch_count, k * n, stream);
GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
cublasStatus_t err = cublasDgemmBatched(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha,
(const double**)workspace, lda,
(const double**)(workspace + batch_count), ldb,
&beta, workspace + 2 * batch_count, ldc, batch_count);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmBatched fail";
#else
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
}
inline static void gemv(Stream<gpu> *stream,
bool trans, int m, int n, double alpha,
@@ -638,9 +704,6 @@ struct BLASEngine<gpu, double> {
inline static Shape<2> GetShape(const Shape<2> &shape, bool transpose) {
return transpose ? Shape2(shape[1], shape[0]) : shape;
}
inline static Shape<3> GetBatchedShape(const Shape<3> &shape, bool transpose) {
return transpose ? Shape3(shape[0], shape[2], shape[1]) : shape;
}
// dst = dot(lhs[.T], rhs[.T])
template<typename SV, typename xpu,
bool transpose_left, bool transpose_right, typename DType>
@@ -732,73 +795,6 @@ struct DotEngine<SV, xpu, 2, 1, 1, true, false, DType> {
}
}
};
// dst = batched_dot(lhs[.T], rhs[.T])
template<typename SV, typename xpu,
bool transpose_left, bool transpose_right, typename DType>
struct DotEngine<SV, xpu, 3, 3, 3, transpose_left, transpose_right, DType> {
inline static void Eval(Tensor<xpu, 3, DType> *p_dst,
const Tensor<xpu, 3, DType> &lhs,
const Tensor<xpu, 3, DType> &rhs,
DType scale) {
Tensor<xpu, 3, DType> &dst = *p_dst;
// set kernel stream
// if there is no stream, crush
BLASEngine<xpu, DType>::SetStream(dst.stream_);
Shape<3> sleft = GetBatchedShape(lhs.shape_, transpose_left);
Shape<3> sright = GetBatchedShape(rhs.shape_, transpose_right);
CHECK(dst.size(0) == sleft[0] && dst.size(0) == sright[0])
<< "batch_dot-gemm: batchsize must be equal."
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << sleft << "\n"
<< "rhs: " << sright << "\n";
CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1])
<< "batch_dot-gemm: matrix shape mismatch"
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << sleft << "\n"
<< "rhs: " << sright << "\n";
// use column major argument to compatible with most BLAS
if (sleft[1] == 1) {
// For (batch, 1, K) gemm (batch, K, N), we can use (batch, N, K) gemv (batch, K)
BLASEngine<xpu, DType>::batched_gemv
(dst.stream_,
transpose_right,
rhs.size(2), rhs.size(1), scale * SV::AlphaBLAS(),
rhs.dptr_, rhs.stride_,
lhs.dptr_, 1, SV::BetaBLAS(),
dst.dptr_, 1, dst.size(0));
} else if (sleft[2] == 1 && (SV::BetaBLAS() == 0.0f || SV::BetaBLAS() == 1.0f)) {
// For (batch, M, 1) gemm (batch, 1, N) + Beta = 0, we can use (batch, M) ger (batch, N)
if (SV::BetaBLAS() == 0.0f) {
dst = DType(0);
}
BLASEngine<xpu, DType>::batched_ger
(dst.stream_, sright[2], sleft[1], scale * SV::AlphaBLAS(),
rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_, dst.size(0));
} else if (sright[2] == 1) {
// For (batch, M, K) gemm (batch, K, 1), we can use (batch, M, K) gemv (batch, K)
BLASEngine<xpu, DType>::batched_gemv
(dst.stream_,
!transpose_left,
lhs.size(2), lhs.size(1), scale * SV::AlphaBLAS(),
lhs.dptr_, lhs.stride_,
rhs.dptr_, 1, SV::BetaBLAS(),
dst.dptr_, 1, dst.size(0));
} else {
// For general case, use gemm
BLASEngine<xpu, DType>::batched_gemm
(dst.stream_,
transpose_right, transpose_left,
transpose_right ? rhs.size(1) : rhs.size(2),
transpose_left ? lhs.size(2) : lhs.size(1),
transpose_right ? rhs.size(2) : rhs.size(1),
DType(scale * SV::AlphaBLAS()),
rhs.dptr_, rhs.stride_,
lhs.dptr_, lhs.stride_,
DType(SV::BetaBLAS()),
dst.dptr_, dst.stride_, dst.size(0));
}
}
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_DOT_ENGINE_INL_H_
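The batched_gemm overloads above all gained a trailing workspace argument of type DType** with room for 3 * batch_count pointers. Only the GPU float and double paths (CUDA >= 4.1) actually use it: the first third is filled with views of A, the second with views of B, the third with views of C, and the three arrays are handed to cublasSgemmBatched / cublasDgemmBatched. The CPU, half-precision, and older-CUDA paths fall back to a loop of per-slice gemm calls. A hedged helper sketch of that layout (BatchedViews and SplitWorkspace are illustrative names, not part of the library):

// Sketch: split a workspace of 3 * batch_count pointer slots into the three
// pointer arrays consumed by the cuBLAS batched GEMM call (illustrative names).
template<typename DType>
struct BatchedViews {
  DType **a;  // batch_count pointers, a[i] = A + i * m * k
  DType **b;  // batch_count pointers, b[i] = B + i * k * n
  DType **c;  // batch_count pointers, c[i] = C + i * m * n
};

template<typename DType>
inline BatchedViews<DType> SplitWorkspace(DType **workspace, int batch_count) {
  return BatchedViews<DType>{workspace,
                             workspace + batch_count,
                             workspace + 2 * batch_count};
}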
19 changes: 17 additions & 2 deletions mshadow/tensor.h
@@ -855,17 +855,32 @@ template<typename Saver, typename Reducer, int dimkeep,
inline void MapReduceKeepHighDim(TRValue<R, gpu, 1, DType> *dst,
const expr::Exp<E, DType, etype> &exp,
DType scale = 1);

/*!
* \brief CPU/GPU: 1 dimension vector dot
* \param dst Length 1 vector, used to hold the result.
* \param lhs Left operand vector
* \param rhs right operand vector
* \param rhs Right operand vector
*/
template<typename Device, typename DType>
inline void VectorDot(Tensor<Device, 1, DType> dst,
const Tensor<Device, 1, DType> &lhs,
const Tensor<Device, 1, DType> &rhs);
/*!
* \brief CPU/GPU: dst = alpha * op(lhs) op(rhs) + beta * dst
 * \param dst 3D result tensor of shape (batch, m, n)
 * \param lhs Left operand 3D tensor
 * \param rhs Right operand 3D tensor
* \param alpha multiplier of op(lhs)op(rhs)
* \param beta multiplier of dst
* \param workspace Workspace for casting DType* to DType** (batched-view), must have size >= 3 * batch_size
*/
template<bool transpose_left, bool transpose_right, typename Device, typename DType>
inline void BatchGEMM(Tensor<Device, 3, DType> dst,
const Tensor<Device, 3, DType> &lhs,
const Tensor<Device, 3, DType> &rhs,
DType alpha,
DType beta,
Tensor<Device, 1, DType*> workspace);
} // namespace mshadow
// include headers
#include "./stream_gpu-inl.h"
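For reference, a hedged usage sketch of the new BatchGEMM entry point on the CPU path (tensor names and sizes are illustrative; the 1-D workspace of DType* must hold at least 3 * batch pointers, as checked in tensor_cpu-inl.h):

#include "mshadow/tensor.h"

// Sketch: C[i] = A[i] * B[i] for a batch of 8 matrices (illustrative names).
void BatchGemmExample() {
  using namespace mshadow;
  Tensor<cpu, 3, float> A(Shape3(8, 4, 5));         // (batch, M, K)
  Tensor<cpu, 3, float> B(Shape3(8, 5, 6));         // (batch, K, N)
  Tensor<cpu, 3, float> C(Shape3(8, 4, 6));         // (batch, M, N)
  Tensor<cpu, 1, float*> workspace(Shape1(3 * 8));  // >= 3 * batch pointers
  // pad = false keeps the tensors contiguous, as BatchGEMM requires
  AllocSpace(&A, false); AllocSpace(&B, false);
  AllocSpace(&C, false); AllocSpace(&workspace, false);
  A = 1.0f;
  B = 2.0f;
  BatchGEMM<false, false>(C, A, B, 1.0f, 0.0f, workspace);  // no transposes
  FreeSpace(&A); FreeSpace(&B); FreeSpace(&C); FreeSpace(&workspace);
}

On the GPU the same call applies, but the workspace must be a device-side Tensor<gpu, 1, DType*>, since the pointer views are written by a CUDA kernel and read directly by cuBLAS.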
44 changes: 44 additions & 0 deletions mshadow/tensor_cpu-inl.h
@@ -449,5 +449,49 @@ inline void VectorDot(Tensor<Device, 1, DType> dst,
mshadow::expr::BLASEngine<Device, DType>::dot(
lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_);
}

template<bool transpose_left, bool transpose_right, typename Device, typename DType>
inline void BatchGEMM(Tensor<Device, 3, DType> dst,
const Tensor<Device, 3, DType> &lhs,
const Tensor<Device, 3, DType> &rhs,
DType alpha,
DType beta,
Tensor<Device, 1, DType*> workspace) {
int batch_size = dst.shape_[0];
expr::BLASEngine<Device, DType>::SetStream(dst.stream_);
Shape<3> sleft = transpose_left ? Shape3(lhs.shape_[0], lhs.shape_[2], lhs.shape_[1])
: lhs.shape_;
Shape<3> sright = transpose_right ? Shape3(rhs.shape_[0], rhs.shape_[2], rhs.shape_[1])
: rhs.shape_;
CHECK_EQ(dst.CheckContiguous(), true);
CHECK_EQ(lhs.CheckContiguous(), true);
CHECK_EQ(rhs.CheckContiguous(), true);
CHECK(sleft[0] == batch_size && sright[0] == batch_size)
<< "BatchGEMM: batchsize must be equal."
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << sleft << "\n"
<< "rhs: " << sright << "\n";
CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1])
<< "BatchGEMM: matrix shape mismatch"
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << sleft << "\n"
<< "rhs: " << sright << "\n";
CHECK(workspace.size(0) >= 3 * batch_size)
<< "Workspace Size must be bigger than " << 3 * batch_size;
CHECK_EQ(workspace.CheckContiguous(), true);
// use column-major arguments to be compatible with most BLAS implementations
expr::BLASEngine<Device, DType>::batched_gemm
(dst.stream_,
transpose_right, transpose_left,
transpose_right ? rhs.size(1) : rhs.size(2),
transpose_left ? lhs.size(2) : lhs.size(1),
transpose_right ? rhs.size(2) : rhs.size(1),
alpha,
rhs.dptr_, rhs.stride_,
lhs.dptr_, lhs.stride_,
beta,
dst.dptr_, dst.stride_, batch_size,
workspace.dptr_);
}
} // namespace mshadow
#endif // MSHADOW_TENSOR_CPU_INL_H_
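A note on the argument order in the batched_gemm call above: mshadow tensors are row-major while BLAS and cuBLAS are column-major, so each slice is computed as op(rhs)^T * op(lhs)^T with the m and n extents swapped, which writes the row-major dst directly. A hedged single-matrix sketch of the same mapping through a plain CBLAS call (illustrative names; assumes a CBLAS header is available):

#include <cblas.h>

// Sketch: row-major C (MxN) = A (MxK) * B (KxN), expressed through a
// column-major sgemm by computing C^T = B^T * A^T (illustrative names).
void RowMajorGemm(int M, int N, int K,
                  const float *A, const float *B, float *C) {
  // Column-major view: pass B first with leading dimension N, A second with
  // leading dimension K, and write into C with leading dimension N.
  cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
              N, M, K, 1.0f,
              B, N, A, K,
              0.0f, C, N);
}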
