@@ -29,12 +29,7 @@ using tir::LoopRV;
2929
3030void ApplyTensorization (const tir::Schedule& sch, const String& func_name,
3131 const tir::PrimFuncNode* func, bool vectorize_init_loop) {
32- struct RewriteJob {
33- std::string block_name;
34- std::function<void (tir::BlockRV)> postproc_fn;
35- };
36-
37- std::vector<RewriteJob> jobs;
32+ std::vector<std::pair<std::string, std::function<void (tir::BlockRV)>>> jobs;
3833
3934 tir::PostOrderVisit (func->body , [=, &jobs](const ObjectRef& obj) -> bool {
4035 if (const auto * block = obj.as <tir::BlockNode>()) {
@@ -43,36 +38,34 @@ void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
4338 tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
4439 std::string block_name = block_sref->StmtAs <tir::BlockNode>()->name_hint ;
4540 if (block_name.find (" init" ) == std::string::npos) {
46- jobs.push_back ({ block_name, [sch, intrin_name](tir::BlockRV block) {
47- sch->Tensorize (block, intrin_name.value ());
48- } });
41+ jobs.emplace_back ( block_name, [sch, intrin_name](tir::BlockRV block) {
42+ sch->Tensorize (block, intrin_name.value ());
43+ });
4944 } else if (vectorize_init_loop) {
50- jobs.push_back ({ block_name, [sch](tir::BlockRV block) {
51- Array<BlockRV> child_blocks = sch->GetChildBlocks (block);
52- ICHECK (child_blocks.size () == 1 );
53- Array<LoopRV> init_loops = sch->GetLoops (child_blocks[0 ]);
54- ICHECK (init_loops.size () == 1 );
55- sch->Vectorize (init_loops[0 ]);
56- } });
45+ jobs.emplace_back ( block_name, [sch](tir::BlockRV block) {
46+ Array<BlockRV> child_blocks = sch->GetChildBlocks (block);
47+ ICHECK (child_blocks.size () == 1 );
48+ Array<LoopRV> init_loops = sch->GetLoops (child_blocks[0 ]);
49+ ICHECK (init_loops.size () == 1 );
50+ sch->Vectorize (init_loops[0 ]);
51+ });
5752 }
5853 }
5954 }
6055 return true ;
6156 });
6257
63- for (auto job : jobs) {
64- tir::BlockRV block = sch->GetBlock (job. block_name , func_name);
58+ for (auto kv : jobs) {
59+ tir::BlockRV block = sch->GetBlock (kv. first , func_name);
6560 sch->Unannotate (block, tir::attr::meta_schedule_auto_tensorize);
66- job. postproc_fn (block);
61+ kv. second (block);
6762 }
6863}
6964
7065class RewriteTensorizeNode : public PostprocNode {
7166 public:
72- // Inherited from PostprocNode
7367 void InitializeWithTuneContext (const TuneContext& context) final {}
7468
75- // Inherited from PostprocNode
7669 bool Apply (const tir::Schedule& sch) final ;
7770
7871 void VisitAttrs (tvm::AttrVisitor* v) {}
0 commit comments