diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 15e483562a4b..db6e32d65f04 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -183,6 +183,8 @@ void CodeGenStackVM::VisitStmt_(const AllocateNode* op) { LOG(FATAL) << "Dynamic allocation not supported"; } +void CodeGenStackVM::VisitStmt_(const DeclBufferNode* op) { VisitStmt(op->body); } + void CodeGenStackVM::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::address_of())) { const BufferLoadNode* load = op->args[0].as(); diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index 87a61f76cc61..0bac55e3b2af 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -139,6 +139,7 @@ class CodeGenStackVM : public ExprFunctor, void VisitStmt_(const ForNode* op) final; void VisitStmt_(const IfThenElseNode* op) final; void VisitStmt_(const AllocateNode* op) final; + void VisitStmt_(const DeclBufferNode* op) final; void VisitStmt_(const AttrStmtNode* op) final; void VisitStmt_(const AssertStmtNode* op) final; void VisitStmt_(const EvaluateNode* op) final; diff --git a/tests/python/unittest/test_target_codegen_vm_basic.py b/tests/python/unittest/test_target_codegen_vm_basic.py index 5667521dc659..d1a3c7217aa9 100644 --- a/tests/python/unittest/test_target_codegen_vm_basic.py +++ b/tests/python/unittest/test_target_codegen_vm_basic.py @@ -17,6 +17,8 @@ import tvm import tvm.testing from tvm import te +from tvm.script import tir as T, ir as I + import numpy as np @@ -122,8 +124,20 @@ def check(f): run_jit(mod, check) +def test_codegen_decl_buffer(): + """The codegen should accept DeclBuffer nodes in its input""" + + @I.ir_module + class mod: + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr({"global_symbol": "kernel"}) + A_buf = T.decl_buffer([256], dtype="float32", scope="global", data=A_data) + + target = tvm.target.Target("stackvm") + stackvm_codegen = tvm.get_global_func("target.build.stackvm") + stackvm_codegen(mod, target) + + if __name__ == "__main__": - test_vm_parallel() - test_stack_vm_loop() - test_stack_vm_basic() - test_stack_vm_cond() + tvm.testing.main()