|
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | 17 | # pylint: disable=missing-function-docstring,missing-module-docstring |
18 | | -import sys |
19 | | - |
20 | 18 | import pytest |
21 | 19 | import tvm |
22 | 20 | import tvm.testing |
23 | | -from tvm import tir |
| 21 | +from tvm import te, tir, topi |
24 | 22 | from tvm.script import tir as T |
25 | 23 | from tvm.tir.schedule.testing import verify_trace_roundtrip |
26 | 24 |
|
@@ -1133,6 +1131,128 @@ def argmin_split_rfactor( |
1133 | 1131 | argmin_v1[i] = v_argmin_v1 |
1134 | 1132 |
|
1135 | 1133 |
|
| 1134 | +@T.prim_func |
| 1135 | +def argmax_topi_rfactor( |
| 1136 | + placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1, "int32"] |
| 1137 | +) -> None: |
| 1138 | + T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| 1139 | + placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32") |
| 1140 | + placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32") |
| 1141 | + placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32") |
| 1142 | + placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32") |
| 1143 | + for i0, i1_0, i1_1 in T.grid(1, 4, 8): |
| 1144 | + with T.block("placeholder_red_temp_rf"): |
| 1145 | + vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) |
| 1146 | + T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1]) |
| 1147 | + T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) |
| 1148 | + with T.init(): |
| 1149 | + placeholder_red_temp_v0_rf[ax0, vi1_1] = -1 |
| 1150 | + placeholder_red_temp_v1_rf[ax0, vi1_1] = -2147483648 |
| 1151 | + v_placeholder_red_temp_v0_rf: T.int32 = T.Select( |
| 1152 | + placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0, vi1_0 * 8 + vi1_1] |
| 1153 | + or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0, vi1_0 * 8 + vi1_1] |
| 1154 | + and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1, |
| 1155 | + placeholder_red_temp_v0_rf[ax0, vi1_1], |
| 1156 | + vi1_0 * 8 + vi1_1, |
| 1157 | + ) |
| 1158 | + v_placeholder_red_temp_v1_rf: T.int32 = T.Select( |
| 1159 | + placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0, vi1_0 * 8 + vi1_1], |
| 1160 | + placeholder_red_temp_v1_rf[ax0, vi1_1], |
| 1161 | + placeholder[ax0, vi1_0 * 8 + vi1_1], |
| 1162 | + ) |
| 1163 | + placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf |
| 1164 | + placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf |
| 1165 | + for i0, i1_1 in T.grid(1, 8): |
| 1166 | + with T.block("placeholder_red_temp"): |
| 1167 | + vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0]) |
| 1168 | + T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) |
| 1169 | + T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0]) |
| 1170 | + with T.init(): |
| 1171 | + placeholder_red_temp_v0[ax0] = -1 |
| 1172 | + placeholder_red_temp_v1[ax0] = -2147483648 |
| 1173 | + v_placeholder_red_temp_v0: T.int32 = T.Select( |
| 1174 | + placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0, vi1_1] |
| 1175 | + or placeholder_red_temp_v1[ax0] == placeholder_red_temp_v1_rf[ax0, vi1_1] |
| 1176 | + and placeholder_red_temp_v0[ax0] < placeholder_red_temp_v0_rf[ax0, vi1_1], |
| 1177 | + placeholder_red_temp_v0[ax0], |
| 1178 | + placeholder_red_temp_v0_rf[ax0, vi1_1], |
| 1179 | + ) |
| 1180 | + v_placeholder_red_temp_v1: T.int32 = T.Select( |
| 1181 | + placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0, vi1_1], |
| 1182 | + placeholder_red_temp_v1[ax0], |
| 1183 | + placeholder_red_temp_v1_rf[ax0, vi1_1], |
| 1184 | + ) |
| 1185 | + placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0 |
| 1186 | + placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1 |
| 1187 | + for i0 in T.serial(1): |
| 1188 | + with T.block("placeholder_red"): |
| 1189 | + ax0 = T.axis.spatial(1, i0) |
| 1190 | + T.reads(placeholder_red_temp_v0[ax0]) |
| 1191 | + T.writes(placeholder_red[ax0]) |
| 1192 | + placeholder_red[ax0] = placeholder_red_temp_v0[ax0] |
| 1193 | + |
| 1194 | + |
| 1195 | +@T.prim_func |
| 1196 | +def argmin_topi_rfactor( |
| 1197 | + placeholder: T.Buffer[(1, 32), "int32"], placeholder_red: T.Buffer[1, "int32"] |
| 1198 | +) -> None: |
| 1199 | + T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| 1200 | + placeholder_red_temp_v0 = T.alloc_buffer([1], dtype="int32") |
| 1201 | + placeholder_red_temp_v1 = T.alloc_buffer([1], dtype="int32") |
| 1202 | + placeholder_red_temp_v0_rf = T.alloc_buffer([1, 8], dtype="int32") |
| 1203 | + placeholder_red_temp_v1_rf = T.alloc_buffer([1, 8], dtype="int32") |
| 1204 | + for i0, i1_0, i1_1 in T.grid(1, 4, 8): |
| 1205 | + with T.block("placeholder_red_temp_rf"): |
| 1206 | + vi1_1, ax0, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) |
| 1207 | + T.reads(placeholder[ax0, vi1_0 * 8 + vi1_1]) |
| 1208 | + T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) |
| 1209 | + with T.init(): |
| 1210 | + placeholder_red_temp_v0_rf[ax0, vi1_1] = -1 |
| 1211 | + placeholder_red_temp_v1_rf[ax0, vi1_1] = 2147483647 |
| 1212 | + v_placeholder_red_temp_v0_rf: T.int32 = T.Select( |
| 1213 | + placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0, vi1_0 * 8 + vi1_1] |
| 1214 | + or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0, vi1_0 * 8 + vi1_1] |
| 1215 | + and placeholder_red_temp_v0_rf[ax0, vi1_1] < vi1_0 * 8 + vi1_1, |
| 1216 | + placeholder_red_temp_v0_rf[ax0, vi1_1], |
| 1217 | + vi1_0 * 8 + vi1_1, |
| 1218 | + ) |
| 1219 | + v_placeholder_red_temp_v1_rf: T.int32 = T.Select( |
| 1220 | + placeholder_red_temp_v1_rf[ax0, vi1_1] < placeholder[ax0, vi1_0 * 8 + vi1_1], |
| 1221 | + placeholder_red_temp_v1_rf[ax0, vi1_1], |
| 1222 | + placeholder[ax0, vi1_0 * 8 + vi1_1], |
| 1223 | + ) |
| 1224 | + placeholder_red_temp_v0_rf[ax0, vi1_1] = v_placeholder_red_temp_v0_rf |
| 1225 | + placeholder_red_temp_v1_rf[ax0, vi1_1] = v_placeholder_red_temp_v1_rf |
| 1226 | + for i0, i1_1 in T.grid(1, 8): |
| 1227 | + with T.block("placeholder_red_temp"): |
| 1228 | + vi1_1, ax0 = T.axis.remap("RS", [i1_1, i0]) |
| 1229 | + T.reads(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) |
| 1230 | + T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0]) |
| 1231 | + with T.init(): |
| 1232 | + placeholder_red_temp_v0[ax0] = -1 |
| 1233 | + placeholder_red_temp_v1[ax0] = 2147483647 |
| 1234 | + v_placeholder_red_temp_v0: T.int32 = T.Select( |
| 1235 | + placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0, vi1_1] |
| 1236 | + or placeholder_red_temp_v1[ax0] == placeholder_red_temp_v1_rf[ax0, vi1_1] |
| 1237 | + and placeholder_red_temp_v0[ax0] < placeholder_red_temp_v0_rf[ax0, vi1_1], |
| 1238 | + placeholder_red_temp_v0[ax0], |
| 1239 | + placeholder_red_temp_v0_rf[ax0, vi1_1], |
| 1240 | + ) |
| 1241 | + v_placeholder_red_temp_v1: T.int32 = T.Select( |
| 1242 | + placeholder_red_temp_v1[ax0] < placeholder_red_temp_v1_rf[ax0, vi1_1], |
| 1243 | + placeholder_red_temp_v1[ax0], |
| 1244 | + placeholder_red_temp_v1_rf[ax0, vi1_1], |
| 1245 | + ) |
| 1246 | + placeholder_red_temp_v0[ax0] = v_placeholder_red_temp_v0 |
| 1247 | + placeholder_red_temp_v1[ax0] = v_placeholder_red_temp_v1 |
| 1248 | + for i0 in T.serial(1): |
| 1249 | + with T.block("placeholder_red"): |
| 1250 | + ax0 = T.axis.spatial(1, i0) |
| 1251 | + T.reads(placeholder_red_temp_v0[ax0]) |
| 1252 | + T.writes(placeholder_red[ax0]) |
| 1253 | + placeholder_red[ax0] = placeholder_red_temp_v0[ax0] |
| 1254 | + |
| 1255 | + |
1136 | 1256 | # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg |
1137 | 1257 |
|
1138 | 1258 |
|
@@ -1490,5 +1610,35 @@ def test_reduction_rfactor_argmax_init_buffer_not_match(): |
1490 | 1610 | s.rfactor(ki, 1) |
1491 | 1611 |
|
1492 | 1612 |
|
| 1613 | +def test_reduction_rfactor_topi_argmax(): |
| 1614 | + A = te.placeholder((1, 32), dtype="int32") |
| 1615 | + B = topi.argmax(A, axis=1) |
| 1616 | + argmax_topi = te.create_prim_func([A, B]) |
| 1617 | + s = tir.Schedule(argmax_topi, debug_mask="all") |
| 1618 | + argmax = s.get_block("placeholder_red_temp") |
| 1619 | + _, k = s.get_loops(argmax) |
| 1620 | + _, ki = s.split(k, [None, 8]) |
| 1621 | + rf_block = s.rfactor(ki, 1) |
| 1622 | + tvm.ir.assert_structural_equal(s.mod["main"], argmax_topi_rfactor) |
| 1623 | + assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf"))) |
| 1624 | + assert s.get(argmax).same_as(s.get(s.get_block("placeholder_red_temp"))) |
| 1625 | + verify_trace_roundtrip(s, mod=argmax_topi) |
| 1626 | + |
| 1627 | + |
| 1628 | +def test_reduction_rfactor_topi_argmin(): |
| 1629 | + A = te.placeholder((1, 32), dtype="int32") |
| 1630 | + B = topi.argmin(A, axis=1) |
| 1631 | + argmin_topi = te.create_prim_func([A, B]) |
| 1632 | + s = tir.Schedule(argmin_topi, debug_mask="all") |
| 1633 | + argmin = s.get_block("placeholder_red_temp") |
| 1634 | + _, k = s.get_loops(argmin) |
| 1635 | + _, ki = s.split(k, [None, 8]) |
| 1636 | + rf_block = s.rfactor(ki, 1) |
| 1637 | + tvm.ir.assert_structural_equal(s.mod["main"], argmin_topi_rfactor) |
| 1638 | + assert s.get(rf_block).same_as(s.get(s.get_block("placeholder_red_temp_rf"))) |
| 1639 | + assert s.get(argmin).same_as(s.get(s.get_block("placeholder_red_temp"))) |
| 1640 | + verify_trace_roundtrip(s, mod=argmin_topi) |
| 1641 | + |
| 1642 | + |
1493 | 1643 | if __name__ == "__main__": |
1494 | 1644 | tvm.testing.main() |
0 commit comments