Skip to content

Commit 2cae905

Browse files
authored
[TIR] Support pattern matching argmax/argmin generated by TOPI (#12827)
This PR introduces two reducers to TIR reduction part, so that rfactor and cross-thread reduction can be applied to those functions who contains argmax/argmin computation generated by TOPI.
1 parent 41b65a3 commit 2cae905

File tree

2 files changed

+233
-57
lines changed

2 files changed

+233
-57
lines changed

src/tir/schedule/primitive/reduction.cc

Lines changed: 80 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -297,60 +297,86 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
297297
*/
298298
struct ReducerRegistry {
299299
ReducerRegistry()
300-
: reducer_getters{CreateReducerGetter(
301-
/*n_buffers=*/1,
302-
[](const Array<Var>& x, const Array<Var>& y) {
303-
return Array<PrimExpr>{x[0] + y[0]};
304-
},
305-
[](const Array<PrimExpr>& values) {
306-
return Array<PrimExpr>{make_const(values[0]->dtype, 0)};
307-
}),
308-
CreateReducerGetter(
309-
/*n_buffers=*/1,
310-
[](const Array<Var>& x, const Array<Var>& y) {
311-
return Array<PrimExpr>{x[0] * y[0]};
312-
},
313-
[](const Array<PrimExpr>& values) {
314-
return Array<PrimExpr>{make_const(values[0]->dtype, 1)};
315-
}),
316-
CreateReducerGetter(
317-
/*n_buffers=*/1,
318-
[](const Array<Var>& x, const Array<Var>& y) {
319-
return Array<PrimExpr>{min(x[0], y[0])};
320-
},
321-
[](const Array<PrimExpr>& values) {
322-
return Array<PrimExpr>{max_value(values[0]->dtype)};
323-
}),
324-
CreateReducerGetter(
325-
/*n_buffers=*/1,
326-
[](const Array<Var>& x, const Array<Var>& y) {
327-
return Array<PrimExpr>{max(x[0], y[0])};
328-
},
329-
[](const Array<PrimExpr>& values) {
330-
return Array<PrimExpr>{min_value(values[0]->dtype)};
331-
}),
332-
CreateReducerGetter(
333-
/*n_buffers=*/2,
334-
[](const Array<Var>& x, const Array<Var>& y) {
335-
PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]);
336-
PrimExpr val = Select(x[1] >= y[1], x[1], y[1]);
337-
return Array<PrimExpr>{idx, val};
338-
},
339-
[](const Array<PrimExpr>& values) {
340-
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
341-
min_value(values[1]->dtype)};
342-
}),
343-
CreateReducerGetter(
344-
/*n_buffers=*/2,
345-
[](const Array<Var>& x, const Array<Var>& y) {
346-
PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]);
347-
PrimExpr val = Select(x[1] <= y[1], x[1], y[1]);
348-
return Array<PrimExpr>{idx, val};
349-
},
350-
[](const Array<PrimExpr>& values) {
351-
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
352-
max_value(values[1]->dtype)};
353-
})} {}
300+
: reducer_getters{
301+
CreateReducerGetter(
302+
/*n_buffers=*/1,
303+
[](const Array<Var>& x, const Array<Var>& y) {
304+
return Array<PrimExpr>{x[0] + y[0]};
305+
},
306+
[](const Array<PrimExpr>& values) {
307+
return Array<PrimExpr>{make_const(values[0]->dtype, 0)};
308+
}),
309+
CreateReducerGetter(
310+
/*n_buffers=*/1,
311+
[](const Array<Var>& x, const Array<Var>& y) {
312+
return Array<PrimExpr>{x[0] * y[0]};
313+
},
314+
[](const Array<PrimExpr>& values) {
315+
return Array<PrimExpr>{make_const(values[0]->dtype, 1)};
316+
}),
317+
CreateReducerGetter(
318+
/*n_buffers=*/1,
319+
[](const Array<Var>& x, const Array<Var>& y) {
320+
return Array<PrimExpr>{min(x[0], y[0])};
321+
},
322+
[](const Array<PrimExpr>& values) {
323+
return Array<PrimExpr>{max_value(values[0]->dtype)};
324+
}),
325+
CreateReducerGetter(
326+
/*n_buffers=*/1,
327+
[](const Array<Var>& x, const Array<Var>& y) {
328+
return Array<PrimExpr>{max(x[0], y[0])};
329+
},
330+
[](const Array<PrimExpr>& values) {
331+
return Array<PrimExpr>{min_value(values[0]->dtype)};
332+
}),
333+
CreateReducerGetter(
334+
/*n_buffers=*/2,
335+
[](const Array<Var>& x, const Array<Var>& y) {
336+
PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]);
337+
PrimExpr val = Select(x[1] >= y[1], x[1], y[1]);
338+
return Array<PrimExpr>{idx, val};
339+
},
340+
[](const Array<PrimExpr>& values) {
341+
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
342+
min_value(values[1]->dtype)};
343+
}),
344+
CreateReducerGetter(
345+
/*n_buffers=*/2,
346+
[](const Array<Var>& x, const Array<Var>& y) {
347+
PrimExpr idx =
348+
Select(Or(greater(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))),
349+
x[0], y[0]);
350+
PrimExpr val = Select(greater(x[1], y[1]), x[1], y[1]);
351+
return Array<PrimExpr>{idx, val};
352+
},
353+
[](const Array<PrimExpr>& values) {
354+
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
355+
min_value(values[1]->dtype)};
356+
}),
357+
CreateReducerGetter(
358+
/*n_buffers=*/2,
359+
[](const Array<Var>& x, const Array<Var>& y) {
360+
PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]);
361+
PrimExpr val = Select(x[1] <= y[1], x[1], y[1]);
362+
return Array<PrimExpr>{idx, val};
363+
},
364+
[](const Array<PrimExpr>& values) {
365+
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
366+
max_value(values[1]->dtype)};
367+
}),
368+
CreateReducerGetter(
369+
/*n_buffers=*/2,
370+
[](const Array<Var>& x, const Array<Var>& y) {
371+
PrimExpr idx = Select(
372+
Or(less(x[1], y[1]), And(equal(x[1], y[1]), less(x[0], y[0]))), x[0], y[0]);
373+
PrimExpr val = Select(less(x[1], y[1]), x[1], y[1]);
374+
return Array<PrimExpr>{idx, val};
375+
},
376+
[](const Array<PrimExpr>& values) {
377+
return Array<PrimExpr>{make_const(values[0]->dtype, -1),
378+
max_value(values[1]->dtype)};
379+
})} {}
354380

