@@ -259,53 +259,29 @@ class IterSumExpr : public IterMapExpr {
259259 TVM_DEFINE_OBJECT_REF_COW_METHOD (IterSumExprNode);
260260};
261261
262+ /* ! \brief Mapping level for iterators. */
263+ enum IterMapLevel {
264+ // Require the mapping to be bijective.
265+ Bijective = 0 ,
266+ // Require the mapping to be subjective.
267+ Surjective = 1 ,
268+ // Require the mapping to be injective.
269+ Injective = 2
270+ };
271+
262272/* !
263- * \brief Detect if indices can be written as
264- * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
265- *
266- * Here y = some-quasi-affine-iter-map(input_iters)
267- * and c are symbolic constants.
268- *
269- * We also requires that y_i and y_j to be independent for i != j.
270- *
271- * For returned value rv, the following is always true:
272- * - rv[i]->args.size() <=1: only one iterator per element.
273- *
274- * \param indices The indices to detect pattern for.
275- * \param input_iters Map from variable to iterator's range.
276- * \param predicate The predicate constraints on the input iterators
277- * \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
278- * \param analyzer Analyzer used to get context information.
279- * \param simplify_trivial_iterators If true, iterators with extent of
280- * 1 will be replaced with a constant value.
281- *
282- * \return The detected pattern if a match exists,
283- * otherwise return an empty array.
273+ * \brief Result of DetectIterMap.
284274 */
285- Array<IterSumExpr> DetectIterMap (const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
286- const PrimExpr& predicate, bool require_bijective,
287- arith::Analyzer* analyzer, bool simplify_trivial_iterators = true );
275+ class IterMapResultNode : public Object {
276+ public:
277+ // The detected pattern if a match exists.
278+ Array<IterSumExpr> indices;
288279
289- /* ! \brief A utility struct for return values from DetectPaddedIterMap
290- */
291- struct PaddedIterMapResult {
292280 // Any errors that occurred while converting the input indices. If
293281 // the array is empty, the conversion was successful.
294282 Array<String> errors;
295283
296- // The detected pattern if a match exists.
297- Array<IterSumExpr> indices;
298-
299- /* \brief Boolean expression indicating if padding was required
300- *
301- * `requires_padding` evaluates to true if the returned indices
302- * contain padding relative to the provided expressions, and false
303- * otherwise. If `input_iters` contains a variable extent, this
304- * expression may be in terms of those variables.
305- */
306- PrimExpr requires_padding;
307-
308- /* \brief Boolean expression indicating if a specific value w
284+ /* ! \brief Boolean expression indicating if a specific value w
309285 *
310286 * `padding_predicate` evaluates to true for a set of indices that
311287 * are outside the bounds of the provided index iterators, but
@@ -314,43 +290,54 @@ struct PaddedIterMapResult {
314290 * `input_iters`.
315291 */
316292 PrimExpr padding_predicate;
293+
294+ // overrides
295+ void VisitAttrs (tvm::AttrVisitor* v) {
296+ v->Visit (" errors" , &errors);
297+ v->Visit (" indices" , &indices);
298+ v->Visit (" padding_predicate" , &padding_predicate);
299+ }
300+
301+ static constexpr const char * _type_key = " arith.IterMapResult" ;
302+ TVM_DECLARE_FINAL_OBJECT_INFO (IterMapResultNode, Object);
303+ };
304+
305+ /* !
306+ * \brief Managed reference to IterMapResultNode.
307+ * \sa IterMapResultNode
308+ */
309+ class IterMapResult : public ObjectRef {
310+ public:
311+ TVM_DEFINE_OBJECT_REF_METHODS (IterMapResult, ObjectRef, IterMapResultNode);
312+ TVM_DEFINE_OBJECT_REF_COW_METHOD (IterMapResultNode);
317313};
318314
319315/* !
320316 * \brief Detect if indices can be written as
321317 * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
322318 *
323- * Here y = some-quasi-affine-iter-map(input_iters) and c are
324- * symbolic constants. The y_i iterators may be padded to fit this
325- * representation.
319+ * Here y = some-quasi-affine-iter-map(input_iters)
320+ * and c are symbolic constants.
326321 *
327322 * We also requires that y_i and y_j to be independent for i != j.
328323 *
329324 * For returned value rv, the following is always true:
330- * - rv.indices [i]->args.size() <=1: only one iterator per element.
325+ * - rv[i]->args.size() <=1: only one iterator per element.
331326 *
332327 * \param indices The indices to detect pattern for.
333- *
334328 * \param input_iters Map from variable to iterator's range.
335- *
336329 * \param predicate The predicate constraints on the input iterators
337- *
338- * \param require_bijective A boolean flag that indicates whether the
339- * mapping should be bijective. If true, no padding may be
340- * introduced.
341- *
330+ * \param check_level The iter mapping check level.
342331 * \param analyzer Analyzer used to get context information.
343- *
344332 * \param simplify_trivial_iterators If true, iterators with extent of
345333 * 1 will be replaced with a constant value.
346334 *
347- * \return An instance of PaddedIterMapResult.
335+ * \return The detected iteration result.
336+ * The return object's .indices is empty on failure.
348337 */
349- PaddedIterMapResult DetectPaddedIterMap (const Array<PrimExpr>& indices,
350- const Map<Var, Range>& input_iters,
351- const PrimExpr& predicate, bool require_bijective,
352- arith::Analyzer* analyzer,
353- bool simplify_trivial_iterators = true );
338+ IterMapResult DetectIterMap (const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
339+ const PrimExpr& predicate, IterMapLevel check_level,
340+ arith::Analyzer* analyzer, bool simplify_trivial_iterators = true );
354341
355342/* !
356343 * \brief Use IterVarMap detector to rewrite and simplify the indices
0 commit comments