[Unity][nn.Module] Refactor ExternModule
#16247
Merged
`nn.ExternModule` allows handcrafted kernels to be incorporated into the compilation stack and invoked by Relax just like TIR or any other ordinary operator. This PR simplifies its workflow.

The system consists of the abstract base class `ExternModule` and its two derivatives:

- `.o` (object files) can be linked using `ObjectModule`.
- `.cpp` (C++ files) and `.cu` (CUDA files) can be compiled and linked into the system using `SourceModule`.

Symbols, and shape/dtype inference. To give the system sufficient information about the kernels, all symbols of an external module must be provided, together with a method for each symbol that tells the system the output dtype/shape of that symbol.
Consider a case where function `my_func` accepts two tensors, `a` of shape `(x, y, 1)` and `b` of shape `(y, z, 5)`, and produces a tensor `c` of shape `(x, y, z, 9)`. The shape/dtype inference function for `my_func` takes `a` and `b` and reports the shape `(x, y, z, 9)` and the dtype of `c`.

Regarding the interface, the symbols and their corresponding shape/dtype inference functions should be provided as a Python dictionary that maps each symbol to its function.
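The inference function and symbol dictionary described above might be sketched as follows. To keep the example runnable without TVM installed, `TensorSpec` is a hypothetical stand-in for whatever tensor placeholder type the real interface expects, and the `float32` output dtype is an assumption, since the description does not state one.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TensorSpec:
    """Hypothetical stand-in for a tensor placeholder: shape and dtype only."""
    shape: Tuple[int, ...]
    dtype: str


def shape_dtype_inference(a: TensorSpec, b: TensorSpec) -> TensorSpec:
    """Infer the output of `my_func` from its two inputs."""
    x, y, _ = a.shape  # a has shape (x, y, 1)
    _, z, _ = b.shape  # b has shape (y, z, 5)
    # The output dtype is assumed to be float32 for illustration.
    return TensorSpec(shape=(x, y, z, 9), dtype="float32")


# Each exported symbol maps to its shape/dtype inference function.
symbols = {
    "my_func": shape_dtype_inference,
}

out = symbols["my_func"](
    TensorSpec((2, 3, 1), "float32"),
    TensorSpec((3, 4, 5), "float32"),
)
print(out.shape, out.dtype)
```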
Calling convention. All external modules now follow the "destination-passing-style" (DPS) calling convention, which means the returned tensors are pre-allocated by the system and passed in as arguments of the external function.
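As a language-neutral illustration of destination-passing style, the sketch below uses plain Python lists rather than TVM tensors. `add_dps` and the caller-side allocation are hypothetical names chosen for this example, not part of TVM.

```python
from typing import List


def add_dps(a: List[float], b: List[float], out: List[float]) -> None:
    """DPS kernel: writes a + b into the caller-provided `out`, returns nothing."""
    assert len(a) == len(b) == len(out)
    for i in range(len(out)):
        out[i] = a[i] + b[i]


# The caller (playing the role of the system) allocates the destination first...
a = [1.0, 2.0, 3.0]
b = [10.0, 20.0, 30.0]
out = [0.0] * len(a)

# ...and passes it in as the trailing argument; the kernel fills it in place.
result = add_dps(a, b, out)
print(out)
```

Because the kernel itself never allocates, the caller keeps control over where and when the output memory is created.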
Reusing the example above, the implementation of `my_func` should include three parameters in its signature, where tensors are represented using `DLTensor` from DLPack, the de facto standard for in-memory representation of tensors. More info on DLPack: https://github.com/dmlc/dlpack/blob/v0.8/include/dlpack/dlpack.h#L163-L206.

To expose the symbol, `TVM_DLL_EXPORT_TYPED_FUNC(symbol, function)` is guaranteed to be available.

A compiler pass `AttachExternModules`. It is introduced to attach a list of `nn.ExternModule`s into an IRModule at any stage of the compilation pipeline, and to attach the compiled external modules as `runtime.Module`s into the IRModule's `external_mods` attribute. It is required for linking in `relax.build`, but with this pass in place, source compilation can be deferred to an arbitrary stage of TVM compilation.

Caveats. It is required to call `nn.add_extern` to register external modules exactly once during `export_tvm`. Each symbol should be registered exactly once to avoid potential conflicts; otherwise an error will be raised. This programming model might be somewhat constraining, and we will consider loosening it slightly in the future.

Also, for backward compatibility, `ExternModule`s are exported from `export_tvm` only when the `allow_extern` flag is turned on. Otherwise, any external module will cause an exception asking to turn on the flag.
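
The registration rules in the caveats can be pictured with a small stand-alone sketch. `ExternRegistry` and its methods are hypothetical and do not mirror TVM's actual implementation. They only mimic the two behaviors described: a duplicate symbol raises an error, and exporting with external modules present fails unless `allow_extern` is set.

```python
from typing import Callable, Dict


class ExternRegistry:
    """Hypothetical registry mimicking the caveats; not TVM's implementation."""

    def __init__(self) -> None:
        self._symbols: Dict[str, Callable] = {}

    def add_extern(self, symbols: Dict[str, Callable]) -> None:
        # Reject the whole batch if any symbol was registered before.
        duplicates = [name for name in symbols if name in self._symbols]
        if duplicates:
            raise ValueError(f"symbols already registered: {duplicates}")
        self._symbols.update(symbols)

    def export(self, allow_extern: bool = False) -> Dict[str, Callable]:
        # External modules are only exported when the flag is turned on.
        if self._symbols and not allow_extern:
            raise RuntimeError("external modules found; set allow_extern=True")
        return dict(self._symbols)


registry = ExternRegistry()
registry.add_extern({"my_func": lambda a, b: None})

try:
    registry.add_extern({"my_func": lambda a, b: None})
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True

try:
    registry.export()
    export_blocked = False
except RuntimeError:
    export_blocked = True

exported = registry.export(allow_extern=True)
print(duplicate_rejected, export_blocked, sorted(exported))
```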