Skip to content

Commit 2c47e06

Browse files
mehrdadhpfk-beta
authored andcommitted
Add order to functions in C Codegen (apache#10590)
* Add function ordering to C Codegen * trigger * fix comment * address comments * add test * add unorder check * fix test * address comments
1 parent 9bdda07 commit 2c47e06

File tree

9 files changed

+161
-7
lines changed

9 files changed

+161
-7
lines changed

include/tvm/runtime/module.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,14 @@ inline const ModuleNode* Module::operator->() const {
251251
return static_cast<const ModuleNode*>(get());
252252
}
253253

254+
inline std::ostream& operator<<(std::ostream& out, const Module& module) {
255+
out << "Module(type_key= ";
256+
out << module->type_key();
257+
out << ")";
258+
259+
return out;
260+
}
261+
254262
} // namespace runtime
255263
} // namespace tvm
256264

src/runtime/hexagon/hexagon_module.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ HexagonHostModuleNode::HexagonHostModuleNode(std::string data, std::string fmt,
5353
PackedFunc HexagonHostModuleNode::GetFunction(const std::string& name,
5454
const ObjectPtr<Object>& sptr_to_self) {
5555
LOG(FATAL) << "HexagonHostModuleNode::GetFunction is not implemented.";
56-
return nullptr;
56+
return PackedFunc();
5757
}
5858

5959
std::string HexagonHostModuleNode::GetSource(const std::string& format) {

src/runtime/pipeline/pipeline_executor.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
8686
LOG(FATAL) << "Unknown packed function: " << name;
8787
return PackedFunc();
8888
}
89-
return nullptr;
9089
}
9190

9291
/*!

src/runtime/vm/executable.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje
105105
});
106106
} else {
107107
LOG(FATAL) << "Unknown packed function: " << name;
108-
return PackedFunc(nullptr);
108+
return PackedFunc();
109109
}
110110
}
111111

src/target/source/codegen_c_host.cc

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
#include <tvm/runtime/module.h>
2929
#include <tvm/target/codegen.h>
3030

31+
#include <algorithm>
3132
#include <string>
33+
#include <utility>
3234
#include <vector>
3335

3436
#include "../../support/str_escape.h"
@@ -362,22 +364,37 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
362364
Map<String, LinkedParam> linked_params;
363365
PrimFunc aot_executor_fn;
364366

367+
std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
365368
for (auto kv : mod->functions) {
366369
// Make sure that the executor function is the last one to be code generated so that all the
367-
// symbols are available to tvm_run_func
370+
// symbols are available to __tvm_main__
368371
auto fun_name = std::string(kv.first->name_hint);
369372
bool is_aot_executor_fn = kv.second->GetAttr<Bool>("runner_function", Bool(false)).value();
370373

371374
if (is_aot_executor_fn) {
372375
aot_executor_fn = Downcast<PrimFunc>(kv.second);
373376
continue;
374377
}
378+
funcs.push_back(kv);
379+
}
375380

381+
// Sort functions
382+
std::sort(funcs.begin(), funcs.end(),
383+
[](std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_a,
384+
std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_b) {
385+
std::string name_hint_a = kv_a.first->name_hint;
386+
std::string name_hint_b = kv_b.first->name_hint;
387+
return name_hint_a < name_hint_b;
388+
});
389+
390+
// Add all functions except __tvm_main__
391+
for (auto& kv : funcs) {
376392
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
377393
auto f = Downcast<PrimFunc>(kv.second);
378394
cg.AddFunction(f);
379395
}
380396

397+
// Add __tvm_main__
381398
if (aot_executor_fn.defined()) {
382399
cg.AddFunction(aot_executor_fn);
383400
}

src/target/source/codegen_c_host.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
#include <string>
2828
#include <unordered_map>
29+
#include <utility>
2930
#include <vector>
3031

3132
#include "codegen_c.h"
@@ -42,7 +43,13 @@ class CodeGenCHost : public CodeGenC {
4243

4344
void InitGlobalContext();
4445
void AddFunction(const PrimFunc& f);
45-
46+
/*!
47+
* \brief Add functions from the (unordered) range to the current module in a deterministic
48+
* order. This helps with debugging.
49+
*
50+
* \param functions A vector of unordered range of current module.
51+
*/
52+
void AddFunctionsOrdered(std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions);
4653
void DefineModuleName();
4754

4855
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)

src/target/source/interface_c.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class InterfaceCNode : public runtime::ModuleNode {
9090
}
9191

9292
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
93-
return PackedFunc(nullptr);
93+
return PackedFunc();
9494
}
9595

9696
private:

src/target/source/source_module.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
171171

