Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2147,7 +2147,8 @@ class InverseAffineIterMapTransformer {
// Case 1: Propagate to the input node directly when the sum expression has only one component
if (iter_map_expr->args.size() == 1) {
const auto& source = iter_map_expr->args[0];
backprop_.Set(source, backprop_.at(source) + input);
ICHECK(analyzer_->CanProveEqual(abs(source->scale), 1));
backprop_.Set(source, (backprop_.at(source) + input) * source->scale);
return;
}

Expand Down
179 changes: 129 additions & 50 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -391,81 +391,163 @@ void RelaxBufferRegions(const Map<Var, PrimExpr>& binding,
* domain
* \param provided The provided integer set to cover the required domain
* \param required The required domain to be covered
* \param dim_max The maximum index bound by the buffer shape
* \param analyzer The arithmetic analyzer
*/
std::pair<Var, arith::IntSet> SolveBlockVarDomain(const arith::IntSet& provided,
const arith::IntSet& required,
arith::Analyzer* analyzer) {
std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain(const arith::IntSet& provided,
const arith::IntSet& required,
PrimExpr dim_max,
arith::Analyzer* analyzer) {
// Simplify the endpoints of both sets first; the pattern matches below operate on the
// simplified provided_min.
PrimExpr provided_min = analyzer->Simplify(provided.min());
PrimExpr provided_max = analyzer->Simplify(provided.max());
PrimExpr required_min = analyzer->Simplify(required.min());
PrimExpr required_max = analyzer->Simplify(required.max());
PrimExpr dom_min{nullptr}, dom_max{nullptr};
Var dom_var{ObjectPtr<VarNode>{nullptr}};
arith::IntSet var_dom, var_bound;
Optional<Var> var;
arith::PVar<Var> p_v;
arith::PVar<PrimExpr> p_e;
// Pattern: provided_min is the product of a variable and an expression, in either operand order.
if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) {
PrimExpr e = p_e.Eval();
dom_var = p_v.Eval();
dom_min = floordiv(required_min, e);
dom_max = floordiv(required_max, e);
var = p_v.Eval();
var_dom = arith::IntSet::Interval(floordiv(required_min, e), floordiv(required_max, e));
var_bound = arith::IntSet::Interval(0, floordiv(dim_max, e));
// Otherwise only a provided set whose min and max are provably equal is handled.
} else if (analyzer->CanProveEqual(provided_min, provided_max)) {
// Pattern: provided_min is a bare variable, so the required bounds apply to it unchanged.
if (p_v.Match(provided_min)) {
dom_var = p_v.Eval();
dom_min = required_min;
dom_max = required_max;
var = p_v.Eval();
var_dom = arith::IntSet::Interval(required_min, required_max);
var_bound = arith::IntSet::Interval(0, dim_max);
} else {
arith::PVar<PrimExpr> p_f;
// Pattern: provided_min is floordiv(variable, factor).
if ((floordiv(p_v, p_f)).Match(provided_min)) {
// a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
PrimExpr fac = p_f.Eval();
// The rewrite above is only applied when the factor is provably at least 1.
if (analyzer->CanProveGreaterEqual(fac, 1)) {
dom_var = p_v.Eval();
dom_min = required_min * fac;
dom_max = analyzer->Simplify(required_max * fac + fac - 1);
var = p_v.Eval();
var_dom = arith::IntSet::Interval(required_min * fac,
analyzer->Simplify(required_max * fac + fac - 1));
var_bound = arith::IntSet::Interval(0, analyzer->Simplify(dim_max * fac + fac - 1));
}
// Pattern: provided_min is floormod(variable, factor).
} else if ((floormod(p_v, p_f).Match(provided_min))) {
// generally domain of (x % fac) enforce no constraints to domain of x
dom_var = p_v.Eval();
return std::make_pair(dom_var, arith::IntSet::Nothing());
return {p_v.Eval(), BlockVarDomainInfo()};
}
}
}
// Fail loudly if none of the patterns above bound a variable.
ICHECK(dom_var.defined()) << "ValueError: BufferRegion pattern match failed: " << provided_min;
return std::make_pair(dom_var, arith::IntSet::Interval(dom_min, dom_max));
ICHECK(var.defined()) << "ValueError: BufferRegion pattern match failed: " << provided_min;
return {var.value(), BlockVarDomainInfo{var_dom, var_bound}};
}

/*!
* \brief Calculate and update the iteration domain info to fully cover the required domain
* \param provided The provided integer set to cover the required domain
* \param required The required domain to be covered
* \param required_bound The additional region bound of the required domain to be covered
* \brief Calculate and update the iteration domain info to fully cover the required domain in
* dimension-wise fashion. The region relation on each buffer dimension is independently estimated.
* \param buffer The accessed buffer
* \param provided_region The provided NDIntSet to cover the required domain
* \param required_region The required NDIntSet domain to be covered
* \param analyzer The arithmetic analyzer
* \param iter_doms The result iteration domains to be updated
*/
void UpdateBlockVarDomainDimwise(
const BufferNode* buffer, const NDIntSet& provided_region, const NDIntSet& required_region,
arith::Analyzer* analyzer, std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms) {
size_t ndim = buffer->shape.size();
for (size_t i = 0; i < ndim; ++i) {
arith::IntSet provided = provided_region[i];
arith::IntSet required = required_region[i];
PrimExpr dim_max = max(buffer->shape[i] - 1, 0);

if (provided.IsSinglePoint() && is_const_int(provided.min())) {
ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min()));
continue;
}

auto [var, dom_info] = SolveBlockVarDomain(provided, required, dim_max, analyzer);
auto it = iter_doms->find(var.get());
if (it != iter_doms->end()) {
it->second.Union(dom_info);
} else {
ICHECK(analyzer->CanProveEqual(provided.min(), required.min()));
ICHECK(analyzer->CanProveEqual(provided.max(), required.max()));
}
}
}

