Skip to content

Commit 03c8a6f

Browse files
junrushao, Siyuan Feng, spectrometerHBH, MasterJH5574, jinhongyii
authored
[TensorIR][M2a] Structural Error Reporting (#8121)
This PR is part of the TensorIR upstreaming effort (#7527), stage M2a. In this PR, we implement ScheduleError, an error-reporting mechanism that lets schedule primitives report user-facing error messages and that can render the TIR in TVM script syntax. This set of APIs leaves room for future improvements to error-location rendering, e.g. more colorful rendering mechanisms like the one synr provides. Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Tristan Konolige <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Tristan Konolige <[email protected]>
1 parent dc5fc68 commit 03c8a6f

File tree

11 files changed

+296
-9
lines changed

11 files changed

+296
-9
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@
2424
namespace tvm {
2525
namespace tir {
2626

27+
/*! \brief The level of detailed error message rendering */
28+
enum class ScheduleErrorRenderLevel : int32_t {
29+
/*! \brief Render a detailed error message */
30+
kDetail = 0,
31+
/*! \brief Render the error in fast mode */
32+
kFast = 1,
33+
/*! \brief No error message at all */
34+
kNone = 2,
35+
};
36+
2737
/**************** Random variable: BlockRV ****************/
2838

2939
/*! \brief A random variable that evaluates to a TensorIR block */
@@ -209,13 +219,15 @@ class Schedule : public runtime::ObjectRef {
209219
* \param mod The IRModule to be scheduled
210220
* \param debug_mode Do extra correctness checking after the class creation
211221
* and each time after calling the Replace method.
222+
* \param error_render_level The level of error rendering
212223
* \return The concrete schedule created
213224
* \sa ScheduleDebugMask
214225
* \note The checks performed include:
215226
* 1) VerifySRefTree
216227
* 2) VerifyCachedFlags
217228
*/
218-
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode);
229+
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode,
230+
ScheduleErrorRenderLevel error_render_level);
219231
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
220232
};
221233

python/tvm/tir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from .op import comm_reducer, min, max, sum
4949
from .op import q_multiply_shift
5050

51-
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule
51+
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
5252

5353
from . import schedule
5454
from . import ir_builder

python/tvm/tir/schedule/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919

2020
from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
2121
from .state import ScheduleDebugMask, ScheduleState
22-
from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule
22+
from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule, ScheduleError

python/tvm/tir/schedule/schedule.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import List, Optional, Union
2020

2121
from tvm._ffi import register_object as _register_object
22+
from tvm.error import TVMError, register_error
2223
from tvm.ir import IRModule, PrimExpr
2324
from tvm.runtime import Object
2425
from tvm.tir import Block, For, IntImm, PrimFunc, Var
@@ -27,6 +28,11 @@
2728
from .state import ScheduleState, StmtSRef
2829

2930

