Skip to content

Commit 02631f6

Browse files
jroeschtqchen
authored andcommitted
[Relay] Add generic & informative Relay error reporting (#2408)
1 parent 4e57323 commit 02631f6

File tree

14 files changed

+537
-86
lines changed

14 files changed

+537
-86
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
[submodule "dlpack"]
88
path = 3rdparty/dlpack
99
url = https://github.com/dmlc/dlpack
10+
[submodule "3rdparty/rang"]
11+
path = 3rdparty/rang
12+
url = https://github.com/agauniyal/rang

3rdparty/rang

Submodule rang added at cabe04d

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
5353
include_directories("include")
5454
include_directories("3rdparty/dlpack/include")
5555
include_directories("3rdparty/dmlc-core/include")
56+
include_directories("3rdparty/rang/include")
5657
include_directories("3rdparty/compiler-rt")
5758

5859
# initial variables

include/tvm/relay/error.h

Lines changed: 118 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,134 @@
77
#define TVM_RELAY_ERROR_H_
88

99
#include <string>
10+
#include <vector>
11+
#include <sstream>
1012
#include "./base.h"
13+
#include "./expr.h"
14+
#include "./module.h"
1115

1216
namespace tvm {
1317
namespace relay {
1418

15-
struct Error : public dmlc::Error {
16-
explicit Error(const std::string &msg) : dmlc::Error(msg) {}
17-
};
19+
#define RELAY_ERROR(msg) (RelayErrorStream() << msg)
20+
21+
// Forward declaratio for RelayErrorStream.
22+
struct Error;
23+
24+
/*! \brief A wrapper around std::stringstream.
25+
*
26+
* This is designed to avoid platform specific
27+
* issues compiling and using std::stringstream
28+
* for error reporting.
29+
*/
30+
struct RelayErrorStream {
31+
std::stringstream ss;
32+
33+
template<typename T>
34+
RelayErrorStream& operator<<(const T& t) {
35+
ss << t;
36+
return *this;
37+
}
1838

19-
struct InternalError : public Error {
20-
explicit InternalError(const std::string &msg) : Error(msg) {}
39+
std::string str() const {
40+
return ss.str();
41+
}
42+
43+
void Raise() const;
2144
};
2245

23-
struct FatalTypeError : public Error {
24-
explicit FatalTypeError(const std::string &s) : Error(s) {}
46+
struct Error : public dmlc::Error {
47+
Span sp;
48+
explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
49+
Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
50+
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
2551
};
2652

27-
struct TypecheckerError : public Error {
28-
explicit TypecheckerError(const std::string &msg) : Error(msg) {}
53+
/*! \brief An abstraction around how errors are stored and reported.
54+
* Designed to be opaque to users, so we can support a robust and simpler
55+
* error reporting mode, as well as a more complex mode.
56+
*
57+
* The first mode is the most accurate: we report a Relay error at a specific
58+
* Span, and then render the error message directly against a textual representation
59+
* of the program, highlighting the exact lines in which it occurs. This mode is not
60+
* implemented in this PR and will not work.
61+
*
62+
* The second mode is a general-purpose mode, which attempts to annotate the program's
63+
* textual format with errors.
64+
*
65+
* The final mode represents the old mode, if we report an error that has no span or
66+
* expression, we will default to throwing an exception with a textual representation
67+
* of the error and no indication of where it occured in the original program.
68+
*
69+
* The latter mode is not ideal, and the goal of the new error reporting machinery is
70+
* to avoid ever reporting errors in this style.
71+
*/
72+
class ErrorReporter {
73+
public:
74+
ErrorReporter() : errors_(), node_to_error_() {}
75+
76+
/*! \brief Report a tvm::relay::Error.
77+
*
78+
* This API is useful for reporting spanned errors.
79+
*
80+
* \param err The error to report.
81+
*/
82+
void Report(const Error& err) {
83+
if (!err.sp.defined()) {
84+
throw err;
85+
}
86+
87+
this->errors_.push_back(err);
88+
}
89+
90+
/*! \brief Report an error against a program, using the full program
91+
* error reporting strategy.
92+
*
93+
* This error reporting method requires the global function in which
94+
* to report an error, the expression to report the error on,
95+
* and the error object.
96+
*
97+
* \param global The global function in which the expression is contained.
98+
* \param node The expression or type to report the error at.
99+
* \param err The error message to report.
100+
*/
101+
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
102+
this->ReportAt(global, node, Error(err));
103+
}
104+
105+
/*! \brief Report an error against a program, using the full program
106+
* error reporting strategy.
107+
*
108+
* This error reporting method requires the global function in which
109+
* to report an error, the expression to report the error on,
110+
* and the error object.
111+
*
112+
* \param global The global function in which the expression is contained.
113+
* \param node The expression or type to report the error at.
114+
* \param err The error to report.
115+
*/
116+
void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err);
117+
118+
/*! \brief Render all reported errors and exit the program.
119+
*
120+
* This function should be used after executing a pass to render reported errors.
121+
*
122+
* It will build an error message from the set of errors, depending on the error
123+
* reporting strategy.
124+
*
125+
* \param module The module to report errors on.
126+
* \param use_color Controls whether to colorize the output.
127+
*/
128+
void RenderErrors(const Module& module, bool use_color = true);
129+
130+
inline bool AnyErrors() {
131+
return errors_.size() != 0;
132+
}
133+
134+
private:
135+
std::vector<Error> errors_;
136+
std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_;
137+
std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_;
29138
};
30139

31140
} // namespace relay

include/tvm/relay/module.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,15 @@ class ModuleNode : public RelayNode {
4343
/*! \brief A map from ids to all global functions. */
4444
tvm::Map<GlobalVar, Function> functions;
4545

46+
/*! \brief The entry function (i.e. "main"). */
47+
GlobalVar entry_func;
48+
4649
ModuleNode() {}
4750

4851
void VisitAttrs(tvm::AttrVisitor* v) final {
4952
v->Visit("functions", &functions);
5053
v->Visit("global_var_map_", &global_var_map_);
54+
v->Visit("entry_func", &entry_func);
5155
}
5256

5357
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
@@ -111,6 +115,20 @@ class ModuleNode : public RelayNode {
111115
*/
112116
void Update(const Module& other);
113117

118+
/*! \brief Construct a module from a standalone expression.
119+
*
120+
* Allows one to optionally pass a global function map as
121+
* well.
122+
*
123+
* \param expr The expression to set as the entry point to the module.
124+
* \param global_funcs The global function map.
125+
*
126+
* \returns A module with expr set as the entry point.
127+
*/
128+
static Module FromExpr(
129+
const Expr& expr,
130+
const tvm::Map<GlobalVar, Function>& global_funcs = {});
131+
114132
static constexpr const char* _type_key = "relay.Module";
115133
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
116134

@@ -132,6 +150,7 @@ struct Module : public NodeRef {
132150
using ContainerType = ModuleNode;
133151
};
134152

153+
135154
} // namespace relay
136155
} // namespace tvm
137156

include/tvm/relay/pass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
#ifndef TVM_RELAY_PASS_H_
77
#define TVM_RELAY_PASS_H_
88

9-
#include <tvm/relay/module.h>
109
#include <tvm/relay/expr.h>
10+
#include <tvm/relay/module.h>
1111
#include <tvm/relay/op_attr_types.h>
1212
#include <string>
1313

include/tvm/relay/type.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,12 @@ class TypeReporterNode : public Node {
295295
*/
296296
TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;
297297

298+
/*!
299+
* \brief Set the location at which to report unification errors.
300+
* \param ref The program node to report the error.
301+
*/
302+
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;
303+
298304
// solver is not serializable.
299305
void VisitAttrs(tvm::AttrVisitor* v) final {}
300306

src/relay/ir/error.cc

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*!
2+
* Copyright (c) 2018 by Contributors
3+
* \file error_reporter.h
4+
* \brief The set of errors raised by Relay.
5+
*/
6+
7+
#include <tvm/relay/expr.h>
8+
#include <tvm/relay/module.h>
9+
#include <tvm/relay/error.h>
10+
#include <string>
11+
#include <vector>
12+
#include <rang.hpp>
13+
14+
namespace tvm {
15+
namespace relay {
16+
17+
void RelayErrorStream::Raise() const {
18+
throw Error(*this);
19+
}
20+
21+
template<typename T, typename U>
22+
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
23+
24+
void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
25+
// First we pick an error reporting strategy for each error.
26+
// TODO(@jroesch): Spanned errors are currently not supported.
27+
for (auto err : this->errors_) {
28+
CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported";
29+
}
30+
31+
NodeMap<GlobalVar, NodeMap<NodeRef, std::string>> error_maps;
32+
33+
// Set control mode in order to produce colors;
34+
if (use_color) {
35+
rang::setControlMode(rang::control::Force);
36+
}
37+
38+
for (auto pair : this->node_to_gv_) {
39+
auto node = pair.first;
40+
auto global = Downcast<GlobalVar>(pair.second);
41+
42+
auto has_errs = this->node_to_error_.find(node);
43+
44+
CHECK(has_errs != this->node_to_error_.end());
45+
46+
const auto& error_indicies = has_errs->second;
47+
48+
std::stringstream err_msg;
49+
50+
err_msg << rang::fg::red;
51+
for (auto index : error_indicies) {
52+
err_msg << this->errors_[index].what() << "; ";
53+
}
54+
err_msg << rang::fg::reset;
55+
56+
// Setup error map.
57+
auto it = error_maps.find(global);
58+
if (it != error_maps.end()) {
59+
it->second.insert({ node, err_msg.str() });
60+
} else {
61+
error_maps.insert({ global, { { node, err_msg.str() }}});
62+
}
63+
}
64+
65+
// Now we will construct the fully-annotated program to display to
66+
// the user.
67+
std::stringstream annotated_prog;
68+
69+
// First we output a header for the errors.
70+
annotated_prog <<
71+
rang::style::bold << std::endl <<
72+
"Error(s) have occurred. We have annotated the program with them:"
73+
<< std::endl << std::endl << rang::style::reset;
74+
75+
// For each global function which contains errors, we will
76+
// construct an annotated function.
77+
for (auto pair : error_maps) {
78+
auto global = pair.first;
79+
auto err_map = pair.second;
80+
auto func = module->Lookup(global);
81+
82+
// We output the name of the function before displaying
83+
// the annotated program.
84+
annotated_prog <<
85+
rang::style::bold <<
86+
"In `" << global->name_hint << "`: " <<
87+
std::endl <<
88+
rang::style::reset;
89+
90+
// We then call into the Relay printer to generate the program.
91+
//
92+
// The annotation callback will annotate the error messages
93+
// contained in the map.
94+
annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) {
95+
auto it = err_map.find(expr);
96+
if (it != err_map.end()) {
97+
return it->second;
98+
} else {
99+
return std::string("");
100+
}
101+
});
102+
}
103+
104+
auto msg = annotated_prog.str();
105+
106+
if (use_color) {
107+
rang::setControlMode(rang::control::Auto);
108+
}
109+
110+
// Finally we report the error, currently we do so to LOG(FATAL),
111+
// it may be good to instead report it to std::cout.
112+
LOG(FATAL) << annotated_prog.str() << std::endl;
113+
}
114+
115+
void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) {
116+
size_t index_to_insert = this->errors_.size();
117+
this->errors_.push_back(err);
118+
auto it = this->node_to_error_.find(node);
119+
if (it != this->node_to_error_.end()) {
120+
it->second.push_back(index_to_insert);
121+
} else {
122+
this->node_to_error_.insert({ node, { index_to_insert }});
123+
}
124+
this->node_to_gv_.insert({ node, global });
125+
}
126+
127+
} // namespace relay
128+
} // namespace tvm

src/relay/ir/module.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
2323
<< "Duplicate global function name " << kv.first->name_hint;
2424
n->global_var_map_.Set(kv.first->name_hint, kv.first);
2525
}
26+
27+
n->entry_func = GlobalVarNode::make("main");
2628
return Module(n);
2729
}
2830

@@ -96,6 +98,21 @@ void ModuleNode::Update(const Module& mod) {
9698
}
9799
}
98100

101+
Module ModuleNode::FromExpr(
102+
const Expr& expr,
103+
const tvm::Map<GlobalVar, Function>& global_funcs) {
104+
auto mod = ModuleNode::make(global_funcs);
105+
auto func_node = expr.as<FunctionNode>();
106+
Function func;
107+
if (func_node) {
108+
func = GetRef<Function>(func_node);
109+
} else {
110+
func = FunctionNode::make({}, expr, Type(), {}, {});
111+
}
112+
mod->Add(mod->entry_func, func);
113+
return mod;
114+
}
115+
99116
TVM_REGISTER_NODE_TYPE(ModuleNode);
100117

101118
TVM_REGISTER_API("relay._make.Module")

0 commit comments

Comments
 (0)