Commit a164d89

[Unity][nn.Module] Refactor ExternModule
`nn.ExternModule` allows handcrafted kernels to be incorporated into the compilation stack and invoked by Relax just like TIR or any other ordinary operator. This PR simplifies its workflow. The system consists of the abstract base class `ExternModule` and its two derivatives:

- `.o` (object files) can be linked in using `ObjectModule`.
- `.cpp` (C++ files) and `.cu` (CUDA files) can be compiled and linked into the system using `SourceModule`.

**Symbols, and shape/dtype inference.** To provide the system with sufficient information about the kernels, all symbols of an external module must be provided, along with a method for each symbol that tells the system the output dtype/shape of that symbol. Consider a case where function `my_func` accepts two tensors, `a` of shape `(x, y, 1)` and `b` of shape `(y, z, 5)`, and produces a tensor `c` of shape `(x, y, z, 9)`. The shape/dtype inference function should look like:

```python
def shape_dtype_inference(a, b):
    x, y, _ = a.shape
    _, z, _ = b.shape
    return nn.Tensor.placeholder((x, y, z, 9), dtype="float32")
```

Regarding the interface, the symbols and their corresponding shape/dtype inference functions should be provided as a Python dictionary that maps each symbol to its function, as below:

```python
symbols={
    "my_func": shape_dtype_inference,
}
```

**Calling convention.** All external modules now follow the "destination-passing style" (DPS) calling convention, which means the returned tensors are pre-allocated by the system and passed in as arguments of the external function. Reusing the example above, the implementation of `my_func` should include three parameters in its signature, where tensors are represented using `DLTensor` from DLPack, the de facto standard for in-memory representation of tensors. More info on DLPack: https://github.com/dmlc/dlpack/blob/v0.8/include/dlpack/dlpack.h#L163-L206. To expose the symbol, `TVM_DLL_EXPORT_TYPED_FUNC(symbol, function)` is guaranteed to be available:

```C++
// these headers are guaranteed to be available
#include <dlpack/dlpack.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/packed_func.h>

namespace {
// the anonymous namespace hides the symbol `_my_func_impl` from other TUs
int _my_func_impl(DLTensor* a, DLTensor* b, DLTensor* c) {
  // `a` and `b` are inputs, and `c` is the output
  return 0;
}
}  // namespace
// expose symbol `my_func` instead of `_my_func_impl`
TVM_DLL_EXPORT_TYPED_FUNC(my_func, _my_func_impl);
```

**A compiler pass `AttachExternModules`.** This pass attaches a list of `nn.ExternModule`s to an IRModule at any stage of the compilation pipeline, compiling them and recording the results as `runtime.Module`s in the IRModule's `external_mods` attribute. It is required for linking in `relax.build`, but with this pass, source compilation can be deferred to an arbitrary stage of TVM compilation.

**Caveats.** `nn.add_extern` must be called to register each external module exactly once during `export_tvm`. Each symbol should be registered exactly once to avoid potential conflicts; otherwise an error will be raised. This programming model might be a bit constraining, and we will consider loosening it slightly in the future. Also, for backward compatibility, `ExternModule`s are exported from `export_tvm` only when the `allow_extern` flag is turned on; otherwise, any external module will cause an exception asking to turn the flag on.
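Putting the pieces above together, the intended end-to-end workflow could look roughly like the sketch below. This is illustrative only: `MyModel`, the tensor shapes, the `my_func.cu` path, the exact `SourceModule` constructor arguments, the bracket-style symbol lookup on the module, the `nn.add_extern` call site, and the `export_tvm(..., allow_extern=True)` signature are assumptions inferred from this description rather than the authoritative API.

```python
from tvm.relax.frontend import nn


def shape_dtype_inference(a, b):
    # `my_func(a, b)` produces a float32 tensor of shape (x, y, z, 9)
    x, y, _ = a.shape
    _, z, _ = b.shape
    return nn.Tensor.placeholder((x, y, z, 9), dtype="float32")


class MyModel(nn.Module):
    def __init__(self):
        # Wrap a handcrafted CUDA kernel ("my_func.cu" is a hypothetical path)
        # and declare its exported symbols with their inference functions.
        self.ext = nn.SourceModule(
            source_code="my_func.cu",
            symbols={"my_func": shape_dtype_inference},
        )
        # Each external module must be registered exactly once.
        nn.add_extern(self.ext)

    def forward(self, a: nn.Tensor, b: nn.Tensor) -> nn.Tensor:
        # The external symbol is invoked like an ordinary operator; its output
        # buffer is pre-allocated by the system (DPS calling convention).
        return self.ext["my_func"](a, b)


# For backward compatibility, external modules are exported only when
# `allow_extern` is turned on. Depending on the exporter, the result may also
# include the compiled external runtime modules, so unpack defensively.
mod, params, *extern_mods = MyModel().export_tvm(
    spec={
        "forward": {
            "a": nn.spec.Tensor((2, 3, 1), "float32"),
            "b": nn.spec.Tensor((3, 4, 5), "float32"),
        }
    },
    allow_extern=True,
)
```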
1 parent f794db4 commit a164d89

File tree

17 files changed (+1234, -1047 lines)


python/tvm/contrib/cc.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -82,7 +82,7 @@ def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=N
         The compiler command.
 
     cwd : Optional[str]
-        The urrent working directory.
+        The current working directory.
 
     ccache_env : Optional[Dict[str, str]]
         The environment variable for ccache. Set `None` to disable ccache by default.
```

python/tvm/relax/frontend/nn/__init__.py

Lines changed: 3 additions & 10 deletions
```diff
@@ -17,15 +17,9 @@
 """A PyTorch-like API to build IRModules."""
 # pylint: disable=redefined-builtin
 from . import op, spec
-from .core import (
-    Effect,
-    ExternModule,
-    Module,
-    ModuleList,
-    Parameter,
-    SourceModule,
-    Tensor,
-)
+from .core import Effect, Module, ModuleList, Parameter, Tensor
+from .exporter import add_extern
+from .extern import ExternModule, ObjectModule, SourceModule
 from .modules import (
     GELU,
     Conv1D,
@@ -35,7 +29,6 @@
     KVCache,
     LayerNorm,
     Linear,
-    MultiLinear,
     ReLU,
     RMSNorm,
     SiLU,
```