355381
static void RegisterReducer(
356382
int n_buffers, TypedPackedFunc<Array<PrimExpr>(Array<Var>, Array<Var>)> combiner_getter,

tests/python/unittest/test_tir_schedule_rfactor.py

Lines changed: 153 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=missing-function-docstring,missing-module-docstring
18-
import sys
19-
2018
import pytest
2119
import tvm
2220
import tvm.testing
23-
from tvm import tir
21+
from tvm import te, tir, topi
2422
from tvm.script import tir as T
2523
from tvm.tir.schedule.testing import verify_trace_roundtrip
2624

@@ -1133,6 +1131,128 @@ def argmin_split_rfactor(
11331131
argmin_v1[i] = v_argmin_v1
11341132

11351133

1134+
@T.prim_func
1135+
def argmax_topi_rfactor(
1136+
placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1, "int32"]
1137+
) -> None:
1138+
T.func_attr({"global_symbol": "main", "tir.noalias": True})
1139+
placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32")
1140+
placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32")
1141+
placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32")
1142+
placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32")
1143+
for i0, i1_0, i1_1 in T.grid(1, 4, 8):
1144+
with T.block("placeholder_red_temp_rf"):
1145+
vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0])
1146+
T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1])
1147+
T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
1148+
with T.init():
1149+
placeholder_red_temp_v0_rf[ax0, vi1_1] = -1
1150+
placeholder_red_temp_v1_rf[ax0, vi1_1] = -2147483648
1151+
v_placeholder_red_temp_v0_rf: T.int32 = T.Select(
1152+
placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0, vi1_0 * 8 + vi1_1]
1153+
or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0, vi1_0 * 8 + vi1_1]
1154+
and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1,
1155+
placeholder_red_temp_v0_rf[ax0, vi1_1],
1156+
vi1_0 * 8 + vi1_1,
1157+
)
1158+
v_placeholder_red_temp_v1_rf: T.int32 = T.Select(
1159+
placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0, vi1_0 * 8 + vi1_1],
1160+
placeholder_red_temp_v1_rf[ax0, vi1_1],
1161+
placeholder[ax0, vi1_0 * 8 + vi1_1],
1162+
)
1163+
placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf
1164+
placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf
1165+
for i0, i1_1 in T.grid(1, 8):
1166+
with T.block("placeholder_red_temp"):
1167+
vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0])
1168+
T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
1169+
T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0])
1170+
with T.init():
1171+
placeholder_red_temp_v0[ax0] = -1
1172+
placeholder_red_temp_v1[ax0] = -2147483648
1173+
v_placeholder_red_temp_v0: T.int32 = T.Select(
1174+
placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0, vi1_1]
1175+
or placeholder_red_temp_v1[ax0] == placeholder_red_temp_v1_rf[ax0, vi1_1]
1176+
and placeholder_red_temp_v0[ax0] < placeholder_red_temp_v0_rf[ax0, vi1_1],
1177+
placeholder_red_temp_v0[ax0],
1178+
placeholder_red_temp_v0_rf[ax0, vi1_1],
1179+
)
1180+
v_placeholder_red_temp_v1: T.int32 = T.Select(
1181+
placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0, vi1_1],
1182+
placeholder_red_temp_v1[ax0],
1183+
placeholder_red_temp_v1_rf[ax0, vi1_1],
1184+
)
1185+
placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0
1186+
placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1
1187+
for i0 in T.serial(1):
1188+
with T.block("placeholder_red"):
1189+
ax0 = T.axis.spatial(1, i0)
1190+
T.reads(placeholder_red_temp_v0[ax0])
1191+
T.writes(placeholder_red[ax0])
1192+
placeholder_red[ax0] = placeholder_red_temp_v0[ax0]
1193+
1194+
1195+
@T.prim_func
1196+
def argmin_topi_rfactor(
1197+
placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1, "int32"]
1198+
) -> None:
1199+
T.func_attr({"global_symbol": "main", "tir.noalias": True})
1200+
placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32")
1201+
placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32")
1202+
placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32")
1203+
placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32")
1204+
for i0, i1_0, i1_1 in T.grid(1, 4, 8):
1205+
with T.block("placeholder_red_temp_rf"):
1206+
vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0])
1207+
T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1])
1208+
T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
1209+
with T.init():
1210+
placeholder_red_temp_v0_rf[ax0, vi1_1] = -1
1211+
placeholder_red_temp_v1_rf[ax0, vi1_1] = 2147483647
1212+
v_placeholder_red_temp_v0_rf: T.int32 = T.Select(
1213+
placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0, vi1_0 * 8 + vi1_1]
1214+
or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0, vi1_0 * 8 + vi1_1]
1215+
and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1,
1216+
placeholder_red_temp_v0_rf[ax0, vi1_1],
1217+
vi1_0 * 8 + vi1_1,
1218+
)
1219+
v_placeholder_red_temp_v1_rf: T.int32 = T.Select(
1220+
placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0, vi1_0 * 8 + vi1_1],
1221+
placeholder_red_temp_v1_rf[ax0, vi1_1],
1222+
placeholder[ax0, vi1_0 * 8 + vi1_1],
1223+
)
1224+
placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf
1225+
placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf
1226+
for i0, i1_1 in T.grid(1, 8):
1227+
with T.block("placeholder_red_temp"):
1228+
vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0])
1229+
T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1])
1230+
T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0])
1231+
with T.init():
1232+
placeholder_red_temp_v0[ax0] = -1
1233+
placeholder_red_temp_v1[ax0] = 2147483647
1234+
v_placeholder_red_temp_v0: T.int32 = T.Select(
1235+
placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0, vi1_1]
1236+
or placeholder_red_temp_v1[ax0] == placeholder_red_temp_v1_rf[ax0, vi1_1]
1237+
and placeholder_red_temp_v0[ax0] < placeholder_red_temp_v0_rf[ax0, vi1_1],
1238+
placeholder_red_temp_v0[ax0],
1239+
placeholder_red_temp_v0_rf[ax0, vi1_1],
1240+
)
1241+
v_placeholder_red_temp_v1: T.int32 = T.Select(
1242+
placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0, vi1_1],
1243+
placeholder_red_temp_v1[ax0],
1244+
placeholder_red_temp_v1_rf[ax0, vi1_1],
1245+
)
1246+
placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0
1247+
placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1
1248+
for i0 in T.serial(1):
1249+
with T.block("placeholder_red"):
1250+
ax0 = T.axis.spatial(1, i0)
1251+
T.reads(placeholder_red_temp_v0[ax0])
1252+
T.writes(placeholder_red[ax0])
1253+
placeholder_red[ax0] = placeholder_red_temp_v0[ax0]
1254+
1255+
11361256
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
11371257

