Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/runtime/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,14 @@ inline const ModuleNode* Module::operator->() const {
return static_cast<const ModuleNode*>(get());
}

inline std::ostream& operator<<(std::ostream& out, const Module& module) {
out << "Module(type_key= ";
out << module->type_key();
out << ")";

return out;
}

} // namespace runtime
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/hexagon/hexagon_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ HexagonHostModuleNode::HexagonHostModuleNode(std::string data, std::string fmt,
PackedFunc HexagonHostModuleNode::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
LOG(FATAL) << "HexagonHostModuleNode::GetFunction is not implemented.";
return nullptr;
return PackedFunc();
}

std::string HexagonHostModuleNode::GetSource(const std::string& format) {
Expand Down
1 change: 0 additions & 1 deletion src/runtime/pipeline/pipeline_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc();
}
return nullptr;
}

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc(nullptr);
return PackedFunc();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this change is equivalent

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, PackedFunc() just matches with the rest of the codebase.

}
}

Expand Down
19 changes: 18 additions & 1 deletion src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
#include <tvm/runtime/module.h>
#include <tvm/target/codegen.h>

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "../../support/str_escape.h"
Expand Down Expand Up @@ -355,22 +357,37 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
Map<String, LinkedParam> linked_params;
PrimFunc aot_executor_fn;

std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
for (auto kv : mod->functions) {
// Make sure that the executor function is the last one to be code generated so that all the
// symbols are available to tvm_run_func
// symbols are available to __tvm_main__
auto fun_name = std::string(kv.first->name_hint);
bool is_aot_executor_fn = kv.second->GetAttr<Bool>("runner_function", Bool(false)).value();

if (is_aot_executor_fn) {
aot_executor_fn = Downcast<PrimFunc>(kv.second);
continue;
}
funcs.push_back(kv);
}

// Sort functions
std::sort(funcs.begin(), funcs.end(),
[](std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_a,
std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_b) {
std::string name_hint_a = kv_a.first->name_hint;
std::string name_hint_b = kv_b.first->name_hint;
return name_hint_a < name_hint_b;
});

// Add all functions except __tvm_main__
for (auto& kv : funcs) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
cg.AddFunction(f);
}

// Add __tvm_main__
if (aot_executor_fn.defined()) {
cg.AddFunction(aot_executor_fn);
}
Expand Down
9 changes: 8 additions & 1 deletion src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "codegen_c.h"
Expand All @@ -42,7 +43,13 @@ class CodeGenCHost : public CodeGenC {

void InitGlobalContext();
void AddFunction(const PrimFunc& f);

/*!
* \brief Add functions from the (unordered) range to the current module in a deterministic
* order. This helps with debugging.
*
* \param functions A vector of unordered range of current module.
*/
void AddFunctionsOrdered(std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions);
void DefineModuleName();

void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/interface_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class InterfaceCNode : public runtime::ModuleNode {
}

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
return PackedFunc(nullptr);
return PackedFunc();
}

private:
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {

std::string GetFormat() { return fmt_; }
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
return PackedFunc(nullptr);
return PackedFunc();
}

void SaveToFile(const std::string& file_name, const std::string& format) final {
Expand Down
108 changes: 108 additions & 0 deletions tests/cpp/c_codegen_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/type.h>
#include <tvm/node/reflection.h>
#include <tvm/runtime/metadata.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target.h>
#include <tvm/te/operation.h>

TEST(CCodegen, MainFunctionOrder) {
using namespace tvm;
using namespace tvm::te;

std::string tvm_module_main = std::string(runtime::symbol::tvm_module_main);

tvm::Target target_c = tvm::Target("c -keys=cpu -link-params=0");

const int n = 4;
Array<PrimExpr> shape{n};

auto A = placeholder(shape, DataType::Float(32), "A");
auto B = placeholder(shape, DataType::Float(32), "B");

auto elemwise_add = compute(
A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "elemwise_add");

auto fcreate = [=]() {
With<Target> llvm_scope(target_c);
return create_schedule({elemwise_add->op});
};

auto args = Array<Tensor>({A, B, elemwise_add});

std::unordered_map<Tensor, Buffer> binds;
auto lowered = LowerSchedule(fcreate(), args, "elemwise_add", binds);
Map<tvm::Target, IRModule> inputs = {{target_c, lowered}};
runtime::Module module = build(inputs, Target());
Array<String> functions = module->GetFunction("get_func_names", false)();

ICHECK(functions.back().compare(tvm_module_main) == 0);
}

TEST(CCodegen, FunctionOrder) {
using testing::_;
using testing::ElementsAre;
using testing::StrEq;
using namespace tvm;
using namespace tvm::te;

auto target = Target("c -keys=cpu -link-params=0");

// The shape of input tensors.
const int n = 4;
Array<PrimExpr> shape{n};

auto A = placeholder(shape, DataType::Float(32), "A");
auto B = placeholder(shape, DataType::Float(32), "B");

auto op_1 = compute(
A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "op_1");

auto op_2 = compute(
A->shape, [&A, &B](PrimExpr i) { return A[i] - B[i]; }, "op_2");

auto fcreate_s1 = [=]() {
With<Target> llvm_scope(target);
return create_schedule({op_1->op});
};

auto fcreate_s2 = [=]() {
With<Target> llvm_scope(target);
return create_schedule({op_2->op});
};

auto args1 = Array<Tensor>({A, B, op_1});
auto args2 = Array<Tensor>({A, B, op_2});

std::unordered_map<Tensor, Buffer> binds;
auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "op_1", binds);
auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "op_2", binds);

// add schedules in reverse order
Map<tvm::Target, IRModule> inputs = {{target, lowered_s2}, {target, lowered_s1}};
auto module = build(inputs, Target());
Array<String> func_array = module->GetFunction("get_func_names", false)();
std::vector<std::string> functions{func_array.begin(), func_array.end()};
EXPECT_THAT(functions, ElementsAre(StrEq("op_1"), _, StrEq("op_2"), _));
}