Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions ffi/python/tvm_ffi/cpp/load_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,12 @@ def load_inline(
cuda_sources: Sequence[str] | str, optional
The CUDA source code. It can be a list of sources or a single source.
functions: Mapping[str, str] | Sequence[str] | str, optional
The functions in cpp_sources that will be exported to the tvm ffi module. When a mapping is given, the keys
are the names of the exported functions, and the values are docstrings for the functions. When a sequence or a
single string is given, they are the functions needed to be exported, and the docstrings are set to empty
strings. A single function name can also be given as a string.
The functions in cpp_sources or cuda_sources that will be exported to the tvm ffi module. When a mapping is
given, the keys are the names of the exported functions, and the values are docstrings for the functions. When
a sequence or a single string is given, it names the functions to be exported, and the docstrings are set
to empty strings. A single function name can also be given as a string. When cpp_sources is given, the functions
must be declared (not necessarily defined) in cpp_sources. When cpp_sources is not given, the functions
must be defined in cuda_sources. If not specified, no function will be exported.
extra_cflags: Sequence[str], optional
The extra compiler flags for C++ compilation.
The default flags are:
Expand Down Expand Up @@ -369,6 +371,7 @@ def load_inline(
elif isinstance(cuda_sources, str):
cuda_sources = [cuda_sources]
cuda_source = "\n".join(cuda_sources)
with_cpp = len(cpp_sources) > 0
with_cuda = len(cuda_sources) > 0

extra_ldflags = extra_ldflags or []
Expand All @@ -381,8 +384,13 @@ def load_inline(
functions = {functions: ""}
elif isinstance(functions, Sequence):
functions = {name: "" for name in functions}
cpp_source = _decorate_with_tvm_ffi(cpp_source, functions)
cuda_source = _decorate_with_tvm_ffi(cuda_source, {})

if with_cpp:
cpp_source = _decorate_with_tvm_ffi(cpp_source, functions)
cuda_source = _decorate_with_tvm_ffi(cuda_source, {})
else:
cpp_source = _decorate_with_tvm_ffi(cpp_source, {})
cuda_source = _decorate_with_tvm_ffi(cuda_source, functions)

# determine the cache dir for the built module
if build_directory is None:
Expand Down
3 changes: 0 additions & 3 deletions ffi/tests/python/test_load_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,6 @@ def test_load_inline_cpp_build_dir():
def test_load_inline_cuda():
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
void add_one_cuda(DLTensor* x, DLTensor* y);
""",
cuda_sources=r"""
__global__ void AddOneKernel(float* x, float* y, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down