Skip to content

Commit 18ff9ff

Browse files
YXY-0922 and yuxiyue authored
[MetaSchedule] Add a testcase for padded conv2d in meta_schedule (#17171)
### Bug Fix

In the `TileWithTensorIntrin` function, when the `allow_padding` parameter is enabled, the original implementation inlines all consumer blocks. This behavior can lead to incorrect inlining of output blocks, causing issues with block shapes and dependencies. To ensure correct inlining operations, only non-output consumer blocks should be inlined.

---------

Co-authored-by: yuxiyue <[email protected]>
1 parent e5bf56d commit 18ff9ff

File tree

2 files changed

+155
-1
lines changed

2 files changed

+155
-1
lines changed

src/tir/schedule/transform.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,9 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block
340340
}
341341
auto consumers = sch->GetConsumers(block_rv);
342342
for (const auto& consumer : consumers) {
343-
sch->ComputeInline(consumer);
343+
auto sref = sch->GetSRef(consumer);
344+
if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true)))
345+
sch->ComputeInline(consumer);
344346
}
345347
}
346348
// Construct a mapping from tir loops back to LoopRVs

tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,5 +1055,157 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
10551055
)
10561056

10571057

1058+
def test_padded_conv():
    """Check the CUDA tensor-core design space for a conv2d that needs padding.

    The workload is NHWC conv2d (1x224x224x3 input, 7x7x3x64 weight, float16 in
    / float32 out, stride 2, pad 3).  Its flattened reduction extent is
    7 * 7 * 3 = 147, which is not a multiple of the 16x16x16 wmma tile, so the
    reindexed operands are padded to 160 (see the ``v1 < 147`` / ``v0 < 147``
    guards that zero-fill the padded tail).  ``padded_conv2d_0`` below is the
    expected schedule output as printed by TVMScript; ``check_sketches``
    compares it against the generated design space, so the loop nests, block
    names, and variable names inside it must stay exactly as generated.
    """
    # fmt: off
    @T.prim_func
    def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buffer((7, 7, 3, 64), "float16"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        conv2d_nhwc_reindex_shared = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="shared")
        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="wmma.accumulator")
        PadInput_reindex_pad_shared = T.alloc_buffer((12544, 160), "float16", scope="shared")
        weight_reindex_pad_shared = T.alloc_buffer((160, 64), "float16", scope="shared")
        PadInput_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((12544, 160), "float16", scope="wmma.matrix_a")
        weight_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((160, 64), "float16", scope="wmma.matrix_b")
        for ax0_0_0_ax1_0_0_fused in T.thread_binding(14, thread="blockIdx.y"):
            for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
                for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"):
                    for ax2_0_0 in range(10):
                        # Cooperative fetch of the zero-padded input tile into shared memory.
                        for ax0_ax1_fused in range(28672):
                            with T.block("PadInput_reindex_pad_shared"):
                                v0 = T.axis.spatial(12544, ax0_0_0_ax1_0_0_fused // 2 * 1792 + ax0_ax1_fused // 16)
                                v1 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused % 16)
                                T.reads(inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3])
                                T.writes(PadInput_reindex_pad_shared[v0, v1])
                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4})
                                PadInput_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 147, T.if_then_else(3 <= v0 // 112 * 2 + v1 // 21 and v0 // 112 * 2 + v1 // 21 < 227 and 3 <= v0 % 112 * 2 + v1 % 21 // 3 and v0 % 112 * 2 + v1 % 21 // 3 < 227, inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3], T.float16(0)), T.float16(0))
                        # Cooperative fetch of the zero-padded weight tile into shared memory.
                        for ax0_ax1_fused in range(512):
                            with T.block("weight_reindex_pad_shared"):
                                v0 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused // 32)
                                v1 = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 2 * 32 + ax0_ax1_fused % 32)
                                T.reads(weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1])
                                T.writes(weight_reindex_pad_shared[v0, v1])
                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2})
                                weight_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 147, weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1], T.float16(0))
                        for ax2_0_1 in range(1):
                            # Load A fragments (shared -> wmma.matrix_a).
                            for ax0_0, ax1_0 in T.grid(14, 1):
                                with T.block("PadInput_reindex_pad_shared_wmma.matrix_a_o"):
                                    v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0)
                                    v1_o = T.axis.spatial(10, ax2_0_0 + ax1_0)
                                    T.reads(PadInput_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                    T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"})
                                    for ax0_1, ax1_1 in T.grid(16, 16):
                                        with T.block("PadInput_reindex_pad_shared_wmma.matrix_a"):
                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
                                            T.reads(PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                            T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                            PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                            # Load B fragments (shared -> wmma.matrix_b).
                            for ax0_0, ax1_0 in T.grid(1, 2):
                                with T.block("weight_reindex_pad_shared_wmma.matrix_b_o"):
                                    v0_o = T.axis.spatial(10, ax2_0_0 + ax0_0)
                                    v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0)
                                    T.reads(weight_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                    T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"})
                                    for ax0_1, ax1_1 in T.grid(16, 16):
                                        with T.block("weight_reindex_pad_shared_wmma.matrix_b"):
                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
                                            T.reads(weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                            T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                            weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                            # Main tensorized matmul (wmma_sync) with wmma_fill init.
                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(7, 2, 1, 2, 1):
                                with T.block("conv2d_nhwc_o"):
                                    v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0_3 * 2 + ax0_0_4)
                                    v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0_3 + ax1_0_4)
                                    v2_o = T.axis.reduce(10, ax2_0_0 + ax2_0_1 + ax2_0_2)
                                    T.reads(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, 0:16, 0:16])
                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
                                    with T.init():
                                        for ax0_1, ax1_1 in T.grid(16, 16):
                                            with T.block("conv2d_nhwc_init"):
                                                v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
                                                T.reads()
                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init])
                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0)
                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
                                        with T.block("conv2d_nhwc"):
                                            v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
                                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i], PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
                                            T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i])
                                            T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                                            conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
                # Write-back: accumulator -> shared -> global output.
                # NOTE(review): this nesting (epilogue as a sibling of the threadIdx.y
                # compute loop) is reconstructed from the variables each block reads;
                # the scraped diff had lost the original indentation.
                for ax2 in range(14):
                    for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"):
                        for ax2_1, ax3 in T.grid(1, 2):
                            with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
                                v0_o = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_fused)
                                v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2)
                                v2_o = T.axis.spatial(14, ax2 + ax2_1)
                                v3_o = T.axis.spatial(2, ax3)
                                v4_o = T.axis.spatial(1, 0)
                                v5_o = T.axis.spatial(1, 0)
                                T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
                                T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
                                T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"})
                                for ax4, ax5 in T.grid(16, 16):
                                    with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
                                        v4_i, v5_i = T.axis.remap("SS", [ax4, ax5])
                                        T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
                                        T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
                                        conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]
                    for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
                        with T.block("conv2d_nhwc_reindex_shared"):
                            v0 = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_ax3_ax4_ax5_fused // 512)
                            v1 = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2)
                            v2 = T.axis.spatial(14, ax2)
                            v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 512 // 256)
                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
                            T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5])
                            T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32])
                            T.block_attr({"meta_schedule.cooperative_fetch": 3})
                            conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
    # fmt: on

    # Sampled decisions that reproduce the sketch above (three tiling splits,
    # then the categorical picks for the cooperative-fetch vector lengths).
    decision_0 = [
        ("SamplePerfectTile", [7, 1, 8, 7, 2]),
        ("SamplePerfectTile", [2, 1, 1, 2, 1]),
        ("SamplePerfectTile", [10, 1, 1]),
        ("SampleCategorical", 2),
        ("SampleCategorical", 2),
        ("SampleCategorical", 1),
    ]
    # Build the NHWC conv2d workload: N=1, H=W=224, CI=3, CO=64, K=7,
    # stride=2, padding=3, float16 inputs accumulating into float32.
    mod = te.create_prim_func(
        te_workload.conv2d_nhwc(
            1,
            224,
            224,
            3,
            64,
            7,
            2,
            3,
            in_dtype="float16",
            out_dtype="float32",
        )
    )
    actual = generate_design_space(
        kind="cuda",
        mod=mod,
        target=tvm.target.Target("cuda --arch=sm_70"),
        types=None,
        sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")]
        + get_rules("cuda", ms.schedule_rule.AutoInline),
    )
    check_sketches(
        mod,
        sketches=actual,
        expected_mods=[padded_conv2d_0],
        expected_decisions=[decision_0],
    )
1208+
1209+
10581210
if __name__ == "__main__":
10591211
tvm.testing.main()

0 commit comments

Comments
 (0)