This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] Add backend MXGetMaxSupportedArch() and frontend get_rtc_compile_opts() for CUDA enhanced compatibility #20443

Merged: 3 commits, Jul 13, 2021
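In short, the PR gives the backend a query for the newest SM architecture the bundled nvrtc can target, and gives the Python frontend a helper that turns that into a `--gpu-architecture` option for runtime compilation. A minimal usage sketch, mirroring the updated `test_cuda_rtc` below; it assumes a CUDA-enabled MXNet build with at least one visible GPU, and the `scale` kernel is made up for illustration:

```python
import mxnet as mx
from mxnet.util import get_rtc_compile_opts, get_max_supported_compute_capability

ctx = mx.gpu(0)
print("nvrtc can target up to sm_%d" % get_max_supported_compute_capability())

# Hypothetical kernel, used only to show the compile-option plumbing.
source = r'''
extern "C" __global__ void scale(float *x, float alpha) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    x[i] *= alpha;
}
'''

# Request arch options matched to this toolkit/driver instead of relying on defaults.
module = mx.rtc.CudaModule(source, options=get_rtc_compile_opts(ctx))
scale = module.get_kernel("scale", "float *x, float alpha")
x = mx.nd.ones((10,), ctx=ctx)
scale.launch([x, 2.0], ctx, (1, 1, 1), (10, 1, 1))
print(x.asnumpy())  # all 2.0
```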
7 changes: 7 additions & 0 deletions include/mxnet/c_api_test.h
@@ -91,6 +91,13 @@ MXNET_DLL int MXGetEnv(const char* name,
MXNET_DLL int MXSetEnv(const char* name,
const char* value);

/*!
 * \brief Get the maximum SM architecture supported by the nvrtc compiler
 * \param max_arch The maximum supported SM architecture (e.g. 80 for Ampere)
 * \return 0 on success, -1 on failure.
 */
MXNET_DLL int MXGetMaxSupportedArch(uint32_t *max_arch);

#ifdef __cplusplus
}
#endif // __cplusplus
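For reference, a sketch of exercising the new entry point directly over ctypes, following the usual MXNet C API convention (0 on success, non-zero with `MXGetLastError()` describing the failure). The hard-coded `libmxnet.so` name is an assumption; in practice `mxnet.base._LIB` already holds this handle, which is exactly what the `util.py` wrapper below relies on:

```python
import ctypes

# Assumes a CUDA-enabled libmxnet.so is discoverable by the dynamic loader.
lib = ctypes.CDLL("libmxnet.so")
lib.MXGetLastError.restype = ctypes.c_char_p

max_arch = ctypes.c_uint32()
if lib.MXGetMaxSupportedArch(ctypes.byref(max_arch)) != 0:
    raise RuntimeError(lib.MXGetLastError().decode("utf-8"))
print("nvrtc can target up to sm_%d" % max_arch.value)  # e.g. 80 on a CUDA 11.0 toolkit
```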
1 change: 1 addition & 0 deletions python/mxnet/test_utils.py
@@ -50,6 +50,7 @@
from .symbol import Symbol
from .symbol.numpy import _Symbol as np_symbol
from .util import use_np, use_np_default_dtype, getenv, setenv # pylint: disable=unused-import
from .util import get_max_supported_compute_capability, get_rtc_compile_opts # pylint: disable=unused-import
from .runtime import Features
from .numpy_extension import get_cuda_compute_capability

26 changes: 25 additions & 1 deletion python/mxnet/util.py
@@ -875,7 +875,7 @@ def get_cuda_compute_capability(ctx):
raise ValueError('Expecting a gpu context to get cuda compute capability, '
'while received ctx {}'.format(str(ctx)))

-    libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
+    libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll', 'cuda.dll')
for libname in libnames:
try:
cuda = ctypes.CDLL(libname)
@@ -1176,3 +1176,27 @@ def setenv(name, value):
"""
passed_value = None if value is None else c_str(value)
check_call(_LIB.MXSetEnv(c_str(name), passed_value))


def get_max_supported_compute_capability():
    """Get the maximum compute capability (SM arch) supported by the nvrtc compiler
    """
    max_supported_cc = ctypes.c_int()
    check_call(_LIB.MXGetMaxSupportedArch(ctypes.byref(max_supported_cc)))
    return max_supported_cc.value


def get_rtc_compile_opts(ctx):
    """Get the RTC compile options suitable for the context, given the toolkit/driver config
    """
    device_cc = get_cuda_compute_capability(ctx)
    max_supported_cc = get_max_supported_compute_capability()

    # CUDA toolkits starting with 11.1 (first to support arch 86) can compile directly to SASS
    can_compile_to_SASS = max_supported_cc >= 86
    should_compile_to_SASS = can_compile_to_SASS and \
                             device_cc <= max_supported_cc
    device_cc_as_used = min(device_cc, max_supported_cc)
    arch_opt = "--gpu-architecture={}_{}".format("sm" if should_compile_to_SASS else "compute",
                                                 device_cc_as_used)
    return [arch_opt]
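To make the SASS-versus-PTX decision above concrete, here is the same logic restated as a standalone function that takes both compute capabilities explicitly, so it can be followed without a GPU; the helper name and the sample values are illustrative only:

```python
def rtc_arch_opt(device_cc, max_supported_cc):
    """Mirror of get_rtc_compile_opts() with both compute capabilities passed in."""
    can_compile_to_sass = max_supported_cc >= 86  # toolkit is 11.1 or newer
    should_compile_to_sass = can_compile_to_sass and device_cc <= max_supported_cc
    cc_as_used = min(device_cc, max_supported_cc)
    prefix = "sm" if should_compile_to_sass else "compute"
    return "--gpu-architecture={}_{}".format(prefix, cc_as_used)

# sm_70 GPU on a CUDA 11.2 toolkit (max arch 86): compile straight to SASS.
print(rtc_arch_opt(70, 86))  # --gpu-architecture=sm_70
# sm_86 GPU on a CUDA 11.0 toolkit (max arch 80): fall back to PTX for the older arch.
print(rtc_arch_opt(86, 80))  # --gpu-architecture=compute_80
# sm_80 GPU on a CUDA 11.0 toolkit: PTX as well, since 11.0 cannot go straight to SASS here.
print(rtc_arch_opt(80, 80))  # --gpu-architecture=compute_80
```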
11 changes: 11 additions & 0 deletions src/c_api/c_api_test.cc
@@ -26,6 +26,7 @@
#include <nnvm/pass.h>
#include "./c_api_common.h"
#include "../operator/subgraph/subgraph_property.h"
#include "../common/cuda/rtc.h"

int MXBuildSubgraphByOpNames(SymbolHandle sym_handle,
const char* prop_name,
@@ -128,3 +129,13 @@ int MXSetEnv(const char* name,
#endif
API_END();
}

int MXGetMaxSupportedArch(uint32_t *max_arch) {
  API_BEGIN();
#if MXNET_USE_CUDA
  *max_arch = static_cast<uint32_t>(mxnet::common::cuda::rtc::GetMaxSupportedArch());
#else
  LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation.";
#endif
  API_END();
}
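On a build without CUDA, the `LOG(FATAL)` branch above is turned by the `API_BEGIN()`/`API_END()` wrappers into a non-zero return, which `check_call` in the Python wrapper surfaces as an `MXNetError` (assuming the usual build configuration where `LOG(FATAL)` throws rather than aborts). A guarded-use sketch:

```python
from mxnet.base import MXNetError
from mxnet.util import get_max_supported_compute_capability

try:
    print("nvrtc can target up to sm_%d" % get_max_supported_compute_capability())
except MXNetError as err:
    print("MXNet was built without CUDA runtime compilation:", err)
```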
50 changes: 32 additions & 18 deletions src/common/cuda/rtc.cc
@@ -67,6 +67,33 @@ std::string to_string(OpReqType req) {

} // namespace util

int GetMaxSupportedArch() {
#if CUDA_VERSION < 10000
  constexpr int max_supported_sm_arch = 72;
#elif CUDA_VERSION < 11000
  constexpr int max_supported_sm_arch = 75;
#elif CUDA_VERSION < 11010
  constexpr int max_supported_sm_arch = 80;
#elif CUDA_VERSION < 11020
  constexpr int max_supported_sm_arch = 86;
#else
  // starting with cuda 11.2, nvrtc can report the max supported arch,
  // removing the need to update this routine with each new cuda version.
  static int max_supported_sm_arch = []() {
    int num_archs = 0;
    NVRTC_CALL(nvrtcGetNumSupportedArchs(&num_archs));
    std::vector<int> archs(num_archs);
    if (num_archs > 0) {
      NVRTC_CALL(nvrtcGetSupportedArchs(archs.data()));
    } else {
      LOG(FATAL) << "Could not determine supported cuda archs.";
    }
    return archs[num_archs - 1];
  }();
#endif
  return max_supported_sm_arch;
}

namespace {

// Obtain compilation log from the program.
Expand Down Expand Up @@ -97,27 +124,14 @@ std::string GetCompiledCode(nvrtcProgram program, bool use_cubin) {
}

std::tuple<bool, std::string> GetArchString(const int sm_arch) {
-#if CUDA_VERSION < 10000
-  constexpr int max_supported_sm_arch = 72;
-#elif CUDA_VERSION < 11000
-  constexpr int max_supported_sm_arch = 75;
-#elif CUDA_VERSION < 11010
-  constexpr int max_supported_sm_arch = 80;
-#else
-  constexpr int max_supported_sm_arch = 86;
-#endif
-
-#if CUDA_VERSION <= 11000
+  const int sm_arch_as_used = std::min(sm_arch, GetMaxSupportedArch());
  // Always use PTX for CUDA <= 11.0
-  const bool known_arch = false;
-#else
-  const bool known_arch = sm_arch <= max_supported_sm_arch;
-#endif
-  const int actual_sm_arch = std::min(sm_arch, max_supported_sm_arch);
+  const bool known_arch = (CUDA_VERSION > 11000) &&
+                          (sm_arch == sm_arch_as_used);
  if (known_arch) {
-    return {known_arch, "sm_" + std::to_string(actual_sm_arch)};
+    return {known_arch, "sm_" + std::to_string(sm_arch_as_used)};
  } else {
-    return {known_arch, "compute_" + std::to_string(actual_sm_arch)};
+    return {known_arch, "compute_" + std::to_string(sm_arch_as_used)};
  }
}

2 changes: 2 additions & 0 deletions src/common/cuda/rtc.h
@@ -53,6 +53,8 @@ std::string to_string(OpReqType req);

} // namespace util

int GetMaxSupportedArch();

extern std::mutex lock;

/*! \brief Compile and get the GPU kernel. Uses cache in order to
4 changes: 2 additions & 2 deletions src/common/rtc.cc
@@ -44,8 +44,8 @@ CudaModule::Chunk::Chunk(
<< "For lower version of CUDA, please prepend your kernel defintiions "
<< "with extern \"C\" instead.";
#endif
-  std::vector<const char*> c_options(options.size());
-  for (const auto& i : options) c_options.emplace_back(i.c_str());
+  std::vector<const char*> c_options;
+  for (const auto& i : options) c_options.push_back(i.c_str());
nvrtcResult compile_res = nvrtcCompileProgram(prog_, c_options.size(), c_options.data());
if (compile_res != NVRTC_SUCCESS) {
size_t err_size;
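For context on the replaced pair of lines: `std::vector<const char*> c_options(options.size())` value-initializes `options.size()` null pointers, and the subsequent `emplace_back` calls appended the real option strings after them, so `nvrtcCompileProgram` received a doubled option count with leading nulls. Starting from an empty vector and using `push_back` passes exactly the options supplied, such as the `--gpu-architecture` flag produced by `get_rtc_compile_opts`.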
17 changes: 10 additions & 7 deletions tests/python/gpu/test_operator_gpu.py
@@ -27,7 +27,7 @@
import mxnet.ndarray.sparse as mxsps
from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, assert_allclose
from mxnet.test_utils import check_symbolic_forward, check_symbolic_backward, discard_stderr
-from mxnet.test_utils import default_context, rand_shape_2d, rand_ndarray, same, environment
+from mxnet.test_utils import default_context, rand_shape_2d, rand_ndarray, same, environment, get_rtc_compile_opts
from mxnet.base import MXNetError
from mxnet import autograd

@@ -1796,6 +1796,7 @@ def test_autograd_save_memory():

@pytest.mark.serial
def test_cuda_rtc():
ctx = mx.gpu(0)
source = r'''
extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
@@ -1809,18 +1810,20 @@ def test_cuda_rtc():
y[i] += alpha * smem[threadIdx.x];
}
'''
-    module = mx.rtc.CudaModule(source)
+
+    compile_opts = get_rtc_compile_opts(ctx)
+    module = mx.rtc.CudaModule(source, options=compile_opts)
axpy = module.get_kernel("axpy", "const float *x, float *y, float alpha")
-    x = mx.nd.ones((10,), ctx=mx.gpu(0))
-    y = mx.nd.zeros((10,), ctx=mx.gpu(0))
-    axpy.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
+    x = mx.nd.ones((10,), ctx=ctx)
+    y = mx.nd.zeros((10,), ctx=ctx)
+    axpy.launch([x, y, 3.0], ctx, (1, 1, 1), (10, 1, 1))
assert (y.asnumpy() == 3).all()

saxpy = module.get_kernel("saxpy", "const float *x, float *y, float alpha")
-    saxpy.launch([x, y, 4.0], mx.gpu(0), (1, 1, 1), (10, 1, 1), 10)
+    saxpy.launch([x, y, 4.0], ctx, (1, 1, 1), (10, 1, 1), 10)
assert (y.asnumpy() == 7).all()

-    saxpy.launch([x, y, 5.0], mx.gpu(0), (2, 1, 1), (5, 1, 1), 5)
+    saxpy.launch([x, y, 5.0], ctx, (2, 1, 1), (5, 1, 1), 5)
assert (y.asnumpy() == 12).all()

