Skip to content

Commit 700b702

Browse files
merge DetectIterMap and DetectIterMapPadded
1 parent 5a2d333 commit 700b702

File tree

14 files changed

+419
-299
lines changed

14 files changed

+419
-299
lines changed

include/tvm/arith/iter_affine_map.h

Lines changed: 45 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -259,53 +259,29 @@ class IterSumExpr : public IterMapExpr {
259259
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode);
260260
};
261261

262+
/*! \brief Mapping level for iterators. */
263+
enum IterMapLevel {
264+
// Require the mapping to be bijective.
265+
Bijective = 0,
266+
// Require the mapping to be subjective.
267+
Surjective = 1,
268+
// Require the mapping to be injective.
269+
Injective = 2
270+
};
271+
262272
/*!
263-
* \brief Detect if indices can be written as
264-
* [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
265-
*
266-
* Here y = some-quasi-affine-iter-map(input_iters)
267-
* and c are symbolic constants.
268-
*
269-
* We also requires that y_i and y_j to be independent for i != j.
270-
*
271-
* For returned value rv, the following is always true:
272-
* - rv[i]->args.size() <=1: only one iterator per element.
273-
*
274-
* \param indices The indices to detect pattern for.
275-
* \param input_iters Map from variable to iterator's range.
276-
* \param predicate The predicate constraints on the input iterators
277-
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
278-
* \param analyzer Analyzer used to get context information.
279-
* \param simplify_trivial_iterators If true, iterators with extent of
280-
* 1 will be replaced with a constant value.
281-
*
282-
* \return The detected pattern if a match exists,
283-
* otherwise return an empty array.
273+
* \brief Result of DetectIterMap.
284274
*/
285-
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
286-
const PrimExpr& predicate, bool require_bijective,
287-
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
275+
class IterMapResultNode : public Object {
276+
public:
277+
// The detected pattern if a match exists.
278+
Array<IterSumExpr> indices;
288279

289-
/*! \brief A utility struct for return values from DetectPaddedIterMap
290-
*/
291-
struct PaddedIterMapResult {
292280
// Any errors that occurred while converting the input indices. If
293281
// the array is empty, the conversion was successful.
294282
Array<String> errors;
295283

296-
// The detected pattern if a match exists.
297-
Array<IterSumExpr> indices;
298-
299-
/* \brief Boolean expression indicating if padding was required
300-
*
301-
* `requires_padding` evaluates to true if the returned indices
302-
* contain padding relative to the provided expressions, and false
303-
* otherwise. If `input_iters` contains a variable extent, this
304-
* expression may be in terms of those variables.
305-
*/
306-
PrimExpr requires_padding;
307-
308-
/* \brief Boolean expression indicating if a specific value w
284+
/*! \brief Boolean expression indicating if a specific value w
309285
*
310286
* `padding_predicate` evaluates to true for a set of indices that
311287
* are outside the bounds of the provided index iterators, but
@@ -314,43 +290,54 @@ struct PaddedIterMapResult {
314290
* `input_iters`.
315291
*/
316292
PrimExpr padding_predicate;
293+
294+
// overrides
295+
void VisitAttrs(tvm::AttrVisitor* v) {
296+
v->Visit("errors", &errors);
297+
v->Visit("indices", &indices);
298+
v->Visit("padding_predicate", &padding_predicate);
299+
}
300+
301+
static constexpr const char* _type_key = "arith.IterMapResult";
302+
TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object);
303+
};
304+
305+
/*!
306+
* \brief Managed reference to IterMapResultNode.
307+
* \sa IterMapResultNode
308+
*/
309+
class IterMapResult : public ObjectRef {
310+
public:
311+
TVM_DEFINE_OBJECT_REF_METHODS(IterMapResult, ObjectRef, IterMapResultNode);
312+
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMapResultNode);
317313
};
318314

319315
/*!
320316
* \brief Detect if indices can be written as
321317
* [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
322318
*
323-
* Here y = some-quasi-affine-iter-map(input_iters) and c are
324-
* symbolic constants. The y_i iterators may be padded to fit this
325-
* representation.
319+
* Here y = some-quasi-affine-iter-map(input_iters)
320+
* and c are symbolic constants.
326321
*
327322
* We also requires that y_i and y_j to be independent for i != j.
328323
*
329324
* For returned value rv, the following is always true:
330-
* - rv.indices[i]->args.size() <=1: only one iterator per element.
325+
* - rv[i]->args.size() <=1: only one iterator per element.
331326
*
332327
* \param indices The indices to detect pattern for.
333-
*
334328
* \param input_iters Map from variable to iterator's range.
335-
*
336329
* \param predicate The predicate constraints on the input iterators
337-
*
338-
* \param require_bijective A boolean flag that indicates whether the
339-
* mapping should be bijective. If true, no padding may be
340-
* introduced.
341-
*
330+
* \param check_level The iter mapping check level.
342331
* \param analyzer Analyzer used to get context information.
343-
*
344332
* \param simplify_trivial_iterators If true, iterators with extent of
345333
* 1 will be replaced with a constant value.
346334
*
347-
* \return An instance of PaddedIterMapResult.
335+
* \return The detected iteration result.
336+
* The return object's .indices is empty on failure.
348337
*/
349-
PaddedIterMapResult DetectPaddedIterMap(const Array<PrimExpr>& indices,
350-
const Map<Var, Range>& input_iters,
351-
const PrimExpr& predicate, bool require_bijective,
352-
arith::Analyzer* analyzer,
353-
bool simplify_trivial_iterators = true);
338+
IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
339+
const PrimExpr& predicate, IterMapLevel check_level,
340+
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
354341

355342
/*!
356343
* \brief Use IterVarMap detector to rewrite and simplify the indices

python/tvm/arith/iter_affine_map.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,14 @@ def detect_iter_map(
117117
118118
Returns
119119
-------
120-
results : List[IterSumExpr]
120+
results : IterMapResult
121121
The iter map matching result.
122-
Empty array if no match can be found.
122+
The result's .indices is empty array if no match can be found.
123123
124124
"""
125125
return _ffi_api.DetectIterMap(
126126
indices, input_iters, predicate, require_bijective, simplify_trivial_iterators
127-
)
127+
).indices
128128

129129

130130
def normalize_iter_map_to_expr(expr):

src/arith/int_set.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -867,9 +867,10 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
867867
for (const Range& range : region) {
868868
affine_indices.push_back(range->min);
869869
}
870-
iter_sum_exprs = DetectIterMap(
870+
auto res = DetectIterMap(
871871
/*indices=*/affine_indices, /*input_iters=*/var_dom,
872-
/*predicate=*/predicate, /*require_bijective=*/false, analyzer);
872+
/*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer);
873+
iter_sum_exprs = res->indices;
873874
}
874875
if (iter_sum_exprs.empty()) {
875876
return NullOpt;

0 commit comments

Comments
 (0)