diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 9997a4d95694..949f92051c14 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -575,6 +575,73 @@ class Allocate : public Stmt { TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode); }; +/*! + * \brief Allocate a buffer that can be used in body. + */ +class AllocateConstNode : public StmtNode { + public: + /*! \brief The buffer variable. */ + Var buffer_var; + /*! \brief The data associated to the constant. */ + ::tvm::runtime::NDArray data; + /*! \brief The type of the buffer. */ + DataType dtype; + /*! \brief The extents of the buffer. */ + Array extents; + /*! \brief The body to be executed. */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer_var", &buffer_var); + v->Visit("dtype", &dtype); + v->Visit("extents", &extents); + v->Visit("body", &body); + v->Visit("span", &span); + } + + bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const { + return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && + equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(buffer_var); + hash_reduce(dtype); + hash_reduce(extents); + hash_reduce(body); + hash_reduce(data); + } + + /*! + * \brief If the buffer size is constant, return the size. + * Otherwise return 0. + * \return The result. + */ + int32_t constant_allocation_size() const { return constant_allocation_size(extents); } + /*! + * \brief If the buffer size is constant, return the size. + * Otherwise return 0. + * \param extents The extents of the buffer. + * \return The result. + */ + TVM_DLL static int32_t constant_allocation_size(const Array& extents); + + static constexpr const char* _type_key = "tir.AllocateConst"; + TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode); +}; + +/*! + * \brief Managed reference to AllocateNode. + * \sa AllocateNode + */ +class AllocateConst : public Stmt { + public: + TVM_DLL AllocateConst(Var buffer_var, ::tvm::runtime::NDArray data, DataType dtype, + Array extents, Stmt body, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode); +}; + /*! * \brief The container of seq statement. * Represent a sequence of statements. diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 24773a5a471f..6e11a9131564 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -87,6 +87,7 @@ class StmtFunctor { virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -113,6 +114,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(ForNode); IR_STMT_FUNCTOR_DISPATCH(WhileNode); IR_STMT_FUNCTOR_DISPATCH(AllocateNode); + IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode); IR_STMT_FUNCTOR_DISPATCH(StoreNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode); @@ -155,6 +157,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const AllocateNode* op) override; + void VisitStmt_(const AllocateConstNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BufferRealizeNode* op) override; @@ -255,6 +258,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const ForNode* op) override; Stmt VisitStmt_(const WhileNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; + Stmt VisitStmt_(const AllocateConstNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const BufferRealizeNode* op) override; diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index d07209485bd4..1b7c522995b4 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -24,6 +24,7 @@ from tvm.runtime import Object from tvm.ir import Span, Range from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind +import numpy as np from .context_maintainer import ContextMaintainer from .utils import ( @@ -147,6 +148,52 @@ def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None): context.update_symbol(name, self.buffer_var, node) +@register +class AllocateConst(WithScopeHandler): + """With scope handler tir.allocate(data, extents, dtype, condition)""" + + def __init__(self): + def allocate_const(raw_data, dtype, shape, span=None): + list_data = [] + for i in raw_data: + list_data.append(i.value) + nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype)) + + n = tvm.tir.AllocateConst(self.buffer_var, nd_data, dtype, shape, self.body, span=span) + return n + + super().__init__(allocate_const, concise_scope=True, def_symbol=True) + self.buffer_var = None + + def enter_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): + # define buffer vars in symbol table + if isinstance(node, ast.With): + vars = WithScopeHandler.get_optional_vars(node, context) + if len(vars) != 1: + context.report_error("Unexpected number of vars", node.span) + name = vars[0].id.name + var_span = vars[0].id.span + elif isinstance(node, ast.Assign): + name = node.lhs.id.name + var_span = node.lhs.id.span + else: + raise Exception("Internal Bug") + + def setup_buffer_var(data, dtype, shape, span: Span = None): + """Setup buffer var for a given type.""" + buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) + self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) + + setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) + context.update_symbol(name, self.buffer_var, node) + + @register class LaunchThread(WithScopeHandler): """With scope handler tir.launch_thread(env_var, extent)""" diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index eb200df0c599..bf3e4bf6fed1 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -28,7 +28,15 @@ from .expr import Call, CallEffectKind, Let, IterVar, Any from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For -from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt +from .stmt import ( + BufferStore, + BufferRealize, + Store, + ProducerStore, + Allocate, + AllocateConst, + AttrStmt, +) from .stmt import ProducerRealize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index d57077f08b52..f9f645842910 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -328,6 +328,40 @@ def __init__(self, buffer_var, dtype, extents, condition, body, span=None): ) +@tvm._ffi.register_object("tir.AllocateConst") +class AllocateConst(Stmt): + """Allocate constant node. + + Parameters + ---------- + buffer_var : Var + The buffer variable. + + data : NDarray + The data associated with the constant + + dtype : str + The data type of the buffer. + + extents : list of Expr + The extents of the allocate + + condition : PrimExpr + The condition. + + body : Stmt + The body statement. + + span : Optional[Span] + The location of this itervar in the source code. + """ + + def __init__(self, buffer_var, dtype, extents, condition, body, span=None): + self.__init_handle_by_constructor__( + _ffi_api.AllocateConst, buffer_var, dtype, extents, condition, body, span + ) + + @tvm._ffi.register_object("tir.AttrStmt") class AttrStmt(Stmt): """AttrStmt node. diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 0332a2d539d2..d6c342e943ae 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -322,6 +322,7 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitStmt_(const BufferStoreNode* op) override; Doc VisitStmt_(const BufferRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; + Doc VisitStmt_(const AllocateConstNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const EvaluateNode* op) override; @@ -359,6 +360,7 @@ class TIRTextPrinter : public StmtFunctor, static Doc PrintConstScalar(DataType dtype, const T& data); Doc GetUniqueName(std::string prefix); Doc AllocVar(const Var& var); + Doc AllocConst(const AllocateConst& var); Doc AllocBuf(const Buffer& buffer); /*! * \brief special method to render vectors of docs with a separator diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 0fefb0515e49..1210d05c2ac9 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -460,6 +460,19 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) { + Doc doc; + doc << "constant(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " + << Print(op->extents) << ")"; + + if (op->body->IsInstance()) { + doc << PrintBody(op->body); + } else { + doc << ";" << Doc::NewLine() << Print(op->body); + } + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) { Doc doc; doc << "if " << Print(op->condition) << PrintBody(op->then_case); diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index e855712617ca..47db47ac145b 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -134,6 +134,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitStmt_(const BufferStoreNode* op) override; Doc VisitStmt_(const BufferRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; + Doc VisitStmt_(const AllocateConstNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; Doc VisitStmt_(const ForNode* op) override; @@ -247,6 +248,26 @@ class TVMScriptPrinter : public StmtFunctor, } return doc; } + + /*! + * \brief special method to print NDArray in TIR + * \param arr the NDArray to be printed + * \param os the output stream where the NDArray will be printed to + */ + template + void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) { + int ndim = arr->ndim; + int tot_dim = 1; + for (int i = 0; i < ndim; i++) { + tot_dim *= arr->shape[i]; + } + T* data_ptr = reinterpret_cast(arr->data); + os << "["; + for (int i = 0; i < tot_dim; i++) { + os << data_ptr[i] << ", "; + } + os << "]"; + } }; Doc TVMScriptPrinter::GetUniqueName(std::string prefix) { @@ -685,6 +706,48 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { return Doc(); } +Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { + std::stringstream ss; + + if (alloc->dtype.is_int()) { + if (alloc->dtype.bits() == 8) { + NDArrayToTIR(alloc->data, ss); + } else if (alloc->dtype.bits() == 16) { + NDArrayToTIR(alloc->data, ss); + } else if (alloc->dtype.bits() == 32) { + NDArrayToTIR(alloc->data, ss); + } else { + LOG(FATAL) << "DataType not supported"; + } + } else if (alloc->dtype.is_float()) { + if (alloc->dtype.bits() == 16) { + NDArrayToTIR(alloc->data, ss); + } else if (alloc->dtype.bits() == 32) { + NDArrayToTIR(alloc->data, ss); + } else if (alloc->dtype.bits() == 64) { + NDArrayToTIR(alloc->data, ss); + } else { + LOG(FATAL) << "DataType not supported"; + } + } else { + LOG(FATAL) << "DataType not supported"; + } + auto ndarray_str = ss.str(); + + Doc doc; + var_not_in_headers.insert(alloc->buffer_var.get()); + if (current_num_ != num_child_ - 1) { + doc << "with tir.allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) << ", " + << Print(alloc->extents) << ")"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body)); + } else { + doc << Print(alloc->buffer_var) << " = tir.allocate_const(" << ndarray_str << ", " + << PrintDType(alloc->dtype) << ", " << Print(alloc->extents); + doc << ")" << Doc::NewLine() << PrintBody(alloc->body); + } + return doc; +} + Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { Doc doc; doc << "if " << Print(op->condition) << ":"; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 42ef60bb86d7..e67a0a099048 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -385,6 +385,68 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->body); }); +// Const +AllocateConst::AllocateConst(Var buffer_var, ::tvm::runtime::NDArray data, DataType dtype, + Array extents, Stmt body, Span span) { + CHECK(IsPointerType(buffer_var->type_annotation, dtype)) + << "The allocated data type (" << dtype + << ") does not match the type annotation of the buffer " << buffer_var << " (" + << buffer_var->type_annotation + << "). The data type should be an element of the pointer type."; + + for (size_t i = 0; i < extents.size(); ++i) { + ICHECK(extents[i].defined()); + ICHECK(extents[i].dtype().is_scalar()); + } + ICHECK(body.defined()); + + ObjectPtr node = make_object(); + node->buffer_var = std::move(buffer_var); + node->dtype = dtype; + node->extents = std::move(extents); + node->body = std::move(body); + node->span = std::move(span); + node->data = data; + data_ = std::move(node); +} + +int32_t AllocateConstNode::constant_allocation_size(const Array& extents) { + int64_t result = 1; + for (size_t i = 0; i < extents.size(); ++i) { + if (const IntImmNode* int_size = extents[i].as()) { + result *= int_size->value; + if (result > std::numeric_limits::max()) { + return 0; + } + } else { + return 0; + } + } + return static_cast(result); +} + +TVM_REGISTER_GLOBAL("tir.AllocateConst") + .set_body_typed([](Var buffer_var, ::tvm::runtime::NDArray data, DataType type, + Array extents, Stmt body, Span span) { + return AllocateConst(buffer_var, data, type, extents, body, span); + }); + +TVM_REGISTER_NODE_TYPE(AllocateConstNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "constant " << op->buffer_var << "[" << op->dtype; + for (size_t i = 0; i < op->extents.size(); ++i) { + p->stream << " * "; + p->Print(op->extents[i]); + } + p->stream << "]"; + p->stream << "\n"; + p->Print(op->body); + }); + // ProducerRealize ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, String storage_scope, Span span) { diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index d60ec72a7589..949e8a1312aa 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -58,6 +58,11 @@ void StmtVisitor::VisitStmt_(const AllocateNode* op) { this->VisitExpr(op->condition); } +void StmtVisitor::VisitStmt_(const AllocateConstNode* op) { + VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); + this->VisitStmt(op->body); +} + void StmtVisitor::VisitStmt_(const StoreNode* op) { this->VisitExpr(op->value); this->VisitExpr(op->index); @@ -319,6 +324,20 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { } } +Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) { + Array extents = Internal::Mutate(this, op->extents); + Stmt body = this->VisitStmt(op->body); + + if (extents.same_as(op->extents) && body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->extents = std::move(extents); + n->body = std::move(body); + return Stmt(n); + } +} + Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case = this->VisitStmt(op->then_case); diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 0a9618cb8ccf..6f51cb6b8cf2 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -17,14 +17,18 @@ import tvm from tvm import tir +from tvm import ir from tvm.script import ty +import numpy as np + @tvm.script.tir class Module1: def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: # function attr dict tir.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + # buffer definition C_global = tir.buffer_decl([1024, 1024], elem_offset=0, align=128, offset_factor=1) packedB = tir.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) @@ -2888,6 +2892,25 @@ def test_opaque_block(): assert len(root_block.body.body[1].block.iter_vars) == 0 +@tvm.script.tir +def constant(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (10), "int32") + C = tir.match_buffer(c, (10), "int32") + B = tir.alloc_buffer((10), "int32") + K = tir.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) + for x in tir.serial(0, 10): + B[x] = A[x] + tir.load("int32", K, x) + + for x in tir.serial(0, 10): + C[x] = B[x] + + +def test_const(): + func = constant + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + @tvm.script.tir def rank0(a: ty.handle) -> None: A = tir.match_buffer(a, (), "float32") @@ -2977,3 +3000,4 @@ def test_abs(): test_block_elements() test_opaque_block() test_abs() + test_const()