Skip to content

Commit a60cd0f

Browse files
authored
[TIR] Allow symbolic bounds in IndexMap analysis (#15264)
This PR adds the bounds of shape variables to the arithmetic analyzer so that it is possible to simplify certain expressions.
1 parent 0c1aad7 commit a60cd0f

35 files changed

+411
-181
lines changed

include/tvm/tir/index_map.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ class IndexMapNode : public Object {
102102
* \returns The indices in the output space. Contains one value for
103103
* each expression in `final_indices`.
104104
*/
105-
Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices,
106-
arith::Analyzer* analyzer = nullptr) const;
105+
Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices, arith::Analyzer* analyzer) const;
107106

108107
/*! \brief Map a memory range to the output space
109108
*
@@ -121,7 +120,7 @@ class IndexMapNode : public Object {
121120
* \returns The ranges in the output space. Contains one value for
122121
* each expression in `final_indices`.
123122
*/
124-
Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer = nullptr) const;
123+
Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer) const;
125124

126125
/*! \brief Map a buffer shape to the output space
127126
*
@@ -134,7 +133,7 @@ class IndexMapNode : public Object {
134133
* \returns The buffer shape in the output space. Contains one
135134
* value for each expression in `final_indices`.
136135
*/
137-
Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* analyzer = nullptr) const;
136+
Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* analyzer) const;
138137

139138
/*! \brief Map an NDArray according to this index map
140139
*
@@ -203,7 +202,7 @@ class IndexMap : public ObjectRef {
203202
* If the user has supplied an `inverse_index_map`, that map is
204203
* assumed to be correct and bijective, and is returned.
205204
*/
206-
IndexMap Inverse(Array<Range> initial_ranges) const;
205+
IndexMap Inverse(Array<Range> initial_ranges, arith::Analyzer* analyzer) const;
207206

208207
/*! \brief Rename the variables in the index map and ensure the names are unique.
209208
*
@@ -225,7 +224,8 @@ class IndexMap : public ObjectRef {
225224
* \return The inverted index map, along with the predicate for
226225
* which the inverse maps to a valid range.
227226
*/
228-
std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> initial_ranges) const;
227+
std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> initial_ranges,
228+
arith::Analyzer* analyzer) const;
229229

230230
TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
231231
};

include/tvm/topi/transform.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#ifndef TVM_TOPI_TRANSFORM_H_
2525
#define TVM_TOPI_TRANSFORM_H_
2626

27+
#include <tvm/arith/analyzer.h>
2728
#include <tvm/te/operation.h>
2829
#include <tvm/tir/data_layout.h>
2930
#include <tvm/tir/index_map.h>
@@ -1738,16 +1739,18 @@ inline Tensor auto_scheduler_layout_transform(const Tensor& src, const String& s
17381739
inline Tensor meta_schedule_layout_transform(const Tensor& src, const tir::IndexMap& index_map,
17391740
const String name = "T_meta_schedule_layout_trans",
17401741
const String tag = kInjective) {
1742+
arith::Analyzer analyzer;
17411743
Array<Range> iter_domain;
17421744
iter_domain.reserve(src->shape.size());
17431745
for (const PrimExpr& e : src->shape) {
17441746
iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
17451747
}
1746-
Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape);
1748+
Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape, &analyzer);
17471749
return compute(
17481750
post_transform_shape,
1749-
[src, inv = index_map.Inverse(iter_domain)](const Array<Var>& indices) -> PrimExpr {
1750-
return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), indices.end()}));
1751+
[src, inv = index_map.Inverse(iter_domain, &analyzer),
1752+
&analyzer](const Array<Var>& indices) -> PrimExpr {
1753+
return src(inv->MapIndices(Array<PrimExpr>{indices.begin(), indices.end()}, &analyzer));
17511754
},
17521755
name, tag);
17531756
}

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
[tool.isort]
18+
profile = "black"
19+
src_paths = ["python", "tests/python"]
1720

