Commit b87ef32

move TilingwithTensorIntrin to auto_tensorize.cc
1 parent 2fc118b

3 files changed: 126 additions, 65 deletions
src/meta_schedule/schedule_rule/auto_tensorize.cc

Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "auto_tensorize.h"
+
+#include "../../tir/schedule/analysis.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::LoopRV;
+
+Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
+                                        const String& intrin_name) {
+  Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
+      sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
+  if (!opt_tensorize_info) return NullOpt;
+  const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
+  // Construct a mapping from tir loops back to LoopRVs
+  Map<tir::StmtSRef, LoopRV> loop2rv;
+  {
+    Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
+    for (const LoopRV& loop_rv : loop_rvs) {
+      loop2rv.Set(sch->GetSRef(loop_rv), loop_rv);
+    }
+  }
+  // Split the loops
+  arith::Analyzer analyzer;
+  std::unordered_set<const tir::StmtSRefNode*> inner_loops;
+  std::vector<LoopRV> reorder_suffix;
+  reorder_suffix.resize(info->loop_map.size());
+  for (const auto& kv : info->loop_map) {
+    // Extract mapping (block_loop => desc_loop)
+    const tir::StmtSRef& block_loop_sref = kv.first;
+    const tir::ForNode* block_loop = block_loop_sref->StmtAs<tir::ForNode>();
+    const tir::ForNode* desc_loop = kv.second.get();
+    ICHECK(block_loop != nullptr && desc_loop != nullptr);
+    // Extract the loop extent
+    PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
+    PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
+    const auto* int_block_extent = block_extent.as<IntImmNode>();
+    const auto* int_desc_extent = desc_extent.as<IntImmNode>();
+    ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr);
+    // Check divisibility
+    int64_t total = int_block_extent->value;
+    int64_t inner = int_desc_extent->value;
+    ICHECK_EQ(total % inner, 0);
+    int64_t outer = int_block_extent->value / int_desc_extent->value;
+    // Do the split
+    Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)});
+    ICHECK_EQ(split.size(), 2);
+    inner_loops.insert(sch->GetSRef(split[1]).operator->());
+    // The inner split will be reordered to the loop domain that is tensorized
+    int desc_loop_index = info->desc_loop_indexer.at(GetRef<tir::For>(desc_loop));
+    reorder_suffix[desc_loop_index] = split[1];
+  }
+  // Reorder the loops
+  std::vector<LoopRV> reorder_list;
+  bool meet = false;
+  Array<LoopRV> all_loops = sch->GetLoops(block_rv);
+  for (const LoopRV& loop : all_loops) {
+    if (inner_loops.count(sch->GetSRef(loop).operator->())) {
+      meet = true;
+    } else if (meet) {
+      reorder_list.push_back(loop);
+    }
+  }
+  reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end());
+  sch->Reorder(reorder_list);
+  ICHECK(!reorder_suffix.empty());
+  return reorder_suffix[0];
+}
+
+}  // namespace meta_schedule
+}  // namespace tvm
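
Aside (not part of the commit): the reorder step above collects every loop that appears after the first inner split and then appends the inner splits in the intrinsic descriptor's loop order, so the tensorizable loops end up contiguous and innermost. A minimal stand-alone C++ sketch of that bookkeeping, using made-up loop names in place of LoopRVs:

#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  // Hypothetical loop nest after the splits: the *_i loops are the inner
  // splits produced above, listed in descriptor order in reorder_suffix.
  std::vector<std::string> all_loops = {"i_o", "i_i", "j_o", "j_i", "k_o", "k_i"};
  std::set<std::string> inner_loops = {"i_i", "j_i", "k_i"};
  std::vector<std::string> reorder_suffix = {"i_i", "j_i", "k_i"};

  // Mirror the reorder-list construction: once the first inner loop is seen,
  // collect every non-inner loop that follows it, then append the suffix.
  std::vector<std::string> reorder_list;
  bool meet = false;
  for (const std::string& loop : all_loops) {
    if (inner_loops.count(loop)) {
      meet = true;
    } else if (meet) {
      reorder_list.push_back(loop);
    }
  }
  reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end());

  for (const std::string& loop : reorder_list) std::cout << loop << ' ';
  std::cout << '\n';  // prints: j_o k_o i_i j_i k_i
  return 0;
}

Passing that order to sch->Reorder hoists j_o and k_o above the inner splits, leaving the descriptor-shaped loops innermost; the function then returns the first of them so the caller can blockize it.
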
src/meta_schedule/schedule_rule/auto_tensorize.h

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_
+#define TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_
+
+#include <tvm/tir/schedule/schedule.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+Optional<tir::LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
+                                             const String& intrin_name);
+}  // namespace meta_schedule
+}  // namespace tvm
+
+#endif  // TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_

src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc

Lines changed: 3 additions & 65 deletions
@@ -18,78 +18,16 @@
  */
 
 #include "../utils.h"
+#include "auto_tensorize.h"
 #include "multi_level_tiling.h"
-#include "../../tir/schedule/analysis.h"
 
 namespace tvm {
 namespace meta_schedule {
 
-using tir::LoopRV;
-
-Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
-                                        const String& intrin_name) {
-  Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
-      sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
-  if (!opt_tensorize_info) return NullOpt;
-  const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
-  // Construct a mapping from tir loops back to LoopRVs
-  Map<tir::StmtSRef, LoopRV> loop2rv;
-  {
-    Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
-    for (const LoopRV& loop_rv : loop_rvs) {
-      loop2rv.Set(sch->GetSRef(loop_rv), loop_rv);
-    }
-  }
-  // Split the loops
-  arith::Analyzer analyzer;
-  std::unordered_set<const tir::StmtSRefNode*> inner_loops;
-  std::vector<LoopRV> reorder_suffix;
-  reorder_suffix.resize(info->loop_map.size());
-  for (const auto& kv : info->loop_map) {
-    // Extract mapping (block_loop => desc_loop)
-    const tir::StmtSRef& block_loop_sref = kv.first;
-    const tir::ForNode* block_loop = block_loop_sref->StmtAs<tir::ForNode>();
-    const tir::ForNode* desc_loop = kv.second.get();
-    ICHECK(block_loop != nullptr && desc_loop != nullptr);
-    // Extract the loop extent
-    PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
-    PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
-    const auto* int_block_extent = block_extent.as<IntImmNode>();
-    const auto* int_desc_extent = desc_extent.as<IntImmNode>();
-    ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr);
-    // Check divisibility
-    int64_t total = int_block_extent->value;
-    int64_t inner = int_desc_extent->value;
-    ICHECK_EQ(total % inner, 0);
-    int64_t outer = int_block_extent->value / int_desc_extent->value;
-    // Do the split
-    Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)});
-    ICHECK_EQ(split.size(), 2);
-    inner_loops.insert(sch->GetSRef(split[1]).operator->());
-    // The inner split will be reordered to the loop domain that is tensorized
-    int desc_loop_index = info->desc_loop_indexer.at(GetRef<tir::For>(desc_loop));
-    reorder_suffix[desc_loop_index] = split[1];
-  }
-  // Reorder the loops
-  std::vector<LoopRV> reorder_list;
-  bool meet = false;
-  Array<LoopRV> all_loops = sch->GetLoops(block_rv);
-  for (const LoopRV& loop : all_loops) {
-    if (inner_loops.count(sch->GetSRef(loop).operator->())) {
-      meet = true;
-    } else if (meet) {
-      reorder_list.push_back(loop);
-    }
-  }
-  reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end());
-  sch->Reorder(reorder_list);
-  ICHECK(!reorder_suffix.empty());
-  return reorder_suffix[0];
-}
-
 std::vector<State> TileForVNNI(State state) {
   const std::string intrin_name = "dot_16x4_vnni";
-  Optional<LoopRV> tiled_loop_rv = TilingwithTensorIntrin(state.sch, state.block_rv, intrin_name);
+  Optional<tir::LoopRV> tiled_loop_rv =
+      TilingwithTensorIntrin(state.sch, state.block_rv, intrin_name);
   ICHECK(tiled_loop_rv.defined());
   state.block_rv = state.sch->Blockize(tiled_loop_rv.value());
   state.sch->Annotate(state.block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
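
Aside (not part of the commit): TilingwithTensorIntrin only succeeds when every mapped block-loop extent is an exact multiple of the corresponding descriptor-loop extent. Assuming dot_16x4_vnni describes a 16-lane, 4-element dot product (an assumption here, not stated in the diff), a 64-wide output loop and a 256-long reduction loop would split as in this stand-alone sketch:

#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical extents for illustration only; real extents come from the
  // block being scheduled, not from this sketch.
  const int64_t extents[2] = {64, 256};  // block loop extents (output lanes, reduction)
  const int64_t desc[2] = {16, 4};       // assumed descriptor extents of the intrinsic
  for (int i = 0; i < 2; ++i) {
    assert(extents[i] % desc[i] == 0);   // mirrors ICHECK_EQ(total % inner, 0)
    const int64_t outer = extents[i] / desc[i];
    // Each mapped loop is split into (outer, inner); the inner parts are later
    // reordered innermost and become the loops that get blockized/tensorized.
    std::cout << "split " << extents[i] << " -> (" << outer << ", " << desc[i] << ")\n";
  }
  return 0;
}

The inner factors (16 and 4 here) become the extents of the loops that are blockized and annotated with meta_schedule_auto_tensorize in TileForVNNI above.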