172172
std::string GetFormat() { return fmt_; }
173173
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
174-
return PackedFunc(nullptr);
174+
return PackedFunc();
175175
}
176176

177177
void SaveToFile(const std::string& file_name, const std::string& format) final {

tests/cpp/c_codegen_test.cc

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <gmock/gmock.h>
21+
#include <gtest/gtest.h>
22+
#include <tvm/driver/driver_api.h>
23+
#include <tvm/ir/type.h>
24+
#include <tvm/node/reflection.h>
25+
#include <tvm/runtime/metadata.h>
26+
#include <tvm/runtime/module.h>
27+
#include <tvm/target/target.h>
28+
#include <tvm/te/operation.h>
29+
30+
TEST(CCodegen, MainFunctionOrder) {
31+
using namespace tvm;
32+
using namespace tvm::te;
33+
34+
std::string tvm_module_main = std::string(runtime::symbol::tvm_module_main);
35+
36+
tvm::Target target_c = tvm::Target("c -keys=cpu -link-params=0");
37+
38+
const int n = 4;
39+
Array<PrimExpr> shape{n};
40+
41+
auto A = placeholder(shape, DataType::Float(32), "A");
42+
auto B = placeholder(shape, DataType::Float(32), "B");
43+
44+
auto elemwise_add = compute(
45+
A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "elemwise_add");
46+
47+
auto fcreate = [=]() {
48+
With<Target> llvm_scope(target_c);
49+
return create_schedule({elemwise_add->op});
50+
};
51+
52+
auto args = Array<Tensor>({A, B, elemwise_add});
53+
54+
std::unordered_map<Tensor, Buffer> binds;
55+
auto lowered = LowerSchedule(fcreate(), args, "elemwise_add", binds);
56+
Map<tvm::Target, IRModule> inputs = {{target_c, lowered}};
57+
runtime::Module module = build(inputs, Target());
58+
Array<String> functions = module->GetFunction("get_func_names", false)();
59+
60+
ICHECK(functions.back().compare(tvm_module_main) == 0);
61+
}
62+
63+
auto BuildLowered(std::string op_name, tvm::Target target) {
64+
using namespace tvm;
65+
using namespace tvm::te;
66+
67+
// The shape of input tensors.
68+
const int n = 4;
69+
Array<PrimExpr> shape{n};
70+
71+
auto A = placeholder(shape, DataType::Float(32), "A");
72+
auto B = placeholder(shape, DataType::Float(32), "B");
73+
74+
auto op = compute(
75+
A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, op_name);
76+
77+
auto fcreate_s = [=]() {
78+
With<Target> llvm_scope(target);
79+
return create_schedule({op->op});
80+
};
81+
82+
auto args = Array<Tensor>({A, B, op});
83+
std::unordered_map<Tensor, Buffer> binds;
84+
auto lowered_s = LowerSchedule(fcreate_s(), args, op_name, binds);
85+
return lowered_s;
86+
}
87+
88+
bool IsSorted(tvm::Map<tvm::Target, tvm::IRModule> inputs) {
89+
std::vector<std::string> schedule_names;
90+
for (auto const& module : inputs) {
91+
for (auto const& func : module.second->functions) {
92+
schedule_names.push_back(func.first->name_hint);
93+
}
94+
}
95+
return std::is_sorted(schedule_names.begin(), schedule_names.end());
96+
}
97+
98+
TEST(CCodegen, FunctionOrder) {
99+
using testing::_;
100+
using testing::ElementsAre;
101+
using testing::StrEq;
102+
using namespace tvm;
103+
using namespace tvm::te;
104+
105+
Target target = Target("c -keys=cpu -link-params=0");
106+
107+
// add schedules in reverse order
108+
Map<tvm::Target, IRModule> inputs;
109+
inputs.Set(Target("c -keys=cpu -link-params=0"), BuildLowered("op_2", target));
110+
inputs.Set(Target("c -keys=cpu -link-params=0"), BuildLowered("op_1", target));
111+
112+
for (uint32_t counter = 99; IsSorted(inputs) && counter > 0; counter--) {
113+
std::string op_name = "op_" + std::to_string(counter);
114+
inputs.Set(Target("c -keys=cpu -link-params=0"), BuildLowered(op_name, target));
115+
}
116+
117+
EXPECT_FALSE(IsSorted(inputs));
118+
119+
auto module = build(inputs, Target());
120+
Array<String> func_array = module->GetFunction("get_func_names", false)();
121+
std::vector<std::string> functions{func_array.begin(), func_array.end()};
122+
EXPECT_THAT(functions, ElementsAre(StrEq("op_1"), _, StrEq("op_2"), _));
123+
}

0 commit comments

Comments
 (0)