Skip to content

Commit a9c610f

Browse files
authored
[TVMScript] Add ObjectPath class (#11977)
Motivation: Same IR node object can be referenced in several different contexts inside a larger IR object. For example, a variable could be referenced in several statements within a block. This makes it impossible to use an object pointer to uniquely identify a "location" within the larger IR object for error reporting purposes. The `ObjectPath` class addresses this problem by serving as a unique "locator". Tracking issue: #11912
1 parent 261de53 commit a9c610f

File tree

4 files changed

+865
-0
lines changed

4 files changed

+865
-0
lines changed

include/tvm/node/object_path.h

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/node/object_path.h
22+
* ObjectPath class that represents a path from a root object to one of its descendants
23+
* via attribute access, array indexing etc.
24+
*/
25+
26+
#ifndef TVM_NODE_OBJECT_PATH_H_
27+
#define TVM_NODE_OBJECT_PATH_H_
28+
29+
#include <tvm/runtime/container/optional.h>
30+
#include <tvm/runtime/container/string.h>
31+
#include <tvm/runtime/object.h>
32+
33+
#include <string>
34+
35+
namespace tvm {
36+
37+
using runtime::Object;
38+
using runtime::ObjectPtr;
39+
using runtime::ObjectRef;
40+
41+
class ObjectPath;
42+
43+
/*!
44+
* \brief Path to an object from some root object.
45+
*
46+
* Motivation:
47+
*
48+
* Same IR node object can be referenced in several different contexts inside a larger IR object.
49+
* For example, a variable could be referenced in several statements within a block.
50+
*
51+
* This makes it impossible to use an object pointer to uniquely identify a "location" within
52+
* the larger IR object for error reporting purposes. The ObjectPath class addresses this problem
53+
* by serving as a unique "locator".
54+
*/
55+
class ObjectPathNode : public Object {
56+
public:
57+
/*! \brief Get the parent path */
58+
Optional<ObjectPath> GetParent() const;
59+
/*!
60+
* \brief Get the length of the path.
61+
*
62+
* For example, the path returned by `ObjectPath::Root()` has length 1.
63+
*/
64+
int32_t Length() const;
65+
66+
/*!
67+
* \brief Get a path prefix of the given length.
68+
*
69+
* Provided `length` must not exceed the `Length()` of this path.
70+
*/
71+
ObjectPath GetPrefix(int32_t length) const;
72+
73+
/*!
74+
* \brief Check if this path is a prefix of another path.
75+
*
76+
* The prefix is not strict, i.e. a path is considered a prefix of itself.
77+
*/
78+
bool IsPrefixOf(const ObjectPath& other) const;
79+
80+
/*! \brief Check if two paths are equal. */
81+
bool PathsEqual(const ObjectPath& other) const;
82+
83+
/*! \brief Extend this path with access to an object attribute. */
84+
ObjectPath Attr(const char* attr_key) const;
85+
86+
/*! \brief Extend this path with access to an object attribute. */
87+
ObjectPath Attr(Optional<String> attr_key) const;
88+
89+
/*! \brief Extend this path with access to an array element. */
90+
ObjectPath ArrayIndex(int32_t index) const;
91+
92+
/*! \brief Extend this path with access to a missing array element. */
93+
ObjectPath MissingArrayElement(int32_t index) const;
94+
95+
/*! \brief Extend this path with access to a map value. */
96+
ObjectPath MapValue(ObjectRef key) const;
97+
98+
/*! \brief Extend this path with access to a missing map entry. */
99+
ObjectPath MissingMapEntry() const;
100+
101+
static constexpr const char* _type_key = "ObjectPath";
102+
TVM_DECLARE_BASE_OBJECT_INFO(ObjectPathNode, Object);
103+
104+
protected:
105+
explicit ObjectPathNode(const ObjectPathNode* parent);
106+
107+
friend class ObjectPath;
108+
friend std::string GetObjectPathRepr(const ObjectPathNode* node);
109+
110+
const ObjectPathNode* ParentNode() const;
111+
112+
/*! Compares just the last node of the path, without comparing the whole path. */
113+
virtual bool LastNodeEqual(const ObjectPathNode* other) const = 0;
114+
115+
virtual std::string LastNodeString() const = 0;
116+
117+
private:
118+
Optional<ObjectRef> parent_;
119+
int32_t length_;
120+
};
121+
122+
class ObjectPath : public ObjectRef {
123+
public:
124+
/*! \brief Create a path that represents the root object itself. */
125+
static ObjectPath Root();
126+
127+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode);
128+
};
129+
130+
//-------------------------------------------------------------------------
131+
//----- Concrete object path nodes ------------------------------------
132+
//-------------------------------------------------------------------------
133+
134+
// ----- Root -----
135+
136+
class RootPathNode final : public ObjectPathNode {
137+
public:
138+
explicit RootPathNode();
139+
140+
static constexpr const char* _type_key = "RootPath";
141+
TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode);
142+
143+
protected:
144+
bool LastNodeEqual(const ObjectPathNode* other) const final;
145+
std::string LastNodeString() const final;
146+
};
147+
148+
class RootPath : public ObjectPath {
149+
public:
150+
TVM_DEFINE_OBJECT_REF_METHODS(RootPath, ObjectPath, RootPathNode);
151+
};
152+
153+
// ----- Attribute access -----
154+
155+
class AttributeAccessPathNode final : public ObjectPathNode {
156+
public:
157+
/*! \brief Name of the attribute being accessed. Must be a static string. */
158+
String attr_key;
159+
160+
explicit AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key);
161+
162+
static constexpr const char* _type_key = "AttributeAccessPath";
163+
TVM_DECLARE_FINAL_OBJECT_INFO(AttributeAccessPathNode, ObjectPathNode);
164+
165+
protected:
166+
bool LastNodeEqual(const ObjectPathNode* other) const final;
167+
std::string LastNodeString() const final;
168+
};
169+
170+
class AttributeAccessPath : public ObjectPath {
171+
public:
172+
TVM_DEFINE_OBJECT_REF_METHODS(AttributeAccessPath, ObjectPath, AttributeAccessPathNode);
173+
};
174+
175+
// ----- Unknown attribute access -----
176+
177+
class UnknownAttributeAccessPathNode final : public ObjectPathNode {
178+
public:
179+
explicit UnknownAttributeAccessPathNode(const ObjectPathNode* parent);
180+
181+
static constexpr const char* _type_key = "UnknownAttributeAccessPath";
182+
TVM_DECLARE_FINAL_OBJECT_INFO(UnknownAttributeAccessPathNode, ObjectPathNode);
183+
184+
protected:
185+
bool LastNodeEqual(const ObjectPathNode* other) const final;
186+
std::string LastNodeString() const final;
187+
};
188+
189+
class UnknownAttributeAccessPath : public ObjectPath {
190+
public:
191+
TVM_DEFINE_OBJECT_REF_METHODS(UnknownAttributeAccessPath, ObjectPath,
192+
UnknownAttributeAccessPathNode);
193+
};
194+
195+
// ----- Array element access by index -----
196+
197+
class ArrayIndexPathNode : public ObjectPathNode {
198+
public:
199+
/*! \brief Index of the array element that is being accessed. */
200+
int32_t index;
201+
202+
explicit ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index);
203+
204+
static constexpr const char* _type_key = "ArrayIndexPath";
205+
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIndexPathNode, ObjectPathNode);
206+
207+
protected:
208+
bool LastNodeEqual(const ObjectPathNode* other) const final;
209+
std::string LastNodeString() const final;
210+
};
211+
212+
class ArrayIndexPath : public ObjectPath {
213+
public:
214+
TVM_DEFINE_OBJECT_REF_METHODS(ArrayIndexPath, ObjectPath, ArrayIndexPathNode);
215+
};
216+
217+
// ----- Missing array element -----
218+
219+
class MissingArrayElementPathNode : public ObjectPathNode {
220+
public:
221+
/*! \brief Index of the array element that is missing. */
222+
int32_t index;
223+
224+
explicit MissingArrayElementPathNode(const ObjectPathNode* parent, int32_t index);
225+
226+
static constexpr const char* _type_key = "MissingArrayElementPath";
227+
TVM_DECLARE_FINAL_OBJECT_INFO(MissingArrayElementPathNode, ObjectPathNode);
228+
229+
protected:
230+
bool LastNodeEqual(const ObjectPathNode* other) const final;
231+
std::string LastNodeString() const final;
232+
};
233+
234+
class MissingArrayElementPath : public ObjectPath {
235+
public:
236+
TVM_DEFINE_OBJECT_REF_METHODS(MissingArrayElementPath, ObjectPath, MissingArrayElementPathNode);
237+
};
238+
239+
// ----- Map value -----
240+
241+
class MapValuePathNode : public ObjectPathNode {
242+
public:
243+
/*! \brief Key of the map entry that is being accessed */
244+
ObjectRef key;
245+
246+
explicit MapValuePathNode(const ObjectPathNode* parent, ObjectRef key);
247+
248+
static constexpr const char* _type_key = "MapValuePath";
249+
TVM_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode);
250+
251+
protected:
252+
bool LastNodeEqual(const ObjectPathNode* other) const final;
253+
std::string LastNodeString() const final;
254+
};
255+
256+
class MapValuePath : public ObjectPath {
257+
public:
258+
TVM_DEFINE_OBJECT_REF_METHODS(MapValuePath, ObjectPath, MapValuePathNode);
259+
};
260+
261+
// ----- Missing map entry -----
262+
263+
class MissingMapEntryPathNode : public ObjectPathNode {
264+
public:
265+
explicit MissingMapEntryPathNode(const ObjectPathNode* parent);
266+
267+
static constexpr const char* _type_key = "MissingMapEntryPath";
268+
TVM_DECLARE_FINAL_OBJECT_INFO(MissingMapEntryPathNode, ObjectPathNode);
269+
270+
protected:
271+
bool LastNodeEqual(const ObjectPathNode* other) const final;
272+
std::string LastNodeString() const final;
273+
};
274+
275+
class MissingMapEntryPath : public ObjectPath {
276+
public:
277+
TVM_DEFINE_OBJECT_REF_METHODS(MissingMapEntryPath, ObjectPath, MissingMapEntryPathNode);
278+
};
279+
280+
} // namespace tvm
281+
282+
#endif // TVM_NODE_OBJECT_PATH_H_

0 commit comments

Comments
 (0)