Skip to content

Commit 6ccde09

Browse files
committed
refactor rewrite_tensorize
1 parent 2ce2066 commit 6ccde09

File tree

1 file changed

+14
-21
lines changed

1 file changed

+14
-21
lines changed

src/meta_schedule/postproc/rewrite_tensorize.cc

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,7 @@ using tir::LoopRV;
2929

3030
void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
3131
const tir::PrimFuncNode* func, bool vectorize_init_loop) {
32-
struct RewriteJob {
33-
std::string block_name;
34-
std::function<void(tir::BlockRV)> postproc_fn;
35-
};
36-
37-
std::vector<RewriteJob> jobs;
32+
std::vector<std::pair<std::string, std::function<void(tir::BlockRV)>>> jobs;
3833

3934
tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) -> bool {
4035
if (const auto* block = obj.as<tir::BlockNode>()) {
@@ -43,36 +38,34 @@ void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
4338
tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
4439
std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
4540
if (block_name.find("init") == std::string::npos) {
46-
jobs.push_back({block_name, [sch, intrin_name](tir::BlockRV block) {
47-
sch->Tensorize(block, intrin_name.value());
48-
}});
41+
jobs.emplace_back(block_name, [sch, intrin_name](tir::BlockRV block) {
42+
sch->Tensorize(block, intrin_name.value());
43+
});
4944
} else if (vectorize_init_loop) {
50-
jobs.push_back({block_name, [sch](tir::BlockRV block) {
51-
Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
52-
ICHECK(child_blocks.size() == 1);
53-
Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
54-
ICHECK(init_loops.size() == 1);
55-
sch->Vectorize(init_loops[0]);
56-
}});
45+
jobs.emplace_back(block_name, [sch](tir::BlockRV block) {
46+
Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
47+
ICHECK(child_blocks.size() == 1);
48+
Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
49+
ICHECK(init_loops.size() == 1);
50+
sch->Vectorize(init_loops[0]);
51+
});
5752
}
5853
}
5954
}
6055
return true;
6156
});
6257

63-
for (auto job : jobs) {
64-
tir::BlockRV block = sch->GetBlock(job.block_name, func_name);
58+
for (auto kv : jobs) {
59+
tir::BlockRV block = sch->GetBlock(kv.first, func_name);
6560
sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
66-
job.postproc_fn(block);
61+
kv.second(block);
6762
}
6863
}
6964

7065
class RewriteTensorizeNode : public PostprocNode {
7166
public:
72-
// Inherited from PostprocNode
7367
void InitializeWithTuneContext(const TuneContext& context) final {}
7468

75-
// Inherited from PostprocNode
7669
bool Apply(const tir::Schedule& sch) final;
7770

7871
void VisitAttrs(tvm::AttrVisitor* v) {}

0 commit comments

Comments
 (0)