Skip to content

Commit a40f73f

Browse files
authored
[NVSHMEM] Extend CUDA backend to compile and link TIR modules with NVSHMEM (#18093)
1 parent bf752dc commit a40f73f

File tree

11 files changed

+218
-49
lines changed

11 files changed

+218
-49
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ endif()
832832

833833

834834
if (USE_CUDA AND USE_NVSHMEM)
835-
include_directories(SYSTEM ${USE_NVSHMEM}/include)
835+
target_include_directories(tvm_runtime_objs PUBLIC ${NVSHMEM_INCLUDE_DIR})
836836
find_library(NVSHMEM_HOST nvshmem_host ${NVSHMEM_LIB_DIR})
837837
find_library(NVSHMEM_DEVICE nvshmem_device ${NVSHMEM_LIB_DIR})
838838
target_link_libraries(tvm PRIVATE ${NVSHMEM_HOST} ${NVSHMEM_DEVICE})

include/tvm/runtime/disco/builtin.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,15 @@ inline std::string ReduceKind2String(ReduceKind kind) {
6262
* \param device The default device used to initialize the RelaxVM
6363
* \return The RelaxVM as a runtime Module
6464
*/
65-
TVM_DLL Module LoadVMModule(std::string path, Device device);
65+
TVM_DLL Module LoadVMModule(std::string path, Optional<Device> device);
6666
/*!
6767
* \brief Create an uninitialized empty NDArray
6868
* \param shape The shape of the NDArray
6969
* \param dtype The dtype of the NDArray
7070
* \param device The device the NDArray is created on. If None, use the thread local default device
7171
* \return The NDArray created
7272
*/
73-
TVM_DLL NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Device device);
73+
TVM_DLL NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Optional<Device> device);
7474
/*!
7575
* \brief Perform an allreduce operation using the underlying communication library
7676
* \param send The array send to perform allreduce on

python/tvm/contrib/nvcc.py

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os
2222
import subprocess
2323
import warnings
24+
from typing import Tuple
2425

2526
import tvm.ffi
2627
from tvm.target import Target
@@ -29,7 +30,7 @@
2930
from . import utils
3031

3132

32-
def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
33+
def compile_cuda(code, target_format=None, arch=None, options=None, path_target=None):
3334
"""Compile cuda code with NVCC from env.
3435
3536
Parameters
@@ -54,6 +55,15 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
5455
cubin : bytearray
5556
The bytearray of the cubin
5657
"""
58+
# Check for NVSHMEM dependency
59+
nvshmem_include_path, nvshmem_lib_path = None, None
60+
use_nvshmem = (
61+
tvm.get_global_func("runtime.nvshmem.cumodule_init", allow_missing=True) is not None
62+
)
63+
if use_nvshmem:
64+
target_format = "cubin"
65+
nvshmem_include_path, nvshmem_lib_path = find_nvshmem_paths()
66+
5767
if arch is None:
5868
# If None, then it will use `tvm.target.Target.current().arch`.
5969
# Target arch could be a str like "sm_xx", or a list, such as
@@ -68,6 +78,8 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
6878

6979
temp = utils.tempdir()
7080
file_name = "tvm_kernels"
81+
if target_format is None and not use_nvshmem:
82+
target_format = "ptx"
7183
if target_format not in ["cubin", "ptx", "fatbin"]:
7284
raise ValueError("target_format must be in cubin, ptx, fatbin")
7385
temp_code = temp.relpath(f"{file_name}.cu")
@@ -89,6 +101,9 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
89101
out_file.write(code)
90102

91103
file_target = path_target if path_target else temp_target
104+
if use_nvshmem:
105+
file_prefix = file_target.split(".")[0]
106+
file_target = f"{file_prefix}.o" # in the first stage, compile to object file
92107
cmd = ["nvcc"]
93108
cmd += [f"--{target_format}", "-O3"]
94109
if kernels_output_dir is not None:
@@ -107,7 +122,12 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
107122
raise ValueError("options must be str or list of str")
108123

109124
cmd += ["-o", file_target]
110-
cmd += [temp_code]
125+
if not use_nvshmem:
126+
cmd += [temp_code]
127+
else:
128+
cmd += ["-c", temp_code]
129+
cmd += ["-rdc=true"]
130+
cmd += ["-I", nvshmem_include_path]
111131

112132
# NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
113133
# just in case it is not in the path. On Windows it is not in the path by default.
@@ -127,6 +147,32 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
127147
msg += py_str(out)
128148
raise RuntimeError(msg)
129149

150+
# start second stage of compilation
151+
if use_nvshmem:
152+
cmd = ["nvlink"]
153+
cmd += [f"-arch=sm_{compute_version}"]
154+
cmd += [
155+
"-L",
156+
nvshmem_lib_path,
157+
]
158+
cmd += ["-L", os.path.join(find_cuda_path(), "lib64")]
159+
cmd += ["-l", "nvshmem_device"]
160+
cmd += ["-l", "cudadevrt"]
161+
cmd += ["-o", f"{file_prefix}.cubin"]
162+
cmd += [file_target]
163+
164+
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
165+
166+
(out, _) = proc.communicate()
167+
168+
if proc.returncode != 0:
169+
msg = code
170+
msg += "\nCompilation error:\n"
171+
msg += py_str(out)
172+
raise RuntimeError(msg)
173+
174+
file_target = f"{file_prefix}.cubin"
175+
130176
with open(file_target, "rb") as f:
131177
data = bytearray(f.read())
132178
if not data:
@@ -198,6 +244,70 @@ def get_cuda_version(cuda_path=None):
198244
raise RuntimeError("Cannot read cuda version file")
199245

200246

247+
def find_nvshmem_paths() -> Tuple[str, str]:
248+
"""
249+
Searches for the NVSHMEM include and library directories.
250+
Returns:
251+
A tuple containing the path to the include directory and the library directory.
252+
(include_path, lib_path)
253+
"""
254+
candidate_roots = []
255+
256+
# 1. NVSHMEM_HOME env variable
257+
if "NVSHMEM_HOME" in os.environ:
258+
candidate_roots.append(os.environ["NVSHMEM_HOME"])
259+
260+
# 2. CUDA Toolkit
261+
try:
262+
cuda_home = find_cuda_path()
263+
candidate_roots.append(cuda_home)
264+
except RuntimeError:
265+
pass
266+
267+
# 3. Other common system installation paths
268+
candidate_roots.extend(["/usr/local", "/usr"])
269+
270+
seen = set()
271+
unique_candidates = []
272+
for path in candidate_roots:
273+
if path and path not in seen:
274+
seen.add(path)
275+
unique_candidates.append(path)
276+
277+
for root in unique_candidates:
278+
include_path = os.path.join(root, "include")
279+
lib_paths_to_check = [
280+
os.path.join(root, "lib64"),
281+
os.path.join(root, "lib"),
282+
]
283+
284+
if os.path.isfile(os.path.join(include_path, "nvshmem.h")):
285+
for lib_path in lib_paths_to_check:
286+
if os.path.isfile(os.path.join(lib_path, "libnvshmem.a")):
287+
return include_path, lib_path
288+
289+
error_message = [
290+
"Error: Could not find NVSHMEM installation.",
291+
"Searched in the following locations:",
292+
]
293+
error_message.extend([f" - {path}" for path in unique_candidates])
294+
error_message.extend(
295+
[
296+
"",
297+
"Please ensure NVSHMEM is installed and try one of the following:",
298+
(
299+
" 1. Set the 'NVSHMEM_HOME' environment variable "
300+
"to your NVSHMEM installation directory."
301+
),
302+
(
303+
" 2. Ensure your CUDA Toolkit installation includes NVSHMEM and "
304+
"'nvcc' is on your PATH."
305+
),
306+
]
307+
)
308+
raise RuntimeError("\n".join(error_message))
309+
310+
201311
@tvm.ffi.register_func
202312
def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument
203313
"""use nvcc to generate fatbin code for better optimization"""

python/tvm/runtime/disco/session.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,6 @@ def empty(
150150
The created NDArray.
151151
152152
"""
153-
if device is None:
154-
device = Device(device_type=0, device_id=0)
155153
func = self._get_cached_method("runtime.disco.empty")
156154
return func(ShapeTuple(shape), dtype, device, worker0_only, in_group)
157155

@@ -237,6 +235,12 @@ def _sync_worker(self, worker_id: int) -> None:
237235
"""
238236
return _ffi_api.SessionSyncWorker(self, worker_id) # type: ignore # pylint: disable=no-member
239237

238+
def _sync_all(self) -> None:
239+
"""Synchronize the controller with all workers in the current session, and it will
240+
wait until all workers finish executing all the existing instructions."""
241+
for i in range(self.num_workers):
242+
self._sync_worker(i)
243+
240244
def sync_worker_0(self) -> None:
241245
"""Synchronize the controller with worker-0, and it will wait until the worker-0 finishes
242246
executing all the existing instructions."""
@@ -302,8 +306,6 @@ def load_vm_module(
302306
module : DModule
303307
The loaded VM module.
304308
"""
305-
if device is None:
306-
device = Device(device_type=0, device_id=0)
307309
func = self._get_cached_method("runtime.disco.load_vm_module")
308310
return DModule(func(path, device), self)
309311

src/runtime/contrib/nvshmem/init.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,26 @@ void InitNVSHMEMWrapper(String args) {
106106
InitNVSHMEM(uid_64, num_workers, worker_id_start);
107107
}
108108

109+
void NVSHMEMXCumoduleInit(void* cuModule) {
110+
CUmodule mod = static_cast<CUmodule>(cuModule);
111+
auto status = nvshmemx_init_status();
112+
// The NVSHMEM library must have completed device initialization prior to
113+
// nvshmemx_cumodule_init. If not, we skip the cumodule initialization.
114+
if (status == NVSHMEM_STATUS_IS_INITIALIZED || status == NVSHMEM_STATUS_LIMITED_MPG ||
115+
status == NVSHMEM_STATUS_FULL_MPG) {
116+
int result = nvshmemx_cumodule_init(mod);
117+
ICHECK_EQ(result, 0) << "nvshmemx_cumodule_init failed with error code: " << result;
118+
}
119+
}
120+
109121
TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);
110122

111123
TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);
112124

113125
TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper")
114126
.set_body_typed(InitNVSHMEMWrapper);
115127

128+
TVM_FFI_REGISTER_GLOBAL("runtime.nvshmem.cumodule_init").set_body_typed(NVSHMEMXCumoduleInit);
129+
116130
} // namespace runtime
117131
} // namespace tvm

src/runtime/cuda/cuda_module.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ class CUDAModuleNode : public runtime::ModuleNode {
108108
// must recheck under the lock scope
109109
if (module_[device_id] == nullptr) {
110110
CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
111+
static auto nvshmem_init_hook = ffi::Function::GetGlobal("runtime.nvshmem.cumodule_init");
112+
if (nvshmem_init_hook.has_value()) {
113+
(*nvshmem_init_hook)(static_cast<void*>(module_[device_id]));
114+
}
111115
}
112116
CUfunction func;
113117
CUresult result = cuModuleGetFunction(&func, module_[device_id], func_name.c_str());
@@ -124,6 +128,10 @@ class CUDAModuleNode : public runtime::ModuleNode {
124128
// must recheck under the lock scope
125129
if (module_[device_id] == nullptr) {
126130
CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
131+
static auto nvshmem_init_hook = ffi::Function::GetGlobal("runtime.nvshmem.cumodule_init");
132+
if (nvshmem_init_hook.has_value()) {
133+
(*nvshmem_init_hook)(static_cast<void*>(module_[device_id]));
134+
}
127135
}
128136
CUdeviceptr global;
129137
size_t nbytes;

src/runtime/disco/builtin.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,27 @@ class DSOLibraryCache {
4646
std::mutex mutex_;
4747
};
4848

49-
Module LoadVMModule(std::string path, Device device) {
49+
Module LoadVMModule(std::string path, Optional<Device> device) {
5050
static DSOLibraryCache cache;
5151
Module dso_mod = cache.Open(path);
52-
device = UseDefaultDeviceIfNone(device);
52+
Device dev = UseDefaultDeviceIfNone(device);
5353
ffi::Function vm_load_executable = dso_mod.GetFunction("vm_load_executable");
54-
CHECK(vm_load_executable != nullptr)
55-
<< "ValueError: File `" << path
56-
<< "` is not built by RelaxVM, because `vm_load_executable` does not exist";
54+
if (vm_load_executable == nullptr) {
55+
// not built by RelaxVM, return the dso_mod directly
56+
return dso_mod;
57+
}
5758
auto mod = vm_load_executable().cast<Module>();
5859
ffi::Function vm_initialization = mod.GetFunction("vm_initialization");
5960
CHECK(vm_initialization != nullptr)
6061
<< "ValueError: File `" << path
6162
<< "` is not built by RelaxVM, because `vm_initialization` does not exist";
62-
vm_initialization(static_cast<int>(device.device_type), static_cast<int>(device.device_id),
63+
vm_initialization(static_cast<int>(dev.device_type), static_cast<int>(dev.device_id),
6364
static_cast<int>(AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
6465
static_cast<int>(AllocatorType::kPooled));
6566
return mod;
6667
}
6768

68-
NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Device device) {
69+
NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Optional<Device> device) {
6970
return NDArray::Empty(shape, dtype, UseDefaultDeviceIfNone(device));
7071
}
7172

@@ -123,7 +124,7 @@ void SyncWorker() {
123124
TVM_FFI_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);
124125

125126
TVM_FFI_REGISTER_GLOBAL("runtime.disco.empty")
126-
.set_body_typed([](ffi::Shape shape, DataType dtype, Device device, bool worker0_only,
127+
.set_body_typed([](ffi::Shape shape, DataType dtype, Optional<Device> device, bool worker0_only,
127128
bool in_group) -> Optional<NDArray> {
128129
int worker_id = WorkerId();
129130
int group_size =

src/runtime/disco/utils.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,8 @@
2727
namespace tvm {
2828
namespace runtime {
2929

30-
inline Device UseDefaultDeviceIfNone(Device device) {
31-
if (device.device_type == 0 && device.device_id == 0) {
32-
return DiscoWorker::ThreadLocal()->default_device;
33-
}
34-
return device;
30+
inline Device UseDefaultDeviceIfNone(Optional<Device> device) {
31+
return device.value_or(DiscoWorker::ThreadLocal()->default_device);
3532
}
3633

3734
/*!

src/target/source/codegen_cuda.cc

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -297,19 +297,10 @@ std::string CodeGenCUDA::Finish() {
297297
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
298298
decl_stream << "#endif\n";
299299

300-
decl_stream << "\n#ifdef _WIN32\n";
301-
decl_stream << " using uint = unsigned int;\n";
302-
decl_stream << " using uchar = unsigned char;\n";
303-
decl_stream << " using ushort = unsigned short;\n";
304-
decl_stream << " using int64_t = long long;\n";
305-
decl_stream << " using uint64_t = unsigned long long;\n";
306-
decl_stream << "#else\n";
307-
decl_stream << " #define uint unsigned int\n";
308-
decl_stream << " #define uchar unsigned char\n";
309-
decl_stream << " #define ushort unsigned short\n";
310-
decl_stream << " #define int64_t long long\n";
311-
decl_stream << " #define uint64_t unsigned long long\n";
312-
decl_stream << "#endif\n";
300+
decl_stream << "#include <cstdint>\n";
301+
decl_stream << "using uint = unsigned int;\n";
302+
decl_stream << "using uchar = unsigned char;\n";
303+
decl_stream << "using ushort = unsigned short;\n";
313304

314305
return CodeGenC::Finish();
315306
}

0 commit comments

Comments (0)