/*!
 * \brief IntSet flavour of `InverseAffineIterMap`.
 *
 *  The min points and the max points of `outputs` are each mapped back through the
 *  expression-level `InverseAffineIterMap`. For every resulting var, the two images are
 *  ordered with min/max and simplified to form the var's interval.
 * \param iter_map The iter map to be inverted
 * \param outputs The output int sets; each one must have both a lower and an upper bound
 * \param analyzer The arithmetic analyzer
 * \return The interval computed for each var
 */
Map<Var, arith::IntSet> InverseAffineIterMap(const Array<arith::IterSumExpr>& iter_map,
                                             const NDIntSet& outputs, arith::Analyzer* analyzer) {
  Array<PrimExpr> lower_corner, upper_corner;
  lower_corner.reserve(outputs.size());
  upper_corner.reserve(outputs.size());
  for (const auto& intset : outputs) {
    ICHECK(intset.HasLowerBound() && intset.HasUpperBound());
    lower_corner.push_back(intset.min());
    upper_corner.push_back(intset.max());
  }

  auto rev_min = InverseAffineIterMap(iter_map, lower_corner);
  auto rev_max = InverseAffineIterMap(iter_map, upper_corner);

  Map<Var, arith::IntSet> intervals;
  for (const auto& [var, from_lower] : rev_min) {
    auto it = rev_max.find(var);
    ICHECK(it != rev_max.end());  // both inversions are assumed to yield the same set of vars
    const PrimExpr& from_upper = (*it).second;
    // Either image may be the smaller one, so order them explicitly.
    PrimExpr lo = analyzer->Simplify(min(from_lower, from_upper));
    PrimExpr hi = analyzer->Simplify(max(from_lower, from_upper));
    intervals.Set(var, arith::IntSet::Interval(lo, hi));
  }
  return intervals;
}

