Skip to content

Commit aea2a68

Browse files
Mousius and manupak committed
Fix incorrect AOT Memory Planning
This change introduces a second memory planning phase in the AOT code generator once the storage rewrite pass has been completed, fixing incorrectly sized workspaces for a variety of models. It comes with accompanying tests so we can safely refactor this later. Also corrected a typo in the TE compiler regarding the memory alignment argument 😸 Co-authored-by: Manupa Karunaratne <[email protected]>
1 parent 06a0d63 commit aea2a68

File tree

4 files changed

+46
-7
lines changed

4 files changed

+46
-7
lines changed

src/relay/backend/aot_executor_codegen.cc

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -598,12 +598,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
598598
tec::UpdateFunctionMetadata(func, this->function_metadata_);
599599
})(mod);
600600

601-
Optional<backend::FunctionInfo> main_func_info =
602-
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
603-
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
604-
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
605601
auto lowered_main = lowered_mod->Lookup("main");
606-
607602
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
608603

609604
// Post-lowering storage map for writing main func - this should be the same map as previously
@@ -656,6 +651,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {
656651
auto storage_rewrite = tir::transform::StorageRewrite();
657652
mod_run = storage_rewrite(mod_run);
658653

654+
// The workspace for main function should be calculated after performing storage_rewrite for
655+
// the top level TIR function.
656+
auto workspace_byte_alignment =
657+
target_host_->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
658+
Integer main_workspace_size = CalculateWorkspaceBytes(
659+
Downcast<tir::PrimFunc>(mod_run->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)),
660+
workspace_byte_alignment);
661+
662+
Optional<backend::FunctionInfo> main_func_info =
663+
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
664+
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
665+
main_func_info.value()->workspace_sizes.Set(target_host_, main_workspace_size);
666+
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
667+
659668
// Legalize AOT if needed. This means that all the packed calls
660669
// need to be wrapped in TVMValues (unless use_unpacked_api is set)
661670
if (!use_unpacked_api_) {

src/relay/backend/te_compiler.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ void UpdateFunctionMetadata(Function relay_func,
800800
CHECK(prim_fn.defined()) << "the primitive function must be defined";
801801

802802
auto workspace_byte_alignment =
803-
relay_target.value()->GetAttr<Integer>("workspace_byte_alignment").value_or(16);
803+
relay_target.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
804804

805805
Integer workspace_size = CalculateWorkspaceBytes(prim_fn, workspace_byte_alignment);
806806

tests/micro/arduino/test_arduino_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_model_header_templating(project_dir, project):
120120
# Ensure model.h was templated with correct WORKSPACE_SIZE
121121
with (project_dir / "src" / "model.h").open() as f:
122122
model_h = f.read()
123-
assert "#define WORKSPACE_SIZE 21312" in model_h
123+
assert "#define WORKSPACE_SIZE 17392" in model_h
124124

125125

126126
def test_import_rerouting(project_dir, project):

tests/python/relay/aot/test_crt_aot.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,5 +559,35 @@ def test_name_sanitiser_name_clash():
559559
)
560560

561561

562+
@pytest.mark.parametrize(
563+
"workspace_byte_alignment,main_workspace_size,sum_workspace_size",
564+
[
565+
(8, 10368, 15328),
566+
(16, 10368, 15344),
567+
(256, 10752, 16896),
568+
],
569+
)
570+
def test_memory_planning(workspace_byte_alignment, main_workspace_size, sum_workspace_size):
571+
mod, params = tvm.relay.testing.synthetic.get_workload()
572+
573+
target = f"c -runtime=c --link-params --executor=aot --workspace-byte-alignment={workspace_byte_alignment}"
574+
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
575+
lib = tvm.relay.build(mod, target, params=params)
576+
577+
assert (
578+
sum(lib.function_metadata["__tvm_main__"].workspace_sizes.values()) == main_workspace_size
579+
)
580+
assert (
581+
sum(
582+
[
583+
size
584+
for metadata in lib.function_metadata.values()
585+
for size in metadata.workspace_sizes.values()
586+
]
587+
)
588+
== sum_workspace_size
589+
)
590+
591+
562592
if __name__ == "__main__":
563593
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)