This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

GEMM Tensor Core Support (#13336)
* add codepath for implicit conversion to half precision for batch_gemm on gpu

* fix argument call

* document implicit float32 conversions

* separate gemm test from laop tests, add test for tensorcore code path for gemm
sbodenstein authored and eric-haibin-lin committed Nov 21, 2018
1 parent 60dd170 commit 1e4afd5
Showing 4 changed files with 135 additions and 44 deletions.
50 changes: 49 additions & 1 deletion src/operator/linalg_impl.h
@@ -315,9 +315,57 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
&beta, C.dptr_, C.stride_, C.size(1) * C.stride_, A.size(0))) \
}

LINALG_GPU_BATCH_GEMM(SgemmStridedBatched, float)
LINALG_GPU_BATCH_GEMM(DgemmStridedBatched, double)

#if CUDA_VERSION < 9010
LINALG_GPU_BATCH_GEMM(SgemmStridedBatched, float)
#else
template <>
inline void linalg_batch_gemm<gpu, float>(const Tensor<gpu, 3, float>& A,
const Tensor<gpu, 3, float>& B,
const Tensor<gpu, 3, float>& C,
float alpha, float beta, bool tA,
bool tB, Stream<gpu>* s) {
using namespace mxnet;
using mshadow::gpu;
CHECK_NOTNULL(s);
linalg_check_batch_size(A.size(0), B.size(0), C.size(0));
check_gemm(A[0], B[0], C[0], alpha, beta, tA, tB);
auto blas_handle = Stream<gpu>::GetBlasHandle(s);
bool use_tensor_ops =
GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion();

using namespace mshadow::cuda;
auto cublas_math_mode =
use_tensor_ops ? CUBLAS_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode);

// cublasGemmStridedBatchedEx is only supported on GPUs with compute
// capability 5.0 or higher. Otherwise fall back to
// cublasSgemmStridedBatched, which does not support implicit conversion
// to half precision and therefore cannot use Tensor Cores.
auto cc_major = (s->prop).major;
if ((cc_major >= 5) && use_tensor_ops) {
CUBLAS_CALL(cublasGemmStridedBatchedEx(
blas_handle, (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), C.size(2), C.size(1),
(tB ? B.size(2) : B.size(1)), &alpha, B.dptr_, CUDA_R_32F,
B.stride_, B.size(1) * B.stride_, A.dptr_, CUDA_R_32F, A.stride_,
A.size(1) * A.stride_, &beta, C.dptr_, CUDA_R_32F, C.stride_,
C.size(1) * C.stride_, A.size(0), CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
CUBLAS_CALL(cublasSgemmStridedBatched(
blas_handle, (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
(tA ? CUBLAS_OP_T : CUBLAS_OP_N), C.size(2), C.size(1),
(tB ? B.size(2) : B.size(1)), &alpha, B.dptr_, B.stride_,
B.size(1) * B.stride_, A.dptr_, A.stride_, A.size(1) * A.stride_,
&beta, C.dptr_, C.stride_, C.size(1) * C.stride_, A.size(0)));
}
SetCublasMathMode(blas_handle, previous_math_mode);
}
#endif // CUDA_VERSION < 9010

// Version where matrix rows are given by second axis.
#define LINALG_GPU_BATCH_GEMM_AXIS(fname, DType) \
template<> inline \
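For illustration only (not part of this diff), here is a minimal Python sketch of how the new batched code path would be exercised from user code. The shapes and variable names are arbitrary, and whether Tensor Cores are actually used depends on the CUDA version and GPU.

```python
import os
import mxnet as mx

# Both environment variables are documented by this commit; setting the
# conversion flag to "0" keeps the default float32 cuBLAS path.
os.environ["MXNET_CUDA_ALLOW_TENSOR_CORE"] = "1"
os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "1"

ctx = mx.gpu(0)
# Batched float32 matrices: the leading axis is the batch dimension.
a = mx.nd.random.uniform(shape=(8, 64, 32), dtype='float32', ctx=ctx)
b = mx.nd.random.uniform(shape=(8, 32, 16), dtype='float32', ctx=ctx)

# With CUDA >= 9.1, compute capability >= 5.0 and both flags set, this
# should reach cublasGemmStridedBatchedEx with CUBLAS_GEMM_DEFAULT_TENSOR_OP;
# otherwise it falls back to cublasSgemmStridedBatched.
c = mx.nd.linalg.gemm2(a, b)
mx.nd.waitall()
print(c.shape)  # (8, 64, 16)
```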
5 changes: 5 additions & 0 deletions src/operator/rnn.cc
@@ -48,6 +48,11 @@ MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
implemented, with both multi-layer and bidirectional support.
When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE
and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use
pseudo-float16 precision (float32 math with float16 I/O) in order to use
Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.
**Vanilla RNN**
Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
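Again purely for illustration (not part of the diff), a minimal Gluon sketch of the float32 RNN case the docstring describes; layer sizes are arbitrary, and on GPU the layer should dispatch to the fused RNN operator via cuDNN.

```python
import os
import mxnet as mx

os.environ["MXNET_CUDA_ALLOW_TENSOR_CORE"] = "1"
os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "1"

ctx = mx.gpu(0)
# float32 input of shape (seq_len, batch, input_size) in the default 'TNC' layout.
x = mx.nd.random.uniform(shape=(35, 16, 128), dtype='float32', ctx=ctx)

lstm = mx.gluon.rnn.LSTM(hidden_size=256, num_layers=2)
lstm.initialize(ctx=ctx)
out = lstm(x)      # may use pseudo-float16 Tensor Core math internally
mx.nd.waitall()
print(out.shape)   # (35, 16, 256)
```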
10 changes: 10 additions & 0 deletions src/operator/tensor/la_op.cc
@@ -62,6 +62,11 @@ calls. For example let *A*, *B*, *C* be 5 dimensional tensors. Then gemm(*A*, *B
without the overhead of the additional swapaxis operations.
When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE
and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use
pseudo-float16 precision (float32 math with float16 I/O) in order to use
Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.
.. note:: The operator supports float32 and float64 data types only.
Examples::
@@ -134,6 +139,11 @@ calls. For example let *A*, *B* be 5 dimensional tensors. Then gemm(*A*, *B*, ax
without the overhead of the additional swapaxis operations.
When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE
and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use
pseudo-float16 precision (float32 math with float16 I/O) in order to use
Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.
.. note:: The operator supports float32 and float64 data types only.
Examples::
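As a hedged usage sketch (not part of the diff) of the operators whose documentation changes above: both gemm and gemm2 accept plain float32 inputs, and the environment variables (together with a suitable GPU) decide whether the Tensor Core conversion path is tried.

```python
import os
import mxnet as mx
import numpy as np

os.environ["MXNET_CUDA_ALLOW_TENSOR_CORE"] = "1"
os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "1"

ctx = mx.gpu(0)
a = mx.nd.array(np.random.uniform(1, 10, (2, 3)), dtype='float32', ctx=ctx)
b = mx.nd.array(np.random.uniform(1, 10, (3, 2)), dtype='float32', ctx=ctx)
c = mx.nd.ones((2, 2), dtype='float32', ctx=ctx)

# gemm computes alpha * dot(a, b) + beta * c; gemm2 omits the additive term.
out1 = mx.nd.linalg.gemm(a, b, c, alpha=2.0, beta=0.5)
out2 = mx.nd.linalg.gemm2(a, b)
mx.nd.waitall()
```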
114 changes: 71 additions & 43 deletions tests/python/unittest/test_operator.py
@@ -29,6 +29,7 @@
from mxnet.base import py_str, MXNetError, _as_list
from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises
import unittest
import os

def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e-4):
dshape = (N, T, I)
@@ -5242,46 +5243,10 @@ def test_deformable_psroipooling():
grad_nodes=grad_nodes, ctx=mx.gpu(0))


# Helper functions for test_laop

def _make_symm_symbol(a, ndims):
assert ndims >= 2
tr_shape = list(range(ndims))
tr_shape[-1] = ndims-2
tr_shape[-2] = ndims-1
tr_shape = tuple(tr_shape)
return 0.5 * (a + mx.sym.transpose(a, axes=tr_shape))

def _make_triangle_symm(a, ndims, m, lower, dtype=np.float32):
assert ndims >= 2
# The last two dimensions must both be m
# Create mask for lower triangle and diagonal
index = mx.sym.arange(start=0, stop=m, step=1, dtype=np.int32)
lt_mask = mx.sym.one_hot(index, depth=m, dtype=dtype)
for j in range(1, m):
part1 = mx.sym.zeros(shape=(j, m), dtype=dtype)
index = mx.sym.arange(start=0, stop=m-j, step=1, dtype=np.int32)
part2 = mx.sym.one_hot(index, depth=m, dtype=dtype)
lt_mask = lt_mask + mx.sym.concat(*[part1, part2], dim=0)
if not lower:
lt_mask = mx.sym.reshape(lt_mask, shape=(m, m))
lt_mask = mx.sym.transpose(lt_mask, axes=(1, 0))
shp = tuple([1]*(ndims-2) + [m, m])
lt_mask = mx.sym.reshape(lt_mask, shape=shp)
return mx.sym.broadcast_mul(a, lt_mask)

# @ankkhedia: Getting rid of fixed seed as flakiness could not be reproduced
# tracked at https://github.com/apache/incubator-mxnet/issues/11718
@with_seed()
def test_laop():
dtype = np.float64
rtol_fw = 1e-7
atol_fw = 1e-9
def _gemm_test_helper(dtype, grad_check, rtol_fw = 1e-7, atol_fw = 1e-9):
num_eps = 1e-6
rtol_bw = 1e-5
atol_bw = 1e-6
# enable numerical checking of gradients
grad_check = 1

data1 = mx.symbol.Variable('data1')
data2 = mx.symbol.Variable('data2')
@@ -5296,15 +5261,14 @@ def test_laop():
rep_3x = lambda a, m, n :\
np.reshape(np.tile(np.array(a).flatten(), 3), (3, 1, m, n))

# Test gemm separately from other la-operators.
shape1 = (2, 3)
shape2 = (3, 2)
shape3 = (3, 3)
shape4 = (2, 2)
data_in1 = np.random.uniform(1, 10, shape1)
data_in2 = np.random.uniform(1, 10, shape2)
data_in3 = np.random.uniform(1, 10, shape3)
data_in4 = np.random.uniform(1, 10, shape4)
data_in1 = np.random.uniform(1, 10, shape1).astype(dtype)
data_in2 = np.random.uniform(1, 10, shape2).astype(dtype)
data_in3 = np.random.uniform(1, 10, shape3).astype(dtype)
data_in4 = np.random.uniform(1, 10, shape4).astype(dtype)
# Check all transpositions of gemm operator.
data_in1_t = np.transpose(data_in1)
data_in2_t = np.transpose(data_in2)
@@ -5406,7 +5370,71 @@ def test_laop():
if grad_check == 1:
check_grad(test_gemm, [a2, b2])

# Now test all the other operators.
# Test gemm separately from other la-operators.
@with_seed()
def test_gemm():
    _gemm_test_helper(np.float64, True)
    os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "0"
    _gemm_test_helper(np.float32, False, rtol_fw = 1e-5, atol_fw = 1e-7)
    if default_context().device_type == 'gpu':
        os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "1"
        _gemm_test_helper(np.float32, False, rtol_fw = 2e-5, atol_fw = 2e-7)
        os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "0"

# Helper functions for test_laop

def _make_symm_symbol(a, ndims):
assert ndims >= 2
tr_shape = list(range(ndims))
tr_shape[-1] = ndims-2
tr_shape[-2] = ndims-1
tr_shape = tuple(tr_shape)
return 0.5 * (a + mx.sym.transpose(a, axes=tr_shape))

def _make_triangle_symm(a, ndims, m, lower, dtype=np.float32):
assert ndims >= 2
# The last two dimensions must both be m
# Create mask for lower triangle and diagonal
index = mx.sym.arange(start=0, stop=m, step=1, dtype=np.int32)
lt_mask = mx.sym.one_hot(index, depth=m, dtype=dtype)
for j in range(1, m):
part1 = mx.sym.zeros(shape=(j, m), dtype=dtype)
index = mx.sym.arange(start=0, stop=m-j, step=1, dtype=np.int32)
part2 = mx.sym.one_hot(index, depth=m, dtype=dtype)
lt_mask = lt_mask + mx.sym.concat(*[part1, part2], dim=0)
if not lower:
lt_mask = mx.sym.reshape(lt_mask, shape=(m, m))
lt_mask = mx.sym.transpose(lt_mask, axes=(1, 0))
shp = tuple([1]*(ndims-2) + [m, m])
lt_mask = mx.sym.reshape(lt_mask, shape=shp)
return mx.sym.broadcast_mul(a, lt_mask)

# @ankkhedia: Getting rid of fixed seed as flakiness could not be reproduced
# tracked at https://github.com/apache/incubator-mxnet/issues/11718
@with_seed()
def test_laop():
dtype = np.float64
rtol_fw = 1e-7
atol_fw = 1e-9
num_eps = 1e-6
rtol_bw = 1e-5
atol_bw = 1e-6
# enable numerical checking of gradients
grad_check = 1

data1 = mx.symbol.Variable('data1')
data2 = mx.symbol.Variable('data2')
data3 = mx.symbol.Variable('data3')

check_fw = lambda sym, location, expected :\
check_symbolic_forward(sym, location, expected, rtol=rtol_fw,
atol=atol_fw, dtype=dtype)
check_grad = lambda sym, location:\
check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw,
atol=atol_bw, dtype=dtype)
rep_3x = lambda a, m, n :\
np.reshape(np.tile(np.array(a).flatten(), 3), (3, 1, m, n))

for lower in [True, False]:
upper = not lower

