Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,9 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py

# Used in CI to communicate between Python and Jenkins
.docker-image-names/

# Printed TIR code on disk
*.tir

# GDB history file
.gdb_history
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,13 @@ TVM_DLL Pass LowerAsyncDMA();
*/
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);

/*!
* \brief Add TIR-printer output as debug information to all ops in the module
* \return The pass.
*/

TVM_DLL Pass InstallDebugSpans();

/*!
* \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
* "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g.,
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,3 +1028,15 @@ def InstrumentProfileIntrinsics():
The result pass
"""
return _ffi_api.InstrumentProfileIntrinsics() # type: ignore


def InstallDebugSpans():
"""Add line information from the TIR printer as spans on each statement and
expression.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InstallDebugSpans() # type: ignore
8 changes: 8 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
Expand Down Expand Up @@ -603,6 +604,9 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
});

transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
transform::PassContext pass_ctx = transform::PassContext::Current();
bool enable_debug = pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();

Array<tvm::transform::Pass> host_pass_list;

runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
Expand All @@ -621,6 +625,10 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho
host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
host_pass_list.push_back(tir::transform::CombineContextCall());

if (enable_debug) {
host_pass_list.push_back(tir::transform::InstallDebugSpans());
}

return transform::Sequential(host_pass_list);
}

Expand Down
1 change: 1 addition & 0 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ Pass GetPass(const String& pass_name) {
// ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
for (const Pass& pass : passes) {
VLOG(0) << "Running pass " << pass->Info()->name;
ICHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!pass_ctx.PassEnabled(pass_info)) {
Expand Down
40 changes: 22 additions & 18 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta)
: show_meta_(show_meta), meta_(meta), meta_collector_(meta) {}

/*! \brief Output a newline */
virtual Doc NewLine();

/*! \brief Print the node */
Doc Print(const ObjectRef& node);

Expand All @@ -290,24 +293,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
*/
bool GetVarName(::tvm::tir::Var v, std::string* s);

private:
/*! \brief whether show meta data */
bool show_meta_;
/*! \brief meta data context */
TextMetaDataContext* meta_;
/*! \brief meta collector */
MetaCollector meta_collector_;
/*! \brief Map from Var to Doc */
std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;

friend class tvm::TextPrinter;

protected:
Doc VisitExpr_(const IntImmNode* op) override;
Doc VisitExpr_(const FloatImmNode* op) override;
Doc VisitExpr_(const StringImmNode* op) override;
Expand Down Expand Up @@ -363,6 +349,24 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const BlockRealizeNode* op) override;
Doc VisitStmtDefault_(const Object* op) override;

private:
/*! \brief whether show meta data */
bool show_meta_;
/*! \brief meta data context */
TextMetaDataContext* meta_;
/*! \brief meta collector */
MetaCollector meta_collector_;
/*! \brief Map from Var to Doc */
std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;

friend class tvm::TextPrinter;

Doc VisitType_(const PrimTypeNode* node) override;
Doc VisitType_(const PointerTypeNode* node) override;
Doc VisitType_(const TupleTypeNode* node) override;
Expand Down
53 changes: 27 additions & 26 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
for (const auto& it : op->attrs->dict) {
attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
}
attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
attr_doc << NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
doc << Doc::Indent(2, attr_doc);
}

Expand All @@ -136,8 +136,8 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
const Buffer buf = op->buffer_map[v];
buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf)));
}
buffer_doc << Doc::NewLine() << "buffers = {";
buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine()));
buffer_doc << NewLine() << "buffers = {";
buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << NewLine()));
doc << Doc::Indent(2, buffer_doc) << "}";
}

Expand All @@ -149,26 +149,28 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
buffer_map_doc.push_back(Print(v) << ": " << Print(buf));
}
doc << Doc::Indent(
2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
2, NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
}

doc << PrintBody(op->body);
return doc;
}

Doc TIRTextPrinter::NewLine() { return Doc::NewLine(); }

Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
const auto* op = module.operator->();
Doc doc;

