2121import os
2222import subprocess
2323import warnings
24+ from typing import Tuple
2425
2526import tvm .ffi
2627from tvm .target import Target
2930from . import utils
3031
3132
32- def compile_cuda (code , target_format = "ptx" , arch = None , options = None , path_target = None ):
33+ def compile_cuda (code , target_format = None , arch = None , options = None , path_target = None ):
3334 """Compile cuda code with NVCC from env.
3435
3536 Parameters
@@ -54,6 +55,15 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
5455 cubin : bytearray
5556 The bytearray of the cubin
5657 """
58+ # Check for NVSHMEM dependency
59+ nvshmem_include_path , nvshmem_lib_path = None , None
60+ use_nvshmem = (
61+ tvm .get_global_func ("runtime.nvshmem.cumodule_init" , allow_missing = True ) is not None
62+ )
63+ if use_nvshmem :
64+ target_format = "cubin"
65+ nvshmem_include_path , nvshmem_lib_path = find_nvshmem_paths ()
66+
5767 if arch is None :
5868 # If None, then it will use `tvm.target.Target.current().arch`.
5969 # Target arch could be a str like "sm_xx", or a list, such as
@@ -68,6 +78,8 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
6878
6979 temp = utils .tempdir ()
7080 file_name = "tvm_kernels"
81+ if target_format is None and not use_nvshmem :
82+ target_format = "ptx"
7183 if target_format not in ["cubin" , "ptx" , "fatbin" ]:
7284 raise ValueError ("target_format must be in cubin, ptx, fatbin" )
7385 temp_code = temp .relpath (f"{ file_name } .cu" )
@@ -89,6 +101,9 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
89101 out_file .write (code )
90102
91103 file_target = path_target if path_target else temp_target
104+ if use_nvshmem :
105+ file_prefix = file_target .split ("." )[0 ]
106+ file_target = f"{ file_prefix } .o" # in the first stage, compile to object file
92107 cmd = ["nvcc" ]
93108 cmd += [f"--{ target_format } " , "-O3" ]
94109 if kernels_output_dir is not None :
@@ -107,7 +122,12 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
107122 raise ValueError ("options must be str or list of str" )
108123
109124 cmd += ["-o" , file_target ]
110- cmd += [temp_code ]
125+ if not use_nvshmem :
126+ cmd += [temp_code ]
127+ else :
128+ cmd += ["-c" , temp_code ]
129+ cmd += ["-rdc=true" ]
130+ cmd += ["-I" , nvshmem_include_path ]
111131
112132 # NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
113133 # just in case it is not in the path. On Windows it is not in the path by default.
@@ -127,6 +147,32 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
127147 msg += py_str (out )
128148 raise RuntimeError (msg )
129149
150+ # start second stage of compilation
151+ if use_nvshmem :
152+ cmd = ["nvlink" ]
153+ cmd += [f"-arch=sm_{ compute_version } " ]
154+ cmd += [
155+ "-L" ,
156+ nvshmem_lib_path ,
157+ ]
158+ cmd += ["-L" , os .path .join (find_cuda_path (), "lib64" )]
159+ cmd += ["-l" , "nvshmem_device" ]
160+ cmd += ["-l" , "cudadevrt" ]
161+ cmd += ["-o" , f"{ file_prefix } .cubin" ]
162+ cmd += [file_target ]
163+
164+ proc = subprocess .Popen (cmd , stdout = subprocess .PIPE , stderr = subprocess .STDOUT )
165+
166+ (out , _ ) = proc .communicate ()
167+
168+ if proc .returncode != 0 :
169+ msg = code
170+ msg += "\n Compilation error:\n "
171+ msg += py_str (out )
172+ raise RuntimeError (msg )
173+
174+ file_target = f"{ file_prefix } .cubin"
175+
130176 with open (file_target , "rb" ) as f :
131177 data = bytearray (f .read ())
132178 if not data :
@@ -198,6 +244,70 @@ def get_cuda_version(cuda_path=None):
198244 raise RuntimeError ("Cannot read cuda version file" )
199245
200246
def find_nvshmem_paths() -> Tuple[str, str]:
    """Locate the NVSHMEM include and library directories.

    Candidate installation roots are probed in priority order:

    1. the ``NVSHMEM_HOME`` environment variable, when it is set;
    2. the root returned by ``find_cuda_path()`` (skipped when that call
       raises ``RuntimeError``);
    3. ``/usr/local`` and ``/usr``.

    A root is accepted when it contains ``include/nvshmem.h`` and a
    ``libnvshmem.a`` under ``lib64`` or ``lib`` (``lib64`` is checked first).

    Returns
    -------
    paths : Tuple[str, str]
        ``(include_path, lib_path)`` for the first root that matches.

    Raises
    ------
    RuntimeError
        If none of the candidate roots holds both the header and the library.
    """
    roots = [os.environ["NVSHMEM_HOME"]] if "NVSHMEM_HOME" in os.environ else []
    try:
        roots.append(find_cuda_path())
    except RuntimeError:
        # No CUDA toolkit located; keep probing the remaining candidates.
        pass
    roots += ["/usr/local", "/usr"]

    # dict.fromkeys drops duplicates while keeping first-seen order; empty
    # entries (e.g. NVSHMEM_HOME="") are discarded as well.
    search_roots = [root for root in dict.fromkeys(roots) if root]

    for root in search_roots:
        include_dir = os.path.join(root, "include")
        if not os.path.isfile(os.path.join(include_dir, "nvshmem.h")):
            continue
        for lib_name in ("lib64", "lib"):
            lib_dir = os.path.join(root, lib_name)
            if os.path.isfile(os.path.join(lib_dir, "libnvshmem.a")):
                return include_dir, lib_dir

    report_lines = [
        "Error: Could not find NVSHMEM installation.",
        "Searched in the following locations:",
        *(f" - {root} " for root in search_roots),
        "",
        "Please ensure NVSHMEM is installed and try one of the following:",
        " 1. Set the 'NVSHMEM_HOME' environment variable "
        "to your NVSHMEM installation directory.",
        " 2. Ensure your CUDA Toolkit installation includes NVSHMEM and "
        "'nvcc' is on your PATH.",
    ]
    raise RuntimeError("\n ".join(report_lines))
309+
310+
201311@tvm .ffi .register_func
202312def tvm_callback_cuda_compile (code , target ): # pylint: disable=unused-argument
203313 """use nvcc to generate fatbin code for better optimization"""
0 commit comments