@@ -68,6 +68,58 @@ enum class ExprPrecedence : int {
6868 kUnknown = 7 ,
6969};
7070
71+ /* ! \brief Utility used for identifying usage of a buffer_var
72+ *
73+ * \details Find the Buffer object that corresponds to a variable or
74+ * allocation, based on the BufferLoad/BufferStore instances that
75+ * occur within the allocation's body.
76+ */
77+ class BufferUsageFinder : public StmtExprVisitor {
78+ public:
79+ static Map<Var, Array<Buffer>> FindUsage (Map<Var, Array<Buffer>> usage, Stmt body) {
80+ BufferUsageFinder visitor (std::move (usage));
81+ visitor.VisitStmt (body);
82+ return std::move (visitor.usage_ );
83+ }
84+
85+ void VisitExpr_ (const VarNode* op) final {
86+ Var var = GetRef<Var>(op);
87+ if (!usage_.count (var)) {
88+ usage_.Set (var, {});
89+ }
90+ }
91+
92+ void VisitExpr_ (const BufferLoadNode* op) final {
93+ VisitBuffer (op->buffer );
94+ StmtExprVisitor::VisitExpr_ (op);
95+ }
96+
97+ void VisitStmt_ (const BufferStoreNode* op) final {
98+ VisitBuffer (op->buffer );
99+ StmtExprVisitor::VisitStmt_ (op);
100+ }
101+
102+ private:
103+ explicit BufferUsageFinder (Map<Var, Array<Buffer>> usage) : usage_(usage) {}
104+
105+ void VisitBuffer (const Buffer& buffer) {
106+ if (buffers_visited_.count (buffer.get ())) {
107+ return ;
108+ }
109+ buffers_visited_.insert (buffer.get ());
110+
111+ Array<Buffer> arr = usage_.Get (buffer->data ).value_or ({});
112+ arr.push_back (buffer);
113+ usage_.Set (buffer->data , arr);
114+ }
115+
116+ // The search result.
117+ Map<Var, Array<Buffer>> usage_;
118+ // The buffers that have been visited so far, to avoid duplicate
119+ // entries in the search result.
120+ std::unordered_set<const BufferNode*> buffers_visited_;
121+ };
122+
71123/* !
72124 * \brief The printer for TVMScript
73125 * \details The printer obtain the precedence of the top-level operation when printing each
@@ -138,6 +190,14 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
138190 * 3. The iter range is equal to loop range
139191 */
140192 std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_;
193+ /* !
194+ * \brief Map from variables to the buffers they are used in.
195+ *
196+ * Used for identifying buffers that should be declared after the
197+ * LetStmt or Allocate that generates their data pointer, rather
198+ * than in the header.
199+ */
200+ Map<Var, Array<Buffer>> buffer_var_usage_;
141201
142202 Doc VisitExpr_ (const CastNode* op, ExprPrecedence* out_precedence) override ;
143203 Doc VisitExpr_ (const VarNode* op, ExprPrecedence* out_precedence) override ;
@@ -201,6 +261,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
201261 Doc PrintRange (const RangeNode* op);
202262 Doc PrintArray (const ArrayNode* op);
203263 Doc PrintBuffer (const BufferNode* op);
264+ Doc PrintNonHeaderBufferDeclarations (Var buffer_var, Stmt body);
204265 Doc AllocBufferDeclaration (const Buffer& buf);
205266 Doc PrintBlockVar (const IterVar& iter_var, const PrimExpr& value);
206267 Doc PrintBlockVarRemaps ();
@@ -830,11 +891,13 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
830891 Doc doc;
831892 if (current_num_ != num_child_ - 1 ) {
832893 doc << " with " << tir_prefix_ << " .let(" << Print (op->var ) << " , " << Print (op->value ) << " ):" ;
833- doc << Doc::Indent (4 , Doc::NewLine () << PrintBody (op->body ));
894+ doc << Doc::Indent (4 , Doc::NewLine () << PrintNonHeaderBufferDeclarations (op->var , op->body )
895+ << PrintBody (op->body ));
834896 } else {
835897 if (memo_var_.find (op->var ) == memo_var_.end ()) var_not_in_headers_.insert (op->var .get ());
836898 doc << Print (op->var ) << " : " << Print (GetType (op->var )) << " = " << Print (op->value )
837- << Doc::NewLine () << PrintBody (op->body );
899+ << Doc::NewLine () << PrintNonHeaderBufferDeclarations (op->var , op->body )
900+ << PrintBody (op->body );
838901 }
839902 return doc;
840903}
@@ -923,33 +986,30 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
923986
924987Doc TVMScriptPrinter::VisitStmt_ (const AllocateNode* op) {
925988 var_not_in_headers_.insert (op->buffer_var .get ());
926- Doc doc;
989+
927990 auto storage_scope = GetPtrStorageScope (op->buffer_var );
991+ Doc func_call;
992+ func_call << tir_prefix_ << " .allocate(" << Print (op->extents ) << " , " << PrintDType (op->dtype )
993+ << " , " << Print (storage_scope);
994+ if (!is_one (op->condition )) {
995+ func_call << " , " << Print (op->condition );
996+ }
997+ if (!op->annotations .empty ()) {
998+ func_call << " , annotations={" ;
999+ func_call << PrintAnnotations (op->annotations );
1000+ func_call << " }" ;
1001+ }
1002+ func_call << " )" ;
1003+
1004+ Doc doc;
9281005 if (current_num_ != num_child_ - 1 ) {
929- doc << " with " << tir_prefix_ << " .allocate(" << Print (op->extents ) << " , "
930- << PrintDType (op->dtype ) << " , " << Print (storage_scope);
931- if (!is_one (op->condition )) {
932- doc << " , " << Print (op->condition );
933- }
934- if (!op->annotations .empty ()) {
935- doc << " , annotations={" ;
936- doc << PrintAnnotations (op->annotations );
937- doc << " }" ;
938- }
939- doc << " ) as " << Print (op->buffer_var ) << " :" ;
940- doc << Doc::Indent (4 , Doc::NewLine () << PrintBody (op->body ));
1006+ doc << " with " << func_call << " as " << Print (op->buffer_var ) << " :" ;
1007+ doc << Doc::Indent (4 , Doc::NewLine ()
1008+ << PrintNonHeaderBufferDeclarations (op->buffer_var , op->body )
1009+ << PrintBody (op->body ));
9411010 } else {
942- doc << Print (op->buffer_var ) << " = " << tir_prefix_ << " .allocate(" << Print (op->extents )
943- << " , " << PrintDType (op->dtype ) << " , " << Print (storage_scope);
944- if (!is_one (op->condition )) {
945- doc << " , " << Print (op->condition );
946- }
947- if (!op->annotations .empty ()) {
948- doc << " , annotations={" ;
949- doc << PrintAnnotations (op->annotations );
950- doc << " }" ;
951- }
952- doc << " )" << Doc::NewLine () << PrintBody (op->body );
1011+ doc << Print (op->buffer_var ) << " = " << func_call << Doc::NewLine ()
1012+ << PrintNonHeaderBufferDeclarations (op->buffer_var , op->body ) << PrintBody (op->body );
9531013 }
9541014 TryDeallocVar (op->buffer_var );
9551015 return doc;
@@ -1458,6 +1518,20 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
14581518 return meta_.InMeta (buffer) ? meta_.GetMetaNode (buffer) : AllocBuf (buffer);
14591519}
14601520
1521+ Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations (Var buffer_var, Stmt body) {
1522+ if (!buffer_var_usage_.count (buffer_var)) {
1523+ buffer_var_usage_ = BufferUsageFinder::FindUsage (std::move (buffer_var_usage_), body);
1524+ }
1525+ Array<Buffer> buffer_usage = buffer_var_usage_.Get (buffer_var).value_or ({});
1526+ Doc decls;
1527+ for (const auto & buf_usage : buffer_usage) {
1528+ decls << Print (buf_usage) << " = " << tir_prefix_ << " .buffer_decl("
1529+ << memo_buf_decl_[buf_usage] << " )" << Doc::NewLine ();
1530+ buf_not_in_headers_.insert (buf_usage.get ());
1531+ }
1532+ return decls;
1533+ }
1534+
14611535Doc TVMScriptPrinter::PrintBufferRegion (const BufferRegionNode* op) {
14621536 Doc doc;
14631537 if (op->region .size () == 0 ) {
0 commit comments