diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index cf139c7df789..0bdf4aff7940 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -20,6 +20,7 @@
 #include <unordered_map>
 
 #include "../../analysis/var_use_def_analysis.h"
+#include "../../transforms/ir_utils.h"
 #include "../utils.h"
 
 namespace tvm {
@@ -425,21 +426,43 @@ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref)
  * \return A SeqStmt, the result after insertion
  */
 Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
-  if (const auto* alloc = stmt.as<AllocateConstNode>()) {
-    auto seq_stmt = InsertCacheStage(alloc->body, pos, stage);
-    return AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data, seq_stmt,
-                         alloc->annotations, alloc->span);
-  }
-  if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
-    ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt);
-    result->seq.insert(result->seq.begin() + pos, stage);
-    return SeqStmt(result);
+  std::vector<Stmt> nest;
+  Stmt body = stmt;
+  while (true) {
+    if (auto opt = body.as<AllocateConst>()) {
+      auto alloc = opt.value();
+      body = alloc->body;
+      alloc.CopyOnWrite()->body = Evaluate(0);
+      nest.push_back(alloc);
+    } else if (auto opt = body.as<DeclBuffer>()) {
+      auto decl_buffer = opt.value();
+      body = decl_buffer->body;
+      decl_buffer.CopyOnWrite()->body = Evaluate(0);
+      nest.push_back(decl_buffer);
+    } else {
+      break;
+    }
   }
