Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Replace cublassgemm with cublassgemmex for >= 7.5
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudh2290 committed Jul 13, 2018
1 parent 5b4d528 commit 08c25e7
Showing 1 changed file with 42 additions and 15 deletions.
57 changes: 42 additions & 15 deletions src/operator/linalg_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,23 +169,50 @@ void linalg_gemm<cpu, mshadow::half::half_t>(const Tensor<cpu, 2, mshadow::half:

// cublas col-major processing accounted for by switching first two operands

#define LINALG_GPU_GEMM(fname, DType) \
template<> inline \
void linalg_gemm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2, DType>& B, \
const Tensor<gpu, 2, DType>& C, DType alpha, DType beta, \
bool tA, bool tB, Stream<gpu> *s) { \
using namespace mxnet; \
using mshadow::gpu; \
CHECK_NOTNULL(s); \
check_gemm(A, B, C, alpha, beta, tA, tB); \
CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
(tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)), \
&alpha, B.dptr_, B.stride_, A.dptr_, A.stride_, \
&beta, C.dptr_, C.stride_)) \
#define LINALG_GPU_GEMM(fname, DType) \
template <> \
inline void linalg_gemm<gpu, DType>( \
const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2, DType>& B, \
const Tensor<gpu, 2, DType>& C, DType alpha, DType beta, bool tA, \
bool tB, Stream<gpu>* s) { \
using namespace mxnet; \
using mshadow::gpu; \
CHECK_NOTNULL(s); \
check_gemm(A, B, C, alpha, beta, tA, tB); \
CUBLAS_CALL(cublas##fname( \
Stream<gpu>::GetBlasHandle(s), (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), C.size(1), C.size(0), \
(tB ? B.size(1) : B.size(0)), &alpha, B.dptr_, B.stride_, A.dptr_, \
A.stride_, &beta, C.dptr_, C.stride_)) \
}

#if CUDA_VERSION >= 7050
template <>
inline void linalg_gemm<gpu, float>(const Tensor<gpu, 2, float>& A,
const Tensor<gpu, 2, float>& B,
const Tensor<gpu, 2, float>& C, float alpha,
float beta, bool tA, bool tB,
Stream<gpu>* s) {
using namespace mxnet;
using mshadow::gpu;
CHECK_NOTNULL(s);
check_gemm(A, B, C, alpha, beta, tA, tB);
#if CUDA_VERSION >= 8000
cudaDataType_t full_datatype = CUDA_R_32F;
#else
cublasDataType_t full_datatype = CUBLAS_DATA_FULL;
#endif
CUBLAS_CALL(cublasSgemmEx(
Stream<gpu>::GetBlasHandle(s), (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), C.size(1), C.size(0),
(tB ? B.size(1) : B.size(0)), &alpha, B.dptr_, full_datatype, B.stride_,
A.dptr_, full_datatype, A.stride_, &beta, C.dptr_, full_datatype,
C.stride_))
}

#else
LINALG_GPU_GEMM(Sgemm, float)
#endif
LINALG_GPU_GEMM(Dgemm, double)

// Version where matrix rows are given by first axis.
Expand Down

0 comments on commit 08c25e7

Please sign in to comment.