Skip to content

Commit 07ff4b4

Browse files
authored
Merge branch 'master' into master
2 parents f88e2a7 + 02631f6 commit 07ff4b4

File tree

14 files changed

+894
-86
lines changed

14 files changed

+894
-86
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
[submodule "dlpack"]
88
path = 3rdparty/dlpack
99
url = https://github.com/dmlc/dlpack
10+
[submodule "3rdparty/rang"]
11+
path = 3rdparty/rang
12+
url = https://github.com/agauniyal/rang

3rdparty/rang

Submodule rang added at cabe04d

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
5353
include_directories("include")
5454
include_directories("3rdparty/dlpack/include")
5555
include_directories("3rdparty/dmlc-core/include")
56+
include_directories("3rdparty/rang/include")
5657
include_directories("3rdparty/compiler-rt")
5758

5859
# initial variables

include/tvm/relay/error.h

Lines changed: 118 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,134 @@
77
#define TVM_RELAY_ERROR_H_
88

99
#include <string>
10+
#include <vector>
11+
#include <sstream>
1012
#include "./base.h"
13+
#include "./expr.h"
14+
#include "./module.h"
1115

1216
namespace tvm {
1317
namespace relay {
1418

15-
struct Error : public dmlc::Error {
16-
explicit Error(const std::string &msg) : dmlc::Error(msg) {}
17-
};
19+
#define RELAY_ERROR(msg) (RelayErrorStream() << msg)
20+
21+
// Forward declaratio for RelayErrorStream.
22+
struct Error;
23+
24+
/*! \brief A wrapper around std::stringstream.
25+
*
26+
* This is designed to avoid platform specific
27+
* issues compiling and using std::stringstream
28+
* for error reporting.
29+
*/
30+
struct RelayErrorStream {
31+
std::stringstream ss;
32+
33+
template<typename T>
34+
RelayErrorStream& operator<<(const T& t) {
35+
ss << t;
36+
return *this;
37+
}
1838

19-
struct InternalError : public Error {
20-
explicit InternalError(const std::string &msg) : Error(msg) {}
39+
std::string str() const {
40+
return ss.str();
41+
}
42+
43+
void Raise() const;
2144
};
2245

23-
struct FatalTypeError : public Error {
24-
explicit FatalTypeError(const std::string &s) : Error(s) {}
46+
struct Error : public dmlc::Error {
47+
Span sp;
48+
explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
49+
Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
50+
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
2551
};
2652

27-
struct TypecheckerError : public Error {
28-
explicit TypecheckerError(const std::string &msg) : Error(msg) {}
53+
/*! \brief An abstraction around how errors are stored and reported.
54+
* Designed to be opaque to users, so we can support a robust and simpler
55+
* error reporting mode, as well as a more complex mode.
56+
*
57+
* The first mode is the most accurate: we report a Relay error at a specific
58+
* Span, and then render the error message directly against a textual representation
59+
* of the program, highlighting the exact lines in which it occurs. This mode is not
60+
* implemented in this PR and will not work.
61+
*
62+
* The second mode is a general-purpose mode, which attempts to annotate the program's
63+
* textual format with errors.
64+
*
65+
* The final mode represents the old mode, if we report an error that has no span or
66+
* expression, we will default to throwing an exception with a textual representation
67+
* of the error and no indication of where it occured in the original program.
68+
*
69+
* The latter mode is not ideal, and the goal of the new error reporting machinery is
70+
* to avoid ever reporting errors in this style.
71+
*/
72+
class ErrorReporter {
73+
public:
74+
ErrorReporter() : errors_(), node_to_error_() {}
75+
76+
/*! \brief Report a tvm::relay::Error.
77+
*
78+
* This API is useful for reporting spanned errors.
79+
*
80+
* \param err The error to report.
81+
*/
82+
void Report(const Error& err) {
83+
if (!err.sp.defined()) {
84+
throw err;
85+
}
86+
87+
this->errors_.push_back(err);
88+
}
89+
90+
/*! \brief Report an error against a program, using the full program
91+
* error reporting strategy.
92+
*
93+
* This error reporting method requires the global function in which
94+
* to report an error, the expression to report the error on,
95+
* and the error object.
96+
*
97+
* \param global The global function in which the expression is contained.
98+
* \param node The expression or type to report the error at.
99+
* \param err The error message to report.
100+
*/
101+
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
102+
this->ReportAt(global, node, Error(err));
103+
}
104+
105+
/*! \brief Report an error against a program, using the full program
106+
* error reporting strategy.
107+
*
108+
* This error reporting method requires the global function in which
109+
* to report an error, the expression to report the error on,
110+
* and the error object.
111+
*
112+
* \param global The global function in which the expression is contained.
113+
* \param node The expression or type to report the error at.
114+
* \param err The error to report.
115+
*/
116+
void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err);
117+
118+
/*! \brief Render all reported errors and exit the program.
119+
*
120+
* This function should be used after executing a pass to render reported errors.
121+
*
122+
* It will build an error message from the set of errors, depending on the error
123+
* reporting strategy.
124+
*
125+
* \param module The module to report errors on.
126+
* \param use_color Controls whether to colorize the output.
127+
*/
128+
void RenderErrors(const Module& module, bool use_color = true);
129+
130+
inline bool AnyErrors() {
131+
return errors_.size() != 0;
132+
}
133+
134+
private:
135+
std::vector<Error> errors_;
136+
std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_;
137+
std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_;
29138
};
30139

31140
} // namespace relay

include/tvm/relay/module.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,15 @@ class ModuleNode : public RelayNode {
4343
/*! \brief A map from ids to all global functions. */
4444
tvm::Map<GlobalVar, Function> functions;
4545

46+
/*! \brief The entry function (i.e. "main"). */
47+
GlobalVar entry_func;
48+
4649
ModuleNode() {}
4750

4851
void VisitAttrs(tvm::AttrVisitor* v) final {
4952
v->Visit("functions", &functions);
5053
v->Visit("global_var_map_", &global_var_map_);
54+
v->Visit("entry_func", &entry_func);
5155
}
5256

5357
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
@@ -111,6 +115,20 @@ class ModuleNode : public RelayNode {
111115
*/
112116
void Update(const Module& other);
113117

118+
/*! \brief Construct a module from a standalone expression.
119+
*
120+
* Allows one to optionally pass a global function map as
121+
* well.
122+
*
123+
* \param expr The expression to set as the entry point to the module.
124+
* \param global_funcs The global function map.
125+
*
126+
* \returns A module with expr set as the entry point.
127+
*/
128+
static Module FromExpr(
129+
const Expr& expr,
130+
const tvm::Map<GlobalVar, Function>& global_funcs = {});
131+
114132
static constexpr const char* _type_key = "relay.Module";
115133
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
116134

@@ -132,6 +150,7 @@ struct Module : public NodeRef {
132150
using ContainerType = ModuleNode;
133151
};
134152

153+
135154
} // namespace relay
136155
} // namespace tvm
137156

include/tvm/relay/pass.h

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
#ifndef TVM_RELAY_PASS_H_
77
#define TVM_RELAY_PASS_H_
88

9-
#include <tvm/relay/module.h>
109
#include <tvm/relay/expr.h>
10+
#include <tvm/relay/module.h>
1111
#include <tvm/relay/op_attr_types.h>
1212
#include <string>
1313

@@ -119,6 +119,17 @@ TVM_DLL bool WellFormed(const Expr& expr);
119119
*/
120120
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
121121

122+
/*! \brief Get all bound variables from expression expr.
123+
*
124+
* Bound variables are all variables that are declared in the expr.
125+
* They only have meaning inside that expr, and can only be used in it.
126+
*
127+
* \param expr the expression.
128+
*
129+
* \return List of bound vars, in the PostDFS order in the expression.
130+
*/
131+
tvm::Array<Var> BoundVars(const Expr& expr);
132+
122133
/*! \brief Get free type parameters from expression expr.
123134
*
124135
* Free variables are variables that are not bound by a
@@ -138,6 +149,14 @@ TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
138149
*/
139150
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
140151

152+
/*! \brief Get all variables from expression expr.
153+
*
154+
* \param expr the expression.
155+
*
156+
* \return List of all vars, in the PostDFS order in the expression.
157+
*/
158+
tvm::Array<Var> AllVars(const Expr& expr);
159+
141160
/*! \brief Get free TypeVars from expression expr.
142161
*
143162
* Free type parameters are type parameters that are not bound by a function
@@ -198,6 +217,55 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr);
198217
*/
199218
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t);
200219

220+
/*! \brief Get free TypeVars from type t.
221+
*
222+
* Free type parameters are type parameters that are not bound by a function
223+
* type in the context.
224+
*
225+
* \param t the type.
226+
*
227+
* \return List of free type vars, in the PostDFS order visited by type.
228+
*/
229+
tvm::Array<TypeVar> FreeTypeVars(const Type& t);
230+
231+
/*! \brief Get all bound type variables from expression expr.
232+
*
233+
* Bound variables are all type variables that are declared in the expr.
234+
* They only have meaning inside that expr, and can only be used in it.
235+
*
236+
* \param expr the expression.
237+
*
238+
* \return List of bound type vars, in the PostDFS order in the expression.
239+
*/
240+
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr);
241+
242+
/*! \brief Get all bound type variables from type t.
243+
*
244+
* Bound variables are all type variables that are declared in the type.
245+
* They only have meaning inside that type, and can only be used in it.
246+
*
247+
* \param t the type
248+
*
249+
* \return List of bound type vars, in the PostDFS order visited by type.
250+
*/
251+
tvm::Array<TypeVar> BoundTypeVars(const Type& t);
252+
253+
/*! \brief Get all type variables in expression expr.
254+
*
255+
* \param expr the expression.
256+
*
257+
* \return List of type vars, in the PostDFS order in the expression.
258+
*/
259+
tvm::Array<TypeVar> AllTypeVars(const Expr& expr);
260+
261+
/*! \brief Get all type variables in type t.
262+
*
263+
* \param t the type.
264+
*
265+
* \return List of type vars, in the PostDFS order visited by type.
266+
*/
267+
tvm::Array<TypeVar> AllTypeVars(const Type& t);
268+
201269
/*! \brief Remove expressions which does not effect the program result.
202270
*
203271
* It will remove let bindings which are not referenced, and branches that will

include/tvm/relay/type.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,12 @@ class TypeReporterNode : public Node {
295295
*/
296296
TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;
297297

298+
/*!
299+
* \brief Set the location at which to report unification errors.
300+
* \param ref The program node to report the error.
301+
*/
302+
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;
303+
298304
// solver is not serializable.
299305
void VisitAttrs(tvm::AttrVisitor* v) final {}
300306

0 commit comments

Comments
 (0)