31+
@register_error
32+
class ScheduleError(TVMError):
33+
"""Error that happens during TensorIR scheduling."""
34+
35+
3036
@_register_object("tir.LoopRV")
3137
class LoopRV(Object):
3238
"""A random variable that refers to a loop"""
@@ -57,10 +63,14 @@ class Schedule(Object):
5763
Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html
5864
"""
5965

66+
ERROR_RENDER_LEVEL = {"detail": 0, "fast": 1, "none": 2}
67+
6068
def __init__(
6169
self,
6270
func_or_mod: Union[PrimFunc, IRModule],
71+
*,
6372
debug_mode: Union[bool, int] = False,
73+
error_render_level: str = "detail",
6474
):
6575
"""Construct a concrete TensorIR schedule from an IRModule or a PrimFunc
6676
@@ -71,6 +81,11 @@ def __init__(
7181
debug_mode : Union[bool, int]
7282
Do extra correctness checking after the class creation and each time
7383
a scheduling primitive is called
84+
error_render_level : str = "detail"
85+
The level of error rendering. Choices: "detail", "fast", "none".
86+
"detail": Render a detailed error message, with the TIR and error locations printed
87+
"fast": Show a simple error message without rendering or string manipulation
88+
"none": Do not show any error message.
7489
7590
Note
7691
----------
@@ -85,10 +100,17 @@ def __init__(
85100
debug_mode = 0
86101
if not isinstance(debug_mode, int):
87102
raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}")
103+
if error_render_level not in Schedule.ERROR_RENDER_LEVEL:
104+
raise ValueError(
105+
'error_render_level can be "detail", "fast", or "none", but got: '
106+
+ f"{error_render_level}"
107+
)
108+
error_render_level = Schedule.ERROR_RENDER_LEVEL.get(error_render_level)
88109
self.__init_handle_by_constructor__(
89110
_ffi_api_schedule.ConcreteSchedule, # pylint: disable=no-member
90111
func_or_mod,
91112
debug_mode,
113+
error_render_level,
92114
)
93115

94116
########## Utilities ##########

src/tir/schedule/concrete_schedule.cc

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
namespace tvm {
2222
namespace tir {
2323

24-
Schedule Schedule::Concrete(IRModule mod, int debug_mode) {
24+
Schedule Schedule::Concrete(IRModule mod, int debug_mode,
25+
ScheduleErrorRenderLevel error_render_level) {
2526
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
2627
n->state_ = ScheduleState(mod, debug_mode);
28+
n->error_render_level_ = error_render_level;
2729
n->symbol_table_ = {};
2830
n->analyzer_ = std::make_unique<arith::Analyzer>();
2931
return Schedule(std::move(n));
@@ -136,6 +138,7 @@ class ScheduleCopier {
136138
scope->src2deps = Copy(old_info.scope->src2deps);
137139
scope->dst2deps = Copy(old_info.scope->dst2deps);
138140
scope->buffer_writers = Copy(old_info.scope->buffer_writers);
141+
scope->stage_pipeline = old_info.scope->stage_pipeline;
139142
new_info.scope = BlockScope(std::move(scope));
140143
result[Copy(old_sref)] = std::move(new_info);
141144
}
@@ -173,21 +176,81 @@ class ScheduleCopier {
173176

174177
void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const {
175178
ScheduleCopier::Copy(this, new_state, new_symbol_table);
179+
new_state->get()->DebugVerify();
176180
}
177181

178182
Schedule ConcreteScheduleNode::Copy() const {
179183
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
180-
Copy(&n->state_, &n->symbol_table_);
184+
n->error_render_level_ = this->error_render_level_;
185+
this->Copy(&n->state_, &n->symbol_table_);
181186
n->analyzer_ = std::make_unique<arith::Analyzer>();
182187
return Schedule(std::move(n));
183188
}
184189

190+
/*! \brief Macro that guards the beginning of each invocation of a TensorIR schedule primitive */
191+
#define TVM_TIR_SCHEDULE_BEGIN() try {
192+
/*!
193+
* \brief Macro that pairs with `TVM_TIR_SCHEDULE_BEGIN`, handling potential errors and error
194+
* message rendering
195+
* \param level A ScheduleErrorRenderLevel enum value, the level of error rendering
196+
* \sa ScheduleErrorRenderLevel
197+
*/
198+
#define TVM_TIR_SCHEDULE_END(level) \
199+
} \
200+
catch (const ScheduleError& error) { \
201+
if ((level) == ScheduleErrorRenderLevel::kDetail) { \
202+
throw tvm::runtime::Error(error.RenderReport()); \
203+
} else if ((level) == ScheduleErrorRenderLevel::kFast) { \
204+
throw tvm::runtime::Error(error.FastErrorString()); \
205+
} else if ((level) == ScheduleErrorRenderLevel::kNone) { \
206+
throw tvm::runtime::Error("ScheduleError: (not rendered)"); \
207+
} \
208+
}
209+
185210
/******** Block/Loop relation ********/
186211

187212
BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) {
213+
class NotSingleResult : public ScheduleError {
214+
public:
215+
explicit NotSingleResult(String name, IRModule mod, const Array<StmtSRef>& blocks)
216+
: name_(name), mod_(mod), blocks_{} {
217+
blocks_.reserve(blocks.size());
218+
for (const StmtSRef& block_sref : blocks) {
219+
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
220+
blocks_.push_back(GetRef<Block>(block));
221+
}
222+
}
223+
224+
String primitive() const final { return "get-block"; }
225+
IRModule mod() const final { return mod_; }
226+
Array<ObjectRef> LocationsOfInterest() const final { return {blocks_.begin(), blocks_.end()}; }
227+
228+
String DetailRenderTemplate() const final {
229+
if (blocks_.empty()) {
230+
return "Cannot find a block with the name: " + name_;
231+
} else {
232+
return "Found " + std::to_string(blocks_.size()) + " blocks with the name: " + name_;
233+
}
234+
}
235+
236+
String FastErrorString() const final {
237+
if (blocks_.empty()) {
238+
return "ScheduleError: Cannot find a block with the specified name";
239+
} else {
240+
return "ScheduleError: Found multiple blocks with the specified name";
241+
}
242+
}
243+
244+
String name_;
245+
IRModule mod_;
246+
Array<Block> blocks_;
247+
};
188248
Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, func_name);
189-
CHECK_EQ(blocks.size(), 1) << "ValueError: There are " << blocks.size()
190-
<< " blocks with the name: " << name;
249+
if (blocks.size() != 1) {
250+
TVM_TIR_SCHEDULE_BEGIN();
251+
throw NotSingleResult(name, this->state_->mod, blocks);
252+
TVM_TIR_SCHEDULE_END(this->error_render_level_);
253+
}
191254
return CreateRV<BlockRV>(blocks[0]);
192255
}
193256

src/tir/schedule/concrete_schedule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,16 @@ class ConcreteScheduleNode : public ScheduleNode {
3737
protected:
3838
/*! \brief The internal state of scheduling */
3939
ScheduleState state_;
40+
/*! \brief The level of error rendering */
41+
ScheduleErrorRenderLevel error_render_level_;
4042
/*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */
4143
TSymbolTable symbol_table_;
4244
/*! \brief A persistent stateless arithmetic analyzer. */
4345
std::unique_ptr<arith::Analyzer> analyzer_;
4446

4547
public:
4648
void VisitAttrs(tvm::AttrVisitor* v) {
49+
// `error_render_level_` is not visited
4750
// `state_` is not visited
4851
// `symbol_table_` is not visited
4952
// `analyzer_` is not visited

src/tir/schedule/error.cc

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "./utils.h"
20+
21+
namespace tvm {
22+
namespace tir {
23+
24+
String ScheduleError::RenderReport() const {
25+
IRModule mod = this->mod();
26+
std::ostringstream os;
27+
os << "ScheduleError: An error occurred in the schedule primitive '" << this->primitive()
28+
<< "'.\n\nThe IR is:\n"
29+
<< AsTVMScript(mod);
30+
Array<ObjectRef> locs = LocationsOfInterest();
31+
int n_locs = locs.size();
32+
std::vector<String> roi_names;
33+
roi_names.reserve(n_locs);
34+
if (n_locs > 0) {
35+
os << "Regions of interest:\n";
36+
for (const ObjectRef& obj : locs) {
37+
String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size());
38+
os << name << "\n" << obj;
39+
roi_names.emplace_back(std::move(name));
40+
}
41+
os << "\n";
42+
}
43+
std::string msg = DetailRenderTemplate();
44+
for (int i = 0; i < n_locs; ++i) {
45+
std::string src = "{" + std::to_string(i) + "}";
46+
for (size_t pos; (pos = msg.find(src)) != std::string::npos;) {
47+
msg.replace(pos, src.length(), roi_names[i]);
48+
}
49+
}
50+
os << "Error message: " << msg;
51+
return os.str();
52+
}
53+
54+
} // namespace tir
55+
} // namespace tvm

src/tir/schedule/error.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#ifndef TVM_TIR_SCHEDULE_ERROR_H_
20+
#define TVM_TIR_SCHEDULE_ERROR_H_
21+
22+
#include <tvm/tir/schedule/state.h>
23+
24+
namespace tvm {
25+
namespace tir {
26+
27+
/*! \brief Error that happens during TensorIR scheduling */
28+
class ScheduleError : public tvm::runtime::Error {
29+
public:
30+
/*! \brief Base constructor */
31+
ScheduleError() : tvm::runtime::Error("") {}
32+
/*! \brief The error occurred in this scheduling primitive */
33+
virtual String primitive() const = 0;
34+
/*! \brief The error occurred in this IRModule */
35+
virtual IRModule mod() const = 0;
36+
/*! \brief The locations of interest that we want to point out */
37+
virtual Array<ObjectRef> LocationsOfInterest() const = 0;
38+
/*!
39+
* \brief Returns an error string template for rendering, corresponds to the "detail" mode.
40+
* \sa ScheduleErrorRenderLevel
41+
* \note The template is a string, e.g.
42+
* "Some error occurred on block {0} and loop {1} blah blah"
43+
* And the renderer will replace {0} and {1} according to the list provided by LocationsOfInterest.
44+
* Right now it only prints all the locations in plain text, but in the future we may want to mark
45+
* the IR with underscores and attach names to each location of interest, like what synr does.
46+
*/
47+
virtual String DetailRenderTemplate() const = 0;
48+
/*!
49+
* \brief Returns an error string without needing to render, corresponds to the "fast" mode
50+
* \sa ScheduleErrorRenderLevel
51+
*/
52+
virtual String FastErrorString() const = 0;
53+
/*! \brief Render the ScheduleError with the template provided by `DetailRenderTemplate` */
54+
String RenderReport() const;
55+
};
56+
57+
} // namespace tir
58+
} // namespace tvm
59+
60+
#endif // TVM_TIR_SCHEDULE_ERROR_H_

src/tir/schedule/schedule.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") //
5656
/**************** (FFI) Constructor ****************/
5757

5858
TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
59-
.set_body_typed([](ObjectRef obj, int debug_mode) -> Schedule {
59+
.set_body_typed([](ObjectRef obj, int debug_mode, int error_render_level) -> Schedule {
6060
IRModule mod{nullptr};
6161
if (const auto* func = obj.as<PrimFuncNode>()) {
6262
mod = IRModule({{GlobalVar("main"), GetRef<BaseFunc>(func)}});
@@ -66,7 +66,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
6666
LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: "
6767
<< obj->GetTypeKey();
6868
}
69-
return Schedule::Concrete(mod, debug_mode);
69+
return Schedule::Concrete(mod, debug_mode,
70+
static_cast<ScheduleErrorRenderLevel>(error_render_level));
7071
});
7172

7273
/******** (FFI) Lookup random variables ********/

src/tir/schedule/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "../../printer/text_printer.h"
3636
#include "../../runtime/thread_storage_scope.h"
3737
#include "./analysis.h"
38+
#include "./error.h"
3839

3940
namespace tvm {
4041
namespace tir {

0 commit comments

Comments
 (0)