Skip to content

Commit e153dad

Browse files
authored
[FFI] Update load_inline interface (apache#18307)
update load_inline interface
1 parent c35eaf1 commit e153dad

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

python/tvm_ffi/cpp/load_inline.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -326,10 +326,12 @@ def load_inline(
326326
cuda_sources: Sequence[str] | str, optional
327327
The CUDA source code. It can be a list of sources or a single source.
328328
functions: Mapping[str, str] | Sequence[str] | str, optional
329-
The functions in cpp_sources that will be exported to the tvm ffi module. When a mapping is given, the keys
330-
are the names of the exported functions, and the values are docstrings for the functions. When a sequence or a
331-
single string is given, they are the functions needed to be exported, and the docstrings are set to empty
332-
strings. A single function name can also be given as a string.
329+
The functions in cpp_sources or cuda_source that will be exported to the tvm ffi module. When a mapping is
330+
given, the keys are the names of the exported functions, and the values are docstrings for the functions. When
331+
a sequence or a single string is given, they are the functions needed to be exported, and the docstrings are set
332+
to empty strings. A single function name can also be given as a string. When cpp_sources is given, the functions
333+
must be declared (not necessarily defined) in the cpp_sources. When cpp_sources is not given, the functions
334+
must be defined in the cuda_sources. If not specified, no function will be exported.
333335
extra_cflags: Sequence[str], optional
334336
The extra compiler flags for C++ compilation.
335337
The default flags are:
@@ -369,6 +371,7 @@ def load_inline(
369371
elif isinstance(cuda_sources, str):
370372
cuda_sources = [cuda_sources]
371373
cuda_source = "\n".join(cuda_sources)
374+
with_cpp = len(cpp_sources) > 0
372375
with_cuda = len(cuda_sources) > 0
373376

374377
extra_ldflags = extra_ldflags or []
@@ -381,8 +384,13 @@ def load_inline(
381384
functions = {functions: ""}
382385
elif isinstance(functions, Sequence):
383386
functions = {name: "" for name in functions}
384-
cpp_source = _decorate_with_tvm_ffi(cpp_source, functions)
385-
cuda_source = _decorate_with_tvm_ffi(cuda_source, {})
387+
388+
if with_cpp:
389+
cpp_source = _decorate_with_tvm_ffi(cpp_source, functions)
390+
cuda_source = _decorate_with_tvm_ffi(cuda_source, {})
391+
else:
392+
cpp_source = _decorate_with_tvm_ffi(cpp_source, {})
393+
cuda_source = _decorate_with_tvm_ffi(cuda_source, functions)
386394

387395
# determine the cache dir for the built module
388396
if build_directory is None:

tests/python/test_load_inline.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,6 @@ def test_load_inline_cpp_build_dir():
159159
def test_load_inline_cuda():
160160
mod: Module = tvm_ffi.cpp.load_inline(
161161
name="hello",
162-
cpp_sources=r"""
163-
void add_one_cuda(DLTensor* x, DLTensor* y);
164-
""",
165162
cuda_sources=r"""
166163
__global__ void AddOneKernel(float* x, float* y, int n) {
167164
int idx = blockIdx.x * blockDim.x + threadIdx.x;

0 commit comments

Comments
 (0)