Skip to content

Commit 21ccbf8

Browse files
update to use iter shift instead of itermark min for lowerbound
1 parent 1b1852e commit 21ccbf8

File tree

7 files changed

+151
-154
lines changed

7 files changed

+151
-154
lines changed

include/tvm/arith/iter_affine_map.h

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class IterMapExpr : public PrimExpr {
8383
};
8484

8585
/*!
86-
* \brief Mark the source as an iterator in [min, extent).
86+
* \brief Mark the source as an iterator in [0, extent).
8787
*
8888
* IterMark is used to mark source expression as a valid
8989
* iterator to make future analysis easy.
@@ -95,10 +95,6 @@ class IterMarkNode : public Object {
9595
* a IterSumExpr or a Var.
9696
*/
9797
PrimExpr source;
98-
/*!
99-
* \brief The min of the iteration.
100-
*/
101-
PrimExpr min;
10298
/*!
10399
* \brief The extent of the iteration.
104100
*/
@@ -107,19 +103,17 @@ class IterMarkNode : public Object {
107103
// overrides
108104
void VisitAttrs(tvm::AttrVisitor* v) {
109105
v->Visit("source", &source);
110-
v->Visit("min", &min);
111106
v->Visit("extent", &extent);
112107
}
113108

114109
bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const {
115110
equal->MarkGraphNode();
116-
return equal(source, other->source) && equal(extent, other->extent) && equal(min, other->min);
111+
return equal(source, other->source) && equal(extent, other->extent);
117112
}
118113

119114
void SHashReduce(SHashReducer hash_reduce) const {
120115
hash_reduce->MarkGraphNode();
121116
hash_reduce(source);
122-
hash_reduce(min);
123117
hash_reduce(extent);
124118
}
125119

@@ -138,10 +132,9 @@ class IterMark : public ObjectRef {
138132
/*!
139133
* \brief constructor.
140134
* \param source The source expression.
141-
* \param min The min of the iterator.
142135
* \param extent The extent of the iterator.
143136
*/
144-
TVM_DLL IterMark(PrimExpr source, PrimExpr min, PrimExpr extent);
137+
TVM_DLL IterMark(PrimExpr source, PrimExpr extent);
145138

146139
TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode);
147140
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMarkNode);

python/tvm/arith/iter_affine_map.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,12 @@ class IterMark(Object):
3434
source : PrimExpr.
3535
The source expression.
3636
37-
min_value : PrimExpr
38-
The min of the iterator.
39-
4037
extent : PrimExpr
4138
The extent of the iterator.
4239
"""
4340

44-
def __init__(self, source, min_value, extent):
45-
self.__init_handle_by_constructor__(_ffi_api.IterMark, source, min_value, extent)
41+
def __init__(self, source, extent):
42+
self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent)
4643

4744

4845
@tvm._ffi.register_object("arith.IterSplitExpr")

src/arith/int_set.cc

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -839,9 +839,9 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
839839
iter_sum_exprs = DetectIterMap(
840840
/*indices=*/affine_indices, /*input_iters=*/var_dom,
841841
/*predicate=*/predicate, /*require_bijective=*/false, analyzer, diag_ctx);
842-
if (iter_sum_exprs.empty()) {
843-
return NullOpt;
844-
}
842+
}
843+
if (iter_sum_exprs.empty()) {
844+
return NullOpt;
845845
}
846846
ICHECK_EQ(iter_sum_exprs.size(), ndim);
847847
Array<IntSet> result;
@@ -858,27 +858,12 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
858858
if (!analyzer->CanProve(range->extent >= split->scale)) {
859859
return NullOpt;
860860
}
861-
if (is_zero(split->source->min)) {
862-
// IterSplitExpr: (source // lower_factor) % extent * scale
863-
// thus `(source // lower_factor) % extent` is within [0, extent - 1]
864-
// Therefore, the range of `region[i]->min` is `base + [0, (extent - 1) * scale]`
865-
const PrimExpr& base = sum_expr->base;
866-
result.push_back(IntSet::FromMinExtent(
867-
base, split->extent * split->scale + (range->extent - split->scale)));
868-
} else {
869-
// if mark_min is not zero, currently only process cases where
870-
// `scale == 1` and `lower_factor == 1` and `extent >= mark_extent`
871-
const IterMark& mark = split->source;
872-
arith::Analyzer analyzer;
873-
if (!analyzer.CanProveEqual(split->scale, 1) ||
874-
!analyzer.CanProveEqual(split->lower_factor, 1) ||
875-
!analyzer.CanProve(mark->extent - split->extent <= 0)) {
876-
return NullOpt;
877-
}
878-
PrimExpr region_min = analyzer.Simplify(sum_expr->base + split->source->min, 3);
879-
PrimExpr region_max = analyzer.Simplify(region_min + mark->extent - 1, 3);
880-
result.push_back(IntSet::Interval(region_min, region_max));
881-
}
861+
const PrimExpr& base = sum_expr->base;
862+
// IterSplitExpr: (source // lower_factor) % extent * scale
863+
// where `(source // lower_factor) % extent` is within [0, extent - 1]
864+
// Therefore, the range of `region[i]->min` is `base + [0, (extent - 1) * scale]`
865+
result.push_back(
866+
IntSet::FromMinExtent(base, split->extent * split->scale + (range->extent - split->scale)));
882867
}
883868
return result;
884869
}

0 commit comments

Comments
 (0)