2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -832,7 +832,7 @@ endif()


 if (USE_CUDA AND USE_NVSHMEM)
-  include_directories(SYSTEM ${USE_NVSHMEM}/include)
+  target_include_directories(tvm_runtime_objs PUBLIC ${NVSHMEM_INCLUDE_DIR})
   find_library(NVSHMEM_HOST nvshmem_host ${NVSHMEM_LIB_DIR})
   find_library(NVSHMEM_DEVICE nvshmem_device ${NVSHMEM_LIB_DIR})
   target_link_libraries(tvm PRIVATE ${NVSHMEM_HOST} ${NVSHMEM_DEVICE})
4 changes: 2 additions & 2 deletions include/tvm/runtime/disco/builtin.h
@@ -62,15 +62,15 @@ inline std::string ReduceKind2String(ReduceKind kind) {
  * \param device The default device used to initialize the RelaxVM
  * \return The RelaxVM as a runtime Module
  */
-TVM_DLL Module LoadVMModule(std::string path, Device device);
+TVM_DLL Module LoadVMModule(std::string path, Optional<Device> device);
 /*!
  * \brief Create an uninitialized empty NDArray
  * \param shape The shape of the NDArray
  * \param dtype The dtype of the NDArray
  * \param device The device the NDArray is created on. If None, use the thread local default device
  * \return The NDArray created
  */
-TVM_DLL NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Device device);
+TVM_DLL NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Optional<Device> device);
 /*!
  * \brief Perform an allreduce operation using the underlying communication library
  * \param send The array send to perform allreduce on
114 changes: 112 additions & 2 deletions python/tvm/contrib/nvcc.py
@@ -21,6 +21,7 @@
 import os
 import subprocess
 import warnings
+from typing import Tuple

 import tvm.ffi
 from tvm.target import Target
@@ -29,7 +30,7 @@
 from . import utils


-def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
+def compile_cuda(code, target_format=None, arch=None, options=None, path_target=None):
     """Compile cuda code with NVCC from env.

     Parameters
@@ -54,6 +55,15 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target
     cubin : bytearray
         The bytearray of the cubin
     """
+    # Check for NVSHMEM dependency
+    nvshmem_include_path, nvshmem_lib_path = None, None
+    use_nvshmem = (
+        tvm.get_global_func("runtime.nvshmem.cumodule_init", allow_missing=True) is not None
+    )
+    if use_nvshmem:
+        target_format = "cubin"
+        nvshmem_include_path, nvshmem_lib_path = find_nvshmem_paths()
+
     if arch is None:
         # If None, then it will use `tvm.target.Target.current().arch`.
         # Target arch could be a str like "sm_xx", or a list, such as
@@ -68,6 +78,8 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target

     temp = utils.tempdir()
     file_name = "tvm_kernels"
+    if target_format is None and not use_nvshmem:
+        target_format = "ptx"
     if target_format not in ["cubin", "ptx", "fatbin"]:
         raise ValueError("target_format must be in cubin, ptx, fatbin")
     temp_code = temp.relpath(f"{file_name}.cu")
@@ -89,6 +101,9 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target
             out_file.write(code)

     file_target = path_target if path_target else temp_target
+    if use_nvshmem:
+        file_prefix = file_target.split(".")[0]
+        file_target = f"{file_prefix}.o"  # in the first stage, compile to object file
     cmd = ["nvcc"]
     cmd += [f"--{target_format}", "-O3"]
     if kernels_output_dir is not None:
@@ -107,7 +122,12 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target
         raise ValueError("options must be str or list of str")

     cmd += ["-o", file_target]
-    cmd += [temp_code]
+    if not use_nvshmem:
+        cmd += [temp_code]
+    else:
+        cmd += ["-c", temp_code]
+        cmd += ["-rdc=true"]
+        cmd += ["-I", nvshmem_include_path]

     # NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
     # just in case it is not in the path. On Windows it is not in the path by default.
@@ -127,6 +147,32 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target
         msg += py_str(out)
         raise RuntimeError(msg)

+    # start second stage of compilation
+    if use_nvshmem:
+        cmd = ["nvlink"]
+        cmd += [f"-arch=sm_{compute_version}"]
+        cmd += [
+            "-L",
+            nvshmem_lib_path,
+        ]
+        cmd += ["-L", os.path.join(find_cuda_path(), "lib64")]
+        cmd += ["-l", "nvshmem_device"]
+        cmd += ["-l", "cudadevrt"]
+        cmd += ["-o", f"{file_prefix}.cubin"]
+        cmd += [file_target]
+
+        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+        (out, _) = proc.communicate()
+
+        if proc.returncode != 0:
+            msg = code
+            msg += "\nCompilation error:\n"
+            msg += py_str(out)
+            raise RuntimeError(msg)
+
+        file_target = f"{file_prefix}.cubin"
+
     with open(file_target, "rb") as f:
         data = bytearray(f.read())
     if not data:
@@ -198,6 +244,70 @@ def get_cuda_version(cuda_path=None):
     raise RuntimeError("Cannot read cuda version file")


+def find_nvshmem_paths() -> Tuple[str, str]:
+    """
+    Searches for the NVSHMEM include and library directories.
+    Returns:
+        A tuple containing the path to the include directory and the library directory:
+        (include_path, lib_path)
+    """
+    candidate_roots = []
+
+    # 1. NVSHMEM_HOME env variable
+    if "NVSHMEM_HOME" in os.environ:
+        candidate_roots.append(os.environ["NVSHMEM_HOME"])
+
+    # 2. CUDA Toolkit
+    try:
+        cuda_home = find_cuda_path()
+        candidate_roots.append(cuda_home)
+    except RuntimeError:
+        pass
+
+    # 3. Other common system installation paths
+    candidate_roots.extend(["/usr/local", "/usr"])
+
+    seen = set()
+    unique_candidates = []
+    for path in candidate_roots:
+        if path and path not in seen:
+            seen.add(path)
+            unique_candidates.append(path)
+
+    for root in unique_candidates:
+        include_path = os.path.join(root, "include")
+        lib_paths_to_check = [
+            os.path.join(root, "lib64"),
+            os.path.join(root, "lib"),
+        ]
+
+        if os.path.isfile(os.path.join(include_path, "nvshmem.h")):
+            for lib_path in lib_paths_to_check:
+                if os.path.isfile(os.path.join(lib_path, "libnvshmem.a")):
+                    return include_path, lib_path
+
+    error_message = [
+        "Error: Could not find NVSHMEM installation.",
+        "Searched in the following locations:",
+    ]
+    error_message.extend([f"  - {path}" for path in unique_candidates])
+    error_message.extend(
+        [
+            "",
+            "Please ensure NVSHMEM is installed and try one of the following:",
+            (
+                "  1. Set the 'NVSHMEM_HOME' environment variable "
+                "to your NVSHMEM installation directory."
+            ),
+            (
+                "  2. Ensure your CUDA Toolkit installation includes NVSHMEM and "
+                "'nvcc' is on your PATH."
+            ),
+        ]
+    )
+    raise RuntimeError("\n".join(error_message))
+
+
 @tvm.ffi.register_func
 def tvm_callback_cuda_compile(code, target):  # pylint: disable=unused-argument
     """use nvcc to generate fatbin code for better optimization"""
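Taken together, the detection logic and the two-stage pipeline above can be exercised end to end. Below is a hedged usage sketch; the kernel, the `sm_80` arch, and the `/opt/nvshmem` root are illustrative assumptions, and the NVSHMEM branch only engages in builds where the runtime hook is registered:

```python
# Usage sketch (assumptions: nvcc on PATH; NVSHMEM_HOME points at a real
# install when the NVSHMEM branch is taken; kernel and arch are illustrative).
import os

from tvm.contrib import nvcc

# find_nvshmem_paths() checks NVSHMEM_HOME first, then the CUDA toolkit,
# then /usr/local and /usr; it raises RuntimeError with search hints if
# nvshmem.h / libnvshmem.a cannot be found.
os.environ.setdefault("NVSHMEM_HOME", "/opt/nvshmem")  # hypothetical root
inc, lib = nvcc.find_nvshmem_paths()
print("NVSHMEM include:", inc, "lib:", lib)

kernel = r"""
extern "C" __global__ void add_one(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] += 1.0f;
}
"""
# target_format now defaults to None: it resolves to "ptx" in a plain CUDA
# build and is forced to "cubin" when the NVSHMEM hook is present, which
# triggers the nvcc -c -rdc=true ... / nvlink second stage shown above.
binary = nvcc.compile_cuda(kernel, arch="sm_80")
print(f"compiled {len(binary)} bytes")
```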
10 changes: 6 additions & 4 deletions python/tvm/runtime/disco/session.py
@@ -150,8 +150,6 @@ def empty(
             The created NDArray.

         """
-        if device is None:
-            device = Device(device_type=0, device_id=0)
         func = self._get_cached_method("runtime.disco.empty")
         return func(ShapeTuple(shape), dtype, device, worker0_only, in_group)

@@ -237,6 +235,12 @@ def _sync_worker(self, worker_id: int) -> None:
         """
         return _ffi_api.SessionSyncWorker(self, worker_id)  # type: ignore # pylint: disable=no-member

+    def _sync_all(self) -> None:
+        """Synchronize the controller with all workers in the current session, and it will
+        wait until all workers finish executing all the existing instructions."""
+        for i in range(self.num_workers):
+            self._sync_worker(i)
+
     def sync_worker_0(self) -> None:
         """Synchronize the controller with worker-0, and it will wait until the worker-0 finishes
         executing all the existing instructions."""
@@ -302,8 +306,6 @@ def load_vm_module(
         module : DModule
             The loaded VM module.
         """
-        if device is None:
-            device = Device(device_type=0, device_id=0)
A Member commented:

> is this change intentional?

@Kathryn-cat (Contributor, Author) replied on Jun 25, 2025:

> Yes. In the latest dlpack, DLDeviceType has enum values 1 to 17, so the value `Device(device_type=0, device_id=0)` would raise an unrecognized-device-type error. Since that value was only meant to indicate a null device, I replaced its subsequent uses with the `Optional<Device>` type; see the changes to `UseDefaultDeviceIfNone`.
         func = self._get_cached_method("runtime.disco.load_vm_module")
         return DModule(func(path, device), self)

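The net effect of these session changes is visible from the controller side. A minimal sketch, assuming a TVM build with disco enabled; the session constructor arguments and module path are illustrative:

```python
# Sketch of the updated disco API (session setup and paths are assumptions).
from tvm.runtime import disco

sess = disco.ProcessSession(num_workers=2)

# device=None is now forwarded as-is; each worker substitutes its
# thread-local default device via UseDefaultDeviceIfNone instead of the
# old Device(device_type=0, device_id=0) sentinel, which current dlpack
# rejects as an invalid device type.
arr = sess.empty((16, 16), "float32", device=None)
mod = sess.load_vm_module("/path/to/model.so", device=None)  # hypothetical path

# New helper added in this PR: block until every worker has drained its
# instruction queue, not just worker-0.
sess._sync_all()
```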
14 changes: 14 additions & 0 deletions src/runtime/contrib/nvshmem/init.cc
@@ -106,12 +106,26 @@ void InitNVSHMEMWrapper(String args) {
   InitNVSHMEM(uid_64, num_workers, worker_id_start);
 }

+void NVSHMEMXCumoduleInit(void* cuModule) {
+  CUmodule mod = static_cast<CUmodule>(cuModule);
+  auto status = nvshmemx_init_status();
+  // The NVSHMEM library must have completed device initialization prior to
+  // nvshmemx_cumodule_init. If not, we skip the cumodule initialization.
A Contributor commented:

> If the device is not initialized, we should return with an error.

@Kathryn-cat (Contributor, Author) replied:

> The design here is to enable NVSHMEM compilation and linking broadly for every kernel, including kernels whose NVSHMEM context is not initialized and that do not use NVSHMEM at all. In that case, nvshmemx_init_status() is used to check whether nvshmemx_cumodule_init needs to be called; if the device is not initialized, we simply skip nvshmemx_cumodule_init.
+  if (status == NVSHMEM_STATUS_IS_INITIALIZED || status == NVSHMEM_STATUS_LIMITED_MPG ||
+      status == NVSHMEM_STATUS_FULL_MPG) {
+    int result = nvshmemx_cumodule_init(mod);
+    ICHECK_EQ(result, 0) << "nvshmemx_cumodule_init failed with error code: " << result;
+  }
+}
+
 TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);

 TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);

 TVM_FFI_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper")
     .set_body_typed(InitNVSHMEMWrapper);

+TVM_FFI_REGISTER_GLOBAL("runtime.nvshmem.cumodule_init").set_body_typed(NVSHMEMXCumoduleInit);
+
 }  // namespace runtime
 }  // namespace tvm
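Because this global is registered only in NVSHMEM-enabled builds, its presence doubles as a feature flag, which is exactly the probe `compile_cuda` performs. A minimal check from Python, using the same call as the diff above:

```python
# Probe for the NVSHMEM runtime hook registered above; this mirrors the
# check in compile_cuda (no assumptions beyond a working TVM install).
import tvm

hook = tvm.get_global_func("runtime.nvshmem.cumodule_init", allow_missing=True)
if hook is None:
    print("NVSHMEM hooks absent: plain single-stage CUDA compilation is used.")
else:
    print("NVSHMEM hooks present: kernels are nvlink'd and cumodule-initialized.")
```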
8 changes: 8 additions & 0 deletions src/runtime/cuda/cuda_module.cc
@@ -108,6 +108,10 @@ class CUDAModuleNode : public runtime::ModuleNode {
     // must recheck under the lock scope
     if (module_[device_id] == nullptr) {
       CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
+      static auto nvshmem_init_hook = ffi::Function::GetGlobal("runtime.nvshmem.cumodule_init");
+      if (nvshmem_init_hook.has_value()) {
+        (*nvshmem_init_hook)(static_cast<void*>(module_[device_id]));
+      }
     }
     CUfunction func;
     CUresult result = cuModuleGetFunction(&func, module_[device_id], func_name.c_str());
@@ -124,6 +128,10 @@ class CUDAModuleNode : public runtime::ModuleNode {
     // must recheck under the lock scope
     if (module_[device_id] == nullptr) {
       CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
+      static auto nvshmem_init_hook = ffi::Function::GetGlobal("runtime.nvshmem.cumodule_init");
+      if (nvshmem_init_hook.has_value()) {
+        (*nvshmem_init_hook)(static_cast<void*>(module_[device_id]));
+      }
     }
     CUdeviceptr global;
     size_t nbytes;
17 changes: 9 additions & 8 deletions src/runtime/disco/builtin.cc
@@ -46,26 +46,27 @@ class DSOLibraryCache {
   std::mutex mutex_;
 };

-Module LoadVMModule(std::string path, Device device) {
+Module LoadVMModule(std::string path, Optional<Device> device) {
   static DSOLibraryCache cache;
   Module dso_mod = cache.Open(path);
-  device = UseDefaultDeviceIfNone(device);
+  Device dev = UseDefaultDeviceIfNone(device);
   ffi::Function vm_load_executable = dso_mod.GetFunction("vm_load_executable");
-  CHECK(vm_load_executable != nullptr)
-      << "ValueError: File `" << path
-      << "` is not built by RelaxVM, because `vm_load_executable` does not exist";
+  if (vm_load_executable == nullptr) {
+    // not built by RelaxVM, return the dso_mod directly
+    return dso_mod;
+  }
   auto mod = vm_load_executable().cast<Module>();
   ffi::Function vm_initialization = mod.GetFunction("vm_initialization");
   CHECK(vm_initialization != nullptr)
       << "ValueError: File `" << path
       << "` is not built by RelaxVM, because `vm_initialization` does not exist";
-  vm_initialization(static_cast<int>(device.device_type), static_cast<int>(device.device_id),
+  vm_initialization(static_cast<int>(dev.device_type), static_cast<int>(dev.device_id),
                     static_cast<int>(AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
                     static_cast<int>(AllocatorType::kPooled));
   return mod;
 }

-NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Device device) {
+NDArray DiscoEmptyNDArray(ffi::Shape shape, DataType dtype, Optional<Device> device) {
   return NDArray::Empty(shape, dtype, UseDefaultDeviceIfNone(device));
 }

@@ -123,7 +124,7 @@ void SyncWorker() {
 TVM_FFI_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);

 TVM_FFI_REGISTER_GLOBAL("runtime.disco.empty")
-    .set_body_typed([](ffi::Shape shape, DataType dtype, Device device, bool worker0_only,
+    .set_body_typed([](ffi::Shape shape, DataType dtype, Optional<Device> device, bool worker0_only,
                        bool in_group) -> Optional<NDArray> {
       int worker_id = WorkerId();
       int group_size =
7 changes: 2 additions & 5 deletions src/runtime/disco/utils.h
@@ -27,11 +27,8 @@
 namespace tvm {
 namespace runtime {

-inline Device UseDefaultDeviceIfNone(Device device) {
-  if (device.device_type == 0 && device.device_id == 0) {
-    return DiscoWorker::ThreadLocal()->default_device;
-  }
-  return device;
+inline Device UseDefaultDeviceIfNone(Optional<Device> device) {
+  return device.value_or(DiscoWorker::ThreadLocal()->default_device);
 }

 /*!
17 changes: 4 additions & 13 deletions src/target/source/codegen_cuda.cc
@@ -297,19 +297,10 @@ std::string CodeGenCUDA::Finish() {
   decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
   decl_stream << "#endif\n";

-  decl_stream << "\n#ifdef _WIN32\n";
A Contributor commented:

> Why change this?

@Kathryn-cat (Contributor, Author) replied:

> It's because NVSHMEM includes `<cstdint>`, which conflicts with the original `#define int64_t long long` and can lead to CUDA compilation errors. The `#define` semantics are quite error-prone, so I removed the defines and just use `using` declarations.

decl_stream << " using uint = unsigned int;\n";
decl_stream << " using uchar = unsigned char;\n";
decl_stream << " using ushort = unsigned short;\n";
decl_stream << " using int64_t = long long;\n";
decl_stream << " using uint64_t = unsigned long long;\n";
decl_stream << "#else\n";
decl_stream << " #define uint unsigned int\n";
decl_stream << " #define uchar unsigned char\n";
decl_stream << " #define ushort unsigned short\n";
decl_stream << " #define int64_t long long\n";
decl_stream << " #define uint64_t unsigned long long\n";
decl_stream << "#endif\n";
decl_stream << "#include <cstdint>\n";
decl_stream << "using uint = unsigned int;\n";
decl_stream << "using uchar = unsigned char;\n";
decl_stream << "using ushort = unsigned short;\n";

return CodeGenC::Finish();
}
Expand Down
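The author's point about `#define` being error-prone is easy to demonstrate. Below is a hedged repro sketch (assumes `nvcc` on PATH; file names and contents are illustrative): the macro rewrites `int64_t` everywhere, so the typedef inside `<cstdint>` expands into invalid C++, while a plain `using` alias coexists with the header.

```python
# Repro sketch of the macro/typedef conflict described above.
import os
import subprocess
import tempfile

# Old codegen style: the macro garbles <cstdint>'s `typedef ... int64_t;`.
BROKEN = "#define int64_t long long\n#include <cstdint>\nint main() { return 0; }\n"
# New codegen style: include <cstdint> and alias only the names it lacks.
FIXED = "#include <cstdint>\nusing uint = unsigned int;\nint main() { return 0; }\n"

with tempfile.TemporaryDirectory() as tmp:
    for name, src in [("broken.cu", BROKEN), ("fixed.cu", FIXED)]:
        path = os.path.join(tmp, name)
        with open(path, "w") as f:
            f.write(src)
        proc = subprocess.run(
            ["nvcc", "-c", path, "-o", path + ".o"], capture_output=True, text=True
        )
        # Expected: nonzero return code for broken.cu, success for fixed.cu.
        print(name, "->", proc.returncode)
```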