Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions src/runtime/static_library.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,39 @@ class StaticLibraryNode final : public runtime::ModuleNode {
}
}

void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(data_);
std::vector<std::string> func_names;
for (const auto func_name : func_names_) func_names.push_back(func_name);
stream->Write(func_names);
}

static Module LoadFromBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
auto n = make_object<StaticLibraryNode>();
// load data
std::string data;
ICHECK(stream->Read(&data)) << "Loading data failed";
n->data_ = std::move(data);

// load func names
std::vector<std::string> func_names;
ICHECK(stream->Read(&func_names)) << "Loading func names failed";
for (auto func_name : func_names) n->func_names_.push_back(String(func_name));

return Module(n);
}

void SaveToFile(const String& file_name, const String& format) final {
VLOG(0) << "Saving static library of " << data_.size() << " bytes implementing " << FuncNames()
<< " to '" << file_name << "'";
SaveBinaryToFile(file_name, data_);
}

// TODO(tvm-team): Make this module serializable
/*! \brief Get the property of the runtime module .*/
int GetPropertyMask() const override { return ModulePropertyMask::kDSOExportable; }
int GetPropertyMask() const override {
return runtime::ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kDSOExportable;
}

bool ImplementsFunction(const String& name, bool query_imports) final {
return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
Expand Down Expand Up @@ -103,6 +127,8 @@ Module LoadStaticLibrary(const std::string& filename, Array<String> func_names)
}

TVM_REGISTER_GLOBAL("runtime.ModuleLoadStaticLibrary").set_body_typed(LoadStaticLibrary);
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_static_library")
.set_body_typed(StaticLibraryNode::LoadFromBinary);

} // namespace runtime
} // namespace tvm
3 changes: 1 addition & 2 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ class ModuleSerializer {
// we will not produce import_tree_.
bool has_import_tree = true;

if (mod_->IsDSOExportable()) {
ICHECK(export_dso) << "`export_dso` should be enabled for DSOExportable modules";
if (export_dso) {
has_import_tree = !mod_->imports().empty();
}

Expand Down
41 changes: 40 additions & 1 deletion src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,39 @@ class CSourceModuleNode : public runtime::ModuleNode {

String GetFormat() override { return fmt_; }

void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(code_);
stream->Write(fmt_);

std::vector<std::string> func_names;
for (const auto func_name : func_names_) func_names.push_back(func_name);
std::vector<std::string> const_vars;
for (auto const_var : const_vars_) const_vars.push_back(const_var);
stream->Write(func_names);
stream->Write(const_vars);
}

static runtime::Module LoadFromBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);

std::string code, fmt;
ICHECK(stream->Read(&code)) << "Loading code failed";
ICHECK(stream->Read(&fmt)) << "Loading format failed";

std::vector<std::string> tmp_func_names, tmp_const_vars;
CHECK(stream->Read(&tmp_func_names)) << "Loading func names failed";
CHECK(stream->Read(&tmp_const_vars)) << "Loading const vars failed";

Array<String> func_names;
for (auto func_name : tmp_func_names) func_names.push_back(String(func_name));

Array<String> const_vars;
for (auto const_var : tmp_const_vars) const_vars.push_back(String(const_var));

auto n = make_object<CSourceModuleNode>(code, fmt, func_names, const_vars);
return runtime::Module(n);
}

void SaveToFile(const String& file_name, const String& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
Expand All @@ -130,7 +163,10 @@ class CSourceModuleNode : public runtime::ModuleNode {
}
}

int GetPropertyMask() const override { return runtime::ModulePropertyMask::kDSOExportable; }
int GetPropertyMask() const override {
return runtime::ModulePropertyMask::kBinarySerializable |
runtime::ModulePropertyMask::kDSOExportable;
}

bool ImplementsFunction(const String& name, bool query_imports) final {
return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
Expand All @@ -151,6 +187,9 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt,
return runtime::Module(n);
}

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_c")
.set_body_typed(CSourceModuleNode::LoadFromBinary);

/*!
* \brief A concrete class to get access to base methods of CodegenSourceBase.
*
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_roundtrip_runtime_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@


def test_csource_module():
mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], None)
# source module that is not binary serializable.
# Thus, it would raise an error.
assert not mod.is_binary_serializable
with pytest.raises(TVMError):
tvm.ir.load_json(tvm.ir.save_json(mod))
mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], [])
assert mod.type_key == "c"
assert mod.is_binary_serializable
new_mod = tvm.ir.load_json(tvm.ir.save_json(mod))
assert new_mod.type_key == "c"
assert new_mod.is_binary_serializable


def test_aot_module():
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_runtime_module_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def create_aot_module():
def test_property():
checker(
create_csource_module(),
expected={"is_binary_serializable": False, "is_runnable": False, "is_dso_exportable": True},
expected={"is_binary_serializable": True, "is_runnable": False, "is_dso_exportable": True},
)

checker(
Expand Down