Skip to content

Commit 8f60213

Browse files
authored
[Runtime] Serialization/Deserialization of runtime module (#15244)
1 parent 0166400 commit 8f60213

File tree

11 files changed

+344
-164
lines changed

11 files changed

+344
-164
lines changed

include/tvm/runtime/module.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,11 @@ class TVM_DLL ModuleNode : public Object {
232232
return (GetPropertyMask() & ModulePropertyMask::kDSOExportable) != 0;
233233
}
234234

235+
/*! \brief Returns true if this module is 'Binary Serializable'. */
236+
bool IsBinarySerializable() const {
237+
return (GetPropertyMask() & ModulePropertyMask::kBinarySerializable) != 0;
238+
}
239+
235240
/*!
236241
* \brief Returns true if this module has a definition for a function of \p name. If
237242
* \p query_imports is true, also search in any imported modules.

include/tvm/target/codegen.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,21 @@ using runtime::TVMRetValue;
4747
*/
4848
runtime::Module Build(IRModule mod, Target target);
4949

50+
/*!
51+
* \brief Serialize runtime module including its submodules
52+
* \param mod The runtime module to serialize including its import tree.
53+
* \param export_dso By default, include the info of DSOExportable modules. If disabled, an error
54+
* will be raised when encountering DSO modules.
55+
*/
56+
std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso = true);
57+
58+
/*!
59+
* \brief Deserialize runtime module including its submodules
60+
* \param blob The byte stream generated by `SerializeModuleToBytes`.
61+
* \return runtime::Module runtime module constructed from the given stream
62+
*/
63+
runtime::Module DeserializeModuleFromBytes(std::string blob);
64+
5065
/*!
5166
* \brief Pack imported device library to a C file.
5267
* Compile the C file and link with the host library
@@ -77,6 +92,7 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib,
7792
runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib,
7893
const std::string& target_triple,
7994
const std::string& c_symbol_prefix = "");
95+
8096
} // namespace codegen
8197
} // namespace tvm
8298
#endif // TVM_TARGET_CODEGEN_H_

python/tvm/contrib/torch/optimize_torch.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,19 @@
2525
which is used to optimize the `torch.nn.module` by TVM metaSchedule,
2626
and returns a custom TorchScript operator
2727
"""
28-
import base64
28+
2929
import contextlib
3030
import tempfile
3131
from typing import Optional, Tuple, Union
32-
32+
import base64
3333
import torch
3434
import torch.utils.dlpack
3535
import tvm
36+
import tvm._ffi
37+
from tvm._ffi import register_func
3638
from tvm import meta_schedule as ms
3739
from tvm import relay
38-
from tvm._ffi import get_global_func, register_func
40+
from tvm._ffi import get_global_func
3941
from tvm.target import Target
4042

4143

@@ -51,14 +53,6 @@ def forward(self, *torch_inputs: Tuple[torch.Tensor]):
5153
return ret
5254

5355

54-
@register_func("script_torch.save_to_base64")
55-
def save_to_base64(obj) -> bytes:
56-
with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
57-
obj.export_library(tmpfile.name)
58-
with open(tmpfile.name, "rb") as temp_file:
59-
return base64.b64encode(temp_file.read())
60-
61-
6256
def optimize_torch(
6357
func,
6458
example_inputs,
@@ -173,3 +167,11 @@ def optimize_torch(
173167
save_runtime_mod(executor_factory.module)
174168

175169
return GraphExecutorFactoryWrapper(torch.classes.tvm_torch.GraphExecutorFactoryWrapper())
170+
171+
172+
@register_func("export_runtime_module")
def save_to_base64(obj) -> bytes:
    """Export ``obj`` as a shared library and return its content base64-encoded.

    The library is written with ``obj.export_library`` to a temporary ``.so`` file,
    read back in binary mode, and encoded with ``base64.b64encode``. The function is
    registered in the global registry under the name ``export_runtime_module``.

    Parameters
    ----------
    obj
        Any object exposing ``export_library(path)``.

    Returns
    -------
    bytes
        The base64 encoding of the exported library file.
    """
    with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
        library_path = tmpfile.name
        obj.export_library(library_path)
        with open(library_path, "rb") as library_file:
            raw_bytes = library_file.read()
        return base64.b64encode(raw_bytes)

python/tvm/runtime/module.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from typing import Sequence
2424
import numpy as np
2525

26-
import tvm._ffi
2726
from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY
2827
from tvm._ffi.libinfo import find_include_path
2928
from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module

src/contrib/torch/base64.h

Lines changed: 0 additions & 75 deletions
This file was deleted.

src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc

Lines changed: 92 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include <vector>
3030

3131
#include "../../../runtime/graph_executor/graph_executor_factory.h"
32-
#include "../base64.h"
32+
#include "../../support/base64.h"
3333
#include "runtime_bridge.h"
3434

3535
namespace tvm {
@@ -46,54 +46,6 @@ struct ThreadLocalStore {
4646
}
4747
};
4848

49-
/*
50-
* Encode TVM runtime module to base64 stream
51-
*/
52-
std::string serialize(tvm::runtime::Module module) {
53-
static const runtime::PackedFunc* f_to_str =
54-
runtime::Registry::Get("script_torch.save_to_base64");
55-
ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
56-
"`script_torch.save_to_base64` in the global registry";
57-
return (*f_to_str)(module);
58-
}
59-
60-
struct Deleter { // deleter
61-
explicit Deleter(std::string file_name) { this->file_name = file_name; }
62-
void operator()(FILE* p) const {
63-
fclose(p);
64-
ICHECK(remove(file_name.c_str()) == 0)
65-
<< "remove temporary file (" << file_name << ") unsuccessfully";
66-
}
67-
std::string file_name;
68-
};
69-
70-
/*
71-
* Decode TVM runtime module from base64 stream
72-
*/
73-
tvm::runtime::Module deserialize(std::string state) {
74-
auto length = tvm::support::b64strlen(state);
75-
76-
std::vector<u_char> bytes(length); // bytes stream
77-
tvm::support::b64decode(state, bytes.data());
78-
79-
const std::string name = tmpnam(NULL);
80-
auto file_name = name + ".so";
81-
std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
82-
fwrite(bytes.data(), sizeof(u_char), length, pFile.get());
83-
fflush(pFile.get());
84-
85-
std::string load_f_name = "runtime.module.loadfile_so";
86-
const PackedFunc* f = runtime::Registry::Get(load_f_name);
87-
ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
88-
<< " resolved to (" << load_f_name << ") in the global registry."
89-
<< "Ensure that you have loaded the correct runtime code, and"
90-
<< "that you are on the correct hardware architecture.";
91-
92-
tvm::runtime::Module ret = (*f)(file_name, "");
93-
94-
return ret;
95-
}
96-
9749
TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) {
9850
ThreadLocalStore::ThreadLocal()->mod = mod;
9951
});
@@ -242,15 +194,104 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod
242194
return output_length;
243195
}
244196

