[Relay] Remove memory planning from LowerTEPass (apache#8974)
* Clean up LowerTEPass

Add attrs to IRModule equal and hash

Make LowerTEPass opt_level 0

Copy IRModule attrs to per_target_modules

Add ShallowCopy to IRModule

* Fix rebase

* Remove comment

* [TEC] Remove memory plan from LowerTEPass

* Fix linting errors

* Fix PR comments

* Remove updated module with function info from LowerTE

* Refactor UpdateMainWorkspaceSize to update func info independently from LowerTEPass

* Fix failing AOT tests

* Revert whitespace fixes

* Remove obsolete function hoisting and minor cleanups

* Address PR comments

Co-authored-by: electriclilies <[email protected]>
2 people authored and ylc committed Jan 13, 2022
1 parent 6f9bd94 commit 798a8e2
Showing 6 changed files with 89 additions and 83 deletions.
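Before the per-file diffs, here is a condensed sketch of the new caller pattern this change introduces, distilled from the executor codegen diffs below rather than taken verbatim from the patch: the executor codegen now computes the main function's workspace info itself and attaches it as a module attribute before lowering, and tec::LowerTEPass no longer receives a StaticMemoryPlan.

// Sketch only; variable names and surrounding setup are assumed from the diffs below.
IRModule mod = IRModule::FromExpr(func);

backend::FunctionInfo func_info;
if (memory_plan.defined()) {
  // Workspace sizing now happens in the codegen, outside of LowerTEPass.
  func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
  mod = WithAttr(mod, "main_func_info", func_info);
}

// LowerTEPass has dropped its memory-plan parameter.
IRModule lowered_mod =
    tec::LowerTEPass(targets_, device_context_map, mod_name,
                     [](Function f) { /* per-function processing hook */ })(mod);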
16 changes: 12 additions & 4 deletions src/relay/backend/aot_executor_codegen.cc
@@ -38,8 +38,8 @@
#include <string>
#include <vector>

#include "te_compiler.h"
#include "utils.h"
#include "./te_compiler.h"
#include "./utils.h"

namespace tvm {
namespace relay {
@@ -583,8 +583,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// performing the preexisting AOT executor code generation phase.
IRModule mod = IRModule::FromExpr(func);

backend::FunctionInfo func_info;

if (memory_plan.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
mod = WithAttr(mod, "main_func_info", func_info);
}

IRModule lowered_mod =
LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
tec::LowerTEPass(targets_, device_context_map, mod_name, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -661,7 +669,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {

Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";

main_func_info.value()->workspace_sizes.Set(target_host_, main_workspace_size);
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());

17 changes: 13 additions & 4 deletions src/relay/backend/graph_executor_codegen.cc
@@ -36,8 +36,8 @@
#include <string>
#include <vector>

#include "te_compiler.h"
#include "utils.h"
#include "./te_compiler.h"
#include "./utils.h"

namespace tvm {
namespace relay {
@@ -221,8 +221,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
device_context_map.insert({expr, dev});
}

backend::FunctionInfo func_info;

if (memory_plan_.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info =
relay::tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
mod = WithAttr(mod, "main_func_info", func_info);
}

IRModule lowered_mod =
LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
tec::LowerTEPass(targets_, device_context_map, mod_name_, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -238,7 +247,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<

Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";

function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());

Function lowered_main_func = Downcast<Function>(lowered_mod->Lookup("main"));
28 changes: 12 additions & 16 deletions src/relay/backend/interpreter.cc
@@ -542,7 +542,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
*
* @param prim_fn_var Global bound to lowered primitive.
* @param all_prim_fn_vars All globals references by lowered primitive, plus prim_fn_var itself.
* @param prim_shape_fn_var Global bound to lowered shape function for primitive, if neeeded.
* @param prim_shape_fn_var Global bound to lowered shape function for primitive, if needed.
* @param all_prim_shape_fn_vars All globals references by lowered shape function, plus
* prim_shape_fn_var itself.
* @param prim_shape_fn_states Records whether shape and/or data is needed by the dynamic
@@ -763,7 +763,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
ObjectRef val = Eval(op->tuple);
const auto* adt_obj = val.as<ADTObj>();
ICHECK(adt_obj) << "interal error: when evaluating TupleGetItem expected an ADT value";
ICHECK(adt_obj) << "internal error: when evaluating TupleGetItem expected an ADT value";
auto adt = GetRef<ADT>(adt_obj);
ICHECK_LT(static_cast<size_t>(op->index), adt.size()) << "internal error: index out of bounds";
return adt[op->index];
@@ -902,21 +902,17 @@ IRModule Prepare(IRModule mod, Device device, Target target) {
// All calls to primitives will use the unique target.
tec::DeviceMap device_map;

// No need for a memory plan.
backend::StaticMemoryPlan memory_plan; /*=nullptr*/

// Run minimal transforms on module to establish invariants needed by interpreter.
transform::Sequential seq(
{transform::SimplifyInference(),
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
// attribute.
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
// eta expand to support constructors in argument position
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType(),
tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"intrp",
[](Function func) { /* no-op */ })});
transform::Sequential seq({transform::SimplifyInference(),
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
// attribute.
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
// eta expand to support constructors in argument position
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType(),
tec::LowerTEPass(targets, device_map, /*module_name=*/"intrp",
[](Function func) { /* no-op */ })});

transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);
73 changes: 28 additions & 45 deletions src/relay/backend/te_compiler.cc
@@ -17,7 +17,7 @@
* under the License.
*/

#include "te_compiler.h"
#include "./te_compiler.h"

#include <tvm/driver/driver_api.h>
#include <tvm/ir/attrs.h>
@@ -42,8 +42,8 @@
#include <utility>
#include <vector>

#include "te_compiler_cache.h"
#include "utils.h"
#include "./te_compiler_cache.h"
#include "./utils.h"

namespace tvm {
namespace relay {
@@ -596,19 +596,7 @@ class LowerTensorExprMutator : public ExprMutator {
const Op& debug_op_;
};

Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
TECompiler compiler, std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function func, IRModule module, PassContext ctx) {
LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn,
module_name, compiler);
return Downcast<Function>(lower_te.Mutate(func));
};
return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
}

Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) {
if (targets.size() == 1) {
// The homogeneous execution case, return the only target.
const auto& it = targets.begin();
@@ -638,26 +626,30 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
}
}

/*!
* \brief Update the "main" control function's metadata
*
* \param mod The module
* \param targets Map of targets
* \return function_infos Function info for each function in the module
*/
Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, const String& module_name,
TECompiler compiler, std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function func, IRModule module, PassContext ctx) {
LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn,
module_name, compiler);
return Downcast<Function>(lower_te.Mutate(func));
};
return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
}

backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap targets,
backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets,
Map<Expr, backend::StorageInfo> storage_info_map) {
CHECK_EQ(mod->functions.size(), 1)
<< "There should only be one function in the module passed to UpdateMainWorkspaceSize";
Function func = Downcast<Function>(mod->Lookup("main"));

// This is a Map<device,Map<storage_id, size>>
std::unordered_map<DLDeviceType, std::unordered_map<int, int>, EnumClassHash> sid_workspace;
std::unordered_map<DLDeviceType, std::unordered_map<int, int>, backend::EnumClassHash>
sid_workspace;
// This is a Map<device, size_of_inputs_and_outputs>
std::unordered_map<DLDeviceType, int, EnumClassHash> device_io;
std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_io;
// This is a Map<device, size_of_constants>
std::unordered_map<DLDeviceType, int, EnumClassHash> device_consts;
std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_consts;

// Initialize the mapping from all storage identifiers to workspace sizes,
// the amount of device io, and the device constants.
@@ -723,7 +715,7 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
}

// This is a Map<device, workspace_size>
std::unordered_map<DLDeviceType, int, EnumClassHash> device_workspace;
std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_workspace;
// Once we know the sizes of sids, we need to accumulate per device
for (const auto& dev_sid_size : sid_workspace) {
auto dev = dev_sid_size.first;
@@ -746,17 +738,17 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
}

for (const auto& dev_and_size : device_workspace) {
auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
workspace_sizes.Set(tgt, dev_and_size.second);
relay_primfuncs.Set(tgt, func);
}
for (const auto& dev_and_size : device_io) {
auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
io_sizes.Set(tgt, dev_and_size.second);
}

for (const auto& dev_and_size : device_consts) {
auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
constant_sizes.Set(tgt, dev_and_size.second);
}

@@ -844,20 +836,13 @@ void UpdateFunctionMetadata(Function relay_func,
}

IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
const String& module_name, std::function<void(Function)> process_fn) {
DLOG(INFO) << "lowering module:\n" << PrettyPrint(module);

TECompiler compiler;

backend::FunctionInfo func_info;
if (memory_plan.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info = UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info);
}

auto updated_module = LowerTensorExpr(targets, device_context_map, memory_plan, module_name,
compiler, process_fn)(module);
auto updated_module =
LowerTensorExpr(targets, device_context_map, module_name, compiler, process_fn)(module);

// A temporary solution until we can rewrite the auto-scheduler task extraction code to work
// in a more reasonable way.
Expand All @@ -882,7 +867,6 @@ IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_con

// Annotate the module with the external modules and function info
updated_module = WithAttr(updated_module, "external_mods", compiler->LowerExternalFunctions());
updated_module = WithAttr(updated_module, "main_func_info", func_info);

return updated_module;
}
@@ -919,12 +903,11 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
return per_target_modules;
}

Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, const String& module_name,
std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
PassContext ctx) {
return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
return LowerTE(module, targets, device_context_map, module_name, process_fn);
};
return tvm::transform::Sequential(
{tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()});
27 changes: 13 additions & 14 deletions src/relay/backend/te_compiler.h
@@ -52,24 +52,15 @@
#include "../transforms/infer_layout_utils.h"
#include "../transforms/pass_utils.h"
#include "./te_compiler_cache.h"
#include "utils.h"
#include "./utils.h"

namespace tvm {
namespace relay {
namespace tec {

// This class is needed to avoid a GCC 5 bug that prevents maps containing enums
// from being compiled. If i386 GCC version is increased, we can remove it.
struct EnumClassHash {
template <typename T>
std::size_t operator()(T t) const {
return static_cast<std::size_t>(t);
}
};

// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake
// we should use a version of context which works in Map
using TargetMap = std::unordered_map<DLDeviceType, Target, EnumClassHash>;
using TargetMap = std::unordered_map<DLDeviceType, Target, backend::EnumClassHash>;
using DeviceMap =
std::unordered_map<Expr, tvm::Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
using ProcessFn = std::function<void(Function)>;
@@ -158,6 +149,16 @@ void UpdateFunctionMetadata(Function relay_func,
*/
Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);

/*!
* \brief Update the "main" control function's metadata
*
* \param mod The module
* \param targets Map of targets
* \return function_infos Function info for each function in the module
*/
backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets,
Map<Expr, backend::StorageInfo> storage_info_map);

/*! \brief Utility to separate the functions in an IRModule by Target.
*
* \param mod The IRModule to extract the per target module from
@@ -192,15 +193,13 @@ IRModule LowerTE(
*
* \param targets The mapping for devices to targets.
* \param device_context_map An analysis result mapping each sub-expression to a device.
* \param memory_plan The memory plan used during lowering
* \param module_name The name of this module
* \param process_fn Callback allowing one-level up code generators to process
* each function that we lower
* \returns The pass which lowers primitive functions to TIR
*/
transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn);
const String& module_name, std::function<void(Function)> process_fn);
} // namespace tec
} // namespace relay
} // namespace tvm
11 changes: 11 additions & 0 deletions src/relay/backend/utils.h
@@ -146,6 +146,17 @@ struct LoweredOutput {
runtime::Metadata metadata;
};

/*!
* \brief This class is needed to avoid a GCC 5 bug that prevents maps containing enums from being
compiled. If i386 GCC version is increased, we can remove it.
*/
struct EnumClassHash {
template <typename T>
std::size_t operator()(T t) const {
return static_cast<std::size_t>(t);
}
};

/*!
* \brief A helper to expand the params by adding the ones used in a given expression.
*/
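The EnumClassHash helper added above is the same struct removed from te_compiler.h earlier in this diff; it simply moved into the backend namespace in utils.h. A minimal usage sketch follows (assumed for illustration, not part of the patch): the struct supplies the hash functor so enum-class keys such as DLDeviceType can be used in std::unordered_map on the affected GCC 5 toolchains, exactly as the TargetMap alias in te_compiler.h now does.

#include <unordered_map>
#include <dlpack/dlpack.h>      // DLDeviceType, kDLCPU
#include <tvm/target/target.h>  // tvm::Target

// Assumed illustrative usage; mirrors the TargetMap alias in te_compiler.h.
std::unordered_map<DLDeviceType, tvm::Target, tvm::relay::backend::EnumClassHash> targets;
targets[kDLCPU] = tvm::Target("llvm");  // map keyed by an enum, hashed via EnumClassHash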
