diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 837a3e6d3587..df7a88598532 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -145,6 +145,7 @@ class BuiltinLower : public StmtExprMutator { if (scope.max_sizes.shape_stack != -1) { scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)}, DataType::Int(64), "stack_shape"); + stmt = DeclBuffer(scope.stack_shape, stmt); stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack), stmt); } @@ -159,6 +160,7 @@ class BuiltinLower : public StmtExprMutator { stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt); + stmt = DeclBuffer(scope.stack_tcode, stmt); stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack), stmt); } diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 6eac5e90b553..cf2e3f045b63 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -71,7 +71,9 @@ def check_packed_func(target="llvm"): # Recursively visit PrimFunc until we meet the for-loop: while True: - if isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)): + if isinstance( + node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt, tvm.tir.DeclBuffer) + ): node = node.body elif isinstance(node, tvm.tir.SeqStmt): node = node[0] @@ -98,7 +100,7 @@ def check_packed_func(target="llvm"): # # let stack_value = tir.tvm_stack_alloca("arg_value", 4) # - alloca_value = alloca_tcode.body + alloca_value = alloca_tcode.body.body assert isinstance(alloca_value, tvm.tir.LetStmt) expected_value = tvm.tir.call_intrin(