NOTE(review): this patch was re-flowed from a whitespace-collapsed copy.
Blank lines were placed so that every hunk matches the line counts in its
"@@" header. Context-line indentation and the C++ template arguments ("<...>")
are reconstructed - confirm against the upstream files before applying, or
apply with `git apply --ignore-whitespace`.

Summary of the change:
  * aot_executor_codegen.cc: compute the workspace size of the AOT main
    function after StorageRewrite has run on the top-level TIR function, and
    record it in the "main_func_info" metadata under target_host_.
  * te_compiler.cc: read the target attribute "workspace-byte-alignment"
    (hyphenated) instead of "workspace_byte_alignment".
  * test_crt_aot.py: add test_memory_planning, which checks the main and the
    summed workspace sizes for alignments 8, 16 and 256.

diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index af8ac2b4e023..0c094cb1fa2c 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -598,12 +598,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
           tec::UpdateFunctionMetadata(func, this->function_metadata_);
         })(mod);
 
-    Optional<backend::FunctionInfo> main_func_info =
-        lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
-    ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
-    function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
     auto lowered_main = lowered_mod->Lookup("main");
-
     auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
 
     // Post-lowering storage map for writing main func - this should be the same map as previously
@@ -656,6 +651,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     auto storage_rewrite = tir::transform::StorageRewrite();
     mod_run = storage_rewrite(mod_run);
 
+    // The workspace for main function should be calculated after performing storage_rewrite for
+    // the top level TIR function.
+    auto workspace_byte_alignment =
+        target_host_->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
+    Integer main_workspace_size = CalculateWorkspaceBytes(
+        Downcast<tir::PrimFunc>(mod_run->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)),
+        workspace_byte_alignment);
+
+    Optional<backend::FunctionInfo> main_func_info =
+        lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
+    ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
+    main_func_info.value()->workspace_sizes.Set(target_host_, main_workspace_size);
+    function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
+
     // Legalize AOT if needed. This means that all the packed calls
     // need to be wrapped in TVMValues (unless use_unpacked_api is set)
     if (!use_unpacked_api_) {
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 9244e06c8c02..4d7f50b4f3a0 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -800,7 +800,7 @@ void UpdateFunctionMetadata(Function relay_func,
       CHECK(prim_fn.defined()) << "the primitive function must be defined";
 
       auto workspace_byte_alignment =
-          relay_target.value()->GetAttr<Integer>("workspace_byte_alignment").value_or(16);
+          relay_target.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
 
       Integer workspace_size = CalculateWorkspaceBytes(prim_fn, workspace_byte_alignment);
 
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
index 64000a9d56b3..68a9b0b436e7 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -559,5 +559,35 @@ def test_name_sanitiser_name_clash():
     )
 
 
+@pytest.mark.parametrize(
+    "workspace_byte_alignment,main_workspace_size,sum_workspace_size",
+    [
+        (8, 10368, 15392),
+        (16, 10368, 15424),
+        (256, 10752, 17664),
+    ],
+)
+def test_memory_planning(workspace_byte_alignment, main_workspace_size, sum_workspace_size):
+    mod, params = tvm.relay.testing.synthetic.get_workload()
+
+    target = f"c -runtime=c --link-params --executor=aot --workspace-byte-alignment={workspace_byte_alignment}"
+    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+        lib = tvm.relay.build(mod, target, params=params)
+
+    assert (
+        sum(lib.function_metadata["__tvm_main__"].workspace_sizes.values()) == main_workspace_size
+    )
+    assert (
+        sum(
+            [
+                size
+                for metadata in lib.function_metadata.values()
+                for size in metadata.workspace_sizes.values()
+            ]
+        )
+        == sum_workspace_size
+    )
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))