1616 * specific language governing permissions and limitations
1717 * under the License.
1818 */
19-
2019#include < tvm/runtime/logging.h>
2120#include < tvm/runtime/registry.h>
21+ #include < tvm/script/printer/doc.h>
22+
23+ #include < algorithm>
24+ #include < string>
2225
2326#include " ../../support/str_escape.h"
27+ #include " ../../support/utils.h"
2428#include " ./base_doc_printer.h"
2529
2630namespace tvm {
@@ -45,8 +49,21 @@ class PythonDocPrinter : public DocPrinter {
4549 void PrintTypedDoc (const DictDoc& doc) final ;
4650 void PrintTypedDoc (const TupleDoc& doc) final ;
4751 void PrintTypedDoc (const SliceDoc& doc) final ;
52+ void PrintTypedDoc (const StmtBlockDoc& doc) final ;
53+ void PrintTypedDoc (const AssignDoc& doc) final ;
54+ void PrintTypedDoc (const IfDoc& doc) final ;
55+ void PrintTypedDoc (const WhileDoc& doc) final ;
56+ void PrintTypedDoc (const ForDoc& doc) final ;
57+ void PrintTypedDoc (const ExprStmtDoc& doc) final ;
58+ void PrintTypedDoc (const AssertDoc& doc) final ;
59+ void PrintTypedDoc (const ReturnDoc& doc) final ;
60+ void PrintTypedDoc (const ScopeDoc& doc) final ;
61+ void PrintTypedDoc (const FunctionDoc& doc) final ;
62+ void PrintTypedDoc (const ClassDoc& doc) final ;
4863
4964 private:
65+ void NewLineWithoutIndent () { output_ << " \n " ; }
66+
5067 template <typename DocType>
5168 void PrintJoinedDocs (const Array<DocType>& docs, const std::string& separator) {
5269 bool is_first = true ;
@@ -59,6 +76,65 @@ class PythonDocPrinter : public DocPrinter {
5976 PrintDoc (doc);
6077 }
6178 }
79+
80+ void PrintIndentedBlock (const Array<StmtDoc>& docs) {
81+ IncreaseIndent ();
82+ for (const StmtDoc& d : docs) {
83+ NewLine ();
84+ PrintDoc (d);
85+ }
86+ if (docs.empty ()) {
87+ NewLine ();
88+ output_ << " pass" ;
89+ }
90+ DecreaseIndent ();
91+ }
92+
93+ void PrintDecorators (const Array<ExprDoc>& decorators) {
94+ for (const ExprDoc& decorator : decorators) {
95+ output_ << " @" ;
96+ PrintDoc (decorator);
97+ NewLine ();
98+ }
99+ }
100+
101+ void MaybePrintCommentInline (const StmtDoc& stmt) {
102+ if (stmt->comment .defined ()) {
103+ const std::string& comment = stmt->comment .value ();
104+ bool has_newline = std::find (comment.begin (), comment.end (), ' \n ' ) != comment.end ();
105+ CHECK (!has_newline) << " ValueError: the comment string of " << stmt->GetTypeKey ()
106+ << " cannot have newline." ;
107+ output_ << " # " << comment;
108+ }
109+ }
110+
111+ void MaybePrintCommentWithNewLine (const StmtDoc& stmt) {
112+ if (stmt->comment .defined ()) {
113+ std::vector<std::string> comment_lines = support::Split (stmt->comment .value (), ' \n ' );
114+ for (const std::string& line : comment_lines) {
115+ output_ << " # " << line;
116+ NewLine ();
117+ }
118+ }
119+ }
120+
121+ void PrintBlockComment (const String& comment) {
122+ IncreaseIndent ();
123+ NewLine () << " \"\"\" " ;
124+
125+ std::vector<std::string> comment_lines = support::Split (comment, ' \n ' );
126+ for (const std::string& line : comment_lines) {
127+ if (line.empty ()) {
128+ // No indentation on empty line
129+ output_ << " \n " ;
130+ } else {
131+ NewLine () << line;
132+ }
133+ }
134+
135+ NewLine () << " \"\"\" " ;
136+ DecreaseIndent ();
137+ }
62138};
63139
64140void PythonDocPrinter::PrintTypedDoc (const LiteralDoc& doc) {
@@ -260,6 +336,140 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
260336 }
261337}
262338
339+ void PythonDocPrinter::PrintTypedDoc (const StmtBlockDoc& doc) {
340+ for (const StmtDoc& stmt : doc->stmts ) {
341+ PrintDoc (stmt);
342+ NewLine ();
343+ }
344+ }
345+
346+ void PythonDocPrinter::PrintTypedDoc (const AssignDoc& doc) {
347+ if (const auto * tuple_doc = doc->lhs .as <TupleDocNode>()) {
348+ PrintJoinedDocs (tuple_doc->elements , " , " );
349+ } else {
350+ PrintDoc (doc->lhs );
351+ }
352+
353+ if (doc->annotation ) {
354+ output_ << " : " ;
355+ PrintDoc (doc->annotation .value ());
356+ }
357+ if (doc->rhs ) {
358+ output_ << " = " ;
359+ PrintDoc (doc->rhs .value ());
360+ }
361+ MaybePrintCommentInline (doc);
362+ }
363+
364+ void PythonDocPrinter::PrintTypedDoc (const IfDoc& doc) {
365+ MaybePrintCommentWithNewLine (doc);
366+ output_ << " if " ;
367+ PrintDoc (doc->predicate );
368+ output_ << " :" ;
369+
370+ PrintIndentedBlock (doc->then_branch );
371+
372+ if (!doc->else_branch .empty ()) {
373+ NewLine ();
374+ output_ << " else:" ;
375+ PrintIndentedBlock (doc->else_branch );
376+ }
377+ }
378+
379+ void PythonDocPrinter::PrintTypedDoc (const WhileDoc& doc) {
380+ MaybePrintCommentWithNewLine (doc);
381+ output_ << " while " ;
382+ PrintDoc (doc->predicate );
383+ output_ << " :" ;
384+
385+ PrintIndentedBlock (doc->body );
386+ }
387+
388+ void PythonDocPrinter::PrintTypedDoc (const ForDoc& doc) {
389+ MaybePrintCommentWithNewLine (doc);
390+ output_ << " for " ;
391+ PrintDoc (doc->lhs );
392+ output_ << " in " ;
393+ PrintDoc (doc->rhs );
394+ output_ << " :" ;
395+
396+ PrintIndentedBlock (doc->body );
397+ }
398+
399+ void PythonDocPrinter::PrintTypedDoc (const ScopeDoc& doc) {
400+ MaybePrintCommentWithNewLine (doc);
401+ output_ << " with " ;
402+ PrintDoc (doc->rhs );
403+ if (doc->lhs != nullptr ) {
404+ output_ << " as " ;
405+ PrintDoc (doc->lhs .value ());
406+ }
407+ output_ << " :" ;
408+
409+ PrintIndentedBlock (doc->body );
410+ }
411+
412+ void PythonDocPrinter::PrintTypedDoc (const ExprStmtDoc& doc) {
413+ PrintDoc (doc->expr );
414+ MaybePrintCommentInline (doc);
415+ }
416+
417+ void PythonDocPrinter::PrintTypedDoc (const AssertDoc& doc) {
418+ output_ << " assert " ;
419+ PrintDoc (doc->test );
420+ if (doc->msg .defined ()) {
421+ output_ << " , " ;
422+ PrintDoc (doc->msg .value ());
423+ }
424+ MaybePrintCommentInline (doc);
425+ }
426+
427+ void PythonDocPrinter::PrintTypedDoc (const ReturnDoc& doc) {
428+ output_ << " return " ;
429+ PrintDoc (doc->value );
430+ MaybePrintCommentInline (doc);
431+ }
432+
433+ void PythonDocPrinter::PrintTypedDoc (const FunctionDoc& doc) {
434+ for (const AssignDoc& arg_doc : doc->args ) {
435+ ICHECK (arg_doc->comment == nullptr ) << " Function arg cannot have comment attached to them." ;
436+ }
437+
438+ PrintDecorators (doc->decorators );
439+
440+ output_ << " def " ;
441+ PrintDoc (doc->name );
442+
443+ output_ << " (" ;
444+ PrintJoinedDocs (doc->args , " , " );
445+ output_ << " )" ;
446+
447+ output_ << " -> " ;
448+ PrintDoc (doc->return_type );
449+
450+ output_ << " :" ;
451+
452+ if (doc->comment .defined ()) {
453+ PrintBlockComment (doc->comment .value ());
454+ }
455+ PrintIndentedBlock (doc->body );
456+ NewLineWithoutIndent ();
457+ }
458+
459+ void PythonDocPrinter::PrintTypedDoc (const ClassDoc& doc) {
460+ PrintDecorators (doc->decorators );
461+
462+ output_ << " class " ;
463+ PrintDoc (doc->name );
464+ output_ << " :" ;
465+
466+ if (doc->comment .defined ()) {
467+ PrintBlockComment (doc->comment .value ());
468+ }
469+ PrintIndentedBlock (doc->body );
470+ NewLineWithoutIndent ();
471+ }
472+
263473String DocToPythonScript (Doc doc, int indent_spaces) {
264474 PythonDocPrinter printer (indent_spaces);
265475 printer.Append (doc);
0 commit comments