@@ -133,6 +133,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
133133 Doc VisitStmt_ (const BufferStoreNode* op) override ;
134134 Doc VisitStmt_ (const BufferRealizeNode* op) override ;
135135 Doc VisitStmt_ (const AllocateNode* op) override ;
136+ Doc VisitStmt_ (const AllocateConstNode* op) override ;
136137 Doc VisitStmt_ (const IfThenElseNode* op) override ;
137138 Doc VisitStmt_ (const SeqStmtNode* op) override ;
138139 Doc VisitStmt_ (const ForNode* op) override ;
@@ -246,6 +247,26 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
246247 }
247248 return doc;
248249 }
250+
251+ /* !
252+ * \brief special method to print NDArray in TIR
253+ * \param arr the NDArray to be printed
254+ * \param os the output stream where the NDArray will be printed to
255+ */
256+ template <typename T>
257+ void NDArrayToTIR (::tvm::runtime::NDArray arr, std::ostream& os) {
258+ int ndim = arr->ndim ;
259+ int tot_dim = 1 ;
260+ for (int i = 0 ; i < ndim; i++) {
261+ tot_dim *= arr->shape [i];
262+ }
263+ T* data_ptr = reinterpret_cast <T*>(arr->data );
264+ os << " [" ;
265+ for (int i = 0 ; i < tot_dim; i++) {
266+ os << data_ptr[i] << " , " ;
267+ }
268+ os << " ]" ;
269+ }
249270};
250271
251272Doc TVMScriptPrinter::GetUniqueName (std::string prefix) {
@@ -684,6 +705,48 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
684705 return Doc ();
685706}
686707
708+ Doc TVMScriptPrinter::VisitStmt_ (const AllocateConstNode* alloc) {
709+ std::stringstream ss;
710+
711+ if (alloc->dtype .is_int ()) {
712+ if (alloc->dtype .bits () == 8 ) {
713+ NDArrayToTIR<int8_t >(alloc->data , ss);
714+ } else if (alloc->dtype .bits () == 16 ) {
715+ NDArrayToTIR<int16_t >(alloc->data , ss);
716+ } else if (alloc->dtype .bits () == 32 ) {
717+ NDArrayToTIR<int32_t >(alloc->data , ss);
718+ } else {
719+ LOG (FATAL) << " DataType not supported" ;
720+ }
721+ } else if (alloc->dtype .is_float ()) {
722+ if (alloc->dtype .bits () == 16 ) {
723+ NDArrayToTIR<int16_t >(alloc->data , ss);
724+ } else if (alloc->dtype .bits () == 32 ) {
725+ NDArrayToTIR<float >(alloc->data , ss);
726+ } else if (alloc->dtype .bits () == 64 ) {
727+ NDArrayToTIR<double >(alloc->data , ss);
728+ } else {
729+ LOG (FATAL) << " DataType not supported" ;
730+ }
731+ } else {
732+ LOG (FATAL) << " DataType not supported" ;
733+ }
734+ auto ndarray_str = ss.str ();
735+
736+ Doc doc;
737+ var_not_in_headers.insert (alloc->buffer_var .get ());
738+ if (current_num_ != num_child_ - 1 ) {
739+ doc << " with tir.allocate_const(" << ndarray_str << " , " << PrintDType (alloc->dtype ) << " , "
740+ << Print (alloc->extents ) << " )" ;
741+ doc << Doc::Indent (4 , Doc::NewLine () << PrintBody (alloc->body ));
742+ } else {
743+ doc << Print (alloc->buffer_var ) << " = tir.allocate_const(" << ndarray_str << " , "
744+ << PrintDType (alloc->dtype ) << " , " << Print (alloc->extents );
745+ doc << " )" << Doc::NewLine () << PrintBody (alloc->body );
746+ }
747+ return doc;
748+ }
749+
687750Doc TVMScriptPrinter::VisitStmt_ (const IfThenElseNode* op) {
688751 Doc doc;
689752 doc << " if " << Print (op->condition ) << " :" ;
0 commit comments