diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md
index b91d476d2aff..e0b70a658b62 100644
--- a/docs/static_site/src/pages/api/faq/env_var.md
+++ b/docs/static_site/src/pages/api/faq/env_var.md
@@ -358,6 +358,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
   - Values: 0(false) or 1(true) ```(default=1)```
   - This variable controls whether to use the MKL-DNN backend in fused RNN operator for CPU context. There are two fusion implementations of RNN operator in MXNet. The MKL-DNN implementation has a better performance than the naive one, but the latter is more stable in the backward operation currently.
 
+* MXNET_FC_TRUE_FP16
+  - Values: 0(false) or 1(true) ```(default=0)```
+  - If this variable is set to true, MXNet will perform fp16 accumulation when using cuBLAS and the input datatype is float16. This can speed up the computation, but may result in a loss of accuracy, which makes the setting useful mainly for inference use cases.
+
 Settings for Minimum Memory Usage
 ---------------------------------
 - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
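For reference, a minimal sketch (not part of this patch) of how the new variable is exercised from the Python side. It assumes a CUDA build of MXNet, a GPU whose architecture supports fp16 compute, and the same Gluon `Dense` setup that the new test further down uses:

```python
import os
# The flag is read with dmlc::GetEnv at GEMM time, so it can be toggled at runtime,
# before the forward pass that should use fp16 accumulation.
os.environ["MXNET_FC_TRUE_FP16"] = "1"

import mxnet as mx
from mxnet.gluon import nn

ctx = mx.gpu(0)
net = nn.Dense(128, in_units=512, use_bias=False)
net.cast('float16')      # fp16 parameters and outputs
net.initialize(ctx=ctx)

x = mx.nd.random.uniform(shape=(1, 512), dtype='float16', ctx=ctx)
y = net(x)               # the FullyConnected GEMM may now accumulate in fp16
print(y.dtype)           # numpy.float16 either way; only the accumulation precision differs
```
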
diff --git a/src/operator/contrib/transformer.cu b/src/operator/contrib/transformer.cu
index 59029eae65c2..44c8ebdbb959 100644
--- a/src/operator/contrib/transformer.cu
+++ b/src/operator/contrib/transformer.cu
@@ -50,6 +50,28 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
     << "Must init CuBLAS handle in stream";
 
   cublasHandle_t blas_handle = mshadow::Stream<gpu>::GetBlasHandle(s);
+  auto err = CUBLAS_STATUS_SUCCESS;
+  using TrueFP16Type = DType;
+  using PseudoFP16Type = typename CublasType<DType>::ScaleType;
+  // Set up alpha and beta values in the possible formats needed (only different when dtype == half)
+  TrueFP16Type trueFP16_alpha = static_cast<TrueFP16Type>(alpha);
+  TrueFP16Type trueFP16_beta = static_cast<TrueFP16Type>(beta);
+  PseudoFP16Type pseudoFP16_alpha = static_cast<PseudoFP16Type>(alpha);
+  PseudoFP16Type pseudoFP16_beta = static_cast<PseudoFP16Type>(beta);
+  const void *alpha_ptr;
+  const void *beta_ptr;
+  cudaDataType_t computeType;
+  bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false);
+  if (use_true_fp16) {
+    alpha_ptr = &trueFP16_alpha;
+    beta_ptr = &trueFP16_beta;
+    computeType = CublasType<TrueFP16Type>::kCudaFlag;
+  } else {
+    alpha_ptr = &pseudoFP16_alpha;
+    beta_ptr = &pseudoFP16_beta;
+    computeType = CublasType<PseudoFP16Type>::kCudaFlag;
+  }
+
   // cublasGemmStridedBatchedEx is only supported for GPU with architecture
   // capabilities equal or greater than 5.0. Fall back to
   // cublasSgemmStridedBatched, which doesn't support implicit conversion
@@ -59,12 +81,12 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
     CUBLAS_CALL(cublasGemmStridedBatchedEx(
       blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
       static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
-      reinterpret_cast<void*>(&alpha),
+      alpha_ptr,
       a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
       b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
-      reinterpret_cast<void*>(&beta),
+      beta_ptr,
       c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
-      static_cast<int>(batchCount), CUDA_R_32F, algo));
+      static_cast<int>(batchCount), computeType, algo));
   } else {
     if (std::is_same<DType, float>::value) {
       CUBLAS_CALL(cublasSgemmStridedBatched(
@@ -124,7 +146,7 @@ void gemm_switch_fp32accum(mshadow::Stream<gpu>* s, bool transA, bool transB,
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
   if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
     CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
-                             strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
+                             strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
   } else {
     CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
                              strideB, beta, c, ldc, strideC, batchCount);
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index d83eb0d08815..fd6800d184e4 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -249,6 +249,7 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half
                                              Stream<gpu> *s) {
   using namespace mxnet;
+  using namespace mxnet::common::cuda;
   using mshadow::gpu;
   CHECK_NOTNULL(s);
   check_gemm(A, B, C, alpha, beta, tA, tB);
@@ -261,25 +262,59 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half
 #if CUDA_VERSION >= 8000
   cudaDataType_t half_datatype = CUDA_R_16F;
 #else
   cublasDataType_t half_datatype = CUBLAS_DATA_HALF;
 #endif
-  CUBLAS_CALL(cublasSgemmEx(blas_handle,
+  auto algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+  using TrueFP16Type = mshadow::half::half_t;
+  using PseudoFP16Type = typename CublasType<mshadow::half::half_t>::ScaleType;
+  TrueFP16Type trueFP16_alpha = static_cast<TrueFP16Type>(alpha);
+  TrueFP16Type trueFP16_beta = static_cast<TrueFP16Type>(beta);
+  PseudoFP16Type pseudoFP16_alpha = static_cast<PseudoFP16Type>(alpha);
+  PseudoFP16Type pseudoFP16_beta = static_cast<PseudoFP16Type>(beta);
+  const void *alpha_ptr;
+  const void *beta_ptr;
+  cudaDataType_t computeType;
+  bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false);
+  if (use_true_fp16) {
+    alpha_ptr = &trueFP16_alpha;
+    beta_ptr = &trueFP16_beta;
+    computeType = CublasType<TrueFP16Type>::kCudaFlag;
+  } else {
+    alpha_ptr = &pseudoFP16_alpha;
+    beta_ptr = &pseudoFP16_beta;
+    computeType = CublasType<PseudoFP16Type>::kCudaFlag;
+  }
+  if (SupportsFloat16Compute(s->dev_id)) {
+    CUBLAS_CALL(cublasGemmEx(blas_handle,
                             (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
                             (tA ? CUBLAS_OP_T : CUBLAS_OP_N),
                             C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
-                            &alpha_f,
+                            alpha_ptr,
                             B.dptr_, half_datatype, B.stride_,
                             A.dptr_, half_datatype, A.stride_,
-                            &beta_f,
-                            C.dptr_, half_datatype, C.stride_));
+                            beta_ptr,
+                            C.dptr_, half_datatype, C.stride_,
+                            computeType, algo));
+  } else {
+    // pseudo-fp16 (fp32 math with fp16 I/O)
+    if (use_true_fp16)
+      common::LogOnce("MXNET_FC_TRUE_FP16 was set but this architecture does not support it.");
+    float alpha_f = static_cast<float>(alpha);
+    float beta_f = static_cast<float>(beta);
+    CUBLAS_CALL(cublasSgemmEx(blas_handle,
+                              (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
+                              (tA ? CUBLAS_OP_T : CUBLAS_OP_N),
+                              C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
+                              &alpha_f,
+                              B.dptr_, half_datatype, B.stride_,
+                              A.dptr_, half_datatype, A.stride_,
+                              &beta_f,
+                              C.dptr_, half_datatype, C.stride_));
+  }
 #if CUDA_VERSION >= 9000
   SetCublasMathMode(blas_handle, previous_math_mode);
 #endif
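To illustrate the accuracy trade-off behind the computeType switch above (and why the new test below compares against the pseudo-fp16 result with relaxed 1e-2 tolerances), here is a rough NumPy sketch, not part of the patch. It models only the rounding effect of keeping the accumulator in fp16 versus fp32 for one output of a Dense(512 -> 128) layer, not how cuBLAS actually blocks the reduction:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=512).astype(np.float16)  # one fp16 input row
w = rng.uniform(-1, 1, size=512).astype(np.float16)  # one fp16 weight row

# Pseudo-fp16 (default, MXNET_FC_TRUE_FP16=0): fp16 inputs, fp32 accumulator.
acc_fp32 = np.float32(0)
for xi, wi in zip(x, w):
    acc_fp32 += np.float32(xi) * np.float32(wi)

# True fp16 (MXNET_FC_TRUE_FP16=1): every partial sum is rounded back to fp16.
acc_fp16 = np.float16(0)
for xi, wi in zip(x, w):
    acc_fp16 = np.float16(acc_fp16 + xi * wi)

print(float(acc_fp32), float(acc_fp16), abs(float(acc_fp32) - float(acc_fp16)))
```

The fp16-accumulated result drifts visibly from the fp32-accumulated one as the reduction length grows, which is why true fp16 accumulation is off by default and aimed mainly at inference.
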
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index aa56eee33dc4..42a2424c7d9b 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -615,6 +615,27 @@ def test_symbol_block_symbolic_bn_fp16_cast():
     y1 = net(x)
     assert np.dtype(y1.dtype).name == 'float16'
 
+@with_seed()
+def test_gemms_true_fp16():
+    ctx = mx.gpu(0)
+    input = mx.nd.random.uniform(shape=(1, 512), dtype='float16', ctx=ctx)
+    weights = mx.nd.random.uniform(shape=(128, 512), ctx=ctx)
+
+    net = nn.Dense(128, in_units=512, use_bias=False)
+    net.cast('float16')
+    net.initialize(ctx=ctx)
+    net.weight.set_data(weights)
+    ref_results = net(input)
+
+    os.environ["MXNET_FC_TRUE_FP16"] = "1"
+    results_trueFP16 = net(input)
+    atol = 1e-2
+    rtol = 1e-2
+    assert_almost_equal(ref_results.asnumpy(), results_trueFP16.asnumpy(),
+                        atol=atol, rtol=rtol)
+    os.environ["MXNET_FC_TRUE_FP16"] = "0"
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()