Skip to content

Commit

Permalink
Convert AOT to TECompiler (apache#8697)
Browse files Browse the repository at this point in the history
* Convert AOT to TECompiler

This removes the dependency on "compile_engine.h" from aot_executor_codegen.cc. This required a few changes to how AOT was operating:
* AOT run_model is now based on the post lowering main_module
* AOTOnDemandAllocator is ran twice to ensure SIDs are updated post-lowering
* Moved to using tec::UpdateFunctionMetadata

Tests are passing, but would appreciate other validation 😸

* Clarify reasoning behind replanning memory later

* Use main_func_info rather than bespoke logic in AOT

This moves from using the bespoke AOT UpdateMainWorkspaceSize to the
LoweredModule main_func_info property to unify with Graph executor
codegen.
  • Loading branch information
Mousius authored and ylc committed Jan 13, 2022
1 parent 56e276b commit 3d333f7
Showing 1 changed file with 64 additions and 173 deletions.
237 changes: 64 additions & 173 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@
#include <string>
#include <vector>

#include "compile_engine.h"
#include "te_compiler.h"
#include "utils.h"

namespace tvm {
namespace relay {
namespace backend {

using IntegerArray = Array<Integer>;
using TargetsMap = std::unordered_map<int, Target>;
using StorageMap =
std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;

Expand Down Expand Up @@ -287,7 +286,6 @@ class AOTExecutorCodegen : public ExprVisitor {
void CreateFuncCall(Call call, std::string func_name) {
tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
std::vector<tir::Stmt> create_func_call_stmts;

// Pack the inputs
for (Expr arg : call->args) {
if (params_by_expr_.find(arg) != params_by_expr_.end()) {
Expand Down Expand Up @@ -365,155 +363,21 @@ class AOTExecutorCodegen : public ExprVisitor {
return ss.str();
}

/*!
* \brief Update the "main" control function's metadata
*
* \param func The main function that contains calls to operator tir primitive functions
*/
void UpdateMainWorkspaceSize(const tir::PrimFunc& primfunc, const relay::Function& func) {
auto workspace_byte_alignment = target_host_->GetAttr<Integer>("workspace-byte-alignment")
.value_or(tvm::runtime::kDefaultWorkspaceAlignment);
Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment);
// Populate FunctionInfo
auto fi_node = make_object<FunctionInfoNode>();
// Initialize all target workspaces to zero
for (const auto& kv : targets_) {
auto tgt = kv.second;
fi_node->workspace_sizes.Set(tgt, 0);
}
fi_node->workspace_sizes.Set(target_host_, workspace_size);
fi_node->relay_primfuncs.Set(target_host_, func);

int64_t io_size = 0;
for (const auto& input : input_vars_) {
io_size += CalculateRelayExprSizeBytes(input->checked_type());
}
io_size += CalculateRelayExprSizeBytes(func->body->checked_type());
fi_node->io_sizes.Set(target_host_, io_size);

int64_t const_size = 0;
for (const auto& kv : params_by_expr_) {
const_size += CalculateRelayExprSizeBytes(kv.first->checked_type());
}
fi_node->constant_sizes.Set(target_host_, const_size);
function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node));
}

/*!
* \brief Update the function metadata for a given cached function and its relay
* primitive function.
*
* \param cfunc The cached function as provided the by the compile engine
* \param relay_func The source relay primitive function
* \param relay_target The target associated with relay primitive function
*/
void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func,
const Target& relay_target) {
auto fi_node = make_object<FunctionInfoNode>();
for (const auto& kv : cfunc->funcs->functions) {
auto primfunc = Downcast<tir::PrimFunc>(kv.second);
auto workspace_byte_alignment =
target_host_->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment);
Target primfunc_target = relay_target;
if (primfunc->attrs->dict.count("target")) {
primfunc_target = Downcast<Target>(primfunc->attrs->dict["target"]);
}
fi_node->workspace_sizes.Set(primfunc_target, workspace_size);
// Calculating size for I/O
for (auto const& param : primfunc->params) {
auto p_shape = primfunc->buffer_map[param]->shape;
int num_of_elements = 1;
for (const auto& dim_index_expr : p_shape) {
if (dim_index_expr->IsInstance<IntImmNode>()) {
num_of_elements *= dim_index_expr.as<IntImmNode>()->value;
} else {
// If shape is dynamic, we cannot calculate workspace in compile time.
num_of_elements = 0;
}
}
int element_size = primfunc->buffer_map[param]->dtype.bytes();
fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements);
}
fi_node->constant_sizes.Set(primfunc_target, 0);
fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
}
function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node));
}

