Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using the VM across the RPC boundary. #7746

Merged
merged 10 commits into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,16 @@ class Executable : public ModuleNode {
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;

/*!
* \brief Save the entire executable to a binary stream.
* \param stream The binary stream to save to.
* \brief Write the Executable to the binary stream in serialized form.
* \param stream The binary stream to save the executable to.
*/
void SaveToBinary(dmlc::Stream* stream) final;

/*!
* \brief Write the Executable to the provided path as a file contianing its serialized content.
* \param path The path to write the serialized data to.
* \param format The format of the serialized blob.
*/
void SaveToFile(const std::string& path, const std::string& format) final;
jroesch marked this conversation as resolved.
Show resolved Hide resolved
jroesch marked this conversation as resolved.
Show resolved Hide resolved

/*!
Expand Down Expand Up @@ -135,18 +140,35 @@ class Executable : public ModuleNode {
*
* \return The runtime module that contains the hardware dependent code.
*/
runtime::Module GetLib() const { return this->imports_[0]; }
runtime::Module GetLib() const {
ICHECK_EQ(this->imports_.size(), 1)
<< "The kernel library must be imported as the only module in an Executable";

return this->imports_[0];
}

/*!
* \brief Set the `lib` module in an executable.
*
* This allows us to do partial initialization in the case of (de|ser)ialization cases.
* This method also ensures correct initialization of library ensuring we only Import a
* single library.
*
* NB: This also provides some abstraction over how libraries are stored as there are plans
* to iterate on the way runtime::Module works in the backend of the compiler.
*/
void SetLib(const runtime::Module& lib) {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
ICHECK(lib.defined()) << "library can not be null";
ICHECK(lib.defined())
<< "the provided library can not be null";

ICHECK_EQ(this->imports().size(), 0) << "can only import the library once";
ICHECK_EQ(this->imports().size(), 0)
<< "you can only import one device specific library";

this->Import(lib);
}

/*!
* \brief Get the arity of the VM Fucntion.
* \brief Get the arity of the VMFunction.
* \param func Function name.
* \return The number of parameters.
*/
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No
"""
Export the module and all imported modules into a single device library.

This function only works on hos LLVM modules, other runtime::Module
This function only works on host LLVM modules, other runtime::Module
subclasses will work with this API but they must support implement
the save and load mechanisms of modules completely including saving
from streams and files. This will pack your non-shared library module
Expand All @@ -285,8 +285,12 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No

fcompile : function(target, file_list, kwargs), optional
The compilation function to use create the final library object during
export. For example this is used to link together all produced artifacts
export.

For example, when fcompile=_cc.create_shared, or when it is not supplied but
module is "llvm," this is used to link all produced artifacts
into a final dynamic library.

This behavior is controlled by the type of object exported.
If fcompile has attribute object_format, will compile host library
to that format. Otherwise, will use default format "o".
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,26 @@ class VirtualMachine(object):
POOLED_ALLOCATOR = 2

def __init__(self, exe, device, memory_cfg=None):
jroesch marked this conversation as resolved.
Show resolved Hide resolved
"""
Construct a VirtualMachine wrapper class which provides a simple
interface over the raw C++ Module based API.

Parameters
----------
exe: Union[Executable, Module]
The executable either with the wrapper Python type or the raw runtime.Module.

device: Union[Device, List[Device]]
The device, or devices on which to execute the VM code.

memory_cfg: Optional[str]
The allocator behavior to use for the VM.

Returns
-------
vm: VirtualMachine
A VM wrapper object.
"""
if not isinstance(exe, Executable) and not isinstance(exe, Module):
jroesch marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(
"exe is expected to be the type of Executable, "
Expand Down
8 changes: 4 additions & 4 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,16 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
if (f == nullptr) {
std::string loaders = "";
for (auto name : Registry::ListNames()) {
if (name.rfind(loadkey, 0) == 0) {
if (name.find(loadkey, 0) == 0) {
if (loaders.size() > 0) {
loaders += ", ";
}
loaders += name.substr(loadkey.size());
}
}
ICHECK(f != nullptr) << "Binary was created using " << type_key
<< " but a loader of that name is not registered. Available loaders are "
<< loaders << ". Perhaps you need to recompile with this runtime enabled.";
LOG(FATAL) << "Binary was created using " << type_key
jroesch marked this conversation as resolved.
Show resolved Hide resolved
<< " but a loader of that name is not registered. Available loaders are "
<< loaders << ". Perhaps you need to recompile with this runtime enabled.";
}

return (*f)(static_cast<void*>(stream));
Expand Down
1 change: 0 additions & 1 deletion src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,6 @@ void LoadHeader(dmlc::Stream* strm) {
}

runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
std::cout << "code: " << code.size() << std::endl;
auto exec = make_object<Executable>();

// Support null-initialization of lib, to enable initialization during
Expand Down
16 changes: 16 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,27 +803,43 @@ def test_constant_shape_with_external_codegen():


def test_vm_rpc():
"""
This test checks to make sure you can export a VMExecutable,
upload it to a remote machine using RPC and then execute it
on the other machine.
"""
target = "llvm"
target_host = "llvm"

# Build a IRModule.
x = relay.var("x", shape=(10, 1))
f = relay.Function([x], x + x)
mod = IRModule.from_expr(f)

# Compile to VMExecutable.
vm_exec = vm.compile(mod, target=target, target_host=target_host)

# Export to Disk
temp = utils.tempdir()
jroesch marked this conversation as resolved.
Show resolved Hide resolved
path = temp.relpath("vm_library.so")
vm_exec.mod.export_library(path)

# Use LocalRPC for testing.
remote = rpc.LocalSession()

# Upload the serialized Executable.
remote.upload(path)
# Get a handle to remote Executable.
rexec = remote.load_module("vm_library.so")

ctx = remote.cpu()
# Build a VM out of the executable and context.
vm_factory = runtime.vm.VirtualMachine(rexec, ctx)
np_input = np.random.uniform(size=(10, 1)).astype("float32")
input_tensor = tvm.nd.array(np_input, ctx)
# Invoke its "main" function.
out = vm_factory.invoke("main", [input_tensor])
# Check the result.
np.testing.assert_allclose(out.asnumpy(), np_input + np_input)


Expand Down