Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using the VM across the RPC boundary. #7746

Merged
merged 10 commits into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ class Executable : public ModuleNode {
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;

/*!
* \brief Write the Executable to the binary stream in serialized form.
* \param stream The binary stream to save the executable to.
*/
void SaveToBinary(dmlc::Stream* stream) final;

/*!
* \brief Write the Executable to the provided path as a file contianing its serialized content.
* \param path The path to write the serialized data to.
* \param format The format of the serialized blob.
*/
void SaveToFile(const std::string& path, const std::string& format) final;
jroesch marked this conversation as resolved.
Show resolved Hide resolved
jroesch marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Serialize the executable into global section, constant section, and
* code section.
Expand Down Expand Up @@ -125,12 +138,37 @@ class Executable : public ModuleNode {
* \brief Get the `lib` module in an executable. Users have the flexibility to call
* `export_library` from the frontend to save the library to disk.
*
* \return The runtime module that contains the hardwre dependent code.
* \return The runtime module that contains the hardware dependent code.
*/
runtime::Module GetLib() const {
ICHECK_EQ(this->imports_.size(), 1)
<< "The kernel library must be imported as the only module in an Executable";

return this->imports_[0];
}

/*!
* \brief Set the `lib` module in an executable.
*
* This allows us to do partial initialization in the case of (de|ser)ialization cases.
* This method also ensures correct initialization of library ensuring we only Import a
* single library.
*
* NB: This also provides some abstraction over how libraries are stored as there are plans
* to iterate on the way runtime::Module works in the backend of the compiler.
*/
runtime::Module GetLib() const { return lib; }
void SetLib(const runtime::Module& lib) {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
ICHECK(lib.defined())
<< "the provided library can not be null";

ICHECK_EQ(this->imports().size(), 0)
<< "you can only import one device specific library";

this->Import(lib);
}

/*!
* \brief Get the arity of the VM Fucntion.
* \brief Get the arity of the VMFunction.
* \param func Function name.
* \return The number of parameters.
*/
Expand All @@ -148,9 +186,6 @@ class Executable : public ModuleNode {

const char* type_key() const final { return "VMExecutable"; }

/*! \brief The runtime module/library that contains both the host and also the device
* code when executing on non-CPU devices. */
runtime::Module lib;
/*! \brief The global constant pool. */
std::vector<ObjectRef> constants;
/*! \brief A map from globals (as strings) to their index in the function map. */
Expand Down
23 changes: 17 additions & 6 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,24 +269,35 @@ def _collect_dso_modules(self):
return self._collect_from_import_tree(is_dso_exportable)

def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
jroesch marked this conversation as resolved.
Show resolved Hide resolved
"""Export the module and its imported device code one library.
"""
Export the module and all imported modules into a single device library.

This function only works on host llvm modules.
It will pack all the imported modules
This function only works on host LLVM modules, other runtime::Module
subclasses will work with this API but they must support implement
the save and load mechanisms of modules completely including saving
from streams and files. This will pack your non-shared library module
into a single shared library which can later be loaded by TVM.

Parameters
----------
file_name : str
The name of the shared library.

fcompile : function(target, file_list, kwargs), optional
Compilation function to use create dynamic library.
The compilation function to use create the final library object during
export.

For example, when fcompile=_cc.create_shared, or when it is not supplied but
module is "llvm," this is used to link all produced artifacts
into a final dynamic library.

This behavior is controlled by the type of object exported.
If fcompile has attribute object_format, will compile host library
to that format. Otherwise, will use default format "o".

workspace_dir : str, optional
the path to a directory used to create intermediary
artifacts for the process exporting of the library.
The path of the directory used to create the intermediate
artifacts when exporting the module.
If this is not provided a temporary dir will be created.

kwargs : dict, optional
Expand Down
37 changes: 35 additions & 2 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np

import tvm
from tvm.runtime import Module
from tvm._ffi.runtime_ctypes import TVMByteArray
from tvm._ffi import base as _base
from .object import Object
Expand Down Expand Up @@ -299,12 +300,44 @@ class VirtualMachine(object):
POOLED_ALLOCATOR = 2

def __init__(self, exe, device, memory_cfg=None):
jroesch marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(exe, Executable):
"""
Construct a VirtualMachine wrapper class which provides a simple
interface over the raw C++ Module based API.

Parameters
----------
exe: Union[Executable, Module]
The executable either with the wrapper Python type or the raw runtime.Module.

In most cases this will be the Python wrapper class tvm.runtime.vm.Executable but
if you instead get the underlying runtime.Module subclass (i.e `exe.mod`) you
can directly pass it to this method.

This case can occur when doing things such as RPC where TVM's module APIs
return the raw modules, not the wrapped modules. This constructor will
handle this internally.

device: Union[Device, List[Device]]
The device, or devices on which to execute the VM code.

memory_cfg: Optional[str]
The allocator behavior to use for the VM.

Returns
-------
vm: VirtualMachine
A VM wrapper object.
"""
if not isinstance(exe, Executable) and not isinstance(exe, Module):
jroesch marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(
"exe is expected to be the type of Executable, "
+ "but received {}".format(type(exe))
)
self.module = _ffi_api._VirtualMachine(exe.module)

if not isinstance(exe, Executable):
exe = Executable(exe)

self.module = exe.mod["vm_load_executable"]()
self._exec = exe
self._init = self.module["init"]
self._invoke = self.module["invoke"]
Expand Down
8 changes: 5 additions & 3 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1155,18 +1155,20 @@ void VMCompiler::Codegen() {

auto compile_engine = CompileEngine::Global();
auto ext_mods = compile_engine->LowerExternalFunctions();
runtime::Module lib;
if (funcs.size() > 0) {
Map<String, IRModule> build_funcs;
for (const auto& i : funcs) {
build_funcs.Set(i.first, i.second);
}
exec_->lib = tvm::build(build_funcs, target_host_);
lib = tvm::build(build_funcs, target_host_);
} else {
// There is no function handled by TVM. We create a virtual main module
// to make sure a DSO module will be also available.
exec_->lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
}
exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods, target_host_);
lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_);
exec_->SetLib(lib);
jroesch marked this conversation as resolved.
Show resolved Hide resolved
}

ExprDeviceMap VMCompiler::AnalyzeContext() const {
Expand Down
42 changes: 23 additions & 19 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,28 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
#undef TVM_INIT_CONTEXT_FUNC
}

Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
std::string loadkey = "runtime.module.loadbinary_";
std::string fkey = loadkey + type_key;
const PackedFunc* f = Registry::Get(fkey);
if (f == nullptr) {
std::string loaders = "";
for (auto name : Registry::ListNames()) {
if (name.find(loadkey, 0) == 0) {
if (loaders.size() > 0) {
loaders += ", ";
}
loaders += name.substr(loadkey.size());
}
}
LOG(FATAL) << "Binary was created using " << type_key
jroesch marked this conversation as resolved.
Show resolved Hide resolved
<< " but a loader of that name is not registered. Available loaders are "
<< loaders << ". Perhaps you need to recompile with this runtime enabled.";
}

return (*f)(static_cast<void*>(stream));
}

