Skip to content

Commit a61c1ad

Browse files
[TIR] Fix plan buffer allocation location for loop carried dependencies (#12757)
* Fix plan buffer allocation location for loop carried dependencies * fix testcase region annotation issue * fix typo in ut
1 parent 71f25b3 commit a61c1ad

File tree

2 files changed

+200
-15
lines changed

2 files changed

+200
-15
lines changed

src/tir/analysis/buffer_access_lca_detector.cc

Lines changed: 96 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,30 +99,113 @@ class LCADetector : public StmtExprVisitor {
9999
}
100100

101101
ancestor_scopes_.push_back(current_scope);
102+
loop_scope_map_.insert({op->loop_var.get(), current_scope});
102103
StmtExprVisitor::VisitStmt_(op);
103104
ancestor_scopes_.pop_back();
105+
loop_scope_map_.erase(op->loop_var.get());
104106
}
105107

106-
void VisitStmt_(const BlockNode* op) final {
108+
void VisitStmt_(const BlockRealizeNode* op) final {
109+
const BlockNode* block = op->block.get();
107110
int n = ancestor_scopes_.size();
108-
for (const Buffer& buf : op->alloc_buffers) {
111+
for (const Buffer& buf : block->alloc_buffers) {
109112
buffer_var_map_.emplace(buf->data.get(), buf.get());
110113
}
111114

112115
const ScopeInfo* parent_scope = ancestor_scopes_.back();
113-
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
116+
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, block, n);
114117

115118
ancestor_scopes_.push_back(current_scope);
119+
120+
// For each accessed buffer of the block, update the buffer's lca to
121+
// the lowest inclusive stmt position, which should dominate all loops
122+
// related to the accessed opaque block iter vars in buffer indices.
123+
UpdateDominateScopeOfOpaqueIter(op);
124+
116125
// Update match_buffers
117-
for (const MatchBufferRegion& match_buffer : op->match_buffers) {
118-
UpdateBufferLCA(match_buffer->source->buffer.get());
126+
for (const MatchBufferRegion& match_buffer : block->match_buffers) {
127+
UpdateBufferLCA(match_buffer->source->buffer.get(), ancestor_scopes_.back());
119128
match_buffers_.insert(match_buffer->buffer.get());
120129
}
121130

122131
StmtExprVisitor::VisitStmt_(op);
123132
ancestor_scopes_.pop_back();
124133
}
125134

135+
void UpdateDominateScopeOfOpaqueIter(const BlockRealizeNode* block_realize) {
136+
// Map each opaque iter var to the scope that dominates all of its loop-carried dependencies.
137+
std::unordered_map<const VarNode*, const ScopeInfo*> itervar_to_dom_scope;
138+
139+
// function to collect `itervar_to_dom_scope`, the result scope for each block
140+
// iter var should be above all loop scopes the opaque iter var binding relates to.
141+
auto do_collect_itervar_scope = [this, &itervar_to_dom_scope](const IterVar& itervar,
142+
const PrimExpr& binding) {
143+
PostOrderVisit(binding, [this, &itervar_to_dom_scope, &itervar](const ObjectRef& obj) {
144+
if (const VarNode* loop_var = obj.as<VarNode>()) {
145+
auto it = loop_scope_map_.find(loop_var);
146+
if (it == loop_scope_map_.end()) {
147+
return;
148+
}
149+
const ScopeInfo* scope = it->second->parent_scope_info;
150+
// Find the highest loop scope that the iter var binding relates to.
151+
auto dom_scope_it = itervar_to_dom_scope.find(itervar->var.get());
152+
if (dom_scope_it == itervar_to_dom_scope.end()) {
153+
itervar_to_dom_scope.insert(dom_scope_it, {itervar->var.get(), scope});
154+
} else if (scope->depth < dom_scope_it->second->depth) {
155+
dom_scope_it->second = scope;
156+
}
157+
}
158+
});
159+
};
160+
161+
// function to update lca scope of the buffer with loop carried dependent buffer accesses.
162+
// the result scope should be above all loop scopes the accessed opaque block iter vars
163+
// relate to, which is recorded in `itervar_to_dom_scope`.
164+
auto do_update = [this, &itervar_to_dom_scope](const BufferRegion& region) {
165+
const Buffer& buffer = region->buffer;
166+
const ScopeInfo* scope = ancestor_scopes_.back();
167+
168+
auto handle_itervar = [&itervar_to_dom_scope, &scope](const ObjectRef& obj) {
169+
if (const VarNode* iter_var = obj.as<VarNode>()) {
170+
auto dom_scope_it = itervar_to_dom_scope.find(iter_var);
171+
if (dom_scope_it == itervar_to_dom_scope.end()) {
172+
return;
173+
}
174+
// find the highest loop scope the accessed buffer index has
175+
// loop carried dependencies to (via opaque iter var binding).
176+
if (dom_scope_it->second->depth < scope->depth) {
177+
scope = dom_scope_it->second;
178+
}
179+
}
180+
};
181+
182+
// visit region min and max to find the lowest legal lca scope
183+
for (const Range& range : region->region) {
184+
PostOrderVisit(range->min, handle_itervar);
185+
PostOrderVisit(range->min + range->extent - 1, handle_itervar);
186+
}
187+
UpdateBufferLCA(buffer.get(), scope);
188+
};
189+
190+
// do collect and update
191+
const Block& block = block_realize->block;
192+
for (size_t i = 0; i < block_realize->iter_values.size(); ++i) {
193+
const IterVar& iter_var = block->iter_vars[i];
194+
if (iter_var->iter_type != IterVarType::kDataPar &&
195+
iter_var->iter_type != IterVarType::kCommReduce) {
196+
do_collect_itervar_scope(iter_var, block_realize->iter_values[i]);
197+
}
198+
}
199+
if (!itervar_to_dom_scope.empty()) {
200+
for (const auto& read : block->reads) {
201+
do_update(read);
202+
}
203+
for (const auto& write : block->writes) {
204+
do_update(write);
205+
}
206+
}
207+
}
208+
126209
void VisitStmt_(const AttrStmtNode* op) final {
127210
if (op->attr_key == attr::thread_extent) {
128211
const auto* iter = op->node.as<IterVarNode>();
@@ -136,17 +219,18 @@ class LCADetector : public StmtExprVisitor {
136219
}
137220

138221
void VisitExpr_(const BufferLoadNode* op) final {
139-
UpdateBufferLCA(op->buffer.get());
222+
UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
140223
StmtExprVisitor::VisitExpr_(op);
141224
}
142225

143226
void VisitStmt_(const BufferStoreNode* op) final {
144-
UpdateBufferLCA(op->buffer.get());
227+
UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
145228
StmtExprVisitor::VisitStmt_(op);
146229
}
147230

148231
void VisitStmt_(const BufferRealizeNode* op) final {
149232
buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get());
233+
UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
150234
StmtExprVisitor::VisitStmt_(op);
151235
}
152236

@@ -165,16 +249,16 @@ class LCADetector : public StmtExprVisitor {
165249
void VisitBufferVar(const VarNode* op) {
166250
auto it = buffer_var_map_.find(op);
167251
if (it != buffer_var_map_.end()) {
168-
UpdateBufferLCA(it->second);
252+
UpdateBufferLCA(it->second, ancestor_scopes_.back());
169253
}
170254
}
171255

172-
void UpdateBufferLCA(const BufferNode* buffer) {
256+
void UpdateBufferLCA(const BufferNode* buffer, const ScopeInfo* scope) {
173257
buffer_var_map_.emplace(buffer->data.get(), buffer);
174258
if (match_buffers_.find(buffer) == match_buffers_.end()) {
175259
// Ignore buffers created by block match_buffer
176260
const ScopeInfo*& lca = buffer_lca_[buffer];
177-
lca = LowestCommonAncestor(lca, ancestor_scopes_.back());
261+
lca = LowestCommonAncestor(lca, scope);
178262
}
179263
}
180264

@@ -229,6 +313,8 @@ class LCADetector : public StmtExprVisitor {
229313
std::unordered_set<const BufferNode*> match_buffers_ = {};
230314
/*! \brief The ForNodes/BlockNodes which contain immediate `blockIdx` launch. */
231315
std::vector<const ScopeInfo*> blockidx_scopes_ = {};
316+
/*! \brief The map from loop var to the corresponding scope. */
317+
std::unordered_map<const VarNode*, const ScopeInfo*> loop_scope_map_ = {};
232318
/*! \brief Internal arena. */
233319
support::Arena arena_;
234320
};

tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py

Lines changed: 104 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import tvm
18+
import tvm.testing
1819
from tvm import te
1920
from tvm.script import tir as T
2021

@@ -242,9 +243,107 @@ def test_lower_te():
242243
) # PlanAndUpdateBufferAllocationLocation should do nothing on TE
243244

244245

246+
def test_loop_carried_dependency():
247+
"""The buffer allocation should be above opaque iter var's loop scopes
248+
such that buffer accesses with loop carried dependencies are covered."""
249+
250+
@T.prim_func
251+
def before(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]):
252+
C = T.alloc_buffer([8, 8, 8], dtype="int32")
253+
for i in T.serial(8):
254+
for j in T.serial(8):
255+
for k in T.serial(8):
256+
with T.block("b0"):
257+
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
258+
C[vi, vj, vk] = A[vi, vj, vk] + 1
259+
for k in T.serial(8):
260+
with T.block("b1"):
261+
vi, vk = T.axis.remap("SS", [i, k])
262+
vj = T.axis.opaque(8, j)
263+
B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
264+
0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
265+
)
266+
267+
@T.prim_func
268+
def after(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]) -> None:
269+
for i in T.serial(8):
270+
with T.block():
271+
T.reads(A[i, 0:8, 0:8])
272+
T.writes(B[i, 0:8, 0:8])
273+
C = T.alloc_buffer([8, 8, 8], dtype="int32")
274+
for j in T.serial(8):
275+
for k in T.serial(8):
276+
with T.block("b0"):
277+
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
278+
C[vi, vj, vk] = A[vi, vj, vk] + 1
279+
for k in T.serial(8):
280+
with T.block("b1"):
281+
vi, vk = T.axis.remap("SS", [i, k])
282+
vj = T.axis.opaque(8, j)
283+
B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
284+
0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
285+
)
286+
287+
_check(before, after)
288+
289+
290+
def test_1D_cascade_op_rolling_buffer():
291+
"""The intermediate buffer must be allocated above rolling buffer's rolling loop,
292+
which is marked as opaque in consumer block's iter mappings."""
293+
294+
@T.prim_func
295+
def before(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]):
296+
B = T.alloc_buffer((4, 6), "int32")
297+
for c in T.serial(4):
298+
for i in T.serial(0, 2):
299+
for j in T.serial(0, 6):
300+
for k in T.serial(3):
301+
with T.block("P1"):
302+
T.where(i < 1 or j >= 2)
303+
cc, vi, vj, vk = T.axis.remap("SSSR", [c, i, j, k])
304+
if vk == 0:
305+
B[cc, T.floormod(vi * 4 + vj, 6)] = 0
306+
B[cc, T.floormod(vi * 4 + vj, 6)] = (
307+
B[cc, T.floormod(vi * 4 + vj, 6)] + A[cc, vi * 4 + vj + vk]
308+
)
309+
for j in T.serial(0, 4):
310+
for k in T.serial(3):
311+
with T.block("P2"):
312+
vi = T.axis.opaque(2, i)
313+
cc, vj, vk = T.axis.remap("SSR", [c, j, k])
314+
if vk == 0:
315+
C[cc, vi * 4 + vj] = 0
316+
C[cc, vi * 4 + vj] = (
317+
C[cc, vi * 4 + vj] + B[cc, T.floormod(vi * 4 + vj + vk, 6)]
318+
)
319+
320+
@T.prim_func
321+
def after(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]):
322+
for c in T.serial(4):
323+
with T.block():
324+
T.reads(A[c, 0:12], C[c, 0:8])
325+
T.writes(C[c, 0:8])
326+
B = T.alloc_buffer([4, 6], dtype="int32")
327+
for i in T.serial(2):
328+
for j, k in T.grid(6, 3):
329+
with T.block("P1"):
330+
T.where(i < 1 or j >= 2)
331+
cc, vi, vj, vk = T.axis.remap("SSSR", [c, i, j, k])
332+
if vk == 0:
333+
B[cc, (vi * 4 + vj) % 6] = 0
334+
B[cc, (vi * 4 + vj) % 6] = (
335+
B[cc, (vi * 4 + vj) % 6] + A[cc, vi * 4 + vj + vk]
336+
)
337+
for j, k in T.grid(4, 3):
338+
with T.block("P2"):
339+
vi = T.axis.opaque(2, i)
340+
cc, vj, vk = T.axis.remap("SSR", [c, j, k])
341+
if vk == 0:
342+
C[cc, vi * 4 + vj] = 0
343+
C[cc, vi * 4 + vj] = C[cc, vi * 4 + vj] + B[cc, (vi * 4 + vj + vk) % 6]
344+
345+
_check(before, after)
346+
347+
245348
if __name__ == "__main__":
246-
test_elementwise()
247-
test_locate_buffer_allocation()
248-
test_match_buffer_allocation()
249-
test_opaque_access()
250-
test_lower_te()
349+
tvm.testing.main()

0 commit comments

Comments
 (0)