@@ -49,8 +49,10 @@ void CollectTensorized(const tir::Schedule& sch, const String& func_name,
4949 tir::StmtSRef block_sref = sch->GetSRef (block);
5050 if (Optional<String> intrin_name =
5151 tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
52- tasks.push_back (std::make_tuple (block_sref->StmtAs <tir::BlockNode>()->name_hint ,
53- func_name, intrin_name.value ()));
52+ std::string block_name = block_sref->StmtAs <tir::BlockNode>()->name_hint ;
53+ if (block_name.find (" init" ) == std::string::npos) {
54+ tasks.push_back (std::make_tuple (block_name, func_name, intrin_name.value ()));
55+ }
5456 }
5557 }
5658 return true ;
@@ -59,6 +61,7 @@ void CollectTensorized(const tir::Schedule& sch, const String& func_name,
5961}
6062
6163bool RewriteVNNINode::Apply (const tir::Schedule& sch) {
64+ LOG (INFO) << " Apply RewriteVNNI " << sch->mod ();
6265 std::vector<BlockPosition> tasks;
6366 for (const auto & kv : sch->mod ()->functions ) {
6467 GlobalVar g_var = kv.first ;
@@ -73,11 +76,13 @@ bool RewriteVNNINode::Apply(const tir::Schedule& sch) {
7376 String intrin_name = std::get<2 >(task);
7477 sch->Unannotate (block_rv, tir::attr::meta_schedule_auto_tensorize);
7578 sch->Tensorize (block_rv, intrin_name);
79+ LOG (INFO) << " After tensorize: " << sch->mod ();
7680 }
7781 return true ;
7882}
7983
8084Postproc RewriteVNNI () {
85+ LOG (INFO) << " RewriteVNNI is called" ;
8186 ObjectPtr<RewriteVNNINode> n = make_object<RewriteVNNINode>();
8287 return Postproc (n);
8388}
0 commit comments