[Relay] Remove memory planning from LowerTEPass #8974

Merged · 18 commits · Sep 15, 2021
Changes from 14 commits
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/ci-problem.md
@@ -19,4 +19,4 @@ Provide a link to the specific run that has failed.

### Flakiness

Have you seen this multiple times in this branch or in other branches?
Have you seen this multiple times in this branch or in other branches?
1 change: 0 additions & 1 deletion .github/ISSUE_TEMPLATE/documentation.md
@@ -19,4 +19,3 @@ Include the title of the document (e.g. "Getting Started with TVM"), and the typ
If an RFC/discuss post exists, link it here.

Otherwise, specify what actions should be taken to provide additional clarity/readability/reproducibility to the document. Include code snippets from the previous documentation if applicable.

2 changes: 1 addition & 1 deletion 3rdparty/vta-hw
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
# Using this script we can reuse docker/install scripts to configure the reference
# Using this script we can reuse docker/install scripts to configure the reference
# virtual machine similar to CI QEMU setup.
#

2 changes: 1 addition & 1 deletion cmake/modules/contrib/EthosU.cmake
@@ -18,4 +18,4 @@
if(USE_ETHOSU)
file(GLOB ETHOSU_RELAY_CONTRIB_SRC src/relay/backend/contrib/ethosu/*)
list(APPEND COMPILER_SRCS ${ETHOSU_RELAY_CONTRIB_SRC})
endif(USE_ETHOSU)
endif(USE_ETHOSU)
4 changes: 2 additions & 2 deletions docker/install/ubuntu_install_vitis_ai_packages_ci.sh
@@ -6,9 +6,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
2 changes: 1 addition & 1 deletion docs/langref/relay_pattern.rst
@@ -89,7 +89,7 @@ Or a convolution with a specific kernel size:
x = relay.var('x')
y = relay.var('y')
assert is_conv2d.match(relay.op.nn.conv2d(x, y, kernel_size=[3, 3]))



Matching an Optional Op
16 changes: 13 additions & 3 deletions src/relay/backend/aot_executor_codegen.cc
@@ -38,8 +38,8 @@
#include <string>
#include <vector>

#include "te_compiler.h"
#include "utils.h"
#include "./te_compiler.h"
#include "./utils.h"

namespace tvm {
namespace relay {
@@ -583,8 +583,15 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// performing the preexisting AOT executor code generation phase.
IRModule mod = IRModule::FromExpr(func);

backend::FunctionInfo func_info;

if (memory_plan.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
}
Contributor:
Can you just put the func_info on the mod here before passing the module into LowerTE? Then you don't need to re-extract it later, and also the logic surrounding func_info is all in one place. (LowerTEPass should preserve all attributes on modules passed into it)
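A minimal sketch of the suggested ordering (a hypothetical arrangement inside aot_executor_codegen.cc, reusing only calls that already appear in this diff, and assuming, as the comment says, that LowerTEPass preserves module attributes):

    backend::FunctionInfo func_info;
    if (memory_plan.defined()) {
      // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
      func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
    }
    // Attach the info *before* lowering so it rides through LowerTEPass with the module.
    mod = WithAttr(mod, "main_func_info", func_info);

    IRModule lowered_mod =
        tec::LowerTEPass(targets_, device_context_map, mod_name,
                         [this](Function func) { /* per-function processing hook */ })(mod);

    // Under the stated assumption the attribute survives lowering, so it can be read back
    // directly and the later ICHECK on "main_func_info" becomes unnecessary.
    Optional<backend::FunctionInfo> main_func_info =
        lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");

The same placement would apply in graph_executor_codegen.cc with its memory_plan_ and mod_name_ members.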


IRModule lowered_mod =
LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
tec::LowerTEPass(targets_, device_context_map, mod_name, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -659,8 +666,11 @@ class AOTExecutorCodegen : public MixedModeVisitor {
Downcast<tir::PrimFunc>(mod_run->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)),
workspace_byte_alignment);

lowered_mod = WithAttr(lowered_mod, "main_func_info", func_info);

Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");

Contributor:
Like I said above, can you attach the func_info right before LowerTEPass is called? And we can then remove the check about whether the attribute is on the module.

ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
main_func_info.value()->workspace_sizes.Set(target_host_, main_workspace_size);
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
17 changes: 14 additions & 3 deletions src/relay/backend/graph_executor_codegen.cc
@@ -36,8 +36,8 @@
#include <string>
#include <vector>

#include "te_compiler.h"
#include "utils.h"
#include "./te_compiler.h"
#include "./utils.h"

namespace tvm {
namespace relay {
@@ -221,8 +221,16 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
device_context_map.insert({expr, dev});
}

backend::FunctionInfo func_info;

if (memory_plan_.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info =
relay::tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
}

Contributor:
Same comment as in aot_executor_codegen -- can we put the func_info on the module here instead of after LowerTEPass is called and delete the check for main_func_info being set?

IRModule lowered_mod =
LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
tec::LowerTEPass(targets_, device_context_map, mod_name_, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -236,8 +244,11 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
tec::UpdateFunctionMetadata(func, this->function_metadata_);
})(mod);

lowered_mod = WithAttr(lowered_mod, "main_func_info", func_info);

Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");

ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());

25 changes: 12 additions & 13 deletions src/relay/backend/interpreter.cc
@@ -542,7 +542,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
*
* @param prim_fn_var Global bound to lowered primitive.
* @param all_prim_fn_vars All globals references by lowered primitive, plus prim_fn_var itself.
* @param prim_shape_fn_var Global bound to lowered shape function for primitive, if neeeded.
* @param prim_shape_fn_var Global bound to lowered shape function for primitive, if needed.
* @param all_prim_shape_fn_vars All globals references by lowered shape function, plus
* prim_shape_fn_var itself.
* @param prim_shape_fn_states Records whether shape and/or data is needed by the dynamic
@@ -763,7 +763,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
ObjectRef val = Eval(op->tuple);
const auto* adt_obj = val.as<ADTObj>();
ICHECK(adt_obj) << "interal error: when evaluating TupleGetItem expected an ADT value";
ICHECK(adt_obj) << "internal error: when evaluating TupleGetItem expected an ADT value";
auto adt = GetRef<ADT>(adt_obj);
ICHECK_LT(static_cast<size_t>(op->index), adt.size()) << "internal error: index out of bounds";
return adt[op->index];
@@ -906,17 +906,16 @@ IRModule Prepare(IRModule mod, Device device, Target target) {
backend::StaticMemoryPlan memory_plan; /*=nullptr*/

// Run minimal transforms on module to establish invariants needed by interpreter.
transform::Sequential seq(
{transform::SimplifyInference(),
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
// attribute.
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
// eta expand to support constructors in argument position
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType(),
tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"intrp",
[](Function func) { /* no-op */ })});
transform::Sequential seq({transform::SimplifyInference(),
Contributor:
Seems we have inconsistent formatters? In any case I'd revert this whitespace change.

// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
// attribute.
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
// eta expand to support constructors in argument position
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType(),
tec::LowerTEPass(targets, device_map, /*module_name=*/"intrp",
[](Function func) { /* no-op */ })});

transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);
133 changes: 58 additions & 75 deletions src/relay/backend/te_compiler.cc
@@ -17,7 +17,7 @@
* under the License.
*/

#include "te_compiler.h"
#include "./te_compiler.h"

#include <tvm/driver/driver_api.h>
#include <tvm/ir/attrs.h>
@@ -42,8 +42,8 @@
#include <utility>
#include <vector>

#include "te_compiler_cache.h"
#include "utils.h"
#include "./te_compiler_cache.h"
#include "./utils.h"

namespace tvm {
namespace relay {
@@ -596,68 +596,19 @@ class LowerTensorExprMutator : public ExprMutator {
const Op& debug_op_;
};

Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map,
Contributor:
How about we don't move these to keep the diff down.

backend::StaticMemoryPlan memory_plan, const String& module_name,
TECompiler compiler, std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function func, IRModule module, PassContext ctx) {
LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn,
module_name, compiler);
return Downcast<Function>(lower_te.Mutate(func));
};
return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
}

Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets) {
if (targets.size() == 1) {
// The homogeneous execution case, return the only target.
const auto& it = targets.begin();
return (*it).second;
} else {
// The heterogeneous execution case, return the target associated with the
// given device type.
// If "dev_type" equals to 0, the device name only can be got from
// "targets", and it may not be "llvm", so here just set it to "unknown".
std::string dev_name = "unknown";
if (dev_type != 0) {
dev_name = runtime::DeviceName(dev_type);
}

if (targets.count(dev_type) == 0) {
std::stringstream msg;
msg << "No target is specified for provided device name: `" << dev_name << "`\n\n"
<< dev_name << " mapped to device type (" << dev_type
<< ") which was not found in the target map.\n"
<< "Availible targets: \n";
for (auto target : targets) {
msg << " " << target.first << "-> " << target.second << "\n";
}
LOG(FATAL) << msg.str();
}
return targets[dev_type];
}
}

/*!
* \brief Update the "main" control function's metadata
*
* \param mod The module
* \param targets Map of targets
* \return function_infos Function info for each function in the module
*/

backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap targets,
backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMap targets,
Map<Expr, backend::StorageInfo> storage_info_map) {
CHECK_EQ(mod->functions.size(), 1)
<< "There should only be one function in the module passed to UpdateMainWorkspaceSize";
Function func = Downcast<Function>(mod->Lookup("main"));

// This is a Map<device,Map<storage_id, size>>
std::unordered_map<DLDeviceType, std::unordered_map<int, int>, EnumClassHash> sid_workspace;
std::unordered_map<DLDeviceType, std::unordered_map<int, int>, backend::EnumClassHash>
sid_workspace;
// This is a Map<device, size_of_inputs_and_outputs>
std::unordered_map<DLDeviceType, int, EnumClassHash> device_io;
std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_io;
// This is a Map<device, size_of_constants>
std::unordered_map<DLDeviceType, int, EnumClassHash> device_consts;
std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_consts;

// Initialize the mapping from all storage identifiers to workspace sizes,
// the amount of device io, and the device constants.
@@ -723,7 +674,7 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
}

// This is a Map<device, workspace_size>
std::unordered_map<DLDeviceType, int, EnumClassHash> device_workspace;
std::unordered_map<DLDeviceType, int, backend::EnumClassHash> device_workspace;
// Once we know the sizes of sids, we need to accumulate per device
for (const auto& dev_sid_size : sid_workspace) {
auto dev = dev_sid_size.first;
@@ -746,24 +697,65 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
}

for (const auto& dev_and_size : device_workspace) {
auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
workspace_sizes.Set(tgt, dev_and_size.second);
relay_primfuncs.Set(tgt, func);
}
for (const auto& dev_and_size : device_io) {
auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
io_sizes.Set(tgt, dev_and_size.second);
}

for (const auto& dev_and_size : device_consts) {
auto tgt = GetTargetFromInteger(dev_and_size.first, targets);
auto tgt = tec::GetTargetFromInteger(dev_and_size.first, targets);
constant_sizes.Set(tgt, dev_and_size.second);
}

return backend::FunctionInfo(workspace_sizes, io_sizes, constant_sizes, tir_primfuncs,
relay_primfuncs);
}

Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) {
if (targets.size() == 1) {
// The homogeneous execution case, return the only target.
const auto& it = targets.begin();
return (*it).second;
} else {
// The heterogeneous execution case, return the target associated with the
// given device type.
// If "dev_type" equals to 0, the device name only can be got from
// "targets", and it may not be "llvm", so here just set it to "unknown".
std::string dev_name = "unknown";
if (dev_type != 0) {
dev_name = runtime::DeviceName(dev_type);
}

if (targets.count(dev_type) == 0) {
std::stringstream msg;
msg << "No target is specified for provided device name: `" << dev_name << "`\n\n"
<< dev_name << " mapped to device type (" << dev_type
<< ") which was not found in the target map.\n"
<< "Availible targets: \n";
for (auto target : targets) {
msg << " " << target.first << "-> " << target.second << "\n";
}
LOG(FATAL) << msg.str();
}
return targets[dev_type];
}
}

Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, const String& module_name,
TECompiler compiler, std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function func, IRModule module, PassContext ctx) {
LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn,
module_name, compiler);
return Downcast<Function>(lower_te.Mutate(func));
};
return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
}

/*!
* \brief A function to create the function metadata for an input function (ie calculate buffer
* input/output sizes)
@@ -844,20 +836,13 @@ void UpdateFunctionMetadata(Function relay_func,
}

IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
const String& module_name, std::function<void(Function)> process_fn) {
DLOG(INFO) << "lowering module:\n" << PrettyPrint(module);

TECompiler compiler;

backend::FunctionInfo func_info;
if (memory_plan.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info = UpdateMainWorkspaceSize(module, targets, memory_plan->expr_to_storage_info);
}

auto updated_module = LowerTensorExpr(targets, device_context_map, memory_plan, module_name,
compiler, process_fn)(module);
auto updated_module =
LowerTensorExpr(targets, device_context_map, module_name, compiler, process_fn)(module);

// A temporary solution until we can rewrite the auto-scheduler task extraction code to work
// in a more reasonable way.
@@ -882,7 +867,6 @@

// Annotate the module with the external modules and function info
updated_module = WithAttr(updated_module, "external_mods", compiler->LowerExternalFunctions());
updated_module = WithAttr(updated_module, "main_func_info", func_info);

return updated_module;
}
@@ -919,12 +903,11 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
return per_target_modules;
}

Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, const String& module_name,
std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
PassContext ctx) {
return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
return LowerTE(module, targets, device_context_map, module_name, process_fn);
};
return tvm::transform::Sequential(
{tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()});