@@ -99,30 +99,113 @@ class LCADetector : public StmtExprVisitor {
9999 }
100100
101101 ancestor_scopes_.push_back (current_scope);
102+ loop_scope_map_.insert ({op->loop_var .get (), current_scope});
102103 StmtExprVisitor::VisitStmt_ (op);
103104 ancestor_scopes_.pop_back ();
105+ loop_scope_map_.erase (op->loop_var .get ());
104106 }
105107
106- void VisitStmt_ (const BlockNode* op) final {
108+ void VisitStmt_ (const BlockRealizeNode* op) final {
109+ const BlockNode* block = op->block .get ();
107110 int n = ancestor_scopes_.size ();
108- for (const Buffer& buf : op ->alloc_buffers ) {
111+ for (const Buffer& buf : block ->alloc_buffers ) {
109112 buffer_var_map_.emplace (buf->data .get (), buf.get ());
110113 }
111114
112115 const ScopeInfo* parent_scope = ancestor_scopes_.back ();
113- auto * current_scope = arena_.make <ScopeInfo>(parent_scope, op , n);
116+ auto * current_scope = arena_.make <ScopeInfo>(parent_scope, block , n);
114117
115118 ancestor_scopes_.push_back (current_scope);
119+
120+ // For each accessed buffer of the block, update the buffer's lca to
121+ // the lowest inclusive stmt position, which should dominate all loops
122+ // related to the accessed opaque block iter vars in buffer indices.
123+ UpdateDominateScopeOfOpaqueIter (op);
124+
116125 // Update match_buffers
117- for (const MatchBufferRegion& match_buffer : op ->match_buffers ) {
118- UpdateBufferLCA (match_buffer->source ->buffer .get ());
126+ for (const MatchBufferRegion& match_buffer : block ->match_buffers ) {
127+ UpdateBufferLCA (match_buffer->source ->buffer .get (), ancestor_scopes_. back () );
119128 match_buffers_.insert (match_buffer->buffer .get ());
120129 }
121130
122131 StmtExprVisitor::VisitStmt_ (op);
123132 ancestor_scopes_.pop_back ();
124133 }
125134
135+ void UpdateDominateScopeOfOpaqueIter (const BlockRealizeNode* block_realize) {
136+ // map opaque iter var to the scope which dominate all loop carried dependencies.
137+ std::unordered_map<const VarNode*, const ScopeInfo*> itervar_to_dom_scope;
138+
139+ // function to collect `itervar_to_dom_scope`, the result scope for each block
140+ // iter var should be above all loop scopes the opaque iter var binding relates to.
141+ auto do_collect_itervar_scope = [this , &itervar_to_dom_scope](const IterVar& itervar,
142+ const PrimExpr& binding) {
143+ PostOrderVisit (binding, [this , &itervar_to_dom_scope, &itervar](const ObjectRef& obj) {
144+ if (const VarNode* loop_var = obj.as <VarNode>()) {
145+ auto it = loop_scope_map_.find (loop_var);
146+ if (it == loop_scope_map_.end ()) {
147+ return ;
148+ }
149+ const ScopeInfo* scope = it->second ->parent_scope_info ;
150+ // find the highest loop scope the iter var binding has related to.
151+ auto dom_scope_it = itervar_to_dom_scope.find (itervar->var .get ());
152+ if (dom_scope_it == itervar_to_dom_scope.end ()) {
153+ itervar_to_dom_scope.insert (dom_scope_it, {itervar->var .get (), scope});
154+ } else if (scope->depth < dom_scope_it->second ->depth ) {
155+ dom_scope_it->second = scope;
156+ }
157+ }
158+ });
159+ };
160+
161+ // function to update lca scope of the buffer with loop carried dependent buffer accesses.
162+ // the result scope should be above all loop scopes the accessed opaque block iter vars
163+ // relate to, which is record in `itervar_to_dom_scope`.
164+ auto do_update = [this , &itervar_to_dom_scope](const BufferRegion& region) {
165+ const Buffer& buffer = region->buffer ;
166+ const ScopeInfo* scope = ancestor_scopes_.back ();
167+
168+ auto handle_itervar = [&itervar_to_dom_scope, &scope](const ObjectRef& obj) {
169+ if (const VarNode* iter_var = obj.as <VarNode>()) {
170+ auto dom_scope_it = itervar_to_dom_scope.find (iter_var);
171+ if (dom_scope_it == itervar_to_dom_scope.end ()) {
172+ return ;
173+ }
174+ // find the highest loop scope the accessed buffer index has
175+ // loop carried dependencies to (via opaque iter var binding).
176+ if (dom_scope_it->second ->depth < scope->depth ) {
177+ scope = dom_scope_it->second ;
178+ }
179+ }
180+ };
181+
182+ // visit region min and max to find the lowest legal lca scope
183+ for (const Range& range : region->region ) {
184+ PostOrderVisit (range->min , handle_itervar);
185+ PostOrderVisit (range->min + range->extent - 1 , handle_itervar);
186+ }
187+ UpdateBufferLCA (buffer.get (), scope);
188+ };
189+
190+ // do collect and update
191+ const Block& block = block_realize->block ;
192+ for (size_t i = 0 ; i < block_realize->iter_values .size (); ++i) {
193+ const IterVar& iter_var = block->iter_vars [i];
194+ if (iter_var->iter_type != IterVarType::kDataPar &&
195+ iter_var->iter_type != IterVarType::kCommReduce ) {
196+ do_collect_itervar_scope (iter_var, block_realize->iter_values [i]);
197+ }
198+ }
199+ if (!itervar_to_dom_scope.empty ()) {
200+ for (const auto & read : block->reads ) {
201+ do_update (read);
202+ }
203+ for (const auto & write : block->writes ) {
204+ do_update (write);
205+ }
206+ }
207+ }
208+
126209 void VisitStmt_ (const AttrStmtNode* op) final {
127210 if (op->attr_key == attr::thread_extent) {
128211 const auto * iter = op->node .as <IterVarNode>();
@@ -136,17 +219,18 @@ class LCADetector : public StmtExprVisitor {
136219 }
137220
138221 void VisitExpr_ (const BufferLoadNode* op) final {
139- UpdateBufferLCA (op->buffer .get ());
222+ UpdateBufferLCA (op->buffer .get (), ancestor_scopes_. back () );
140223 StmtExprVisitor::VisitExpr_ (op);
141224 }
142225
143226 void VisitStmt_ (const BufferStoreNode* op) final {
144- UpdateBufferLCA (op->buffer .get ());
227+ UpdateBufferLCA (op->buffer .get (), ancestor_scopes_. back () );
145228 StmtExprVisitor::VisitStmt_ (op);
146229 }
147230
148231 void VisitStmt_ (const BufferRealizeNode* op) final {
149232 buffer_var_map_.emplace (op->buffer ->data .get (), op->buffer .get ());
233+ UpdateBufferLCA (op->buffer .get (), ancestor_scopes_.back ());
150234 StmtExprVisitor::VisitStmt_ (op);
151235 }
152236
@@ -165,16 +249,16 @@ class LCADetector : public StmtExprVisitor {
165249 void VisitBufferVar (const VarNode* op) {
166250 auto it = buffer_var_map_.find (op);
167251 if (it != buffer_var_map_.end ()) {
168- UpdateBufferLCA (it->second );
252+ UpdateBufferLCA (it->second , ancestor_scopes_. back () );
169253 }
170254 }
171255
172- void UpdateBufferLCA (const BufferNode* buffer) {
256+ void UpdateBufferLCA (const BufferNode* buffer, const ScopeInfo* scope ) {
173257 buffer_var_map_.emplace (buffer->data .get (), buffer);
174258 if (match_buffers_.find (buffer) == match_buffers_.end ()) {
175259 // Ingore buffer created by block match_buffer
176260 const ScopeInfo*& lca = buffer_lca_[buffer];
177- lca = LowestCommonAncestor (lca, ancestor_scopes_. back () );
261+ lca = LowestCommonAncestor (lca, scope );
178262 }
179263 }
180264
@@ -229,6 +313,8 @@ class LCADetector : public StmtExprVisitor {
229313 std::unordered_set<const BufferNode*> match_buffers_ = {};
230314 /* ! \brief The ForNodes/BlockNodes which contain immediate `blockIdx` launch. */
231315 std::vector<const ScopeInfo*> blockidx_scopes_ = {};
316+ /* ! \brief The map from loop var to the corresponding scope. */
317+ std::unordered_map<const VarNode*, const ScopeInfo*> loop_scope_map_ = {};
232318 /* ! \brief Internal arena. */
233319 support::Arena arena_;
234320};
0 commit comments