@@ -72,6 +72,107 @@ class AttrVisitor {
7272 // ! \endcond
7373};
7474
75+ // Attr getter.
76+ class AttrGetter : public AttrVisitor {
77+ public:
78+ const String& skey;
79+ runtime::TVMRetValue* ret;
80+
81+ AttrGetter (const String& skey, runtime::TVMRetValue* ret) : skey(skey), ret(ret) {}
82+
83+ bool found_ref_object{false };
84+
85+ void Visit (const char * key, double * value) final {
86+ if (skey == key) *ret = value[0 ];
87+ }
88+ void Visit (const char * key, int64_t * value) final {
89+ if (skey == key) *ret = value[0 ];
90+ }
91+ void Visit (const char * key, uint64_t * value) final {
92+ ICHECK_LE (value[0 ], static_cast <uint64_t >(std::numeric_limits<int64_t >::max ()))
93+ << " cannot return too big constant" ;
94+ if (skey == key) *ret = static_cast <int64_t >(value[0 ]);
95+ }
96+ void Visit (const char * key, int * value) final {
97+ if (skey == key) *ret = static_cast <int64_t >(value[0 ]);
98+ }
99+ void Visit (const char * key, bool * value) final {
100+ if (skey == key) *ret = static_cast <int64_t >(value[0 ]);
101+ }
102+ void Visit (const char * key, void ** value) final {
103+ if (skey == key) *ret = static_cast <void *>(value[0 ]);
104+ }
105+ void Visit (const char * key, DataType* value) final {
106+ if (skey == key) *ret = value[0 ];
107+ }
108+ void Visit (const char * key, std::string* value) final {
109+ if (skey == key) *ret = value[0 ];
110+ }
111+
112+ void Visit (const char * key, runtime::NDArray* value) final {
113+ if (skey == key) {
114+ *ret = value[0 ];
115+ found_ref_object = true ;
116+ }
117+ }
118+ void Visit (const char * key, runtime::ObjectRef* value) final {
119+ if (skey == key) {
120+ *ret = value[0 ];
121+ found_ref_object = true ;
122+ }
123+ }
124+ };
125+
126+ class NodeAttrSetter : public AttrVisitor {
127+ public:
128+ std::string type_key;
129+ std::unordered_map<std::string, runtime::TVMArgValue> attrs;
130+
131+ void Visit (const char * key, double * value) final { *value = GetAttr (key).operator double (); }
132+ void Visit (const char * key, int64_t * value) final { *value = GetAttr (key).operator int64_t (); }
133+ void Visit (const char * key, uint64_t * value) final { *value = GetAttr (key).operator uint64_t (); }
134+ void Visit (const char * key, int * value) final { *value = GetAttr (key).operator int (); }
135+ void Visit (const char * key, bool * value) final { *value = GetAttr (key).operator bool (); }
136+ void Visit (const char * key, std::string* value) final {
137+ *value = GetAttr (key).operator std::string ();
138+ }
139+ void Visit (const char * key, void ** value) final { *value = GetAttr (key).operator void *(); }
140+ void Visit (const char * key, DataType* value) final { *value = GetAttr (key).operator DataType (); }
141+ void Visit (const char * key, runtime::NDArray* value) final {
142+ *value = GetAttr (key).operator runtime::NDArray ();
143+ }
144+ void Visit (const char * key, ObjectRef* value) final {
145+ *value = GetAttr (key).operator ObjectRef ();
146+ }
147+
148+ runtime::TVMArgValue GetAttr (const char * key) {
149+ auto it = attrs.find (key);
150+ if (it == attrs.end ()) {
151+ LOG (FATAL) << type_key << " : require field " << key;
152+ }
153+ runtime::TVMArgValue v = it->second ;
154+ attrs.erase (it);
155+ return v;
156+ }
157+ };
158+
159+ // List names;
160+ class AttrDir : public AttrVisitor {
161+ public:
162+ std::vector<std::string>* names;
163+
164+ void Visit (const char * key, double * value) final { names->push_back (key); }
165+ void Visit (const char * key, int64_t * value) final { names->push_back (key); }
166+ void Visit (const char * key, uint64_t * value) final { names->push_back (key); }
167+ void Visit (const char * key, bool * value) final { names->push_back (key); }
168+ void Visit (const char * key, int * value) final { names->push_back (key); }
169+ void Visit (const char * key, void ** value) final { names->push_back (key); }
170+ void Visit (const char * key, DataType* value) final { names->push_back (key); }
171+ void Visit (const char * key, std::string* value) final { names->push_back (key); }
172+ void Visit (const char * key, runtime::NDArray* value) final { names->push_back (key); }
173+ void Visit (const char * key, runtime::ObjectRef* value) final { names->push_back (key); }
174+ };
175+
75176/* !
76177 * \brief Virtual function table to support IR/AST node reflection.
77178 *
0 commit comments