Skip to content

Commit 888fc75

Browse files
committed
stash
1 parent 78818bc commit 888fc75

File tree

4 files changed

+212
-46
lines changed

4 files changed

+212
-46
lines changed

main.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import tvm
2+
from tvm import tir
3+
from tvm.script import tir as T
4+
5+
6+
# Winograd-style conv2d written as a TIR PrimFunc. The stages run in this order:
# data_pad -> d -> (B) -> data_pack -> bgemm -> (A) -> inverse -> output.
# p0 is the (1, 64, 56, 56) input, p1 the (6, 6, 64, 64) operand consumed by "bgemm",
# and `output` the (1, 64, 56, 56) result buffer.
# NOTE(review): p2 is declared as a parameter but is never read in the body — confirm this is intended.
@T.prim_func
7+
def func(
8+
p0: T.Buffer[(1, 64, 56, 56), "float32"],
9+
p1: T.Buffer[(6, 6, 64, 64), "float32"],
10+
p2: T.Buffer[(1, 64, 1, 1), "float32"],
11+
output: T.Buffer[(1, 64, 56, 56), "float32"],
12+
) -> None:
13+
# function attr dict
14+
T.func_attr({"global_symbol": "main", "tir.noalias": True})
15+
# body
16+
# with T.block("root")
17+
# Intermediate buffers, one per stage below.
data_pad = T.alloc_buffer([1, 64, 58, 58], dtype="float32")
18+
d = T.alloc_buffer([64, 196, 6, 6], dtype="float32")
19+
B = T.alloc_buffer([6, 6], dtype="float32")
20+
data_pack = T.alloc_buffer([6, 6, 64, 196], dtype="float32")
21+
bgemm = T.alloc_buffer([6, 6, 64, 196], dtype="float32")
22+
A = T.alloc_buffer([6, 4], dtype="float32")
23+
inverse = T.alloc_buffer([64, 196, 4, 4], dtype="float32")
24+
# Stage "data_pad": copy p0 into a 58x58 plane shifted by one element; the border rows/columns are filled with 0.
for i0, i1, i2, i3 in T.grid(1, 64, 58, 58):
25+
with T.block("data_pad"):
26+
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
27+
T.reads(p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1])
28+
T.writes(data_pad[i0_1, i1_1, i2_1, i3_1])
29+
# fmt: off
30+
data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i2_1 and i2_1 < 57 and 1 <= i3_1 and i3_1 < 57, p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32")
31+
# fmt: on
32+
# Stage "d": gather one 6x6 window per (channel c, tile p). Window origins advance by 4 elements,
# with p % 196 // 14 selecting the tile row and p % 14 the tile column.
for i0, i1, i2, i3 in T.grid(64, 196, 6, 6):
33+
with T.block("d"):
34+
c, p, eps, nu = T.axis.remap("SSSS", [i0, i1, i2, i3])
35+
T.reads(data_pad[p // 196, c, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu])
36+
T.writes(d[c, p, eps, nu])
37+
d[c, p, eps, nu] = data_pad[p // 196, c, p % 196 // 14 * 4 + eps, p % 14 * 4 + nu]
38+
# Stage "B": constant 6x6 matrix, spelled out entry by entry as a nested T.Select chain (innermost fallback is 0).
for i0, i1 in T.grid(6, 6):
39+
with T.block("B"):
40+
i, j = T.axis.remap("SS", [i0, i1])
41+
T.reads()
42+
T.writes(B[i, j])
43+
# T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"})
44+
# fmt: off
45+
B[i, j] = T.Select(i % 6 == 5 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 5 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 6 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 6 == 5, T.float32(1.5), T.Select(i % 6 == 4 and j % 6 == 4, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 3, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 2, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 4 and j % 6 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 6 == 5, T.float32(-2), T.Select(i % 6 == 3 and j % 6 == 4, T.float32(-0.5), T.Select(i % 6 == 3 and j % 6 == 3, T.float32(2), T.Select(i % 6 == 3 and j % 6 == 2, T.float32(2.5), T.Select(i % 6 == 3 and j % 6 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 6 == 0, T.float32(1.5), T.Select(i % 6 == 2 and j % 6 == 5, T.float32(-1.5), T.Select(i % 6 == 2 and j % 6 == 4, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 3, T.float32(-1), T.Select(i % 6 == 2 and j % 6 == 2, T.float32(0.5), T.Select(i % 6 == 2 and j % 6 == 1, T.float32(-2.5), T.Select(i % 6 == 2 and j % 6 == 0, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 5, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 4, T.float32(0.5), T.Select(i % 6 == 1 and j % 6 == 3, T.float32(-2), T.Select(i % 6 == 1 and j % 6 == 2, T.float32(-1), T.Select(i % 6 == 1 and j % 6 == 1, T.float32(1), T.Select(i % 6 == 1 and j % 6 == 0, T.float32(-1.5), T.Select(i % 6 == 0 and j % 6 == 5, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 4, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 6 == 0, T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))
46+
# fmt: on
47+
# Stage "data_pack": data_pack[eps, nu, ci, p] = sum over (r_a, r_a_1) of d[ci, p, r_a, r_a_1] * B[r_a, eps] * B[r_a_1, nu].
for i0, i1, i2, i3, i4, i5 in T.grid(6, 6, 64, 196, 6, 6):
48+
with T.block("data_pack"):
49+
eps, nu, ci, p, r_a, r_a_1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
50+
T.reads(
51+
d[ci, p, r_a, r_a_1],
52+
B[T.min(r_a, r_a_1) : T.max(r_a, r_a_1) + 1, T.min(eps, nu) : T.max(eps, nu) + 1],
53+
)
54+
T.writes(data_pack[eps, nu, ci, p])
55+
# T.block_attr({"schedule_rule":"meta_schedule.winograd_data_pack.nchw.cuda"})
56+
with T.init():
57+
data_pack[eps, nu, ci, p] = T.float32(0)
58+
data_pack[eps, nu, ci, p] = (
59+
data_pack[eps, nu, ci, p] + d[ci, p, r_a, r_a_1] * B[r_a, eps] * B[r_a_1, nu]
60+
)
61+
# Stage "bgemm": for every (eps, nu) pair, contract p1 with data_pack over the ci axis (extent 64).
for i0, i1, i2, i3, i4 in T.grid(6, 6, 64, 196, 64):
62+
with T.block("bgemm"):
63+
eps, nu, co, p, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4])
64+
T.reads(p1[eps, nu, ci, co], data_pack[eps, nu, ci, p])
65+
T.writes(bgemm[eps, nu, co, p])
66+
with T.init():
67+
bgemm[eps, nu, co, p] = T.float32(0)
68+
bgemm[eps, nu, co, p] = (
69+
bgemm[eps, nu, co, p] + p1[eps, nu, ci, co] * data_pack[eps, nu, ci, p]
70+
)
71+
# Stage "A": constant 6x4 matrix, again written as a nested T.Select chain (innermost fallback is 0).
for i0, i1 in T.grid(6, 4):
72+
with T.block("A"):
73+
i, j = T.axis.remap("SS", [i0, i1])
74+
T.reads()
75+
T.writes(A[i, j])
76+
# T.block_attr({"const_matrix":True, "schedule_rule":"meta_schedule.compute_inline"})
77+
# fmt: off
78+
A[i, j] = T.Select(i % 6 == 5 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 5 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 5 and j % 4 == 0, T.float32(0), T.Select(i % 6 == 4 and j % 4 == 3, T.float32(-8), T.Select(i % 6 == 4 and j % 4 == 2, T.float32(4), T.Select(i % 6 == 4 and j % 4 == 1, T.float32(-2), T.Select(i % 6 == 4 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 3 and j % 4 == 3, T.float32(0.125), T.Select(i % 6 == 3 and j % 4 == 2, T.float32(0.25), T.Select(i % 6 == 3 and j % 4 == 1, T.float32(0.5), T.Select(i % 6 == 3 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 3, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 1, T.float32(1), T.Select(i % 6 == 2 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 3, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 2, T.float32(1), T.Select(i % 6 == 1 and j % 4 == 1, T.float32(-1), T.Select(i % 6 == 1 and j % 4 == 0, T.float32(1), T.Select(i % 6 == 0 and j % 4 == 3, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 2, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 1, T.float32(0), T.Select(i % 6 == 0 and j % 4 == 0, T.float32(1), T.float32(0)))))))))))))))))))))))))
79+
# fmt: on
80+
# Stage "inverse": inverse[co, p, vh, vw] = sum over (r_a_2, r_a_3) of bgemm[r_a_2, r_a_3, co, p] * A[r_a_2, vh] * A[r_a_3, vw].
for i0, i1, i2, i3, i4, i5 in T.grid(64, 196, 4, 4, 6, 6):
81+
with T.block("inverse"):
82+
co, p, vh, vw, r_a_2, r_a_3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
83+
T.reads(
84+
bgemm[r_a_2, r_a_3, co, p],
85+
A[T.min(r_a_2, r_a_3) : T.max(r_a_2, r_a_3) + 1, T.min(vh, vw) : T.max(vh, vw) + 1],
86+
)
87+
T.writes(inverse[co, p, vh, vw])
88+
# T.block_attr({"schedule_rule":"meta_schedule.winograd_inverse.nchw.cuda"})
89+
with T.init():
90+
inverse[co, p, vh, vw] = T.float32(0)
91+
inverse[co, p, vh, vw] = (
92+
inverse[co, p, vh, vw] + bgemm[r_a_2, r_a_3, co, p] * A[r_a_2, vh] * A[r_a_3, vw]
93+
)
94+
# Stage "output": place each tile's 4x4 result at its (h, w) position; h // 4 and w // 4 pick the tile, h % 4 and w % 4 the element.
for i0, i1, i2, i3 in T.grid(1, 64, 56, 56):
95+
with T.block("output"):
96+
n, co, h, w = T.axis.remap("SSSS", [i0, i1, i2, i3])
97+
T.reads(inverse[co, n * 196 + h // 4 * 14 + w // 4, h % 4, w % 4])
98+
T.writes(output[n, co, h, w])
99+
# T.block_attr({"schedule_rule":"meta_schedule.winograd_output.nchw.cuda", "winograd_tile_size":4, "workload":["conv2d_nchw_winograd_without_weight_transform.cuda", ["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [6, 6, 64, 64], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "float32"]})
100+
output[n, co, h, w] = inverse[co, n * 196 + h // 4 * 14 + w // 4, h % 4, w % 4]
101+
102+
103+
def schedule_data_pack(sch: tir.Schedule, data_pack: tir.schedule.BlockRV):
    """Tile the two middle loops of the ``data_pack`` block and move them outermost.

    Loops 2 and 3 of the block are fused into a single loop, which is then
    split into an outer part and an inner part of extent 128.  The two
    resulting loops are reordered in front of loops 0, 1, 4 and 5, whose
    relative order is kept.

    Earlier experiments (previously kept here as commented-out code) sampled
    perfect tiles for loops 2 and 3 separately and unrolled loops 0, 1, 4
    and 5; this version does neither.

    Parameters
    ----------
    sch : tir.Schedule
        Schedule owning the block; it is modified in place.
    data_pack : tir.schedule.BlockRV
        Block whose loop nest is rewritten.  Loop positions 0 through 5 are
        indexed.

    Returns
    -------
    The inner (extent-128) loop produced by the split.
    """
    loops = sch.get_loops(data_pack)

    fused = sch.fuse(loops[2], loops[3])
    outer, inner = sch.split(fused, factors=[None, 128])

    # Tiled loops first, then the untouched loops in their original order.
    sch.reorder(outer, inner, *(loops[k] for k in (0, 1, 4, 5)))
    return inner
143+
144+
145+
def main():
146+
sch = tir.Schedule(func)
147+
sch.compute_inline(sch.get_block("A"))
148+
sch.compute_inline(sch.get_block("B"))
149+
# data_pack
150+
data_pack = sch.get_block("data_pack")
151+
(input_tile,) = sch.get_producers(data_pack)
152+
(data_pad,) = sch.get_producers(input_tile)
153+
loop = schedule_data_pack(sch, data_pack)
154+
# sch->ComputeAt(input_tile, /*loop_rv=*/loop, /*preserve_unit_loops=*/true);
155+
# sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local");
156+
# sch->ComputeInline(data_pad);
157+
sch.compute_at(input_tile, loop, preserve_unit_loops=True)
158+
sch.set_scope(input_tile, 0, "local")
159+
sch.compute_inline(data_pad)
160+
161+
tvm.lower(sch.mod).show()
162+
163+
164+
if __name__ == "__main__":
165+
main()

python/tvm/topi/cuda/conv2d_winograd.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,15 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
174174
return output
175175

176176

177+
FN_INPUTS = None
178+
179+
180+
@tvm._ffi.register_func("debug_store_fn_inputs")
def debug_store_fn_inputs(fn_inputs):
    """Stash ``fn_inputs`` in the module-level ``FN_INPUTS`` variable.

    Registered under the global function name "debug_store_fn_inputs".
    Every call overwrites whatever was stored before; nothing is returned.
    """
    global FN_INPUTS
    FN_INPUTS = fn_inputs
184+
185+
177186
def schedule_winograd_cuda(cfg, s, output, pre_computed):
178187
"""Schedule winograd template"""
179188
# get stages
@@ -183,27 +192,47 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
183192
input_tile, B = s[data_pack].op.input_tensors
184193
pad_data = s[input_tile].op.input_tensors[0]
185194

195+
def _debug_io_tensors():
    """Build the tensor list ``[p0, p1, *extra, output]`` shared by the debug printers.

    ``FN_INPUTS`` stays ``None`` until "debug_store_fn_inputs" has been
    called, so slicing it unconditionally raised ``TypeError``.  A missing
    value is treated as "no extra inputs" instead.
    """
    p0 = pad_data.op.input_tensors[0]
    p1 = kernel_pack
    pn = FN_INPUTS[2:] if FN_INPUTS is not None else []
    return [p0, p1, *pn, output]


def _print_lower(msg):
    """Print ``msg`` followed by the lowered form of schedule ``s``.

    Does nothing unless ``pre_computed`` is truthy.
    """
    if pre_computed:
        print(msg)
        tvm.lower(s, _debug_io_tensors()).show()


def _print_tir():
    """Show the PrimFunc created from the same tensor list.

    Does nothing unless ``pre_computed`` is truthy.
    """
    if pre_computed:
        tvm.te.create_prim_func(_debug_io_tensors()).show()
209+
210+
_print_lower("initial")
211+
_print_tir()
212+
186213
# data transform
187214
s[B].compute_inline()
188215

189216
data_l = s.cache_write(data_pack, "local")
190217
eps, nu, c, p = s[data_l].op.axis
191218
r_a, r_b = s[data_l].op.reduce_axis
192-
for axis in [eps, nu, r_a, r_b]:
193-
s[data_l].unroll(axis)
219+
# for axis in [eps, nu, r_a, r_b]:
220+
# s[data_l].unroll(axis)
194221

195222
eps, nu, c, p = s[data_pack].op.axis
196223
p, pi = s[data_pack].split(p, 1)
197224
fused = s[data_pack].fuse(c, p)
198225
bb, tt = s[data_pack].split(fused, 128)
199226
s[data_pack].reorder(bb, tt, pi, eps, nu)
200-
s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
201-
s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
227+
# s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
228+
# s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
202229

203230
s[data_l].compute_at(s[data_pack], pi)
204231
s[input_tile].compute_at(s[data_pack], pi)
205232
s[pad_data].compute_inline()
206233

234+
_print_lower("after `data_pack`")
235+
207236
# transform kernel
208237
if not pre_computed:
209238
kernel, G = s[kernel_pack].op.input_tensors
@@ -296,8 +325,8 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
296325
s[load].bind(ty, te.thread_axis("threadIdx.y"))
297326
s[load].bind(tx, te.thread_axis("threadIdx.x"))
298327

299-
s[C].pragma(bgemm_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
300-
s[C].pragma(bgemm_scope, "unroll_explicit", cfg["unroll_explicit"].val)
328+
# s[C].pragma(bgemm_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
329+
# s[C].pragma(bgemm_scope, "unroll_explicit", cfg["unroll_explicit"].val)
301330

302331
# schedule inverse, output and fusion
303332
if output.op in s.outputs:
@@ -328,6 +357,8 @@ def schedule_winograd_cuda(cfg, s, output, pre_computed):
328357
s[inverse].unroll(axis)
329358
s[inverse].compute_at(s[output], tt)
330359

360+
if pre_computed:
361+
breakpoint()
331362
return s
332363

333364

src/meta_schedule/schedule_rule/winograd.cc

Lines changed: 8 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,8 @@ inline LoopRV ScheduleDataPack(Schedule sch, BlockRV block) {
5555
Array<LoopRV> t1 = sch->Split(loops[3], {factors.begin(), factors.end()});
5656
ICHECK_EQ(t1.size(), 2);
5757

58-
if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[0]))) {
59-
if (*i <= 16) {
60-
sch->Unroll(loops[0]);
61-
}
62-
}
63-
if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[1]))) {
64-
if (*i <= 16) {
65-
sch->Unroll(loops[1]);
66-
}
67-
}
58+
sch->Unroll(loops[0]);
59+
sch->Unroll(loops[1]);
6860
sch->Unroll(loops[4]);
6961
sch->Unroll(loops[5]);
7062
sch->Reorder({
@@ -127,16 +119,8 @@ inline LoopRV ScheduleDataPackNCHW(Schedule sch, BlockRV block) {
127119
Array<LoopRV> loops = sch->GetLoops(block);
128120
ICHECK_EQ(loops.size(), 6);
129121

130-
if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[0]))) {
131-
if (*i <= 16) {
132-
sch->Unroll(loops[0]);
133-
}
134-
}
135-
if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[1]))) {
136-
if (*i <= 16) {
137-
sch->Unroll(loops[1]);
138-
}
139-
}
122+
sch->Unroll(loops[0]);
123+
sch->Unroll(loops[1]);
140124
sch->Unroll(loops[4]);
141125
sch->Unroll(loops[5]);
142126

