2929#include < vector>
3030
3131#include " ../../../runtime/graph_executor/graph_executor_factory.h"
32- #include " ../base64.h"
32+ #include " ../../support/ base64.h"
3333#include " runtime_bridge.h"
3434
3535namespace tvm {
@@ -46,54 +46,6 @@ struct ThreadLocalStore {
4646 }
4747};
4848
49- /*
50- * Encode TVM runtime module to base64 stream
51- */
52- std::string serialize (tvm::runtime::Module module ) {
53- static const runtime::PackedFunc* f_to_str =
54- runtime::Registry::Get (" script_torch.save_to_base64" );
55- ICHECK (f_to_str) << " IndexError: Cannot find the packed function "
56- " `script_torch.save_to_base64` in the global registry" ;
57- return (*f_to_str)(module );
58- }
59-
60- struct Deleter { // deleter
61- explicit Deleter (std::string file_name) { this ->file_name = file_name; }
62- void operator ()(FILE* p) const {
63- fclose (p);
64- ICHECK (remove (file_name.c_str ()) == 0 )
65- << " remove temporary file (" << file_name << " ) unsuccessfully" ;
66- }
67- std::string file_name;
68- };
69-
70- /*
71- * Decode TVM runtime module from base64 stream
72- */
73- tvm::runtime::Module deserialize (std::string state) {
74- auto length = tvm::support::b64strlen (state);
75-
76- std::vector<u_char> bytes (length); // bytes stream
77- tvm::support::b64decode (state, bytes.data ());
78-
79- const std::string name = tmpnam (NULL );
80- auto file_name = name + " .so" ;
81- std::unique_ptr<FILE, Deleter> pFile (fopen (file_name.c_str (), " wb" ), Deleter (file_name));
82- fwrite (bytes.data (), sizeof (u_char), length, pFile.get ());
83- fflush (pFile.get ());
84-
85- std::string load_f_name = " runtime.module.loadfile_so" ;
86- const PackedFunc* f = runtime::Registry::Get (load_f_name);
87- ICHECK (f != nullptr ) << " Loader for `.so` files is not registered,"
88- << " resolved to (" << load_f_name << " ) in the global registry."
89- << " Ensure that you have loaded the correct runtime code, and"
90- << " that you are on the correct hardware architecture." ;
91-
92- tvm::runtime::Module ret = (*f)(file_name, " " );
93-
94- return ret;
95- }
96-
9749TVM_REGISTER_GLOBAL (" tvmtorch.save_runtime_mod" ).set_body_typed([](tvm::runtime::Module mod) {
9850 ThreadLocalStore::ThreadLocal ()->mod = mod;
9951});
@@ -242,15 +194,104 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod
242194 return output_length;
243195}
244196
197+ inline size_t b64strlen (const std::string b64str) {
198+ ICHECK (b64str.size () % 4 == 0 ) << " invalid base64 encoding" ;
199+ size_t length = b64str.size () / 4 * 3 ;
200+ if (b64str[b64str.size () - 2 ] == ' =' ) {
201+ length -= 2 ;
202+ } else if (b64str[b64str.size () - 1 ] == ' =' ) {
203+ length -= 1 ;
204+ }
205+ return length;
206+ }
207+
208+ inline void b64decode (const std::string b64str, uint8_t * ret) {
209+ size_t index = 0 ;
210+ const auto length = b64str.size ();
211+ for (size_t i = 0 ; i < length; i += 4 ) {
212+ int8_t ch0 = base64::DecodeTable[(int32_t )b64str[i]];
213+ int8_t ch1 = base64::DecodeTable[(int32_t )b64str[i + 1 ]];
214+ int8_t ch2 = base64::DecodeTable[(int32_t )b64str[i + 2 ]];
215+ int8_t ch3 = base64::DecodeTable[(int32_t )b64str[i + 3 ]];
216+ uint8_t st1 = (ch0 << 2 ) + (ch1 >> 4 );
217+ ret[index++] = st1;
218+ if (b64str[i + 2 ] != ' =' ) {
219+ uint8_t st2 = ((ch1 & 0b1111 ) << 4 ) + (ch2 >> 2 );
220+ ret[index++] = st2;
221+ if (b64str[i + 3 ] != ' =' ) {
222+ uint8_t st3 = ((ch2 & 0b11 ) << 6 ) + ch3;
223+ ret[index++] = st3;
224+ }
225+ }
226+ }
227+ ICHECK (b64strlen (b64str) == index) << " base64 decoding fails" ;
228+ }
229+
230+ /* !
231+ * \brief Export TVM runtime module to base64 stream including its submodules.
232+ * Note that this targets modules that are binary serializable and DSOExportable.
233+ * \param module The runtime module to export
234+ * \return std::string The content of exported file
235+ */
236+ std::string ExportModuleToBase64 (tvm::runtime::Module module ) {
237+ static const tvm::runtime::PackedFunc* f_to_str =
238+ tvm::runtime::Registry::Get (" export_runtime_module" );
239+ ICHECK (f_to_str) << " IndexError: Cannot find the packed function "
240+ " `export_runtime_module` in the global registry" ;
241+ return (*f_to_str)(module );
242+ }
243+
244+ struct Deleter { // deleter
245+ explicit Deleter (std::string file_name) { this ->file_name = file_name; }
246+ void operator ()(FILE* p) const {
247+ fclose (p);
248+ ICHECK (remove (file_name.c_str ()) == 0 )
249+ << " remove temporary file (" << file_name << " ) unsuccessfully" ;
250+ }
251+ std::string file_name;
252+ };
253+
254+ /* !
255+ * \brief Import TVM runtime module from base64 stream
256+ * Note that this targets modules that are binary serializable and DSOExportable.
257+ * \param base64str base64 stream, which are generated by `ExportModuleToBase64`.
258+ * \return runtime::Module runtime module constructed from the given stream
259+ */
260+ tvm::runtime::Module ImportModuleFromBase64 (std::string base64str) {
261+ auto length = b64strlen (base64str);
262+
263+ std::vector<uint8_t > bytes (length); // bytes stream
264+ b64decode (base64str, bytes.data ());
265+
266+ auto now = std::chrono::system_clock::now ();
267+ auto in_time_t = std::chrono::system_clock::to_time_t (now);
268+ std::stringstream datetime;
269+ datetime << std::put_time (std::localtime (&in_time_t ), " %Y-%m-%d-%X" );
270+ const std::string file_name = " tmp-module-" + datetime.str () + " .so" ;
271+ LOG (INFO) << file_name;
272+ std::unique_ptr<FILE, Deleter> pFile (fopen (file_name.c_str (), " wb" ), Deleter (file_name));
273+ fwrite (bytes.data (), sizeof (uint8_t ), length, pFile.get ());
274+ fflush (pFile.get ());
275+
276+ std::string load_f_name = " runtime.module.loadfile_so" ;
277+ const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get (load_f_name);
278+ ICHECK (f != nullptr ) << " Loader for `.so` files is not registered,"
279+ << " resolved to (" << load_f_name << " ) in the global registry."
280+ << " Ensure that you have loaded the correct runtime code, and"
281+ << " that you are on the correct hardware architecture." ;
282+ tvm::runtime::Module ret = (*f)(file_name, " " );
283+ return ret;
284+ }
285+
245286char * tvm_contrib_torch_encode (TVMContribTorchRuntimeModule* runtime_module) {
246- std::string std = tvm::contrib::serialize (runtime_module->mod );
287+ std::string std = ExportModuleToBase64 (runtime_module->mod );
247288 char * ret = new char [std.length () + 1 ];
248289 snprintf (ret, std.length () + 1 , " %s" , std.c_str ());
249290 return ret;
250291}
251292
252293TVMContribTorchRuntimeModule* tvm_contrib_torch_decode (const char * state) {
253- tvm::runtime::Module ret = tvm::contrib::deserialize (state);
294+ tvm::runtime::Module ret = ImportModuleFromBase64 (state);
254295 return new TVMContribTorchRuntimeModule (ret);
255296}
256297
0 commit comments