Skip to content

Commit fe1090e

Browse files
authored
[TIR] IndexMap Simplification Constraints (#11342)
* [TIR] Added optional arith::Analyzer argument to IndexMap methods

  Simplifications done when applying a transformation may require
  iteration bounds from the caller scope. This is a C++ only feature,
  because `arith::Analyzer` doesn't inherit from `ObjectRef`, and
  cannot be passed through the FFI.

* [TIR] Pass analyzer from TransformLayoutRewriter to IndexMap

  Avoid needing to simplify twice, now that IndexMap can accept the
  analyzer from the calling scope.

* [TIR] Added BlockNode handling to IRMutatorWithAnalyzer

  Iteration variables defined in `BlockNode::iter_vars` may be useful
  for simplifications. This functionality was extracted from
  `TransformLayoutRewriter`.
1 parent 2b1e5ce commit fe1090e

File tree

5 files changed

+70
-30
lines changed

5 files changed

+70
-30
lines changed

include/tvm/tir/index_map.h

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@
3333

3434
#include <utility>
3535

36+
namespace tvm {
37+
namespace arith {
38+
class Analyzer;
39+
}
40+
} // namespace tvm
41+
3642
namespace tvm {
3743
namespace tir {
3844

@@ -78,10 +84,14 @@ class IndexMapNode : public Object {
7884
* \param indices The indices in the input space. Should contain
7985
* one value for each variable in `initial_indices`.
8086
*
87+
* \param analyzer An optional analyzer to be used to simplify the
88+
* resulting expressions. If null, will use a fresh analyzer.
89+
*
8190
* \returns The indices in the output space. Contains one value for
8291
* each expression in `final_indices`.
8392
*/
84-
Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices) const;
93+
Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices,
94+
arith::Analyzer* analyzer = nullptr) const;
8595

8696
/*! \brief Map a memory range to the output space
8797
*
@@ -93,20 +103,26 @@ class IndexMapNode : public Object {
93103
* \param ranges The ranges in the input space. Should contain one
94104
* value for each variable in `initial_indices`.
95105
*
106+
* \param analyzer An optional analyzer to be used to simplify the
107+
* resulting expressions. If null, will use a fresh analyzer.
108+
*
96109
* \returns The ranges in the output space. Contains one value for
97110
* each expression in `final_indices`.
98111
*/
99-
Array<Range> MapRanges(const Array<Range>& ranges) const;
112+
Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer = nullptr) const;
100113

101114
/*! \brief Map a buffer shape to the output space
102115
*
103116
* \param shape The buffer shape in the input space. Should contain
104117
* one value for each variable in `initial_indices`.
105118
*
119+
* \param analyzer An optional analyzer to be used to simplify the
120+
* resulting expressions. If null, will use a fresh analyzer.
121+
*
106122
* \returns The buffer shape in the output space. Contains one
107123
* value for each expression in `final_indices`.
108124
*/
109-
Array<PrimExpr> MapShape(const Array<PrimExpr>& shape) const;
125+
Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* analyzer = nullptr) const;
110126