/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
Expand Down Expand Up @@ -133,25 +155,7 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
ICHECK(stream->Read(&import_tree_row_ptr));
ICHECK(stream->Read(&import_tree_child_indices));
} else {
std::string loadkey = "runtime.module.loadbinary_";
std::string fkey = loadkey + tkey;
const PackedFunc* f = Registry::Get(fkey);
if (f == nullptr) {
std::string loaders = "";
for (auto name : Registry::ListNames()) {
if (name.rfind(loadkey, 0) == 0) {
if (loaders.size() > 0) {
loaders += ", ";
}
loaders += name.substr(loadkey.size());
}
}
ICHECK(f != nullptr)
<< "Binary was created using " << tkey
<< " but a loader of that name is not registered. Available loaders are " << loaders
<< ". Perhaps you need to recompile with this runtime enabled.";
}
Module m = (*f)(static_cast<void*>(stream));
auto m = LoadModuleFromBinary(tkey, stream);
modules.emplace_back(m);
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/library_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,21 @@
#include <tvm/runtime/module.h>

#include <functional>
#include <string>

namespace tvm {
namespace runtime {

/*! \brief Load a module with the given type key directly from the stream.
* This function wraps the registry mechanism used to store type based deserializers
* for each runtime::Module sub-class.
*
* \param type_key The type key of the serialized module.
* \param stream A pointer to the stream containing the serialized module.
* \return module The deserialized module.
*/
Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream);

/*!
* \brief Library is the common interface
* for storing data in the form of shared libaries.
Expand Down
58 changes: 57 additions & 1 deletion src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
#include <utility>
#include <vector>

#include "../file_utils.h"
#include "../library_module.h"
#include "serialize_utils.h"

namespace tvm {
Expand Down Expand Up @@ -74,6 +76,12 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje
int index = args[1];
*rv = this->GetFunctionParameterName(func_name, index);
});
} else if (name == "vm_load_executable") {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
auto vm = make_object<VirtualMachine>();
vm->LoadExecutable(this);
*rv = Module(vm);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc(nullptr);
Expand Down Expand Up @@ -477,7 +485,17 @@ void LoadHeader(dmlc::Stream* strm) {

runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
auto exec = make_object<Executable>();
exec->lib = lib;

// Support null-initialization of lib, to enable initialization during
// deserialization before we have we have deserialized the imports.
if (lib.defined()) {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
ICHECK_EQ(exec->imports_.size(), 0)
jroesch marked this conversation as resolved.
Show resolved Hide resolved
<< "A VMExecutable should never have more than one import inside an the executable, \n"
<< "the first import should *always* be the library containing"
<< "the platform specific kernel code";
exec->Import(lib);
jroesch marked this conversation as resolved.
Show resolved Hide resolved
}

exec->code_ = code;
dmlc::MemoryStringStream strm(&exec->code_);

Expand Down Expand Up @@ -765,6 +783,44 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) {
}
}

void Executable::SaveToBinary(dmlc::Stream* stream) {
auto code_bytes = this->Save();
std::string code(code_bytes.data, code_bytes.size);
stream->Write(code);

ICHECK(this->imports()[0].defined()) << "the library must be imported before serialization";
}

Module ExecutableLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::string code;
stream->Read(&code);
auto exec = Executable::Load(code, Module());
return exec;
}

void Executable::SaveToFile(const std::string& path, const std::string& format) {
std::string data;
dmlc::MemoryStringStream writer(&data);
dmlc::SeekStream* strm = &writer;
SaveToBinary(strm);
SaveBinaryToFile(path, data);
}

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_VMExecutable").set_body_typed(ExecutableLoadBinary);

// Load module from module.
Module ExecutableLoadFile(const std::string& file_name, const std::string& format) {
std::string data;
LoadBinaryFromFile(file_name, &data);
dmlc::MemoryStringStream reader(&data);
dmlc::Stream* strm = &reader;
auto exec = ExecutableLoadBinary(reinterpret_cast<void*>(strm));
return exec;
}

TVM_REGISTER_GLOBAL("runtime.module.loadfile_VMExecutable").set_body_typed(ExecutableLoadFile);

TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->());
Expand Down
9 changes: 5 additions & 4 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,12 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
ICHECK(exec) << "The executable is not created yet.";
exec_ = exec;

runtime::Module lib = exec_->lib;
// Get the list of packed functions.
runtime::Module lib = exec_->GetLib();

ICHECK(exec->primitive_map.empty() || lib.operator->())
<< "runtime module should have been built for primitive functions"
<< "\n";
<< "If the executable has declared primitive functions, the"
<< "generated kernel library must non-be null.";

for (const auto& it : exec_->primitive_map) {
const auto& packed_name = it.first;
auto packed_index = static_cast<size_t>(it.second);
Expand Down
Loading