Skip to content

Commit fcc31ee

Browse files
committed
tensorize worked
1 parent 2b53437 commit fcc31ee

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

src/meta_schedule/postproc/rewrite_vnni.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ void CollectTensorized(const tir::Schedule& sch, const String& func_name,
4949
tir::StmtSRef block_sref = sch->GetSRef(block);
5050
if (Optional<String> intrin_name =
5151
tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
52-
tasks.push_back(std::make_tuple(block_sref->StmtAs<tir::BlockNode>()->name_hint,
53-
func_name, intrin_name.value()));
52+
std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
53+
if (block_name.find("init") == std::string::npos) {
54+
tasks.push_back(std::make_tuple(block_name, func_name, intrin_name.value()));
55+
}
5456
}
5557
}
5658
return true;
@@ -59,6 +61,7 @@ void CollectTensorized(const tir::Schedule& sch, const String& func_name,
5961
}
6062

6163
bool RewriteVNNINode::Apply(const tir::Schedule& sch) {
64+
LOG(INFO) << "Apply RewriteVNNI " << sch->mod();
6265
std::vector<BlockPosition> tasks;
6366
for (const auto& kv : sch->mod()->functions) {
6467
GlobalVar g_var = kv.first;
@@ -73,11 +76,13 @@ bool RewriteVNNINode::Apply(const tir::Schedule& sch) {
7376
String intrin_name = std::get<2>(task);
7477
sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize);
7578
sch->Tensorize(block_rv, intrin_name);
79+
LOG(INFO) << "After tensorize: " << sch->mod();
7680
}
7781
return true;
7882
}
7983

8084
Postproc RewriteVNNI() {
85+
LOG(INFO) << "RewriteVNNI is called";
8186
ObjectPtr<RewriteVNNINode> n = make_object<RewriteVNNINode>();
8287
return Postproc(n);
8388
}

src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,9 +453,8 @@ inline std::vector<State> MultiLevelTilingNode::TileForVNNI(State state) const {
453453
const std::string intrin_name = "dot_16x1x16_uint8_int8_int32_cascadelake";
454454
Optional<LoopRV> tiled_loop_rv = TilingwithTensorIntrin(state.sch, block_rv, intrin_name);
455455
ICHECK(tiled_loop_rv.defined());
456-
LOG(INFO) << "After TilingwithTensorIntrin" << state.sch->mod();
457456
state.block_rv = state.sch->Blockize(tiled_loop_rv.value());
458-
state.sch->Annotate(block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
457+
state.sch->Annotate(state.block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
459458
result.push_back(state);
460459
return result;
461460
}

0 commit comments

Comments
 (0)