Skip to content

Commit ad6fbec

Browse files
authored
[TIR] Improved error message in InjectSoftwarePipeline (#14391)
Updated the error message to state which PrimFunc has a malformed pipeline annotation, the blocks found in that primfunc, and the pipeline annotation found.
1 parent 6759702 commit ad6fbec

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/tir/transforms/inject_software_pipeline.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,8 @@ void BuildDependencyGraph(
989989
class PipelineInjector : private StmtExprMutator {
990990
public:
991991
static Stmt Inject(const PrimFunc& func) {
992-
PipelineInjector injector;
992+
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
993+
PipelineInjector injector(global_symbol);
993994
for (const auto& kv : func->buffer_map) {
994995
const Buffer& buffer = kv.second;
995996
injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
@@ -999,7 +1000,7 @@ class PipelineInjector : private StmtExprMutator {
9991000
}
10001001

10011002
private:
1002-
PipelineInjector() {}
1003+
explicit PipelineInjector(Optional<String> global_symbol) : global_symbol_(global_symbol) {}
10031004

10041005
/*!
10051006
* \brief Check the pipeline satisfies the following conditions:
@@ -1103,8 +1104,14 @@ class PipelineInjector : private StmtExprMutator {
11031104
Downcast<Array<Integer>>(op->annotations.at(attr::software_pipeline_stage));
11041105
auto pipeline_orders =
11051106
Downcast<Array<Integer>>(op->annotations.at(attr::software_pipeline_order));
1106-
CHECK_EQ(pipeline_stages.size(), original_order.size());
1107-
CHECK_EQ(pipeline_orders.size(), original_order.size());
1107+
CHECK_EQ(pipeline_stages.size(), original_order.size())
1108+
<< "PrimFunc " << global_symbol_ << " has original order "
1109+
<< original_order.Map([](const auto& block) { return block->name_hint; })
1110+
<< ", but pipeline annotation is " << pipeline_stages << " with different size";
1111+
CHECK_EQ(pipeline_orders.size(), original_order.size())
1112+
<< "PrimFunc " << global_symbol_ << " has original order "
1113+
<< original_order.Map([](const auto& block) { return block->name_hint; })
1114+
<< ", but pipeline annotation is " << pipeline_orders << " with different size";
11081115

11091116
std::unordered_set<int> pipeline_async_stages;
11101117
if (auto annot = op->annotations.Get(attr::software_pipeline_async_stages)) {
@@ -1205,6 +1212,7 @@ class PipelineInjector : private StmtExprMutator {
12051212
Map<Var, Buffer> buffer_data_to_buffer_;
12061213
std::unordered_map<const VarNode*, FragmentInfo> fragment_info_;
12071214
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> double_buffers;
1215+
Optional<String> global_symbol_;
12081216
};
12091217

12101218
} // namespace software_pipeline

0 commit comments

Comments
 (0)