Commit a0e5898

[Unity][nn.Module] Refactor ExternModule (#16247)
`nn.ExternModule` allows handcrafted kernels to be incorporated into the
compilation stack and invoked by Relax just like TIR or any other
ordinary operator. This PR simplifies its workflow.
The system consists of the abstract base class `ExternModule` and its
two subclasses:
- `.o` (object files) can be linked in using `ObjectModule`.
- `.cpp` (C++ files) and `.cu` (CUDA files) can be compiled and linked
into the system using `SourceModule`.
**Symbols and shape/dtype inference.**
To provide the system with sufficient information about the kernels, an
external module must declare all of its symbols, along with a method for
each symbol that tells the system the shape and dtype of that symbol's
output.
Consider a function `my_func` that accepts two tensors, `a` of shape
`(x, y, 1)` and `b` of shape `(y, z, 5)`, and produces a tensor `c` of
shape `(x, y, z, 9)`. Its shape/dtype inference function should look
like:
```python
def shape_dtype_inference(a, b):
    x, y, _ = a.shape
    _, z, _ = b.shape
    return nn.Tensor.placeholder((x, y, z, 9), dtype="float32")
```
Regarding the interface, the symbols and their corresponding shape/dtype
inference functions should be provided as a Python dictionary that maps
each symbol to its function, as below:
```python
symbols={
    "my_func": shape_dtype_inference,
}
```
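Putting the pieces together, constructing an external module might look like the sketch below. This is a sketch only: the parameter names `source_code`, `source_format`, and `filepath` are assumptions about the constructor signatures, not verbatim API.

```python
from tvm.relax.frontend import nn

# Sketch, not verbatim API: parameter names are assumptions.
ext_mod = nn.SourceModule(
    symbols={"my_func": shape_dtype_inference},
    source_code="/path/to/my_func.cu",  # a path or an inline source string
    source_format="cu",                 # "cpp" for C++ sources
)

# `ObjectModule` is analogous, but wraps a pre-compiled object file
# instead of source code.
obj_mod = nn.ObjectModule(
    symbols={"my_func": shape_dtype_inference},
    filepath="/path/to/my_func.o",
)
```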
**Calling convention.**
All external modules now follow the "destination-passing style" (DPS)
calling convention, meaning the returned tensors are pre-allocated by
the system and passed in as arguments of the external function.
Reusing the example above, the implementation of `my_func` should include
three parameters in its signature, where tensors are represented using
`DLTensor` from DLPack, the de facto standard for the in-memory
representation of tensors. More info on DLPack:
https://github.com/dmlc/dlpack/blob/v0.8/include/dlpack/dlpack.h#L163-L206.
To expose the symbol, the macro `TVM_DLL_EXPORT_TYPED_FUNC(symbol, function)`
is guaranteed to be available:
```C++
// these headers are guaranteed to be available
#include <dlpack/dlpack.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/packed_func.h>

namespace {
// the anonymous namespace hides the symbol `_my_func_impl` from other TUs
int _my_func_impl(DLTensor* a, DLTensor* b, DLTensor* c) {
  // `a` and `b` are inputs, and `c` is the pre-allocated output
  return 0;
}
}  // namespace
// expose symbol `my_func` instead of `_my_func_impl`
TVM_DLL_EXPORT_TYPED_FUNC(my_func, _my_func_impl);
```
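On the Python side, DPS stays invisible to the caller: the system pre-allocates `c` according to the shape/dtype inference function. Below is a hedged sketch of what an invocation might look like from a frontend module, assuming an `op.extern(name, args, out)` helper; treat the exact spelling as an assumption.

```python
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op


class MyModule(nn.Module):
    def forward(self, a: nn.Tensor, b: nn.Tensor) -> nn.Tensor:
        x, y, _ = a.shape
        _, z, _ = b.shape
        # The output is pre-allocated by the system per DPS; `out` only
        # describes its shape/dtype, mirroring shape_dtype_inference.
        return op.extern(
            "my_func",
            [a, b],
            out=nn.Tensor.placeholder((x, y, z, 9), dtype="float32"),
        )
```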
**A compiler pass `AttachExternModules`.**
This pass is introduced to attach a list of `nn.ExternModule`s to an
IRModule at any stage of the compilation pipeline, attaching the compiled
external modules as `runtime.Module`s in the IRModule's `external_mods`
attribute. Linking in `relax.build` requires this pass, but thanks to it,
source compilation can be deferred to an arbitrary stage of TVM
compilation.
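As a sketch, applying the pass explicitly might look like the following; the import path and the argument type (a list of external modules, such as `ext_mod` from the earlier sketch) are assumptions based on the description above.

```python
from tvm import relax

# Assumed usage: compile and attach the external modules at some point
# in the pipeline, so that relax.build can link them later.
mod = relax.transform.AttachExternModules([ext_mod])(mod)
executable = relax.build(mod, target="llvm")
```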
**Caveats.**
It is required to call `nn.add_extern` to register external modules exactly
once during `export_tvm`. Each symbol should be registered exactly once to
avoid potential conflicts; otherwise an error will be raised. This
programming model might be a bit constraining, and we will consider
loosening it slightly in the future.
Also, for backward compatibility, `ExternModule`s are exported from
`export_tvm` only when the `allow_extern` flag is turned on. Otherwise, any
external module will cause an exception asking to turn the flag on.
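A minimal end-to-end sketch of the programming model described above; the module body, the `spec` layout, and the triple return value under `allow_extern=True` are illustrative assumptions, and `ext_mod` comes from the earlier sketch.

```python
from tvm.relax.frontend import nn


class MyModel(nn.Module):
    def __init__(self):
        # Register the external module exactly once; registering the
        # same symbol twice raises an error.
        nn.add_extern(ext_mod)

    def forward(self, a: nn.Tensor, b: nn.Tensor) -> nn.Tensor:
        ...  # e.g., the op.extern call sketched earlier


# Without `allow_extern=True`, exporting a model that registers
# external modules raises an exception.
mod, params, ext_mods = MyModel().export_tvm(  # return arity is an assumption
    spec={
        "forward": {
            "a": nn.spec.Tensor((4, 8, 1), "float32"),
            "b": nn.spec.Tensor((8, 16, 5), "float32"),
        }
    },
    allow_extern=True,
)
```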