Skip to content

Commit d1724d5

Browse files
mbs-octomlyangulei
authored andcommitted
Prepare for switching VM to LowerTEPass. (apache#9550)
This is a grab bag of fallout changes from switching the VM to use LoweTEPass which can be easily split out of the main apache#9483 PR. - AnnotateSpans can be used from C++ (though, unfortunately, it didn't help me with debugging since spans are universally dropped in most passes). - Can get a human readable dump of the VM's PackedFunc names and indexes for debugging. - If TVM_LOG_DEBUG defined then include types and ids of GlobalVars. I had a lot of difficulty tracking down where duplicate GlobalVars for the same name_hint were getting created and propagated. - GetCallLoweredProps follows same API as GetDeviceCopy and GetOnDevice where will return 'null' properties if call/expr is not of call_lowered form. Mildly more convenient, though switching all the above to ICHECK and push 'if (op == the relevant op)' into all use sites would also be just fine. - Misc VLOG improvements made while tracking down issues in apache#9483.
1 parent 4260b8c commit d1724d5

File tree

27 files changed

+311
-112
lines changed

27 files changed

+311
-112
lines changed

include/tvm/parser/parser.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
#ifndef TVM_PARSER_PARSER_H_
2121
#define TVM_PARSER_PARSER_H_
2222
/*!
23-
* \file parser.h
23+
* \file include/tvm/parser/parser.h
2424
* \brief A parser for TVM IR.
2525
*/
2626
#include <tvm/ir/module.h>
27+
#include <tvm/ir/transform.h>
2728
#include <tvm/runtime/packed_func.h>
2829
#include <tvm/runtime/registry.h>
2930

@@ -39,6 +40,13 @@ IRModule ParseModule(const std::string& file_name, const std::string& file_conte
3940
const Optional<IRModule>& init_module = Optional<IRModule>(),
4041
const MetaTable& init_meta_table = MetaTable());
4142

43+
/*!
44+
* \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
45+
* for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
46+
* modules constructed programaticaly rather than textually.
47+
*/
48+
transform::Pass AnnotateSpans();
49+
4250
} // namespace parser
4351
} // namespace tvm
4452

