4 changes: 3 additions & 1 deletion src/tir/schedule/transform.cc
@@ -340,7 +340,9 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block
     }
     auto consumers = sch->GetConsumers(block_rv);
     for (const auto& consumer : consumers) {
-      sch->ComputeInline(consumer);
+      auto sref = sch->GetSRef(consumer);
+      if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true)))
+        sch->ComputeInline(consumer);
     }
   }
   // Construct a mapping from tir loops back to LoopRVs
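For context on the transform.cc change above: ComputeInline rejects a block that writes one of the PrimFunc's output buffers, so TileWithTensorIntrin must skip output-block consumers instead of trying to inline them. The following minimal sketch is not part of the PR (the toy PrimFunc and block names are made up) and only illustrates that behaviour with the Python schedule API, assuming a recent TVM build:

# Minimal sketch (hypothetical toy workload, not from the PR): compute_inline
# succeeds on an intermediate block but is rejected for an output block, which
# is the case the new IsOutputBlock check in TileWithTensorIntrin guards against.
import tvm
from tvm.script import tir as T


@T.prim_func
def toy(a: T.Buffer((16,), "float32"), c: T.Buffer((16,), "float32")):
    b = T.alloc_buffer((16,), "float32")
    for i in range(16):
        with T.block("scale"):
            vi = T.axis.spatial(16, i)
            b[vi] = a[vi] * T.float32(2)
    for i in range(16):
        with T.block("add"):
            vi = T.axis.spatial(16, i)
            c[vi] = b[vi] + T.float32(1)


sch = tvm.tir.Schedule(toy)
sch.compute_inline(sch.get_block("scale"))  # OK: "scale" writes an intermediate buffer
# sch.compute_inline(sch.get_block("add"))  # raises ScheduleError: "add" writes the output buffer c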
152 changes: 152 additions & 0 deletions tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -1055,5 +1055,157 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
)


def test_padded_conv():
# fmt: off
@T.prim_func
def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buffer((7, 7, 3, 64), "float16"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
conv2d_nhwc_reindex_shared = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="shared")
conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="wmma.accumulator")
PadInput_reindex_pad_shared = T.alloc_buffer((12544, 160), "float16", scope="shared")
weight_reindex_pad_shared = T.alloc_buffer((160, 64), "float16", scope="shared")
PadInput_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((12544, 160), "float16", scope="wmma.matrix_a")
weight_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((160, 64), "float16", scope="wmma.matrix_b")
for ax0_0_0_ax1_0_0_fused in T.thread_binding(14, thread="blockIdx.y"):
for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"):
for ax2_0_0 in range(10):
for ax0_ax1_fused in range(28672):
with T.block("PadInput_reindex_pad_shared"):
v0 = T.axis.spatial(12544, ax0_0_0_ax1_0_0_fused // 2 * 1792 + ax0_ax1_fused // 16)
v1 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused % 16)
T.reads(inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3])
T.writes(PadInput_reindex_pad_shared[v0, v1])
T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4})
PadInput_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 147, T.if_then_else(3 <= v0 // 112 * 2 + v1 // 21 and v0 // 112 * 2 + v1 // 21 < 227 and 3 <= v0 % 112 * 2 + v1 % 21 // 3 and v0 % 112 * 2 + v1 % 21 // 3 < 227, inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3], T.float16(0)), T.float16(0))
for ax0_ax1_fused in range(512):
with T.block("weight_reindex_pad_shared"):
v0 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused // 32)
v1 = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 2 * 32 + ax0_ax1_fused % 32)
T.reads(weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1])
T.writes(weight_reindex_pad_shared[v0, v1])
T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2})
weight_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 147, weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1], T.float16(0))
for ax2_0_1 in range(1):
for ax0_0, ax1_0 in T.grid(14, 1):
with T.block("PadInput_reindex_pad_shared_wmma.matrix_a_o"):
v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0)
v1_o = T.axis.spatial(10, ax2_0_0 + ax1_0)
T.reads(PadInput_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"})
for ax0_1, ax1_1 in T.grid(16, 16):
with T.block("PadInput_reindex_pad_shared_wmma.matrix_a"):
v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
T.reads(PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
for ax0_0, ax1_0 in T.grid(1, 2):
with T.block("weight_reindex_pad_shared_wmma.matrix_b_o"):
v0_o = T.axis.spatial(10, ax2_0_0 + ax0_0)
v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0)
T.reads(weight_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"})
for ax0_1, ax1_1 in T.grid(16, 16):
with T.block("weight_reindex_pad_shared_wmma.matrix_b"):
v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
T.reads(weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(7, 2, 1, 2, 1):
with T.block("conv2d_nhwc_o"):
v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0_3 * 2 + ax0_0_4)
v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0_3 + ax1_0_4)
v2_o = T.axis.reduce(10, ax2_0_0 + ax2_0_1 + ax2_0_2)
T.reads(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, 0:16, 0:16])
T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
with T.init():
for ax0_1, ax1_1 in T.grid(16, 16):
with T.block("conv2d_nhwc_init"):
v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
T.reads()
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init])
conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0)
for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
with T.block("conv2d_nhwc"):
v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i], PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i])
T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
for ax2 in range(14):
for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"):
for ax2_1, ax3 in T.grid(1, 2):
with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
v0_o = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_fused)
v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2)
v2_o = T.axis.spatial(14, ax2 + ax2_1)
v3_o = T.axis.spatial(2, ax3)
v4_o = T.axis.spatial(1, 0)
v5_o = T.axis.spatial(1, 0)
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"})
for ax4, ax5 in T.grid(16, 16):
with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
v4_i, v5_i = T.axis.remap("SS", [ax4, ax5])
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]
for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
with T.block("conv2d_nhwc_reindex_shared"):
v0 = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_ax3_ax4_ax5_fused // 512)
v1 = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2)
v2 = T.axis.spatial(14, ax2)
v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 512 // 256)
v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5])
T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32])
T.block_attr({"meta_schedule.cooperative_fetch": 3})
conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
# fmt: on

decision_0 = [
("SamplePerfectTile", [7, 1, 8, 7, 2]),
("SamplePerfectTile", [2, 1, 1, 2, 1]),
("SamplePerfectTile", [10, 1, 1]),
("SampleCategorical", 2),
("SampleCategorical", 2),
("SampleCategorical", 1),
]
mod = te.create_prim_func(
te_workload.conv2d_nhwc(
1,
224,
224,
3,
64,
7,
2,
3,
in_dtype="float16",
out_dtype="float32",
)
)
actual = generate_design_space(
kind="cuda",
mod=mod,
target=tvm.target.Target("cuda --arch=sm_70"),
types=None,
sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")]
+ get_rules("cuda", ms.schedule_rule.AutoInline),
)
check_sketches(
mod,
sketches=actual,
expected_mods=[padded_conv2d_0],
expected_decisions=[decision_0],
)


if __name__ == "__main__":
tvm.testing.main()
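Usage note (assumption, standard pytest invocation for TVM's Python test tree): the new sketch check can be run in isolation with pytest tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py -k test_padded_conv, or the whole file can be executed directly, in which case tvm.testing.main() above dispatches to pytest.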