Skip to content

Commit a1229f6

Browse files
authored
[TIR] Handle nullptr returned by FindEntryFunc (#13852)
The FindEntryFunc function can return a null pointer. In development I got this situation, which appears as a segfault.
1 parent 7db77ad commit a1229f6

File tree

4 files changed

+58
-27
lines changed

4 files changed

+58
-27
lines changed

src/tir/analysis/stmt_finding.cc

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -111,30 +111,31 @@ const BlockNode* FindAnchorBlock(const IRModule& mod) {
111111
std::vector<const BlockNode*> blocks;
112112
};
113113

114-
auto prim_func = FindEntryFunc(mod, nullptr);
114+
if (auto prim_func = FindEntryFunc(mod, nullptr)) {
115+
ReductionBlockCollector collector;
116+
collector(prim_func->body);
115117

116-
ReductionBlockCollector collector;
117-
collector(prim_func->body);
118+
const auto& candidates = collector.blocks;
118119

119-
const auto& candidates = collector.blocks;
120-
121-
if (candidates.empty()) {
122-
return nullptr;
123-
} else if (candidates.size() == 1) {
124-
return candidates[0];
125-
}
120+
if (candidates.empty()) {
121+
return nullptr;
122+
} else if (candidates.size() == 1) {
123+
return candidates[0];
124+
}
126125

127-
double best_flops = -1;
128-
int best_idx = 0;
129-
for (size_t i = 0; i < candidates.size(); ++i) {
130-
auto loop = GetEnclosingLoop(candidates[i], prim_func->body);
131-
auto flops = EstimateTIRFlops(loop);
132-
if (flops > best_flops) {
133-
best_flops = flops;
134-
best_idx = i;
126+
double best_flops = -1;
127+
int best_idx = 0;
128+
for (size_t i = 0; i < candidates.size(); ++i) {
129+
auto loop = GetEnclosingLoop(candidates[i], prim_func->body);
130+
auto flops = EstimateTIRFlops(loop);
131+
if (flops > best_flops) {
132+
best_flops = flops;
133+
best_idx = i;
134+
}
135135
}
136+
return candidates[best_idx];
136137
}
137-
return candidates[best_idx];
138+
return nullptr;
138139
}
139140

140141
TVM_REGISTER_GLOBAL("tir.analysis.find_anchor_block").set_body_typed([](const IRModule& mod) {

src/tir/schedule/concrete_schedule.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class ConcreteScheduleNode : public ScheduleNode {
5656
// `error_render_level_` is not visited
5757
// `symbol_table_` is not visited
5858
// `analyzer_` is not visited
59-
// `rgnd_state_` is not visited
59+
// `rand_state_` is not visited
6060
}
6161

6262
virtual ~ConcreteScheduleNode() = default;

src/tir/schedule/utils.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,10 +456,12 @@ inline std::unordered_set<std::string> GetBlockNames(const IRModule& mod) {
456456
std::unordered_set<std::string> block_names;
457457
};
458458

459-
auto prim_func = tir::FindEntryFunc(mod, nullptr);
460-
BlockNameCollector collector;
461-
collector(prim_func->body);
462-
return collector.block_names;
459+
if (auto prim_func = tir::FindEntryFunc(mod, nullptr)) {
460+
BlockNameCollector collector;
461+
collector(prim_func->body);
462+
return collector.block_names;
463+
}
464+
return {};
463465
}
464466

465467
/*! \brief Query if the given block name exists in the module associated with the schedule */

tests/python/unittest/test_meta_schedule_database.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from tvm.ir.module import IRModule
3131
from tvm.script import tir as T
3232
from tvm.tir import Schedule
33-
33+
from tvm import relay
3434

3535
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
3636
# fmt: off
@@ -93,10 +93,10 @@ def _create_schedule(mod: IRModule, sch_fn: Callable[[Schedule], None]) -> Sched
9393
return sch
9494

9595

96-
def _create_tmp_database(tmpdir: str) -> ms.database.JSONDatabase:
96+
def _create_tmp_database(tmpdir: str, mod_eq: str = "structural") -> ms.database.JSONDatabase:
9797
path_workload = osp.join(tmpdir, "workloads.json")
9898
path_tuning_record = osp.join(tmpdir, "tuning_records.json")
99-
return ms.database.JSONDatabase(path_workload, path_tuning_record)
99+
return ms.database.JSONDatabase(path_workload, path_tuning_record, module_equality=mod_eq)
100100

101101

102102
def _equal_record(a: ms.database.TuningRecord, b: ms.database.TuningRecord):
@@ -583,5 +583,33 @@ def test_json_database_get_top_k(k, expected):
583583
assert result == expected
584584

585585

586+
def MatmulFunc() -> IRModule:
587+
a = relay.var("a", relay.TensorType((1024, 1024), "float32"))
588+
b = relay.var("b", relay.TensorType((1024, 1024), "float32"))
589+
func = relay.Function([a, b], relay.nn.matmul(a, b))
590+
return tvm.IRModule.from_expr(func)
591+
592+
593+
def MatmulPrimFunc() -> IRModule:
594+
return Matmul
595+
596+
597+
@pytest.mark.parametrize("f_mod", [MatmulPrimFunc, MatmulFunc])
598+
@pytest.mark.parametrize("mod_eq", ["structural", "ignore-ndarray", "anchor-block"])
599+
def test_json_database_commit_workload(f_mod, mod_eq):
600+
mod: IRModule = f_mod()
601+
with tempfile.TemporaryDirectory() as tmpdir:
602+
database = _create_tmp_database(tmpdir, mod_eq)
603+
database.commit_workload(mod)
604+
605+
606+
@pytest.mark.parametrize("f_mod", [MatmulPrimFunc, MatmulFunc])
607+
@pytest.mark.parametrize("mod_eq", ["structural", "ignore-ndarray", "anchor-block"])
608+
def test_memory_database_commit_workload(f_mod, mod_eq):
609+
mod: IRModule = f_mod()
610+
database = ms.database.MemoryDatabase(module_equality=mod_eq)
611+
database.commit_workload(mod)
612+
613+
586614
if __name__ == "__main__":
587615
tvm.testing.main()

0 commit comments

Comments
 (0)