Skip to content

Commit 5b9e9bd

Browse files
authored
[TIR] Keep trivial LetStmt in tir.Simplify when used in buffer decl (#14951)
Prior to this commit, any trivial let binding of `var1 = var2` is inlined. However, buffer definitions are not updated, so this can result in dangling `tir::Var` instances. This commit updates the `tir.Simplify` pass to keep trivial let bindings if they are used as part of a buffer definition. Ideally, the trivial `LetStmt` variable would be inlined into the buffer definition as well as other expressions. However, because a buffer may be implicitly declared, the first usage may be within a constrained context. If that happens, the simplified shape/strides expression cannot be used to update the buffer definition, as that simplification is not valid at all possible usage points of the buffer. ```python for i in range(n): elem_offset = i view = T.Buffer(1, data=buf, elem_offset = elem_offset) if i == 0: # First occurrence in TIR is here, where elem_offset would # simplify to zero. view[0] = 1 else: # But the same buffer is used here, where elem_offset doesn't # simplify to zero. view[0] = 2 ``` This will be resolvable after #14778 lands, requiring all buffers to be declared with `DeclBuffer` prior to usage. ```python for i in range(n): elem_offset = i # All variables used by the DeclBuffer are valid across the entire # body of the DeclBuffer. view = T.decl_buffer(1, data=buf, elem_offset = elem_offset) if i == 0: view[0] = 1 else: view[0] = 2 ```
1 parent 0af9ff9 commit 5b9e9bd

File tree

2 files changed

+123
-9
lines changed

2 files changed

+123
-9
lines changed

src/tir/transforms/simplify.cc

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
#include "../../arith/ir_mutator_with_analyzer.h"
3535
#include "../../tir/analysis/control_flow_graph.h"
36+
#include "../../tir/analysis/var_use_def_analysis.h"
3637

3738
namespace tvm {
3839
namespace arith {
@@ -91,6 +92,46 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
9192
}
9293
};
9394

95+
/* \brief Utility function to collect vars that should be retained */
96+
std::unordered_set<const VarNode*> CollectVarsUsedInBufferDefinition(const Stmt& stmt) {
97+
struct Visitor : StmtExprVisitor {
98+
using StmtExprVisitor::VisitExpr_;
99+
using StmtExprVisitor::VisitStmt_;
100+
101+
void VisitExpr_(const BufferLoadNode* op) override {
102+
VisitBuffer(op->buffer);
103+
StmtExprVisitor::VisitExpr_(op);
104+
}
105+
void VisitStmt_(const BufferStoreNode* op) override {
106+
VisitBuffer(op->buffer);
107+
StmtExprVisitor::VisitStmt_(op);
108+
}
109+
110+
void VisitBuffer(const Buffer& buf) {
111+
// Collect variables that should remain defined
112+
VarUseDefAnalyzer usage(Array<Var>{});
113+
usage(buf->data);
114+
for (const auto& dim : buf->shape) {
115+
usage(dim);
116+
}
117+
for (const auto& dim : buf->strides) {
118+
usage(dim);
119+
}
120+
usage(buf->elem_offset);
121+
122+
// Track for use in LetStmtNode mutator
123+
for (const auto& var : usage.undefined_) {
124+
used_in_buffer_def_.insert(var.get());
125+
}
126+
}
127+
std::unordered_set<const VarNode*> used_in_buffer_def_;
128+
};
129+
130+
Visitor visitor;
131+
visitor(stmt);
132+
return visitor.used_in_buffer_def_;
133+
}
134+
94135
class SimplifyConfig : public Attrs {
95136
public:
96137
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, SimplifyConfigNode);
@@ -110,16 +151,24 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
110151
config->propagate_knowns_to_simplify_expressions) {
111152
touch_pattern = ControlFlowGraph(stmt);
112153
}
113-
StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern));
154+
155+
std::unordered_set<const VarNode*> used_in_buffer_def = CollectVarsUsedInBufferDefinition(stmt);
156+
StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
157+
std::move(used_in_buffer_def));
114158
return simplifier(std::move(stmt));
115159
}
116160

117161
private:
118162
explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config,
119-
std::optional<ControlFlowGraph> touch_pattern)
120-
: IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern) {}
163+
std::optional<ControlFlowGraph> touch_pattern,
164+
std::unordered_set<const VarNode*> used_in_buffer_def)
165+
: IRMutatorWithAnalyzer(analyzer),
166+
config_(config),
167+
touch_pattern_(touch_pattern),
168+
used_in_buffer_def_(used_in_buffer_def) {}
121169

122170
using Parent = IRMutatorWithAnalyzer;
171+
using Parent::VisitExpr_;
123172
using Parent::VisitStmt;
124173
using Parent::VisitStmt_;
125174

@@ -159,18 +208,36 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
159208

