Skip to content

Commit 0e99e8a

Browse files
author
ZihengJiang
committed
Add tile operation
1 parent 0c72ca9 commit 0e99e8a

File tree

5 files changed

+36
-0
lines changed

5 files changed

+36
-0
lines changed

include/tvm/schedule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ class Schedule : public NodeRef {
100100
* \return reference to self.
101101
*/
102102
Schedule& reorder(const Array<IterVar>& order); // NOLINT(*)
103+
Schedule& tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
104+
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
105+
Expr x_factor, Expr y_factor); // NOLINT(*)
103106
};
104107

105108
/*!

python/tvm/schedule.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,8 @@ def reorder(self, *args):
107107
The order to be ordered
108108
"""
109109
_function_internal._ScheduleReorder(self, args)
110+
111+
def tile(self, x_parent, y_parent, x_factor, y_factor):
112+
x_outer, y_outer, x_inner, y_inner = _function_internal._ScheduleTile(
113+
self, x_parent, y_parent, x_factor, y_factor)
114+
return x_outer, y_outer, x_inner, y_inner

src/c_api/c_api_lang.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,5 +151,13 @@ TVM_REGISTER_API(_ScheduleReorder)
151151
.reorder(args.at(1));
152152
});
153153

154+
TVM_REGISTER_API(_ScheduleTile)
155+
.set_body([](const ArgStack& args, RetValue *ret) {
156+
IterVar x_outer, y_outer, x_inner, y_inner;
157+
args.at(0).operator Schedule()
158+
.tile(args.at(1), args.at(2), &x_outer, &y_outer,
159+
&x_inner, &y_inner, args.at(3), args.at(4));
160+
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
161+
});
154162

155163
} // namespace tvm

src/lang/schedule.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,16 @@ Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*)
148148
return *this;
149149
}
150150

151+
Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
152+
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
153+
Expr x_factor, Expr y_factor) { // NOLINT(*)
154+
155+
split(x_parent, p_x_outer, p_x_inner, x_factor);
156+
split(y_parent, p_y_outer, p_y_inner, y_factor);
157+
reorder(Array<IterVar>({*p_x_inner, *p_y_inner, *p_x_outer, *p_y_outer}));
158+
return *this;
159+
}
160+
151161
IterVarRelation SplitNode::make(
152162
IterVar parent, IterVar outer,
153163
IterVar inner, Expr factor) {

tests/python/test_schedule.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,18 @@ def test_reorder():
3434
sch_T.reorder(*order)
3535
assert tuple(sch_T.leaf_iter_vars) == order
3636

37+
def test_tile():
38+
m = tvm.Var('m')
39+
n = tvm.Var('n')
40+
A = tvm.placeholder((m, n), name='A')
41+
T = tvm.compute((m, n), lambda i, j: A[i, j])
42+
43+
sch_T = tvm.Schedule(T.op, scope="shared")
44+
xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5)
45+
assert tuple(sch_T.leaf_iter_vars) == (xi, yi, xo, yo)
3746

3847
if __name__ == "__main__":
3948
test_schedule_create()
4049
test_reorder()
50+
test_tile()
4151

0 commit comments

Comments
 (0)