@@ -1055,5 +1055,157 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
     )


+def test_padded_conv():
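+    # The 7x7x3 reduction axes of this conv flatten to 147, which is not a
+    # multiple of the 16-wide wmma tile; the expected sketch below pads the
+    # reduction dimension to 160 (note the (12544, 160) buffers and the
+    # `v1 < 147` / `v0 < 147` predicates) so the tensor-core rule still applies.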
+    # fmt: off
+    @T.prim_func
+    def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buffer((7, 7, 3, 64), "float16"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        conv2d_nhwc_reindex_shared = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="shared")
+        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="wmma.accumulator")
+        PadInput_reindex_pad_shared = T.alloc_buffer((12544, 160), "float16", scope="shared")
+        weight_reindex_pad_shared = T.alloc_buffer((160, 64), "float16", scope="shared")
+        PadInput_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((12544, 160), "float16", scope="wmma.matrix_a")
+        weight_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((160, 64), "float16", scope="wmma.matrix_b")
+        for ax0_0_0_ax1_0_0_fused in T.thread_binding(14, thread="blockIdx.y"):
+            for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
+                for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"):
+                    for ax2_0_0 in range(10):
+                        for ax0_ax1_fused in range(28672):
+                            with T.block("PadInput_reindex_pad_shared"):
+                                v0 = T.axis.spatial(12544, ax0_0_0_ax1_0_0_fused // 2 * 1792 + ax0_ax1_fused // 16)
+                                v1 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused % 16)
+                                T.reads(inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3])
+                                T.writes(PadInput_reindex_pad_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4})
+                                PadInput_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 147, T.if_then_else(3 <= v0 // 112 * 2 + v1 // 21 and v0 // 112 * 2 + v1 // 21 < 227 and 3 <= v0 % 112 * 2 + v1 % 21 // 3 and v0 % 112 * 2 + v1 % 21 // 3 < 227, inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3], T.float16(0)), T.float16(0))
+                        for ax0_ax1_fused in range(512):
+                            with T.block("weight_reindex_pad_shared"):
+                                v0 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 2 * 32 + ax0_ax1_fused % 32)
+                                T.reads(weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1])
+                                T.writes(weight_reindex_pad_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2})
+                                weight_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 147, weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1], T.float16(0))
+                        for ax2_0_1 in range(1):
+                            for ax0_0, ax1_0 in T.grid(14, 1):
+                                with T.block("PadInput_reindex_pad_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0)
+                                    v1_o = T.axis.spatial(10, ax2_0_0 + ax1_0)
+                                    T.reads(PadInput_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("PadInput_reindex_pad_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(1, 2):
+                                with T.block("weight_reindex_pad_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(10, ax2_0_0 + ax0_0)
+                                    v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0)
+                                    T.reads(weight_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("weight_reindex_pad_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(7, 2, 1, 2, 1):
+                                with T.block("conv2d_nhwc_o"):
+                                    v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0_3 * 2 + ax0_0_4)
+                                    v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0_3 + ax1_0_4)
+                                    v2_o = T.axis.reduce(10, ax2_0_0 + ax2_0_1 + ax2_0_2)
+                                    T.reads(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, 0:16, 0:16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
+                                    with T.init():
+                                        for ax0_1, ax1_1 in T.grid(16, 16):
+                                            with T.block("conv2d_nhwc_init"):
+                                                v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
+                                                T.reads()
+                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init])
+                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0)
+                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
+                                        with T.block("conv2d_nhwc"):
+                                            v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i], PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i])
+                                            T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                for ax2 in range(14):
+                    for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 2):
+                            with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
+                                v0_o = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_fused)
+                                v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2)
+                                v2_o = T.axis.spatial(14, ax2 + ax2_1)
+                                v3_o = T.axis.spatial(2, ax3)
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
+                                T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, ax5])
+                                        T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
+                                        T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
+                                        conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
+                        with T.block("conv2d_nhwc_reindex_shared"):
+                            v0 = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_ax3_ax4_ax5_fused // 512)
+                            v1 = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2)
+                            v2 = T.axis.spatial(14, ax2)
+                            v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 512 // 256)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
+                            T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5])
+                            T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 3})
+                            conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
+    # fmt: on
+
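+    # Recorded sampling decisions that reproduce the sketch above: the first two
+    # SamplePerfectTile entries factor the fused spatial axes into 16x16 wmma
+    # tiles (7*1*8*7*2 = 784 row tiles, 2*1*1*2*1 = 4 column tiles), and the
+    # third factors the padded reduction axis (10*1*1 = 160/16). The three
+    # SampleCategorical picks presumably select the cooperative-fetch
+    # vectorization factors; that mapping is an assumption, not spelled out here.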
+    decision_0 = [
+        ("SamplePerfectTile", [7, 1, 8, 7, 2]),
+        ("SamplePerfectTile", [2, 1, 1, 2, 1]),
+        ("SamplePerfectTile", [10, 1, 1]),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 1),
+    ]
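+    # Workload under test: conv2d_nhwc with N=1, H=W=224, CI=3, CO=64, a 7x7
+    # kernel, stride 2, and padding 3, fp16 inputs accumulating into fp32; the
+    # positional-argument meaning is inferred from the buffer shapes of
+    # padded_conv2d_0 above.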
+    mod = te.create_prim_func(
+        te_workload.conv2d_nhwc(
+            1,
+            224,
+            224,
+            3,
+            64,
+            7,
+            2,
+            3,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    actual = generate_design_space(
+        kind="cuda",
+        mod=mod,
+        target=tvm.target.Target("cuda --arch=sm_70"),
+        types=None,
+        sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")]
+        + get_rules("cuda", ms.schedule_rule.AutoInline),
+    )
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[padded_conv2d_0],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()