
Commit 6e30c3c

[AOT] BugFix of workspace calculation
Following an investigation from apache#10022, it turns out the workspace calculation currently assumes that a single lowered PrimFunc is produced per primitive Relay function. The exception is the CMSIS-NN codegen, which produces multiple calls/PrimFuncs in place of a single call to a single Relay PrimFunc. This commit changes the workspace calculation to be performed on the lowered IRModule. Additionally, it changes the test utils so that no stack allocator code is generated when USMP is used, making the tests stricter.

Change-Id: I5202d9cc7c6a8c00c73791b82df062a8e13dd224
1 parent 2f93780 commit 6e30c3c
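The effect of the change can be observed through the Model Library Format memory map, which is what the new test added in this commit asserts on. Below is a minimal sketch (not part of the commit) assuming a TVM build that contains this fix; it uses `mlf._build_function_memory_map`, the internal helper the added test also calls:

import tvm
from tvm.relay.testing import synthetic
from tvm.relay.backend import Executor, Runtime
from tvm.micro import model_library_format as mlf

# Build a small synthetic workload with the AOT executor, as the new
# test_workspace_calculation test does.
mod, params = synthetic.get_workload()
executor = Executor("aot", {"workspace-byte-alignment": 16})
runtime = Runtime("crt")

with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
    lib = tvm.relay.build(mod, "c", executor=executor, runtime=runtime, params=params)

# After this fix, the workspace size reported for "main" aggregates every lowered
# PrimFunc, not just one PrimFunc per primitive Relay function.
memory_map = mlf._build_function_memory_map(lib.function_metadata)
print(memory_map["main"][0]["workspace_size_bytes"])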

8 files changed (+197, -66 lines)

apps/microtvm/zephyr_cmsisnn/src/main.c

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ extern float output_storage[12];
 
 extern const size_t output_len;
 
-static uint8_t g_crt_workspace[TVMGEN_DEFAULT_WORKSPACE_SIZE + 512];
+static uint8_t g_crt_workspace[TVMGEN_DEFAULT_WORKSPACE_SIZE];
 tvm_workspace_t app_workspace;
 
 void TVMLogf(const char* msg, ...) {

src/relay/backend/aot_executor_codegen.cc

Lines changed: 31 additions & 10 deletions

@@ -686,6 +686,35 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
   }
 
+  /*!
+   * \brief Calculate workspace sizes for PrimFuncs in the IRModule
+   */
+  Map<String, FunctionInfo> CalculateWorkspaceSizes(
+      const IRModule& lowered_mod, const Map<String, FunctionInfo>& function_metadata) {
+    Executor executor_config = lowered_mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
+    Integer workspace_byte_alignment =
+        executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
+    Map<String, FunctionInfo> updated_function_metadata;
+    for (const auto& kv : lowered_mod->functions) {
+      GlobalVar global_var = kv.first;
+      BaseFunc base_func = kv.second;
+      if (base_func->IsInstance<tir::PrimFuncNode>()) {
+        tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(base_func);
+        Target tgt = pfunc->GetAttr<Target>(tvm::attr::kTarget).value();
+        const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment);
+        if (function_metadata.count(global_var->name_hint)) {
+          updated_function_metadata.Set(global_var->name_hint,
+                                        function_metadata[global_var->name_hint]);
+          updated_function_metadata[global_var->name_hint]->workspace_sizes.Set(tgt, ws);
+        } else {
+          FunctionInfo finfo{{{tgt, ws}}, {}, {}, {{tgt, pfunc}}, {}};
+          updated_function_metadata.Set(global_var->name_hint, finfo);
+        }
+      }
+    }
+    return updated_function_metadata;
+  }
+
   /*!
    * \brief Run USMP to plan memory for lowered IRModule
    */
@@ -694,17 +723,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     Integer workspace_byte_alignment =
         executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
     IRModule lowered_mod = mod->ShallowCopy();
+    function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_);
     lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod);
-    // Update workspace size based on the pool allocations.
-    for (const auto& kv : function_metadata_) {
-      if (lowered_mod->ContainGlobalVar(kv.first) &&
-          lowered_mod->Lookup(kv.first)->IsInstance<tir::PrimFuncNode>()) {
-        tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(lowered_mod->Lookup(kv.first));
-        Target tgt = pfunc->GetAttr<Target>(tvm::attr::kTarget).value();
-        const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment);
-        kv.second->workspace_sizes.Set(tgt, ws);
-      }
-    }
     Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
         lowered_mod->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
     backend::FunctionInfo main_func_info =
@@ -736,6 +756,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     Integer workspace_byte_alignment =
         executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
     IRModule lowered_mod = mod->ShallowCopy();
+    function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_);
     // Running StorageRewrite just on the main function
     tir::PrimFunc tir_main_func =
         Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
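For readers more familiar with the Python API, the idea behind CalculateWorkspaceSizes can be sketched as follows: walk every lowered PrimFunc in the IRModule (rather than assuming one PrimFunc per primitive Relay function) and record a workspace size for each. This is an illustrative sketch only, not part of the commit; it assumes a lowered IRModule is at hand and uses tvm.tir.analysis.calculate_workspace_bytes, the Python binding of the CalculateWorkspaceBytes helper used above:

from tvm import tir
from tvm.tir.analysis import calculate_workspace_bytes


def workspace_sizes_per_primfunc(lowered_mod, workspace_byte_alignment=16):
    """Collect a workspace size for every lowered PrimFunc in the module."""
    sizes = {}
    for global_var, base_func in lowered_mod.functions.items():
        # Only TIR PrimFuncs carry a workspace requirement; Relay functions are skipped.
        if isinstance(base_func, tir.PrimFunc):
            sizes[global_var.name_hint] = calculate_workspace_bytes(
                base_func, workspace_byte_alignment
            )
    return sizes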

tests/python/contrib/test_ethosu/infra.py

Lines changed: 2 additions & 1 deletion

@@ -242,12 +242,13 @@ def build_source(
 def verify_source(
     models: List[AOTCompiledTestModel],
     accel="ethos-u55-256",
+    enable_usmp=True,
 ):
     """
     This method verifies the generated source from an NPU module by building it and running on an FVP.
     """
     interface_api = "c"
-    test_runner = create_test_runner(accel)
+    test_runner = create_test_runner(accel, enable_usmp)
     run_and_check(
         models,
         test_runner,

tests/python/contrib/test_ethosu/test_networks.py

Lines changed: 1 addition & 1 deletion

@@ -71,7 +71,7 @@ def test_forward_mobilenet_v1(accel_type, enable_usmp):
     compiled_models = infra.build_source(
         mod, input_data, output_data, accel_type, output_tolerance=10, enable_usmp=enable_usmp
     )
-    infra.verify_source(compiled_models, accel_type)
+    infra.verify_source(compiled_models, accel_type, enable_usmp=enable_usmp)
 
 
 if __name__ == "__main__":

tests/python/relay/aot/aot_test_utils.py

Lines changed: 75 additions & 20 deletions

@@ -265,29 +265,54 @@ def emit_data_linkage(output_file, data_linkage):
 
 
 def emit_main_prologue(
-    main_file, custom_prologue, workspace_bytes, data_linkage, compiled_models, interface_api
+    main_file,
+    custom_prologue,
+    workspace_bytes,
+    data_linkage,
+    compiled_models,
+    interface_api,
+    use_stack_allocator=True,
 ):
-    # Add TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES because of memory alignment.
-    workspace_define = f"#define WORKSPACE_SIZE ({workspace_bytes}"
-    if interface_api == "c":
-        for compiled_model in compiled_models:
-            model = compiled_model.model
-            workspace_define += f" + TVMGEN_{model.name.upper()}_WORKSPACE_SIZE"
-    workspace_define += " + TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)\n"
-    main_file.write(workspace_define)
-    emit_data_linkage(main_file, data_linkage)
-    main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n")
-    main_file.write("tvm_workspace_t app_workspace;\n")
-    main_file.write(
-        """
+    if use_stack_allocator:
+        workspace_define = f"#define WORKSPACE_SIZE ({workspace_bytes}"
+        if interface_api == "c":
+            for compiled_model in compiled_models:
+                model = compiled_model.model
+                workspace_define += f" + TVMGEN_{model.name.upper()}_WORKSPACE_SIZE"
+        # Add TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES because of memory alignment.
+        workspace_define += " + TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)\n"
+        main_file.write(workspace_define)
+        emit_data_linkage(main_file, data_linkage)
+        main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n")
+        main_file.write("tvm_workspace_t app_workspace;\n")
+        main_file.write(
+            """
 tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) {
     return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr);
 }
 
 tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
     return StackMemoryManager_Free(&app_workspace,ptr);
 }
+"""
+        )
+    else:
+        # An implementation is not needed for these if the stack allocator is not used
+        main_file.write(
+            """
+tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) {
+    return kTvmErrorNoError;
+}
 
+tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
+    return kTvmErrorNoError;
+}
+
+"""
+        )
+    main_file.write(
+        """
+
 void TVMPlatformAbort(tvm_crt_error_t code) { exit(-1); }
 
 void TVMLogf(const char* msg, ...) {
@@ -296,10 +321,10 @@ def emit_main_prologue(
     vfprintf(stdout, msg, args);
     va_end(args);
 }
-
+
 TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {}
 int main(){\n
-"""
+"""
     )
     main_file.write(custom_prologue)
 
@@ -511,6 +536,7 @@ def create_main(
     data_linkage,
    interface_api,
    workspace_bytes,
+    use_stack_allocator=True,
 ):
     file_path = pathlib.Path(f"{output_path}/" + test_name).resolve()
     # create header file
@@ -533,8 +559,10 @@ def create_main(
             data_linkage,
             compiled_models,
             interface_api,
+            use_stack_allocator,
         )
-        emit_main_init_memory_manager(main_file)
+        if use_stack_allocator:
+            emit_main_init_memory_manager(main_file)
 
         if interface_api == "c":
             for compiled_model in compiled_models:
@@ -709,11 +737,14 @@ def run_and_check(
         t = tarfile.open(tar_file)
         t.extractall(base_path)
 
-        workspace_bytes = model.extra_memory_in_bytes
-        use_usmp = runner.pass_config.get("tir.usmp.enable", False)
-        if interface_api == "packed" and not use_usmp:
+        # Interface C APIs does not need compiler generated
+        # workspace to generate the test application, because
+        # workspace size is codegen'd as a macro to
+        # tvmgen_<model_name>.h.
+        if interface_api != "c":
             workspace_bytes += mlf_extract_workspace_size_bytes(tar_file)
 
+        workspace_bytes += model.extra_memory_in_bytes
        for key in model.inputs:
            sanitized_tensor_name = re.sub(r"\W", "_", key)
            create_header_file(
@@ -738,6 +769,10 @@ def run_and_check(
             data_linkage,
         )
 
+    use_usmp = runner.pass_config.get("tir.usmp.enable", False)
+    # We only need the stack allocator if USMP is not used
+    use_stack_allocator = not use_usmp
+
     create_main(
         "test.c",
         models,
@@ -748,6 +783,7 @@ def run_and_check(
         data_linkage,
         interface_api,
         workspace_bytes,
+        use_stack_allocator,
     )
 
     # Verify that compiles fine
@@ -868,3 +904,22 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"):
     output_tensor_names = main.attrs["output_tensor_names"]
 
     return dict(zip(output_tensor_names, out))
+
+
+def create_relay_module_and_inputs_from_tflite_file(tflite_model_file):
+    """A helper function to create a Relay IRModule with inputs
+    and params from a tflite file"""
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+    mod, params = convert_to_relay(tflite_model_buf)
+
+    inputs = dict()
+    for param in mod["main"].params:
+        name = str(param.name_hint)
+        data_shape = [int(i) for i in param.type_annotation.shape]
+        dtype = str(param.type_annotation.dtype)
+        in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max)
+        data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype)
+        inputs[name] = data
+
+    return mod, inputs, params
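A usage sketch of the new helper, mirroring how the reworked test_quant_mobilenet_tfl test consumes it in test_crt_aot.py below; `tflite_model_file` is a placeholder path to a quantized .tflite model:

from aot_test_utils import (
    AOT_DEFAULT_RUNNER,
    AOTTestModel,
    compile_and_run,
    create_relay_module_and_inputs_from_tflite_file,
    generate_ref_data,
)

# `tflite_model_file` is assumed to point at a quantized .tflite model on disk.
mod, inputs, params = create_relay_module_and_inputs_from_tflite_file(tflite_model_file)

# The helper returns random-but-valid inputs, so reference outputs and the AOT
# test harness can be driven directly from its return values.
output_list = generate_ref_data(mod, inputs, params)
compile_and_run(
    AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params),
    AOT_DEFAULT_RUNNER,
    interface_api="c",
    use_unpacked_api=True,
)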

tests/python/relay/aot/test_crt_aot.py

Lines changed: 78 additions & 7 deletions

@@ -28,14 +28,18 @@
 from tvm.relay.testing import byoc
 from tvm.relay.op.annotation import compiler_begin, compiler_end
 from tvm.relay.backend import Executor, Runtime
+from tvm.micro import model_library_format as mlf
 from aot_test_utils import (
     AOTTestModel,
     AOT_DEFAULT_RUNNER,
+    AOT_CORSTONE300_RUNNER,
+    AOTDataLinkage,
     generate_ref_data,
     convert_to_relay,
     compile_and_run,
     compile_models,
     parametrize_aot_options,
+    create_relay_module_and_inputs_from_tflite_file,
 )
 
 
@@ -87,11 +91,16 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5),
     inputs = {"data": input_data}
     output_list = generate_ref_data(mod, inputs, params)
 
+    data_linkage = None
+    if test_runner == AOT_CORSTONE300_RUNNER:
+        data_linkage = AOTDataLinkage(section=".data.tvm", alignment=8)
+
     compile_and_run(
         AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params),
         test_runner,
         interface_api,
         use_unpacked_api,
+        data_linkage=data_linkage,
     )
 
 
@@ -501,6 +510,10 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5),
     inputs2 = {"data": input_data}
     output_list2 = generate_ref_data(mod2, inputs2, params2)
 
+    data_linkage = None
+    if test_runner == AOT_CORSTONE300_RUNNER:
+        data_linkage = AOTDataLinkage(section=".data.tvm", alignment=8)
+
     compile_and_run(
         [
             AOTTestModel(
@@ -521,6 +534,7 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5),
         test_runner,
         interface_api,
         use_unpacked_api,
+        data_linkage=data_linkage,
     )
 
 
@@ -541,13 +555,7 @@ def test_quant_mobilenet_tfl():
         "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
         "mobilenet_v1_1.0_224_quant.tflite",
     )
-    with open(tflite_model_file, "rb") as f:
-        tflite_model_buf = f.read()
-    data_shape = (1, 224, 224, 3)
-    in_min, in_max = (0, 255)
-    data = np.random.randint(in_min, high=in_max, size=data_shape, dtype="uint8")
-    mod, params = convert_to_relay(tflite_model_buf)
-    inputs = {"input": data}
+    mod, inputs, params = create_relay_module_and_inputs_from_tflite_file(tflite_model_file)
     output_list = generate_ref_data(mod, inputs, params)
     compile_and_run(
         AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params),
@@ -843,5 +851,68 @@ def representative_dataset():
     assert output_name in source
 
 
+@pytest.mark.parametrize(
+    "workspace_byte_alignment,main_workspace_size",
+    [
+        (8, 55296),
+        (16, 55296),
+        (256, 57344),
+    ],
+)
+def test_workspace_calculation(workspace_byte_alignment, main_workspace_size):
+    mod, params = tvm.relay.testing.synthetic.get_workload()
+    target = "c"
+    runtime = Runtime("crt")
+    executor = Executor(
+        "aot",
+        {
+            "workspace-byte-alignment": workspace_byte_alignment,
+        },
+    )
+    with tvm.transform.PassContext(
+        opt_level=3,
+        config={
+            "tir.disable_vectorize": True,
+        },
+    ):
+        lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params)
+
+    mlf_memory_map = mlf._build_function_memory_map(lib.function_metadata)
+    assert mlf_memory_map["main"][0]["workspace_size_bytes"] == main_workspace_size
+
+
+@tvm.testing.requires_package("tflite")
+@tvm.testing.requires_cmsisnn
+def test_workspace_calculation_cmsis_nn():
+    """This tests cmsis_nn codegen for workspace calculation.
+    This is tested specially because cmsis-nn codegen creates
+    multiple PrimFuncs per offloaded relay function in a non
+    -hierarchical manner."""
+    pytest.importorskip("tflite")
+
+    import tvm.relay.testing.tf as tf_testing
+    from tvm.relay.op.contrib import cmsisnn
+
+    target = "c"
+    runtime = Runtime("crt")
+    executor = Executor("aot")
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/download.tensorflow.org/"
+        "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
+        "mobilenet_v1_1.0_224_quant.tflite",
+    )
+    mod, _, params = create_relay_module_and_inputs_from_tflite_file(tflite_model_file)
+    mod = cmsisnn.partition_for_cmsisnn(mod, params)
+    with tvm.transform.PassContext(
+        opt_level=3,
+        config={
+            "tir.disable_vectorize": True,
+        },
+    ):
+        lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params)
+    mlf_memory_map = mlf._build_function_memory_map(lib.function_metadata)
+    assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 12907328
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