void VisitExpr_(const CallNode* op) override {
// Descend the call tree
for (auto arg : op->args) {
VisitExpr(arg);
}

Expr expr = GetRef<Expr>(op);
Function func;
if (op->op.as<OpNode>()) {
LOG(FATAL) << "Operators should be transformed away; try applying"
<< "the fuse_ops transformation to the expression.";
} else if (op->op.as<GlobalVarNode>()) {
LOG(FATAL) << "Not implemented";
} else if (op->op.as<FunctionNode>()) {
func = GetRef<Function>(op->op.as<FunctionNode>());
GlobalVar node = GetRef<GlobalVar>(op->op.as<GlobalVarNode>());
CreateFuncCall(GetRef<Call>(op), node->name_hint);
} else {
LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
}
if (!func->HasNonzeroAttr(attr::kPrimitive)) {
LOG(FATAL) << "TVM only support calls to primitive functions "
<< "(i.e functions composed of fusable operator invocations)";
}

Target target;

// Handle external function
if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = Target("ext_dev");
CCacheKey key = CCacheKey(func, target);
CachedFunc ext_func = compile_engine_->Lower(key, mod_name_);
ICHECK(ext_func.defined()) << "External function is not defined.";
UpdateConstants(func, &params_);

// Generate the TIR function call
CreateFuncCall(GetRef<Call>(op), ext_func->prim_fn_var->name_hint);
return;
}

ICHECK_GE(storage_device_map_.count(expr), 0);
StorageInfo& sinfo = storage_device_map_[expr];
auto call_dev_type = sinfo->device_types[0];
// Normal Relay Function
if (targets_.size() == 1) {
// homogeneous execution.
const auto& it = targets_.begin();
target = (*it).second;
} else {
// heterogeneous execution.
std::string call_dev_name;
if (call_dev_type == 0) {
call_dev_name = "llvm";
} else {
call_dev_name = runtime::DeviceName(call_dev_type);
}
if (targets_.count(call_dev_type) == 0) {
LOG(FATAL) << "No target is provided for device " << call_dev_name;
}
target = targets_[call_dev_type];
}

CCacheKey key = CCacheKey(func, target);
CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_);

if (!lowered_funcs_.count(target->str())) {
lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
}
lowered_funcs_[target->str()]->Update(lowered_func->funcs);
// Update function metadata via looking at all primfuncs
UpdateFunctionMetadata(lowered_func, func, target);

// Generate the TIR function call
CreateFuncCall(GetRef<Call>(op), lowered_func->prim_fn_var->name_hint);
}

void VisitExpr_(const VarNode* op) override {
Expand Down Expand Up @@ -598,7 +462,7 @@ class AOTExecutorCodegen : public ExprVisitor {
// Create the main PrimFunc to execute the graph. Please note that
// the packed function calls don't pack their arguments. The AOT
// runner function needs to be legalized by the LegalizePackedCalls pass.
tir::PrimFunc CreateMainFunc(unsigned int relay_params) {
tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) {
tir::Stmt body = tir::SeqStmt(stmts_);

// Allocate the sids
Expand Down Expand Up @@ -637,7 +501,7 @@ class AOTExecutorCodegen : public ExprVisitor {
// Define the PrimFunc attributes
Map<String, ObjectRef> dict_attrs;
String run_func_name =
runtime::get_name_mangled(mod_name_, runtime::symbol::tvm_run_func_suffix);
runtime::get_name_mangled(mod_name, runtime::symbol::tvm_run_func_suffix);
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));

Expand All @@ -654,7 +518,7 @@ class AOTExecutorCodegen : public ExprVisitor {
/*! \brief input and output variables belonging to the main function signature */
Array<tir::Var> main_signature_;
/*! \brief target device */
TargetsMap targets_;
tec::TargetMap targets_;
/*! \brief target host */
Target target_host_;
/*!
Expand Down Expand Up @@ -684,35 +548,70 @@ class AOTExecutorCodegen : public ExprVisitor {
/*! \brief mapping sid -> tir::Var */
std::unordered_map<int, te::Var> sids_table_;
/*! \brief lowered funcs */
std::unordered_map<std::string, IRModule> lowered_funcs_;
/*! \brief lowered funcs */
Map<String, FunctionInfo> function_metadata_;
/*! \brief compile engine */
CompileEngine compile_engine_;
/*! \brief the set of statements that make the program */
std::vector<tir::Stmt> stmts_;
/*! \brief the list of return sids (note that the function might return more then one output */
std::vector<int> return_sid_;
/*! \brief the module name we use to mangle the function names */
String mod_name_;

public:
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host)
: mod_(mod),
targets_(targets),
target_host_(target_host),
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))),
compile_engine_(CompileEngine::Global()) {}
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {}

