|
| 1 | +import tvm |
| 2 | +from tvm import tir |
| 3 | +from tvm.script import tir as T |
| 4 | + |
| 5 | + |
| 6 | +@T.prim_func |
| 7 | +def func( |
| 8 | + p0: T.Buffer[(1, 64, 56, 56), "float32"], |
| 9 | + p1: T.Buffer[(6, 6, 64, 64), "float32"], |
| 10 | + p2: T.Buffer[(1, 64, 1, 1), "float32"], |
| 11 | + output: T.Buffer[(1, 64, 56, 56), "float32"], |
| 12 | +) -> None: |
| 13 | + # function attr dict |
| 14 | + T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| 15 | + # body |
| 16 | + # with T.block("root") |
| 17 | + data_pad = T.alloc_buffer([1, 64, 58, 58], dtype="float32") |
| 18 | + d = T.alloc_buffer([64, 196, 6, 6], dtype="float32") |
| 19 | + B = T.alloc_buffer([6, 6], dtype="float32") |
| 20 | + data_pack = T.alloc_buffer([6, 6, 64, 196], dtype="float32") |
| 21 | + bgemm = T.alloc_buffer([6, 6, 64, 196], dtype="float32") |
| 22 | + A = T.alloc_buffer([6, 4], dtype="float32") |
| 23 | + inverse = T.alloc_buffer([64, 196, 4, 4], dtype="float32") |
| 24 | + for i0, i1, i2, i3 in T.grid(1, 64, 58, 58): |
| 25 | + with T.block("data_pad"): |
| 26 | + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) |
| 27 | + T.reads(p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1]) |
| 28 | + T.writes(data_pad[i0_1, i1_1, i2_1, i3_1]) |
| 29 | + # fmt: off |
| 30 | + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i2_1 and i2_1 < 57 and 1 <= i3_1 and i3_1 < 57, p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") |
| 31 | + # fmt: on |
| 32 | + for i0, i1, i2, i3 in T.grid(64, 196, 6, 6): |
| 33 | + with T.block("d"): |
| 34 | + c, p, eps, nu = T.axis.remap("SSSS", [i0, i1, i2, i3]) |
| 35 | + T.reads(data_pad[p // 196, c, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu]) |
| 36 | + T.writes(d[c, p, eps, nu]) |
| 37 | + d[c, p, eps, nu] = data_pad[p // 196, c, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu] |
| 38 | + for i0, i1 in T.grid(6, 6): |
| 39 | + with T.block("B"): |
| 40 | + i, j = T.axis.remap("SS", [i0, i1]) |
| 41 | + T.reads() |
| 42 | + T.writes(B[i, j]) |
| 43 | + # T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) |
| 44 | + # fmt: off |
| 45 | + B[i, j] = T.Select(i % 6 == 5 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 5 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 6 == 5, T.float32(1.5), T.Select(i % 6 == 4 and j % 6 == 4, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 3, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 2, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 6 == 5, T.float32(-2), T.Select(i % 6 == 3 and j % 6 == 4, T.float32(-0.5), T.Select(i % 6 == 3 and j % 6 == 3, T.float32(2), T.Select(i % 6 == 3 and j % 6 == 2, T.float32(2.5), T.Select(i % 6 == 3 and j % 6 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 6 == 0, T.float32(1.5), T.Select(i % 6 == 2 and j % 6 == 5, T.float32(-1.5), T.Select(i % 6 == 2 and j % 6 == 4, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 3, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 2, T.float32(0.5), T.Select(i % 6 == 2 and j % 6 == 1, T.float32(-2.5), T.Select(i % 6 == 2 and j % 6 == 0, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 4, T.float32(0.5), T.Select(i % 6 == 1 and j % 6 == 3, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 2, T.float32(-1), T.Select(i % 6 == 1 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 0, T.float32(-1.5), T.Select(i % 6 == 0 and j % 6 == 5, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) |
| 46 | + # fmt: on |
| 47 | + for i0, i1, i2, i3, i4, i5 in T.grid(6, 6, 64, 196, 6, 6): |
| 48 | + with T.block("data_pack"): |
| 49 | + eps, nu, ci, p, r_a, r_a_1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) |
| 50 | + T.reads( |
| 51 | + d[ci, p, r_a, r_a_1], |
| 52 | + B[T.min(r_a, r_a_1) : T.max(r_a, r_a_1) + 1, T.min(eps, nu) : T.max(eps, nu) + 1], |
| 53 | + ) |
| 54 | + T.writes(data_pack[eps, nu, ci, p]) |
| 55 | + # T.block_attr({"schedule_rule":"meta_schedule.winograd_data_pack.nchw.cuda"}) |
| 56 | + with T.init(): |
| 57 | + data_pack[eps, nu, ci, p] = T.float32(0) |
| 58 | + data_pack[eps, nu, ci, p] = ( |
| 59 | + data_pack[eps, nu, ci, p] + d[ci, p, r_a, r_a_1] * B[r_a, eps] * B[r_a_1, nu] |
| 60 | + ) |
| 61 | + for i0, i1, i2, i3, i4 in T.grid(6, 6, 64, 196, 64): |
| 62 | + with T.block("bgemm"): |
| 63 | + eps, nu, co, p, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4]) |
| 64 | + T.reads(p1[eps, nu, ci, co], data_pack[eps, nu, ci, p]) |
| 65 | + T.writes(bgemm[eps, nu, co, p]) |
| 66 | + with T.init(): |
| 67 | + bgemm[eps, nu, co, p] = T.float32(0) |
| 68 | + bgemm[eps, nu, co, p] = ( |
| 69 | + bgemm[eps, nu, co, p] + p1[eps, nu, ci, co] * data_pack[eps, nu, ci, p] |
| 70 | + ) |
| 71 | + for i0, i1 in T.grid(6, 4): |
| 72 | + with T.block("A"): |
| 73 | + i, j = T.axis.remap("SS", [i0, i1]) |
| 74 | + T.reads() |
| 75 | + T.writes(A[i, j]) |
| 76 | + # T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"}) |
| 77 | + # fmt: off |
| 78 | + A[i, j] = T.Select(i % 6 == 5 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 5 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 4 == 3, T.float32(-8), T.Select(i % 6 == 4 and j % 4 == 2, T.float32(4), T.Select(i % 6 == 4 and j % 4 == 1, T.float32(-2), T.Select(i % 6 == 4 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 4 == 3, T.float32(0.125), T.Select(i % 6 == 3 and j % 4 == 2, T.float32(0.25), T.Select(i % 6 == 3 and j % 4 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 1, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 3, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 1, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 0 and j % 4 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))) |
| 79 | + # fmt: on |
| 80 | + for i0, i1, i2, i3, i4, i5 in T.grid(64, 196, 4, 4, 6, 6): |
| 81 | + with T.block("inverse"): |
| 82 | + co, p, vh, vw, r_a_2, r_a_3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) |
| 83 | + T.reads( |
| 84 | + bgemm[r_a_2, r_a_3, co, p], |
| 85 | + A[T.min(r_a_2, r_a_3) : T.max(r_a_2, r_a_3) + 1, T.min(vh, vw) : T.max(vh, vw) + 1], |
| 86 | + ) |
| 87 | + T.writes(inverse[co, p, vh, vw]) |
| 88 | + # T.block_attr({"schedule_rule":"meta_schedule.winograd_inverse.nchw.cuda"}) |
| 89 | + with T.init(): |
| 90 | + inverse[co, p, vh, vw] = T.float32(0) |
| 91 | + inverse[co, p, vh, vw] = ( |
| 92 | + inverse[co, p, vh, vw] + bgemm[r_a_2, r_a_3, co, p] * A[r_a_2, vh] * A[r_a_3, vw] |
| 93 | + ) |
| 94 | + for i0, i1, i2, i3 in T.grid(1, 64, 56, 56): |
| 95 | + with T.block("output"): |
| 96 | + n, co, h, w = T.axis.remap("SSSS", [i0, i1, i2, i3]) |
| 97 | + T.reads(inverse[co, n * 196 + h // 4 * 14 + w // 4, h % 4, w % 4]) |
| 98 | + T.writes(output[n, co, h, w]) |
| 99 | + # T.block_attr({"schedule_rule":"meta_schedule.winograd_output.nchw.cuda", "winograd_tile_size":4, "workload":["conv2d_nchw_winograd_without_weight_transform.cuda", ["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [6, 6, 64, 64], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "float32"]}) |
| 100 | + output[n, co, h, w] = inverse[co, n * 196 + h // 4 * 14 + w // 4, h % 4, w % 4] |
| 101 | + |
| 102 | + |
| 103 | +def schedule_data_pack(sch: tir.Schedule, data_pack: tir.schedule.BlockRV): |
| 104 | + loops = sch.get_loops(data_pack) |
| 105 | + |
| 106 | + # factors = sch.sample_perfect_tile(loops[2], n=2, max_innermost_factor=64) |
| 107 | + # t0 = sch.split(loops[2], factors) |
| 108 | + # |
| 109 | + # factors = sch.sample_perfect_tile(loops[3], n=2, max_innermost_factor=64) |
| 110 | + # t1 = sch.split(loops[3], factors) |
| 111 | + |
| 112 | + # sch.unroll(loops[0]) |
| 113 | + # sch.unroll(loops[1]) |
| 114 | + # sch.unroll(loops[4]) |
| 115 | + # sch.unroll(loops[5]) |
| 116 | + # sch.reorder( |
| 117 | + # t0[0], |
| 118 | + # t1[0], |
| 119 | + # t0[1], |
| 120 | + # t1[1], |
| 121 | + # loops[0], |
| 122 | + # loops[1], |
| 123 | + # loops[4], |
| 124 | + # loops[5], |
| 125 | + # ) |
| 126 | + # return t1[1] |
| 127 | + |
| 128 | + # sch.unroll(loops[0]) |
| 129 | + # sch.unroll(loops[1]) |
| 130 | + # sch.unroll(loops[4]) |
| 131 | + # sch.unroll(loops[5]) |
| 132 | + t0_t1 = sch.fuse(loops[2], loops[3]) |
| 133 | + t0, t1 = sch.split(t0_t1, factors=[None, 128]) |
| 134 | + sch.reorder( |
| 135 | + t0, |
| 136 | + t1, |
| 137 | + loops[0], |
| 138 | + loops[1], |
| 139 | + loops[4], |
| 140 | + loops[5], |
| 141 | + ) |
| 142 | + return t1 |
| 143 | + |
| 144 | + |
| 145 | +def main(): |
| 146 | + sch = tir.Schedule(func) |
| 147 | + sch.compute_inline(sch.get_block("A")) |
| 148 | + sch.compute_inline(sch.get_block("B")) |
| 149 | + # data_pack |
| 150 | + data_pack = sch.get_block("data_pack") |
| 151 | + (input_tile,) = sch.get_producers(data_pack) |
| 152 | + (data_pad,) = sch.get_producers(input_tile) |
| 153 | + loop = schedule_data_pack(sch, data_pack) |
| 154 | + # sch->ComputeAt(input_tile, /*loop_rv=*/loop, /*preserve_unit_loops=*/true); |
| 155 | + # sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local"); |
| 156 | + # sch->ComputeInline(data_pad); |
| 157 | + sch.compute_at(input_tile, loop, preserve_unit_loops=True) |
| 158 | + sch.set_scope(input_tile, 0, "local") |
| 159 | + sch.compute_inline(data_pad) |
| 160 | + |
| 161 | + tvm.lower(sch.mod).show() |
| 162 | + |
| 163 | + |
| 164 | +if __name__ == "__main__": |
| 165 | + main() |
0 commit comments