160209
Stmt VisitStmt_(const LetStmtNode* op) override {
161210
PrimExpr value = this->VisitExpr(op->value);
162-
if (CanInlineLetStmt(op)) {
163-
// it is fine to discard the let binding
164-
// because the call to simplify will always inline the var.
211+
bool can_inline = CanInlineLetStmt(op);
212+
if (can_inline) {
213+
// It is usually fine to discard the let binding because the
214+
// call to simplify will always inline the var.
215+
//
216+
// The exception is when the variable is used in a Buffer's
217+
// definition, as these are not updated by the simplification.
218+
// After DeclBuffer is required prior to use of a buffer,
219+
// simplifying can update the buffer definition as well. The
220+
// buffer can only be updated at its point of definition,
221+
// because the points of use may occur within contexts that
222+
// allow for additional simplifications (e.g. a buffer of shape
223+
// [i,j] whose first use occurs within "if i==1" should not have
224+
// its shape simplified to [1,j]).
165225
analyzer_->Bind(op->var, value);
166-
return this->VisitStmt(op->body);
167226
} else if (SideEffect(op->value) <= CallEffectKind::kPure) {
168227
// Even if we aren't replacing all occurrences, they may be
169228
// necessary for proving conditional statements.
170229
non_inlined_bindings_.Set(op->var, value);
171230
}
172231
Stmt body = this->VisitStmt(op->body);
173-
if (value.same_as(op->value) && body.same_as(op->body)) {
232+
233+
// TODO(Lunderberg): Update the Buffer object as part of
234+
// DeclBuffer updates, which will first require
235+
// https://github.com/apache/tvm/pull/14778.
236+
bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());
237+
238+
if (can_inline && !used_in_buffer_def) {
239+
return body;
240+
} else if (value.same_as(op->value) && body.same_as(op->body)) {
174241
return GetRef<Stmt>(op);
175242
} else {
176243
auto n = this->CopyOnWrite(op);
@@ -207,8 +274,10 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
207274
return Parent::VisitExpr_(op);
208275
}
209276

277+
PrimExpr VisitExpr_(const BufferLoadNode* op) override { return Parent::VisitExpr_(op); }
278+
210279
// eliminate useless stores
211-
Stmt VisitStmt_(const BufferStoreNode* op) final {
280+
Stmt VisitStmt_(const BufferStoreNode* op) override {
212281
BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
213282
if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
214283
if (load->buffer->data.same_as(store->buffer->data) &&
@@ -260,6 +329,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
260329

261330
Map<Var, PrimExpr> non_inlined_bindings_;
262331
Optional<Stmt> current_stmt_{NullOpt};
332+
std::unordered_set<const VarNode*> used_in_buffer_def_;
263333
};
264334

265335
} // namespace arith

tests/python/unittest/test_tir_transform_simplify.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,5 +1689,49 @@ def expected(A: T.Buffer(1, "int32")):
16891689
A[0] = 12
16901690

16911691

1692+
class TestSimplifyTrivialLetBufferVar(BaseBeforeAfter):
1693+
"""A LetStmt used in a buffer definition should be retained"""
1694+
1695+
def before(A_ptr: T.handle("float32")):
1696+
A_ptr_redef: T.handle("float32") = A_ptr
1697+
A = T.decl_buffer(1, "float32", data=A_ptr_redef)
1698+
A[0] = 42.0
1699+
1700+
expected = before
1701+
1702+
1703+
class TestSimplifyTrivialLetElemOffset(BaseBeforeAfter):
1704+
"""A LetStmt used in a buffer definition should be retained"""
1705+
1706+
def before(A_ptr: T.handle("float32"), A_offset: T.int32):
1707+
A_offset_redef = A_offset
1708+
A = T.decl_buffer(1, "float32", elem_offset=A_offset_redef, data=A_ptr)
1709+
A[0] = 42.0
1710+
1711+
expected = before
1712+
1713+
1714+
class TestSimplifyTrivialLetShape(BaseBeforeAfter):
1715+
"""A LetStmt used in a buffer definition should be retained"""
1716+
1717+
def before(A_ptr: T.handle("float32"), A_size: T.int32):
1718+
A_size_redef = A_size
1719+
A = T.decl_buffer([A_size_redef], "float32", data=A_ptr)
1720+
A[0] = 42.0
1721+
1722+
expected = before
1723+
1724+
1725+
class TestSimplifyTrivialLetStride(BaseBeforeAfter):
1726+
"""A LetStmt used in a buffer definition should be retained"""
1727+
1728+
def before(A_ptr: T.handle("float32"), A_stride: T.int32):
1729+
A_stride_redef = A_stride
1730+
A = T.decl_buffer(1, "float32", strides=[A_stride_redef], data=A_ptr)
1731+
A[0] = 42.0
1732+
1733+
expected = before
1734+
1735+
16921736
if __name__ == "__main__":
16931737
tvm.testing.main()

0 commit comments

Comments
 (0)