2323import subprocess
2424import tarfile
2525import logging
26- from typing import Any , NamedTuple , Union , Tuple , Optional , List , Dict
26+ from typing import Any , NamedTuple , Union , Tuple , Optional , List , Dict , Callable
2727import numpy as np
2828
2929import tvm
@@ -200,6 +200,7 @@ def _emit_main_prologue(
200200 compiled_models ,
201201 interface_api ,
202202 use_stack_allocator = True ,
203+ debug_last_error = False ,
203204):
204205 if use_stack_allocator :
205206 workspace_define = f"#define WORKSPACE_SIZE ({ workspace_bytes } "
@@ -243,11 +244,28 @@ def _emit_main_prologue(
243244 va_start(args, msg);
244245 vfprintf(stdout, msg, args);
245246 va_end(args);
246- }\n
247- TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {}
248- int main(){\n
247+ }
249248 """
250249 )
250+ if debug_last_error :
251+ main_file .write (
252+ """\n
253+ tvm_crt_error_t TVMPlatformTimerStart() {
254+ return kTvmErrorFunctionCallNotImplemented;
255+ }
256+ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
257+ return kTvmErrorFunctionCallNotImplemented;
258+ }
259+ const TVMModule* TVMSystemLibEntryPoint(void) { return NULL; }
260+ """
261+ )
262+ else :
263+ main_file .write (
264+ """\n
265+ TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {}
266+ """
267+ )
268+ main_file .write ("\n int main(){\n " )
251269 main_file .write (custom_prologue )
252270
253271
@@ -332,10 +350,10 @@ def _emit_main_data_setup(main_file, input_map, output_map, mod_name):
332350
333351
334352def _emit_main_c_interface_call (
335- main_file , devices , workspace_pool_names , mod_name , use_workspace_io
353+ main_file , devices , workspace_pool_names , mod_name , use_workspace_io , debug_last_error
336354):
337355 sub_strings = list ()
338- sub_strings .append (f'{ _mangle_name (mod_name ,"run" )} (' )
356+ sub_strings .append (f'if ( { _mangle_name (mod_name ,"run" )} (' )
339357 if not use_workspace_io :
340358 sub_strings .append (f'&{ _mangle_name (mod_name ,"inputs" )} , ' )
341359 sub_strings .append (f'&{ _mangle_name (mod_name ,"outputs" )} , ' )
@@ -346,10 +364,14 @@ def _emit_main_c_interface_call(
346364 # Removing the last two characters that is a comma and a space
347365 sub_strings [- 1 ] = sub_strings [- 1 ][:- 2 ]
348366 # Adding brackets and newline instead
349- sub_strings [- 1 ] = sub_strings [- 1 ] + ");\n "
350-
367+ sub_strings [- 1 ] = sub_strings [- 1 ] + ") == -1) {\n "
351368 main_file_string = "" .join (sub_strings )
352369 main_file .write (main_file_string )
370+ if debug_last_error :
371+ main_file .write (f'\t printf("ERROR: %s\\ n", TVMGetLastError());\n ' )
372+ main_file .write (f'\t printf("{ AOT_FAILURE_TOKEN } \\ n");\n ' )
373+ main_file .write (f"\t return -1;\n " )
374+ main_file .write ("}\n " )
353375
354376
355377def _emit_main_fake_packed_values (main_file ):
@@ -447,13 +469,15 @@ def _emit_main_epilogue(main_file, custom_epilogue):
447469 main_file .write ("}\n " )
448470
449471
450- def _emit_main_common_includes (main_file , custom_includes ):
472+ def _emit_main_common_includes (main_file , custom_includes , debug_last_error ):
451473 main_file .write ("#include <stdio.h>\n " )
452474 main_file .write ("#include <stdarg.h>\n " )
453475 main_file .write ("#include <stdlib.h>\n " )
454476 main_file .write ("#include <math.h>\n " )
455477 main_file .write ('#include "tvm/runtime/c_runtime_api.h"\n ' )
456478 main_file .write ('#include "tvm/runtime/crt/stack_allocator.h"\n ' )
479+ if debug_last_error :
480+ main_file .write ('#include "tvm/runtime/crt/module.h"\n ' )
457481 for include in custom_includes :
458482 main_file .write (f'#include "{ include } "\n ' )
459483
@@ -474,12 +498,13 @@ def _create_main(
474498 workspace_bytes ,
475499 use_stack_allocator = True ,
476500 use_workspace_io = False ,
501+ debug_last_error = False ,
477502):
478503 file_path = pathlib .Path (f"{ output_path } /" + test_name ).resolve ()
479504 # create header file
480505 raw_path = file_path .with_suffix (".c" ).resolve ()
481506 with open (raw_path , "w" ) as main_file :
482- _emit_main_common_includes (main_file , custom_includes )
507+ _emit_main_common_includes (main_file , custom_includes , debug_last_error )
483508
484509 if interface_api == "c" :
485510 for compiled_model in compiled_models :
@@ -497,6 +522,7 @@ def _create_main(
497522 compiled_models ,
498523 interface_api ,
499524 use_stack_allocator ,
525+ debug_last_error ,
500526 )
501527 if use_stack_allocator :
502528 _emit_main_init_memory_manager (main_file )
@@ -529,6 +555,7 @@ def _create_main(
529555 list (workspace_pool_names .keys ()),
530556 model .name ,
531557 use_workspace_io ,
558+ debug_last_error ,
532559 )
533560 else :
534561 _emit_main_fake_packed_values (main_file )
@@ -701,6 +728,8 @@ def run_and_check(
701728 test_dir : str = None ,
702729 verbose : bool = False ,
703730 use_workspace_io : bool = False ,
731+ debug_last_error : bool = False ,
732+ checker : Optional [Callable [[str ], bool ]] = None ,
704733):
705734 """
706735 This method uses the original test data and compiled runtime.Modules
@@ -780,8 +809,12 @@ def run_and_check_body(base_path):
780809 workspace_bytes ,
781810 use_stack_allocator ,
782811 use_workspace_io ,
812+ debug_last_error ,
783813 )
784814
815+ if checker and (not checker (base_path )):
816+ return False
817+
785818 # Verify that compiles fine
786819 file_dir = os .path .dirname (os .path .abspath (__file__ ))
787820 makefile_dir = os .path .join (file_dir , "../../../tests/python/relay/aot" )
@@ -829,11 +862,13 @@ def run_and_check_body(base_path):
829862 with open (run_log_path ) as run_log :
830863 assert AOT_SUCCESS_TOKEN in run_log .read ()
831864
865+ return True
866+
832867 if test_dir is None :
833868 tmpdir = utils .tempdir ()
834- run_and_check_body (os .path .join (tmpdir .path , "test" ))
869+ return run_and_check_body (os .path .join (tmpdir .path , "test" ))
835870 else :
836- run_and_check_body (test_dir )
871+ return run_and_check_body (test_dir )
837872
838873
839874def compile_and_run (
@@ -852,7 +887,9 @@ def compile_and_run(
852887 test_dir : str = None ,
853888 verbose : bool = False ,
854889 schedule_name : str = None ,
855- ):
890+ debug_last_error : bool = False ,
891+ checker : Optional [Callable [[str ], bool ]] = None ,
892+ ) -> bool :
856893 """This is a wrapper API to compile and run models as test for AoT
857894
858895 Parameters
@@ -883,7 +920,7 @@ def compile_and_run(
883920 schedule_name = schedule_name ,
884921 )
885922
886- run_and_check (
923+ return run_and_check (
887924 models = compiled_test_mods ,
888925 runner = runner ,
889926 interface_api = interface_api ,
@@ -893,6 +930,8 @@ def compile_and_run(
893930 data_linkage = data_linkage ,
894931 test_dir = test_dir ,
895932 verbose = verbose ,
933+ debug_last_error = debug_last_error ,
934+ checker = checker ,
896935 )
897936
898937
0 commit comments