3737#include < tvm/relay/attrs/annotation.h>
3838#include < tvm/relay/expr_functor.h>
3939#include < tvm/relay/pattern_functor.h>
40+ #include < tvm/target/se_scope.h>
4041#include < tvm/tir/function.h>
4142
4243#include " ../ir/attr_functor.h"
@@ -120,9 +121,6 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
120121 return PrintPattern (Downcast<Pattern>(node), meta);
121122 } else if (node.as <IRModuleNode>()) {
122123 return PrintMod (Downcast<IRModule>(node));
123- } else if (!show_meta_data_ && node.as <BaseAttrsNode>()) {
124- // Show attributes in readable form.
125- return PrintAttrs (Downcast<Attrs>(node));
126124 } else {
127125 // default module.
128126 std::ostringstream os;
@@ -444,7 +442,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
444442 for (Var param : fn->params ) {
445443 params.push_back (AllocVar (param));
446444 }
447- for (const Doc& d : PrintFuncAttrs (fn->attrs )) {
445+ for (const Doc& d : PrintDictAttrs (fn->attrs )) {
448446 params.push_back (d);
449447 }
450448 doc << Doc::Concat (params) << " ) " ;
@@ -684,8 +682,10 @@ Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) {
684682 Doc doc;
685683 doc << " Tensor[(" ;
686684 std::vector<Doc> shapes;
687- for (ObjectRef shape : node->shape ) {
688- shapes.push_back (PrintAttr (shape));
685+ for (const PrimExpr& prim_expr : node->shape ) {
686+ // Though not bound within an attribute the attribute visitor will handle the PrimExprs we
687+ // care about.
688+ shapes.push_back (PrintAttributeValue (prim_expr));
689689 }
690690 doc << Doc::Concat (shapes);
691691 return doc << " ), " << PrintDType (node->dtype ) << " ]" ;
@@ -766,34 +766,18 @@ Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) {
766766// Overload of Attr printing functions
767767// ------------------------------------
768768
769- Doc RelayTextPrinter::PrintAttr (const ObjectRef& value, bool meta) {
770- if (value.defined ()) {
771- Doc printed_attr;
772- if (value.as <tvm::tir::AnyNode>()) {
773- printed_attr << " ?" ;
774- } else if (auto str_obj = value.as <tvm::StringObj>()) {
775- printed_attr << Doc::StrLiteral (GetRef<String>(str_obj));
776- } else if (meta) {
777- printed_attr = meta_->GetMetaNode (Downcast<ObjectRef>(value));
778- } else {
779- printed_attr = VisitAttr (value);
780- }
781- return printed_attr;
782- } else {
783- return Doc::Text (" None" );
784- }
785- }
786-
787769Doc RelayTextPrinter::VisitAttrDefault_ (const Object* op) {
788- return PrintAttr (GetRef<ObjectRef>(op), /* meta=*/ true );
770+ // Since we don't have any overload for a specific attribute type we'll need to force
771+ // the meta[...] representation to avoid infinite regress.
772+ return PrintAttributeValue (GetRef<ObjectRef>(op), /* force_meta=*/ true );
789773}
790774
791775Doc RelayTextPrinter::VisitAttr_ (const ArrayNode* op) {
792776 Doc doc;
793777 doc << " [" ;
794778 std::vector<Doc> arr_vals;
795- for (auto val : *op) {
796- arr_vals.push_back (PrintAttr (val));
779+ for (const auto & val : *op) {
780+ arr_vals.push_back (PrintAttributeValue (val));
797781 }
798782 doc << Doc::Concat (arr_vals);
799783 doc << " ]" ;
@@ -831,6 +815,7 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor {
831815 doc << key << " =" << *value << " f" ;
832816 docs->push_back (doc);
833817 }
818+
834819 void Visit (const char * key, int64_t * value) final { PrintKV (key, *value); }
835820 void Visit (const char * key, uint64_t * value) final { PrintKV (key, *value); }
836821 void Visit (const char * key, int * value) final { PrintKV (key, *value); }
@@ -844,58 +829,134 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor {
844829 LOG (FATAL) << " do not allow NDarray as argument" ;
845830 }
846831 void Visit (const char * key, runtime::ObjectRef* obj) final {
847- PrintKV (key, parent_->PrintAttr (*obj));
832+ PrintKV (key, parent_->PrintAttributeValue (*obj));
848833 }
849834
850835 private:
851836 std::vector<Doc>* docs;
852837 RelayTextPrinter* parent_;
853838};
854839
855- Doc RelayTextPrinter::PrintAttrs (const Attrs& attrs) {
856- std::vector<Doc> docs;
857- AttrPrinter printer (&docs, this );
858- const_cast <BaseAttrsNode*>(attrs.operator ->())->VisitNonDefaultAttrs (&printer);
859- Doc doc;
860- doc << " {" << Doc::Concat (docs) << " }" ;
861-
862- return doc;
840+ void RelayTextPrinter::AppendGenericAttrs (std::vector<Doc>* docs, const Attrs& attrs,
841+ bool include_type_key) {
842+ if (!attrs.defined ()) {
843+ return ;
844+ }
845+ AttrPrinter printer (docs, this );
846+ // Need to drop cost cast since in general VisitNonDefaultAttrs can mutate, but in this
847+ // case we are read-only.
848+ const_cast <BaseAttrsNode*>(attrs.get ())->VisitNonDefaultAttrs (&printer);
849+ if (include_type_key) {
850+ std::string s = attrs->GetTypeKey ();
851+ printer.Visit (" attrs_type_key" , &s);
852+ }
863853}
864854
865855std::vector<Doc> RelayTextPrinter::PrintCallAttrs (const Attrs& attrs, const Expr& op) {
866856 std::vector<Doc> docs;
867- if (!attrs.defined ()) return docs;
857+ if (!attrs.defined ()) {
858+ return docs;
859+ }
868860 const auto * op_node = op.as <OpNode>();
869861 if (show_meta_data_ && op_node && (attrs->type_index () != op_node->attrs_type_index )) {
870- // fallback
862+ // The parser can only understand calls with attributes if they match the operator's
863+ // declared attribute type. If that's not the case fall back to the meta[...] representation.
864+ docs.push_back (meta_->GetMetaNode (attrs));
865+ } else {
866+ AppendGenericAttrs (&docs, attrs, /* include_type_key=*/ !op_node);
867+ }
868+ return docs;
869+ }
870+
871+ std::vector<Doc> RelayTextPrinter::PrintDictAttrs (const DictAttrs& dict_attrs) {
872+ if (!dict_attrs.defined ()) {
873+ return {};
874+ }
875+ return PrintDictAttrs (dict_attrs->dict );
876+ }
877+
878+ std::vector<Doc> RelayTextPrinter::PrintDictAttrs (const Map<String, ObjectRef>& dict_attrs) {
879+ std::vector<Doc> docs;
880+ if (!dict_attrs.defined ()) {
881+ return docs;
882+ }
883+ for (const auto & k : dict_attrs) {
871884 Doc doc;
872- doc << meta_-> GetMetaNode (attrs );
885+ doc << k. first << " = " << PrintAttributeValue (k. second );
873886 docs.push_back (doc);
874- return docs;
875- } else {
876- // Show attributes in readable form.
877- AttrPrinter printer (&docs, this );
878- const_cast <BaseAttrsNode*>(attrs.operator ->())->VisitNonDefaultAttrs (&printer);
879- if (!op_node) {
880- // print call attr type key to restore expr for relay parser
881- std::string s = std::string (attrs->GetTypeKey ());
882- printer.Visit (" attrs_type_key" , &s);
887+ }
888+ return docs;
889+ }
890+
891+ Doc RelayTextPrinter::PrintAttributeValue (const ObjectRef& value, bool force_meta) {
892+ if (value.defined ()) {
893+ Doc printed_attr;
894+ if (value.as <tvm::tir::AnyNode>()) {
895+ printed_attr << " ?" ;
896+ } else if (auto str_obj = value.as <tvm::StringObj>()) {
897+ printed_attr << Doc::StrLiteral (GetRef<String>(str_obj));
898+ } else if (force_meta) {
899+ printed_attr = meta_->GetMetaNode (Downcast<ObjectRef>(value));
900+ } else if (const auto * se_scope_node = value.as <SEScopeNode>()) {
901+ if (show_meta_data_) {
902+ printed_attr = meta_->GetMetaNode (GetRef<ObjectRef>(se_scope_node));
903+ } else {
904+ // Special case: The ReprPrinter for SEScopeNodes is much easier to work with while
905+ // debugging.
906+ std::ostringstream os;
907+ os << GetRef<SEScope>(se_scope_node);
908+ return Doc::Text (os.str ());
909+ }
910+ } else if (const auto * base_attr_node = value.as <BaseAttrsNode>()) {
911+ if (show_meta_data_) {
912+ printed_attr = meta_->GetMetaNode (GetRef<ObjectRef>(base_attr_node));
913+ } else {
914+ // Special case: The non-meta form for attributes are much easier to work with while
915+ // debugging.
916+ printed_attr = PrintAttrsAsAttributeValue (GetRef<Attrs>(base_attr_node));
917+ }
918+ } else if (const auto * base_map_node = value.as <MapNode>()) {
919+ if (show_meta_data_) {
920+ printed_attr = meta_->GetMetaNode (GetRef<ObjectRef>(base_map_node));
921+ } else {
922+ // Special case: Show maps fields as key=value pairs to help debugging.
923+ printed_attr << PrintMapAsAttributeValue (GetRef<Map<ObjectRef, ObjectRef>>(base_map_node));
924+ }
925+ } else if (const auto * global_var_node = value.as <GlobalVarNode>()) {
926+ if (show_meta_data_) {
927+ printed_attr = meta_->GetMetaNode (GetRef<ObjectRef>(global_var_node));
928+ } else {
929+ printed_attr << " '" << global_var_node->name_hint << " '" ;
930+ }
931+ } else {
932+ printed_attr = VisitAttr (value);
883933 }
884- return docs;
934+ return printed_attr;
935+ } else {
936+ return Doc::Text (" None" );
885937 }
886938}
887939
888- std::vector< Doc> RelayTextPrinter::PrintFuncAttrs (const Attrs& attrs) {
940+ Doc RelayTextPrinter::PrintAttrsAsAttributeValue (const Attrs& attrs) {
889941 std::vector<Doc> docs;
890- if (!attrs.defined ()) return docs;
891- const auto * dict_attrs = attrs.as <DictAttrsNode>();
892- ICHECK (dict_attrs);
893- for (const auto & k : dict_attrs->dict ) {
942+ AppendGenericAttrs (&docs, attrs, /* include_type_key=*/ false );
943+ Doc doc;
944+ doc << " {" << Doc::Concat (docs) << " }" ;
945+ return doc;
946+ }
947+
948+ Doc RelayTextPrinter::PrintMapAsAttributeValue (const Map<ObjectRef, ObjectRef>& map) {
949+ std::vector<Doc> docs;
950+ for (const auto & k : map) {
894951 Doc doc;
895- doc << k.first << " =" << Print (k.second );
952+ doc << PrintAttributeValue (k.first );
953+ doc << " =" ;
954+ doc << PrintAttributeValue (k.second );
896955 docs.push_back (doc);
897956 }
898- return docs;
957+ Doc doc;
958+ doc << " {" << Doc::Concat (docs) << " }" ;
959+ return doc;
899960}
900961
901962Doc RelayTextPrinter::PrintSpan (const Span& span) {
0 commit comments