Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -598,12 +598,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
tec::UpdateFunctionMetadata(func, this->function_metadata_);
})(mod);

Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
auto lowered_main = lowered_mod->Lookup("main");

auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

// Post-lowering storage map for writing main func - this should be the same map as previously
Expand Down Expand Up @@ -656,6 +651,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {
auto storage_rewrite = tir::transform::StorageRewrite();
mod_run = storage_rewrite(mod_run);

// The workspace for main function should be calculated after performing storage_rewrite for
// the top level TIR function.
auto workspace_byte_alignment =
target_host_->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
Integer main_workspace_size = CalculateWorkspaceBytes(
Downcast<tir::PrimFunc>(mod_run->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)),
workspace_byte_alignment);

Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
main_func_info.value()->workspace_sizes.Set(target_host_, main_workspace_size);
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());

// Legalize AOT if needed. This means that all the packed calls
// need to be wrapped in TVMValues (unless use_unpacked_api is set)
if (!use_unpacked_api_) {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ void UpdateFunctionMetadata(Function relay_func,
CHECK(prim_fn.defined()) << "the primitive function must be defined";

auto workspace_byte_alignment =
relay_target.value()->GetAttr<Integer>("workspace_byte_alignment").value_or(16);
relay_target.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16);

Integer workspace_size = CalculateWorkspaceBytes(prim_fn, workspace_byte_alignment);

Expand Down
30 changes: 30 additions & 0 deletions tests/python/relay/aot/test_crt_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,5 +559,35 @@ def test_name_sanitiser_name_clash():
)


@pytest.mark.parametrize(
"workspace_byte_alignment,main_workspace_size,sum_workspace_size",
[
(8, 10368, 15392),
(16, 10368, 15424),
(256, 10752, 17664),
],
)
def test_memory_planning(workspace_byte_alignment, main_workspace_size, sum_workspace_size):
mod, params = tvm.relay.testing.synthetic.get_workload()

target = f"c -runtime=c --link-params --executor=aot --workspace-byte-alignment={workspace_byte_alignment}"
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
lib = tvm.relay.build(mod, target, params=params)

assert (
sum(lib.function_metadata["__tvm_main__"].workspace_sizes.values()) == main_workspace_size
)
assert (
sum(
[
size
for metadata in lib.function_metadata.values()
for size in metadata.workspace_sizes.values()
]
)
== sum_workspace_size
)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))