@@ -903,39 +903,39 @@ def test_conv_1x1():
     def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")):
         T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
         # with T.block("root"):
-        conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 2, 8, 2, 16, 16), scope="shared")
-        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 2, 8, 2, 16, 16), scope="wmma.accumulator")
+        conv2d_nhwc_reindex_shared = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="shared")
+        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((2, 1, 8, 4, 16, 16), scope="wmma.accumulator")
         PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", scope="shared")
         weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16", scope="shared")
         PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64), "float16", scope="wmma.matrix_a")
         weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64), "float16", scope="wmma.matrix_b")
-        for ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused in T.thread_binding(4, thread="blockIdx.y"):
-            for ax0_1_ax1_1_ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
-                for ax0_2_ax1_2_ax2_0_2_ax3_0_2_fused in T.thread_binding(1, thread="threadIdx.y"):
-                    for ax4_0_0 in range(1):
+        for ax0_ax1_ax2_0_0_ax3_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
+            for ax2_0_1_ax3_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
+                for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"):
+                    for ax4_0_0 in range(2):
                         for ax0_ax1_fused in range(8192):
                             with T.block("PadInput_reindex_shared"):
-                                v0 = T.axis.spatial(256, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 128 + ax0_ax1_fused // 64)
-                                v1 = T.axis.spatial(64, ax0_ax1_fused % 64)
+                                v0 = T.axis.spatial(256, ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_fused % 32)
                                 T.reads(inputs[0, v0 // 16, v0 % 16, v1])
                                 T.writes(PadInput_reindex_shared[v0, v1])
-                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2})
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8})
                                 PadInput_reindex_shared[v0, v1] = inputs[0, v0 // 16, v0 % 16, v1]
                         for ax0_ax1_ax2_ax3_fused in range(2048):
                             with T.block("weight_reindex_shared"):
                                 v0 = T.axis.spatial(1, 0)
                                 v1 = T.axis.spatial(1, 0)
-                                v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused // 32)
-                                v3 = T.axis.spatial(64, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32)
+                                v2 = T.axis.spatial(64, ax4_0_0 * 32 + ax0_ax1_ax2_ax3_fused // 64)
+                                v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
                                 T.reads(weight[v0, v1, v2, v3])
                                 T.writes(weight_reindex_shared[v0, v1, v2, v3])
-                                T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 8})
+                                T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 4})
                                 weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
                         for ax4_0_1 in range(1):
-                            for ax0_0, ax1_0 in T.grid(8, 4):
+                            for ax0_0, ax1_0 in T.grid(8, 2):
                                 with T.block("PadInput_reindex_shared_wmma.matrix_a_o"):
-                                    v0_o = T.axis.spatial(16, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax0_0)
-                                    v1_o = T.axis.spatial(4, ax1_0)
+                                    v0_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax0_0)
+                                    v1_o = T.axis.spatial(4, ax4_0_0 * 2 + ax1_0)
                                     T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                     T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                     T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"})
@@ -945,10 +945,11 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
                                             T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                             T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                             PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-                            for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 4, 2):
+                            for ax0, ax1, ax2_0, ax3_0 in T.grid(1, 1, 2, 4):
                                 with T.block("weight_reindex_shared_wmma.matrix_b_o"):
-                                    v0_o, v1_o, v2_o = T.axis.remap("SSS", [ax0, ax1, ax2_0])
-                                    v3_o = T.axis.spatial(4, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 2 + ax3_0)
+                                    v0_o, v1_o = T.axis.remap("SS", [ax0, ax1])
+                                    v2_o = T.axis.spatial(4, ax4_0_0 * 2 + ax2_0)
+                                    v3_o = T.axis.spatial(4, ax3_0)
                                     T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
                                     T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
                                     T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"})
@@ -958,38 +959,38 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
                                             T.reads(weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                             T.writes(weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i])
                                             weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i] = weight_reindex_shared[v0_o, v1_o, v2_o * 16 + v2_i, v3_o * 16 + v3_i]
-                            for ax0_3, ax1_3, ax2_0_3, ax3_0_3, ax4_0_2, ax0_4, ax1_4, ax2_0_4, ax3_0_4 in T.grid(1, 1, 8, 2, 4, 1, 1, 1, 1):
+                            for ax2_0_3, ax3_0_3, ax4_0_2, ax2_0_4, ax3_0_4 in T.grid(8, 1, 2, 1, 4):
                                 with T.block("conv2d_nhwc_o"):
-                                    v0_o = T.axis.spatial(1, ax0_3 + ax0_4)
-                                    v1_o = T.axis.spatial(1, ax1_3 + ax1_4)
-                                    v2_o = T.axis.spatial(16, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2 * 8 + ax2_0_3 + ax2_0_4)
-                                    v3_o = T.axis.spatial(4, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2 * 2 + ax3_0_3 + ax3_0_4)
-                                    v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2)
+                                    v0_o = T.axis.spatial(1, 0)
+                                    v1_o = T.axis.spatial(1, 0)
+                                    v2_o = T.axis.spatial(16, ax2_0_2_ax3_0_2_fused * 8 + ax2_0_3 + ax2_0_4)
+                                    v3_o = T.axis.spatial(4, ax3_0_3 * 4 + ax3_0_4)
+                                    v4_o = T.axis.reduce(4, ax4_0_0 * 2 + ax4_0_1 * 2 + ax4_0_2)
                                     T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
-                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, 0:16, 0:16])
+                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, 0:16, 0:16])
                                     T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
                                     with T.init():
                                         for ax2_1, ax3_1 in T.grid(16, 16):
                                             with T.block("conv2d_nhwc_init"):
                                                 v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1])
                                                 T.reads()
-                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i_init, v3_i_init])
-                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i_init, v3_i_init] = T.float32(0)
+                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init])
+                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i_init, v3_i_init] = T.float32(0)
                                     for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16):
                                         with T.block("conv2d_nhwc"):
                                             v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1])
-                                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
-                                            T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i])
+                                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
+                                            T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i])
                                             T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
-                                            conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
+                                            conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o // 8, 0, v2_o % 8, v3_o, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0_o, v1_o, v4_o * 16 + v4_i, v3_o * 16 + v3_i])
                 for ax2 in range(8):
-                    for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"):
-                        for ax2_1, ax3 in T.grid(1, 2):
+                    for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 4):
                             with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
-                                v0_o = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2)
-                                v1_o = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2)
+                                v0_o = T.axis.spatial(2, ax0_ax1_fused)
+                                v1_o = T.axis.spatial(1, 0)
                                 v2_o = T.axis.spatial(8, ax2 + ax2_1)
-                                v3_o = T.axis.spatial(2, ax3)
+                                v3_o = T.axis.spatial(4, ax3)
                                 v4_o = T.axis.spatial(1, 0)
                                 v5_o = T.axis.spatial(1, 0)
                                 T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
@@ -1001,29 +1002,29 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
                                         T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
                                         T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
                                         conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]
-                for ax0_ax1_ax3_ax4_ax5_fused in range(512):
+                for ax0_ax1_ax3_ax4_ax5_fused in range(2048):
                     with T.block("conv2d_nhwc_reindex_shared"):
-                        v0 = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused // 2)
-                        v1 = T.axis.spatial(2, ax0_0_ax1_0_ax2_0_0_ax3_0_0_fused % 2)
+                        v0 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024)
+                        v1 = T.axis.spatial(1, 0)
                         v2 = T.axis.spatial(8, ax2)
-                        v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 256)
+                        v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256)
                         v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
                         v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
                         T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5])
-                        T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32])
+                        T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16])
                         T.block_attr({"meta_schedule.cooperative_fetch": 1})
-                        conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
+                        conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 128) // 16, (v4 + v2 * 16 + v0 * 128) % 16, v5 + v3 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
     # fmt: on
 
     decision_0 = [
-        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
-        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
-        ("SamplePerfectTile", [2, 1, 1, 8, 1]),
-        ("SamplePerfectTile", [2, 1, 1, 2, 1]),
-        ("SamplePerfectTile", [1, 1, 4]),
+        # ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+        # ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 1, 2, 8, 1]),
+        ("SamplePerfectTile", [1, 1, 1, 1, 4]),
+        ("SamplePerfectTile", [2, 1, 2]),
         ("SampleCategorical", 0),
-        ("SampleCategorical", 1),
         ("SampleCategorical", 3),
+        ("SampleCategorical", 2),
     ]
 
     mod = te.create_prim_func(
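
A quick sanity check on the updated decisions: each SamplePerfectTile entry lists the split factors for one tiled loop, and a perfect tile's factors must multiply back to that loop's extent, here counted in 16x16 WMMA tiles. A minimal sketch (not part of the test; the axis labels in the comments are our annotation, inferred from the module above):

    from math import prod

    # New split factors from decision_0, paired with the extent of the loop
    # they tile (number of 16x16 WMMA tiles along that axis).
    new_tiles = [
        ([1, 1, 2, 8, 1], 16),  # spatial v2_o: 256 rows      -> 16 tiles
        ([1, 1, 1, 1, 4], 4),   # spatial v3_o: 64 columns    -> 4 tiles
        ([2, 1, 2], 4),         # reduction v4_o: 64 channels -> 4 tiles
    ]
    for factors, extent in new_tiles:
        assert prod(factors) == extent  # a perfect tile multiplies back exactly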