Skip to content

Commit f422a7c

Browse files
committed
[ARITH] Enhance IterMapSimplify for symbolic
This PR refactors and enhances DetectIterMap and IterMapSimplify to enable symbolic shape simplification. Specifically, we add a routine to combine multiple IterSplitExpr into one if they come from the same source. It is helpful to distinguish iterator from normal constants in the simplification process. IterMapSimplify takes advantage of these information. This improvements is helpful to simplify the indices in flattened buffer when there is symbolic shape involved and normal simplifier. Also updated FlattenBuffer to take benefit of the enhanced simplifier. Test cases are added.
1 parent 4a919b4 commit f422a7c

File tree

11 files changed

+541
-87
lines changed

11 files changed

+541
-87
lines changed

include/tvm/arith/iter_affine_map.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,12 +349,13 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
349349
* \param input_iters Map from variable to iterator's range.
350350
* \param input_pred The predicate constraints on the input iterators
351351
* \param check_level The iter mapping checking level.
352+
* \param analyzer Analyzer used to get context information.
352353
* \param simplify_trivial_iterators If true, iterators with unit extents are simplified
353354
* \return The indices after rewrite
354355
*/
355356
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
356357
const PrimExpr& input_pred, IterMapLevel check_level,
357-
bool simplify_trivial_iterators = true);
358+
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
358359

359360
/*!
360361
* \brief Apply the inverse of the affine transformation to the outputs.

python/tvm/arith/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
3131
from .iter_affine_map import (
3232
detect_iter_map,
33+
iter_map_simplify,
3334
normalize_iter_map_to_expr,
3435
subspace_divide,
3536
inverse_affine_iter_map,

python/tvm/arith/iter_affine_map.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,49 @@ def detect_iter_map(
156156
)
157157

158158

159+
def iter_map_simplify(
160+
indices,
161+
input_iters,
162+
predicate=True,
163+
check_level=IterMapLevel.Surjective,
164+
simplify_trivial_iterators=True,
165+
):
166+
"""Simplify the indices using iter map detection.
167+
168+
Parameters
169+
----------
170+
indices : List[PrimExpr]
171+
The input indices
172+
173+
input_iters : Map[Var, Range]
174+
The domain of each input iterators.
175+
176+
predicate : PrimExpr
177+
The predicate constraints on the input iterators
178+
179+
check_level : Union[str, IterMapLevel]
180+
Checking level of iteration mapping
181+
182+
simplify_trivial_iterators: bool
183+
If true, iterators with extent of 1 will be replaced with a
184+
constant value.
185+
186+
Returns
187+
-------
188+
results : IterMapResult
189+
The iter map matching result.
190+
The result's .indices is empty array if no match can be found.
191+
192+
"""
193+
if isinstance(check_level, str):
194+
check_level = IterMapLevel.from_str(check_level)
195+
elif check_level is None:
196+
check_level = IterMapLevel.NoCheck
197+
return _ffi_api.IterMapSimplify(
198+
indices, input_iters, predicate, check_level, simplify_trivial_iterators
199+
)
200+
201+
159202
def normalize_iter_map_to_expr(expr):
160203
"""Given an IterMapExpr, transform it to normal PrimExpr
161204

src/arith/canonical_simplify.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include "const_fold.h"
2929
#include "pattern_match.h"
30+
#include "product_normal_form.h"
3031
#include "rewrite_simplify.h"
3132

3233
namespace tvm {
@@ -808,12 +809,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
808809
}
809810

810811
// normal path.
812+
// this only happens when b is symbolic
811813
a = Normalize(a);
812814
b = Normalize(b);
813-
if (op->a.same_as(a) && op->b.same_as(b)) {
815+
816+
PrimExpr ret = MulAndNormalize(a, b);
817+
const MulNode* mul = ret.as<MulNode>();
818+
819+
if (mul && mul->a.same_as(op->a) && mul->b.same_as(op->b)) {
814820
return GetRef<PrimExpr>(op);
815821
} else {
816-
return Mul(a, b);
822+
return ret;
817823
}
818824
}
819825

0 commit comments

Comments
 (0)