13 changes: 13 additions & 0 deletions src/tir/schedule/instruction_traits.h
@@ -197,6 +197,8 @@ class PythonAPICall {
  inline void Input(String arg_name, int arg);
  /*! \brief Add an integer input */
  inline void Input(String arg_name, int64_t arg);
+  /*! \brief Add a bool input */
+  inline void Input(String arg_name, bool arg);
  /*! \brief Add a double input */
  inline void Input(String arg_name, double arg);
  /*! \brief Add an input random variable */
@@ -462,6 +464,17 @@ void PythonAPICall::Input(String arg_name, int64_t arg) {
  args_.push_back(std::to_string(arg));
}

+void PythonAPICall::Input(String arg_name, bool arg) {
+  static const char* true_str = "True";
+  static const char* false_str = "False";
+  arg_names_.emplace_back(std::move(arg_name));
+  if (arg) {
+    args_.push_back(true_str);
+  } else {
+    args_.push_back(false_str);
+  }
+}
+
void PythonAPICall::Input(String arg_name, double arg) {
  arg_names_.emplace_back(std::move(arg_name));
  std::ostringstream os;
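Why the test expectations below change: without a dedicated overload, a C++ `bool` argument binds to the `int` overload of `PythonAPICall::Input` via integral promotion, so boolean flags such as `preserve_unit_loops` were printed as `1`/`0` in traced schedules. The new overload emits Python's `True`/`False` literals instead. Below is a minimal sketch of the rendering rule in plain Python, for illustration only; `render_arg` is not code from this PR.

def render_arg(arg):
    # bool must be tested before int: in Python, bool is a subclass of int,
    # just as in C++ a bool argument would otherwise promote to int.
    if isinstance(arg, bool):
        return "True" if arg else "False"
    if isinstance(arg, (int, float)):
        return str(arg)
    return repr(arg)

print("sch.compute_at(block=b0, loop=l2, preserve_unit_loops=%s)" % render_arg(True))
# -> sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)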
@@ -84,7 +84,7 @@ def test_gpu_softmax_mn():
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l5, l6 = sch.split(loop=l3, factors=[None, v4])",
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)",
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l7, l8, l9 = sch.get_loops(block=b0)",
"l10, l11 = sch.split(loop=l9, factors=[None, v4])",
@@ -97,7 +97,7 @@ def test_gpu_softmax_mn():
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l5, l6 = sch.split(loop=l3, factors=[None, v4])",
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)",
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l7, l8, l9 = sch.get_loops(block=b0)",
"l10, l11 = sch.split(loop=l9, factors=[None, v4])",
@@ -111,7 +111,7 @@ def test_gpu_softmax_mn():
"v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l6, l7 = sch.split(loop=l4, factors=[None, v5])",
'sch.bind(loop=l7, thread_axis="threadIdx.x")',
"sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1)",
"sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
"l8, l9, l10 = sch.get_loops(block=b1)",
"l11, l12 = sch.split(loop=l10, factors=[None, v5])",
@@ -121,7 +121,7 @@ def test_gpu_softmax_mn():
"v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l17, l18 = sch.split(loop=l15, factors=[None, v16])",
'sch.bind(loop=l18, thread_axis="threadIdx.x")',
"sch.compute_at(block=b0, loop=l14, preserve_unit_loops=1)",
"sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l19, l20, l21 = sch.get_loops(block=b0)",
"l22, l23 = sch.split(loop=l21, factors=[None, v16])",
@@ -161,7 +161,7 @@ def test_gpu_softmax_mn_after_inline():
"v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l5, l6 = sch.split(loop=l3, factors=[None, v4])",
'sch.bind(loop=l6, thread_axis="threadIdx.x")',
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)",
"sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l7, l8, l9 = sch.get_loops(block=b0)",
"l10, l11 = sch.split(loop=l9, factors=[None, v4])",
@@ -175,14 +175,14 @@ def test_gpu_softmax_mn_after_inline():
"v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l6, l7 = sch.split(loop=l4, factors=[None, v5])",
'sch.bind(loop=l7, thread_axis="threadIdx.x")',
"sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1)",
"sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)",
'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
"l8, l9, l10 = sch.get_loops(block=b1)",
"l11, l12 = sch.split(loop=l10, factors=[None, v5])",
'sch.bind(loop=l12, thread_axis="threadIdx.x")',
"b13, b14 = sch.get_consumers(block=b0)",
"l15, l16, l17, l18 = sch.get_loops(block=b13)",
"sch.compute_at(block=b0, loop=l15, preserve_unit_loops=1)",
"sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l19, l20, l21 = sch.get_loops(block=b0)",
"l22, l23 = sch.split(loop=l21, factors=[None, v5])",
@@ -210,7 +210,7 @@ def test_gpu_batch_norm_bmn():
"v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
"l4, l5 = sch.split(loop=l2, factors=[None, v3])",
'sch.bind(loop=l5, thread_axis="threadIdx.x")',
"sch.compute_at(block=b0, loop=l4, preserve_unit_loops=1)",
"sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)",
'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
"l6, l7, l8, l9 = sch.get_loops(block=b0)",
"l10 = sch.fuse(l8, l9)",
@@ -55,7 +55,7 @@ def test_cpu_matmul():
"l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
"sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)",
"sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
],
[
'b0 = sch.get_block(name="C", func_name="main")',
@@ -69,7 +69,7 @@ def test_cpu_matmul():
"l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
"sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)",
"sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
],
[
'b0 = sch.get_block(name="C", func_name="main")',
@@ -116,7 +116,7 @@ def test_cpu_matmul_relu():
"l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
"b24, = sch.get_consumers(block=b0)",
"sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)",
"sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)",
],
[
'b0 = sch.get_block(name="C", func_name="main")',
@@ -130,7 +130,7 @@ def test_cpu_matmul_relu():
"l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
"sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
"b24, = sch.get_consumers(block=b0)",
"sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)",
"sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)",
],
[
'b0 = sch.get_block(name="C", func_name="main")',
@@ -186,15 +186,15 @@ def test_cuda_matmul():
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)',
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)',
'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
"sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)",
"sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)",
'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
"sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)",
"sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
"l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
"l41 = sch.fuse(l39, l40)",
"v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
"sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)",
"sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
"l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
"l50 = sch.fuse(l48, l49)",
"v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
@@ -240,15 +240,15 @@ def test_cuda_matmul_relu():
"l32 = sch.fuse(l11, l21)",
'sch.bind(loop=l32, thread_axis="threadIdx.x")',
'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
"sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)",
"sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)",
'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
"sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)",
"sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)",
"l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
"l41 = sch.fuse(l39, l40)",
"v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
"sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)",
"sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)",
"l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
"l50 = sch.fuse(l48, l49)",
"v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
@@ -74,7 +74,7 @@ def test_random_compute_location():
[
'b0 = sch.get_block(name="move", func_name="main")',
"l1 = sch.sample_compute_location(block=b0)",
"sch.compute_at(block=b0, loop=l1, preserve_unit_loops=1)",
"sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True)",
]
]
mod = Add
31 changes: 31 additions & 0 deletions tests/python/unittest/test_tir_schedule_trace.py
@@ -82,6 +82,15 @@ def _make_compute_inline(input): # pylint: disable=redefined-builtin
    )


+def _make_split(inputs, outputs):  # pylint: disable=redefined-builtin
+    return Instruction(
+        kind=InstructionKind.get("Split"),
+        inputs=inputs,
+        attrs=[],
+        outputs=outputs,
+    )
+
+
def _make_enter_postproc():
    return Instruction(
        kind=InstructionKind.get("EnterPostproc"),
@@ -129,6 +138,17 @@ def _make_trace_3(b0, b1, add_postproc): # pylint: disable=invalid-name
    return Trace(insts=insts, decisions={})


+def _make_trace_4(b0, l1, l2, l3):  # pylint: disable=invalid-name
+    return Trace(
+        insts=[
+            _make_get_block(name="B", output=b0),
+            _make_get_loops(input=b0, outputs=[l1]),
+            _make_split([l1, None, 32], [l2, l3]),
+        ],
+        decisions={},
+    )
+
+
def test_trace_construct_1():
    trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV())
    assert str(trace) == "\n".join(
@@ -235,6 +255,17 @@ def test_trace_simplified_2():
    )


+def test_trace_simplified_3():
+    trace = _make_trace_4(BlockRV(), LoopRV(), LoopRV(), LoopRV()).simplified(remove_postproc=False)
+    assert str(trace) == "\n".join(
+        (
+            'b0 = sch.get_block(name="B", func_name="main")',
+            "l1, = sch.get_loops(block=b0)",
+            "l2, l3 = sch.split(loop=l1, factors=[None, 32])",
+        )
+    )
+
+
def test_apply_json_to_schedule_1():
    trace = _make_trace_2(BlockRV())
    json_obj = trace.as_json()
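For completeness, a hedged sketch of what test_trace_simplified_3 exercises end to end. It assumes a TVM build that contains this change and is meant to run inside tests/python/unittest/test_tir_schedule_trace.py, since it reuses the _make_trace_4 helper defined there.

from tvm.tir.schedule import BlockRV, LoopRV

trace = _make_trace_4(BlockRV(), LoopRV(), LoopRV(), LoopRV())
# simplified() must keep the Split instruction and leave the symbolic None
# factor intact, so the printed trace stays valid, replayable Python:
print(trace.simplified(remove_postproc=False))
# last line of the output: l2, l3 = sch.split(loop=l1, factors=[None, 32])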