Skip to content
This repository has been archived by the owner on Feb 7, 2023. It is now read-only.

gemm: fp16 tensorcore fixes #1211

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 59 additions & 32 deletions caffe2/utils/math_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -315,39 +315,66 @@ void Gemm<float16, CUDAContext, TensorCoreEngine>(
cublasOperation_t cuTransB =
(TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

// enable TensorCore for this call on this handle
if (TensorCoreAvailable()) {
CUBLAS_ENFORCE(cublasSetMathMode(
context->cublas_handle(),
CUBLAS_TENSOR_OP_MATH));
}

CUBLAS_CHECK(cublasGemmEx(
context->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
CUDA_R_16F,
ldb,
A,
CUDA_R_16F,
lda,
&beta,
C,
CUDA_R_16F,
N,
CUDA_R_32F,
CUBLAS_GEMM_DFALT_TENSOR_OP));
int device = CaffeCudaGetDevice();
auto& prop = GetDeviceProperty(device);

if (prop.major >= 5) {
// enable TensorCore for this call on this handle
auto algo = CUBLAS_GEMM_DFALT;
if (TensorCoreAvailable()) {
CUBLAS_ENFORCE(cublasSetMathMode(
context->cublas_handle(),
CUBLAS_TENSOR_OP_MATH));
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}

// Now disable TensorCore math for subsequent calls to this handle
if (TensorCoreAvailable()) {
CUBLAS_ENFORCE(cublasSetMathMode(
context->cublas_handle(),
CUBLAS_DEFAULT_MATH));
CUBLAS_CHECK(cublasGemmEx(
context->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
CUDA_R_16F,
ldb,
A,
CUDA_R_16F,
lda,
&beta,
C,
CUDA_R_16F,
N,
CUDA_R_32F,
algo));

// Now disable TensorCore math for subsequent calls to this handle
if (TensorCoreAvailable()) {
CUBLAS_ENFORCE(cublasSetMathMode(
context->cublas_handle(),
CUBLAS_DEFAULT_MATH));
}
} else {
// fall back to SgemmEx when arch < Maxwell
CUBLAS_CHECK(cublasSgemmEx(
context->cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
CUDA_R_16F,
ldb,
A,
CUDA_R_16F,
lda,
&beta,
C,
CUDA_R_16F,
N));
}
}

Expand Down