
Commit d3a5811

[FFI] Update the interface of ffi.load_inline to match torch (#18274)
This PR updates the interface of ffi.load_inline to match torch.utils.cpp_extension.load_inline (see the usage sketch below):

- Rename cpp_source to cpp_sources and cuda_source to cuda_sources.
- Unify cpp_functions and cuda_functions into a single functions parameter.
- Add build_directory so the user can specify the build directory directly.
1 parent e1700e1 commit d3a5811
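For orientation, the sketch below shows a minimal call through the new interface. It mirrors the updated example in ffi/examples/inline_module/main.py; the module name, the build_directory value, and the CPU-only source are illustrative rather than part of this commit.

    import tvm_ffi

    # Minimal sketch of the new load_inline interface (illustrative names and paths).
    mod = tvm_ffi.cpp.load_inline(
        name="hello_cpu",
        cpp_sources=r"""
        void add_one_cpu(DLTensor* x, DLTensor* y) {
          // assumes 1D float32 tensors, as in the full example below
          for (int64_t i = 0; i < x->shape[0]; ++i) {
            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
          }
        }
        """,
        functions=["add_one_cpu"],        # replaces cpp_functions={"exported_name": "name_in_source"}
        build_directory="./build_hello",  # new optional parameter; omit to use the default cache dir
    )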

File tree

3 files changed: +204 -85 lines changed


ffi/examples/inline_module/main.py

Lines changed: 7 additions & 6 deletions
@@ -23,8 +23,8 @@
def main():
    mod: Module = tvm_ffi.cpp.load_inline(
        name="hello",
-        cpp_source=r"""
-        void AddOne(DLTensor* x, DLTensor* y) {
+        cpp_sources=r"""
+        void add_one_cpu(DLTensor* x, DLTensor* y) {
          // implementation of a library function
          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
          DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -36,16 +36,18 @@ def main():
            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
          }
        }
+
+        void add_one_cuda(DLTensor* x, DLTensor* y);
        """,
-        cuda_source=r"""
+        cuda_sources=r"""
        __global__ void AddOneKernel(float* x, float* y, int n) {
          int idx = blockIdx.x * blockDim.x + threadIdx.x;
          if (idx < n) {
            y[idx] = x[idx] + 1;
          }
        }

-        void AddOneCUDA(DLTensor* x, DLTensor* y) {
+        void add_one_cuda(DLTensor* x, DLTensor* y) {
          // implementation of a library function
          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
          DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -67,8 +69,7 @@ def main():
                         static_cast<float*>(y->data), n);
        }
        """,
-        cpp_functions={"add_one_cpu": "AddOne"},
-        cuda_functions={"add_one_cuda": "AddOneCUDA"},
+        functions=["add_one_cpu", "add_one_cuda"],
    )

    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
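The hunk above ends where the input tensor x is created. A hedged sketch of how the rest of main() might drive the freshly built module follows; the attribute-style calls mod.add_one_cpu and mod.add_one_cuda are an assumption about the Module API, not something shown in this diff.

    # Hypothetical continuation of main() after the hunks above.
    y = torch.empty_like(x)
    mod.add_one_cpu(x, y)                 # exported name from functions=["add_one_cpu", "add_one_cuda"]
    print(y)                              # expected: tensor([2., 3., 4., 5., 6.])

    if torch.cuda.is_available():
        x_cuda = x.cuda()
        y_cuda = torch.empty_like(x_cuda)
        mod.add_one_cuda(x_cuda, y_cuda)  # compiled from cuda_sources, declared in cpp_sources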

ffi/python/tvm_ffi/cpp/load_inline.py

Lines changed: 76 additions & 60 deletions
@@ -34,8 +34,7 @@
def _hash_sources(
    cpp_source: str,
    cuda_source: str,
-    cpp_functions: Mapping[str, str],
-    cuda_functions: Mapping[str, str],
+    functions: Sequence[str] | Mapping[str, str],
    extra_cflags: Sequence[str],
    extra_cuda_cflags: Sequence[str],
    extra_ldflags: Sequence[str],
@@ -45,12 +44,13 @@ def _hash_sources(
    m = hashlib.sha256()
    m.update(cpp_source.encode("utf-8"))
    m.update(cuda_source.encode("utf-8"))
-    for name, doc in sorted(cpp_functions.items()):
-        m.update(name.encode("utf-8"))
-        m.update(doc.encode("utf-8"))
-    for name, doc in sorted(cuda_functions.items()):
-        m.update(name.encode("utf-8"))
-        m.update(doc.encode("utf-8"))
+    if isinstance(functions, Mapping):
+        for name in sorted(functions):
+            m.update(name.encode("utf-8"))
+            m.update(functions[name].encode("utf-8"))
+    else:
+        for name in sorted(functions):
+            m.update(name.encode("utf-8"))
    for flag in extra_cflags:
        m.update(flag.encode("utf-8"))
    for flag in extra_cuda_cflags:
@@ -242,8 +242,10 @@ def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
        source,
    ]

-    for exported_name, func_name_in_source in functions.items():
-        sources.append(f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({exported_name}, {func_name_in_source});")
+    for func_name, func_doc in functions.items():
+        sources.append(f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({func_name}, {func_name});")
+        _ = func_doc  # todo: add support to embed function docstring to the tvm ffi functions.
+
    sources.append("")

    return "\n".join(sources)
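To make the effect of the reworked export loop concrete, the short standalone sketch below reproduces just the loop from the hunk above and prints the registration lines it appends for two exported names; the surrounding boilerplate emitted by _decorate_with_tvm_ffi is elided.

    # Standalone illustration of the export loop above (docstrings are ignored for now).
    functions = {"add_one_cpu": "", "add_one_cuda": ""}

    lines = []
    for func_name, func_doc in functions.items():
        # With the unified interface the exported name and the source name are the same.
        lines.append(f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({func_name}, {func_name});")

    print("\n".join(lines))
    # TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, add_one_cpu);
    # TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, add_one_cuda);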
@@ -252,26 +254,26 @@ def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
def load_inline(
    name: str,
    *,
-    cpp_source: str | None = None,
-    cuda_source: str | None = None,
-    cpp_functions: Mapping[str, str] | None = None,
-    cuda_functions: Mapping[str, str] | None = None,
+    cpp_sources: str | None = None,
+    cuda_sources: str | None = None,
+    functions: Sequence[str] | None = None,
    extra_cflags: Sequence[str] | None = None,
    extra_cuda_cflags: Sequence[str] | None = None,
    extra_ldflags: Sequence[str] | None = None,
    extra_include_paths: Sequence[str] | None = None,
+    build_directory: Optional[str] = None,
) -> Module:
    """Compile and load a C++/CUDA tvm ffi module from inline source code.

-    This function compiles the given C++ and/or CUDA source code into a shared library. Both cpp_source and cuda_source
-    are compiled to an object file, and then linked together into a shared library. It's possible to only provide
-    cpp_source or cuda_source.
+    This function compiles the given C++ and/or CUDA source code into a shared library. Both cpp_sources and
+    cuda_sources are compiled to an object file, and then linked together into a shared library. It's possible to
+    provide only cpp_sources or only cuda_sources.

-    The `cpp_functions` and `cuda_functions` parameters are used to specify which functions in the source code
-    should be exported to the tvm ffi module. The keys of the mapping are the names of the exported functions, and the
-    values are the names of the functions in the source code. The exported name and the function name in the source code
-    must be different. The exported name must be a valid C identifier while the function name in the source code can
-    contain namespace qualifiers.
+    The `functions` parameter specifies which functions in the source code should be exported to the tvm ffi module.
+    It can be a mapping, a sequence, or a single string. When a mapping is given, the keys are the names of the
+    exported functions and the values are their docstrings. When a sequence is given, the listed functions are
+    exported and their docstrings are set to empty strings. A single function name may also be given as a string,
+    in which case only that function is exported.

    Extra compiler and linker flags can be provided via the `extra_cflags`, `extra_cuda_cflags`, and `extra_ldflags`
    parameters. The default flags are generally sufficient for most use cases, but you may need to provide additional
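The three accepted forms of the functions argument described in the new docstring correspond to calls like the sketch below; the source string and module names are placeholders rather than part of this commit.

    src = r"""
    void add_one_cpu(DLTensor* x, DLTensor* y) {
      for (int64_t i = 0; i < x->shape[0]; ++i) {
        static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
      }
    }
    """

    # Single string: export exactly one function.
    load_inline(name="m_str", cpp_sources=src, functions="add_one_cpu")
    # Sequence: export every listed function; docstrings default to "".
    load_inline(name="m_seq", cpp_sources=src, functions=["add_one_cpu"])
    # Mapping: exported name -> docstring (docstrings are not embedded yet, see _decorate_with_tvm_ffi).
    load_inline(name="m_map", cpp_sources=src, functions={"add_one_cpu": "adds one to a 1D float32 tensor"})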
@@ -281,22 +283,24 @@ def load_inline(
    any header from tvm ffi and dlpack in your source code. You can also provide additional include paths via the
    `extra_include_paths` parameter and include custom headers in your source code.

-    The compiled shared library is cached in a cache directory to avoid recompilation. The cache directory can be
-    specified via the `TVM_FFI_CACHE_DIR` environment variable. If not specified, the default cache directory is
-    `~/.cache/tvm-ffi`.
+    The compiled shared library is cached in a cache directory to avoid recompilation. The `build_directory`
+    parameter can be used to specify the build directory directly. If it is not given, a default tvm ffi cache
+    directory is used; that directory can be set via the `TVM_FFI_CACHE_DIR` environment variable and defaults to
+    `~/.cache/tvm-ffi`.

    Parameters
    ----------
    name: str
        The name of the tvm ffi module.
-    cpp_source: str, optional
-        The C++ source code.
-    cuda_source: str, optional
-        The CUDA source code.
-    cpp_functions: Mapping[str, str], optional
-        The mapping from the exported function name to the function name in the C++ source code.
-    cuda_functions: Mapping[str, str], optional
-        The mapping from the exported function name to the function name in the CUDA source code.
+    cpp_sources: Sequence[str] | str, optional
+        The C++ source code. It can be a list of sources or a single source.
+    cuda_sources: Sequence[str] | str, optional
+        The CUDA source code. It can be a list of sources or a single source.
+    functions: Mapping[str, str] | Sequence[str] | str, optional
+        The functions in cpp_sources that will be exported to the tvm ffi module. When a mapping is given, the keys
+        are the names of the exported functions and the values are their docstrings. When a sequence is given, the
+        listed functions are exported and their docstrings are set to empty strings. A single function name may
+        also be given as a string.
    extra_cflags: Sequence[str], optional
        The extra compiler flags for C++ compilation.
        The default flags are:
@@ -316,46 +320,58 @@ def load_inline(
        The extra include paths.
        The default include paths are:
        - The include path of tvm ffi
+    build_directory: str, optional
+        The build directory. If not specified, a default tvm ffi cache directory will be used. By default, the
+        cache directory is `~/.cache/tvm-ffi`. You can also set the `TVM_FFI_CACHE_DIR` environment variable to
+        specify the cache directory.
+
    Returns
    -------
    mod: Module
        The loaded tvm ffi module.
    """
-    if cpp_source is None:
-        cpp_source = ""
-    if cuda_source is None:
-        cuda_source = ""
-    if cpp_functions is None:
-        cpp_functions = {}
-    if cuda_functions is None:
-        cuda_functions = {}
+    if cpp_sources is None:
+        cpp_sources = []
+    elif isinstance(cpp_sources, str):
+        cpp_sources = [cpp_sources]
+    cpp_source = "\n".join(cpp_sources)
+    if cuda_sources is None:
+        cuda_sources = []
+    elif isinstance(cuda_sources, str):
+        cuda_sources = [cuda_sources]
+    cuda_source = "\n".join(cuda_sources)
+    with_cuda = len(cuda_sources) > 0
+
    extra_ldflags = extra_ldflags or []
    extra_cflags = extra_cflags or []
    extra_cuda_cflags = extra_cuda_cflags or []
    extra_include_paths = extra_include_paths or []

-    # whether we have cuda source in this module
-    with_cuda = len(cuda_source.strip()) > 0
-
    # add function registration code to sources
-    cpp_source = _decorate_with_tvm_ffi(cpp_source, cpp_functions)
-    cuda_source = _decorate_with_tvm_ffi(cuda_source, cuda_functions)
+    if isinstance(functions, str):
+        functions = {functions: ""}
+    elif isinstance(functions, Sequence):
+        functions = {name: "" for name in functions}
+    cpp_source = _decorate_with_tvm_ffi(cpp_source, functions)
+    cuda_source = _decorate_with_tvm_ffi(cuda_source, {})

    # determine the cache dir for the built module
-    cache_dir = os.path.join(
-        os.environ.get("TVM_FFI_CACHE_DIR", os.path.expanduser("~/.cache/tvm-ffi"))
-    )
-    source_hash: str = _hash_sources(
-        cpp_source,
-        cuda_source,
-        cpp_functions,
-        cuda_functions,
-        extra_cflags,
-        extra_cuda_cflags,
-        extra_ldflags,
-        extra_include_paths,
-    )
-    build_dir: str = os.path.join(cache_dir, "{}_{}".format(name, source_hash))
+    if build_directory is None:
+        build_directory = os.environ.get(
+            "TVM_FFI_CACHE_DIR", os.path.expanduser("~/.cache/tvm-ffi")
+        )
+        source_hash: str = _hash_sources(
+            cpp_source,
+            cuda_source,
+            functions,
+            extra_cflags,
+            extra_cuda_cflags,
+            extra_ldflags,
+            extra_include_paths,
+        )
+        build_dir: str = os.path.join(build_directory, "{}_{}".format(name, source_hash))
+    else:
+        build_dir = os.path.abspath(build_directory)
    os.makedirs(build_dir, exist_ok=True)

    # generate build.ninja
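As a recap of the new build-directory handling in the final hunk, the standalone sketch below re-implements just the resolution logic; resolve_build_dir is a hypothetical helper for illustration, and source_hash stands in for the value computed by _hash_sources.

    import os

    def resolve_build_dir(name: str, source_hash: str, build_directory: str | None = None) -> str:
        """Mirror of the build-directory resolution shown above (illustrative only)."""
        if build_directory is None:
            cache_dir = os.environ.get("TVM_FFI_CACHE_DIR", os.path.expanduser("~/.cache/tvm-ffi"))
            # cached builds are keyed by module name plus a hash of sources and flags
            return os.path.join(cache_dir, "{}_{}".format(name, source_hash))
        # an explicit build_directory is used as-is, with no hash-based subdirectory
        return os.path.abspath(build_directory)

    print(resolve_build_dir("hello", "abc123"))               # e.g. ~/.cache/tvm-ffi/hello_abc123
    print(resolve_build_dir("hello", "abc123", "./my_build")) # <cwd>/my_build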
