3434def _hash_sources (
3535 cpp_source : str ,
3636 cuda_source : str ,
37- cpp_functions : Mapping [str , str ],
38- cuda_functions : Mapping [str , str ],
37+ functions : Sequence [str ] | Mapping [str , str ],
3938 extra_cflags : Sequence [str ],
4039 extra_cuda_cflags : Sequence [str ],
4140 extra_ldflags : Sequence [str ],
@@ -45,12 +44,13 @@ def _hash_sources(
4544 m = hashlib .sha256 ()
4645 m .update (cpp_source .encode ("utf-8" ))
4746 m .update (cuda_source .encode ("utf-8" ))
48- for name , doc in sorted (cpp_functions .items ()):
49- m .update (name .encode ("utf-8" ))
50- m .update (doc .encode ("utf-8" ))
51- for name , doc in sorted (cuda_functions .items ()):
52- m .update (name .encode ("utf-8" ))
53- m .update (doc .encode ("utf-8" ))
47+ if isinstance (functions , Mapping ):
48+ for name in sorted (functions ):
49+ m .update (name .encode ("utf-8" ))
50+ m .update (functions [name ].encode ("utf-8" ))
51+ else :
52+ for name in sorted (functions ):
53+ m .update (name .encode ("utf-8" ))
5454 for flag in extra_cflags :
5555 m .update (flag .encode ("utf-8" ))
5656 for flag in extra_cuda_cflags :
@@ -242,8 +242,10 @@ def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
242242 source ,
243243 ]
244244
245- for exported_name , func_name_in_source in functions .items ():
246- sources .append (f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({ exported_name } , { func_name_in_source } );" )
245+ for func_name , func_doc in functions .items ():
246+ sources .append (f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({ func_name } , { func_name } );" )
247+ _ = func_doc # todo: add support to embed function docstring to the tvm ffi functions.
248+
247249 sources .append ("" )
248250
249251 return "\n " .join (sources )
@@ -252,26 +254,26 @@ def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
252254def load_inline (
253255 name : str ,
254256 * ,
255- cpp_source : str | None = None ,
256- cuda_source : str | None = None ,
257- cpp_functions : Mapping [str , str ] | None = None ,
258- cuda_functions : Mapping [str , str ] | None = None ,
257+ cpp_sources : str | None = None ,
258+ cuda_sources : str | None = None ,
259+ functions : Sequence [str ] | None = None ,
259260 extra_cflags : Sequence [str ] | None = None ,
260261 extra_cuda_cflags : Sequence [str ] | None = None ,
261262 extra_ldflags : Sequence [str ] | None = None ,
262263 extra_include_paths : Sequence [str ] | None = None ,
264+ build_directory : Optional [str ] = None ,
263265) -> Module :
264266 """Compile and load a C++/CUDA tvm ffi module from inline source code.
265267
266- This function compiles the given C++ and/or CUDA source code into a shared library. Both cpp_source and cuda_source
267- are compiled to an object file, and then linked together into a shared library. It's possible to only provide
268- cpp_source or cuda_source .
268+ This function compiles the given C++ and/or CUDA source code into a shared library. Both cpp_sources and
269+ cuda_sources are compiled to an object file, and then linked together into a shared library. It's possible to only
270+ provide cpp_sources or cuda_sources .
269271
270- The `cpp_functions` and `cuda_functions` parameters are used to specify which functions in the source code
271- should be exported to the tvm ffi module. The keys of the mapping are the names of the exported functions, and the
272- values are the names of the functions in the source code. The exported name and the function name in the source code
273- must be different. The exported name must be a valid C identifier while the function name in the source code can
274- contain namespace qualifiers .
272+ The `functions` parameter is used to specify which functions in the source code should be exported to the tvm ffi module.
273+ It can be a mapping, a sequence, or a single string. When a mapping is given, the keys are the names of the exported
274+ functions, and the values are docstrings for the functions. When a sequence or a single string is given, they are the
275+ functions needed to be exported, and the docstrings are set to empty strings. A single function name can also be given
276+ as a string, indicating that only one function is to be exported .
275277
276278 Extra compiler and linker flags can be provided via the `extra_cflags`, `extra_cuda_cflags`, and `extra_ldflags`
277279 parameters. The default flags are generally sufficient for most use cases, but you may need to provide additional
@@ -281,22 +283,24 @@ def load_inline(
281283 any header from tvm ffi and dlpack in your source code. You can also provide additional include paths via the
282284 `extra_include_paths` parameter and include custom headers in your source code.
283285
284- The compiled shared library is cached in a cache directory to avoid recompilation. The cache directory can be
285- specified via the `TVM_FFI_CACHE_DIR` environment variable. If not specified, the default cache directory is
286- `~/.cache/tvm-ffi`.
286+ The compiled shared library is cached in a cache directory to avoid recompilation. The `build_directory` parameter
287+ is provided to specify the build directory. If not specified, a default tvm ffi cache directory will be used.
288+ The default cache directory can be specified via the `TVM_FFI_CACHE_DIR` environment variable. If not specified,
289+ the default cache directory is `~/.cache/tvm-ffi`.
287290
288291 Parameters
289292 ----------
290293 name: str
291294 The name of the tvm ffi module.
292- cpp_source: str, optional
293- The C++ source code.
294- cuda_source: str, optional
295- The CUDA source code.
296- cpp_functions: Mapping[str, str], optional
297- The mapping from the exported function name to the function name in the C++ source code.
298- cuda_functions: Mapping[str, str], optional
299- The mapping from the exported function name to the function name in the CUDA source code.
295+ cpp_sources: Sequence[str] | str, optional
296+ The C++ source code. It can be a list of sources or a single source.
297+ cuda_sources: Sequence[str] | str, optional
298+ The CUDA source code. It can be a list of sources or a single source.
299+ functions: Mapping[str, str] | Sequence[str] | str, optional
300+ The functions in cpp_sources that will be exported to the tvm ffi module. When a mapping is given, the keys
301+ are the names of the exported functions, and the values are docstrings for the functions. When a sequence or a
302+ single string is given, they are the functions needed to be exported, and the docstrings are set to empty
303+ strings. A single function name can also be given as a string.
300304 extra_cflags: Sequence[str], optional
301305 The extra compiler flags for C++ compilation.
302306 The default flags are:
@@ -316,46 +320,58 @@ def load_inline(
316320 The extra include paths.
317321 The default include paths are:
318322 - The include path of tvm ffi
323+ build_directory: str, optional
324+ The build directory. If not specified, a default tvm ffi cache directory will be used. By default, the
325+ cache directory is `~/.cache/tvm-ffi`. You can also set the `TVM_FFI_CACHE_DIR` environment variable to
326+ specify the cache directory.
327+
319328 Returns
320329 -------
321330 mod: Module
322331 The loaded tvm ffi module.
323332 """
324- if cpp_source is None :
325- cpp_source = ""
326- if cuda_source is None :
327- cuda_source = ""
328- if cpp_functions is None :
329- cpp_functions = {}
330- if cuda_functions is None :
331- cuda_functions = {}
333+ if cpp_sources is None :
334+ cpp_sources = []
335+ elif isinstance (cpp_sources , str ):
336+ cpp_sources = [cpp_sources ]
337+ cpp_source = "\n " .join (cpp_sources )
338+ if cuda_sources is None :
339+ cuda_sources = []
340+ elif isinstance (cuda_sources , str ):
341+ cuda_sources = [cuda_sources ]
342+ cuda_source = "\n " .join (cuda_sources )
343+ with_cuda = len (cuda_sources ) > 0
344+
332345 extra_ldflags = extra_ldflags or []
333346 extra_cflags = extra_cflags or []
334347 extra_cuda_cflags = extra_cuda_cflags or []
335348 extra_include_paths = extra_include_paths or []
336349
337- # whether we have cuda source in this module
338- with_cuda = len (cuda_source .strip ()) > 0
339-
340350 # add function registration code to sources
341- cpp_source = _decorate_with_tvm_ffi (cpp_source , cpp_functions )
342- cuda_source = _decorate_with_tvm_ffi (cuda_source , cuda_functions )
351+ if isinstance (functions , str ):
352+ functions = {functions : "" }
353+ elif isinstance (functions , Sequence ):
354+ functions = {name : "" for name in functions }
355+ cpp_source = _decorate_with_tvm_ffi (cpp_source , functions )
356+ cuda_source = _decorate_with_tvm_ffi (cuda_source , {})
343357
344358 # determine the cache dir for the built module
345- cache_dir = os .path .join (
346- os .environ .get ("TVM_FFI_CACHE_DIR" , os .path .expanduser ("~/.cache/tvm-ffi" ))
347- )
348- source_hash : str = _hash_sources (
349- cpp_source ,
350- cuda_source ,
351- cpp_functions ,
352- cuda_functions ,
353- extra_cflags ,
354- extra_cuda_cflags ,
355- extra_ldflags ,
356- extra_include_paths ,
357- )
358- build_dir : str = os .path .join (cache_dir , "{}_{}" .format (name , source_hash ))
359+ if build_directory is None :
360+ build_directory = os .environ .get (
361+ "TVM_FFI_CACHE_DIR" , os .path .expanduser ("~/.cache/tvm-ffi" )
362+ )
363+ source_hash : str = _hash_sources (
364+ cpp_source ,
365+ cuda_source ,
366+ functions ,
367+ extra_cflags ,
368+ extra_cuda_cflags ,
369+ extra_ldflags ,
370+ extra_include_paths ,
371+ )
372+ build_dir : str = os .path .join (build_directory , "{}_{}" .format (name , source_hash ))
373+ else :
374+ build_dir = os .path .abspath (build_directory )
359375 os .makedirs (build_dir , exist_ok = True )
360376
361377 # generate build.ninja
0 commit comments