Skip to content

Commit 1b70fa2

Browse files
committed
[TVMScript] Updated buffer_var printing
LetStmt and AllocateNode can both be used to generate handles that are used in Buffer objects. In these cases, the Buffer declarations must go after the handle declaration, not in the function header.
1 parent 0c589aa commit 1b70fa2

File tree

1 file changed

+100
-26
lines changed

1 file changed

+100
-26
lines changed

src/printer/tvmscript_printer.cc

Lines changed: 100 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,58 @@ enum class ExprPrecedence : int {
6868
kUnknown = 7,
6969
};
7070

71+
/*! \brief Utility used for identifying usage of a buffer_var
72+
*
73+
* \details Find the Buffer object that corresponds to a variable or
74+
* allocation, based on the BufferLoad/BufferStore instances that
75+
* occur within the allocation's body.
76+
*/
77+
class BufferUsageFinder : public StmtExprVisitor {
78+
public:
79+
static Map<Var, Array<Buffer>> FindUsage(Map<Var, Array<Buffer>> usage, Stmt body) {
80+
BufferUsageFinder visitor(std::move(usage));
81+
visitor.VisitStmt(body);
82+
return std::move(visitor.usage_);
83+
}
84+
85+
void VisitExpr_(const VarNode* op) final {
86+
Var var = GetRef<Var>(op);
87+
if (!usage_.count(var)) {
88+
usage_.Set(var, {});
89+
}
90+
}
91+
92+
void VisitExpr_(const BufferLoadNode* op) final {
93+
VisitBuffer(op->buffer);
94+
StmtExprVisitor::VisitExpr_(op);
95+
}
96+
97+
void VisitStmt_(const BufferStoreNode* op) final {
98+
VisitBuffer(op->buffer);
99+
StmtExprVisitor::VisitStmt_(op);
100+
}
101+
102+
private:
103+
explicit BufferUsageFinder(Map<Var, Array<Buffer>> usage) : usage_(usage) {}
104+
105+
void VisitBuffer(const Buffer& buffer) {
106+
if (buffers_visited_.count(buffer.get())) {
107+
return;
108+
}
109+
buffers_visited_.insert(buffer.get());
110+
111+
Array<Buffer> arr = usage_.Get(buffer->data).value_or({});
112+
arr.push_back(buffer);
113+
usage_.Set(buffer->data, arr);
114+
}
115+
116+
// The search result.
117+
Map<Var, Array<Buffer>> usage_;
118+
// The buffers that have been visited so far, to avoid duplicate
119+
// entries in the search result.
120+
std::unordered_set<const BufferNode*> buffers_visited_;
121+
};
122+
71123
/*!
72124
* \brief The printer for TVMScript
73125
* \details The printer obtain the precedence of the top-level operation when printing each
@@ -138,6 +190,14 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
138190
* 3. The iter range is equal to loop range
139191
*/
140192
std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_;
193+
/*!
194+
* \brief Map from variables to the buffers they are used in.
195+
*
196+
* Used for identifying buffers that should be declared after the
197+
* LetStmt or Allocate that generates their data pointer, rather
198+
* than in the header.
199+
*/
200+
Map<Var, Array<Buffer>> buffer_var_usage_;
141201

142202
Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
143203
Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override;
@@ -201,6 +261,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
201261
Doc PrintRange(const RangeNode* op);
202262
Doc PrintArray(const ArrayNode* op);
203263
Doc PrintBuffer(const BufferNode* op);
264+
Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body);
204265
Doc AllocBufferDeclaration(const Buffer& buf);
205266
Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
206267
Doc PrintBlockVarRemaps();
@@ -830,11 +891,13 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
830891
Doc doc;
831892
if (current_num_ != num_child_ - 1) {
832893
doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):";
833-
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
894+
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body)
895+
<< PrintBody(op->body));
834896
} else {
835897
if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get());
836898
doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value)
837-
<< Doc::NewLine() << PrintBody(op->body);
899+
<< Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body)
900+
<< PrintBody(op->body);
838901
}
839902
return doc;
840903
}
@@ -923,33 +986,30 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
923986

924987
Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
925988
var_not_in_headers_.insert(op->buffer_var.get());
926-
Doc doc;
989+
927990
auto storage_scope = GetPtrStorageScope(op->buffer_var);
991+
Doc func_call;
992+
func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype)
993+
<< ", " << Print(storage_scope);
994+
if (!is_one(op->condition)) {
995+
func_call << ", " << Print(op->condition);
996+
}
997+
if (!op->annotations.empty()) {
998+
func_call << ", annotations={";
999+
func_call << PrintAnnotations(op->annotations);
1000+
func_call << "}";
1001+
}
1002+
func_call << ")";
1003+
1004+
Doc doc;
9281005
if (current_num_ != num_child_ - 1) {
929-
doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extents) << ", "
930-
<< PrintDType(op->dtype) << ", " << Print(storage_scope);
931-
if (!is_one(op->condition)) {
932-
doc << ", " << Print(op->condition);
933-
}
934-
if (!op->annotations.empty()) {
935-
doc << ", annotations={";
936-
doc << PrintAnnotations(op->annotations);
937-
doc << "}";
938-
}
939-
doc << ") as " << Print(op->buffer_var) << ":";
940-
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1006+
doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
1007+
doc << Doc::Indent(4, Doc::NewLine()
1008+
<< PrintNonHeaderBufferDeclarations(op->buffer_var, op->body)
1009+
<< PrintBody(op->body));
9411010
} else {
942-
doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extents)
943-
<< ", " << PrintDType(op->dtype) << ", " << Print(storage_scope);
944-
if (!is_one(op->condition)) {
945-
doc << ", " << Print(op->condition);
946-
}
947-
if (!op->annotations.empty()) {
948-
doc << ", annotations={";
949-
doc << PrintAnnotations(op->annotations);
950-
doc << "}";
951-
}
952-
doc << ")" << Doc::NewLine() << PrintBody(op->body);
1011+
doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine()
1012+
<< PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body);
9531013
}
9541014
TryDeallocVar(op->buffer_var);
9551015
return doc;
@@ -1458,6 +1518,20 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
14581518
return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
14591519
}
14601520

1521+
Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body) {
1522+
if (!buffer_var_usage_.count(buffer_var)) {
1523+
buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), body);
1524+
}
1525+
Array<Buffer> buffer_usage = buffer_var_usage_.Get(buffer_var).value_or({});
1526+
Doc decls;
1527+
for (const auto& buf_usage : buffer_usage) {
1528+
decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl("
1529+
<< memo_buf_decl_[buf_usage] << ")" << Doc::NewLine();
1530+
buf_not_in_headers_.insert(buf_usage.get());
1531+
}
1532+
return decls;
1533+
}
1534+
14611535
Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
14621536
Doc doc;
14631537
if (op->region.size() == 0) {

0 commit comments

Comments
 (0)