include/tvm/runtime/vm/executable.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,13 @@ class Executable : public ModuleNode {
144144
*/
145145
std::string GetVirtualDevices() const;
146146

147+
/*!
148+
* \brief Returns a description of all the 'primitive' (ie PackedFuncs) in the executable.
149+
* These correspond to eithed PrimFuncs we've compiled locally, or functions compiled by
150+
* a BYOC external codegen.
151+
*/
152+
std::string GetPrimitives() const;
153+
147154
/*!
148155
* \brief Print the detailed statistics of the given code, i.e. number of
149156
* globls and constants, etc.
@@ -201,9 +208,9 @@ class Executable : public ModuleNode {
201208
int host_device_index = -1;
202209
/*! \brief The global constant pool. */
203210
std::vector<ObjectRef> constants;
204-
/*! \brief A map from globals (as strings) to their index in the function map. */
211+
/*! \brief A map from globals (as strings) to their index in the Relay function map. */
205212
std::unordered_map<std::string, Index> global_map;
206-
/*! \brief A mapping from the packed function (as string) to the index that
213+
/*! \brief A mapping from the packed function's global name (as string) to the index that
207214
* corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
208215
*/
209216
std::unordered_map<std::string, Index> primitive_map;

include/tvm/target/compilation_config.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace tvm {
3232

3333
/*!
3434
* \brief Gathers the \p Targets and distinguished \p SEScopes in canonical form needed to
35-
* compile a Relay module. All centralizes any setup and validation logic needed to transition
35+
* compile a Relay module. Centralizes any setup and validation logic needed to transition
3636
* from configuration options conveyed implicitly (eg in \p PassContexts) or explicitly
3737
* (eg a a list of \p Targets) to the configuration.
3838
*
@@ -49,9 +49,12 @@ namespace tvm {
4949
class CompilationConfigNode : public Object {
5050
public:
5151
/*!
52-
* \brief The legacy targets map, mapping device type to \p Targets. Does not include any
53-
* entry for the host target. Intended to give a unique \p Target for every \p DLDeviceType,
54-
* though we want to get rid of that limitation.
52+
* \brief The legacy targets map, mapping device type to the corresponding \p Target to use
53+
* when compiling primitive functions. Does not include an entry for the host target, however
54+
* each \p Target in this map will have it's \p host field set to the \p host_target.
55+
*
56+
* Currently we require at most one \p Target per \p DLDeviceType, though we want to get rid of
57+
* that limitation.
5558
*
5659
* CAUTION: Since keys are \p Integers they are compared by object equality not integer
5760
* value.
@@ -63,13 +66,18 @@ class CompilationConfigNode : public Object {
6366
/*!
6467
* \brief The host target. Used for 'scalar' data and code (such as shapes and shape
6568
* functions) and residual Relay expressions and data (such as conditionals and ADTs).
69+
*
70+
* Note that it is possible for a \p Target used for primitive operations to be structurally
71+
* equal to the host \p Target (up to the \p host field.) However the \p Target objects will
72+
* be distinct, and can be used as keys within a \p Map without collision.
6673
*/
6774
Target host_target;
6875

6976
/*!
70-
* \brief Vector of all available targets for primitive operators. May contain a \p Target
71-
* for the same device type as for the \p host_target, however the \p host_target should
72-
* be preferred for all host computations and data.
77+
* \brief Vector of all available \p Targets for compiling primitive operators. May contain
78+
* a \p Target for the same device type as for the \p host_target, however the \p host_target
79+
* should be used for all host computations and data. Each \p Target will have \p host_target
80+
* as its host.
7381
*/
7482
Array<Target> primitive_targets;
7583

include/tvm/target/se_scope.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,12 @@ class SEScope : public ObjectRef {
297297
return SEScope(device.device_type, device.device_id, std::move(target));
298298
}
299299

300+
/*! \brief Returns the \p SEScope for \p target. */
301+
static SEScope ForTarget(Target target) {
302+
return SEScope(static_cast<DLDeviceType>(target->kind->device_type), /*virtual_device_id=*/0,
303+
std::move(target));
304+
}
305+
300306
/*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
301307
TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target,
302308
MemoryScope memory_scope) {

python/tvm/runtime/vm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(self, mod):
7373
self._get_bytecode = self.mod["get_bytecode"]
7474
self._get_constants = self.mod["get_constants"]
7575
self._get_virtual_devices = self.mod["get_virtual_devices"]
76+
self._get_primitives = self.mod["get_primitives"]
7677
self._get_stats = self.mod["get_stats"]
7778
self._get_function_arity = self.mod["get_function_arity"]
7879
self._get_function_param_name = self.mod["get_function_param_name"]
@@ -257,6 +258,12 @@ def virtual_devices(self):
257258
"""Returns a human-readable description of all the (virtual) devices in the executable."""
258259
return self._get_virtual_devices()
259260

261+
@property
262+
def primitive(self):
263+
"""Returns a human-readable dencription of all the primitives (ie PackedFuncs) in the
264+
executable"""
265+
return self._get_primitives()
266+
260267
@property
261268
def globals(self):
262269
"""Get the globals used by the Relay VM executable.

src/ir/module.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
213213
ICHECK_EQ((*it).second, var);
214214
} else {
215215
ICHECK(global_var_map_.count(var->name_hint) == 0)
216-
<< "Duplicate global function name " << var->name_hint;
216+
<< "Duplicate global function name " << PrettyPrint(var);
217217
}
218218

219219
global_var_map_.Set(var->name_hint, var);
@@ -243,7 +243,7 @@ void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData&
243243
if (!update) {
244244
// set global type var map
245245
ICHECK(global_type_var_map_.count(var->name_hint) == 0)
246-
<< "Duplicate global type definition name " << var->name_hint;
246+
<< "Duplicate global type definition name " << PrettyPrint(var);
247247
}
248248
global_type_var_map_.Set(var->name_hint, var);
249249
RegisterConstructors(var, type);
@@ -266,7 +266,7 @@ void IRModuleNode::Remove(const GlobalVar& var) {
266266

267267
BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
268268
auto it = functions.find(var);
269-
ICHECK(it != functions.end()) << "There is no definition of " << var->name_hint;
269+
ICHECK(it != functions.end()) << "There is no definition of " << PrettyPrint(var);
270270
return (*it).second;
271271
}
272272

@@ -277,7 +277,7 @@ BaseFunc IRModuleNode::Lookup(const String& name) const {
277277

278278
TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
279279
auto it = type_definitions.find(var);
280-
ICHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint;
280+
ICHECK(it != type_definitions.end()) << "There is no definition of " << PrettyPrint(var);
281281
return (*it).second;
282282
}
283283

@@ -306,6 +306,10 @@ String IRModuleNode::GetUniqueName(const String& name) {
306306
}
307307
}
308308

309+
/*!
310+
* \brief Renames global type/term variables to prefer the GlobalTypeVar/GlobalVar in the lhs
311+
* ('one') side above the rhs ('two').
312+
*/
309313
struct Renamer : relay::ExprMutator, TypeMutator {
310314
Map<String, GlobalVar> defs;
311315
Map<String, GlobalTypeVar> types;
@@ -411,7 +415,6 @@ IRModule IRModule::FromExpr(const RelayExpr& expr, const Map<GlobalVar, BaseFunc
411415
void IRModuleNode::Import(const String& path) {
412416
if (this->import_set_.count(path) == 0) {
413417
this->import_set_.insert(path);
414-
DLOG(INFO) << "Importing: " << path;
415418
std::fstream src_file(path, std::fstream::in);
416419
std::string file_contents{std::istreambuf_iterator<char>(src_file),
417420
std::istreambuf_iterator<char>()};

src/parser/parser.cc

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1909,7 +1909,8 @@ Parser InitParser(const std::string& file_name, const std::string& file_content,
19091909

19101910
IRModule ParseModule(const std::string& file_name, const std::string& file_content,
19111911
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
1912-
VLOG(9) << "ParseModule";
1912+
VLOG_CONTEXT << "ParseModule";
1913+
VLOG(9) << "parsing and type-checking " << file_name;
19131914
auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
19141915
auto mod = parser.ParseModule();
19151916
ICHECK(mod.defined()) << "The parser must return a non-null module.";
@@ -1952,15 +1953,21 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr")
19521953
return ParseExpr(file_name, file_content);
19531954
});
19541955

1955-
TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() {
1956-
return CreateModulePass(
1957-
[](const IRModule& mod, const PassContext& ctx) {
1958-
String text = AsText(mod, /*show_meta_data=*/true);
1959-
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
1960-
return ParseModule("GeneratedSource", text);
1961-
},
1962-
0, "AnnotateSpans", {});
1963-
});
1956+
/*!
1957+
* \brief This pass pretty-prints mod then parses it back so as to establish spans and sources
1958+
* for all Relay sub-expressions. This improves error and debugging diagnostics downstream for
1959+
* modules constructed programaticaly rather than textually.
1960+
*/
1961+
Pass AnnotateSpans() {
1962+
auto pass_func = [](const IRModule& mod, const PassContext& ctx) {
1963+
String text = AsText(mod, /*show_meta_data=*/true);
1964+
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
1965+
return ParseModule("GeneratedSource", text);
1966+
};
1967+
return CreateModulePass(pass_func, 0, "AnnotateSpans", {});
1968+
}
1969+
1970+
TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans);
19641971

19651972
} // namespace parser
19661973
} // namespace tvm

src/printer/relay_text_printer.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,17 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
499499
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
500500
}
501501

502-
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text("@" + op->name_hint); }
502+
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
503+
Doc doc;
504+
doc << "@" << op->name_hint;
505+
#if TVM_LOG_DEBUG
506+
if (op->checked_type_.defined()) {
507+
doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */";
508+
}
509+
doc << " /* id=" << reinterpret_cast<uint64_t>(op) << " */";
510+
#endif
511+
return doc;
512+
}
503513