1821
[tool.black]
1922
line-length = 100

python/tvm/te/schedule.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@
2222

2323
import tvm._ffi
2424
from tvm._ffi.base import string_types
25-
26-
from tvm.runtime import Object, convert
2725
from tvm.ir import container as _container
28-
from tvm.tir import IterVar, Buffer, Var, IndexMap
26+
from tvm.runtime import Object, convert
27+
from tvm.tir import Buffer, IndexMap, IterVar, Var
2928

30-
from . import tensor as _tensor
3129
from . import _ffi_api
30+
from . import tensor as _tensor
3231

3332

3433
@tvm._ffi.register_object
@@ -600,7 +599,9 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
600599
"""
601600

602601
ndim = len(self.op.output(0).shape)
603-
index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim=ndim)
602+
index_map, axis_separators = IndexMap.from_func_with_separators(
603+
mapping_function, ndim=ndim, index_dtype="int32"
604+
)
604605

605606
new_iter_vars = _ffi_api.StageTransformLayout(
606607
self, index_map.initial_indices, index_map.final_indices

python/tvm/tir/function.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def __init__(
6767
attrs=None,
6868
span=None,
6969
):
70-
7170
param_list = []
7271
buffer_map = {} if buffer_map is None else buffer_map
7372
for x in params:
@@ -266,6 +265,8 @@ def from_func(
266265
mapping_function: Callable,
267266
ndim: Optional[int] = None,
268267
inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
268+
*,
269+
index_dtype: str = "int64",
269270
):
270271
"""Create an index map from a function
271272
@@ -302,7 +303,10 @@ def from_func(
302303
303304
"""
304305
index_map, axis_separators = IndexMap.from_func_with_separators(
305-
mapping_function, ndim, inverse_index_map
306+
mapping_function,
307+
ndim,
308+
inverse_index_map,
309+
index_dtype=index_dtype,
306310
)
307311
assert not axis_separators, (
308312
"The mapping_function provided to IndexMap.from_func "
@@ -316,6 +320,8 @@ def from_func_with_separators(
316320
mapping_function: Callable,
317321
ndim: Optional[int] = None,
318322
inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
323+
*,
324+
index_dtype: str = "int64",
319325
):
320326
"""Create an index map from a function
321327
@@ -346,6 +352,9 @@ def from_func_with_separators(
346352
It is the user's responsibility to ensure the correctness of the pre-defined inverse
347353
index map.
348354
355+
index_dtype : str
356+
The default index dtype to use for input iters in the mapping function.
357+
349358
Returns
350359
-------
351360
ret: Tuple[IndexMap, List[int]]
@@ -361,20 +370,19 @@ def from_func_with_separators(
361370
args = []
362371
var_arg_name = None
363372
kwargs = collections.OrderedDict()
364-
default_index_dtype = "int32"
365373

366374
for name, param in params.items():
367375
if param.kind in [
368376
inspect.Parameter.POSITIONAL_ONLY,
369377
inspect.Parameter.POSITIONAL_OR_KEYWORD,
370378
]:
371-
args.append(tvm.tir.Var(name, default_index_dtype))
379+
args.append(tvm.tir.Var(name, index_dtype))
372380

373381
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
374382
var_arg_name = name
375383

376384
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
377-
kwargs[name] = tvm.tir.Var(name, default_index_dtype)
385+
kwargs[name] = tvm.tir.Var(name, index_dtype)
378386

379387
else:
380388
raise ValueError("transform_layout mapping may not have *args")
@@ -386,7 +394,7 @@ def from_func_with_separators(
386394
assert ndim is not None, "ndim must be specified when *args is used"
387395
num_var_args = ndim - len(args) - len(kwargs)
388396
for i in range(num_var_args):
389-
args.append(tvm.tir.Var(f"{var_arg_name}_{i}", default_index_dtype))
397+
args.append(tvm.tir.Var(f"{var_arg_name}_{i}", index_dtype))
390398

391399
mapping = mapping_function(*args, **kwargs)
392400

python/tvm/tir/schedule/schedule.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ def _parse_seed(seed: Optional[int]) -> int:
9393
return seed
9494

9595

96+
def _get_block_default_dtype(block: Block) -> str:
97+
for i in block.iter_vars:
98+
return i.var.dtype
99+
for buffer_region in list(block.reads) + list(block.writes):
100+
for dom in buffer_region.region:
101+
return dom.min.dtype
102+
return "int64"
103+
104+
96105
@_register_object("tir.Schedule")
97106
class Schedule(Object):
98107
"""The user-facing schedule class
@@ -1492,7 +1501,10 @@ def after_reindex_cache_read(a: T.handle, b: T.handle) -> None:
14921501
block = self._normalize_block_arg(block)
14931502

14941503
if callable(index_map):
1495-
index_map = IndexMap.from_func(index_map)
1504+
index_map = IndexMap.from_func(
1505+
index_map,
1506+
index_dtype=_get_block_default_dtype(self.get(block)),
1507+
)
14961508
return _ffi_api.ScheduleReindexCacheRead( # type: ignore # pylint: disable=no-member
14971509
self, block, read_buffer_index, storage_scope, index_map
14981510
)
@@ -1589,7 +1601,10 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
15891601
block = self._normalize_block_arg(block)
15901602

15911603
if callable(index_map):
1592-
index_map = IndexMap.from_func(index_map)
1604+
index_map = IndexMap.from_func(
1605+
index_map,
1606+
index_dtype=_get_block_default_dtype(self.get(block)),
1607+
)
15931608
return _ffi_api.ScheduleReindexCacheWrite( # type: ignore # pylint: disable=no-member
15941609
self, block, write_buffer_index, storage_scope, index_map
15951610
)
@@ -3246,14 +3261,22 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
32463261

32473262
ndim = len(buffer_obj.shape)
32483263
if callable(index_map):
3249-
index_map, axis_separators = IndexMap.from_func_with_separators(index_map, ndim=ndim)
3264+
index_map, axis_separators = IndexMap.from_func_with_separators(
3265+
index_map,
3266+
ndim=ndim,
3267+
index_dtype=_get_block_default_dtype(self.get(block)),
3268+
)
32503269
else:
32513270
axis_separators = []
32523271

32533272
if pad_value is None:
32543273
pass
32553274
elif callable(pad_value):
3256-
pad_value = IndexMap.from_func(pad_value, ndim=len(index_map.final_indices))
3275+
pad_value = IndexMap.from_func(
3276+
pad_value,
3277+
ndim=len(index_map.final_indices),
3278+
index_dtype=_get_block_default_dtype(self.get(block)),
3279+
)
32573280
elif not isinstance(pad_value, IndexMap):
32583281
# Explicitly convert python int/float arguments to the
32593282
# buffer's type. If the default `tvm.runtime.convert`
@@ -3264,7 +3287,9 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
32643287
elif "float" in buffer_obj.dtype and isinstance(pad_value, float):
32653288
pad_value = FloatImm(buffer_obj.dtype, pad_value)
32663289
pad_value = IndexMap.from_func(
3267-
lambda *indices: pad_value, ndim=len(index_map.final_indices)
3290+
lambda *indices: pad_value,
3291+
ndim=len(index_map.final_indices),
3292+
index_dtype=_get_block_default_dtype(self.get(block)),
32683293
)
32693294

32703295
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
@@ -3337,7 +3362,10 @@ def after_transform_block_layout(
33373362
"""
33383363
block = self._normalize_block_arg(block)
33393364
if callable(index_map):
3340-
index_map = IndexMap.from_func(index_map)
3365+
index_map = IndexMap.from_func(
3366+
index_map,
3367+
index_dtype=_get_block_default_dtype(self.get(block)),
3368+
)
33413369
_ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint: disable=no-member
33423370
self, block, index_map
33433371
)

python/tvm/tir/schedule/testing.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def verify_trace_roundtrip(
5151
The text format or formats whose round-trip behavior should be
5252
validated. If a single string, validate round-trips through
5353
"""
54+
from tvm.script import tir as T # pylint: disable=import-outside-toplevel
55+
5456
if not isinstance(text_format, str):
5557
for opt in text_format:
5658
new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask, text_format=opt)
@@ -66,7 +68,9 @@ def verify_trace_roundtrip(
6668
Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
6769
elif text_format == "python":
6870
py_trace = "\n".join(trace.as_python())
69-
exec(py_trace, tvm.tir.__dict__, {"sch": new_sch}) # pylint: disable=exec-used
71+
vars_dict = {"T": T}
72+
vars_dict.update(tvm.tir.__dict__)
73+
exec(py_trace, vars_dict, {"sch": new_sch}) # pylint: disable=exec-used
7074
else:
7175
assert text_format in ("json", "python"), f"Unknown text format: {text_format}"
7276

src/arith/ir_mutator_with_analyzer.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
*/
2323
#include "ir_mutator_with_analyzer.h"
2424

25+
#include <tvm/arith/iter_affine_map.h>
2526
#include <tvm/tir/analysis.h>
2627
#include <tvm/tir/op.h>
2728

@@ -39,6 +40,25 @@ void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tir::PrimFunc& func) {
3940
}
4041
}
4142

43+
Array<PrimExpr> IRMutatorWithAnalyzer::IterMapSimplifyWithContext(const Array<PrimExpr>& indices,
44+
bool non_trivial_only) {
45+
PrimExpr pred = const_true();
46+
for (PrimExpr val : iter_predicates_) {
47+
pred = pred && val;
48+
}
49+
int n = indices.size();
50+
Array<PrimExpr> simplified = arith::IterMapSimplify(
51+
indices, this->iter_vars_, pred, arith::IterMapLevel::Surjective, this->analyzer_);
52+
if (non_trivial_only) {
53+
for (int i = 0; i < n; ++i) {
54+
if (simplified[i]->IsInstance<IntImmNode>() && indices[i]->IsInstance<VarNode>()) {
55+
simplified.Set(i, indices[i]);
56+
}
57+
}
58+
}
59+
return simplified;
60+
}
61+
4262
Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
4363
// record the loop variable as iterators
4464
Range dom = Range::FromMinExtent(op->min, op->extent);

src/arith/ir_mutator_with_analyzer.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
7070
*/
7171
void MarkBufferMapShapes(const tir::PrimFunc& func);
7272

73+
/*!
74+
* \brief Use internal bound information to perform iter map simplification of indices.
75+
* \note Only do this during layout remapping
76+
*/
77+
Array<PrimExpr> IterMapSimplifyWithContext(const Array<PrimExpr>& indices, bool non_trivial_only);
78+
7379
/*! \brief internal analyzer field. */
7480
Analyzer* analyzer_;
7581
// the following two fields are useful in case we want

src/arith/iter_affine_map.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2061,6 +2061,12 @@ class IterMapToExprNormalizer : public ExprMutator {
20612061
if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) {
20622062
return source * expr->scale;
20632063
} else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) {
2064+
// Simplify if `expr` is always 0. The 2nd condition guarantees that we do not aggressively
2065+
// simplify trivial iters like `vi \in [0, 1)`, which can be useful for subsequent analysis
2066+
// like tensorization.
2067+
if (is_one(expr->extent) && !is_one(expr->source->extent)) {
2068+
return make_const(expr->extent->dtype, 0);
2069+
}
20642070
return floordiv(source, expr->lower_factor) * expr->scale;
20652071
} else {
20662072
return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) *

0 commit comments

Comments
 (0)