diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index 7adfeb19c377..b9e63271cfa9 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -56,15 +56,39 @@ class StaticLibraryNode final : public runtime::ModuleNode { } } + void SaveToBinary(dmlc::Stream* stream) final { + stream->Write(data_); + std::vector func_names; + for (const auto func_name : func_names_) func_names.push_back(func_name); + stream->Write(func_names); + } + + static Module LoadFromBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + auto n = make_object(); + // load data + std::string data; + ICHECK(stream->Read(&data)) << "Loading data failed"; + n->data_ = std::move(data); + + // load func names + std::vector func_names; + ICHECK(stream->Read(&func_names)) << "Loading func names failed"; + for (auto func_name : func_names) n->func_names_.push_back(String(func_name)); + + return Module(n); + } + void SaveToFile(const String& file_name, const String& format) final { VLOG(0) << "Saving static library of " << data_.size() << " bytes implementing " << FuncNames() << " to '" << file_name << "'"; SaveBinaryToFile(file_name, data_); } - // TODO(tvm-team): Make this module serializable /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const override { return ModulePropertyMask::kDSOExportable; } + int GetPropertyMask() const override { + return runtime::ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kDSOExportable; + } bool ImplementsFunction(const String& name, bool query_imports) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); @@ -103,6 +127,8 @@ Module LoadStaticLibrary(const std::string& filename, Array func_names) } TVM_REGISTER_GLOBAL("runtime.ModuleLoadStaticLibrary").set_body_typed(LoadStaticLibrary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_static_library") + .set_body_typed(StaticLibraryNode::LoadFromBinary); } // namespace runtime } // namespace tvm diff --git a/src/target/codegen.cc b/src/target/codegen.cc index a221fa60e63c..67dbf101016c 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -84,8 +84,7 @@ class ModuleSerializer { // we will not produce import_tree_. bool has_import_tree = true; - if (mod_->IsDSOExportable()) { - ICHECK(export_dso) << "`export_dso` should be enabled for DSOExportable modules"; + if (export_dso) { has_import_tree = !mod_->imports().empty(); } diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index c75f3008ef6b..90640a6db647 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -119,6 +119,39 @@ class CSourceModuleNode : public runtime::ModuleNode { String GetFormat() override { return fmt_; } + void SaveToBinary(dmlc::Stream* stream) final { + stream->Write(code_); + stream->Write(fmt_); + + std::vector func_names; + for (const auto func_name : func_names_) func_names.push_back(func_name); + std::vector const_vars; + for (auto const_var : const_vars_) const_vars.push_back(const_var); + stream->Write(func_names); + stream->Write(const_vars); + } + + static runtime::Module LoadFromBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + + std::string code, fmt; + ICHECK(stream->Read(&code)) << "Loading code failed"; + ICHECK(stream->Read(&fmt)) << "Loading format failed"; + + std::vector tmp_func_names, tmp_const_vars; + CHECK(stream->Read(&tmp_func_names)) << "Loading func names failed"; + CHECK(stream->Read(&tmp_const_vars)) << "Loading const vars failed"; + + Array func_names; + for (auto func_name : tmp_func_names) func_names.push_back(String(func_name)); + + Array const_vars; + for (auto const_var : tmp_const_vars) const_vars.push_back(String(const_var)); + + auto n = make_object(code, fmt, func_names, const_vars); + return runtime::Module(n); + } + void SaveToFile(const String& file_name, const String& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); @@ -130,7 +163,10 @@ class CSourceModuleNode : public runtime::ModuleNode { } } - int GetPropertyMask() const override { return runtime::ModulePropertyMask::kDSOExportable; } + int GetPropertyMask() const override { + return runtime::ModulePropertyMask::kBinarySerializable | + runtime::ModulePropertyMask::kDSOExportable; + } bool ImplementsFunction(const String& name, bool query_imports) final { return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end(); @@ -151,6 +187,9 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, return runtime::Module(n); } +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_c") + .set_body_typed(CSourceModuleNode::LoadFromBinary); + /*! * \brief A concrete class to get access to base methods of CodegenSourceBase. * diff --git a/tests/python/unittest/test_roundtrip_runtime_module.py b/tests/python/unittest/test_roundtrip_runtime_module.py index 6a1abeedd914..494143fc0bf2 100644 --- a/tests/python/unittest/test_roundtrip_runtime_module.py +++ b/tests/python/unittest/test_roundtrip_runtime_module.py @@ -25,12 +25,12 @@ def test_csource_module(): - mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], None) - # source module that is not binary serializable. - # Thus, it would raise an error. - assert not mod.is_binary_serializable - with pytest.raises(TVMError): - tvm.ir.load_json(tvm.ir.save_json(mod)) + mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], []) + assert mod.type_key == "c" + assert mod.is_binary_serializable + new_mod = tvm.ir.load_json(tvm.ir.save_json(mod)) + assert new_mod.type_key == "c" + assert new_mod.is_binary_serializable def test_aot_module(): diff --git a/tests/python/unittest/test_runtime_module_property.py b/tests/python/unittest/test_runtime_module_property.py index 30af8d086a42..bd71e856d917 100644 --- a/tests/python/unittest/test_runtime_module_property.py +++ b/tests/python/unittest/test_runtime_module_property.py @@ -44,7 +44,7 @@ def create_aot_module(): def test_property(): checker( create_csource_module(), - expected={"is_binary_serializable": False, "is_runnable": False, "is_dso_exportable": True}, + expected={"is_binary_serializable": True, "is_runnable": False, "is_dso_exportable": True}, ) checker(