111127
/*!
112128
* \brief Convert to string representation in Python.

src/arith/ir_mutator_with_analyzer.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
3535
return StmtExprMutator::VisitStmt_(op);
3636
}
3737

38+
Stmt IRMutatorWithAnalyzer::VisitStmt_(const BlockNode* op) {
39+
for (const auto& iter_var : op->iter_vars) {
40+
analyzer_->Bind(iter_var->var, iter_var->dom);
41+
}
42+
return StmtExprMutator::VisitStmt_(op);
43+
}
44+
3845
Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
3946
PrimExpr value = this->VisitExpr(op->value);
4047
if (SideEffect(value) <= CallEffectKind::kPure) {

src/arith/ir_mutator_with_analyzer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
5050

5151
// override functions that need to populate the context information.
5252
tir::Stmt VisitStmt_(const tir::ForNode* op) override;
53+
tir::Stmt VisitStmt_(const tir::BlockNode* op) override;
5354
tir::Stmt VisitStmt_(const tir::LetStmtNode* op) override;
5455
tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override;
5556
tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override;

src/tir/ir/index_map.cc

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -159,24 +159,29 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
159159
return IndexMap(output_vars, inverse_exprs);
160160
}
161161

162-
Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices) const {
162+
Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
163+
arith::Analyzer* analyzer) const {
163164
ICHECK_EQ(indices.size(), initial_indices.size());
164165

165-
arith::Analyzer analyzer;
166+
Map<Var, PrimExpr> vmap;
166167

167168
for (size_t i = 0; i < initial_indices.size(); i++) {
168-
analyzer.Bind(initial_indices[i], indices[i]);
169+
vmap.Set(initial_indices[i], indices[i]);
169170
}
170171

171-
Array<PrimExpr> output;
172-
for (const auto& output_dim : final_indices) {
173-
output.push_back(analyzer.Simplify(output_dim));
172+
arith::Analyzer local_analyzer;
173+
if (!analyzer) {
174+
analyzer = &local_analyzer;
174175
}
175176

177+
Array<PrimExpr> output = final_indices;
178+
output.MutateByApply(
179+
[&](const PrimExpr& index) { return analyzer->Simplify(Substitute(index, vmap)); });
180+
176181
return output;
177182
}
178183

179-
Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges) const {
184+
Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer) const {
180185
ICHECK_EQ(ranges.size(), initial_indices.size());
181186

182187
Map<Var, Range> input_iters;
@@ -189,25 +194,30 @@ Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges) const {
189194
dom_map[initial_indices[i].get()] = arith::IntSet::FromRange(ranges[i]);
190195
}
191196

197+
arith::Analyzer local_analyzer;
198+
if (!analyzer) {
199+
analyzer = &local_analyzer;
200+
}
201+
192202
Array<Range> output;
193-
arith::Analyzer analyzer;
194203
for (const auto& final_index : final_indices) {
195204
auto int_set = arith::EvalSet(final_index, dom_map);
196-
output.push_back(Range::FromMinExtent(analyzer.Simplify(int_set.min()),
197-
analyzer.Simplify(int_set.max() - int_set.min() + 1)));
205+
output.push_back(Range::FromMinExtent(analyzer->Simplify(int_set.min()),
206+
analyzer->Simplify(int_set.max() - int_set.min() + 1)));
198207
}
199208

200209
return output;
201210
}
202211

203-
Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape) const {
212+
Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
213+
arith::Analyzer* analyzer) const {
204214
ICHECK_EQ(shape.size(), initial_indices.size());
205215

206216
Array<Range> ranges;
207217
for (auto& dim : shape) {
208218
ranges.push_back(Range(0, dim));
209219
}
210-
Array<Range> mapped = MapRanges(std::move(ranges));
220+
Array<Range> mapped = MapRanges(std::move(ranges), analyzer);
211221

212222
Array<PrimExpr> output;
213223
for (auto& range : mapped) {
@@ -265,8 +275,12 @@ TVM_REGISTER_GLOBAL("tir.IndexMap")
265275
return IndexMap(initial_indices, final_indices);
266276
});
267277

268-
TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices").set_body_method<IndexMap>(&IndexMapNode::MapIndices);
269-
TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_method<IndexMap>(&IndexMapNode::MapShape);
278+
TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices")
279+
.set_body_typed([](IndexMap map, Array<PrimExpr> indices) { return map->MapIndices(indices); });
280+
281+
TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_typed([](IndexMap map, Array<PrimExpr> shape) {
282+
return map->MapShape(shape);
283+
});
270284
TVM_REGISTER_GLOBAL("tir.IndexMapInverse").set_body_method(&IndexMap::Inverse);
271285

272286
TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse")

src/tir/schedule/primitive/layout_transformation.cc

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19+
#include "../../../arith/ir_mutator_with_analyzer.h"
1920
#include "../utils.h"
2021

2122
namespace tvm {
2223
namespace tir {
2324

24-
class TransformLayoutRewriter : private StmtExprMutator {
25+
class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
2526
public:
2627
/*!
2728
* \brief Rewrite the access to the buffer after the transformation
@@ -36,27 +37,32 @@ class TransformLayoutRewriter : private StmtExprMutator {
3637
const Buffer& old_buffer,
3738
const Buffer& new_buffer,
3839
const IndexMap& index_map) {
39-
TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map);
40+
arith::Analyzer analyzer;
41+
TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, &analyzer);
4042
Stmt result = rewriter(scope_stmt);
4143
return {result, rewriter.block_sref_reuse_};
4244
}
4345

4446
private:
4547
TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer,
46-
const IndexMap& index_map)
47-
: old_buffer_(old_buffer),
48+
const IndexMap& index_map, arith::Analyzer* analyzer)
49+
: IRMutatorWithAnalyzer(analyzer),
50+
old_buffer_(old_buffer),
4851
new_buffer_(new_buffer),
4952
index_map_(index_map),
5053
buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {}
5154

5255
void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
5356
*buffer = new_buffer_;
54-
*indices = index_map_->MapIndices(*indices);
55-
(*indices).MutateByApply([this](const PrimExpr& index) { return analyzer_.Simplify(index); });
57+
*indices = index_map_->MapIndices(*indices, analyzer_);
5658
}
5759

60+
using Parent = arith::IRMutatorWithAnalyzer;
61+
using Parent::VisitExpr_;
62+
using Parent::VisitStmt_;
63+
5864
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
59-
BufferLoad buffer_load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
65+
BufferLoad buffer_load = Downcast<BufferLoad>(Parent::VisitExpr_(op));
6066
if (buffer_load->buffer.same_as(old_buffer_)) {
6167
auto* n = buffer_load.CopyOnWrite();
6268
RewriteBufferAccess(&n->buffer, &n->indices);
@@ -65,7 +71,7 @@ class TransformLayoutRewriter : private StmtExprMutator {
6571
}
6672

6773
Stmt VisitStmt_(const BufferStoreNode* op) final {
68-
BufferStore buffer_store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
74+
BufferStore buffer_store = Downcast<BufferStore>(Parent::VisitStmt_(op));
6975
if (buffer_store->buffer.same_as(old_buffer_)) {
7076
auto* n = buffer_store.CopyOnWrite();
7177
RewriteBufferAccess(&n->buffer, &n->indices);
@@ -86,10 +92,7 @@ class TransformLayoutRewriter : private StmtExprMutator {
8692
}
8793

8894
Stmt VisitStmt_(const BlockNode* op) final {
89-
for (const auto& iter_var : op->iter_vars) {
90-
analyzer_.Bind(iter_var->var, iter_var->dom);
91-
}
92-
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
95+
Block block = Downcast<Block>(Parent::VisitStmt_(op));
9396
auto infered_access_regions = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
9497
auto* n = block.CopyOnWrite();
9598
RewriteAccessRegion(&n->reads, infered_access_regions[0]);
@@ -101,7 +104,6 @@ class TransformLayoutRewriter : private StmtExprMutator {
101104
const Buffer& old_buffer_;
102105
const Buffer& new_buffer_;
103106
const IndexMap& index_map_;
104-
arith::Analyzer analyzer_;
105107
Map<Var, Buffer> buffer_data_to_buffer_;
106108
Map<Block, Block> block_sref_reuse_;
107109
};

0 commit comments

Comments
 (0)