@@ -185,16 +169,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.nchw.cuda")
185169
sch->SetScope(inverse, /*buffer_index=*/0, /*storage_scope=*/"local");
186170
Array<LoopRV> loops = sch->GetLoops(inverse);
187171
ICHECK_EQ(loops.size(), 6);
188-
if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[2]))) {
189-
if (*i <= 16) {
190-
sch->Unroll(loops[2]);
191-
}
192-
}
193-
if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[3]))) {
194-
if (*i <= 16) {
195-
sch->Unroll(loops[3]);
196-
}
197-
}
172+
sch->Unroll(loops[2]);
173+
sch->Unroll(loops[3]);
198174
sch->Unroll(loops[4]);
199175
sch->Unroll(loops[5]);
200176
return {sch};
@@ -204,16 +180,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_kernel_pack.nchw.cuda")
204180
.set_body_typed([](Schedule sch, BlockRV kernel_pack) -> Array<Schedule> {
205181
Array<LoopRV> loops = sch->GetLoops(kernel_pack);
206182
ICHECK_EQ(loops.size(), 6);
207-
if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[0]))) {
208-
if (*i <= 16) {
209-
sch->Unroll(loops[0]);
210-
}
211-
}
212-
if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[1]))) {
213-
if (*i <= 16) {
214-
sch->Unroll(loops[1]);
215-
}
216-
}
183+
sch->Unroll(loops[0]);
184+
sch->Unroll(loops[1]);
217185
sch->Unroll(loops[4]);
218186
sch->Unroll(loops[5]);
219187

src/relay/backend/te_compiler_cache.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,8 @@ class ScheduleBuilder : public ExprVisitor {
477477
LowerToTECompute lower_te_compute(target_);
478478
Array<te::Tensor> tensor_outs = lower_te_compute.Lower(relay_func);
479479
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
480+
static const auto* f_store_fn_inputs = runtime::Registry::Get("debug_store_fn_inputs");
481+
(*f_store_fn_inputs)(fn_inputs);
480482
VisitExpr(relay_func->body);
481483

482484
// TODO(mbs): This should be the definitive global by which the PrimFunc is known and

0 commit comments

Comments
 (0)