@@ -134,6 +134,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
134134 Doc VisitStmt_ (const BufferStoreNode* op) override ;
135135 Doc VisitStmt_ (const BufferRealizeNode* op) override ;
136136 Doc VisitStmt_ (const AllocateNode* op) override ;
137+ Doc VisitStmt_ (const AllocateConstNode* op) override ;
137138 Doc VisitStmt_ (const IfThenElseNode* op) override ;
138139 Doc VisitStmt_ (const SeqStmtNode* op) override ;
139140 Doc VisitStmt_ (const ForNode* op) override ;
@@ -247,6 +248,26 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
247248 }
248249 return doc;
249250 }
251+
252+ /* !
253+ * \brief special method to print NDArray in TIR
254+ * \param arr the NDArray to be printed
255+ * \param os the output stream where the NDArray will be printed to
256+ */
257+ template <typename T>
258+ void NDArrayToTIR (::tvm::runtime::NDArray arr, std::ostream& os) {
259+ int ndim = arr->ndim ;
260+ int tot_dim = 1 ;
261+ for (int i = 0 ; i < ndim; i++) {
262+ tot_dim *= arr->shape [i];
263+ }
264+ T* data_ptr = reinterpret_cast <T*>(arr->data );
265+ os << " [" ;
266+ for (int i = 0 ; i < tot_dim; i++) {
267+ os << data_ptr[i] << " , " ;
268+ }
269+ os << " ]" ;
270+ }
250271};
251272
252273Doc TVMScriptPrinter::GetUniqueName (std::string prefix) {
@@ -685,6 +706,48 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
685706 return Doc ();
686707}
687708
709+ Doc TVMScriptPrinter::VisitStmt_ (const AllocateConstNode* alloc) {
710+ std::stringstream ss;
711+
712+ if (alloc->dtype .is_int ()) {
713+ if (alloc->dtype .bits () == 8 ) {
714+ NDArrayToTIR<int8_t >(alloc->data , ss);
715+ } else if (alloc->dtype .bits () == 16 ) {
716+ NDArrayToTIR<int16_t >(alloc->data , ss);
717+ } else if (alloc->dtype .bits () == 32 ) {
718+ NDArrayToTIR<int32_t >(alloc->data , ss);
719+ } else {
720+ LOG (FATAL) << " DataType not supported" ;
721+ }
722+ } else if (alloc->dtype .is_float ()) {
723+ if (alloc->dtype .bits () == 16 ) {
724+ NDArrayToTIR<int16_t >(alloc->data , ss);
725+ } else if (alloc->dtype .bits () == 32 ) {
726+ NDArrayToTIR<float >(alloc->data , ss);
727+ } else if (alloc->dtype .bits () == 64 ) {
728+ NDArrayToTIR<double >(alloc->data , ss);
729+ } else {
730+ LOG (FATAL) << " DataType not supported" ;
731+ }
732+ } else {
733+ LOG (FATAL) << " DataType not supported" ;
734+ }
735+ auto ndarray_str = ss.str ();
736+
737+ Doc doc;
738+ var_not_in_headers.insert (alloc->buffer_var .get ());
739+ if (current_num_ != num_child_ - 1 ) {
740+ doc << " with tir.allocate_const(" << ndarray_str << " , " << PrintDType (alloc->dtype ) << " , "
741+ << Print (alloc->extents ) << " )" ;
742+ doc << Doc::Indent (4 , Doc::NewLine () << PrintBody (alloc->body ));
743+ } else {
744+ doc << Print (alloc->buffer_var ) << " = tir.allocate_const(" << ndarray_str << " , "
745+ << PrintDType (alloc->dtype ) << " , " << Print (alloc->extents );
746+ doc << " )" << Doc::NewLine () << PrintBody (alloc->body );
747+ }
748+ return doc;
749+ }
750+
688751Doc TVMScriptPrinter::VisitStmt_ (const IfThenElseNode* op) {
689752 Doc doc;
690753 doc << " if " << Print (op->condition ) << " :" ;
0 commit comments