Skip to content

Commit 43f06ca

Browse files
authored
[TIR] Avoid re-defining var = arg_var in ArgBinder (#14952)
Prior to this commit, `ArgBinder` would always introduce a new variable to represent the input argument, even if the argument already a primitive type. This introduces trivial let bindings that are expected to be simplified out, but which can produce dangling `tir::Var` usage in some cases (see #14951). This commit updates `ArgBinder` to prefer using the original `tir::Var` when possible. That is, when a function takes `n: T.int32` as input, the packed function should produce a binding `n: T.int32 = T.tvm_struct_get(...)`, rather than producing a binding `arg_n = T.tvm_struct_get(...)` followed by `n = arg_n`.
1 parent 94c1b89 commit 43f06ca

File tree

2 files changed

+16
-33
lines changed

2 files changed

+16
-33
lines changed

src/tir/transforms/make_packed_api.cc

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -262,24 +262,13 @@ PrimFunc MakePackedAPI(PrimFunc func) {
262262
return res;
263263
};
264264

265-
// Need to re-declare vars, in case some arguments also appears in the buffer.
266-
std::vector<std::pair<Var, Var>> var_def;
265+
// Need to delay binding of the buffers, in case some arguments also
266+
// appear in the buffer.
267+
std::vector<std::pair<PrimExpr, Var>> var_def;
267268
std::vector<std::pair<Var, Buffer>> buffer_def;
268269

269270
for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
270271
Var param = func_ptr->params[i];
271-
std::string param_name = [&]() {
272-
std::ostringstream oss;
273-
oss << "arg";
274-
if (param->name_hint.defined() && (!param->name_hint.empty())) {
275-
oss << "." << param->name_hint;
276-
277-
} else {
278-
oss << i;
279-
}
280-
return oss.str();
281-
}();
282-
Var v_arg = Var(param_name, param->dtype);
283272

284273
// Pluck the device API context out based on name
285274
if (param->name_hint == kDeviceContextVar) {
@@ -288,19 +277,16 @@ PrimFunc MakePackedAPI(PrimFunc func) {
288277
continue;
289278
}
290279

280+
var_def.emplace_back(f_arg_value(param.dtype(), i), param);
291281
if (func_ptr->buffer_map.count(param)) {
292-
buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]);
293-
} else {
294-
var_def.emplace_back(v_arg, param);
282+
buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
295283
}
296284

297-
// Value loads
298-
seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop));
299285
// type code checks
300-
Var tcode(v_arg->name_hint + ".code", DataType::Int(32));
286+
Var tcode(param->name_hint + ".code", DataType::Int(32));
301287
seq_init.emplace_back(
302288
LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop));
303-
DataType t = v_arg.dtype();
289+
DataType t = param.dtype();
304290
if (t.is_handle()) {
305291
std::ostringstream msg;
306292
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
@@ -330,8 +316,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
330316
// either 0 or the original stride will be correctly used. Checks here have
331317
// to use the args that may have no let binding yet. Therefore, hoisting let
332318
// binding for args before buffer declaration is needed.
333-
for (const auto& kv : var_def) {
334-
binder.Bind(kv.second, kv.first, name_hint + "." + kv.first->name_hint, true);
319+
for (const auto& [expr, param] : var_def) {
320+
binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
335321
}
336322

337323
for (const auto& kv : buffer_def) {

tests/python/unittest/test_tir_transform_make_packed_api.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,18 +101,15 @@ def test_variable_passed_from_args():
101101
assert func.body.condition.b == 2
102102

103103
# Arguments unpacking
104-
assignment = _find_assignment(func.body, "arg.input_buffer")
104+
assignment = _find_assignment(func.body, "input_buffer")
105105
assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")'
106106

107-
assignment = _find_assignment(func.body, "arg.not_device_context")
108-
assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")'
109-
110-
assignment = _find_assignment(func.body, "input_buffer")
111-
assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, "handle")'
107+
assignment = _find_assignment(assignment.body, "input_buffer")
108+
assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, "handle")'
112109
unpacked_input_buffer = assignment.var
113110

114111
assignment = _find_assignment(func.body, "not_device_context")
115-
assert str(assignment.value) == "arg_not_device_context"
112+
assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")'
116113
unpacked_not_device_context = assignment.var
117114

118115
seq_stmt = _find_next(assignment, tvm.tir.SeqStmt)
@@ -147,11 +144,11 @@ def test_device_api_context_implicit_resource_handle():
147144
assert func.body.condition.b == 1
148145

149146
# Arguments unpacking
150-
assignment = _find_assignment(func.body, "arg.input_buffer")
147+
assignment = _find_assignment(func.body, "input_buffer")
151148
assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")'
152149

153-
assignment = _find_assignment(func.body, "input_buffer")
154-
assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, "handle")'
150+
assignment = _find_assignment(assignment.body, "input_buffer")
151+
assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, "handle")'
155152
unpacked_input_buffer = assignment.var
156153

157154
seq_stmt = _find_next(assignment, tvm.tir.SeqStmt)

0 commit comments

Comments
 (0)