Skip to content
This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

Commit

Permalink
Add fp16 support for batch_dot (#366)
Browse files: browse the repository at this point in the history
* fp16 batch dot

* fix unsupported arch
  • Loading branch information
eric-haibin-lin authored Dec 23, 2018
1 parent 8b31376 commit 6dc04f7
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions mshadow/dot_engine-inl.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -531,6 +531,26 @@ struct BLASEngine<gpu, half::half_t> {
const half::half_t *A, int lda, const half::half_t *B, int ldb,
half::half_t beta, half::half_t *C, int ldc, int batch_count,
half::half_t **workspace) {
#if defined(__CUDACC__) && CUDA_VERSION >= 9000
int major = stream->prop.major;
int minor = stream->prop.minor;
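// The fp16 strided-batched GEMM is only used on devices with compute
// capability 5.3 or newer; older devices fall through to the per-batch loop below.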
if ((major > 5) || (major == 5 && minor >= 3)) {
const __half* A_h = reinterpret_cast<const __half*>(A);
const __half* B_h = reinterpret_cast<const __half*>(B);
__half* alpha_h = reinterpret_cast<__half*>(&alpha);
__half* beta_h = reinterpret_cast<__half*>(&beta);
__half* C_h = reinterpret_cast<__half*>(C);
cublasStatus_t err = cublasHgemmStridedBatched(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, alpha_h,
A_h, lda, m * k,
B_h, ldb, k * n,
beta_h, C_h, ldc, m * n,
batch_count);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: HgemmStridedBatched fail";
return;
}
#endif
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
Expand Down

0 comments on commit 6dc04f7

Please sign in to comment.