Clean up IRModule attrs and LowerTEPass #8914

Merged: 1 commit, Sep 13, 2021
20 changes: 18 additions & 2 deletions include/tvm/ir/module.h
@@ -122,6 +122,7 @@ class IRModuleNode : public Object {
v->Visit("global_var_map_", &global_var_map_);
v->Visit("global_type_var_map_", &global_type_var_map_);
v->Visit("source_map", &source_map);
v->Visit("attrs", &attrs);
Contributor:

I think this is a very important change and should be called out in the PR title or the main description.

What happens to these when they are split into the Map<Target, IRModule> in lowered_funcs, or into the per_target_module in the interpreter? Are they copied in?

Will it be possible to add a test to make sure attrs are passed on to TIR lowering?

Contributor:

After having checked:

Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
per_target_modules;
for (const auto& kv : mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<tir::PrimFuncNode>()) {
// Extract target
Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target) << "Target should be set at this point";
// Put the function in per_target_modules
if (!per_target_modules.count(target.value())) {
// Initialize the IRModule for this target and add the function
IRModule target_module;
target_module->Add(var, func);
per_target_modules[target.value()] = target_module;

I don't think it is transferred.

Member:

Hi @manupa-arm, attrs were actually added to IRModule in #8750 so anything we do now is an iterative improvement on that. Given this series of PRs aims to remove the Map<Target, IRModule> entirely (this one removes the per_target_module from the interpreter) I don't think we need to ensure the copy happens here but I agree we should have some test coverage when the unified IRModule is lowered to ensure it contains all the attributes we've accrued - this should be a follow up when we change the interface from Map<Target, IRModule> to IRModule - does that sound reasonable to you?
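The invariant being asked for here (module-level attrs surviving the per-target split) can be sketched without the real TVM API. Below, `FakeIRModule`, `FakeFunc`, and `get_per_target_modules` are hypothetical stand-ins that only model the behavior of `GetPerTargetModules`; the real types and signatures differ:

```python
from dataclasses import dataclass, field
from typing import Dict

# Illustrative stand-ins, not the real TVM classes.
@dataclass
class FakeFunc:
    target: str  # models func->GetAttr(tvm::attr::kTarget)

@dataclass
class FakeIRModule:
    functions: Dict[str, FakeFunc] = field(default_factory=dict)
    attrs: Dict[str, str] = field(default_factory=dict)

def get_per_target_modules(mod: FakeIRModule) -> Dict[str, FakeIRModule]:
    """Split a unified module by target, carrying module attrs across."""
    per_target: Dict[str, FakeIRModule] = {}
    for name, func in mod.functions.items():
        if func.target not in per_target:
            # Initialize the per-target module *with* the unified attrs,
            # which is the copy being discussed in this thread.
            per_target[func.target] = FakeIRModule(attrs=dict(mod.attrs))
        per_target[func.target].functions[name] = func
    return per_target

unified = FakeIRModule(
    functions={"add": FakeFunc("llvm"), "mul": FakeFunc("cuda")},
    attrs={"executor": "graph"},
)
split = get_per_target_modules(unified)
assert all(m.attrs == unified.attrs for m in split.values())
```

A test along these lines against the real API would lower a module carrying a known attribute and assert each per-target module still reports it.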

manupak (Contributor), Sep 7, 2021:

Would there be a PR to remove lowered_funcs as well, @electriclilies?

manupak (Contributor), Sep 7, 2021:

It's not very clear to me how we can avoid per_target_module throughout the lowering process -- though we could push it way down.

Unified IRModule --> per_target_modules (lowered_funcs) --> runtime.Modules

Unless I'm missing something here, there will always be a stage where an IRModule contains the things that get lowered to a single runtime.Module.

Contributor (Author):

Most likely, although I have not dug into that part of the codebase in depth yet so I can't say for sure.
The two options that I think are most likely are

  1. build consuming the IRModule directly, traversing the functions in the IRModule and dealing with each directly (which is what you just mentioned)
  2. Right before build is invoked, separating the functions in the module by Target (although we wouldn't store them in a Map<Target, IRModule>)

So to summarize, we'll either completely remove the data structure that stores functions separated by target, or just push it all the way down to right before build is called.
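Option 2 above can be sketched as a plain grouping step performed only at the build boundary; the names and flat `(name, target)` representation here are hypothetical, and the point is that the grouping is local to the call rather than a persistent Map<Target, IRModule>:

```python
from collections import defaultdict

# Hypothetical flat view of a lowered module: each function knows its target.
lowered_functions = [
    ("add", "llvm"),
    ("mul", "cuda"),
    ("sub", "llvm"),
]

def group_by_target(functions):
    """Group (name, target) pairs by target right before build.

    The result lives only for the duration of the build call; no
    target-keyed module map is stored anywhere.
    """
    groups = defaultdict(list)
    for name, target in functions:
        groups[target].append(name)
    return dict(groups)

groups = group_by_target(lowered_functions)
assert groups == {"llvm": ["add", "sub"], "cuda": ["mul"]}
```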

Contributor:

So the attrs of an IRModule are globally valid for all targets. Therefore I feel conveying them to the codegen might still be beneficial.

For 1, we would need to expand the build API to pass the attrs.

For 2, we can embed them into each per-target IRModule.

The attrs provide a channel that passes through the whole unified lowering process, up until the concept of an IRModule ceases to exist.

If you agree, I feel we should pass them on to the per-target IRModules, and later change that in whichever way we decide to proceed.

Contributor (Author):

OK, I can add them to the per-target IRModules and then we can figure out what to do with them later.

Contributor:

+1 to:

  • preserving module attrs through all passes, on the assumption they apply to all targets
  • transferring them when we project a cross-target IRModule to a single-target IRModule, just before the transition into the lowering phases

We should probably have a convention that IRModule attrs describe properties of the code and not of the compilation? Since they will appear everywhere, it would be tempting to start adding Target-like things there.

Does anyone have thoughts on whether we should try to line up the IRModule and runtime::Module worlds?

                   cross-target         single-target
IR                 IRModule             ?
Executable         ?                    runtime::Module       

Contributor (Author):

That convention sounds good to me, but one thing we should be cautious of is that the attributes of functions do contain properties of the compilation flow (like targets), and this is inconsistent.

@mbs-octoml By lining up the IRModule and runtime::Module worlds do you mean also moving to a world where runtime::Modules are cross-target? I'm honestly not sure if that's something we want to do. @areusch do you have any thoughts about this?

}

TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
@@ -277,6 +278,12 @@ class IRModuleNode : public Object {
*/
TVM_DLL void Update(const IRModule& other);

/*!
* \brief Create a shallow copy of this IRModule.
* \returns The shallow copy of the IRModule.
*/
TVM_DLL IRModule ShallowCopy();

/*!
* \brief Import Relay code from the file at path.
* \param path The path of the Relay code to import.
@@ -348,12 +355,14 @@ class IRModule : public ObjectRef {
* \brief constructor
* \param functions Functions in the module.
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module
* \param import_set Set of imported files in the module.
* \param map The module source map.
* \param attrs The module attributes.
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<String> import_set = {}, parser::SourceMap map = {});
std::unordered_set<String> import_set = {}, parser::SourceMap map = {},
DictAttrs attrs = {});

/*! \brief default constructor */
IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
@@ -415,6 +424,13 @@
*/
TVM_DLL static IRModule FromText(const String& text, const String& source_path);

/*!
* \brief Create a shallow copy of an IRModule.
* \param mod The module to copy.
* \return The copied module.
*/
IRModule ShallowCopyIRModule(IRModule mod);

/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;

11 changes: 10 additions & 1 deletion src/ir/module.cc
@@ -43,7 +43,8 @@ namespace tvm {

IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, parser::SourceMap source_map) {
std::unordered_set<String> import_set, parser::SourceMap source_map,
DictAttrs attrs) {
Member:

Should we add attrs to SEqualReduce (return equal(attrs, other->attrs);) and SHashReduce (hash_reduce(attrs))?

Contributor (Author):

Yes, I'll do this. FunctionNode has attrs in the hash and equal functions, so we should be consistent with that.
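The intended effect can be modeled with plain tuples (a stand-in for the SEqualReducer/SHashReducer machinery, not the real API): once attrs participate, two modules with identical functions but different attrs compare unequal, and the hash stays consistent with equality:

```python
# Model a module's structural identity as (functions, attrs);
# this mirrors adding attrs to both SEqualReduce and SHashReduce.
def structural_key(functions, attrs):
    return (frozenset(functions.items()), frozenset(attrs.items()))

m1 = structural_key({"main": "fn"}, {"executor": "graph"})
m2 = structural_key({"main": "fn"}, {"executor": "aot"})
m3 = structural_key({"main": "fn"}, {"executor": "graph"})

assert m1 != m2               # differing attrs now break equality
assert m1 == m3               # identical functions and attrs still match
assert hash(m1) == hash(m3)   # hash agrees with equality
```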

auto n = make_object<IRModuleNode>();
n->functions = std::move(functions);
n->type_definitions = std::move(type_definitions);
@@ -52,6 +53,7 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
n->constructor_tag_map_ = {};
n->import_set_ = std::move(import_set);
n->source_map = source_map;
n->attrs = std::move(attrs);

for (const auto& kv : n->functions) {
// set global var map
@@ -72,6 +74,7 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,

bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
if (functions.size() != other->functions.size()) return false;
if (!equal(this->attrs, other->attrs)) return false;
for (const auto& kv : this->functions) {
if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
@@ -112,6 +115,7 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
temp.emplace_back(kv.first->name_hint, kv.second);
}
reduce_temp();
hash_reduce(this->attrs);
}

bool IRModuleNode::ContainGlobalVar(const String& name) const {
@@ -361,6 +365,11 @@ void IRModuleNode::Update(const IRModule& mod) {
}
}

IRModule IRModuleNode::ShallowCopy() {
return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map,
this->attrs);
}

std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
const RelayExpr& expr, const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions,
2 changes: 1 addition & 1 deletion src/relay/backend/aot_executor_codegen.cc
@@ -676,7 +676,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {

Optional<Array<tvm::runtime::Module>> external_modules =
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";
ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";

// This is the point where we separate the functions in the module by target
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
13 changes: 5 additions & 8 deletions src/relay/backend/graph_executor_codegen.cc
@@ -241,26 +241,23 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());

// Get only the Relay functions out of the lowered module so we can run type inference on them
IRModule main_module = tec::GetMainModule(lowered_mod);
main_module = relay::transform::InferType()(main_module);
relay::Function main_func = Downcast<relay::Function>(main_module->Lookup("main"));
Function lowered_main_func = Downcast<Function>(lowered_mod->Lookup("main"));

// Now that we have lowered all operators to TIR code, we can proceed with compilation.
//
// We need to unfortunately re-plan as the previous results have been invalidated by lowering
// we will fix this in future refactors.
memory_plan_ = GraphPlanMemory(main_func);
memory_plan_ = GraphPlanMemory(lowered_main_func);

The graph planner also can not handle planning calls to global variables, so we must remap

// First we convert all the parameters into input nodes.
for (auto param : main_func->params) {
for (auto param : lowered_main_func->params) {
auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
var_map_[param.get()] = AddNode(node_ptr, param);
}

heads_ = VisitExpr(main_func->body);
heads_ = VisitExpr(lowered_main_func->body);
std::ostringstream os;

dmlc::JSONWriter writer(&os);
@@ -277,7 +274,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<

Optional<Array<tvm::runtime::Module>> external_modules =
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";
ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";

// This is the point where we separate the functions in the module by target
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
52 changes: 20 additions & 32 deletions src/relay/backend/interpreter.cc
@@ -292,14 +292,8 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st
class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
// TODO(mbs, electriclilies): Collapse mod and per_target_module once IRModule subsumes
// LoweredModule.
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
: mod_(mod),
per_target_module_(per_target_module),
device_(device),
target_(target),
debug_op_(Op::Get("debug")) {}
Interpreter(IRModule unified_mod, Device device, Target target)
: unified_mod_(unified_mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {}

template <typename T>
T WithFrame(const Frame& fr, const std::function<T()>& f) {
@@ -316,7 +310,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
ObjectRef VisitExpr_(const VarNode* var_node) final { return Lookup(GetRef<Var>(var_node)); }

ObjectRef VisitExpr_(const GlobalVarNode* op) final {
return Eval(mod_->Lookup(GetRef<GlobalVar>(op)));
return Eval(unified_mod_->Lookup(GetRef<GlobalVar>(op)));
}

ObjectRef VisitExpr_(const OpNode* id) override {
@@ -387,9 +381,9 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,

// Project out just the function(s) we need.
IRModule lowered_projected_mod;
Map<Target, IRModule> per_target_module = tec::GetPerTargetModules(unified_mod_);
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
per_target_module_std_map =
backend::TargetModuleMapToTargetStrModuleMap(per_target_module_);
per_target_module_std_map = backend::TargetModuleMapToTargetStrModuleMap(per_target_module);
auto mod_itr = per_target_module_std_map.find(target);
ICHECK(mod_itr != per_target_module_std_map.end())
<< "No target module for target '" << target->str() << "'";
@@ -876,13 +870,11 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
}

private:
// Main module. All expressions are eval'ed w.r.t. the definitions in this module. This module
// may contain calls to TIR functions bound in a per_target_module_ below.
IRModule mod_;
// Map from target key to lowered TIR functions derived from mod_.
// Note that primitives are implicitly executed on target_, while shape functions are implicitly
// executed on the default 'cpu' host. Thus this map has at most two entries.
Map<Target, IRModule> per_target_module_;
// Unified module. Functions are annotated with their target.
// All expressions are eval'ed w.r.t. the definitions in this module.
// This module contains functions that used to be in main_module and the per_target_module (TIR
// functions) in one module.
IRModule unified_mod_;
// Cached packed functions for the primitives and shape functions, keyed by target and
// global var name.
std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash> compiled_packed_funcs_;
@@ -902,7 +894,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
* rewritten \p mod and target-specific modules containing bindings for all TIR primitive
* functions needed by the rewritten module.
*/
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
IRModule Prepare(IRModule mod, Device device, Target target) {
// Things to initialize to pass into tec::LowerTEPass
// We only have one device-specific target.
tec::TargetMap targets = {{device.device_type, target}};
@@ -930,8 +922,7 @@ std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device,
With<transform::PassContext> ctx(pass_ctx);
mod = seq(mod);

// Lower all primitive functions reachable from expr.
return {tec::GetMainModule(mod), tec::GetPerTargetModules(mod)};
return mod;
}

/*! \brief Check if an expression could be changed by \p Prepare.
@@ -1020,11 +1011,9 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
// and can just eval it directly.
expr_to_eval = expr;
}
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_with_expr, device, target);
std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
target);
IRModule lowered_mod = Prepare(mod_with_expr, device, target);

std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(lowered_mod, device, target);

//
// Step 2: Evaluate target function to a closure.
@@ -1063,12 +1052,11 @@ ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target target) {
std::pair<IRModule, GlobalVar> mod_and_global =
IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set);
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_and_global.first, device, target);
Interpreter intrp(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
target);
Expr expr_to_eval = main_and_lowered.first->GetGlobalVar(mod_and_global.second->name_hint);

IRModule mod = Prepare(mod_and_global.first, device, target);

Interpreter intrp(mod, device, target);
Expr expr_to_eval = mod->GetGlobalVar(mod_and_global.second->name_hint);
if (expr.as<BaseFuncNode>() == nullptr) {
// TODO(mbs): IRModule::FromExpr will implicitly close over the free vars of expr
// unless it is a function, so we must reverse that in the expression to eval.
27 changes: 5 additions & 22 deletions src/relay/backend/te_compiler.cc
@@ -900,8 +900,9 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {

// Put the function in per_target_modules
if (!per_target_modules.count(target.value())) {
// Initialize the IRModule for this target and add the function
IRModule target_module;
// Initialize the IRModule for this target with the attributes from the input IRModule
IRModule target_module = IRModule({}, {}, {}, {}, mod->attrs);
// Add the function to the IRModule
target_module->Add(var, func);
per_target_modules[target.value()] = target_module;
} else {
@@ -918,33 +919,15 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
return per_target_modules;
}

IRModule GetMainModule(IRModule mod) {
IRModule main_module;
// Copy the type defs
for (const auto& kv : mod->type_definitions) {
main_module->AddTypeDef(kv.first, kv.second);
}
// Copy all Relay functions (we don't include PrimFuncs)
for (auto kv : mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<tvm::relay::FunctionNode>()) {
main_module->Add(var, func);
}
}
return main_module;
}

Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
PassContext ctx) {
return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
};
// TODO(@electriclilies, mbs): Fold InferType() pass into LowerTEPass since it will always need to
// be called afterwards
return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
return tvm::transform::Sequential(
{tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()});
}
} // namespace tec
} // namespace relay
7 changes: 0 additions & 7 deletions src/relay/backend/te_compiler.h
@@ -165,13 +165,6 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);
*/
Map<Target, IRModule> GetPerTargetModules(IRModule mod);

/*!
* \brief Utility to extract all the Relay functions from an IRModule, with no PrimFuncs.
* \param mod The IRModule to extract the Relay functions from
* \return An IRModule containing only the Relay functions that are in the input mod (no PrimFuncs)
*/
IRModule GetMainModule(IRModule mod);

/*! \brief Lower an IRModule's primitive functions to TIR.
*
* This is the "back half" of the Relay compiler which lowers "primitive functions"
4 changes: 1 addition & 3 deletions src/relay/ir/transform.cc
@@ -133,9 +133,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
DLOG(INFO) << "Executing function pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level;

// Execute the pass function and return a new module.
IRModule updated_mod =
IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map);
IRModule updated_mod = mod->ShallowCopy();

std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) {
5 changes: 4 additions & 1 deletion src/relay/transforms/partition_graph.cc
@@ -30,6 +30,7 @@
*/

#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
@@ -509,7 +510,9 @@ class NameMangleExtFuncs : public MixedModeMutator {

// Walk the tree and mangle the functions. Then replace compiler functions
// with mangled functions in the module
IRModule new_module = IRModule({}, module_->type_definitions, module_->Imports());
IRModule new_module = module_->ShallowCopy();
new_module->functions = {};

for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
2 changes: 1 addition & 1 deletion src/relay/transforms/to_basic_block_normal_form.cc
@@ -52,7 +52,7 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) {
DLOG(INFO) << "ToBBlock:" << std::endl << mod;

// Create a new module by shallow copy.
auto mod_ = IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map);
IRModule mod_ = mod->ShallowCopy();

tvm::Map<GlobalVar, Function> updates;
auto funcs = mod_->functions;
17 changes: 10 additions & 7 deletions src/relay/transforms/type_infer.cc
@@ -205,13 +205,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
this->EmitFatal(Diagnostic::Error(op->span) << "Cannot do type inference on global variables "
<< "without a module");
}

if (mod_->ContainGlobalVar(var->name_hint)) {
relay::Function e = Downcast<Function>(mod_->Lookup(var));
return e->checked_type();
} else {
return op->checked_type_;
BaseFunc func = mod_->Lookup(var->name_hint);

if (func->IsInstance<FunctionNode>()) {
relay::Function relay_func = Downcast<Function>(func);
return relay_func->checked_type();
}
}
// Return op->checked_type if the module doesn't contain the GlobalVar or the function is a
// PrimFunc (we don't typecheck PrimFuncs)
return op->checked_type_;
}

Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); }
@@ -822,8 +826,7 @@ Pass InferType() {
[=](IRModule mod, const PassContext& pass_ctx) {
DLOG(INFO) << "tvm::relay::transform::InferType";
// Execute the pass function and return a new module.
IRModule updated_mod =
IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map);
IRModule updated_mod = mod->ShallowCopy();

pass_ctx->diag_ctx = DiagnosticContext::Default(updated_mod);

2 changes: 2 additions & 0 deletions tests/python/relay/test_backend_compile_engine.py
@@ -194,6 +194,8 @@ def get_func(shape):
engine.dump()


# Note: Once compile engine is removed, we should keep this test so that
# we make sure that opt_level=0 passes are being called correctly.
def test_compile_placeholder_bypass():
engine = relay.backend.compile_engine.get()
x = relay.var("x", shape=(2, 3))