504514
Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }
505515

src/printer/text_printer.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,29 @@ Doc TextPrinter::PrintMod(const IRModule& mod) {
5656
if (kv.second.as<relay::FunctionNode>()) {
5757
std::ostringstream os;
5858
os << "def @" << kv.first->name_hint;
59+
#if TVM_LOG_DEBUG
60+
os << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
61+
#endif
5962
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
6063
} else if (kv.second.as<tir::PrimFuncNode>()) {
61-
doc << "@" << kv.first->name_hint << " = ";
62-
doc << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
64+
doc << "@" << kv.first->name_hint;
65+
#if TVM_LOG_DEBUG
66+
doc << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
67+
#endif
68+
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
6369
}
6470
doc << Doc::NewLine();
6571
}
72+
#if TVM_LOG_DEBUG
73+
// attributes
74+
if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
75+
doc << "attributes {" << Doc::NewLine();
76+
for (const auto& kv : mod->attrs->dict) {
77+
doc << " '" << kv.first << "' = " << PrettyPrint(kv.second) << Doc::NewLine();
78+
}
79+
doc << "}" << Doc::NewLine();
80+
}
81+
#endif
6682
return doc;
6783
}
6884

src/relay/backend/aot_executor_codegen.cc

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
8383
Expr func;
8484
Array<Expr> args;
8585

