Skip to content

Commit ef8502d

Browse files
jinhongyiiMasterJH5574
authored andcommitted
[Dynamic] M2 for S3: Compute Inline (apache#173)
1 parent ab0e91c commit ef8502d

File tree

4 files changed

+542
-112
lines changed

4 files changed

+542
-112
lines changed

src/tir/analysis/block_access_region_detector.cc

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ namespace tir {
3737
*/
3838
class BlockReadWriteDetector : public StmtExprVisitor {
3939
public:
40-
explicit BlockReadWriteDetector(const Map<Var, Buffer>& buffer_var_map)
41-
: buffer_var_map_(buffer_var_map) {}
40+
explicit BlockReadWriteDetector(const Array<Buffer>& alloc_buffers,
41+
const Map<Var, Buffer>& buffer_var_map)
42+
: buffer_var_map_(buffer_var_map),
43+
alloc_buffers_(alloc_buffers.begin(), alloc_buffers.end()) {}
4244

4345
/*! \brief Return read regions of the block */
4446
Array<BufferRegion> CollectReads(
@@ -78,6 +80,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
7880
std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
7981
/*!\ brief Internal analyzer. */
8082
arith::Analyzer ana_;
83+
/*! \brief The alloc buffers of the current block*/
84+
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> alloc_buffers_;
8185

8286
/*!
8387
* \brief Update read/write buffers and regions with provided buffer and region
@@ -145,11 +149,13 @@ Array<BufferRegion> BlockReadWriteDetector::CollectOpaques() {
145149
void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef<Var>(op)); }
146150

147151
void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
148-
std::vector<arith::IntSet> relaxed_region;
149-
for (const PrimExpr& index : op->indices) {
150-
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
152+
if (!alloc_buffers_.count(op->buffer)) {
153+
std::vector<arith::IntSet> relaxed_region;
154+
for (const PrimExpr& index : op->indices) {
155+
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
156+
}
157+
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
151158
}
152-
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
153159
ExprVisitor::VisitExpr_(op);
154160
}
155161

@@ -182,20 +188,22 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
182188
auto it = buffer_var_map_.find(GetRef<Var>(buffer_var));
183189
if (it != buffer_var_map_.end()) {
184190
const Buffer& buffer = (*it).second;
185-
const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
186-
const Region& region = buffer_region->region;
187-
std::vector<arith::IntSet> int_set;
188-
int_set.reserve(region.size());
189-
for (const Range& range : region) {
190-
int_set.push_back(arith::EvalSet(range, dom_map_));
191-
}
192-
// read access, write access or opaque access
193-
if ((access_mask->value & 1) && (access_mask->value & 2)) {
194-
Update(&opaque_buffers_, &opaque_regions_, buffer, int_set);
195-
} else if (access_mask->value & 1) {
196-
Update(&read_buffers_, &read_regions_, buffer, int_set);
197-
} else if (access_mask->value & 2) {
198-
Update(&writes_buffers_, &write_regions_, buffer, int_set);
191+
if (!alloc_buffers_.count(buffer)) {
192+
const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
193+
const Region& region = buffer_region->region;
194+
std::vector<arith::IntSet> int_set;
195+
int_set.reserve(region.size());
196+
for (const Range& range : region) {
197+
int_set.push_back(arith::EvalSet(range, dom_map_));
198+
}
199+
// read access, write access or opaque access
200+
if ((access_mask->value & 1) && (access_mask->value & 2)) {
201+
Update(&opaque_buffers_, &opaque_regions_, buffer, int_set);
202+
} else if (access_mask->value & 1) {
203+
Update(&read_buffers_, &read_regions_, buffer, int_set);
204+
} else if (access_mask->value & 2) {
205+
Update(&writes_buffers_, &write_regions_, buffer, int_set);
206+
}
199207
}
200208
}
201209
} else {
@@ -221,11 +229,13 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
221229
}
222230

223231
void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
224-
std::vector<arith::IntSet> relaxed_region;
225-
for (const PrimExpr& index : op->indices) {
226-
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
232+
if (!alloc_buffers_.count(op->buffer)) {
233+
std::vector<arith::IntSet> relaxed_region;
234+
for (const PrimExpr& index : op->indices) {
235+
relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_));
236+
}
237+
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
227238
}
228-
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
229239
StmtVisitor::VisitStmt_(op);
230240
}
231241

@@ -236,24 +246,28 @@ void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) {
236246
vmap[op->block->iter_vars[i]->var.get()] = op->iter_values[i];
237247
}
238248
for (const auto& read : op->block->reads) {
239-
std::vector<arith::IntSet> relaxed_region;
240-
for (const auto& range : read->region) {
241-
relaxed_region.push_back(
242-
arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
243-
Substitute(range->min, vmap), Substitute(range->extent, vmap))),
244-
dom_map_));
249+
if (!alloc_buffers_.count(read->buffer)) {
250+
std::vector<arith::IntSet> relaxed_region;
251+
for (const auto& range : read->region) {
252+
relaxed_region.push_back(
253+
arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
254+
Substitute(range->min, vmap), Substitute(range->extent, vmap))),
255+
dom_map_));
256+
}
257+
Update(&read_buffers_, &read_regions_, read->buffer, relaxed_region);
245258
}
246-
Update(&read_buffers_, &read_regions_, read->buffer, relaxed_region);
247259
}
248260
for (const auto& write : op->block->writes) {
249-
std::vector<arith::IntSet> relaxed_region;
250-
for (const auto& range : write->region) {
251-
relaxed_region.push_back(
252-
arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
253-
Substitute(range->min, vmap), Substitute(range->extent, vmap))),
254-
dom_map_));
261+
if (!alloc_buffers_.count(write->buffer)) {
262+
std::vector<arith::IntSet> relaxed_region;
263+
for (const auto& range : write->region) {
264+
relaxed_region.push_back(
265+
arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
266+
Substitute(range->min, vmap), Substitute(range->extent, vmap))),
267+
dom_map_));
268+
}
269+
Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region);
255270
}
256-
Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region);
257271
}
258272
}
259273

@@ -349,7 +363,7 @@ void BlockReadWriteDetector::UpdateOpaque(const Var& buffer_var) {
349363

350364
Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
351365
const Map<Var, Buffer>& buffer_var_map) {
352-
BlockReadWriteDetector detector(buffer_var_map);
366+
BlockReadWriteDetector detector(block->alloc_buffers, buffer_var_map);
353367
detector(block);
354368
Array<BufferRegion> writes = detector.CollectWrites();
355369
std::unordered_set<const BufferNode*> excluded_buffers;
@@ -366,7 +380,7 @@ Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
366380

367381
Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
368382
const Map<Var, Buffer>& buffer_var_map) {
369-
BlockReadWriteDetector detector(buffer_var_map);
383+
BlockReadWriteDetector detector(block->alloc_buffers, buffer_var_map);
370384
detector(block);
371385
Array<BufferRegion> opaques = detector.CollectOpaques();
372386
std::unordered_set<const BufferNode*> excluded_buffers;

src/tir/schedule/primitive/compute_inline.cc

Lines changed: 66 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ namespace tvm {
2222
namespace tir {
2323

2424
static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of
25-
'A[i, j, k, ...] = f(i, j, k, ...)',
26-
where the indices on the left are distinct atomic variables,
27-
and there should be no variables other than the index variables)";
25+
'A[f(i, j, k, ...)] = g(i, j, k, ...)',
26+
where the store indices mapping f on the left are bijective affine.)";
2827

2928
static const char kErrBodyReverseInline[] = R"(The body of the inlined block should be in form of
3029
`B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
@@ -284,31 +283,6 @@ class BaseInliner : public StmtExprMutator {
284283
return std::move(tgt_block);
285284
}
286285

287-
/*!
288-
* \brief Count the number of undefined variables that are not used
289-
* as buffer objects.
290-
*
291-
* This is used to determine whether inlining or reverse inlining is
292-
* possible. The only undefined variables present should be the
293-
* load/store indices, or buffer access based on those indices.
294-
*
295-
* \param stmt The statement in which to count undefined variables
296-
*/
297-
static int GetNumUndefinedNonpointerVars(const Stmt& stmt) {
298-
auto undefined_vars = UndefinedVars(stmt, {});
299-
// Buffer pointers and the inlined indices are allowed, but no
300-
// other variables may appear in the inlined block.
301-
int num_nonpointer_vars = 0;
302-
for (const auto& var : undefined_vars) {
303-
bool is_pointer = var->dtype.is_handle() && var->type_annotation.defined() &&
304-
var->type_annotation.as<PointerTypeNode>();
305-
if (!is_pointer) {
306-
num_nonpointer_vars++;
307-
}
308-
}
309-
return num_nonpointer_vars;
310-
}
311-
312286
private:
313287
/*!
314288
* \brief Add the buffers in the block signature to the `buffer_var_map_`,
@@ -406,7 +380,7 @@ class BaseInliner : public StmtExprMutator {
406380
/*! \brief Maps a buffer's data field to itself */
407381
Map<Var, Buffer> buffer_var_map_;
408382
/*! \brief The indices used for indexing the buffer to be inlined */
409-
std::vector<const VarNode*> idx_vars_;
383+
std::vector<Var> idx_vars_;
410384
/*! \brief The mapping to substitute index variables to PrimExprs */
411385
std::unordered_map<const VarNode*, PrimExpr> idx_sub_;
412386

@@ -443,10 +417,62 @@ class ComputeInliner : public BaseInliner {
443417
return false;
444418
}
445419

446-
int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));
447-
if (!UpdateAndCheckIndexVars(inlined_store_->indices, n_vars)) {
420+
// Fast path on trivial case:
421+
// Check the store indices are same with the block iters;
422+
store_value_ = inlined_store_->value;
423+
size_t num_iters = producer_block->iter_vars.size();
424+
size_t buffer_ndim = inlined_store_->indices.size();
425+
if (num_iters == buffer_ndim) {
426+
std::vector<Var> idx_vars;
427+
idx_vars.reserve(num_iters);
428+
for (size_t i = 0; i < num_iters; ++i) {
429+
const IterVar& iter = producer_block->iter_vars[i];
430+
const PrimExpr& e = inlined_store_->indices[i];
431+
if (e.same_as(iter->var) ||
432+
(analyzer_.CanProveEqual(e, 0) && analyzer_.CanProveEqual(iter->dom->min, 0) &&
433+
analyzer_.CanProveEqual(iter->dom->extent, 1))) {
434+
idx_vars.push_back(iter->var);
435+
} else {
436+
break;
437+
}
438+
}
439+
if (idx_vars.size() == num_iters) {
440+
// match success
441+
idx_vars_ = std::move(idx_vars);
442+
return true;
443+
}
444+
}
445+
446+
// If the mapping for store indices is non-trivial
447+
// check bijective mapping from producer iter var to store indices
448+
Map<Var, Range> producer_iter_doms;
449+
for (const auto& iter : producer_block->iter_vars) {
450+
producer_iter_doms.Set(iter->var, iter->dom);
451+
}
452+
auto res = arith::DetectIterMap(
453+
/*indices=*/inlined_store_->indices,
454+
/*input_iters=*/producer_iter_doms,
455+
/*predicate=*/true,
456+
/*check_level=*/arith::IterMapLevel::Bijective,
457+
/*analyzer=*/&analyzer_,
458+
/*simplify_trivial_iterators=*/false);
459+
if (res->indices.empty()) {
460+
// Failure: indices of BufferStore are not bijective affine
448461
return false;
449462
}
463+
idx_vars_.resize(buffer_ndim);
464+
for (size_t i = 0; i < idx_vars_.size(); ++i) {
465+
idx_vars_[i] = Var("ph_" + std::to_string(i), inlined_store_->indices[i].dtype());
466+
}
467+
auto inverse_iter_map = arith::InverseAffineIterMap(
468+
res->indices, Array<PrimExpr>(idx_vars_.begin(), idx_vars_.end()));
469+
for (const auto& iter : producer_block->iter_vars) {
470+
if (is_const_int(iter->dom->min) && analyzer_.CanProveEqual(iter->dom->extent, 1)) {
471+
// fallback mapping for constant iters
472+
inverse_iter_map.Set(iter->var, iter->dom->min);
473+
}
474+
}
475+
store_value_ = Substitute(store_value_, inverse_iter_map);
450476
return true;
451477
}
452478

@@ -464,45 +490,7 @@ class ComputeInliner : public BaseInliner {
464490

465491
PrimExpr ReplaceInlinedBuffer(BufferLoad load) {
466492
SetIndexSubstitution(load->indices);
467-
return Substitute(inlined_store_->value, idx_sub_);
468-
}
469-
470-
/*!
471-
* \brief Check if the indices are atomic distinct variables and the access is n-dimensional.
472-
* If so, set `self->idx_vars_` properly.
473-
* \param indices The indices to be extracted
474-
* \param expected_ndim The expected ndim of the access
475-
* \return A boolean flag indicating if the check is successful
476-
*/
477-
bool UpdateAndCheckIndexVars(const Array<PrimExpr>& indices, int expected_ndim) {
478-
int n = indices.size();
479-
if (n != expected_ndim) {
480-
// Failure: dimension mismatch
481-
return false;
482-
}
483-
std::vector<const VarNode*> result;
484-
result.reserve(n);
485-
for (const PrimExpr& i : indices) {
486-
if (const auto* var = i.as<VarNode>()) {
487-
result.push_back(var);
488-
} else {
489-
// Failure: indexing expression is not a variable
490-
return false;
491-
}
492-
}
493-
using DistinctSet = std::unordered_set<const VarNode*>;
494-
int n_distinct = DistinctSet(result.begin(), result.end()).size();
495-
if (n != n_distinct) {
496-
// Failure: indexing variables are not distinct
497-
return false;
498-
}
499-
if (idx_vars_.empty()) {
500-
idx_vars_ = std::move(result);
501-
} else if (!support::ArrayWithSameContent(idx_vars_, result)) {
502-
// Failure: indexing variables are not consitent in different BufferLoads
503-
return false;
504-
}
505-
return true;
493+
return Substitute(store_value_, idx_sub_);
506494
}
507495

508496
/*!
@@ -512,11 +500,17 @@ class ComputeInliner : public BaseInliner {
512500
void SetIndexSubstitution(const Array<PrimExpr>& indices) {
513501
ICHECK_EQ(indices.size(), idx_vars_.size());
514502
int n = idx_vars_.size();
515-
idx_sub_.reserve(n);
516503
for (int i = 0; i < n; ++i) {
517-
idx_sub_[idx_vars_[i]] = indices[i];
504+
idx_sub_[idx_vars_[i].get()] = indices[i];
518505
}
519506
}
507+
508+
/*! \brief The arithmetic analyzer */
509+
arith::Analyzer analyzer_;
510+
/*! \brief The store value for inlinement. If the producer
511+
store indices are trivial, it is wrt the producer block iter var,
512+
otherwise it is wrt to the placeholder vars of store indices. */
513+
PrimExpr store_value_;
520514
};
521515

522516
/*!

0 commit comments

Comments
 (0)