197+
/*!
 * \brief Compute how many bytes a padded base64 string decodes to.
 * \param b64str The base64 text. Its size must be a multiple of 4, otherwise the
 *        check below fails with "invalid base64 encoding".
 * \return 3 bytes per 4-character group, minus 2 if the second-to-last character
 *         is '=', or minus 1 if only the last character is '='. An empty string
 *         yields 0.
 */
inline size_t b64strlen(const std::string& b64str) {
  ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding";
  // An empty string passes the check above; without this guard the padding
  // probes below would index size() - 2, which wraps around for size() == 0.
  if (b64str.empty()) {
    return 0;
  }
  size_t length = b64str.size() / 4 * 3;
  if (b64str[b64str.size() - 2] == '=') {
    length -= 2;
  } else if (b64str[b64str.size() - 1] == '=') {
    length -= 1;
  }
  return length;
}
207+
208+
inline void b64decode(const std::string b64str, uint8_t* ret) {
209+
size_t index = 0;
210+
const auto length = b64str.size();
211+
for (size_t i = 0; i < length; i += 4) {
212+
int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]];
213+
int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]];
214+
int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]];
215+
int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]];
216+
uint8_t st1 = (ch0 << 2) + (ch1 >> 4);
217+
ret[index++] = st1;
218+
if (b64str[i + 2] != '=') {
219+
uint8_t st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2);
220+
ret[index++] = st2;
221+
if (b64str[i + 3] != '=') {
222+
uint8_t st3 = ((ch2 & 0b11) << 6) + ch3;
223+
ret[index++] = st3;
224+
}
225+
}
226+
}
227+
ICHECK(b64strlen(b64str) == index) << "base64 decoding fails";
228+
}
229+
230+
/*!
231+
* \brief Export TVM runtime module to base64 stream including its submodules.
232+
* Note that this targets modules that are binary serializable and DSOExportable.
233+
* \param module The runtime module to export
234+
* \return std::string The content of exported file
235+
*/
236+
std::string ExportModuleToBase64(tvm::runtime::Module module) {
237+
static const tvm::runtime::PackedFunc* f_to_str =
238+
tvm::runtime::Registry::Get("export_runtime_module");
239+
ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
240+
"`export_runtime_module` in the global registry";
241+
return (*f_to_str)(module);
242+
}
243+
244+
struct Deleter { // deleter
245+
explicit Deleter(std::string file_name) { this->file_name = file_name; }
246+
void operator()(FILE* p) const {
247+
fclose(p);
248+
ICHECK(remove(file_name.c_str()) == 0)
249+
<< "remove temporary file (" << file_name << ") unsuccessfully";
250+
}
251+
std::string file_name;
252+
};
253+
254+
/*!
255+
* \brief Import TVM runtime module from base64 stream
256+
* Note that this targets modules that are binary serializable and DSOExportable.
257+
* \param base64str base64 stream, which are generated by `ExportModuleToBase64`.
258+
* \return runtime::Module runtime module constructed from the given stream
259+
*/
260+
tvm::runtime::Module ImportModuleFromBase64(std::string base64str) {
261+
auto length = b64strlen(base64str);
262+
263+
std::vector<uint8_t> bytes(length); // bytes stream
264+
b64decode(base64str, bytes.data());
265+
266+
auto now = std::chrono::system_clock::now();
267+
auto in_time_t = std::chrono::system_clock::to_time_t(now);
268+
std::stringstream datetime;
269+
datetime << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d-%X");
270+
const std::string file_name = "tmp-module-" + datetime.str() + ".so";
271+
LOG(INFO) << file_name;
272+
std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
273+
fwrite(bytes.data(), sizeof(uint8_t), length, pFile.get());
274+
fflush(pFile.get());
275+
276+
std::string load_f_name = "runtime.module.loadfile_so";
277+
const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get(load_f_name);
278+
ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
279+
<< " resolved to (" << load_f_name << ") in the global registry."
280+
<< "Ensure that you have loaded the correct runtime code, and"
281+
<< "that you are on the correct hardware architecture.";
282+
tvm::runtime::Module ret = (*f)(file_name, "");
283+
return ret;
284+
}
285+
245286
/*!
 * \brief Encode the module held by \p runtime_module as a NUL-terminated base64 string.
 * \param runtime_module Wrapper whose `mod` member is exported.
 * \return A buffer allocated with `new char[]` holding the encoded text; ownership
 *         passes to the caller (presumably released with `delete[]` — confirm against
 *         the matching free routine).
 */
