diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index dd9d471c5066..e0284f4a93c5 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -259,24 +259,13 @@ PrimFunc MakePackedAPI(PrimFunc func) { return res; }; - // Need to re-declare vars, in case some arguments also appears in the buffer. - std::vector> var_def; + // Need to delay binding of the buffers, in case some arguments also + // appear in the buffer. + std::vector> var_def; std::vector> buffer_def; for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { Var param = func_ptr->params[i]; - std::string param_name = [&]() { - std::ostringstream oss; - oss << "arg"; - if (param->name_hint.defined() && (!param->name_hint.empty())) { - oss << "." << param->name_hint; - - } else { - oss << i; - } - return oss.str(); - }(); - Var v_arg = Var(param_name, param->dtype); // Pluck the device API context out based on name if (param->name_hint == kDeviceContextVar) { @@ -285,19 +274,16 @@ PrimFunc MakePackedAPI(PrimFunc func) { continue; } + var_def.emplace_back(f_arg_value(param.dtype(), i), param); if (func_ptr->buffer_map.count(param)) { - buffer_def.emplace_back(v_arg, func_ptr->buffer_map[param]); - } else { - var_def.emplace_back(v_arg, param); + buffer_def.emplace_back(param, func_ptr->buffer_map[param]); } - // Value loads - seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop)); // type code checks - Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); + Var tcode(param->name_hint + ".code", DataType::Int(32)); seq_init.emplace_back( LetStmt(tcode, BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}), nop)); - DataType t = v_arg.dtype(); + DataType t = param.dtype(); if (t.is_handle()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be pointer"; @@ -327,8 +313,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { // either 0 or the original stride will be correctly used. Checks here have // to use the args that may have no let binding yet. Therefore, hoisting let // binding for args before buffer declaration is needed. - for (const auto& kv : var_def) { - binder.Bind(kv.second, kv.first, name_hint + "." + kv.first->name_hint, true); + for (const auto& [expr, param] : var_def) { + binder.Bind(param, expr, name_hint + "." + param->name_hint, true); } for (const auto& kv : buffer_def) { diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py index cd27c0305c8b..8af7efb59604 100644 --- a/tests/python/unittest/test_tir_transform_make_packed_api.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -101,18 +101,15 @@ def test_variable_passed_from_args(): assert func.body.condition.b == 2 # Arguments unpacking - assignment = _find_assignment(func.body, "arg.input_buffer") + assignment = _find_assignment(func.body, "input_buffer") assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")' - assignment = _find_assignment(func.body, "arg.not_device_context") - assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")' - - assignment = _find_assignment(func.body, "input_buffer") - assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, "handle")' + assignment = _find_assignment(assignment.body, "input_buffer") + assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, "handle")' unpacked_input_buffer = assignment.var assignment = _find_assignment(func.body, "not_device_context") - assert str(assignment.value) == "arg_not_device_context" + assert str(assignment.value) == 'T.tvm_struct_get(args, 1, 12, "handle")' unpacked_not_device_context = assignment.var seq_stmt = _find_next(assignment, tvm.tir.SeqStmt) @@ -147,11 +144,11 @@ def test_device_api_context_implicit_resource_handle(): assert func.body.condition.b == 1 # Arguments unpacking - assignment = _find_assignment(func.body, "arg.input_buffer") + assignment = _find_assignment(func.body, "input_buffer") assert str(assignment.value) == 'T.tvm_struct_get(args, 0, 12, "handle")' - assignment = _find_assignment(func.body, "input_buffer") - assert str(assignment.value) == 'T.tvm_struct_get(arg_input_buffer, 0, 1, "handle")' + assignment = _find_assignment(assignment.body, "input_buffer") + assert str(assignment.value) == 'T.tvm_struct_get(input_buffer, 0, 1, "handle")' unpacked_input_buffer = assignment.var seq_stmt = _find_next(assignment, tvm.tir.SeqStmt)