@junrushao commented Dec 15, 2023


`nn.ExternModule` allows handcrafted kernels to be incorporated into the
compilation stack and invoked by Relax just like TIR or any other
ordinary operator. This PR simplifies its workflow.

The system consists of the abstract base class `ExternModule` and its
two derivatives:
- `.o` (object files) can be linked in using `ObjectModule`.
- `.cpp` (C++ files) and `.cu` (CUDA files) can be compiled and linked
  into the system using `SourceModule`.

**Symbols and shape/dtype inference.**
To give the system sufficient information about the kernels, all symbols
of an external module must be provided, together with a method for each
symbol that tells the system the shape/dtype of that symbol's output.

Consider a function `my_func` that accepts two tensors, `a` of shape
`(x, y, 1)` and `b` of shape `(y, z, 5)`, and produces a tensor `c` of
shape `(x, y, z, 9)`. Its shape/dtype inference function should look
like:

```python
def shape_dtype_inference(a, b):
    # `a` and `b` are the input tensors (`nn.Tensor`); their shapes may be symbolic
    x, y, _ = a.shape
    _, z, _ = b.shape
    # describe the output: shape (x, y, z, 9), dtype float32
    return nn.Tensor.placeholder((x, y, z, 9), dtype="float32")
```
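
For intuition, here is a small usage sketch of the inference function
above; the concrete shapes are made up for illustration, and the import
path of `nn` is an assumption:

```python
from tvm.relax.frontend import nn  # assumed import path

# hypothetical concrete shapes matching the (x, y, 1) / (y, z, 5) pattern
a = nn.Tensor.placeholder((4, 8, 1), dtype="float32")
b = nn.Tensor.placeholder((8, 16, 5), dtype="float32")

# returns a placeholder of shape (4, 8, 16, 9) and dtype "float32"
c = shape_dtype_inference(a, b)
```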

Regarding the interface, the symbols and their corresponding shape/dtype
inference functions should be provided as a Python dictionary that maps
each symbol to its function, as below:

```python
symbols={
  "my_func": shape_dtype_inference,
}
```
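
This dictionary is passed to the constructor of the chosen `ExternModule`
derivative. A hypothetical construction sketch is shown below; only the
`symbols` argument comes from this PR, while the remaining constructor
parameters (file path, source path, and their names) are assumptions to
be checked against the actual class signatures:

```python
from tvm.relax.frontend import nn  # assumed import path

# link a pre-built object file (the `filepath` parameter name is an assumption)
obj_mod = nn.ObjectModule(
    symbols={"my_func": shape_dtype_inference},
    filepath="my_func.o",
)

# or compile C++/CUDA source on the fly (parameter names are assumptions)
src_mod = nn.SourceModule(
    symbols={"my_func": shape_dtype_inference},
    source_code="my_func.cu",
    source_format="cu",
)
```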

**Calling convention.**
All external modules now follow the destination-passing-style (DPS)
calling convention: the returned tensors are pre-allocated by the system
and passed to the external function as arguments.

Reusing the example above, the implementation of `my_func` should take
three parameters, with each tensor represented as a `DLTensor` from
DLPack, the de facto standard for in-memory tensor representation. More
info on DLPack:
https://github.com/dmlc/dlpack/blob/v0.8/include/dlpack/dlpack.h#L163-L206.

To expose the symbol, the macro `TVM_DLL_EXPORT_TYPED_FUNC(symbol, function)`
is guaranteed to be available:

```C++
// these headers are guaranteed to be available
#include <dlpack/dlpack.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/packed_func.h>

namespace {
// the anonymous namespace hides the symbol `_my_func_impl` from other TUs
int _my_func_impl(DLTensor* a, DLTensor* b, DLTensor* c) {
  // `a` and `b` are inputs, and `c` is the pre-allocated output (DPS).
  // Shapes and data pointers are available via fields such as
  // `a->ndim`, `a->shape`, and `a->data`.
  return 0;
}
}  // namespace
// expose symbol `my_func` instead of `_my_func_impl`
TVM_DLL_EXPORT_TYPED_FUNC(my_func, _my_func_impl);
```

**A compiler pass `AttachExternModules`.**
This pass attaches a list of `nn.ExternModule`s to an IRModule at any
stage of the compilation pipeline, recording the compiled external
modules as `runtime.Module`s in the IRModule's `external_mods` attribute.
The linking step in `relax.build` requires this attribute, but with this
pass in place, source compilation can be deferred to an arbitrary stage
of TVM compilation.

**Caveats.**
External modules must be registered by calling `nn.add_extern` exactly
once during `export_tvm`. Each symbol must be registered exactly once to
avoid potential conflicts; otherwise an error is raised. This programming
model may feel a bit constraining, and we will consider loosening it
slightly in the future.

Also, for backward compatibility, `ExternModule`s are exported from
`export_tvm` only when the `allow_extern` flag is turned on; otherwise,
any external module raises an exception asking for the flag to be
enabled.
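
To tie the caveats together, a minimal registration sketch follows;
`ext_mod` and `model` are hypothetical placeholders, the exact point at
which `nn.add_extern` must be called should be verified (the text above
requires it to happen during `export_tvm`), and only `nn.add_extern` and
the `allow_extern` flag themselves come from this PR:

```python
from tvm.relax.frontend import nn  # assumed import path

# `ext_mod` stands for the ObjectModule/SourceModule built earlier
nn.add_extern(ext_mod)  # register exactly once; duplicate symbols raise an error

# `model` is an nn.Module instance and `spec` its export spec (placeholders);
# without allow_extern=True, the presence of an external module raises an exception
exported = model.export_tvm(spec=..., allow_extern=True)
```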