Skip to content

Commit 695f958

Browse files
authored
[TIR] Improve well-formed check's handling of match buffer (#16655)
* [TIR] Improve well-formed check's handling of match buffer - The `T.match_buffer` at the start of a function may contain repeated use of the same data var. For example, a function that must accept two `DLTensor` objects with the same backing allocation. - The `"buffer_bind_scope"` is an older style of match buffer, and may be the point of definition for variables. * Improved comment, added context.pop_back()
1 parent c00cc03 commit 695f958

File tree

4 files changed

+228
-43
lines changed

4 files changed

+228
-43
lines changed

src/tir/analysis/verify_well_formed.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ class UndefinedVarVerifier : public Verifier<UndefinedVarVerifier> {
228228
using Verifier::Verifier;
229229

230230
private:
231+
using Verifier::Visit;
231232
void Visit(const PrimFunc& prim_func, ObjectPath path) override {
232233
Verifier::Visit(prim_func, path);
233234
redefine_allowed_within_function_.clear();

src/tir/ir/tir_visitor_with_path.cc

Lines changed: 35 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -78,47 +78,22 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) {
7878
// variable has occurred. Therefore, to ensure that we only avoid
7979
// duplicate calls to VisitVarDef, these semantics need to be
8080
// checked.
81-
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_params;
8281
std::vector<std::variant<DefContext<Var>, DefContext<Buffer>>> context;
8382

8483
auto ppath = path->Attr("params");
8584
for (size_t i = 0; i < func->params.size(); i++) {
8685
context.push_back(WithDef(func->params[i], ppath->ArrayIndex(i)));
87-
defined_params.insert(func->params[i]);
8886
}
8987

90-
auto try_visit_implicit_var_def = [this, &defined_params, &context](const PrimExpr& expr,
91-
ObjectPath path) {
92-
if (auto opt = expr.as<Var>()) {
93-
auto var = opt.value();
94-
if (!defined_params.count(var)) {
95-
context.push_back(WithDef(var, path));
96-
defined_params.insert(var);
97-
}
98-
}
99-
};
100-
auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](const Array<PrimExpr>& arr,
101-
ObjectPath path) {
102-
for (size_t i = 0; i < arr.size(); i++) {
103-
try_visit_implicit_var_def(arr[i], path->ArrayIndex(i));
104-
}
105-
};
106-
10788
auto buffer_map_path = path->Attr("buffer_map");
10889
for (size_t i = 0; i < func->params.size(); i++) {
10990
if (auto opt = func->buffer_map.Get(func->params[i])) {
11091
auto buf = opt.value();
11192
auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i));
11293

113-
// A buffer in the buffer_map always defines its data pointer
114-
context.push_back(WithDef(buf->data, buf_path->Attr("data")));
115-
116-
// But other implicit definitions only apply if they weren't
117-
// provided as explicit parameters, and they weren't defined
118-
// implicitly by any previous buffer.
119-
try_visit_implicit_var_def_array(buf->shape, buf_path->Attr("shape"));
120-
try_visit_implicit_var_def_array(buf->strides, buf_path->Attr("strides"));
121-
try_visit_implicit_var_def(buf->elem_offset, buf_path->Attr("elem_offset"));
94+
for (auto& def : WithMatchBufferDefs(buf, buf_path)) {
95+
context.push_back(std::move(def));
96+
}
12297
}
12398
}
12499

@@ -127,7 +102,7 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) {
127102
for (size_t i = 0; i < func->params.size(); i++) {
128103
if (auto opt = func->buffer_map.Get(func->params[i])) {
129104
auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i));
130-
EnterDef(opt.value(), buf_path);
105+
context.push_back(WithDef(opt.value(), buf_path));
131106
}
132107
}
133108

