diff --git a/include/mxnet/c_api_test.h b/include/mxnet/c_api_test.h
index df7079842657..fcfac33b2938 100644
--- a/include/mxnet/c_api_test.h
+++ b/include/mxnet/c_api_test.h
@@ -91,6 +91,13 @@ MXNET_DLL int MXGetEnv(const char* name,
 MXNET_DLL int MXSetEnv(const char* name,
                        const char* value);
 
+/*!
+ * \brief Get the maximum SM architecture supported by the nvrtc compiler
+ * \param max_arch The maximum supported architecture (e.g. 80 for Ampere)
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXGetMaxSupportedArch(uint32_t *max_arch);
+
 #ifdef __cplusplus
 }
 #endif  // __cplusplus
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 01bdaad5b581..bd0e73c6e2f5 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -50,6 +50,7 @@
 from .symbol import Symbol
 from .symbol.numpy import _Symbol as np_symbol
 from .util import use_np, use_np_default_dtype, getenv, setenv  # pylint: disable=unused-import
+from .util import get_max_supported_compute_capability, get_rtc_compile_opts  # pylint: disable=unused-import
 from .runtime import Features
 from .numpy_extension import get_cuda_compute_capability
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
index a3785354a640..cafff0f9dd9e 100644
--- a/python/mxnet/util.py
+++ b/python/mxnet/util.py
@@ -875,7 +875,7 @@ def get_cuda_compute_capability(ctx):
         raise ValueError('Expecting a gpu context to get cuda compute capability, '
                          'while received ctx {}'.format(str(ctx)))
 
-    libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
+    libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll', 'cuda.dll')
     for libname in libnames:
         try:
             cuda = ctypes.CDLL(libname)
@@ -1176,3 +1176,27 @@ def setenv(name, value):
     """
     passed_value = None if value is None else c_str(value)
     check_call(_LIB.MXSetEnv(c_str(name), passed_value))
+
+
+def get_max_supported_compute_capability():
+    """Get the maximum compute capability (SM arch) supported by the nvrtc compiler
+    """
+    max_supported_cc = ctypes.c_int()
+    check_call(_LIB.MXGetMaxSupportedArch(ctypes.byref(max_supported_cc)))
+    return max_supported_cc.value
+
+
+def get_rtc_compile_opts(ctx):
+    """Get the compile options suitable for the context, given the toolkit/driver config
+    """
+    device_cc = get_cuda_compute_capability(ctx)
+    max_supported_cc = get_max_supported_compute_capability()
+
+    # CUDA toolkits starting with 11.1 (first to support arch 86) can compile directly to SASS
+    can_compile_to_SASS = max_supported_cc >= 86
+    should_compile_to_SASS = can_compile_to_SASS and \
+                             device_cc <= max_supported_cc
+    device_cc_as_used = min(device_cc, max_supported_cc)
+    arch_opt = "--gpu-architecture={}_{}".format("sm" if should_compile_to_SASS else "compute",
+                                                 device_cc_as_used)
+    return [arch_opt]
diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc
index e84b0c0b1395..ac691234fcf0 100644
--- a/src/c_api/c_api_test.cc
+++ b/src/c_api/c_api_test.cc
@@ -26,6 +26,7 @@
 #include
 #include "./c_api_common.h"
 #include "../operator/subgraph/subgraph_property.h"
+#include "../common/cuda/rtc.h"
 
 int MXBuildSubgraphByOpNames(SymbolHandle sym_handle,
                              const char* prop_name,
@@ -128,3 +129,13 @@ int MXSetEnv(const char* name,
 #endif
   API_END();
 }
+
+int MXGetMaxSupportedArch(uint32_t *max_arch) {
+  API_BEGIN();
+#if MXNET_USE_CUDA
+  *max_arch = static_cast<uint32_t>(mxnet::common::cuda::rtc::GetMaxSupportedArch());
+#else
+  LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation.";
+#endif
+  API_END();
+}
diff --git a/src/common/cuda/rtc.cc b/src/common/cuda/rtc.cc
index af4abbee468e..5b27e0bbd225 100644
--- a/src/common/cuda/rtc.cc
+++ b/src/common/cuda/rtc.cc
@@ -67,6 +67,33 @@ std::string to_string(OpReqType req) {
 
 }  // namespace util
 
+int GetMaxSupportedArch() {
+#if CUDA_VERSION < 10000
+  constexpr int max_supported_sm_arch = 72;
+#elif CUDA_VERSION < 11000
+  constexpr int max_supported_sm_arch = 75;
+#elif CUDA_VERSION < 11010
+  constexpr int max_supported_sm_arch = 80;
+#elif CUDA_VERSION < 11020
+  constexpr int max_supported_sm_arch = 86;
+#else
+  // starting with cuda 11.2, nvrtc can report the max supported arch,
+  // removing the need to update this routine with each new cuda version.
+  static int max_supported_sm_arch = []() {
+    int num_archs = 0;
+    NVRTC_CALL(nvrtcGetNumSupportedArchs(&num_archs));
+    std::vector<int> archs(num_archs);
+    if (num_archs > 0) {
+      NVRTC_CALL(nvrtcGetSupportedArchs(archs.data()));
+    } else {
+      LOG(FATAL) << "Could not determine supported cuda archs.";
+    }
+    return archs[num_archs - 1];
+  }();
+#endif
+  return max_supported_sm_arch;
+}
+
 namespace {
 
 // Obtain compilation log from the program.
@@ -97,27 +124,14 @@ std::string GetCompiledCode(nvrtcProgram program, bool use_cubin) {
 }
 
 std::tuple<bool, std::string> GetArchString(const int sm_arch) {
-#if CUDA_VERSION < 10000
-  constexpr int max_supported_sm_arch = 72;
-#elif CUDA_VERSION < 11000
-  constexpr int max_supported_sm_arch = 75;
-#elif CUDA_VERSION < 11010
-  constexpr int max_supported_sm_arch = 80;
-#else
-  constexpr int max_supported_sm_arch = 86;
-#endif
-
-#if CUDA_VERSION <= 11000
+  const int sm_arch_as_used = std::min(sm_arch, GetMaxSupportedArch());
   // Always use PTX for CUDA <= 11.0
-  const bool known_arch = false;
-#else
-  const bool known_arch = sm_arch <= max_supported_sm_arch;
-#endif
-  const int actual_sm_arch = std::min(sm_arch, max_supported_sm_arch);
+  const bool known_arch = (CUDA_VERSION > 11000) &&
+                          (sm_arch == sm_arch_as_used);
   if (known_arch) {
-    return {known_arch, "sm_" + std::to_string(actual_sm_arch)};
+    return {known_arch, "sm_" + std::to_string(sm_arch_as_used)};
   } else {
-    return {known_arch, "compute_" + std::to_string(actual_sm_arch)};
+    return {known_arch, "compute_" + std::to_string(sm_arch_as_used)};
   }
 }
diff --git a/src/common/cuda/rtc.h b/src/common/cuda/rtc.h
index 126c967a0cb3..8c36aa161927 100644
--- a/src/common/cuda/rtc.h
+++ b/src/common/cuda/rtc.h
@@ -53,6 +53,8 @@ std::string to_string(OpReqType req);
 
 }  // namespace util
 
+int GetMaxSupportedArch();
+
 extern std::mutex lock;
 
 /*! \brief Compile and get the GPU kernel. Uses cache in order to
diff --git a/src/common/rtc.cc b/src/common/rtc.cc
index a2662ce9f59c..ece9c0566acd 100644
--- a/src/common/rtc.cc
+++ b/src/common/rtc.cc
@@ -44,8 +44,8 @@ CudaModule::Chunk::Chunk(
       << "For lower version of CUDA, please prepend your kernel defintiions "
       << "with extern \"C\" instead.";
 #endif
-  std::vector<const char*> c_options(options.size());
-  for (const auto& i : options) c_options.emplace_back(i.c_str());
+  std::vector<const char*> c_options;
+  for (const auto& i : options) c_options.push_back(i.c_str());
   nvrtcResult compile_res = nvrtcCompileProgram(prog_, c_options.size(), c_options.data());
   if (compile_res != NVRTC_SUCCESS) {
     size_t err_size;
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 73690059aa5c..03aef70934b0 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -27,7 +27,7 @@
 import mxnet.ndarray.sparse as mxsps
 from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, assert_allclose
 from mxnet.test_utils import check_symbolic_forward, check_symbolic_backward, discard_stderr
-from mxnet.test_utils import default_context, rand_shape_2d, rand_ndarray, same, environment
+from mxnet.test_utils import default_context, rand_shape_2d, rand_ndarray, same, environment, get_rtc_compile_opts
 from mxnet.base import MXNetError
 from mxnet import autograd
@@ -1796,6 +1796,7 @@ def test_autograd_save_memory():
 
 @pytest.mark.serial
 def test_cuda_rtc():
+    ctx = mx.gpu(0)
     source = r'''
     extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
         int i = threadIdx.x + blockIdx.x * blockDim.x;
@@ -1809,18 +1810,20 @@ def test_cuda_rtc():
         y[i] += alpha * smem[threadIdx.x];
     }
     '''
-    module = mx.rtc.CudaModule(source)
+
+    compile_opts = get_rtc_compile_opts(ctx)
+    module = mx.rtc.CudaModule(source, options=compile_opts)
     axpy = module.get_kernel("axpy", "const float *x, float *y, float alpha")
-    x = mx.nd.ones((10,), ctx=mx.gpu(0))
-    y = mx.nd.zeros((10,), ctx=mx.gpu(0))
-    axpy.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
+    x = mx.nd.ones((10,), ctx=ctx)
+    y = mx.nd.zeros((10,), ctx=ctx)
+    axpy.launch([x, y, 3.0], ctx, (1, 1, 1), (10, 1, 1))
     assert (y.asnumpy() == 3).all()
 
     saxpy = module.get_kernel("saxpy", "const float *x, float *y, float alpha")
-    saxpy.launch([x, y, 4.0], mx.gpu(0), (1, 1, 1), (10, 1, 1), 10)
+    saxpy.launch([x, y, 4.0], ctx, (1, 1, 1), (10, 1, 1), 10)
     assert (y.asnumpy() == 7).all()
 
-    saxpy.launch([x, y, 5.0], mx.gpu(0), (2, 1, 1), (5, 1, 1), 5)
+    saxpy.launch([x, y, 5.0], ctx, (2, 1, 1), (5, 1, 1), 5)
    assert (y.asnumpy() == 12).all()
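
For reference, a small standalone Python sketch (not part of the patch) of the arch-option selection performed by the new get_rtc_compile_opts helper; the pick_arch_opt function and the compute-capability pairs below are hypothetical illustrations, not output captured from a real toolkit/device.

def pick_arch_opt(device_cc, max_supported_cc):
    # CUDA toolkits from 11.1 onwards (the first to support arch 86) can compile straight to SASS.
    can_compile_to_sass = max_supported_cc >= 86
    # Emit sm_XX (SASS) only when the toolkit knows the device arch; otherwise fall back
    # to compute_XX (PTX), capped at the newest arch the toolkit's nvrtc supports.
    should_compile_to_sass = can_compile_to_sass and device_cc <= max_supported_cc
    cc_as_used = min(device_cc, max_supported_cc)
    return "--gpu-architecture={}_{}".format("sm" if should_compile_to_sass else "compute",
                                             cc_as_used)

print(pick_arch_opt(86, 86))  # e.g. CUDA 11.1+ toolkit, Ampere GPU -> --gpu-architecture=sm_86
print(pick_arch_opt(86, 80))  # e.g. CUDA 11.0 toolkit, Ampere GPU  -> --gpu-architecture=compute_80
print(pick_arch_opt(70, 86))  # e.g. CUDA 11.1+ toolkit, Volta GPU  -> --gpu-architecture=sm_70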