Skip to content

Commit 668ea13

Browse files
committed
[TIR] Handle DeclBuffer in CacheReadWrite schedule primitive
Part of changes being split out from apache#14778 into independent portions. This commit allows TIR `cache_read` and `cache_write` schedule primitives to preserve `DeclBuffer` nodes.
1 parent 1a1784c commit 668ea13

File tree

2 files changed

+113
-80
lines changed

2 files changed

+113
-80
lines changed

src/tir/schedule/primitive/cache_read_write.cc

Lines changed: 44 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
#include <unordered_set>
2121

2222
#include "../../analysis/var_use_def_analysis.h"
23+
#include "../../transforms/ir_utils.h"
2324
#include "../utils.h"
2425

2526
namespace tvm {
@@ -425,21 +426,43 @@ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref)
425426
* \return A SeqStmt, the result after insertion
426427
*/
427428
Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
428-
if (const auto* alloc = stmt.as<AllocateConstNode>()) {
429-
auto seq_stmt = InsertCacheStage(alloc->body, pos, stage);
430-
return AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data, seq_stmt,
431-
alloc->annotations, alloc->span);
432-
}
433-
if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
434-
ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt);
435-
result->seq.insert(result->seq.begin() + pos, stage);
436-
return SeqStmt(result);
429+
std::vector<Stmt> nest;
430+
Stmt body = stmt;
431+
while (true) {
432+
if (auto opt = body.as<AllocateConst>()) {
433+
auto alloc = opt.value();
434+
body = alloc->body;
435+
alloc.CopyOnWrite()->body = Evaluate(0);
436+
nest.push_back(alloc);
437+
} else if (auto opt = body.as<DeclBuffer>()) {
438+
auto decl_buffer = opt.value();
439+
body = decl_buffer->body;
440+
decl_buffer.CopyOnWrite()->body = Evaluate(0);
441+
nest.push_back(decl_buffer);
442+
} else {
443+
break;
444+
}
437445
}
438-
if (pos == 0) {
439-
return SeqStmt({stage, stmt});
446+
447+
if (const auto* seq_stmt = body.as<SeqStmtNode>()) {
448+
Array<Stmt> seq = seq_stmt->seq;
449+
ICHECK_LE(pos, seq.size()) << "Cannot insert at position " << pos << " into sequence of length "
450+
<< seq.size();
451+
seq.insert(seq.begin() + pos, stage);
452+
body = SeqStmt(seq);
453+
} else if (pos == 0) {
454+
body = SeqStmt({stage, stmt});
455+
} else if (pos == 1) {
456+
body = SeqStmt({stmt, stage});
457+
} else {
458+
LOG(FATAL) << "Cannot insert at position " << pos
459+
<< ". When inserting adjacent to non-SeqStmt, "
460+
<< "only positions 0 and 1 are valid.";
440461
}
441-
ICHECK_EQ(pos, 1);
442-
return SeqStmt({stmt, stage});
462+
463+
body = MergeNest(nest, body);
464+
465+
return body;
443466
}
444467