/*!
* \brief Calculate and update the iteration domain info to fully cover the required domain
* with affine analysis. It requires a bijective mapping from the block vars to the points of
* the provided region.
* \param buffer The accessed buffer
* \param iter_vars The list of block vars to cover the required region
* \param provided_region The provided NDIntSet to cover the required domain
* \param required_region The required NDIntSet domain to be covered
* \param analyzer The arithmetic analyzer
* \param iter_doms The result iteration domains to be updated
* \return Whether the update succeeded. False is returned when some provided set is not a
* single point or when no bijective iter map is detected.
*/
void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required,
const arith::IntSet& required_bound,
std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms,
arith::Analyzer* analyzer) {
if (provided.IsSinglePoint() && is_const_int(provided.min())) {
ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min()));
ICHECK(required_bound.IsSinglePoint() &&
analyzer->CanProveEqual(provided.min(), required_bound.min()));
return;
bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array<IterVar>& iter_vars,
const NDIntSet& provided_region, const NDIntSet& required_region,
arith::Analyzer* analyzer,
std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms) {
// we only support single point provided region now, which could cover most cases
for (const auto& intset : provided_region) {
if (!intset.IsSinglePoint()) return false;
}
// calculate forward mapping (block vars -> provided region point)
Map<Var, Range> dom_map;
for (const IterVar& iter_var : iter_vars) {
dom_map.Set(iter_var->var, iter_var->dom);
}
auto var_with_dom = SolveBlockVarDomain(provided, required, analyzer);
auto var_with_bound = SolveBlockVarDomain(provided, required_bound, analyzer);
const Var& var = var_with_dom.first;
const auto& var_dom = var_with_dom.second;
const auto& var_bound = var_with_bound.second;
ICHECK(var.same_as(var_with_bound.first));
auto it = iter_doms->find(var.get());
if (it != iter_doms->end()) {
it->second.Union({var_dom, var_bound});
} else {
ICHECK(analyzer->CanProveEqual(provided.min(), required.min()));
ICHECK(analyzer->CanProveEqual(provided.max(), required.max()));
size_t ndim = buffer->shape.size();
Array<PrimExpr> provide_indices;
provide_indices.reserve(ndim);
for (size_t i = 0; i < ndim; ++i) {
provide_indices.push_back(provided_region[i].min());
}
// Ask for a bijective iter map over the provided indices; an empty index list in the result
// is treated as a failed detection.
auto res = arith::DetectIterMap(provide_indices, dom_map, const_true(),
arith::IterMapLevel::Bijective, analyzer, false);
if (res->indices.empty()) {
return false;
}
// calculate backward mapping (required region point -> block vars)
// The bound region spans [0, max(shape[i] - 1, 0)] on every buffer dimension.
NDIntSet required_bound;
for (size_t i = 0; i < ndim; ++i) {
required_bound.push_back(
arith::IntSet::Interval(make_zero(buffer->shape[i]->dtype), max(buffer->shape[i] - 1, 0)));
}
// Invert the detected map over both the required region and the bound region.
Map<Var, arith::IntSet> var_dom = InverseAffineIterMap(res->indices, required_region, analyzer);
Map<Var, arith::IntSet> var_bound = InverseAffineIterMap(res->indices, required_bound, analyzer);
for (const auto& kv : var_dom) {
const Var& var = kv.first;
auto it = var_bound.find(var);
ICHECK(it != var_bound.end()); // InverseAffineIterMap's result vars are assumed stable
// operator[] creates a default entry when the var is not in iter_doms yet, then unions into it.
(*iter_doms)[var.get()].Union(BlockVarDomainInfo{kv.second, (*it).second});
}
return true;
}

/*!
Expand Down Expand Up @@ -501,13 +583,10 @@ std::vector<BlockVarDomainInfo> CalculateBlockVarDomain(
NDIntSet provided_region = support::NDIntSetUnion(many_provided_regions);
ICHECK_EQ(provided_region.size(), buffer->shape.size());
ICHECK_EQ(required_region.size(), buffer->shape.size());
// For each dimension, update the iteration domain
int ndim = buffer->shape.size();
for (int i = 0; i < ndim; ++i) {
arith::IntSet provided = provided_region[i];
arith::IntSet required = required_region[i];
arith::IntSet required_bound = arith::IntSet::FromMinExtent(Integer(0), buffer->shape[i]);
UpdateBlockVarDomain(provided, required, required_bound, &iter_doms, analyzer);
// Try update iter var domains with current required and provided region pair.
if (!UpdateBlockVarDomainAffine(buffer, iter_vars, provided_region, required_region, analyzer,
&iter_doms)) {
UpdateBlockVarDomainDimwise(buffer, provided_region, required_region, analyzer, &iter_doms);
}
}
// Union the iter var domains, put them in the same order of block vars, and return
Expand Down
74 changes: 74 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,40 @@ def test_compute_at_tiled_repeat_op(use_block_name):
verify_trace_roundtrip(sch=sch, mod=tiled_repeat_op)


def test_compute_at_rev_iter():
    """Check compute_at when the producer writes with reversed indices.

    Block "b0" writes Y[9 - vi, 9 - vj] and block "b1" reads Y[vj, vi]; "b0" is moved
    under the outermost loop of "b1" and the result is compared with `after`.
    """

    @T.prim_func
    def before(X: T.Buffer[(10, 10), "float32"], Z: T.Buffer[(10, 10), "float32"]):
        Y = T.alloc_buffer([10, 10], "float32")
        for i, j in T.grid(10, 10):
            with T.block("b0"):
                vi, vj = T.axis.remap("SS", [i, j])
                Y[9 - vi, 9 - vj] = X[vi, vj] + 1.0
        for i, j in T.grid(10, 10):
            with T.block("b1"):
                vi, vj = T.axis.remap("SS", [i, j])
                Z[vi, vj] = Y[vj, vi] + 2.0

    @T.prim_func
    def after(X: T.Buffer[(10, 10), "float32"], Z: T.Buffer[(10, 10), "float32"]):
        Y = T.alloc_buffer([10, 10], "float32")
        for i in range(10):
            for j in range(10):
                with T.block("b0"):
                    vi = T.axis.spatial(10, j)
                    vj = T.axis.spatial(10, 9 - i)
                    Y[9 - vi, 9 - vj] = X[vi, vj] + 1.0
            for j in range(10):
                with T.block("b1"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    Z[vi, vj] = Y[vj, vi] + 2.0

    schedule = tir.Schedule(before, debug_mask="all")
    # Look up the consumer and its outermost loop first, then the producer.
    consumer = schedule.get_block("b1")
    outer_loop = schedule.get_loops(consumer)[0]
    producer = schedule.get_block("b0")
    schedule.compute_at(producer, outer_loop)
    tvm.ir.assert_structural_equal(after, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=before)


def test_reverse_compute_at_tiled(use_block_name):
sch = tir.Schedule(tiled, debug_mask="all")
block = sch.get_block("C")
Expand Down Expand Up @@ -1557,5 +1591,45 @@ def main_reverse_compute_at(
tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"])


def test_reverse_compute_at_layout_trans():
    """Check reverse_compute_at for a consumer that reads through floordiv/floormod indices.

    Block "T_layout_trans" reads B at (v_ax1 * 8 + v_ax4) // 16 and (v_ax1 * 8 + v_ax4) % 16;
    it is moved under the second loop of block "compute" and the result is compared with `after`.
    """

    @T.prim_func
    def before(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 5, 8), "float32")):
        B = T.alloc_buffer((1, 3, 5, 5, 16))
        for i0, i1, i2, i3, i4 in T.grid(1, 3, 5, 5, 16):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                B[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4] + T.float32(1)
        for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 6, 5, 5, 8):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                C[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = B[
                    v_ax0, (v_ax1 * 8 + v_ax4) // 16, v_ax2, v_ax3, (v_ax1 * 8 + v_ax4) % 16
                ]

    @T.prim_func
    def after(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 5, 8), "float32")):
        B = T.alloc_buffer((1, 3, 5, 5, 16))
        for i0, i1 in T.grid(1, 3):
            for i2, i3, i4 in T.grid(5, 5, 16):
                with T.block("compute"):
                    v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                    B[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4] + T.float32(1)
            for ax0, ax1, ax2, ax3 in T.grid(2, 5, 5, 8):
                with T.block("T_layout_trans"):
                    v_ax0 = T.axis.spatial(1, 0)
                    v_ax1 = T.axis.spatial(6, i1 * 2 + ax0)
                    v_ax2, v_ax3, v_ax4 = T.axis.remap("SSS", [ax1, ax2, ax3])
                    C[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = B[
                        v_ax0, (v_ax1 * 8 + v_ax4) // 16, v_ax2, v_ax3, (v_ax1 * 8 + v_ax4) % 16
                    ]

    schedule = tir.Schedule(before, debug_mask="all")
    consumer = schedule.get_block("T_layout_trans")
    # Index 1 selects the second loop (i1) around the "compute" block.
    target_loop = schedule.get_loops("compute")[1]
    schedule.reverse_compute_at(consumer, target_loop)
    tvm.ir.assert_structural_equal(after, schedule.mod["main"])
    verify_trace_roundtrip(sch=schedule, mod=before)


if __name__ == "__main__":
tvm.testing.main()