Doc body;
body << Doc::NewLine();
body << NewLine();
std::vector<Doc> functions;
for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
if ((*it).second.as<PrimFuncNode>()) {
functions.push_back(Print((*it).second));
}
}
body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
body << TIRTextPrinter::PrintSep(functions, NewLine() << NewLine());
doc << Doc::Indent(0, body);
return doc;
}
Expand Down Expand Up @@ -451,7 +453,7 @@ Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) {

Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) {
Doc doc;
doc << "let " << Print(op->var) << " = " << Print(op->value) << Doc::NewLine() << Print(op->body);
doc << "let " << Print(op->var) << " = " << Print(op->value) << NewLine() << Print(op->body);
return doc;
}

Expand All @@ -463,14 +465,14 @@ Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) {
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) {
Doc doc;
doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << Doc::NewLine()
doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << NewLine()
<< Print(op->body);
return doc;
}
Expand Down Expand Up @@ -529,7 +531,7 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}
Expand All @@ -542,19 +544,19 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) {
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) {
Doc doc;
doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data) << ", "
<< PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << Doc::NewLine();
<< PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << NewLine();
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}
Expand All @@ -572,9 +574,9 @@ Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) {
std::vector<Doc> stmts;
Doc seq_doc, doc;
for (Stmt stmt : op->seq) {
seq_doc << Doc::NewLine() << Print(stmt);
seq_doc << NewLine() << Print(stmt);
}
doc << " {" << Doc::Indent(2, seq_doc) << Doc::NewLine() << "}";
doc << " {" << Doc::Indent(2, seq_doc) << NewLine() << "}";
return doc;
}

Expand Down Expand Up @@ -657,37 +659,36 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) {
Doc block_attr_doc;
// print predicate, binding, read/write tensor region, annotations
if (!is_one(op->predicate)) {
block_attr_doc << Doc::NewLine() << "where(" << Print(op->predicate) << ")";
block_attr_doc << NewLine() << "where(" << Print(op->predicate) << ")";
}
for (size_t i = 0; i < block_op->iter_vars.size(); ++i)
block_attr_doc << Doc::NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", "
block_attr_doc << NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", "
<< Print(op->iter_values[i]) << ")";
block_attr_doc << Doc::NewLine() << "tir.reads(" << Print(block_op->reads) << ")";
block_attr_doc << Doc::NewLine() << "tir.writes(" << Print(block_op->writes) << ")";
block_attr_doc << NewLine() << "tir.reads(" << Print(block_op->reads) << ")";
block_attr_doc << NewLine() << "tir.writes(" << Print(block_op->writes) << ")";
if (!block_op->annotations.empty()) {
std::vector<Doc> attr_docs;
for (const auto& it : block_op->annotations) {
attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
}
block_attr_doc << Doc::NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", "))
<< "})";
block_attr_doc << NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", ")) << "})";
}
// print body
Doc body;
body << Doc::NewLine();
body << NewLine();
for (const auto& alloc_buf : block_op->alloc_buffers) {
body << AllocBuf(alloc_buf) << " = alloc_buffer(" << PrintDType(alloc_buf->dtype)
<< Print(alloc_buf->shape) << ")" << Doc::NewLine();
<< Print(alloc_buf->shape) << ")" << NewLine();
}
for (const auto& match_buf : block_op->match_buffers) {
body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")"
<< Doc::NewLine();
<< NewLine();
}
if (block_op->init.defined()) {
Doc init_block;
init_block << "with init()";
init_block << PrintBody(block_op->init.value());
body << init_block << Doc::NewLine();
body << init_block << NewLine();
}
body << Print(block_op->body);
doc << Doc::Indent(2, block_attr_doc << body);
Expand Down Expand Up @@ -826,7 +827,7 @@ Doc TIRTextPrinter::PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) {
Doc doc;
if (body->IsInstance<SeqStmtNode>()) return Print(body);
doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}";
doc << " {" << Doc::Indent(2, NewLine() << Print(body)) << NewLine() << "}";
return doc;
}

Expand Down
Loading