Skip to content

Commit 946581a

Browse files
[TIR][Compute-at] Utilize InverseAffineIterMap for dom estimation (#14184)
utilize inverse iter map tool for compute_at iter region estimation
1 parent f6b7579 commit 946581a

File tree

3 files changed

+205
-51
lines changed

3 files changed

+205
-51
lines changed

src/arith/iter_affine_map.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2147,7 +2147,8 @@ class InverseAffineIterMapTransformer {
21472147
// Case 1: Propagate to the input node directly when the sum expression has only one component
21482148
if (iter_map_expr->args.size() == 1) {
21492149
const auto& source = iter_map_expr->args[0];
2150-
backprop_.Set(source, backprop_.at(source) + input);
2150+
ICHECK(analyzer_->CanProveEqual(abs(source->scale), 1));
2151+
backprop_.Set(source, (backprop_.at(source) + input) * source->scale);
21512152
return;
21522153
}
21532154

src/tir/schedule/primitive/compute_at.cc

Lines changed: 129 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -391,81 +391,163 @@ void RelaxBufferRegions(const Map<Var, PrimExpr>& binding,
391391
* domain
392392
* \param provided The provided integer set to cover the required domain
393393
* \param required The required domain to be covered
394+
* \param dim_max The maximum index bound by the buffer shape
394395
* \param analyzer The arithmetic analyzer
395396
*/
396-
std::pair<Var, arith::IntSet> SolveBlockVarDomain(const arith::IntSet& provided,
397-
const arith::IntSet& required,
398-
arith::Analyzer* analyzer) {
397+
std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& provided,
398+
const arith::IntSet& required,
399+
PrimExpr dim_max,
400+
arith::Analyzer* analyzer) {
399401
PrimExpr provided_min = analyzer->Simplify(provided.min());
400402
PrimExpr provided_max = analyzer->Simplify(provided.max());
401403
PrimExpr required_min = analyzer->Simplify(required.min());
402404
PrimExpr required_max = analyzer->Simplify(required.max());
403-
PrimExpr dom_min{nullptr}, dom_max{nullptr};
404-
Var dom_var{ObjectPtr<VarNode>{nullptr}};
405+
arith::IntSet var_dom, var_bound;
406+
Optional<Var> var;
405407
arith::PVar<Var> p_v;
406408
arith::PVar<PrimExpr> p_e;
407409
if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) {
408410
PrimExpr e = p_e.Eval();
409-
dom_var = p_v.Eval();
410-
dom_min = floordiv(required_min, e);
411-
dom_max = floordiv(required_max, e);
411+
var = p_v.Eval();
412+
var_dom = arith::IntSet::Interval(floordiv(required_min, e), floordiv(required_max, e));
413+
var_bound = arith::IntSet::Interval(0, floordiv(dim_max, e));
412414
} else if (analyzer->CanProveEqual(provided_min, provided_max)) {
413415
if (p_v.Match(provided_min)) {
414-
dom_var = p_v.Eval();
415-
dom_min = required_min;
416-
dom_max = required_max;
416+
var = p_v.Eval();
417+
var_dom = arith::IntSet::Interval(required_min, required_max);
418+
var_bound = arith::IntSet::Interval(0, dim_max);
417419
} else {
418420
arith::PVar<PrimExpr> p_f;
419421
if ((floordiv(p_v, p_f)).Match(provided_min)) {
420422
// a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
421423
PrimExpr fac = p_f.Eval();
422424
if (analyzer->CanProveGreaterEqual(fac, 1)) {
423-
dom_var = p_v.Eval();
424-
dom_min = required_min * fac;
425-
dom_max = analyzer->Simplify(required_max * fac + fac - 1);
425+
var = p_v.Eval();
426+
var_dom = arith::IntSet::Interval(required_min * fac,
427+
analyzer->Simplify(required_max * fac + fac - 1));
428+
var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
426429
}
427430
} else if ((floormod(p_v, p_f).Match(provided_min))) {
428431
// generally, the domain of (x % fac) enforces no constraints on the domain of x
429-
dom_var = p_v.Eval();
430-
return std::make_pair(dom_var, arith::IntSet::Nothing());
432+
return {p_v.Eval(), BlockVarDomainInfo()};
431433
}
432434
}
433435
}
434-
ICHECK(dom_var.defined()) << "ValueError: BufferRegion pattern match failed: " << provided_min;
435-
return std::make_pair(dom_var, arith::IntSet::Interval(dom_min, dom_max));
436+
ICHECK(var.defined()) << "ValueError: BufferRegion pattern match failed: " << provided_min;
437+
return {var.value(), BlockVarDomainInfo{var_dom, var_bound}};
436438
}
437439

438440
/*!
439-
* \brief Calculate and update the iteration domain info to fully cover the required domain
440-
* \param provided The provided integer set to cover the required domain
441-
* \param required The required domain to be covered
442-
* \param required_bound The additional region bound of the required domain to be covered
441+
* \brief Calculate and update the iteration domain info to fully cover the required domain in
442+
* dimension-wise fashion. The region relation on each buffer dimension is independently estimated.
443+
* \param buffer The accessed buffer
444+
* \param provided_region The provided NDIntSet to cover the required domain
445+
* \param required_region The required NDIntSet domain to be covered
446+
* \param analyzer The arithmetic analyzer
443447
* \param iter_doms The result iteration domains to be updated
448+
*/
449+
void UpdateBlockVarDomainDimwise(
450+
const BufferNode* buffer, const NDIntSet& provided_region, const NDIntSet& required_region,
451+
arith::Analyzer* analyzer, std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms) {
452+
size_t ndim = buffer->shape.size();
453+
for (size_t i = 0; i < ndim; ++i) {
454+
arith::IntSet provided = provided_region[i];
455+
arith::IntSet required = required_region[i];
456+
PrimExpr dim_max = max(buffer->shape[i] - 1, 0);
457+
458+
if (provided.IsSinglePoint() && is_const_int(provided.min())) {
459+
ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min()));
460+
continue;
461+
}
462+
463+
auto [var, dom_info] = SolveBlockVarDomain(provided, required, dim_max, analyzer);
464+
auto it = iter_doms->find(var.get());
465+
if (it != iter_doms->end()) {
466+
it->second.Union(dom_info);
467+
} else {
468+
ICHECK(analyzer->CanProveEqual(provided.min(), required.min()));
469+
ICHECK(analyzer->CanProveEqual(provided.max(), required.max()));
470+
}
471+
}
472+
}
473+
474+
/*! \brief Helper function to implement intset version of `InverseAffineIterMap`. */
475+
Map<Var, arith::IntSet> InverseAffineIterMap(const Array<arith::IterSumExpr>& iter_map,
476+
const NDIntSet& outputs, arith::Analyzer* analyzer) {
477+
Array<PrimExpr> min_point, max_point;
478+
min_point.reserve(outputs.size());
479+
max_point.reserve(outputs.size());
480+
for (const auto& intset : outputs) {
481+
ICHECK(intset.HasLowerBound() && intset.HasUpperBound());
482+
min_point.push_back(intset.min());
483+
max_point.push_back(intset.max());
484+
}
485+
auto rev_min = InverseAffineIterMap(iter_map, min_point);
486+
auto rev_max = InverseAffineIterMap(iter_map, max_point);
487+
Map<Var, arith::IntSet> dom_map;
488+
for (const auto& kv : rev_min) {
489+
const Var& var = kv.first;
490+
auto it = rev_max.find(var);
491+
ICHECK(it != rev_max.end()); // InverseAffineIterMap's result vars are assumed stable
492+
const PrimExpr& rev_min_point = kv.second;
493+
const PrimExpr& rev_max_point = (*it).second;
494+
dom_map.Set(var,
495+
arith::IntSet::Interval(analyzer->Simplify(min(rev_min_point, rev_max_point)),
496+
analyzer->Simplify(max(rev_min_point, rev_max_point))));
497+
}
498+
return dom_map;
499+
}
500+
501+
/*!
502+
* \brief Calculate and update the iteration domain info to fully cover the required domain
503+
* with affine analysis. It requires bijective mapping of block var to provided region points.
504+
* \param buffer The accessed buffer
505+
* \param iter_vars The list of block vars to cover the required region
506+
* \param provided_region The provided NDIntSet to cover the required domain
507+
* \param required_region The required NDIntSet domain to be covered
444508
* \param analyzer The arithmetic analyzer
509+
* \param iter_doms The result iteration domains to be updated
510+
* \returns bool. Denotes whether the update succeeded
445511
*/
446-
void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required,
447-
const arith::IntSet& required_bound,
448-
std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms,
449-
arith::Analyzer* analyzer) {
450-
if (provided.IsSinglePoint() && is_const_int(provided.min())) {
451-
ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min()));
452-
ICHECK(required_bound.IsSinglePoint() &&
453-
analyzer->CanProveEqual(provided.min(), required_bound.min()));
454-
return;
512+
bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array<IterVar>& iter_vars,
513+
const NDIntSet& provided_region, const NDIntSet& required_region,
514+
arith::Analyzer* analyzer,
515+
std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms) {
516+
// we only support single point provided region now, which could cover most cases
517+
for (const auto& intset : provided_region) {
518+
if (!intset.IsSinglePoint()) return false;
519+
}
520+
// calculate forward mapping (block vars -> provided region point)
521+
Map<Var, Range> dom_map;
522+
for (const IterVar& iter_var : iter_vars) {
523+
dom_map.Set(iter_var->var, iter_var->dom);
455524
}
456-
auto var_with_dom = SolveBlockVarDomain(provided, required, analyzer);
457-
auto var_with_bound = SolveBlockVarDomain(provided, required_bound, analyzer);
458-
const Var& var = var_with_dom.first;
459-
const auto& var_dom = var_with_dom.second;
460-
const auto& var_bound = var_with_bound.second;
461-
ICHECK(var.same_as(var_with_bound.first));
462-
auto it = iter_doms->find(var.get());
463-
if (it != iter_doms->end()) {
464-
it->second.Union({var_dom, var_bound});
465-
} else {
466-
ICHECK(analyzer->CanProveEqual(provided.min(), required.min()));
467-
ICHECK(analyzer->CanProveEqual(provided.max(), required.max()));
525+
size_t ndim = buffer->shape.size();
526+
Array<PrimExpr> provide_indices;
527+
provide_indices.reserve(ndim);
528+
for (size_t i = 0; i < ndim; ++i) {
529+
provide_indices.push_back(provided_region[i].min());
530+
}
531+
auto res = arith::DetectIterMap(provide_indices, dom_map, const_true(),
532+
arith::IterMapLevel::Bijective, analyzer, false);
533+
if (res->indices.empty()) {
534+
return false;
468535
}
536+
// calculate backward mapping (required region point -> block vars)
537+
NDIntSet required_bound;
538+
for (size_t i = 0; i < ndim; ++i) {
539+
required_bound.push_back(
540+
arith::IntSet::Interval(make_zero(buffer->shape[i]->dtype), max(buffer->shape[i] - 1, 0)));
541+
}
542+
Map<Var, arith::IntSet> var_dom = InverseAffineIterMap(res->indices, required_region, analyzer);
543+
Map<Var, arith::IntSet> var_bound = InverseAffineIterMap(res->indices, required_bound, analyzer);
544+
for (const auto& kv : var_dom) {
545+
const Var& var = kv.first;
546+
auto it = var_bound.find(var);
547+
ICHECK(it != var_bound.end()); // InverseAffineIterMap's result vars are assumed stable
548+
(*iter_doms)[var.get()].Union(BlockVarDomainInfo{kv.second, (*it).second});
549+
}
550+
return true;
469551
}
470552

471553
/*!
@@ -501,13 +583,10 @@ std::vector<BlockVarDomainInfo> CalculateBlockVarDomain(
501583
NDIntSet provided_region = support::NDIntSetUnion(many_provided_regions);
502584
ICHECK_EQ(provided_region.size(), buffer->shape.size());
503585
ICHECK_EQ(required_region.size(), buffer->shape.size());
504-
// For each dimension, update the iteration domain
505-
int ndim = buffer->shape.size();
506-
for (int i = 0; i < ndim; ++i) {
507-
arith::IntSet provided = provided_region[i];
508-
arith::IntSet required = required_region[i];
509-
arith::IntSet required_bound = arith::IntSet::FromMinExtent(Integer(0), buffer->shape[i]);
510-
UpdateBlockVarDomain(provided, required, required_bound, &iter_doms, analyzer);
586+
// Try to update iter var domains with the current required and provided region pair.
587+
if (!UpdateBlockVarDomainAffine(buffer, iter_vars, provided_region, required_region, analyzer,
588+
&iter_doms)) {
589+
UpdateBlockVarDomainDimwise(buffer, provided_region, required_region, analyzer, &iter_doms);
511590
}
512591
}
513592
// Union the iter var domains, put them in the same order of block vars, and return

tests/python/unittest/test_tir_schedule_compute_at.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,6 +1174,40 @@ def test_compute_at_tiled_repeat_op(use_block_name):
11741174
verify_trace_roundtrip(sch=sch, mod=tiled_repeat_op)
11751175

11761176

1177+
def test_compute_at_rev_iter():
1178+
@T.prim_func
1179+
def before(X: T.Buffer[(10, 10), "float32"], Z: T.Buffer[(10, 10), "float32"]):
1180+
Y = T.alloc_buffer([10, 10], "float32")
1181+
for i, j in T.grid(10, 10):
1182+
with T.block("b0"):
1183+
vi, vj = T.axis.remap("SS", [i, j])
1184+
Y[9 - vi, 9 - vj] = X[vi, vj] + 1.0
1185+
for i, j in T.grid(10, 10):
1186+
with T.block("b1"):
1187+
vi, vj = T.axis.remap("SS", [i, j])
1188+
Z[vi, vj] = Y[vj, vi] + 2.0
1189+
1190+
@T.prim_func
1191+
def after(X: T.Buffer[(10, 10), "float32"], Z: T.Buffer[(10, 10), "float32"]):
1192+
Y = T.alloc_buffer([10, 10], "float32")
1193+
for i in range(10):
1194+
for j in range(10):
1195+
with T.block("b0"):
1196+
vi = T.axis.spatial(10, j)
1197+
vj = T.axis.spatial(10, 9 - i)
1198+
Y[9 - vi, 9 - vj] = X[vi, vj] + 1.0
1199+
for j in range(10):
1200+
with T.block("b1"):
1201+
vi, vj = T.axis.remap("SS", [i, j])
1202+
Z[vi, vj] = Y[vj, vi] + 2.0
1203+
1204+
sch = tir.Schedule(before, debug_mask="all")
1205+
axis = sch.get_loops(sch.get_block("b1"))[0]
1206+
sch.compute_at(sch.get_block("b0"), axis)
1207+
tvm.ir.assert_structural_equal(after, sch.mod["main"])
1208+
verify_trace_roundtrip(sch=sch, mod=before)
1209+
1210+
11771211
def test_reverse_compute_at_tiled(use_block_name):
11781212
sch = tir.Schedule(tiled, debug_mask="all")
11791213
block = sch.get_block("C")
@@ -1557,5 +1591,45 @@ def main_reverse_compute_at(
15571591
tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"])
15581592

15591593

1594+
def test_reverse_compute_at_layout_trans():
1595+
@T.prim_func
1596+
def before(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 5, 8), "float32")):
1597+
B = T.alloc_buffer((1, 3, 5, 5, 16))
1598+
for i0, i1, i2, i3, i4 in T.grid(1, 3, 5, 5, 16):
1599+
with T.block("compute"):
1600+
v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
1601+
B[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4] + T.float32(1)
1602+
for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 6, 5, 5, 8):
1603+
with T.block("T_layout_trans"):
1604+
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
1605+
C[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = B[
1606+
v_ax0, (v_ax1 * 8 + v_ax4) // 16, v_ax2, v_ax3, (v_ax1 * 8 + v_ax4) % 16
1607+
]
1608+
1609+
@T.prim_func
1610+
def after(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 5, 8), "float32")):
1611+
B = T.alloc_buffer((1, 3, 5, 5, 16))
1612+
for i0, i1 in T.grid(1, 3):
1613+
for i2, i3, i4 in T.grid(5, 5, 16):
1614+
with T.block("compute"):
1615+
v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
1616+
B[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4] + T.float32(1)
1617+
for ax0, ax1, ax2, ax3 in T.grid(2, 5, 5, 8):
1618+
with T.block("T_layout_trans"):
1619+
v_ax0 = T.axis.spatial(1, 0)
1620+
v_ax1 = T.axis.spatial(6, i1 * 2 + ax0)
1621+
v_ax2, v_ax3, v_ax4 = T.axis.remap("SSS", [ax1, ax2, ax3])
1622+
C[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = B[
1623+
v_ax0, (v_ax1 * 8 + v_ax4) // 16, v_ax2, v_ax3, (v_ax1 * 8 + v_ax4) % 16
1624+
]
1625+
1626+
sch = tir.Schedule(before, debug_mask="all")
1627+
trans = sch.get_block("T_layout_trans")
1628+
axis = sch.get_loops("compute")[1]
1629+
sch.reverse_compute_at(trans, axis)
1630+
tvm.ir.assert_structural_equal(after, sch.mod["main"])
1631+
verify_trace_roundtrip(sch=sch, mod=before)
1632+
1633+
15601634
if __name__ == "__main__":
15611635
tvm.testing.main()

0 commit comments

Comments
 (0)