@@ -199,16 +174,40 @@ void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, ObjectPath path) {
199174
void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) {
200175
Visit(op->value, path->Attr("value"));
201176

202-
std::optional<DefContext<IterVar>> context = std::nullopt;
177+
std::vector<std::variant<DefContext<IterVar>, DefContext<Var>>> context;
203178
if (auto iter_var = op->node.as<IterVar>();
204179
iter_var && (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread)) {
205180
// Some attributes serve as a source of definition for the
206181
// tir::Var they annotate.
207-
context = WithDef(iter_var.value(), path->Attr("node"));
182+
context.push_back(WithDef(iter_var.value(), path->Attr("node")));
183+
184+
} else if (op->attr_key == attr::buffer_bind_scope) {
185+
// The `attr::buffer_bind_scope` attribute defines a view into an
186+
// existing buffer, similar to the newer
187+
// `BlockNode::match_buffers` field. It requires the buffer being
188+
// viewed to be defined prior to the attribute. The
189+
// `attr::buffer_bind_scope` is the point of definition for the
190+
// `tir::Buffer buffer_view`, its `tir::Var` data pointer, and any
191+
// symbolic shapes used within `buffer_view that are not already
192+
// defined.
193+
Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
194+
ICHECK_EQ(arr.size(), 2U);
195+
Buffer buffer_view = Downcast<Buffer>(arr[0]);
196+
Buffer orig_buffer = Downcast<Buffer>(arr[1]);
197+
Visit(orig_buffer, path->Attr("node")->ArrayIndex(1));
198+
199+
for (auto& var : WithMatchBufferDefs(buffer_view, path->Attr("node")->ArrayIndex(0))) {
200+
context.push_back(std::move(var));
201+
}
202+
208203
} else if (auto expr = op->node.as<PrimExpr>()) {
209204
Visit(expr.value(), path->Attr("node"));
210205
}
211206
Visit(op->body, path->Attr("body"));
207+
208+
while (context.size()) {
209+
context.pop_back();
210+
}
212211
}
213212

214213
void TIRVisitorWithPath::VisitStmt_(const ForNode* op, ObjectPath path) {
@@ -250,7 +249,8 @@ void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, ObjectPath path)
250249
void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, ObjectPath path) {
251250
Visit(op->condition, path->Attr("condition"));
252251
Visit(op->bounds, path->Attr("bounds"));
253-
auto context = WithDef(op->buffer, path->Attr("buffer"));
252+
auto context = WithDefIfUndefined(op->buffer->data, path->Attr("buffer")->Attr("data"));
253+
Visit(op->buffer, path->Attr("buffer"));
254254
Visit(op->body, path->Attr("body"));
255255
}
256256