445468
/*!
@@ -550,8 +573,14 @@ class CacheLocDetector : public StmtVisitor {
550573

551574
auto block_body = scope_sref->StmtAs<BlockNode>()->body;
552575
// Find the SeqStmtNode within (potentially nested) AllocateConstNodes and DeclBufferNodes
553-
while (block_body->IsInstance<AllocateConstNode>()) {
554-
block_body = block_body.as<AllocateConstNode>()->body;
576+
while (true) {
577+
if (auto* ptr = block_body.as<AllocateConstNode>()) {
578+
block_body = ptr->body;
579+
} else if (auto* ptr = block_body.as<DeclBufferNode>()) {
580+
block_body = ptr->body;
581+
} else {
582+
break;
583+
}
555584
}
556585
const auto* body = block_body.as<SeqStmtNode>();
557586
info->loc_pos = body == nullptr ? 1 : body->size();

tests/python/unittest/test_tir_schedule_cache_read_write.py

Lines changed: 69 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -1172,67 +1172,6 @@ def block_predicate_cache_write_output_buf() -> None:
11721172
use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
11731173

11741174

1175-
@T.prim_func
1176-
def cache_write_allocate_const(
1177-
A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")
1178-
):
1179-
B = T.alloc_buffer([128, 128], dtype="float32")
1180-
const = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1181-
const_1 = T.Buffer([8], dtype="float32", data=const)
1182-
const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1183-
const_2 = T.Buffer([8], dtype="float32", data=const)
1184-
for i, j in T.grid(128, 128):
1185-
for x in range(8):
1186-
with T.block("B"):
1187-
vi, vj, vx = T.axis.remap("SSS", [i, j, x])
1188-
T.reads(A[vi, vj], const_1[vx], const_2[vx])
1189-
T.writes(B[vi, vj])
1190-
B[vi, vj] = A[vi, vj] * const_1[vx] + const_2[vx]
1191-
for i, j in T.grid(128, 128):
1192-
with T.block("C"):
1193-
vi, vj = T.axis.remap("SS", [i, j])
1194-
T.reads(B[vi, vj])
1195-
T.writes(C[vi, vj])
1196-
C[vi, vj] = B[vi, vj] + 1.0
1197-
1198-
1199-
@T.prim_func
1200-
def cache_write_allocate_const_output(
1201-
A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")
1202-
):
1203-
B = T.alloc_buffer([128, 128], dtype="float32")
1204-
A_global = T.alloc_buffer([128, 128], dtype="float32")
1205-
C_global = T.alloc_buffer([128, 128], dtype="float16")
1206-
const_2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1207-
const_1 = T.Buffer([8], dtype="float32", data=const_2)
1208-
const_2_1 = T.Buffer([8], dtype="float32", data=const_2)
1209-
const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1210-
for ax0, ax1 in T.grid(128, 128):
1211-
with T.block("A_global"):
1212-
v0, v1 = T.axis.remap("SS", [ax0, ax1])
1213-
T.reads(A[v0, v1])
1214-
T.writes(A_global[v0, v1])
1215-
A_global[v0, v1] = A[v0, v1]
1216-
for i, j, x in T.grid(128, 128, 8):
1217-
with T.block("B"):
1218-
vi, vj, vx = T.axis.remap("SSS", [i, j, x])
1219-
T.reads(A_global[vi, vj], const_1[vx], const_2_1[vx])
1220-
T.writes(B[vi, vj])
1221-
B[vi, vj] = A_global[vi, vj] * const_1[vx] + const_2_1[vx]
1222-
for i, j in T.grid(128, 128):
1223-
with T.block("C"):
1224-
vi, vj = T.axis.remap("SS", [i, j])
1225-
T.reads(B[vi, vj])
1226-
T.writes(C_global[vi, vj])
1227-
C_global[vi, vj] = B[vi, vj] + T.float32(1)
1228-
for ax0, ax1 in T.grid(128, 128):
1229-
with T.block("C_global"):
1230-
v0, v1 = T.axis.remap("SS", [ax0, ax1])
1231-
T.reads(C_global[v0, v1])
1232-
T.writes(C[v0, v1])
1233-
C[v0, v1] = C_global[v0, v1]
1234-
1235-
12361175
def test_cache_read_elementwise(use_block_name):
12371176
sch = tir.Schedule(elementwise, debug_mask="all")
12381177
block_b = sch.get_block("B")
@@ -1493,14 +1432,79 @@ def test_cache_write_fail_invalid_storage_scope(use_block_name):
14931432
sch.cache_write(block_b, 0, "test_scope")
14941433

14951434

1496-
def test_cache_write_allocate_const():
1497-
sch = tir.Schedule(cache_write_allocate_const)
1435+
@pytest.mark.parametrize("use_decl_buffer", [True, False])
1436+
def test_cache_write_allocate_const(use_decl_buffer):
1437+
def apply_decl_buffer(*args, **kwargs):
1438+
if use_decl_buffer:
1439+
return T.decl_buffer(*args, **kwargs)
1440+
else:
1441+
return T.Buffer(*args, **kwargs)
1442+
1443+
@T.prim_func
1444+
def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")):
1445+
B = T.alloc_buffer([128, 128], dtype="float32")
1446+
const1 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1447+
const1_buf = apply_decl_buffer([8], dtype="float32", data=const1)
1448+
const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1449+
const2_buf = apply_decl_buffer([8], dtype="float32", data=const2)
1450+
for i, j in T.grid(128, 128):
1451+
for x in range(8):
1452+
with T.block("B"):
1453+
vi, vj, vx = T.axis.remap("SSS", [i, j, x])
1454+
T.reads(A[vi, vj], const1_buf[vx], const2_buf[vx])
1455+
T.writes(B[vi, vj])
1456+
B[vi, vj] = A[vi, vj] * const1_buf[vx] + const2_buf[vx]
1457+
for i, j in T.grid(128, 128):
1458+
with T.block("C"):
1459+
vi, vj = T.axis.remap("SS", [i, j])
1460+
T.reads(B[vi, vj])
1461+
T.writes(C[vi, vj])
1462+
C[vi, vj] = B[vi, vj] + 1.0
1463+
1464+
@T.prim_func
1465+
def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")):
1466+
B = T.alloc_buffer([128, 128], dtype="float32")
1467+
A_global = T.alloc_buffer([128, 128], dtype="float32")
1468+
C_global = T.alloc_buffer([128, 128], dtype="float16")
1469+
const1 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1470+
const1_buf = apply_decl_buffer([8], dtype="float32", data=const1)
1471+
const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8])
1472+
const2_buf = apply_decl_buffer([8], dtype="float32", data=const2)
1473+
for ax0, ax1 in T.grid(128, 128):
1474+
with T.block("A_global"):
1475+
v0, v1 = T.axis.remap("SS", [ax0, ax1])
1476+
T.reads(A[v0, v1])
1477+
T.writes(A_global[v0, v1])
1478+
A_global[v0, v1] = A[v0, v1]
1479+
for i, j, x in T.grid(128, 128, 8):
1480+
with T.block("B"):
1481+
vi, vj, vx = T.axis.remap("SSS", [i, j, x])
1482+
T.reads(A_global[vi, vj], const1_buf[vx], const2_buf[vx])
1483+
T.writes(B[vi, vj])
1484+
B[vi, vj] = A_global[vi, vj] * const1_buf[vx] + const2_buf[vx]
1485+
for i, j in T.grid(128, 128):
1486+
with T.block("C"):
1487+
vi, vj = T.axis.remap("SS", [i, j])
1488+
T.reads(B[vi, vj])
1489+
T.writes(C_global[vi, vj])
1490+
C_global[vi, vj] = B[vi, vj] + T.float32(1)
1491+
for ax0, ax1 in T.grid(128, 128):
1492+
with T.block("C_global"):
1493+
v0, v1 = T.axis.remap("SS", [ax0, ax1])
1494+
T.reads(C_global[v0, v1])
1495+
T.writes(C[v0, v1])
1496+
C[v0, v1] = C_global[v0, v1]
1497+
1498+
sch = tir.Schedule(before)
14981499
block_b = sch.get_block("B")
14991500
block_c = sch.get_block("C")
15001501
sch.cache_read(block_b, 0, "global")
15011502
sch.cache_write(block_c, 0, "global")
1502-
tvm.ir.assert_structural_equal(cache_write_allocate_const_output, sch.mod["main"])
1503-
verify_trace_roundtrip(sch=sch, mod=cache_write_allocate_const)
1503+
1504+
after = sch.mod["main"]
1505+
1506+
tvm.ir.assert_structural_equal(expected, after)
1507+
verify_trace_roundtrip(sch=sch, mod=before)
15041508

15051509

15061510
def test_reindex_cache_read():

0 commit comments

Comments
 (0)