Skip to content

Commit aae484f

Browse files
committed
Update testcase test_meta_schedule_schedule_rule_mlt_tc.py::test_conv_1x1
1 parent 7c4c620 commit aae484f

File tree

1 file changed

+48
-47
lines changed

1 file changed

+48
-47
lines changed

tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py

Lines changed: 48 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -903,39 +903,39 @@ def test_conv_1x1():
903903
def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")):
904904
T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
905905
# with T.block("root"):
906-
conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 2, 8, 2, 16, 16), scope="shared")
907-
conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 2, 8, 2, 16, 16), scope="wmma.accumulator")
906+
conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="shared")
907+
conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="wmma.accumulator")
908908
PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", scope="shared")
909909
weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16", scope="shared")
910910
PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64), "float16", scope="wmma.matrix_a")
911911
weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64), "float16", scope="wmma.matrix_b")
912-
for ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused in T.thread_binding(4, thread="blockIdx.y"):
913-
for ax0_1_ax1_1_ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
914-
for ax0_2_ax1_2_ax2_0_2_ax3_0_2_fused in T.thread_binding(1, thread="threadIdx.y"):
915-
for ax4_0_0 in range(1):
912+
for ax0_ax1_ax2_0_0_ax3_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
913+
for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
914+
for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"):
915+
for ax4_0_0 in range(2):
916916
for ax0_ax1_fused in range(8192):
917917
with T.block("PadInput_reindex_shared"):
918-
v0 = T.axis.spatial(256, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 128 + ax0_ax1_fused // 64)
919-
v1 = T.axis.spatial(64, ax0_ax1_fused % 64)
918+
v0 = T.axis.spatial(256, ax0_ax1_fused // 32)
919+
v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32)
920920
T.reads(inputs[0, v0 // 16, v0 % 16, v1])
921921
T.writes(PadInput_reindex_shared[v0, v1])
922-
T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2})
922+
T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8})
923923
PadInput_reindex_shared[v0, v1] = inputs[0, v0 // 16, v0 % 16, v1]
924924
for ax0_ax1_ax2_ax3_fused in range(2048):
925925
with T.block("weight_reindex_shared"):
926926
v0 = T.axis.spatial(1, 0)
927927
v1 = T.axis.spatial(1, 0)
928-
v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused // 32)
929-
v3 = T.axis.spatial(64, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32)
928+
v2 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused // 64)
929+
v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
930930
T.reads(weight[v0, v1, v2, v3])
931931
T.writes(weight_reindex_shared[v0, v1, v2, v3])
932-
T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 8})
932+
T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 4})
933933
weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
934934
for ax4_0_1 in range(1):
935-
for ax0_0, ax1_0 in T.grid(8, 4):
935+
for ax0_0, ax1_0 in T.grid(8, 2):
936936
with T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
937-
v0_o = T.axis.spatial(16, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax0_0)
938-
v1_o = T.axis.spatial(4, ax1_0)
937+
v0_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax0_0)
938+
v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0)
939939
T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
940940
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
941941
T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"})
@@ -945,10 +945,11 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
945945
T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
946946
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
947947
PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
948-
for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 2):
948+
for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 4):
949949
with T.block("weight_reindex_shared_wmma.matrix_b_o"):
950-
v0_o, v1_o, v2_o = T.axis.remap("SSS", [ax0, ax1, ax2_0])
951-
v3_o = T.axis.spatial(4, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 2 + ax3_0)
950+
v0_o, v1_o = T.axis.remap("SS", [ax0, ax1])
951+
v2_o = T.axis.spatial(4, ax4_0_0 * 2 + ax2_0)
952+
v3_o = T.axis.spatial(4, ax3_0)
952953
T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
953954
T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
954955
T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"})
@@ -958,38 +959,38 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
958959
T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
959960
T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
960961
weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]
961-
for ax0_3, ax1_3, ax2_0_3, ax3_0_3, ax4_0_2, ax0_4, ax1_4, ax2_0_4, ax3_0_4 in T.grid(1, 1, 8, 2, 4, 1, 1, 1, 1):
962+
for ax2_0_3, ax3_0_3, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(8, 1, 2, 1, 4):
962963
with T.block("conv2d_nhwc_o"):
963-
v0_o = T.axis.spatial(1, ax0_3 + ax0_4)
964-
v1_o = T.axis.spatial(1, ax1_3 + ax1_4)
965-
v2_o = T.axis.spatial(16, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax2_0_3 + ax2_0_4)
966-
v3_o = T.axis.spatial(4, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 2 + ax3_0_3 + ax3_0_4)
967-
v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2)
964+
v0_o = T.axis.spatial(1, 0)
965+
v1_o = T.axis.spatial(1, 0)
966+
v2_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax2_0_3 + ax2_0_4)
967+
v3_o = T.axis.spatial(4, ax3_0_3 * 4 + ax3_0_4)
968+
v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2)
968969
T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
969-
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, 0:16, 0:16])
970+
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, 0:16, 0:16])
970971
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
971972
with T.init():
972973
for ax2_1, ax3_1 in T.grid(16, 16):
973974
with T.block("conv2d_nhwc_init"):
974975
v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1])
975976
T.reads()
976-
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i_init, v3_i_init])
977-
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i_init, v3_i_init] = T.float32(0)
977+
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init])
978+
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init] = T.float32(0)
978979
for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16):
979980
with T.block("conv2d_nhwc"):
980981
v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1])
981-
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
982-
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i])
982+
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
983+
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i])
983984
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
984-
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
985+
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
985986
for ax2 in range(8):
986-
for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"):
987-
for ax2_1, ax3 in T.grid(1, 2):
987+
for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"):
988+
for ax2_1, ax3 in T.grid(1, 4):
988989
with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
989-
v0_o = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2)
990-
v1_o = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2)
990+
v0_o = T.axis.spatial(2, ax0_ax1_fused)
991+
v1_o = T.axis.spatial(1, 0)
991992
v2_o = T.axis.spatial(8, ax2 + ax2_1)
992-
v3_o = T.axis.spatial(2, ax3)
993+
v3_o = T.axis.spatial(4, ax3)
993994
v4_o = T.axis.spatial(1, 0)
994995
v5_o = T.axis.spatial(1, 0)
995996
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
@@ -1001,29 +1002,29 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
10011002
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
10021003
T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
10031004
conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]
1004-
for ax0_ax1_ax3_ax4_ax5_fused in range(512):
1005+
for ax0_ax1_ax3_ax4_ax5_fused in range(2048):
10051006
with T.block("conv2d_nhwc_reindex_shared"):
1006-
v0 = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2)
1007-
v1 = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2)
1007+
v0 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024)
1008+
v1 = T.axis.spatial(1, 0)
10081009
v2 = T.axis.spatial(8, ax2)
1009-
v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 256)
1010+
v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256)
10101011
v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
10111012
v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
10121013
T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5])
1013-
T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32])
1014+
T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16])
10141015
T.block_attr({"meta_schedule.cooperative_fetch": 1})
1015-
conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
1016+
conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
10161017
# fmt: on
10171018

10181019
decision_0 = [
1019-
("SamplePerfectTile", [1, 1, 1, 1, 1]),
1020-
("SamplePerfectTile", [1, 1, 1, 1, 1]),
1021-
("SamplePerfectTile", [2, 1, 1, 8, 1]),
1022-
("SamplePerfectTile", [2, 1, 1, 2, 1]),
1023-
("SamplePerfectTile", [1, 1, 4]),
1020+
# ("SamplePerfectTile", [1, 1, 1, 1, 1]),
1021+
# ("SamplePerfectTile", [1, 1, 1, 1, 1]),
1022+
("SamplePerfectTile", [1, 1, 2, 8, 1]),
1023+
("SamplePerfectTile", [1, 1, 1, 1, 4]),
1024+
("SamplePerfectTile", [2, 1, 2]),
10241025
("SampleCategorical", 0),
1025-
("SampleCategorical", 1),
10261026
("SampleCategorical", 3),
1027+
("SampleCategorical", 2),
10271028
]
10281029

10291030
mod = te.create_prim_func(

0 commit comments

Comments (0)