@@ -318,18 +318,10 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) {
318318
for (size_t i = 0; i < op->match_buffers.size(); i++) {
319319
auto buf = op->match_buffers[i]->buffer;
320320
auto buffer_path = match_path->ArrayIndex(i)->Attr("buffer");
321-
auto buffer_strides_path = buffer_path->Attr("strides");
322-
context.push_back(WithDef(buf->data, buffer_path->Attr("data")));
323-
// Define buffer strides and elem_offset if they are vars
324-
if (const auto* v = buf->elem_offset.as<VarNode>()) {
325-
context.push_back(WithDef(GetRef<Var>(v), buffer_path->Attr("elem_offset")));
326-
}
327-
for (size_t i = 0; i < buf->strides.size(); ++i) {
328-
if (const auto* v = buf->strides[i].as<VarNode>()) {
329-
context.push_back(WithDef(GetRef<Var>(v), buffer_strides_path->ArrayIndex(i)));
330-
}
321+
322+
for (auto& def : WithMatchBufferDefs(buf, buffer_path)) {
323+
context.push_back(std::move(def));
331324
}
332-
context.push_back(WithDef(buf, buffer_path));
333325
}
334326
}
335327

src/tir/ir/tir_visitor_with_path.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
#include <tvm/tir/stmt_functor.h>
3030

3131
#include <exception>
32+
#include <optional>
33+
#include <unordered_set>
3234
#include <utility>
35+
#include <vector>
3336

3437
namespace tvm {
3538
namespace tir {
@@ -173,6 +176,7 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
173176
// construction of the DefContext and the destruction, we avoid
174177
// this case and allow the first error to propagate upward.
175178
if (self_ && std::uncaught_exceptions() == uncaught_exceptions_) {
179+
self_->in_scope_definitions_.erase(obj_);
176180
self_->ExitDef(obj_, path_);
177181
}
178182
}
@@ -182,6 +186,7 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
182186

183187
DefContext(TIRVisitorWithPath* self, T obj, ObjectPath path)
184188
: self_(self), obj_(obj), path_(path), uncaught_exceptions_(std::uncaught_exceptions()) {
189+
self_->in_scope_definitions_.insert(obj_);
185190
self_->EnterDef(obj_, path_);
186191
}
187192

@@ -203,6 +208,44 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
203208
DefContext<T> WithDef(T obj, ObjectPath path) {
204209
return DefContext(this, obj, path);
205210
}
211+
212+
/* \brief Utility to track the scope of a node's definition. */
213+
template <typename T>
214+
std::optional<DefContext<T>> WithDefIfUndefined(T obj, ObjectPath path) {
215+
if (in_scope_definitions_.count(obj)) {
216+
return std::nullopt;
217+
} else {
218+
return WithDef(obj, path);
219+
}
220+
}
221+
222+
std::vector<DefContext<Var>> WithMatchBufferDefs(Buffer buf, ObjectPath path) {
223+
std::vector<DefContext<Var>> context;
224+
225+
auto try_visit_implicit_var_def = [this, &context](const PrimExpr& expr, ObjectPath path) {
226+
if (auto opt = expr.as<Var>()) {
227+
auto var = opt.value();
228+
if (auto var_def = WithDefIfUndefined(var, path)) {
229+
context.push_back(std::move(var_def).value());
230+
}
231+
}
232+
};
233+
auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](
234+
const Array<PrimExpr>& arr, ObjectPath path) {
235+
for (size_t i = 0; i < arr.size(); i++) {
236+
try_visit_implicit_var_def(arr[i], path->ArrayIndex(i));
237+
}
238+
};
239+
240+
try_visit_implicit_var_def(buf->data, path->Attr("data"));
241+
try_visit_implicit_var_def_array(buf->shape, path->Attr("shape"));
242+
try_visit_implicit_var_def_array(buf->strides, path->Attr("strides"));
243+
try_visit_implicit_var_def(buf->elem_offset, path->Attr("elem_offset"));
244+
245+
return context;
246+
}
247+
248+
std::unordered_set<ObjectRef, ObjectPtrHash, ObjectPtrEqual> in_scope_definitions_;
206249
};
207250

208251
} // namespace tir

tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,5 +199,154 @@ def kernel_2(A: T.Buffer([256], "float32")):
199199
tvm.tir.analysis.verify_well_formed(mod)
200200

201201