char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) {
  std::string encoded = ExportModuleToBase64(runtime_module->mod);
  const size_t buffer_size = encoded.length() + 1;  // room for the trailing NUL
  char* buffer = new char[buffer_size];
  snprintf(buffer, buffer_size, "%s", encoded.c_str());
  return buffer;
}
251292

252293
/*!
 * \brief Rebuild a runtime module wrapper from a base64 string.
 * \param state NUL-terminated base64 text, as produced by `tvm_contrib_torch_encode`;
 *        must not be null.
 * \return A `TVMContribTorchRuntimeModule` allocated with `new`; ownership passes to
 *         the caller.
 */
TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) {
  // Converting a null pointer to std::string is undefined behaviour; fail loudly instead.
  ICHECK(state != nullptr) << "cannot decode a runtime module from a null state string";
  tvm::runtime::Module ret = ImportModuleFromBase64(state);
  return new TVMContribTorchRuntimeModule(ret);
}
256297

src/node/structural_hash.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/runtime/container/adt.h>
2929
#include <tvm/runtime/profiling.h>
3030
#include <tvm/runtime/registry.h>
31+
#include <tvm/target/codegen.h>
3132

3233
#include <algorithm>
3334
#include <unordered_map>
@@ -360,6 +361,22 @@ struct ADTObjTrait {
360361

361362
TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
362363

364+
/*!
 * \brief Reflection trait for runtime::ModuleNode. All three hooks are null, i.e. the
 *  trait supplies no attribute visitor, structural hash or structural equality.
 */
struct ModuleNodeTrait {
  static constexpr std::nullptr_t VisitAttrs = nullptr;
  static constexpr std::nullptr_t SHashReduce = nullptr;
  static constexpr std::nullptr_t SEqualReduce = nullptr;
};
369+
370+
// Reflection hooks for runtime::ModuleNode: a node is created from a byte blob via
// codegen::DeserializeModuleFromBytes, and its byte representation is produced by
// codegen::SerializeModuleToBytes with export_dso disabled.
TVM_REGISTER_REFLECTION_VTABLE(runtime::ModuleNode, ModuleNodeTrait)
    .set_creator([](const std::string& blob) {
      runtime::Module restored = codegen::DeserializeModuleFromBytes(blob);
      return RefToObjectPtr::Get(restored);
    })
    .set_repr_bytes([](const Object* n) -> std::string {
      const auto* module_node = static_cast<const runtime::ModuleNode*>(n);
      const runtime::Module module_ref = GetRef<runtime::Module>(module_node);
      return codegen::SerializeModuleToBytes(module_ref, /*export_dso*/ false);
    });
379+
363380
void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce,
364381
bool hash_data) {
365382
ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";

src/runtime/library_module.cc

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,6 @@ class LibraryModuleNode final : public ModuleNode {
6767
PackedFuncWrapper packed_func_wrapper_;
6868
};
6969

70-
/*!
71-
* \brief Helper classes to get into internal of a module.
72-
*/
73-
class ModuleInternal {
74-
public:
75-
// Get mutable reference of imports.
76-
static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); }
77-
};
78-
7970
PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& sptr_to_self) {
8071
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
8172
TVMValue ret_value;

0 commit comments

Comments
 (0)