-  if (pos == 0) {
-    return SeqStmt({stage, stmt});
+
+  if (const auto* seq_stmt = body.as<SeqStmtNode>()) {
+    Array<Stmt> seq = seq_stmt->seq;
+    ICHECK_LE(pos, seq.size()) << "Cannot insert at position " << pos
+                               << " into sequence of length " << seq.size();
+    seq.insert(seq.begin() + pos, stage);
+    body = SeqStmt(seq);
+  } else if (pos == 0) {
+    body = SeqStmt({stage, stmt});
+  } else if (pos == 1) {
+    body = SeqStmt({stmt, stage});
+  } else {
+    LOG(FATAL) << "Cannot insert at position " << pos
+               << ". When inserting adjacent to non-SeqStmt, "
+               << "only positions 0 and 1 are valid.";
   }
-  ICHECK_EQ(pos, 1);
-  return SeqStmt({stmt, stage});
+
+  body = MergeNest(nest, body);
+
+  return body;
 }
 
 /*!
@@ -550,8 +573,14 @@ class CacheLocDetector : public StmtVisitor {
       auto block_body = scope_sref->StmtAs<BlockNode>()->body;
       // Find the SeqStmtNode within (potentially nested) AllocateConstNodes
-      while (block_body->IsInstance<AllocateConstNode>()) {
-        block_body = block_body.as<AllocateConstNode>()->body;
+      while (true) {
+        if (auto* ptr = block_body.as<AllocateConstNode>()) {
+          block_body = ptr->body;
+        } else if (auto* ptr = block_body.as<DeclBufferNode>()) {
+          block_body = ptr->body;
+        } else {
+          break;
+        }
       }
       const auto* body = block_body.as<SeqStmtNode>();
       info->loc_pos = body == nullptr ? 1 : body->size();
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 604dbed325ec..43bf6b983eb5 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -75,6 +75,11 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
       ICHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
+    } else if (const auto* alloc = s.as<AllocateConstNode>()) {
+      auto n = make_object<AllocateConstNode>(*alloc);
+      ICHECK(is_no_op(n->body));
+      n->body = body;
+      body = Stmt(n);
     } else if (const auto* decl_buffer = s.as<DeclBufferNode>()) {
       auto n = make_object<DeclBufferNode>(*decl_buffer);
       ICHECK(is_no_op(n->body));
diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py
index 454557a2bde7..95955646c64b 100644
--- a/tests/python/unittest/test_tir_schedule_cache_read_write.py
+++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py
@@ -1172,67 +1172,6 @@ def block_predicate_cache_write_output_buf() -> None:
 use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
 
 
-@T.prim_func
-def cache_write_allocate_const(
-    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")
-):
-    B = T.alloc_buffer([128, 128], dtype="float32")
-    const = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
-    const_1 = T.Buffer([8], dtype="float32", data=const)
-    const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
-    const_2 = T.Buffer([8], dtype="float32", data=const)
-    for i, j in T.grid(128, 128):
-        for x in range(8):
-            with T.block("B"):
-                vi, vj, vx = T.axis.remap("SSS", [i, j, x])
-                T.reads(A[vi, vj], const_1[vx], const_2[vx])
-                T.writes(B[vi, vj])
-                B[vi, vj] = A[vi, vj] * const_1[vx] + const_2[vx]
-    for i, j in T.grid(128, 128):
-        with T.block("C"):
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.reads(B[vi, vj])
-            T.writes(C[vi, vj])
-            C[vi, vj] = B[vi, vj] + 1.0
-
-
-@T.prim_func
-def cache_write_allocate_const_output(
-    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")
-):
-    B = T.alloc_buffer([128, 128], dtype="float32")
-    A_global = T.alloc_buffer([128, 128], dtype="float32")
-    C_global = T.alloc_buffer([128, 128], dtype="float16")
-    const_2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
-    const_1 = T.Buffer([8], dtype="float32", data=const_2)
-    const_2_1 = T.Buffer([8], dtype="float32", data=const_2)
-    const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
-    for ax0, ax1 in T.grid(128, 128):
-        with T.block("A_global"):
-            v0, v1 = T.axis.remap("SS", [ax0, ax1])
-            T.reads(A[v0, v1])
-            T.writes(A_global[v0, v1])
-            A_global[v0, v1] = A[v0, v1]
-    for i, j, x in T.grid(128, 128, 8):
-        with T.block("B"):
-            vi, vj, vx = T.axis.remap("SSS", [i, j, x])
-            T.reads(A_global[vi, vj], const_1[vx], const_2_1[vx])
-            T.writes(B[vi, vj])
-            B[vi, vj] = A_global[vi, vj] * const_1[vx] + const_2_1[vx]
-    for i, j in T.grid(128, 128):
-        with T.block("C"):
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.reads(B[vi, vj])
-            T.writes(C_global[vi, vj])
-            C_global[vi, vj] = B[vi, vj] + T.float32(1)
-    for ax0, ax1 in T.grid(128, 128):
-        with T.block("C_global"):
-            v0, v1 = T.axis.remap("SS", [ax0, ax1])
-            T.reads(C_global[v0, v1])
-            T.writes(C[v0, v1])
-            C[v0, v1] = C_global[v0, v1]
-
-
 def test_cache_read_elementwise(use_block_name):
     sch = tir.Schedule(elementwise, debug_mask="all")
     block_b = sch.get_block("B")
@@ -1493,14 +1432,79 @@ def test_cache_write_fail_invalid_storage_scope(use_block_name):
         sch.cache_write(block_b, 0, "test_scope")
 
 
-def test_cache_write_allocate_const():
-    sch = tir.Schedule(cache_write_allocate_const)
+@pytest.mark.parametrize("use_decl_buffer", [True, False])
+def test_cache_write_allocate_const(use_decl_buffer):
+    def apply_decl_buffer(*args, **kwargs):
+        if use_decl_buffer:
+            return T.decl_buffer(*args, **kwargs)
+        else:
+            return T.Buffer(*args, **kwargs)
+
+    @T.prim_func
+    def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")):
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        const1 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+        const1_buf = apply_decl_buffer([8], dtype="float32", data=const1)
+        const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+        const2_buf = apply_decl_buffer([8], dtype="float32", data=const2)
+        for i, j in T.grid(128, 128):
+            for x in range(8):
+                with T.block("B"):
+                    vi, vj, vx = T.axis.remap("SSS", [i, j, x])
+                    T.reads(A[vi, vj], const1_buf[vx], const2_buf[vx])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] * const1_buf[vx] + const2_buf[vx]
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(B[vi, vj])
+                T.writes(C[vi, vj])
+                C[vi, vj] = B[vi, vj] + 1.0
+
+    @T.prim_func
+    def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")):
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        A_global = T.alloc_buffer([128, 128], dtype="float32")
+        C_global = T.alloc_buffer([128, 128], dtype="float16")
+        const1 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+        const1_buf = apply_decl_buffer([8], dtype="float32", data=const1)
+        const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
+        const2_buf = apply_decl_buffer([8], dtype="float32", data=const2)
+        for ax0, ax1 in T.grid(128, 128):
+            with T.block("A_global"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(A[v0, v1])
+                T.writes(A_global[v0, v1])
+                A_global[v0, v1] = A[v0, v1]
+        for i, j, x in T.grid(128, 128, 8):
+            with T.block("B"):
+                vi, vj, vx = T.axis.remap("SSS", [i, j, x])
+                T.reads(A_global[vi, vj], const1_buf[vx], const2_buf[vx])
+                T.writes(B[vi, vj])
+                B[vi, vj] = A_global[vi, vj] * const1_buf[vx] + const2_buf[vx]
+        for i, j in T.grid(128, 128):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                T.reads(B[vi, vj])
+                T.writes(C_global[vi, vj])
+                C_global[vi, vj] = B[vi, vj] + T.float32(1)
+        for ax0, ax1 in T.grid(128, 128):
+            with T.block("C_global"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(C_global[v0, v1])
+                T.writes(C[v0, v1])
+                C[v0, v1] = C_global[v0, v1]
+
+    sch = tir.Schedule(before)
     block_b = sch.get_block("B")
     block_c = sch.get_block("C")
     sch.cache_read(block_b, 0, "global")
     sch.cache_write(block_c, 0, "global")
-    tvm.ir.assert_structural_equal(cache_write_allocate_const_output, sch.mod["main"])
-    verify_trace_roundtrip(sch=sch, mod=cache_write_allocate_const)
+
+    after = sch.mod["main"]
+
+    tvm.ir.assert_structural_equal(expected, after)
+    verify_trace_roundtrip(sch=sch, mod=before)
 
 
 def test_reindex_cache_read():