11381258

@@ -1490,5 +1610,35 @@ def test_reduction_rfactor_argmax_init_buffer_not_match():
14901610
s.rfactor(ki, 1)
14911611

14921612

1613+
def test_reduction_rfactor_topi_argmax():
1614+
A = te.placeholder((1, 32), dtype="int32")
1615+
B = topi.argmax(A, axis=1)
1616+
argmax_topi = te.create_prim_func([A, B])
1617+
s = tir.Schedule(argmax_topi, debug_mask="all")
1618+
argmax = s.get_block("placeholder_red_temp")
1619+
_, k = s.get_loops(argmax)
1620+
_, ki = s.split(k, [None, 8])
1621+
rf_block = s.rfactor(ki, 1)
1622+
tvm.ir.assert_structural_equal(s.mod["main"], argmax_topi_rfactor)
1623+
assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf")))
1624+
assert s.get(argmax).same_as(s.get(s.get_block("placeholder_red_temp")))
1625+
verify_trace_roundtrip(s, mod=argmax_topi)
1626+
1627+
1628+
def test_reduction_rfactor_topi_argmin():
1629+
A = te.placeholder((1, 32), dtype="int32")
1630+
B = topi.argmin(A, axis=1)
1631+
argmin_topi = te.create_prim_func([A, B])
1632+
s = tir.Schedule(argmin_topi, debug_mask="all")
1633+
argmin = s.get_block("placeholder_red_temp")
1634+
_, k = s.get_loops(argmin)
1635+
_, ki = s.split(k, [None, 8])
1636+
rf_block = s.rfactor(ki, 1)
1637+
tvm.ir.assert_structural_equal(s.mod["main"], argmin_topi_rfactor)
1638+
assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf")))
1639+
assert s.get(argmin).same_as(s.get(s.get_block("placeholder_red_temp")))
1640+
verify_trace_roundtrip(s, mod=argmin_topi)
1641+
1642+
14931643
if __name__ == "__main__":
14941644
tvm.testing.main()

0 commit comments

Comments
 (0)