@@ -326,10 +326,12 @@ def load_inline(
326326 cuda_sources: Sequence[str] | str, optional
327327 The CUDA source code. It can be a list of sources or a single source.
328328 functions: Mapping[str, str] | Sequence[str] | str, optional
329- The functions in cpp_sources that will be exported to the tvm ffi module. When a mapping is given, the keys
330- are the names of the exported functions, and the values are docstrings for the functions. When a sequence or a
331- single string is given, they are the functions needed to be exported, and the docstrings are set to empty
332- strings. A single function name can also be given as a string.
329+ The functions in cpp_sources or cuda_source that will be exported to the tvm ffi module. When a mapping is
330+ given, the keys are the names of the exported functions, and the values are docstrings for the functions. When
331+ a sequence or a single string is given, they are the functions needed to be exported, and the docstrings are set
332+ to empty strings. A single function name can also be given as a string. When cpp_sources is given, the functions
333+ must be declared (not necessarily defined) in the cpp_sources. When cpp_sources is not given, the functions
334+ must be defined in the cuda_sources. If not specified, no function will be exported.
333335 extra_cflags: Sequence[str], optional
334336 The extra compiler flags for C++ compilation.
335337 The default flags are:
@@ -369,6 +371,7 @@ def load_inline(
369371 elif isinstance (cuda_sources , str ):
370372 cuda_sources = [cuda_sources ]
371373 cuda_source = "\n " .join (cuda_sources )
374+ with_cpp = len (cpp_sources ) > 0
372375 with_cuda = len (cuda_sources ) > 0
373376
374377 extra_ldflags = extra_ldflags or []
@@ -381,8 +384,13 @@ def load_inline(
381384 functions = {functions : "" }
382385 elif isinstance (functions , Sequence ):
383386 functions = {name : "" for name in functions }
384- cpp_source = _decorate_with_tvm_ffi (cpp_source , functions )
385- cuda_source = _decorate_with_tvm_ffi (cuda_source , {})
387+
388+ if with_cpp :
389+ cpp_source = _decorate_with_tvm_ffi (cpp_source , functions )
390+ cuda_source = _decorate_with_tvm_ffi (cuda_source , {})
391+ else :
392+ cpp_source = _decorate_with_tvm_ffi (cpp_source , {})
393+ cuda_source = _decorate_with_tvm_ffi (cuda_source , functions )
386394
387395 # determine the cache dir for the built module
388396 if build_directory is None :
0 commit comments