Skip to content

Commit d3d8715

Browse files
authored
[TIR] TIR Schedule Misc Update (#10341)
* tir schedule misc update * Trigger Build
1 parent 8947729 commit d3d8715

File tree

5 files changed

+63
-19
lines changed

5 files changed

+63
-19
lines changed

src/tir/schedule/instruction_traits.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ class PythonAPICall {
197197
inline void Input(String arg_name, int arg);
198198
/*! \brief Add an integer input */
199199
inline void Input(String arg_name, int64_t arg);
200+
/*! \brief Add a bool input */
201+
inline void Input(String arg_name, bool arg);
200202
/*! \brief Add a double input */
201203
inline void Input(String arg_name, double arg);
202204
/*! \brief Add an input random variable */
@@ -462,6 +464,17 @@ void PythonAPICall::Input(String arg_name, int64_t arg) {
462464
args_.push_back(std::to_string(arg));
463465
}
464466

467+
/*!
 * \brief Add a bool input, rendered with Python's boolean spelling
 * \param arg_name The name of the argument being recorded
 * \param arg The value; stored as the text "True" when set and "False" otherwise
 */
void PythonAPICall::Input(String arg_name, bool arg) {
  arg_names_.emplace_back(std::move(arg_name));
  // Python capitalizes its boolean literals, so the C++ value is spelled out explicitly.
  args_.push_back(arg ? "True" : "False");
}
477+
465478
void PythonAPICall::Input(String arg_name, double arg) {
466479
arg_names_.emplace_back(std::move(arg_name));
467480
std::ostringstream os;

tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_gpu_softmax_mn():
8484
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
8585
"l5, l6 = sch.split(loop=l3, factors=[None, v4])",
8686
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
87-
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)",
87+
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
8888
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
8989
"l7, l8, l9 = sch.get_loops(block=b0)",
9090
"l10, l11 = sch.split(loop=l9, factors=[None, v4])",
@@ -97,7 +97,7 @@ def test_gpu_softmax_mn():
9797
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
9898
"l5, l6 = sch.split(loop=l3, factors=[None, v4])",
9999
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
100-
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)",
100+
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
101101
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
102102
"l7, l8, l9 = sch.get_loops(block=b0)",
103103
"l10, l11 = sch.split(loop=l9, factors=[None, v4])",
@@ -111,7 +111,7 @@ def test_gpu_softmax_mn():
111111
"v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
112112
"l6, l7 = sch.split(loop=l4, factors=[None, v5])",
113113
'sch.bind(loop=l7, thread_axis="threadIdx.x")',
114-
"sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1)",
114+
"sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
115115
'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
116116
"l8, l9, l10 = sch.get_loops(block=b1)",
117117
"l11, l12 = sch.split(loop=l10, factors=[None, v5])",
@@ -121,7 +121,7 @@ def test_gpu_softmax_mn():
121121
"v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
122122
"l17, l18 = sch.split(loop=l15, factors=[None, v16])",
123123
'sch.bind(loop=l18, thread_axis="threadIdx.x")',
124-
"sch.compute_at(block=b0, loop=l14, preserve_unit_loops=1)",
124+
"sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True)",
125125
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
126126
"l19, l20, l21 = sch.get_loops(block=b0)",
127127
"l22, l23 = sch.split(loop=l21, factors=[None, v16])",
@@ -161,7 +161,7 @@ def test_gpu_softmax_mn_after_inline():
161161
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
162162
"l5, l6 = sch.split(loop=l3, factors=[None, v4])",
163163
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
164-
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)",
164+
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
165165
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
166166
"l7, l8, l9 = sch.get_loops(block=b0)",
167167
"l10, l11 = sch.split(loop=l9, factors=[None, v4])",
@@ -175,14 +175,14 @@ def test_gpu_softmax_mn_after_inline():
175175
"v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
176176
"l6, l7 = sch.split(loop=l4, factors=[None, v5])",
177177
'sch.bind(loop=l7, thread_axis="threadIdx.x")',
178-
"sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1)",
178+
"sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
179179
'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
180180
"l8, l9, l10 = sch.get_loops(block=b1)",
181181
"l11, l12 = sch.split(loop=l10, factors=[None, v5])",
182182
'sch.bind(loop=l12, thread_axis="threadIdx.x")',
183183
"b13, b14 = sch.get_consumers(block=b0)",
184184
"l15, l16, l17, l18 = sch.get_loops(block=b13)",
185-
"sch.compute_at(block=b0, loop=l15, preserve_unit_loops=1)",
185+
"sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)",
186186
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
187187
"l19, l20, l21 = sch.get_loops(block=b0)",
188188
"l22, l23 = sch.split(loop=l21, factors=[None, v5])",
@@ -210,7 +210,7 @@ def test_gpu_batch_norm_bmn():
210210
"v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
211211
"l4, l5 = sch.split(loop=l2, factors=[None, v3])",
212212
'sch.bind(loop=l5, thread_axis="threadIdx.x")',
213-
"sch.compute_at(block=b0, loop=l4, preserve_unit_loops=1)",
213+
"sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)",
214214
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
215215
"l6, l7, l8, l9 = sch.get_loops(block=b0)",
216216
"l10 = sch.fuse(l8, l9)",

tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_cpu_matmul():
5555
"l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
5656
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
5757
'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
58-
"sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)",
58+
"sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
5959
],
6060
[
6161
'b0 = sch.get_block(name="C", func_name="main")',
@@ -69,7 +69,7 @@ def test_cpu_matmul():
6969
"l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
7070
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
7171
'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
72-
"sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)",
72+
"sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
7373
],
7474
[
7575
'b0 = sch.get_block(name="C", func_name="main")',
@@ -116,7 +116,7 @@ def test_cpu_matmul_relu():
116116
"l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
117117
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
118118
"b24, = sch.get_consumers(block=b0)",
119-
"sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)",
119+
"sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
120120
],
121121
[
122122
'b0 = sch.get_block(name="C", func_name="main")',
@@ -130,7 +130,7 @@ def test_cpu_matmul_relu():
130130
"l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
131131
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
132132
"b24, = sch.get_consumers(block=b0)",
133-
"sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)",
133+
"sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
134134
],
135135
[
136136
'b0 = sch.get_block(name="C", func_name="main")',
@@ -186,15 +186,15 @@ def test_cuda_matmul():
186186
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)',
187187
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)',
188188
'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
189-
"sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)",
189+
"sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)",
190190
'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
191-
"sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)",
191+
"sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
192192
"l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
193193
"l41 = sch.fuse(l39, l40)",
194194
"v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
195195
'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
196196
'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
197-
"sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)",
197+
"sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
198198
"l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
199199
"l50 = sch.fuse(l48, l49)",
200200
"v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
@@ -240,15 +240,15 @@ def test_cuda_matmul_relu():
240240
"l32 = sch.fuse(l11, l21)",
241241
'sch.bind(loop=l32, thread_axis="threadIdx.x")',
242242
'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
243-
"sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)",
243+
"sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)",
244244
'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
245-
"sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)",
245+
"sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
246246
"l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
247247
"l41 = sch.fuse(l39, l40)",
248248
"v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
249249
'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
250250
'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
251-
"sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)",
251+
"sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
252252
"l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
253253
"l50 = sch.fuse(l48, l49)",
254254
"v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",

tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_random_compute_location():
7474
[
7575
'b0 = sch.get_block(name="move", func_name="main")',
7676
"l1 = sch.sample_compute_location(block=b0)",
77-
"sch.compute_at(block=b0, loop=l1, preserve_unit_loops=1)",
77+
"sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True)",
7878
]
7979
]
8080
mod = Add

tests/python/unittest/test_tir_schedule_trace.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@ def _make_compute_inline(input): # pylint: disable=redefined-builtin
8282
)
8383

8484

85+
def _make_split(inputs, outputs):  # pylint: disable=redefined-builtin
    """Create a "Split" instruction with the given inputs and outputs and no attributes."""
    split_kind = InstructionKind.get("Split")
    return Instruction(kind=split_kind, inputs=inputs, attrs=[], outputs=outputs)
92+
93+
8594
def _make_enter_postproc():
8695
return Instruction(
8796
kind=InstructionKind.get("EnterPostproc"),
@@ -129,6 +138,17 @@ def _make_trace_3(b0, b1, add_postproc): # pylint: disable=invalid-name
129138
return Trace(insts=insts, decisions={})
130139

131140

141+
def _make_trace_4(b0, l1, l2, l3):  # pylint: disable=invalid-name
    """Create a trace with no decisions: get block "B", get its loops, then split `l1`.

    The split receives `[l1, None, 32]` as inputs and `[l2, l3]` as outputs.
    """
    insts = [
        _make_get_block(name="B", output=b0),
        _make_get_loops(input=b0, outputs=[l1]),
        _make_split([l1, None, 32], [l2, l3]),
    ]
    return Trace(insts=insts, decisions={})
150+
151+
132152
def test_trace_construct_1():
133153
trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV())
134154
assert str(trace) == "\n".join(
@@ -235,6 +255,17 @@ def test_trace_simplified_2():
235255
)
236256

237257

258+
def test_trace_simplified_3():
    """Simplifying trace 4 (post-processing kept) must print all three instructions,
    including the `None` factor passed to split."""
    original = _make_trace_4(BlockRV(), LoopRV(), LoopRV(), LoopRV())
    trace = original.simplified(remove_postproc=False)
    expected_lines = [
        'b0 = sch.get_block(name="B", func_name="main")',
        "l1, = sch.get_loops(block=b0)",
        "l2, l3 = sch.split(loop=l1, factors=[None, 32])",
    ]
    assert str(trace) == "\n".join(expected_lines)
267+
268+
238269
def test_apply_json_to_schedule_1():
239270
trace = _make_trace_2(BlockRV())
240271
json_obj = trace.as_json()

0 commit comments

Comments
 (0)