From 1a1784c927200f3c304dc9868740492d15166f3b Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 9 May 2023 08:36:52 -0500
Subject: [PATCH 1/2] [Util] Handle AllocateConst in MergeNest

---
 src/tir/transforms/ir_utils.cc | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 604dbed325ec..43bf6b983eb5 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -75,6 +75,11 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
       ICHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
+    } else if (const auto* alloc = s.as<AllocateConstNode>()) {
+      auto n = make_object<AllocateConstNode>(*alloc);
+      ICHECK(is_no_op(n->body));
+      n->body = body;
+      body = Stmt(n);
     } else if (const auto* decl_buffer = s.as<DeclBufferNode>()) {
       auto n = make_object<DeclBufferNode>(*decl_buffer);
       ICHECK(is_no_op(n->body));

From 1f2192787992d605bbf8cb678cd8da346a59d616 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 9 May 2023 08:37:48 -0500
Subject: [PATCH 2/2] [TIR] Handle DeclBuffer in Inline/ComputeAt/ReverseComputeAt

Part of the changes being split out from
https://github.com/apache/tvm/pull/14778 into independent portions.

This commit allows the TIR `compute_inline`, `compute_at`, and
`reverse_compute_at` schedule primitives to preserve `DeclBuffer` nodes.
---
 src/tir/schedule/transform.cc                 | 28 ++--
 .../unittest/test_tir_schedule_compute_at.py | 99 +++++++++++++++++++
 2 files changed, 117 insertions(+), 10 deletions(-)

diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index baa7f44bbcf2..9c209658c3e6 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -17,6 +17,7 @@
  * under the License.
  */
+#include "../transforms/ir_utils.h"
 #include "./utils.h"
 
 namespace tvm {
 namespace tir {
@@ -261,21 +262,28 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
   if (const auto* block = sref->StmtAs<BlockNode>()) {
     auto body = block->body;
     // Peel off AllocateConst nodes at the beginning of the block body.
-    std::vector<const AllocateConstNode*> allocs;
-    while (const auto* alloc = body.as<AllocateConstNode>()) {
-      allocs.push_back(alloc);
-      body = alloc->body;
+    std::vector<Stmt> allocs;
+    while (true) {
+      if (auto opt = body.as<AllocateConst>()) {
+        auto alloc = opt.value();
+        body = alloc->body;
+        alloc.CopyOnWrite()->body = Evaluate(0);
+        allocs.push_back(alloc);
+      } else if (auto opt = body.as<DeclBuffer>()) {
+        auto decl_buffer = opt.value();
+        body = decl_buffer->body;
+        decl_buffer.CopyOnWrite()->body = Evaluate(0);
+        allocs.push_back(decl_buffer);
+      } else {
+        break;
+      }
     }
+
     if (const auto* seq = body.as<SeqStmtNode>()) {
       ObjectPtr<BlockNode> n = make_object<BlockNode>(*block);
       auto new_seq = RemoveFromSeqStmt(GetRef<SeqStmt>(seq), GetRef<Stmt>(last_stmt));
       // Re-attach AllocateConst nodes
-      auto new_body = new_seq;
-      for (int i = 0; i < static_cast<int>(allocs.size()); ++i) {
-        auto alloc = allocs[allocs.size() - 1 - i];
-        new_body = AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data,
-                                 new_body, alloc->annotations, alloc->span);
-      }
+      auto new_body = MergeNest(allocs, new_seq);
       n->body = new_body;
       *src_stmt = GetRef<Stmt>(block);
       *tgt_stmt = Stmt(std::move(n));
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py
index 0623fb02f3d6..7efb4cccc0d0 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1672,5 +1672,104 @@ def after(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 5, 8),
     verify_trace_roundtrip(sch=sch, mod=before)
 
 
+@pytest.mark.parametrize("use_decl_buffer", [True, False])
+@pytest.mark.parametrize("use_reverse_compute_at", [True, False])
+def test_compute_at_allocate_const(use_decl_buffer, use_reverse_compute_at):
+    def apply_decl_buffer(*args, **kwargs):
+        if use_decl_buffer:
+            return T.decl_buffer(*args, **kwargs)
+        else:
+            return T.Buffer(*args, **kwargs)
+
+    @T.prim_func
+    def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
+        B = T.alloc_buffer([4])
+
+        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
+        offset = apply_decl_buffer([4], data=offset_ptr)
+        for i in range(4):
+            with T.block("compute_B"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = 10.0 * vi + offset[vi]
+
+        for i, j in T.grid(4, 256):
+            with T.block("compute_C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi] + 100.0 * vj
+
+    @T.prim_func
+    def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
+        B = T.alloc_buffer([4])
+
+        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
+        offset = apply_decl_buffer([4], data=offset_ptr)
+        for i in range(4):
+            with T.block("compute_B"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = 10.0 * vi + offset[vi]
+
+            for j in range(256):
+                with T.block("compute_C"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    C[vi, vj] = B[vi] + 100.0 * vj
+
+    sch = tir.Schedule(before, debug_mask="all")
+    if use_reverse_compute_at:
+        block = sch.get_block("compute_C")
+        axis = sch.get_loops("compute_B")[0]
+        sch.reverse_compute_at(block, axis)
+    else:
+        block = sch.get_block("compute_B")
+        axis = sch.get_loops("compute_C")[0]
+        sch.compute_at(block, axis)
+
+    after = sch.mod["main"]
+
+    tvm.ir.assert_structural_equal(expected, after)
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
+@pytest.mark.parametrize("use_decl_buffer", [True, False])
+def test_compute_inline_allocate_const(use_decl_buffer):
+    def apply_decl_buffer(*args, **kwargs):
+        if use_decl_buffer:
+            return T.decl_buffer(*args, **kwargs)
+        else:
+            return T.Buffer(*args, **kwargs)
+
+    @T.prim_func
+    def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
+        B = T.alloc_buffer([4])
+
+        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
+        offset = apply_decl_buffer([4], data=offset_ptr)
+        for i in range(4):
+            with T.block("compute_B"):
+                vi = T.axis.remap("S", [i])
+                B[vi] = 10.0 * vi + offset[vi]
+
+        for i, j in T.grid(4, 256):
+            with T.block("compute_C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi] + 100.0 * vj
+
+    @T.prim_func
+    def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
+        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
+        offset = apply_decl_buffer([4], data=offset_ptr)
+        for i, j in T.grid(4, 256):
+            with T.block("compute_C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = (10.0 * vi + offset[vi]) + 100.0 * vj
+
+    sch = tir.Schedule(before, debug_mask="all")
+    block = sch.get_block("compute_B")
+    sch.compute_inline(block)
+    after = sch.mod["main"]
+
+    tvm.ir.assert_structural_equal(expected, after)
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
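
Usage sketch of the behavior the second patch targets, assuming a TVM build with both patches applied. The PrimFunc and names below are illustrative, adapted from the new tests rather than taken from the patch itself; with the change, `compute_at` is expected to keep the `T.allocate_const`/`T.decl_buffer` wrapping of `offset` intact while rewriting the surrounding block:

    import tvm
    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def func(C: T.Buffer([4, 256], "float32")):
        B = T.alloc_buffer([4])
        # Constant offsets held by an AllocateConst, exposed through a DeclBuffer.
        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
        offset = T.decl_buffer([4], data=offset_ptr)
        for i in range(4):
            with T.block("compute_B"):
                vi = T.axis.remap("S", [i])
                B[vi] = 10.0 * vi + offset[vi]
        for i, j in T.grid(4, 256):
            with T.block("compute_C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi] + 100.0 * vj

    sch = tir.Schedule(func)
    # Move compute_B under the outer loop of compute_C; the DeclBuffer/AllocateConst
    # nodes should survive the block rewrite instead of being dropped.
    sch.compute_at(sch.get_block("compute_B"), sch.get_loops(sch.get_block("compute_C"))[0])
    print(sch.mod["main"].script())

Printing the module after the transform should still show the `allocate_const` and `decl_buffer` wrapping the fused loop nest, which is what `test_compute_at_allocate_const` above asserts structurally.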