@@ -37,8 +37,10 @@ namespace tir {
3737 */
3838class BlockReadWriteDetector : public StmtExprVisitor {
3939 public:
40- explicit BlockReadWriteDetector (const Map<Var, Buffer>& buffer_var_map)
41- : buffer_var_map_(buffer_var_map) {}
40+ explicit BlockReadWriteDetector (const Array<Buffer>& alloc_buffers,
41+ const Map<Var, Buffer>& buffer_var_map)
42+ : buffer_var_map_(buffer_var_map),
43+ alloc_buffers_(alloc_buffers.begin(), alloc_buffers.end()) {}
4244
4345 /* ! \brief Return read regions of the block */
4446 Array<BufferRegion> CollectReads (
@@ -78,6 +80,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
7880 std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
7981 /* !\ brief Internal analyzer. */
8082 arith::Analyzer ana_;
83+ /* ! \brief The alloc buffers of the current block*/
84+ std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> alloc_buffers_;
8185
8286 /* !
8387 * \brief Update read/write buffers and regions with provided buffer and region
@@ -145,11 +149,13 @@ Array<BufferRegion> BlockReadWriteDetector::CollectOpaques() {
145149void BlockReadWriteDetector::VisitExpr_ (const VarNode* op) { UpdateOpaque (GetRef<Var>(op)); }
146150
147151void BlockReadWriteDetector::VisitExpr_ (const BufferLoadNode* op) {
148- std::vector<arith::IntSet> relaxed_region;
149- for (const PrimExpr& index : op->indices ) {
150- relaxed_region.push_back (arith::EvalSet (arith::IntSet::Vector (index), dom_map_));
152+ if (!alloc_buffers_.count (op->buffer )) {
153+ std::vector<arith::IntSet> relaxed_region;
154+ for (const PrimExpr& index : op->indices ) {
155+ relaxed_region.push_back (arith::EvalSet (arith::IntSet::Vector (index), dom_map_));
156+ }
157+ Update (&read_buffers_, &read_regions_, op->buffer , relaxed_region);
151158 }
152- Update (&read_buffers_, &read_regions_, op->buffer , relaxed_region);
153159 ExprVisitor::VisitExpr_ (op);
154160}
155161
@@ -182,20 +188,22 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
182188 auto it = buffer_var_map_.find (GetRef<Var>(buffer_var));
183189 if (it != buffer_var_map_.end ()) {
184190 const Buffer& buffer = (*it).second ;
185- const BufferRegion buffer_region = BufferRegion::FullRegion (buffer);
186- const Region& region = buffer_region->region ;
187- std::vector<arith::IntSet> int_set;
188- int_set.reserve (region.size ());
189- for (const Range& range : region) {
190- int_set.push_back (arith::EvalSet (range, dom_map_));
191- }
192- // read access, write access or opaque access
193- if ((access_mask->value & 1 ) && (access_mask->value & 2 )) {
194- Update (&opaque_buffers_, &opaque_regions_, buffer, int_set);
195- } else if (access_mask->value & 1 ) {
196- Update (&read_buffers_, &read_regions_, buffer, int_set);
197- } else if (access_mask->value & 2 ) {
198- Update (&writes_buffers_, &write_regions_, buffer, int_set);
191+ if (!alloc_buffers_.count (buffer)) {
192+ const BufferRegion buffer_region = BufferRegion::FullRegion (buffer);
193+ const Region& region = buffer_region->region ;
194+ std::vector<arith::IntSet> int_set;
195+ int_set.reserve (region.size ());
196+ for (const Range& range : region) {
197+ int_set.push_back (arith::EvalSet (range, dom_map_));
198+ }
199+ // read access, write access or opaque access
200+ if ((access_mask->value & 1 ) && (access_mask->value & 2 )) {
201+ Update (&opaque_buffers_, &opaque_regions_, buffer, int_set);
202+ } else if (access_mask->value & 1 ) {
203+ Update (&read_buffers_, &read_regions_, buffer, int_set);
204+ } else if (access_mask->value & 2 ) {
205+ Update (&writes_buffers_, &write_regions_, buffer, int_set);
206+ }
199207 }
200208 }
201209 } else {
@@ -221,11 +229,13 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
221229}
222230
223231void BlockReadWriteDetector::VisitStmt_ (const BufferStoreNode* op) {
224- std::vector<arith::IntSet> relaxed_region;
225- for (const PrimExpr& index : op->indices ) {
226- relaxed_region.push_back (arith::EvalSet (arith::IntSet::Vector (index), dom_map_));
232+ if (!alloc_buffers_.count (op->buffer )) {
233+ std::vector<arith::IntSet> relaxed_region;
234+ for (const PrimExpr& index : op->indices ) {
235+ relaxed_region.push_back (arith::EvalSet (arith::IntSet::Vector (index), dom_map_));
236+ }
237+ Update (&writes_buffers_, &write_regions_, op->buffer , relaxed_region);
227238 }
228- Update (&writes_buffers_, &write_regions_, op->buffer , relaxed_region);
229239 StmtVisitor::VisitStmt_ (op);
230240}
231241
@@ -236,24 +246,28 @@ void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) {
236246 vmap[op->block ->iter_vars [i]->var .get ()] = op->iter_values [i];
237247 }
238248 for (const auto & read : op->block ->reads ) {
239- std::vector<arith::IntSet> relaxed_region;
240- for (const auto & range : read->region ) {
241- relaxed_region.push_back (
242- arith::EvalSet (arith::IntSet::FromRange (Range::FromMinExtent (
243- Substitute (range->min , vmap), Substitute (range->extent , vmap))),
244- dom_map_));
249+ if (!alloc_buffers_.count (read->buffer )) {
250+ std::vector<arith::IntSet> relaxed_region;
251+ for (const auto & range : read->region ) {
252+ relaxed_region.push_back (
253+ arith::EvalSet (arith::IntSet::FromRange (Range::FromMinExtent (
254+ Substitute (range->min , vmap), Substitute (range->extent , vmap))),
255+ dom_map_));
256+ }
257+ Update (&read_buffers_, &read_regions_, read->buffer , relaxed_region);
245258 }
246- Update (&read_buffers_, &read_regions_, read->buffer , relaxed_region);
247259 }
248260 for (const auto & write : op->block ->writes ) {
249- std::vector<arith::IntSet> relaxed_region;
250- for (const auto & range : write->region ) {
251- relaxed_region.push_back (
252- arith::EvalSet (arith::IntSet::FromRange (Range::FromMinExtent (
253- Substitute (range->min , vmap), Substitute (range->extent , vmap))),
254- dom_map_));
261+ if (!alloc_buffers_.count (write->buffer )) {
262+ std::vector<arith::IntSet> relaxed_region;
263+ for (const auto & range : write->region ) {
264+ relaxed_region.push_back (
265+ arith::EvalSet (arith::IntSet::FromRange (Range::FromMinExtent (
266+ Substitute (range->min , vmap), Substitute (range->extent , vmap))),
267+ dom_map_));
268+ }
269+ Update (&writes_buffers_, &write_regions_, write->buffer , relaxed_region);
255270 }
256- Update (&writes_buffers_, &write_regions_, write->buffer , relaxed_region);
257271 }
258272}
259273
@@ -349,7 +363,7 @@ void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) {
349363
350364Array<Array<BufferRegion>> GetBlockAccessRegion (const Block& block,
351365 const Map<Var, Buffer>& buffer_var_map) {
352- BlockReadWriteDetector detector (buffer_var_map);
366+ BlockReadWriteDetector detector (block-> alloc_buffers , buffer_var_map);
353367 detector (block);
354368 Array<BufferRegion> writes = detector.CollectWrites ();
355369 std::unordered_set<const BufferNode*> excluded_buffers;
@@ -366,7 +380,7 @@ Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
366380
367381Array<Array<BufferRegion>> GetBlockReadWriteRegion (const Block& block,
368382 const Map<Var, Buffer>& buffer_var_map) {
369- BlockReadWriteDetector detector (buffer_var_map);
383+ BlockReadWriteDetector detector (block-> alloc_buffers , buffer_var_map);
370384 detector (block);
371385 Array<BufferRegion> opaques = detector.CollectOpaques ();
372386 std::unordered_set<const BufferNode*> excluded_buffers;
0 commit comments