202+
def test_multiple_buffer_arguments_may_share_allocation():
203+
"""T.match_buffer may re-use a data argument
204+
205+
Like the shape/strides/elem_offset fields in a buffer, the first
206+
occurrence of a `buffer->data` field defines it, and the
207+
occurrences are usages of that definition.
208+
"""
209+
210+
@I.ir_module
211+
class mod:
212+
@T.prim_func
213+
def func(A_handle: T.handle, B_handle: T.handle):
214+
A = T.match_buffer(A_handle, [256], "float32")
215+
B = T.match_buffer(B_handle, [256], "float32", data=A.data)
216+
217+
pass
218+
219+
tvm.tir.analysis.verify_well_formed(mod)
220+
221+
222+
def test_buffer_bind_scope_defines_buffer_obj():
223+
"""The "buffer_bind_scope" attribute defines a buffer view"""
224+
225+
@I.ir_module
226+
class mod:
227+
@T.prim_func
228+
def func(A: T.Buffer([256, 256], "float32")):
229+
230+
for tile_i, tile_j in T.grid(16, 16):
231+
B = T.Buffer([16, 16], "float32")
232+
T.attr(
233+
[B, A],
234+
"buffer_bind_scope",
235+
T.tvm_tuple(
236+
tile_i * 16,
237+
16,
238+
tile_j * 16,
239+
16,
240+
dtype="handle",
241+
),
242+
)
243+
for i, j in T.grid(16, 16):
244+
B[i, j] = 0.0
245+
246+
tvm.tir.analysis.verify_well_formed(mod)
247+
248+
249+
def test_buffer_bind_scope_defines_symbolic_variables():
250+
"""The "buffer_bind_scope" attribute may define symbolic variables"""
251+
252+
@I.ir_module
253+
class mod:
254+
@T.prim_func
255+
def func(A: T.Buffer([256, 256], "int32")):
256+
257+
for tile_i, tile_j in T.grid(16, 16):
258+
elem_offset = T.int32()
259+
B = T.Buffer([16, 16], "int32", elem_offset=elem_offset)
260+
T.attr(
261+
[B, A],
262+
"buffer_bind_scope",
263+
T.tvm_tuple(
264+
tile_i * 16,
265+
16,
266+
tile_j * 16,
267+
16,
268+
dtype="handle",
269+
),
270+
)
271+
for i, j in T.grid(16, 16):
272+
B[i, j] = elem_offset
273+
274+
tvm.tir.analysis.verify_well_formed(mod)
275+
276+
277+
def test_block_match_buffer_defines_buffer_obj():
278+
"""In a block, T.match_buffer defines a buffer view"""
279+
280+
@I.ir_module
281+
class mod:
282+
@T.prim_func
283+
def func(A: T.Buffer([256, 256], "float32")):
284+
for iters in T.grid(16, 16, 16, 16):
285+
with T.block("compute"):
286+
tile_i, tile_j, i, j = T.axis.remap("SSSS", iters)
287+
B = T.match_buffer(
288+
A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16],
289+
dtype="float32",
290+
)
291+
B[i, j] = 0.0
292+
293+
tvm.tir.analysis.verify_well_formed(mod)
294+
295+
296+
def test_block_match_buffer_defines_symbolic_variables():
297+
"""In a block, T.match_buffer may define symbolic variables"""
298+
299+
@I.ir_module
300+
class mod:
301+
@T.prim_func
302+
def func(A: T.Buffer([256, 256], "int32")):
303+
304+
for iters in T.grid(16, 16, 16, 16):
305+
with T.block("compute"):
306+
tile_i, tile_j, i, j = T.axis.remap("SSSS", iters)
307+
308+
elem_offset = T.int32()
309+
B = T.match_buffer(
310+
A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16],
311+
dtype="float32",
312+
elem_offset=elem_offset,
313+
)
314+
315+
B[i, j] = elem_offset
316+
317+
tvm.tir.analysis.verify_well_formed(mod)
318+
319+
320+
def test_buffer_realize_on_external_buffer_is_annotation():
321+
"""A T.realize statement on an existing buffer annotates the region used"""
322+
323+
@I.ir_module
324+
class mod:
325+
@T.prim_func
326+
def func(A: T.Buffer(256, "int32")):
327+
T.realize(A[0:16], "global")
328+
329+
for i in range(16):
330+
A[i] = 1
331+
332+
tvm.tir.analysis.verify_well_formed(mod)
333+
334+
335+
def test_buffer_realize_is_allocation():
336+
"""A T.realize statement on an fresh buffer allocates the buffer"""
337+
338+
@I.ir_module
339+
class mod:
340+
@T.prim_func
341+
def func():
342+
A = T.Buffer(256, "int32")
343+
T.realize(A[0:16], "global")
344+
345+
for i in range(16):
346+
A[i] = 1
347+
348+
tvm.tir.analysis.verify_well_formed(mod)
349+
350+
202351
if __name__ == "__main__":
203352
tvm.testing.main()

0 commit comments

Comments
 (0)