Skip to content

Commit 2c28efd

Browse files
quic-sanirudh authored and yongwww committed
[TIR] [Bugfix] Pass the correct block_sref_reuse to Replace (apache#14023)
* [TIR] [Bugfix] Pass the correct block_sref_reuse to Replace. A mismatch between the blocks present in `result` and the blocks passed in `block_sref_reuse` caused the bug reported in apache#13974. This patch fixes that bug by collecting only the blocks that are part of `result` and are also present in the block replacement map `new_block_to_old_`. Since the scope block is `result`, only that block and its child blocks are replaced, and any replaced block is present in `rewriter.new_block_to_old_`. Thus, collecting the replaced blocks from among the child blocks of `result` guarantees that `block_sref_reuse` contains all the replaced blocks and that they point to the correct blocks in `result`, which avoids the missing SRef error.
1 parent befd956 commit 2c28efd

File tree

2 files changed

+89
-11
lines changed

2 files changed

+89
-11
lines changed

src/tir/schedule/primitive/layout_transformation.cc

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,42 @@ class TransformLayoutPlanner : private StmtExprVisitor {
704704
Buffer old_buffer_;
705705
};
706706

707+
/*!
708+
* \brief Collect blocks that are part of root block to be passed to ScheduleState::Replace for SRef
709+
* reuse
710+
*/
711+
class ReuseBlocksCollector : public tir::StmtVisitor {
712+
public:
713+
static Map<Block, Block> Collect(Block result, Map<Block, Block> new_block_to_old) {
714+
return ReuseBlocksCollector(new_block_to_old).Run(result);
715+
}
716+
717+
private:
718+
/*! \brief Entry point */
719+
Map<Block, Block> Run(const Block result) {
720+
VisitStmt(result);
721+
return block_sref_reuse_;
722+
}
723+
/*! \brief Constructor */
724+
explicit ReuseBlocksCollector(Map<Block, Block> new_block_to_old)
725+
: new_block_to_old_(new_block_to_old) {}
726+
727+
/*! \brief Override the Stmt visiting behaviour */
728+
void VisitStmt_(const tir::BlockNode* block) override {
729+
Block block_ref = GetRef<Block>(block);
730+
auto it = new_block_to_old_.find(block_ref);
731+
if (it != new_block_to_old_.end()) {
732+
block_sref_reuse_.Set((*it).second, (*it).first);
733+
}
734+
StmtVisitor::VisitStmt_(block);
735+
}
736+
737+
/*! \brief New map to be filled with just blocks from scope block */
738+
Map<Block, Block> block_sref_reuse_;
739+
/*! \brief All block replacements collected so far */
740+
Map<Block, Block> new_block_to_old_;
741+
};
742+
707743
class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
708744
public:
709745
/*!
@@ -730,17 +766,8 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
730766
write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body});
731767
}
732768

733-
Map<Block, Block> block_sref_reuse;
734-
for (auto [after, before] : rewriter.new_block_to_old_) {
735-
while (auto opt = rewriter.new_block_to_old_.Get(before)) {
736-
before = opt.value();
737-
}
738-
while (auto opt = block_sref_reuse.Get(after)) {
739-
after = opt.value();
740-
}
741-
742-
block_sref_reuse.Set(before, after);
743-
}
769+
Map<Block, Block> block_sref_reuse =
770+
ReuseBlocksCollector::Collect(result, rewriter.new_block_to_old_);
744771

745772
return {result, block_sref_reuse};
746773
}

tests/python/unittest/test_tir_schedule_transform_layout.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,57 @@ def two_elementwise_unit_dim(A: T.Buffer((1, 128), "float32"), C: T.Buffer((1, 1
173173
vi, vj = T.axis.remap("SS", [i, j])
174174
C[vi, vj] = B[vi, vj] + 1.0
175175

176+
class TestTransformLayoutWithCacheWriteAndAxisSeparators(tvm.testing.CompareBeforeAfter):
    """
    Applying transform_layout with an axis separator to the write buffer of a block,
    after cache_write has been applied to that block, should yield the expected module.
    """

    @pytest.fixture
    def transform(self):
        def index_map(x, y):
            # Split the first axis by 32; the remainder goes after the axis separator.
            return [x // 32, y, tvm.te.AXIS_SEPARATOR, x % 32]

        def apply_schedule(mod):
            schedule = tvm.tir.Schedule(mod, debug_mask="all")
            t_add = schedule.get_block("T_add")
            schedule.cache_write(t_add, 0, "global")
            schedule.transform_layout(t_add, ("write", 0), index_map, pad_value=0.0)
            return schedule.mod

        return apply_schedule

    def before(
        p0: T.Buffer((T.int64(33), T.int64(128)), "float32"),
        p1: T.Buffer((T.int64(33), T.int64(128)), "float32"),
        T_add: T.Buffer((T.int64(33), T.int64(128)), "float32"),
    ):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(33), T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p0[v_ax0, v_ax1], p1[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = p0[v_ax0, v_ax1] + p1[v_ax0, v_ax1]

    def expected(
        p0: T.Buffer((T.int64(33), T.int64(128)), "float32"),
        p1: T.Buffer((T.int64(33), T.int64(128)), "float32"),
        T_add: T.Buffer((T.int64(33), T.int64(128)), "float32"),
    ):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        T_add_global = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(32)), axis_separators=[2])
        for axis0, axis1, axis2 in T.grid(T.int64(2), T.int64(128), T.int64(32)):
            with T.block("T_add"):
                v_axis0, v_axis1, v_axis2 = T.axis.remap("SSS", [axis0, axis1, axis2])
                T.reads(p0[v_axis0 * T.int64(32) + v_axis2, v_axis1], p1[v_axis0 * T.int64(32) + v_axis2, v_axis1])
                T.writes(T_add_global[v_axis0, v_axis1, v_axis2])
                T_add_global[v_axis0, v_axis1, v_axis2] = T.if_then_else(v_axis0 == T.int64(1) and T.int64(1) <= v_axis2, T.float32(0), p0[v_axis0 * T.int64(32) + v_axis2, v_axis1] + p1[v_axis0 * T.int64(32) + v_axis2, v_axis1])
        for ax0, ax1 in T.grid(T.int64(33), T.int64(128)):
            with T.block("T_add_global"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)])
                T.writes(T_add[v0, v1])
                T_add[v0, v1] = T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)]
226+
176227
# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
177228
# fmt: on
178229

0 commit comments

Comments
 (0)