|
21 | 21 | namespace tvm { |
22 | 22 | namespace tir { |
23 | 23 |
|
24 | | -Schedule Schedule::Concrete(IRModule mod, int debug_mode) { |
| 24 | +Schedule Schedule::Concrete(IRModule mod, int debug_mode, |
| 25 | + ScheduleErrorRenderLevel error_render_level) { |
25 | 26 | ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>(); |
26 | 27 | n->state_ = ScheduleState(mod, debug_mode); |
| 28 | + n->error_render_level_ = error_render_level; |
27 | 29 | n->symbol_table_ = {}; |
28 | 30 | n->analyzer_ = std::make_unique<arith::Analyzer>(); |
29 | 31 | return Schedule(std::move(n)); |
@@ -136,6 +138,7 @@ class ScheduleCopier { |
136 | 138 | scope->src2deps = Copy(old_info.scope->src2deps); |
137 | 139 | scope->dst2deps = Copy(old_info.scope->dst2deps); |
138 | 140 | scope->buffer_writers = Copy(old_info.scope->buffer_writers); |
| 141 | + scope->stage_pipeline = old_info.scope->stage_pipeline; |
139 | 142 | new_info.scope = BlockScope(std::move(scope)); |
140 | 143 | result[Copy(old_sref)] = std::move(new_info); |
141 | 144 | } |
@@ -173,21 +176,81 @@ class ScheduleCopier { |
173 | 176 |
|
174 | 177 | void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const { |
175 | 178 | ScheduleCopier::Copy(this, new_state, new_symbol_table); |
| 179 | + new_state->get()->DebugVerify(); |
176 | 180 | } |
177 | 181 |
|
178 | 182 | Schedule ConcreteScheduleNode::Copy() const { |
179 | 183 | ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>(); |
180 | | - Copy(&n->state_, &n->symbol_table_); |
| 184 | + n->error_render_level_ = this->error_render_level_; |
| 185 | + this->Copy(&n->state_, &n->symbol_table_); |
181 | 186 | n->analyzer_ = std::make_unique<arith::Analyzer>(); |
182 | 187 | return Schedule(std::move(n)); |
183 | 188 | } |
184 | 189 |
|
| 190 | +/*! \brief Macro that guards the beginning of each invocation of TensorIR schedule primitive */ |
| 191 | +#define TVM_TIR_SCHEDULE_BEGIN() try { |
| 192 | +/*! |
| 193 | + * \brief Macro that pairs with `TVM_TIR_SCHEDULE_BEGIN`, handling potential errors and error |
| 194 | + * message rendering |
| 195 | + * \param level An ScheduleErrorRenderLevel enum, level of error rendering |
| 196 | + * \sa ScheduleErrorRenderLevel |
| 197 | + */ |
| 198 | +#define TVM_TIR_SCHEDULE_END(level) \ |
| 199 | + } \ |
| 200 | + catch (const ScheduleError& error) { \ |
| 201 | + if ((level) == ScheduleErrorRenderLevel::kDetail) { \ |
| 202 | + throw tvm::runtime::Error(error.RenderReport()); \ |
| 203 | + } else if ((level) == ScheduleErrorRenderLevel::kFast) { \ |
| 204 | + throw tvm::runtime::Error(error.FastErrorString()); \ |
| 205 | + } else if ((level) == ScheduleErrorRenderLevel::kNone) { \ |
| 206 | + throw tvm::runtime::Error("ScheduleError: (not rendered)"); \ |
| 207 | + } \ |
| 208 | + } |
| 209 | + |
185 | 210 | /******** Block/Loop relation ********/ |
186 | 211 |
|
187 | 212 | BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { |
| 213 | + class NotSingleResult : public ScheduleError { |
| 214 | + public: |
| 215 | + explicit NotSingleResult(String name, IRModule mod, const Array<StmtSRef>& blocks) |
| 216 | + : name_(name), mod_(mod), blocks_{} { |
| 217 | + blocks_.reserve(blocks.size()); |
| 218 | + for (const StmtSRef& block_sref : blocks) { |
| 219 | + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); |
| 220 | + blocks_.push_back(GetRef<Block>(block)); |
| 221 | + } |
| 222 | + } |
| 223 | + |
| 224 | + String primitive() const final { return "get-block"; } |
| 225 | + IRModule mod() const final { return mod_; } |
| 226 | + Array<ObjectRef> LocationsOfInterest() const final { return {blocks_.begin(), blocks_.end()}; } |
| 227 | + |
| 228 | + String DetailRenderTemplate() const final { |
| 229 | + if (blocks_.empty()) { |
| 230 | + return "Cannot find a block with the name: " + name_; |
| 231 | + } else { |
| 232 | + return "Found " + std::to_string(blocks_.size()) + " blocks with the name: " + name_; |
| 233 | + } |
| 234 | + } |
| 235 | + |
| 236 | + String FastErrorString() const final { |
| 237 | + if (blocks_.empty()) { |
| 238 | + return "ScheduleError: Cannot find a block with the specified name"; |
| 239 | + } else { |
| 240 | + return "ScheduleError: Found multiple blocks with the specified name"; |
| 241 | + } |
| 242 | + } |
| 243 | + |
| 244 | + String name_; |
| 245 | + IRModule mod_; |
| 246 | + Array<Block> blocks_; |
| 247 | + }; |
188 | 248 | Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, func_name); |
189 | | - CHECK_EQ(blocks.size(), 1) << "ValueError: There are " << blocks.size() |
190 | | - << " blocks with the name: " << name; |
| 249 | + if (blocks.size() != 1) { |
| 250 | + TVM_TIR_SCHEDULE_BEGIN(); |
| 251 | + throw NotSingleResult(name, this->state_->mod, blocks); |
| 252 | + TVM_TIR_SCHEDULE_END(this->error_render_level_); |
| 253 | + } |
191 | 254 | return CreateRV<BlockRV>(blocks[0]); |
192 | 255 | } |
193 | 256 |
|
|
0 commit comments