diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 91a79456e579..14d05a4a340c 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -197,6 +197,8 @@ class PythonAPICall { inline void Input(String arg_name, int arg); /*! \brief Add an integer input */ inline void Input(String arg_name, int64_t arg); + /*! \brief Add a bool input */ + inline void Input(String arg_name, bool arg); /*! \brief Add a double input */ inline void Input(String arg_name, double arg); /*! \brief Add an input random variable */ @@ -462,6 +464,17 @@ void PythonAPICall::Input(String arg_name, int64_t arg) { args_.push_back(std::to_string(arg)); } +void PythonAPICall::Input(String arg_name, bool arg) { + static const char* true_str = "True"; + static const char* false_str = "False"; + arg_names_.emplace_back(std::move(arg_name)); + if (arg) { + args_.push_back(true_str); + } else { + args_.push_back(false_str); + } +} + void PythonAPICall::Input(String arg_name, double arg) { arg_names_.emplace_back(std::move(arg_name)); std::ostringstream os; diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py index 7bed18b0f9ea..47f405842c98 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -84,7 +84,7 @@ def test_gpu_softmax_mn(): "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l5, l6 = sch.split(loop=l3, factors=[None, v4])", 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l7, l8, l9 = sch.get_loops(block=b0)", "l10, l11 = sch.split(loop=l9, factors=[None, v4])", @@ -97,7 +97,7 @@ def test_gpu_softmax_mn(): "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l5, l6 = sch.split(loop=l3, factors=[None, v4])", 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l7, l8, l9 = sch.get_loops(block=b0)", "l10, l11 = sch.split(loop=l9, factors=[None, v4])", @@ -111,7 +111,7 @@ def test_gpu_softmax_mn(): "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l6, l7 = sch.split(loop=l4, factors=[None, v5])", 'sch.bind(loop=l7, thread_axis="threadIdx.x")', - "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1)", + "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)", 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', "l8, l9, l10 = sch.get_loops(block=b1)", "l11, l12 = sch.split(loop=l10, factors=[None, v5])", @@ -121,7 +121,7 @@ def test_gpu_softmax_mn(): "v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l17, l18 = sch.split(loop=l15, factors=[None, v16])", 'sch.bind(loop=l18, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l19, l20, l21 = sch.get_loops(block=b0)", "l22, l23 = sch.split(loop=l21, factors=[None, v16])", @@ -161,7 +161,7 @@ def test_gpu_softmax_mn_after_inline(): "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l5, l6 = sch.split(loop=l3, factors=[None, v4])", 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l7, l8, l9 = sch.get_loops(block=b0)", "l10, l11 = sch.split(loop=l9, factors=[None, v4])", @@ -175,14 +175,14 @@ def test_gpu_softmax_mn_after_inline(): "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l6, l7 = sch.split(loop=l4, factors=[None, v5])", 'sch.bind(loop=l7, thread_axis="threadIdx.x")', - "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1)", + "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)", 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', "l8, l9, l10 = sch.get_loops(block=b1)", "l11, l12 = sch.split(loop=l10, factors=[None, v5])", 'sch.bind(loop=l12, thread_axis="threadIdx.x")', "b13, b14 = sch.get_consumers(block=b0)", "l15, l16, l17, l18 = sch.get_loops(block=b13)", - "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l19, l20, l21 = sch.get_loops(block=b0)", "l22, l23 = sch.split(loop=l21, factors=[None, v5])", @@ -210,7 +210,7 @@ def test_gpu_batch_norm_bmn(): "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l4, l5 = sch.split(loop=l2, factors=[None, v3])", 'sch.bind(loop=l5, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l6, l7, l8, l9 = sch.get_loops(block=b0)", "l10 = sch.fuse(l8, l9)", diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py index c6a63aae7427..8882ed625bf1 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -55,7 +55,7 @@ def test_cpu_matmul(): "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', - "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)", + "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', @@ -69,7 +69,7 @@ def test_cpu_matmul(): "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', - "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)", + "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', @@ -116,7 +116,7 @@ def test_cpu_matmul_relu(): "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", "b24, = sch.get_consumers(block=b0)", - "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)", + "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', @@ -130,7 +130,7 @@ def test_cpu_matmul_relu(): "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", "b24, = sch.get_consumers(block=b0)", - "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)", + "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', @@ -186,15 +186,15 @@ def test_cuda_matmul(): 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)', 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)', 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', - "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)", + "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)", 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', - "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)", + "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)", "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", "l41 = sch.fuse(l39, l40)", "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)", + "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)", "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", "l50 = sch.fuse(l48, l49)", "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", @@ -240,15 +240,15 @@ def test_cuda_matmul_relu(): "l32 = sch.fuse(l11, l21)", 'sch.bind(loop=l32, thread_axis="threadIdx.x")', 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', - "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)", + "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True)", 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', - "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)", + "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True)", "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", "l41 = sch.fuse(l39, l40)", "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)", + "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True)", "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", "l50 = sch.fuse(l48, l49)", "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py index 92c7da922c39..18db006c6ca8 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py @@ -74,7 +74,7 @@ def test_random_compute_location(): [ 'b0 = sch.get_block(name="move", func_name="main")', "l1 = sch.sample_compute_location(block=b0)", - "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True)", ] ] mod = Add diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index f1c97c57b2ff..1923eb23af5b 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -82,6 +82,15 @@ def _make_compute_inline(input): # pylint: disable=redefined-builtin ) +def _make_split(inputs, outputs): # pylint: disable=redefined-builtin + return Instruction( + kind=InstructionKind.get("Split"), + inputs=inputs, + attrs=[], + outputs=outputs, + ) + + def _make_enter_postproc(): return Instruction( kind=InstructionKind.get("EnterPostproc"), @@ -129,6 +138,17 @@ def _make_trace_3(b0, b1, add_postproc): # pylint: disable=invalid-name return Trace(insts=insts, decisions={}) +def _make_trace_4(b0, l1, l2, l3): # pylint: disable=invalid-name + return Trace( + insts=[ + _make_get_block(name="B", output=b0), + _make_get_loops(input=b0, outputs=[l1]), + _make_split([l1, None, 32], [l2, l3]), + ], + decisions={}, + ) + + def test_trace_construct_1(): trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) assert str(trace) == "\n".join( @@ -235,6 +255,17 @@ def test_trace_simplified_2(): ) +def test_trace_simplified_3(): + trace = _make_trace_4(BlockRV(), LoopRV(), LoopRV(), LoopRV()).simplified(remove_postproc=False) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="B", func_name="main")', + "l1, = sch.get_loops(block=b0)", + "l2, l3 = sch.split(loop=l1, factors=[None, 32])", + ) + ) + + def test_apply_json_to_schedule_1(): trace = _make_trace_2(BlockRV()) json_obj = trace.as_json()