Skip to content

Commit 871b3e3

Browse files
committed
Support printing StmtDoc in PythonDocPrinter
1 parent 95ff28c commit 871b3e3

File tree

4 files changed

+906
-6
lines changed

4 files changed

+906
-6
lines changed

src/script/printer/base_doc_printer.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,28 @@ void DocPrinter::PrintDoc(const Doc& doc) {
5858
PrintTypedDoc(GetRef<DictDoc>(doc_node));
5959
} else if (const auto* doc_node = doc.as<SliceDocNode>()) {
6060
PrintTypedDoc(GetRef<SliceDoc>(doc_node));
61+
} else if (const auto* doc_node = doc.as<StmtBlockDocNode>()) {
62+
PrintTypedDoc(GetRef<StmtBlockDoc>(doc_node));
63+
} else if (const auto* doc_node = doc.as<AssignDocNode>()) {
64+
PrintTypedDoc(GetRef<AssignDoc>(doc_node));
65+
} else if (const auto* doc_node = doc.as<IfDocNode>()) {
66+
PrintTypedDoc(GetRef<IfDoc>(doc_node));
67+
} else if (const auto* doc_node = doc.as<WhileDocNode>()) {
68+
PrintTypedDoc(GetRef<WhileDoc>(doc_node));
69+
} else if (const auto* doc_node = doc.as<ForDocNode>()) {
70+
PrintTypedDoc(GetRef<ForDoc>(doc_node));
71+
} else if (const auto* doc_node = doc.as<ScopeDocNode>()) {
72+
PrintTypedDoc(GetRef<ScopeDoc>(doc_node));
73+
} else if (const auto* doc_node = doc.as<ExprStmtDocNode>()) {
74+
PrintTypedDoc(GetRef<ExprStmtDoc>(doc_node));
75+
} else if (const auto* doc_node = doc.as<AssertDocNode>()) {
76+
PrintTypedDoc(GetRef<AssertDoc>(doc_node));
77+
} else if (const auto* doc_node = doc.as<ReturnDocNode>()) {
78+
PrintTypedDoc(GetRef<ReturnDoc>(doc_node));
79+
} else if (const auto* doc_node = doc.as<FunctionDocNode>()) {
80+
PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
81+
} else if (const auto* doc_node = doc.as<ClassDocNode>()) {
82+
PrintTypedDoc(GetRef<ClassDoc>(doc_node));
6183
} else {
6284
LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
6385
throw;

src/script/printer/base_doc_printer.h

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,22 +84,22 @@ class DocPrinter {
8484
virtual void PrintTypedDoc(const LiteralDoc& doc) = 0;
8585

8686
/*!
87-
* \brief Virtual method to print a IdDoc
87+
* \brief Virtual method to print an IdDoc
8888
*/
8989
virtual void PrintTypedDoc(const IdDoc& doc) = 0;
9090

9191
/*!
92-
* \brief Virtual method to print a AttrAccessDoc
92+
* \brief Virtual method to print an AttrAccessDoc
9393
*/
9494
virtual void PrintTypedDoc(const AttrAccessDoc& doc) = 0;
9595

9696
/*!
97-
* \brief Virtual method to print a IndexDoc
97+
* \brief Virtual method to print an IndexDoc
9898
*/
9999
virtual void PrintTypedDoc(const IndexDoc& doc) = 0;
100100

101101
/*!
102-
* \brief Virtual method to print a OperationDoc
102+
* \brief Virtual method to print an OperationDoc
103103
*/
104104
virtual void PrintTypedDoc(const OperationDoc& doc) = 0;
105105

@@ -133,6 +133,61 @@ class DocPrinter {
133133
*/
134134
virtual void PrintTypedDoc(const SliceDoc& doc) = 0;
135135

136+
/*!
137+
* \brief Virtual method to print a StmtBlockDoc
138+
*/
139+
virtual void PrintTypedDoc(const StmtBlockDoc& doc) = 0;
140+
141+
/*!
142+
* \brief Virtual method to print an AssignDoc
143+
*/
144+
virtual void PrintTypedDoc(const AssignDoc& doc) = 0;
145+
146+
/*!
147+
* \brief Virtual method to print an IfDoc
148+
*/
149+
virtual void PrintTypedDoc(const IfDoc& doc) = 0;
150+
151+
/*!
152+
* \brief Virtual method to print a WhileDoc
153+
*/
154+
virtual void PrintTypedDoc(const WhileDoc& doc) = 0;
155+
156+
/*!
157+
* \brief Virtual method to print a ForDoc
158+
*/
159+
virtual void PrintTypedDoc(const ForDoc& doc) = 0;
160+
161+
/*!
162+
* \brief Virtual method to print a ScopeDoc
163+
*/
164+
virtual void PrintTypedDoc(const ScopeDoc& doc) = 0;
165+
166+
/*!
167+
* \brief Virtual method to print an ExprStmtDoc
168+
*/
169+
virtual void PrintTypedDoc(const ExprStmtDoc& doc) = 0;
170+
171+
/*!
172+
* \brief Virtual method to print an AssertDoc
173+
*/
174+
virtual void PrintTypedDoc(const AssertDoc& doc) = 0;
175+
176+
/*!
177+
* \brief Virtual method to print a ReturnDoc
178+
*/
179+
virtual void PrintTypedDoc(const ReturnDoc& doc) = 0;
180+
181+
/*!
182+
* \brief Virtual method to print a FunctionDoc
183+
*/
184+
virtual void PrintTypedDoc(const FunctionDoc& doc) = 0;
185+
186+
/*!
187+
* \brief Virtual method to print a ClassDoc
188+
*/
189+
virtual void PrintTypedDoc(const ClassDoc& doc) = 0;
190+
136191
/*!
137192
* \brief Increase the indent level of any content to be
138193
* printed after this call

src/script/printer/python_doc_printer.cc

Lines changed: 211 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19-
2019
#include <tvm/runtime/logging.h>
2120
#include <tvm/runtime/registry.h>
21+
#include <tvm/script/printer/doc.h>
22+
23+
#include <algorithm>
24+
#include <string>
2225

2326
#include "../../support/str_escape.h"
27+
#include "../../support/utils.h"
2428
#include "./base_doc_printer.h"
2529

2630
namespace tvm {
@@ -45,8 +49,21 @@ class PythonDocPrinter : public DocPrinter {
4549
void PrintTypedDoc(const DictDoc& doc) final;
4650
void PrintTypedDoc(const TupleDoc& doc) final;
4751
void PrintTypedDoc(const SliceDoc& doc) final;
52+
void PrintTypedDoc(const StmtBlockDoc& doc) final;
53+
void PrintTypedDoc(const AssignDoc& doc) final;
54+
void PrintTypedDoc(const IfDoc& doc) final;
55+
void PrintTypedDoc(const WhileDoc& doc) final;
56+
void PrintTypedDoc(const ForDoc& doc) final;
57+
void PrintTypedDoc(const ExprStmtDoc& doc) final;
58+
void PrintTypedDoc(const AssertDoc& doc) final;
59+
void PrintTypedDoc(const ReturnDoc& doc) final;
60+
void PrintTypedDoc(const ScopeDoc& doc) final;
61+
void PrintTypedDoc(const FunctionDoc& doc) final;
62+
void PrintTypedDoc(const ClassDoc& doc) final;
4863

4964
private:
65+
void NewLineWithoutIndent() { output_ << "\n"; }
66+
5067
template <typename DocType>
5168
void PrintJoinedDocs(const Array<DocType>& docs, const std::string& separator) {
5269
bool is_first = true;
@@ -59,6 +76,65 @@ class PythonDocPrinter : public DocPrinter {
5976
PrintDoc(doc);
6077
}
6178
}
79+
80+
void PrintIndentedBlock(const Array<StmtDoc>& docs) {
81+
IncreaseIndent();
82+
for (const StmtDoc& d : docs) {
83+
NewLine();
84+
PrintDoc(d);
85+
}
86+
if (docs.empty()) {
87+
NewLine();
88+
output_ << "pass";
89+
}
90+
DecreaseIndent();
91+
}
92+
93+
void PrintDecorators(const Array<ExprDoc>& decorators) {
94+
for (const ExprDoc& decorator : decorators) {
95+
output_ << "@";
96+
PrintDoc(decorator);
97+
NewLine();
98+
}
99+
}
100+
101+
void MaybePrintCommentInline(const StmtDoc& stmt) {
102+
if (stmt->comment.defined()) {
103+
const std::string& comment = stmt->comment.value();
104+
bool has_newline = std::find(comment.begin(), comment.end(), '\n') != comment.end();
105+
CHECK(!has_newline) << "ValueError: the comment string of " << stmt->GetTypeKey()
106+
<< " cannot have newline.";
107+
output_ << " # " << comment;
108+
}
109+
}
110+
111+
void MaybePrintCommentWithNewLine(const StmtDoc& stmt) {
112+
if (stmt->comment.defined()) {
113+
std::vector<std::string> comment_lines = support::Split(stmt->comment.value(), '\n');
114+
for (const std::string& line : comment_lines) {
115+
output_ << "# " << line;
116+
NewLine();
117+
}
118+
}
119+
}
120+
121+
void PrintBlockComment(const String& comment) {
122+
IncreaseIndent();
123+
NewLine() << "\"\"\"";
124+
125+
std::vector<std::string> comment_lines = support::Split(comment, '\n');
126+
for (const std::string& line : comment_lines) {
127+
if (line.empty()) {
128+
// No indentation on empty line
129+
output_ << "\n";
130+
} else {
131+
NewLine() << line;
132+
}
133+
}
134+
135+
NewLine() << "\"\"\"";
136+
DecreaseIndent();
137+
}
62138
};
63139

64140
void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
@@ -260,6 +336,140 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
260336
}
261337
}
262338

339+
void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) {
340+
for (const StmtDoc& stmt : doc->stmts) {
341+
PrintDoc(stmt);
342+
NewLine();
343+
}
344+
}
345+
346+
void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
347+
if (const auto* tuple_doc = doc->lhs.as<TupleDocNode>()) {
348+
PrintJoinedDocs(tuple_doc->elements, ", ");
349+
} else {
350+
PrintDoc(doc->lhs);
351+
}
352+
353+
if (doc->annotation) {
354+
output_ << ": ";
355+
PrintDoc(doc->annotation.value());
356+
}
357+
if (doc->rhs) {
358+
output_ << " = ";
359+
PrintDoc(doc->rhs.value());
360+
}
361+
MaybePrintCommentInline(doc);
362+
}
363+
364+
void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
365+
MaybePrintCommentWithNewLine(doc);
366+
output_ << "if ";
367+
PrintDoc(doc->predicate);
368+
output_ << ":";
369+
370+
PrintIndentedBlock(doc->then_branch);
371+
372+
if (!doc->else_branch.empty()) {
373+
NewLine();
374+
output_ << "else:";
375+
PrintIndentedBlock(doc->else_branch);
376+
}
377+
}
378+
379+
void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
380+
MaybePrintCommentWithNewLine(doc);
381+
output_ << "while ";
382+
PrintDoc(doc->predicate);
383+
output_ << ":";
384+
385+
PrintIndentedBlock(doc->body);
386+
}
387+
388+
void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
389+
MaybePrintCommentWithNewLine(doc);
390+
output_ << "for ";
391+
PrintDoc(doc->lhs);
392+
output_ << " in ";
393+
PrintDoc(doc->rhs);
394+
output_ << ":";
395+
396+
PrintIndentedBlock(doc->body);
397+
}
398+
399+
void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) {
400+
MaybePrintCommentWithNewLine(doc);
401+
output_ << "with ";
402+
PrintDoc(doc->rhs);
403+
if (doc->lhs != nullptr) {
404+
output_ << " as ";
405+
PrintDoc(doc->lhs.value());
406+
}
407+
output_ << ":";
408+
409+
PrintIndentedBlock(doc->body);
410+
}
411+
412+
void PythonDocPrinter::PrintTypedDoc(const ExprStmtDoc& doc) {
413+
PrintDoc(doc->expr);
414+
MaybePrintCommentInline(doc);
415+
}
416+
417+
void PythonDocPrinter::PrintTypedDoc(const AssertDoc& doc) {
418+
output_ << "assert ";
419+
PrintDoc(doc->test);
420+
if (doc->msg.defined()) {
421+
output_ << ", ";
422+
PrintDoc(doc->msg.value());
423+
}
424+
MaybePrintCommentInline(doc);
425+
}
426+
427+
void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) {
428+
output_ << "return ";
429+
PrintDoc(doc->value);
430+
MaybePrintCommentInline(doc);
431+
}
432+
433+
void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) {
434+
for (const AssignDoc& arg_doc : doc->args) {
435+
ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment attached to them.";
436+
}
437+
438+
PrintDecorators(doc->decorators);
439+
440+
output_ << "def ";
441+
PrintDoc(doc->name);
442+
443+
output_ << "(";
444+
PrintJoinedDocs(doc->args, ", ");
445+
output_ << ")";
446+
447+
output_ << " -> ";
448+
PrintDoc(doc->return_type);
449+
450+
output_ << ":";
451+
452+
if (doc->comment.defined()) {
453+
PrintBlockComment(doc->comment.value());
454+
}
455+
PrintIndentedBlock(doc->body);
456+
NewLineWithoutIndent();
457+
}
458+
459+
void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
460+
PrintDecorators(doc->decorators);
461+
462+
output_ << "class ";
463+
PrintDoc(doc->name);
464+
output_ << ":";
465+
466+
if (doc->comment.defined()) {
467+
PrintBlockComment(doc->comment.value());
468+
}
469+
PrintIndentedBlock(doc->body);
470+
NewLineWithoutIndent();
471+
}
472+
263473
String DocToPythonScript(Doc doc, int indent_spaces) {
264474
PythonDocPrinter printer(indent_spaces);
265475
printer.Append(doc);

0 commit comments

Comments
 (0)