86-
if (call_node->op == CallLoweredOp()) {
87-
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
86+
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
87+
if (call_lowered_props.lowered_func.defined()) {
8888
func = call_lowered_props.lowered_func;
8989
args = call_lowered_props.arguments;
9090
} else { // Relay functions that have not been lowered and lowered extern functions
@@ -516,10 +516,11 @@ class AOTExecutorCodegen : public MixedModeVisitor {
516516
}
517517
call_lowered_props = CallLoweredProps{GetRef<GlobalVar>(gvn), call_node->args, {}};
518518
} else {
519-
ICHECK(call_node->op == CallLoweredOp()) << "Operators should be transformed away; Try "
520-
"applying the fuse_ops transformation to the "
521-
"expression.";
522519
call_lowered_props = GetCallLoweredProps(call_node);
520+
ICHECK(call_lowered_props.lowered_func.defined())
521+
<< "Operators should be transformed away; Try "
522+
"applying the fuse_ops transformation to the "
523+
"expression.";
523524
for (const auto& arg : call_lowered_props.arguments) {
524525
VisitExpr(arg);
525526
}
@@ -717,6 +718,14 @@ class AOTExecutorCodegen : public MixedModeVisitor {
717718
: mod_(mod), targets_(targets), target_host_(target_host), use_unpacked_api_(Bool(false)) {}
718719

719720
LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) {
721+
VLOG_CONTEXT << "AOT";
722+
for (const auto& kv : targets_) {
723+
VLOG(1) << "target: " << kv.second->ToDebugString();
724+
}
725+
if (target_host_.defined()) {
726+
VLOG(1) << "target host: " << target_host_->ToDebugString();
727+
}
728+
720729
Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
721730
String interface_api = executor_config->GetAttr<String>("interface-api").value_or("packed");
722731
Integer workspace_byte_alignment =
@@ -793,10 +802,11 @@ class AOTExecutorCodegen : public MixedModeVisitor {
793802
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
794803
}
795804

796-
// Build the TIR IRModule for the AOT function
805+
// Build the TIR IRModule for the main AOT function
797806
Map<GlobalVar, BaseFunc> symbol_map;
798807
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
799808
IRModule mod_run(symbol_map, {}, {}, {}, mod->attrs);
809+
VLOG(1) << "main module:" << std::endl << PrettyPrint(mod_run);
800810

801811
// Apply storage rewrite pass to the runner function to do memory planning
802812
auto storage_rewrite = tir::transform::StorageRewrite();
@@ -827,12 +837,23 @@ class AOTExecutorCodegen : public MixedModeVisitor {
827837
ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";
828838

829839
// This is the point where we separate the functions in the module by target
840+
VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod);
830841
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
842+
VLOG(1) << "per-target modules:";
843+
for (const auto& kv : ret.lowered_funcs) {
844+
VLOG(1) << "target:" << std::endl
845+
<< kv.first->ToDebugString() << std::endl
846+
<< "maps to:" << std::endl
847+
<< PrettyPrint(kv.second);
848+
}
849+
831850
ret.external_mods = external_modules.value();
832851

833852
if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
853+
VLOG(1) << "merging main into existing module for host target";
834854
ret.lowered_funcs[target_host_]->Update(mod_run);
835855
} else {
856+
VLOG(1) << "adding main into new module for host target";
836857
ret.lowered_funcs.Set(target_host_, mod_run);
837858
}
838859

0 commit comments

Comments
 (0)