
Commit 0070b6c

masahi authored (with Siyuan Feng, spectrometerHBH, jinhongyii, MasterJH5574)
[TIR] Add TileWithTensorIntrin (apache#11075)
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
1 parent 68beae9 commit 0070b6c

File tree

5 files changed: +300 / -0 lines

python/tvm/tir/schedule/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,3 +24,4 @@
from .trace import Trace

from . import analysis
from . import transform
python/tvm/tir/schedule/transform.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Transformations on TIR schedules."""
from typing import Optional

from tvm.tir.schedule import Schedule, BlockRV, LoopRV
from . import _ffi_api


def tile_with_tensor_intrin(sch: Schedule, block: BlockRV, intrin_name: str) -> Optional[LoopRV]:
    """Tile a subset of loops in the block according to the given tensor intrinsic.

    Parameters
    ----------
    sch : Schedule
        The schedule to which tiling is applied
    block : BlockRV
        The block whose subset of loops will be tiled
    intrin_name : str
        The name of a tensor intrinsic; it must be registered via TensorIntrin.register(...) beforehand

    Returns
    -------
    tiled_loop_rv : Optional[LoopRV]
        LoopRV corresponding to the outermost loop of a block tiled according to the given intrin,
        or NullOpt if no valid loop mapping is found
    """
    return _ffi_api.TileWithTensorIntrin(sch, block, intrin_name)  # type: ignore
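
A minimal usage sketch of this API (an illustration, not part of the commit): it assumes TVM is built with this patch, and reuses DenseVNNIModule and the VNNI intrinsic name from the test file added later in this commit.

from tvm.tir import Schedule
from tvm.tir.schedule.transform import tile_with_tensor_intrin
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN

sch = Schedule(DenseVNNIModule)  # DenseVNNIModule is defined in the test file below
block = sch.get_block("compute")
loop = tile_with_tensor_intrin(sch, block, VNNI_DOT_16x4_INTRIN)
if loop is None:
    raise RuntimeError("no loop mapping found for the requested intrinsic")
# `loop` is the outermost tiled loop; a typical follow-up step would be
# sch.tensorize(loop, VNNI_DOT_16x4_INTRIN) to map it onto the VNNI instruction.
print(sch.mod.script())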

src/tir/schedule/transform.cc

Lines changed: 63 additions & 0 deletions
@@ -136,5 +136,68 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
  throw OnlyLeafError(self->mod, GetRef<Block>(leaf_block), GetRef<Block>(scope_block));
}

Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
                                      const String& intrin_name) {
  Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
      sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
  if (!opt_tensorize_info) return NullOpt;
  const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
  // Construct a mapping from tir loops back to LoopRVs
  Map<tir::StmtSRef, LoopRV> loop2rv;
  {
    Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
    for (const LoopRV& loop_rv : loop_rvs) {
      loop2rv.Set(sch->GetSRef(loop_rv), loop_rv);
    }
  }
  // Split the loops
  arith::Analyzer analyzer;
  std::unordered_set<const tir::StmtSRefNode*> inner_loops;
  std::vector<LoopRV> reorder_suffix;
  reorder_suffix.resize(info->loop_map.size());
  for (const auto& kv : info->loop_map) {
    // Extract mapping (block_loop => desc_loop)
    const tir::StmtSRef& block_loop_sref = kv.first;
    const tir::ForNode* block_loop = block_loop_sref->StmtAs<tir::ForNode>();
    const tir::ForNode* desc_loop = kv.second.get();
    ICHECK(block_loop != nullptr && desc_loop != nullptr);
    // Extract the loop extent
    PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
    PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
    const auto* int_block_extent = block_extent.as<IntImmNode>();
    const auto* int_desc_extent = desc_extent.as<IntImmNode>();
    ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr);
    // Check divisibility
    int64_t total = int_block_extent->value;
    int64_t inner = int_desc_extent->value;
    ICHECK_EQ(total % inner, 0);
    int64_t outer = int_block_extent->value / int_desc_extent->value;
    // Do the split
    Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)});
    ICHECK_EQ(split.size(), 2);
    inner_loops.insert(sch->GetSRef(split[1]).operator->());
    // The inner split will be reordered to the loop domain that is tensorized
    int desc_loop_index = info->desc_loop_indexer.at(GetRef<tir::For>(desc_loop));
    reorder_suffix[desc_loop_index] = split[1];
  }
  // Reorder the loops
  std::vector<LoopRV> reorder_list;
  bool meet = false;
  Array<LoopRV> all_loops = sch->GetLoops(block_rv);
  for (const LoopRV& loop : all_loops) {
    if (inner_loops.count(sch->GetSRef(loop).operator->())) {
      meet = true;
    } else if (meet) {
      reorder_list.push_back(loop);
    }
  }
  reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end());
  sch->Reorder(reorder_list);
  ICHECK(!reorder_suffix.empty());
  return reorder_suffix[0];
}

TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin);

}  // namespace tir
}  // namespace tvm
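
A rough Python-level sketch of what TileWithTensorIntrin does for the dense VNNI workload in the tests below (a 1024x1024x1024 GEMM mapped onto a 16x4 intrinsic). The real implementation discovers which loops to split via GetTensorizeLoopMapping instead of hard-coding them, so treat this only as an illustration of the split/reorder steps.

from tvm.tir import Schedule

sch = Schedule(DenseVNNIModule)  # module defined in the test file below
i, j, k = sch.get_loops(sch.get_block("compute"))
# Split each mapped block loop by the extent of the matching loop in the
# intrinsic description: j (1024) by 16 and k (1024) by 4.
j0, j1 = sch.split(j, factors=[None, 16])
k0, k1 = sch.split(k, factors=[None, 4])
# Move the inner splits to the innermost positions, in the order of the
# description loops; the first of them (j1) is the loop the function returns.
sch.reorder(k0, j1, k1)
# Resulting loop extents: (1024, 64, 256, 16, 4), matching DenseVNNIModuleTiled.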

src/tir/schedule/transform.h

Lines changed: 13 additions & 0 deletions
@@ -19,6 +19,7 @@
#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_
#define TVM_TIR_SCHEDULE_TRANSFORM_H_

#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/schedule/state.h>

namespace tvm {
@@ -104,6 +105,18 @@ Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, c
void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref,
                          Stmt* src_stmt, Stmt* tgt_stmt);

/*!
 * \brief Tile a subset of loops in the block according to the given tensor intrinsic.
 * \param sch The schedule to which tiling is applied
 * \param block_rv The block whose subset of loops will be tiled
 * \param intrin_name The name of a tensor intrinsic; it must be registered via
 *        TensorIntrin.register(...) beforehand
 * \return LoopRV corresponding to the outermost loop of a block tiled according to the given
 *         intrin, or NullOpt if a valid loop mapping is not found
 */
Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
                                           const String& intrin_name);

}  // namespace tir
}  // namespace tvm

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN

from tvm.tir import Schedule
from tvm.script import tir as T
from tvm.tir.schedule.transform import tile_with_tensor_intrin


@tvm.script.ir_module
class DenseVNNIModule:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(1024, 1024), "uint8"],
        placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
        compute: T.Buffer[(1024, 1024), "int32"],
    ) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        with T.block("root"):
            T.reads()
            T.writes()
            for i0, i1, i2 in T.grid(1024, 1024, 1024):
                with T.block("compute"):
                    i, j, k = T.axis.remap("SSR", [i0, i1, i2])
                    T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
                    T.writes(compute[i, j])
                    with T.init():
                        compute[i, j] = 0
                    compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(
                        placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32"
                    )


@tvm.script.ir_module
class DenseVNNIModuleTiled:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(1024, 1024), "uint8"],
        placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
        compute: T.Buffer[(1024, 1024), "int32"],
    ) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1_0, i2_0, i1_1, i2_1 in T.grid(1024, 64, 256, 16, 4):
            with T.block("compute"):
                i = T.axis.spatial(1024, i0)
                j = T.axis.spatial(1024, i1_0 * 16 + i1_1)
                k = T.axis.reduce(1024, i2_0 * 4 + i2_1)
                T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4])
                T.writes(compute[i, j])
                with T.init():
                    compute[i, j] = 0
                compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast(
                    placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32"
                )


@tvm.script.ir_module
class Conv2dNCHWcVNNIModule:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
        placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
        conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
    ) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
            with T.block("conv2d_NCHWc_int8"):
                (
                    n,
                    oc_chunk,
                    oh,
                    ow,
                    oc_block,
                    kh,
                    kw,
                    ic_outer,
                    ic_f_inner,
                    ic_s_inner,
                ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
                T.reads(
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                )
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
                with T.init():
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                    n, oc_chunk, oh, ow, oc_block
                ] + T.cast(
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
                ) * T.cast(
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                    "int32",
                )


@tvm.script.ir_module
class Conv2dNCHWcVNNIModuleTiled:
    @T.prim_func
    def main(
        placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
        placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
        conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
    ) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2, i3, i4_0, i5, i6, i7, i8, i9_0, i4_1, i9_1 in T.grid(
            1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4
        ):
            with T.block("conv2d_NCHWc_int8"):
                n = T.axis.spatial(1, 0)
                oc_chunk, oh, ow, oc_block = T.axis.remap("SSSS", [i1, i2, i3, i4_1])
                kh = T.axis.reduce(1, 0)
                kw = T.axis.reduce(1, 0)
                ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("RRR", [i7, i8, i9_1])
                T.reads(
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                )
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
                with T.init():
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                    n, oc_chunk, oh, ow, oc_block
                ] + T.cast(
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
                ) * T.cast(
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                    "int32",
                )


def test_tile_with_tensor_intrin_dense_vnni():
    s = Schedule(DenseVNNIModule)
    block = s.get_block("compute")

    tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN)

    _, _, _, i1_1, _ = s.get_loops(block)

    assert s.get(tiled_loop) == s.get(i1_1)
    tvm.ir.assert_structural_equal(s.mod, DenseVNNIModuleTiled)


def test_tile_with_tensor_intrin_conv2d_nchwc_vnni():
    s = Schedule(Conv2dNCHWcVNNIModule)
    block = s.get_block("conv2d_NCHWc_int8")

    tiled_loop = tile_with_tensor_intrin(s, block, VNNI_DOT_16x4_INTRIN)

    tiled_loops = s.get_loops(block)

    assert len(tiled_loops) == 12
    assert s.get(tiled_loop) == s.get(tiled_loops[-2])

    tvm.ir.assert_structural_equal(s.mod, Conv2dNCHWcVNNIModuleTiled)


if __name__ == "__main__":
    test_tile_with_tensor_intrin_dense_vnni()
    test_tile_with_tensor_intrin_conv2d_nchwc_vnni()