LoweredOutput Codegen(relay::Function func, String mod_name) {
auto aot_allocator = AOTOnDemandAllocator();
aot_allocator.Run(func);

// Retrieve the storage map
storage_device_map_ = aot_allocator.GetStorageMap();
mod_name_ = mod_name;
// Pre-lowering storage map and memory plan
StorageMap initial_storage_map = aot_allocator.GetStorageMap();
StaticMemoryPlan memory_plan(initial_storage_map);

// Build a map from each operation to device.
tec::DeviceMap device_context_map;
for (const auto& it : memory_plan->expr_to_storage_info) {
auto expr = it.first;
auto storage_info = it.second;
auto device_types = storage_info->device_types;
// CHECK_EQ(device_types.size(), 1);
tvm::Device dev;
dev.device_id = 0;
dev.device_type = device_types[0];
device_context_map.insert({expr, dev});
}

// This first phase moves from implicit use of compile engine,
// to instead explicitly lowering the incoming IRModule, and then
// performing the preexisting AOT executor code generation phase.
IRModule mod = IRModule::FromExpr(func);
auto lowered_module = tec::LowerTE(
mod, targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
if (func->GetAttr<String>(attr::kCompiler).defined()) {
UpdateConstants(func, &params_);
}

// TODO(@areusch, @jroesch): We should refactor this to
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
});

for (auto input : func->params) {
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
auto lowered_main = lowered_module.main_module->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

// Post-lowering storage map for writing main func - this should be the same map as previously
// created, just referencing the new expressions created from lowering
auto new_allocator = AOTOnDemandAllocator();
new_allocator.Run(lowered_main_func);
storage_device_map_ = new_allocator.GetStorageMap();

for (auto input : lowered_main_func->params) {
input_vars_.push_back(input);
main_signature_.push_back(tir::Var("input", DataType::Handle()));
}
Expand All @@ -732,13 +631,12 @@ class AOTExecutorCodegen : public ExprVisitor {
main_signature_.push_back(tir::Var("output", DataType::Handle()));
}

VisitExpr(func->body);
VisitExpr(lowered_main_func->body);

// Create the runner function. Please note that the function is not legal yet
// because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
// to run the LegalizePackedCalls pass.
auto prim_func = CreateMainFunc(func->params.size());
UpdateMainWorkspaceSize(prim_func, func);
auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
LoweredOutput ret;

ret.params = std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>>();
Expand All @@ -748,17 +646,7 @@ class AOTExecutorCodegen : public ExprVisitor {
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
}

for (auto& kv : lowered_funcs_) {
if (ret.lowered_funcs.count(kv.first) == 0) {
ret.lowered_funcs.Set(kv.first, IRModule(Map<GlobalVar, BaseFunc>({})));
}
auto& mod = ret.lowered_funcs[kv.first];
mod->Update(kv.second);
ret.lowered_funcs.Set(kv.first, mod);
}
ret.external_mods = compile_engine_->LowerExternalFunctions();

// Build the TIR IRModule
// Build the TIR IRModule for the AOT function
Map<GlobalVar, BaseFunc> symbol_map;
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
IRModule mod_run(symbol_map);
Expand All @@ -774,14 +662,17 @@ class AOTExecutorCodegen : public ExprVisitor {
mod_run = pack_calls(mod_run);
}

// Update the lowered functions
ret.function_metadata = std::move(function_metadata_);

ret.lowered_funcs = lowered_module.per_target_module;
ret.external_mods = lowered_module.external_mods;

auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_str]->Update(mod_run);
} else {
ret.lowered_funcs.Set(target_host_str, mod_run);
}
ret.function_metadata = std::move(function_metadata_);

std::vector<String> input_var_names(input_vars_.size());
std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(),
Expand Down Expand Up @@ -845,15 +736,15 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {

private:
void init(void* mod, Map<Integer, tvm::Target> tmp) {
TargetsMap targets;
tec::TargetMap targets;
Target target_host;
for (const auto& it : tmp) {
auto dev_type = it.first.as<tir::IntImmNode>();
if (!target_host.defined() && it.second->kind->device_type == kDLCPU) {
target_host = it.second;
}
ICHECK(dev_type);
targets[dev_type->value] = it.second;
targets[static_cast<DLDeviceType>(dev_type->value)] = it.second;
}
codegen_ = std::make_shared<AOTExecutorCodegen>(reinterpret_cast<runtime::Module*>(mod),
targets, target_host);
Expand Down

0 comments on commit 3d333f7

Please sign in to comment.